| | |
| | """Trace speculative_generate step by step to find exactly where NaN appears.""" |
| | import sys, os, torch |
| | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| | from hebbian_finetune_demo import load_engine |
| |
|
| | MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct" |
| | EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt") |
| | PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWrite a function to check primes.<|im_end|>\n<|im_start|>assistant\n" |
| |
|
| |
|
def check_nan(label, tensor):
    """Report whether *tensor* contains non-finite (NaN/Inf) values.

    Prints a one-line summary; for a 3-D tensor (assumed to be
    (batch, seq, vocab) logits — TODO confirm against callers) it also
    pinpoints which sequence positions of batch row 0 hold non-finite
    values.

    Args:
        label: human-readable tag used in the printed output.
        tensor: any torch tensor; the OK-path peek at ``tensor[:, -1, :]``
            assumes at least 3 dims, matching how callers pass logits.

    Returns:
        True if any NaN or Inf is present, False otherwise.
    """
    has_nan = tensor.isnan().any().item()
    has_inf = tensor.isinf().any().item()
    if has_nan or has_inf:
        print(f" *** {label}: NaN={has_nan} Inf={has_inf} shape={list(tensor.shape)}")
        if tensor.dim() == 3:
            # Fix: scan for ANY non-finite value, not just NaN — the old
            # loop silently reported nothing when only Inf was present,
            # even though the summary line above flagged Inf=True.
            bad = ~torch.isfinite(tensor)
            for s in range(tensor.shape[1]):
                if bad[0, s].any():
                    print(f" Position {s}: non-finite!")
        return True
    # Healthy path: show the greedy top token of the last position as a
    # quick sanity signal.
    top = tensor[:, -1, :].argmax(dim=-1).item()
    print(f" {label}: OK (top={top}) shape={list(tensor.shape)}")
    return False
