File size: 7,072 Bytes
099bec8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#!/usr/bin/env python
"""Watch HF Jobs training jobs and fire vLLM evals automatically when each
finishes — with the right token + flavor per run.

Designed to run interactively (the assistant pokes it periodically) or as a
detached watchdog. We avoid mid-training evals because TRL's GRPO pushes the
top-level weights every save_steps, so a download mid-push could see partial
state. Once status=COMPLETED, the final weights are stable and the eval is
deterministic.

Usage:
    HF_TOKEN_AGARWAL=hf_xxx HF_TOKEN_KANAN=hf_xxx HF_TOKEN_MNIT=hf_xxx \\
        python scripts/watch_and_eval.py [--once] [--interval 120]

State file (``outputs/auto_eval_state.json``) records what evals have been
launched so a restart of this script doesn't double-fire.

Why this lives separately from monitor_training.py: this script is a *driver*
(reads job status → triggers shell action), whereas monitor_training.py is a
*reporter* (reads logs → prints metrics). Mixing the two makes failure modes
hard to debug.
"""

from __future__ import annotations

import argparse
import json
import os
import shutil
import subprocess
import sys
import time
from dataclasses import dataclass
from pathlib import Path

# huggingface_hub is imported lazily inside _check_run (not here) so that
# --help works even if the HF SDK isn't installed locally.

# Directory two levels above this file (the parent of the file's directory).
ROOT = Path(__file__).resolve().parent.parent
# JSON file recording which evals have already been launched.
STATE_PATH = ROOT.joinpath("outputs", "auto_eval_state.json")
# Shell script that is run through bash to launch an eval.
LAUNCH_EVAL = ROOT.joinpath("scripts", "launch_eval_job.sh")


@dataclass
class TrainingRun:
    """One watched training job plus the settings for its follow-up eval.

    Attributes:
        run_id: Short label used as the state-file key and as the log prefix.
        job_id: Identifier of the training job whose status is polled.
        repo: Model repo whose file listing is checked for weights and which
            is handed to the eval launch script.
        token_env: Name of the env var that holds the right HF write token.
        eval_flavor: GPU flavor for the eval job.
        eval_limit: Limit value passed to the eval launch script.
        eval_label: Label exported to the launch script as ``EVAL_LABEL``.
    """

    run_id: str
    job_id: str
    repo: str
    token_env: str
    eval_flavor: str
    eval_limit: int = 50
    eval_label: str = "n50_v4"

# Hand-maintained watch list: edit it whenever new runs are launched. Entries
# are copied over from outputs/runs.json manually (never loaded automatically)
# so we don't accidentally fire evals against stale entries.
RUNS: list[TrainingRun] = [
    # Run 3 v5 (after v3 OOM-died at step 1 with default num_gen=4).
    # Eval flavor note: 4B fp16 ≈ 8 GB weights, fits easily.
    TrainingRun(
        run_id="run3",
        job_id="69ed2569d2c8bd8662bce61a",
        repo="Kanan2005/clarify-rl-grpo-qwen3-4b",
        token_env="HF_TOKEN_KANAN",
        eval_flavor="a100-large",
        eval_limit=50,
        eval_label="n50_v4",
    ),
    # Run 4.
    TrainingRun(
        run_id="run4",
        job_id="69ed1a3fd70108f37acdee5e",
        repo="2022uec1542/clarify-rl-grpo-qwen3-1-7b",
        token_env="HF_TOKEN_MNIT",
        eval_flavor="a10g-small",
        eval_limit=50,
        eval_label="n50_v4",
    ),
]

# Job stages after which the watcher stops polling a run.
# NOTE(review): the Hub job API appears to spell the cancelled stage
# "CANCELED" (one L) and to also report "DELETED" — confirm against the
# huggingface_hub JobStage enum. The previously listed spellings are kept so
# nothing that matched before stops matching.
TERMINAL_STAGES = {
    "COMPLETED",
    "FAILED",
    "CANCELED",
    "CANCELLED",
    "DELETED",
    "ERROR",
    "TIMED_OUT",
}


def _load_state() -> dict:
    if STATE_PATH.exists():
        try:
            return json.loads(STATE_PATH.read_text())
        except Exception:
            return {}
    return {}


def _save_state(state: dict) -> None:
    STATE_PATH.parent.mkdir(parents=True, exist_ok=True)
    STATE_PATH.write_text(json.dumps(state, indent=2))


