feather-a10-runtime / entrypoint.py
Jackoatmon's picture
Update Feather H200 training runtime image
2bddef6 verified
#!/usr/bin/env python3
from __future__ import annotations
import json
import os
import subprocess
import sys
import tempfile
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from threading import Thread
# =============================================================================
# EARLY CUDA FABRIC MANAGER KICK (before ANY CUDA-touching imports)
# =============================================================================
# On H200 hosts, cudaGetDeviceCount can return Error 802 "system not yet
# initialized" on first use, because nvidia-fabricmanager on the host
# synchronizes with the container's first driver call. Once any NVML/CUDA
# call succeeds once (even just nvidia-smi), the fabric is up for the rest
# of the container lifetime.
#
# Our previous approach (wait in a subprocess before training) didn't work
# because the "initialization failed" state persisted across calls in the
# same container. The real fix: kick the driver exactly once with
# nvidia-smi, which is what successfully-working baseline containers do
# implicitly via their first torch.cuda call.
#
# Must happen BEFORE `import torch` (because any import that eagerly calls
# cudaGetDeviceCount will cache the Error 802 state).
def _early_cuda_kick(smi_timeout_s: float = 120.0, torch_timeout_s: float = 120.0) -> None:
    """Wake the NVIDIA driver before any CUDA-touching import.

    Phase 1 runs ``nvidia-smi`` until it succeeds or ``smi_timeout_s`` elapses.
    Phase 2 runs ``torch.cuda.is_available()`` in a child interpreter (so a bad
    CUDA state cannot leak into this process) until it reports True or
    ``torch_timeout_s`` elapses.

    Neither phase raises. A missing ``nvidia-smi`` binary, a hung probe or a
    failed spawn is logged and boot continues. If torch never sees CUDA, a
    warning is printed.

    Args:
        smi_timeout_s: wall-clock budget for phase 1, in seconds.
        torch_timeout_s: wall-clock budget for phase 2, in seconds.
    """
    # GPU names whose presence in nvidia-smi output counts as "driver answered"
    # even on a non-zero exit (presumably nvidia-smi warning exits) — TODO confirm.
    known_gpus = ('H200', 'H100', 'A100')
    deadline = time.time() + smi_timeout_s
    attempt = 0
    while time.time() < deadline:
        attempt += 1
        try:
            r = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=30)
        except subprocess.TimeoutExpired:
            print(f'[boot] nvidia-smi attempt {attempt} timed out', flush=True)
            time.sleep(2)
            continue
        except OSError as exc:
            # Binary missing / not executable: retrying cannot help.
            print(
                f'[boot] nvidia-smi unavailable ({type(exc).__name__}: {exc}); skipping driver kick',
                flush=True,
            )
            break
        stdout = r.stdout or ''
        # Explicitly grouped so every listed GPU is treated the same way
        # regardless of exit code, and exit 0 always counts as success.
        if r.returncode == 0 or any(gpu in stdout for gpu in known_gpus):
            print(f'[boot] nvidia-smi OK on attempt {attempt}', flush=True)
            break
        print(f'[boot] nvidia-smi attempt {attempt} rc={r.returncode} stderr={(r.stderr or "")[:120]}',
              flush=True)
        time.sleep(2)

    # After nvidia-smi, probe torch in a subprocess so any latent error state
    # doesn't leak into the main process's CUDA context.
    probe = 'import torch; import sys; sys.exit(0 if torch.cuda.is_available() else 1)'
    torch_deadline = time.time() + torch_timeout_s
    t_attempt = 0
    while time.time() < torch_deadline:
        t_attempt += 1
        try:
            r = subprocess.run([sys.executable, '-c', probe], capture_output=True, text=True, timeout=60)
        except subprocess.TimeoutExpired:
            print(f'[boot] torch cuda probe {t_attempt} timed out', flush=True)
            continue
        except OSError as exc:
            print(f'[boot] torch cuda probe could not start ({type(exc).__name__}: {exc})', flush=True)
            break
        if r.returncode == 0:
            print(f'[boot] torch.cuda.is_available() = True after {t_attempt} probe(s)', flush=True)
            return
        if t_attempt == 1:
            print(f'[boot] torch cuda probe {t_attempt}: {(r.stderr or "")[:200]}', flush=True)
        time.sleep(2)
    print('[boot] WARNING: torch.cuda never became ready — training will likely fail', flush=True)
_early_cuda_kick()

# Hydrate the triton compilation cache from HF Hub before any triton/mamba_ssm
# import. triton_cache_setup.py is copied next to this file by the job bash
# command. A missing module and a failed hydrate are both non-fatal (presumably
# the only cost is a cold compile cache), mirroring the non-fatal teardown
# handling in run_job_mode.
try:
    import triton_cache_setup as _tcs
