#!/usr/bin/env python3
"""Trace speculative_generate step by step to find exactly where NaN appears."""
import os
import sys

import torch

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine

MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWrite a function to check primes.<|im_end|>\n<|im_start|>assistant\n"

# Number of tokens the EAGLE head drafts per speculation round.  Used for
# draft_depth at enable time, the draft call, KV rollback, and the control
# test so all five former hard-coded "5"s stay in sync.
DRAFT_DEPTH = 5


def check_nan(label, tensor):
    """Report whether *tensor* contains NaN/Inf and print a short summary.

    Args:
        label: Human-readable tag printed with the result.
        tensor: Logits tensor to inspect. Typically [B, S, V]; any rank
            with a vocab-like last dimension is accepted.

    Returns:
        True if the tensor contains any NaN or Inf, else False.
    """
    has_nan = tensor.isnan().any().item()
    has_inf = tensor.isinf().any().item()
    if has_nan or has_inf:
        print(f"  *** {label}: NaN={has_nan} Inf={has_inf} shape={list(tensor.shape)}")
        # Narrow down which sequence positions are poisoned.
        if tensor.dim() == 3:  # [B, S, V]
            for s in range(tensor.shape[1]):
                if tensor[0, s].isnan().any():
                    print(f"      Position {s}: NaN!")
        return True
    # Healthy tensor: show the argmax of the final position as a sanity token.
    # Flattening all leading dims generalizes the original `tensor[:, -1, :]`
    # (which required dim >= 3) to any rank; identical result for [1, S, V].
    if tensor.dim() >= 2:
        top = tensor.reshape(-1, tensor.shape[-1])[-1].argmax().item()
    else:
        top = tensor.argmax().item()
    print(f"  {label}: OK (top={top}) shape={list(tensor.shape)}")
    return False


