#!/usr/bin/env python
"""Watch HF Jobs training jobs and fire vLLM evals automatically when each
finishes — with the right token + flavor per run.
Designed to run interactively (the assistant pokes it periodically) or as a
detached watchdog. We avoid mid-training evals because TRL's GRPO pushes the
top-level weights every save_steps, so a download mid-push could see partial
state. Once status=COMPLETED, the final weights are stable and the eval is
deterministic.
Usage:
HF_TOKEN_AGARWAL=hf_xxx HF_TOKEN_KANAN=hf_xxx HF_TOKEN_MNIT=hf_xxx \\
python scripts/watch_and_eval.py [--once] [--interval 120]
State file (``outputs/auto_eval_state.json``) records what evals have been
launched so a restart of this script doesn't double-fire.
Why this lives separately from monitor_training.py: this script is a *driver*
(reads job status → triggers shell action), whereas monitor_training.py is a
*reporter* (reads logs → prints metrics). Mixing the two makes failure modes
hard to debug.
"""
from __future__ import annotations
import argparse
import json
import os
import shutil
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path
# Filesystem layout, anchored at the repository root (this script lives in
# <root>/scripts/). Note: huggingface_hub and truststore are imported lazily
# inside _check_run, so `--help` works even without the HF SDK installed.
ROOT = Path(__file__).resolve().parents[1]
LAUNCH_EVAL = ROOT.joinpath("scripts", "launch_eval_job.sh")
STATE_PATH = ROOT.joinpath("outputs", "auto_eval_state.json")
@dataclass()
class TrainingRun:
    """One HF Jobs training run to watch, plus how to evaluate it when done.

    Fields, in constructor order:
        run_id: Short local name; also the key of this run's record in the
            state file.
        job_id: HF Jobs id passed to ``HfApi.inspect_job``.
        repo: Hub model repo the trained weights are pushed to; it is also
            the repo handed to the eval launch script.
        token_env: Name of the environment variable holding the HF write
            token for this run (each run may belong to a different account).
        eval_flavor: GPU flavor requested for the eval job.
        eval_limit: Third argument to ``launch_eval_job.sh`` (presumably the
            number of eval examples — confirm in the script). Defaults to 50.
        eval_label: Exported to the launch script as ``EVAL_LABEL``.
    """

    run_id: str
    job_id: str
    repo: str
    token_env: str  # env-var NAME, not the token itself
    eval_flavor: str  # e.g. "a10g-small", "a100-large"
    eval_limit: int = 50
    eval_label: str = "n50_v4"
