# FireEcho / FireEcho Engine / debug_seqlen_threshold.py
# (Hugging Face upload metadata: Joysulem's picture — "Upload 3258 files" — b5bff9c verified)
#!/usr/bin/env python3
"""Find exact sequence length threshold for NaN. Test with/without pack_all_experts."""
import sys, os, torch
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from hebbian_finetune_demo import load_engine
# Local filesystem path to the Qwen3-Omni checkpoint under test — edit to match your machine.
MODEL_PATH = "/run/media/echo/Echo/ECHO/training/Prototype Fireecho/model/Qwen3-Omni-30B-A3B-Instruct"
@torch.no_grad()
def test_len(engine, tokenizer, n_tokens, label=""):
    """Run one prefill forward pass with a prompt of exactly n_tokens tokens.

    Builds a repetitive prompt, truncates it to ``n_tokens``, resets the
    engine's KV cache, runs ``engine.forward``, and reports whether the
    output logits contain NaN.

    Args:
        engine: loaded FireEcho engine (exposes reset_cache/forward/kv_cache).
        tokenizer: HF-style tokenizer with ``encode(..., return_tensors='pt')``.
        n_tokens: target sequence length in tokens.
        label: extra tag appended to the printed status line.

    Returns:
        True if any logit is NaN, else False.
    """
    # Repeat enough text that encoding yields at least n_tokens tokens.
    # "word " is presumably ~1 token per repeat; the +8 margin covers any
    # BOS/special tokens the tokenizer prepends — TODO confirm per tokenizer.
    base = "word " * (max(n_tokens, 1) + 8)
    ids = tokenizer.encode(base, return_tensors='pt').cuda()
    # Truncate to the exact target length.
    ids = ids[:, :n_tokens]
    actual = ids.shape[1]
    if actual < n_tokens:
        # Original code silently tested a shorter sequence here while the
        # printout still claimed len=n_tokens; report the shortfall instead.
        print(f" len={n_tokens:4d} {label}: SKIP (only {actual} tokens encoded)")
        return False
    engine.reset_cache()
    engine._current_seq_id = 0
    # Disable CUDA-graph flat-decode mode for this variable-length prefill.
    if hasattr(engine.kv_cache, '_graph_mode'):
        engine.kv_cache._graph_mode = False
    logits = engine.forward(ids, use_cache=True, position=0)
    torch.cuda.synchronize()  # surface any pending async CUDA error now
    has_nan = logits.isnan().any().item()
    status = "NaN" if has_nan else "OK"
    print(f" len={n_tokens:4d} {label}: {status}")
    return has_nan
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print(" Sequence Length NaN Threshold Finder")
    print(banner)

    # --- Pass 1: engine with experts packed ---
    print("\n[1] Loading engine (WITH pack)...")
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    engine.pack_all_experts()

    # Coarse sweep of sequence lengths, dense around the suspected 20-32 range.
    print("\n[2] Testing WITH pack_all_experts (coarse)...")
    packed_lengths = [1, 5, 10, 15, 20, 21, 22, 23, 24, 25, 26,
                      27, 28, 29, 30, 31, 32, 40, 50, 64, 100]
    for length in packed_lengths:
        test_len(engine, tokenizer, length, "(packed)")

    # --- Pass 2: fresh engine, experts deliberately left unpacked ---
    print("\n[3] Reloading engine WITHOUT pack_all_experts...")
    del engine
    torch.cuda.empty_cache()
    engine, tokenizer, config = load_engine(MODEL_PATH, max_seq_len=4096, device="cuda")
    engine.eval()
    engine.kv_cache.enable_flat_decode(4096)
    # NOTE: pack_all_experts() intentionally NOT called for this pass.

    print("\n[4] Testing WITHOUT pack_all_experts...")
    unpacked_lengths = [1, 10, 20, 25, 30, 31, 32, 40, 50, 64, 100]
    for length in unpacked_lengths:
        test_len(engine, tokenizer, length, "(unpacked)")

    print("\n" + banner)
    print(" DONE")
    print(banner)