| | |
| | """Bisect: exactly which step of the training flow causes NaN. |
| | |
| | Replicates train_eagle_head.py main() step by step, checking forward() after each. |
| | FIXED: pack_all_experts + enable_flat_decode BEFORE first forward() call. |
| | """ |
| | import sys, os, torch, gc, time |
| | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| | from hebbian_finetune_demo import load_engine |
| | from fireecho_kernel import FireEchoEagleHead |
| |
|
# Local path to the base Qwen3-Omni 30B MoE checkpoint under test.
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
# Best EAGLE draft-head checkpoint produced by train_eagle_head.py (may not exist yet).
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
# Minimal ChatML prompt used by check() for the teacher-forced NaN probe.
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
| |
|
| |
|
@torch.no_grad()
def check(engine, tokenizer, label):
    """Run one prefill forward() over PROMPT and report NaN status plus VRAM.

    Prints a one-line [label] report and returns True when the logits
    contain any NaN, False otherwise.
    """
    input_ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # NOTE(review): presumably forces an eager (non-CUDA-graph) decode path so
    # the freshly reset cache is actually read — confirm against the engine.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(input_ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    nan_found = logits.isnan().any().item()
    gb = torch.cuda.memory_allocated() / 1e9
    if not nan_found:
        best = logits[:, -1, :].argmax(dim=-1).item()
        print(f" [{label}] OK top={best} ('{tokenizer.decode([best])}') VRAM={gb:.2f}GB")
    else:
        print(f" [{label}] *** NaN DETECTED *** VRAM={gb:.2f}GB")
    return nan_found
| |
|
| |
|
@torch.no_grad()
def check_speculative(engine, tokenizer, label):
    """Test speculative_generate specifically.

    Greedily generates up to 20 tokens and flags the degenerate
    "same token repeated" output that the NaN bug manifests as.
    Returns True when that failure pattern is detected, else False.
    """
    prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWrite a Python function to check if a number is prime.<|im_end|>\n<|im_start|>assistant\n"
    ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
    engine.reset_cache()
    engine.eval()
    eos_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    # BUG FIX: the previous truthiness test (`if eos_id`) discarded a
    # legitimate eos id of 0, and convert_tokens_to_ids can return None for
    # unknown tokens — check identity against None, not truthiness.
    # 151645 is the Qwen <|im_end|> id used as a fallback.
    stop = [eos_id] if eos_id is not None else [151645]
    out = engine.speculative_generate(ids, max_new_tokens=20, temperature=0.0, stop_tokens=stop)
    gen_tokens = out[0, ids.shape[1]:].tolist()
    text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    # A run of identical tokens is the signature of NaN logits (argmax of
    # all-NaN rows is constant); require >3 tokens to avoid false positives.
    all_same = len(set(gen_tokens)) <= 1 if gen_tokens else True
    if all_same and len(gen_tokens) > 3:
        print(f" [{label}] *** ALL SAME TOKEN *** = NaN bug! tokens={gen_tokens[:5]}")
        return True
    else:
        print(f" [{label}] OK: '{text[:80]}' ({len(gen_tokens)} tokens, {len(set(gen_tokens))} unique)")
        return False
