from __future__ import annotations
"""Diagnostic script for torch.compile deadlock after ~500 steps.
F17 investigation: validates that the _compiled_core / forward split
fixes the deadlock by running forward+backward loops with compile on.
Usage:
LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda/lib64 \
HYDRA_TIME_BUDGET=30 HYDRA_BATCH_SIZE=8 HYDRA_TOTAL_BATCH=16384 \
HYDRA_HTM_LEARN_EVERY=4 HYDRA_HESTIA_INTERVAL=9999 \
.venv/bin/python -u scripts/compile_debug.py [mode]
Modes:
eager - no compile (baseline)
model_only - compile model _compiled_core only
muon_only - compile muon step only
both - compile both (default)
"""
import gc
import os
import signal
import sys
import threading
import time
# Set CUDA env before torch import so the allocator picks the settings up.
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
# PyTorch >= 2.6 reads PYTORCH_ALLOC_CONF; older builds only honor the
# CUDA-prefixed name and silently ignore the new one. Set both so
# expandable_segments takes effect regardless of the installed version.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")
import torch
import torch.nn as nn
import torch.nn.functional as F
# -------------------------------------------------------------------------
# Config
# -------------------------------------------------------------------------
MAX_STEPS = 800  # steps per mode; the reported deadlock appears around step ~500
WATCHDOG_TIMEOUT_S = 20  # kill if no progress for this many seconds
BATCH_SIZE = int(os.environ.get("HYDRA_BATCH_SIZE", "8"))  # overridable via env
SEQ_LEN = 2048  # tokens per synthetic sequence
VOCAB_SIZE = 8192  # synthetic vocabulary size for randint inputs/targets
# -------------------------------------------------------------------------
# Watchdog thread: kills process if no progress
# -------------------------------------------------------------------------
_last_progress = time.time()  # wall-clock time of last observed step; updated by tick()
_watchdog_armed = True  # cleared by main() so the watchdog thread can exit
def _watchdog_fn():
    """Background loop: SIGTERM this process when step progress stalls.

    Polls once per second. If more than WATCHDOG_TIMEOUT_S elapses since the
    last tick(), dumps diagnostics, signals the process, and exits.
    """
    while _watchdog_armed:
        time.sleep(1.0)
        stalled_for = time.time() - _last_progress
        if stalled_for <= WATCHDOG_TIMEOUT_S:
            continue
        print(f"\n*** WATCHDOG: no progress for {stalled_for:.1f}s — DEADLOCK DETECTED ***",
              flush=True)
        _dump_diagnostics()
        os.kill(os.getpid(), signal.SIGTERM)
        return
def _dump_diagnostics():
"""Dump CUDA/dynamo state at deadlock time."""
try:
stats = torch.cuda.memory_stats()
print(f" alloc_retries: {stats.get('num_alloc_retries', 'N/A')}")
print(f" allocated_bytes: {stats.get('allocated_bytes.all.current', 0) / 1e6:.1f} MB")
print(f" reserved_bytes: {stats.get('reserved_bytes.all.current', 0) / 1e6:.1f} MB")
print(f" num_ooms: {stats.get('num_ooms', 0)}")
except Exception as e:
print(f" (memory_stats failed: {e})")
try:
import torch._dynamo.utils as du
print(f" dynamo counters: {dict(du.counters)}")
except Exception as e:
print(f" (dynamo counters failed: {e})")
def tick():
    """Record forward progress: reset the watchdog's stall timer to now."""
    global _last_progress
    _last_progress = time.time()
