File size: 3,976 Bytes
36b2bff | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 | #!/usr/bin/env python3
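"""Compare embedding models on domain-specific retrieval queries.

Runs the evaluation query set against the currently configured embedding
model (``settings.rag_embedding_model``) and reports retrieval quality
metrics. Each run evaluates a single model; to compare models, run the
script once per model configuration and compare the saved results
(``model_comparison_results.json`` next to this script, overwritten on
every run). Used to validate the model choice before deployment (Gate 1).

Usage:
    python scripts/evaluation/run_model_comparison.py
"""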
import json
import logging
import sys
from pathlib import Path
# Module logger; configured below to emit plain "LEVEL: message" lines.
logger = logging.getLogger(__name__)
logging.basicConfig(format="%(levelname)s: %(message)s", level=logging.INFO)
def _top_category(hit, default=None):
    """Return the ``category`` stored in a retrieved hit's ``metadata``.

    Args:
        hit: One item returned by ``rag.retrieve`` (accessed as a mapping).
        default: Value returned when the hit carries no category.

    A missing *or* ``None`` ``metadata`` field is treated as empty, so a
    sparse hit is reported as uncategorised instead of raising.
    """
    metadata = hit.get("metadata") or {}
    return metadata.get("category", default)


def _load_queries(eval_path):
    """Load the evaluation query list from *eval_path*.

    Exits the process with status 1 if the file is missing or is not valid
    JSON, so the script fails cleanly instead of with a raw traceback.
    """
    if not eval_path.exists():
        logger.error(f"Evaluation queries not found: {eval_path}")
        sys.exit(1)
    try:
        return json.loads(eval_path.read_text(encoding="utf-8"))
    except json.JSONDecodeError as e:
        logger.error(f"Invalid JSON in evaluation queries {eval_path}: {e}")
        sys.exit(1)


def _init_rag(project_root):
    """Import and initialise the project's RAG service.

    Args:
        project_root: Repository root, added to ``sys.path`` so the ``app``
            package is importable when this file is run as a plain script.

    Returns:
        ``(settings, rag)`` — the loaded settings and the initialised service.

    Exits the process with status 1 if any import or initialisation step fails.
    """
    try:
        import asyncio

        sys.path.insert(0, str(project_root))
        from app.core.config import get_settings
        from app.services.rag import RAGService

        settings = get_settings()
        rag = RAGService(settings)
        # initialize() is a coroutine; drive it to completion synchronously.
        asyncio.run(rag.initialize())
        logger.info(f"RAG service initialized with model: {settings.rag_embedding_model}")
    except Exception as e:
        # Top-level boundary of a CLI script: log with traceback, then abort.
        logger.exception(f"Failed to initialize RAG: {e}")
        sys.exit(1)
    return settings, rag


def _evaluate(rag, queries, n_results=5):
    """Run every query through ``rag.retrieve`` and classify the outcome.

    Each query dict must have a ``"query"`` key. Queries with an
    ``expected_category`` are "relevant" and are judged on the top hit's
    category: CORRECT, WRONG_CATEGORY, or NO_RESULTS. Queries without one
    (absent or null) are "irrelevant": CORRECT_EMPTY when nothing is
    retrieved, otherwise FALSE_POSITIVE.

    Returns:
        A list of per-query result dicts, in input order.
    """
    results = []
    for q in queries:
        query_text = q["query"]
        expected_cat = q.get("expected_category")
        retrieved = rag.retrieve(query_text, n_results=n_results)

        if expected_cat is None:
            # Any retrieved hit counts as a false positive; relevance scores
            # are not inspected here.
            if not retrieved:
                results.append({"query": query_text, "status": "CORRECT_EMPTY", "expected": None})
            else:
                results.append(
                    {
                        "query": query_text,
                        "status": "FALSE_POSITIVE",
                        "top_category": _top_category(retrieved[0]),
                    }
                )
            continue

        if not retrieved:
            results.append({"query": query_text, "status": "NO_RESULTS", "expected": expected_cat})
            continue

        top_cat = _top_category(retrieved[0], "")
        if top_cat == expected_cat:
            results.append({"query": query_text, "status": "CORRECT", "top_category": top_cat})
        else:
            results.append(
                {"query": query_text, "status": "WRONG_CATEGORY", "expected": expected_cat, "got": top_cat}
            )
    return results


def _log_report(model, queries, results):
    """Log summary metrics for relevant and irrelevant queries, then list each failure."""
    from collections import Counter

    counts = Counter(r["status"] for r in results)
    correct = counts["CORRECT"]
    with_results = correct + counts["WRONG_CATEGORY"]
    relevant = with_results + counts["NO_RESULTS"]
    correct_empty = counts["CORRECT_EMPTY"]
    irrelevant = correct_empty + counts["FALSE_POSITIVE"]

    logger.info("\n=== MODEL COMPARISON RESULTS ===")
    logger.info(f"Model: {model}")
    logger.info(f"Total queries: {len(queries)}")
    logger.info(f"Relevant queries: {relevant}")
    logger.info(f"Queries with results: {with_results}")
    # max(..., 1) avoids division by zero when a category is empty.
    logger.info(f"Correct top category: {correct}/{relevant} ({correct / max(relevant, 1) * 100:.1f}%)")
    logger.info(f"Irrelevant queries: {irrelevant}")
    logger.info(
        f"Correctly returned no results: {correct_empty}/{irrelevant} "
        f"({correct_empty / max(irrelevant, 1) * 100:.1f}%)"
    )
    logger.info(f"False positives: {counts['FALSE_POSITIVE']}")

    failures = [r for r in results if r["status"] in ("WRONG_CATEGORY", "NO_RESULTS", "FALSE_POSITIVE")]
    if failures:
        logger.info(f"\nFailures ({len(failures)}):")
        for r in failures:
            query = r["query"]
            shown = query if len(query) <= 60 else f"{query[:60]}..."
            # FALSE_POSITIVE records store the hit under "top_category", not "got".
            got = r.get("got", r.get("top_category", "N/A"))
            logger.info(f"  [{r['status']}] {shown} expected={r.get('expected')}, got={got}")


def main():
    """Evaluate the configured embedding model and save per-query results.

    Loads ``tests/evaluation/eval_queries.json`` from the repository root,
    initialises the RAG service, classifies each query's retrieval outcome,
    logs a summary, and writes ``model_comparison_results.json`` next to
    this script (overwriting any previous run).
    """
    # This script lives at <root>/scripts/evaluation/; resolve() makes the
    # root correct even when __file__ is relative.
    script_dir = Path(__file__).resolve().parent
    project_root = script_dir.parents[1]

    queries = _load_queries(project_root / "tests" / "evaluation" / "eval_queries.json")
    logger.info(f"Loaded {len(queries)} evaluation queries")

    settings, rag = _init_rag(project_root)
    results = _evaluate(rag, queries)
    _log_report(settings.rag_embedding_model, queries, results)

    output_path = script_dir / "model_comparison_results.json"
    output_path.write_text(
        json.dumps({"model": settings.rag_embedding_model, "results": results}, indent=2),
        encoding="utf-8",
    )
    logger.info(f"\nResults saved to {output_path}")


if __name__ == "__main__":
    main()
|