"""Evaluate retrieval on the train set. For each unique query in train.csv, retrieve top-K, compute Recall@K against the labeled URLs. Reports both raw recall and 'reachable' recall (treating labels not present in our index as if they don't exist) so corpus coverage doesn't get conflated with retrieval quality. Three methods supported: -- Method 1: Baseline — pure dense embedding retrieval --hybrid --rag Method 3: Latest — hybrid (dense + BM25 + RRF) + LLM rerank. Uses concept-enriched index when data/products_with_concepts.jsonl is present. --hybrid --rag --train-knn Method 2: LLM + RAG + train — same as Method 3 plus a 4th ranklist that borrows gold URLs from the most semantically-similar train queries (leave-one-out enabled here for honest eval). """ from __future__ import annotations import argparse import csv import json import time from collections import defaultdict from pathlib import Path from recsys.retrieve import search from recsys.hybrid import ( search as hybrid_search, search_with_rag as hybrid_rag_search, ) from recsys.urls import slug ROOT = Path(__file__).resolve().parent.parent TRAIN = ROOT / "train.csv" DOCS = ROOT / "data" / "documents.jsonl" def load_queries(path: Path) -> dict[str, list[str]]: """Group label URLs by query string. Preserves first-seen order of queries.""" out: dict[str, list[str]] = defaultdict(list) with path.open(encoding="utf-8") as f: reader = csv.reader(f) next(reader, None) # header for row in reader: if len(row) < 2 or not row[0] or not row[1]: continue out[row[0]].append(row[1]) return out def known_slugs() -> set[str]: docs = [json.loads(l) for l in DOCS.read_text(encoding="utf-8").splitlines() if l.strip()] return {d["id"] for d in docs} def recall_at_k(predicted_slugs: list[str], gold_slugs: set[str], k: int) -> float: if not gold_slugs: return 0.0 top = set(predicted_slugs[:k]) return len(top & gold_slugs) / len(gold_slugs) def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--k", type=int, default=10) ap.add_argument("--rag", action="store_true", help="use LLM reranker") ap.add_argument("--hybrid", action="store_true", help="dense + BM25 with RRF fusion") ap.add_argument("--train-knn", action="store_true", help="borrow gold URLs from similar train queries (LOO on train eval)") ap.add_argument("--fanout", type=int, default=50, help="candidates fed to reranker") ap.add_argument("--show-misses", action="store_true") args = ap.parse_args() queries = load_queries(TRAIN) have = known_slugs() if args.train_knn: args.hybrid = True if args.hybrid and args.rag: mode = "hybrid+RAG" + ("+train-knn" if args.train_knn else "") elif args.hybrid: mode = "hybrid" + ("+train-knn" if args.train_knn else "") elif args.rag: mode = "RAG" else: mode = "baseline" print(f"mode={mode} | queries: {len(queries)} | k={args.k} | corpus: {len(have)} slugs\n") rows: list[tuple[str, int, int, int, float, float]] = [] first = True for q, urls in queries.items(): if args.rag and not first: time.sleep(8.0) # stay under LLM free-tier per-minute chat quota first = False gold = {slug(u) for u in urls if slug(u)} reachable = {s for s in gold if s in have} extras: dict = {} if args.train_knn: extras.update(use_train_knn=True, train_knn_loo=True) # LOO on train if args.hybrid and args.rag: hits = hybrid_rag_search(q, k=args.k, fanout=args.fanout, **extras) elif args.hybrid: hits = hybrid_search(q, k=args.k, fanout=args.fanout, **extras) else: hits = search(q, k=args.k) pred = [h["id"] for h in hits] raw = recall_at_k(pred, gold, args.k) reach = recall_at_k(pred, reachable, args.k) if reachable else 0.0 rows.append((q, len(gold), len(reachable), len(set(pred[:args.k]) & gold), raw, reach)) if args.show_misses: missed = gold - set(pred[: args.k]) if missed: missed_str = ", ".join(sorted(missed)) print(f" [{q[:60]!r}]\n missed: {missed_str}") print(f"{'#':>2} {'gold':>4} {'reach':>5} {'hits':>4} {'R@K(raw)':>8} {'R@K(reach)':>10} query") print("-" * 110) for i, (q, g, r, h, raw, reach) in enumerate(rows, 1): print(f"{i:2d} {g:4d} {r:5d} {h:4d} {raw:8.3f} {reach:10.3f} {q[:60]}") n = len(rows) mean_raw = sum(r[4] for r in rows) / n if n else 0.0 mean_reach = sum(r[5] for r in rows) / n if n else 0.0 print("-" * 110) print(f"mean Recall@{args.k} (raw): {mean_raw:.3f}") print(f"mean Recall@{args.k} (reachable): {mean_reach:.3f}") if __name__ == "__main__": main()