| |
|
| |
|
@torch.no_grad()
def main():
    """Manually replay the speculative-decoding loop step by step,
    checking logits for NaN/Inf after every engine call to localize
    where corruption first appears.

    Stages: [1] prefill, [2] single-token decode, [3] EAGLE draft of
    K=5 tokens, [4] batched verify of the draft (with a [4b] one-token
    fallback probe on failure), [5] a control multi-token forward with
    known-good tokens. All engine/tokenizer APIs are project-local;
    behavior notes below are from this call site only.
    """
    print("=" * 60)
    print(" Speculative Generate NaN Trace")
    print("=" * 60)

    # --- Engine setup -------------------------------------------------
    print("\n[SETUP] Loading engine...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()

    # Enable the EAGLE draft head from its trained checkpoint.
    # capture_layers/num_heads/etc. mirror the training config —
    # presumably must match eagle_best.pt; verify against training script.
    engine.enable_eagle(
        capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
        draft_depth=5, num_head_layers=8, checkpoint_path=EAGLE_CKPT)

    # Warm up kernels/caches so later timings and cache state are stable.
    print("\n[SETUP] Warmup...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
    del wids

    # --- Trace begins -------------------------------------------------
    print("\n[TRACE] Starting manual speculation trace...")
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    prompt_len = ids.shape[1]
    print(f" Prompt length: {prompt_len}")

    # Fresh cache; force eager (non-CUDA-graph) decode if the cache
    # supports a graph mode toggle, so shapes can vary per call.
    engine.reset_cache()
    engine._current_seq_id = 0
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False

    # [1] Prefill the whole prompt; abort early if corruption is already here.
    print("\n[1] Prefill...")
    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    nan1 = check_nan("Prefill logits", logits)
    if nan1:
        print(" FATAL: NaN in prefill!")
        return

    current_pos = prompt_len
    # Greedy first token; kept as shape (1, 1) so it can be fed back in.
    first_token = logits[:, -1:, :].argmax(dim=-1)
    print(f" First token: {first_token.item()} ('{tokenizer.decode([first_token.item()])}')")

    # [2] One-token decode through the main model (also populates the
    # EAGLE hidden-state captures used for drafting below).
    print("\n[2] Process first token through main model...")
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(first_token, use_cache=True, position=current_pos)
    torch.cuda.synchronize()
    nan2 = check_nan("First-token logits", logits)
    if nan2:
        print(" FATAL: NaN at first token forward!")
        return
    current_pos += 1
    # main_pred: what the main model would emit next — the reference for
    # draft acceptance in [4].
    main_pred = logits[:, -1, :].argmax(dim=-1).item()
    print(f" main_pred: {main_pred} ('{tokenizer.decode([main_pred])}')")

    # [3] Draft K=5 tokens with the EAGLE head from the captured features.
    print("\n[3] Draft K=5 tokens...")
    features = [engine._eagle_hidden_states[l] for l in engine._eagle_capture_layers]
    for idx, f in enumerate(features):
        has_nan = f.isnan().any().item()
        print(f" Feature {idx} (layer {engine._eagle_capture_layers[idx]}): "
        f"shape={list(f.shape)}, NaN={has_nan}")

    # Memory context is derived from the deepest capture layer.
    memory_ctx = engine._get_eagle_memory_context(
    engine._eagle_hidden_states[engine._eagle_capture_layers[-1]])

    draft_tokens, draft_logits = engine.eagle_head.generate_draft(
    features, first_token, engine.embed, depth=5, memory_context=memory_ctx)

    print(f" Draft tokens: {[t.item() for t in draft_tokens]}")
    print(f" Draft decoded: {[tokenizer.decode([t.item()]) for t in draft_tokens]}")
    for i, dl in enumerate(draft_logits):
        has_nan = dl.isnan().any().item()
        if has_nan:
            print(f" *** Draft logits[{i}]: NaN!")

    # [4] Batched verification: feed all K drafted tokens in one forward.
    print("\n[4] Verify K=5 draft tokens through main model...")
    print(f" Verifying at position={current_pos} (prompt_len={prompt_len})")
    draft_input = torch.cat(draft_tokens, dim=1)
    print(f" draft_input shape: {list(draft_input.shape)}, tokens: {draft_input[0].tolist()}")

    verify_logits = engine.forward(draft_input, use_cache=True, position=current_pos)
    torch.cuda.synchronize()
    nan4 = check_nan("Verify logits", verify_logits)

    if nan4:
        print("\n FOUND THE BUG: Verify forward (K>1 tokens at position>0) produces NaN!")
        print(" This is likely a causal mask or KV cache issue in multi-token decode.")

        # [4b] Fallback probe: redo verification one token at a time to
        # separate "multi-token decode broken" from "position X broken".
        print("\n[4b] Trying verify ONE token at a time...")
        # Drop the K cache entries written by the failed batched verify
        # (rollback_to semantics assumed: (position, count) — TODO confirm).
        engine.kv_cache.rollback_to(current_pos, 5)

        for i, dt in enumerate(draft_tokens):
            one_logit = engine.forward(dt, use_cache=True, position=current_pos + i)
            torch.cuda.synchronize()
            has_nan = one_logit.isnan().any().item()
            top = one_logit[:, -1, :].argmax(dim=-1).item() if not has_nan else -1
            print(f" Token {i} at pos {current_pos + i}: NaN={has_nan} top={top}")
            if has_nan:
                print(f" SINGLE token verify also fails at position {current_pos + i}!")
                break
    else:
        # Logits are clean — check how many drafted tokens the main model
        # would actually accept.
        print("\n Verify logits OK — checking acceptance logic...")
        if draft_tokens[0].item() == main_pred:
            print(f" Draft[0] matches main_pred ({main_pred}) ✓")
        else:
            print(f" Draft[0]={draft_tokens[0].item()} ≠ main_pred={main_pred} ✗")

        # verify_logits[:, i-1] is the model's prediction after consuming
        # draft token i-1, i.e. its candidate for draft token i.
        for i in range(1, len(draft_tokens)):
            target_pred = verify_logits[:, i-1, :].argmax(dim=-1).item()
            match = "✓" if draft_tokens[i].item() == target_pred else "✗"
            print(f" verify[{i-1}]={target_pred} vs draft[{i}]={draft_tokens[i].item()} {match}")

    # [5] Control: same multi-token forward shape but with a known-good
    # (repeated first_token) input, on a freshly prefilled cache —
    # isolates whether K>1 decode itself is broken vs. the drafted tokens.
    print("\n[5] Control test: multi-token forward with KNOWN-GOOD tokens...")
    engine.reset_cache()
    engine._current_seq_id = 0

    logits = engine.forward(ids, use_cache=True, position=0)

    test_tokens = torch.full((1, 5), first_token.item(), dtype=torch.long, device='cuda')
    test_logits = engine.forward(test_tokens, use_cache=True, position=prompt_len)
    torch.cuda.synchronize()
    nan5 = check_nan("Control multi-token logits", test_logits)

    print("\n" + "=" * 60)
    print(" TRACE COMPLETE")
    print("=" * 60)
| |
|
| |
|
# Script entry point: run the trace only when executed directly.
if __name__ == "__main__":
    main()
| |
|