Spaces:
Sleeping
Sleeping
"""Generate the test-set CSV for SHL submission (Appendix 3 format).

Reads test.csv (9 unlabeled queries), runs the chosen retriever, and writes
predictions to data/predictions.csv:

    Query,Assessment_url
    Query 1,url1
    Query 1,url2
    ...
    Query 1,url10
    Query 2,url1
    ...

Three modes match the three system methods:

    baseline        Method 1: pure dense embedding retrieval
    hybrid_rag      Method 3 (default, current best): hybrid + LLM rerank
                    over the concept-enriched index
    hybrid_rag_knn  Method 2: hybrid + LLM rerank + train-query similarity
                    (no LOO — test queries are unseen anyway)
"""
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import time | |
| from pathlib import Path | |
| from recsys.hybrid import search_with_rag as hybrid_rag_search | |
| from recsys.retrieve import search as dense_search | |
# Project root: two directory levels above this file's resolved location.
ROOT = Path(__file__).resolve().parent.parent
# Input CSV of test queries (first column is read; the first row is a header).
TEST = ROOT / "test.csv"
# Destination CSV for the generated predictions.
OUT = ROOT / "data" / "predictions.csv"

# Maps each --mode value to a callable taking (query, k, fanout) positionally.
# All entries share that signature so the caller can dispatch uniformly;
# "baseline" accepts fanout but does not forward it to dense_search.
SEARCHERS = {
    "baseline": lambda q, k, fanout: dense_search(q, k=k),
    "hybrid_rag": lambda q, k, fanout: hybrid_rag_search(q, k=k, fanout=fanout),
    "hybrid_rag_knn": lambda q, k, fanout: hybrid_rag_search(
        q, k=k, fanout=fanout, use_train_knn=True
    ),
}
def _load_queries(path: Path) -> list[str]:
    """Return the non-blank first-column values of the CSV at *path*, skipping its header row."""
    # newline="" lets the csv module handle line endings itself, which is
    # required for quoted fields that contain embedded newlines.
    with path.open(newline="", encoding="utf-8") as f:
        reader = csv.reader(f)
        next(reader, None)  # discard the header row (no-op on an empty file)
        return [row[0] for row in reader if row and row[0].strip()]


def main() -> None:
    """Run the selected retriever over every test query and write the predictions CSV.

    Command-line options:
        --mode    key into SEARCHERS (default "hybrid_rag")
        --k       number of results requested per query (default 10, must be >= 1)
        --fanout  passed through to the searcher (default 50)
        --sleep   seconds to pause between queries in the LLM-using modes
                  (default 2.0, must be >= 0)

    Reads queries from TEST, writes a "Query,Assessment_url" header followed by
    one row per returned hit to OUT, and prints progress plus a final row count.
    Queries for which the searcher returns nothing are reported with a warning
    and contribute no rows.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--mode", choices=list(SEARCHERS), default="hybrid_rag")
    ap.add_argument("--k", type=int, default=10)
    ap.add_argument("--fanout", type=int, default=50)
    ap.add_argument("--sleep", type=float, default=2.0,
                    help="seconds between LLM-using queries (rate-limit pacing)")
    args = ap.parse_args()

    # Fail fast on values that would otherwise only blow up mid-run:
    # time.sleep() raises ValueError for a negative duration, and that would
    # happen after the first query had already been processed.
    if args.k < 1:
        ap.error("--k must be a positive integer")
    if args.sleep < 0:
        ap.error("--sleep must be non-negative")

    queries = _load_queries(TEST)
    print(f"loaded {len(queries)} test queries; mode={args.mode}; k={args.k}")

    uses_llm = args.mode in {"hybrid_rag", "hybrid_rag_knn"}
    searcher = SEARCHERS[args.mode]

    OUT.parent.mkdir(parents=True, exist_ok=True)
    rows_written = 0  # prediction rows only; the header is not counted
    with OUT.open("w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["Query", "Assessment_url"])
        for i, q in enumerate(queries, 1):
            # Pace only the LLM-backed modes, and only between queries
            # (no pause before the first one).
            if uses_llm and i > 1:
                time.sleep(args.sleep)
            print(f" [{i}/{len(queries)}] {q[:80]}")
            hits = searcher(q, args.k, args.fanout)
            if not hits:
                print(f" WARNING: no hits for query {i}")
                continue
            for h in hits:
                # Each hit is indexed by "url"; the full query text is
                # repeated in the first column of every row.
                w.writerow([q, h["url"]])
                rows_written += 1

    print(f"\nwrote {OUT}")
    # Counted while writing, so the output file is not re-read and re-parsed.
    print(f"rows: {rows_written} ({len(queries)} queries × ~{args.k} preds)")
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()