"""Generate the test-set CSV for SHL submission (Appendix 3 format). Reads test.csv (9 unlabeled queries), runs the chosen retriever, and writes predictions to data/predictions.csv: Query,Assessment_url Query 1,url1 Query 1,url2 ... Query 1,url10 Query 2,url1 ... Three modes match the three system methods: baseline Method 1: pure dense embedding retrieval hybrid_rag Method 3 (default, current best): hybrid + LLM rerank over the concept-enriched index hybrid_rag_knn Method 2: hybrid + LLM rerank + train-query similarity (no LOO — test queries are unseen anyway) """ from __future__ import annotations import argparse import csv import time from pathlib import Path from recsys.hybrid import search_with_rag as hybrid_rag_search from recsys.retrieve import search as dense_search ROOT = Path(__file__).resolve().parent.parent TEST = ROOT / "test.csv" OUT = ROOT / "data" / "predictions.csv" SEARCHERS = { "baseline": lambda q, k, fanout: dense_search(q, k=k), "hybrid_rag": lambda q, k, fanout: hybrid_rag_search(q, k=k, fanout=fanout), "hybrid_rag_knn": lambda q, k, fanout: hybrid_rag_search(q, k=k, fanout=fanout, use_train_knn=True), } def main() -> None: ap = argparse.ArgumentParser() ap.add_argument("--mode", choices=list(SEARCHERS), default="hybrid_rag") ap.add_argument("--k", type=int, default=10) ap.add_argument("--fanout", type=int, default=50) ap.add_argument("--sleep", type=float, default=2.0, help="seconds between LLM-using queries (rate-limit pacing)") args = ap.parse_args() queries: list[str] = [] with TEST.open(encoding="utf-8") as f: r = csv.reader(f); next(r, None) for row in r: if not row or not row[0]: continue queries.append(row[0]) print(f"loaded {len(queries)} test queries; mode={args.mode}; k={args.k}") uses_llm = args.mode in {"hybrid_rag", "hybrid_rag_knn"} searcher = SEARCHERS[args.mode] OUT.parent.mkdir(parents=True, exist_ok=True) with OUT.open("w", newline="", encoding="utf-8") as f: w = csv.writer(f) w.writerow(["Query", "Assessment_url"]) for i, q in enumerate(queries, 1): if uses_llm and i > 1: time.sleep(args.sleep) print(f" [{i}/{len(queries)}] {q[:80]}") hits = searcher(q, args.k, args.fanout) if not hits: print(f" WARNING: no hits for query {i}") continue for h in hits: w.writerow([q, h["url"]]) print(f"\nwrote {OUT}") with OUT.open(encoding="utf-8") as f: rows = list(csv.reader(f)) print(f"rows: {len(rows) - 1} ({len(queries)} queries × ~{args.k} preds)") if __name__ == "__main__": main()