#!/usr/bin/env python3
"""Test: does max_seq_len=512 vs 4096 cause NaN?"""
import os
import sys

import torch

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine

MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
PROMPT = (
    "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    "<|im_start|>user\nHello<|im_end|>\n"
    "<|im_start|>assistant\n"
)


@torch.no_grad()
def check(engine, tokenizer, label):
    """Run one prefill forward pass and report whether the logits contain NaN."""
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    # Fresh prefill state: clear the KV cache, reset the sequence id, and
    # force non-graph mode if the cache exposes that switch.
    engine.reset_cache()
    engine._current_seq_id = 0
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    has_nan = logits.isnan().any().item()
    if has_nan:
        print(f" [{label}] NaN DETECTED")
    else:
        top = logits[:, -1, :].argmax(dim=-1).item()
        print(f" [{label}] OK — top={top} ('{tokenizer.decode([top])}')")
    return has_nan


if __name__ == "__main__":
    print("=" * 60)
    print(" max_seq_len test")
    print("=" * 60)

    # Replicate EXACT training script flow: max_seq_len=512
    print("\n[1] load_engine(max_seq_len=512)...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=512, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()
    vram = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM: {vram:.2f} GB")

    # Warmup
    print("\n[2] Warmup...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)

    # Test WITHOUT eagle (should work)
    print("\n[3] Forward without eagle (max_seq_len=512)...")
    check(engine, tokenizer, "no eagle, seq=512")

    # Test WITH D=8 eagle
    print("\n[4] Enable D=8 eagle + checkpoint...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    vram = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM: {vram:.2f} GB")
    nan_512 = check(engine, tokenizer, "D=8, seq=512")

    print(f"\n{'='*60}")
    print(f" max_seq_len=512 + D=8: {'NaN' if nan_512 else 'OK'}")
    print(f"{'='*60}")
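    # --- Sketch of the 4096 half of the comparison --------------------------
    # The docstring asks "512 vs 4096", but only 512 is exercised above. Below
    # is a minimal, hedged sketch of the other half. It ASSUMES load_engine()
    # can be called a second time in the same process and that deleting the
    # first engine frees enough VRAM to reload the model; if either assumption
    # fails on your setup, run this half as a separate process instead.
    del engine
    torch.cuda.empty_cache()

    print("\n[5] load_engine(max_seq_len=4096)...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()
    # Same D=8 eagle configuration as step [4], so the only variable that
    # changes between the two runs is max_seq_len.
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    nan_4096 = check(engine, tokenizer, "D=8, seq=4096")

    print(f"\n{'='*60}")
    print(f" max_seq_len=4096 + D=8: {'NaN' if nan_4096 else 'OK'}")
    print(f"{'='*60}")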