File size: 7,384 Bytes
b5bff9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
#!/usr/bin/env python3
"""Replicate the exact training eval flow to verify acceptance rate.

Matches train_eagle_head.py: enable_eagle (no ckpt), load_checkpoint, evaluate.
"""
import sys, os, time, torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine

# Absolute path to the base model weights — machine-specific; adjust per host.
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
# Checkpoint written by the training script (see module docstring); loaded in __main__.
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")

# ChatML-formatted prompts used for the acceptance-rate evaluation.
EVAL_PROMPTS = [
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWrite a Python function to check if a number is prime.<|im_end|>\n<|im_start|>assistant\n",
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nExplain what a neural network is in simple terms.<|im_end|>\n<|im_start|>assistant\n",
    "<|im_start|>system\nYou are a helpful coding assistant.<|im_end|>\n<|im_start|>user\nWrite a binary search function in Python.<|im_end|>\n<|im_start|>assistant\n",
]


@torch.no_grad()
def evaluate_verbose(engine, tokenizer, max_new=60):
    """Run speculative_generate on every eval prompt and report the results.

    For each prompt this prints the number of generated tokens, throughput
    (tok/s), a 150-char preview of the decoded text, and a warning when the
    output has collapsed into a single repeated token — a known NaN symptom.
    """
    engine.eval()

    # Prefer the tokenizer's <|im_end|> id; fall back to the known Qwen id.
    eos_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    stop_tokens = [151645] if eos_id is None else [eos_id]

    for pi, prompt in enumerate(EVAL_PROMPTS):
        ids = tokenizer.encode(prompt, return_tensors="pt").to("cuda")
        prompt_len = ids.shape[1]
        engine.reset_cache()

        t0 = time.perf_counter()
        out = engine.speculative_generate(
            ids, max_new_tokens=max_new, temperature=0.0,
            stop_tokens=stop_tokens)
        torch.cuda.synchronize()  # ensure timing covers all queued GPU work
        t1 = time.perf_counter()

        new_tokens = out[0, prompt_len:]
        gen_len = out.shape[1] - prompt_len
        text = tokenizer.decode(new_tokens, skip_special_tokens=True)
        tps = gen_len / max(t1 - t0, 1e-6)
        print(f"\n  Prompt {pi}: {gen_len} tokens, {tps:.1f} tok/s")
        print(f"  Output: {text[:150]}")

        # A long run of one identical token is the classic NaN-in-logits symptom.
        gen_ids = new_tokens.tolist()
        if len(gen_ids) > 5 and len(set(gen_ids)) == 1:
            print(f"  WARNING: All tokens are the same ({gen_ids[0]}) — likely NaN bug!")


@torch.no_grad()
def test_manual_speculation(engine, tokenizer):
    """Manually run one round of draft+verify and check each step.

    Walks the speculative-decoding pipeline by hand — prefill, first decoded
    token, draft generation via the EAGLE head, then target verification —
    printing NaN checks and per-token match/miss so a failure can be localized
    to a specific stage.

    Fix vs. earlier revision: the per-token verify line printed
    ``target=<id>MATCH`` with no separator; a space is now inserted.
    """
    print("\n--- Manual speculation test ---")
    engine.eval()
    prompt = EVAL_PROMPTS[0]
    ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
    prompt_len = ids.shape[1]

    engine.reset_cache()
    engine._current_seq_id = 0
    # Disable CUDA-graph capture so each forward runs eagerly and is inspectable.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False

    # --- Stage 1: prefill the full prompt and sanity-check logits ---
    logits = engine.forward(ids, use_cache=True, position=0)
    has_nan = logits.isnan().any().item()
    print(f"  Prefill logits: has_nan={has_nan}")
    if has_nan:
        print("  FATAL: NaN in prefill! Cannot continue.")
        return

    # --- Stage 2: greedy-decode the first token from the prefill logits ---
    next_token = logits[:, -1:, :].argmax(dim=-1)
    print(f"  First token: {next_token.item()} = '{tokenizer.decode([next_token.item()])}'")

    # Feed the first token back through the target model (advances the KV cache
    # and refreshes the captured hidden states the EAGLE head drafts from).
    logits = engine.forward(next_token, use_cache=True, position=prompt_len)
    has_nan = logits.isnan().any().item()
    print(f"  Post-first-token logits: has_nan={has_nan}")
    if has_nan:
        print("  FATAL: NaN after first token forward!")
        return

    # The target model's greedy prediction for the position after next_token;
    # the first draft token must match this to be accepted.
    main_pred = logits[:, -1, :].argmax(dim=-1).item()
    print(f"  Target predicts next: {main_pred} = '{tokenizer.decode([main_pred])}'")

    # --- Stage 3: draft 5 tokens from the captured multi-layer features ---
    features = [engine._eagle_hidden_states[l] for l in engine._eagle_capture_layers]
    for li, f in zip(engine._eagle_capture_layers, features):
        print(f"  Feature L{li}: has_nan={f.isnan().any().item()}, "
              f"shape={list(f.shape)}")

    memory_ctx = engine._get_eagle_memory_context(
        engine._eagle_hidden_states[engine._eagle_capture_layers[-1]])

    dt, dl = engine.eagle_head.generate_draft(
        features, next_token, engine.embed, depth=5, memory_context=memory_ctx)

    print(f"\n  Draft tokens:")
    for i, t in enumerate(dt):
        print(f"    [{i}] {t.item()} = '{tokenizer.decode([t.item()])}'")

    # --- Stage 4: verify the draft against the target model in one forward ---
    draft_input = torch.cat(dt, dim=1)
    current_pos = prompt_len + 1  # one past next_token, which is already cached
    verify_logits = engine.forward(draft_input, use_cache=True, position=current_pos)
    has_nan = verify_logits.isnan().any().item()
    print(f"\n  Verify logits: has_nan={has_nan}")

    # Count the longest accepted prefix: draft[0] must match main_pred; each
    # later draft[i] must match the target's prediction at draft position i-1.
    accepted = 0
    if dt[0].item() == main_pred:
        accepted = 1
        for i in range(1, len(dt)):
            target_pred = verify_logits[:, i - 1, :].argmax(dim=-1).item()
            match = "MATCH" if dt[i].item() == target_pred else "MISS"
            print(f"    [{i}] draft={dt[i].item()} target={target_pred} {match}")
            if dt[i].item() == target_pred:
                accepted += 1
            else:
                break
    else:
        print(f"    [0] MISS: draft={dt[0].item()} target={main_pred}")

    print(f"  Accepted: {accepted}/{len(dt)}")


