Spaces:
Running
Running
| """ | |
evaluation.py
=============
Measures how well the retrieval and generation pipeline performs.

Two categories of metrics:

1. RETRIEVAL METRICS — How good is the search?
   - Recall@K: Is the correct chunk in the top K results?
   - MRR: Mean Reciprocal Rank — how high is the correct chunk ranked?
   - NDCG@K: Normalized Discounted Cumulative Gain — weighted ranking quality

2. GENERATION METRICS — How good is the answer?
   - ROUGE-L: Longest Common Subsequence overlap with the reference answer
   - BERTScore: Semantic similarity between the generated and reference answers
   - Citation Accuracy: Does the answer cite the correct source?

Usage:
    from src.evaluation import evaluate_retrieval, evaluate_generation
| """ | |
| import numpy as np | |
| from typing import List, Dict, Tuple, Optional | |
| from dataclasses import dataclass, field | |
| from src.utils import ChildChunk, UnifiedIndex, PaperResult | |
| from src.pipeline import hybrid_search, rerank_chunks | |
# ─── Data structures ─────────────────────────────────────────────────────────
@dataclass
class RetrievalMetrics:
    """Results from retrieval evaluation.

    Every rate is a mean over ``num_queries`` queries. The ``@dataclass``
    decorator is required: ``evaluate_retrieval`` builds instances with
    keyword arguments, which a plain class with annotated attributes does
    not accept.
    """

    recall_at_1: float = 0.0
    recall_at_5: float = 0.0
    recall_at_10: float = 0.0
    mrr: float = 0.0  # Mean Reciprocal Rank
    ndcg_at_10: float = 0.0
    num_queries: int = 0

    def __str__(self):
        return (
            f"Retrieval Metrics ({self.num_queries} queries):\n"
            f" Recall@1: {self.recall_at_1:.4f}\n"
            f" Recall@5: {self.recall_at_5:.4f}\n"
            f" Recall@10: {self.recall_at_10:.4f}\n"
            f" MRR: {self.mrr:.4f}\n"
            f" NDCG@10: {self.ndcg_at_10:.4f}"
        )
@dataclass
class GenerationMetrics:
    """Results from generation evaluation.

    The ``@dataclass`` decorator is required: ``evaluate_generation`` builds
    instances with keyword arguments, which a plain class with annotated
    attributes does not accept.
    """

    rouge_l_precision: float = 0.0
    rouge_l_recall: float = 0.0
    rouge_l_f1: float = 0.0
    bert_score_f1: float = 0.0
    citation_accuracy: float = 0.0  # fraction (0.0-1.0) of answers with correct citations
    num_examples: int = 0

    def __str__(self):
        return (
            f"Generation Metrics ({self.num_examples} examples):\n"
            f" ROUGE-L F1: {self.rouge_l_f1:.4f}\n"
            f" BERTScore F1: {self.bert_score_f1:.4f}\n"
            f" Citation Accuracy: {self.citation_accuracy:.4f}"
        )
@dataclass
class EvalExample:
    """A single evaluation example with query, expected evidence, and answer.

    Without ``@dataclass`` the un-defaulted fields would be bare annotations:
    the class could not be constructed with arguments, and ``example.query``
    would raise AttributeError.
    """

    query: str
    relevant_chunk_text: str  # the ground-truth evidence
    expected_answer: str = ""  # optional reference answer
    paper_title: str = ""  # optional title of the paper the evidence comes from
# ─── 1. RETRIEVAL EVALUATION ─────────────────────────────────────────────────
| def _dcg(relevances: List[int], k: int) -> float: | |
| """Discounted Cumulative Gain at K.""" | |
| relevances = relevances[:k] | |
| dcg = 0.0 | |
| for i, rel in enumerate(relevances): | |
| dcg += rel / np.log2(i + 2) # i+2 because log2(1) = 0 | |
| return dcg | |
def _ndcg(relevances: List[int], k: int) -> float:
    """Return NDCG@k for a ranked list of relevance scores.

    The DCG of the given ordering is divided by the DCG of the same scores
    sorted from most to least relevant (the "ideal" ordering). The result is
    1.0 when the list is already in that order, and 0.0 when the ideal DCG is
    zero (e.g. nothing in the list is relevant).
    """
    actual_gain = _dcg(relevances, k)
    best_gain = _dcg(sorted(relevances, reverse=True), k)
    if best_gain == 0:
        return 0.0
    return actual_gain / best_gain
