"""Diagnostic script for torch.compile deadlock after ~500 steps.

F17 investigation: validates that the _compiled_core / forward split
fixes the deadlock by running forward+backward loops with compile on.

Usage:
    LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda/lib64 \
    HYDRA_TIME_BUDGET=30 HYDRA_BATCH_SIZE=8 HYDRA_TOTAL_BATCH=16384 \
    HYDRA_HTM_LEARN_EVERY=4 HYDRA_HESTIA_INTERVAL=9999 \
    .venv/bin/python -u scripts/compile_debug.py [mode]

Modes:
    eager      - no compile (baseline)
    model_only - compile model _compiled_core only
    muon_only  - compile muon step only
    both       - compile both (default)
"""
from __future__ import annotations
import gc
import os
import signal
import sys
import threading
import time
# Set CUDA env before torch import
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
import torch
import torch.nn as nn
import torch.nn.functional as F
# -------------------------------------------------------------------------
# Config
# -------------------------------------------------------------------------
MAX_STEPS = 800
WATCHDOG_TIMEOUT_S = 20 # kill if no progress for this many seconds
BATCH_SIZE = int(os.environ.get("HYDRA_BATCH_SIZE", "8"))
SEQ_LEN = 2048
VOCAB_SIZE = 8192
# -------------------------------------------------------------------------
# Watchdog thread: kills process if no progress
# -------------------------------------------------------------------------
_last_progress = time.time()
_watchdog_armed = True
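# Protocol: the training loop calls tick() after every completed step; this
# thread fires if the gap since the last tick exceeds WATCHDOG_TIMEOUT_S,
# which is exactly what a wedged CUDA call or a compile lock looks like.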
def _watchdog_fn():
global _last_progress, _watchdog_armed
while _watchdog_armed:
time.sleep(1.0)
elapsed = time.time() - _last_progress
if elapsed > WATCHDOG_TIMEOUT_S:
print(f"\n*** WATCHDOG: no progress for {elapsed:.1f}s — DEADLOCK DETECTED ***",
flush=True)
_dump_diagnostics()
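            # SIGTERM keeps its default disposition and ends the whole
            # process. The hard kill is deliberate: a wedged CUDA call
            # cannot be unwound from Python, so in-process cleanup is
            # not an option.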
os.kill(os.getpid(), signal.SIGTERM)
return
def _dump_diagnostics():
"""Dump CUDA/dynamo state at deadlock time."""
try:
stats = torch.cuda.memory_stats()
print(f" alloc_retries: {stats.get('num_alloc_retries', 'N/A')}")
print(f" allocated_bytes: {stats.get('allocated_bytes.all.current', 0) / 1e6:.1f} MB")
print(f" reserved_bytes: {stats.get('reserved_bytes.all.current', 0) / 1e6:.1f} MB")
print(f" num_ooms: {stats.get('num_ooms', 0)}")
except Exception as e:
print(f" (memory_stats failed: {e})")
try:
import torch._dynamo.utils as du
print(f" dynamo counters: {dict(du.counters)}")
except Exception as e:
print(f" (dynamo counters failed: {e})")
def tick():
global _last_progress
_last_progress = time.time()
# -------------------------------------------------------------------------
# Test
# -------------------------------------------------------------------------
def run_test(mode: str) -> dict:
"""Run forward+backward loop with specified compile config."""
print(f"\n{'='*70}")
print(f"TEST MODE: {mode}")
print(f"{'='*70}", flush=True)
compile_model = mode in ("model_only", "both")
compile_muon = mode in ("muon_only", "both")
os.environ["HYDRA_MODEL_COMPILE"] = "1" if compile_model else "0"
os.environ["HYDRA_MUON_COMPILE"] = "1" if compile_muon else "0"
os.environ["HYDRA_ASYNC_POSTPROCESS"] = "0"
os.environ["HYDRA_HESTIA_INTERVAL"] = "9999"
os.environ["HYDRA_HTM_LEARN_EVERY"] = "4"
# Clear cached modules for fresh env var reads
for mod_name in list(sys.modules.keys()):
if mod_name.startswith("hydra."):
del sys.modules[mod_name]
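    # With hydra.* purged, the next import re-reads the HYDRA_* flags set
    # above, so each mode gets a fresh module-level compile configuration.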
torch._dynamo.reset()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
gc.collect()
from hydra.model import PostSemClawModel
from hydra.config import PostSemClawConfig
device = torch.device("cuda")
config = PostSemClawConfig(
d_model=256, n_layer=4, d_state=64, headdim=32, expand=2,
vocab_size=VOCAB_SIZE, sequence_len=SEQ_LEN,
)
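    # Meta-device construction allocates no real storage; to_empty() then
    # materializes the parameters on the GPU and init_weights() fills them,
    # avoiding a transient CPU copy of the full model.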
with torch.device("meta"):
model = PostSemClawModel(config)
model.to_empty(device=device)
model.init_weights()
optimizer = model.setup_optimizer()
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
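    # A torch.amp.autocast instance can be re-entered, so a single context
    # manager object is safely reused across all steps.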
result = {"mode": mode, "max_step": 0, "tps_samples": []}
    # memory_stats counters are cumulative for the process, so seed the
    # retry baseline from the current value rather than zero; otherwise a
    # later mode inherits retries from earlier modes as a spurious delta.
    alloc_retries_prev = torch.cuda.memory_stats().get("num_alloc_retries", 0)
tick()
for step in range(MAX_STEPS):
t0 = time.time()
x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)
y = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)
with autocast_ctx:
loss = model(x, y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
model.zero_grad(set_to_none=True)
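        # Synchronize so dt measures real GPU work. If the step is wedged
        # on-device, this call never returns and tick() below is never
        # reached, which is what lets the watchdog fire.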
torch.cuda.synchronize()
dt = time.time() - t0
tps = int(BATCH_SIZE * SEQ_LEN / dt)
tick()
stats = torch.cuda.memory_stats()
retries = stats.get("num_alloc_retries", 0)
retry_delta = retries - alloc_retries_prev
alloc_retries_prev = retries
result["max_step"] = step
if step % 50 == 0 or retry_delta > 0 or step < 3:
alloc_mb = stats.get("allocated_bytes.all.current", 0) / 1e6
print(
f" step={step:04d} tps={tps:6d} dt={dt*1000:.0f}ms "
f"alloc={alloc_mb:.0f}MB retries={retries}",
flush=True,
)
result["tps_samples"].append((step, tps))
result["completed"] = True
print(f"\n COMPLETED: {MAX_STEPS} steps, mode={mode}", flush=True)
return result
def main():
print(f"torch: {torch.__version__} CUDA: {torch.version.cuda}")
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")
print(f"Steps: {MAX_STEPS} Watchdog: {WATCHDOG_TIMEOUT_S}s")
wd = threading.Thread(target=_watchdog_fn, daemon=True)
wd.start()
modes = sys.argv[1:] if len(sys.argv) > 1 else ["both"]
results = []
for mode in modes:
try:
r = run_test(mode)
except SystemExit:
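            # Normally unreachable: the watchdog's SIGTERM has its default
            # disposition and kills the process outright. Kept as a guard in
            # case a handler that translates SIGTERM to SystemExit is
            # installed.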
print(f"\n DEADLOCK/KILLED mode={mode}", flush=True)
r = {"mode": mode, "completed": False, "max_step": "?"}
except Exception as e:
print(f"\n ERROR mode={mode}: {e}", flush=True)
r = {"mode": mode, "completed": False, "error": str(e)}
results.append(r)
print(f"\n{'='*70}")
print("SUMMARY")
print(f"{'='*70}")
for r in results:
status = "PASS" if r.get("completed") else "FAIL"
print(f" {r['mode']:20s}: {status} (step {r.get('max_step', '?')})")
global _watchdog_armed
_watchdog_armed = False
if __name__ == "__main__":
main()