#!/usr/bin/env python3
"""Depth-sweep driver: pre-warm retina for HYDRA_SDR_TARGET_ACTIVE, then fan
out N parallel HF Jobs with different HYDRA_N_LAYER values, each running with
full per-layer diagnostics. Collects job IDs for downstream monitoring.

Usage:
    export HF_TOKEN=...
    # Optional overrides:
    export HYDRA_SDR_TARGET_ACTIVE=137
    export HYDRA_TIME_BUDGET=300          # 5 min training per job
    export HYDRA_MID_VAL_INTERVAL=250     # per-layer diag panel cadence
    export SWEEP_N_LAYERS=2,3,4,5,6,8
    export SWEEP_D_MODEL=768
    export SWEEP_SKIP_PREWARM=0           # set =1 if retina cache already populated
    python scripts/sweep_depth.py

Exit codes:
    0  at least one sweep job was submitted
    2  configuration error (HF_TOKEN missing or SWEEP_N_LAYERS empty)
    3  the pre-warm job could not be submitted
    4  no sweep job could be submitted
"""
from __future__ import annotations

import os
import subprocess
import sys
import time
from pathlib import Path

REPO_ROOT = Path(__file__).resolve().parents[1]
LAUNCHER = REPO_ROOT / 'scripts' / 'launch_feather_hf_job.py'

# Marker the launcher prints on the line that carries the submitted job id.
JOB_ID_MARKER = 'submitted job_id='

# Stages at which poll_until_done() stops polling.
TERMINAL_STAGES = frozenset({'COMPLETED', 'ERROR', 'CANCELED', 'FAILED'})

# Header row of the tab-separated manifest written by main().
MANIFEST_HEADER = 'n_layer\tjob_id\tmetrics_path'


def parse_n_layers(raw: str) -> list[int]:
    """Parse a comma-separated list of layer counts.

    Blank entries (e.g. from a trailing comma) are skipped instead of raising,
    and surrounding whitespace is tolerated. A non-numeric entry still raises
    ValueError.
    """
    return [int(token) for token in raw.split(',') if token.strip()]


def parse_job_id(stdout: str) -> str | None:
    """Extract the job id from the launcher's stdout.

    The first line containing ``submitted job_id=`` decides the result: the id
    is the first whitespace-delimited token after the marker. Returns None when
    no line carries the marker or nothing follows it.
    """
    for line in stdout.splitlines():
        _, marker, tail = line.partition(JOB_ID_MARKER)
        if not marker:
            continue
        tokens = tail.split()
        return tokens[0] if tokens else None
    return None


def metrics_path(n_layer: int) -> str:
    """Return the metrics JSON path used for the sweep job with ``n_layer`` layers."""
    return f'/tmp/sweep_n{n_layer}_metrics.json'


def format_manifest(jobs: dict[int, str]) -> str:
    """Render the tab-separated manifest text for ``{n_layer: job_id}``.

    One header row, then one row per job; the text always ends with exactly
    one newline, including when ``jobs`` is empty.
    """
    rows = [MANIFEST_HEADER]
    rows.extend(f'{n}\t{j}\t{metrics_path(n)}' for n, j in jobs.items())
    return '\n'.join(rows) + '\n'


SWEEP_N_LAYERS = parse_n_layers(os.environ.get('SWEEP_N_LAYERS', '2,3,4,5,6,8'))
SWEEP_D_MODEL = os.environ.get('SWEEP_D_MODEL', '768')
SKIP_PREWARM = os.environ.get('SWEEP_SKIP_PREWARM', '0') == '1'
TARGET_ACTIVE = os.environ.get('HYDRA_SDR_TARGET_ACTIVE', '327')
# Short budget — we want diagnostic signal, not convergence.
TIME_BUDGET = os.environ.get('HYDRA_TIME_BUDGET', '300')
MID_VAL = os.environ.get('HYDRA_MID_VAL_INTERVAL', '250')
# Per-job timeouts handed to the launcher via FEATHER_HF_JOB_TIMEOUT: a short
# one for the pre-warm job, a longer one (default 60m) for each sweep job.
PREWARM_TIMEOUT = os.environ.get('SWEEP_PREWARM_TIMEOUT', '30m')
SWEEP_TIMEOUT = os.environ.get('SWEEP_TIMEOUT', '60m')