except ImportError:
    print('[boot] triton_cache_setup not found; skipping cache hydrate', flush=True)
else:
    # Kept outside the import guard so an ImportError raised *inside* setup()
    # is not misreported as the module being absent.
    try:
        _tcs.setup()
    except Exception as _tcs_err:  # best-effort cache warm-up
        print(f'[triton_cache] setup error (non-fatal): {type(_tcs_err).__name__}: {_tcs_err}', flush=True)
from huggingface_hub import HfApi # noqa: E402 (import after cuda kick)
if '/workspace/feather' not in sys.path: # noqa: E402
sys.path.insert(0, '/workspace/feather')
from scripts.benchmark_assets import hydrate_benchmark_assets # noqa: E402
from subsystems.sdr_retina import build_retina # noqa: E402
# Filesystem layout inside the runtime image.
REPO_ROOT = Path('/workspace/feather')
LOG_FILE = REPO_ROOT.joinpath('run_domain_expanded.log')
CACHE_ROOT = Path.home().joinpath('.cache', 'autoresearch')
# Job identity and Hub destination; each can be overridden via the environment.
JOB_ID = os.getenv('JOB_ID', 'local-job')
OUTPUT_REPO = os.getenv('HF_REPO_ID', 'icarus112/feather-pretrain-checkpoints')
TOKEN = os.getenv('HF_TOKEN')
# Runtime selector read by main(): 'job', 'throughput-gate', or anything else
# (default 'space') for the health-server mode; PORT is the health server port.
RUNTIME_MODE = os.getenv('FEATHER_RUNTIME_MODE', 'space')
APP_PORT = int(os.getenv('PORT', '7860'))
class _HealthHandler(BaseHTTPRequestHandler):
    """Minimal HTTP handler that answers liveness probes with a JSON status.

    A GET on one of the health paths returns 200 with the body
    ``{"status": "ok", "mode": RUNTIME_MODE, "job_id": JOB_ID}``. Any other GET
    path gets a bare 404 with no body. Per-request access logging is disabled.
    """

    # Paths treated as health checks; anything else is a 404.
    _HEALTH_PATHS = frozenset({'/', '/health', '/healthz', '/ready'})

    def do_GET(self):
        if self.path not in self._HEALTH_PATHS:
            self.send_response(404)
            self.end_headers()
            return
        status = {'status': 'ok', 'mode': RUNTIME_MODE, 'job_id': JOB_ID}
        encoded = json.dumps(status).encode('utf-8')
        self.send_response(200)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(encoded)))
        self.end_headers()
        self.wfile.write(encoded)

    def log_message(self, format, *args):  # signature fixed by the base class
        """Suppress the base class's per-request access log."""
        return
def _start_health_server() -> HTTPServer:
    """Bind the health endpoint on 0.0.0.0:APP_PORT and serve it in the background.

    Serving runs on a daemon thread, so it does not by itself keep the
    interpreter alive. The server is returned so the caller can shut it down.
    """
    bind_addr = ('0.0.0.0', APP_PORT)
    httpd = HTTPServer(bind_addr, _HealthHandler)
    Thread(target=httpd.serve_forever, daemon=True).start()
    print(f'[space] health server listening on 0.0.0.0:{APP_PORT}', flush=True)
    return httpd
def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
    """Upload ``path`` to ``dest`` inside OUTPUT_REPO (a model repo).

    A missing local file is logged and skipped. Exceptions from
    ``api.upload_file`` propagate to the caller.
    """
    if path.exists():
        api.upload_file(
            path_or_fileobj=str(path),
            path_in_repo=dest,
            repo_id=OUTPUT_REPO,
            repo_type='model',
        )
        print(f'[upload] uploaded {path} -> {OUTPUT_REPO}/{dest}', flush=True)
    else:
        print(f'[upload] skip missing {path}', flush=True)
