# feather-runtime/overlay/scripts/sweep_depth.py
# (from commit e317e25 — "Update Feather h200 training runtime image")
#!/usr/bin/env python3
"""Depth-sweep driver: pre-warm retina for HYDRA_SDR_TARGET_ACTIVE, then fan out
N parallel HF Jobs with different HYDRA_N_LAYER values, each running with full
per-layer diagnostics. Collects job IDs for downstream monitoring.
Usage:
export HF_TOKEN=...
# Optional overrides:
export HYDRA_SDR_TARGET_ACTIVE=137
export HYDRA_TIME_BUDGET=300 # 5 min training per job
export HYDRA_MID_VAL_INTERVAL=250 # per-layer diag panel cadence
export SWEEP_N_LAYERS=2,3,4,5,6,8
export SWEEP_D_MODEL=768
export SWEEP_SKIP_PREWARM=0 # set =1 if retina cache already populated
python scripts/sweep_depth.py
"""
from __future__ import annotations
import os
import subprocess
import sys
import time
from pathlib import Path
# Repository root: this file lives in <repo>/scripts/, so parents[1] is <repo>.
REPO_ROOT = Path(__file__).resolve().parents[1]
LAUNCHER = REPO_ROOT / 'scripts' / 'launch_feather_hf_job.py'
# Depths to fan out, one HF Job per value; comma-separated env override.
SWEEP_N_LAYERS = [int(v) for v in os.environ.get('SWEEP_N_LAYERS', '2,3,4,5,6,8').split(',')]
SWEEP_D_MODEL = os.environ.get('SWEEP_D_MODEL', '768')
SKIP_PREWARM = os.environ.get('SWEEP_SKIP_PREWARM', '0') == '1'
# NOTE(review): default here is 327 but the module docstring's example uses 137 — confirm intended default.
TARGET_ACTIVE = os.environ.get('HYDRA_SDR_TARGET_ACTIVE', '327')
# Short budget — we want diagnostic signal, not convergence.
TIME_BUDGET = os.environ.get('HYDRA_TIME_BUDGET', '300')
MID_VAL = os.environ.get('HYDRA_MID_VAL_INTERVAL', '250')
# Wall-clock timeouts handed to the launcher via FEATHER_HF_JOB_TIMEOUT:
# short for the pre-warm job, longer for the sweep jobs themselves.
PREWARM_TIMEOUT = os.environ.get('SWEEP_PREWARM_TIMEOUT', '30m')
SWEEP_TIMEOUT = os.environ.get('SWEEP_TIMEOUT', '60m')
def launch(env_extra: dict, timeout: str) -> str | None:
    """Run launch_feather_hf_job.py with *env_extra* layered over os.environ.

    Sweep-wide diagnostic defaults are filled in without overriding explicit
    values, the launcher's stdout/stderr are echoed through, and the job id is
    parsed out of its "submitted job_id=..." line.

    Returns the job id string, or None if the launcher exited non-zero or no
    job id line was found.
    """
    child_env = {**os.environ, **env_extra}
    child_env['FEATHER_HF_JOB_TIMEOUT'] = timeout
    # Always enable diagnostics + JSON emission for sweep jobs.
    for key, default in (
        ('HYDRA_LAYER_DIAGNOSTICS', '1'),
        ('HYDRA_MID_VAL_INTERVAL', MID_VAL),
        ('HYDRA_USE_NEMOTRON', '1'),
    ):
        child_env.setdefault(key, default)
    print(f'[sweep] launching with env overrides: {env_extra}', flush=True)
    result = subprocess.run(
        [sys.executable, str(LAUNCHER)],
        env=child_env,
        capture_output=True,
        text=True,
    )
    sys.stdout.write(result.stdout)
    sys.stderr.write(result.stderr)
    if result.returncode != 0:
        print(f'[sweep] launcher exited {result.returncode}', flush=True)
        return None
    # format: [launch] submitted job_id=<id> status=<stage> url=...
    for line in result.stdout.splitlines():
        if 'submitted job_id=' in line:
            return line.split('submitted job_id=', 1)[1].split()[0].strip()
    return None
def poll_until_done(job_id: str, poll_s: int = 30, max_wait_s: int = 1800) -> str:
    """Poll the HF Jobs API until *job_id* reaches a terminal stage.

    Returns the final stage string ('COMPLETED', 'ERROR', 'CANCELED',
    'FAILED'), 'UNKNOWN' when huggingface_hub cannot be imported, or
    'TIMEOUT' when max_wait_s elapses without observing a terminal stage.
    """
    try:
        from huggingface_hub import HfApi  # type: ignore
    except Exception as e:
        print(f'[sweep] cannot poll (huggingface_hub missing: {e})', flush=True)
        return 'UNKNOWN'
    api = HfApi(token=os.environ.get('HF_TOKEN'))
    terminal_stages = {'COMPLETED', 'ERROR', 'CANCELED', 'FAILED'}
    started = time.time()
    prev_stage = None
    while True:
        stage = None
        try:
            job = api.inspect_job(job_id=job_id)
            if hasattr(job, 'status'):
                stage = getattr(job.status, 'stage', None)
        except Exception as e:
            # Transient API/network errors are logged and retried until timeout.
            print(f'[sweep] poll error job={job_id} err={e}', flush=True)
        if stage != prev_stage:
            print(f'[sweep] job={job_id} stage={stage}', flush=True)
            prev_stage = stage
        if stage in terminal_stages:
            return stage or 'UNKNOWN'
        if time.time() - started > max_wait_s:
            print(f'[sweep] timed out waiting for job={job_id}', flush=True)
            return stage or 'TIMEOUT'
        time.sleep(poll_s)
