| """ | |
| Chunk Quality Evaluation Tool | |
| Evaluates RAG chunking quality using: | |
| - Recall@K: How many relevant chunks are retrieved | |
| - Precision@K: How many retrieved chunks are relevant | |
| - MRR (Mean Reciprocal Rank): Average rank of first relevant chunk | |
| """ | |
import asyncio
import json
import sys
from pathlib import Path

# Add backend to path
sys.path.insert(0, str(Path(__file__).parent.parent.parent))

import httpx  # noqa: E402

class ChunkEvaluator:
    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url
        self.questions_file = Path(__file__).parent / "evaluation_questions.json"

    def load_questions(self):
        """Load test questions"""
        with open(self.questions_file, encoding="utf-8") as f:
            return json.load(f)
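
    # The questions file is assumed to be a JSON list of objects shaped roughly like the
    # entry below (field names inferred from how evaluate() reads them; the values are
    # purely illustrative):
    # [
    #   {
    #     "id": 1,
    #     "question": "How does chunk overlap affect retrieval quality?",
    #     "category": "chunking",
    #     "expected_keywords": ["chunk", "overlap", "retrieval"]
    #   }
    # ]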

    async def search_chunks(self, question: str, top_k: int = 10):
        """Search for relevant chunks using the RAG system"""
        async with httpx.AsyncClient(timeout=60.0) as client:
            response = await client.post(
                f"{self.base_url}/api/search/",
                json={
                    "query": question,
                    "top_k": top_k,
                    "similarity_threshold": 0.0,  # Get all results
                },
            )
            if response.status_code == 200:
                return response.json()["results"]
            else:
                print(f"Error: {response.status_code} - {response.text}")
                return []
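
    # Note: the /api/search/ endpoint is assumed (based on how its response is used in
    # this script) to return JSON of the form
    #   {"results": [{"text": "...", "distance": 0.42, ...}, ...]}
    # where the "distance" value shown is illustrative; extra fields in each result are ignored.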

    def calculate_relevance(self, chunk_text: str, expected_keywords: list[str]) -> float:
        """
        Calculate relevance score based on keyword matching

        Returns: Float between 0 and 1
        """
        chunk_lower = chunk_text.lower()
        matches = sum(1 for keyword in expected_keywords if keyword.lower() in chunk_lower)
        return matches / len(expected_keywords) if expected_keywords else 0.0
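
    # Example (hypothetical keywords): with expected_keywords = ["chunk", "overlap", "embedding"],
    # a chunk that mentions "chunk" and "overlap" but not "embedding" scores 2 / 3 ≈ 0.67,
    # which clears the default relevance_threshold of 0.3 used below.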

    def calculate_metrics(
        self, results: list, expected_keywords: list[str], relevance_threshold: float = 0.3
    ):
        """
        Calculate Recall@K, Precision@K, and MRR

        Args:
            results: List of search results with 'text' and 'distance'
            expected_keywords: Keywords that should appear in relevant chunks
            relevance_threshold: Minimum relevance score for a chunk to count as relevant
        """
        # Calculate relevance for each result
        relevances = [
            self.calculate_relevance(result["text"], expected_keywords) for result in results
        ]

        # Determine which chunks are relevant
        relevant_flags = [rel >= relevance_threshold for rel in relevances]

        # Calculate metrics for different K values
        metrics = {}
        for k in [1, 3, 5, 10]:
            if len(results) < k:
                continue
            top_k_relevant = relevant_flags[:k]

            # Recall@K: proportion of relevant chunks retrieved in the top K
            # (relative to the relevant chunks found among the retrieved results,
            # since no gold relevance labels are available)
            recall_at_k = (
                sum(top_k_relevant) / max(1, sum(relevant_flags)) if any(relevant_flags) else 0.0
            )

            # Precision@K: proportion of the top K retrieved chunks that are relevant
            precision_at_k = sum(top_k_relevant) / k

            metrics[f"recall@{k}"] = recall_at_k
            metrics[f"precision@{k}"] = precision_at_k

        # Reciprocal rank of the first relevant result (1/rank); averaging this value
        # across questions in evaluate() yields the Mean Reciprocal Rank (MRR)
        first_relevant_idx = next((i for i, rel in enumerate(relevant_flags) if rel), None)
        mrr = 1.0 / (first_relevant_idx + 1) if first_relevant_idx is not None else 0.0
        metrics["mrr"] = mrr

        # Average relevance score
        metrics["avg_relevance"] = sum(relevances) / len(relevances) if relevances else 0.0

        return metrics

    async def evaluate(
        self, top_k: int = 10, relevance_threshold: float = 0.3, limit_questions: int | None = None
    ):
        """Run full evaluation"""
        questions = self.load_questions()
        if limit_questions:
            questions = questions[:limit_questions]

        print(f"Evaluating {len(questions)} questions with top_k={top_k}")
        print("=" * 80)

        all_metrics = []
        for q in questions:
            print(f"\nQ{q['id']}: {q['question']}")
            print(f"  Category: {q['category']}")
            keywords = ", ".join(q["expected_keywords"][:5])
            suffix = "..." if len(q["expected_keywords"]) > 5 else ""
            print(f"  Expected keywords: {keywords}{suffix}")

            # Search
            results = await self.search_chunks(q["question"], top_k=top_k)
            if not results:
                print("  WARNING: No results returned")
                continue

            # Calculate metrics
            metrics = self.calculate_metrics(results, q["expected_keywords"], relevance_threshold)
            all_metrics.append(metrics)

            # Display results
            print("  Metrics:")
            for metric_name, value in metrics.items():
                print(f"    {metric_name}: {value:.3f}")

            # Show top 3 chunks
            print("  Top 3 chunks:")
            for i, result in enumerate(results[:3]):
                relevance = self.calculate_relevance(result["text"], q["expected_keywords"])
                preview = result["text"][:100].replace("\n", " ")
                print(f"    {i + 1}. [Rel: {relevance:.2f}] {preview}...")

        # Calculate average metrics
        print("\n" + "=" * 80)
        print("OVERALL RESULTS")
        print("=" * 80)

        if all_metrics:
            avg_metrics = {}
            for key in all_metrics[0].keys():
                avg_metrics[key] = sum(m[key] for m in all_metrics) / len(all_metrics)

            print("\nAverage Metrics:")
            for metric_name, value in avg_metrics.items():
                print(f"  {metric_name}: {value:.3f}")
            return avg_metrics
        else:
            print("WARNING: No metrics calculated")
            return {}

async def main():
    evaluator = ChunkEvaluator()

    print("Starting Chunk Quality Evaluation")
    print("=" * 80)

    results = await evaluator.evaluate(top_k=10, relevance_threshold=0.3, limit_questions=3)

    print("\nEvaluation complete!")

    # Save results
    output_file = Path(__file__).parent / "evaluation_results.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    print(f"Results saved to: {output_file}")


if __name__ == "__main__":
    asyncio.run(main())
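
# To run this evaluation (assuming the search backend is reachable at
# http://localhost:8000 and evaluation_questions.json sits next to this script):
#   python <path-to-this-script>.py
# Per-question metrics are printed to stdout and the averaged metrics are written to
# evaluation_results.json in the same directory.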