F2: silent wrong results from add_lora_sgmv_cutlass on single segments of 144 tokens or more

#2
by alkari33 - opened

Summary

add_lora_sgmv_cutlass produces silently incorrect numerical results when invoked with a single segment whose token count T is 144 or greater. The kernel returns without raising; the output tensor is well-formed but diverges numerically from the PEFT-factored reference by roughly 20x the bf16 GEMM accumulation tolerance.

The pattern is non-monotonic: most lengths from 144 upward are broken, with sporadic correct outputs at a small number of intermediate lengths. This is consistent with a CTA-tile / shared-memory boundary in the underlying CUTLASS implementation.

Environment

  • Kernel repo snapshot: kernels-community/punica-sgmv at commit be89a97bbd04562e2834d5c8f0e9342dc7ae7715
  • Build variant resolved: build/torch29-cxx11-cu130-x86_64-linux/
  • PyTorch: 2.9.1+cu130
  • CUDA: 13.0
  • Python: 3.11.2
  • kernels runtime: >=0.5
  • peft: 0.19.0
  • GPU A: NVIDIA RTX 5080, SM 12.0
  • GPU B: NVIDIA RTX PRO 6000 Blackwell Server Edition, SM 10.0
  • Both architectures listed in the kernel metadata.json backend.archs.

Minimal repro

import torch
from kernels import get_kernel

kernel = get_kernel("kernels-community/punica-sgmv")

device = "cuda"
dtype = torch.bfloat16
in_features, out_features, rank = 512, 512, 16
scaling = 32.0 / float(rank)

torch.manual_seed(0)
lora_A = (torch.randn(rank, in_features, dtype=torch.float32) * 0.02).to(device, dtype)
lora_B = (torch.randn(out_features, rank, dtype=torch.float32) * 0.02).to(device, dtype)
wa = lora_A.unsqueeze(0).contiguous()
wb = lora_B.transpose(0, 1).unsqueeze(0).contiguous()

def factored_ref(x):
    return ((x @ lora_A.T) @ lora_B.T) * scaling

def sgmv(x):
    n_tokens = x.shape[0]
    y = torch.zeros(n_tokens, out_features, device=device, dtype=dtype)
    wa_ptr = torch.tensor([wa.data_ptr()], dtype=torch.int64, device=device)
    wb_ptr = torch.tensor([wb.data_ptr()], dtype=torch.int64, device=device)
    s_start = torch.tensor([0], dtype=torch.int32, device=device)
    s_end = torch.tensor([n_tokens], dtype=torch.int32, device=device)
    kernel.add_lora_sgmv_cutlass(y, x, wa_ptr, wb_ptr, s_start, s_end, 0, rank)
    return y * scaling

for n_tokens in [64, 128, 144, 176, 192, 256, 512]:
    torch.manual_seed(n_tokens)
    x = torch.randn(n_tokens, in_features, device=device, dtype=dtype).contiguous()
    y_kernel = sgmv(x)
    y_ref = factored_ref(x)
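    err = (y_kernel.float() - y_ref.float()).abs().max().item()
    # Label each length against the 1e-2 bf16 tolerance used in this report.
    status = "OK" if err <= 1e-2 else "BROKEN"
    print(f"n_tokens={n_tokens:>4}  max_abs_err={err:.2e}   {status}")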

Expected behavior

The kernel result should match the PEFT factored reference within the bf16 GEMM accumulation tolerance (we use 1e-2). This must hold across any segment length up to whatever ceiling the kernel documents.

Actual behavior

Measured on the same RTX 5080 (SM 12.0) staging machine with the snapshot above:

n_tokens=  64  max_abs_err=9.77e-04   OK
n_tokens= 128  max_abs_err=9.77e-04   OK
n_tokens= 144  max_abs_err=1.98e-01   BROKEN
n_tokens= 176  max_abs_err=9.77e-04   OK
n_tokens= 192  max_abs_err=1.98e-01   BROKEN
n_tokens= 256  max_abs_err=1.98e-01   BROKEN
n_tokens= 512  max_abs_err=1.98e-01   BROKEN

The error is ~20x the 1e-2 bf16 tolerance, and ~200x the error observed at passing lengths. The kernel does not raise. The break first appears at n_tokens = 144; above that the pattern is non-monotonic, with occasional correct outputs at 176, 184, 191, 320, and 480 and silent failures everywhere else above 144.
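As a quick check on the magnitudes, the sketch below compares each measured error against the 1e-2 tolerance. The values are copied from the table above; the names `measured`, `TOL`, and `classify` are ours and exist only for illustration.

```python
# Max-abs errors copied from the RTX 5080 measurements in this report.
measured = {
    64: 9.77e-04, 128: 9.77e-04, 144: 1.98e-01, 176: 9.77e-04,
    192: 1.98e-01, 256: 1.98e-01, 512: 1.98e-01,
}
TOL = 1e-2  # bf16 tolerance used in this report


def classify(err, tol=TOL):
    """Return (status, err / tol) for one measured max-abs error."""
    status = "OK" if err <= tol else "BROKEN"
    return status, err / tol


report = {n: classify(e) for n, e in measured.items()}
for n, (status, ratio) in report.items():
    print(f"n_tokens={n:>4}  err/tol={ratio:6.2f}  {status}")
```

The broken lengths land at about 19.8x the tolerance, while the passing lengths sit at roughly a tenth of it.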

For comparison, calling the kernel with multiple segments of <=128 tokens each (all pointing at the same homogeneous LoRA factor pair) returns correct results at any aggregate n_tokens.

Workaround

In our wrapper we cap each segment at MAX_SAFE_SGMV_SEGMENT = 128 and chunk longer flat-batch inputs into multiple equivalent segments. We chose 128 as a conservative ceiling: it is the largest length in our sweep below the first failure at 144, and chunking at this size still admits decode and prefill batches of arbitrary length without numerical drift. The change is caller-side only; the kernel is not modified.
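A minimal sketch of the boundary computation behind this workaround, in plain Python so it runs without a GPU. The helper name `chunk_segments` is hypothetical and is not taken from our wrapper or from the kernel repo.

```python
MAX_SAFE_SGMV_SEGMENT = 128  # ceiling from the workaround described above


def chunk_segments(n_tokens, max_len=MAX_SAFE_SGMV_SEGMENT):
    """Split [0, n_tokens) into contiguous [start, end) ranges of at most max_len tokens."""
    if n_tokens < 0 or max_len <= 0:
        raise ValueError("n_tokens must be >= 0 and max_len must be > 0")
    starts = list(range(0, n_tokens, max_len))
    ends = [min(s + max_len, n_tokens) for s in starts]
    return starts, ends


starts, ends = chunk_segments(300)
print(starts, ends)  # [0, 128, 256] [128, 256, 300]
```

In the repro, these lists would take the place of the single-element s_start and s_end tensors, with the wa_ptr and wb_ptr entries repeated once per chunk so that every segment points at the same LoRA factor pair.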

Suggested fix direction

The non-monotonic pattern starting at 144 is consistent with a CTA-tile-count or shared-memory-block size boundary inside the SGMV CUTLASS dispatch. Likely sites to inspect:

  1. The threadblock M-tile size used for the shrink and expand kernels; the safe ceiling 128 and the first failure at 144 suggest a tile boundary at or near 128 with a follow-on correctness issue when the residual block is partially populated.
  2. Output-pointer arithmetic for segments whose token count is not a multiple of the tile M. If the kernel writes past the end of the segment range when the residual partial tile is non-zero, the output could be silently overwriting a downstream slot.
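  3. Shared-memory layout for the rank-R reduction in the expand step; the rank is 16 in our use, well within typical limits, but the failure pattern would also be consistent with a shared-memory sizing or indexing error (for example an out-of-bounds write or a missing synchronization) that only manifests above a particular M*R footprint. A bank conflict alone would affect performance, not correctness.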
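To help triage items 1 and 2, the sketch below computes how many M-tiles each tested length would occupy and how many rows fall in the trailing partial tile. The tile size of 128 is an assumption for illustration only; we have not confirmed the actual threadblock shape in the CUTLASS dispatch, and `tile_stats` is a hypothetical helper.

```python
def tile_stats(n_tokens, tile_m=128):
    """Return (number of M-tiles, rows in the trailing partial tile, 0 if none).

    tile_m=128 is an assumed value, not read from the kernel source.
    """
    full, residual = divmod(n_tokens, tile_m)
    return full + (1 if residual else 0), residual


for n in [64, 128, 144, 176, 192, 256, 512]:
    tiles, residual = tile_stats(n)
    print(f"n_tokens={n:>4}  tiles={tiles}  residual_rows={residual}")
```

Lining these numbers up against the measured table is one way to test the hypotheses: under this assumed tile size, 256 and 512 have no partial tile yet still fail, so a residual-only explanation would need to account for those lengths as well.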

We can share intermediate dumps from the broken cases if helpful.

Snapshot pinning

Reproduced against be89a97bbd04562e2834d5c8f0e9342dc7ae7715 with build variant torch29-cxx11-cu130-x86_64-linux. Confirmation of whether this also reproduces on main would tell us when we can remove the chunking workaround.