File size: 3,976 Bytes
36b2bff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
#!/usr/bin/env python3
"""Compare embedding models on domain-specific retrieval queries.

Runs the evaluation query set against the loaded embedding model and
reports retrieval quality metrics. Used to validate model choice
before deployment (Gate 1).

Usage:
    python scripts/evaluation/run_model_comparison.py
"""

import json
import logging
import sys
from pathlib import Path

# Script-level logging: INFO and above, printed as "LEVEL: message".
# NOTE: this runs at import time, so importing this module configures the
# root logger (basicConfig is a no-op if the root logger already has handlers).
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)


def _load_queries(eval_path):
    """Load the evaluation query set from a JSON file.

    Args:
        eval_path: Path to the JSON file. Each entry is read as a dict with a
            ``"query"`` string and an optional ``"expected_category"``; a
            missing/``None`` category marks a query that should retrieve nothing.

    Returns:
        The parsed JSON value (iterated by the caller as a list of query dicts).

    Exits the process with status 1 if the file does not exist.
    """
    if not eval_path.exists():
        logger.error(f"Evaluation queries not found: {eval_path}")
        sys.exit(1)

    # Explicit UTF-8: the default is the locale encoding, which can mis-decode
    # non-ASCII queries on some platforms.
    queries = json.loads(eval_path.read_text(encoding="utf-8"))
    logger.info(f"Loaded {len(queries)} evaluation queries")
    return queries


def _init_rag(project_root):
    """Import and initialize the application's RAG service.

    Args:
        project_root: Repository root; prepended to ``sys.path`` so the
            ``app`` package is importable when this script is run directly.

    Returns:
        ``(settings, rag)`` — the app settings and the initialized ``RAGService``.

    Exits the process with status 1 (after logging the full traceback) if the
    import or initialization fails.
    """
    import asyncio

    try:
        sys.path.insert(0, str(project_root))
        from app.core.config import get_settings
        from app.services.rag import RAGService

        settings = get_settings()
        rag = RAGService(settings)
        # initialize() is awaited via asyncio.run; the rest of this script is synchronous.
        asyncio.run(rag.initialize())
        logger.info(f"RAG service initialized with model: {settings.rag_embedding_model}")
    except Exception as e:
        # Top-level script boundary: keep the traceback so setup failures are debuggable.
        logger.exception(f"Failed to initialize RAG: {e}")
        sys.exit(1)
    return settings, rag


def _score_queries(rag, queries, n_results=5):
    """Run every evaluation query and classify the top retrieved result.

    Per-query status:
      * ``CORRECT_EMPTY`` / ``FALSE_POSITIVE`` — the query has no expected
        category and retrieval returned nothing / something.
      * ``NO_RESULTS`` — a query with an expected category retrieved nothing.
      * ``CORRECT`` / ``WRONG_CATEGORY`` — the top hit's
        ``metadata["category"]`` does / does not equal ``expected_category``.

    Args:
        rag: Initialized service exposing ``retrieve(text, n_results=...)``.
        queries: Iterable of query dicts (see ``_load_queries``).
        n_results: Number of results requested per query; only the first is scored.

    Returns:
        ``(results, counts)``: one result dict per query, and a dict of
        aggregate tallies used by the report.
    """
    results = []
    counts = {
        "relevant": 0,  # queries that should have results
        "with_results": 0,  # relevant queries that returned at least one hit
        "correct_category": 0,  # relevant queries whose top hit has the expected category
        "irrelevant": 0,  # queries that should return nothing
        "false_positives": 0,  # irrelevant queries that still returned hits
    }

    for q in queries:
        query_text = q["query"]
        expected_cat = q.get("expected_category")
        retrieved = rag.retrieve(query_text, n_results=n_results)

        if expected_cat is None:
            # Irrelevant query: only an empty result counts as correct here.
            counts["irrelevant"] += 1
            if not retrieved:
                results.append({"query": query_text, "status": "CORRECT_EMPTY", "expected": None})
            else:
                counts["false_positives"] += 1
                results.append(
                    {
                        "query": query_text,
                        "status": "FALSE_POSITIVE",
                        "top_category": retrieved[0].get("metadata", {}).get("category"),
                    }
                )
            continue

        counts["relevant"] += 1
        if not retrieved:
            results.append({"query": query_text, "status": "NO_RESULTS", "expected": expected_cat})
            continue

        counts["with_results"] += 1
        top_cat = retrieved[0].get("metadata", {}).get("category", "")
        if top_cat == expected_cat:
            counts["correct_category"] += 1
            results.append({"query": query_text, "status": "CORRECT", "top_category": top_cat})
        else:
            results.append({"query": query_text, "status": "WRONG_CATEGORY", "expected": expected_cat, "got": top_cat})

    return results, counts