def emit_runtime_state(state: str, **extra: object) -> None:
    """Log a runtime-state record and, when HF_TOKEN is set, mirror it to the Hub.

    The record holds job_id, mode, state and an integer Unix timestamp. Keyword
    arguments are merged on top and may override those keys. It is printed as
    a single sorted-key JSON line prefixed with ``[state]``.

    With a token, the record is also written to a temporary JSON file and
    uploaded to ``jobs/<JOB_ID>/runtime_state/<state>.json`` in OUTPUT_REPO.
    Any exception during that upload step is logged, not raised.
    """
    record: dict[str, object] = {
        'job_id': JOB_ID,
        'mode': RUNTIME_MODE,
        'state': state,
        'timestamp': int(time.time()),
        **extra,
    }
    print('[state] ' + json.dumps(record, sort_keys=True), flush=True)
    if TOKEN:
        try:
            hub = HfApi(token=TOKEN)
            hub.create_repo(repo_id=OUTPUT_REPO, repo_type='model', private=True, exist_ok=True)
            with tempfile.NamedTemporaryFile('w', encoding='utf-8', suffix='.json', delete=False) as fh:
                json.dump(record, fh, indent=2, sort_keys=True)
                fh.write('\n')
                tmp = Path(fh.name)
            # delete=False so the file survives the close for upload; removed here.
            try:
                upload_artifact(hub, tmp, f'jobs/{JOB_ID}/runtime_state/{state}.json')
            finally:
                tmp.unlink(missing_ok=True)
        except Exception as exc:
            print(f'[state] upload warning: {type(exc).__name__}: {exc}', flush=True)
def rebuild_seeded_components_if_requested() -> None:
    """Delete cached tokenizer/retina artifacts when FEATHER_REBUILD_SEEDED_COMPONENTS=1.

    Targets are ``CACHE_ROOT/tokenizer`` and ``CACHE_ROOT/retina.npz``.
    Directories are removed with ``shutil.rmtree(..., ignore_errors=True)`` and
    files with ``unlink()``. A failure on one target is logged and does not stop
    the others.
    """
    if os.environ.get('FEATHER_REBUILD_SEEDED_COMPONENTS', '0') == '1':
        stale_targets = (CACHE_ROOT / 'tokenizer', CACHE_ROOT / 'retina.npz')
        for stale in stale_targets:
            try:
                if stale.is_dir():
                    import shutil
                    # NOTE(review): ignore_errors=True means this logs "removed"
                    # even if rmtree silently failed.
                    shutil.rmtree(stale, ignore_errors=True)
                    print(f'[seeded-rebuild] removed directory {stale}', flush=True)
                elif stale.exists():
                    stale.unlink()
                    print(f'[seeded-rebuild] removed file {stale}', flush=True)
            except Exception as exc:
                print(f'[seeded-rebuild] warning removing {stale}: {type(exc).__name__}: {exc}', flush=True)
def hydrate_seeded_components_for_throughput_gate() -> None:
    """Ensure the tokenizer and retina exist before the throughput benchmark runs.

    Steps, in order:
      1. Optionally wipe cached components
         (``rebuild_seeded_components_if_requested``).
      2. Stop early if HF_TOKEN is unset.
      3. Pull benchmark assets from the Hub; a failure is only logged.
      4. Run ``prepare.py`` if ``tokenizer.pkl`` is still missing. A non-zero
         exit raises ``RuntimeError``.
      5. Build the retina; a failure is only logged.
    """
    rebuild_seeded_components_if_requested()
    if not TOKEN:
        print('[seeded-rebuild] HF_TOKEN not set; skipping tokenizer/retina hydration', flush=True)
        return

    try:
        assets = hydrate_benchmark_assets(
            cache_dir=CACHE_ROOT,
            output_repo=OUTPUT_REPO,
            tokenizer_repo=os.environ.get('HYDRA_TOKENIZER_CACHE_REPO', OUTPUT_REPO),
            token=TOKEN,
        )
        print(f'[seeded-rebuild] hydrated benchmark assets: {assets}', flush=True)
    except Exception as exc:
        print(f'[seeded-rebuild] tokenizer/checkpoint hydrate warning: {type(exc).__name__}: {exc}', flush=True)

    tokenizer_pkl = CACHE_ROOT.joinpath('tokenizer', 'tokenizer.pkl')
    if not tokenizer_pkl.exists():
        n_shards = os.environ.get('HYDRA_TARGET_SHARDS', '1')
        n_workers = os.environ.get('HYDRA_DOWNLOAD_WORKERS', '2')
        print(
            f'[seeded-rebuild] tokenizer missing at {tokenizer_pkl}; running prepare.py --num-shards {n_shards}',
            flush=True,
        )
        argv = [sys.executable, 'prepare.py', '--num-shards', n_shards, '--download-workers', n_workers]
        exit_code = subprocess.run(argv, cwd=str(REPO_ROOT), check=False).returncode
        print(f'[seeded-rebuild] tokenizer prepare exit={exit_code}', flush=True)
        if exit_code:
            raise RuntimeError(f'tokenizer prepare failed with exit {exit_code}')

    try:
        report = build_retina()
        summary = (
            f'vocab_size={report.vocab_size} n_bits={report.n_bits} '
            f'train_tokens={report.train_tokens} wall={report.wall_time_sec:.1f}s'
        )
        print(f'[seeded-rebuild] retina ready ({summary})', flush=True)
    except Exception as exc:
        print(f'[seeded-rebuild] retina build warning: {type(exc).__name__}: {exc}', flush=True)
