File size: 3,872 Bytes
5a3b322
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from __future__ import annotations

import argparse
import json
import shlex
import subprocess
import sys
from pathlib import Path


def run_cmd(cmd: list[str]) -> None:
    """Echo *cmd*, then run it, raising CalledProcessError on a nonzero exit.

    The echo uses shlex.join so arguments containing spaces (e.g. the default
    --train path "data/Gen_AI Dataset.xlsx") are printed shell-quoted and can
    be copy-pasted back into a terminal unambiguously.
    """
    print(">>", shlex.join(cmd))
    # check=True surfaces failures immediately instead of letting the
    # ablation loop continue with a broken run.
    subprocess.run(cmd, check=True)


def main():
    """Run the retrieval ablation suite.

    For each candidate-pool size in --topn-list, evaluates three retrieval
    variants (BM25, dense vector, hybrid RRF) by shelling out to
    ``eval.run_eval``, then aggregates recall@10 / mrr@10 from each run's
    metrics.json into runs/ablation/ablation_summary.json.
    """
    parser = argparse.ArgumentParser(description="Run ablation suite for retrieval settings.")
    parser.add_argument("--catalog", default="data/catalog_docs.jsonl")
    parser.add_argument("--train", default="data/Gen_AI Dataset.xlsx")
    parser.add_argument("--vector-index", default="data/faiss_index/index.faiss")
    parser.add_argument("--assessment-ids", default="data/embeddings/assessment_ids.json")
    parser.add_argument("--model", default="sentence-transformers/all-MiniLM-L6-v2")
    parser.add_argument("--topn-list", default="50,100,200,400", help="Comma-separated topn candidates to test")
    args = parser.parse_args()

    topn_vals = [int(x) for x in args.topn_list.split(",") if x.strip()]
    runs_dir = Path("runs/ablation")
    runs_dir.mkdir(parents=True, exist_ok=True)

    # Flags shared by the two variants that need the dense index.
    vector_flags = [
        "--vector-index", args.vector_index,
        "--assessment-ids", args.assessment_ids,
        "--model", args.model,
    ]
    # (variant name, extra CLI flags). The name doubles as the --recommender
    # value and the run-directory prefix, exactly as in the per-variant
    # command blocks this table replaces.
    variants = [
        ("bm25", []),
        ("vector", vector_flags),
        ("hybrid_rrf", [*vector_flags, "--rrf-k", "60"]),
    ]

    summary = []
    for topn in topn_vals:
        # Run all three variants at this pool size first, then collect
        # metrics, preserving the original run order.
        for name, extra_flags in variants:
            run_cmd(
                [
                    sys.executable, "-m", "eval.run_eval",
                    "--catalog", args.catalog,
                    "--train", args.train,
                    "--recommender", name,
                    *extra_flags,
                    "--topn-candidates", str(topn),
                    "--out-dir", str(runs_dir / f"{name}_top{topn}"),
                ]
            )

        for name, _ in variants:
            mpath = runs_dir / f"{name}_top{topn}" / "metrics.json"
            if not mpath.exists():
                # Skip (as before), but make the gap visible instead of
                # silently producing an incomplete summary.
                print(f"warning: {mpath} not found; skipping variant", file=sys.stderr)
                continue
            with open(mpath) as f:
                metrics = json.load(f)
            summary.append(
                {
                    "variant": f"{name}_top{topn}",
                    "train_recall@10": metrics["train"]["recall@10"],
                    "val_recall@10": metrics["val"]["recall@10"],
                    "train_mrr@10": metrics["train"]["mrr@10"],
                    "val_mrr@10": metrics["val"]["mrr@10"],
                }
            )

    with open(runs_dir / "ablation_summary.json", "w") as f:
        json.dump(summary, f, indent=2)
    print("Ablation summary written to runs/ablation/ablation_summary.json")


# Script entry point: run the full ablation suite when invoked directly.
if __name__ == "__main__":
    main()