@torch.no_grad()
def main():
    """Manually replay one speculative-decoding round, checking for NaN at
    every stage: prefill, first-token forward, EAGLE draft, multi-token
    verify, and a known-good control forward."""
    print("=" * 60)
    print(" Speculative Generate NaN Trace")
    print("=" * 60)

    # Load engine exactly like training.
    print("\n[SETUP] Loading engine...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()

    # Enable EAGLE D=8.
    engine.enable_eagle(
        capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
        draft_depth=DRAFT_DEPTH, num_head_layers=8, checkpoint_path=EAGLE_CKPT)

    # Warmup so later timings/state reflect steady-state kernels.
    print("\n[SETUP] Warmup...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
    del wids

    # Now replicate speculative_generate manually.
    print("\n[TRACE] Starting manual speculation trace...")
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    prompt_len = ids.shape[1]
    print(f"  Prompt length: {prompt_len}")

    # Step 1: Reset + prefill.
    engine.reset_cache()
    engine._current_seq_id = 0
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    print("\n[1] Prefill...")
    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    nan1 = check_nan("Prefill logits", logits)
    if nan1:
        print("  FATAL: NaN in prefill!")
        return
    current_pos = prompt_len
    first_token = logits[:, -1:, :].argmax(dim=-1)
    print(f"  First token: {first_token.item()} ('{tokenizer.decode([first_token.item()])}')")

    # Step 2: Process first token through main model.
    print("\n[2] Process first token through main model...")
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(first_token, use_cache=True, position=current_pos)
    torch.cuda.synchronize()
    nan2 = check_nan("First-token logits", logits)
    if nan2:
        print("  FATAL: NaN at first token forward!")
        return
    current_pos += 1
    main_pred = logits[:, -1, :].argmax(dim=-1).item()
    print(f"  main_pred: {main_pred} ('{tokenizer.decode([main_pred])}')")

    # Step 3: Draft K tokens using EAGLE.
    print(f"\n[3] Draft K={DRAFT_DEPTH} tokens...")
    features = [engine._eagle_hidden_states[l] for l in engine._eagle_capture_layers]
    for idx, f in enumerate(features):
        has_nan = f.isnan().any().item()
        print(f"  Feature {idx} (layer {engine._eagle_capture_layers[idx]}): "
              f"shape={list(f.shape)}, NaN={has_nan}")
    memory_ctx = engine._get_eagle_memory_context(
        engine._eagle_hidden_states[engine._eagle_capture_layers[-1]])
    draft_tokens, draft_logits = engine.eagle_head.generate_draft(
        features, first_token, engine.embed, depth=DRAFT_DEPTH, memory_context=memory_ctx)
    print(f"  Draft tokens: {[t.item() for t in draft_tokens]}")
    print(f"  Draft decoded: {[tokenizer.decode([t.item()]) for t in draft_tokens]}")
    for i, dl in enumerate(draft_logits):
        has_nan = dl.isnan().any().item()
        if has_nan:
            print(f"  *** Draft logits[{i}]: NaN!")

    # Step 4: Verify draft tokens through main model (this is the suspicious step).
    print(f"\n[4] Verify K={DRAFT_DEPTH} draft tokens through main model...")
    print(f"  Verifying at position={current_pos} (prompt_len={prompt_len})")
    draft_input = torch.cat(draft_tokens, dim=1)
    print(f"  draft_input shape: {list(draft_input.shape)}, tokens: {draft_input[0].tolist()}")
    verify_logits = engine.forward(draft_input, use_cache=True, position=current_pos)
    torch.cuda.synchronize()
    nan4 = check_nan("Verify logits", verify_logits)
    if nan4:
        print("\n  FOUND THE BUG: Verify forward (K>1 tokens at position>0) produces NaN!")
        print("  This is likely a causal mask or KV cache issue in multi-token decode.")
        # Additional test: verify ONE draft token at a time.
        print("\n[4b] Trying verify ONE token at a time...")
        # Rollback the K tokens we just stored.
        engine.kv_cache.rollback_to(current_pos, DRAFT_DEPTH)
        for i, dt in enumerate(draft_tokens):
            one_logit = engine.forward(dt, use_cache=True, position=current_pos + i)
            torch.cuda.synchronize()
            has_nan = one_logit.isnan().any().item()
            top = one_logit[:, -1, :].argmax(dim=-1).item() if not has_nan else -1
            print(f"  Token {i} at pos {current_pos + i}: NaN={has_nan} top={top}")
            if has_nan:
                print(f"  SINGLE token verify also fails at position {current_pos + i}!")
                break
    else:
        print("\n  Verify logits OK — checking acceptance logic...")
        if draft_tokens[0].item() == main_pred:
            print(f"  Draft[0] matches main_pred ({main_pred}) ✓")
        else:
            print(f"  Draft[0]={draft_tokens[0].item()} ≠ main_pred={main_pred} ✗")
        for i in range(1, len(draft_tokens)):
            # verify_logits position i-1 predicts the token that should follow
            # draft_tokens[i-1]; compare it against the drafted token i.
            target_pred = verify_logits[:, i - 1, :].argmax(dim=-1).item()
            match = "✓" if draft_tokens[i].item() == target_pred else "✗"
            print(f"  verify[{i-1}]={target_pred} vs draft[{i}]={draft_tokens[i].item()} {match}")

    # Step 5: Also test a multi-token forward with RANDOM tokens at position>0.
    print("\n[5] Control test: multi-token forward with KNOWN-GOOD tokens...")
    engine.reset_cache()
    engine._current_seq_id = 0
    # Prefill.
    logits = engine.forward(ids, use_cache=True, position=0)
    # Now try K copies of a valid token at position=prompt_len.
    test_tokens = torch.full((1, DRAFT_DEPTH), first_token.item(), dtype=torch.long, device='cuda')
    test_logits = engine.forward(test_tokens, use_cache=True, position=prompt_len)
    torch.cuda.synchronize()
    check_nan("Control multi-token logits", test_logits)

    print("\n" + "=" * 60)
    print(" TRACE COMPLETE")
    print("=" * 60)


if __name__ == "__main__":
    main()