| | |
| | """Replicate the exact training eval flow to verify acceptance rate. |
| | |
| | Matches train_eagle_head.py: enable_eagle (no ckpt), load_checkpoint, evaluate. |
| | """ |
| | import sys, os, time, torch |
| | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| | from hebbian_finetune_demo import load_engine |
| |
|
# Absolute path to the base Qwen3-Omni model weights (machine-specific mount).
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
# Best EAGLE draft-head checkpoint written by the training script, resolved
# relative to this file so the test runs from any working directory.
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")

# Fixed eval prompts in the Qwen chat template (<|im_start|>/<|im_end|>),
# kept identical across runs so acceptance rates are comparable.
EVAL_PROMPTS = [
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWrite a Python function to check if a number is prime.<|im_end|>\n<|im_start|>assistant\n",
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nExplain what a neural network is in simple terms.<|im_end|>\n<|im_start|>assistant\n",
    "<|im_start|>system\nYou are a helpful coding assistant.<|im_end|>\n<|im_start|>user\nWrite a binary search function in Python.<|im_end|>\n<|im_start|>assistant\n",
]
| |
|
| |
|
@torch.no_grad()
def evaluate_verbose(engine, tokenizer, max_new=60):
    """Run speculative_generate and print acceptance + output for each prompt."""
    engine.eval()
    # Prefer the tokenizer's own id for the chat end marker; fall back to the
    # known Qwen id when the lookup yields nothing.
    eos_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    stop_tokens = [151645] if eos_id is None else [eos_id]

    for pi, prompt in enumerate(EVAL_PROMPTS):
        ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
        engine.reset_cache()

        # Time the whole speculative generation; synchronize so the GPU work
        # is actually finished before reading the clock.
        start = time.perf_counter()
        out = engine.speculative_generate(
            ids, max_new_tokens=max_new, temperature=0.0,
            stop_tokens=stop_tokens)
        torch.cuda.synchronize()
        elapsed = time.perf_counter() - start

        prompt_len = ids.shape[1]
        new_tokens = out[0, prompt_len:]
        gen_len = out.shape[1] - prompt_len
        text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        tps = gen_len / max(elapsed, 1e-6)
        print(f"\n Prompt {pi}: {gen_len} tokens, {tps:.1f} tok/s")
        print(f" Output: {text[:150]}")

        # Degenerate-output check: a long run of a single repeated token id
        # usually indicates numerical breakage (e.g. NaN logits) upstream.
        gen_ids = new_tokens.tolist()
        if len(gen_ids) > 5 and len(set(gen_ids)) == 1:
            print(f" WARNING: All tokens are the same ({gen_ids[0]}) — likely NaN bug!")
| |
|
| |
|
@torch.no_grad()
def test_manual_speculation(engine, tokenizer):
    """Manually run one round of draft+verify and check each step.

    Walks the speculative-decoding pipeline by hand — prefill, first token,
    feature capture, EAGLE draft, batched target verify — printing NaN checks
    and per-position accept/miss results so a failure can be localized to a
    single stage instead of appearing only as a bad end-to-end acceptance rate.
    """
    print("\n--- Manual speculation test ---")
    engine.eval()
    prompt = EVAL_PROMPTS[0]
    ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
    prompt_len = ids.shape[1]

    # Fresh KV cache and sequence slot; disable graph mode if the cache has
    # one so the step-by-step forwards below run eagerly (presumably graph
    # mode replays captured shapes — TODO confirm against kv_cache impl).
    engine.reset_cache()
    engine._current_seq_id = 0
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False

    # Stage 1: prefill the full prompt and sanity-check the logits.
    logits = engine.forward(ids, use_cache=True, position=0)
    has_nan = logits.isnan().any().item()
    print(f" Prefill logits: has_nan={has_nan}")
    if has_nan:
        print(" FATAL: NaN in prefill! Cannot continue.")
        return

    # Stage 2: greedy first token from the last prefill position.
    next_token = logits[:, -1:, :].argmax(dim=-1)
    print(f" First token: {next_token.item()} = '{tokenizer.decode([next_token.item()])}'")

    # Stage 3: feed the first token back through the target model; this
    # forward populates the hidden states the EAGLE head drafts from.
    logits = engine.forward(next_token, use_cache=True, position=prompt_len)
    has_nan = logits.isnan().any().item()
    print(f" Post-first-token logits: has_nan={has_nan}")
    if has_nan:
        print(" FATAL: NaN after first token forward!")
        return

    # The target model's own greedy prediction for the next position — the
    # reference that draft token [0] must match to be accepted.
    main_pred = logits[:, -1, :].argmax(dim=-1).item()
    print(f" Target predicts next: {main_pred} = '{tokenizer.decode([main_pred])}'")

    # Stage 4: inspect the captured features, one per configured capture layer.
    features = [engine._eagle_hidden_states[l] for l in engine._eagle_capture_layers]
    for li, f in zip(engine._eagle_capture_layers, features):
        print(f" Feature L{li}: has_nan={f.isnan().any().item()}, "
              f"shape={list(f.shape)}")

    # Memory context is derived from the deepest captured layer's hidden state.
    memory_ctx = engine._get_eagle_memory_context(
        engine._eagle_hidden_states[engine._eagle_capture_layers[-1]])

    # Stage 5: draft 5 tokens with the EAGLE head (dt = tokens, dl = logits —
    # presumably; dl is unused here).
    dt, dl = engine.eagle_head.generate_draft(
        features, next_token, engine.embed, depth=5, memory_context=memory_ctx)

    print(f"\n Draft tokens:")
    for i, t in enumerate(dt):
        print(f" [{i}] {t.item()} = '{tokenizer.decode([t.item()])}'")

    # Stage 6: verify all draft tokens in one batched target forward.
    # Position is prompt_len + 1 because the first-token forward above
    # already consumed cache slot prompt_len.
    draft_input = torch.cat(dt, dim=1)
    current_pos = prompt_len + 1
    verify_logits = engine.forward(draft_input, use_cache=True, position=current_pos)
    has_nan = verify_logits.isnan().any().item()
    print(f"\n Verify logits: has_nan={has_nan}")

    # Stage 7: greedy acceptance. Draft [0] is checked against main_pred;
    # draft [i] (i >= 1) against the target argmax at verify position i-1.
    # Stop at the first mismatch, exactly as speculative decoding would.
    accepted = 0
    if dt[0].item() == main_pred:
        accepted = 1
        for i in range(1, len(dt)):
            target_pred = verify_logits[:, i - 1, :].argmax(dim=-1).item()
            match = "MATCH" if dt[i].item() == target_pred else "MISS"
            print(f" [{i}] draft={dt[i].item()} target={target_pred} → {match}")
            if dt[i].item() == target_pred:
                accepted += 1
            else:
                break
    else:
        print(f" [0] MISS: draft={dt[0].item()} target={main_pred}")

    print(f" Accepted: {accepted}/{len(dt)}")
