File size: 6,957 Bytes
b5bff9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#!/usr/bin/env python3
"""Isolate exactly which step of enable_eagle() causes NaN in target model.

Tests each sub-step of enable_eagle() independently to find the culprit.
Also checks per-layer output to find where NaN first appears.
"""
import sys, os, torch, gc
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
from fireecho_kernel import FireEchoEagleHead

MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")

PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"


@torch.no_grad()
def check_forward(engine, tokenizer, label):
    """Run one cached forward pass over PROMPT and report NaN status.

    Args:
        engine: inference engine exposing reset_cache/forward/kv_cache
            (project type — assumed API, see hebbian_finetune_demo.load_engine).
        tokenizer: tokenizer used to encode PROMPT and decode the top token.
        label: tag printed with the result so successive probes can be told apart.

    Returns:
        bool: True if any logit is NaN, False otherwise.
    """
    torch.cuda.synchronize()
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # Force the eager path: graph-capture mode would replay stale buffers
    # instead of exercising the code path under test.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False

    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()

    nan_count = logits.isnan().sum().item()
    has_nan = nan_count > 0
    if has_nan:
        # Report the actual NaN fraction rather than claiming "all NaN":
        # fully-NaN logits suggest corrupted weights/cache, while a partial
        # count points at a single bad layer or capture hook.
        print(f"  [{label}] NaN DETECTED — {nan_count}/{logits.numel()} logits NaN")
    else:
        last = logits[:, -1, :]
        top_id = last.argmax(dim=-1).item()
        top_val = last.max().item()
        print(f"  [{label}] OK — top token={top_id} "
              f"('{tokenizer.decode([top_id])}'), max={top_val:.2f}")
    return has_nan


@torch.no_grad()
def check_per_layer(engine, tokenizer, label):
    """Run the forward pass manually, checking for NaN after every layer.

    Walks embed -> each transformer layer -> final norm -> lm_head, printing
    where NaN first appears so the offending layer can be identified.

    Args:
        engine: inference engine exposing embed/layers/norm/lm_head/kv_cache.
        tokenizer: tokenizer used to encode PROMPT and decode the top token.
        label: tag printed with each diagnostic line.

    Returns:
        int | None: index of the first layer whose output contains NaN,
        -1 if NaN is present already after the embedding, or None if no
        NaN was detected anywhere (callers may ignore this; previous
        versions returned None unconditionally).
    """
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # Force the eager path; graph mode would bypass the layer-by-layer walk.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False

    x = engine.embed(ids)
    has_nan = x.isnan().any().item()
    print(f"  [{label}] After embed: has_nan={has_nan}")
    if has_nan:
        return -1

    first_nan_layer = None
    for i, layer in enumerate(engine.layers):
        # Positional args mirror engine.forward's internal call:
        # (x, kv_cache, seq_id, position, use_cache) — TODO confirm signature.
        x = layer(x, engine.kv_cache, engine._current_seq_id, 0, True)
        if x.isnan().any().item():
            first_nan_layer = i
            print(f"  [{label}] FIRST NaN at layer {i} !!!")
            break

    if first_nan_layer is None:
        # All layers clean — check the final norm and the output projection.
        x = engine.norm(x)
        has_nan = x.isnan().any().item()
        print(f"  [{label}] After norm: has_nan={has_nan}")
        logits = engine.lm_head(x)
        has_nan = logits.isnan().any().item()
        print(f"  [{label}] After lm_head: has_nan={has_nan}")
        if not has_nan:
            top_id = logits[:, -1, :].argmax(dim=-1).item()
            print(f"  [{label}] Top token: {top_id} ('{tokenizer.decode([top_id])}')")
    else:
        print(f"  [{label}] NaN starts at layer {first_nan_layer}")
    return first_nan_layer