def evaluate_retrieval(
    eval_examples: List[EvalExample],
    unified_indices: List[UnifiedIndex],
    use_reranker: bool = True,
    top_k: int = 10,
    relevance_threshold: float = 0.5,
    candidates_per_index: int = 20,
) -> RetrievalMetrics:
    """
    Evaluate the retrieval pipeline on a set of examples.

    For each query:
      1. Run hybrid search (FAISS + BM25) against every index, requesting up
         to ``candidates_per_index`` candidates from each.
      2. Rank the pooled candidates:
         - when ``use_reranker`` is true, with the CrossEncoder reranker;
         - otherwise by merging the per-index result lists round-robin
           (rank 1 of every index, then rank 2, ...).
      3. Mark each of the top ``top_k`` chunks relevant when its word-level
         Jaccard overlap with the ground-truth evidence exceeds
         ``relevance_threshold``.
      4. Compute Recall@1/5/10, MRR and NDCG@10.

    Only the top ``top_k`` chunks are judged. With ``top_k < 10`` the "@10"
    metrics are therefore effectively computed at ``top_k``.

    Args:
        eval_examples: list of EvalExample with query + relevant_chunk_text
        unified_indices: the paper indices to search over
        use_reranker: whether to apply the CrossEncoder reranker
        top_k: number of ranked chunks to judge per query
        relevance_threshold: Jaccard overlap a chunk must exceed (strictly)
            to count as matching the evidence (0.5 = previous hard-coded value)
        candidates_per_index: how many candidates to request from each index
            (20 = previous hard-coded value)

    Returns:
        RetrievalMetrics averaged over ``eval_examples``. A query that yields
        no candidates scores 0 on every metric.
    """
    recalls_1: List[int] = []
    recalls_5: List[int] = []
    recalls_10: List[int] = []
    reciprocal_ranks: List[float] = []
    ndcg_scores: List[float] = []

    def _mean(values) -> float:
        return float(np.mean(values)) if values else 0.0

    for example in eval_examples:
        # Keep each index's results separate so the no-reranker path can merge
        # them fairly; the reranker gets the flat pool.
        per_index = [
            list(hybrid_search(example.query, index, top_k=candidates_per_index))
            for index in unified_indices
        ]
        all_candidates = [chunk for results in per_index for chunk in results]

        if not all_candidates:
            recalls_1.append(0)
            recalls_5.append(0)
            recalls_10.append(0)
            reciprocal_ranks.append(0)
            ndcg_scores.append(0)
            continue

        if use_reranker:
            ranked_chunks = rerank_chunks(example.query, all_candidates, top_n=top_k)
        else:
            # Slicing the concatenation would rank every result of the first
            # index above every result of the second, so with several indices
            # the baseline could only ever see the first paper(s). Round-robin
            # keeps each index's own order while sharing the top slots.
            # NOTE(review): if hybrid_search scores are comparable across
            # indices, sorting by that score would be an even better baseline.
            ranked_chunks = _interleave_rankings(per_index)[:top_k]

        # Fuzzy match: a chunk is relevant when it overlaps the evidence enough.
        relevances = [
            1 if _text_overlap(chunk.text, example.relevant_chunk_text) > relevance_threshold else 0
            for chunk in ranked_chunks
        ]
        found_rank = relevances.index(1) + 1 if 1 in relevances else None  # 1-indexed

        recalls_1.append(1 if found_rank is not None and found_rank <= 1 else 0)
        recalls_5.append(1 if found_rank is not None and found_rank <= 5 else 0)
        recalls_10.append(1 if found_rank is not None and found_rank <= 10 else 0)
        reciprocal_ranks.append(1.0 / found_rank if found_rank else 0.0)
        ndcg_scores.append(_ndcg(relevances, 10))

    return RetrievalMetrics(
        recall_at_1=_mean(recalls_1),
        recall_at_5=_mean(recalls_5),
        recall_at_10=_mean(recalls_10),
        mrr=_mean(reciprocal_ranks),
        ndcg_at_10=_mean(ndcg_scores),
        num_queries=len(eval_examples),
    )


