#!/usr/bin/env python
"""Watch HF Jobs training jobs and fire vLLM evals automatically when each
finishes — with the right token + flavor per run.

Designed to run interactively (the assistant pokes it periodically) or as
a detached watchdog. We avoid mid-training evals because TRL's GRPO pushes
the top-level weights every save_steps, so a download mid-push could see
partial state. Once status=COMPLETED, the final weights are stable and the
eval is deterministic.

Usage:
    HF_TOKEN_AGARWAL=hf_xxx HF_TOKEN_KANAN=hf_xxx HF_TOKEN_MNIT=hf_xxx \\
        python scripts/watch_and_eval.py [--once] [--interval 120]

State file (``outputs/auto_eval_state.json``) records what evals have been
launched so a restart of this script doesn't double-fire.

Why this lives separately from monitor_training.py: this script is a
*driver* (reads job status → triggers shell action), whereas
monitor_training.py is a *reporter* (reads logs → prints metrics). Mixing
the two makes failure modes hard to debug.
"""

from __future__ import annotations

import argparse
import json
import os
import shutil
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path

# Lazy import so --help works even if HF SDK isn't installed locally.

# Repository root: this file is expected to live one directory below it.
ROOT = Path(__file__).resolve().parent.parent
# JSON file remembering which evals were already launched.
STATE_PATH = ROOT / "outputs" / "auto_eval_state.json"
# Shell script that actually submits an eval job.
LAUNCH_EVAL = ROOT / "scripts" / "launch_eval_job.sh"


@dataclass
class TrainingRun:
    """One training job to watch, plus the settings for its follow-up eval."""

    # Short label for this run.
    run_id: str
    # Identifier of the training job being watched.
    job_id: str
    # Model repository the training job pushes to.
    repo: str
    # Which env-var holds the right HF write token.
    token_env: str
    # GPU flavor for the eval job.
    eval_flavor: str
    # Sample limit handed to the eval.
    eval_limit: int = 50
    # Label attached to the eval.
    eval_label: str = "n50_v4"


# Edit this when launching new runs. Keep it sourced from outputs/runs.json
# manually so we don't accidentally fire evals against stale entries.
RUNS: list[TrainingRun] = [
    TrainingRun(
        # Run 3 v5 (after v3 OOM-died at step 1 with default num_gen=4)
        run_id="run3",
        job_id="69ed2569d2c8bd8662bce61a",
        repo="Kanan2005/clarify-rl-grpo-qwen3-4b",
        token_env="HF_TOKEN_KANAN",
        eval_flavor="a100-large",  # 4B fp16 ≈ 8 GB weights, fits easily
        eval_limit=50,
        eval_label="n50_v4",
    ),
    TrainingRun(
        run_id="run4",
        job_id="69ed1a3fd70108f37acdee5e",
        repo="2022uec1542/clarify-rl-grpo-qwen3-1-7b",
        token_env="HF_TOKEN_MNIT",
        eval_flavor="a10g-small",
        eval_limit=50,
        eval_label="n50_v4",
    ),
]

# Stages after which a job is treated as finished. Both spellings of
# "cancelled" are listed: NOTE(review) the HF SDK appears to report the
# single-L form "CANCELED" — confirm against the JobStage reference.
TERMINAL_STAGES = {
    "COMPLETED",
    "FAILED",
    "CANCELLED",
    "CANCELED",
    "ERROR",
    "TIMED_OUT",
}


def _load_state() -> dict:
    """Load the persisted launch state.

    Returns:
        The dict stored in ``STATE_PATH``, or an empty dict when the file is
        missing, unreadable, not valid JSON, or not a JSON object. Every
        case except "missing" prints a warning to stderr, because starting
        from an empty state means already-launched evals may be fired again.
    """
    try:
        raw = STATE_PATH.read_text()
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as exc:
        # ValueError covers UnicodeDecodeError from read_text().
        print(
            f"WARNING: cannot read {STATE_PATH} ({exc}); starting with empty state",
            file=sys.stderr,
        )
        return {}

    try:
        state = json.loads(raw)
    except ValueError as exc:  # json.JSONDecodeError is a ValueError
        print(
            f"WARNING: {STATE_PATH} is not valid JSON ({exc}); starting with empty state",
            file=sys.stderr,
        )
        return {}

    if not isinstance(state, dict):
        print(
            f"WARNING: {STATE_PATH} does not hold a JSON object; starting with empty state",
            file=sys.stderr,
        )
        return {}
    return state


def _save_state(state: dict) -> None:
    """Persist *state* to ``STATE_PATH`` as indented JSON.

    The data is written to a sibling temp file and then swapped in with
    ``os.replace`` so an interrupted write cannot leave a truncated state
    file behind.
    """
    STATE_PATH.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = STATE_PATH.with_name(STATE_PATH.name + ".tmp")
    tmp_path.write_text(json.dumps(state, indent=2))
    os.replace(tmp_path, STATE_PATH)