# Runs to watch. Update this list by hand whenever a new run is launched,
# copying the details from outputs/runs.json. It is deliberately NOT loaded
# from that file automatically, so stale entries there can never fire an eval.
RUNS: list[TrainingRun] = [
    # Run 3 v5 — the v3 attempt OOM-died at step 1 with the default num_gen=4.
    TrainingRun(
        run_id="run3",
        repo="Kanan2005/clarify-rl-grpo-qwen3-4b",
        job_id="69ed2569d2c8bd8662bce61a",
        token_env="HF_TOKEN_KANAN",
        # 4B params in fp16 ≈ 8 GB of weights, which fits easily on a100-large.
        eval_flavor="a100-large",
        eval_label="n50_v4",
        eval_limit=50,
    ),
    # Run 4 — Qwen3 1.7B.
    TrainingRun(
        run_id="run4",
        repo="2022uec1542/clarify-rl-grpo-qwen3-1-7b",
        job_id="69ed1a3fd70108f37acdee5e",
        token_env="HF_TOKEN_MNIT",
        eval_flavor="a10g-small",
        eval_label="n50_v4",
        eval_limit=50,
    ),
]
# Job stages after which a run is treated as finished: _check_run considers
# launching its eval, and main() stops counting it as still in progress.
# NOTE(review): huggingface_hub's JobStage enum appears to spell cancellation
# "CANCELED" (one L) and to include "DELETED". Both are listed alongside the
# original spellings so a canceled/deleted job cannot keep the watcher polling
# until --max-iterations. Verify against the installed huggingface_hub version.
TERMINAL_STAGES = {
    "COMPLETED",
    "FAILED",
    "CANCELLED",
    "CANCELED",
    "ERROR",
    "TIMED_OUT",
    "DELETED",
}
def _load_state() -> dict:
if STATE_PATH.exists():
try:
return json.loads(STATE_PATH.read_text())
except Exception:
return {}
return {}
def _save_state(state: dict) -> None:
STATE_PATH.parent.mkdir(parents=True, exist_ok=True)
STATE_PATH.write_text(json.dumps(state, indent=2))
def _check_run(run: TrainingRun, state: dict, interval: int) -> str:
"""Returns the latest stage observed; updates state in place."""
import truststore
truststore.inject_into_ssl()
from huggingface_hub import HfApi # noqa: WPS433
token = os.environ.get(run.token_env)
if not token:
print(f"[{run.run_id}] {run.token_env} not set β cannot watch")
return "UNKNOWN"
api = HfApi(token=token)
try:
job = api.inspect_job(job_id=run.job_id)
except Exception as exc:
print(f"[{run.run_id}] inspect_job failed: {exc}")
return "UNKNOWN"
stage = job.status.stage if job.status else "UNKNOWN"
msg = (job.status.message if job.status else "") or ""
rec = state.setdefault(run.run_id, {})
rec["last_stage"] = stage
rec["last_message"] = msg
print(f"[{run.run_id}] stage={stage} flavor={job.flavor} msg={msg[:80]}")
if stage in TERMINAL_STAGES and rec.get("eval_launched") != True:
# Double-confirm there's a model.safetensors before firing the eval
try:
files = api.list_repo_files(run.repo)
except Exception as exc:
print(f"[{run.run_id}] list_repo_files failed: {exc}")
return stage
if "model.safetensors" not in files and not any(f.endswith(".safetensors") for f in files):
print(f"[{run.run_id}] training ended in {stage} but no weights pushed β skipping eval")
rec["eval_launched"] = "skipped_no_weights"
return stage
if stage != "COMPLETED":
# FAILED but weights exist β still useful to eval the final checkpoint
print(f"[{run.run_id}] stage={stage} but weights present β evaluating partial run")
cmd = [
"bash", str(LAUNCH_EVAL),
run.repo, run.eval_flavor, str(run.eval_limit),
]
env = os.environ.copy()
env["HF_TOKEN"] = token
env["EVAL_LABEL"] = run.eval_label
env["DETACH"] = "1"
print(f"[{run.run_id}] launching eval: {' '.join(cmd)}")
proc = subprocess.run(cmd, env=env, capture_output=True, text=True)
rec["eval_launched"] = True
rec["eval_launched_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
rec["eval_stdout_tail"] = proc.stdout[-2000:]
rec["eval_stderr_tail"] = proc.stderr[-2000:]
rec["eval_returncode"] = proc.returncode
print(proc.stdout[-1000:])
if proc.returncode != 0:
print(f"[{run.run_id}] EVAL LAUNCH FAILED rc={proc.returncode}\n{proc.stderr[-1500:]}")
return stage
def main() -> None:
    """Parse CLI args, then poll every run in ``RUNS`` until all are terminal.

    State is reloaded from disk at the start of every pass and saved after it.
    The loop stops after a single pass with ``--once``, when every run has
    reached a terminal stage, or after ``--max-iterations`` passes. Exits
    with status 1 if the launch script or ``bash`` is missing.
    """
    p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--once", action="store_true", help="Single pass and exit")
    p.add_argument("--interval", type=int, default=180, help="Seconds between polls (default 180)")
    p.add_argument("--max-iterations", type=int, default=240,
                   help="Safety cap on poll iterations (default 240 ≈ 12 hours @ 180s)")
    args = p.parse_args()
    if args.interval <= 0 and not args.once:
        # time.sleep() rejects negative values, and 0 would hammer the Hub API.
        p.error("--interval must be a positive number of seconds")

    if not LAUNCH_EVAL.is_file():
        print(f"ERROR: {LAUNCH_EVAL} not found", file=sys.stderr)
        sys.exit(1)
    if shutil.which("bash") is None:
        print("ERROR: bash not on PATH", file=sys.stderr)
        sys.exit(1)

    iter_count = 0
    while True:
        iter_count += 1
        state = _load_state()
        all_done = True
        # Poll every run on each pass (no early exit) so all records refresh.
        for run in RUNS:
            stage = _check_run(run, state, args.interval)
            if stage not in TERMINAL_STAGES:
                all_done = False
        _save_state(state)
        if args.once or all_done or iter_count >= args.max_iterations:
            break
        print(f"[wait] sleeping {args.interval}s (iter {iter_count}/{args.max_iterations})")
        time.sleep(args.interval)
if __name__ == "__main__":
    # Script entry point. main() returns None, so the exit status is 0 unless
    # main() itself calls sys.exit().
    raise SystemExit(main())