#!/usr/bin/env python3
"""Replicate the exact training eval flow to verify acceptance rate.
Matches train_eagle_head.py: enable_eagle (no ckpt), load_checkpoint, evaluate.
"""
import sys, os, time, torch

# Make sibling modules importable when this script is run directly.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine

# Local path to the base model weights (machine-specific; adjust as needed).
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
# Best EAGLE draft-head checkpoint written by train_eagle_head.py.
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")

# Chat-formatted prompts used by every evaluation pass below.
EVAL_PROMPTS = [
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWrite a Python function to check if a number is prime.<|im_end|>\n<|im_start|>assistant\n",
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nExplain what a neural network is in simple terms.<|im_end|>\n<|im_start|>assistant\n",
    "<|im_start|>system\nYou are a helpful coding assistant.<|im_end|>\n<|im_start|>user\nWrite a binary search function in Python.<|im_end|>\n<|im_start|>assistant\n",
]
@torch.no_grad()
def evaluate_verbose(engine, tokenizer, max_new=60):
    """Run speculative generation over EVAL_PROMPTS and report per-prompt stats.

    For each prompt: resets the KV cache, calls engine.speculative_generate
    greedily (temperature 0), then prints the generated token count,
    throughput, and a preview of the decoded text. Warns when the output
    collapses to a single repeated token — a symptom of NaN logits upstream.
    """
    engine.eval()
    # NOTE(review): convert_tokens_to_ids may return the unk id rather than
    # None for an unknown token on some tokenizers — 151645 is the Qwen
    # "<|im_end|>" id used as a fallback; confirm against the tokenizer.
    eos_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    stop_tokens = [151645] if eos_id is None else [eos_id]
    for idx, prompt in enumerate(EVAL_PROMPTS):
        input_ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
        prompt_len = input_ids.shape[1]
        engine.reset_cache()
        start = time.perf_counter()
        result = engine.speculative_generate(
            input_ids,
            max_new_tokens=max_new,
            temperature=0.0,
            stop_tokens=stop_tokens,
        )
        # Generation launches async CUDA work; synchronize for honest timing.
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        new_tokens = result.shape[1] - prompt_len
        decoded = tokenizer.decode(result[0, prompt_len:], skip_special_tokens=True)
        rate = new_tokens / max(elapsed, 1e-6)  # guard against zero elapsed
        print(f"\n Prompt {idx}: {new_tokens} tokens, {rate:.1f} tok/s")
        print(f" Output: {decoded[:150]}")
        # A long run of identical tokens is the classic NaN-logits signature.
        generated = result[0, prompt_len:].tolist()
        if len(generated) > 5 and len(set(generated)) == 1:
            print(f" WARNING: All tokens are the same ({generated[0]}) — likely NaN bug!")
