"""Diagnostic script for torch.compile deadlock after ~500 steps. F17 investigation: validates that the _compiled_core / forward split fixes the deadlock by running forward+backward loops with compile on. Usage: LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda/lib64 \ HYDRA_TIME_BUDGET=30 HYDRA_BATCH_SIZE=8 HYDRA_TOTAL_BATCH=16384 \ HYDRA_HTM_LEARN_EVERY=4 HYDRA_HESTIA_INTERVAL=9999 \ .venv/bin/python -u scripts/compile_debug.py [mode] Modes: eager - no compile (baseline) model_only - compile model _compiled_core only muon_only - compile muon step only both - compile both (default) """ from __future__ import annotations import gc import os import signal import sys import threading import time # Set CUDA env before torch import os.environ.setdefault("CUDA_HOME", "/usr/local/cuda") os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True") import torch import torch.nn as nn import torch.nn.functional as F # ------------------------------------------------------------------------- # Config # ------------------------------------------------------------------------- MAX_STEPS = 800 WATCHDOG_TIMEOUT_S = 20 # kill if no progress for this many seconds BATCH_SIZE = int(os.environ.get("HYDRA_BATCH_SIZE", "8")) SEQ_LEN = 2048 VOCAB_SIZE = 8192 # ------------------------------------------------------------------------- # Watchdog thread: kills process if no progress # ------------------------------------------------------------------------- _last_progress = time.time() _watchdog_armed = True def _watchdog_fn(): global _last_progress, _watchdog_armed while _watchdog_armed: time.sleep(1.0) elapsed = time.time() - _last_progress if elapsed > WATCHDOG_TIMEOUT_S: print(f"\n*** WATCHDOG: no progress for {elapsed:.1f}s — DEADLOCK DETECTED ***", flush=True) _dump_diagnostics() os.kill(os.getpid(), signal.SIGTERM) return def _dump_diagnostics(): """Dump CUDA/dynamo state at deadlock time.""" try: stats = torch.cuda.memory_stats() print(f" alloc_retries: {stats.get('num_alloc_retries', 'N/A')}") print(f" allocated_bytes: {stats.get('allocated_bytes.all.current', 0) / 1e6:.1f} MB") print(f" reserved_bytes: {stats.get('reserved_bytes.all.current', 0) / 1e6:.1f} MB") print(f" num_ooms: {stats.get('num_ooms', 0)}") except Exception as e: print(f" (memory_stats failed: {e})") try: import torch._dynamo.utils as du print(f" dynamo counters: {dict(du.counters)}") except Exception as e: print(f" (dynamo counters failed: {e})") def tick(): global _last_progress _last_progress = time.time() # ------------------------------------------------------------------------- # Test # ------------------------------------------------------------------------- def run_test(mode: str) -> dict: """Run forward+backward loop with specified compile config.""" print(f"\n{'='*70}") print(f"TEST MODE: {mode}") print(f"{'='*70}", flush=True) compile_model = mode in ("model_only", "both") compile_muon = mode in ("muon_only", "both") os.environ["HYDRA_MODEL_COMPILE"] = "1" if compile_model else "0" os.environ["HYDRA_MUON_COMPILE"] = "1" if compile_muon else "0" os.environ["HYDRA_ASYNC_POSTPROCESS"] = "0" os.environ["HYDRA_HESTIA_INTERVAL"] = "9999" os.environ["HYDRA_HTM_LEARN_EVERY"] = "4" # Clear cached modules for fresh env var reads for mod_name in list(sys.modules.keys()): if mod_name.startswith("hydra."): del sys.modules[mod_name] torch._dynamo.reset() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats() gc.collect() from hydra.model import PostSemClawModel from hydra.config import PostSemClawConfig device = torch.device("cuda") config = PostSemClawConfig( d_model=256, n_layer=4, d_state=64, headdim=32, expand=2, vocab_size=VOCAB_SIZE, sequence_len=SEQ_LEN, ) with torch.device("meta"): model = PostSemClawModel(config) model.to_empty(device=device) model.init_weights() optimizer = model.setup_optimizer() autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16) result = {"mode": mode, "max_step": 0, "tps_samples": []} alloc_retries_prev = 0 tick() for step in range(MAX_STEPS): t0 = time.time() x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device) y = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device) with autocast_ctx: loss = model(x, y) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) optimizer.step() model.zero_grad(set_to_none=True) torch.cuda.synchronize() dt = time.time() - t0 tps = int(BATCH_SIZE * SEQ_LEN / dt) tick() stats = torch.cuda.memory_stats() retries = stats.get("num_alloc_retries", 0) retry_delta = retries - alloc_retries_prev alloc_retries_prev = retries result["max_step"] = step if step % 50 == 0 or retry_delta > 0 or step < 3: alloc_mb = stats.get("allocated_bytes.all.current", 0) / 1e6 print( f" step={step:04d} tps={tps:6d} dt={dt*1000:.0f}ms " f"alloc={alloc_mb:.0f}MB retries={retries}", flush=True, ) result["tps_samples"].append((step, tps)) result["completed"] = True print(f"\n COMPLETED: {MAX_STEPS} steps, mode={mode}", flush=True) return result def main(): print(f"torch: {torch.__version__} CUDA: {torch.version.cuda}") print(f"GPU: {torch.cuda.get_device_name()}") print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB") print(f"Steps: {MAX_STEPS} Watchdog: {WATCHDOG_TIMEOUT_S}s") wd = threading.Thread(target=_watchdog_fn, daemon=True) wd.start() modes = sys.argv[1:] if len(sys.argv) > 1 else ["both"] results = [] for mode in modes: try: r = run_test(mode) except SystemExit: print(f"\n DEADLOCK/KILLED mode={mode}", flush=True) r = {"mode": mode, "completed": False, "max_step": "?"} except Exception as e: print(f"\n ERROR mode={mode}: {e}", flush=True) r = {"mode": mode, "completed": False, "error": str(e)} results.append(r) print(f"\n{'='*70}") print("SUMMARY") print(f"{'='*70}") for r in results: status = "PASS" if r.get("completed") else "FAIL" print(f" {r['mode']:20s}: {status} (step {r.get('max_step', '?')})") global _watchdog_armed _watchdog_armed = False if __name__ == "__main__": main()