File size: 3,608 Bytes
d992912
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import logging
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Set

import numpy as np
from tqdm.auto import tqdm

from backend.app.engine.search_engine import ASOSSearchEngine

# Module-level logger; it uses the fixed name "asos_search" rather than
# __name__, so records from this module are emitted under that logger name.
logger = logging.getLogger("asos_search")

# Names exported by ``from <this module> import *``.
__all__ = ["EvalResult", "SearchEvaluator"]


@dataclass
class EvalResult:
    """Metrics recorded for a single evaluated search query."""

    # Query text that was passed to the engine.
    query: str
    # Recall per cutoff, keyed by k: relevant SKUs found in the top k divided
    # by the number of relevant SKUs.
    recall_at_k: Dict[int, float]
    # Precision per cutoff, keyed by k: relevant SKUs found in the top k
    # divided by k.
    precision_at_k: Dict[int, float]
    # Reciprocal rank of the first relevant result (0.0 when none was found).
    mrr: float
    # Elapsed time of the engine's search call, in milliseconds.
    latency_ms: float


class SearchEvaluator:
    """Score a search engine's ranked results against known-relevant SKUs.

    The wrapped engine must expose ``search(query, top_n=...)`` and return an
    object for which ``result["sku"].astype(str).tolist()`` yields the ranked
    SKUs (presumably a pandas DataFrame with a ``sku`` column -- confirm
    against ``ASOSSearchEngine``).
    """

    # Cutoffs used when the caller does not pass ``k_values``. Kept as an
    # immutable tuple so the default can never be shared and mutated.
    _DEFAULT_K_VALUES = (5, 10, 20)

    def __init__(self, engine: "ASOSSearchEngine"):
        """Store the engine whose ``search`` method will be evaluated."""
        self.engine = engine

    def evaluate_single(
        self,
        query: str,
        relevant_skus: Set[str],
        k_values: Optional[List[int]] = None,
    ) -> "EvalResult":
        """Run one query and compute recall@k, precision@k, MRR and latency.

        Args:
            query: Query text passed to ``engine.search``.
            relevant_skus: SKUs considered relevant; compared as strings.
            k_values: Cutoffs to score. Defaults to 5, 10 and 20.

        Returns:
            An ``EvalResult`` for the query. Recall is 0.0 when there are no
            relevant SKUs. Precision always divides by ``k`` (even if the
            engine returned fewer than ``k`` rows) and is 0.0 for ``k <= 0``.
            ``mrr`` is the reciprocal rank of the first relevant SKU among all
            retrieved rows, or 0.0 when none is relevant.

        Raises:
            ValueError: If ``k_values`` is empty.
        """
        if k_values is None:
            cutoffs = list(self._DEFAULT_K_VALUES)
        else:
            cutoffs = list(k_values)
        if not cutoffs:
            raise ValueError("k_values must contain at least one cutoff")

        max_k = max(cutoffs)
        # perf_counter is monotonic and high-resolution, so the measured
        # interval cannot be skewed by system clock adjustments.
        t0 = time.perf_counter()
        results = self.engine.search(query, top_n=max_k)
        latency = (time.perf_counter() - t0) * 1000

        retrieved = results["sku"].astype(str).tolist()
        relevant = {str(s) for s in relevant_skus}

        recall_at, precision_at = {}, {}
        for k in cutoffs:
            # A non-positive cutoff selects nothing; slicing with a negative
            # k would instead drop items from the end of the ranking.
            top_k = retrieved[:k] if k > 0 else []
            found = len(set(top_k) & relevant)
            recall_at[k] = found / len(relevant) if relevant else 0.0
            precision_at[k] = found / k if k > 0 else 0.0

        mrr = 0.0
        for rank, sku in enumerate(retrieved, 1):
            if sku in relevant:
                mrr = 1.0 / rank
                break

        return EvalResult(
            query=query,
            recall_at_k=recall_at,
            precision_at_k=precision_at,
            mrr=mrr,
            latency_ms=latency,
        )

    def evaluate(
        self, test_queries: List[Dict], k_values: Optional[List[int]] = None
    ) -> Dict:
        """Evaluate every test query and aggregate the metrics.

        Args:
            test_queries: Dicts with a ``"query"`` string and a
                ``"relevant_skus"`` iterable.
            k_values: Cutoffs to score. Defaults to 5, 10 and 20.

        Returns:
            ``{"error": "No successful evaluations"}`` when no query could be
            evaluated; otherwise a dict with ``"aggregate"`` (query count,
            mean/median latency, mean MRR and mean recall/precision per
            cutoff) and ``"per_query"`` (one dict of metrics per query).
            Queries that raise are logged as warnings and skipped.
        """
        if k_values is None:
            cutoffs = list(self._DEFAULT_K_VALUES)
        else:
            cutoffs = list(k_values)

        results = []
        for tq in tqdm(test_queries, desc="Evaluating"):
            try:
                res = self.evaluate_single(
                    tq["query"],
                    {str(s) for s in tq["relevant_skus"]},
                    cutoffs,
                )
            except Exception as e:
                # Best-effort per query: one bad entry must not abort the run.
                # The label is looked up defensively because the failure may
                # itself be a missing "query" key or a non-dict entry.
                if isinstance(tq, dict):
                    label = tq.get("query", "<missing query>")
                else:
                    label = repr(tq)
                logger.warning("Eval failed for '%s': %s", label, e)
            else:
                results.append(res)

        if not results:
            return {"error": "No successful evaluations"}

        latencies = [r.latency_ms for r in results]
        agg = {
            "n_queries": len(results),
            "avg_latency_ms": float(np.mean(latencies)),
            "median_latency_ms": float(np.median(latencies)),
            "mean_mrr": float(np.mean([r.mrr for r in results])),
        }
        for k in cutoffs:
            agg[f"mean_recall@{k}"] = float(
                np.mean([r.recall_at_k.get(k, 0) for r in results])
            )
            agg[f"mean_precision@{k}"] = float(
                np.mean([r.precision_at_k.get(k, 0) for r in results])
            )

        per_query = [
            {
                "query": r.query,
                "mrr": r.mrr,
                "latency_ms": r.latency_ms,
                "recall_at_k": r.recall_at_k,
                "precision_at_k": r.precision_at_k,
            }
            for r in results
        ]
        return {"aggregate": agg, "per_query": per_query}

    @staticmethod
    def _metric_order(key: str):
        """Return a sort key ordering ``name@k`` by name, then numeric k."""
        name, _, cutoff = key.partition("@")
        try:
            rank = int(cutoff)
        except ValueError:
            # No "@k" suffix (or a non-numeric one): list before numbered keys.
            rank = -1
        return (name, rank, key)

    @staticmethod
    def print_report(report: Dict):
        """Print a human-readable summary of a report from ``evaluate``.

        Missing aggregate values are printed as 0. If the report carries an
        ``"error"`` entry it is printed as well, so a failed run is not
        mistaken for a run with all-zero metrics.
        """
        agg = report.get("aggregate", {})
        rule = "=" * 65
        print("\n" + rule)
        print("  SEARCH ENGINE EVALUATION REPORT")
        print(rule)
        if "error" in report:
            print(f"  Error:               {report['error']}")
        print(f"  Queries evaluated:   {agg.get('n_queries', 0)}")
        print(f"  Avg latency:         {agg.get('avg_latency_ms', 0):.1f} ms")
        print(f"  Mean MRR:            {agg.get('mean_mrr', 0):.4f}")
        metric_keys = [
            key for key in agg if "recall" in key or "precision" in key
        ]
        # Sort cutoffs numerically; a plain string sort puts "@5" after "@20".
        for key in sorted(metric_keys, key=SearchEvaluator._metric_order):
            print(f"  {key:25s}  {agg[key]:.4f}")
        print(rule)