File size: 2,896 Bytes
870800f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Generate the test-set CSV for SHL submission (Appendix 3 format).

Reads test.csv (9 unlabeled queries), runs the chosen retriever, and writes
predictions to data/predictions.csv:

    Query,Assessment_url
    Query 1,url1
    Query 1,url2
    ...
    Query 1,url10
    Query 2,url1
    ...

Three modes match the three system methods:

  baseline         Method 1: pure dense embedding retrieval
  hybrid_rag       Method 3 (default, current best): hybrid + LLM rerank
                   over the concept-enriched index
  hybrid_rag_knn   Method 2: hybrid + LLM rerank + train-query similarity
                   (no leave-one-out exclusion needed — test queries are
                   unseen anyway)
"""
from __future__ import annotations

import argparse
import csv
import time
from pathlib import Path

from recsys.hybrid import search_with_rag as hybrid_rag_search
from recsys.retrieve import search as dense_search

# Repository root: two levels up from this script's resolved location.
ROOT = Path(__file__).resolve().parent.parent
# Unlabeled test queries (first column) read by main().
TEST = ROOT.joinpath("test.csv")
# Destination of the submission CSV.
OUT = ROOT.joinpath("data", "predictions.csv")


def _run_baseline(q, k, fanout):
    """Method 1: dense embedding retrieval; *fanout* is accepted but unused."""
    return dense_search(q, k=k)


def _run_hybrid_rag(q, k, fanout):
    """Method 3: hybrid retrieval + LLM rerank."""
    return hybrid_rag_search(q, k=k, fanout=fanout)


def _run_hybrid_rag_knn(q, k, fanout):
    """Method 2: hybrid retrieval + LLM rerank + train-query similarity."""
    return hybrid_rag_search(q, k=k, fanout=fanout, use_train_knn=True)


# Mode name -> callable(query, k, fanout). Key order also defines the order
# of the ``--mode`` choices shown by argparse.
SEARCHERS = {
    "baseline": _run_baseline,
    "hybrid_rag": _run_hybrid_rag,
    "hybrid_rag_knn": _run_hybrid_rag_knn,
}


def _positive_int(text: str) -> int:
    """argparse ``type=``: parse *text* as an int and require it to be > 0."""
    value = int(text)
    if value <= 0:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {value}")
    return value


def _non_negative_float(text: str) -> float:
    """argparse ``type=``: parse *text* as a finite float that is >= 0.

    NaN and infinity are rejected too, because ``time.sleep`` raises on them.
    Validating here fails fast, before any output file is touched.
    """
    value = float(text)
    if not 0.0 <= value < float("inf"):  # also False for NaN
        raise argparse.ArgumentTypeError(f"must be a finite number >= 0, got {text}")
    return value


def _load_queries(path: Path) -> list[str]:
    """Return the first-column query strings from the CSV file at *path*.

    The first row is treated as a header and skipped. Rows that are empty, or
    whose first cell is blank/whitespace-only, are ignored. Query text is kept
    verbatim (not stripped) so the Query column of the output matches the
    input file exactly.
    """
    queries: list[str] = []
    # newline="" as the csv docs require: preserves line endings embedded in
    # quoted multi-line cells instead of translating them.
    with path.open(encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header row
        for row in reader:
            if row and row[0].strip():
                queries.append(row[0])
    return queries


def main() -> None:
    """Run the selected retriever over every test query and write the CSV.

    Command-line options:
      --mode    key of SEARCHERS to use (default: hybrid_rag)
      --k       predictions requested per query (positive int, default 10)
      --fanout  forwarded as ``fanout`` to the hybrid searchers; ignored by
                baseline (positive int, default 50)
      --sleep   seconds to wait between queries in the LLM-using modes
                (finite, >= 0, default 2.0)
      --test    input query CSV (default: TEST)
      --out     output predictions CSV (default: OUT)

    One ``Query,Assessment_url`` row is written per hit. Each hit is expected
    to be a mapping with a ``"url"`` key. Queries with no hits are reported
    and omitted.

    The output is written to a sibling ``.tmp`` file and only replaces
    ``--out`` once every query has succeeded. If the run fails, the previous
    predictions file is left untouched.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--mode", choices=list(SEARCHERS), default="hybrid_rag")
    ap.add_argument("--k", type=_positive_int, default=10)
    ap.add_argument("--fanout", type=_positive_int, default=50)
    ap.add_argument("--sleep", type=_non_negative_float, default=2.0,
                    help="seconds between LLM-using queries (rate-limit pacing)")
    ap.add_argument("--test", type=Path, default=TEST,
                    help="CSV of test queries (first column, header row skipped)")
    ap.add_argument("--out", type=Path, default=OUT,
                    help="where to write the predictions CSV")
    args = ap.parse_args()

    queries = _load_queries(args.test)
    print(f"loaded {len(queries)} test queries; mode={args.mode}; k={args.k}")

    uses_llm = args.mode in {"hybrid_rag", "hybrid_rag_knn"}
    searcher = SEARCHERS[args.mode]

    out: Path = args.out
    out.parent.mkdir(parents=True, exist_ok=True)
    # Write to a temp file and swap it in at the end. A crash or Ctrl-C
    # mid-run (e.g. an LLM rate-limit error) then cannot leave a truncated
    # file that still looks like a complete submission.
    tmp = out.with_name(out.name + ".tmp")
    n_rows = 0
    try:
        with tmp.open("w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(["Query", "Assessment_url"])
            for i, q in enumerate(queries, 1):
                if uses_llm and i > 1:
                    time.sleep(args.sleep)  # pace LLM calls between queries
                print(f"  [{i}/{len(queries)}] {q[:80]}")
                hits = searcher(q, args.k, args.fanout)
                if not hits:
                    print(f"    WARNING: no hits for query {i}")
                    continue
                for h in hits:
                    w.writerow([q, h["url"]])
                    n_rows += 1
        tmp.replace(out)  # atomic swap; overwrites an existing file
    except BaseException:
        tmp.unlink(missing_ok=True)
        raise

    print(f"\nwrote {out}")
    print(f"rows: {n_rows} ({len(queries)} queries × ~{args.k} preds)")


if __name__ == "__main__":
    main()