#!/usr/bin/env python3
"""Container entrypoint for the Feather runtime image.

Depending on ``FEATHER_RUNTIME_MODE`` this script either runs the
domain-expanded pretrain job (``job``), runs the throughput gate benchmark
(``throughput-gate``), or serves a small JSON health endpoint (any other
value; the default is ``space``).
"""
from __future__ import annotations

import json
import os
import shutil
import subprocess
import sys
import tempfile
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from threading import Thread

# Per-call timeouts (seconds) for the readiness probes run in subprocesses.
_NVIDIA_SMI_TIMEOUT_S = 30
_TORCH_PROBE_TIMEOUT_S = 60


# =============================================================================
# EARLY CUDA FABRIC MANAGER KICK (before ANY CUDA-touching imports)
# =============================================================================
# On H200 hosts, cudaGetDeviceCount can return Error 802 "system not yet
# initialized" on first use, because nvidia-fabricmanager on the host
# synchronizes with the container's first driver call. Once any NVML/CUDA
# call succeeds once (even just nvidia-smi), the fabric is up for the rest
# of the container lifetime.
#
# Our previous approach (wait in a subprocess before training) didn't work
# because the "initialization failed" state persisted across calls in the
# same container. The real fix: kick the driver exactly once with
# nvidia-smi, which is what successfully-working baseline containers do
# implicitly via their first torch.cuda call.
#
# Must happen BEFORE `import torch` (because any import that eagerly calls
# cudaGetDeviceCount will cache the Error 802 state).
def _early_cuda_kick() -> None:
    """Poke the NVIDIA driver via ``nvidia-smi``, then probe torch CUDA.

    Phase 1 retries ``nvidia-smi`` for up to 120 s until it exits 0.
    Phase 2 retries a throwaway ``torch.cuda.is_available()`` subprocess for
    up to another 120 s. Neither phase raises: this runs at import time, so
    a missing binary or a hung probe is logged and treated as a failed
    attempt instead of crashing the entrypoint.
    """
    deadline = time.time() + 120.0
    attempt = 0
    while time.time() < deadline:
        attempt += 1
        try:
            r = subprocess.run(
                ['nvidia-smi'],
                capture_output=True,
                text=True,
                timeout=_NVIDIA_SMI_TIMEOUT_S,
            )
        except OSError as exc:
            # Binary missing or not executable: retrying cannot help.
            print(f'[boot] nvidia-smi unavailable ({type(exc).__name__}: {exc}); skipping driver kick', flush=True)
            break
        except subprocess.TimeoutExpired:
            print(f'[boot] nvidia-smi attempt {attempt} timed out after {_NVIDIA_SMI_TIMEOUT_S}s', flush=True)
            time.sleep(2)
            continue
        # Success is decided by the exit code alone; the GPU model printed in
        # stdout says nothing about whether the call succeeded.
        if r.returncode == 0:
            print(f'[boot] nvidia-smi OK on attempt {attempt}', flush=True)
            break
        print(f'[boot] nvidia-smi attempt {attempt} rc={r.returncode} stderr={(r.stderr or "")[:120]}', flush=True)
        time.sleep(2)

    # After nvidia-smi, probe torch in a subprocess so any latent error state
    # doesn't leak into the main process's CUDA context.
    probe = 'import torch; import sys; sys.exit(0 if torch.cuda.is_available() else 1)'
    torch_deadline = time.time() + 120.0
    t_attempt = 0
    while time.time() < torch_deadline:
        t_attempt += 1
        try:
            r = subprocess.run(
                [sys.executable, '-c', probe],
                capture_output=True,
                text=True,
                timeout=_TORCH_PROBE_TIMEOUT_S,
            )
        except subprocess.TimeoutExpired:
            detail = f'probe timed out after {_TORCH_PROBE_TIMEOUT_S}s'
        else:
            if r.returncode == 0:
                print(f'[boot] torch.cuda.is_available() = True after {t_attempt} probe(s)', flush=True)
                return
            detail = (r.stderr or '')[:200]
        # Only the first failure is logged to keep the boot log short.
        if t_attempt == 1:
            print(f'[boot] torch cuda probe {t_attempt}: {detail}', flush=True)
        time.sleep(2)
    print('[boot] WARNING: torch.cuda never became ready — training will likely fail', flush=True)


