File size: 2,232 Bytes
b5bff9c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/usr/bin/env python3
"""Find exact sequence length threshold for NaN. Test with/without pack_all_experts."""
import sys, os, torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine

MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"


@torch.no_grad()
def test_len(engine, tokenizer, n_tokens, label=""):
    """Run one prefill forward pass over exactly ``n_tokens`` tokens and
    report whether the resulting logits contain NaN.

    Args:
        engine: inference engine exposing ``forward()``, ``reset_cache()``
            and a ``kv_cache`` attribute (project type).
        tokenizer: tokenizer exposing ``encode(..., return_tensors='pt')``.
        n_tokens: target sequence length in tokens.
        label: suffix printed next to the result, e.g. ``"(packed)"``.

    Returns:
        bool: True if any logit is NaN, False otherwise.

    Raises:
        ValueError: if the repeated prompt encodes to fewer than
            ``n_tokens`` tokens, so an exact-length test is impossible.
    """
    # Over-generate by 2x: "word " * n may encode to fewer than n tokens
    # (BPE merges, special-token handling vary by tokenizer).
    base = "word " * max(2 * n_tokens, 1)
    ids = tokenizer.encode(base, return_tensors='pt').cuda()
    if ids.shape[1] < n_tokens:
        # Bug fix: the original sliced without checking, so a short encoding
        # silently tested a shorter sequence while printing len=n_tokens.
        raise ValueError(
            f"prompt encoded to {ids.shape[1]} tokens, need {n_tokens}"
        )
    # Truncate to the exact requested length.
    ids = ids[:, :n_tokens]
    engine.reset_cache()
    engine._current_seq_id = 0
    # Graph-capture mode (if present) must be off for a plain prefill pass.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()
    has_nan = logits.isnan().any().item()
    status = "NaN" if has_nan else "OK"
    print(f"  len={n_tokens:4d} {label}: {status}")
    return has_nan


def _load_configured_engine(pack):
    """Load and configure the engine; dedupes the phase [1]/[3] setup.

    Args:
        pack: when True, call ``engine.pack_all_experts()`` after setup.

    Returns:
        (engine, tokenizer, config) with the engine in eval mode and
        flat decode enabled for the full 4096-token window.
    """
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    if pack:
        engine.pack_all_experts()
    return engine, tokenizer, config


if __name__ == "__main__":
    print("=" * 60)
    print("  Sequence Length NaN Threshold Finder")
    print("=" * 60)

    print("\n[1] Loading engine (WITH pack)...")
    engine, tokenizer, config = _load_configured_engine(pack=True)

    # Dense sampling in 20-32 where the NaN threshold is suspected to sit.
    print("\n[2] Testing WITH pack_all_experts (coarse)...")
    for n in [1, 5, 10, 15, *range(20, 33), 40, 50, 64, 100]:
        test_len(engine, tokenizer, n, "(packed)")

    print("\n[3] Reloading engine WITHOUT pack_all_experts...")
    # Drop the packed engine and release its VRAM before the reload.
    del engine
    torch.cuda.empty_cache()
    engine, tokenizer, config = _load_configured_engine(pack=False)

    print("\n[4] Testing WITHOUT pack_all_experts...")
    for n in [1, 10, 20, 25, 30, 31, 32, 40, 50, 64, 100]:
        test_len(engine, tokenizer, n, "(unpacked)")

    print("\n" + "=" * 60)
    print("  DONE")
    print("=" * 60)