def launch(env_extra: dict[str, str], timeout: str) -> str | None:
    """Invoke launch_feather_hf_job.py with the given env overlay, parse job_id.

    Args:
        env_extra: Environment variables layered on top of ``os.environ`` for
            the launcher subprocess.
        timeout: Value exported to the launcher as ``FEATHER_HF_JOB_TIMEOUT``.

    Returns:
        The job id printed by the launcher, or None when the launcher exits
        non-zero or prints no ``submitted job_id=`` line.
    """
    env = {**os.environ, **env_extra, 'FEATHER_HF_JOB_TIMEOUT': timeout}
    # Always enable diagnostics + JSON emission for sweep jobs. setdefault
    # keeps any value the caller (or the ambient environment) already chose.
    env.setdefault('HYDRA_LAYER_DIAGNOSTICS', '1')
    env.setdefault('HYDRA_MID_VAL_INTERVAL', MID_VAL)
    env.setdefault('HYDRA_USE_NEMOTRON', '1')
    print(f'[sweep] launching with env overrides: {env_extra}', flush=True)
    proc = subprocess.run(
        [sys.executable, str(LAUNCHER)],
        env=env,
        capture_output=True,
        text=True,
        check=False,
    )
    # Output is captured so the job id can be parsed; replay it for the user.
    sys.stdout.write(proc.stdout)
    sys.stderr.write(proc.stderr)
    if proc.returncode != 0:
        print(f'[sweep] launcher exited {proc.returncode}', flush=True)
        return None
    job_id = parse_job_id(proc.stdout)
    if job_id is None:
        print('[sweep] launcher exited 0 but printed no job_id', flush=True)
    return job_id


def poll_until_done(job_id: str, poll_s: int = 30, max_wait_s: int = 1800) -> str:
    """Poll HF Jobs API until the job reaches a terminal stage or we exceed
    max_wait_s.

    Args:
        job_id: Job to inspect.
        poll_s: Seconds to sleep between polls.
        max_wait_s: Give up once this many seconds have elapsed.

    Returns:
        The terminal stage; on timeout the last observed stage, or 'TIMEOUT'
        if the last poll yielded none; 'UNKNOWN' if huggingface_hub cannot be
        imported.
    """
    try:
        from huggingface_hub import HfApi  # type: ignore
    except Exception as e:  # best-effort: any import failure disables polling
        print(f'[sweep] cannot poll (huggingface_hub missing: {e})', flush=True)
        return 'UNKNOWN'
    api = HfApi(token=os.environ.get('HF_TOKEN'))
    # Monotonic clock: elapsed time must not jump if the wall clock is adjusted.
    started = time.monotonic()
    last_stage = None
    while True:
        try:
            job = api.inspect_job(job_id=job_id)
            stage = getattr(getattr(job, 'status', None), 'stage', None)
        except Exception as e:  # best-effort: a failed poll is retried next tick
            print(f'[sweep] poll error job={job_id} err={e}', flush=True)
            stage = None
        # Log transitions only, not every poll.
        if stage != last_stage:
            print(f'[sweep] job={job_id} stage={stage}', flush=True)
            last_stage = stage
        if stage in TERMINAL_STAGES:
            return stage or 'UNKNOWN'
        if time.monotonic() - started > max_wait_s:
            print(f'[sweep] timed out waiting for job={job_id}', flush=True)
            return stage or 'TIMEOUT'
        time.sleep(poll_s)


