#!/usr/bin/env python3
"""Bisect: exactly which step of the training flow causes NaN.

Replicates train_eagle_head.py main() step by step, checking forward()
after each. FIXED: pack_all_experts + enable_flat_decode BEFORE first
forward() call.
"""
import sys, os, torch, gc, time

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
from fireecho_kernel import FireEchoEagleHead

MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"


@torch.no_grad()
def check(engine, tokenizer, label):
    """Run one plain prefill forward() and report whether the logits contain NaN.

    Args:
        engine: the inference engine under test (project type).
        tokenizer: its paired tokenizer.
        label: human-readable tag printed with the result.

    Returns:
        True if any NaN was found in the output logits, False otherwise.
    """
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # Graph-captured decode must be off for a plain prefill forward —
    # presumably _graph_mode gates a CUDA-graph replay path; TODO confirm.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    has_nan = logits.isnan().any().item()
    vram = torch.cuda.memory_allocated() / 1e9
    if has_nan:
        print(f" [{label}] *** NaN DETECTED *** VRAM={vram:.2f}GB")
    else:
        top = logits[:, -1, :].argmax(dim=-1).item()
        print(f" [{label}] OK top={top} ('{tokenizer.decode([top])}') VRAM={vram:.2f}GB")
    return has_nan


@torch.no_grad()
def check_speculative(engine, tokenizer, label):
    """Test speculative_generate specifically.

    A run whose generated tokens are all identical is treated as the
    NaN-in-logits signature (greedy decode collapses to one token).

    Returns:
        True if the degenerate all-same-token pattern was detected.
    """
    prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWrite a Python function to check if a number is prime.<|im_end|>\n<|im_start|>assistant\n"
    ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
    engine.reset_cache()
    engine.eval()
    eos_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    # FIX: the old truthiness check (`if eos_id`) also discarded a valid
    # token id of 0; only fall back to the hard-coded Qwen <|im_end|> id
    # when the vocabulary lookup actually failed.
    stop = [eos_id] if eos_id is not None else [151645]
    out = engine.speculative_generate(ids, max_new_tokens=20, temperature=0.0, stop_tokens=stop)
    gen_tokens = out[0, ids.shape[1]:].tolist()
    text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    all_same = len(set(gen_tokens)) <= 1 if gen_tokens else True
    if all_same and len(gen_tokens) > 3:
        print(f" [{label}] *** ALL SAME TOKEN *** = NaN bug! tokens={gen_tokens[:5]}")
        return True
    else:
        print(f" [{label}] OK: '{text[:80]}' ({len(gen_tokens)} tokens, {len(set(gen_tokens))} unique)")
        return False


def main():
    """Run the six bisection steps in training order, probing for NaN after each."""
    print("=" * 60)
    print(" Training Flow Bisection (v2 — fixed)")
    print("=" * 60)

    # === Step 1: load_engine (matches training exactly) ===
    print("\n[Step 1] load_engine(max_seq_len=4096) + eval + flat_decode + pack...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()
    nan1 = check(engine, tokenizer, "after load+pack+flat")
    if nan1:
        # Everything downstream depends on a clean baseline forward.
        print(" FATAL: NaN at baseline! Cannot continue.")
        sys.exit(1)

    # === Step 2: enable_eagle D=8 (NO checkpoint, matches training) ===
    print("\n[Step 2] enable_eagle(D=8, no checkpoint)...")
    engine.enable_eagle(
        capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
        draft_depth=5, num_head_layers=8)
    nan2 = check(engine, tokenizer, "after eagle D=8 random")

    # === Step 3: create optimizer ===
    print("\n[Step 3] create AdamW optimizer...")
    eagle = engine.eagle_head
    # Train everything except the (tied) lm_head, matching the training script.
    eagle_params = [p for n, p in eagle.named_parameters()
                    if 'lm_head' not in n and p.requires_grad]
    optimizer = torch.optim.AdamW(eagle_params, lr=3e-4, betas=(0.9, 0.95), weight_decay=0.0)
    nan3 = check(engine, tokenizer, "after optimizer")

    # === Step 4: load_checkpoint (matches training: weights_only=False) ===
    print("\n[Step 4] load_checkpoint...")
    if os.path.exists(EAGLE_CKPT):
        # SECURITY NOTE: weights_only=False executes arbitrary pickle code;
        # acceptable only because this checkpoint is produced locally by
        # train_eagle_head.py — never point EAGLE_CKPT at untrusted files.
        ckpt = torch.load(EAGLE_CKPT, weights_only=False, map_location='cuda')
        sd = ckpt.get('eagle_head', ckpt)
        # Legacy checkpoints use flat norm1./q_proj. key prefixes.
        is_legacy = any(k.startswith(('norm1.', 'q_proj.')) for k in sd)
        if is_legacy:
            eagle.load_legacy_checkpoint(sd)
            print(" Loaded legacy checkpoint")
        else:
            eagle.load_state_dict(sd, strict=False)
            print(" Loaded new-format checkpoint")
        if 'optimizer' in ckpt:
            try:
                optimizer.load_state_dict(ckpt['optimizer'])
                print(" Loaded optimizer state")
            except (ValueError, KeyError) as e:
                print(f" Optimizer mismatch: {e}")
        step = ckpt.get('step', 0)
        print(f" Step={step}")
        # Drop the CPU/GPU copy of the raw checkpoint before the next probe.
        del ckpt
        torch.cuda.empty_cache()
    else:
        print(" No checkpoint found, using random weights")
    nan4 = check(engine, tokenizer, "after ckpt load")

    # === Step 5: warmup ===
    print("\n[Step 5] warmup 3x generate()...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for i in range(3):
        out = engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
        text = tokenizer.decode(out[0, wids.shape[1]:], skip_special_tokens=True)
        print(f" Warmup {i}: '{text}'")
    del wids
    nan5 = check(engine, tokenizer, "after warmup")

    # === Step 6: speculative_generate (the actual eval path) ===
    print("\n[Step 6] speculative_generate()...")
    nan6 = check_speculative(engine, tokenizer, "speculative_generate")

    # === Summary ===
    print("\n" + "=" * 60)
    print(" BISECTION RESULTS")
    print("=" * 60)
    results = [
        ("Step 1: load+pack+flat", nan1),
        ("Step 2: enable_eagle D=8", nan2),
        ("Step 3: create optimizer", nan3),
        ("Step 4: load checkpoint", nan4),
        ("Step 5: warmup", nan5),
        ("Step 6: speculative_generate", nan6),
    ]
    for name, had_nan in results:
        status = "*** NaN ***" if had_nan else "OK"
        print(f" {name}: {status}")
    first_fail = next((name for name, nan in results if nan), None)
    if first_fail:
        print(f"\n FIRST FAILURE: {first_fail}")
    else:
        # FIX: was an f-string with no placeholders.
        print("\n ALL PASSED — no NaN detected!")
    print("=" * 60)


if __name__ == "__main__":
    main()