def _interleave_rankings(ranked_lists: List[list]) -> list:
    """Merge ranked lists round-robin, preserving each list's internal order.

    >>> _interleave_rankings([["a1", "a2", "a3"], ["b1"]])
    ['a1', 'b1', 'a2', 'a3']
    """
    from itertools import zip_longest

    gap = object()  # filler for exhausted lists; never identical to a real item
    return [
        item
        for tier in zip_longest(*ranked_lists, fillvalue=gap)
        for item in tier
        if item is not gap
    ]
| def _text_overlap(text_a: str, text_b: str) -> float: | |
| """ | |
| Compute word-level Jaccard overlap between two texts. | |
| Returns a float between 0 and 1. | |
| """ | |
| words_a = set(text_a.lower().split()) | |
| words_b = set(text_b.lower().split()) | |
| if not words_a or not words_b: | |
| return 0.0 | |
| intersection = words_a & words_b | |
| union = words_a | words_b | |
| return len(intersection) / len(union) | |
# ─── 2. GENERATION EVALUATION ────────────────────────────────────────────────
def evaluate_generation(
    predictions: List[str],
    references: List[str],
    source_papers: Optional[List[str]] = None
) -> GenerationMetrics:
    """
    Evaluate the quality of generated answers against reference answers.

    Metrics:
      - ROUGE-L: Measures overlap of the longest common subsequence.
        Good for checking factual coverage.
      - BERTScore: Uses BERT embeddings to measure semantic similarity.
        Catches paraphrases that ROUGE would miss.
      - Citation Accuracy: Checks whether each generated answer contains
        proper [SOURCE N: ...] citations; ``source_papers`` is forwarded to
        that check.

    If an optional scoring library is missing, the corresponding helper
    prints a warning and that metric is reported as 0.0.

    Args:
        predictions: generated answers.
        references: reference answers, aligned by position with predictions.
        source_papers: optional expected source paper per prediction.

    Returns:
        GenerationMetrics. All fields are zero when either input list is empty.

    Raises:
        ValueError: if ``predictions`` and ``references`` are both non-empty
            but differ in length. Previously ROUGE silently truncated to the
            shorter list while ``num_examples`` still reported
            ``len(predictions)``.
    """
    if not predictions or not references:
        return GenerationMetrics()
    if len(predictions) != len(references):
        raise ValueError(
            "predictions and references must be aligned: got "
            f"{len(predictions)} predictions and {len(references)} references"
        )

    rouge_scores = _compute_rouge_l(predictions, references)
    bert_f1 = _compute_bert_score(predictions, references)
    citation_acc = _compute_citation_accuracy(predictions, source_papers)

    return GenerationMetrics(
        rouge_l_precision=rouge_scores["precision"],
        rouge_l_recall=rouge_scores["recall"],
        rouge_l_f1=rouge_scores["f1"],
        bert_score_f1=bert_f1,
        citation_accuracy=citation_acc,
        num_examples=len(predictions),
    )