def main() -> int:
    """Run the optional pre-warm job, submit one job per layer count, and
    write the manifest. Returns the process exit code (see module docstring).
    """
    if not os.environ.get('HF_TOKEN'):
        print('ERROR: HF_TOKEN must be set', file=sys.stderr)
        return 2
    if not SWEEP_N_LAYERS:
        print('ERROR: SWEEP_N_LAYERS must list at least one layer count',
              file=sys.stderr)
        return 2
    print(f'[sweep] plan: n_layers={SWEEP_N_LAYERS} d_model={SWEEP_D_MODEL} '
          f'target_active={TARGET_ACTIVE} time_budget={TIME_BUDGET}s mid_val={MID_VAL}',
          flush=True)

    # If using Space image, upload once now; all subsequent launches reuse it.
    use_space = os.environ.get('FEATHER_HF_USE_SPACE_IMAGE', '0') == '1'
    if use_space:
        print('[sweep] Space image mode: uploading overlay now, subsequent '
              'launches will skip upload', flush=True)

    # --- Pre-warm retina cache ---
    if not SKIP_PREWARM:
        print('[sweep] === PRE-WARM retina cache ===', flush=True)
        prewarm_env = {
            'HYDRA_N_LAYER': '2',
            'HYDRA_D_MODEL': SWEEP_D_MODEL,
            'HYDRA_SDR_TARGET_ACTIVE': TARGET_ACTIVE,
            # Minimal training — just enough to force retina build + upload.
            'HYDRA_TIME_BUDGET': '30',
            'HYDRA_CKPT_INTERVAL': '0',
            'HYDRA_MID_VAL_INTERVAL': '0',
            'HYDRA_LAYER_DIAGNOSTICS': '0',  # no need during pre-warm
            'HYDRA_METRICS_OUT': '/tmp/prewarm_metrics.json',
        }
        prewarm_id = launch(prewarm_env, PREWARM_TIMEOUT)
        if not prewarm_id:
            print('[sweep] pre-warm failed to submit', flush=True)
            return 3
        # After the first launch, Space image (if used) is built — skip re-upload.
        if use_space:
            os.environ['FEATHER_HF_SKIP_UPLOAD'] = '1'
        print(f'[sweep] pre-warm job={prewarm_id}, waiting for completion...',
              flush=True)
        stage = poll_until_done(prewarm_id, poll_s=20, max_wait_s=1800)
        print(f'[sweep] pre-warm finished stage={stage}', flush=True)
        if stage != 'COMPLETED':
            print(f'[sweep] WARNING: pre-warm did not COMPLETE (stage={stage}); '
                  f'sweep jobs will each rebuild retina. Proceeding anyway.',
                  flush=True)
    else:
        print('[sweep] SKIP_PREWARM=1; assuming retina cache already populated',
              flush=True)

    # --- Fan out sweep jobs (concurrent) ---
    print('[sweep] === FAN OUT n_layer sweep ===', flush=True)
    # In Space-image mode the upload is skipped only once a launch has actually
    # succeeded; a failed first launch must not suppress the upload for the rest.
    upload_pending = use_space and os.environ.get('FEATHER_HF_SKIP_UPLOAD') != '1'
    sweep_jobs: dict[int, str] = {}
    for n_layer in SWEEP_N_LAYERS:
        env_extra = {
            'HYDRA_N_LAYER': str(n_layer),
            'HYDRA_D_MODEL': SWEEP_D_MODEL,
            'HYDRA_SDR_TARGET_ACTIVE': TARGET_ACTIVE,
            'HYDRA_TIME_BUDGET': TIME_BUDGET,
            'HYDRA_CKPT_INTERVAL': '0',
            'HYDRA_LAYER_DIAGNOSTICS': '1',
            'HYDRA_MID_VAL_INTERVAL': MID_VAL,
            'HYDRA_METRICS_OUT': metrics_path(n_layer),
        }
        jid = launch(env_extra, SWEEP_TIMEOUT)
        if not jid:
            print(f'[sweep] n_layer={n_layer} FAILED to submit', flush=True)
            continue
        if upload_pending:
            os.environ['FEATHER_HF_SKIP_UPLOAD'] = '1'
            upload_pending = False
        sweep_jobs[n_layer] = jid
        print(f'[sweep] n_layer={n_layer} -> job_id={jid}', flush=True)

    print('[sweep] === SWEEP SUBMITTED ===', flush=True)
    print('[sweep] tracked jobs:', flush=True)
    for n, j in sweep_jobs.items():
        print(f' n_layer={n:2d} job_id={j}', flush=True)

    # Write manifest so the aggregator can find them.
    manifest = Path('/tmp/sweep_depth_manifest.txt')
    manifest.write_text(format_manifest(sweep_jobs))
    print(f'[sweep] manifest -> {manifest}', flush=True)

    if not sweep_jobs:
        # Nothing was launched: do not report success to the caller.
        print('[sweep] ERROR: no sweep job could be submitted', file=sys.stderr)
        return 4
    return 0


if __name__ == '__main__':
    raise SystemExit(main())