Spaces:
Sleeping
Sleeping
| """Per-query breakdown of pipeline predictions vs gold labels. | |
| For each unique query in train.csv: | |
| - Run the configured retriever (default: hybrid + RAG + train-knn LOO) | |
| - Compare the predicted top-K (default K=10) against the gold URLs | |
| - Show gold split into: | |
| ✓ HIT gold label the retriever returned | |
| ✗ MISS-IND individual test we have indexed but missed | |
| · MISS-PREPKG pre-packaged label we don't index (unreachable) | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import json | |
| import time | |
| from collections import OrderedDict | |
| from pathlib import Path | |
| from recsys.hybrid import ( | |
| search as hybrid_search, | |
| search_with_rag as hybrid_rag_search, | |
| ) | |
| from recsys.urls import slug | |
# Repository root: two directory levels above this script's resolved path.
ROOT = Path(__file__).resolve().parent.parent
# Training labels (rows of query, gold URL) and the JSONL file whose "id"
# values define what is in our index.
TRAIN = ROOT.joinpath("train.csv")
DOCS = ROOT.joinpath("data", "documents.jsonl")
def _load_queries(path: Path) -> "OrderedDict[str, list[str]]":
    """Read ``path`` (a CSV with a header row) into ``query -> gold slugs``.

    Column 0 is the query text and column 1 a gold URL, reduced with
    ``slug``. Rows missing either field, or whose URL yields an empty slug,
    are skipped. Queries keep their first-seen file order; a query listed on
    several rows accumulates all of its slugs.
    """
    queries: "OrderedDict[str, list[str]]" = OrderedDict()
    # newline="" is what the csv module documents for reader input; it keeps
    # line breaks embedded in quoted (multi-line) query cells intact.
    with path.open(encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            if len(row) < 2 or not row[0] or not row[1]:
                continue
            gold = slug(row[1])
            if gold:
                queries.setdefault(row[0], []).append(gold)
    return queries


def _load_indexed_ids(path: Path) -> set:
    """Return the ``"id"`` field of every non-blank JSON line in ``path``."""
    with path.open(encoding="utf-8") as f:
        return {json.loads(line)["id"] for line in f if line.strip()}


def _make_runner(mode: str, k: int, fanout: int):
    """Return a callable ``query -> hits`` for the chosen ``--mode``.

    ``hybrid`` calls ``hybrid_search``; ``hybrid_rag`` calls
    ``hybrid_rag_search``; ``hybrid_rag_knn`` calls ``hybrid_rag_search``
    with ``use_train_knn=True`` and ``train_knn_loo=True``.
    """
    if mode == "hybrid":
        return lambda q: hybrid_search(q, k=k, fanout=fanout)
    if mode == "hybrid_rag":
        return lambda q: hybrid_rag_search(q, k=k, fanout=fanout)
    return lambda q: hybrid_rag_search(
        q, k=k, fanout=fanout,
        use_train_knn=True, train_knn_loo=True,
    )


def _pct(num: int, den: int) -> str:
    """Format ``num/den`` as a whole-number percentage, or ``"n/a"`` if ``den`` is 0."""
    return f"{100 * num / den:.0f}%" if den else "n/a"


def _ratio(num: int, den: int) -> str:
    """Format ``num/den`` with three decimals, or ``"n/a"`` if ``den`` is 0."""
    return f"{num / den:.3f}" if den else "n/a"


def main() -> None:
    """Run the retriever on every train.csv query and print hit/miss breakdowns.

    For each query, gold slugs absent from the index (DOCS) are counted as
    pre-packaged. Each gold slug is listed with a mark: ✓ hit,
    ✗ indexed but missed, · pre-packaged and missed. An overall summary
    follows. Ratios with a zero denominator are printed as "n/a".
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--mode", choices=["hybrid", "hybrid_rag", "hybrid_rag_knn"],
                    default="hybrid_rag_knn")
    ap.add_argument("--k", type=int, default=10)
    ap.add_argument("--fanout", type=int, default=50)
    ap.add_argument("--sleep", type=float, default=2.0,
                    help="seconds between LLM-using queries")
    args = ap.parse_args()

    queries = _load_queries(TRAIN)    # query -> gold slugs
    have = _load_indexed_ids(DOCS)    # what's actually in our index
    run = _make_runner(args.mode, args.k, args.fanout)
    # Only the RAG modes are throttled (the --sleep help: "LLM-using queries").
    throttle = args.mode != "hybrid"

    print(f"mode={args.mode} K={args.k} fanout={args.fanout}\n")
    totals = {"gold": 0, "hits": 0, "missed_ind": 0, "missed_prepkg": 0}
    for i, (q, gold_slugs) in enumerate(queries.items(), 1):
        if i > 1 and throttle:
            time.sleep(args.sleep)
        gold_set = set(gold_slugs)
        gold_pre = gold_set - have        # pre-packaged: not in our index
        gold_ind = gold_set - gold_pre    # individual: indexed
        pred = {h["id"] for h in run(q)}
        hit_set = pred & gold_set
        missed = gold_set - pred
        missed_pre = missed & gold_pre
        missed_ind = missed - gold_pre
        # Same definition as the overall "reachable recall" below: every gold
        # label except the pre-packaged ones we failed to surface.
        reachable = len(gold_set) - len(missed_pre)
        totals["gold"] += len(gold_set)
        totals["hits"] += len(hit_set)
        totals["missed_ind"] += len(missed_ind)
        totals["missed_prepkg"] += len(missed_pre)

        # Computed outside the f-string: a backslash is not allowed inside
        # f-string expressions before Python 3.12.
        preview = q[:80].replace("\n", " / ")
        print(f"=== Q{i}: {preview!r} ===")
        print(f" gold: {len(gold_set):>2} "
              f"(individual: {len(gold_ind)}, pre-packaged: {len(gold_pre)})")
        print(f" HIT : {len(hit_set):>2} "
              f"({_pct(len(hit_set), len(gold_set))} of gold, "
              f"{_pct(len(hit_set), reachable)} of reachable)")
        print(f" MISS : individual={len(missed_ind)}, prepackaged={len(missed_pre)} "
              f"(unreachable — not in our index)")
        for s in sorted(gold_set):
            if s in hit_set:
                mark = "✓"
            elif s in gold_pre:
                mark = "·"
            else:
                mark = "✗"
            tag = "PREPKG" if s in gold_pre else "IND"
            print(f" [{mark}] [{tag:6s}] {s}")
        print()

    print("=" * 80)
    print("OVERALL")
    print(f" queries: {len(queries)}")
    print(f" total gold labels: {totals['gold']}")
    print(f" hits (correct): {totals['hits']}")
    print(f" missed (individual): {totals['missed_ind']} (in index, retriever didn't surface)")
    print(f" missed (pre-packaged): {totals['missed_prepkg']} (NOT in index — structurally unreachable)")
    print(f" raw recall: {_ratio(totals['hits'], totals['gold'])}")
    reachable_total = totals["gold"] - totals["missed_prepkg"]
    print(f" reachable recall: {_ratio(totals['hits'], reachable_total)}")


if __name__ == "__main__":
    main()