import logging
import time
from dataclasses import dataclass
from typing import Dict, List, Optional, Set

import numpy as np
from tqdm.auto import tqdm

from backend.app.engine.search_engine import ASOSSearchEngine
|
|
# Logger registered under the fixed name "asos_search" (not ``__name__``).
logger = logging.getLogger("asos_search")

# Names exported by a star-import of this module.
__all__ = [
    "EvalResult",
    "SearchEvaluator",
]
|
|
|
|
@dataclass
class EvalResult:
    """Evaluation outcome for a single search query.

    Attributes:
        query: The query text that was evaluated.
        recall_at_k: Recall value keyed by cut-off ``k``.
        precision_at_k: Precision value keyed by cut-off ``k``.
        mrr: Reciprocal rank of the first relevant hit (0.0 when none).
        latency_ms: Time spent in the search call, in milliseconds.
    """

    query: str
    recall_at_k: Dict[int, float]
    precision_at_k: Dict[int, float]
    mrr: float
    latency_ms: float
|
|
|
|
class SearchEvaluator:
    """Measure retrieval quality and latency of a search engine.

    The engine only needs a ``search(query, top_n=...)`` method whose result
    supports ``result["sku"].astype(str).tolist()`` (presumably a pandas
    DataFrame with a ``sku`` column -- confirm against the engine).
    """

    # Cut-offs used when the caller does not pass ``k_values``. Kept as an
    # immutable tuple so the default can never be mutated between calls.
    DEFAULT_K_VALUES = (5, 10, 20)

    def __init__(self, engine: ASOSSearchEngine):
        """Store the search engine that will be evaluated."""
        self.engine = engine

    @classmethod
    def _resolve_k_values(cls, k_values: Optional[List[int]]) -> List[int]:
        """Return a private list of cut-offs, falling back to the defaults."""
        if k_values is None:
            return list(cls.DEFAULT_K_VALUES)
        return list(k_values)

    @staticmethod
    def _query_label(test_query) -> str:
        """Return a printable label for a test query, even a malformed one."""
        if isinstance(test_query, dict):
            return str(test_query.get("query", "<missing query>"))
        return repr(test_query)

    def evaluate_single(
        self,
        query: str,
        relevant_skus: Set[str],
        k_values: Optional[List[int]] = None,
    ) -> EvalResult:
        """Run one query and score the ranked result against known answers.

        Args:
            query: Query text passed to ``engine.search``.
            relevant_skus: SKUs considered relevant; compared as strings.
            k_values: Cut-offs at which recall and precision are computed.
                ``None`` selects ``DEFAULT_K_VALUES`` (5, 10, 20).

        Returns:
            An ``EvalResult`` with recall@k, precision@k, MRR and latency.

        Raises:
            ValueError: If ``k_values`` is empty.
        """
        cutoffs = self._resolve_k_values(k_values)
        if not cutoffs:
            raise ValueError("k_values must contain at least one cut-off")
        max_k = max(cutoffs)

        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted and distort short measurements.
        started = time.perf_counter()
        results = self.engine.search(query, top_n=max_k)
        latency = (time.perf_counter() - started) * 1000

        retrieved = results["sku"].astype(str).tolist()
        relevant = {str(sku) for sku in relevant_skus}

        recall_at, precision_at = {}, {}
        for k in cutoffs:
            # Clamp so a non-positive k gives an empty prefix instead of the
            # negative slice retrieved[:-n], which would count bogus hits.
            found = len(relevant.intersection(retrieved[: max(k, 0)]))
            recall_at[k] = found / len(relevant) if relevant else 0.0
            precision_at[k] = found / k if k > 0 else 0.0

        # Reciprocal rank of the first relevant hit; 0.0 when nothing matches.
        mrr = next(
            (
                1.0 / rank
                for rank, sku in enumerate(retrieved, 1)
                if sku in relevant
            ),
            0.0,
        )

        return EvalResult(
            query=query,
            recall_at_k=recall_at,
            precision_at_k=precision_at,
            mrr=mrr,
            latency_ms=latency,
        )

    def evaluate(
        self, test_queries: List[Dict], k_values: Optional[List[int]] = None
    ) -> Dict:
        """Evaluate every test query and aggregate the metrics.

        Args:
            test_queries: Dicts with a ``"query"`` string and an iterable of
                ``"relevant_skus"``.
            k_values: Cut-offs for recall/precision. ``None`` selects
                ``DEFAULT_K_VALUES`` (5, 10, 20).

        Returns:
            ``{"aggregate": {...}, "per_query": [...]}`` on success, or
            ``{"error": "No successful evaluations"}`` when every query
            failed or ``test_queries`` was empty.
        """
        cutoffs = self._resolve_k_values(k_values)

        results = []
        for tq in tqdm(test_queries, desc="Evaluating"):
            try:
                results.append(
                    self.evaluate_single(
                        tq["query"],
                        {str(sku) for sku in tq["relevant_skus"]},
                        cutoffs,
                    )
                )
            except Exception as exc:
                # Deliberate best-effort: one bad query must not abort the
                # run. The label is looked up defensively because the failure
                # may itself be a missing "query" key.
                logger.warning(
                    "Eval failed for '%s': %s", self._query_label(tq), exc
                )

        if not results:
            return {"error": "No successful evaluations"}

        latencies = [r.latency_ms for r in results]
        agg = {
            "n_queries": len(results),
            "avg_latency_ms": float(np.mean(latencies)),
            "median_latency_ms": float(np.median(latencies)),
            "mean_mrr": float(np.mean([r.mrr for r in results])),
        }
        for k in cutoffs:
            agg[f"mean_recall@{k}"] = float(
                np.mean([r.recall_at_k.get(k, 0) for r in results])
            )
            agg[f"mean_precision@{k}"] = float(
                np.mean([r.precision_at_k.get(k, 0) for r in results])
            )

        per_query = [
            {
                "query": r.query,
                "mrr": r.mrr,
                "latency_ms": r.latency_ms,
                "recall_at_k": r.recall_at_k,
                "precision_at_k": r.precision_at_k,
            }
            for r in results
        ]
        return {"aggregate": agg, "per_query": per_query}

    @staticmethod
    def print_report(report: Dict):
        """Print a human-readable summary of an ``evaluate`` report to stdout.

        Args:
            report: The dict returned by ``evaluate``. An ``"error"`` entry is
                printed explicitly so a failed run is not mistaken for a run
                whose metrics are all zero.
        """
        agg = report.get("aggregate", {})
        rule = "=" * 65
        print("\n" + rule)
        print(" SEARCH ENGINE EVALUATION REPORT")
        print(rule)
        if "error" in report:
            print(f" Error: {report['error']}")
        print(f" Queries evaluated: {agg.get('n_queries', 0)}")
        print(f" Avg latency: {agg.get('avg_latency_ms', 0):.1f} ms")
        print(f" Mean MRR: {agg.get('mean_mrr', 0):.4f}")
        for key, val in sorted(agg.items()):
            if "recall" in key or "precision" in key:
                print(f" {key:25s} {val:.4f}")
        print(rule)
|
|