# FireEcho / FireEcho Engine / debug_bisect.py
# (Hugging Face upload metadata — Joysulem, "Upload 3258 files", commit b5bff9c
#  verified — preserved as a comment so the file parses as Python. NOTE: this
#  header precedes the shebang on the next line, so the shebang is inert.)
#!/usr/bin/env python3
"""Bisect: exactly which step of the training flow causes NaN.
Replicates train_eagle_head.py main() step by step, checking forward() after each.
FIXED: pack_all_experts + enable_flat_decode BEFORE first forward() call.
"""
import sys, os, torch, gc, time
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
from fireecho_kernel import FireEchoEagleHead
# Absolute path to the local Qwen3-Omni model weights passed to load_engine().
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
# EAGLE draft-head checkpoint written by the training script (eagle_best.pt).
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
# Minimal chat-formatted probe prompt used by check() for the NaN forward test.
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
@torch.no_grad()
def check(engine, tokenizer, label):
    """Run one forward pass over PROMPT and report whether the logits hold NaN.

    Prints a status line tagged with *label* (including current VRAM use) and
    returns True when any NaN was found in the logits, False otherwise.
    """
    token_ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # Make sure graph-capture mode is off (when the cache supports it) so this
    # is a plain eager forward pass.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(token_ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    has_nan = logits.isnan().any().item()
    vram = torch.cuda.memory_allocated() / 1e9
    if not has_nan:
        # Healthy path: show the greedy next-token prediction as a sanity check.
        top = logits[:, -1, :].argmax(dim=-1).item()
        print(f" [{label}] OK top={top} ('{tokenizer.decode([top])}') VRAM={vram:.2f}GB")
    else:
        print(f" [{label}] *** NaN DETECTED *** VRAM={vram:.2f}GB")
    return has_nan
@torch.no_grad()
def check_speculative(engine, tokenizer, label):
    """Test speculative_generate specifically.

    The NaN bug manifests in this path as degenerate output — every generated
    token identical (argmax over NaN logits) — so that signature is what gets
    flagged here. Returns True when the bug signature is detected.
    """
    prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWrite a Python function to check if a number is prime.<|im_end|>\n<|im_start|>assistant\n"
    ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
    engine.reset_cache()
    engine.eval()
    eos_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    # FIX: the original `if eos_id` treated a legitimate token id of 0 as a
    # failed lookup. Only fall back to the hard-coded id (presumably Qwen's
    # <|im_end|> = 151645 — confirm against the tokenizer) when the lookup
    # genuinely returned nothing.
    stop = [eos_id] if eos_id is not None else [151645]
    out = engine.speculative_generate(ids, max_new_tokens=20, temperature=0.0, stop_tokens=stop)
    gen_tokens = out[0, ids.shape[1]:].tolist()
    text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    # More than 3 generated tokens that are all identical = NaN signature.
    all_same = len(set(gen_tokens)) <= 1 if gen_tokens else True
    if all_same and len(gen_tokens) > 3:
        print(f" [{label}] *** ALL SAME TOKEN *** = NaN bug! tokens={gen_tokens[:5]}")
        return True
    else:
        print(f" [{label}] OK: '{text[:80]}' ({len(gen_tokens)} tokens, {len(set(gen_tokens))} unique)")
        return False
if __name__ == "__main__":
    # Bisection driver: replicate the training flow step by step, probing for
    # NaN in a forward pass after each step so the first failing step is found.
    print("=" * 60)
    print(" Training Flow Bisection (v2 — fixed)")
    print("=" * 60)
    # === Step 1: load_engine (matches training exactly) ===
    print("\n[Step 1] load_engine(max_seq_len=4096) + eval + flat_decode + pack...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()
    nan1 = check(engine, tokenizer, "after load+pack+flat")
    if nan1:
        # If the baseline already produces NaN, later steps tell us nothing.
        print(" FATAL: NaN at baseline! Cannot continue.")
        sys.exit(1)
    # === Step 2: enable_eagle D=8 (NO checkpoint, matches training) ===
    print("\n[Step 2] enable_eagle(D=8, no checkpoint)...")
    engine.enable_eagle(
        capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
        draft_depth=5, num_head_layers=8)
    nan2 = check(engine, tokenizer, "after eagle D=8 random")
    # === Step 3: create optimizer ===
    print("\n[Step 3] create AdamW optimizer...")
    eagle = engine.eagle_head
    # lm_head is excluded from training, matching train_eagle_head.py.
    eagle_params = [p for n, p in eagle.named_parameters()
                    if 'lm_head' not in n and p.requires_grad]
    optimizer = torch.optim.AdamW(eagle_params, lr=3e-4, betas=(0.9, 0.95), weight_decay=0.0)
    nan3 = check(engine, tokenizer, "after optimizer")
    # === Step 4: load_checkpoint (matches training: weights_only=False) ===
    print("\n[Step 4] load_checkpoint...")
    if os.path.exists(EAGLE_CKPT):
        # NOTE(review): weights_only=False unpickles arbitrary objects — fine
        # for a locally produced checkpoint, unsafe on untrusted files.
        ckpt = torch.load(EAGLE_CKPT, weights_only=False, map_location='cuda')
        sd = ckpt.get('eagle_head', ckpt)
        # Legacy checkpoints are detected by their old flat key names.
        is_legacy = any(k.startswith('norm1.') or k.startswith('q_proj.') for k in sd)
        if is_legacy:
            eagle.load_legacy_checkpoint(sd)
            print(" Loaded legacy checkpoint")
        else:
            eagle.load_state_dict(sd, strict=False)
            print(" Loaded new-format checkpoint")
        if 'optimizer' in ckpt:
            try:
                optimizer.load_state_dict(ckpt['optimizer'])
                print(" Loaded optimizer state")
            except (ValueError, KeyError) as e:
                # Param-group mismatch is non-fatal for this bisection.
                print(f" Optimizer mismatch: {e}")
        step = ckpt.get('step', 0)
        print(f" Step={step}")
        del ckpt
        torch.cuda.empty_cache()
    else:
        print(" No checkpoint found, using random weights")
    nan4 = check(engine, tokenizer, "after ckpt load")
    # === Step 5: warmup ===
    print("\n[Step 5] warmup 3x generate()...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for i in range(3):
        out = engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
        text = tokenizer.decode(out[0, wids.shape[1]:], skip_special_tokens=True)
        print(f" Warmup {i}: '{text}'")
    del wids
    nan5 = check(engine, tokenizer, "after warmup")
    # === Step 6: speculative_generate (the actual eval path) ===
    print("\n[Step 6] speculative_generate()...")
    nan6 = check_speculative(engine, tokenizer, "speculative_generate")
    # === Summary ===
    print("\n" + "=" * 60)
    print(" BISECTION RESULTS")
    print("=" * 60)
    results = [
        ("Step 1: load+pack+flat", nan1),
        ("Step 2: enable_eagle D=8", nan2),
        ("Step 3: create optimizer", nan3),
        ("Step 4: load checkpoint", nan4),
        ("Step 5: warmup", nan5),
        ("Step 6: speculative_generate", nan6),
    ]
    for name, had_nan in results:
        status = "*** NaN ***" if had_nan else "OK"
        print(f" {name}: {status}")
    # First step whose probe saw NaN (None when everything passed).
    first_fail = next((name for name, nan in results if nan), None)
    if first_fail:
        print(f"\n FIRST FAILURE: {first_fail}")
    else:
        # FIX: was an f-string with no placeholders (ruff F541).
        print("\n ALL PASSED — no NaN detected!")
    print("=" * 60)