_early_cuda_kick()

# Hydrate triton compilation cache from HF Hub before any triton/mamba_ssm import.
# triton_cache_setup.py is copied next to this file by the job bash command.
try:
    import triton_cache_setup as _tcs
    _tcs.setup()
except ImportError:
    print('[boot] triton_cache_setup not found; skipping cache hydrate', flush=True)

from huggingface_hub import HfApi  # noqa: E402 (import after cuda kick)

if '/workspace/feather' not in sys.path:  # noqa: E402
    sys.path.insert(0, '/workspace/feather')

from scripts.benchmark_assets import hydrate_benchmark_assets  # noqa: E402
from subsystems.sdr_retina import build_retina  # noqa: E402

REPO_ROOT = Path('/workspace/feather')
CACHE_ROOT = Path.home() / '.cache' / 'autoresearch'
LOG_FILE = REPO_ROOT / 'run_domain_expanded.log'
JOB_ID = os.environ.get('JOB_ID', 'local-job')
OUTPUT_REPO = os.environ.get('HF_REPO_ID', 'icarus112/feather-pretrain-checkpoints')
TOKEN = os.environ.get('HF_TOKEN')
RUNTIME_MODE = os.environ.get('FEATHER_RUNTIME_MODE', 'space')
APP_PORT = int(os.environ.get('PORT', '7860'))


class _HealthHandler(BaseHTTPRequestHandler):
    """Minimal HTTP handler that answers health probes with a JSON status."""

    def do_GET(self):
        """Reply 200 + JSON on the known health paths, 404 on anything else."""
        if self.path in ('/', '/health', '/healthz', '/ready'):
            payload = {
                'status': 'ok',
                'mode': RUNTIME_MODE,
                'job_id': JOB_ID,
            }
            body = json.dumps(payload).encode('utf-8')
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.send_header('Content-Length', str(len(body)))
            self.end_headers()
            self.wfile.write(body)
            return
        self.send_response(404)
        self.end_headers()

    def log_message(self, format, *args):  # noqa: A002 (name fixed by base class)
        """Suppress the default per-request logging."""
        return


def _start_health_server() -> HTTPServer:
    """Start the health server on ``0.0.0.0:APP_PORT`` in a daemon thread.

    Returns:
        The running ``HTTPServer`` so the caller can shut it down.
    """
    server = HTTPServer(('0.0.0.0', APP_PORT), _HealthHandler)
    thread = Thread(target=server.serve_forever, daemon=True)
    thread.start()
    print(f'[space] health server listening on 0.0.0.0:{APP_PORT}', flush=True)
    return server


def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
    """Upload ``path`` to ``OUTPUT_REPO`` at ``dest``; skip if it is missing.

    Args:
        api: Hub client used for the upload.
        path: Local file to upload.
        dest: Destination path inside the repository.
    """
    if not path.exists():
        print(f'[upload] skip missing {path}', flush=True)
        return
    api.upload_file(
        path_or_fileobj=str(path),
        path_in_repo=dest,
        repo_id=OUTPUT_REPO,
        repo_type='model',
    )
    print(f'[upload] uploaded {path} -> {OUTPUT_REPO}/{dest}', flush=True)


