shl-recommender-api / scripts /eval_recall.py
pankaj
SHL recommender — initial deploy
870800f
"""Evaluate retrieval on the train set.
For each unique query in train.csv, retrieve top-K, compute Recall@K against
the labeled URLs. Reports both raw recall and 'reachable' recall (treating
labels not present in our index as if they don't exist) so corpus coverage
doesn't get conflated with retrieval quality.
Three methods supported:
(no flags)                 Method 1: Baseline — pure dense embedding retrieval
--hybrid --rag Method 3: Latest — hybrid (dense + BM25 + RRF) +
LLM rerank. Uses concept-enriched index when
data/products_with_concepts.jsonl is present.
--hybrid --rag --train-knn Method 2: LLM + RAG + train — same as Method 3
plus a 4th ranklist that borrows gold URLs from
the most semantically-similar train queries
(leave-one-out enabled here for honest eval).
"""
from __future__ import annotations
import argparse
import csv
import json
import time
from collections import defaultdict
from pathlib import Path
from recsys.retrieve import search
from recsys.hybrid import (
search as hybrid_search,
search_with_rag as hybrid_rag_search,
)
from recsys.urls import slug
# Repository root: this script sits one directory below it (in scripts/).
ROOT = Path(__file__).resolve().parent.parent
# Labelled training data: first column is the query, second a label URL.
TRAIN = ROOT.joinpath("train.csv")
# Indexed corpus: one JSON document per line, each with an "id".
DOCS = ROOT.joinpath("data", "documents.jsonl")
def load_queries(path: Path) -> dict[str, list[str]]:
    """Group label URLs by query string.

    The CSV's first row is treated as a header and skipped. In every other
    row the first column is the query and the second a label URL; rows with
    fewer than two columns, or with an empty query or URL, are ignored.

    Args:
        path: CSV file to read (UTF-8).

    Returns:
        A plain ``dict`` mapping each query to its label URLs. Queries keep
        their first-seen order; URLs keep file order (duplicates included).
    """
    grouped: defaultdict[str, list[str]] = defaultdict(list)
    # newline="" hands raw line endings to the csv module; without it,
    # newlines embedded in quoted fields (multi-line queries) are not
    # interpreted correctly.
    with path.open(encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header row
        for row in reader:
            if len(row) < 2 or not row[0] or not row[1]:
                continue
            grouped[row[0]].append(row[1])
    # Plain dict: a lookup of an unknown query must not silently insert an
    # empty entry, as a defaultdict would.
    return dict(grouped)
def known_slugs(path: Path | None = None) -> set[str]:
    """Return the set of document ids present in the JSONL corpus.

    Args:
        path: JSONL file with one JSON object per line, each carrying an
            ``"id"`` key. Defaults to the module-level ``DOCS`` path.

    Returns:
        The set of ``"id"`` values. Blank lines are ignored.
    """
    if path is None:
        path = DOCS
    ids: set[str] = set()
    # Iterate the file rather than using str.splitlines(): splitlines also
    # breaks on U+2028/U+2029/U+0085, which json.dumps(ensure_ascii=False)
    # leaves unescaped inside strings, so a single record could be split.
    with path.open(encoding="utf-8") as f:
        for line in f:
            if line.strip():
                ids.add(json.loads(line)["id"])
    return ids
def recall_at_k(predicted_slugs: list[str], gold_slugs: set[str], k: int) -> float:
    """Fraction of ``gold_slugs`` found among the first ``k`` predictions.

    Repeated predictions count once. An empty gold set yields 0.0.
    """
    n_gold = len(gold_slugs)
    if n_gold == 0:
        return 0.0
    head = set(predicted_slugs[:k])
    return len(gold_slugs & head) / n_gold
def _parse_args() -> argparse.Namespace:
    """Parse CLI flags and normalize implied options."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--k", type=int, default=10)
    ap.add_argument("--rag", action="store_true", help="use LLM reranker")
    ap.add_argument("--hybrid", action="store_true", help="dense + BM25 with RRF fusion")
    ap.add_argument("--train-knn", action="store_true",
                    help="borrow gold URLs from similar train queries (LOO on train eval)")
    ap.add_argument("--fanout", type=int, default=50, help="candidates fed to reranker")
    ap.add_argument("--rag-delay", type=float, default=8.0,
                    help="seconds to wait between queries in --rag mode (LLM rate limit)")
    ap.add_argument("--show-misses", action="store_true")
    args = ap.parse_args()
    if args.k <= 0:
        ap.error("--k must be a positive integer")
    # The LLM reranker and train-kNN are only reachable through the hybrid
    # search functions. Previously --rag alone silently ran the dense
    # baseline while being labelled "RAG".
    if args.train_knn or args.rag:
        args.hybrid = True
    return args


def _mode_label(args: argparse.Namespace) -> str:
    """Human-readable name of the retrieval configuration."""
    if not args.hybrid:
        return "baseline"
    label = "hybrid+RAG" if args.rag else "hybrid"
    if args.train_knn:
        label += "+train-knn"
    return label


def _retrieve(q: str, args: argparse.Namespace) -> list[str]:
    """Run the configured retriever for one query; return predicted ids in rank order."""
    extras: dict = {}
    if args.train_knn:
        # train_knn_loo requests leave-one-out, so the evaluated train query
        # should not borrow its own gold URLs (honest eval).
        extras.update(use_train_knn=True, train_knn_loo=True)
    if args.hybrid and args.rag:
        hits = hybrid_rag_search(q, k=args.k, fanout=args.fanout, **extras)
    elif args.hybrid:
        hits = hybrid_search(q, k=args.k, fanout=args.fanout, **extras)
    else:
        hits = search(q, k=args.k)
    return [h["id"] for h in hits]


def _short(q: str, width: int = 60) -> str:
    """Collapse whitespace (incl. newlines) so a query fits on one table row."""
    return " ".join(q.split())[:width]


def _print_report(rows: list, k: int) -> None:
    """Print the per-query table followed by mean raw and reachable recall."""
    print(f"{'#':>2} {'gold':>4} {'reach':>5} {'hits':>4} {'R@K(raw)':>8} {'R@K(reach)':>10} query")
    print("-" * 110)
    for i, (q, g, r, h, raw, reach) in enumerate(rows, 1):
        reach_str = f"{reach:10.3f}" if reach is not None else f"{'n/a':>10}"
        print(f"{i:2d} {g:4d} {r:5d} {h:4d} {raw:8.3f} {reach_str} {_short(q)}")
    n = len(rows)
    mean_raw = sum(row[4] for row in rows) / n if n else 0.0
    reach_vals = [row[5] for row in rows if row[5] is not None]
    mean_reach = sum(reach_vals) / len(reach_vals) if reach_vals else 0.0
    print("-" * 110)
    print(f"mean Recall@{k} (raw): {mean_raw:.3f}")
    print(f"mean Recall@{k} (reachable): {mean_reach:.3f} "
          f"(over {len(reach_vals)}/{n} queries with >=1 indexed label)")


def main() -> None:
    """Evaluate Recall@K for every unique train query and print a report."""
    args = _parse_args()
    queries = load_queries(TRAIN)
    have = known_slugs()
    print(f"mode={_mode_label(args)} | queries: {len(queries)} | k={args.k} | corpus: {len(have)} slugs\n")

    rows: list[tuple[str, int, int, int, float, float | None]] = []
    for i, (q, urls) in enumerate(queries.items()):
        if args.rag and i > 0:
            time.sleep(args.rag_delay)  # stay under LLM free-tier per-minute chat quota
        gold = {s for s in map(slug, urls) if s}
        reachable = gold & have
        pred = _retrieve(q, args)
        top = set(pred[: args.k])
        raw = recall_at_k(pred, gold, args.k)
        # If none of a query's labels are indexed, reachable recall is
        # undefined; store None so the query is excluded from the reachable
        # mean instead of counting as 0 (which would mix coverage into it).
        reach = recall_at_k(pred, reachable, args.k) if reachable else None
        rows.append((q, len(gold), len(reachable), len(top & gold), raw, reach))
        if args.show_misses:
            missed = gold - top
            if missed:
                missed_str = ", ".join(
                    s if s in have else f"{s} (not indexed)" for s in sorted(missed)
                )
                print(f" [{q[:60]!r}]\n missed: {missed_str}")

    _print_report(rows, args.k)


if __name__ == "__main__":
    main()