"""Find exact sequence length threshold for NaN. Test with/without pack_all_experts."""
import os
import sys

import torch

# Make the sibling module importable when this script is run directly.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from hebbian_finetune_demo import load_engine

# Local checkpoint used for every engine load in this script.
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
@torch.no_grad()
def test_len(engine, tokenizer, n_tokens, label=""):
    """Run one forward pass over a prompt of exactly n_tokens tokens and report NaN status.

    Args:
        engine: inference engine exposing reset_cache(), forward(), kv_cache.
        tokenizer: tokenizer with an encode(..., return_tensors='pt') method.
        n_tokens: exact sequence length (in tokens) to test.
        label: extra tag printed next to the result (e.g. "(packed)").

    Returns:
        True if the output logits contain any NaN, False otherwise.
    """
    # Over-generate filler words, then truncate to exactly n_tokens.
    # NOTE: a single "word " is assumed to encode to >= 1 token; if the
    # tokenizer merges words and yields fewer than n_tokens ids, re-encode
    # with more repetitions so the truncated length is really n_tokens.
    reps = max(n_tokens, 1)
    ids = tokenizer.encode("word " * reps, return_tensors='pt').cuda()
    while ids.shape[1] < n_tokens:
        reps *= 2
        ids = tokenizer.encode("word " * reps, return_tensors='pt').cuda()
    ids = ids[:, :n_tokens]

    # Fresh cache state per probe so results are independent of prior runs.
    engine.reset_cache()
    engine._current_seq_id = 0
    # Disable CUDA-graph replay if the cache supports it: graph capture can
    # mask/alter numerics we are trying to observe.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False

    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()  # ensure the forward actually completed before checking
    has_nan = logits.isnan().any().item()
    status = "NaN" if has_nan else "OK"
    print(f"  len={n_tokens:4d} {label}: {status}")
    return has_nan
| |
|
| |
|
if __name__ == "__main__":
    # Decorative separator reused for all banners.
    BAR = "=" * 60

    print(BAR)
    print(" Sequence Length NaN Threshold Finder")
    print(BAR)

    # --- Pass 1: engine with packed experts -------------------------------
    print("\n[1] Loading engine (WITH pack)...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()

    print("\n[2] Testing WITH pack_all_experts (coarse)...")
    # Coarse sweep with a fine-grained run through 21..32 where the
    # threshold is suspected to sit.
    packed_lengths = [1, 5, 10, 15, 20] + list(range(21, 33)) + [40, 50, 64, 100]
    for seq_len in packed_lengths:
        test_len(engine, tokenizer, seq_len, "(packed)")

    # --- Pass 2: fresh engine, experts left unpacked ----------------------
    print("\n[3] Reloading engine WITHOUT pack_all_experts...")
    del engine  # release GPU memory before the second load
    torch.cuda.empty_cache()
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)

    print("\n[4] Testing WITHOUT pack_all_experts...")
    for seq_len in [1, 10, 20, 25, 30, 31, 32, 40, 50, 64, 100]:
        test_len(engine, tokenizer, seq_len, "(unpacked)")

    print("\n" + BAR)
    print(" DONE")
    print(BAR)
| |
|