shl-recommender-api / scripts /per_query_breakdown.py
pankaj
SHL recommender — initial deploy
870800f
"""Per-query breakdown of pipeline predictions vs gold labels.
For each unique query in train.csv:
  - Run the configured retriever (default: hybrid + RAG + train-knn LOO)
  - Compare the predicted top-K (default K=10) against gold URLs
  - Show gold split into:
      ✓ HIT          predicted correctly
      ✗ MISS-IND     individual test we have indexed but missed
      · MISS-PREPKG  pre-packaged label we don't index (unreachable)
"""
from __future__ import annotations
import argparse
import csv
import json
import time
from collections import OrderedDict
from pathlib import Path
from recsys.hybrid import (
search as hybrid_search,
search_with_rag as hybrid_rag_search,
)
from recsys.urls import slug
# Project root: the directory two levels above this script file.
ROOT = Path(__file__).resolve().parent.parent

# CSV of labelled rows; main() reads column 0 as the query and column 1 as
# the gold URL.
TRAIN = ROOT.joinpath("train.csv")

# JSONL file of indexed documents; main() reads the "id" of every record.
DOCS = ROOT.joinpath("data", "documents.jsonl")
def _load_gold(path: Path) -> OrderedDict[str, list[str]]:
    """Group gold slugs by query, preserving first-seen query order.

    The first CSV row is treated as a header and skipped.  Rows with fewer
    than two columns, an empty query, an empty URL, or a URL for which
    ``slug`` returns a falsy value are ignored.
    """
    queries: OrderedDict[str, list[str]] = OrderedDict()
    with path.open(encoding="utf-8") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row (no-op on an empty file)
        for row in reader:
            if len(row) < 2 or not row[0] or not row[1]:
                continue
            s = slug(row[1])
            if s:
                queries.setdefault(row[0], []).append(s)
    return queries


def _load_index_ids(path: Path) -> set[str]:
    """Return the ``"id"`` of every non-blank JSONL record in *path*.

    The file is streamed line by line.  Iterating the file only breaks on
    real newlines, whereas ``str.splitlines()`` also breaks on characters
    such as U+2028 that may legally appear unescaped inside a JSON string
    and would cut a record in half.
    """
    with path.open(encoding="utf-8") as f:
        return {json.loads(line)["id"] for line in f if line.strip()}


def _format_ratio(numerator: int, denominator: int) -> str:
    """Format ``numerator / denominator`` to 3 decimals, or ``"n/a"`` if undefined."""
    if denominator <= 0:
        return "n/a"
    return f"{numerator / denominator:.3f}"


def main() -> None:
    """Print a per-query and overall hit/miss breakdown against gold labels.

    Reads gold labels from ``TRAIN`` and the indexed ids from ``DOCS``, runs
    the retriever selected by ``--mode`` for every unique query, and prints
    which gold slugs were hit, which indexed ones were missed, and which are
    not in the index at all.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--mode", choices=["hybrid", "hybrid_rag", "hybrid_rag_knn"],
                    default="hybrid_rag_knn")
    ap.add_argument("--k", type=int, default=10)
    ap.add_argument("--fanout", type=int, default=50)
    ap.add_argument("--sleep", type=float, default=2.0,
                    help="seconds between LLM-using queries")
    args = ap.parse_args()

    # Load train.csv → query → gold slugs
    queries = _load_gold(TRAIN)

    # Load known slugs (what's actually in our index)
    have = _load_index_ids(DOCS)

    # Choose searcher
    if args.mode == "hybrid":
        def run(q):
            return hybrid_search(q, k=args.k, fanout=args.fanout)
    elif args.mode == "hybrid_rag":
        def run(q):
            return hybrid_rag_search(q, k=args.k, fanout=args.fanout)
    else:
        def run(q):
            return hybrid_rag_search(
                q, k=args.k, fanout=args.fanout,
                use_train_knn=True, train_knn_loo=True,
            )

    # Only the RAG modes are throttled between queries.
    throttle = args.mode in ("hybrid_rag", "hybrid_rag_knn")

    print(f"mode={args.mode} K={args.k} fanout={args.fanout}\n")

    totals = {"gold": 0, "hits": 0, "missed_ind": 0, "missed_prepkg": 0}
    for i, (q, gold_slugs) in enumerate(queries.items(), 1):
        if i > 1 and throttle:
            time.sleep(args.sleep)

        # gold_set is never empty: _load_gold only creates an entry when it
        # has a slug to append, so the per-query division below is safe.
        gold_set = set(gold_slugs)
        gold_pre = {s for s in gold_set if s not in have}
        gold_ind = gold_set - gold_pre

        hits = run(q)
        pred = {h["id"] for h in hits}
        hit_set = pred & gold_set
        missed = gold_set - pred
        missed_pre = missed & gold_pre
        missed_ind = missed - gold_pre

        totals["gold"] += len(gold_set)
        totals["hits"] += len(hit_set)
        totals["missed_ind"] += len(missed_ind)
        totals["missed_prepkg"] += len(missed_pre)

        print(f"=== Q{i}: {q[:80].replace(chr(10),' / ')!r} ===")
        print(f" gold: {len(gold_set):>2} "
              f"(individual: {len(gold_ind)}, pre-packaged: {len(gold_pre)})")
        print(f" HIT : {len(hit_set):>2} "
              f"({100*len(hit_set)/len(gold_set):.0f}% of gold, "
              f"{100*len(hit_set)/len(gold_ind) if gold_ind else 0:.0f}% of reachable)")
        print(f" MISS : individual={len(missed_ind)}, prepackaged={len(missed_pre)} "
              f"(unreachable — not in our index)")
        for s in sorted(gold_set):
            mark = "✓" if s in hit_set else ("·" if s in gold_pre else "✗")
            tag = "PREPKG" if s in gold_pre else "IND"
            print(f" [{mark}] [{tag:6s}] {s}")
        print()

    print("=" * 80)
    print("OVERALL")
    n = len(queries)
    print(f" queries: {n}")
    print(f" total gold labels: {totals['gold']}")
    print(f" hits (correct): {totals['hits']}")
    print(f" missed (individual): {totals['missed_ind']} (in index, retriever didn't surface)")
    print(f" missed (pre-packaged): {totals['missed_prepkg']} (NOT in index — structurally unreachable)")
    # Both denominators can be zero (no usable rows in train.csv, or every
    # gold label is pre-packaged); report "n/a" instead of crashing.
    print(f" raw recall: {_format_ratio(totals['hits'], totals['gold'])}")
    reachable_total = totals['gold'] - totals['missed_prepkg']
    print(f" reachable recall: {_format_ratio(totals['hits'], reachable_total)}")


if __name__ == "__main__":
    main()