#!/usr/bin/env python3
# FireEcho / FireEcho Engine — debug_d8_isolate.py
"""Isolate exactly what about D=8 causes NaN.
Tests:
1. D=2 eagle head → forward → should be OK
2. D=8 eagle head (random, no ckpt) → forward → is NaN from VRAM pressure?
3. D=8 eagle head (random, NOT assigned to engine) → forward → is NaN from registration?
4. D=8 allocated but eagle_enabled=False → forward → is NaN from .to() side effect?
"""
import sys, os, torch, gc
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
from fireecho_kernel import FireEchoEagleHead
# Absolute path to the base Qwen3-Omni model weights (local disk mount).
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
# Trained eagle-head checkpoint, stored alongside this script.
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
# Fixed chat-template prompt used for every NaN probe below.
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
@torch.no_grad()
def check(engine, tokenizer, label):
    """Run one prefill forward pass over PROMPT and report NaN status.

    Resets the engine's KV cache, forces graph-capture mode off (if the
    cache exposes that flag), runs a single cached forward at position 0,
    and prints either the top-1 token or a NaN warning tagged with *label*.

    Returns True when the logits contain any NaN, False otherwise.
    """
    token_ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # A plain eager forward must not run under CUDA-graph capture mode.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(token_ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    found_nan = bool(logits.isnan().any().item())
    if not found_nan:
        top = logits[:, -1, :].argmax(dim=-1).item()
        print(f" [{label}] OK — top={top} ('{tokenizer.decode([top])}')")
    else:
        print(f" [{label}] NaN DETECTED")
    return found_nan
if __name__ == "__main__":

    def _report_vram(base_gb):
        """Print current CUDA allocation (GB) and its delta vs. *base_gb*."""
        vram = torch.cuda.memory_allocated() / 1e9
        print(f" VRAM: {vram:.2f} GB (+{vram - base_gb:.2f})")

    def _teardown_eagle(eng):
        """Drop the registered eagle head from *eng* and release its VRAM."""
        if hasattr(eng, 'eagle_head'):
            del eng.eagle_head
        eng._eagle_enabled = False
        eng._eagle_hidden_states = {}
        # Collect Python refs FIRST so empty_cache() can actually return the
        # head's allocator blocks (reversing this order leaves them cached).
        gc.collect()
        torch.cuda.empty_cache()

    print("=" * 60)
    print(" D=8 NaN Isolation")
    print("=" * 60)

    print("\n[1] Loading model...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.pack_all_experts()
    engine.kv_cache.enable_flat_decode()
    engine.eval()

    # Warmup: a few short greedy generations to stabilise allocator state
    # before the VRAM baseline is measured.
    print("\n[2] Warmup...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
    vram_base = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM baseline: {vram_base:.2f} GB")

    # Test 1: Baseline (no eagle)
    print("\n[3] Baseline (no eagle)...")
    check(engine, tokenizer, "baseline")

    # Test 2: D=2 eagle head (should work)
    print("\n[4] D=2 eagle head...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=2, checkpoint_path=EAGLE_CKPT)
    _report_vram(vram_base)
    check(engine, tokenizer, "D=2")
    _teardown_eagle(engine)

    # Test 3: D=8 eagle head (NO checkpoint, random init)
    print("\n[5] D=8 eagle head (random init, no checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8)  # no checkpoint_path
    _report_vram(vram_base)
    nan_d8_random = check(engine, tokenizer, "D=8 random")
    _teardown_eagle(engine)

    # Test 4: D=8 eagle head WITH checkpoint
    print("\n[6] D=8 eagle head (with checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    _report_vram(vram_base)
    nan_d8_ckpt = check(engine, tokenizer, "D=8 with ckpt")
    _teardown_eagle(engine)

    # Test 5: D=8 eagle head allocated but NOT registered as submodule
    print("\n[7] D=8 eagle head (allocated, NOT registered on engine)...")
    head_ext = FireEchoEagleHead(
        dim=config.dim, num_capture_layers=3,
        num_heads=16, ffn_mult=2, num_layers=8,
    ).to(dtype=torch.bfloat16, device='cuda')
    # Do NOT assign to engine — keep as local variable
    engine._eagle_enabled = True
    engine._eagle_capture_set = {8, 24, 47}
    engine._eagle_capture_layers = [8, 24, 47]
    engine._eagle_hidden_states = {}
    _report_vram(vram_base)
    nan_d8_unreg = check(engine, tokenizer, "D=8 unregistered")
    # Cleanup: the head is a local here, so _teardown_eagle doesn't apply.
    del head_ext
    engine._eagle_enabled = False
    gc.collect()
    torch.cuda.empty_cache()

    # Test 6: D=4 eagle head (between D=2 and D=8)
    print("\n[8] D=4 eagle head (checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=4, checkpoint_path=EAGLE_CKPT)
    _report_vram(vram_base)
    nan_d4 = check(engine, tokenizer, "D=4")
    _teardown_eagle(engine)

    # Test 7: D=8 but eagle_enabled=False (head exists but flag off)
    print("\n[9] D=8 eagle head, but _eagle_enabled=False...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    engine._eagle_enabled = False  # disable the flag
    _report_vram(vram_base)
    nan_d8_flagoff = check(engine, tokenizer, "D=8 flag OFF")

    # Summary
    print(f"\n{'='*60}")
    print(" RESULTS")
    print(f"{'='*60}")
    print(f" D=8 random: {'NaN' if nan_d8_random else 'OK'}")
    print(f" D=8 with ckpt: {'NaN' if nan_d8_ckpt else 'OK'}")
    print(f" D=8 unregistered: {'NaN' if nan_d8_unreg else 'OK'}")
    print(f" D=4: {'NaN' if nan_d4 else 'OK'}")
    print(f" D=8 flag OFF: {'NaN' if nan_d8_flagoff else 'OK'}")