FireEcho / FireEcho Engine /debug_acceptance.py
Joysulem's picture
Upload 3258 files
b5bff9c verified
#!/usr/bin/env python3
"""Debug: Why does D=8 eagle head show 100% acceptance?
Compare draft tokens vs target predictions for D=2 and D=8.
ROOT CAUSE FOUND: Missing torch.no_grad() caused NaN logits (Goliath FP4
Triton kernels don't support autograd). argmax(NaN)=0 for both draft and
target → fake 100% acceptance. This version fixes that.
"""
import sys, os, torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
# Location of the target model weights. The default is machine-specific, so it
# can be overridden with the FIREECHO_MODEL_PATH environment variable instead of
# editing this file.
MODEL_PATH = os.environ.get(
    "FIREECHO_MODEL_PATH",
    "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct")
# EAGLE head checkpoint stored next to this script. abspath() (as used for the
# sys.path entry above) keeps this independent of the current working directory;
# a cwd-relative path would make the os.path.exists() check in test_acceptance
# silently fall back to an untrained head.
EAGLE_CKPT = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                          "eagle_checkpoints", "eagle_best.pt")
def _nan_report(t, digits=2):
    """Return a ``has_nan=..., min=..., max=...`` summary string for tensor ``t``.

    ``digits`` is the number of decimals used for min/max.
    """
    return (f"has_nan={t.isnan().any().item()}, "
            f"min={t.min().item():.{digits}f}, max={t.max().item():.{digits}f}")


@torch.no_grad()
def test_acceptance(engine, tokenizer, num_layers, label, depth=5):
    """Enable an EAGLE head with ``num_layers`` layers and run one draft+verify round.

    Prefills a fixed chat prompt, greedily decodes one token, drafts ``depth``
    tokens with the EAGLE head, verifies them with one target forward pass and
    counts how many leading draft tokens match the target's argmax. NaN
    diagnostics are printed for every tensor whose values feed those argmaxes,
    because NaN logits previously produced fake 100% acceptance. A short
    ``speculative_generate`` run follows for comparison. The EAGLE head is
    removed again before returning, also when an error is raised.

    Args:
        engine: engine object as returned by ``load_engine``.
        tokenizer: tokenizer as returned by ``load_engine``.
        num_layers: number of EAGLE head layers (the "D" under test).
        label: free-form description printed in the section header.
        depth: number of tokens to draft (default 5).

    Returns:
        int: number of accepted leading draft tokens (0..len(draft_tokens)).

    Raises:
        ValueError: if ``depth`` < 1 (nothing to draft or verify).
    """
    if depth < 1:
        raise ValueError(f"depth must be >= 1, got {depth}")

    print(f"\n{'='*60}")
    print(f" Testing D={num_layers} ({label})")
    print(f"{'='*60}")

    # Enable eagle; load trained weights only if the checkpoint file exists.
    engine.enable_eagle(
        capture_layers=(8, 24, 47),
        num_head_layers=num_layers,
        checkpoint_path=EAGLE_CKPT if os.path.exists(EAGLE_CKPT) else None)
    engine.eval()

    try:
        prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWrite a Python function to check if a number is prime.<|im_end|>\n<|im_start|>assistant\n"
        ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
        prompt_len = ids.shape[1]

        # Prefill from an empty cache, with KV-cache graph mode switched off.
        engine.reset_cache()
        engine._current_seq_id = 0
        if hasattr(engine.kv_cache, '_graph_mode'):
            engine.kv_cache._graph_mode = False
        logits = engine.forward(ids, use_cache=True, position=0)
        current_pos = prompt_len
        # has_nan covers every prompt position; min/max cover the last one only.
        has_nan = logits.isnan().any().item()
        print(f" Target prefill logits: has_nan={has_nan}, "
              f"min={logits[:,-1,:].min().item():.2f}, max={logits[:,-1,:].max().item():.2f}")

        # Decode first token greedily.
        next_token = logits[:, -1:, :].argmax(dim=-1)
        print(f" First decoded token: {next_token.item()} = '{tokenizer.decode([next_token.item()])}'")

        # Forward it (stores KV, captures hidden states).
        logits = engine.forward(next_token, use_cache=True, position=current_pos)
        current_pos += 1
        # These logits decide main_pred, which draft[0] is judged against.
        print(f" Target decode logits: {_nan_report(logits[:, -1, :])}")
        main_pred = logits[:, -1, :].argmax(dim=-1).item()
        print(f" Target predicts next: {main_pred} = '{tokenizer.decode([main_pred])}'")

        # Hidden states captured at the EAGLE capture layers by the forward above.
        features = [engine._eagle_hidden_states[layer]
                    for layer in engine._eagle_capture_layers]
        for li, f in zip(engine._eagle_capture_layers, features):
            print(f" Feature layer {li}: {_nan_report(f, digits=4)}")

        # Memory context comes from the last capture layer's hidden state.
        memory_ctx = engine._get_eagle_memory_context(features[-1])
        draft_tokens, draft_logits = engine.eagle_head.generate_draft(
            features, next_token, engine.embed, depth=depth,
            memory_context=memory_ctx)
        print(" Draft tokens:")
        for i, dt in enumerate(draft_tokens):
            tok_id = dt.item()
            print(f" [{i}] {tok_id} = '{tokenizer.decode([tok_id])}'")
        # NOTE(review): assumes draft_logits[0] is (1, 1, vocab) — confirm.
        dl0 = draft_logits[0][0, 0, :]
        print(f" Draft logits[0]: {_nan_report(dl0)}")

        # Verify: forward all draft tokens through the target in one pass.
        draft_input = torch.cat(draft_tokens, dim=1)
        verify_logits = engine.forward(draft_input, use_cache=True, position=current_pos)
        print(f" Target verify logits: {_nan_report(verify_logits)}")
        if logits.isnan().any().item() or verify_logits.isnan().any().item():
            # argmax over NaN logits does not reflect real predictions, so the
            # MATCH/MISS results below cannot be trusted.
            print(" WARNING: NaN in target logits — acceptance below is not meaningful")

        print(" Target verify predictions:")
        accepted = 0
        for i, dt in enumerate(draft_tokens):
            draft_id = dt.item()
            if i == 0:
                # draft[0] is judged against the decode-step prediction.
                target_pred = main_pred
            else:
                # verify_logits[:, i-1] is the target's prediction after it has
                # consumed draft tokens 0..i-1.
                target_pred = verify_logits[:, i - 1, :].argmax(dim=-1).item()
            hit = draft_id == target_pred
            print(f" [{i}] target={target_pred} ('{tokenizer.decode([target_pred])}'), "
                  f"draft={draft_id} ('{tokenizer.decode([draft_id])}') → {'MATCH' if hit else 'MISS'}")
            if not hit:
                break  # greedy acceptance stops at the first mismatch
            accepted += 1
        print(f" Accepted: {accepted}/{len(draft_tokens)}")

        # Also run full speculative_generate to match training eval.
        print(f"\n --- Full speculative_generate (max_new=30) ---")
        engine.reset_cache()
        # NOTE(review): stop token ids presumably EOS / end-of-turn — confirm.
        out = engine.speculative_generate(
            ids, max_new_tokens=30, temperature=0.0,
            stop_tokens=[199999, 200020])
        text = tokenizer.decode(out[0, ids.shape[1]:], skip_special_tokens=True)
        print(f" Output: {text[:120]}")
        return accepted
    finally:
        # Cleanup eagle so the next call (e.g. with another D) starts fresh,
        # even if something above raised.
        del engine.eagle_head
        engine._eagle_enabled = False
