"""Diagnostic script for torch.compile deadlock after ~500 steps.

F17 investigation: validates that the _compiled_core / forward split
fixes the deadlock by running forward+backward loops with compile on.

Usage:
    LD_LIBRARY_PATH=/usr/lib/wsl/lib:/usr/local/cuda/lib64 \
      HYDRA_TIME_BUDGET=30 HYDRA_BATCH_SIZE=8 HYDRA_TOTAL_BATCH=16384 \
      HYDRA_HTM_LEARN_EVERY=4 HYDRA_HESTIA_INTERVAL=9999 \
      .venv/bin/python -u scripts/compile_debug.py [mode]

Modes:
    eager       - no compile (baseline)
    model_only  - compile model _compiled_core only
    muon_only   - compile muon step only
    both        - compile both (default)
"""

from __future__ import annotations

import gc
import os
import signal
import sys
import threading
import time

# Set CUDA env before torch import
os.environ.setdefault("CUDA_HOME", "/usr/local/cuda")
os.environ.setdefault("PYTORCH_ALLOC_CONF", "expandable_segments:True")

import torch
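
# For context: the _compiled_core / forward split this script exercises looks
# roughly like the sketch below (illustrative only; the real model is
# hydra.model.PostSemClawModel and the method bodies here are placeholders):
#
#     class Model(nn.Module):
#         @torch.compile                       # compiled: pure tensor compute
#         def _compiled_core(self, x):
#             return self.backbone(x)
#
#         def forward(self, x, y):             # eager: control flow, logging,
#             logits = self._compiled_core(x)  # and Python-side bookkeeping
#             return cross_entropy(logits, y)  # stay outside the graph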

# -------------------------------------------------------------------------
# Config
# -------------------------------------------------------------------------
MAX_STEPS = 800
WATCHDOG_TIMEOUT_S = 20  # kill if no progress for this many seconds
BATCH_SIZE = int(os.environ.get("HYDRA_BATCH_SIZE", "8"))
SEQ_LEN = 2048
VOCAB_SIZE = 8192


# -------------------------------------------------------------------------
# Watchdog thread: kills process if no progress
# -------------------------------------------------------------------------
_last_progress = time.time()
_watchdog_armed = True

def _watchdog_fn():
    while _watchdog_armed:
        time.sleep(1.0)
        elapsed = time.time() - _last_progress
        if elapsed > WATCHDOG_TIMEOUT_S:
            print(f"\n*** WATCHDOG: no progress for {elapsed:.1f}s — DEADLOCK DETECTED ***",
                  flush=True)
            _dump_diagnostics()
            # SIGTERM is mapped to SystemExit in main(), so a responsive
            # main thread can record the failure and move on to the next mode.
            os.kill(os.getpid(), signal.SIGTERM)
            # Grace period: if the handler ran, main() ticks and we keep
            # watching; if the main thread is wedged inside a C extension
            # call the handler never runs, so escalate to a hard kill.
            time.sleep(5.0)
            if time.time() - _last_progress > WATCHDOG_TIMEOUT_S:
                os.kill(os.getpid(), signal.SIGKILL)
                return

def _dump_diagnostics():
    """Dump CUDA/dynamo state at deadlock time."""
    try:
        stats = torch.cuda.memory_stats()
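        # num_alloc_retries counts failed cudaMalloc calls that forced the
        # caching allocator to flush and retry; a climb here points at
        # memory pressure/fragmentation rather than a compile-side hang.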
        print(f"  alloc_retries:     {stats.get('num_alloc_retries', 'N/A')}")
        print(f"  allocated_bytes:   {stats.get('allocated_bytes.all.current', 0) / 1e6:.1f} MB")
        print(f"  reserved_bytes:    {stats.get('reserved_bytes.all.current', 0) / 1e6:.1f} MB")
        print(f"  num_ooms:          {stats.get('num_ooms', 0)}")
    except Exception as e:
        print(f"  (memory_stats failed: {e})")

    try:
        import torch._dynamo.utils as du
        print(f"  dynamo counters:   {dict(du.counters)}")
    except Exception as e:
        print(f"  (dynamo counters failed: {e})")


def tick():
    """Record a heartbeat so the watchdog knows the main thread is alive."""
    global _last_progress
    _last_progress = time.time()


# -------------------------------------------------------------------------
# Test
# -------------------------------------------------------------------------
def run_test(mode: str) -> dict:
    """Run forward+backward loop with specified compile config."""
    print(f"\n{'='*70}")
    print(f"TEST MODE: {mode}")
    print(f"{'='*70}", flush=True)

    compile_model = mode in ("model_only", "both")
    compile_muon = mode in ("muon_only", "both")

    os.environ["HYDRA_MODEL_COMPILE"] = "1" if compile_model else "0"
    os.environ["HYDRA_MUON_COMPILE"] = "1" if compile_muon else "0"
    os.environ["HYDRA_ASYNC_POSTPROCESS"] = "0"
    os.environ["HYDRA_HESTIA_INTERVAL"] = "9999"
    os.environ["HYDRA_HTM_LEARN_EVERY"] = "4"

    # Clear cached modules for fresh env var reads
    for mod_name in list(sys.modules.keys()):
        if mod_name.startswith("hydra."):
            del sys.modules[mod_name]

    # Fresh compile and allocator state per mode, so cached graphs and
    # guards from one run don't bleed into the next.
    torch._dynamo.reset()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    gc.collect()

    from hydra.model import PostSemClawModel
    from hydra.config import PostSemClawConfig

    device = torch.device("cuda")
    config = PostSemClawConfig(
        d_model=256, n_layer=4, d_state=64, headdim=32, expand=2,
        vocab_size=VOCAB_SIZE, sequence_len=SEQ_LEN,
    )

    with torch.device("meta"):
        model = PostSemClawModel(config)
    model.to_empty(device=device)
    model.init_weights()

    optimizer = model.setup_optimizer()
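    # bf16 autocast wraps the forward only; backward replays with the dtypes
    # recorded during forward, and bf16 needs no GradScaler.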
    autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)

    result = {"mode": mode, "max_step": 0, "tps_samples": []}
    alloc_retries_prev = 0

    tick()

    for step in range(MAX_STEPS):
        t0 = time.time()

        x = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)
        y = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=device)

        with autocast_ctx:
            loss = model(x, y)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        model.zero_grad(set_to_none=True)

        torch.cuda.synchronize()
        dt = time.time() - t0
        tps = int(BATCH_SIZE * SEQ_LEN / dt)

        tick()

        stats = torch.cuda.memory_stats()
        retries = stats.get("num_alloc_retries", 0)
        retry_delta = retries - alloc_retries_prev
        alloc_retries_prev = retries

        result["max_step"] = step

        if step % 50 == 0 or retry_delta > 0 or step < 3:
            alloc_mb = stats.get("allocated_bytes.all.current", 0) / 1e6
            print(
                f"  step={step:04d} tps={tps:6d} dt={dt*1000:.0f}ms "
                f"alloc={alloc_mb:.0f}MB retries={retries}",
                flush=True,
            )
            result["tps_samples"].append((step, tps))

    result["completed"] = True
    print(f"\n  COMPLETED: {MAX_STEPS} steps, mode={mode}", flush=True)
    return result


def main():
    print(f"torch: {torch.__version__}  CUDA: {torch.version.cuda}")
    print(f"GPU:   {torch.cuda.get_device_name()}")
    print(f"VRAM:  {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB")
    print(f"Steps: {MAX_STEPS}  Watchdog: {WATCHDOG_TIMEOUT_S}s")

    # Map the watchdog's SIGTERM to SystemExit so the per-mode loop below
    # can catch it, record a FAIL, and still print the summary.
    signal.signal(signal.SIGTERM, lambda signum, frame: sys.exit(1))

    wd = threading.Thread(target=_watchdog_fn, daemon=True)
    wd.start()

    modes = sys.argv[1:] if len(sys.argv) > 1 else ["both"]
    results = []

    for mode in modes:
        try:
            r = run_test(mode)
        except SystemExit:
            print(f"\n  DEADLOCK/KILLED mode={mode}", flush=True)
            r = {"mode": mode, "completed": False, "max_step": "?"}
            tick()  # heartbeat: the main thread recovered, so keep watching
        except Exception as e:
            print(f"\n  ERROR mode={mode}: {e}", flush=True)
            r = {"mode": mode, "completed": False, "error": str(e)}
        results.append(r)

    print(f"\n{'='*70}")
    print("SUMMARY")
    print(f"{'='*70}")
    for r in results:
        status = "PASS" if r.get("completed") else "FAIL"
        print(f"  {r['mode']:20s}: {status} (step {r.get('max_step', '?')})")

    global _watchdog_armed
    _watchdog_armed = False


if __name__ == "__main__":
    main()