@torch.no_grad()
def test_manual_speculation(engine, tokenizer):
    """Manually run one round of draft+verify and check each step.

    Walks the speculative-decoding pipeline by hand — prefill, first greedy
    token, EAGLE draft generation, then target-model verification — printing
    NaN checks and per-token accept/miss results at every stage. Useful for
    isolating exactly where NaNs or draft/target mismatches first appear.
    """
    print("\n--- Manual speculation test ---")
    engine.eval()
    prompt = EVAL_PROMPTS[0]
    ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
    prompt_len = ids.shape[1]
    engine.reset_cache()
    # Pin to sequence slot 0 and disable CUDA-graph mode (if present) so the
    # manual step-by-step forwards below run eagerly.
    engine._current_seq_id = 0
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    # Prefill: run the entire prompt through the target model once.
    logits = engine.forward(ids, use_cache=True, position=0)
    has_nan = logits.isnan().any().item()
    print(f" Prefill logits: has_nan={has_nan}")
    if has_nan:
        print(" FATAL: NaN in prefill! Cannot continue.")
        return
    # Decode the first token greedily from the last prefill position.
    next_token = logits[:, -1:, :].argmax(dim=-1)
    print(f" First token: {next_token.item()} = '{tokenizer.decode([next_token.item()])}'")
    # Forward the first token so the KV cache (and the EAGLE hidden-state
    # capture) reflect it before drafting.
    logits = engine.forward(next_token, use_cache=True, position=prompt_len)
    has_nan = logits.isnan().any().item()
    print(f" Post-first-token logits: has_nan={has_nan}")
    if has_nan:
        print(" FATAL: NaN after first token forward!")
        return
    main_pred = logits[:, -1, :].argmax(dim=-1).item()
    print(f" Target predicts next: {main_pred} = '{tokenizer.decode([main_pred])}'")
    # Draft 5 tokens with the EAGLE head from the captured layer features.
    features = [engine._eagle_hidden_states[l] for l in engine._eagle_capture_layers]
    for li, f in zip(engine._eagle_capture_layers, features):
        print(f" Feature L{li}: has_nan={f.isnan().any().item()}, "
              f"shape={list(f.shape)}")
    memory_ctx = engine._get_eagle_memory_context(
        engine._eagle_hidden_states[engine._eagle_capture_layers[-1]])
    # dt: draft token tensors; dl: draft logits (unused here).
    dt, dl = engine.eagle_head.generate_draft(
        features, next_token, engine.embed, depth=5, memory_context=memory_ctx)
    print(f"\n Draft tokens:")
    for i, t in enumerate(dt):
        print(f" [{i}] {t.item()} = '{tokenizer.decode([t.item()])}'")
    # Verify: feed all draft tokens to the target model in one forward pass,
    # starting at the position right after the first decoded token.
    draft_input = torch.cat(dt, dim=1)
    current_pos = prompt_len + 1
    verify_logits = engine.forward(draft_input, use_cache=True, position=current_pos)
    has_nan = verify_logits.isnan().any().item()
    print(f"\n Verify logits: has_nan={has_nan}")
    # Count the accepted prefix: draft[0] is checked against the target's
    # prediction from the first-token forward (main_pred); draft[i] against
    # the target's argmax at verify position i-1. Stop at the first miss.
    accepted = 0
    if dt[0].item() == main_pred:
        accepted = 1
        for i in range(1, len(dt)):
            target_pred = verify_logits[:, i - 1, :].argmax(dim=-1).item()
            match = "MATCH" if dt[i].item() == target_pred else "MISS"
            print(f" [{i}] draft={dt[i].item()} target={target_pred} → {match}")
            if dt[i].item() == target_pred:
                accepted += 1
            else:
                break
    else:
        print(f" [0] MISS: draft={dt[0].item()} target={main_pred}")
    print(f" Accepted: {accepted}/{len(dt)}")
if __name__ == "__main__":
    print("=" * 60)
    print(" Eval Flow Test (replicates training eval)")
    print("=" * 60)
    # === Match training script flow exactly ===
    print("\n[1] Loading model...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=512, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()
    print("\n[2] Enabling EAGLE (no checkpoint)...")
    # EAGLE hyperparameters mirror train_eagle_head.py — NOTE(review): keep
    # these in sync manually if the training script changes.
    engine.enable_eagle(
        capture_layers=(8, 24, 47),
        num_heads=16, ffn_mult=2,
        draft_depth=5, num_head_layers=8)
    print("\n[3] Loading checkpoint separately (like training script)...")
    if os.path.exists(EAGLE_CKPT):
        ckpt = torch.load(EAGLE_CKPT, weights_only=False, map_location='cuda')
        # Checkpoint may be a full training dict (with an 'eagle_head' key)
        # or a bare state_dict.
        sd = ckpt.get('eagle_head', ckpt)
        # Legacy checkpoints used flat 'norm1.'/'q_proj.' key prefixes.
        is_legacy = any(k.startswith('norm1.') or k.startswith('q_proj.') for k in sd)
        if is_legacy:
            engine.eagle_head.load_legacy_checkpoint(sd)
        else:
            engine.eagle_head.load_state_dict(sd, strict=False)
        print(f" Loaded checkpoint (step {ckpt.get('step', '?')})")
    else:
        print(f" No checkpoint found, using random init")
    # Setup optimizer (like training script). It is never stepped here —
    # created only so VRAM/allocation state matches the training run.
    eagle_params = [p for n, p in engine.eagle_head.named_parameters()
                    if 'lm_head' not in n and p.requires_grad]
    optimizer = torch.optim.AdamW(eagle_params, lr=3e-4, betas=(0.9, 0.95))
    vram = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM: {vram:.2f} GB")
    # Test WITHOUT warmup first, to see whether warmup changes behavior.
    print("\n[4a] Running manual speculation test WITHOUT warmup...")
    test_manual_speculation(engine, tokenizer)
    # Now do warmup (3 short greedy generations).
    print("\n[4b] Warmup (3x generate)...")
    warmup_ids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(warmup_ids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
    print(" Warmup done")
    # Test AFTER warmup for comparison against [4a].
    print("\n[4c] Running manual speculation test AFTER warmup...")
    test_manual_speculation(engine, tokenizer)
    print("\n[5] Running full speculative_generate eval...")
    evaluate_verbose(engine, tokenizer)
    print("\n" + "=" * 60)
    print(" Done")
    print("=" * 60)