| | |
| | """ |
| | Evaluation script for EyeWiki RAG system. |
| | |
| | Evaluates the system on a set of test questions and measures: |
| | - Retrieval recall (relevant sources retrieved) |
| | - Answer relevance (expected topics covered) |
| | - Source citation accuracy |
| | |
| | Usage: |
| | python scripts/evaluate.py |
| | python scripts/evaluate.py --questions tests/custom_questions.json |
| | python scripts/evaluate.py --output results/eval_results.json |
| | """ |
| |
|
| | import argparse |
| | import json |
| | import sys |
| | import time |
| | from pathlib import Path |
| | from typing import Dict, List, Any |
| |
|
| | from rich.console import Console |
| | from rich.progress import Progress, SpinnerColumn, TextColumn, BarColumn, TimeElapsedColumn |
| | from rich.table import Table |
| | from rich.panel import Panel |
| |
|
| | |
# Put the directory two levels above this file (the project root) at the front
# of sys.path so the `config` and `src` imports below resolve when the script
# is run directly (e.g. `python scripts/evaluate.py`).
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
| |
|
| | from config.settings import Settings |
| | from src.llm.ollama_client import OllamaClient |
| | from src.rag.query_engine import EyeWikiQueryEngine |
| | from src.rag.reranker import CrossEncoderReranker |
| | from src.rag.retriever import HybridRetriever |
| | from src.vectorstore.qdrant_store import QdrantStoreManager |
| |
|
| |
|
# Shared Rich console used for all terminal output in this script.
console = Console()
| |
|
| |
|
| | |
| | |
| | |
| |
|
def calculate_retrieval_recall(
    retrieved_sources: List[str],
    expected_sources: List[str],
) -> float:
    """
    Calculate retrieval recall.

    Recall = (# of distinct expected titles retrieved) / (# of distinct expected titles)

    Titles are compared case-insensitively after trimming surrounding
    whitespace. An expected title counts as retrieved when it is a substring
    of some retrieved title or vice versa (so "Glaucoma" matches
    "Primary Open-Angle Glaucoma"). Blank titles are ignored on both sides,
    because an empty string is a substring of everything and would otherwise
    match any title.

    Args:
        retrieved_sources: List of retrieved source titles
        expected_sources: List of expected source titles

    Returns:
        Recall score (0-1); 1.0 when there are no (non-blank) expected titles
    """
    if not expected_sources:
        return 1.0

    # Normalize and deduplicate; the denominator must be computed from the
    # same deduplicated set as the numerator or duplicates skew the score.
    expected = {title.strip().lower() for title in expected_sources} - {""}
    if not expected:
        return 1.0
    retrieved = {title.strip().lower() for title in retrieved_sources or ()} - {""}

    matches = sum(
        1
        for exp in expected
        if any(exp in ret or ret in exp for ret in retrieved)
    )
    return matches / len(expected)
| |
|
| |
|
def calculate_answer_relevance(
    answer: str,
    expected_topics: List[str],
) -> float:
    """
    Score an answer by the share of expected topic keywords it mentions.

    Each topic is searched for as a case-insensitive substring of the answer.
    Every entry of ``expected_topics`` counts once, including duplicates.

    Args:
        answer: Generated answer text
        expected_topics: Keywords/phrases the answer is expected to mention

    Returns:
        Fraction of topics found (0-1); 1.0 when no topics are expected
    """
    if not expected_topics:
        return 1.0

    haystack = answer.lower()
    hits = 0
    for topic in expected_topics:
        if topic.lower() in haystack:
            hits += 1
    return hits / len(expected_topics)
