| | |
| | """Isolate exactly what about D=8 causes NaN. |
| | |
| | Tests: |
| | 1. D=2 eagle head → forward → should be OK |
| | 2. D=8 eagle head (random, no ckpt) → forward → is NaN from VRAM pressure? |
| | 3. D=8 eagle head (random, NOT assigned to engine) → forward → is NaN from registration? |
| | 4. D=8 allocated but eagle_enabled=False → forward → is NaN from .to() side effect? |
| | """ |
| | import sys, os, torch, gc |
| | sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| | from hebbian_finetune_demo import load_engine |
| | from fireecho_kernel import FireEchoEagleHead |
| |
|
# Local filesystem path to the base model weights (Qwen3-Omni 30B MoE).
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
# Trained eagle-head checkpoint, used by the "with checkpoint" variants below.
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")

# ChatML-formatted prompt used for every forward-pass probe.
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"
| |
|
| |
|
@torch.no_grad()
def check(engine, tokenizer, label):
    """Run one forward pass over PROMPT and report whether logits contain NaN.

    Prints a one-line verdict tagged with *label* (and the greedy top token
    when clean) and returns True iff any NaN was found in the logits.
    """
    token_ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # NOTE(review): presumably this forces eager (non-CUDA-graph) decode so the
    # probe sees the real kernel path — confirm against the engine's kv cache.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(token_ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    found_nan = torch.isnan(logits).any().item()
    if not found_nan:
        top = logits[:, -1, :].argmax(dim=-1).item()
        print(f" [{label}] OK — top={top} ('{tokenizer.decode([top])}')")
    else:
        print(f" [{label}] NaN DETECTED")
    return found_nan
| |
|
| |
|
def _drop_eagle(engine):
    """Detach the engine-registered eagle head and release its VRAM."""
    del engine.eagle_head
    engine._eagle_enabled = False
    engine._eagle_hidden_states = {}
    torch.cuda.empty_cache()
    gc.collect()


def _print_vram(vram_base):
    """Print currently-allocated VRAM in GB and its delta from *vram_base*."""
    vram = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM: {vram:.2f} GB (+{vram - vram_base:.2f})")


if __name__ == "__main__":
    print("=" * 60)
    print(" D=8 NaN Isolation")
    print("=" * 60)

    # [1] Load the engine once; every variant below reuses it.
    print("\n[1] Loading model...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.pack_all_experts()
    engine.kv_cache.enable_flat_decode()
    engine.eval()

    # [2] Warm up kernels/allocator so later VRAM readings are stable.
    print("\n[2] Warmup...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for _ in range(3):
        engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)

    vram_base = torch.cuda.memory_allocated() / 1e9
    print(f" VRAM baseline: {vram_base:.2f} GB")

    # [3] Sanity check: the engine with no eagle head must be NaN-free.
    print("\n[3] Baseline (no eagle)...")
    check(engine, tokenizer, "baseline")

    # [4] Shallow D=2 head with trained weights — the known-good configuration.
    print("\n[4] D=2 eagle head...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=2, checkpoint_path=EAGLE_CKPT)
    _print_vram(vram_base)
    check(engine, tokenizer, "D=2")
    _drop_eagle(engine)

    # [5] D=8 with random init: separates VRAM pressure from checkpoint weights.
    print("\n[5] D=8 eagle head (random init, no checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8)
    _print_vram(vram_base)
    nan_d8_random = check(engine, tokenizer, "D=8 random")
    _drop_eagle(engine)

    # [6] D=8 with checkpoint weights loaded.
    print("\n[6] D=8 eagle head (with checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    _print_vram(vram_base)
    nan_d8_ckpt = check(engine, tokenizer, "D=8 with ckpt")
    _drop_eagle(engine)

    # [7] D=8 head allocated on-device but deliberately NOT registered on the
    # engine; capture flags are set by hand so only hidden-state capture runs.
    # Teardown differs from _drop_eagle: there is no engine.eagle_head to delete.
    print("\n[7] D=8 eagle head (allocated, NOT registered on engine)...")
    head_ext = FireEchoEagleHead(
        dim=config.dim, num_capture_layers=3,
        num_heads=16, ffn_mult=2, num_layers=8,
    ).to(dtype=torch.bfloat16, device='cuda')

    engine._eagle_enabled = True
    engine._eagle_capture_set = {8, 24, 47}
    engine._eagle_capture_layers = [8, 24, 47]
    engine._eagle_hidden_states = {}
    _print_vram(vram_base)
    nan_d8_unreg = check(engine, tokenizer, "D=8 unregistered")

    del head_ext
    engine._eagle_enabled = False
    torch.cuda.empty_cache()
    gc.collect()

    # [8] Intermediate depth D=4 to bracket where NaN first appears.
    print("\n[8] D=4 eagle head (checkpoint)...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=4, checkpoint_path=EAGLE_CKPT)
    _print_vram(vram_base)
    nan_d4 = check(engine, tokenizer, "D=4")
    _drop_eagle(engine)

    # [9] D=8 head registered but capture flag forced off: tests whether mere
    # allocation/registration of the head perturbs the forward pass.
    print("\n[9] D=8 eagle head, but _eagle_enabled=False...")
    engine.enable_eagle(capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
                        num_head_layers=8, checkpoint_path=EAGLE_CKPT)
    engine._eagle_enabled = False
    _print_vram(vram_base)
    nan_d8_flagoff = check(engine, tokenizer, "D=8 flag OFF")

    print(f"\n{'='*60}")
    print(" RESULTS")
    print(f"{'='*60}")
    print(f" D=8 random: {'NaN' if nan_d8_random else 'OK'}")
    print(f" D=8 with ckpt: {'NaN' if nan_d8_ckpt else 'OK'}")
    print(f" D=8 unregistered: {'NaN' if nan_d8_unreg else 'OK'}")
    print(f" D=4: {'NaN' if nan_d4 else 'OK'}")
    print(f" D=8 flag OFF: {'NaN' if nan_d8_flagoff else 'OK'}")
| |
|