File size: 4,551 Bytes
870800f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Per-query breakdown of pipeline predictions vs gold labels.

For each unique query in train.csv:
  - Run the configured retriever (default: hybrid + RAG + train-knn LOO)
  - Compare the predicted top-k (--k, default 10) against the gold URLs
  - Show gold split into:
      ✓ HIT          predicted correctly
      ✗ MISS-IND     individual test we have indexed but missed
      · MISS-PREPKG  pre-packaged label we don't index (unreachable)
"""
from __future__ import annotations

import argparse
import csv
import json
import time
from collections import OrderedDict
from pathlib import Path

from recsys.hybrid import (
    search as hybrid_search,
    search_with_rag as hybrid_rag_search,
)
from recsys.urls import slug

# Repository root: the grandparent directory of this script's resolved path.
ROOT = Path(__file__).resolve().parent.parent
# Gold labels, read by main() as (query, URL) rows after a header row.
TRAIN = ROOT.joinpath("train.csv")
# Indexed documents, read by main() as JSON lines carrying an "id" field.
DOCS = ROOT.joinpath("data", "documents.jsonl")


def _load_gold(path: Path) -> "OrderedDict[str, list[str]]":
    """Group gold-label slugs by query from a train.csv-style file.

    The first row is skipped as a header. Column 0 is the query text and
    column 1 a gold URL, converted with ``slug``. Rows missing either value,
    or whose URL ``slug`` maps to a falsy value, are ignored. Queries keep
    the order in which they first appear; each query's slugs keep row order
    (duplicates included).
    """
    queries: "OrderedDict[str, list[str]]" = OrderedDict()
    with path.open(encoding="utf-8") as f:
        reader = csv.reader(f)
        next(reader, None)  # header row
        for row in reader:
            if len(row) < 2 or not row[0] or not row[1]:
                continue
            s = slug(row[1])
            if s:
                queries.setdefault(row[0], []).append(s)
    return queries


def _load_indexed_ids(path: Path) -> set[str]:
    """Return the ``"id"`` value of every non-blank line of a JSON-lines file.

    The file is iterated line by line instead of using ``str.splitlines()``.
    splitlines also breaks on U+0085, U+2028 and U+2029, which ``json.dumps``
    leaves unescaped when ``ensure_ascii=False``, so a record containing them
    would be split in two. Text-mode iteration splits only on CR/LF line
    endings, which JSON always escapes inside strings.

    Raises:
        json.JSONDecodeError: if a non-blank line is not valid JSON.
        KeyError: if a record has no ``"id"`` key.
    """
    with path.open(encoding="utf-8") as f:
        return {json.loads(line)["id"] for line in f if line.strip()}


def _make_searcher(mode: str, k: int, fanout: int):
    """Return a function mapping a query string to the retriever's hit list.

    ``"hybrid"`` calls ``hybrid_search``. ``"hybrid_rag"`` calls
    ``hybrid_rag_search``. Any other mode (the CLI only allows
    ``"hybrid_rag_knn"``) calls ``hybrid_rag_search`` with
    ``use_train_knn=True, train_knn_loo=True``. Every call forwards ``k``
    and ``fanout``.
    """
    if mode == "hybrid":
        return lambda q: hybrid_search(q, k=k, fanout=fanout)
    if mode == "hybrid_rag":
        return lambda q: hybrid_rag_search(q, k=k, fanout=fanout)
    return lambda q: hybrid_rag_search(
        q, k=k, fanout=fanout,
        use_train_knn=True, train_knn_loo=True,
    )


def _ratio(num: int, den: int, scale: int = 1) -> float:
    """Return ``scale * num / den``, or 0.0 when ``den`` is 0."""
    return scale * num / den if den else 0.0


def main() -> None:
    """Run the retriever over every train.csv query and report hits vs gold.

    For each query, this prints the gold labels split into hits, misses
    of indexed ("individual") slugs, and misses of slugs absent from
    data/documents.jsonl ("pre-packaged"). It then prints overall totals
    with raw and reachable recall.

    Command-line options:
      --mode    "hybrid", "hybrid_rag" or "hybrid_rag_knn" (default)
      --k       number of results requested per query (default 10)
      --fanout  forwarded to the search call as ``fanout`` (default 50)
      --sleep   seconds to pause between queries in the two RAG modes
                (default 2.0)
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--mode", choices=["hybrid", "hybrid_rag", "hybrid_rag_knn"],
                    default="hybrid_rag_knn")
    ap.add_argument("--k", type=int, default=10)
    ap.add_argument("--fanout", type=int, default=50)
    ap.add_argument("--sleep", type=float, default=2.0,
                    help="seconds between LLM-using queries")
    args = ap.parse_args()

    # train.csv → query → gold slugs
    queries = _load_gold(TRAIN)
    # Slugs present in our index; gold slugs outside it count as pre-packaged.
    have = _load_indexed_ids(DOCS)
    run = _make_searcher(args.mode, args.k, args.fanout)
    # The RAG modes are the LLM-using ones (see --sleep help), so throttle them.
    throttle = args.mode in ("hybrid_rag", "hybrid_rag_knn")

    print(f"mode={args.mode}  K={args.k}  fanout={args.fanout}\n")

    totals = {"gold": 0, "hits": 0, "missed_ind": 0, "missed_prepkg": 0}

    for i, (q, gold_slugs) in enumerate(queries.items(), 1):
        if i > 1 and throttle:
            time.sleep(args.sleep)

        gold_set = set(gold_slugs)
        gold_pre = {s for s in gold_set if s not in have}
        gold_ind = gold_set - gold_pre

        hits = run(q)
        pred = {h["id"] for h in hits}
        hit_set = pred & gold_set
        missed = gold_set - pred
        missed_pre = missed & gold_pre
        missed_ind = missed - gold_pre
        # Same definition as the overall "reachable recall" below: all gold
        # except the pre-packaged labels that were missed. Dividing by
        # len(gold_ind) instead could exceed 100% if a pre-packaged slug is hit.
        reachable = len(gold_set) - len(missed_pre)

        totals["gold"] += len(gold_set)
        totals["hits"] += len(hit_set)
        totals["missed_ind"] += len(missed_ind)
        totals["missed_prepkg"] += len(missed_pre)

        print(f"=== Q{i}: {q[:80].replace(chr(10),' / ')!r} ===")
        print(f"    gold: {len(gold_set):>2}  "
              f"(individual: {len(gold_ind)}, pre-packaged: {len(gold_pre)})")
        print(f"    HIT  : {len(hit_set):>2}  "
              f"({_ratio(len(hit_set), len(gold_set), 100):.0f}% of gold, "
              f"{_ratio(len(hit_set), reachable, 100):.0f}% of reachable)")
        print(f"    MISS : individual={len(missed_ind)}, prepackaged={len(missed_pre)} "
              f"(unreachable — not in our index)")
        for s in sorted(gold_set):
            mark = "✓" if s in hit_set else ("·" if s in gold_pre else "✗")
            tag = "PREPKG" if s in gold_pre else "IND"
            print(f"      [{mark}] [{tag:6s}] {s}")
        print()

    print("=" * 80)
    print("OVERALL")
    print(f"  queries:                 {len(queries)}")
    print(f"  total gold labels:       {totals['gold']}")
    print(f"  hits (correct):          {totals['hits']}")
    print(f"  missed (individual):     {totals['missed_ind']}  (in index, retriever didn't surface)")
    print(f"  missed (pre-packaged):   {totals['missed_prepkg']}  (NOT in index — structurally unreachable)")
    # _ratio guards against an empty train.csv or all-unreachable gold labels.
    print(f"  raw recall:              {_ratio(totals['hits'], totals['gold']):.3f}")
    reachable_total = totals['gold'] - totals['missed_prepkg']
    print(f"  reachable recall:        {_ratio(totals['hits'], reachable_total):.3f}")


if __name__ == "__main__":
    main()