File size: 6,316 Bytes
b5bff9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
#!/usr/bin/env python3
"""Bisect: exactly which step of the training flow causes NaN.

Replicates train_eagle_head.py main() step by step, checking forward() after each.
FIXED: pack_all_experts + enable_flat_decode BEFORE first forward() call.
"""
import sys, os, torch, gc, time
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
from fireecho_kernel import FireEchoEagleHead

MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
EAGLE_CKPT = os.path.join(os.path.dirname(__file__), "eagle_checkpoints", "eagle_best.pt")
PROMPT = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n"


@torch.no_grad()
def check(engine, tokenizer, label):
    """Probe the engine with one cached forward pass over PROMPT.

    Prints a one-line status tagged with *label* (top token on success, a NaN
    banner on failure) plus current VRAM usage, and returns True iff the
    logits contained any NaN.
    """
    input_ids = tokenizer.encode(PROMPT, return_tensors='pt').cuda()
    engine.reset_cache()
    engine._current_seq_id = 0
    # Graph-captured flat decode must be disabled for this ad-hoc probe pass.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(input_ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    nan_found = bool(logits.isnan().any().item())
    vram_gb = torch.cuda.memory_allocated() / 1e9
    if nan_found:
        print(f"  [{label}] *** NaN DETECTED *** VRAM={vram_gb:.2f}GB")
        return nan_found
    top = logits[:, -1, :].argmax(dim=-1).item()
    print(f"  [{label}] OK top={top} ('{tokenizer.decode([top])}') VRAM={vram_gb:.2f}GB")
    return nan_found


@torch.no_grad()
def check_speculative(engine, tokenizer, label):
    """Test speculative_generate specifically.

    Runs a short greedy speculative generation and flags the degenerate
    "every token identical" output pattern, which is the observable symptom
    of the NaN bug. Prints a status line tagged with *label* and returns
    True iff the bug signature was detected.
    """
    prompt = "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nWrite a Python function to check if a number is prime.<|im_end|>\n<|im_start|>assistant\n"
    ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
    engine.reset_cache()
    engine.eval()
    eos_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
    # BUG FIX: the old truthiness test (`if eos_id`) treated a valid token
    # id of 0 as "missing"; only None (token not in vocab) should trigger
    # the hard-coded Qwen fallback id.
    stop = [eos_id] if eos_id is not None else [151645]
    out = engine.speculative_generate(ids, max_new_tokens=20, temperature=0.0, stop_tokens=stop)
    gen_tokens = out[0, ids.shape[1]:].tolist()
    text = tokenizer.decode(gen_tokens, skip_special_tokens=True)
    # len(set([])) <= 1 already holds for empty output, so no special case needed.
    all_same = len(set(gen_tokens)) <= 1
    if all_same and len(gen_tokens) > 3:
        print(f"  [{label}] *** ALL SAME TOKEN *** = NaN bug! tokens={gen_tokens[:5]}")
        return True
    else:
        print(f"  [{label}] OK: '{text[:80]}' ({len(gen_tokens)} tokens, {len(set(gen_tokens))} unique)")
        return False


