# FireEcho / FireEcho Engine — debug_specgen_trace.py
#!/usr/bin/env python3
"""Trace speculative_generate step by step to find exactly where NaN appears."""
import sys, os, torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
# Absolute path to the Qwen3-Omni base model weights (machine-specific; adjust per host).
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
# EAGLE draft-head checkpoint stored next to this script.
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
# ChatML-formatted prompt used for the trace run.
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWrite a function to check primes.<|im_end|>\n<|im_start|>assistant\n"
def check_nan(label, tensor):
    """Print a NaN/Inf report for `tensor` and return True if any were found.

    On the clean path the top token of the last position is printed, which
    assumes a [B, S, V] logits tensor; on the dirty path a per-position NaN
    scan is emitted for 3-D tensors.
    """
    nan_found = tensor.isnan().any().item()
    inf_found = tensor.isinf().any().item()
    if not (nan_found or inf_found):
        # Clean: report the argmax of the final sequence position.
        top = tensor[:, -1, :].argmax(dim=-1).item()
        print(f" {label}: OK (top={top}) shape={list(tensor.shape)}")
        return False
    print(f" *** {label}: NaN={nan_found} Inf={inf_found} shape={list(tensor.shape)}")
    # For [B, S, V] logits, pinpoint which sequence positions went bad.
    if tensor.dim() == 3:
        for pos in range(tensor.shape[1]):
            if tensor[0, pos].isnan().any():
                print(f" Position {pos}: NaN!")
    return True
