"""Two-stage T5 sweep: fast screen then confirmation.

Stage 1 trains every (lr, length_penalty, beams) combination for a small
number of steps.  The best ``--top-k`` configs, ranked by ``--primary-metric``,
are then re-trained for more steps in stage 2.  Results are written as one
JSON file plus one CSV per stage into the analysis directory.
"""
from __future__ import annotations

import argparse
import csv
import json
import subprocess
import sys
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Sequence

REPO_ROOT = Path(__file__).resolve().parents[1]
TRAIN_SCRIPT = REPO_ROOT / "scripts" / "train_t5_rewrite.py"

# (output key, section of train_metrics.json, key inside that section).
# The order here is the column order of the result rows and CSV files.
_METRIC_SOURCES = (
    ("test_recall", "test", "test_set_recall"),
    ("test_f1", "test", "test_set_f1"),
    ("test_precision", "test", "test_set_precision"),
    ("test_loss", "test", "test_loss"),
    ("val_recall", "val", "eval_set_recall"),
    ("val_f1", "val", "eval_set_f1"),
    ("val_precision", "val", "eval_set_precision"),
    ("val_loss", "val", "eval_loss"),
    ("train_runtime_s", "train", "train_runtime"),
    ("train_loss", "train", "train_loss"),
)


@dataclass
class SweepConfig:
    """One hyper-parameter combination evaluated by the sweep."""

    lr: float  # passed to the training script as --lr
    length_penalty: float  # passed as --generation-length-penalty
    beams: int  # passed as --num-beams

    @property
    def name(self) -> str:
        """Return an identifier used for the run directory and progress files.

        The learning rate is rendered in exponent notation with as few digits
        as are needed to round-trip the value.  A fixed single digit would
        give e.g. 1.5e-4 the same name as a neighbouring one-digit rate; the
        two configs would then share an output directory and the second one
        would be reported as "cached" with the first one's metrics.
        Learning rates with one significant digit keep the short form
        (``1e-4`` -> ``1em04``).
        """
        lr_txt = f"{self.lr:.0e}"
        for digits in range(1, 17):
            if float(lr_txt) == self.lr:
                break
            # 17 significant digits (digits == 16) always round-trip a double.
            lr_txt = f"{self.lr:.{digits}e}"
        lr_txt = lr_txt.replace("-", "m").replace(".", "p")
        lp_txt = str(self.length_penalty).replace(".", "p")
        return f"lr_{lr_txt}__lp_{lp_txt}__b_{self.beams}"


def _parse_csv_floats(text: str) -> List[float]:
    """Parse a comma-separated list of floats, ignoring empty items.

    Raises:
        ValueError: if an item is not a float or no item is present.
    """
    values = [float(tok) for tok in (part.strip() for part in text.split(",")) if tok]
    if not values:
        raise ValueError("expected at least one float value")
    return values


def _read_metrics(path: Path) -> Dict[str, float]:
    """Load the metrics JSON at *path* and flatten it into a float dict.

    Missing sections, sections that are JSON ``null``, and missing or falsy
    values all yield ``0.0`` for the affected keys.
    """
    obj = json.loads(path.read_text(encoding="utf-8"))
    # ``or {}`` also covers a section that is present but null.
    sections = {name: obj.get(name) or {} for name in ("test", "val", "train")}
    return {
        out_key: float(sections[section].get(src_key, 0.0) or 0.0)
        for out_key, section, src_key in _METRIC_SOURCES
    }


def _sort_rows(rows: Sequence[Dict[str, object]], primary: str) -> List[Dict[str, object]]:
    """Return *rows* sorted best-first.

    Rows are ordered by the *primary* metric, then ``test_f1``, then
    ``test_precision``, all descending.  Missing or falsy values count as 0.
    """

    def sort_key(row: Dict[str, object]):
        return tuple(
            float(row.get(field, 0.0) or 0.0)
            for field in (primary, "test_f1", "test_precision")
        )

    return sorted(rows, key=sort_key, reverse=True)


