#!/usr/bin/env python3
"""
Retrieval quality evaluation.

Usage:
    python scripts/eval_retrieval.py tests/eval_data/queries.json

Measures:
- Precision@k
- Recall@k
- Mean Reciprocal Rank (MRR)
"""
import sys
import json
from pathlib import Path
from dataclasses import dataclass
from typing import List, Set, Optional

sys.path.insert(0, str(Path(__file__).parent.parent))


@dataclass
class RetrievalMetrics:
    """Metrics for a single query."""
    query_id: str
    query: str
    precision_at_k: float
    recall_at_k: float
    reciprocal_rank: float
    retrieved_ids: List[str]
    relevant_found: List[str]
    relevant_missed: List[str]


@dataclass
class AggregateMetrics:
    """Aggregate metrics across all queries."""
    total_queries: int
    mean_precision: float
    mean_recall: float
    mrr: float  # Mean Reciprocal Rank
    queries_with_hits: int


def evaluate_single_query(
    query_id: str,
    query: str,
    relevant_chunks: Set[str],
    retrieved_chunks: List[str],
    k: int = 5
) -> RetrievalMetrics:
    """Evaluate retrieval for a single query."""
    top_k = retrieved_chunks[:k]
    top_k_set = set(top_k)

    # Precision@k: relevant in top-k / k
    relevant_in_top_k = top_k_set & relevant_chunks
    precision = len(relevant_in_top_k) / k if k > 0 else 0.0

    # Recall@k: relevant in top-k / total relevant
    recall = len(relevant_in_top_k) / len(relevant_chunks) if relevant_chunks else 0.0

    # Reciprocal Rank: 1 / rank of first relevant result (0.0 if none found)
    reciprocal_rank = 0.0
    for i, chunk_id in enumerate(top_k):
        if chunk_id in relevant_chunks:
            reciprocal_rank = 1.0 / (i + 1)
            break

    return RetrievalMetrics(
        query_id=query_id,
        query=query,
        precision_at_k=precision,
        recall_at_k=recall,
        reciprocal_rank=reciprocal_rank,
        retrieved_ids=top_k,
        relevant_found=list(relevant_in_top_k),
        relevant_missed=list(relevant_chunks - top_k_set)
    )


def run_retrieval_eval(
    queries_file: str,
    k: int = 5,
    use_mock: bool = False
) -> Optional[AggregateMetrics]:
    """Run retrieval evaluation from a queries file."""
    with open(queries_file, 'r') as f:
        data = json.load(f)

    queries = data.get("queries", [])
    if not queries:
        print("No queries found in file")
        return None

    # Import the retrieval function unless mocking
    if not use_mock:
        try:
            from src.retrieval.hybrid import hybrid_search
        except ImportError:
            print("Warning: Could not import hybrid_search, using mock")
            use_mock = True

    all_metrics = []

    print("\n" + "=" * 60)
    print("  RETRIEVAL QUALITY EVALUATION")
    print("=" * 60)

    for q in queries:
        query_id = q.get("id", "unknown")
        query_text = q.get("query", "")
        relevant = set(q.get("relevant_chunks", []))

        if not relevant:
            print(f"\nāš ļø Query {query_id}: No relevant chunks defined, skipping")
            continue

        print(f"\nšŸ“ Query {query_id}: {query_text[:50]}...")

        # Get retrieval results
        if use_mock:
            # Mock results for testing without Pinecone
            retrieved = list(relevant)[:k] + ["mock::0", "mock::1"]
        else:
            try:
                results = hybrid_search(query_text, top_k=k)
                retrieved = [r.get("id", "") for r in results]
            except Exception as e:
                print(f"  Error: {e}")
                retrieved = []

        # Evaluate
        metrics = evaluate_single_query(
            query_id=query_id,
            query=query_text,
            relevant_chunks=relevant,
            retrieved_chunks=retrieved,
            k=k
        )
        all_metrics.append(metrics)

        # Print per-query results
        print(f"  Precision@{k}: {metrics.precision_at_k:.2f}")
        print(f"  Recall@{k}: {metrics.recall_at_k:.2f}")
        print(f"  Reciprocal Rank: {metrics.reciprocal_rank:.2f}")
        if metrics.relevant_found:
            print(f"  āœ… Found: {metrics.relevant_found}")
        if metrics.relevant_missed:
            print(f"  āŒ Missed: {metrics.relevant_missed}")

    # Aggregate
    if not all_metrics:
        print("\nNo queries evaluated")
        return None

    aggregate = AggregateMetrics(
        total_queries=len(all_metrics),
        mean_precision=sum(m.precision_at_k for m in all_metrics) / len(all_metrics),
        mean_recall=sum(m.recall_at_k for m in all_metrics) / len(all_metrics),
        mrr=sum(m.reciprocal_rank for m in all_metrics) / len(all_metrics),
        queries_with_hits=sum(1 for m in all_metrics if m.reciprocal_rank > 0)
    )

    # Print summary
    print("\n" + "-" * 60)
    print("  SUMMARY")
    print("-" * 60)
    print(f"  Total queries: {aggregate.total_queries}")
    print(f"  Mean Precision@{k}: {aggregate.mean_precision:.2f}")
    print(f"  Mean Recall@{k}: {aggregate.mean_recall:.2f}")
    print(f"  MRR: {aggregate.mrr:.2f}")
    print(f"  Queries with hits: {aggregate.queries_with_hits}/{aggregate.total_queries}")

    # Quality assessment
    print("\nšŸ“Š Quality Assessment")
    if aggregate.mean_precision >= 0.6:
        print("  āœ… Precision: GOOD (≄60%)")
    elif aggregate.mean_precision >= 0.4:
        print("  āš ļø Precision: FAIR (40-60%)")
    else:
        print("  āŒ Precision: POOR (<40%)")

    if aggregate.mrr >= 0.5:
        print("  āœ… MRR: GOOD (≄0.5)")
    elif aggregate.mrr >= 0.3:
        print("  āš ļø MRR: FAIR (0.3-0.5)")
    else:
        print("  āŒ MRR: POOR (<0.3)")

    return aggregate


if __name__ == "__main__":
    if len(sys.argv) < 2:
        print("Usage: python scripts/eval_retrieval.py queries.json [--mock]")
        print("\nExample:")
        print("  python scripts/eval_retrieval.py tests/eval_data/queries.json")
        print("  python scripts/eval_retrieval.py tests/eval_data/queries.json --mock")
        sys.exit(1)

    queries_file = sys.argv[1]
    use_mock = "--mock" in sys.argv
    k = 5

    # Parse --k=N if provided
    for arg in sys.argv:
        if arg.startswith("--k="):
            k = int(arg.split("=")[1])

    if not Path(queries_file).exists():
        print(f"Error: File not found: {queries_file}")
        sys.exit(1)

    metrics = run_retrieval_eval(queries_file, k=k, use_mock=use_mock)

    # Non-zero exit if mean precision falls below the POOR threshold
    if metrics and metrics.mean_precision < 0.4:
        sys.exit(1)
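
# Example queries file (illustrative sketch only): run_retrieval_eval() expects
# a top-level "queries" list where each entry has an "id", the "query" text,
# and the ground-truth "relevant_chunks" ids that retrieval should surface.
# The ids and query below are made-up placeholders, not taken from the real
# tests/eval_data/queries.json; use whatever id scheme your chunker produces
# (the mock path above assumes a "<source>::<index>" style).
#
# {
#   "queries": [
#     {
#       "id": "q001",
#       "query": "How does the hybrid retriever combine dense and sparse scores?",
#       "relevant_chunks": ["docs/retrieval.md::2", "docs/retrieval.md::3"]
#     }
#   ]
# }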