def emit_runtime_state(state: str, **extra: object) -> None:
    """Print a JSON state record and, when a token is set, upload it.

    The record is always printed to stdout. With ``HF_TOKEN`` set it is also
    written to a temporary file and uploaded to
    ``jobs/<JOB_ID>/runtime_state/<state>.json``. Upload failures are logged
    as warnings and never raised.

    Args:
        state: Short state name; also used as the uploaded file name.
        **extra: Additional JSON-serializable fields merged into the record.
    """
    payload: dict[str, object] = {
        'job_id': JOB_ID,
        'mode': RUNTIME_MODE,
        'state': state,
        'timestamp': int(time.time()),
    }
    payload.update(extra)
    print(f'[state] {json.dumps(payload, sort_keys=True)}', flush=True)
    if not TOKEN:
        return
    try:
        api = HfApi(token=TOKEN)
        api.create_repo(repo_id=OUTPUT_REPO, repo_type='model', private=True, exist_ok=True)
        prefix = f'jobs/{JOB_ID}/runtime_state'
        # delete=False: the file must outlive the `with` so it can be uploaded
        # by path; it is removed explicitly in the `finally` below.
        with tempfile.NamedTemporaryFile('w', encoding='utf-8', suffix='.json', delete=False) as handle:
            json.dump(payload, handle, indent=2, sort_keys=True)
            handle.write('\n')
            temp_path = Path(handle.name)
        try:
            upload_artifact(api, temp_path, f'{prefix}/{state}.json')
        finally:
            temp_path.unlink(missing_ok=True)
    except Exception as exc:
        print(f'[state] upload warning: {type(exc).__name__}: {exc}', flush=True)


def rebuild_seeded_components_if_requested() -> None:
    """Delete the cached tokenizer and retina when a rebuild is requested.

    Does nothing unless ``FEATHER_REBUILD_SEEDED_COMPONENTS`` equals ``'1'``.
    Removal failures are logged as warnings and do not raise.
    """
    if os.environ.get('FEATHER_REBUILD_SEEDED_COMPONENTS', '0') != '1':
        return
    targets = [
        CACHE_ROOT / 'tokenizer',
        CACHE_ROOT / 'retina.npz',
    ]
    for target in targets:
        try:
            if target.is_dir():
                # Let errors propagate to the handler below so a failed
                # removal is not reported as a success.
                shutil.rmtree(target)
                print(f'[seeded-rebuild] removed directory {target}', flush=True)
            elif target.exists():
                target.unlink()
                print(f'[seeded-rebuild] removed file {target}', flush=True)
        except Exception as exc:
            print(f'[seeded-rebuild] warning removing {target}: {type(exc).__name__}: {exc}', flush=True)


def hydrate_seeded_components_for_throughput_gate() -> None:
    """Ensure the tokenizer and retina exist before the throughput gate runs.

    Steps: optionally clear cached components, hydrate benchmark assets from
    the Hub, run ``prepare.py`` if the tokenizer file is still missing, then
    build the retina. Without ``HF_TOKEN`` everything after the optional
    clear is skipped.

    Raises:
        RuntimeError: If ``prepare.py`` exits with a non-zero status.
    """
    rebuild_seeded_components_if_requested()
    if not TOKEN:
        print('[seeded-rebuild] HF_TOKEN not set; skipping tokenizer/retina hydration', flush=True)
        return
    try:
        tokenizer_repo = os.environ.get('HYDRA_TOKENIZER_CACHE_REPO', OUTPUT_REPO)
        assets = hydrate_benchmark_assets(
            cache_dir=CACHE_ROOT,
            output_repo=OUTPUT_REPO,
            tokenizer_repo=tokenizer_repo,
            token=TOKEN,
        )
        print(f'[seeded-rebuild] hydrated benchmark assets: {assets}', flush=True)
    except Exception as exc:
        print(f'[seeded-rebuild] tokenizer/checkpoint hydrate warning: {type(exc).__name__}: {exc}', flush=True)

    tokenizer_file = CACHE_ROOT / 'tokenizer' / 'tokenizer.pkl'
    if not tokenizer_file.exists():
        shards = os.environ.get('HYDRA_TARGET_SHARDS', '1')
        workers = os.environ.get('HYDRA_DOWNLOAD_WORKERS', '2')
        print(
            f'[seeded-rebuild] tokenizer missing at {tokenizer_file}; running prepare.py --num-shards {shards}',
            flush=True,
        )
        prep = subprocess.run(
            [sys.executable, 'prepare.py', '--num-shards', shards, '--download-workers', workers],
            cwd=str(REPO_ROOT),
            check=False,
        )
        print(f'[seeded-rebuild] tokenizer prepare exit={prep.returncode}', flush=True)
        if prep.returncode != 0:
            raise RuntimeError(f'tokenizer prepare failed with exit {prep.returncode}')

    try:
        report = build_retina()
        print(
            '[seeded-rebuild] retina ready '
            f'(vocab_size={report.vocab_size} n_bits={report.n_bits} train_tokens={report.train_tokens} wall={report.wall_time_sec:.1f}s)',
            flush=True,
        )
    except Exception as exc:
        print(f'[seeded-rebuild] retina build warning: {type(exc).__name__}: {exc}', flush=True)


