File size: 6,277 Bytes
b5bff9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
#!/usr/bin/env python3
"""Isolate exactly what about D=8 causes NaN.

Tests:
1. D=2 eagle head β†’ forward β†’ should be OK
2. D=8 eagle head (random, no ckpt) β†’ forward β†’ is NaN from VRAM pressure?
3. D=8 eagle head (random, NOT assigned to engine) β†’ forward β†’ is NaN from registration?
4. D=8 allocated but eagle_enabled=False β†’ forward β†’ is NaN from .to() side effect?
"""
import sys, os, torch, gc
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
from fireecho_kernel import FireEchoEagleHead

MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")

PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"


@torch.no_grad()
def check(engine, tokenizer, label):
    """Run one cached prefill forward over PROMPT and test logits for NaN.

    Args:
        engine: project inference engine exposing ``reset_cache()`` and
            ``forward()`` (project type — not constructible here).
        tokenizer: HF-style tokenizer with ``encode()``/``decode()``.
        label: tag printed alongside the result line.

    Returns:
        bool: True when any logit is NaN, False otherwise.
    """
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # CUDA-graph replay would bypass this ad-hoc forward; force eager mode.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(ids, use_cache=True, position=0)
    # Synchronize so any async CUDA error/NaN is materialized before the scan.
    torch.cuda.synchronize()
    has_nan = bool(logits.isnan().any().item())
    if has_nan:
        print(f"  [{label}] NaN DETECTED")
    else:
        top = logits[:, -1, :].argmax(dim=-1).item()
        # Fixed mojibake: 'β€”' in the original was a mis-encoded em dash.
        print(f"  [{label}] OK — top={top} ('{tokenizer.decode([top])}')")
    return has_nan


if __name__ == "__main__":
    print("=" * 60)
    print("  D=8 NaN Isolation")
    print("=" * 60)

    print("\n[1] Loading model...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.pack_all_experts()
    engine.kv_cache.enable_flat_decode()
    engine.eval()

    # Warmup
    print("\n[2] Warmup...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)

    vram_base = torch.cuda.memory_allocated() / 1e9
    print(f"  VRAM baseline: {vram_base:.2f} GB")

    # Test 1: Baseline (no eagle)
    print("\n[3] Baseline (no eagle)...")
    check(engine, tokenizer, "baseline")

    # Test 2: D=2 eagle head (should work)
    print("\n[4] D=2 eagle head...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=2, checkpoint_path=EAGLE_CKPT)
    vram = torch.cuda.memory_allocated() / 1e9
    print(f"  VRAM: {vram:.2f} GB (+{vram - vram_base:.2f})")
    check(engine, tokenizer, "D=2")
    # Cleanup
    del engine.eagle_head
    engine._eagle_enabled = False
    engine._eagle_hidden_states = {}
    torch.cuda.empty_cache()
    gc.collect()

    # Test 3: D=8 eagle head (NO checkpoint, random init)
    print("\n[5] D=8 eagle head (random init, no checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8)  # no checkpoint_path
    vram = torch.cuda.memory_allocated() / 1e9
    print(f"  VRAM: {vram:.2f} GB (+{vram - vram_base:.2f})")
    nan_d8_random = check(engine, tokenizer, "D=8 random")
    # Cleanup
    del engine.eagle_head
    engine._eagle_enabled = False
    engine._eagle_hidden_states = {}
    torch.cuda.empty_cache()
    gc.collect()

    # Test 4: D=8 eagle head WITH checkpoint
    print("\n[6] D=8 eagle head (with checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    vram = torch.cuda.memory_allocated() / 1e9
    print(f"  VRAM: {vram:.2f} GB (+{vram - vram_base:.2f})")
    nan_d8_ckpt = check(engine, tokenizer, "D=8 with ckpt")
    # Cleanup
    del engine.eagle_head
    engine._eagle_enabled = False
    engine._eagle_hidden_states = {}
    torch.cuda.empty_cache()
    gc.collect()

    # Test 5: D=8 eagle head allocated but NOT registered as submodule
    print("\n[7] D=8 eagle head (allocated, NOT registered on engine)...")
    head_ext = FireEchoEagleHead(
        dim=config.dim, num_capture_layers=3,
        num_heads=16, ffn_mult=2, num_layers=8,
    ).to(dtype=torch.bfloat16, device='cuda')
    # Do NOT assign to engine β€” keep as local variable
    engine._eagle_enabled = True
    engine._eagle_capture_set = {8, 24, 47}
    engine._eagle_capture_layers = [8, 24, 47]
    engine._eagle_hidden_states = {}
    vram = torch.cuda.memory_allocated() / 1e9
    print(f"  VRAM: {vram:.2f} GB (+{vram - vram_base:.2f})")
    nan_d8_unreg = check(engine, tokenizer, "D=8 unregistered")
    # Cleanup
    del head_ext
    engine._eagle_enabled = False
    torch.cuda.empty_cache()
    gc.collect()

    # Test 6: D=4 eagle head (between D=2 and D=8)
    print("\n[8] D=4 eagle head (checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=4, checkpoint_path=EAGLE_CKPT)
    vram = torch.cuda.memory_allocated() / 1e9
    print(f"  VRAM: {vram:.2f} GB (+{vram - vram_base:.2f})")
    nan_d4 = check(engine, tokenizer, "D=4")
    # Cleanup
    del engine.eagle_head
    engine._eagle_enabled = False
    engine._eagle_hidden_states = {}
    torch.cuda.empty_cache()
    gc.collect()

    # Test 7: D=8 but eagle_enabled=False (head exists but flag off)
    print("\n[9] D=8 eagle head, but _eagle_enabled=False...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    engine._eagle_enabled = False  # disable the flag
    vram = torch.cuda.memory_allocated() / 1e9
    print(f"  VRAM: {vram:.2f} GB (+{vram - vram_base:.2f})")
    nan_d8_flagoff = check(engine, tokenizer, "D=8 flag OFF")

    # Summary
    print(f"\n{'='*60}")
    print("  RESULTS")
    print(f"{'='*60}")
    print(f"  D=8 random:      {'NaN' if nan_d8_random else 'OK'}")
    print(f"  D=8 with ckpt:   {'NaN' if nan_d8_ckpt else 'OK'}")
    print(f"  D=8 unregistered: {'NaN' if nan_d8_unreg else 'OK'}")
    print(f"  D=4:             {'NaN' if nan_d4 else 'OK'}")
    print(f"  D=8 flag OFF:    {'NaN' if nan_d8_flagoff else 'OK'}")