def _shorten(text, limit=60):
    """Return *text* cut to *limit* characters, appending "..." only when it was cut."""
    return text if len(text) <= limit else text[:limit] + "..."


def _log_report(model_name, total_queries, results, counts):
    """Log aggregate metrics, then the individual misses, for one evaluation run."""
    relevant = counts["relevant"]
    correct = counts["correct_category"]

    logger.info("\n=== MODEL COMPARISON RESULTS ===")
    logger.info(f"Model: {model_name}")
    logger.info(f"Total queries: {total_queries}")
    logger.info(f"Relevant queries: {relevant}")
    logger.info(f"Queries with results: {counts['with_results']}")
    # max(..., 1) avoids division by zero when the set has no relevant queries.
    logger.info(f"Correct top category: {correct}/{relevant} ({correct / max(relevant, 1) * 100:.1f}%)")
    # Irrelevant queries are scored too; report how many still returned hits.
    if counts["irrelevant"]:
        logger.info(f"False positives on irrelevant queries: {counts['false_positives']}/{counts['irrelevant']}")

    failures = [r for r in results if r["status"] in ("WRONG_CATEGORY", "NO_RESULTS")]
    if failures:
        logger.info(f"\nFailures ({len(failures)}):")
        for failure in failures:
            logger.info(
                f"  [{failure['status']}] {_shorten(failure['query'])} "
                f"expected={failure.get('expected')}, got={failure.get('got', 'N/A')}"
            )

    false_positives = [r for r in results if r["status"] == "FALSE_POSITIVE"]
    if false_positives:
        logger.info(f"\nFalse positives ({len(false_positives)}):")
        for hit in false_positives:
            logger.info(f"  [FALSE_POSITIVE] {_shorten(hit['query'])} got={hit.get('top_category')}")


def main():
    """Evaluate the configured embedding model against the evaluation query set.

    Loads ``tests/evaluation/eval_queries.json`` from the repository root,
    initializes the app's RAG service, scores the top retrieved result of each
    query, logs a summary, and writes per-query results to
    ``model_comparison_results.json`` next to this script.

    Exits with status 1 if the query file is missing or RAG setup fails.
    """
    # This script lives in scripts/evaluation/, so the repo root is three levels up.
    # resolve() keeps this correct when __file__ is a relative path.
    script_path = Path(__file__).resolve()
    project_root = script_path.parent.parent.parent

    queries = _load_queries(project_root / "tests" / "evaluation" / "eval_queries.json")
    settings, rag = _init_rag(project_root)
    model_name = settings.rag_embedding_model

    results, counts = _score_queries(rag, queries)
    _log_report(model_name, len(queries), results, counts)

    output_path = script_path.parent / "model_comparison_results.json"
    output_path.write_text(
        json.dumps({"model": model_name, "results": results}, indent=2),
        encoding="utf-8",
    )
    logger.info(f"\nResults saved to {output_path}")


# Entry point: run the comparison only when executed as a script, not on import.
if __name__ == "__main__":
    main()