File size: 6,957 Bytes
b5bff9c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 | #!/usr/bin/env python3
"""Isolate exactly which step of enable_eagle() causes NaN in target model.
Tests each sub-step of enable_eagle() independently to find the culprit.
Also checks per-layer output to find where NaN first appears.
"""
import sys, os, torch, gc
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
from fireecho_kernel import FireEchoEagleHead
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
@torch.no_grad()
def check_forward(engine, tokenizer, label):
    """Run one cached forward pass over PROMPT and report NaN status.

    Args:
        engine: model engine exposing forward()/reset_cache() and a kv_cache.
        tokenizer: tokenizer with encode()/decode().
        label: tag included in the printed result line.

    Returns:
        bool: True if any NaN was found in the logits.
    """
    torch.cuda.synchronize()
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    # Fresh cache + fixed seq id so every probe starts from identical state.
    engine.reset_cache()
    engine._current_seq_id = 0
    # NOTE(review): _graph_mode presumably switches the cache to CUDA-graph
    # replay; force it off so this probe runs eagerly — confirm with kv_cache.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    has_nan = logits.isnan().any().item()
    if has_nan:
        # BUGFIX: the old message claimed "logits all NaN" although only
        # .any() was checked — report the measured NaN fraction instead.
        nan_frac = logits.isnan().float().mean().item()
        print(f" [{label}] NaN DETECTED — {nan_frac:.1%} of logits are NaN")
    else:
        # Only the last position matters for the greedy next-token readout.
        last = logits[:, -1, :]
        top_id = last.argmax(dim=-1).item()
        top_val = last.max().item()
        print(f" [{label}] OK — top token={top_id} "
              f"('{tokenizer.decode([top_id])}'), max={top_val:.2f}")
    return has_nan
@torch.no_grad()
def check_per_layer(engine, tokenizer, label):
    """Drive the transformer stack manually, one layer at a time, and
    report the first point (embed / layer i / norm / lm_head) where NaN
    shows up in the hidden states or logits."""
    token_ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # Keep the probe in eager mode if the cache supports graph capture.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False

    hidden = engine.embed(token_ids)
    embed_bad = hidden.isnan().any().item()
    print(f" [{label}] After embed: has_nan={embed_bad}")
    if embed_bad:
        return

    # Walk the decoder layers; stop at the first one that emits NaN.
    bad_layer = None
    for idx, block in enumerate(engine.layers):
        hidden = block(hidden, engine.kv_cache, engine._current_seq_id, 0, True)
        if hidden.isnan().any().item():
            bad_layer = idx
            print(f" [{label}] FIRST NaN at layer {idx} !!!")
            # Check sub-components
            break

    if bad_layer is not None:
        print(f" [{label}] NaN starts at layer {bad_layer}")
        return

    # All layers clean — finish with the final norm and projection head.
    hidden = engine.norm(hidden)
    norm_bad = hidden.isnan().any().item()
    print(f" [{label}] After norm: has_nan={norm_bad}")
    logits = engine.lm_head(hidden)
    head_bad = logits.isnan().any().item()
    print(f" [{label}] After lm_head: has_nan={head_bad}")
    if not head_bad:
        top_id = logits[:, -1, :].argmax(dim=-1).item()
        print(f" [{label}] Top token: {top_id} ('{tokenizer.decode([top_id])}')")