@torch.no_grad()
def main():
    """Replay speculative_generate one stage at a time, checking logits for
    NaN/Inf after every forward pass to isolate where corruption first
    appears: prefill, single-token decode, EAGLE drafting, multi-token
    verification, or a multi-token control run with known-good tokens.
    """
    print("=" * 60)
    print(" Speculative Generate NaN Trace")
    print("=" * 60)
    # Load engine exactly like training so the trace matches the failing setup.
    print("\n[SETUP] Loading engine...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()
    # Enable EAGLE D=8 speculative head with the trained checkpoint.
    # NOTE(review): capture_layers/num_heads here must match how the
    # checkpoint was trained — confirm against the training config.
    engine.enable_eagle(
        capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
        draft_depth=5, num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    # Warmup: a few greedy generations, presumably to trigger lazy init /
    # CUDA graph capture before the measured trace (verify in engine.generate).
    print("\n[SETUP] Warmup...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
    del wids
    # Now replicate speculative_generate manually, stage by stage.
    print("\n[TRACE] Starting manual speculation trace...")
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    prompt_len = ids.shape[1]
    print(f" Prompt length: {prompt_len}")
    # Step 1: Reset + prefill. _graph_mode is forced off so the forward runs
    # eagerly rather than through a captured CUDA graph (if the cache has one).
    engine.reset_cache()
    engine._current_seq_id = 0
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    print("\n[1] Prefill...")
    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    nan1 = check_nan("Prefill logits", logits)
    if nan1:
        print(" FATAL: NaN in prefill!")
        return
    current_pos = prompt_len
    # Greedy pick of the first generated token; kept as a [1, 1] tensor.
    first_token = logits[:, -1:, :].argmax(dim=-1)
    print(f" First token: {first_token.item()} ('{tokenizer.decode([first_token.item()])}')")
    # Step 2: Process first token through main model (single-token decode).
    print("\n[2] Process first token through main model...")
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(first_token, use_cache=True, position=current_pos)
    torch.cuda.synchronize()
    nan2 = check_nan("First-token logits", logits)
    if nan2:
        print(" FATAL: NaN at first token forward!")
        return
    current_pos += 1
    # main_pred is the main model's greedy continuation; used below to judge
    # the draft head's first prediction.
    main_pred = logits[:, -1, :].argmax(dim=-1).item()
    print(f" main_pred: {main_pred} ('{tokenizer.decode([main_pred])}')")
    # Step 3: Draft K tokens using EAGLE from the captured hidden states of
    # the forward in step 2. Also scan the captured features for NaN.
    print("\n[3] Draft K=5 tokens...")
    features = [engine._eagle_hidden_states[l] for l in engine._eagle_capture_layers]
    for idx, f in enumerate(features):
        has_nan = f.isnan().any().item()
        print(f" Feature {idx} (layer {engine._eagle_capture_layers[idx]}): "
              f"shape={list(f.shape)}, NaN={has_nan}")
    # Memory context comes from the deepest capture layer — TODO confirm
    # that matches what speculative_generate itself passes.
    memory_ctx = engine._get_eagle_memory_context(
        engine._eagle_hidden_states[engine._eagle_capture_layers[-1]])
    draft_tokens, draft_logits = engine.eagle_head.generate_draft(
        features, first_token, engine.embed, depth=5, memory_context=memory_ctx)
    print(f" Draft tokens: {[t.item() for t in draft_tokens]}")
    print(f" Draft decoded: {[tokenizer.decode([t.item()]) for t in draft_tokens]}")
    for i, dl in enumerate(draft_logits):
        has_nan = dl.isnan().any().item()
        if has_nan:
            print(f" *** Draft logits[{i}]: NaN!")
    # Step 4: Verify draft tokens through main model (this is the suspicious
    # step: a K>1 forward at position>0, i.e. multi-token decode).
    print("\n[4] Verify K=5 draft tokens through main model...")
    print(f" Verifying at position={current_pos} (prompt_len={prompt_len})")
    draft_input = torch.cat(draft_tokens, dim=1)
    print(f" draft_input shape: {list(draft_input.shape)}, tokens: {draft_input[0].tolist()}")
    verify_logits = engine.forward(draft_input, use_cache=True, position=current_pos)
    torch.cuda.synchronize()
    nan4 = check_nan("Verify logits", verify_logits)
    if nan4:
        print("\n FOUND THE BUG: Verify forward (K>1 tokens at position>0) produces NaN!")
        print(" This is likely a causal mask or KV cache issue in multi-token decode.")
        # Additional test: verify ONE draft token at a time to see whether the
        # failure is specific to multi-token decode.
        print("\n[4b] Trying verify ONE token at a time...")
        # Rollback the K tokens we just stored so positions can be re-written.
        # NOTE(review): rollback_to(current_pos, 5) argument order assumed to
        # be (position, count) — verify against the KV-cache implementation.
        engine.kv_cache.rollback_to(current_pos, 5)
        for i, dt in enumerate(draft_tokens):
            one_logit = engine.forward(dt, use_cache=True, position=current_pos + i)
            torch.cuda.synchronize()
            has_nan = one_logit.isnan().any().item()
            top = one_logit[:, -1, :].argmax(dim=-1).item() if not has_nan else -1
            print(f" Token {i} at pos {current_pos + i}: NaN={has_nan} top={top}")
            if has_nan:
                print(f" SINGLE token verify also fails at position {current_pos + i}!")
                break
    else:
        # Verify produced finite logits: report how many draft tokens the
        # main model would actually have accepted.
        print("\n Verify logits OK — checking acceptance logic...")
        if draft_tokens[0].item() == main_pred:
            print(f" Draft[0] matches main_pred ({main_pred}) ✓")
        else:
            print(f" Draft[0]={draft_tokens[0].item()} ≠ main_pred={main_pred} ✗")
        for i in range(1, len(draft_tokens)):
            # verify_logits[:, i-1] is the main model's prediction for the
            # token AFTER draft i-1, which should equal draft i on acceptance.
            target_pred = verify_logits[:, i-1, :].argmax(dim=-1).item()
            match = "✓" if draft_tokens[i].item() == target_pred else "✗"
            print(f" verify[{i-1}]={target_pred} vs draft[{i}]={draft_tokens[i].item()} {match}")
    # Step 5: Also test a multi-token forward with RANDOM tokens at position>0:
    # if this control also NaNs, the bug is in multi-token decode itself,
    # not in the EAGLE-drafted token values.
    print("\n[5] Control test: multi-token forward with KNOWN-GOOD tokens...")
    engine.reset_cache()
    engine._current_seq_id = 0
    # Prefill
    logits = engine.forward(ids, use_cache=True, position=0)
    # Now try 5 copies of a valid token at position=prompt_len
    test_tokens = torch.full((1, 5), first_token.item(), dtype=torch.long, device='cuda')
    test_logits = engine.forward(test_tokens, use_cache=True, position=prompt_len)
    torch.cuda.synchronize()
    nan5 = check_nan("Control multi-token logits", test_logits)
    print("\n" + "=" * 60)
    print(" TRACE COMPLETE")
    print("=" * 60)
# Script entry point: run the full trace when executed directly.
if __name__ == "__main__":
    main()