#!/usr/bin/env python3
"""Isolate exactly which step of enable_eagle() causes NaN in target model.

Tests each sub-step of enable_eagle() independently to find the culprit.
Also checks per-layer output to find where NaN first appears.
"""
import sys, os, torch, gc

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
from fireecho_kernel import FireEchoEagleHead

MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"

# Layers whose hidden states the eagle head captures (presumably matches
# enable_eagle()'s own choice — TODO confirm against the engine implementation).
EAGLE_CAPTURE_LAYERS = (8, 24, 47)


def _reset_for_probe(engine, tokenizer):
    """Encode PROMPT and put the engine into a clean eager-mode state.

    Shared setup for both probes below. Returns the prompt token ids on CUDA.
    """
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # Force the non-graph path: a captured CUDA graph would replay stale
    # kernels instead of reflecting the eagle-related state we just changed.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    return ids


def _arm_eagle(engine):
    """Set the eagle capture flags only — no head creation or assignment."""
    engine._eagle_enabled = True
    engine._eagle_capture_set = set(EAGLE_CAPTURE_LAYERS)
    engine._eagle_capture_layers = list(EAGLE_CAPTURE_LAYERS)
    engine._eagle_hidden_states = {}


@torch.no_grad()
def check_forward(engine, tokenizer, label):
    """Run a full forward pass and report NaN status.

    Returns True if any logit is NaN, False otherwise. Prints a one-line
    summary tagged with *label* (top token when healthy).
    """
    torch.cuda.synchronize()
    ids = _reset_for_probe(engine, tokenizer)
    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()  # surface any async CUDA error before inspecting
    has_nan = logits.isnan().any().item()
    last = logits[:, -1, :]
    if has_nan:
        print(f"  [{label}] NaN DETECTED — logits all NaN")
    else:
        top_id = last.argmax(dim=-1).item()
        top_val = last.max().item()
        print(f"  [{label}] OK — top token={top_id} "
              f"('{tokenizer.decode([top_id])}'), max={top_val:.2f}")
    return has_nan


@torch.no_grad()
def check_per_layer(engine, tokenizer, label):
    """Run the forward pass manually layer-by-layer, checking NaN at each step.

    Reports the first transformer layer whose output contains NaN; if all
    layers are clean, also checks the final norm and lm_head projections.
    """
    ids = _reset_for_probe(engine, tokenizer)
    x = engine.embed(ids)
    has_nan = x.isnan().any().item()
    print(f"  [{label}] After embed: has_nan={has_nan}")
    if has_nan:
        return

    first_nan_layer = None
    for i, layer in enumerate(engine.layers):
        x = layer(x, engine.kv_cache, engine._current_seq_id, 0, True)
        if x.isnan().any().item():
            # Loop exits on the first NaN, so no "already found" guard needed.
            first_nan_layer = i
            print(f"  [{label}] FIRST NaN at layer {i} !!!")
            # Check sub-components
            break

    if first_nan_layer is None:
        # All transformer layers clean — check final norm + lm_head.
        x = engine.norm(x)
        has_nan = x.isnan().any().item()
        print(f"  [{label}] After norm: has_nan={has_nan}")
        logits = engine.lm_head(x)
        has_nan = logits.isnan().any().item()
        print(f"  [{label}] After lm_head: has_nan={has_nan}")
        if not has_nan:
            top_id = logits[:, -1, :].argmax(dim=-1).item()
            print(f"  [{label}] Top token: {top_id} ('{tokenizer.decode([top_id])}')")
    else:
        print(f"  [{label}] NaN starts at layer {first_nan_layer}")


if __name__ == "__main__":
    print("=" * 60)
    print("  NaN Isolation Test")
    print("=" * 60)

    print("\n[1/6] Loading model...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.pack_all_experts()
    engine.kv_cache.enable_flat_decode()
    engine.eval()
    # Check VRAM
    vram = torch.cuda.memory_allocated() / 1e9
    print(f"  VRAM after load: {vram:.2f} GB")

    print("\n[2/6] Warmup...")
    warmup_ids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(warmup_ids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)

    print("\n[3/6] Test BEFORE enable_eagle()...")
    nan_before = check_forward(engine, tokenizer, "before eagle")
    if nan_before:
        print("\n  ERROR: NaN even before enable_eagle! Something wrong with model load.")
        sys.exit(1)

    print("\n[4/6] Test: just set _eagle_enabled=True (no head creation)...")
    _arm_eagle(engine)
    nan_flag_only = check_forward(engine, tokenizer, "flag only")
    engine._eagle_enabled = False  # reset

    print("\n[5/6] Test: create eagle head + assign as submodule...")
    eagle_head = FireEchoEagleHead(
        dim=config.dim,
        num_capture_layers=3,
        num_heads=16,
        ffn_mult=2,
        num_layers=2,
    ).to(dtype=torch.bfloat16, device='cuda')
    # Tie the eagle head's output projection to the target model's lm_head.
    eagle_head.lm_head = engine.lm_head
    engine.eagle_head = eagle_head  # registers as nn.Module submodule
    vram2 = torch.cuda.memory_allocated() / 1e9
    print(f"  VRAM after eagle head: {vram2:.2f} GB (+{vram2 - vram:.2f} GB)")
    nan_with_head = check_forward(engine, tokenizer, "with head (no ckpt)")

    print("\n[6/6] Test: load checkpoint into eagle head...")
    if os.path.exists(EAGLE_CKPT):
        ckpt = torch.load(EAGLE_CKPT, map_location='cuda', weights_only=True)
        sd = ckpt.get('eagle_head', ckpt)
        # Legacy checkpoints use flat per-projection keys instead of the
        # current layered layout; detect and convert.
        is_legacy = any(k.startswith('norm1.') or k.startswith('q_proj.') for k in sd)
        if is_legacy:
            eagle_head.load_legacy_checkpoint(sd)
        else:
            eagle_head.load_state_dict(sd, strict=False)
        nan_with_ckpt = check_forward(engine, tokenizer, "with ckpt")
    else:
        print(f"  No checkpoint at {EAGLE_CKPT}, skipping")
        # No checkpoint to test — carry forward the no-ckpt result.
        nan_with_ckpt = nan_with_head

    # Summary
    print(f"\n{'=' * 60}")
    print("  RESULTS")
    print(f"{'=' * 60}")
    print(f"  Before eagle:        {'NaN' if nan_before else 'OK'}")
    print(f"  Flag only:           {'NaN' if nan_flag_only else 'OK'}")
    print(f"  With head (no ckpt): {'NaN' if nan_with_head else 'OK'}")
    print(f"  With checkpoint:     {'NaN' if nan_with_ckpt else 'OK'}")

    # If any NaN found, do per-layer analysis on the simplest failing config.
    if nan_flag_only or nan_with_head or nan_with_ckpt:
        print("\n--- Per-layer NaN analysis ---")
        if nan_flag_only:
            _arm_eagle(engine)
            check_per_layer(engine, tokenizer, "flag-only per-layer")
        elif nan_with_head or nan_with_ckpt:
            # eagle_head is still assigned
            _arm_eagle(engine)
            check_per_layer(engine, tokenizer, "full-eagle per-layer")
            # Also test: head assigned but flag OFF
            print("\n--- Test: head assigned but _eagle_enabled=False ---")
            engine._eagle_enabled = False
            check_forward(engine, tokenizer, "head assigned, flag OFF")
    else:
        print("  All tests passed — no NaN detected!")