def _write_csv(path: Path, rows: Sequence[Dict[str, object]]) -> None:
    """Write *rows* to *path* as CSV; do nothing when *rows* is empty.

    The header is taken from the keys of the first row.
    """
    if not rows:
        return
    keys = list(rows[0].keys())
    with path.open("w", encoding="utf-8", newline="") as f:
        w = csv.DictWriter(f, fieldnames=keys)
        w.writeheader()
        w.writerows(rows)


def _run_one(
    cfg: SweepConfig,
    *,
    stage_name: str,
    max_steps: int,
    base_model_dir: Path,
    split_dir: Path,
    out_dir: Path,
    runtime_dir: Path,
    eval_steps: int,
    test_eval_every_steps: int,
    max_val_samples: int,
    max_test_samples: int,
    seed: int,
    resume_if_available: bool,
    force: bool,
) -> Dict[str, object]:
    """Train one config (or reuse its existing metrics) and return a result row.

    The run directory is ``out_dir / stage_name / cfg.name``.  If it already
    contains ``train_metrics.json`` and *force* is false, training is skipped
    and the row has status ``"cached"``.  Otherwise the training script is run
    as a subprocess and the row has status ``"ran"``.

    Raises:
        subprocess.CalledProcessError: if the training script exits non-zero.
        FileNotFoundError: if training succeeded but wrote no metrics file.
    """
    run_dir = out_dir / stage_name / cfg.name
    run_dir.mkdir(parents=True, exist_ok=True)
    metrics_path = run_dir / "train_metrics.json"

    def build_row(status: str) -> Dict[str, object]:
        return {
            "stage": stage_name,
            "config": cfg.name,
            "lr": cfg.lr,
            "length_penalty": cfg.length_penalty,
            "num_beams": cfg.beams,
            "max_steps": max_steps,
            "status": status,
            **_read_metrics(metrics_path),
            "output_dir": str(run_dir),
            "metrics_path": str(metrics_path),
        }

    if metrics_path.is_file() and not force:
        return build_row("cached")

    run_tag = f"{stage_name}__{cfg.name}"
    progress_file = runtime_dir / f"{run_tag}__progress.json"
    history_file = runtime_dir / f"{run_tag}__history.jsonl"

    cmd = [
        str(sys.executable),
        str(TRAIN_SCRIPT),
        "--split-dir", str(split_dir),
        "--base-model-dir", str(base_model_dir),
        "--output-dir", str(run_dir),
        "--max-steps", str(max_steps),
        "--eval-during-train",
        "--eval-steps", str(eval_steps),
        "--test-eval-every-steps", str(test_eval_every_steps),
        "--save-steps", str(eval_steps),
        "--best-model-metric", "recall",
        "--generation-length-penalty", str(cfg.length_penalty),
        "--lr", str(cfg.lr),
        "--num-beams", str(cfg.beams),
        "--max-val-samples", str(max_val_samples),
        "--max-test-samples", str(max_test_samples),
        "--seed", str(seed),
        "--require-cuda",
        "--fp16",
        "--report-to", "none",
        "--progress-file", str(progress_file),
        "--progress-history-file", str(history_file),
    ]
    if resume_if_available:
        cmd.append("--resume-if-available")

    subprocess.run(cmd, cwd=str(REPO_ROOT), check=True)
    if not metrics_path.is_file():
        raise FileNotFoundError(
            f"training script exited successfully but did not write {metrics_path}"
        )
    return build_row("ran")


def _resolve_under_repo(path: Path) -> Path:
    """Return *path* unchanged if absolute, else resolved relative to REPO_ROOT."""
    return path if path.is_absolute() else (REPO_ROOT / path).resolve()


def _score_text(row: Dict[str, object], primary: str, *, with_precision: bool) -> str:
    """Format the primary metric, F1 and optionally precision of *row*."""
    text = (
        f"{primary}={float(row.get(primary, 0.0)):.4f} "
        f"F1={float(row.get('test_f1', 0.0)):.4f}"
    )
    if with_precision:
        text += f" P={float(row.get('test_precision', 0.0)):.4f}"
    return text


