#!/usr/bin/env python3
"""Test: does prompt length cause NaN? Test with/without eagle.

The script runs a single forward pass over three prompts of increasing length
and reports whether the returned logits contain NaN. The probe is repeated
before and after warmup generation, with eagle enabled from a checkpoint, and
on a fresh engine with eagle enabled without a checkpoint.
"""
import os
import sys

import torch

# Absolute directory of this script; used both for the import path of the
# sibling `hebbian_finetune_demo` module and for locating the checkpoint.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, SCRIPT_DIR)

MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
EAGLE_CKPT = os.path.join(SCRIPT_DIR, "eagle_checkpoints", "eagle_best.pt")

MAX_SEQ_LEN = 4096

SHORT = "Hello"
MEDIUM = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\nHello<|im_end|>\n"
    "<|im_start|>assistant\n"
)
LONG = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\nWrite a Python function to check if a number is prime.<|im_end|>\n"
    "<|im_start|>assistant\n"
)

# (label prefix, prompt) pairs probed in every phase, shortest first.
PROMPTS = (("short", SHORT), ("medium", MEDIUM), ("long", LONG))

# Arguments shared by both enable_eagle() calls.
# NOTE(review): the phase banners say "D=8" while draft_depth is 5 and
# num_head_layers is 8 -- confirm which value "D" is meant to describe.
EAGLE_KWARGS = dict(
    capture_layers=(8, 24, 47),
    num_heads=16,
    ffn_mult=2,
    draft_depth=5,
    num_head_layers=8,
)


def count_nan_positions(logits):
    """Return the number of sequence positions in batch row 0 containing NaN.

    Args:
        logits: tensor indexed as ``[batch, position, vocab]`` (the same
            layout ``test_forward`` relies on when it takes ``[:, -1, :]``).

    Returns:
        int: how many positions of ``logits[0]`` have at least one NaN entry.
    """
    # One reduction over the last axis instead of a Python loop that
    # evaluates a tensor truthiness check per position.
    return int(logits[0].isnan().any(dim=-1).sum().item())


@torch.no_grad()
def test_forward(engine, tokenizer, label, prompt):
    """Run one forward pass over ``prompt`` and report NaN in the logits.

    The engine cache is reset and the pass starts at position 0, so each call
    is independent of earlier calls on the same engine.

    Args:
        engine: object providing ``reset_cache``, ``forward`` and ``kv_cache``.
        tokenizer: object providing ``encode`` and ``decode``.
        label: text printed in brackets to identify this probe.
        prompt: text to encode and feed through the engine.

    Returns:
        bool: True when any logit is NaN.
    """
    ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # Switch off the KV cache's graph-mode flag when the cache exposes one.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()

    has_nan = logits.isnan().any().item()
    if has_nan:
        nan_count = count_nan_positions(logits)
        print(f" [{label}] NaN! ({nan_count}/{logits.shape[1]} positions) len={ids.shape[1]}")
    else:
        top = logits[:, -1, :].argmax(dim=-1).item()
        print(f" [{label}] OK top={top} ('{tokenizer.decode([top])}') len={ids.shape[1]}")
    return has_nan


# The name starts with "test_" but this is a probe helper that needs a live
# engine, not a unit test; keep pytest from collecting it.
test_forward.__test__ = False


def _build_engine():
    """Load the model and prepare it for decoding; return (engine, tokenizer)."""
    # Imported at the point of use so that the helpers above can be imported
    # without loading the engine module.
    from hebbian_finetune_demo import load_engine

    engine, tokenizer, _config = load_engine(
        MODEL_PATH, max_seq_len=MAX_SEQ_LEN, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(MAX_SEQ_LEN)
    engine.pack_all_experts()
    return engine, tokenizer


def _warmup(engine, tokenizer, rounds=3):
    """Generate a few tokens from "Hello" ``rounds`` times."""
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(rounds):
        engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)


def _run_prompts(engine, tokenizer, tag):
    """Probe every prompt in PROMPTS, labelling each as "<name> (<tag>)"."""
    for name, prompt in PROMPTS:
        test_forward(engine, tokenizer, f"{name} ({tag})", prompt)


def main():
    """Run all five probe phases in order and print the results."""
    print("=" * 60)
    print(" Prompt Length NaN Test")
    print("=" * 60)

    print("\n[SETUP] Loading engine...")
    engine, tokenizer = _build_engine()

    # Test WITHOUT eagle
    print("\n[Phase 1] No eagle — varying prompt lengths...")
    _run_prompts(engine, tokenizer, "no eagle")

    # Warmup
    print("\n[Warmup]...")
    _warmup(engine, tokenizer)

    # Test again after warmup
    print("\n[Phase 2] No eagle, after warmup...")
    _run_prompts(engine, tokenizer, "warmed")

    # Enable eagle WITH checkpoint
    print("\n[Phase 3] Enable eagle D=8 with checkpoint...")
    engine.enable_eagle(checkpoint_path=EAGLE_CKPT, **EAGLE_KWARGS)
    _run_prompts(engine, tokenizer, "eagle+ckpt")

    # Warmup again after eagle
    print("\n[Warmup after eagle]...")
    _warmup(engine, tokenizer)

    print("\n[Phase 4] Eagle + ckpt, after warmup...")
    _run_prompts(engine, tokenizer, "eagle warmed")

    # Test: enable_eagle WITHOUT checkpoint
    print("\n[Phase 5] Fresh engine, eagle D=8 NO checkpoint...")
    # Drop the first engine before loading the second one.
    del engine
    torch.cuda.empty_cache()
    engine, tokenizer = _build_engine()
    engine.enable_eagle(**EAGLE_KWARGS)  # NO checkpoint
    _warmup(engine, tokenizer)
    _run_prompts(engine, tokenizer, "no ckpt")

    print("\n" + "=" * 60)
    print(" DONE")
    print("=" * 60)


if __name__ == "__main__":
    main()