Spaces:
Sleeping
Sleeping
| """Evaluate retrieval on the train set. | |
| For each unique query in train.csv, retrieve top-K, compute Recall@K against | |
| the labeled URLs. Reports both raw recall and 'reachable' recall (treating | |
| labels not present in our index as if they don't exist) so corpus coverage | |
| doesn't get conflated with retrieval quality. | |
| Three methods supported: | |
| (no flags) Method 1: Baseline — pure dense embedding retrieval | |
| --hybrid --rag Method 3: Latest — hybrid (dense + BM25 + RRF) + | |
| LLM rerank. Uses concept-enriched index when | |
| data/products_with_concepts.jsonl is present. | |
| --hybrid --rag --train-knn Method 2: LLM + RAG + train — same as Method 3 | |
| plus a 4th ranklist that borrows gold URLs from | |
| the most semantically-similar train queries | |
| (leave-one-out enabled here for honest eval). | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import json | |
| import time | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from recsys.retrieve import search | |
| from recsys.hybrid import ( | |
| search as hybrid_search, | |
| search_with_rag as hybrid_rag_search, | |
| ) | |
| from recsys.urls import slug | |
# Repository root: the parent of the directory that contains this script.
ROOT = Path(__file__).resolve().parent.parent
# Labeled evaluation set: one (query, label URL) pair per CSV row.
TRAIN = ROOT.joinpath("train.csv")
# Indexed corpus, stored as JSONL (one JSON record per line).
DOCS = ROOT.joinpath("data", "documents.jsonl")
def load_queries(path: Path) -> dict[str, list[str]]:
    """Group label URLs by query string.

    The first CSV row is treated as a header and skipped. In every other row
    column 0 is the query and column 1 a label URL; extra columns are ignored.
    Rows with fewer than two columns, or with an empty query/URL cell, are
    skipped.

    Args:
        path: UTF-8 CSV file to read.

    Returns:
        A plain ``dict`` mapping each query to its label URLs in file order.
        Queries keep their first-seen order (dicts preserve insertion order).
    """
    grouped: defaultdict[str, list[str]] = defaultdict(list)
    # newline="" is required by the csv module so quoted fields that contain
    # line breaks (and \r\n line endings) are parsed correctly.
    with path.open(encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row (no-op on an empty file)
        for row in reader:
            if len(row) < 2 or not row[0] or not row[1]:
                continue
            grouped[row[0]].append(row[1])
    # Hand back a plain dict so looking up an unknown query raises KeyError
    # instead of silently inserting an empty entry.
    return dict(grouped)
| def known_slugs() -> set[str]: | |
| docs = [json.loads(l) for l in DOCS.read_text(encoding="utf-8").splitlines() if l.strip()] | |
| return {d["id"] for d in docs} | |
def recall_at_k(predicted_slugs: list[str], gold_slugs: set[str], k: int) -> float:
    """Fraction of ``gold_slugs`` found among the first ``k`` predictions.

    Duplicate predictions count once. Returns 0.0 when there are no gold
    slugs, and also when ``k`` is not positive: a negative ``k`` would
    otherwise slice from the end of the list and yield a bogus score.
    """
    if not gold_slugs or k <= 0:
        return 0.0
    found = gold_slugs.intersection(predicted_slugs[:k])
    return len(found) / len(gold_slugs)
| def main() -> None: | |
| ap = argparse.ArgumentParser() | |
| ap.add_argument("--k", type=int, default=10) | |
| ap.add_argument("--rag", action="store_true", help="use LLM reranker") | |
| ap.add_argument("--hybrid", action="store_true", help="dense + BM25 with RRF fusion") | |
| ap.add_argument("--train-knn", action="store_true", | |
| help="borrow gold URLs from similar train queries (LOO on train eval)") | |
| ap.add_argument("--fanout", type=int, default=50, help="candidates fed to reranker") | |
| ap.add_argument("--show-misses", action="store_true") | |
| args = ap.parse_args() | |
| queries = load_queries(TRAIN) | |
| have = known_slugs() | |
| if args.train_knn: | |
| args.hybrid = True | |
| if args.hybrid and args.rag: | |
| mode = "hybrid+RAG" + ("+train-knn" if args.train_knn else "") | |
| elif args.hybrid: | |
| mode = "hybrid" + ("+train-knn" if args.train_knn else "") | |
| elif args.rag: | |
| mode = "RAG" | |
| else: | |
| mode = "baseline" | |
| print(f"mode={mode} | queries: {len(queries)} | k={args.k} | corpus: {len(have)} slugs\n") | |
| rows: list[tuple[str, int, int, int, float, float]] = [] | |
| first = True | |
| for q, urls in queries.items(): | |
| if args.rag and not first: | |
| time.sleep(8.0) # stay under LLM free-tier per-minute chat quota | |
| first = False | |
| gold = {slug(u) for u in urls if slug(u)} | |
| reachable = {s for s in gold if s in have} | |
| extras: dict = {} | |
| if args.train_knn: | |
| extras.update(use_train_knn=True, train_knn_loo=True) # LOO on train | |
| if args.hybrid and args.rag: | |
| hits = hybrid_rag_search(q, k=args.k, fanout=args.fanout, **extras) | |
| elif args.hybrid: | |
| hits = hybrid_search(q, k=args.k, fanout=args.fanout, **extras) | |
| else: | |
| hits = search(q, k=args.k) | |
| pred = [h["id"] for h in hits] | |
| raw = recall_at_k(pred, gold, args.k) | |
| reach = recall_at_k(pred, reachable, args.k) if reachable else 0.0 | |
| rows.append((q, len(gold), len(reachable), len(set(pred[:args.k]) & gold), raw, reach)) | |
| if args.show_misses: | |
| missed = gold - set(pred[: args.k]) | |
| if missed: | |
| missed_str = ", ".join(sorted(missed)) | |
| print(f" [{q[:60]!r}]\n missed: {missed_str}") | |
| print(f"{'#':>2} {'gold':>4} {'reach':>5} {'hits':>4} {'R@K(raw)':>8} {'R@K(reach)':>10} query") | |
| print("-" * 110) | |
| for i, (q, g, r, h, raw, reach) in enumerate(rows, 1): | |
| print(f"{i:2d} {g:4d} {r:5d} {h:4d} {raw:8.3f} {reach:10.3f} {q[:60]}") | |
| n = len(rows) | |
| mean_raw = sum(r[4] for r in rows) / n if n else 0.0 | |
| mean_reach = sum(r[5] for r in rows) / n if n else 0.0 | |
| print("-" * 110) | |
| print(f"mean Recall@{args.k} (raw): {mean_raw:.3f}") | |
| print(f"mean Recall@{args.k} (reachable): {mean_reach:.3f}") | |
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()