| | |
| | """Isolate exactly which step of enable_eagle() causes NaN in target model. |
| | |
| | Tests each sub-step of enable_eagle() independently to find the culprit. |
| | Also checks per-layer output to find where NaN first appears. |
| | """ |
| | import sys, os, torch, gc |
| | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| | from hebbian_finetune_demo import load_engine |
| | from fireecho_kernel import FireEchoEagleHead |
| |
|
# Absolute path to the local Qwen3-Omni-30B-A3B-Instruct weights directory.
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
# Best EAGLE draft-head checkpoint, stored next to this script.
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")

# Fixed chat-format probe prompt used by every forward-pass check below,
# so NaN results are comparable across test stages.
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
| |
|
| |
|
@torch.no_grad()
def check_forward(engine, tokenizer, label):
    """Run one full forward pass over PROMPT and report whether the logits contain NaN.

    Prints a one-line status tagged with *label* and returns True when any
    logit is NaN, False otherwise.
    """
    torch.cuda.synchronize()
    input_ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # Graph-capture mode would conflict with this ad-hoc eager pass — force it off.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False

    logits = engine.forward(input_ids, use_cache=True, position=0)
    torch.cuda.synchronize()

    nan_found = logits.isnan().any().item()
    if nan_found:
        print(f" [{label}] NaN DETECTED — logits all NaN")
    else:
        final_step = logits[:, -1, :]
        top_id = final_step.argmax(dim=-1).item()
        top_val = final_step.max().item()
        print(f" [{label}] OK — top token={top_id} "
              f"('{tokenizer.decode([top_id])}'), max={top_val:.2f}")
    return nan_found
| |
|
| |
|
@torch.no_grad()
def check_per_layer(engine, tokenizer, label):
    """Replay the forward pass manually, one transformer layer at a time,
    printing where NaN first appears (embed / layer i / norm / lm_head)."""
    input_ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False

    # Embedding stage.
    hidden = engine.embed(input_ids)
    has_nan = hidden.isnan().any().item()
    print(f" [{label}] After embed: has_nan={has_nan}")
    if has_nan:
        return

    # Walk each layer; stop at the first one that produces NaN.
    nan_layer = None
    for idx, block in enumerate(engine.layers):
        hidden = block(hidden, engine.kv_cache, engine._current_seq_id, 0, True)
        if hidden.isnan().any().item():
            nan_layer = idx
            print(f" [{label}] FIRST NaN at layer {idx} !!!")
            break

    if nan_layer is not None:
        print(f" [{label}] NaN starts at layer {nan_layer}")
        return

    # All layers clean — check the final norm and projection too.
    hidden = engine.norm(hidden)
    has_nan = hidden.isnan().any().item()
    print(f" [{label}] After norm: has_nan={has_nan}")
    logits = engine.lm_head(hidden)
    has_nan = logits.isnan().any().item()
    print(f" [{label}] After lm_head: has_nan={has_nan}")
    if not has_nan:
        top_id = logits[:, -1, :].argmax(dim=-1).item()
        print(f" [{label}] Top token: {top_id} ('{tokenizer.decode([top_id])}')")