def _wait_for_cuda_ready(timeout_s: int = 120) -> None:
    """Block until CUDA is fully initialized or ``timeout_s`` seconds pass.

    On H200 hosts with NVSwitch/fabric manager, nvidia driver setup can race
    with container start. cudaGetDeviceCount can return
    CUDA_ERROR_SYSTEM_NOT_READY (error 802) for the first few seconds. Any
    import that triggers @triton.autotune (e.g. mamba_ssm, torch amp utilities)
    blows up with "0 active drivers" if it happens during that window.

    CUDA is pre-initialized in a throwaway child interpreter, so any error
    state does not leak into the main training process. The probe is retried
    until torch reports an available device. The child uses ``sys.executable``
    so it runs the same interpreter and environment as this process. Each probe
    is bounded by the remaining budget so a wedged driver call cannot block
    past ``timeout_s``. On timeout this only logs; it never raises.

    Args:
        timeout_s: total wall-clock budget in seconds.
    """
    probe = (
        "import torch; "
        "import sys; "
        "avail = torch.cuda.is_available(); "
        "count = torch.cuda.device_count() if avail else 0; "
        "sys.exit(0 if (avail and count > 0) else 1)"
    )
    deadline = time.time() + timeout_s
    attempt = 0
    while time.time() < deadline:
        attempt += 1
        remaining = max(1.0, deadline - time.time())
        try:
            r = subprocess.run(
                [sys.executable, '-c', probe],
                capture_output=True,
                text=True,
                timeout=remaining,
            )
        except subprocess.TimeoutExpired:
            print(f'[job] CUDA probe {attempt} timed out after {remaining:.0f}s', flush=True)
            continue
        if r.returncode == 0:
            print(f'[job] CUDA ready after {attempt} probe(s)', flush=True)
            return
        if attempt == 1:
            print(f'[job] CUDA not ready yet (will retry up to {timeout_s}s): {r.stderr.strip()[:200]}', flush=True)
        time.sleep(2)
    print(f'[job] CUDA still not ready after {timeout_s}s — continuing anyway (training will likely fail)', flush=True)
def run_job_mode() -> int:
    """Run the domain-expanded pretrain job and push its artifacts to the Hub.

    HYDRA_* defaults are set with ``os.environ.setdefault``, so caller-provided
    values win. The function then:
      * waits for CUDA;
      * runs ``scripts/run_domain_expanded_pretrain.sh``;
      * pushes the triton cache back via ``triton_cache_setup.teardown()``;
      * when HF_TOKEN is set, uploads the run log and checkpoints under
        ``jobs/<JOB_ID>/`` in OUTPUT_REPO.

    Teardown, repo-creation and per-file upload failures are logged, not
    raised. Each artifact is uploaded independently, so one failure cannot
    drop the others.

    Returns:
        The pretrain script's exit code.
    """
    os.chdir(REPO_ROOT)
    emit_runtime_state('job_boot', cwd=str(REPO_ROOT))
    os.environ.setdefault('HYDRA_TIME_BUDGET', '43200')
    os.environ.setdefault('HYDRA_TARGET_SHARDS', '2048')
    os.environ.setdefault('HYDRA_DOWNLOAD_WORKERS', '16')
    os.environ.setdefault('HYDRA_CKPT_INTERVAL', '1000')
    os.environ.setdefault('HYDRA_RESUME_CKPT', str(CACHE_ROOT / 'latest.pt'))
    # CUDA readiness was kicked at module import via _early_cuda_kick. Keep
    # the wait as a second safety net — returns after one probe if CUDA is ready.
    _wait_for_cuda_ready()
    emit_runtime_state(
        'job_cuda_ready',
        time_budget=os.environ['HYDRA_TIME_BUDGET'],
        target_shards=os.environ['HYDRA_TARGET_SHARDS'],
    )
    cmd = [
        'bash',
        './scripts/run_domain_expanded_pretrain.sh',
        '--target-shards', os.environ['HYDRA_TARGET_SHARDS'],
        '--download-workers', os.environ['HYDRA_DOWNLOAD_WORKERS'],
    ]
    print('[job] starting Feather domain-expanded pretrain', flush=True)
    print(f'[job] command={cmd}', flush=True)
    proc = subprocess.run(cmd, check=False)
    emit_runtime_state('job_finished', returncode=proc.returncode)

    # Push triton compilation cache back to HF Hub for next run.
    try:
        import triton_cache_setup as _tcs
        _tcs.teardown()
    except Exception as _tcs_err:
        print(f'[triton_cache] teardown error (non-fatal): {_tcs_err}', flush=True)

    if not TOKEN:
        print('[upload] HF_TOKEN not set; skipping artifact upload', flush=True)
        return proc.returncode

    api = HfApi(token=TOKEN)
    try:
        api.create_repo(repo_id=OUTPUT_REPO, repo_type='model', private=True, exist_ok=True)
    except Exception as e:
        print(f'[upload] create_repo warning: {type(e).__name__}: {e}', flush=True)
    prefix = f'jobs/{JOB_ID}'
    artifacts = (
        (LOG_FILE, 'run_domain_expanded.log'),
        (CACHE_ROOT / 'latest.pt', 'latest.pt'),
        (CACHE_ROOT / 'pretrain_final.pt', 'pretrain_final.pt'),
    )
    for local_path, remote_name in artifacts:
        # Isolate each upload: a failed log upload must not skip the checkpoints.
        try:
            upload_artifact(api, local_path, f'{prefix}/{remote_name}')
        except Exception as e:
            print(f'[upload] upload warning: {type(e).__name__}: {e}', flush=True)
    return proc.returncode