def _launch_eval(run: TrainingRun, token: str, rec: dict) -> None:
    """Run the eval launcher script for *run* and record the outcome in *rec*.

    ``rec["eval_launched"]`` is set to ``True`` whatever the launcher's exit
    code is, so a failed launch is not retried automatically; the return
    code and output tails are stored in *rec* for inspection.
    """
    cmd = [
        "bash",
        str(LAUNCH_EVAL),
        run.repo,
        run.eval_flavor,
        str(run.eval_limit),
    ]
    env = {
        **os.environ,
        "HF_TOKEN": token,
        "EVAL_LABEL": run.eval_label,
        "DETACH": "1",
    }
    print(f"[{run.run_id}] launching eval: {' '.join(cmd)}")
    proc = subprocess.run(cmd, env=env, capture_output=True, text=True)

    rec["eval_launched"] = True
    rec["eval_launched_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
    rec["eval_stdout_tail"] = proc.stdout[-2000:]
    rec["eval_stderr_tail"] = proc.stderr[-2000:]
    rec["eval_returncode"] = proc.returncode

    print(proc.stdout[-1000:])
    if proc.returncode != 0:
        print(f"[{run.run_id}] EVAL LAUNCH FAILED rc={proc.returncode}\n{proc.stderr[-1500:]}")


def _check_run(run: TrainingRun, state: dict, interval: int) -> str:
    """Inspect one training job and launch its eval once it has finished.

    Args:
        run: The training run to inspect.
        state: Launch state keyed by ``run_id``; updated in place.
        interval: Poll interval in seconds. Currently unused; kept so the
            call signature stays stable.

    Returns:
        The latest stage observed, or ``"UNKNOWN"`` when the token env-var
        is unset or the job could not be inspected.
    """
    token = os.environ.get(run.token_env)
    if not token:
        print(f"[{run.run_id}] {run.token_env} not set — cannot watch")
        return "UNKNOWN"

    # Imported lazily (and only once a token is available) so --help and the
    # no-token path work even if the HF SDK isn't installed locally.
    import truststore

    truststore.inject_into_ssl()
    from huggingface_hub import HfApi  # noqa: WPS433

    api = HfApi(token=token)
    try:
        job = api.inspect_job(job_id=run.job_id)
    except Exception as exc:  # SDK/network boundary: report and keep polling
        print(f"[{run.run_id}] inspect_job failed: {exc}")
        return "UNKNOWN"

    status = job.status
    stage = status.stage if status else "UNKNOWN"
    # NOTE(review): the SDK may hand back a str-valued enum here; reduce it
    # to its plain value so printing and the JSON state stay uniform.
    stage = getattr(stage, "value", stage)
    msg = (status.message if status else "") or ""

    rec = state.setdefault(run.run_id, {})
    rec["last_stage"] = stage
    rec["last_message"] = msg
    print(f"[{run.run_id}] stage={stage} flavor={job.flavor} msg={msg[:80]}")

    if stage not in TERMINAL_STAGES or rec.get("eval_launched") is True:
        return stage

    # Double-confirm there are safetensors weights before firing the eval.
    try:
        files = api.list_repo_files(run.repo)
    except Exception as exc:  # SDK/network boundary: retry on the next poll
        print(f"[{run.run_id}] list_repo_files failed: {exc}")
        return stage

    if not any(name.endswith(".safetensors") for name in files):
        print(f"[{run.run_id}] training ended in {stage} but no weights pushed — skipping eval")
        rec["eval_launched"] = "skipped_no_weights"
        return stage

    if stage != "COMPLETED":
        # FAILED but weights exist — still useful to eval the final checkpoint
        print(f"[{run.run_id}] stage={stage} but weights present → evaluating partial run")

    _launch_eval(run, token, rec)
    return stage


def main() -> None:
    """Poll every run in ``RUNS`` until all are terminal, then exit.

    Exits with status 1 when the launcher script or ``bash`` is missing.
    Otherwise loops until ``--once`` is given, every run has reached a
    terminal stage, or ``--max-iterations`` polls have been made.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--once", action="store_true", help="Single pass and exit")
    parser.add_argument(
        "--interval",
        type=int,
        default=180,
        help="Seconds between polls (default 180)",
    )
    parser.add_argument(
        "--max-iterations",
        type=int,
        default=240,
        help="Safety cap on poll iterations (default 240 ≈ 12 hours @ 180s)",
    )
    args = parser.parse_args()

    if not LAUNCH_EVAL.is_file():
        print(f"ERROR: {LAUNCH_EVAL} not found", file=sys.stderr)
        sys.exit(1)
    if shutil.which("bash") is None:
        print("ERROR: bash not on PATH", file=sys.stderr)
        sys.exit(1)

    iter_count = 0
    while True:
        iter_count += 1
        state = _load_state()
        all_done = True
        for run in RUNS:
            stage = _check_run(run, state, args.interval)
            # Persist after every run so a crash while handling a later run
            # cannot lose the record of an eval that was already launched.
            _save_state(state)
            if stage not in TERMINAL_STAGES:
                all_done = False
        if args.once or all_done or iter_count >= args.max_iterations:
            break
        print(f"[wait] sleeping {args.interval}s (iter {iter_count}/{args.max_iterations})")
        time.sleep(args.interval)


if __name__ == "__main__":
    main()