Prompt_Squirrel_RAG / scripts /run_t5_sweep.py
Food Desert
Roll out T5 rewrite updates, tooling, docs, and artifact ignore rules
34c53b5
from __future__ import annotations
import argparse
import csv
import json
import subprocess
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Sequence
# Directory one level above the directory that contains this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Training script that is launched as a subprocess for every sweep run.
TRAIN_SCRIPT = REPO_ROOT.joinpath("scripts", "train_t5_rewrite.py")
@dataclass
class SweepConfig:
    """One hyperparameter combination evaluated by the sweep.

    Attributes:
        lr: Learning rate.
        length_penalty: Generation length penalty.
        beams: Number of beams used for generation.
    """

    lr: float
    length_penalty: float
    beams: int

    @property
    def name(self) -> str:
        """Return an identifier for this configuration usable as a path part.

        The learning rate is written in scientific notation with ``-``
        replaced by ``m`` and ``.`` replaced by ``p``; the length penalty has
        ``.`` replaced by ``p``. Example: ``lr_1em04__lp_0p8__b_4``.

        The learning rate uses the fewest mantissa digits that reproduce the
        value exactly. Rates with a single significant digit keep the short
        ``1em04`` form, while rates such as ``1.5e-4`` become ``1p5em04``
        instead of being rounded onto a neighbouring rate's name.
        """
        # Start with zero fractional digits and add precision only while the
        # text does not round-trip; 16 fractional digits (17 significant)
        # always round-trip for a finite float, so the loop ends with a
        # faithful representation.
        lr_txt = f"{self.lr:.0e}"
        for digits in range(1, 17):
            if float(lr_txt) == self.lr:
                break
            lr_txt = f"{self.lr:.{digits}e}"
        lr_txt = lr_txt.replace("-", "m").replace(".", "p")
        lp_txt = str(self.length_penalty).replace(".", "p")
        return f"lr_{lr_txt}__lp_{lp_txt}__b_{self.beams}"
