F2: silent wrong results from `add_lora_sgmv_cutlass` on single segments of 144 or more tokens
Summary
`add_lora_sgmv_cutlass` silently produces incorrect numerical results when it is invoked with a single segment whose token count T is 144 or greater. The kernel returns without raising and the output tensor is well-formed, but its values diverge from the PEFT-factored reference by roughly 20x the 1e-2 bf16 GEMM accumulation tolerance.
The pattern is non-monotonic: most lengths at or above 144 are broken, with sporadic correct outputs at a small number of intermediate lengths. This is consistent with a CTA-tile / shared-memory boundary in the underlying CUTLASS implementation.
Environment
- Kernel repo snapshot: `kernels-community/punica-sgmv` at commit `be89a97bbd04562e2834d5c8f0e9342dc7ae7715`
- Build variant resolved: `build/torch29-cxx11-cu130-x86_64-linux/`
- PyTorch: 2.9.1+cu130
- CUDA: 13.0
- Python: 3.11.2
- `kernels` runtime: >=0.5
- `peft`: 0.19.0
- GPU A: NVIDIA RTX 5080, SM 12.0
- GPU B: NVIDIA RTX PRO 6000 Blackwell Server Edition, SM 10.0
- Both architectures are listed in the kernel's `metadata.json` under `backend.archs`.
Minimal repro
```python
import torch
from kernels import get_kernel

kernel = get_kernel("kernels-community/punica-sgmv")

device = "cuda"
dtype = torch.bfloat16
in_features, out_features, rank = 512, 512, 16
scaling = 32.0 / float(rank)  # lora_alpha / r, with lora_alpha = 32
TOL = 1e-2  # bf16 GEMM accumulation tolerance used for the OK/BROKEN verdict

torch.manual_seed(0)
# PEFT layout: lora_A is (r, in_features), lora_B is (out_features, r).
lora_A = (torch.randn(rank, in_features, dtype=torch.float32) * 0.02).to(device, dtype)
lora_B = (torch.randn(out_features, rank, dtype=torch.float32) * 0.02).to(device, dtype)

# Stacked single-layer weights in the layout passed to the SGMV kernel.
wa = lora_A.unsqueeze(0).contiguous()                  # (1, r, in_features)
wb = lora_B.transpose(0, 1).unsqueeze(0).contiguous()  # (1, r, out_features)


def factored_ref(x):
    # PEFT-style factored reference: ((x @ A^T) @ B^T) * scaling.
    return ((x @ lora_A.T) @ lora_B.T) * scaling


def sgmv(x):
    n_tokens = x.shape[0]
    y = torch.zeros(n_tokens, out_features, device=device, dtype=dtype)
    # One segment covering all tokens, pointing at the single LoRA factor pair.
    wa_ptr = torch.tensor([wa.data_ptr()], dtype=torch.int64, device=device)
    wb_ptr = torch.tensor([wb.data_ptr()], dtype=torch.int64, device=device)
    s_start = torch.tensor([0], dtype=torch.int32, device=device)
    s_end = torch.tensor([n_tokens], dtype=torch.int32, device=device)
    # Accumulates the LoRA delta into y (layer index 0, LoRA rank).
    kernel.add_lora_sgmv_cutlass(y, x, wa_ptr, wb_ptr, s_start, s_end, 0, rank)
    return y * scaling


for n_tokens in [64, 128, 144, 176, 192, 256, 512]:
    torch.manual_seed(n_tokens)
    x = torch.randn(n_tokens, in_features, device=device, dtype=dtype).contiguous()
    y_kernel = sgmv(x)
    y_ref = factored_ref(x)
    err = (y_kernel.float() - y_ref.float()).abs().max().item()
    status = "OK" if err <= TOL else "BROKEN"
    print(f"n_tokens={n_tokens:>4} max_abs_err={err:.3e} {status}")
```
Expected behavior
The kernel result should match the PEFT-factored reference to within the bf16 GEMM accumulation tolerance (we use a max absolute error of 1e-2). This should hold for every segment length up to whatever ceiling the kernel documents.
Actual behavior
Measured on the RTX 5080 (SM 12.0, GPU A) staging machine with the snapshot above:
```
n_tokens=  64 max_abs_err=9.77e-04 OK
n_tokens= 128 max_abs_err=9.77e-04 OK
n_tokens= 144 max_abs_err=1.98e-01 BROKEN
n_tokens= 176 max_abs_err=9.77e-04 OK
n_tokens= 192 max_abs_err=1.98e-01 BROKEN
n_tokens= 256 max_abs_err=1.98e-01 BROKEN
n_tokens= 512 max_abs_err=1.98e-01 BROKEN
```
The error is ~20x the 1e-2 tolerance and ~200x the error seen at the correct lengths (9.77e-04). The kernel does not raise. The break first appears at n_tokens = 144. Beyond that the pattern is non-monotonic: we observed occasional correct outputs at 176, 184, 191, 320, and 480, and silent failures at every other length we tested above 144.
For comparison, calling the kernel with multiple segments of <=128 tokens each (all pointing at the same homogeneous LoRA factor pair) returns correct results at any aggregate n_tokens.
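Because the single-segment failures are scattered rather than forming a clean threshold, a full sweep that reports contiguous broken ranges may map the pattern better than spot checks. A minimal sketch using the same 1e-2 tolerance; `broken_runs` and the `err_fn` callback are hypothetical helpers, with `err_fn(n)` assumed to return the max absolute error between `sgmv(x)` and `factored_ref(x)` from the repro for an `n`-token input:

```python
from typing import Callable, Iterable, List, Tuple


def broken_runs(
    err_fn: Callable[[int], float],
    lengths: Iterable[int],
    tol: float = 1e-2,
) -> List[Tuple[int, int]]:
    """Group lengths whose error exceeds tol into inclusive (first, last) runs.

    `lengths` is expected in ascending order; a run continues only while
    consecutive integers are broken.
    """
    runs: List[Tuple[int, int]] = []
    for n in lengths:
        if err_fn(n) <= tol:
            continue  # within tolerance: treat as correct
        if runs and runs[-1][1] == n - 1:
            runs[-1] = (runs[-1][0], n)  # extend the current broken run
        else:
            runs.append((n, n))  # start a new broken run
    return runs


# Hypothetical usage with the repro's helpers (needs a CUDA device):
#
# def err_fn(n):
#     torch.manual_seed(n)
#     x = torch.randn(n, in_features, device=device, dtype=dtype)
#     return (sgmv(x).float() - factored_ref(x).float()).abs().max().item()
#
# print(broken_runs(err_fn, range(1, 1025)))
```

Reporting runs rather than individual lengths keeps the output short even for a sweep over every length up to 1024.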
Workaround
In our wrapper we cap each segment at `MAX_SAFE_SGMV_SEGMENT = 128` and chunk longer flat-batch inputs into multiple equivalent segments that all point at the same LoRA factor pair. We chose 128 as a conservative ceiling: it sits below the first observed failure at 144, and with chunking, decode and prefill batches of arbitrary total length run without numerical drift. The workaround is entirely caller-side; the kernel is not modified.
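The chunking can be sketched as below. This is a minimal illustration rather than our actual wrapper: `chunk_segments` is a hypothetical helper, and the commented usage assumes one `wa_ptr`/`wb_ptr` entry per segment, all pointing at the same LoRA factor pair, matching the multi-segment calls that returned correct results.

```python
from typing import List, Tuple

MAX_SAFE_SGMV_SEGMENT = 128  # caller-side ceiling from the workaround


def chunk_segments(
    n_tokens: int, max_seg: int = MAX_SAFE_SGMV_SEGMENT
) -> Tuple[List[int], List[int]]:
    """Split rows [0, n_tokens) into contiguous segments of at most max_seg rows.

    Returns half-open (starts, ends) lists, one entry per segment.
    """
    if n_tokens < 0 or max_seg <= 0:
        raise ValueError("need n_tokens >= 0 and max_seg > 0")
    starts = list(range(0, n_tokens, max_seg))
    ends = [min(s + max_seg, n_tokens) for s in starts]
    return starts, ends


# Hypothetical usage inside the repro's sgmv() (needs a CUDA device):
#
# starts, ends = chunk_segments(x.shape[0])
# s_start = torch.tensor(starts, dtype=torch.int32, device=device)
# s_end = torch.tensor(ends, dtype=torch.int32, device=device)
# wa_ptr = torch.tensor([wa.data_ptr()] * len(starts), dtype=torch.int64, device=device)
# wb_ptr = torch.tensor([wb.data_ptr()] * len(starts), dtype=torch.int64, device=device)
# kernel.add_lora_sgmv_cutlass(y, x, wa_ptr, wb_ptr, s_start, s_end, 0, rank)
```

Because every segment references the same factors, the chunked call computes the same product as a single long segment; only the per-segment row count changes.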
Suggested fix direction
The non-monotonic pattern starting at 144 is consistent with a CTA-tile-count or shared-memory-block-size boundary inside the SGMV CUTLASS dispatch. Likely sites to inspect:
- The threadblock `M`-tile size used for the shrink and expand kernels. The safe ceiling of 128 and the first failure at 144 suggest a tile boundary at or near 128, with a follow-on correctness issue when the residual block is only partially populated.
- Output-pointer arithmetic for segments whose token count is not a multiple of the tile `M`. If the kernel writes past the end of the segment range when the residual partial tile is non-zero, it could silently overwrite a downstream slot.
- Shared-memory layout for the rank-`R` reduction in the expand step. The rank is 16 in our use, well within typical limits, but the failure pattern would also be consistent with a shared-memory sizing or indexing error (for example, an undersized or aliased buffer) that only manifests above a particular `M*R` footprint.
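To check the tile-boundary hypotheses against the measured lengths, the arithmetic can be laid out explicitly. A minimal sketch, assuming a 128-row `M`-tile; `TILE_M` and `tile_layout` are illustrative names, and we do not know the actual tile shape used by the SGMV CUTLASS dispatch.

```python
from typing import Dict

# ASSUMPTION: a threadblock M-tile of 128 rows. This is a guess based on the
# safe ceiling, not a value read from the kernel source.
TILE_M = 128


def tile_layout(n_tokens: int, tile_m: int = TILE_M) -> Dict[str, int]:
    """Describe how a segment of n_tokens rows maps onto M-tiles of tile_m rows."""
    tiles = (n_tokens + tile_m - 1) // tile_m  # ceil(n_tokens / tile_m)
    residual = n_tokens % tile_m               # rows in a partial last tile (0 = none)
    overrun = tiles * tile_m - n_tokens        # extra rows if every tile wrote a full M rows
    return {"tiles": tiles, "residual": residual, "overrun_if_full_tile_write": overrun}


# Observed single-segment results from the "Actual behavior" output.
observed = {64: "OK", 128: "OK", 144: "BROKEN", 176: "OK", 192: "BROKEN", 256: "BROKEN", 512: "BROKEN"}
for n, status in observed.items():
    print(f"n_tokens={n:>4} {status:<6} {tile_layout(n)}")
```

Under the assumed 128-row tile, both single-tile lengths (64, 128) pass and every failing length spans two or more tiles. However, 256 and 512 fail with no partial tile at all, while 176 passes with one. If the tile really is 128 rows, that would point more toward the multi-tile path than toward residual-tile handling alone; the real tile shape from the dispatch would settle it.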
We can share intermediate dumps from the broken cases if helpful.
Snapshot pinning
Reproduced against `be89a97bbd04562e2834d5c8f0e9342dc7ae7715`, build variant `torch29-cxx11-cu130-x86_64-linux`. Confirmation of whether this also reproduces on `main` would tell us when we can remove the chunking workaround.