Long-context (>32K) crashes on ORT CUDA EP — fix proposal: slice logits to last token before LM head

#3
by NERDDISCO - opened

model_q4f16.onnx crashes on ORT CUDA EP for any prompt > 32 768 tokens

cudaErrorIllegalAddress (error 700) at /model/graph_outputs/logits/Cast. This affects everyone running this model with onnxruntime-gpu past 32K context, even though the model advertises 128K.

Root cause (in ORT, not this model): cast_op.cu uses an int32_t element counter. Final logits tensor element count = B × S_q × vocab = 1 × 32768 × 65536 = 2³¹ exactly, right at the edge of the int32 range (INT32_MAX = 2³¹ − 1). Past that → counter wraps → illegal access. Filed upstream at microsoft/onnxruntime#28385 (twin of the recently-fixed Gather variant #28107 / #28108).
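To make the arithmetic above concrete, here is a small stdlib-only sketch. It only counts elements and shows what a 32-bit signed value holds after truncation; it does not reproduce the ORT kernel, and the helper names (`logits_element_count`, `as_int32`) are made up for illustration. The vocab size of 65536 is taken from the numbers in this post.

```python
import ctypes

INT32_MAX = 2**31 - 1
VOCAB = 65536  # vocab size used in the post's arithmetic


def logits_element_count(batch: int, seq_len: int, vocab: int) -> int:
    """Number of elements in a [batch, seq_len, vocab] logits tensor."""
    return batch * seq_len * vocab


def as_int32(n: int) -> int:
    """Value a 32-bit signed counter would hold after two's-complement truncation."""
    return ctypes.c_int32(n).value


if __name__ == "__main__":
    for seq_len in (32767, 32768, 32769, 131072):
        n = logits_element_count(1, seq_len, VOCAB)
        print(f"S_q={seq_len:>6}  elements={n:>11}  "
              f"fits_int32={n <= INT32_MAX}  as_int32={as_int32(n)}")
    # With the last-token slice the count no longer depends on S_q at all.
    print("sliced:", logits_element_count(1, 1, VOCAB))
```

With logits shaped [B, 1, V] the element count is just B × V, so it stays far below the int32 range at any context length.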

Fix on the model side: re-export with a Slice before the LM head matmul so it only projects the last token (logits shape [B, 1, V] instead of [B, S_q, V]). This is what HF transformers calls logits_to_keep=1 and what production LLM serving stacks already do (vLLM, TGI, llama.cpp, onnxruntime-genai's prune_lm_head). The bare optimum-cli export onnx leaves this last-token slicing disabled by default (optimum-onnx/convert.py:592), which is why so many HF ONNX exports share this defect.
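A minimal NumPy sketch of why the slice is safe for next-token generation: projecting only the last hidden state gives the same logits as projecting everything and then taking the last row. This is an illustration with toy shapes and random weights, not the actual graph patch or export code; the function names are hypothetical, and it assumes the LM head is a plain matmul applied per position.

```python
import numpy as np


def full_logits(hidden: np.ndarray, lm_head: np.ndarray) -> np.ndarray:
    """Project every position: [B, S, H] @ [H, V] -> [B, S, V]."""
    return hidden @ lm_head


def last_token_logits(hidden: np.ndarray, lm_head: np.ndarray) -> np.ndarray:
    """Slice to the last position first, then project: -> [B, 1, V]."""
    return hidden[:, -1:, :] @ lm_head


def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Cosine similarity of two tensors, flattened and computed in float64."""
    a = a.ravel().astype(np.float64)
    b = b.ravel().astype(np.float64)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


# Toy sizes (assumed for illustration only).
B, S, H, V = 1, 16, 8, 32
rng = np.random.default_rng(0)
hidden = rng.standard_normal((B, S, H)).astype(np.float32)
lm_head = rng.standard_normal((H, V)).astype(np.float32)

full = full_logits(hidden, lm_head)
last = last_token_logits(hidden, lm_head)

print(full.shape, last.shape)
print(np.allclose(full[:, -1:, :], last, atol=1e-4))
print(cosine_similarity(full[:, -1:, :], last))
```

The trade-off is that per-position logits for the prompt are no longer available, which matters for things like prompt perplexity scoring but not for ordinary decoding.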

The upstream ORT fix will take time to land and propagate to a stable release. Until then, every onnxruntime-gpu user (including transformers.js / wandler / onnxruntime-genai consumers) is stuck at 32K. Re-exporting the model is the fastest path to fixing it for everyone.

I have a verified post-hoc graph patcher that does the slice and confirmed it on this exact model (works cleanly to 128 K, fp16 cos sim 1.00000): last_token_logits.py. Happy to PR a re-export to this repo if you share the export script you used.