def _wait_for_cuda_ready(timeout_s: int = 120) -> None:
    """Block until CUDA is fully initialized or timeout.

    On H200 hosts with NVSwitch/fabric manager, nvidia driver setup can race
    with container start. cudaGetDeviceCount can return
    CUDA_ERROR_SYSTEM_NOT_READY (error 802) for the first few seconds, and any
    import that triggers @triton.autotune (e.g. mamba_ssm, torch amp utilities)
    blows up with "0 active drivers" if it happens during that window.

    We pre-init CUDA in a throwaway Python subprocess (so any error state does
    not leak into the main training process) and retry until torch.cuda
    reports ready.

    Args:
        timeout_s: Overall budget in seconds for the retry loop. Each probe is
            additionally capped so a hung probe cannot block forever.

    The function never raises on failure; it logs and returns so the caller
    can continue.
    """
    probe = (
        "import torch; "
        "import sys; "
        "avail = torch.cuda.is_available(); "
        "count = torch.cuda.device_count() if avail else 0; "
        "sys.exit(0 if (avail and count > 0) else 1)"
    )
    deadline = time.time() + timeout_s
    attempt = 0
    while time.time() < deadline:
        attempt += 1
        try:
            # Use the running interpreter: a bare `python` on PATH may be
            # absent or belong to an environment without torch.
            r = subprocess.run(
                [sys.executable, '-c', probe],
                capture_output=True,
                text=True,
                timeout=_TORCH_PROBE_TIMEOUT_S,
            )
        except subprocess.TimeoutExpired:
            detail = f'probe timed out after {_TORCH_PROBE_TIMEOUT_S}s'
        else:
            if r.returncode == 0:
                print(f'[job] CUDA ready after {attempt} probe(s)', flush=True)
                return
            detail = (r.stderr or '').strip()[:200]
        if attempt == 1:
            print(f'[job] CUDA not ready yet (will retry up to {timeout_s}s): {detail}', flush=True)
        time.sleep(2)
    print(f'[job] CUDA still not ready after {timeout_s}s — continuing anyway (training will likely fail)', flush=True)


def run_job_mode() -> int:
    """Run the domain-expanded pretrain script and upload its artifacts.

    Sets default ``HYDRA_*`` environment variables, waits for CUDA, runs
    ``scripts/run_domain_expanded_pretrain.sh``, tears down the triton cache,
    and uploads the log and checkpoints when ``HF_TOKEN`` is set.

    Returns:
        The exit code of the pretrain script.
    """
    os.chdir(REPO_ROOT)
    emit_runtime_state('job_boot', cwd=str(REPO_ROOT))
    os.environ.setdefault('HYDRA_TIME_BUDGET', '43200')
    os.environ.setdefault('HYDRA_TARGET_SHARDS', '2048')
    os.environ.setdefault('HYDRA_DOWNLOAD_WORKERS', '16')
    os.environ.setdefault('HYDRA_CKPT_INTERVAL', '1000')
    os.environ.setdefault('HYDRA_RESUME_CKPT', str(CACHE_ROOT / 'latest.pt'))

    # CUDA readiness was kicked at module import via _early_cuda_kick. Keep
    # the wait as a second safety net — no-op if CUDA already ready.
    _wait_for_cuda_ready()
    emit_runtime_state(
        'job_cuda_ready',
        time_budget=os.environ['HYDRA_TIME_BUDGET'],
        target_shards=os.environ['HYDRA_TARGET_SHARDS'],
    )

    cmd = [
        'bash',
        './scripts/run_domain_expanded_pretrain.sh',
        '--target-shards',
        os.environ['HYDRA_TARGET_SHARDS'],
        '--download-workers',
        os.environ['HYDRA_DOWNLOAD_WORKERS'],
    ]
    print('[job] starting Feather domain-expanded pretrain', flush=True)
    print(f'[job] command={cmd}', flush=True)
    proc = subprocess.run(cmd, check=False)
    emit_runtime_state('job_finished', returncode=proc.returncode)

    # Push triton compilation cache back to HF Hub for next run.
    try:
        import triton_cache_setup as _tcs
        _tcs.teardown()
    except Exception as _tcs_err:
        print(f'[triton_cache] teardown error (non-fatal): {_tcs_err}', flush=True)

    if TOKEN:
        api = HfApi(token=TOKEN)
        try:
            api.create_repo(repo_id=OUTPUT_REPO, repo_type='model', private=True, exist_ok=True)
        except Exception as e:
            print(f'[upload] create_repo warning: {type(e).__name__}: {e}', flush=True)
        prefix = f'jobs/{JOB_ID}'
        try:
            upload_artifact(api, LOG_FILE, f'{prefix}/run_domain_expanded.log')
            upload_artifact(api, CACHE_ROOT / 'latest.pt', f'{prefix}/latest.pt')
            upload_artifact(api, CACHE_ROOT / 'pretrain_final.pt', f'{prefix}/pretrain_final.pt')
        except Exception as e:
            print(f'[upload] upload warning: {type(e).__name__}: {e}', flush=True)
    else:
        print('[upload] HF_TOKEN not set; skipping artifact upload', flush=True)

    return proc.returncode