| |
|
| |
|
if __name__ == "__main__":
    # Bisection driver: replicate the train_eagle_head.py setup one step at a
    # time and probe forward() after each step to find which one introduces NaN.
    print("=" * 60)
    print(" Training Flow Bisection (v2 — fixed)")
    print("=" * 60)

    # Step 1: baseline — load the engine, switch to eval, enable the flat
    # decode path and pack experts BEFORE the first forward() (the fix noted
    # in the module docstring).
    print("\n[Step 1] load_engine(max_seq_len=4096) + eval + flat_decode + pack...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()
    nan1 = check(engine, tokenizer, "after load+pack+flat")
    if nan1:
        # If the bare engine already produces NaN, bisection is meaningless.
        print(" FATAL: NaN at baseline! Cannot continue.")
        sys.exit(1)

    # Step 2: attach a randomly initialized EAGLE draft head.
    print("\n[Step 2] enable_eagle(D=8, no checkpoint)...")
    engine.enable_eagle(
        capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
        draft_depth=5, num_head_layers=8)
    nan2 = check(engine, tokenizer, "after eagle D=8 random")

    # Step 3: build the optimizer exactly as training does (lm_head excluded,
    # frozen params skipped). Creating it should be a no-op for forward().
    print("\n[Step 3] create AdamW optimizer...")
    eagle = engine.eagle_head
    eagle_params = [p for n, p in eagle.named_parameters()
                    if 'lm_head' not in n and p.requires_grad]
    optimizer = torch.optim.AdamW(eagle_params, lr=3e-4, betas=(0.9, 0.95), weight_decay=0.0)
    nan3 = check(engine, tokenizer, "after optimizer")

    # Step 4: restore EAGLE head (and optionally optimizer) state from disk.
    print("\n[Step 4] load_checkpoint...")
    if os.path.exists(EAGLE_CKPT):
        # weights_only=False because the checkpoint also stores optimizer state.
        ckpt = torch.load(EAGLE_CKPT, weights_only=False, map_location='cuda')
        sd = ckpt.get('eagle_head', ckpt)
        # Old checkpoints used flat names like 'norm1.'/'q_proj.'; detect and
        # route them through the dedicated legacy loader.
        is_legacy = any(k.startswith('norm1.') or k.startswith('q_proj.') for k in sd)
        if is_legacy:
            eagle.load_legacy_checkpoint(sd)
            print(" Loaded legacy checkpoint")
        else:
            eagle.load_state_dict(sd, strict=False)
            print(" Loaded new-format checkpoint")
        if 'optimizer' in ckpt:
            try:
                optimizer.load_state_dict(ckpt['optimizer'])
                print(" Loaded optimizer state")
            except (ValueError, KeyError) as e:
                # Param-group mismatch (e.g. head architecture changed) is
                # tolerable for this probe — report and continue.
                print(f" Optimizer mismatch: {e}")
        step = ckpt.get('step', 0)
        print(f" Step={step}")
        # Drop the checkpoint dict promptly to keep VRAM headroom.
        del ckpt
        torch.cuda.empty_cache()
    else:
        print(" No checkpoint found, using random weights")
    nan4 = check(engine, tokenizer, "after ckpt load")

    # Step 5: three plain (non-speculative) greedy generations, mirroring the
    # warmup the training script performs.
    print("\n[Step 5] warmup 3x generate()...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for i in range(3):
        out = engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
        text = tokenizer.decode(out[0, wids.shape[1]:], skip_special_tokens=True)
        print(f" Warmup {i}: '{text}'")
    del wids
    nan5 = check(engine, tokenizer, "after warmup")

    # Step 6: the speculative path itself (detected via the all-same-token
    # heuristic rather than a direct logits probe).
    print("\n[Step 6] speculative_generate()...")
    nan6 = check_speculative(engine, tokenizer, "speculative_generate")

    # Summary table and first-failure report.
    print("\n" + "=" * 60)
    print(" BISECTION RESULTS")
    print("=" * 60)
    results = [
        ("Step 1: load+pack+flat", nan1),
        ("Step 2: enable_eagle D=8", nan2),
        ("Step 3: create optimizer", nan3),
        ("Step 4: load checkpoint", nan4),
        ("Step 5: warmup", nan5),
        ("Step 6: speculative_generate", nan6),
    ]
    for name, had_nan in results:
        status = "*** NaN ***" if had_nan else "OK"
        print(f" {name}: {status}")

    # First step whose probe reported NaN, or None if everything passed.
    first_fail = next((name for name, nan in results if nan), None)
    if first_fail:
        print(f"\n FIRST FAILURE: {first_fail}")
    else:
        print(f"\n ALL PASSED — no NaN detected!")
    print("=" * 60)
|