#!/usr/bin/env python3
"""Bisect: exactly which step of the training flow causes NaN.
Replicates train_eagle_head.py main() step by step, checking forward() after each.
FIXED: pack_all_experts + enable_flat_decode BEFORE first forward() call.
"""
import sys, os, torch, gc, time
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
from fireecho_kernel import FireEchoEagleHead
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
@torch.no_grad()
def check(engine, tokenizer, label):
    """Probe one plain forward pass and report whether the logits contain NaN.

    Encodes the module-level PROMPT, resets the engine's KV cache, runs a
    single cached forward at position 0, and prints either the NaN marker or
    the greedy top token together with current VRAM usage.

    Returns:
        bool-like: truthy when any NaN was found in the logits.
    """
    token_ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # Force CUDA-graph decode mode off (when the cache exposes the flag) so
    # this ad-hoc forward does not replay a captured graph.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(token_ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    has_nan = logits.isnan().any().item()
    vram = torch.cuda.memory_allocated() / 1e9
    # Guard-clause style: report the failure and bail out early.
    if has_nan:
        print(f" [{label}] *** NaN DETECTED *** VRAM={vram:.2f}GB")
        return has_nan
    top = logits[:, -1, :].argmax(dim=-1).item()
    print(f" [{label}] OK top={top} ('{tokenizer.decode([top])}') VRAM={vram:.2f}GB")
    return has_nan
@torch.no_grad()
def check_speculative(engine, tokenizer, label):
    """Test speculative_generate specifically.

    Runs a short greedy speculative generation and detects the degenerate
    "every generated token is identical" signature, which in this engine
    indicates NaN propagating through the draft/verify path.

    Returns:
        bool: True when the all-same-token failure signature is detected.
    """
    prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWrite a Python function to check if a number is prime.<|im_end|>\n<|im_start|>assistant\n"
    ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
    engine.reset_cache()
    engine.eval()
    eos_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    # BUGFIX: the previous truthiness test (`if eos_id`) fell back to the
    # hard-coded id whenever the lookup returned 0 — a valid token id.
    # Only fall back when the lookup actually failed (returned None).
    stop = [eos_id] if eos_id is not None else [151645]
    out = engine.speculative_generate(ids, max_new_tokens=20, temperature=0.0, stop_tokens=stop)
    gen_tokens = out[0, ids.shape[1]:].tolist()
    text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    # `len(set([])) <= 1` is already True, so the empty list needs no
    # special case (the old `... if gen_tokens else True` was redundant).
    all_same = len(set(gen_tokens)) <= 1
    if all_same and len(gen_tokens) > 3:
        print(f" [{label}] *** ALL SAME TOKEN *** = NaN bug! tokens={gen_tokens[:5]}")
        return True
    else:
        print(f" [{label}] OK: '{text[:80]}' ({len(gen_tokens)} tokens, {len(set(gen_tokens))} unique)")
        return False
if __name__ == "__main__":
    # Bisection driver: replicate each step of the training flow in order and
    # run a NaN probe after every one, so the first step that corrupts the
    # forward pass is identified.
    print("=" * 60)
    print(" Training Flow Bisection (v2 — fixed)")
    print("=" * 60)

    # === Step 1: load_engine (matches training exactly) ===
    print("\n[Step 1] load_engine(max_seq_len=4096) + eval + flat_decode + pack...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()
    nan1 = check(engine, tokenizer, "after load+pack+flat")
    if nan1:
        # Baseline already broken — later steps would be meaningless.
        print(" FATAL: NaN at baseline! Cannot continue.")
        sys.exit(1)

    # === Step 2: enable_eagle D=8 (NO checkpoint, matches training) ===
    print("\n[Step 2] enable_eagle(D=8, no checkpoint)...")
    engine.enable_eagle(
        capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
        draft_depth=5, num_head_layers=8)
    nan2 = check(engine, tokenizer, "after eagle D=8 random")

    # === Step 3: create optimizer ===
    print("\n[Step 3] create AdamW optimizer...")
    eagle = engine.eagle_head
    # Exclude lm_head params and anything frozen, matching the training script.
    eagle_params = [p for n, p in eagle.named_parameters()
                    if 'lm_head' not in n and p.requires_grad]
    optimizer = torch.optim.AdamW(eagle_params, lr=3e-4, betas=(0.9, 0.95), weight_decay=0.0)
    nan3 = check(engine, tokenizer, "after optimizer")

    # === Step 4: load_checkpoint (matches training: weights_only=False) ===
    print("\n[Step 4] load_checkpoint...")
    if os.path.exists(EAGLE_CKPT):
        # NOTE(review): weights_only=False unpickles arbitrary objects; this is
        # acceptable only because the checkpoint is produced by local training.
        ckpt = torch.load(EAGLE_CKPT, weights_only=False, map_location='cuda')
        sd = ckpt.get('eagle_head', ckpt)
        # Legacy checkpoints use flat key names (norm1., q_proj., ...).
        is_legacy = any(k.startswith('norm1.') or k.startswith('q_proj.') for k in sd)
        if is_legacy:
            eagle.load_legacy_checkpoint(sd)
            print(" Loaded legacy checkpoint")
        else:
            eagle.load_state_dict(sd, strict=False)
            print(" Loaded new-format checkpoint")
        if 'optimizer' in ckpt:
            try:
                optimizer.load_state_dict(ckpt['optimizer'])
                print(" Loaded optimizer state")
            except (ValueError, KeyError) as e:
                # Shape/param-group mismatch is non-fatal for bisection.
                print(f" Optimizer mismatch: {e}")
        step = ckpt.get('step', 0)
        print(f" Step={step}")
        del ckpt
        torch.cuda.empty_cache()
    else:
        print(" No checkpoint found, using random weights")
    nan4 = check(engine, tokenizer, "after ckpt load")

    # === Step 5: warmup ===
    print("\n[Step 5] warmup 3x generate()...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for i in range(3):
        out = engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
        text = tokenizer.decode(out[0, wids.shape[1]:], skip_special_tokens=True)
        print(f" Warmup {i}: '{text}'")
    del wids
    nan5 = check(engine, tokenizer, "after warmup")

    # === Step 6: speculative_generate (the actual eval path) ===
    print("\n[Step 6] speculative_generate()...")
    nan6 = check_speculative(engine, tokenizer, "speculative_generate")

    # === Summary ===
    print("\n" + "=" * 60)
    print(" BISECTION RESULTS")
    print("=" * 60)
    results = [
        ("Step 1: load+pack+flat", nan1),
        ("Step 2: enable_eagle D=8", nan2),
        ("Step 3: create optimizer", nan3),
        ("Step 4: load checkpoint", nan4),
        ("Step 5: warmup", nan5),
        ("Step 6: speculative_generate", nan6),
    ]
    for name, had_nan in results:
        status = "*** NaN ***" if had_nan else "OK"
        print(f" {name}: {status}")
    first_fail = next((name for name, nan in results if nan), None)
    if first_fail:
        print(f"\n FIRST FAILURE: {first_fail}")
    else:
        # Plain literal: the old f-string had no placeholders (F541).
        print("\n ALL PASSED — no NaN detected!")
    print("=" * 60)
|