Spaces:
Running
Running
| from __future__ import annotations | |
| import argparse | |
| import csv | |
| import json | |
| import subprocess | |
| import sys | |
| from dataclasses import dataclass | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Dict, List, Sequence | |
# Repository root: the grandparent of this file's resolved path, i.e. the
# parent of the directory that holds this script.
REPO_ROOT = Path(__file__).resolve().parents[1]
# Training entry point that the sweep runs as a subprocess for each config.
TRAIN_SCRIPT = REPO_ROOT.joinpath("scripts", "train_t5_rewrite.py")
@dataclass(frozen=True)
class SweepConfig:
    """One point of the hyper-parameter grid explored by the sweep.

    ``name`` is used as the run's output directory name and inside the
    progress/history file names, so distinct configurations must produce
    distinct names (otherwise one run would reuse another's cached metrics).
    """

    lr: float  # forwarded to the training script as --lr
    length_penalty: float  # forwarded as --generation-length-penalty
    beams: int  # forwarded as --num-beams

    @staticmethod
    def _lr_token(lr: float) -> str:
        """Encode *lr* losslessly in exponent form, safe for file names.

        The shortest ``e`` format that parses back to exactly *lr* is used.
        For "round" rates such as 1e-4 this is the ``.0e`` form the sweep has
        always used (``1em04``), so existing run directories keep their names.
        Rates that ``.0e`` would round (e.g. 1.2e-4) get extra digits
        (``1p2em04``) instead of colliding with a neighbouring rate.
        """
        # 17 significant digits (".16e") round-trip every finite double.
        for digits in range(17):
            text = f"{lr:.{digits}e}"
            if float(text) == lr:
                break
        else:  # only NaN never round-trips
            text = repr(lr)
        return text.replace("-", "m").replace(".", "p")

    @property
    def name(self) -> str:
        """Return a file-system friendly identifier, e.g. ``lr_1em04__lp_0p8__b_4``."""
        lr_txt = self._lr_token(self.lr)
        lp_txt = str(self.length_penalty).replace(".", "p")
        return f"lr_{lr_txt}__lp_{lp_txt}__b_{self.beams}"
| def _parse_csv_floats(text: str) -> List[float]: | |
| out: List[float] = [] | |
| for raw in text.split(","): | |
| raw = raw.strip() | |
| if not raw: | |
| continue | |
| out.append(float(raw)) | |
| if not out: | |
| raise ValueError("expected at least one float value") | |
| return out | |
| def _read_metrics(path: Path) -> Dict[str, float]: | |
| obj = json.loads(path.read_text(encoding="utf-8")) | |
| test = obj.get("test", {}) | |
| val = obj.get("val", {}) | |
| train = obj.get("train", {}) | |
| return { | |
| "test_recall": float(test.get("test_set_recall", 0.0) or 0.0), | |
| "test_f1": float(test.get("test_set_f1", 0.0) or 0.0), | |
| "test_precision": float(test.get("test_set_precision", 0.0) or 0.0), | |
| "test_loss": float(test.get("test_loss", 0.0) or 0.0), | |
| "val_recall": float(val.get("eval_set_recall", 0.0) or 0.0), | |
| "val_f1": float(val.get("eval_set_f1", 0.0) or 0.0), | |
| "val_precision": float(val.get("eval_set_precision", 0.0) or 0.0), | |
| "val_loss": float(val.get("eval_loss", 0.0) or 0.0), | |
| "train_runtime_s": float(train.get("train_runtime", 0.0) or 0.0), | |
| "train_loss": float(train.get("train_loss", 0.0) or 0.0), | |
| } | |
| def _sort_rows(rows: Sequence[Dict[str, object]], primary: str) -> List[Dict[str, object]]: | |
| return sorted( | |
| rows, | |
| key=lambda r: ( | |
| float(r.get(primary, 0.0) or 0.0), | |
| float(r.get("test_f1", 0.0) or 0.0), | |
| float(r.get("test_precision", 0.0) or 0.0), | |
| ), | |
| reverse=True, | |
| ) | |
| def _write_csv(path: Path, rows: Sequence[Dict[str, object]]) -> None: | |
| if not rows: | |
| return | |
| keys = list(rows[0].keys()) | |
| with path.open("w", encoding="utf-8", newline="") as f: | |
| w = csv.DictWriter(f, fieldnames=keys) | |
| w.writeheader() | |
| for row in rows: | |
| w.writerow(row) | |
def _result_row(
    cfg: SweepConfig,
    *,
    stage_name: str,
    max_steps: int,
    status: str,
    run_dir: Path,
    metrics_path: Path,
) -> Dict[str, object]:
    """Build one result row: config columns, then metrics read from *metrics_path*, then locations."""
    return {
        "stage": stage_name,
        "config": cfg.name,
        "lr": cfg.lr,
        "length_penalty": cfg.length_penalty,
        "num_beams": cfg.beams,
        "max_steps": max_steps,
        "status": status,
        **_read_metrics(metrics_path),
        "output_dir": str(run_dir),
        "metrics_path": str(metrics_path),
    }


