# NOTE(review): non-code extraction artifacts (file size, commit hash, line index) removed.
from __future__ import annotations

import argparse
import json
import shlex
import subprocess
import sys
from pathlib import Path
def run_cmd(cmd: list[str]) -> None:
    """Echo *cmd* and run it, raising ``CalledProcessError`` on a non-zero exit.

    Args:
        cmd: Argument vector passed directly to ``subprocess.run`` (no shell).
    """
    # shlex.join quotes arguments containing spaces (e.g. the default
    # "data/Gen_AI Dataset.xlsx" train path), so the echoed line is a
    # valid, copy-pasteable shell command — " ".join would print it broken.
    print(">>", shlex.join(cmd))
    subprocess.run(cmd, check=True)
def main() -> None:
    """Run the retrieval ablation suite and write a summary of metrics.

    For each candidate-pool size in ``--topn-list``, runs ``eval.run_eval``
    for three recommender variants (bm25, vector, hybrid_rrf), then gathers
    each run's ``metrics.json`` into ``runs/ablation/ablation_summary.json``.
    """
    parser = argparse.ArgumentParser(description="Run ablation suite for retrieval settings.")
    parser.add_argument("--catalog", default="data/catalog_docs.jsonl")
    parser.add_argument("--train", default="data/Gen_AI Dataset.xlsx")
    parser.add_argument("--vector-index", default="data/faiss_index/index.faiss")
    parser.add_argument("--assessment-ids", default="data/embeddings/assessment_ids.json")
    parser.add_argument("--model", default="sentence-transformers/all-MiniLM-L6-v2")
    parser.add_argument(
        "--topn-list",
        default="50,100,200,400",
        help="Comma-separated topn candidates to test",
    )
    args = parser.parse_args()

    topn_vals = [int(x) for x in args.topn_list.split(",") if x.strip()]
    runs_dir = Path("runs/ablation")
    runs_dir.mkdir(parents=True, exist_ok=True)

    # Flags shared by the vector-backed recommenders; bm25 needs none of them.
    vector_flags = [
        "--vector-index", args.vector_index,
        "--assessment-ids", args.assessment_ids,
        "--model", args.model,
    ]
    # (variant name, extra CLI flags). Order matters: it fixes both the run
    # order and the row order in the summary file.
    variants = [
        ("bm25", []),
        ("vector", vector_flags),
        ("hybrid_rrf", vector_flags + ["--rrf-k", "60"]),
    ]

    summary = []
    for topn in topn_vals:
        # Run all three variants for this topn before collecting metrics.
        for name, extra in variants:
            run_cmd(
                [
                    sys.executable, "-m", "eval.run_eval",
                    "--catalog", args.catalog,
                    "--train", args.train,
                    "--recommender", name,
                    *extra,
                    "--topn-candidates", str(topn),
                    "--out-dir", str(runs_dir / f"{name}_top{topn}"),
                ]
            )
        # Collect metrics; a run may legitimately have produced no metrics.json
        # (e.g. it failed upstream), in which case it is skipped silently.
        for name, _ in variants:
            mpath = runs_dir / f"{name}_top{topn}" / "metrics.json"
            if mpath.exists():
                metrics = json.loads(mpath.read_text())
                summary.append(
                    {
                        "variant": f"{name}_top{topn}",
                        "train_recall@10": metrics["train"]["recall@10"],
                        "val_recall@10": metrics["val"]["recall@10"],
                        "train_mrr@10": metrics["train"]["mrr@10"],
                        "val_mrr@10": metrics["val"]["mrr@10"],
                    }
                )

    # Derive the printed path from runs_dir so the message cannot drift from
    # where the file is actually written (the original hard-coded the string).
    summary_path = runs_dir / "ablation_summary.json"
    summary_path.write_text(json.dumps(summary, indent=2))
    print(f"Ablation summary written to {summary_path}")


if __name__ == "__main__":
    main()