"""Test: does max_seq_len=512 vs 4096 cause NaN?"""
import sys, os, torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine

MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"


@torch.no_grad()
def check(engine, tokenizer, label):
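    """Run one cached prefill forward over PROMPT and report whether the logits contain NaN."""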
    ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
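    # Force the plain (non-graph) forward path if the cache exposes a graph-mode
    # flag; this is an assumption based on the flag's name.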
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    has_nan = logits.isnan().any().item()
    if has_nan:
        print(f" [{label}] NaN DETECTED")
    else:
        top = logits[:, -1, :].argmax(dim=-1).item()
        print(f" [{label}] OK — top={top} ('{tokenizer.decode([top])}')")
    return has_nan


if __name__ == "__main__":
    print("=" * 60)
    print(" max_seq_len test")
    print("=" * 60)
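
    # Step 1: load the engine with max_seq_len=512 while the flat-decode KV
    # buffer is sized for 4096 tokens, i.e. the mismatch this test is probing.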
| | print("\n[1] load_engine(max_seq_len=512)...") |
| | engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=512, device="cuda") |
| | engine.eval() |
| | engine.kv_cache.enable_flat_decode(4096) |
| | engine.pack_all_experts() |
| |
|
| | vram = torch.cuda.memory_allocated() / 1e9 |
| | print(f" VRAM: {vram:.2f} GB") |
| |
|
| | |
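
    # Step 2: a few short greedy generations to warm up kernels and lazy
    # allocations before the NaN checks.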
| | print("\n[2] Warmup...") |
| | wids = tokenizer.encode("Hello", return_tensors='pt').cuda() |
| | for _ in range(3): |
| | engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0) |
| |
|
| | |
| | print("\n[3] Forward without eagle (max_seq_len=512)...") |
| | check(engine, tokenizer, "no eagle, seq=512") |
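
    # Step 4: attach the eagle heads (num_head_layers=8, the "D=8" config),
    # load the trained checkpoint, then repeat the NaN check.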
    print("\n[4] Enable D=8 eagle + checkpoint...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    vram = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM: {vram:.2f} GB")
    nan_512 = check(engine, tokenizer, "D=8, seq=512")

    print(f"\n{'='*60}")
    print(f" max_seq_len=512 + D=8: {'NaN' if nan_512 else 'OK'}")
    print(f"{'='*60}")