"""Per-query breakdown of pipeline predictions vs gold labels. For each unique query in train.csv: - Run the configured retriever (default: hybrid + RAG + train-knn LOO) - Compare predicted top-10 against gold URLs - Show gold split into: ✓ HIT predicted correctly ✗ MISS-IND individual test we have indexed but missed ✗ MISS-PREPKG pre-packaged label we don't index (unreachable) """ from __future__ import annotations import argparse import csv import json import time from collections import OrderedDict from pathlib import Path from recsys.hybrid import ( search as hybrid_search, search_with_rag as hybrid_rag_search, ) from recsys.urls import slug ROOT = Path(__file__).resolve().parent.parent TRAIN = ROOT / "train.csv" DOCS = ROOT / "data" / "documents.jsonl" def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--mode", choices=["hybrid", "hybrid_rag", "hybrid_rag_knn"], default="hybrid_rag_knn") ap.add_argument("--k", type=int, default=10) ap.add_argument("--fanout", type=int, default=50) ap.add_argument("--sleep", type=float, default=2.0, help="seconds between LLM-using queries") args = ap.parse_args() # Load train.csv → query → gold slugs queries: "OrderedDict[str, list[str]]" = OrderedDict() with TRAIN.open(encoding="utf-8") as f: r = csv.reader(f); next(r, None) for row in r: if len(row) < 2 or not row[0] or not row[1]: continue s = slug(row[1]) if s: queries.setdefault(row[0], []).append(s) # Load known slugs (what's actually in our index) have = {json.loads(l)["id"] for l in DOCS.read_text(encoding="utf-8").splitlines() if l.strip()} # Choose searcher use_knn = args.mode == "hybrid_rag_knn" if args.mode == "hybrid": run = lambda q: hybrid_search(q, k=args.k, fanout=args.fanout) elif args.mode == "hybrid_rag": run = lambda q: hybrid_rag_search(q, k=args.k, fanout=args.fanout) else: run = lambda q: hybrid_rag_search( q, k=args.k, fanout=args.fanout, use_train_knn=True, train_knn_loo=True, ) print(f"mode={args.mode} K={args.k} fanout={args.fanout}\n") totals = {"gold": 0, "hits": 0, "missed_ind": 0, "missed_prepkg": 0} for i, (q, gold_slugs) in enumerate(queries.items(), 1): if i > 1 and args.mode in ("hybrid_rag", "hybrid_rag_knn"): time.sleep(args.sleep) gold_set = set(gold_slugs) gold_pre = {s for s in gold_set if s not in have} gold_ind = gold_set - gold_pre hits = run(q) pred = {h["id"] for h in hits} hit_set = pred & gold_set missed = gold_set - pred missed_pre = missed & gold_pre missed_ind = missed - gold_pre totals["gold"] += len(gold_set) totals["hits"] += len(hit_set) totals["missed_ind"] += len(missed_ind) totals["missed_prepkg"] += len(missed_pre) print(f"=== Q{i}: {q[:80].replace(chr(10),' / ')!r} ===") print(f" gold: {len(gold_set):>2} " f"(individual: {len(gold_ind)}, pre-packaged: {len(gold_pre)})") print(f" HIT : {len(hit_set):>2} " f"({100*len(hit_set)/len(gold_set):.0f}% of gold, " f"{100*len(hit_set)/len(gold_ind) if gold_ind else 0:.0f}% of reachable)") print(f" MISS : individual={len(missed_ind)}, prepackaged={len(missed_pre)} " f"(unreachable — not in our index)") for s in sorted(gold_set): mark = "✓" if s in hit_set else ("·" if s in gold_pre else "✗") tag = "PREPKG" if s in gold_pre else "IND" print(f" [{mark}] [{tag:6s}] {s}") print() print("=" * 80) print("OVERALL") n = len(queries) print(f" queries: {n}") print(f" total gold labels: {totals['gold']}") print(f" hits (correct): {totals['hits']}") print(f" missed (individual): {totals['missed_ind']} (in index, retriever didn't surface)") print(f" missed (pre-packaged): {totals['missed_prepkg']} (NOT in index — structurally unreachable)") print(f" raw recall: {totals['hits']/totals['gold']:.3f}") reachable_total = totals['gold'] - totals['missed_prepkg'] print(f" reachable recall: {totals['hits']/reachable_total:.3f}") if __name__ == "__main__": main()