if __name__ == "__main__":
    print("Loading model...")
    engine, tokenizer, _config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.pack_all_experts()
    engine.kv_cache.enable_flat_decode()
    engine.eval()

    # Warmup: a few short greedy generations before the acceptance runs.
    print("Warmup...")
    warmup_ids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(warmup_ids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)

    # test_acceptance drafts 5 tokens per round by default, so 5/5 is the
    # best possible score.
    draft_depth = 5

    # Test D=2
    acc2 = test_acceptance(engine, tokenizer, 2, "D=2 baseline")
    # Test D=8
    acc8 = test_acceptance(engine, tokenizer, 8, "D=8 with random layers 2-7")

    print(f"\n{'='*60}")
    print(f" D=2 accepted: {acc2}/{draft_depth}")
    print(f" D=8 accepted: {acc8}/{draft_depth}")
    if acc8 > acc2 + 2:
        print(f" WARNING: D=8 significantly better than D=2 — investigate!")
    elif acc2 <= 2 and acc8 <= 2:
        print(f" EXPECTED: Both D=2 and D=8 have low acceptance (undertrained)")
    else:
        print(" NOTE: no clear D=8 advantage, but acceptance is not uniformly low")
    # Full acceptance is the symptom this script was written to debug (fake
    # 100% acceptance from NaN logits); previously it produced no verdict at
    # all when both heads scored 5/5.
    if draft_depth in (acc2, acc8):
        print(" WARNING: full acceptance observed — re-check the NaN diagnostics above")
    print(f"{'='*60}")