# -------------------------------------------------------------------------
# Test
# -------------------------------------------------------------------------
def run_test(mode: str) -> dict:
    """Run forward+backward loop with specified compile config.

    mode is one of "eager", "model_only", "muon_only", "both"; it selects
    which components get torch.compile'd via env vars read at hydra import
    time. Returns a dict with keys: mode, max_step, tps_samples, and
    completed=True when all MAX_STEPS finished without deadlock.
    """
    print(f"\n{'='*70}")
    print(f"TEST MODE: {mode}")
    print(f"{'='*70}", flush=True)
    compile_model = mode in ("model_only", "both")
    compile_muon = mode in ("muon_only", "both")
    # NOTE(review): these flags are presumably consumed at hydra import time
    # (hence the sys.modules purge below) — confirm against hydra sources.
    os.environ["HYDRA_MODEL_COMPILE"] = "1" if compile_model else "0"
    os.environ["HYDRA_MUON_COMPILE"] = "1" if compile_muon else "0"
    os.environ["HYDRA_ASYNC_POSTPROCESS"] = "0"
    os.environ["HYDRA_HESTIA_INTERVAL"] = "9999"
    os.environ["HYDRA_HTM_LEARN_EVERY"] = "4"
    # Clear cached modules for fresh env var reads
    for mod_name in list(sys.modules.keys()):
        if mod_name.startswith("hydra."):
            del sys.modules[mod_name]
    # Start each mode from a clean compiler/allocator state.
    torch._dynamo.reset()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    gc.collect()
    from hydra.model import PostSemClawModel
    from hydra.config import PostSemClawConfig
    device = torch.device("cuda")
    config = PostSemClawConfig(
        d_model=256, n_layer=4, d_state=64, headdim=32, expand=2,
        vocab_size=VOCAB_SIZE, sequence_len=SEQ_LEN,
    )
    # Materialize on meta first, then allocate empty storage on CUDA and
    # initialize — avoids a throwaway CPU allocation of the full model.
    with torch.device("meta"):
        model = PostSemClawModel(config)
    model.to_empty(device=device)
    model.init_weights()
    optimizer = model.setup_optimizer()
    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
    result = {"mode": mode, "max_step": 0, "tps_samples": []}
    alloc_retries_prev = 0
    tick()
    for step in range(MAX_STEPS):
        t0 = time.time()
        # Random token ids as inputs/targets; content doesn't matter for
        # the deadlock reproduction, only the compute pattern does.
        x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)
        y = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)
        with autocast_ctx:
            loss = model(x, y)  # model(x, y) returns the loss (backward() below)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        model.zero_grad(set_to_none=True)
        # Synchronize so dt measures real GPU work and a hang is detected
        # here (tick() below is what feeds the watchdog).
        torch.cuda.synchronize()
        dt = time.time() - t0
        tps = int(BATCH_SIZE * SEQ_LEN / dt)
        tick()
        stats = torch.cuda.memory_stats()
        retries = stats.get("num_alloc_retries", 0)
        retry_delta = retries - alloc_retries_prev
        alloc_retries_prev = retries
        result["max_step"] = step
        # Log periodically, on any new alloc retry, and for the first steps.
        if step % 50 == 0 or retry_delta > 0 or step < 3:
            alloc_mb = stats.get("allocated_bytes.all.current", 0) / 1e6
            print(
                f" step={step:04d} tps={tps:6d} dt={dt*1000:.0f}ms "
                f"alloc={alloc_mb:.0f}MB retries={retries}",
                flush=True,
            )
        # Recorded every step (cheap: MAX_STEPS tuples).
        result["tps_samples"].append((step, tps))
    result["completed"] = True
    print(f"\n COMPLETED: {MAX_STEPS} steps, mode={mode}", flush=True)
    return result
def main():
    """Print environment info, start the watchdog, and run each requested mode."""
    print(f"torch: {torch.__version__} CUDA: {torch.version.cuda}")
    print(f"GPU: {torch.cuda.get_device_name()}")
    # BUG FIX: the attribute is total_memory, not total_mem — the original
    # raised AttributeError before any test ran.
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"Steps: {MAX_STEPS} Watchdog: {WATCHDOG_TIMEOUT_S}s")
    # BUG FIX: the watchdog kills us with SIGTERM, whose default action
    # terminates the process outright, so the `except SystemExit` below could
    # never fire and remaining modes were skipped. Convert SIGTERM into
    # SystemExit on the main thread so the per-mode recovery path works.
    # (Best-effort: if the main thread is wedged inside a native CUDA call,
    # the Python-level handler may still never run.)
    signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit(1))
    wd = threading.Thread(target=_watchdog_fn, daemon=True)
    wd.start()
    modes = sys.argv[1:] if len(sys.argv) > 1 else ["both"]
    results = []
    for mode in modes:
        try:
            r = run_test(mode)
        except SystemExit:
            print(f"\n DEADLOCK/KILLED mode={mode}", flush=True)
            r = {"mode": mode, "completed": False, "max_step": "?"}
        except Exception as e:
            print(f"\n ERROR mode={mode}: {e}", flush=True)
            r = {"mode": mode, "completed": False, "error": str(e)}
        results.append(r)
    print(f"\n{'='*70}")
    print("SUMMARY")
    print(f"{'='*70}")
    for r in results:
        status = "PASS" if r.get("completed") else "FAIL"
        print(f" {r['mode']:20s}: {status} (step {r.get('max_step', '?')})")
    # Disarm so the daemon watchdog loop exits on its next poll.
    global _watchdog_armed
    _watchdog_armed = False
if __name__ == "__main__":
    main()