#!/usr/bin/env python3
from __future__ import annotations
import json
import os
import subprocess
import sys
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from threading import Thread
# =============================================================================
# EARLY CUDA FABRIC MANAGER KICK (before ANY CUDA-touching imports)
# =============================================================================
# On H200 hosts, cudaGetDeviceCount can return Error 802 "system not yet
# initialized" on first use, because nvidia-fabricmanager on the host
# synchronizes with the container's first driver call. Once any NVML/CUDA
# call succeeds once (even just nvidia-smi), the fabric is up for the rest
# of the container lifetime.
#
# Our previous approach (wait in a subprocess before training) didn't work
# because the "initialization failed" state persisted across calls in the
# same container. The real fix: kick the driver exactly once with
# nvidia-smi, which is what successfully-working baseline containers do
# implicitly via their first torch.cuda call.
#
# Must happen BEFORE `import torch` (because any import that eagerly calls
# cudaGetDeviceCount will cache the Error 802 state).
def _early_cuda_kick() -> None:
    deadline = time.time() + 120.0
    attempt = 0
    while time.time() < deadline:
        attempt += 1
        r = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=30)
        if r.returncode == 0:
            # Any successful nvidia-smi call is enough to kick the fabric manager;
            # the GPU name is only detected for logging.
            gpu = next((g for g in ('H200', 'H100', 'A100') if g in (r.stdout or '')), 'unknown GPU')
            print(f'[boot] nvidia-smi OK on attempt {attempt} ({gpu})', flush=True)
            break
        print(f'[boot] nvidia-smi attempt {attempt} rc={r.returncode} stderr={(r.stderr or "")[:120]}',
              flush=True)
        time.sleep(2)
    # After nvidia-smi, probe torch in a subprocess so any latent error state
    # doesn't leak into the main process's CUDA context.
    probe = 'import torch; import sys; sys.exit(0 if torch.cuda.is_available() else 1)'
    torch_deadline = time.time() + 120.0
    t_attempt = 0
    while time.time() < torch_deadline:
        t_attempt += 1
        r = subprocess.run([sys.executable, '-c', probe], capture_output=True, text=True, timeout=60)
        if r.returncode == 0:
            print(f'[boot] torch.cuda.is_available() = True after {t_attempt} probe(s)', flush=True)
            return
        if t_attempt == 1:
            print(f'[boot] torch cuda probe {t_attempt}: {(r.stderr or "")[:200]}', flush=True)
        time.sleep(2)
    print('[boot] WARNING: torch.cuda never became ready - training will likely fail', flush=True)
_early_cuda_kick()
# Hydrate triton compilation cache from HF Hub before any triton/mamba_ssm import.
# triton_cache_setup.py is copied next to this file by the job bash command.
try:
    import triton_cache_setup as _tcs
    _tcs.setup()
except ImportError:
    print('[boot] triton_cache_setup not found; skipping cache hydrate', flush=True)
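# A minimal sketch of what a setup()/teardown() pair like this could look like
# (hypothetical - triton_cache_setup itself is not shown here, and the cache
# repo id below is made up): setup() downloads a previously uploaded Triton
# cache snapshot and points TRITON_CACHE_DIR at it, e.g.
#
#   from huggingface_hub import snapshot_download
#   cache_dir = snapshot_download(repo_id='icarus112/triton-cache',  # hypothetical repo
#                                 repo_type='dataset')
#   os.environ['TRITON_CACHE_DIR'] = cache_dir
#
# and teardown() uploads the (possibly updated) cache directory back to the Hub.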
from huggingface_hub import HfApi # noqa: E402 (import after cuda kick)
REPO_ROOT = Path('/workspace/feather')
CACHE_ROOT = Path.home() / '.cache' / 'autoresearch'
LOG_FILE = REPO_ROOT / 'run_domain_expanded.log'
JOB_ID = os.environ.get('JOB_ID', 'local-job')
OUTPUT_REPO = os.environ.get('HF_REPO_ID', 'icarus112/feather-pretrain-checkpoints')
TOKEN = os.environ.get('HF_TOKEN')
RUNTIME_MODE = os.environ.get('FEATHER_RUNTIME_MODE', 'space')
APP_PORT = int(os.environ.get('PORT', '7860'))
class _HealthHandler(BaseHTTPRequestHandler):
    def do_GET(self):
        if self.path in ('/', '/health', '/healthz', '/ready'):
            payload = {
                'status': 'ok',
                'mode': RUNTIME_MODE,
                'job_id': JOB_ID,
            }
            body = json.dumps(payload).encode('utf-8')
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.send_header('Content-Length', str(len(body)))
            self.end_headers()
            self.wfile.write(body)
            return
        self.send_response(404)
        self.end_headers()

    def log_message(self, format, *args):
        return
def _start_health_server() -> HTTPServer:
    server = HTTPServer(('0.0.0.0', APP_PORT), _HealthHandler)
    thread = Thread(target=server.serve_forever, daemon=True)
    thread.start()
    print(f'[space] health server listening on 0.0.0.0:{APP_PORT}', flush=True)
    return server
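# Example check against the health server (assuming the default PORT of 7860;
# the JSON fields mirror the payload built in _HealthHandler.do_GET):
#
#   $ curl -s http://localhost:7860/health
#   {"status": "ok", "mode": "space", "job_id": "local-job"}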
def upload_artifact(api: HfApi, path: Path, dest: str) -> None:
    if not path.exists():
        print(f'[upload] skip missing {path}', flush=True)
        return
    api.upload_file(
        path_or_fileobj=str(path),
        path_in_repo=dest,
        repo_id=OUTPUT_REPO,
        repo_type='model',
    )
    print(f'[upload] uploaded {path} -> {OUTPUT_REPO}/{dest}', flush=True)
def _wait_for_cuda_ready(timeout_s: int = 120) -> None:
"""Block until CUDA is fully initialized or timeout.
On H200 hosts with NVSwitch/fabric manager, nvidia driver setup can race
with container start. cudaGetDeviceCount can return CUDA_ERROR_SYSTEM_NOT_READY
(error 802) for the first few seconds, and any import that triggers
@triton.autotune (e.g. mamba_ssm, torch amp utilities) blows up with
"0 active drivers" if it happens during that window.
We pre-init CUDA in a throwaway Python subprocess (so any error state does
not leak into the main training process) and retry until torch.cuda
reports ready.
"""
import time as _t
probe = (
"import torch; "
"import sys; "
"avail = torch.cuda.is_available(); "
"count = torch.cuda.device_count() if avail else 0; "
"sys.exit(0 if (avail and count > 0) else 1)"
)
deadline = _t.time() + timeout_s
attempt = 0
while _t.time() < deadline:
attempt += 1
r = subprocess.run(['python', '-c', probe], capture_output=True, text=True)
if r.returncode == 0:
print(f'[job] CUDA ready after {attempt} probe(s)', flush=True)
return
if attempt == 1:
print(f'[job] CUDA not ready yet (will retry up to {timeout_s}s): {r.stderr.strip()[:200]}', flush=True)
_t.sleep(2)
print(f'[job] CUDA still not ready after {timeout_s}s — continuing anyway (training will likely fail)', flush=True)
def run_job_mode() -> int:
    os.chdir(REPO_ROOT)

    # Dynamic live patch from GitHub to bypass Space build errors
    GIT_REF = os.environ.get('FEATHER_GIT_REF')
    if GIT_REF and (REPO_ROOT / '.git').exists():
        print(f'[bootstrap] dynamic sync to {GIT_REF}...', flush=True)
        subprocess.run(['git', 'fetch', 'origin'], cwd=REPO_ROOT, check=False)
        subprocess.run(['git', 'checkout', GIT_REF], cwd=REPO_ROOT, check=False)
    elif GIT_REF:
        print(f'[bootstrap] skipping dynamic sync (no .git in {REPO_ROOT})', flush=True)

    os.environ.setdefault('HYDRA_TIME_BUDGET', '43200')
    os.environ.setdefault('HYDRA_TARGET_SHARDS', '2048')
    os.environ.setdefault('HYDRA_DOWNLOAD_WORKERS', '16')
    os.environ.setdefault('HYDRA_CKPT_INTERVAL', '1000')
    os.environ.setdefault('HYDRA_RESUME_CKPT', str(CACHE_ROOT / 'latest.pt'))

    # CUDA readiness was kicked at module import via _early_cuda_kick. Keep
    # the wait as a second safety net - no-op if CUDA is already ready.
    _wait_for_cuda_ready()

    cmd = [
        'bash',
        './scripts/run_domain_expanded_pretrain.sh',
        '--target-shards', os.environ['HYDRA_TARGET_SHARDS'],
        '--download-workers', os.environ['HYDRA_DOWNLOAD_WORKERS'],
    ]
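    # For local debugging, the equivalent shell invocation (assuming the default
    # env values set above) is roughly:
    #
    #   HYDRA_TARGET_SHARDS=2048 HYDRA_DOWNLOAD_WORKERS=16 \
    #       bash ./scripts/run_domain_expanded_pretrain.sh \
    #       --target-shards 2048 --download-workers 16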
    print('[job] starting Feather domain-expanded pretrain', flush=True)
    print(f'[job] command={cmd}', flush=True)
    proc = subprocess.run(cmd, check=False)

    # Push triton compilation cache back to HF Hub for next run.
    try:
        import triton_cache_setup as _tcs
        _tcs.teardown()
    except Exception as _tcs_err:
        print(f'[triton_cache] teardown error (non-fatal): {_tcs_err}', flush=True)

    if TOKEN:
        api = HfApi(token=TOKEN)
        try:
            api.create_repo(repo_id=OUTPUT_REPO, repo_type='model', private=True, exist_ok=True)
        except Exception as e:
            print(f'[upload] create_repo warning: {type(e).__name__}: {e}', flush=True)
        prefix = f'jobs/{JOB_ID}'
        try:
            upload_artifact(api, LOG_FILE, f'{prefix}/run_domain_expanded.log')
            upload_artifact(api, CACHE_ROOT / 'latest.pt', f'{prefix}/latest.pt')
            upload_artifact(api, CACHE_ROOT / 'pretrain_final.pt', f'{prefix}/pretrain_final.pt')
        except Exception as e:
            print(f'[upload] upload warning: {type(e).__name__}: {e}', flush=True)
    else:
        print('[upload] HF_TOKEN not set; skipping artifact upload', flush=True)

    return proc.returncode
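# When HF_TOKEN is set, the artifacts of a run land in OUTPUT_REPO under:
#
#   jobs/<JOB_ID>/run_domain_expanded.log
#   jobs/<JOB_ID>/latest.pt
#   jobs/<JOB_ID>/pretrain_final.pt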
def run_space_mode() -> int:
    server = _start_health_server()
    print('[space] Feather runtime image ready', flush=True)
    try:
        # Idle forever; the daemon health-server thread keeps the Space alive.
        while True:
            time.sleep(3600)
    except KeyboardInterrupt:
        # Return a clean exit code when the runtime is stopped interactively.
        return 0
    finally:
        server.shutdown()
        server.server_close()
def main() -> int:
    if RUNTIME_MODE == 'job':
        return run_job_mode()
    return run_space_mode()
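# Mode selection is driven by environment variables. The script filename below
# is illustrative; the JOB_ID / HF_REPO_ID values shown are the defaults defined above:
#
#   # Space mode (default): only serve the /health endpoint on $PORT
#   FEATHER_RUNTIME_MODE=space PORT=7860 python entrypoint.py
#
#   # Job mode: run the pretrain script, then upload log + checkpoints
#   FEATHER_RUNTIME_MODE=job JOB_ID=local-job HF_TOKEN=<token> \
#       HF_REPO_ID=icarus112/feather-pretrain-checkpoints python entrypoint.py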
if __name__ == '__main__':
    raise SystemExit(main())