def _parse_csv_floats(text: str) -> List[float]:
out: List[float] = []
for raw in text.split(","):
raw = raw.strip()
if not raw:
continue
out.append(float(raw))
if not out:
raise ValueError("expected at least one float value")
return out
def _read_metrics(path: Path) -> Dict[str, float]:
obj = json.loads(path.read_text(encoding="utf-8"))
test = obj.get("test", {})
val = obj.get("val", {})
train = obj.get("train", {})
return {
"test_recall": float(test.get("test_set_recall", 0.0) or 0.0),
"test_f1": float(test.get("test_set_f1", 0.0) or 0.0),
"test_precision": float(test.get("test_set_precision", 0.0) or 0.0),
"test_loss": float(test.get("test_loss", 0.0) or 0.0),
"val_recall": float(val.get("eval_set_recall", 0.0) or 0.0),
"val_f1": float(val.get("eval_set_f1", 0.0) or 0.0),
"val_precision": float(val.get("eval_set_precision", 0.0) or 0.0),
"val_loss": float(val.get("eval_loss", 0.0) or 0.0),
"train_runtime_s": float(train.get("train_runtime", 0.0) or 0.0),
"train_loss": float(train.get("train_loss", 0.0) or 0.0),
}
def _sort_rows(rows: Sequence[Dict[str, object]], primary: str) -> List[Dict[str, object]]:
return sorted(
rows,
key=lambda r: (
float(r.get(primary, 0.0) or 0.0),
float(r.get("test_f1", 0.0) or 0.0),
float(r.get("test_precision", 0.0) or 0.0),
),
reverse=True,
)
def _write_csv(path: Path, rows: Sequence[Dict[str, object]]) -> None:
if not rows:
return
keys = list(rows[0].keys())
with path.open("w", encoding="utf-8", newline="") as f:
w = csv.DictWriter(f, fieldnames=keys)
w.writeheader()
for row in rows:
w.writerow(row)
def _run_one(
    cfg: SweepConfig,
    *,
    stage_name: str,
    max_steps: int,
    base_model_dir: Path,
    split_dir: Path,
    out_dir: Path,
    runtime_dir: Path,
    eval_steps: int,
    test_eval_every_steps: int,
    max_val_samples: int,
    max_test_samples: int,
    seed: int,
    resume_if_available: bool,
    force: bool,
) -> Dict[str, object]:
    """Train (or reuse) one configuration for one stage and return its row.

    The run directory is ``out_dir / stage_name / cfg.name`` and is created
    if needed. When ``train_metrics.json`` already exists there and ``force``
    is false, no training is launched and the stored metrics are returned
    with status ``"cached"``. Otherwise the training script is executed as a
    subprocess from the repository root and the metrics it leaves behind are
    returned with status ``"ran"``.

    Returns:
        A dict describing the run: stage, config name, hyperparameters,
        ``max_steps``, status, the flattened metrics, and the output and
        metrics paths as strings.

    Raises:
        subprocess.CalledProcessError: If the training subprocess exits with
            a non-zero status.
    """
    config_name = cfg.name
    run_dir = out_dir / stage_name / config_name
    run_dir.mkdir(parents=True, exist_ok=True)
    metrics_path = run_dir / "train_metrics.json"

    def _result_row(status: str) -> Dict[str, object]:
        # Key order matters: it becomes the CSV column order downstream.
        row: Dict[str, object] = {
            "stage": stage_name,
            "config": config_name,
            "lr": cfg.lr,
            "length_penalty": cfg.length_penalty,
            "num_beams": cfg.beams,
            "max_steps": max_steps,
            "status": status,
        }
        row.update(_read_metrics(metrics_path))
        row["output_dir"] = str(run_dir)
        row["metrics_path"] = str(metrics_path)
        return row

    # Reuse an earlier result unless the caller explicitly asked to rerun.
    if not force and metrics_path.is_file():
        return _result_row("cached")

    run_tag = f"{stage_name}__{config_name}"
    progress_file = runtime_dir / f"{run_tag}__progress.json"
    history_file = runtime_dir / f"{run_tag}__history.jsonl"

    cmd: List[str] = [str(sys.executable), str(TRAIN_SCRIPT)]
    cmd.extend(("--split-dir", str(split_dir)))
    cmd.extend(("--base-model-dir", str(base_model_dir)))
    cmd.extend(("--output-dir", str(run_dir)))
    cmd.extend(("--max-steps", str(max_steps)))
    cmd.append("--eval-during-train")
    cmd.extend(("--eval-steps", str(eval_steps)))
    cmd.extend(("--test-eval-every-steps", str(test_eval_every_steps)))
    # Checkpoints are saved at the same interval as evaluation.
    cmd.extend(("--save-steps", str(eval_steps)))
    cmd.extend(("--best-model-metric", "recall"))
    cmd.extend(("--generation-length-penalty", str(cfg.length_penalty)))
    cmd.extend(("--lr", str(cfg.lr)))
    cmd.extend(("--num-beams", str(cfg.beams)))
    cmd.extend(("--max-val-samples", str(max_val_samples)))
    cmd.extend(("--max-test-samples", str(max_test_samples)))
    cmd.extend(("--seed", str(seed)))
    cmd.extend(("--require-cuda", "--fp16"))
    cmd.extend(("--report-to", "none"))
    cmd.extend(("--progress-file", str(progress_file)))
    cmd.extend(("--progress-history-file", str(history_file)))
    if resume_if_available:
        cmd.append("--resume-if-available")

    # check=True aborts the sweep on the first failing training run.
    subprocess.run(cmd, cwd=str(REPO_ROOT), check=True)
    return _result_row("ran")