if __name__ == "__main__":
    print("=" * 60)
    print(" NaN Isolation Test")
    print("=" * 60)
    # [1/6] Load the target model and set it up the same way inference does:
    # pack experts, switch the KV cache to flat decode, eval mode.
    print("\n[1/6] Loading model...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.pack_all_experts()
    engine.kv_cache.enable_flat_decode()
    engine.eval()
    # Check VRAM
    vram = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM after load: {vram:.2f} GB")
    # [2/6] A few tiny greedy generations to warm up kernels before probing.
    print("\n[2/6] Warmup...")
    warmup_ids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(warmup_ids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
    # [3/6] Baseline: model must be NaN-free BEFORE touching eagle state;
    # otherwise the fault is in model loading, not in enable_eagle().
    print("\n[3/6] Test BEFORE enable_eagle()...")
    nan_before = check_forward(engine, tokenizer, "before eagle")
    if nan_before:
        print("\n ERROR: NaN even before enable_eagle! Something wrong with model load.")
        sys.exit(1)
    # [4/6] Sub-step A of enable_eagle(): only set capture flags/state —
    # no head module is created yet. Isolates the hidden-state capture path.
    print("\n[4/6] Test: just set _eagle_enabled=True (no head creation)...")
    engine._eagle_enabled = True
    engine._eagle_capture_set = {8, 24, 47}
    engine._eagle_capture_layers = [8, 24, 47]
    engine._eagle_hidden_states = {}
    nan_flag_only = check_forward(engine, tokenizer, "flag only")
    engine._eagle_enabled = False  # reset
    # [5/6] Sub-step B: instantiate the eagle head (bf16, CUDA), share the
    # target lm_head, and register it as a submodule — this allocates VRAM.
    print("\n[5/6] Test: create eagle head + assign as submodule...")
    eagle_head = FireEchoEagleHead(
        dim=config.dim, num_capture_layers=3,
        num_heads=16, ffn_mult=2, num_layers=2,
    ).to(dtype=torch.bfloat16, device='cuda')
    eagle_head.lm_head = engine.lm_head
    engine.eagle_head = eagle_head  # registers as nn.Module submodule
    vram2 = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM after eagle head: {vram2:.2f} GB (+{vram2 - vram:.2f} GB)")
    nan_with_head = check_forward(engine, tokenizer, "with head (no ckpt)")
    # [6/6] Sub-step C: load trained weights into the head. Legacy checkpoints
    # are detected by their old parameter names ('norm1.' / 'q_proj.').
    print("\n[6/6] Test: load checkpoint into eagle head...")
    if os.path.exists(EAGLE_CKPT):
        ckpt = torch.load(EAGLE_CKPT, map_location='cuda', weights_only=True)
        # Checkpoint may be either a wrapper dict or a raw state_dict.
        sd = ckpt.get('eagle_head', ckpt)
        is_legacy = any(k.startswith('norm1.') or k.startswith('q_proj.') for k in sd)
        if is_legacy:
            eagle_head.load_legacy_checkpoint(sd)
        else:
            eagle_head.load_state_dict(sd, strict=False)
        nan_with_ckpt = check_forward(engine, tokenizer, "with ckpt")
    else:
        print(f" No checkpoint at {EAGLE_CKPT}, skipping")
        nan_with_ckpt = nan_with_head  # no new info without a checkpoint
    # Summary
    print(f"\n{'=' * 60}")
    print(" RESULTS")
    print(f"{'=' * 60}")
    print(f" Before eagle: {'NaN' if nan_before else 'OK'}")
    print(f" Flag only: {'NaN' if nan_flag_only else 'OK'}")
    print(f" With head (no ckpt): {'NaN' if nan_with_head else 'OK'}")
    print(f" With checkpoint: {'NaN' if nan_with_ckpt else 'OK'}")
    # If any NaN found, do per-layer analysis
    if nan_flag_only or nan_with_head or nan_with_ckpt:
        print(f"\n--- Per-layer NaN analysis ---")
        if nan_flag_only:
            # The flag alone already breaks the forward pass — re-enable
            # capture state and locate the first NaN layer.
            engine._eagle_enabled = True
            engine._eagle_capture_set = {8, 24, 47}
            engine._eagle_capture_layers = [8, 24, 47]
            engine._eagle_hidden_states = {}
            check_per_layer(engine, tokenizer, "flag-only per-layer")
        elif nan_with_head or nan_with_ckpt:
            # eagle_head is still assigned
            engine._eagle_enabled = True
            engine._eagle_capture_set = {8, 24, 47}
            engine._eagle_capture_layers = [8, 24, 47]
            engine._eagle_hidden_states = {}
            check_per_layer(engine, tokenizer, "full-eagle per-layer")
            # Also test: head assigned but flag OFF — distinguishes "capture
            # path is broken" from "mere submodule registration is broken".
            print(f"\n--- Test: head assigned but _eagle_enabled=False ---")
            engine._eagle_enabled = False
            check_forward(engine, tokenizer, "head assigned, flag OFF")
    else:
        print(" All tests passed — no NaN detected!")
|