| def _compute_rouge_l(predictions: List[str], references: List[str]) -> Dict[str, float]: | |
| """ | |
| Compute ROUGE-L (Longest Common Subsequence) between predictions and references. | |
| """ | |
| try: | |
| from rouge_score import rouge_scorer | |
| scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True) | |
| precisions = [] | |
| recalls = [] | |
| f1s = [] | |
| for pred, ref in zip(predictions, references): | |
| scores = scorer.score(ref, pred) | |
| precisions.append(scores["rougeL"].precision) | |
| recalls.append(scores["rougeL"].recall) | |
| f1s.append(scores["rougeL"].fmeasure) | |
| return { | |
| "precision": np.mean(precisions), | |
| "recall": np.mean(recalls), | |
| "f1": np.mean(f1s) | |
| } | |
| except ImportError: | |
| print("Warning: rouge_score not installed. Skipping ROUGE-L.") | |
| return {"precision": 0.0, "recall": 0.0, "f1": 0.0} | |
| def _compute_bert_score(predictions: List[str], references: List[str]) -> float: | |
| """ | |
| Compute BERTScore F1 using the bert-score library. | |
| BERTScore computes token-level cosine similarity between | |
| contextual embeddings of prediction and reference tokens, | |
| then aggregates using greedy matching. | |
| """ | |
| try: | |
| from bert_score import score as bert_score_fn | |
| P, R, F1 = bert_score_fn( | |
| predictions, references, | |
| lang="en", | |
| verbose=False, | |
| rescale_with_baseline=True | |
| ) | |
| return float(F1.mean()) | |
| except ImportError: | |
| print("Warning: bert_score not installed. Skipping BERTScore.") | |
| return 0.0 | |
| except Exception as e: | |
| print(f"Warning: BERTScore failed: {e}") | |
| return 0.0 | |
| def _compute_citation_accuracy( | |
| predictions: List[str], | |
| source_papers: Optional[List[str]] = None | |
| ) -> float: | |
| """ | |
| Check if generated answers contain proper citations. | |
| Checks for: | |
| 1. Contains at least one [SOURCE N: ...] citation | |
| 2. If source_papers provided, checks if cited paper exists | |
| """ | |
| import re | |
| if not predictions: | |
| return 0.0 | |
| correct = 0 | |
| citation_pattern = re.compile(r'\[SOURCE\s+\d+:.*?\]', re.IGNORECASE) | |
| for pred in predictions: | |
| citations = citation_pattern.findall(pred) | |
| if citations: | |
| correct += 1 | |
| return correct / len(predictions) | |
# ─── 3. QUICK EVALUATION REPORT ──────────────────────────────────────────────
def run_full_evaluation(
    eval_examples: List[EvalExample],
    unified_indices: List[UnifiedIndex],
    generate_fn=None
) -> Dict[str, object]:
    """
    Run a complete evaluation of both retrieval and generation.

    Retrieval is evaluated twice, without and with the CrossEncoder
    reranker, so the reranker's contribution can be measured. Generation is
    evaluated only when ``generate_fn`` is supplied, and only on examples
    that carry an ``expected_answer``. Progress and metrics are printed.

    Args:
        eval_examples: test examples with queries and ground truth
        unified_indices: paper indices to search
        generate_fn: optional callable ``(query, indices) -> answer string``

    Returns:
        dict with keys:
          "retrieval_no_reranker":   RetrievalMetrics without reranking
          "retrieval_with_reranker": RetrievalMetrics with reranking
          "reranker_recall5_delta":  Recall@5 (with) minus Recall@5 (without)
          "generation":              GenerationMetrics; present only when
                                     generate_fn was given and at least one
                                     example had an expected answer
    """
    print("Evaluating retrieval pipeline...")
    retrieval_without_reranker = evaluate_retrieval(
        eval_examples, unified_indices, use_reranker=False
    )
    retrieval_with_reranker = evaluate_retrieval(
        eval_examples, unified_indices, use_reranker=True
    )

    print("\n--- Without Reranker ---")
    print(retrieval_without_reranker)
    print("\n--- With Reranker ---")
    print(retrieval_with_reranker)

    recall_improvement = retrieval_with_reranker.recall_at_5 - retrieval_without_reranker.recall_at_5
    print(f"\nReranker Recall@5 improvement: {recall_improvement:+.4f}")

    results: Dict[str, object] = {
        "retrieval_no_reranker": retrieval_without_reranker,
        "retrieval_with_reranker": retrieval_with_reranker,
        "reranker_recall5_delta": recall_improvement,
    }

    if generate_fn is not None:
        print("\nEvaluating generation pipeline...")
        predictions: List[str] = []
        references: List[str] = []
        source_papers: List[str] = []
        for example in eval_examples:
            if not example.expected_answer:
                continue  # nothing to score this example against
            predictions.append(generate_fn(example.query, unified_indices))
            references.append(example.expected_answer)
            source_papers.append(example.paper_title)

        if predictions:
            # Forward paper titles so the citation check can verify the cited
            # source; omit them entirely when no example provides one.
            gen_metrics = evaluate_generation(
                predictions,
                references,
                source_papers=source_papers if any(source_papers) else None,
            )
            print(gen_metrics)
            results["generation"] = gen_metrics

    return results