if __name__ == "__main__":
    print("=" * 60)
    print("  NaN Isolation Test")
    print("=" * 60)

    print("\n[1/6] Loading model...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.pack_all_experts()
    engine.kv_cache.enable_flat_decode()
    engine.eval()

    # Check VRAM
    vram = torch.cuda.memory_allocated() / 1e9
    print(f"  VRAM after load: {vram:.2f} GB")

    print("\n[2/6] Warmup...")
    warmup_ids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(warmup_ids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)

    print("\n[3/6] Test BEFORE enable_eagle()...")
    nan_before = check_forward(engine, tokenizer, "before eagle")

    if nan_before:
        print("\n  ERROR: NaN even before enable_eagle! Something wrong with model load.")
        sys.exit(1)

    print("\n[4/6] Test: just set _eagle_enabled=True (no head creation)...")
    engine._eagle_enabled = True
    engine._eagle_capture_set = {8, 24, 47}
    engine._eagle_capture_layers = [8, 24, 47]
    engine._eagle_hidden_states = {}
    nan_flag_only = check_forward(engine, tokenizer, "flag only")
    engine._eagle_enabled = False  # reset

    print("\n[5/6] Test: create eagle head + assign as submodule...")
    eagle_head = FireEchoEagleHead(
        dim=config.dim, num_capture_layers=3,
        num_heads=16, ffn_mult=2, num_layers=2,
    ).to(dtype=torch.bfloat16, device='cuda')
    eagle_head.lm_head = engine.lm_head
    engine.eagle_head = eagle_head  # registers as nn.Module submodule
    vram2 = torch.cuda.memory_allocated() / 1e9
    print(f"  VRAM after eagle head: {vram2:.2f} GB (+{vram2 - vram:.2f} GB)")
    nan_with_head = check_forward(engine, tokenizer, "with head (no ckpt)")

    print("\n[6/6] Test: load checkpoint into eagle head...")
    if os.path.exists(EAGLE_CKPT):
        ckpt = torch.load(EAGLE_CKPT, map_location='cuda', weights_only=True)
        sd = ckpt.get('eagle_head', ckpt)
        is_legacy = any(k.startswith('norm1.') or k.startswith('q_proj.') for k in sd)
        if is_legacy:
            eagle_head.load_legacy_checkpoint(sd)
        else:
            eagle_head.load_state_dict(sd, strict=False)
        nan_with_ckpt = check_forward(engine, tokenizer, "with ckpt")
    else:
        print(f"  No checkpoint at {EAGLE_CKPT}, skipping")
        nan_with_ckpt = nan_with_head

    # Summary
    print(f"\n{'=' * 60}")
    print("  RESULTS")
    print(f"{'=' * 60}")
    print(f"  Before eagle:      {'NaN' if nan_before else 'OK'}")
    print(f"  Flag only:         {'NaN' if nan_flag_only else 'OK'}")
    print(f"  With head (no ckpt): {'NaN' if nan_with_head else 'OK'}")
    print(f"  With checkpoint:   {'NaN' if nan_with_ckpt else 'OK'}")

    # If any NaN found, do per-layer analysis
    if nan_flag_only or nan_with_head or nan_with_ckpt:
        print(f"\n--- Per-layer NaN analysis ---")
        if nan_flag_only:
            engine._eagle_enabled = True
            engine._eagle_capture_set = {8, 24, 47}
            engine._eagle_capture_layers = [8, 24, 47]
            engine._eagle_hidden_states = {}
            check_per_layer(engine, tokenizer, "flag-only per-layer")
        elif nan_with_head or nan_with_ckpt:
            # eagle_head is still assigned
            engine._eagle_enabled = True
            engine._eagle_capture_set = {8, 24, 47}
            engine._eagle_capture_layers = [8, 24, 47]
            engine._eagle_hidden_states = {}
            check_per_layer(engine, tokenizer, "full-eagle per-layer")

            # Also test: head assigned but flag OFF
            print(f"\n--- Test: head assigned but _eagle_enabled=False ---")
            engine._eagle_enabled = False
            check_forward(engine, tokenizer, "head assigned, flag OFF")
    else:
        print("  All tests passed — no NaN detected!")