 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
#!/usr/bin/env python3
"""Debug: Why does D=8 eagle head show 100% acceptance?
Compare draft tokens vs target predictions for D=2 and D=8.

ROOT CAUSE FOUND: Missing torch.no_grad() caused NaN logits (Goliath FP4
Triton kernels don't support autograd). argmax(NaN)=0 for both draft and
target → fake 100% acceptance. This version fixes that.
"""
import sys, os, torch
# Make sibling modules (hebbian_finetune_demo) importable when this script
# is run directly from anywhere.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine

# Local path to the base model weights (machine-specific mount).
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
# Trained eagle draft-head checkpoint; optional — see enable_eagle() call below.
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")

@torch.no_grad()  # THE FIX: without no_grad the FP4 Triton kernels emit NaN logits (see module docstring)
def test_acceptance(engine, tokenizer, num_layers, label):
    """Enable eagle with given D, run one round of draft+verify, print details.

    Enables the eagle draft head with `num_layers` head layers, runs a manual
    prefill → decode → draft(5 tokens) → greedy-verify round with diagnostic
    prints (NaN checks on target logits, captured features, and draft logits),
    then runs a short `speculative_generate` for comparison and tears eagle
    back down.

    Args:
        engine: inference engine from `load_engine` (project type; must expose
            enable_eagle/forward/speculative_generate and eagle internals used below).
        tokenizer: matching tokenizer with encode/decode.
        num_layers: eagle head depth D under test.
        label: human-readable tag printed in the section banner.

    Returns:
        int: number of draft tokens (out of 5) greedily accepted by the target.
    """
    print(f"\n{'='*60}")
    print(f"  Testing D={num_layers} ({label})")
    print(f"{'='*60}")

    # Enable eagle
    engine.enable_eagle(
        capture_layers=(8, 24, 47),
        num_head_layers=num_layers,
        checkpoint_path=EAGLE_CKPT if os.path.exists(EAGLE_CKPT) else None)
    engine.eval()

    prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWrite a Python function to check if a number is prime.<|im_end|>\n<|im_start|>assistant\n"
    ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
    prompt_len = ids.shape[1]

    # Prefill
    engine.reset_cache()
    engine._current_seq_id = 0
    # NOTE(review): presumably forces eager (non-CUDA-graph) execution for this
    # manual step-by-step run — confirm against the kv_cache implementation.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(ids, use_cache=True, position=0)
    current_pos = prompt_len

    # Check for NaN in target logits
    has_nan = logits.isnan().any().item()
    print(f"  Target prefill logits: has_nan={has_nan}, "
          f"min={logits[:,-1,:].min().item():.2f}, max={logits[:,-1,:].max().item():.2f}")

    # Decode first token
    next_token = logits[:, -1:, :].argmax(dim=-1)
    print(f"  First decoded token: {next_token.item()} = '{tokenizer.decode([next_token.item()])}'")

    # Forward it (stores KV, captures hidden states)
    logits = engine.forward(next_token, use_cache=True, position=current_pos)
    current_pos += 1

    # Target model's prediction
    main_pred = logits[:, -1, :].argmax(dim=-1).item()
    print(f"  Target predicts next: {main_pred} = '{tokenizer.decode([main_pred])}'")

    # Draft 5 tokens
    # Hidden states captured as a side effect of the forward() above, one per
    # configured capture layer (8, 24, 47).
    features = [engine._eagle_hidden_states[l]
                for l in engine._eagle_capture_layers]

    # Check features for NaN
    for li, f in zip(engine._eagle_capture_layers, features):
        print(f"  Feature layer {li}: has_nan={f.isnan().any().item()}, "
              f"min={f.min().item():.4f}, max={f.max().item():.4f}")

    # Memory context is derived from the deepest captured layer only.
    memory_ctx = engine._get_eagle_memory_context(
        engine._eagle_hidden_states[engine._eagle_capture_layers[-1]])

    draft_tokens, draft_logits = engine.eagle_head.generate_draft(
        features, next_token, engine.embed, depth=5,
        memory_context=memory_ctx)

    print(f"  Draft tokens:")
    for i, dt in enumerate(draft_tokens):
        tok_id = dt.item()
        print(f"    [{i}] {tok_id} = '{tokenizer.decode([tok_id])}'")

    # Check draft logits for NaN
    dl0 = draft_logits[0][0, 0, :]
    print(f"  Draft logits[0]: has_nan={dl0.isnan().any().item()}, "
          f"min={dl0.min().item():.2f}, max={dl0.max().item():.2f}")

    # Verify: forward draft tokens through target
    draft_input = torch.cat(draft_tokens, dim=1)
    verify_logits = engine.forward(draft_input, use_cache=True, position=current_pos)

    # Greedy acceptance: draft[0] must match main_pred (already computed above);
    # for i>=1, verify_logits[:, i-1] is the target's next-token distribution
    # after consuming draft tokens 0..i-1, i.e. its prediction for position i.
    # Stop at the first mismatch, as in standard speculative decoding.
    print(f"  Target verify predictions:")
    accepted = 0
    if draft_tokens[0].item() == main_pred:
        accepted = 1
        for i in range(1, len(draft_tokens)):
            target_pred = verify_logits[:, i - 1, :].argmax(dim=-1).item()
            match = "MATCH" if draft_tokens[i].item() == target_pred else "MISS"
            print(f"    [{i}] target={target_pred} ('{tokenizer.decode([target_pred])}'), "
                  f"draft={draft_tokens[i].item()} ('{tokenizer.decode([draft_tokens[i].item()])}') → {match}")
            if draft_tokens[i].item() == target_pred:
                accepted += 1
            else:
                break
    else:
        print(f"    [0] MISS: draft[0]={draft_tokens[0].item()} "
              f"('{tokenizer.decode([draft_tokens[0].item()])}') "
              f"!= main_pred={main_pred} ('{tokenizer.decode([main_pred])}')")

    print(f"  Accepted: {accepted}/{len(draft_tokens)}")

    # Also run full speculative_generate to match training eval
    print(f"\n  --- Full speculative_generate (max_new=30) ---")
    engine.reset_cache()
    ids2 = tokenizer.encode(prompt, return_tensors='pt').cuda()
    out = engine.speculative_generate(
        ids2, max_new_tokens=30, temperature=0.0,
        stop_tokens=[199999, 200020])
    text = tokenizer.decode(out[0, ids2.shape[1]:], skip_special_tokens=True)
    print(f"  Output: {text[:120]}")

    # Cleanup eagle
    # Tear down so the next test_acceptance call re-creates the head with a
    # different D instead of reusing this one.
    del engine.eagle_head
    engine._eagle_enabled = False

    return accepted


def _main():
    """Load the engine once, warm it up, then compare eagle acceptance at D=2 vs D=8."""
    print("Loading model...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.pack_all_experts()
    engine.kv_cache.enable_flat_decode()
    engine.eval()

    # Three throwaway generations so kernel compilation/caching doesn't skew the runs.
    print("Warmup...")
    hello_ids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(hello_ids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)

    # One acceptance round per head depth under comparison.
    accepted_d2 = test_acceptance(engine, tokenizer, 2, "D=2 baseline")
    accepted_d8 = test_acceptance(engine, tokenizer, 8, "D=8 with random layers 2-7")

    banner = "=" * 60
    print(f"\n{banner}")
    print(f"  D=2 accepted: {accepted_d2}/5")
    print(f"  D=8 accepted: {accepted_d8}/5")
    if accepted_d8 > accepted_d2 + 2:
        # A deeper-but-undertrained head beating the baseline is suspicious.
        print(f"  WARNING: D=8 significantly better than D=2 — investigate!")
    elif accepted_d2 <= 2 and accepted_d8 <= 2:
        print(f"  EXPECTED: Both D=2 and D=8 have low acceptance (undertrained)")
    print(banner)


if __name__ == "__main__":
    _main()