if __name__ == "__main__":
    print("=" * 60)
    print("  Training Flow Bisection (v2 — fixed)")
    print("=" * 60)

    # === Step 1: load_engine (matches training exactly) ===
    # Order matters: pack_all_experts + enable_flat_decode must happen BEFORE
    # the first forward() (see module docstring) — this is the "fixed" part of v2.
    print("\n[Step 1] load_engine(max_seq_len=4096) + eval + flat_decode + pack...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()
    nan1 = check(engine, tokenizer, "after load+pack+flat")
    if nan1:
        # If the untouched engine already produces NaN, nothing downstream is meaningful.
        print("  FATAL: NaN at baseline! Cannot continue.")
        sys.exit(1)

    # === Step 2: enable_eagle D=8 (NO checkpoint, matches training) ===
    # Randomly-initialized EAGLE head; isolates whether merely attaching the
    # head (capture hooks on layers 8/24/47) corrupts the base forward pass.
    print("\n[Step 2] enable_eagle(D=8, no checkpoint)...")
    engine.enable_eagle(
        capture_layers=(8, 24, 47), num_heads=16, ffn_mult=2,
        draft_depth=5, num_head_layers=8)
    nan2 = check(engine, tokenizer, "after eagle D=8 random")

    # === Step 3: create optimizer ===
    # lm_head params are excluded from training, mirroring train_eagle_head.py.
    print("\n[Step 3] create AdamW optimizer...")
    eagle = engine.eagle_head
    eagle_params = [p for n, p in eagle.named_parameters()
                    if 'lm_head' not in n and p.requires_grad]
    optimizer = torch.optim.AdamW(eagle_params, lr=3e-4, betas=(0.9, 0.95), weight_decay=0.0)
    nan3 = check(engine, tokenizer, "after optimizer")

    # === Step 4: load_checkpoint (matches training: weights_only=False) ===
    # NOTE(review): weights_only=False allows arbitrary pickle execution; fine
    # for this trusted local checkpoint, but replicated here only to match training.
    print("\n[Step 4] load_checkpoint...")
    if os.path.exists(EAGLE_CKPT):
        ckpt = torch.load(EAGLE_CKPT, weights_only=False, map_location='cuda')
        sd = ckpt.get('eagle_head', ckpt)
        # Legacy (pre-refactor) checkpoints use flat top-level key names.
        is_legacy = any(k.startswith('norm1.') or k.startswith('q_proj.') for k in sd)
        if is_legacy:
            eagle.load_legacy_checkpoint(sd)
            print("  Loaded legacy checkpoint")
        else:
            eagle.load_state_dict(sd, strict=False)
            print("  Loaded new-format checkpoint")
        if 'optimizer' in ckpt:
            try:
                optimizer.load_state_dict(ckpt['optimizer'])
                print("  Loaded optimizer state")
            except (ValueError, KeyError) as e:
                # Shape/param-group mismatch (e.g. architecture changed) — continue without it.
                print(f"  Optimizer mismatch: {e}")
        step = ckpt.get('step', 0)
        print(f"  Step={step}")
        del ckpt
        torch.cuda.empty_cache()
    else:
        print("  No checkpoint found, using random weights")
    nan4 = check(engine, tokenizer, "after ckpt load")

    # === Step 5: warmup ===
    # Greedy generate() a few times, as training does before evaluation.
    print("\n[Step 5] warmup 3x generate()...")
    wids = tokenizer.encode("Hello", return_tensors='pt').cuda()
    for i in range(3):
        out = engine.generate(wids, max_new_tokens=5, temperature=0.0, top_k=0, top_p=1.0)
        text = tokenizer.decode(out[0, wids.shape[1]:], skip_special_tokens=True)
        print(f"  Warmup {i}: '{text}'")
    del wids
    nan5 = check(engine, tokenizer, "after warmup")

    # === Step 6: speculative_generate (the actual eval path) ===
    print("\n[Step 6] speculative_generate()...")
    nan6 = check_speculative(engine, tokenizer, "speculative_generate")

    # === Summary ===
    print("\n" + "=" * 60)
    print("  BISECTION RESULTS")
    print("=" * 60)
    results = [
        ("Step 1: load+pack+flat", nan1),
        ("Step 2: enable_eagle D=8", nan2),
        ("Step 3: create optimizer", nan3),
        ("Step 4: load checkpoint", nan4),
        ("Step 5: warmup", nan5),
        ("Step 6: speculative_generate", nan6),
    ]
    for name, had_nan in results:
        status = "*** NaN ***" if had_nan else "OK"
        print(f"  {name}: {status}")

    # First failing step is the bisection answer.
    first_fail = next((name for name, nan in results if nan), None)
    if first_fail:
        print(f"\n  FIRST FAILURE: {first_fail}")
    else:
        print(f"\n  ALL PASSED — no NaN detected!")
    print("=" * 60)