def main() -> int:
    """Run the two-stage sweep from the command line.

    Stage 1 trains every combination of the listed learning rates, length
    penalties and beam counts for ``--stage1-steps`` steps. The best
    ``--top-k`` rows (at least one) by ``--primary-metric`` are then trained
    again for ``--stage2-steps`` steps as stage 2. A JSON summary and one CSV
    per stage are written to the analysis directory with a timestamped name.

    Directory arguments given as relative paths are resolved against the
    repository root.

    Returns:
        ``0`` on completion.

    Raises:
        ValueError: If one of the list arguments contains no usable value or
            a value that cannot be parsed.
    """
    parser = argparse.ArgumentParser(description="Two-stage T5 sweep: fast screen then confirmation.")
    add = parser.add_argument
    add("--split-dir", type=Path, default=REPO_ROOT / "data" / "external" / "caption_emporium" / "t5_rewrite_splits")
    add("--base-model-dir", type=Path, default=REPO_ROOT / "models" / "t5-small")
    add("--sweep-out-dir", type=Path, default=REPO_ROOT / "models" / "finetune" / "t5-sweep")
    add("--runtime-dir", type=Path, default=REPO_ROOT / "data" / "runtime_metrics" / "t5_sweep")
    add("--analysis-dir", type=Path, default=REPO_ROOT / "data" / "analysis")
    add("--lr-list", type=str, default="1e-4,2e-4")
    add("--length-penalty-list", type=str, default="0.7,0.8,0.9")
    add("--beams-list", type=str, default="4")
    add("--stage1-steps", type=int, default=1000)
    add("--stage2-steps", type=int, default=3750)
    add("--top-k", type=int, default=2)
    add("--eval-steps", type=int, default=500)
    add("--test-eval-every-steps", type=int, default=1000)
    add("--max-val-samples", type=int, default=128)
    add("--max-test-samples", type=int, default=128)
    add("--seed", type=int, default=42)
    add("--primary-metric", type=str, default="test_recall", choices=["test_recall", "test_f1", "val_recall", "val_f1"])
    # Paired on/off switches sharing one destination; resume is on by default.
    add(
        "--resume-if-available",
        dest="resume_if_available",
        action="store_true",
        default=True,
        help="Resume from latest checkpoint in each config output dir when available",
    )
    add(
        "--no-resume-if-available",
        dest="resume_if_available",
        action="store_false",
        help="Disable checkpoint resume and always start each config from step 0",
    )
    add("--force", action="store_true", default=False)
    args = parser.parse_args()

    def _anchor(candidate: Path) -> Path:
        # Absolute paths are used as given; relative ones hang off the repo root.
        if candidate.is_absolute():
            return candidate
        return (REPO_ROOT / candidate).resolve()

    split_dir = _anchor(args.split_dir)
    base_model_dir = _anchor(args.base_model_dir)
    sweep_out_dir = _anchor(args.sweep_out_dir)
    runtime_dir = _anchor(args.runtime_dir)
    analysis_dir = _anchor(args.analysis_dir)
    for directory in (runtime_dir, analysis_dir, sweep_out_dir):
        directory.mkdir(parents=True, exist_ok=True)

    lrs = _parse_csv_floats(args.lr_list)
    lps = _parse_csv_floats(args.length_penalty_list)
    beam_tokens = (token.strip() for token in args.beams_list.split(","))
    beams_list = [int(token) for token in beam_tokens if token]
    if not beams_list:
        raise ValueError("beams-list must contain at least one integer")

    primary = args.primary_metric

    def _fmt(row: Dict[str, object], key: str) -> str:
        return f"{float(row.get(key, 0.0)):.4f}"

    def _run_stage(stage_name: str, configs: Sequence[SweepConfig], max_steps: int) -> List[Dict[str, object]]:
        # Runs the configs one after another, echoing progress and metrics.
        rows: List[Dict[str, object]] = []
        total = len(configs)
        for position, cfg in enumerate(configs, start=1):
            print(f"[{stage_name} {position}/{total}] {cfg.name}")
            row = _run_one(
                cfg,
                stage_name=stage_name,
                max_steps=max_steps,
                base_model_dir=base_model_dir,
                split_dir=split_dir,
                out_dir=sweep_out_dir,
                runtime_dir=runtime_dir,
                eval_steps=args.eval_steps,
                test_eval_every_steps=args.test_eval_every_steps,
                max_val_samples=args.max_val_samples,
                max_test_samples=args.max_test_samples,
                seed=args.seed,
                resume_if_available=args.resume_if_available,
                force=args.force,
            )
            rows.append(row)
            print(
                f" -> {row['status']} {primary}={_fmt(row, primary)} "
                f"F1={_fmt(row, 'test_f1')} P={_fmt(row, 'test_precision')}"
            )
            print()
        return rows

    # Stage 1: full grid, learning rate outermost and beam count innermost.
    stage1_configs: List[SweepConfig] = []
    for lr in lrs:
        for lp in lps:
            for beams in beams_list:
                stage1_configs.append(SweepConfig(lr=lr, length_penalty=lp, beams=beams))

    print(f"Stage1 configs: {len(stage1_configs)}")
    for candidate_cfg in stage1_configs:
        print(f" - {candidate_cfg.name}")
    print()

    stage1_rows = _run_stage("stage1", stage1_configs, args.stage1_steps)

    stage1_sorted = _sort_rows(stage1_rows, primary)
    # Always promote at least one configuration, even for --top-k <= 0.
    keep = max(1, args.top_k)
    top = stage1_sorted[:keep]
    print("Stage1 top configs:")
    for rank, row in enumerate(top, start=1):
        print(
            f" {rank}. {row['config']} {primary}={_fmt(row, primary)} "
            f"F1={_fmt(row, 'test_f1')}"
        )
    print()

    # Stage 2: rebuild the configs of the promoted rows and train them longer.
    top_configs: List[SweepConfig] = []
    for row in top:
        top_configs.append(
            SweepConfig(
                lr=float(row["lr"]),
                length_penalty=float(row["length_penalty"]),
                beams=int(row["num_beams"]),
            )
        )

    stage2_rows = _run_stage("stage2", top_configs, args.stage2_steps)

    stage2_sorted = _sort_rows(stage2_rows, primary)
    winner = stage2_sorted[0]

    meta = {
        "timestamp": datetime.now().isoformat(),
        "python_executable": sys.executable,
        "split_dir": str(split_dir),
        "base_model_dir": str(base_model_dir),
        "sweep_out_dir": str(sweep_out_dir),
        "runtime_dir": str(runtime_dir),
        "stage1_steps": args.stage1_steps,
        "stage2_steps": args.stage2_steps,
        "top_k": args.top_k,
        "eval_steps": args.eval_steps,
        "test_eval_every_steps": args.test_eval_every_steps,
        "max_val_samples": args.max_val_samples,
        "max_test_samples": args.max_test_samples,
        "seed": args.seed,
        "primary_metric": args.primary_metric,
        "resume_if_available": args.resume_if_available,
        "lr_list": lrs,
        "length_penalty_list": lps,
        "beams_list": beams_list,
        "force": args.force,
    }
    payload = {
        "meta": meta,
        "stage1_rows": stage1_rows,
        "stage1_top": top,
        "stage2_rows": stage2_rows,
        "stage2_sorted": stage2_sorted,
        "winner": winner,
    }

    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    prefix = f"t5_sweep_two_stage_{stamp}"
    out_json = analysis_dir / f"{prefix}.json"
    out_csv_stage1 = analysis_dir / f"{prefix}_stage1.csv"
    out_csv_stage2 = analysis_dir / f"{prefix}_stage2.csv"
    out_json.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    _write_csv(out_csv_stage1, stage1_rows)
    _write_csv(out_csv_stage2, stage2_rows)

    print("Winner:")
    print(
        f" {winner['config']} {primary}={_fmt(winner, primary)} "
        f"F1={_fmt(winner, 'test_f1')} P={_fmt(winner, 'test_precision')}"
    )
    print(f"Results JSON: {out_json}")
    print(f"Stage1 CSV: {out_csv_stage1}")
    print(f"Stage2 CSV: {out_csv_stage2}")
    return 0
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())