def run_throughput_gate_mode() -> int:
    """Run the throughput gate benchmark.

    Waits for CUDA, hydrates the tokenizer/retina, then runs
    ``scripts/benchmark_hyena_stack.py`` with settings taken from
    ``FEATHER_BENCH_*`` and ``HYDRA_TIME_BUDGET``.

    Returns:
        The exit code of the benchmark process.
    """
    os.chdir(REPO_ROOT)
    emit_runtime_state('throughput_gate_boot', cwd=str(REPO_ROOT))
    os.environ.setdefault('HYDRA_TIME_BUDGET', '300')
    _wait_for_cuda_ready()
    hydrate_seeded_components_for_throughput_gate()
    emit_runtime_state(
        'throughput_gate_cuda_ready',
        bench_config=os.environ.get('FEATHER_BENCH_CONFIG', 'fullarch_a10'),
        time_budget=os.environ['HYDRA_TIME_BUDGET'],
        min_tps=os.environ.get('FEATHER_BENCH_MIN_TPS', '7000'),
        warmup_steps=os.environ.get('FEATHER_BENCH_WARMUP_STEPS', '5'),
    )
    cmd = [
        # Same interpreter as this process, matching the prepare.py launch.
        sys.executable,
        str(REPO_ROOT / 'scripts' / 'benchmark_hyena_stack.py'),
        '--config',
        os.environ.get('FEATHER_BENCH_CONFIG', 'fullarch_a10'),
        '--time',
        os.environ['HYDRA_TIME_BUDGET'],
        '--min-tps',
        os.environ.get('FEATHER_BENCH_MIN_TPS', '7000'),
        '--warmup-steps',
        os.environ.get('FEATHER_BENCH_WARMUP_STEPS', '5'),
    ]
    print('[job] starting throughput gate', flush=True)
    print(f'[job] command={cmd}', flush=True)
    proc = subprocess.run(cmd, check=False)
    emit_runtime_state('throughput_gate_finished', returncode=proc.returncode)
    return proc.returncode


def run_space_mode() -> int:
    """Serve the health endpoint until the process is interrupted.

    The loop never exits normally; the server is shut down and closed when an
    exception (e.g. ``KeyboardInterrupt``) unwinds through the ``finally``.
    """
    server = _start_health_server()
    print('[space] Feather runtime image ready', flush=True)
    try:
        while True:
            time.sleep(3600)
    finally:
        server.shutdown()
        server.server_close()


def main() -> int:
    """Dispatch on ``RUNTIME_MODE`` and return the process exit code."""
    emit_runtime_state('entrypoint_started')
    if RUNTIME_MODE == 'job':
        return run_job_mode()
    if RUNTIME_MODE == 'throughput-gate':
        return run_throughput_gate_mode()
    return run_space_mode()


if __name__ == '__main__':
    raise SystemExit(main())