def run_throughput_gate_mode() -> int:
    """Run the short throughput benchmark gate.

    HYDRA_TIME_BUDGET defaults to '300' unless already set. The function waits
    for CUDA and hydrates the tokenizer/retina. It then runs
    ``scripts/benchmark_hyena_stack.py`` with the config, time budget, minimum
    tokens/s and warm-up steps taken from the FEATHER_BENCH_* variables.

    Each setting is read once, so the values recorded in the
    ``throughput_gate_cuda_ready`` state are exactly the ones passed to the
    benchmark. The benchmark runs under ``sys.executable`` (this interpreter),
    consistent with how ``prepare.py`` and the CUDA probes are launched.

    Returns:
        The benchmark script's exit code.
    """
    os.chdir(REPO_ROOT)
    emit_runtime_state('throughput_gate_boot', cwd=str(REPO_ROOT))
    os.environ.setdefault('HYDRA_TIME_BUDGET', '300')
    _wait_for_cuda_ready()
    hydrate_seeded_components_for_throughput_gate()
    # Read after hydration (matching the previous read order), then reuse.
    bench_config = os.environ.get('FEATHER_BENCH_CONFIG', 'fullarch_a10')
    time_budget = os.environ['HYDRA_TIME_BUDGET']
    min_tps = os.environ.get('FEATHER_BENCH_MIN_TPS', '7000')
    warmup_steps = os.environ.get('FEATHER_BENCH_WARMUP_STEPS', '5')
    emit_runtime_state(
        'throughput_gate_cuda_ready',
        bench_config=bench_config,
        time_budget=time_budget,
        min_tps=min_tps,
        warmup_steps=warmup_steps,
    )
    cmd = [
        sys.executable,
        str(REPO_ROOT / 'scripts' / 'benchmark_hyena_stack.py'),
        '--config', bench_config,
        '--time', time_budget,
        '--min-tps', min_tps,
        '--warmup-steps', warmup_steps,
    ]
    print('[job] starting throughput gate', flush=True)
    print(f'[job] command={cmd}', flush=True)
    proc = subprocess.run(cmd, check=False)
    emit_runtime_state('throughput_gate_finished', returncode=proc.returncode)
    return proc.returncode
def run_space_mode() -> int:
    """Keep the container alive and serving health checks; does no training.

    The sleep loop never exits normally, so the declared ``int`` return is
    never produced. The ``finally`` clause shuts the server down and closes its
    socket when an exception (e.g. KeyboardInterrupt) unwinds the loop.
    """
    httpd = _start_health_server()
    print('[space] Feather runtime image ready', flush=True)
    try:
        while 1:
            time.sleep(3600)
    finally:
        httpd.shutdown()
        httpd.server_close()
def main() -> int:
    """Dispatch to the runtime selected by FEATHER_RUNTIME_MODE.

    'job' runs the pretrain, 'throughput-gate' runs the benchmark gate, and
    any other value (default 'space') serves the health endpoint. An
    ``entrypoint_started`` state is emitted first. Returns the chosen mode's
    exit code.
    """
    emit_runtime_state('entrypoint_started')
    handlers = {
        'job': run_job_mode,
        'throughput-gate': run_throughput_gate_mode,
    }
    runner = handlers.get(RUNTIME_MODE, run_space_mode)
    return runner()


if __name__ == '__main__':
    sys.exit(main())