def _run_one(
    cfg: SweepConfig,
    *,
    stage_name: str,
    max_steps: int,
    base_model_dir: Path,
    split_dir: Path,
    out_dir: Path,
    runtime_dir: Path,
    eval_steps: int,
    test_eval_every_steps: int,
    max_val_samples: int,
    max_test_samples: int,
    seed: int,
    resume_if_available: bool,
    force: bool,
) -> Dict[str, object]:
    """Train (or reuse) one sweep configuration and return its result row.

    The run directory is ``out_dir / stage_name / cfg.name``. If it already
    holds ``train_metrics.json`` and *force* is false, training is skipped
    and the row is built from that file (``status == "cached"``). Otherwise
    the training script is run as a subprocess from the repository root, and
    the metrics it writes are read back (``status == "ran"``).

    Raises:
        subprocess.CalledProcessError: if the training script exits non-zero.
        FileNotFoundError: if the training script succeeded but did not write
            ``train_metrics.json`` into the run directory.
    """
    run_dir = out_dir / stage_name / cfg.name
    run_dir.mkdir(parents=True, exist_ok=True)
    metrics_path = run_dir / "train_metrics.json"

    # NOTE(review): the cache check only tests that the metrics file exists;
    # it does not verify the cached run used the same max_steps/seed/etc.
    if metrics_path.is_file() and not force:
        return _result_row(
            cfg,
            stage_name=stage_name,
            max_steps=max_steps,
            status="cached",
            run_dir=run_dir,
            metrics_path=metrics_path,
        )

    # Per-run progress files, named by stage and config so runs do not share them.
    progress_file = runtime_dir / f"{stage_name}__{cfg.name}__progress.json"
    history_file = runtime_dir / f"{stage_name}__{cfg.name}__history.jsonl"

    cmd = [
        str(sys.executable),
        str(TRAIN_SCRIPT),
        "--split-dir", str(split_dir),
        "--base-model-dir", str(base_model_dir),
        "--output-dir", str(run_dir),
        "--max-steps", str(max_steps),
        "--eval-during-train",
        "--eval-steps", str(eval_steps),
        "--test-eval-every-steps", str(test_eval_every_steps),
        # Checkpoints are saved at the evaluation cadence (presumably so the
        # best-model selection always has a matching checkpoint).
        "--save-steps", str(eval_steps),
        # NOTE(review): fixed to recall regardless of the sweep's --primary-metric.
        "--best-model-metric", "recall",
        "--generation-length-penalty", str(cfg.length_penalty),
        "--lr", str(cfg.lr),
        "--num-beams", str(cfg.beams),
        "--max-val-samples", str(max_val_samples),
        "--max-test-samples", str(max_test_samples),
        "--seed", str(seed),
        "--require-cuda",
        "--fp16",
        "--report-to", "none",
        "--progress-file", str(progress_file),
        "--progress-history-file", str(history_file),
    ]
    if resume_if_available:
        cmd.append("--resume-if-available")

    subprocess.run(cmd, cwd=str(REPO_ROOT), check=True)
    if not metrics_path.is_file():
        raise FileNotFoundError(
            f"training for {stage_name}/{cfg.name} exited successfully but did not write {metrics_path}"
        )
    return _result_row(
        cfg,
        stage_name=stage_name,
        max_steps=max_steps,
        status="ran",
        run_dir=run_dir,
        metrics_path=metrics_path,
    )
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the two-stage sweep."""
    ap = argparse.ArgumentParser(description="Two-stage T5 sweep: fast screen then confirmation.")
    ap.add_argument(
        "--split-dir",
        type=Path,
        default=REPO_ROOT / "data" / "external" / "caption_emporium" / "t5_rewrite_splits",
    )
    ap.add_argument("--base-model-dir", type=Path, default=REPO_ROOT / "models" / "t5-small")
    ap.add_argument("--sweep-out-dir", type=Path, default=REPO_ROOT / "models" / "finetune" / "t5-sweep")
    ap.add_argument("--runtime-dir", type=Path, default=REPO_ROOT / "data" / "runtime_metrics" / "t5_sweep")
    ap.add_argument("--analysis-dir", type=Path, default=REPO_ROOT / "data" / "analysis")
    # Comma-separated grids; stage 1 runs their full Cartesian product.
    ap.add_argument("--lr-list", type=str, default="1e-4,2e-4")
    ap.add_argument("--length-penalty-list", type=str, default="0.7,0.8,0.9")
    ap.add_argument("--beams-list", type=str, default="4")
    ap.add_argument("--stage1-steps", type=int, default=1000)
    ap.add_argument("--stage2-steps", type=int, default=3750)
    ap.add_argument("--top-k", type=int, default=2)
    ap.add_argument("--eval-steps", type=int, default=500)
    ap.add_argument("--test-eval-every-steps", type=int, default=1000)
    ap.add_argument("--max-val-samples", type=int, default=128)
    ap.add_argument("--max-test-samples", type=int, default=128)
    ap.add_argument("--seed", type=int, default=42)
    # NOTE(review): the default ranks configs on the *test* split, so the
    # winner's test scores are optimistically biased; val_* avoids that.
    ap.add_argument(
        "--primary-metric",
        type=str,
        default="test_recall",
        choices=["test_recall", "test_f1", "val_recall", "val_f1"],
    )
    ap.add_argument(
        "--resume-if-available",
        dest="resume_if_available",
        action="store_true",
        default=True,
        help="Resume from latest checkpoint in each config output dir when available",
    )
    ap.add_argument(
        "--no-resume-if-available",
        dest="resume_if_available",
        action="store_false",
        help="Disable checkpoint resume and always start each config from step 0",
    )
    ap.add_argument("--force", action="store_true", default=False)
    return ap


def _resolve_repo_path(path: Path) -> Path:
    """Return *path* unchanged if absolute, else resolved relative to REPO_ROOT."""
    return path if path.is_absolute() else (REPO_ROOT / path).resolve()


def _format_scores(row: Dict[str, object], primary: str, *, with_precision: bool) -> str:
    """Format ``<primary>=x F1=y[ P=z]`` for a result row (4 decimals)."""
    text = f"{primary}={float(row.get(primary, 0.0)):.4f} F1={float(row.get('test_f1', 0.0)):.4f}"
    if with_precision:
        text += f" P={float(row.get('test_precision', 0.0)):.4f}"
    return text


def _run_stage(
    stage_name: str,
    configs: Sequence[SweepConfig],
    *,
    max_steps: int,
    args: argparse.Namespace,
    split_dir: Path,
    base_model_dir: Path,
    sweep_out_dir: Path,
    runtime_dir: Path,
) -> List[Dict[str, object]]:
    """Run each config of one stage in order, printing progress; return the result rows."""
    rows: List[Dict[str, object]] = []
    total = len(configs)
    for idx, cfg in enumerate(configs, start=1):
        print(f"[{stage_name} {idx}/{total}] {cfg.name}")
        row = _run_one(
            cfg,
            stage_name=stage_name,
            max_steps=max_steps,
            base_model_dir=base_model_dir,
            split_dir=split_dir,
            out_dir=sweep_out_dir,
            runtime_dir=runtime_dir,
            eval_steps=args.eval_steps,
            test_eval_every_steps=args.test_eval_every_steps,
            max_val_samples=args.max_val_samples,
            max_test_samples=args.max_test_samples,
            seed=args.seed,
            resume_if_available=args.resume_if_available,
            force=args.force,
        )
        rows.append(row)
        print(f" -> {row['status']} {_format_scores(row, args.primary_metric, with_precision=True)}")
    print()
    return rows


def main() -> int:
    """Run the two-stage sweep and write JSON/CSV summaries to the analysis dir.

    Stage 1 trains every (lr, length_penalty, beams) combination for
    ``--stage1-steps`` steps. The ``--top-k`` best rows (at least one), ranked
    by ``--primary-metric`` then test F1 then test precision, are retrained
    for ``--stage2-steps`` steps in stage 2. The best stage-2 row is reported
    as the winner.

    Returns:
        Process exit code; 0 on success (failures propagate as exceptions).

    Raises:
        ValueError: if a grid list is empty or contains a non-numeric entry.
    """
    args = _build_arg_parser().parse_args()
    primary = args.primary_metric

    split_dir = _resolve_repo_path(args.split_dir)
    base_model_dir = _resolve_repo_path(args.base_model_dir)
    sweep_out_dir = _resolve_repo_path(args.sweep_out_dir)
    runtime_dir = _resolve_repo_path(args.runtime_dir)
    analysis_dir = _resolve_repo_path(args.analysis_dir)
    for directory in (runtime_dir, analysis_dir, sweep_out_dir):
        directory.mkdir(parents=True, exist_ok=True)

    lrs = _parse_csv_floats(args.lr_list)
    lps = _parse_csv_floats(args.length_penalty_list)
    beams_list = [int(x.strip()) for x in args.beams_list.split(",") if x.strip()]
    if not beams_list:
        raise ValueError("beams-list must contain at least one integer")

    stage1_configs = [SweepConfig(lr=lr, length_penalty=lp, beams=b) for lr in lrs for lp in lps for b in beams_list]
    print(f"Stage1 configs: {len(stage1_configs)}")
    for c in stage1_configs:
        print(" -", c.name)
    print()

    dirs = dict(
        split_dir=split_dir,
        base_model_dir=base_model_dir,
        sweep_out_dir=sweep_out_dir,
        runtime_dir=runtime_dir,
    )

    # ---- Stage 1: short screening runs over the full grid.
    stage1_rows = _run_stage("stage1", stage1_configs, max_steps=args.stage1_steps, args=args, **dirs)
    stage1_sorted = _sort_rows(stage1_rows, primary)
    top = stage1_sorted[: max(1, args.top_k)]
    print("Stage1 top configs:")
    for i, row in enumerate(top, start=1):
        print(f" {i}. {row['config']} {_format_scores(row, primary, with_precision=False)}")
    print()

    # ---- Stage 2: longer confirmation runs for the top-k configs only.
    top_configs = [
        SweepConfig(
            lr=float(row["lr"]),
            length_penalty=float(row["length_penalty"]),
            beams=int(row["num_beams"]),
        )
        for row in top
    ]
    stage2_rows = _run_stage("stage2", top_configs, max_steps=args.stage2_steps, args=args, **dirs)
    stage2_sorted = _sort_rows(stage2_rows, primary)
    # Non-empty: the grid lists are validated non-empty and top keeps >= 1 row.
    winner = stage2_sorted[0]

    # One clock read so the metadata timestamp and the file-name stamp agree.
    finished_at = datetime.now()
    payload = {
        "meta": {
            "timestamp": finished_at.isoformat(),
            "python_executable": sys.executable,
            "split_dir": str(split_dir),
            "base_model_dir": str(base_model_dir),
            "sweep_out_dir": str(sweep_out_dir),
            "runtime_dir": str(runtime_dir),
            "stage1_steps": args.stage1_steps,
            "stage2_steps": args.stage2_steps,
            "top_k": args.top_k,
            "eval_steps": args.eval_steps,
            "test_eval_every_steps": args.test_eval_every_steps,
            "max_val_samples": args.max_val_samples,
            "max_test_samples": args.max_test_samples,
            "seed": args.seed,
            "primary_metric": primary,
            "resume_if_available": args.resume_if_available,
            "lr_list": lrs,
            "length_penalty_list": lps,
            "beams_list": beams_list,
            "force": args.force,
        },
        "stage1_rows": stage1_rows,
        "stage1_top": top,
        "stage2_rows": stage2_rows,
        "stage2_sorted": stage2_sorted,
        "winner": winner,
    }

    stamp = finished_at.strftime("%Y%m%d_%H%M%S")
    out_json = analysis_dir / f"t5_sweep_two_stage_{stamp}.json"
    out_csv_stage1 = analysis_dir / f"t5_sweep_two_stage_{stamp}_stage1.csv"
    out_csv_stage2 = analysis_dir / f"t5_sweep_two_stage_{stamp}_stage2.csv"
    out_json.write_text(json.dumps(payload, ensure_ascii=False, indent=2), encoding="utf-8")
    _write_csv(out_csv_stage1, stage1_rows)
    _write_csv(out_csv_stage2, stage2_rows)

    print("Winner:")
    print(f" {winner['config']} {_format_scores(winner, primary, with_precision=True)}")
    print(f"Results JSON: {out_json}")
    print(f"Stage1 CSV: {out_csv_stage1}")
    print(f"Stage2 CSV: {out_csv_stage2}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())