| |
|
| |
|
if __name__ == "__main__":
    print("=" * 60)
    print(" Eval Flow Test (replicates training eval)")
    print("=" * 60)

    # [1] Base model, loaded with the same decode settings the trainer uses.
    print("\n[1] Loading model...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=512, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()

    # [2] Attach a randomly-initialized EAGLE head first — mirrors
    # train_eagle_head.py, which also enables before loading weights.
    print("\n[2] Enabling EAGLE (no checkpoint)...")
    engine.enable_eagle(
        capture_layers=(8, 24, 47),
        num_heads=16, ffn_mult=2,
        draft_depth=5, num_head_layers=8)

    # [3] Load trained weights on top, handling both the legacy flat layout
    # (top-level norm1./q_proj. keys) and the current state_dict layout.
    print("\n[3] Loading checkpoint separately (like training script)...")
    if os.path.exists(EAGLE_CKPT):
        ckpt = torch.load(EAGLE_CKPT, weights_only=False, map_location='cuda')
        sd = ckpt.get('eagle_head', ckpt)
        is_legacy = any(k.startswith('norm1.') or k.startswith('q_proj.') for k in sd)
        if is_legacy:
            engine.eagle_head.load_legacy_checkpoint(sd)
        else:
            engine.eagle_head.load_state_dict(sd, strict=False)
        print(f" Loaded checkpoint (step {ckpt.get('step', '?')})")
    else:
        print(f" No checkpoint found, using random init")

    # Build the optimizer exactly as training does. It is never stepped in
    # this script; presumably it exists only to replicate the training run's
    # allocation/VRAM state — TODO confirm, otherwise it can be dropped.
    eagle_params = [p for n, p in engine.eagle_head.named_parameters()
                    if 'lm_head' not in n and p.requires_grad]
    optimizer = torch.optim.AdamW(eagle_params, lr=3e-4, betas=(0.9, 0.95))

    vram = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM: {vram:.2f} GB")

    # [4a] Manual single-round speculation BEFORE any warmup generates, so a
    # diff against [4c] isolates effects of warmup (e.g. lazy init / graphs).
    print("\n[4a] Running manual speculation test WITHOUT warmup...")
    test_manual_speculation(engine, tokenizer)

    # [4b] Short greedy generates to trigger any lazy initialization paths.
    print("\n[4b] Warmup (3x generate)...")
    warmup_ids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(warmup_ids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
    print(" Warmup done")

    # [4c] Same manual test again, after warmup.
    print("\n[4c] Running manual speculation test AFTER warmup...")
    test_manual_speculation(engine, tokenizer)

    # [5] Full speculative_generate pass over all eval prompts.
    print("\n[5] Running full speculative_generate eval...")
    evaluate_verbose(engine, tokenizer)

    print("\n" + "=" * 60)
    print(" Done")
    print("=" * 60)
| |
|