| |
|
| |
|
if __name__ == "__main__":
    print("=" * 60)
    print(" NaN Isolation Test")
    print("=" * 60)

    # [1/6] Load the engine and put it in its normal inference configuration
    # (packed experts + flat-decode KV cache), matching the production path.
    print("\n[1/6] Loading model...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.pack_all_experts()
    engine.kv_cache.enable_flat_decode()
    engine.eval()

    # Baseline VRAM so step [5/6] can report the eagle head's incremental cost.
    vram = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM after load: {vram:.2f} GB")

    # [2/6] A few short generations so lazy init / caching effects don't
    # pollute the NaN measurements that follow.
    print("\n[2/6] Warmup...")
    warmup_ids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(warmup_ids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)

    # [3/6] Control: the target model alone must be NaN-free before any
    # eagle state is touched, otherwise the rest of the test is meaningless.
    print("\n[3/6] Test BEFORE enable_eagle()...")
    nan_before = check_forward(engine, tokenizer, "before eagle")

    if nan_before:
        print("\n ERROR: NaN even before enable_eagle! Something wrong with model load.")
        sys.exit(1)

    # [4/6] First enable_eagle() sub-step in isolation: only set the capture
    # flags/bookkeeping (presumably this makes layers 8/24/47 stash their
    # hidden states — confirm against the engine implementation), with no
    # head module created. Flag is reset afterwards.
    print("\n[4/6] Test: just set _eagle_enabled=True (no head creation)...")
    engine._eagle_enabled = True
    engine._eagle_capture_set = {8, 24, 47}
    engine._eagle_capture_layers = [8, 24, 47]
    engine._eagle_hidden_states = {}
    nan_flag_only = check_forward(engine, tokenizer, "flag only")
    engine._eagle_enabled = False

    # [5/6] Second sub-step: construct the draft head (sharing the target's
    # lm_head weights) and attach it as a submodule — note _eagle_enabled is
    # still False here, so this isolates the pure allocation/attachment effect.
    print("\n[5/6] Test: create eagle head + assign as submodule...")
    eagle_head = FireEchoEagleHead(
        dim=config.dim, num_capture_layers=3,
        num_heads=16, ffn_mult=2, num_layers=2,
    ).to(dtype=torch.bfloat16, device='cuda')
    eagle_head.lm_head = engine.lm_head
    engine.eagle_head = eagle_head
    vram2 = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM after eagle head: {vram2:.2f} GB (+{vram2 - vram:.2f} GB)")
    nan_with_head = check_forward(engine, tokenizer, "with head (no ckpt)")

    # [6/6] Third sub-step: load trained weights into the head. Legacy
    # checkpoints are detected by their old key names (norm1./q_proj.) and
    # routed through the dedicated converter.
    print("\n[6/6] Test: load checkpoint into eagle head...")
    if os.path.exists(EAGLE_CKPT):
        ckpt = torch.load(EAGLE_CKPT, map_location='cuda', weights_only=True)
        # Checkpoint may be a raw state dict or wrapped under 'eagle_head'.
        sd = ckpt.get('eagle_head', ckpt)
        is_legacy = any(k.startswith('norm1.') or k.startswith('q_proj.') for k in sd)
        if is_legacy:
            eagle_head.load_legacy_checkpoint(sd)
        else:
            eagle_head.load_state_dict(sd, strict=False)
        nan_with_ckpt = check_forward(engine, tokenizer, "with ckpt")
    else:
        print(f" No checkpoint at {EAGLE_CKPT}, skipping")
        # No checkpoint loaded, so this stage's result equals the previous one.
        nan_with_ckpt = nan_with_head

    # Summary table of all four probe stages.
    print(f"\n{'=' * 60}")
    print(" RESULTS")
    print(f"{'=' * 60}")
    print(f" Before eagle: {'NaN' if nan_before else 'OK'}")
    print(f" Flag only: {'NaN' if nan_flag_only else 'OK'}")
    print(f" With head (no ckpt): {'NaN' if nan_with_head else 'OK'}")
    print(f" With checkpoint: {'NaN' if nan_with_ckpt else 'OK'}")

    # If any stage produced NaN, re-run the failing configuration layer by
    # layer to locate where the NaN first appears.
    if nan_flag_only or nan_with_head or nan_with_ckpt:
        print(f"\n--- Per-layer NaN analysis ---")
        if nan_flag_only:
            # Capture flags alone were enough to break — analyze that setup.
            engine._eagle_enabled = True
            engine._eagle_capture_set = {8, 24, 47}
            engine._eagle_capture_layers = [8, 24, 47]
            engine._eagle_hidden_states = {}
            check_per_layer(engine, tokenizer, "flag-only per-layer")
        elif nan_with_head or nan_with_ckpt:
            # Head (and possibly checkpoint) is attached; enable capture too
            # so the per-layer walk matches the full-eagle configuration.
            engine._eagle_enabled = True
            engine._eagle_capture_set = {8, 24, 47}
            engine._eagle_capture_layers = [8, 24, 47]
            engine._eagle_hidden_states = {}
            check_per_layer(engine, tokenizer, "full-eagle per-layer")

        # Counter-probe: head still attached but capture disabled — shows
        # whether mere submodule attachment (vs. active capture) is at fault.
        print(f"\n--- Test: head assigned but _eagle_enabled=False ---")
        engine._eagle_enabled = False
        check_forward(engine, tokenizer, "head assigned, flag OFF")
    else:
        print(" All tests passed — no NaN detected!")
|