#!/usr/bin/env python3
"""Depth-sweep driver: pre-warm retina for HYDRA_SDR_TARGET_ACTIVE, then fan out
N parallel HF Jobs with different HYDRA_N_LAYER values, each running with full
per-layer diagnostics. Collects job IDs for downstream monitoring.

Usage:
    export HF_TOKEN=...
    # Optional overrides:
    export HYDRA_SDR_TARGET_ACTIVE=137
    export HYDRA_TIME_BUDGET=300        # 5 min training per job
    export HYDRA_MID_VAL_INTERVAL=250   # per-layer diag panel cadence
    export SWEEP_N_LAYERS=2,3,4,5,6,8
    export SWEEP_D_MODEL=768
    export SWEEP_SKIP_PREWARM=0         # set =1 if retina cache already populated
    export SWEEP_PREWARM_TIMEOUT=30m    # hard timeout for the pre-warm job
    export SWEEP_TIMEOUT=60m            # hard timeout for each sweep job
    python scripts/sweep_depth.py
"""
from __future__ import annotations

import os
import subprocess
import sys
import time
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parents[1]
LAUNCHER = REPO_ROOT / 'scripts' / 'launch_feather_hf_job.py'

SWEEP_N_LAYERS = [int(v) for v in os.environ.get('SWEEP_N_LAYERS', '2,3,4,5,6,8').split(',') if v.strip()]
SWEEP_D_MODEL  = os.environ.get('SWEEP_D_MODEL', '768')
SKIP_PREWARM   = os.environ.get('SWEEP_SKIP_PREWARM', '0') == '1'
TARGET_ACTIVE  = os.environ.get('HYDRA_SDR_TARGET_ACTIVE', '327')
# Short budget — we want diagnostic signal, not convergence.
TIME_BUDGET    = os.environ.get('HYDRA_TIME_BUDGET', '300')
MID_VAL        = os.environ.get('HYDRA_MID_VAL_INTERVAL', '250')
# Short timeout for the pre-warm job; sweep jobs default to 60m (override via SWEEP_TIMEOUT).
PREWARM_TIMEOUT = os.environ.get('SWEEP_PREWARM_TIMEOUT', '30m')
SWEEP_TIMEOUT   = os.environ.get('SWEEP_TIMEOUT', '60m')
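# Timeout strings like '30m' are forwarded to the launcher verbatim as
# FEATHER_HF_JOB_TIMEOUT; the launcher is assumed to parse that duration format.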


def launch(env_extra: dict[str, str], timeout: str) -> str | None:
    """Invoke launch_feather_hf_job.py with the given env overlay; return the
    parsed job_id, or None if submission failed."""
    env = dict(os.environ)
    env.update(env_extra)
    env['FEATHER_HF_JOB_TIMEOUT'] = timeout
    # Default-on diagnostics for sweep jobs; values already set in env_extra
    # or the ambient environment take precedence (setdefault never overrides).
    env.setdefault('HYDRA_LAYER_DIAGNOSTICS', '1')
    env.setdefault('HYDRA_MID_VAL_INTERVAL', MID_VAL)
    env.setdefault('HYDRA_USE_NEMOTRON', '1')

    print(f'[sweep] launching with env overrides: {env_extra}', flush=True)
    proc = subprocess.run(
        [sys.executable, str(LAUNCHER)],
        env=env,
        capture_output=True,
        text=True,
    )
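    # Re-emit the captured launcher output so it still shows in our logs,
    # then scan stdout for the submitted job_id line.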
    sys.stdout.write(proc.stdout)
    sys.stderr.write(proc.stderr)
    if proc.returncode != 0:
        print(f'[sweep] launcher exited {proc.returncode}', flush=True)
        return None
    job_id = None
    for ln in proc.stdout.splitlines():
        if 'submitted job_id=' in ln:
            # format: [launch] submitted job_id=<id> status=<stage> url=...
            tail = ln.split('submitted job_id=', 1)[1]
            job_id = tail.split()[0].strip()
            break
    return job_id


def poll_until_done(job_id: str, poll_s: int = 30, max_wait_s: int = 1800) -> str:
    """Poll HF Jobs API until the job leaves the running/pending state or we
    exceed max_wait_s. Returns final stage string."""
    try:
        from huggingface_hub import HfApi  # type: ignore
    except Exception as e:
        print(f'[sweep] cannot poll (huggingface_hub missing: {e})', flush=True)
        return 'UNKNOWN'
    api = HfApi(token=os.environ.get('HF_TOKEN'))
    t0 = time.time()
    last_stage = None
    while True:
        try:
            j = api.inspect_job(job_id=job_id)
            stage = getattr(getattr(j, 'status', None), 'stage', None)
        except Exception as e:
            print(f'[sweep] poll error job={job_id} err={e}', flush=True)
            stage = None
        if stage != last_stage:
            print(f'[sweep] job={job_id} stage={stage}', flush=True)
            last_stage = stage
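        # Terminal stages (assumed HF Jobs stage names); anything else means
        # the job is still pending/running, so keep polling.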
        if stage in {'COMPLETED', 'ERROR', 'CANCELED', 'FAILED'}:
            return stage or 'UNKNOWN'
        if time.time() - t0 > max_wait_s:
            print(f'[sweep] timed out waiting for job={job_id}', flush=True)
            return stage or 'TIMEOUT'
        time.sleep(poll_s)


def main() -> int:
    if not os.environ.get('HF_TOKEN'):
        print('ERROR: HF_TOKEN must be set', file=sys.stderr)
        return 2

    print(f'[sweep] plan: n_layers={SWEEP_N_LAYERS} d_model={SWEEP_D_MODEL} '
          f'target_active={TARGET_ACTIVE} time_budget={TIME_BUDGET}s mid_val={MID_VAL}',
          flush=True)

    # In Space-image mode the first launch uploads the overlay; later launches
    # reuse the built image.
    use_space = os.environ.get('FEATHER_HF_USE_SPACE_IMAGE', '0') == '1'
    if use_space:
        print('[sweep] Space image mode: first launch uploads the overlay, '
              'subsequent launches will skip the upload', flush=True)

    # --- Pre-warm retina cache ---
    if not SKIP_PREWARM:
        print('[sweep] === PRE-WARM retina cache ===', flush=True)
        prewarm_env = {
            'HYDRA_N_LAYER': '2',
            'HYDRA_D_MODEL': SWEEP_D_MODEL,
            'HYDRA_SDR_TARGET_ACTIVE': TARGET_ACTIVE,
            # Minimal training — just enough to force retina build + upload.
            'HYDRA_TIME_BUDGET': '30',
            'HYDRA_CKPT_INTERVAL': '0',
            'HYDRA_MID_VAL_INTERVAL': '0',
            'HYDRA_LAYER_DIAGNOSTICS': '0',  # no need during pre-warm
            'HYDRA_METRICS_OUT': '/tmp/prewarm_metrics.json',
        }
        prewarm_id = launch(prewarm_env, PREWARM_TIMEOUT)
        # After the first launch, Space image (if used) is built — skip re-upload.
        if use_space:
            os.environ['FEATHER_HF_SKIP_UPLOAD'] = '1'
        if not prewarm_id:
            print('[sweep] pre-warm failed to submit', flush=True)
            return 3
        print(f'[sweep] pre-warm job={prewarm_id}, waiting for completion...', flush=True)
        stage = poll_until_done(prewarm_id, poll_s=20, max_wait_s=1800)
        print(f'[sweep] pre-warm finished stage={stage}', flush=True)
        if stage != 'COMPLETED':
            print(f'[sweep] WARNING: pre-warm did not COMPLETE (stage={stage}); '
                  'sweep jobs will each rebuild retina. Proceeding anyway.',
                  flush=True)
    else:
        print('[sweep] SKIP_PREWARM=1; assuming retina cache already populated', flush=True)

    # --- Fan out sweep jobs (concurrent) ---
    print('[sweep] === FAN OUT n_layer sweep ===', flush=True)
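    # Map n_layer -> job_id for each successfully submitted job.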
    sweep_jobs = {}
    for idx, n_layer in enumerate(SWEEP_N_LAYERS):
        env_extra = {
            'HYDRA_N_LAYER': str(n_layer),
            'HYDRA_D_MODEL': SWEEP_D_MODEL,
            'HYDRA_SDR_TARGET_ACTIVE': TARGET_ACTIVE,
            'HYDRA_TIME_BUDGET': TIME_BUDGET,
            'HYDRA_CKPT_INTERVAL': '0',
            'HYDRA_LAYER_DIAGNOSTICS': '1',
            'HYDRA_MID_VAL_INTERVAL': MID_VAL,
            'HYDRA_METRICS_OUT': f'/tmp/sweep_n{n_layer}_metrics.json',
        }
        jid = launch(env_extra, SWEEP_TIMEOUT)
        # After the first launch in Space-image mode, mark skip-upload for the rest.
        if use_space and idx == 0:
            os.environ['FEATHER_HF_SKIP_UPLOAD'] = '1'
        if jid:
            sweep_jobs[n_layer] = jid
            print(f'[sweep]   n_layer={n_layer} -> job_id={jid}', flush=True)
        else:
            print(f'[sweep]   n_layer={n_layer} FAILED to submit', flush=True)

    print('[sweep] === SWEEP SUBMITTED ===', flush=True)
    print('[sweep] tracked jobs:', flush=True)
    for n, j in sweep_jobs.items():
        print(f'    n_layer={n:2d}  job_id={j}', flush=True)

    # Write manifest so the aggregator can find them.
    manifest = Path('/tmp/sweep_depth_manifest.txt')
    manifest.write_text(
        'n_layer\tjob_id\tmetrics_path\n' +
        '\n'.join(
            f'{n}\t{j}\t/tmp/sweep_n{n}_metrics.json'
            for n, j in sweep_jobs.items()
        ) + '\n'
    )
    print(f'[sweep] manifest -> {manifest}', flush=True)
    return 0


if __name__ == '__main__':
    raise SystemExit(main())
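

# --- Illustrative only: consuming the manifest downstream --------------------
# A minimal sketch of how an aggregator might read /tmp/sweep_depth_manifest.txt
# (kept as comments so this script stays side-effect free; the aggregation step
# itself is hypothetical and not part of this script):
#
#     from pathlib import Path
#     rows = Path('/tmp/sweep_depth_manifest.txt').read_text().splitlines()
#     for row in rows[1:]:                      # skip the header line
#         n_layer, job_id, metrics_path = row.split('\t')
#         # ...load metrics_path and key results by int(n_layer)...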