#!/usr/bin/env python3
"""Isolate exactly what about D=8 causes NaN.

Runs one prefill forward pass under several eagle-head configurations and
reports whether the resulting logits contain NaN:

1. Baseline, no eagle head                          -> control
2. D=2 eagle head (checkpoint)                      -> should be OK
3. D=8 eagle head (random init, no ckpt)            -> NaN from VRAM pressure?
4. D=8 eagle head (with checkpoint)                 -> NaN from the weights?
5. D=8 head allocated but NOT registered on engine  -> NaN from registration?
6. D=4 eagle head (checkpoint)                      -> midpoint between D=2 and D=8
7. D=8 head registered but _eagle_enabled=False     -> NaN from a .to() side effect?
"""
import sys, os, torch, gc

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
from fireecho_kernel import FireEchoEagleHead

MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"


@torch.no_grad()
def check(engine, tokenizer, label):
    """Run one prefill forward pass over PROMPT and report NaN status.

    Returns True when the logits contain any NaN, False otherwise.
    """
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # Force eager decode — graph mode would replay captured buffers.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    has_nan = logits.isnan().any().item()
    if has_nan:
        print(f" [{label}] NaN DETECTED")
    else:
        top = logits[:, -1, :].argmax(dim=-1).item()
        print(f" [{label}] OK — top={top} ('{tokenizer.decode([top])}')")
    return has_nan


def _report_vram(vram_base):
    """Print allocated VRAM (GB) and its delta from the warmup baseline."""
    vram = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM: {vram:.2f} GB (+{vram - vram_base:.2f})")
    return vram


def _teardown_eagle(engine):
    """Drop the registered eagle head and reset eagle state between tests.

    Keeps every test independent: the next enable_eagle() starts from the
    same engine state the baseline saw.
    """
    del engine.eagle_head
    engine._eagle_enabled = False
    engine._eagle_hidden_states = {}
    torch.cuda.empty_cache()
    gc.collect()


if __name__ == "__main__":
    print("=" * 60)
    print(" D=8 NaN Isolation")
    print("=" * 60)

    print("\n[1] Loading model...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.pack_all_experts()
    engine.kv_cache.enable_flat_decode()
    engine.eval()

    # Warmup so lazy allocations don't distort the VRAM deltas below.
    print("\n[2] Warmup...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
    vram_base = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM baseline: {vram_base:.2f} GB")

    # Test 1: Baseline (no eagle)
    print("\n[3] Baseline (no eagle)...")
    check(engine, tokenizer, "baseline")

    # Test 2: D=2 eagle head (should work)
    print("\n[4] D=2 eagle head...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=2, checkpoint_path=EAGLE_CKPT)
    _report_vram(vram_base)
    check(engine, tokenizer, "D=2")
    _teardown_eagle(engine)

    # Test 3: D=8 eagle head (NO checkpoint, random init)
    print("\n[5] D=8 eagle head (random init, no checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8)  # no checkpoint_path
    _report_vram(vram_base)
    nan_d8_random = check(engine, tokenizer, "D=8 random")
    _teardown_eagle(engine)

    # Test 4: D=8 eagle head WITH checkpoint
    print("\n[6] D=8 eagle head (with checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    _report_vram(vram_base)
    nan_d8_ckpt = check(engine, tokenizer, "D=8 with ckpt")
    _teardown_eagle(engine)

    # Test 5: D=8 eagle head allocated but NOT registered as submodule
    print("\n[7] D=8 eagle head (allocated, NOT registered on engine)...")
    head_ext = FireEchoEagleHead(
        dim=config.dim,
        num_capture_layers=3,
        num_heads=16,
        ffn_mult=2,
        num_layers=8,
    ).to(dtype=torch.bfloat16, device='cuda')
    # Do NOT assign to engine — keep as local variable
    engine._eagle_enabled = True
    engine._eagle_capture_set = {8, 24, 47}
    engine._eagle_capture_layers = [8, 24, 47]
    engine._eagle_hidden_states = {}
    _report_vram(vram_base)
    nan_d8_unreg = check(engine, tokenizer, "D=8 unregistered")
    # Cleanup — manual here: the head was never registered, so there is no
    # engine.eagle_head to delete. Clear the capture-state dict too, for
    # consistency with the registered-head teardown.
    del head_ext
    engine._eagle_enabled = False
    engine._eagle_hidden_states = {}
    torch.cuda.empty_cache()
    gc.collect()

    # Test 6: D=4 eagle head (between D=2 and D=8)
    print("\n[8] D=4 eagle head (checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=4, checkpoint_path=EAGLE_CKPT)
    _report_vram(vram_base)
    nan_d4 = check(engine, tokenizer, "D=4")
    _teardown_eagle(engine)

    # Test 7: D=8 but eagle_enabled=False (head exists but flag off)
    print("\n[9] D=8 eagle head, but _eagle_enabled=False...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    engine._eagle_enabled = False  # disable the flag
    _report_vram(vram_base)
    nan_d8_flagoff = check(engine, tokenizer, "D=8 flag OFF")

    # Summary
    print(f"\n{'='*60}")
    print(" RESULTS")
    print(f"{'='*60}")
    print(f" D=8 random: {'NaN' if nan_d8_random else 'OK'}")
    print(f" D=8 with ckpt: {'NaN' if nan_d8_ckpt else 'OK'}")
    print(f" D=8 unregistered: {'NaN' if nan_d8_unreg else 'OK'}")
    print(f" D=4: {'NaN' if nan_d4 else 'OK'}")
    print(f" D=8 flag OFF: {'NaN' if nan_d8_flagoff else 'OK'}")