def _check_run(run: TrainingRun, state: dict, interval: int) -> str:
    """Returns the latest stage observed; updates state in place."""
    import truststore
    truststore.inject_into_ssl()
    from huggingface_hub import HfApi  # noqa: WPS433

    token = os.environ.get(run.token_env)
    if not token:
        print(f"[{run.run_id}] {run.token_env} not set β€” cannot watch")
        return "UNKNOWN"

    api = HfApi(token=token)
    try:
        job = api.inspect_job(job_id=run.job_id)
    except Exception as exc:
        print(f"[{run.run_id}] inspect_job failed: {exc}")
        return "UNKNOWN"

    stage = job.status.stage if job.status else "UNKNOWN"
    msg = (job.status.message if job.status else "") or ""
    rec = state.setdefault(run.run_id, {})
    rec["last_stage"] = stage
    rec["last_message"] = msg

    print(f"[{run.run_id}] stage={stage}  flavor={job.flavor}  msg={msg[:80]}")

    if stage in TERMINAL_STAGES and rec.get("eval_launched") != True:
        # Double-confirm there's a model.safetensors before firing the eval
        try:
            files = api.list_repo_files(run.repo)
        except Exception as exc:
            print(f"[{run.run_id}] list_repo_files failed: {exc}")
            return stage
        if "model.safetensors" not in files and not any(f.endswith(".safetensors") for f in files):
            print(f"[{run.run_id}] training ended in {stage} but no weights pushed β€” skipping eval")
            rec["eval_launched"] = "skipped_no_weights"
            return stage

        if stage != "COMPLETED":
            # FAILED but weights exist β€” still useful to eval the final checkpoint
            print(f"[{run.run_id}] stage={stage} but weights present β†’ evaluating partial run")

        cmd = [
            "bash", str(LAUNCH_EVAL),
            run.repo, run.eval_flavor, str(run.eval_limit),
        ]
        env = os.environ.copy()
        env["HF_TOKEN"] = token
        env["EVAL_LABEL"] = run.eval_label
        env["DETACH"] = "1"

        print(f"[{run.run_id}] launching eval: {' '.join(cmd)}")
        proc = subprocess.run(cmd, env=env, capture_output=True, text=True)
        rec["eval_launched"] = True
        rec["eval_launched_at"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())
        rec["eval_stdout_tail"] = proc.stdout[-2000:]
        rec["eval_stderr_tail"] = proc.stderr[-2000:]
        rec["eval_returncode"] = proc.returncode
        print(proc.stdout[-1000:])
        if proc.returncode != 0:
            print(f"[{run.run_id}] EVAL LAUNCH FAILED rc={proc.returncode}\n{proc.stderr[-1500:]}")

    return stage


def main() -> None:
    """Parse CLI options and poll every run in ``RUNS`` until all have ended.

    Each pass reloads the state file, checks every run, and persists the state
    after each individual run. The loop stops when ``--once`` is given, when
    every run reports a terminal stage, or when ``--max-iterations`` passes
    have been made. Exits with status 1 if the eval launch script is missing
    or ``bash`` is not on PATH.
    """
    p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
    p.add_argument("--once", action="store_true", help="Single pass and exit")
    p.add_argument("--interval", type=int, default=180, help="Seconds between polls (default 180)")
    p.add_argument("--max-iterations", type=int, default=240,
                   help="Safety cap on poll iterations (default 240 ≈ 12 hours @ 180s)")
    args = p.parse_args()

    # Fail fast on missing prerequisites instead of at the first eval launch.
    if not LAUNCH_EVAL.is_file():
        print(f"ERROR: {LAUNCH_EVAL} not found", file=sys.stderr)
        sys.exit(1)
    if shutil.which("bash") is None:
        print("ERROR: bash not on PATH", file=sys.stderr)
        sys.exit(1)

    iter_count = 0
    while True:
        iter_count += 1
        state = _load_state()
        all_done = True
        for run in RUNS:
            try:
                stage = _check_run(run, state, args.interval)
            finally:
                # Persist after every run (even if the check raised) so a
                # crash later in the pass cannot lose the record of an eval
                # that was already launched.
                _save_state(state)
            if stage not in TERMINAL_STAGES:
                all_done = False
        if args.once or all_done or iter_count >= args.max_iterations:
            break
        print(f"[wait] sleeping {args.interval}s (iter {iter_count}/{args.max_iterations})")
        time.sleep(args.interval)


# Run the watcher only when executed as a script, not when imported.
if __name__ == "__main__":
    main()