#!/usr/bin/env python3
"""Isolate exactly what about D=8 causes NaN.
Tests:
1. D=2 eagle head → forward → should be OK
2. D=8 eagle head (random, no ckpt) → forward → is NaN from VRAM pressure?
3. D=8 eagle head (random, NOT assigned to engine) → forward → is NaN from registration?
4. D=8 allocated but eagle_enabled=False → forward → is NaN from .to() side effect?
"""
import sys, os, torch, gc
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
from fireecho_kernel import FireEchoEagleHead
# Absolute path to the base Qwen3-Omni model weights — machine-specific; TODO confirm/parameterize.
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
# Trained eagle-head checkpoint, used by the tests that load weights (vs. random init).
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
# Fixed chat-template prompt fed to every forward-pass NaN check.
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
@torch.no_grad()
def check(engine, tokenizer, label):
    """Run one prefill forward over PROMPT and report whether the logits hold NaN.

    Prints a one-line verdict tagged with *label*; on success also shows the
    argmax token of the final position. Returns True when NaN was detected.
    """
    token_ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # Force the eager KV-cache path; a captured CUDA graph could mask where
    # the NaN originates. NOTE(review): assumes _graph_mode gates graph replay.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(token_ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    nan_found = logits.isnan().any().item()
    if not nan_found:
        top = logits[:, -1, :].argmax(dim=-1).item()
        print(f" [{label}] OK β top={top} ('{tokenizer.decode([top])}')")
    else:
        print(f" [{label}] NaN DETECTED")
    return nan_found
if __name__ == "__main__":
    def _cleanup(eng):
        """Tear down eagle state between tests and return freed VRAM to CUDA.

        Fix vs. original: run gc.collect() BEFORE torch.cuda.empty_cache() —
        blocks still referenced from Python cannot be reclaimed by empty_cache,
        so collecting first lets the allocator actually release them.
        """
        if hasattr(eng, 'eagle_head'):
            del eng.eagle_head
        eng._eagle_enabled = False
        eng._eagle_hidden_states = {}
        gc.collect()               # drop Python-side references first...
        torch.cuda.empty_cache()   # ...so the cached blocks are reclaimable

    def _report_vram(baseline):
        """Print current CUDA allocator usage and delta vs the warmup baseline."""
        used = torch.cuda.memory_allocated() / 1e9
        print(f" VRAM: {used:.2f} GB (+{used - baseline:.2f})")

    def _verdict(flag):
        """Map a NaN flag to the summary-table word."""
        return 'NaN' if flag else 'OK'

    print("=" * 60)
    print(" D=8 NaN Isolation")
    print("=" * 60)
    print("\n[1] Loading model...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.pack_all_experts()
    engine.kv_cache.enable_flat_decode()
    engine.eval()

    # Warmup: a few short generations so lazy allocations / autotuning settle
    # before we record the VRAM baseline.
    print("\n[2] Warmup...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
    vram_base = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM baseline: {vram_base:.2f} GB")

    # Test 1: baseline forward with no eagle machinery attached.
    # Fix vs. original: capture the result so it appears in the summary.
    print("\n[3] Baseline (no eagle)...")
    nan_baseline = check(engine, tokenizer, "baseline")

    # Test 2: D=2 eagle head — the known-good configuration.
    print("\n[4] D=2 eagle head...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=2, checkpoint_path=EAGLE_CKPT)
    _report_vram(vram_base)
    nan_d2 = check(engine, tokenizer, "D=2")
    _cleanup(engine)

    # Test 3: D=8, random init — isolates checkpoint weights from depth.
    print("\n[5] D=8 eagle head (random init, no checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8)  # no checkpoint_path
    _report_vram(vram_base)
    nan_d8_random = check(engine, tokenizer, "D=8 random")
    _cleanup(engine)

    # Test 4: D=8 with trained checkpoint.
    print("\n[6] D=8 eagle head (with checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    _report_vram(vram_base)
    nan_d8_ckpt = check(engine, tokenizer, "D=8 with ckpt")
    _cleanup(engine)

    # Test 5: D=8 head allocated but never registered on the engine —
    # separates raw VRAM pressure from submodule-registration side effects.
    print("\n[7] D=8 eagle head (allocated, NOT registered on engine)...")
    head_ext = FireEchoEagleHead(
        dim=config.dim, num_capture_layers=3,
        num_heads=16, ffn_mult=2, num_layers=8,
    ).to(dtype=torch.bfloat16, device='cuda')
    # Do NOT assign to engine — keep as local variable
    engine._eagle_enabled = True
    engine._eagle_capture_set = {8, 24, 47}
    engine._eagle_capture_layers = [8, 24, 47]
    engine._eagle_hidden_states = {}
    _report_vram(vram_base)
    nan_d8_unreg = check(engine, tokenizer, "D=8 unregistered")
    del head_ext
    _cleanup(engine)

    # Test 6: D=4 — midpoint between good D=2 and failing D=8.
    print("\n[8] D=4 eagle head (checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=4, checkpoint_path=EAGLE_CKPT)
    _report_vram(vram_base)
    nan_d4 = check(engine, tokenizer, "D=4")
    _cleanup(engine)

    # Test 7: D=8 head stays registered but the capture flag is forced off.
    print("\n[9] D=8 eagle head, but _eagle_enabled=False...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    engine._eagle_enabled = False  # disable the flag
    _report_vram(vram_base)
    nan_d8_flagoff = check(engine, tokenizer, "D=8 flag OFF")

    # Summary — now includes the baseline and D=2 controls.
    print(f"\n{'='*60}")
    print(" RESULTS")
    print(f"{'='*60}")
    print(f" baseline: {_verdict(nan_baseline)}")
    print(f" D=2: {_verdict(nan_d2)}")
    print(f" D=8 random: {_verdict(nan_d8_random)}")
    print(f" D=8 with ckpt: {_verdict(nan_d8_ckpt)}")
    print(f" D=8 unregistered: {_verdict(nan_d8_unreg)}")
    print(f" D=4: {_verdict(nan_d4)}")
    print(f" D=8 flag OFF: {_verdict(nan_d8_flagoff)}")