#!/usr/bin/env python3
"""Trace speculative_generate step by step to find exactly where NaN appears."""
import sys, os, torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWrite a function to check primes.<|im_end|>\n<|im_start|>assistant\n"
def check_nan(label, tensor):
has_nan = tensor.isnan().any().item()
has_inf = tensor.isinf().any().item()
if has_nan or has_inf:
print(f" *** {label}: NaN={has_nan} Inf={has_inf} shape={list(tensor.shape)}")
# Check which positions have NaN
if tensor.dim() == 3: # [B, S, V]
for s in range(tensor.shape[1]):
if tensor[0, s].isnan().any():
print(f" Position {s}: NaN!")
return True
else:
top = tensor[:, -1, :].argmax(dim=-1).item()
print(f" {label}: OK (top={top}) shape={list(tensor.shape)}")
return False
@torch.no_grad()
def main():
    """Replay the speculative-decoding pipeline one stage at a time,
    running check_nan() on the logits after every engine.forward() to
    isolate which stage first produces NaN/Inf.

    Stages: prefill -> single-token decode -> EAGLE draft (K=5) ->
    multi-token verify -> control multi-token forward with known-good
    tokens. Requires CUDA and the module-level MODEL_PATH / EAGLE_CKPT
    paths; produces console output only.
    """
    print("=" * 60)
    print(" Speculative Generate NaN Trace")
    print("=" * 60)
    # Load engine exactly like training
    print("\n[SETUP] Loading engine...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()
    # Enable EAGLE D=8
    engine.enable_eagle(
        capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
        draft_depth=5, num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    # Warmup
    print("\n[SETUP] Warmup...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
    del wids
    # Now replicate speculative_generate manually
    print("\n[TRACE] Starting manual speculation trace...")
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    prompt_len = ids.shape[1]
    print(f" Prompt length: {prompt_len}")
    # Step 1: Reset + prefill
    engine.reset_cache()
    engine._current_seq_id = 0
    # NOTE(review): _graph_mode is presumably a CUDA-graph capture flag; it
    # is forced off so each forward runs eagerly — confirm against engine.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    print("\n[1] Prefill...")
    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    nan1 = check_nan("Prefill logits", logits)
    if nan1:
        print(" FATAL: NaN in prefill!")
        return
    current_pos = prompt_len
    # Greedy-pick the first generated token from the last prompt position.
    first_token = logits[:, -1:, :].argmax(dim=-1)
    print(f" First token: {first_token.item()} ('{tokenizer.decode([first_token.item()])}')")
    # Step 2: Process first token through main model
    print("\n[2] Process first token through main model...")
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(first_token, use_cache=True, position=current_pos)
    torch.cuda.synchronize()
    nan2 = check_nan("First-token logits", logits)
    if nan2:
        print(" FATAL: NaN at first token forward!")
        return
    current_pos += 1
    main_pred = logits[:, -1, :].argmax(dim=-1).item()
    print(f" main_pred: {main_pred} ('{tokenizer.decode([main_pred])}')")
    # Step 3: Draft K tokens using EAGLE
    print("\n[3] Draft K=5 tokens...")
    # Hidden states captured during the step-2 forward feed the draft head.
    features = [engine._eagle_hidden_states[l] for l in engine._eagle_capture_layers]
    for idx, f in enumerate(features):
        has_nan = f.isnan().any().item()
        print(f" Feature {idx} (layer {engine._eagle_capture_layers[idx]}): "
              f"shape={list(f.shape)}, NaN={has_nan}")
    memory_ctx = engine._get_eagle_memory_context(
        engine._eagle_hidden_states[engine._eagle_capture_layers[-1]])
    draft_tokens, draft_logits = engine.eagle_head.generate_draft(
        features, first_token, engine.embed, depth=5, memory_context=memory_ctx)
    print(f" Draft tokens: {[t.item() for t in draft_tokens]}")
    print(f" Draft decoded: {[tokenizer.decode([t.item()]) for t in draft_tokens]}")
    for i, dl in enumerate(draft_logits):
        has_nan = dl.isnan().any().item()
        if has_nan:
            print(f" *** Draft logits[{i}]: NaN!")
    # Step 4: Verify draft tokens through main model (this is the suspicious step)
    print("\n[4] Verify K=5 draft tokens through main model...")
    print(f" Verifying at position={current_pos} (prompt_len={prompt_len})")
    draft_input = torch.cat(draft_tokens, dim=1)
    print(f" draft_input shape: {list(draft_input.shape)}, tokens: {draft_input[0].tolist()}")
    verify_logits = engine.forward(draft_input, use_cache=True, position=current_pos)
    torch.cuda.synchronize()
    nan4 = check_nan("Verify logits", verify_logits)
    if nan4:
        print("\n FOUND THE BUG: Verify forward (K>1 tokens at position>0) produces NaN!")
        print(" This is likely a causal mask or KV cache issue in multi-token decode.")
        # Additional test: verify ONE draft token at a time
        print("\n[4b] Trying verify ONE token at a time...")
        # Rollback the K tokens we just stored
        engine.kv_cache.rollback_to(current_pos, 5)
        for i, dt in enumerate(draft_tokens):
            one_logit = engine.forward(dt, use_cache=True, position=current_pos + i)
            torch.cuda.synchronize()
            has_nan = one_logit.isnan().any().item()
            top = one_logit[:, -1, :].argmax(dim=-1).item() if not has_nan else -1
            print(f" Token {i} at pos {current_pos + i}: NaN={has_nan} top={top}")
            if has_nan:
                print(f" SINGLE token verify also fails at position {current_pos + i}!")
                break
    else:
        print("\n Verify logits OK — checking acceptance logic...")
        if draft_tokens[0].item() == main_pred:
            print(f" Draft[0] matches main_pred ({main_pred}) ✓")
        else:
            print(f" Draft[0]={draft_tokens[0].item()} ≠ main_pred={main_pred} ✗")
        # Compare draft token i against the greedy prediction taken from
        # verify position i-1 (inputs shifted one step from targets).
        for i in range(1, len(draft_tokens)):
            target_pred = verify_logits[:, i-1, :].argmax(dim=-1).item()
            match = "✓" if draft_tokens[i].item() == target_pred else "✗"
            print(f" verify[{i-1}]={target_pred} vs draft[{i}]={draft_tokens[i].item()} {match}")
    # Step 5: Also test a multi-token forward with RANDOM tokens at position>0
    print("\n[5] Control test: multi-token forward with KNOWN-GOOD tokens...")
    engine.reset_cache()
    engine._current_seq_id = 0
    # Prefill
    logits = engine.forward(ids, use_cache=True, position=0)
    # Now try 5 copies of a valid token at position=prompt_len
    test_tokens = torch.full((1, 5), first_token.item(), dtype=torch.long, device='cuda')
    test_logits = engine.forward(test_tokens, use_cache=True, position=prompt_len)
    torch.cuda.synchronize()
    nan5 = check_nan("Control multi-token logits", test_logits)
    print("\n" + "=" * 60)
    print(" TRACE COMPLETE")
    print("=" * 60)
# Entry-point guard: run the trace only when executed as a script.
if __name__ == "__main__":
    main()