| |
|
| |
|
def calculate_citation_accuracy(
    answer: str,
    cited_sources: List[str],
    expected_sources: List[str],
) -> Dict[str, Any]:
    """
    Calculate citation accuracy metrics.

    Titles are compared case-insensitively after trimming whitespace, using
    a substring test in either direction. Duplicate titles (e.g. several
    chunks from the same article) are collapsed before scoring, and blank
    titles are ignored because they would match everything.

    - precision: share of distinct cited titles matching some expected title
    - recall: share of distinct expected titles matched by some cited title
    - f1: harmonic mean of precision and recall

    Args:
        answer: Generated answer text (scanned for explicit citation markers)
        cited_sources: Sources returned by system
        expected_sources: Expected sources

    Returns:
        Dictionary with ``has_explicit_citations`` (bool) and ``precision``,
        ``recall`` and ``f1`` floats in [0, 1]. All three scores are 0.0 when
        either list is empty after dropping blank titles.
    """
    # Heuristic: the answer counts as explicitly cited if it contains either marker.
    has_citations = "[Source:" in answer or "According to" in answer

    cited = {title.strip().lower() for title in cited_sources or ()} - {""}
    expected = {title.strip().lower() for title in expected_sources or ()} - {""}

    precision = recall = f1 = 0.0
    if cited and expected:

        def _matches(a: str, b: str) -> bool:
            return a in b or b in a

        # Count hits from each side separately. A single counter divided by
        # both set sizes can push recall above 1 when several cited titles
        # match the same expected title.
        cited_hits = sum(1 for c in cited if any(_matches(c, e) for e in expected))
        expected_hits = sum(1 for e in expected if any(_matches(c, e) for c in cited))

        precision = cited_hits / len(cited)
        recall = expected_hits / len(expected)
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)

    return {
        "has_explicit_citations": has_citations,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
| |
|
| |
|
| | |
| | |
| | |
| |
|
def evaluate_question(
    question_data: Dict[str, Any],
    query_engine: EyeWikiQueryEngine,
) -> Dict[str, Any]:
    """
    Evaluate a single question.

    Args:
        question_data: Test-case dict. ``id`` and ``question`` are required;
            ``expected_topics`` and ``expected_sources`` default to empty
            lists and ``category`` defaults to ``"unknown"``.
        query_engine: Query engine instance

    Returns:
        On success: id, question, category, answer, confidence, query_time,
        a ``metrics`` dict, a ``details`` dict and ``success=True``.
        If querying or scoring raises, the exception is caught and a dict
        with ``error`` and ``success=False`` is returned instead, so one bad
        question does not abort the whole run.
    """
    question_id = question_data["id"]
    question = question_data["question"]
    category = question_data.get("category", "unknown")
    # Optional keys: a missing list must not crash the whole evaluation run.
    expected_topics = question_data.get("expected_topics", [])
    expected_sources = question_data.get("expected_sources", [])

    # perf_counter is monotonic, so durations are immune to wall-clock jumps.
    start_time = time.perf_counter()
    try:
        response = query_engine.query(
            question=question,
            include_sources=True,
        )
        query_time = time.perf_counter() - start_time

        retrieved_sources = [s.title for s in response.sources]

        retrieval_recall = calculate_retrieval_recall(
            retrieved_sources, expected_sources
        )
        answer_relevance = calculate_answer_relevance(
            response.answer, expected_topics
        )
        citation_metrics = calculate_citation_accuracy(
            response.answer, retrieved_sources, expected_sources
        )

        # Per-item breakdown for reporting, using the same matching rules as
        # the metric functions (case-insensitive substring tests).
        answer_lower = response.answer.lower()
        topics_found, topics_missing = _split_by(
            expected_topics, lambda topic: topic.lower() in answer_lower
        )

        retrieved_normalized = [t.strip().lower() for t in retrieved_sources]
        sources_retrieved, sources_missing = _split_by(
            expected_sources,
            lambda source: _title_in_any(source, retrieved_normalized),
        )

        return {
            "id": question_id,
            "question": question,
            "category": category,
            "answer": response.answer,
            "confidence": response.confidence,
            "query_time": query_time,
            "metrics": {
                "retrieval_recall": retrieval_recall,
                "answer_relevance": answer_relevance,
                "citation_precision": citation_metrics["precision"],
                "citation_recall": citation_metrics["recall"],
                "citation_f1": citation_metrics["f1"],
            },
            "details": {
                "retrieved_sources": retrieved_sources,
                "expected_sources": expected_sources,
                "sources_retrieved": sources_retrieved,
                "sources_missing": sources_missing,
                "topics_found": topics_found,
                "topics_missing": topics_missing,
                "has_explicit_citations": citation_metrics["has_explicit_citations"],
            },
            "success": True,
        }

    except Exception as e:  # deliberate best-effort: record and keep going
        return {
            "id": question_id,
            "question": question,
            "category": category,
            "error": str(e),
            "query_time": time.perf_counter() - start_time,
            "success": False,
        }


def _split_by(items, predicate):
    """Split *items* into (matching, non-matching) lists, preserving order."""
    matching, rest = [], []
    for item in items:
        (matching if predicate(item) else rest).append(item)
    return matching, rest


def _title_in_any(title: str, candidates: List[str]) -> bool:
    """Return True if *title* fuzzily matches one of the normalized *candidates*.

    *candidates* must already be stripped and lowercased. The match is a
    case-insensitive substring test in either direction. Blank strings never
    match, since an empty string is a substring of everything.
    """
    needle = title.strip().lower()
    if not needle:
        return False
    return any(needle in cand or cand in needle for cand in candidates if cand)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def calculate_aggregate_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Summarise per-question results into run-level metrics.

    Only results whose ``success`` flag is set are averaged. Failed results
    contribute only to ``total_questions`` and ``failed``.

    Args:
        results: Result dicts as produced by ``evaluate_question``

    Returns:
        ``{"error": ...}`` when no question succeeded. Otherwise the question
        counts, a ``metrics`` dict of means plus ``citation_rate`` (the share
        of answers with explicit citations), and ``by_category`` holding each
        category's count and mean retrieval recall / answer relevance.
    """
    succeeded = [item for item in results if item["success"]]
    n_ok = len(succeeded)
    if n_ok == 0:
        return {"error": "No successful evaluations"}

    def mean(values):
        return sum(values) / n_ok

    per_metric = {
        name: mean(item["metrics"][name] for item in succeeded)
        for name in (
            "retrieval_recall",
            "answer_relevance",
            "citation_precision",
            "citation_recall",
            "citation_f1",
        )
    }

    with_citations = sum(
        1 for item in succeeded if item["details"]["has_explicit_citations"]
    )

    # Accumulate sums per category, then turn them into means.
    by_category: Dict[str, Dict[str, Any]] = {}
    for item in succeeded:
        bucket = by_category.setdefault(
            item["category"],
            {"count": 0, "retrieval_recall": 0, "answer_relevance": 0},
        )
        bucket["count"] += 1
        bucket["retrieval_recall"] += item["metrics"]["retrieval_recall"]
        bucket["answer_relevance"] += item["metrics"]["answer_relevance"]

    for bucket in by_category.values():
        bucket["retrieval_recall"] /= bucket["count"]
        bucket["answer_relevance"] /= bucket["count"]

    return {
        "total_questions": len(results),
        "successful": n_ok,
        "failed": len(results) - n_ok,
        "metrics": {
            **per_metric,
            "avg_confidence": mean(item["confidence"] for item in succeeded),
            "avg_query_time": mean(item["query_time"] for item in succeeded),
            "citation_rate": with_citations / n_ok,
        },
        "by_category": by_category,
    }
| |
|
| |
|
| | |
| | |
| | |
| |
|
def print_question_result(result: Dict[str, Any]):
    """Print the outcome of a single evaluated question to the console.

    Failed evaluations print the question and the captured error. Successful
    ones print a status line, a metrics table, and any missing topics or
    sources. The status is PASS when the mean of retrieval recall and answer
    relevance is >= 0.8, PARTIAL when it is >= 0.6, and FAIL otherwise.

    Args:
        result: A result dict as produced by ``evaluate_question``.
    """
    # Question text, error messages, topics and source titles come from the
    # test file or from exceptions. Escape them before embedding them in Rich
    # markup: text such as "[/x]" would otherwise raise MarkupError and abort
    # the whole run.
    from rich.markup import escape

    question_id = escape(str(result["id"]))
    question = escape(str(result["question"]))

    if not result["success"]:
        console.print(
            f"\n[red]✗ {question_id}: {question}[/red]",
            f"[red]Error: {escape(str(result['error']))}[/red]",
        )
        return

    metrics = result["metrics"]

    table = Table(show_header=False, box=None, padding=(0, 1))
    table.add_column(style="cyan")
    table.add_column(style="yellow")
    for label, value in (
        ("Retrieval Recall", f"{metrics['retrieval_recall']:.2%}"),
        ("Answer Relevance", f"{metrics['answer_relevance']:.2%}"),
        ("Citation F1", f"{metrics['citation_f1']:.2%}"),
        ("Confidence", f"{result['confidence']:.2%}"),
        ("Query Time", f"{result['query_time']:.2f}s"),
    ):
        table.add_row(label, value)

    # The overall status considers only retrieval recall and answer relevance.
    avg_score = (metrics["retrieval_recall"] + metrics["answer_relevance"]) / 2
    if avg_score >= 0.8:
        status = "[green]✓ PASS[/green]"
    elif avg_score >= 0.6:
        status = "[yellow]~ PARTIAL[/yellow]"
    else:
        status = "[red]✗ FAIL[/red]"

    console.print(f"\n{status} [bold]{question_id}:[/bold] {question}")
    console.print(table)

    details = result["details"]
    if details["topics_missing"]:
        missing_topics = ", ".join(escape(str(t)) for t in details["topics_missing"])
        console.print(f"  [dim]Missing topics: {missing_topics}[/dim]")
    if details["sources_missing"]:
        missing_sources = ", ".join(escape(str(s)) for s in details["sources_missing"])
        console.print(f"  [dim]Missing sources: {missing_sources}[/dim]")
| |
|
| |
|
def print_aggregate_results(aggregate: Dict[str, Any]):
    """Print the run-level summary: graded metrics, statistics and per-category scores.

    Args:
        aggregate: Dict as produced by ``calculate_aggregate_metrics``. When it
            carries no ``metrics`` (the "no successful evaluations" case), only
            the error message is printed.
    """
    # Category names and error text come from data files / exceptions; escape
    # them so stray "[...]" sequences are not parsed as Rich markup.
    from rich.markup import escape

    console.print("\n")
    console.print(
        Panel.fit(
            "[bold cyan]Evaluation Summary[/bold cyan]",
            border_style="cyan",
        )
    )

    # calculate_aggregate_metrics returns only {"error": ...} when every
    # question failed; indexing "metrics" would raise KeyError here.
    if "metrics" not in aggregate:
        message = aggregate.get("error", "No metrics available")
        console.print(f"[red]{escape(str(message))}[/red]")
        return

    metrics = aggregate["metrics"]

    def get_grade(score: float) -> str:
        if score >= 0.9:
            return "[green]A[/green]"
        elif score >= 0.8:
            return "[green]B[/green]"
        elif score >= 0.7:
            return "[yellow]C[/yellow]"
        elif score >= 0.6:
            return "[yellow]D[/yellow]"
        else:
            return "[red]F[/red]"

    table = Table(show_header=True, header_style="bold magenta")
    table.add_column("Metric", style="cyan")
    table.add_column("Score", style="yellow", justify="right")
    table.add_column("Grade", style="green", justify="center")
    for label, key in (
        ("Retrieval Recall", "retrieval_recall"),
        ("Answer Relevance", "answer_relevance"),
        ("Citation Precision", "citation_precision"),
        ("Citation Recall", "citation_recall"),
        ("Citation F1", "citation_f1"),
    ):
        score = metrics[key]
        table.add_row(label, f"{score:.2%}", get_grade(score))

    console.print(table)

    console.print("\n[bold]Statistics:[/bold]")
    # sep="\n": each statistic goes on its own line (Console.print joins
    # positional objects with a single space by default).
    console.print(
        f"  Total Questions: {aggregate['total_questions']}",
        f"  Successful: [green]{aggregate['successful']}[/green]",
        f"  Failed: [red]{aggregate['failed']}[/red]",
        f"  Avg Confidence: {metrics['avg_confidence']:.2%}",
        f"  Avg Query Time: {metrics['avg_query_time']:.2f}s",
        f"  Citation Rate: {metrics['citation_rate']:.2%}",
        sep="\n",
    )

    if aggregate["by_category"]:
        console.print("\n[bold]Performance by Category:[/bold]")
        cat_table = Table(show_header=True, header_style="bold magenta")
        cat_table.add_column("Category", style="cyan")
        cat_table.add_column("Count", justify="right")
        cat_table.add_column("Retrieval", justify="right")
        cat_table.add_column("Relevance", justify="right")

        for category, data in sorted(aggregate["by_category"].items()):
            cat_table.add_row(
                escape(str(category)),
                str(data["count"]),
                f"{data['retrieval_recall']:.2%}",
                f"{data['answer_relevance']:.2%}",
            )

        console.print(cat_table)
| |
|
| |
|
| | |
| | |
| | |
| |
|
def load_test_questions(questions_file: Path) -> List[Dict[str, Any]]:
    """Load test questions from a JSON file.

    The file must contain a top-level JSON array of question dicts. On a
    missing file, unreadable/invalid JSON, or a non-array top level, a red
    error is printed and the process exits with status 1.

    Args:
        questions_file: Path to the questions JSON file.

    Returns:
        The list of question dicts as loaded from the file.
    """
    if not questions_file.exists():
        console.print(f"[red]Error: Questions file not found: {questions_file}[/red]")
        sys.exit(1)

    try:
        # Explicit UTF-8 so decoding does not depend on the platform's
        # default locale encoding.
        with open(questions_file, "r", encoding="utf-8") as f:
            questions = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        console.print(f"[red]Error: Could not parse {questions_file}: {e}[/red]")
        sys.exit(1)

    # Callers iterate the result as question dicts; any other top level
    # would fail later with a confusing error.
    if not isinstance(questions, list):
        console.print(
            f"[red]Error: Expected a JSON array of questions in {questions_file}, "
            f"got {type(questions).__name__}[/red]"
        )
        sys.exit(1)

    console.print(f"[green]✓[/green] Loaded {len(questions)} test questions")
    return questions
| |
|
| |
|
def initialize_system(*, retrieval_k: int = 20, rerank_k: int = 5) -> EyeWikiQueryEngine:
    """Build the RAG query engine from the application settings.

    Wires together the Ollama client, the Qdrant store manager, a hybrid
    retriever and a cross-encoder reranker. The prompt files
    ``system_prompt.txt``, ``query_prompt.txt`` and ``medical_disclaimer.txt``
    under ``<project_root>/prompts`` are passed to the engine only if they
    exist; otherwise ``None`` is passed for that prompt.

    Args:
        retrieval_k: Passed through as the engine's ``retrieval_k``
            (presumably the number of candidates fetched before reranking —
            confirm in EyeWikiQueryEngine). Defaults to 20.
        rerank_k: Passed through as the engine's ``rerank_k`` (presumably the
            number of chunks kept after reranking). Defaults to 5.

    Returns:
        The configured ``EyeWikiQueryEngine``.
    """
    console.print("[bold]Initializing RAG system...[/bold]")

    settings = Settings()

    ollama_client = OllamaClient(
        base_url=settings.ollama_base_url,
        llm_model=settings.llm_model,
        embedding_model=settings.embedding_model,
    )

    qdrant_manager = QdrantStoreManager(
        collection_name=settings.qdrant_collection_name,
        qdrant_path=settings.qdrant_path,
        vector_size=settings.embedding_dim,
    )

    retriever = HybridRetriever(
        qdrant_manager=qdrant_manager,
        ollama_client=ollama_client,
    )

    reranker = CrossEncoderReranker(
        model_name=settings.reranker_model,
    )

    prompts_dir = project_root / "prompts"

    def existing_or_none(path: Path):
        """Return *path* if the file exists, else None (prompt files are optional)."""
        return path if path.exists() else None

    query_engine = EyeWikiQueryEngine(
        retriever=retriever,
        reranker=reranker,
        llm_client=ollama_client,
        system_prompt_path=existing_or_none(prompts_dir / "system_prompt.txt"),
        query_prompt_path=existing_or_none(prompts_dir / "query_prompt.txt"),
        disclaimer_path=existing_or_none(prompts_dir / "medical_disclaimer.txt"),
        max_context_tokens=settings.max_context_tokens,
        retrieval_k=retrieval_k,
        rerank_k=rerank_k,
    )

    console.print("[green]✓[/green] System initialized\n")
    return query_engine
| |
|
| |
|
def run_evaluation(
    questions_file: Path,
    output_file: Path = None,
    verbose: bool = False,
):
    """
    Run evaluation on test questions.

    Loads the questions, builds the RAG system, evaluates each question
    behind a progress bar, prints per-question and summary reports, and
    optionally writes the raw results plus aggregate metrics as JSON.

    Args:
        questions_file: Path to test questions JSON
        output_file: Optional path to save results (None to skip saving)
        verbose: Print each question's result as soon as it is evaluated
            (otherwise all are printed after the run)
    """
    console.print(
        Panel.fit(
            "[bold blue]EyeWiki RAG Evaluation[/bold blue]",
            border_style="blue",
        )
    )

    questions = load_test_questions(questions_file)
    query_engine = initialize_system()

    results = []
    console.print("[bold]Evaluating questions...[/bold]\n")

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
        TimeElapsedColumn(),
        console=console,
    ) as progress:
        task = progress.add_task("Processing...", total=len(questions))

        for question_data in questions:
            result = evaluate_question(question_data, query_engine)
            results.append(result)

            if verbose:
                print_question_result(result)

            progress.update(task, advance=1)

    aggregate = calculate_aggregate_metrics(results)

    # Save in `finally` so a failure (or Ctrl+C) while rendering the report
    # cannot lose the results of a potentially long evaluation run.
    try:
        if not verbose:
            console.print("\n[bold]Per-Question Results:[/bold]")
            for result in results:
                print_question_result(result)

        print_aggregate_results(aggregate)
    finally:
        if output_file:
            _write_results_json(output_file, results, aggregate)


def _write_results_json(
    output_file: Path,
    results: List[Dict[str, Any]],
    aggregate: Dict[str, Any],
) -> None:
    """Write per-question results, aggregate metrics and a Unix timestamp to *output_file*."""
    output_data = {
        "results": results,
        "aggregate": aggregate,
        "timestamp": time.time(),
    }

    output_file.parent.mkdir(parents=True, exist_ok=True)
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(output_data, f, indent=2)

    console.print(f"\n[green]✓[/green] Results saved to {output_file}")
| |
|
| |
|
def main():
    """Parse command-line arguments and run the evaluation.

    Exits with status 1 if the run is interrupted with Ctrl+C, or if it
    raises; in the latter case the error and its traceback are printed.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate EyeWiki RAG system on test questions"
    )

    parser.add_argument(
        "--questions",
        type=Path,
        default=project_root / "tests" / "test_questions.json",
        help="Path to test questions JSON file",
    )

    parser.add_argument(
        "--output",
        type=Path,
        default=None,
        help="Path to save evaluation results (JSON)",
    )

    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Print detailed results for each question",
    )

    args = parser.parse_args()

    try:
        run_evaluation(
            questions_file=args.questions,
            output_file=args.output,
            verbose=args.verbose,
        )
    except KeyboardInterrupt:
        console.print("\n[yellow]Evaluation interrupted by user[/yellow]")
        sys.exit(1)
    except Exception as e:
        import traceback

        # markup=False: the exception text is arbitrary and may contain
        # "[...]" sequences that Rich would parse as tags (possibly raising
        # MarkupError inside this handler and hiding the real error).
        console.print(f"\nError: {e}", style="red", markup=False)
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()
| |
|