def _run_stage(
    stage_name: str,
    configs: Sequence[SweepConfig],
    *,
    max_steps: int,
    args: argparse.Namespace,
    split_dir: Path,
    base_model_dir: Path,
    sweep_out_dir: Path,
    runtime_dir: Path,
) -> List[Dict[str, object]]:
    """Run every config of one stage in order and return the result rows."""
    rows: List[Dict[str, object]] = []
    total = len(configs)
    for idx, cfg in enumerate(configs, start=1):
        print(f"[{stage_name} {idx}/{total}] {cfg.name}")
        row = _run_one(
            cfg,
            stage_name=stage_name,
            max_steps=max_steps,
            base_model_dir=base_model_dir,
            split_dir=split_dir,
            out_dir=sweep_out_dir,
            runtime_dir=runtime_dir,
            eval_steps=args.eval_steps,
            test_eval_every_steps=args.test_eval_every_steps,
            max_val_samples=args.max_val_samples,
            max_test_samples=args.max_test_samples,
            seed=args.seed,
            resume_if_available=args.resume_if_available,
            force=args.force,
        )
        rows.append(row)
        print(
            f" -> {row['status']} "
            + _score_text(row, args.primary_metric, with_precision=True)
        )
    print()
    return rows


def main() -> int:
    """Parse CLI arguments, run both sweep stages and write the reports.

    Returns:
        Process exit code (always 0; failures surface as exceptions).
    """
    ap = argparse.ArgumentParser(description="Two-stage T5 sweep: fast screen then confirmation.")
    ap.add_argument("--split-dir", type=Path, default=REPO_ROOT / "data" / "external" / "caption_emporium" / "t5_rewrite_splits")
    ap.add_argument("--base-model-dir", type=Path, default=REPO_ROOT / "models" / "t5-small")
    ap.add_argument("--sweep-out-dir", type=Path, default=REPO_ROOT / "models" / "finetune" / "t5-sweep")
    ap.add_argument("--runtime-dir", type=Path, default=REPO_ROOT / "data" / "runtime_metrics" / "t5_sweep")
    ap.add_argument("--analysis-dir", type=Path, default=REPO_ROOT / "data" / "analysis")
    ap.add_argument("--lr-list", type=str, default="1e-4,2e-4")
    ap.add_argument("--length-penalty-list", type=str, default="0.7,0.8,0.9")
    ap.add_argument("--beams-list", type=str, default="4")
    ap.add_argument("--stage1-steps", type=int, default=1000)
    ap.add_argument("--stage2-steps", type=int, default=3750)
    ap.add_argument("--top-k", type=int, default=2)
    ap.add_argument("--eval-steps", type=int, default=500)
    ap.add_argument("--test-eval-every-steps", type=int, default=1000)
    ap.add_argument("--max-val-samples", type=int, default=128)
    ap.add_argument("--max-test-samples", type=int, default=128)
    ap.add_argument("--seed", type=int, default=42)
    ap.add_argument("--primary-metric", type=str, default="test_recall", choices=["test_recall", "test_f1", "val_recall", "val_f1"])
    ap.add_argument(
        "--resume-if-available",
        dest="resume_if_available",
        action="store_true",
        default=True,
        help="Resume from latest checkpoint in each config output dir when available",
    )
    ap.add_argument(
        "--no-resume-if-available",
        dest="resume_if_available",
        action="store_false",
        help="Disable checkpoint resume and always start each config from step 0",
    )
    ap.add_argument("--force", action="store_true", default=False)
    args = ap.parse_args()

    split_dir = _resolve_under_repo(args.split_dir)
    base_model_dir = _resolve_under_repo(args.base_model_dir)
    sweep_out_dir = _resolve_under_repo(args.sweep_out_dir)
    runtime_dir = _resolve_under_repo(args.runtime_dir)
    analysis_dir = _resolve_under_repo(args.analysis_dir)

    for directory in (runtime_dir, analysis_dir, sweep_out_dir):
        directory.mkdir(parents=True, exist_ok=True)

    lrs = _parse_csv_floats(args.lr_list)
    lps = _parse_csv_floats(args.length_penalty_list)
    beams_list = [int(x.strip()) for x in args.beams_list.split(",") if x.strip()]
    if not beams_list:
        raise ValueError("beams-list must contain at least one integer")

    # Full grid: every lr x length_penalty x beams combination.
    stage1_configs = [
        SweepConfig(lr=lr, length_penalty=lp, beams=b)
        for lr in lrs
        for lp in lps
        for b in beams_list
    ]

    print(f"Stage1 configs: {len(stage1_configs)}")
    for c in stage1_configs:
        print(" -", c.name)
    print()

    stage_dirs = dict(
        split_dir=split_dir,
        base_model_dir=base_model_dir,
        sweep_out_dir=sweep_out_dir,
        runtime_dir=runtime_dir,
    )

    stage1_rows = _run_stage(
        "stage1", stage1_configs, max_steps=args.stage1_steps, args=args, **stage_dirs
    )

    stage1_sorted = _sort_rows(stage1_rows, args.primary_metric)
    # Always promote at least one config, even if --top-k is 0 or negative.
    top = stage1_sorted[: max(1, args.top_k)]

    print("Stage1 top configs:")
    for i, row in enumerate(top, start=1):
        print(
            f" {i}. {row['config']} "
            + _score_text(row, args.primary_metric, with_precision=False)
        )
    print()

    top_configs = [
        SweepConfig(
            lr=float(row["lr"]),
            length_penalty=float(row["length_penalty"]),
            beams=int(row["num_beams"]),
        )
        for row in top
    ]

    stage2_rows = _run_stage(
        "stage2", top_configs, max_steps=args.stage2_steps, args=args, **stage_dirs
    )

    stage2_sorted = _sort_rows(stage2_rows, args.primary_metric)
    winner = stage2_sorted[0]

    # One clock reading so the JSON timestamp and the file names agree.
    now = datetime.now()

    payload = {
        "meta": {
            "timestamp": now.isoformat(),
            "python_executable": sys.executable,
            "split_dir": str(split_dir),
            "base_model_dir": str(base_model_dir),
            "sweep_out_dir": str(sweep_out_dir),
            "runtime_dir": str(runtime_dir),
            "stage1_steps": args.stage1_steps,
            "stage2_steps": args.stage2_steps,
            "top_k": args.top_k,
            "eval_steps": args.eval_steps,
            "test_eval_every_steps": args.test_eval_every_steps,
            "max_val_samples": args.max_val_samples,
            "max_test_samples": args.max_test_samples,
            "seed": args.seed,
            "primary_metric": args.primary_metric,
            "resume_if_available": args.resume_if_available,
            "lr_list": lrs,
            "length_penalty_list": lps,
            "beams_list": beams_list,
            "force": args.force,
        },
        "stage1_rows": stage1_rows,
        "stage1_top": top,
        "stage2_rows": stage2_rows,
        "stage2_sorted": stage2_sorted,
        "winner": winner,
    }

    stamp = now.strftime("%Y%m%d_%H%M%S")
    out_json = analysis_dir / f"t5_sweep_two_stage_{stamp}.json"
    out_csv_stage1 = analysis_dir / f"t5_sweep_two_stage_{stamp}_stage1.csv"
    out_csv_stage2 = analysis_dir / f"t5_sweep_two_stage_{stamp}_stage2.csv"

    out_json.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    _write_csv(out_csv_stage1, stage1_rows)
    _write_csv(out_csv_stage2, stage2_rows)

    print("Winner:")
    print(
        f" {winner['config']} "
        + _score_text(winner, args.primary_metric, with_precision=True)
    )
    print(f"Results JSON: {out_json}")
    print(f"Stage1 CSV: {out_csv_stage1}")
    print(f"Stage2 CSV: {out_csv_stage2}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())