def main() -> int:
    """Drive the depth sweep: optional retina pre-warm, then one HF Job per depth.

    Returns:
        0 — at least one sweep job submitted (manifest written to /tmp).
        2 — HF_TOKEN missing.
        3 — pre-warm job failed to submit.
        4 — no sweep job could be submitted (fix: previously exited 0 and
            wrote an empty manifest even when every submission failed).
    """
    if not os.environ.get('HF_TOKEN'):
        print('ERROR: HF_TOKEN must be set', file=sys.stderr)
        return 2
    print(f'[sweep] plan: n_layers={SWEEP_N_LAYERS} d_model={SWEEP_D_MODEL} '
          f'target_active={TARGET_ACTIVE} time_budget={TIME_BUDGET}s mid_val={MID_VAL}',
          flush=True)
    # If using Space image, upload once now; all subsequent launches reuse it.
    use_space = os.environ.get('FEATHER_HF_USE_SPACE_IMAGE', '0') == '1'
    if use_space:
        print('[sweep] Space image mode: uploading overlay now, subsequent '
              'launches will skip upload', flush=True)

    # --- Pre-warm retina cache ---
    if not SKIP_PREWARM:
        print('[sweep] === PRE-WARM retina cache ===', flush=True)
        prewarm_env = {
            'HYDRA_N_LAYER': '2',
            'HYDRA_D_MODEL': SWEEP_D_MODEL,
            'HYDRA_SDR_TARGET_ACTIVE': TARGET_ACTIVE,
            # Minimal training — just enough to force retina build + upload.
            'HYDRA_TIME_BUDGET': '30',
            'HYDRA_CKPT_INTERVAL': '0',
            'HYDRA_MID_VAL_INTERVAL': '0',
            'HYDRA_LAYER_DIAGNOSTICS': '0',  # no need during pre-warm
            'HYDRA_METRICS_OUT': '/tmp/prewarm_metrics.json',
        }
        prewarm_id = launch(prewarm_env, PREWARM_TIMEOUT)
        # After the first launch, Space image (if used) is built — skip re-upload.
        if use_space:
            os.environ['FEATHER_HF_SKIP_UPLOAD'] = '1'
        if not prewarm_id:
            print('[sweep] pre-warm failed to submit', flush=True)
            return 3
        print(f'[sweep] pre-warm job={prewarm_id}, waiting for completion...', flush=True)
        stage = poll_until_done(prewarm_id, poll_s=20, max_wait_s=1800)
        print(f'[sweep] pre-warm finished stage={stage}', flush=True)
        if stage != 'COMPLETED':
            # Best-effort: a cold cache only costs each sweep job a rebuild.
            print(f'[sweep] WARNING: pre-warm did not COMPLETE (stage={stage}); '
                  f'sweep jobs will each rebuild retina. Proceeding anyway.',
                  flush=True)
    else:
        print('[sweep] SKIP_PREWARM=1; assuming retina cache already populated', flush=True)

    # --- Fan out sweep jobs (concurrent) ---
    print('[sweep] === FAN OUT n_layer sweep ===', flush=True)
    sweep_jobs = {}
    for idx, n_layer in enumerate(SWEEP_N_LAYERS):
        env_extra = {
            'HYDRA_N_LAYER': str(n_layer),
            'HYDRA_D_MODEL': SWEEP_D_MODEL,
            'HYDRA_SDR_TARGET_ACTIVE': TARGET_ACTIVE,
            'HYDRA_TIME_BUDGET': TIME_BUDGET,
            'HYDRA_CKPT_INTERVAL': '0',
            'HYDRA_LAYER_DIAGNOSTICS': '1',
            'HYDRA_MID_VAL_INTERVAL': MID_VAL,
            'HYDRA_METRICS_OUT': f'/tmp/sweep_n{n_layer}_metrics.json',
        }
        jid = launch(env_extra, SWEEP_TIMEOUT)
        # After the first launch in Space-image mode, mark skip-upload for the rest.
        if use_space and idx == 0:
            os.environ['FEATHER_HF_SKIP_UPLOAD'] = '1'
        if jid:
            sweep_jobs[n_layer] = jid
            print(f'[sweep] n_layer={n_layer} -> job_id={jid}', flush=True)
        else:
            print(f'[sweep] n_layer={n_layer} FAILED to submit', flush=True)

    if not sweep_jobs:
        # Nothing to track — signal failure instead of writing an empty manifest.
        print('[sweep] ERROR: no sweep jobs were submitted', file=sys.stderr, flush=True)
        return 4

    print('[sweep] === SWEEP SUBMITTED ===', flush=True)
    print('[sweep] tracked jobs:', flush=True)
    for n, j in sweep_jobs.items():
        print(f'  n_layer={n:2d} job_id={j}', flush=True)
    # Write manifest so the aggregator can find them.
    manifest = Path('/tmp/sweep_depth_manifest.txt')
    manifest.write_text(
        'n_layer\tjob_id\tmetrics_path\n' +
        '\n'.join(
            f'{n}\t{j}\t/tmp/sweep_n{n}_metrics.json'
            for n, j in sweep_jobs.items()
        ) + '\n'
    )
    print(f'[sweep] manifest -> {manifest}', flush=True)
    return 0
# Script entry point: propagate main()'s exit code to the shell.
if __name__ == '__main__':
    raise SystemExit(main())