FireEcho / FireEcho Engine /debug_nan_isolate.py
Joysulem's picture
Upload 3258 files
b5bff9c verified
#!/usr/bin/env python3
"""Isolate exactly which step of enable_eagle() causes NaN in target model.
Tests each sub-step of enable_eagle() independently to find the culprit.
Also checks per-layer output to find where NaN first appears.
"""
import sys, os, torch, gc
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
from fireecho_kernel import FireEchoEagleHead
# Model directory. Override with FIREECHO_MODEL_PATH to run on a machine where
# the default absolute path below does not exist.
MODEL_PATH = os.environ.get(
    "FIREECHO_MODEL_PATH",
    "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct",
)
# EAGLE head checkpoint (override with FIREECHO_EAGLE_CKPT). The default is
# resolved from the absolute script directory, so the lookup does not depend on
# the current working directory.
EAGLE_CKPT = os.environ.get(
    "FIREECHO_EAGLE_CKPT",
    os.path.join(os.path.dirname(os.path.abspath(__file__)),
                 "eagle_checkpoints", "eagle_best.pt"),
)
# ChatML-style prompt: a system turn, a "Hello" user turn, and an open
# assistant turn for the model to continue.
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
@torch.no_grad()
def check_forward(engine, tokenizer, label):
    """Run one full forward pass over PROMPT and report the NaN status of the logits.

    Before the pass the engine's cache is reset, ``_current_seq_id`` is set to
    0 and, if the KV cache has a ``_graph_mode`` attribute, it is forced to
    False, so every call starts from the same state.

    Args:
        engine: Engine object returned by ``load_engine``.
        tokenizer: Tokenizer used to encode PROMPT and decode the top token.
        label: Tag printed in front of the status line.

    Returns:
        bool: True if any element of the logits is NaN.
    """
    torch.cuda.synchronize()
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()

    nan_count = int(logits.isnan().sum().item())
    has_nan = nan_count > 0
    if has_nan:
        # A single NaN element is enough to fail this check; report how many
        # logits are affected instead of claiming the whole output is NaN.
        print(f" [{label}] NaN DETECTED — {nan_count}/{logits.numel()} logits are NaN")
    else:
        last = logits[:, -1, :]  # logits for the final prompt position
        top_id = last.argmax(dim=-1).item()
        top_val = last.max().item()
        print(f" [{label}] OK — top token={top_id} "
              f"('{tokenizer.decode([top_id])}'), max={top_val:.2f}")
    return has_nan
@torch.no_grad()
def check_per_layer(engine, tokenizer, label):
    """Run PROMPT through the model stage by stage and report where NaN first appears.

    The stages are checked in order: ``engine.embed``, each module in
    ``engine.layers``, ``engine.norm`` and ``engine.lm_head``. The check stops
    at the first embedding or layer output that contains NaN. If no layer
    produces NaN, the norm and lm_head results are also reported, along with
    the top token.

    NOTE(review): this calls the layers directly instead of going through
    ``engine.forward``. Any extra work forward() does (for example EAGLE
    hidden-state capture when ``_eagle_enabled`` is set) is not exercised
    here, so a NaN caused only inside forward() may not reproduce. Confirm
    against the engine implementation.

    Args:
        engine: Engine object returned by ``load_engine``.
        tokenizer: Tokenizer used to encode PROMPT and decode the top token.
        label: Tag printed in front of every status line.

    Returns:
        None; results are printed.
    """
    torch.cuda.synchronize()  # same clean starting point as check_forward
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False

    x = engine.embed(ids)
    has_nan = x.isnan().any().item()
    print(f" [{label}] After embed: has_nan={has_nan}")
    if has_nan:
        return

    for i, layer in enumerate(engine.layers):
        # Positional args presumably (kv_cache, seq_id, position=0, use_cache=True),
        # mirroring forward(ids, use_cache=True, position=0). TODO confirm the
        # layer signature.
        x = layer(x, engine.kv_cache, engine._current_seq_id, 0, True)
        if x.isnan().any().item():
            print(f" [{label}] FIRST NaN at layer {i} !!!")
            print(f" [{label}] NaN starts at layer {i}")
            return

    # Every layer output was NaN-free; check the final norm and the LM head.
    x = engine.norm(x)
    has_nan = x.isnan().any().item()
    print(f" [{label}] After norm: has_nan={has_nan}")
    logits = engine.lm_head(x)
    has_nan = logits.isnan().any().item()
    print(f" [{label}] After lm_head: has_nan={has_nan}")
    if not has_nan:
        top_id = logits[:, -1, :].argmax(dim=-1).item()
        print(f" [{label}] Top token: {top_id} ('{tokenizer.decode([top_id])}')")
# Layer indices whose hidden states are captured when the EAGLE flags are set.
# This was previously repeated inline at every flag-setting site.
EAGLE_CAPTURE_LAYERS = (8, 24, 47)


