# shl-recommender-api/scripts/predict_test.py
# pankaj
# SHL recommender — initial deploy
# 870800f
"""Generate the test-set CSV for SHL submission (Appendix 3 format).

Reads test.csv (9 unlabeled queries), runs the chosen retriever, and writes
predictions to data/predictions.csv, one row per (query, URL) pair:

    Query,Assessment_url
    Query 1,url1
    Query 1,url2
    ...
    Query 1,url10
    Query 2,url1
    ...

Three modes match the three system methods:

    baseline        Method 1: pure dense embedding retrieval
    hybrid_rag      Method 3 (default, current best): hybrid + LLM rerank
                    over the concept-enriched index
    hybrid_rag_knn  Method 2: hybrid + LLM rerank + train-query similarity
                    (no LOO — test queries are unseen anyway)
"""
from __future__ import annotations
import argparse
import csv
import time
from pathlib import Path
from recsys.hybrid import search_with_rag as hybrid_rag_search
from recsys.retrieve import search as dense_search
# Repository root: two directory levels above this script's resolved location.
ROOT = Path(__file__).resolve().parents[1]
# Input file read by main(); its first column holds the query text.
TEST = ROOT.joinpath("test.csv")
# Output CSV written by main().
OUT = ROOT.joinpath("data", "predictions.csv")


def _search_baseline(q, k, fanout):
    """Dense retrieval only; `fanout` is accepted for a uniform signature and ignored."""
    return dense_search(q, k=k)


def _search_hybrid_rag(q, k, fanout):
    """Hybrid RAG search with the given candidate fanout."""
    return hybrid_rag_search(q, k=k, fanout=fanout)


def _search_hybrid_rag_knn(q, k, fanout):
    """Hybrid RAG search with the train-query kNN option switched on."""
    return hybrid_rag_search(q, k=k, fanout=fanout, use_train_knn=True)


# Mode name -> callable(query, k, fanout). Key order is the order in which
# the choices are offered by the --mode command-line option.
SEARCHERS = {
    "baseline": _search_baseline,
    "hybrid_rag": _search_hybrid_rag,
    "hybrid_rag_knn": _search_hybrid_rag_knn,
}
def _load_queries(path: Path) -> list[str]:
    """Return the non-empty first-column values of the CSV at *path*.

    The first row is treated as a header and skipped. Rows that are empty
    or whose first cell is empty are ignored.
    """
    # newline="" lets the csv module do its own newline handling, which it
    # needs in order to parse quoted fields containing line breaks correctly.
    with path.open(newline="", encoding="utf-8") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row (tolerates an empty file)
        return [row[0] for row in reader if row and row[0]]


def main() -> None:
    """Run the selected retriever over every test query and write OUT.

    Command-line options:
        --mode    key into SEARCHERS (default "hybrid_rag")
        --k       passed to the searcher as ``k`` (default 10)
        --fanout  passed to the searcher as ``fanout`` (default 50)
        --sleep   seconds to pause between queries in the LLM-using modes
                  (default 2.0)

    Writes a header row followed by one ``[query, url]`` row per hit, taking
    the URL from each hit's ``"url"`` key. Queries that return no hits get a
    warning on stdout and contribute no rows.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--mode", choices=list(SEARCHERS), default="hybrid_rag")
    ap.add_argument("--k", type=int, default=10)
    ap.add_argument("--fanout", type=int, default=50)
    ap.add_argument("--sleep", type=float, default=2.0,
                    help="seconds between LLM-using queries (rate-limit pacing)")
    args = ap.parse_args()
    # time.sleep() raises ValueError on negative input; reject it up front
    # instead of failing after the first query has already been run.
    if args.sleep < 0:
        ap.error("--sleep must be non-negative")

    queries = _load_queries(TEST)
    print(f"loaded {len(queries)} test queries; mode={args.mode}; k={args.k}")

    uses_llm = args.mode in {"hybrid_rag", "hybrid_rag_knn"}
    searcher = SEARCHERS[args.mode]

    OUT.parent.mkdir(parents=True, exist_ok=True)
    rows_written = 0  # data rows only; the header is not counted
    with OUT.open("w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["Query", "Assessment_url"])
        for i, q in enumerate(queries, 1):
            # Pace only between queries, never before the first one.
            if uses_llm and i > 1:
                time.sleep(args.sleep)
            print(f" [{i}/{len(queries)}] {q[:80]}")
            hits = searcher(q, args.k, args.fanout)
            if not hits:
                print(f" WARNING: no hits for query {i}")
                continue
            for h in hits:
                w.writerow([q, h["url"]])
                rows_written += 1

    print(f"\nwrote {OUT}")
    # Counted while writing, so the output file need not be re-read.
    print(f"rows: {rows_written} ({len(queries)} queries × ~{args.k} preds)")


if __name__ == "__main__":
    main()