if __name__ == "__main__":
    print("=" * 60)
    print("  Eval Flow Test (replicates training eval)")
    print("=" * 60)

    # === Match training script flow exactly ===
    # Step order mirrors train_eagle_head.py: load model, enable EAGLE with no
    # checkpoint, then load the checkpoint separately (see module docstring).
    print("\n[1] Loading model...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=512, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()

    print("\n[2] Enabling EAGLE (no checkpoint)...")
    # Head hyperparameters must match the training run for the checkpoint
    # state dict to load cleanly.
    engine.enable_eagle(
        capture_layers=(8, 24, 47),
        num_heads=16, ffn_mult=2,
        draft_depth=5, num_head_layers=8)

    print("\n[3] Loading checkpoint separately (like training script)...")
    if os.path.exists(EAGLE_CKPT):
        # weights_only=False: checkpoint may contain non-tensor metadata
        # (e.g. the 'step' counter read below).
        ckpt = torch.load(EAGLE_CKPT, weights_only=False, map_location='cuda')
        sd = ckpt.get('eagle_head', ckpt)
        # Pre-refactor checkpoints used flat key names (norm1./q_proj.) and
        # need the dedicated legacy loader.
        is_legacy = any(k.startswith('norm1.') or k.startswith('q_proj.') for k in sd)
        if is_legacy:
            engine.eagle_head.load_legacy_checkpoint(sd)
        else:
            engine.eagle_head.load_state_dict(sd, strict=False)
        print(f"  Loaded checkpoint (step {ckpt.get('step', '?')})")
    else:
        print(f"  No checkpoint found, using random init")

    # Setup optimizer (like training script)
    # NOTE: the optimizer is built only to replicate the training script's
    # state; it is never stepped anywhere in this eval script.
    eagle_params = [p for n, p in engine.eagle_head.named_parameters()
                    if 'lm_head' not in n and p.requires_grad]
    optimizer = torch.optim.AdamW(eagle_params, lr=3e-4, betas=(0.9, 0.95))

    vram = torch.cuda.memory_allocated() / 1e9
    print(f"  VRAM: {vram:.2f} GB")

    # Test WITHOUT warmup first — isolates whether warmup generates (and any
    # CUDA-graph capture they trigger) changes speculation behavior.
    print("\n[4a] Running manual speculation test WITHOUT warmup...")
    test_manual_speculation(engine, tokenizer)

    # Now do warmup
    print("\n[4b] Warmup (3x generate)...")
    warmup_ids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(warmup_ids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
    print("  Warmup done")

    # Test AFTER warmup — compare against [4a] to spot warmup-induced drift.
    print("\n[4c] Running manual speculation test AFTER warmup...")
    test_manual_speculation(engine, tokenizer)

    print("\n[5] Running full speculative_generate eval...")
    evaluate_verbose(engine, tokenizer)

    print("\n" + "=" * 60)
    print("  Done")
    print("=" * 60)