def set_eagle_capture_flags(engine):
    """Put ``engine`` into the flag-only "EAGLE enabled" state this script tests.

    Sets ``_eagle_enabled`` to True, installs fresh copies of the capture-layer
    set and list built from EAGLE_CAPTURE_LAYERS, and clears
    ``_eagle_hidden_states``. It does not create or attach an EAGLE head.
    """
    engine._eagle_enabled = True
    engine._eagle_capture_set = set(EAGLE_CAPTURE_LAYERS)
    engine._eagle_capture_layers = list(EAGLE_CAPTURE_LAYERS)
    engine._eagle_hidden_states = {}


if __name__ == "__main__":
    print("=" * 60)
    print(" NaN Isolation Test")
    print("=" * 60)

    print("\n[1/6] Loading model...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.pack_all_experts()
    engine.kv_cache.enable_flat_decode()
    engine.eval()
    # Check VRAM
    vram = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM after load: {vram:.2f} GB")

    print("\n[2/6] Warmup...")
    warmup_ids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(warmup_ids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)

    print("\n[3/6] Test BEFORE enable_eagle()...")
    nan_before = check_forward(engine, tokenizer, "before eagle")
    if nan_before:
        print("\n ERROR: NaN even before enable_eagle! Something wrong with model load.")
        sys.exit(1)

    print("\n[4/6] Test: just set _eagle_enabled=True (no head creation)...")
    set_eagle_capture_flags(engine)
    nan_flag_only = check_forward(engine, tokenizer, "flag only")
    engine._eagle_enabled = False  # reset

    print("\n[5/6] Test: create eagle head + assign as submodule...")
    eagle_head = FireEchoEagleHead(
        dim=config.dim, num_capture_layers=len(EAGLE_CAPTURE_LAYERS),
        num_heads=16, ffn_mult=2, num_layers=2,
    ).to(dtype=torch.bfloat16, device='cuda')
    eagle_head.lm_head = engine.lm_head
    engine.eagle_head = eagle_head  # registers as nn.Module submodule
    vram2 = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM after eagle head: {vram2:.2f} GB (+{vram2 - vram:.2f} GB)")
    # _eagle_enabled was reset to False after step 4, so this (and step 6)
    # measures "head attached, flag OFF".
    nan_with_head = check_forward(engine, tokenizer, "with head (no ckpt)")

    print("\n[6/6] Test: load checkpoint into eagle head...")
    if os.path.exists(EAGLE_CKPT):
        ckpt = torch.load(EAGLE_CKPT, map_location='cuda', weights_only=True)
        sd = ckpt.get('eagle_head', ckpt)
        is_legacy = any(k.startswith(('norm1.', 'q_proj.')) for k in sd)
        if is_legacy:
            eagle_head.load_legacy_checkpoint(sd)
        else:
            load_result = eagle_head.load_state_dict(sd, strict=False)
            # strict=False silently ignores key mismatches; surface them so a
            # partially-loaded head is not mistaken for a clean load.
            if load_result.missing_keys:
                print(f" Missing keys ({len(load_result.missing_keys)}): "
                      f"{load_result.missing_keys[:5]}")
            if load_result.unexpected_keys:
                print(f" Unexpected keys ({len(load_result.unexpected_keys)}): "
                      f"{load_result.unexpected_keys[:5]}")
        nan_with_ckpt = check_forward(engine, tokenizer, "with ckpt")
    else:
        print(f" No checkpoint at {EAGLE_CKPT}, skipping")
        nan_with_ckpt = None  # test not run; must not inherit another result

    # Summary
    ckpt_status = "skipped" if nan_with_ckpt is None else ('NaN' if nan_with_ckpt else 'OK')
    print(f"\n{'=' * 60}")
    print(" RESULTS")
    print(f"{'=' * 60}")
    print(f" Before eagle: {'NaN' if nan_before else 'OK'}")
    print(f" Flag only: {'NaN' if nan_flag_only else 'OK'}")
    print(f" With head (no ckpt): {'NaN' if nan_with_head else 'OK'}")
    print(f" With checkpoint: {ckpt_status}")

    # If any NaN found, do per-layer analysis
    if nan_flag_only or nan_with_head or nan_with_ckpt:
        print("\n--- Per-layer NaN analysis ---")
        # NOTE(review): both cases run with _eagle_enabled=True and eagle_head
        # still attached (it was assigned in step 5). That is not the exact
        # state either failing check ran in: flag-only had no head, and
        # head/ckpt ran with the flag off. Confirm this is intended.
        set_eagle_capture_flags(engine)
        per_layer_label = "flag-only per-layer" if nan_flag_only else "full-eagle per-layer"
        check_per_layer(engine, tokenizer, per_layer_label)

        # Also test: head assigned but flag OFF
        print("\n--- Test: head assigned but _eagle_enabled=False ---")
        engine._eagle_enabled = False
        check_forward(engine, tokenizer, "head assigned, flag OFF")
    else:
        print(" All tests passed — no NaN detected!")