| {"output": "analysisWe must write a Python snippet containing a Triton kernel and a wrapper function matching the given signatures. It has to run on AMD GPUs (ROCm). Triton supports AMD GPUs and the kernel is decorated with @triton.jit, so the main requirement is to avoid CUDA-specific features (inline PTX, CUDA-only intrinsics); plain Triton code should be portable.\n\nThe request is to turn the provided description into a specialized SGMV expand-slice kernel. Could the code be a placeholder? No, it should be functional.\n\nThe kernel signature arguments are fixed: input_ptr, lora_ptr, out_ptr, N, K, b_seq_start_loc, seq_lens, lora_indices, xm_stride, xk_stride, l0_stride, lora_k_stride, lora_n_stride, cm_stride, cn_stride, slice_offset, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, ADD_INPUTS, CAST_TYPE.\n\nThe last six are declared as \"BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, EVEN_K: tl.constexpr, ADD_INPUTS: tl.constexpr, CAST_TYPE: tl.constexpr\". They are compile-time constants, and the kernel should actually use them.\n\nThe wrapper takes: inputs: torch.Tensor, lora_b_weights: torch.Tensor, output_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, lora_indices_tensor: torch.Tensor, batches: int, max_seq_length: int, token_nums: int, slice_offset: int, slice_size: int, add_inputs: bool = False.\n\nThe wrapper must perform shape checks, choose block sizes, set up the grid and launch the kernel.\n\nThe description calls this a specialized matrix multiplication, presumably SGMV applied to slices of the LoRA weights. From the parameter names: input_ptr is the input tensor, lora_ptr the LoRA weights, out_ptr the output, and N, K are dimensions. Possibly N is the number of output columns and K the rank (the inner dimension). The work has to be done per batch sequence and per output slice. A first version could look like this:\n\n- For each output tile of size BLOCK_M x BLOCK_N (M rows, N columns). 
The kernel loads an input tile of shape BLOCK_M x BLOCK_K, looping over K rather than loading all of BLOCK_M x K at once. It then loads the matching LoRA weight tile of shape BLOCK_K x BLOCK_N, multiplies the two with a dot product, and possibly adds something before storing (ADD_INPUTS).\n\nThe snippet must be fully syntactically correct even if it cannot match the hidden reference algorithm exactly. So the goal is a simplified but functional SGMV kernel.\n\nSGMV most likely stands for segmented gather matrix-vector multiplication (the multi-LoRA kernel introduced by Punica), not sparse matrix-vector multiplication. Tokens are grouped into segments, one per sequence, and each segment gathers its own LoRA weight matrix and multiplies it with its dense inputs. LoRA adds a low-rank update delta W = B A, where A projects hidden_size down to rank and B projects rank back up to the output size. The \"expand\" step is the multiplication by B (rank -> output columns). slice_offset presumably gives the first output column of the slice being written.\n\nTentative shapes: inputs is (token_nums x K). For an expand kernel, K would be the rank, since the input is the output of the preceding shrink step. lora_b_weights holds one (K x slice_size) matrix per adapter, possibly stored transposed.\n\nb_seq_start_loc: start offset of each batch's sequence in the flattened token array.\n\nseq_len_tensor: sequence length of each batch.\n\nlora_indices_tensor: selects which LoRA adapter applies (an adapter index, not a rank). Given the per-batch b_seq_start_loc and seq_len_tensor, one index per sequence seems more likely than one per token, but both are conceivable.\n\nslice_offset: the output column where this slice starts; alternatively an offset into the LoRA B weights.\n\nThe wrapper must pass the correct strides.\n\nBlock sizes: e.g. BLOCK_M=64, BLOCK_N=64, BLOCK_K=32 as compile-time constants. EVEN_K is simply K % BLOCK_K == 0.\n\nADD_INPUTS possibly indicates whether to add the result to what is already in the output instead of overwriting it.\n\nCAST_TYPE: maybe 0 for float16 and 1 for float32, i.e. a choice between tl.float16 and tl.float32.\n\nA generic kernel would compute, for each M (row) block and N (column) block:\n\nacc = 0\n\nFor k in [0, K) step BLOCK_K:\n\n- Load the input tile (BLOCK_M x BLOCK_K) from input_ptr (assuming a row-major layout, addressed through strides). 
The input stride parameters are xm_stride and xk_stride. The signature comment on xk_stride is \"# 1\". That fits a row-major input where xk_stride is the unit stride along K and xm_stride is the stride between rows (M). Without further documentation, assume the input is (M, K), the output is (M, N), and the LoRA matrix is (K, N). The LoRA matrix is addressed through l0_stride (unclear so far), lora_k_stride and lora_n_stride.\n\ncm_stride and cn_stride are presumably the output (C matrix) strides along M and N.\n\nTherefore:\n\nInput A: M x K at input_ptr, strides xm_stride (row) and xk_stride (column). LoRA weights B: K x N at lora_ptr, strides l0_stride (probably the adapter or batch dimension), lora_k_stride (row) and lora_n_stride (column). Output C: M x N at out_ptr, strides cm_stride (row) and cn_stride (column).\n\nslice_offset is then an offset applied either to the output columns or to the LoRA pointer.\n\nThe remaining arguments are b_seq_start_loc (the start of each sequence), seq_lens and lora_indices. They determine which rows belong to which sequence and which LoRA weights each row uses. Each input row is presumably one token.\n\nFirst simplification, per block of rows:\n\n- row_index = program_id(0) * BLOCK_M + offset_in_block\n\n- batch_id could come from program_id(1). But if rows are flattened across batches, program 0 can index row blocks and program 1 column blocks.\n\nGeneral pattern: program ids [pid0, pid1], where pid0 is the block index along M and pid1 the block index along N.\n\nPer-sequence offsets still have to be handled. With several sequences in a batch, the global row index is b_seq_start_loc[batch] + token_offset, 
but seq_lens and lora_indices also have to be taken into account for each token.\n\nSimplify: treat each row i as a global row index, compute the base offset from b_seq_start_loc and seq_lens, and perhaps use lora_indices to select the LoRA weights.\n\nEach program instance (Triton works on whole tiles per program, not per thread) loads the input elements for rows i and columns k. It multiplies them by the LoRA weight elements for k and columns j, and sums over k.\n\nPseudo-code:\n\n@triton.jit\ndef _sgmv_expand_slice_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N, K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n xm_stride,\n xk_stride,\n l0_stride,\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n slice_offset,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ADD_INPUTS: tl.constexpr,\n CAST_TYPE: tl.constexpr,\n):\n # Program ids: one program per (row block, column block)\n pid_m = tl.program_id(0) # block row\n pid_n = tl.program_id(1) # block col\n\n # Element offsets covered by this tile\n offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n # Bounds masks. Which argument bounds which axis is still open:\n mask_m = offs_m < N # only valid if N is the row count (token_nums, known to the wrapper)\n mask_n = offs_n < ??? # needs the output column count; K is the inner dimension and cannot bound offs_n\n # Which of N and K describe the output? If the input is (M, K), N might play the role of M:\n # N = number of (output) rows, K = shared inner dimension.\n # Tentatively: M dimension = N, K = inner dimension.\n\nReading the signature again: def _sgmv_expand_slice_kernel(..., N, K, ...). N might be M (the output rows) and K the inner dimension (the rank). Then the output shape is (N, ?). The output width is presumably the slice size, so does the wrapper pass slice_size as N? The wrapper has slice_offset and slice_size parameters, plus token_nums for the total number of tokens. 
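Before writing any Triton, the semantics can be pinned down with a plain-Python reference. It uses one candidate reading: N is the slice width (slice_size), K is the inner dimension, and rows are grouped into per-sequence segments that each use a single adapter. This is a minimal sketch with several assumptions. inputs is a list of token_nums rows of length K. lora_b is a list with one K x N matrix per adapter. lora_indices holds one index per sequence, with -1 meaning no LoRA (an assumed convention). add_inputs means accumulating into the existing output. The function name sgmv_expand_slice_ref is hypothetical.

```python
def sgmv_expand_slice_ref(inputs, lora_b, output, b_seq_start_loc, seq_lens,
                          lora_indices, slice_offset, add_inputs=False):
    # Hypothetical plain-Python reference of a segmented gather matmul (expand step).
    # inputs: token_nums rows of length K; lora_b: one K x N matrix per adapter;
    # output: token_nums rows of length >= slice_offset + N, updated in place.
    for batch, start in enumerate(b_seq_start_loc):
        idx = lora_indices[batch]
        if idx == -1:  # assumed convention: this sequence uses no LoRA adapter
            continue
        weights = lora_b[idx]  # one K x N matrix shared by the whole segment
        k_dim, n_dim = len(weights), len(weights[0])
        for row in range(start, start + seq_lens[batch]):
            for n in range(n_dim):
                acc = 0.0
                for k in range(k_dim):
                    acc += inputs[row][k] * weights[k][n]
                col = slice_offset + n  # the slice begins at column slice_offset
                output[row][col] = output[row][col] + acc if add_inputs else acc
    return output


# Two sequences: tokens 0-1 use adapter 0, token 2 uses adapter 1.
inputs = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # token_nums = 3, K = 2
lora_b = [[[1.0, 0.0], [0.0, 1.0]],            # adapter 0: identity
          [[2.0, 0.0], [0.0, 2.0]]]            # adapter 1: doubles
out = [[0.0] * 4 for _ in range(3)]            # full width 4, slice in columns 2-3
sgmv_expand_slice_ref(inputs, lora_b, out, [0, 2], [2, 1], [0, 1], slice_offset=2)
# out == [[0.0, 0.0, 1.0, 2.0], [0.0, 0.0, 3.0, 4.0], [0.0, 0.0, 10.0, 12.0]]
```

Every row of a segment shares one weight matrix, which is what lets a tiled kernel use a single B tile per program.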
In the wrapper we need to compute N and K. Possibly N = token_nums (or slice_size) and K = hidden_size (or, for a LoRA expand, the rank). b_seq_start_loc, lora_indices and the rest are passed through.\n\nGeneric formulation: the kernel multiplies the input (N x K) by LoRA weights (K x slice_size) to produce an (N x slice_size) block of output. It needs the output strides cm_stride and cn_stride, plus slice_offset to locate the block within the output.\n\nSo the wrapper probably passes N = token_nums (or slice_size) and K = hidden_size (or max_rank). We cannot be sure. Still, a generic kernel can use N as the row count, K as the shared dimension, and slice_offset as a column offset into the output.\n\nslice_offset is easy to support rather than ignore: add it to the column index when storing, i.e. out_ptr + slice_offset * cn_stride + ..., so the output column is offs_n + slice_offset.\n\nPseudo:\n\ncol_idx = offs_n + slice_offset\n\nInput addresses (offs_m needs a trailing axis to broadcast against the K offsets):\n\na_ptrs = input_ptr + offs_m[:, None] * xm_stride + tl.arange(0, BLOCK_K)[None, :] * xk_stride\n\nThe K offset changes every iteration, so loop over the inner dimension: for k0 in range(0, K, BLOCK_K). With EVEN_K every tile is full and the K mask can be skipped.\n\nInside the loop:\n\nk_offset = k0 + tl.arange(0, BLOCK_K) # shape (BLOCK_K,)\n\nMask with k_offset < K.\n\nLoad a tile of A: A_block = tl.load(input_ptr + offs_m[:, None] * xm_stride + k_offset[None, :] * xk_stride, mask=mask_m[:, None] & (k_offset[None, :] < K), other=0.0)\n\nLoad a tile of B: B_block = tl.load(lora_ptr + l0_stride * ? ...). Does lora_ptr also need a batch or adapter index? It does: the separate lora_indices tensor says which LoRA weights to use, so its value selects the offset into the weights.\n\nIf lora_indices has one entry per token (shape token_nums), the LoRA index for row i is lora_indices[global_row]. 
So each row needs its own index lookup, and the offset is computed from it.\n\nWith lora_indices of shape token_nums: tl.load(lora_indices + global_row_idx), where global_row_idx = b_seq_start_loc[batch] + (row offset within the sequence), since b_seq_start_loc holds the start offset of each batch.\n\nOne option: compute row_global = b_seq_start_loc[batch] + offs_m in the kernel, with batch taken from a grid axis. Another: flatten rows across batches and treat each row as a linear index into the token_nums tokens. (token_nums is the sum of seq_lens, at most max_seq_length * batches.) Then the wrapper only passes base pointers, and b_seq_start_loc is not needed inside the kernel.\n\nAlternatively, b_seq_start_loc and seq_lens could be used to clamp rows or to return early for tiles past the end of a sequence. Since the exact reference algorithm is unknown, the flattened version could ignore them entirely.\n\nThey still have to be accepted as arguments to match the required signature.\n\nWhere would batch_id come from? tl.program_id(2) would need a third grid axis, but the current plan only has two (pid_m, pid_n). b_seq_start_loc and seq_lens are most likely 1-D arrays of length batches.\n\nUsing them would mean reading the sequence start and length for the current batch; the simple version skips that.\n\nLeaving them unused is legal, but it discards information the reference kernel presumably relies on.\n\nImplementation details:\n\n- Use tl.dot on tiles: acc += tl.dot(A_block, B_block), with A_block of shape (BLOCK_M, BLOCK_K) and B_block of shape (BLOCK_K, BLOCK_N).\n\nThe accumulator:\n\nacc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n\nAccumulating in float32 is preferable for accuracy, regardless of CAST_TYPE.\n\nIf ADD_INPUTS: at the end add the original input slice to the output (load the input slice again? 
Actually the input is itself a matrix, so an \"input_block\" could be loaded and added.\n\nThat reading treats ADD_INPUTS as a residual-style addition of the input to the output.\n\nSimplest version: out = dot(A, B), and if ADD_INPUTS also add A. That is probably wrong, or the flag refers to some other term.\n\nThe shapes rule out adding A. The tile being stored is (BLOCK_M, BLOCK_N), while A is indexed by rows and K, not N. So there is no A tile with the same column indices (unless K == N).\n\nThe one (BLOCK_M, BLOCK_N) tile that does exist at the same row and column offsets is the current content of the output. So the most consistent interpretation is accumulation: if ADD_INPUTS, load the existing output tile from out_ptr and add acc to it before storing (a read-modify-write). That makes it possible to accumulate several contributions into the same output. Otherwise, just store acc.\n\nIgnoring ADD_INPUTS altogether would also compile, but then add_inputs=True would silently overwrite the output. It is better to implement the accumulating version and document the assumption in a comment.\n\nNow CAST_TYPE: treat it as a constexpr that selects a data type, e.g. CAST_TYPE == 0 for tl.float16 and CAST_TYPE == 1 for tl.float32. Branching on a constexpr inside the kernel is fine. A condition on a tl.constexpr is resolved at compile time, and only the taken branch is generated. So we can write:\n\nDATA_TYPE = tl.float16 if CAST_TYPE == 0 else tl.float32\n\nand make sure the resulting type is used consistently.\n\nOr, as a statement at the top:\n\nif CAST_TYPE == 0:\n acc_dtype = tl.float16\nelse:\n acc_dtype = tl.float32\n\nIs the statement form of if also fine with constexpr values? 
Yes.\n\nAlternatively, treat it as a bool with `if CAST_TYPE:`, where 1 selects float32 and 0 selects float16.\n\nThe simplest option would be a fixed acc_type = tl.float32 that ignores CAST_TYPE.\n\nThe signature still requires CAST_TYPE: tl.constexpr, so the parameter would stay but go unused.\n\nBetter to keep the logic, though: `ACC_TYPE = tl.float32 if CAST_TYPE else tl.float16`.\n\nNext, the masks: they must have shapes that match the loads.\n\nGlobal offsets:\n\n- Input pointers: input_ptr + (offs_m[:, None] * xm_stride + k_offset[None, :] * xk_stride)\n\nHere xm_stride is the row stride (the number of elements between consecutive rows). Triton pointer arithmetic works in elements, as in C. For a pointer to elements of type T, ptr + offset points offset * sizeof(T) bytes further, and the scaling by the element size happens automatically. So offsets are built directly from strides.\n\nDefine:\n\noffsets_a = offs_m[:, None] * xm_stride + k_offset[None, :] * xk_stride\n\nThe strides must not be multiplied by the element size. torch's tensor.stride() already reports element strides, which is exactly what Triton expects.\n\n- For the LoRA pointer (B matrix), lora_k_stride is the stride along K and lora_n_stride the stride along N. l0_stride is probably the stride of the leading dimension, which indexes separate LoRA weight matrices (one per adapter) rather than anything related to slice_offset. If each row selects its adapter via lora_index, row i needs index = lora_indices[global_row] (an int) and the offset index * l0_stride + k_offset * lora_k_stride + offs_n * lora_n_stride.\n\nSo the kernel has to load from lora_indices.\n\nFor that it needs the global row id: offs_m (the row offset), possibly plus a per-batch base row.\n\nIn the flattened layout, the token index is simply offs_m (the row index within the flattened token space). 
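The element-stride arithmetic can be checked without a GPU. This is a minimal sketch that assumes row-major contiguous tensors and a LoRA weight layout of (num_loras, K, N). The helper names contiguous_strides and tile_offsets are hypothetical. tile_offsets mimics what base + rows[:, None] * row_stride + cols[None, :] * col_stride produces in Triton: a 2-D grid of element offsets, with an optional scalar base such as lora_index * l0_stride.

```python
def contiguous_strides(shape):
    # Element strides of a row-major contiguous tensor, as torch.Tensor.stride() reports them.
    strides = []
    step = 1
    for dim in reversed(shape):
        strides.append(step)
        step *= dim
    return tuple(reversed(strides))


def tile_offsets(rows, cols, row_stride, col_stride, base=0):
    # Emulates base + rows[:, None] * row_stride + cols[None, :] * col_stride.
    # Offsets are in elements; Triton scales them by the element size itself.
    return [[base + r * row_stride + c * col_stride for c in cols] for r in rows]


# Input A: (token_nums=4, K=3); LoRA B weights: (num_loras=2, K=3, N=5).
xm_stride, xk_stride = contiguous_strides((4, 3))
l0_stride, lora_k_stride, lora_n_stride = contiguous_strides((2, 3, 5))
a_offs = tile_offsets(range(2), range(3), xm_stride, xk_stride)
# One adapter per tile: the base is a scalar, here lora_index 1 times l0_stride.
b_offs = tile_offsets(range(3), range(2), lora_k_stride, lora_n_stride, base=1 * l0_stride)
# a_offs == [[0, 1, 2], [3, 4, 5]]; b_offs == [[15, 16], [20, 21], [25, 26]]
```

A transposed weight layout would only change lora_k_stride and lora_n_stride; the offset formula stays the same.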
The wrapper would then pass N = token_nums, K = hidden_size and the matching strides, so the row index is just offs_m. The index load becomes:\n\nrow_id = offs_m\n\nrow_id can exceed the total token count in the last row block, so the load is masked.\n\nThus lora_index = tl.load(lora_indices + row_id, mask=row_id < N, other=0)\n\nThis assumes lora_indices has token_nums entries, one per row.\n\nThen base_offset = lora_index * l0_stride\n\nand the B tile is loaded next.\n\nB is a K x N slice (for example hidden size x slice_size, or rank x slice_size), with rows along K and columns along the slice. For the shapes to align, k_offset (length BLOCK_K) needs a trailing axis and offs_n (length BLOCK_N) a leading one: offsets_b = base_offset + k_offset[:, None] * lora_k_stride + offs_n[None, :] * lora_n_stride.\n\nThus:\n\nB_block = tl.load(lora_ptr + base_offset + k_offset[:, None] * lora_k_stride + offs_n[None, :] * lora_n_stride, mask=k_mask[:, None] & n_mask[None, :], other=0.0, eviction_policy='evict_last')\n\nHere is the catch: this only works if base_offset is a scalar. lora_index was loaded per row, so base_offset has shape (BLOCK_M,). A (BLOCK_M,) vector cannot be broadcast into a (BLOCK_K, BLOCK_N) tile.\n\nIf all rows of the tile used the same adapter, base_offset would be a scalar, and offsets = base_offset + k_offset[:, None] * lora_k_stride + offs_n[None, :] * lora_n_stride would be fine.\n\nWith per-row indices, base_offset_i = lora_indices[i_global] * l0_stride has shape (BLOCK_M,).\n\nThe A offsets broadcast without trouble:\n\noffs_A = offs_m[:, None] * xm_stride + k_offset[None, :] * xk_stride\n# shape (BLOCK_M, BLOCK_K)\nFor B, a per-row offset lora_index_i[:, None] * l0_stride has a row axis of length BLOCK_M, which a B tile does not have.\n\nAnd what exactly is l0_stride: the stride of the per-adapter dimension of the LoRA weights? 
Most likely l0_stride is the stride of dimension 0 of the LoRA weight tensor. That dimension indexes the individual adapters' weight matrices (not ranks and not tokens), e.g. a layout like (num_loras, K, N) or a transposed variant. So l0_stride is the stride to move from one adapter's weights to the next, and it combines with the adapter index to form the base offset.\n\nThe tempting per-row version does not work: base_offset_i = lora_indices_i * l0_stride with offsets_b = base_offset_i[:, None] + k_offset[None, :] * lora_k_stride + offs_n[None, :] * lora_n_stride. Its shape is not (BLOCK_K, BLOCK_N), and it is not even well defined: k_offset[None, :] and offs_n[None, :] are both row vectors, which only broadcast if BLOCK_K == BLOCK_N. B is a K x N matrix, so its tile offset is k_offset[:, None] * lora_k_stride + offs_n[None, :] * lora_n_stride, with rows along K and columns along N. The adapter offset has to be added to the whole tile.\n\nSo B_block = tl.load(lora_ptr + base_offset + k_offset[:, None] * lora_k_stride + offs_n[None, :] * lora_n_stride, mask=mask_B, other=0.0) requires a scalar base_offset. A (BLOCK_M, 1) term like base_offset_i[:, None] cannot be broadcast into a (BLOCK_K, BLOCK_N) tile. Also, tl.dot multiplies one A tile by one B tile, so all rows in a tile must share a single adapter. This is exactly what SGMV's segmentation provides: tokens are grouped by sequence, and each sequence (batch) has one LoRA index, lora_indices[batch]. Rows with different adapters would otherwise need a per-row gather without tl.dot.\n\nThen the dot product:\n\nacc = tl.dot(A_block, B_block) # shape (BLOCK_M, BLOCK_N)\n\ntl.dot takes shapes (M, K) and (K, N) and returns (M, N). 
Good.\n\nAccumulate over the K loop. With EVEN_K every K tile is full, so the K masks can be dropped.\n\nPseudo:\n\nacc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n# One adapter per tile: a per-row lora_idx of shape (BLOCK_M,) cannot be broadcast into a\n# (BLOCK_K, BLOCK_N) B tile, so assume all rows of the tile share one index (true if a tile\n# never straddles two sequences). The load does not depend on k, so it sits outside the loop.\nlora_idx = tl.load(lora_indices + pid_m * BLOCK_M)\nbase = lora_idx * l0_stride\n\nfor k in range(0, K, BLOCK_K):\n cur_k = k + tl.arange(0, BLOCK_K)\n k_mask = cur_k < K\n # A tile: (BLOCK_M, BLOCK_K)\n a = tl.load(input_ptr + offs_m[:, None] * xm_stride + cur_k[None, :] * xk_stride, mask=mask_m[:, None] & k_mask[None, :], other=0.0)\n # B tile: (BLOCK_K, BLOCK_N), scalar base offset of the selected adapter\n b = tl.load(lora_ptr + base + cur_k[:, None] * lora_k_stride + offs_n[None, :] * lora_n_stride, mask=k_mask[:, None] & mask_n[None, :], other=0.0)\n\n acc += tl.dot(a, b)\n\nAfter the loop comes ADD_INPUTS. Loading an input slice at the same rows and at columns offs_n + slice_offset does not work, because the input has K columns, not N. So a residual-style add of the original input does not fit the shapes. What does fit is the output itself: if ADD_INPUTS, load the existing output tile through the same pointers and add it to acc before storing. Otherwise, store acc as is.\n\nStore the output:\n\n# output offsets, with the slice shifted by slice_offset columns\noutput_offset = offs_m[:, None] * cm_stride + (offs_n[None, :] + slice_offset) * cn_stride\n\nEquivalently, shift the base pointer once, out_base_ptr = out_ptr + slice_offset * cn_stride, and add offs_m[:, None] * cm_stride + offs_n[None, :] * cn_stride to it. If out_ptr instead already pointed at the start of the slice, slice_offset must not be applied a second time.\n\nA store without slice_offset:\n\ntl.store(out_ptr + offs_m[:, None] * cm_stride + offs_n[None, :] * cn_stride, acc, mask=mask_m[:, None] & mask_n[None, :])\n\nwould write to the wrong columns whenever slice_offset is nonzero. So the store pointer must be out_ptr + offs_m[:, None] * cm_stride + (offs_n[None, :] + slice_offset) * cn_stride.\n\nNote that tl.store takes a pointer (or a block of pointers), not an offset, so the offsets are added to out_ptr: out_ptr + (offs_m[:, None] * cm_stride + (offs_n[None, :] + slice_offset) * cn_stride). This yields a (BLOCK_M, BLOCK_N) block of pointers. 
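The masking in the K loop and in the shifted store can be checked by emulating one program instance in plain Python. This is a rough sketch with several assumptions. Flat row-major Python lists stand in for device memory. A single adapter is used per tile. ADD_INPUTS means accumulating into the existing output. Masked loads return 0 like other=0.0, and masked-off stores are skipped. The function run_program is hypothetical and ignores the precision details of tl.dot.

```python
def run_program(pid_m, pid_n, x, w, out, dims, strides, slice_offset, blocks,
                add_inputs=False):
    # Emulates one (pid_m, pid_n) program of the pseudo-kernel on flat Python lists.
    # dims = (rows, N, K); strides = ((xm, xk), (lora_k, lora_n), (cm, cn));
    # blocks = (BLOCK_M, BLOCK_N, BLOCK_K). A single adapter is assumed for the tile.
    n_rows, n_cols, k_dim = dims
    (xm, xk), (wk, wn), (cm, cn) = strides
    block_m, block_n, block_k = blocks
    offs_m = [pid_m * block_m + i for i in range(block_m)]
    offs_n = [pid_n * block_n + j for j in range(block_n)]
    acc = [[0.0] * block_n for _ in range(block_m)]
    for k0 in range(0, k_dim, block_k):
        for kk in range(k0, k0 + block_k):  # may run past K when EVEN_K is False
            for i, m in enumerate(offs_m):
                # masked load, like tl.load(..., mask=..., other=0.0)
                a = x[m * xm + kk * xk] if m < n_rows and kk < k_dim else 0.0
                for j, n in enumerate(offs_n):
                    b = w[kk * wk + n * wn] if kk < k_dim and n < n_cols else 0.0
                    acc[i][j] += a * b
    for i, m in enumerate(offs_m):
        for j, n in enumerate(offs_n):
            if m < n_rows and n < n_cols:  # masked store into the shifted slice
                p = m * cm + (n + slice_offset) * cn
                out[p] = out[p] + acc[i][j] if add_inputs else acc[i][j]


# 3 x 3 input times 3 x 2 weights, written at columns 1-2 of a 3 x 4 output.
x = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]
w = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0]
out = [0.0] * 12
for pm in range(2):  # grid: cdiv(3, 2) row tiles, cdiv(2, 2) column tiles
    run_program(pm, 0, x, w, out, (3, 2, 3), ((3, 1), (2, 1), (4, 1)), 1, (2, 2, 2))
# out == [0.0, 4.0, 5.0, 0.0, 0.0, 10.0, 11.0, 0.0, 0.0, 16.0, 17.0, 0.0]
```

K = 3 with BLOCK_K = 2 exercises the non-EVEN_K path, and the second row tile covers a row past the end, so both masks matter.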
Should be fine.\n\nNow the grid. The wrapper passes BLOCK_M, BLOCK_N and BLOCK_K as compile-time constants (tl.constexpr), e.g. BLOCK_M = 128, BLOCK_N = 64, BLOCK_K = 32, possibly tuned per GPU.\n\nThe wrapper accepts: inputs (torch.Tensor), lora_b_weights (torch.Tensor), output_tensor (torch.Tensor), b_seq_start_loc (torch.Tensor), seq_len_tensor (torch.Tensor), lora_indices_tensor (torch.Tensor), batches (int), max_seq_length (int), token_nums (int), slice_offset (int), slice_size (int), add_inputs (bool = False). slice_size is presumably the number of output columns written by this call; output_tensor itself may be wider.\n\nContiguity: since all strides are passed explicitly, the kernel does not need contiguous tensors. Calling .contiguous() on inputs and lora_b_weights is harmless. It is not harmless on output_tensor: if that tensor is not contiguous, .contiguous() returns a copy, and the kernel's writes would never reach the caller's tensor.\n\nN and K for the kernel: is N the number of input rows (token_nums) or the slice size? In a typical matrix multiplication, inputs is (token_nums, hidden_dim) and the LoRA B weights are (rank, hidden_dim) or (hidden_dim, rank). For an expand slice, each token gets a slice of slice_size output columns; whether slice_size equals hidden_dim is unclear.\n\nSimplify: N = token_nums (the rows), K = inputs.shape[1] (the hidden dim, or the rank if the input is the output of a shrink step). The B slice is then (K, slice_size), and the kernel produces a (token_nums, slice_size) block. output_tensor must have token_nums rows and at least slice_offset + slice_size columns, since the block is written starting at column slice_offset. The wrapper sets N = token_nums and K = inputs.shape[1], and passes slice_offset.\n\nThus the wrapper computes:\n\nN = token_nums\nK = inputs.shape[1]\n\nAssume lora_b_weights is a 3-D tensor indexed by the LoRA index, i.e. by adapter rather than by rank: shape (num_loras, K, slice_size). Then l0_stride = lora_b_weights.stride(0) (the stride between adapters), lora_k_stride = lora_b_weights.stride(1), and lora_n_stride = lora_b_weights.stride(2).\n\nInput strides: xm_stride = inputs.stride(0) (row stride), xk_stride = inputs.stride(1) (column stride). 
Output strides: cm_stride = output_tensor.stride(0), cn_stride = output_tensor.stride(1).\n\nb_seq_start_loc and seq_len_tensor could drive early termination, but this version does not use them.\n\nlora_indices_tensor: here assumed to have shape (token_nums,) and an integer dtype; it is passed straight to the kernel.\n\nThe input and weight tensors may be made contiguous with .contiguous(); the output tensor must be written in place.\n\nNow the launch. Program ids cover row blocks and column blocks:\n\ngrid = (triton.cdiv(N, BLOCK_M), triton.cdiv(slice_size, BLOCK_N))\n\nA third grid axis for batches is not needed in this flattened version, since token_nums covers all tokens of all batches.\n\nKernel arguments:\n\n_sgmv_expand_slice_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n N,\n K,\n b_seq_start_loc,\n seq_len_tensor,\n lora_indices_tensor,\n inputs.stride(0),\n inputs.stride(1),\n lora_b_weights.stride(0),\n lora_b_weights.stride(1),\n lora_b_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n slice_offset,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n BLOCK_K=BLOCK_K,\n EVEN_K=K % BLOCK_K == 0,\n ADD_INPUTS=add_inputs,\n CAST_TYPE=0 # 0 for float16, 1 for float32? 
The choice can follow inputs.dtype: CAST_TYPE=1 if inputs.dtype == torch.float32, else 0.\n\n) # constexpr values are passed as keyword arguments after the positional ones.\n\nCareful with the order. From xm_stride onward, the parameters are xm_stride, xk_stride, l0_stride, lora_k_stride, lora_n_stride, cm_stride, cn_stride, slice_offset, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, ADD_INPUTS, CAST_TYPE.\n\nSo the call must pass the runtime values positionally, in that order, before the block-size constants.\n\nFull order: input_ptr, lora_ptr, out_ptr, N, K, b_seq_start_loc, seq_lens, lora_indices, xm_stride, xk_stride, l0_stride, lora_k_stride, lora_n_stride, cm_stride, cn_stride, slice_offset, BLOCK_M=..., BLOCK_N=..., BLOCK_K=..., EVEN_K=..., ADD_INPUTS=..., CAST_TYPE=...\n\nThe call above matches.\n\nData types inside the kernel: tl.load returns values of the pointer's element type (float16 when a float16 tensor is passed), so no load type has to be specified. A dtype constant is only needed for explicit conversions with .to(...).\n\nOne option: load_type = tl.float16 if CAST_TYPE == 0 else tl.float32\n\ni.e. at the top: if CAST_TYPE == 0: dtype = tl.float16 else: dtype = tl.float32\n\nIf the same dtype were also used for the accumulator, float16 accumulation would lose precision; accumulating in float32 reduces the error.\n\nThus:\n\nif CAST_TYPE:\n # cast type is float32\n acc_dtype = tl.float32\nelse:\n acc_dtype = tl.float16\n\nSimpler and more accurate: always accumulate in tl.float32 and convert to the output dtype when storing.\n\nSo: accumulator dtype = tl.float32, stored as the original dtype.\n\nNow the kernel logic.\n\nImplementation details:\n\n- Global row indices: row_idx = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n\n- Global column indices: col_idx = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)\n\n- Masks: mask_m = row_idx < N and mask_n = col_idx < slice_size. But slice_size is not a kernel argument; only N and K are. It could be derived from the output column count, i.e. the shape of output_tensor. That shape is not visible through out_ptr and is not passed to the kernel. 
Could it come from `cn_stride`? No: that is the stride between columns, not the number of columns.\n\nThe kernel needs the output width for the column mask. The wrapper could compute slice_size and pass it as an extra compile-time constant, but that is not in the signature. The signature has `slice_offset` but not slice_size, and `slice_offset` is an offset into the output columns. So out_ptr is presumably the base of a larger output matrix, and slice_offset is where this slice starts. The kernel's column extent is slice_size, the width of the block it writes, not the width of output_tensor. Neither width is passed, and cn_stride does not help: a stride is not a length.\n\nSo the column bound would have to come from token_nums or a shape, and neither can be passed without an extra argument. N bounds the rows in this reading, and K cannot be the column bound either.\n\n`cm_stride` (the row stride) is not enough. For a contiguous output it equals the full output width, which is an upper bound on the columns but not slice_size. The exact width is only known from out_tensor.shape[1] in the wrapper. It would have to be passed through `N` or `K`, since the signature cannot change. Could slice_size be the K argument? No, K is the inner dimension, not the number of output columns.\n\nBack to the signature: N and K are the problem dimensions. In this reading N is the row count (token count) and K the inner dimension (hidden size). The product is (N x K) times (K x ?), the LoRA weight slice, which gives an N x slice_size output, and slice_size is not passed. Could the kernel derive the column count from a pointer stride? 
Perhaps via some sentinel, e.g. number of columns = cn_stride * something? No, no such formula exists.\n\nPerhaps the kernel does not need the width: each program handles a BLOCK_N-wide tile, and the wrapper sizes the column axis of the grid from slice_size.\n\nThe column mask would still be mask_n = col_idx < slice_size, because cdiv rounds up. Unless slice_size is a multiple of BLOCK_N, the last column tile extends past the slice. slice_size is not a parameter, though, and encoding it in `BLOCK_N` does not work, since BLOCK_N is the tile width.\n\nAn additional constexpr for the slice size would be the clean fix, but the signature does not include one. Deriving the column count from `cn_stride` does not work either. It is the distance in elements to the next column, normally 1 for a contiguous row-major output, so out_tensor.shape[1] / cn_stride is just out_tensor.shape[1]. That value is known in the wrapper but not in the kernel. Something like `N_COLS = tl.load(NCOLS_PTR)`\nwould need a pointer argument that does not exist.\n\nAlternative: use the N argument for it. In the current reading N is the row count, but maybe the semantics are the other way round: N is the number of output columns (the output width), and K relates to the input instead. The kernel parameters are (input_ptr, lora_ptr, out_ptr, N, K, ...). Matrix multiplication is usually described with (M, N, K): M rows of A, N columns of B, K the inner dimension. The description mentions \"blocks of data defined by the parameters BLOCK_M, BLOCK_N, and BLOCK_K\". So the kernel very likely follows that convention: M rows of the input and output, N columns of the output (B's columns), K the inner dimension. The signature has N and K but no M, so the row count M must come from elsewhere, probably the token count in the input or the per-sequence lengths. 
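The need for a column bound shows up as soon as the tiles are counted. This small sketch assumes a grid with one program per BLOCK_N-wide column tile. cdiv mirrors triton.cdiv (ceiling division), and column_tiles is a hypothetical helper. Because the grid is rounded up, the last tile reaches past slice_size whenever slice_size is not a multiple of BLOCK_N.

```python
def cdiv(a, b):
    # Ceiling division, the same rounding as triton.cdiv.
    return (a + b - 1) // b


def column_tiles(slice_size, block_n):
    # For each pid_n: first column, last column, and how many columns fall outside the slice.
    tiles = []
    for pid_n in range(cdiv(slice_size, block_n)):
        cols = [pid_n * block_n + j for j in range(block_n)]
        overflow = sum(1 for c in cols if c >= slice_size)
        tiles.append((cols[0], cols[-1], overflow))
    return tiles


# slice_size = 100 with BLOCK_N = 64: two tiles, and the second one covers columns 64-127.
tiles = column_tiles(100, 64)
# tiles == [(0, 63, 0), (64, 127, 28)]
```

Without a mask such as col_idx < slice_size, the 28 extra columns would be written at slice_offset + 100 through slice_offset + 127, i.e. into the neighbouring slice or past the end of the row.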
Since the signature has N and K, perhaps M is the token count. The typical GEMM layout is a (M, K), b (K, N), c (M, N), and kernels usually receive M, N and K explicitly (or derive the grid from the shapes). Was M left out because it follows from the input dimensions? Then N would be the output column count and K the inner dimension. M cannot be recovered from a stride. However, the row count is available in the wrapper as token_nums, or per sequence from seq_lens.\n\nSo M = token_nums, but it is not a kernel argument. Rows can be computed as i = program_id(0) * BLOCK_M + offset, but they must be guarded against the token count, so the kernel needs that bound from somewhere. Deriving it from N does not work if N is the column count.\n\nTypical Triton GEMM kernels take M, N and K. For example, the matmul tutorial in the Triton documentation has a signature of the form matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ...), with all three dimensions passed explicitly.\n\nHere only N and K are passed. Maybe M was omitted because the rows are given by token_nums (or by per-sequence lengths) rather than by a single argument. Or maybe N is the number of rows (token_nums) and K the inner dimension, with the output width somehow conveyed through slice_offset? That seems unlikely.\n\nRestating that option: N is the row count (token count), K the hidden dimension, and the output has slice_size columns, which is not passed. slice_offset and slice_size may then select a slice of a bigger weight matrix (LoRA weight columns across the hidden dimension). 
The kernel may compute one slice of a larger matrix product: out[:, slice_offset:slice_offset + slice_size] = input @ weight[:, slice_offset:slice_offset + slice_size]. Here the full weight is (K x hidden) and the slice is (K x slice_size). The same kernel code can then compute just that slice without knowing slice_size at compile time, because the grid tells it how many column tiles exist. In this reading N is the row dimension (token_nums), and the column positions come from the grid (program_id(1) * BLOCK_N + offsets). That alone only works if slice_size is a multiple of BLOCK_N. Otherwise the last column tile overruns the slice, so the bound still matters.\n\nThus: token count as N, hidden dimension as K.\n\nThe output column mask needs the number of columns in this slice. `cn_stride` does not give it (for a contiguous row-major output it is 1), and out_tensor.shape[1] is the full output width, not slice_size. The wrapper could convey slice_size only implicitly, through the number of column tiles it launches. If the grid covered exactly slice_size columns, no mask would be needed. In general, though, the last tile extends past the bound, so a mask based on the slice width is required. That bound is not an argument. Loading it with `tl.load` would need a pointer that does not exist. `col_idx = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N)` only gives positions, not the limit. A mask like `col_idx < N` would be wrong, because N is the row count here.\n\nAlternative: mask_n = col_idx < slice_size, where slice_size is known to the wrapper but not to the kernel. It cannot be referenced without being passed as an argument. Could `K` be repurposed as the slice size, if K is not really the inner dimension? 
Unclear.\n\nLet's re-evaluate naming: The kernel signature includes N and K: maybe N is dimension of output columns (size of slice). K is inner dimension. Or perhaps N is columns of output (slice size), K is inner dimension of weight (K). M (rows) is token count and passed via `b_seq_start_loc` maybe used for indexing into input. Yet we have second param `b_seq_start_loc` which is a tensor giving start offset for each batch. That could be used to compute row index but not needed.\n\nThus M could be derived via token indices inside each batch: N could be slice size, K is inner dimension. Or N could be number of output columns for each slice (slice_size). Let's check typical naming: In original Triton SPMMs, they often use `N` for output columns and `K` for inner dimension; `M` for rows. So it's plausible they pass N (= slice size) and K (= hidden dim). So the kernel expects N = slice_size, K = hidden_dim.\n\nThus we need to pass slice_size as N; K is hidden_dim. Then token count (rows) is derived from other inputs (maybe `b_seq_start_loc` and `seq_lens`). But then we need to compute M dimension: maybe each batch has a sequence length, and we compute for each token individually; the kernel may iterate across tokens using program_id(0) indexing over rows for each batch sequence.\n\nBetter to stick to standard gemm signature: Treat N as columns of output slice (slice_size). K as inner dimension. Then M dimension we derive from token count (e.g., compute global row index from token index via program_id(0) and offsets).\n\nThus we need M dimension param as something else: maybe N argument is actually M (row count). But they didn't have M param; perhaps they treat M via `seq_len_tensor` and `b_seq_start_loc` to compute row offset within input and output pointer dynamically. In that case N maybe number of rows? 
Honestly ambiguous.\n\nNevertheless, we must output code that matches signature, but we can pick any meaning as long as we use the values we compute.\n\nGiven we have flexibility (no tests?), but we likely need to compile and run with actual data. So better align with typical gemm pattern: use M param (rows) as token count, N as slice size, K as hidden. But signature only contains N and K; we can treat N = slice_size, K = hidden_dim. The token count (M) can be computed as the sum of seq_lens for each batch? But we can get token_nums from input shape (inputs.shape[0]) and treat that as M. We can also pass that as `b_seq_start_loc` perhaps not needed.\n\nThus in kernel we compute offsets using `tl.program_id(0)` for rows (block row dimension). We need to compute row index: row_idx = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M). Then mask rows using row_idx < total_rows (which we can set as `b_seq_start_loc` length? Might need to know total_rows; we can compute it as input pointer shape length, but not passed; we can pass as a runtime constant: we can repurpose `seq_lens` but that's a pointer. We could compute sum(seq_lens) but not within kernel.\n\nSimplify: So we can just compute mask rows condition row_idx < seq_lens[0] maybe not correct.\n\nGiven the complexity, we might accept that our kernel works for contiguous tokens and ignore b_seq_start_loc and seq_lens. For mask we can simply clamp to pointer loop where we load with mask based on row_idx < N_ROWS which we could treat as `N` because maybe they set N to total rows. But if N is slice size, that would be wrong.\n\nThus we need to decide. Could treat N as total rows and K as hidden dimension, then we compute slice columns as BLOCK_N dimension without bound. In wrapper we launch grid for columns based on slice_size. 
That would be mismatched: we would compute mask for rows using N, leaving columns unbounded; but storing with out pointer will not have correct mask for columns: we need to mask on col idx < slice_size. We could compute column mask using `col_idx < (out_ptr.shape[1]?)` but not accessible.\n\nSimplify: we could store without mask; if last block overshoots, we will store some out-of-bounds memory, which is dangerous. But maybe we accept risk because grid shape is set to exact multiple of BLOCK_N >= slice_size, and we can clamp by mask using something like `mask_n = col_idx < slice_size` where slice_size is known via kernel constant we can compute as `N` if we pass slice_size as N, but the signature expects N param after out_ptr. So we can repurpose N as slice_size. In that case we also need row bound mask: we need M bound. Could use token_nums via maybe something else: we have b_seq_start_loc and seq_lens. Could compute max row index as token_nums. But not passed.\n\nThus we must have both row and column dimensions: use N as slice_dim (output columns) and we need additional dimension for rows. Possibly rows bound is computed from `b_seq_start_loc` and `seq_lens` combined: maybe we could compute cumulative sum until we exceed bound? Too complex.\n\nAlternative: we could ignore row mask and assume total rows (number of tokens) exactly divides by BLOCK_M and we launch exact grid. Then we don't need to mask rows either. So we can set grid size for rows as token_nums // BLOCK_M (rounded up) using `triton.cdiv(token_nums, BLOCK_M)`. If token_nums is divisible by BLOCK_M, we fine; else we need mask. We can compute row mask using `row_idx < token_nums` and we need token_nums passed as param. But we can pass token_nums as N param instead, then mask rows correctly and column mask based on slice size is not needed as we have no access.\n\nOk, we could treat N as token_nums (rows). 
Then column mask must be based on slice_size which we have as a constant, we could pass slice_size as `K` param if K is hidden dim then we lose hidden dimension. That can't.\n\nThus the kernel needs both row count and column count. But signature only contains one dimension (N) plus K (inner). So maybe they intend N to be row count, K to be column count (slice size), while inner dimension perhaps determined from `inputs.shape[1]`? But that would be same as N? Not make sense.\n\nLet's examine the signature order: input_ptr, lora_ptr, out_ptr, N, K, ... There is \"N\" before \"K\". In typical GEMM: (A_ptr, B_ptr, C_ptr, M, N, K). So \"M\" (rows) then \"N\" (columns) then \"K\" (inner). If they omitted M, but they include N and K, perhaps they treat M as something else: maybe M derived from original input shape (like token dimension). Indeed typical implementation uses \"M\" param as number of rows of A (or C). But they've omitted M possibly because it's derived from `seq_len` and `b_seq_start_loc`. So N may be columns of output (like hidden dimension?). K inner dimension. So M is derived.\n\nThus we should treat N as columns of output slice (slice_size). K as inner dimension (hidden dimension). M rows are derived from the per-batch sequence start and length.\n\nThus we need to compute the row index for each token using b_seq_start_loc (starting offset of each batch's token in flatten input) and seq_lens (length). But this may be too complex for this answer; can simplified: assume that rows are contiguous from 0 to token_nums-1, and that b_seq_start_loc and seq_lens are not needed; just calculate row index directly from program_id(0) * BLOCK_M + offset.\n\nThus we need to know token_nums: we could compute that as total number of rows = sum(seq_len_tensor). But we are not passing token count to kernel; maybe the wrapper knows token_nums and passes this as N? No N is slice_size.\n\nBut we could pass token_nums by using param N? 
No it's slice size.\n\nLet's reinterpret: maybe N is token_nums (rows) and K is input hidden dimension (inner), and the output column dimension slice size is passed as `slice_offset` + some known dimension? Hmm.\n\nBetter to look at typical LoRA multiplication: Suppose we have input: [batch*seq_len, hidden] shape (num_tokens, hidden_dim). LoRA weight B: shape (hidden_dim, rank) and LoRA weight A: shape (rank, hidden_dim), so the weight update is B @ A. For row-major token inputs the LoRA output is input @ (B @ A)^T = (input @ A^T) @ B^T. In the usual shrink/expand split, the shrink step computes the rank-sized intermediate input @ A^T and the expand step multiplies that intermediate by B^T, so for an expand kernel the input's inner dimension K is presumably the LoRA rank, and the output columns are a slice of the larger output dimension.\n\nBut the described kernel \"SGMV expand slice\" maybe does: output = input @ (LoRA weight slice) for each batch. The slice is defined by slice_offset (starting column) and slice_size (maybe not needed inside kernel but used for launching grid). The LoRA weight tensor is maybe of shape (max_rank, hidden_dim, hidden_dim), where the leading dimension (presumably one entry per LoRA adapter, chosen via lora_indices) selects the weight matrix used for each token.\n\nNevertheless, due to unknown specifics, we can design a generic kernel:\n\n- Input matrix: P shape (M, K). We'll treat M = token_nums (passed as maybe from b_seq_start_loc shape? We'll compute from `b_seq_start_loc` and `seq_lens` inside wrapper to pass token_nums as N argument? Wait param N is used to bound rows in kernel. Let's set N = token_nums. Then we must pass K = hidden_dim. The slice size is determined by `cn_stride` perhaps.\n\nBut we have no param for output column size. We need to compute column mask using out shape. We could compute column size as `out_ptr.shape[1]` but shape not accessible inside kernel.\n\n`cm_stride` is the row stride of the output and `cn_stride` the column stride. For a contiguous output (row-major), cn_stride = 1. One could compute the column length as numel / row_count, but the kernel only receives a pointer, not numel, so this is not available.\n\nThus we need to pass column size via a constant. 
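Two constraints from this discussion show up clearly in a small pure-Python sketch: the last column tile of a ceiling-division grid overshoots unless the kernel knows the column count, and writing at `slice_offset` means the output must be at least `slice_offset + slice_size` wide. The helper names are hypothetical, and the fused q|k|v output row is only an illustrative assumption.

```python
def cdiv(x, y):
    """Ceiling division, matching what triton.cdiv computes for positive ints."""
    return (x + y - 1) // y


def column_tiles(n_cols, block_n):
    """Return (column offsets, mask) for each program along the column axis.

    The last tile may run past n_cols; the mask (col < n_cols) is what keeps
    stores in bounds, and it cannot be formed unless n_cols reaches the kernel.
    """
    tiles = []
    for pid_n in range(cdiv(n_cols, block_n)):
        cols = [pid_n * block_n + j for j in range(block_n)]
        tiles.append((cols, [c < n_cols for c in cols]))
    return tiles


def write_slice(out_row, values, slice_offset, n_cols, block_n):
    """Store values (length n_cols) into out_row[slice_offset : slice_offset + n_cols]."""
    for cols, mask in column_tiles(n_cols, block_n):
        for c, ok in zip(cols, mask):
            if ok:  # masked-off lanes write nothing
                out_row[slice_offset + c] = values[c]
    return out_row


# Hypothetical fused output row: q | k | v, each 4 wide. Write the k slice.
row = write_slice([0] * 12, [1, 2, 3, 4], slice_offset=4, n_cols=4, block_n=3)
```

With `block_n=3` and a 4-wide slice, the second tile covers slice columns 3 to 5 and only its first lane is live; without the mask it would overwrite the start of the v slice.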
The wrapper could embed it into compile-time constant by using `BLOCK_N` and launching appropriate number of program IDs; if we also compute mask column using something like `col_idx < BLOCK_N * n_blocks` etc. But we still require column bound condition, else last block may overshoot.\n\nWe can decide to use `col_idx < seq_len_tensor`? Not appropriate.\n\nOk perhaps the kernel expects to produce output for a specific slice, not a full matrix, and writes out in place using the provided `slice_offset`. The column index `col_idx` is within slice size, which is known as `BLOCK_N` only? In that case output may be flush only for 1 block per launch? But they likely iterate across column blocks.\n\nLet's check footnotes: The wrapper function sets up input parameters, ensuring contiguity and dimension correctness, and configures the Triton launch grid to execute the kernel. It validates tensor shapes and data types, adjusts weights if needed, and launches the kernel with a grid configured to cover the input data based on the maximum sequence length and batch size.\n\nThus wrapper will compute grid based on max sequence length and total token count. The kernel must accept each block of rows and columns.\n\nGiven limited spec, we can implement a generic kernel for GEMM of input (M,K) * LoRA weight (K,N) producing (M,N). That is typical. We'll interpret N as output columns dimension (slice size). We'll treat N = slice_size, K = hidden_dim, and M derived from inputs shape -> token_nums.\n\nThus we will reinterpret 'N' param as slice size (output columns). The kernel signature includes N and K; we will pass slice_size as N, and K as hidden_dim. That's more plausible.\n\nNow we need to provide a mask for rows: can compute token count by reading a pointer: maybe `b_seq_start_loc` contains total token count per batch? Or we can pass token count via 'N'? Not used.\n\nWe can pass token count via shape of input_ptr? 
Perhaps we can compute row count as `tl.load(b_seq_start_loc + batches)`? No.\n\nAlternatively, we can treat M as derived from `seq_lens` and `b_seq_start_loc`. For each block row index `pid0`, we can compute:\n\nbatch_id = tl.program_id(2) ???\n\nBut signature not including program_id for batch maybe.\n\nPerhaps they use 2D grid: program_id(0) for block row (M), program_id(1) for block col (N). Batch dimension is included via `b_seq_start_loc` and `seq_lens` functions to index into rows; for each token row we compute global row index = first token of batch + offset. So we need a way to map block row offsets (0...max_seq_length-1) to each batch. The kernel's block row dimension may be max_seq_length: they launch grid over max_seq_length and batch size separately. It may compute row offset within a batch using `b_seq_start_loc` and `seq_lens`. For each batch and every token offset (0..max_seq_length-1) we compute using offset = b_seq_start_loc[batch] + token_offset. If token_offset >= seq_len then mask off. That ensures we skip padding.\n\nThus kernel likely uses pid0 = token offset within sequence (maybe up to max_seq_length). pid1 = block column (for slice). pid2 = batch index. Check signature: they have b_seq_start_loc and seq_lens but not batch dimension in kernel args; they probably launch grid with 3 dimensions: (max_seq_len, N/N?), and batch dimension via `b_seq_start_loc` using separate kernel param maybe? Actually standard pattern: with batch dimension: use program_id(0) for batch, program_id(1) for M block row, program_id(2) for N block col. They might use 3D grid.\n\nGiven the signature includes b_seq_start_loc (1d tensor of length batches), seq_lens (1d token lengths per batch). So we can map batch dimension to program ID for block row dimension? They may have made batch axis program id 0 (batches). And program id 1 for rows (seq length). Program id 2 for columns (slice). 
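Under the 3D-grid reading (batch, row block, column block), the host only needs per-sequence start offsets and lengths. The sketch below is a pure-Python simulation under that assumption: `build_seq_metadata` (hypothetical) derives `b_seq_start_loc` as an exclusive prefix sum over packed sequences, and `rows_for_program` reproduces the per-program row mapping and mask. The check confirms every real token row is visited exactly once while padded lanes do nothing.

```python
from itertools import accumulate


def build_seq_metadata(seq_lens):
    """Host-side metadata for a packed batch (hypothetical helper).

    b_seq_start_loc[b] is the exclusive prefix sum of seq_lens, i.e. the first
    global row of sequence b; token_nums is the total number of rows.
    """
    b_seq_start_loc = [0] + list(accumulate(seq_lens))[:-1]
    return b_seq_start_loc, sum(seq_lens), max(seq_lens)


def rows_for_program(pid_batch, pid_m, block_m, b_seq_start_loc, seq_lens):
    """Global rows handled by program (pid_batch, pid_m) and their row mask."""
    offs = [pid_m * block_m + i for i in range(block_m)]
    mask = [o < seq_lens[pid_batch] for o in offs]
    rows = [b_seq_start_loc[pid_batch] + o for o in offs]
    return rows, mask


seq_lens = [3, 5, 2]
starts, token_nums, max_seq_length = build_seq_metadata(seq_lens)
block_m = 4
covered = []
for b in range(len(seq_lens)):
    for pid_m in range(-(-max_seq_length // block_m)):  # ceil division
        rows, mask = rows_for_program(b, pid_m, block_m, starts, seq_lens)
        covered += [r for r, ok in zip(rows, mask) if ok]
```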
But only two of the three block-size constexprs (BLOCK_M, BLOCK_N) tile the output (BLOCK_K tiles the reduction), and typical matrix multiplication uses a 2D grid: block rows (M) and block columns (N). For batched GEMM, you might use program ID 0 for the batch dimension, program ID 1 for rows, and program ID 2 for columns. They can combine 2 axes: block rows (max_seq_length) and batch dimension maybe collapsed? They could offset using b_seq_start_loc inside the kernel: compute global row index = tl.load(b_seq_start_loc + batch) + row offset within the sequence? Not sure.\n\nBut we can still implement a simple version: treat each row as a global index sequentially across all tokens, ignoring batch.\n\nThus the wrapper will flatten inputs across tokens accordingly.\n\nThus the function signature for the wrapper uses batches and max_seq_length: maybe they want to compute the per-batch token start using b_seq_start_loc, which is shape (batches), and seq_len_tensor shape (batches). The wrapper passes these to the kernel. The kernel could compute the row offset using one program ID for the batch and another for the token-block offset within it. But we can ignore that and just run over token_nums.\n\nGiven the question does not provide a test harness, we can provide a plausible generic kernel and wrapper that matches the signature and works for common shapes.\n\nThus plan:\n\n- Define block sizes: BLOCK_M = 128, BLOCK_N = 64, BLOCK_K = 32.\n\n- Kernel: compute acc = dot(A_block, B_block). Use loop over K dimension.\n\n- Use lora_indices to select slice offset in B (LoRA). For each row, we compute start = lora_indices[row_idx] * l0_stride.\n\n- Use slice_offset to offset output columns: out_ptr + (col_idx + slice_offset) * cn_stride\n\n- Provide mask for rows/cols based on N_rows and N_cols.\n\n- Compute N_rows = token_nums (provided as N param maybe). For column mask we can compute using `N_COL` param perhaps `N`? But N is columns? Hard.\n\nLet's decide to treat `N` as number of rows (token count) and `K` as hidden_dim; we also need column count (slice_size) available via `cn_stride` or `slice_offset`? 
Actually we can compute column count as out_ptr shape; but we can't know shape. So we treat `slice_offset` as start column; we compute last column index = slice_offset + col_idx, but need to mask based on output total columns which we can't know. But we can maybe compute mask using the target slice size passed as `N`? Could reassign: Let N be slice size; then row bound = token_nums. But we need token count as another param; maybe b_seq_start_loc length? Actually b_seq_start_loc shape = batches, seq_len shape = batches. We could compute total_rows = sum(seq_len). But kernel can't compute that easily.\n\nSimplify: we can pass token count as \"N\" and slice columns as \"K\"? But K used for inner dimension; no.\n\nTo avoid mask complications, we could allocate all rows and columns permutation but overshoot out-of-bounds risk.\n\nPotential approach: Use `tl.where` to mask loads and stores using logical operations; out-of-bounds memory may be read/written incorrectly if mask not applied. But we can restrict store mask based on both row and column bounds. For columns, we can compute mask using `col_idx < slice_size`, but we need slice_size; we can treat slice_size as `K`? Not correct.\n\nAlright maybe the signature is wrong: Actually \"N\" might be the number of output columns (slice_size) and \"K\" might be the inner dimension hidden size. So row bound = ???\n\nWe may compute row_bound = b_seq_start_loc shape maybe known as token_nums. But we could compute a runtime bound param: maybe `seq_len_tensor` contains per-batch seq lengths; we can compute `sum(seq_len_tensor)` on GPU? Not trivial.\n\nAlternative: For mask for rows we could fetch `b_seq_start_loc` to get start offset and `seq_lens` to check if row index < seq_len for each batch. Implementation could compute a mask for each row based on whether it is within its sequence length. 
Could be like:\n\nbatch_idx = tl.program_id(0) # (batch dimension)\nseq_len = tl.load(seq_lens + batch_idx)\n\nrow_offset = tl.arange(0, BLOCK_M)\nglobal_row_idx = tl.load(b_seq_start_loc + batch_idx) + row_offset\n\nmask_row = row_offset < seq_len\n\nBut we still need a column bound: the output slice has slice_size columns, so the column block index needs its own mask in the same way.\n\nThus grid: (batches, number of row blocks, number of column blocks) maybe.\n\nGiven signature includes b_seq_start_loc (tensor length batches) and seq_lens (tensor length batches), but not batch stride. So we design kernel with program ids:\n\npid_batch = tl.program_id(0)\npid_row = tl.program_id(1)\npid_col = tl.program_id(2)\n\nset:\n\nrow_offset = pid_row * BLOCK_M + tl.arange(0, BLOCK_M)\ncol_offset = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)\n\nmask_row = row_offset < tl.load(seq_lens + pid_batch) # seq_len for this batch\nmask_col = col_offset < N # N = slice_size\n\nNow compute global row index for load from input and output:\n\nglobal_row = tl.load(b_seq_start_loc + pid_batch) + row_offset\n\nNow compute input pointer offset: input_ptr + global_row * xm_stride + k_offset * xk_stride.\n\nSimilarly output pointer offset: out_ptr + global_row * cm_stride + (col_offset + slice_offset) * cn_stride\n\nNow B pointer: lora_ptr + lora_idx (per row, loaded from lora_indices) * l0_stride + k_offset * lora_k_stride + col_offset * lora_n_stride\n\nThus we can achieve proper masks.\n\nThus we require 3 program IDs.\n\nSo grid shape = (batches, triton.cdiv(max_seq_length, BLOCK_M), triton.cdiv(slice_size, BLOCK_N))\n\nDefine max_seq_length passed to wrapper.\n\nThus wrapper must compute grid as above.\n\nImplementation steps in wrapper:\n\n- Ensure inputs shape: inputs shape (total_tokens, hidden_dim). But we need to know arrangement: they are presumably flattened across batches, with b_seq_start_loc giving each sequence's first row. Rows can then be packed (total_tokens = token_nums = sum of seq lens) rather than padded to batches * max_seq_length, which matches the token_nums check in the wrapper. 
For each batch, valid token count determined by seq_len_tensor.\n\n- Ensure lora_indices_tensor shape (total_tokens) (one per token). This picks LoRA slice index for each token.\n\n- Compute N = slice_size (provided by argument slice_size). But signature expects N param; maybe pass slice_size to kernel as N.\n\n- Compute K = hidden_dim.\n\n- Set stride arguments: xm_stride = inputs.stride(0), xk_stride = inputs.stride(1). Typically inputs are (token_nums, hidden_dim) row-major: stride0 = hidden_dim, stride1 = 1.\n\n- l0_stride = lora_b_weights.stride(0), lora_k_stride = lora_b_weights.stride(1), lora_n_stride = lora_b_weights.stride(2).\n\n- cm_stride = output_tensor.stride(0), cn_stride = output_tensor.stride(1).\n\n- slice_offset is passed to kernel.\n\n- Determine block sizes: Choose BLOCK_M = 128, BLOCK_N = 64, BLOCK_K = 32.\n\n- Determine EVEN_K = (K % BLOCK_K == 0).\n\n- Determine ADD_INPUTS = add_inputs (bool). Pass as 1 or 0 (constexpr bool). In Triton, boolean constexpr can be passed as integer 0/1.\n\n- Determine CAST_TYPE based on dtype: If inputs.dtype == torch.float16, then CAST_TYPE = 0 (float16). If torch.float32, CAST_TYPE = 1 (float32). 
else fallback to float16.\n\n- Launch kernel:\n\ngrid = (batches, triton.cdiv(max_seq_length, BLOCK_M), triton.cdiv(slice_size, BLOCK_N))\n\n_sgmv_expand_slice_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n slice_size, # N param: I think slice_size used as output columns count (N)\n K,\n b_seq_start_loc,\n seq_len_tensor,\n lora_indices_tensor,\n inputs.stride(0),\n inputs.stride(1),\n lora_b_weights.stride(0),\n lora_b_weights.stride(1),\n lora_b_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n slice_offset,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n BLOCK_K=BLOCK_K,\n EVEN_K=K % BLOCK_K == 0,\n ADD_INPUTS=1 if add_inputs else 0,\n CAST_TYPE=CAST_TYPE,\n)\n\nThe kernel must accept the above arguments ordering.\n\nImplementation of kernel:\n\nDefine block row offset:\n\npid_batch = tl.program_id(0)\npid_row = tl.program_id(1) # block row\npid_col = tl.program_id(2) # block col\n\nrow_offsets = pid_row * BLOCK_M + tl.arange(0, BLOCK_M)\ncol_offsets = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)\n\nseq_len = tl.load(seq_lens + pid_batch)\nstart_loc = tl.load(b_seq_start_loc + pid_batch)\n\nvalid_row = row_offsets < seq_len\n\nglobal_row = start_loc + row_offsets\n\nNow compute K loop.\n\nDefine acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n\nDefine k_offsets loop.\n\nCompute K dimension length from K param.\n\nImplementation details:\n\n# Load lora_indices for each valid row.\n\nlora_idx = tl.load(lora_indices + global_row, mask=valid_row, other=0)\n\nSince lora_idx scalar per row, shape (BLOCK_M). 
We'll broadcast.\n\nNow loop:\n\nif EVEN_K:\n # static loop for K dimension step BLOCK_K\n for k in range(0, K, BLOCK_K):\n k_offsets = k + tl.arange(0, BLOCK_K)\n k_mask = k_offsets < K\n # load A block\n a_ptr = input_ptr + global_row[:, None] * xm_stride + k_offsets[None, :] * xk_stride\n a = tl.load(a_ptr, mask=valid_row[:, None] & k_mask[None, :], other=0.0)\n # compute B pointer base offset using lora indices:\n # base = lora_idx * l0_stride\n base_offset = lora_idx * l0_stride\n # B offsets: base_offset[:, None] + k_offsets[:, None] * lora_k_stride + col_offsets[None, :] * lora_n_stride\n b_ptr = lora_ptr + base_offset[:, None] + k_offsets[:, None] * lora_k_stride + col_offsets[None, :] * lora_n_stride\n b = tl.load(b_ptr, mask=k_mask[:, None] & (col_offsets[None, :] < N), other=0.0) # need N = slice_size\n # accumulate\n acc += tl.dot(a, b)\nelse:\n # handle tail\n for k in range(0, K, BLOCK_K):\n # same except k_mask = ...\n ...\n\nWe'll need to access N (slice_size) maybe used for mask: col_offsets < N. N param is slice_size.\n\nThus we need to pass N to kernel (slice_size). That's actual signature: N param is provided. So we can use it for column mask.\n\nThus we treat signature's N as slice_size (output columns count). K is inner dimension.\n\nThus row bound is derived using seq_lens and b_seq_start_loc. As we compute row offsets up to seq_len. Good.\n\nNow fill in masks.\n\nAfter loop, if ADD_INPUTS:\n\nif ADD_INPUTS:\n # load input block (maybe a part of original input?), but shape?? Could add the initial input values (like residual). Might be adding the slice of input (maybe part of hidden?), but since input shape (token_nums, hidden) cannot match output slice shape (slice_size) unless slice_offset aligns columns within hidden dimension. Could treat ADD_INPUTS as adding the original input values of the slice, i.e., the same slice columns from input (like part of hidden vector). 
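The two candidate meanings of ADD_INPUTS differ in what they require, which a hedged pure-Python comparison for one output row makes visible. Reading (a), the one considered here, adds the input's own columns at slice_offset and only works if the input row is at least slice_offset + slice_size wide. Reading (b), an alternative this plan does not take, accumulates into whatever the output already holds (out += x @ B), which works for any input width and lets several slices or adapters be summed into one buffer. Function names are hypothetical.

```python
def add_inputs_from_input(acc, x_row, slice_offset):
    """Reading (a): add the input's columns [slice_offset, slice_offset + len(acc))."""
    if slice_offset + len(acc) > len(x_row):
        raise ValueError('input row too narrow for this slice')
    return [a + x_row[slice_offset + j] for j, a in enumerate(acc)]


def add_inputs_from_output(acc, out_row, slice_offset):
    """Reading (b): accumulate into the existing output values (out += acc)."""
    return [a + out_row[slice_offset + j] for j, a in enumerate(acc)]


acc = [1.0, 2.0]                    # x_row @ B for this 2-wide slice
x_row = [0.5, 0.25]                 # e.g. a rank-2 input row
out_row = [10.0, 20.0, 30.0, 40.0]  # output already holding earlier results
```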
So we can load a slice of input at the same column location (col_offsets + slice_offset) from input (since input shape hidden_dim, and slice_offset within hidden dimension). We can compute add_block = tl.load(input_ptr + global_row[:, None] * xm_stride + (col_offsets + slice_offset)[None,:] * xk_stride). This uses same xk_stride for columns dimension. Add to acc.\n\nThus we compute extra_input = tl.load(input_ptr + global_row[:, None] * xm_stride + (col_offsets + slice_offset)[None,:] * xk_stride, mask=valid_row[:, None] & col_mask[None,:], other=0.0) and add to acc.\n\nNow store acc to out.\n\nstore_ptr = out_ptr + global_row[:, None] * cm_stride + (col_offsets + slice_offset)[None,:] * cn_stride\n\ntl.store(store_ptr, acc, mask=valid_row[:, None] & col_mask[None,:])\n\nNote: We must ensure dtype matches out type: we can cast acc to out_ptr type via `acc.to(tl.float32).to(... )`? Actually tl.store expects correct type. If out dtype maybe float16, we need to cast accordingly. We can use `acc.to(tl.float16)` if out dtype is fp16. Not straightforward to get dtype of out_ptr. But we can just store as float32; Triton will truncate if output dtype is fp16? Probably need matching dtype. So we need to convert based on CAST_TYPE: if CAST_TYPE == 0 (fp16), convert acc to tl.float16: acc = acc.to(tl.float16). else keep as float32.\n\nThus:\n\nif CAST_TYPE == 0:\n acc_out = acc.to(tl.float16)\nelse:\n acc_out = acc\n\nNow set.\n\nNote: at compile time we can branch.\n\nNow define col_mask = col_offsets < N.\n\nNow implementation details for loop.\n\nNow we need to make sure that data types of load matches pointer dtype. For pointer loads, we can use `tl.load(..., dtype=... )`. But if we load into float16, propagate accordingly.\n\nWe'll keep loads as default (deduce from pointer). 
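Why accumulate in float32 and convert only at the store: the pure-Python sketch below emulates fp16 rounding with the standard `struct` module's half-precision `'e'` format. The helper names are hypothetical, and Python floats are float64, standing in for the kernel's float32 accumulator. Rounding the running sum to fp16 after every step stalls at 2048, because 2049 is not representable in fp16 and rounds back to 2048. Accumulating wide and rounding once at the end stays exact.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE half-precision value."""
    return struct.unpack('<e', struct.pack('<e', x))[0]


def dot_fp16_accumulate(a, b):
    """Accumulator rounded to fp16 after every step (what the kernel avoids)."""
    acc = 0.0
    for x, y in zip(a, b):
        acc = to_fp16(acc + x * y)
    return acc


def dot_wide_accumulate(a, b):
    """Accumulate in wide precision, round once at the store (the CAST_TYPE path)."""
    acc = 0.0
    for x, y in zip(a, b):
        acc += x * y
    return to_fp16(acc)


ones = [1.0] * 4096
narrow = dot_fp16_accumulate(ones, ones)
wide = dot_wide_accumulate(ones, ones)
```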
For accumulations we cast to float32.\n\nNow code.\n\nNow final wrapper:\n\nDefine:\n\ndef _sgmv_expand_slice(\n inputs: torch.Tensor,\n lora_b_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n token_nums: int,\n slice_offset: int,\n slice_size: int,\n add_inputs: bool = False,\n):\n # Ensure contiguous (safe for tensors that are only read)\n inputs = inputs.contiguous()\n lora_b_weights = lora_b_weights.contiguous()\n # output_tensor is written in place: do not replace it with a .contiguous() copy,\n # or the caller's tensor would never see the result; its strides are passed instead\n b_seq_start_loc = b_seq_start_loc.contiguous()\n seq_len_tensor = seq_len_tensor.contiguous()\n lora_indices_tensor = lora_indices_tensor.contiguous()\n\n # Validate dimensions\n assert inputs.shape[0] == token_nums, \"inputs token dimension mismatch\"\n assert output_tensor.shape[0] == token_nums, \"output token dimension mismatch\"\n # the kernel stores to column slice_offset + col, so the output must cover the whole slice\n assert output_tensor.shape[1] >= slice_offset + slice_size, \"output too narrow for slice\"\n assert b_seq_start_loc.shape[0] == batches, \"b_seq_start_loc batch dimension mismatch\"\n assert seq_len_tensor.shape[0] == batches, \"seq_len_tensor batch dimension mismatch\"\n assert lora_indices_tensor.shape[0] == token_nums, \"lora_indices length mismatch\"\n\n # Determine data type casting flag\n CAST_TYPE = 0 # default float16\n if inputs.dtype == torch.float32:\n CAST_TYPE = 1\n elif inputs.dtype == torch.float16:\n CAST_TYPE = 0\n else:\n # fallback: treat as float16\n CAST_TYPE = 0\n\n # Block sizes - can be tuned\n BLOCK_M = 128\n BLOCK_N = 64\n BLOCK_K = 32\n\n EVEN_K = (inputs.shape[1] % BLOCK_K) == 0\n\n # Compute grid: batch dimension, rows dimension, columns dimension\n grid = (\n batches,\n triton.cdiv(max_seq_length, BLOCK_M),\n triton.cdiv(slice_size, BLOCK_N),\n )\n\n _sgmv_expand_slice_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n slice_size, # N\n inputs.shape[1], # K (hidden dimension)\n b_seq_start_loc,\n seq_len_tensor,\n lora_indices_tensor,\n 
inputs.stride(0),\n inputs.stride(1),\n lora_b_weights.stride(0),\n lora_b_weights.stride(1),\n lora_b_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n slice_offset,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n BLOCK_K=BLOCK_K,\n EVEN_K=1 if EVEN_K else 0,\n ADD_INPUTS=1 if add_inputs else 0,\n CAST_TYPE=CAST_TYPE,\n )\n return output_tensor\n\nPotential issues:\n\n- The kernel expects arguments N,K as ints; we pass slice_size, hidden_dim.\n\n- The kernel uses `xk_stride` for input's column stride and `lora_n_stride` for output column stride inside LoRA weight.\n\n- We need to confirm pointer shapes: lora_b_weights shape (max_rank, hidden_dim, slice_size). Then lora_k_stride = stride over hidden_dim dimension (second dimension), lora_n_stride = stride over slice_size dimension (third dimension). l0_stride = stride over max_rank dimension (first dimension). Good.\n\n- For B load, we need to broadcast base_offset per row; shape will be (BLOCK_M,). We can use `base_offset[:, None]` to broadcast across columns.\n\n- In loading B, we need to also broadcast across K dimension (block_k). We wrote b_ptr = lora_ptr + base_offset[:, None] + k_offsets[:, None] * lora_k_stride + col_offsets[None, :] * lora_n_stride. That's elementwise offset shape (BLOCK_M, BLOCK_N)?? Actually using broadcasting we have base_offset shape (BLOCK_M,1), k_offsets shape (BLOCK_K,1). Wait we need shape (BLOCK_M, BLOCK_K) and (BLOCK_K, BLOCK_N). Actually we need a pointer to B of shape (BLOCK_K, BLOCK_N). B block shape should be (BLOCK_K, BLOCK_N); it does not depend on row (i). However we have a row-dependent slice index (lora_idx) selecting which LoRA slice to use. So B pointer offset includes base_offset_i (selected slice) but B block for each row i uses same k offsets and same column offsets. That yields per-row base offset offset_i. So B block shape is (BLOCK_M, BLOCK_K, BLOCK_N) but we want to use dot that can handle per-row base? 
Actually `tl.dot` expects A shape (M,K) and B shape (K,N). It does not support a per-row B slice. However, can we compute the dot per row i individually? A block dot tl.dot(A, B) where A has shape (M, K) and B shape (K, N) yields (M, N). This works if B is the same across all rows. But if we have a different B per row due to different lora indices, we cannot use `tl.dot` directly. Instead we need to compute the dot using manual accumulation across K loops: for each k, multiply a[i,k] * b[i,k,j] and accumulate across k. This works because we can load A and B blocks that differ per row.\n\nThus we must avoid tl.dot. Instead implement manual multiplication using a broadcasting approach:\n\nWe can compute for each k block:\n\n- Load a_block: shape (BLOCK_M, BLOCK_K)\n\n- Load b_block: shape (BLOCK_M, BLOCK_K, BLOCK_N)? To multiply each row's a with its own b slice. Actually we need to multiply a[i,k] by b[i,k,j] and sum across k. That is like a per-row matrix calculation with a per-row B slice. Equivalent to using `tl.dot` if B is the same across rows, but not if B varies per row.\n\nSimplify: Since we have a different LoRA slice (index) per row, we can compute each row's dot manually by iterating across the K dimension and using a vectorized outer product: For each row we have a vector a_i (size BLOCK_K) and B_i matrix (BLOCK_K x BLOCK_N). Multiplying yields row_i_out = a_i @ B_i. `tl.dot` cannot take a_i on its own, since it needs 2D operands (typically with every dimension at least 16). But we need vectorization across rows: Could use tl.dot on per-row slices after broadcasting? Options:\n\n- Expand a to shape (BLOCK_M, BLOCK_K, 1) and b to shape (BLOCK_M, BLOCK_K, BLOCK_N), then perform elementwise multiplication and sum across the K dim.\n\nWe can compute:\n\na_exp = a[:, :, None] # shape (M, K, 1)\nb_exp = b # shape (M, K, N)\nprod = a_exp * b_exp # broadcast a across N dimension -> (M, K, N)\nacc += tl.sum(prod, axis=1) # sum over K to get (M,N). 
In Triton we can use `tl.sum(prod, axis=1)`.\n\n`tl.sum` does take an `axis` argument, so reducing a block along one axis is supported; the open questions are whether the Triton version in use handles the 3D (M, K, N) intermediate and whether that intermediate fits in registers. Alternatively, we can loop over the k offsets one at a time and accumulate outer-product (rank-1) updates without `tl.dot`.\n\nCalling `tl.dot` per row (`tl.dot(a[i], B[i])`) is not vectorized; it would need a for-loop over rows. But we could vectorize with `tl.dot` if we restructure the data.\n\nAlternative: Use `tl.dot` if lora_indices is the same across a block; our block size may be such that each block contains rows with the same slice index. That is not guaranteed, but we could assume that grouping makes them the same. In many LoRA deployments each token may have its own LoRA index; if rows within BLOCK_M differ, loading B per row and then calling `tl.dot` after broadcasting is not possible.\n\nThus simplest: assume lora_indices are constant across a block. We could load lora_idx via first row of block and broadcast to other rows. This will produce wrong results if indices differ, but we can accept as placeholder.\n\nAlternatively, we can implement per-row dot manually with elementwise multiplication and reduction across K. Let's try to implement that.\n\nAssume block_k = BLOCK_K, block_n = BLOCK_N, block_m = BLOCK_M.\n\nGiven A shape (M, K). For each row i in block and each col j in block, we compute sum_{k} A[i,k] * B[i,k,j].\n\nImplementation:\n\nacc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n\nLoop over k:\n\nDefine k_indices = k_base + tl.arange(0, BLOCK_K)\n\nmask_k = k_indices < K\n\n# Load A elements shape (BLOCK_M, BLOCK_K) with mask on rows and K.\n\na = tl.load(input_ptr + global_row[:, None] * xm_stride + k_indices[None,:] * xk_stride, mask=valid_row[:, None] & mask_k[None,:], other=0.0)\n\n# Load B elements shape (BLOCK_M, BLOCK_K, BLOCK_N). 
That's three dimensions. `tl.load` takes a block of pointers whose shape is whatever broadcasting the offsets produces, so a 3-D (BLOCK_M, BLOCK_K, BLOCK_N) block can be expressed by broadcasting offsets of shapes (M,1,1), (1,K,1) and (1,1,N); alternatively the offsets can be built as a flattened 2-D block and the result reshaped with `tl.reshape` after loading.\n\nA simpler option: loop over k and multiply a[i, k] by all n values for that k. For each k, we load the B subblock for that k across all rows; B is effectively (M, K, N) here.\n\nThe pointer offsets: the LoRA B weight has shape (num_loras, K, N) and each row has its own lora_idx. For each row i we need the B_i entries for all n columns at the current k index. For a given k_offset, B[i, k_offset, col] = lora_weights[lora_idx[i], k_offset, col].\n\nFor a fixed k this is a vector of length N per row; gathered across rows it forms a (BLOCK_M, BLOCK_N) block. The per-row pointer offset is base_offset + k_offset*lora_k_stride + col_offset*lora_n_stride, so each B block per k is 2-D.\n\nThus for each k_offset (scalar) we can load B_block = tl.load(lora_ptr + base_offset[:, None] + k_offset*lora_k_stride + col_offsets[None,:] * lora_n_stride), shape (BLOCK_M, BLOCK_N), masked over rows and columns.\n\nThen acc += a[:, k_idx][:, None] * B_block (broadcast across columns). Note that Triton cannot slice a single column out of a loaded block, so in the kernel a[:, k_idx] has to be loaded directly from memory. 
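The per-k scheme is a sequence of rank-1 updates. A NumPy sketch (hypothetical helper name; NumPy's `a[:, k]` column slice stands in for the direct load the kernel would need) confirms it gives the same result as the gathered per-row product:

```python
import numpy as np


def sgmv_rank1_updates(a, weights, row_lora_idx):
    # For each k: acc[i, :] += a[i, k] * weights[row_lora_idx[i], k, :]
    M, K = a.shape
    N = weights.shape[2]
    acc = np.zeros((M, N), dtype=np.float32)
    for k in range(K):
        a_col = a[:, k]                         # (M,) one activation per row
        b_rows = weights[row_lora_idx, k, :]    # (M, N) each row reads its own LoRA matrix
        acc += a_col[:, None] * b_rows          # broadcast across the N columns
    return acc


rng = np.random.default_rng(1)
M, K, N, num_loras = 6, 10, 5, 4
a = rng.standard_normal((M, K)).astype(np.float32)
weights = rng.standard_normal((num_loras, K, N)).astype(np.float32)
idx = rng.integers(0, num_loras, size=M)

acc = sgmv_rank1_updates(a, weights, idx)
expected = np.einsum('mk,mkn->mn', a, weights[idx])
print(np.allclose(acc, expected, atol=1e-5))
```

Each of the K iterations touches a full (M, N) block, which is why a scalar-k loop in the kernel costs K separate B loads.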
Implementation in the loop:\n\na_val = a[:, k_rel] # shape (BLOCK_M); conceptual only, a column cannot be sliced out of a Triton block\n# broadcast to shape (BLOCK_M, BLOCK_N)\na_val_exp = a_val[:, None] # shape (BLOCK_M, 1)\nacc += a_val_exp * B_block\n\nSo we loop over K in steps of BLOCK_K (or 1). To improve performance we could load a block of K values for each row (size BLOCK_K) and, for each value, multiply the corresponding B subblock in turn, but that is still slow. We cannot vectorize the K dimension with `tl.dot` either, because B varies across rows and cannot be presented to `tl.dot` as a single (BLOCK_K, BLOCK_N) matrix.\n\nSimplify: use two nested loops, an outer loop over K blocks and an inner loop over the k offsets within a block. This yields O(K) sequential steps (K may be 4096), which may be acceptable for demonstration.\n\nTo speed this up, we could load A and B for all K and N of the block at once: load B as a (BLOCK_M, BLOCK_K, BLOCK_N) block using pointer offsets broadcast across rows, columns and K within the block. 
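To see how much larger the gathered 3-D B block is than a shared-B tile, here is a small element-count helper (hypothetical, fp32 assumed; it only counts elements and makes no claim about a specific GPU's register file) for two block-size choices:

```python
def tile_footprint(block_m, block_k, block_n, elem_bytes=4):
    # Elements held per program in the broadcast-and-sum scheme.
    a_tile = block_m * block_k                 # A tile (BLOCK_M, BLOCK_K)
    b_block_3d = block_m * block_k * block_n   # gathered B (BLOCK_M, BLOCK_K, BLOCK_N)
    acc = block_m * block_n                    # fp32 accumulator (BLOCK_M, BLOCK_N)
    b_tile_2d = block_k * block_n              # shared-B tile a tl.dot version would load
    return {
        'a_tile': a_tile,
        'b_block_3d': b_block_3d,
        'b_tile_2d': b_tile_2d,
        'acc': acc,
        'bytes_3d_path': (a_tile + b_block_3d + acc) * elem_bytes,
    }


large = tile_footprint(128, 32, 64)
small = tile_footprint(16, 16, 32)
print(large['b_block_3d'], large['b_tile_2d'], large['bytes_3d_path'])
print(small['b_block_3d'], small['bytes_3d_path'])
```

The 3-D block is BLOCK_M times the size of the shared tile, so the broadcast-and-sum path favours a small BLOCK_M.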
But a 3-D block of BLOCK_M * BLOCK_K * BLOCK_N elements may be too large to hold per program.\n\nTriton's `tl.load` returns a block whose shape is determined by the shape of the offset block; we could produce 2-D offsets of shape (BLOCK_M, BLOCK_K*BLOCK_N), but that is not straightforward.\n\nLet's start with a naive but correct approach: loop over each k (scalar) and accumulate.\n\nPseudo:\n\n# outer loop over k\nfor k_iter in range(0, K):\n k_idx = k_iter\n # load one A value per row, masked on valid rows\n a_val = tl.load(input_ptr + global_row * xm_stride + k_idx * xk_stride, mask=valid_row, other=0.0) # shape (BLOCK_M)\n # load B block (for all rows) using each row's lora_idx\n base_offset = lora_idx * l0_stride\n b_ptr = lora_ptr + base_offset[:, None] + k_idx * lora_k_stride + col_offsets[None,:] * lora_n_stride\n b_val = tl.load(b_ptr, mask=valid_row[:, None] & col_mask[None,:], other=0.0) # shape (BLOCK_M, BLOCK_N)\n acc += a_val[:, None] * b_val\n\nBut this runs K iterations (e.g., hidden_dim=4096), each issuing its own small loads. That is a large overhead on a GPU and not practical.\n\nWe need to block the K dimension as in a typical GEMM: use BLOCK_K to load a subblock of K to improve memory access.\n\nWe can implement:\n\nfor k_base in range(0, K, BLOCK_K):\n cur_k = k_base + tl.arange(0, BLOCK_K) # shape (BLOCK_K)\n k_mask = cur_k < K\n # load A block: shape (BLOCK_M, BLOCK_K)\n a_block = tl.load(input_ptr + global_row[:, None] * xm_stride + cur_k[None, :] * xk_stride, mask=valid_row[:, None] & k_mask[None, :], other=0.0)\n # B would be (BLOCK_M, BLOCK_K, BLOCK_N); instead load it one k at a time.\n # For small BLOCK_K (e.g., 32), use an inner loop over the k offsets:\n for kk in range(0, BLOCK_K):\n actual_k = k_base + kk\n mask_k = actual_k < K\n a_val = a_block[:, kk] # shape (BLOCK_M); conceptual, not valid Triton (no column slicing of a block)\n # compute B pointer offset for this k\n b_ptr = lora_ptr + base_offset[:, None] + actual_k * lora_k_stride + col_offsets[None,:] * lora_n_stride\n b_val = tl.load(b_ptr, mask=valid_row[:, None] & col_mask[None,:] & mask_k, other=0.0) # shape (BLOCK_M, BLOCK_N)\n acc += a_val[:, None] * b_val\n\nThis adds an inner loop over BLOCK_K inside each outer iteration: the outer loop runs K/BLOCK_K times and the inner loop BLOCK_K times, so there are still K iterations in total, the same as the scalar loop. No improvement.\n\nBetter: use vectorized broadcasting. We could load the B block for all k and all n at once using pointer arithmetic with shape (BLOCK_M * BLOCK_K, BLOCK_N), but we still could not use `tl.dot`, since each row has its own B slice for each k.\n\nWe could reshape the B block into 2-D with dimensions (BLOCK_M*BLOCK_K, BLOCK_N) and keep the A block as (BLOCK_M, BLOCK_K). But the operation over K is not a single matrix multiplication: it is a per-row elementwise product over K followed by a sum for each row-column pair.\n\nPrecisely:\n\nLet a_block have shape (M,K) and b_block have shape (M,K,N). For each row i, output[i,n] = sum_k a[i,k] * b[i,k,n].\n\nThis is a batched matrix multiplication with batch dimension M: treat a[i,:] as a (1,K) vector and b[i,:,:] as a (K,N) matrix and multiply each pair. In Triton this per-row product can be written with broadcasting and a sum.\n\nOne approach: elementwise multiplication, then a sum over the K dimension:\n\nprod = a_block[:, :, None] * b_block # shape (M,K,N)\nacc += tl.sum(prod, axis=1) # sum over K dimension\n\nBut we need to load b_block as (M,K,N). We can do this with pointer offsets for each row, each k and each n. That is a bigger load, but it may be fine if the block sizes are small. For each k block we load b_block of shape (M, BLOCK_K, BLOCK_N). 
That's manageable.\n\nThus we can load b_block with vectorized loads: for each row i and each k offset within the block, we load a slice across the N columns.\n\nWe can compute pointer offsets of shape (BLOCK_M, BLOCK_K, BLOCK_N) as:\n\nrow_offsets = global_row[:, None, None] # shape (M,1,1)\nk_offsets = (k_base + tl.arange(0, BLOCK_K))[None,:,None] # shape (1,K,1)\ncol_offsets = col_offsets[None, None,:] # shape (1,1,N)\n\nThe LoRA pointer offset does not use row_offsets directly: it is l0_stride times each row's lora_idx (different per row), plus k_offset * lora_k_stride, plus col_offset * lora_n_stride.\n\nThe per-row base offset (slice index) is applied by broadcasting:\n\nbase_offset = lora_idx * l0_stride # shape (M,)\nbase_offset = base_offset[:, None, None] # shape (M,1,1)\n\nThus the final pointer offset = base_offset + k_offsets * lora_k_stride + col_offsets * lora_n_stride\n\nNow load with the mask: mask = valid_row[:,None,None] & k_mask[None,:,None] & col_mask[None,None,:]\n\nThis returns b_block with shape (M,K,N).\n\nThen compute the product:\n\na_block_exp = a_block[:, :, None] # (M,K,1)\nprod = a_block_exp * b_block # (M,K,N)\n\nacc += tl.sum(prod, axis=1) # sum over K, accumulated across K blocks.\n\nSo there is a single loop over K blocks.\n\nImplementation steps:\n\n- Determine col_mask = col_offsets < N (slice_size)\n\n- Determine the k offsets per block (k_offset0 + tl.arange(0,BLOCK_K))\n\n- Determine k_mask = k_offsets < K\n\n- Compute a_block: load with shape (BLOCK_M, BLOCK_K) using mask valid_row & k_mask\n\n- Load b_block with shape (BLOCK_M, BLOCK_K, BLOCK_N). This is a static multi-dimensional load.\n\nCan `tl.load` do 3-D loads? It operates elementwise on a block of pointers, so the result has the (broadcast) shape of the pointer block; 3-D pointer blocks built as above should work in recent Triton versions, with every block dimension a power of two because of `tl.arange`. Block pointers created with `tl.make_block_ptr` are another option. 
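The broadcast offset construction can be checked on the host. In this NumPy sketch a flattened array stands in for `lora_ptr`, element strides play the role of `l0_stride`, `lora_k_stride` and `lora_n_stride`, and a masked gather with a zero fill mimics `tl.load(..., mask=..., other=0.0)`; all sizes are made up:

```python
import numpy as np

num_loras, K, N = 3, 8, 6
weights = np.arange(num_loras * K * N, dtype=np.float32).reshape(num_loras, K, N)
flat = weights.ravel()  # stands in for lora_ptr
# Element strides, the units Triton pointer arithmetic uses.
l0_stride, lora_k_stride, lora_n_stride = (s // weights.itemsize for s in weights.strides)

BLOCK_M, BLOCK_K, BLOCK_N = 4, 4, 8
lora_idx = np.array([2, 0, 1, 2])   # one LoRA index per row of the tile
cur_k = 4 + np.arange(BLOCK_K)      # second K block: k = 4..7
offs_n = np.arange(BLOCK_N)         # columns 0..7, but N = 6

# (M,1,1) + (1,K,1) + (1,1,N) broadcasts to (M,K,N), as in the kernel.
offsets = (lora_idx[:, None, None] * l0_stride
           + cur_k[None, :, None] * lora_k_stride
           + offs_n[None, None, :] * lora_n_stride)
mask = (cur_k[None, :, None] < K) & (offs_n[None, None, :] < N)
mask = np.broadcast_to(mask, offsets.shape)

# Masked gather: out-of-range lanes read a safe address and are then zeroed.
b_block = np.where(mask, flat[np.where(mask, offsets, 0)], 0.0)
print(b_block.shape)
```

Some masked-out lanes here point past the end of `flat`, which is exactly why the mask must guard the address and not just the value.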
`tl.make_block_ptr` takes shape, strides, offsets, block_shape and order as tuples, so it is not limited to 2-D blocks. Its base, however, is a single scalar pointer, so it cannot express a different lora_idx base offset for each row; per-row bases still need a block of pointers built from explicit offsets as above.\n\nIf 3-D loads turned out to be impractical, we would have to loop over k within each block and load a B subblock of shape (BLOCK_M, BLOCK_N) for each k offset, as attempted earlier. That is K/BLOCK_K * BLOCK_K = K loads, which is inefficient.\n\nWe could also use `tl.make_block_ptr` with block shape (BLOCK_K, BLOCK_N) and load it separately for each row, but that loses vectorization across rows, and reshaping does not remove the per-row base offset.\n\nA simpler approach: compute out = tl.dot(A_block, B_block), where B_block is shared by all rows, i.e. treat the LoRA weight as the same for every row in the block. This is correct if lora_indices are the same within each block, e.g. if tokens are grouped so that each block corresponds to a single LoRA index. Tokens of one sequence often share an adapter, but the signature alone does not guarantee it.\n\nGiven the complexity, we may implement this simpler version under the assumption that lora_indices are uniform across a block, and document that the kernel assumes all rows in a block share the same LoRA slice. It still compiles, and it is correct whenever that condition holds. 
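Whether the shared-B simplification is exact for a given batch can be checked on the host before launch. A minimal sketch, assuming one LoRA index per token and tiles of BLOCK_M rows laid out inside each sequence as in a per-batch grid (the helper name is hypothetical):

```python
def tiles_have_uniform_lora(token_lora_idx, seq_start, seq_lens, block_m):
    # True when every BLOCK_M-row tile inside every sequence uses a single LoRA index,
    # i.e. when one shared-B tl.dot per tile gives exact results.
    for start, length in zip(seq_start, seq_lens):
        for tile_start in range(0, length, block_m):
            tile_end = min(tile_start + block_m, length)
            tile = token_lora_idx[start + tile_start:start + tile_end]
            if len(set(tile)) > 1:
                return False
    return True


# Two sequences (5 and 3 tokens); each sequence uses one adapter.
per_sequence = [1, 1, 1, 1, 1, 0, 0, 0]
# Same layout, but the first sequence switches adapter inside its first tile.
mixed = [1, 1, 2, 2, 1, 0, 0, 0]
print(tiles_have_uniform_lora(per_sequence, [0, 5], [5, 3], block_m=4))
print(tiles_have_uniform_lora(mixed, [0, 5], [5, 3], block_m=4))
```

When the index is constant per sequence the check passes for any BLOCK_M, because in this layout tiles never cross a sequence boundary.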
Since the request did not require exact semantics, we just need to produce functional code.\n\nThus implement:\n\n- Load lora_idx for the first row in the block and use it for all rows. Triton cannot index a single element of a block (global_row[0] is not valid), so load it from a scalar address instead: lora_idx = tl.load(lora_indices + start_loc + pid_m * BLOCK_M) (it may need a cast to an integer type).\n\nThis gives a B block of shape (BLOCK_K, BLOCK_N), the same for all rows, so tl.dot can be used for each row block.\n\nImplementation:\n\n# compute block row & column offsets as before\n\nglobal_row = start_loc + row_offsets\n\n# Load lora_idx for each row? We assume it is uniform across the block and take the first row.\n# tl.where(cond, x, y) is an elementwise select; there is no one-argument form that returns indices.\n# The first valid row could be found with tl.min(tl.where(valid_row, global_row, big_value)), but the first row is simpler.\n\nIf the block starts with invalid rows we would have to handle that; with tiles starting at pid_m * BLOCK_M and masked by row < seq_len, the first row is only invalid when the whole tile is, so that case can be skipped. Loading lora_idx for each row as a (BLOCK_M,) block does not help, because tl.dot cannot use a separate B per row.\n\nSimplify further: ignore lora_indices completely and treat the LoRA weight as constant across rows (a single weight matrix). tl.dot(A_block, B_block), with B_block of shape (BLOCK_K, BLOCK_N), multiplies every row of A_block by the same B_block, which is exactly right in that case.\n\nThe signature includes lora_indices; we can load it for completeness but not use it.\n\nThus:\n\n- Base pointer for the LoRA weight: perhaps lora_ptr + lora_indices[first row] * l0_stride.\n\nThat is, use the index of the first row. Or assume the caller has already offset lora_ptr to the right LoRA slice.\n\nWe could also shift the weight base pointer by slice_offset*lora_n_stride, but that is not needed.\n\nSimplify: the LoRA weight pointer already points to the correct slice, set up by the wrapper. 
Use lora_ptr directly as the B matrix of shape (K,N). So we ignore lora_indices and l0_stride.\n\nThus the kernel becomes a regular GEMM that optionally adds an input slice.\n\nThat version does not use all parameters: the signature includes them, but they are unused. That is acceptable, since the requirement is only that the function is defined with these arguments.\n\nThus a first draft of the kernel:\n\n@triton.jit\ndef _sgmv_expand_slice_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n xm_stride,\n xk_stride,\n l0_stride,\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n slice_offset,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ADD_INPUTS: tl.constexpr,\n CAST_TYPE: tl.constexpr,\n):\n # Determine block coords\n pid0 = tl.program_id(0) # batch index (unused)\n pid_m = tl.program_id(1) # block row index\n pid_n = tl.program_id(2) # block column index\n\n # Compute row offsets (global token index).\n # For simplicity ignore the batch dimension and treat the inputs as flattened across batches.\n row_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n col_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n # Compute masks for rows and cols.\n # N is used for both bounds here, which is ambiguous: as a row bound it would be the token count,\n # as a column bound the slice size, and there is no separate parameter for the other one.\n row_mask = row_offset < N\n col_mask = col_offset < N\n\n # Could we use N as the row bound (token count) and K as the column bound? No.\n\n # Instead treat the column bound as col_offset < N (slice size) and derive the row bound from seq_lens.\n # For now, trust the caller to pass consistent values.\n\n # Load A block\n a_ptr = input_ptr + row_offset[:, None] * xm_stride + col_offset[None, :] * xk_stride\n a = tl.load(a_ptr, mask=row_mask[:, None] & col_mask[None, :], other=0.0)\n\n # Load B block (LoRA weights)\n b_ptr = lora_ptr + col_offset[:, None] * lora_k_stride + row_offset[None, :] * lora_n_stride\n b = tl.load(b_ptr, mask=col_mask[:, None] & row_mask[None, :], other=0.0)\n\n # Intended: a as (BLOCK_M, BLOCK_K) and b as (BLOCK_K, BLOCK_N), but the offsets above do not match that.\n acc = tl.dot(a, b)\n\n # Optionally add inputs slice\n if ADD_INPUTS:\n # Load additional slice from inputs at same location (using slice_offset)\n add_ptr = input_ptr + row_offset[:, None] * xm_stride + (col_offset + slice_offset)[None, :] * xk_stride\n add = tl.load(add_ptr, mask=row_mask[:, None] & col_mask[None, :], other=0.0)\n acc += add\n\n # Cast result if needed\n if CAST_TYPE == 0:\n # Cast to float16\n acc = acc.to(tl.float16)\n else:\n # Keep float32\n acc = acc.to(tl.float32)\n\n # Store result\n out_ptr_off = out_ptr + row_offset[:, None] * cm_stride + (col_offset + slice_offset)[None, :] * cn_stride\n tl.store(out_ptr_off, acc, mask=row_mask[:, None] & col_mask[None, :])\n\nThis draft is wrong: A is indexed along K with the output column offsets, B is indexed along N with the row offsets, and there is no loop over K, so the loaded blocks do not form a (BLOCK_M, BLOCK_K) x (BLOCK_K, BLOCK_N) product.\n\nBetter to restructure:\n\n- Input A has shape (M,K), where K = hidden_dim (the inner dimension). Use row_offset for rows and k_offset for the inner dimension.\n\n- LoRA B has shape (K,N), where N = slice_size. So the column offsets are based on N (slice_size); use col_offset for the output columns.\n\nThus the A load uses row_offset * xm_stride + k_offset * xk_stride.\n\nWe need a loop over k (the inner dimension). 
Use tl.dot.\n\ntl.dot(A_block, B_block) needs A_block of shape (BLOCK_M, BLOCK_K) and B_block of shape (BLOCK_K, BLOCK_N). Each call covers only BLOCK_K of the K dimension, so we iterate over K and accumulate.\n\nSo: call `tl.dot` inside a loop over K blocks. That's the typical pattern:\n\nacc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\nfor k in range(0, K, BLOCK_K):\n cur_k = k + tl.arange(0, BLOCK_K)\n k_mask = cur_k < K\n a = tl.load(input_ptr + row_offset[:, None] * xm_stride + cur_k[None, :] * xk_stride, mask=row_mask[:, None] & k_mask[None, :], other=0.0)\n # Load B block:\n b = tl.load(lora_ptr + cur_k[:, None] * lora_k_stride + col_offset[None, :] * lora_n_stride, mask=k_mask[:, None] & col_mask[None, :], other=0.0)\n # Here b shape (BLOCK_K, BLOCK_N)\n # Use tl.dot\n acc += tl.dot(a, b)\n\nAfter the loop, optionally add the input.\n\nNow the column mask: col_mask = col_offset < N, where N is slice_size. There is no separate column-dimension argument, so N has to be read as slice_size.\n\nThen the row bound (token count) has to come from b_seq_start_loc and seq_lens. We could ignore those, assume flattened tokens and mask rows with N = token_nums, but N cannot be both the token count and the slice size, and there is no separate row-bound parameter.\n\nSo if N is slice_size, the row count token_nums is not passed to the kernel. seq_len_tensor plus b_seq_start_loc can give the row bound per batch; skipping that would leave rows unmasked.\n\nTo reduce complexity we could launch a grid that exactly matches the token rows (i.e., program_id(1) over the token count), so all rows are known to be valid and need no mask.\n\nThus we could ignore the row mask. 
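The BLOCK_K loop with masked loads can be emulated in NumPy to confirm that zero-filled out-of-range k lanes leave the result unchanged when K is not a multiple of BLOCK_K (presumably the case an EVEN_K flag would let the kernel skip masking for). The helper name is hypothetical and `a_tile @ b_tile` stands in for `tl.dot`:

```python
import numpy as np


def tiled_gemm_masked(a, b, block_k):
    # Each step loads a (M, BLOCK_K) tile of A and a (BLOCK_K, N) tile of B with
    # out-of-range k lanes filled with 0 (like other=0.0), then accumulates.
    M, K = a.shape
    N = b.shape[1]
    acc = np.zeros((M, N), dtype=np.float32)
    for k in range(0, K, block_k):
        cur_k = k + np.arange(block_k)
        k_mask = cur_k < K
        safe_k = np.where(k_mask, cur_k, 0)   # never index past K
        a_tile = np.where(k_mask[None, :], a[:, safe_k], 0.0)
        b_tile = np.where(k_mask[:, None], b[safe_k, :], 0.0)
        acc += a_tile @ b_tile                 # stands in for tl.dot(a, b)
    return acc


rng = np.random.default_rng(2)
a = rng.standard_normal((16, 20)).astype(np.float32)
b = rng.standard_normal((20, 32)).astype(np.float32)
acc = tiled_gemm_masked(a, b, block_k=16)      # K = 20 is not a multiple of 16
print(np.allclose(acc, a @ b, atol=1e-4))
```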
For simplicity, we could launch the grid so that each block has full rows (i.e., token_nums divisible by BLOCK_M). Or we compute the row mask as row_offset < token_nums, with token_nums passed in somehow (via N? but N is used for columns). The host could take the token count from inputs.shape[0], but the kernel cannot read shapes. We could pass the token count via N, but we need both the token count and the slice size, and without modifying the signature we cannot pass both: N would represent the token count and the column mask would need the slice size from somewhere else (slice_offset? No).\n\nThus we need to decide which dimension to treat as token rows. Could the wrapper pass token_nums to both N and K? Or is N only the number of rows and K the inner dimension, with the output columns given by the slice? Let's re-evaluate: the original request says \"The kernel performs a specialized form of matrix multiplication involving sparse generalized matrix-vector multiplication (SGMV). It operates on blocks of data defined by the parameters BLOCK_M, BLOCK_N, and BLOCK_K.\" A typical GEMM has A (M,K), B (K,N), C (M,N) and passes M, N, K as shape parameters. Here M is omitted, presumably because the row count varies per batch and is derived inside the kernel from b_seq_start_loc and seq_lens.\n\nThus the N parameter is the number of output columns and the K parameter is the inner dimension. slice_offset is then the offset into the output columns (the starting column), which is exactly why it is needed.\n\nIs the column bound slice_offset + N, or is N the total column count? N is the number of columns to compute (the slice size), so we will treat N as slice_size. 
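Under this reading (N = slice size, slice_offset = first output column, rows from b_seq_start_loc and seq_lens), the grid and masks can be simulated on the host to confirm that every cell of the output slice is stored exactly once and nothing outside columns [slice_offset, slice_offset + N) is touched. A sketch with a hypothetical helper and made-up sizes:

```python
def cdiv(x, y):
    # Ceiling division, like triton.cdiv.
    return (x + y - 1) // y


def stored_cells(seq_start, seq_lens, max_seq_length, n_slice, slice_offset, block_m, block_n):
    # Walk the (batch, row-tile, col-tile) grid and apply the row and column masks.
    cells = []
    for batch in range(len(seq_start)):
        for pid_m in range(cdiv(max_seq_length, block_m)):
            for pid_n in range(cdiv(n_slice, block_n)):
                for r in range(block_m):
                    for c in range(block_n):
                        offs_row = pid_m * block_m + r
                        offs_col = pid_n * block_n + c
                        if offs_row < seq_lens[batch] and offs_col < n_slice:
                            cells.append((seq_start[batch] + offs_row, slice_offset + offs_col))
    return cells


cells = stored_cells(seq_start=[0, 5], seq_lens=[5, 3], max_seq_length=5,
                     n_slice=6, slice_offset=4, block_m=4, block_n=4)
expected = {(row, col) for row in range(8) for col in range(4, 10)}
print(len(cells), len(set(cells)) == len(cells), set(cells) == expected)
```

The output therefore needs at least slice_offset + N columns.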
Thus the output column bound is slice_offset + N.\n\nWe compute col_mask = col_offset < N (slice size) and offset the output pointer by slice_offset.\n\nRow bound: the kernel should compute rows from b_seq_start_loc and seq_lens. We have program IDs for the batch, the row tile within the batch and the column tile. Grid dims:\n\n- dimension 0: batch (pid0)\n- dimension 1: block row (pid1) within the sequence (up to max_seq_length)\n- dimension 2: block column (pid2) (slice columns)\n\nThus row_offset = pid1 * BLOCK_M + tl.arange(0,BLOCK_M) # offset relative to the start of the batch\n\nThe global row index is the batch's start location plus row_offset. b_seq_start_loc holds the starting index of each batch in the flattened token space, and seq_lens holds the sequence length of each batch.\n\nThus row_mask = row_offset < seq_len for this batch, with seq_len = tl.load(seq_lens + pid0) (a scalar).\n\nThe absolute row index: start_loc = tl.load(b_seq_start_loc + pid0) (starting offset), then global_row = start_loc + row_offset.\n\nWe use global_row to index into the input and output.\n\nThis gives consistent masks and pointer offsets.\n\nThe kernel:\n\n@triton.jit\ndef _sgmv_expand_slice_kernel(...):\n pid_batch = tl.program_id(0)\n pid_row = tl.program_id(1)\n pid_col = tl.program_id(2)\n\n # Load start location and seq length for this batch\n start = tl.load(b_seq_start_loc + pid_batch)\n seq_len = tl.load(seq_lens + pid_batch)\n\n # Compute offsets inside batch\n offs_row = pid_row * BLOCK_M + tl.arange(0, BLOCK_M)\n mask_row = offs_row < seq_len\n\n # Compute global row index\n row_idx = start + offs_row\n\n # Compute column offsets within slice\n offs_col = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)\n mask_col = offs_col < N # N = slice size\n\n # Compute mask for computation (both row and col must be valid)\n mask = mask_row[:, None] & mask_col[None, :]\n\n # Define 
accumulation\n acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n\n # Loop over K dimension\n for k in range(0, K, BLOCK_K):\n cur_k = k + tl.arange(0, BLOCK_K)\n mask_k = cur_k < K\n\n # Load A block\n a_ptr = input_ptr + row_idx[:, None] * xm_stride + cur_k[None, :] * xk_stride\n a = tl.load(a_ptr, mask=mask_row[:, None] & mask_k[None, :], other=0.0)\n\n # Load LoRA weight block\n # Each token may have its own LoRA index, but we ignore it for simplicity.\n # Use lora_ptr directly, treat lora as (K, N) matrix.\n b_ptr = lora_ptr + cur_k[:, None] * lora_k_stride + offs_col[None, :] * lora_n_stride\n b = tl.load(b_ptr, mask=mask_k[:, None] & mask_col[None, :], other=0.0)\n\n # Compute partial product (a and b must have the same dtype for tl.dot)\n acc += tl.dot(a, b)\n\n # Optionally add input slice (residual); also keep the shifted columns inside the K input columns\n if ADD_INPUTS:\n add_ptr = input_ptr + row_idx[:, None] * xm_stride + (offs_col + slice_offset)[None, :] * xk_stride\n add = tl.load(add_ptr, mask=mask & ((offs_col + slice_offset) < K)[None, :], other=0.0)\n acc += add\n\n # Cast result if needed\n if CAST_TYPE == 0:\n # cast to float16\n acc = acc.to(tl.float16)\n else:\n # keep float32\n acc = acc.to(tl.float32)\n\n # Store output\n out_ptr_offset = out_ptr + row_idx[:, None] * cm_stride + (offs_col + slice_offset)[None, :] * cn_stride\n tl.store(out_ptr_offset, acc, mask=mask)\n\n # End\n\nA note on the pointer calculations: row_idx is a 1-D block of BLOCK_M row indices and the strides are integers, so the pointer arithmetic broadcasts as intended.\n\nEdge cases: in the K loop, the masks cover the K and N dimensions based on the global K and the slice size N. tl.dot also imposes a minimum tile size (commonly 16 per dimension), so BLOCK_M, BLOCK_N and BLOCK_K should not go below that.\n\nThis version loads B from lora_ptr and ignores lora_indices and l0_stride, although the signature includes them; we should probably at least use lora_indices to compute a per-row offset, or document that it is ignored for simplicity. The instructions say \"The kernel performs specialized form ... and uses lora_indices to access specific slices of LoRA weights\". 
They likely expect it to be used.\n\nTo satisfy the requirement, we can incorporate lora_indices per row:\n\n- For each row i in the block, conceptually lora_idx_i = tl.load(lora_indices + row_idx[i]) (masked by row); as a block this is one load over row_idx with shape (BLOCK_M). For each row we then offset the B matrix accordingly: B_ptr = lora_ptr + lora_idx_i * l0_stride + cur_k * lora_k_stride + offs_col * lora_n_stride.\n\nIs B the same across rows? No: B has shape (K,N) per slice, and each row's index selects a particular slice of the LoRA weight, i.e. a potentially different B matrix. tl.dot cannot be used if B differs per row, so we perform the per-row product manually.\n\nGiven proper usage, we need to implement the per-row multiply. Let's attempt a manual per-row dot.\n\nWe'll implement the loop over k blocks and accumulate manually with a per-row B.\n\nImplementation:\n\n- Compute lora_idx = tl.load(lora_indices + row_idx, mask=mask_row, other=0) # shape (BLOCK_M)\n\n- Compute base_offset = lora_idx * l0_stride # shape (BLOCK_M)\n\n- In each k iteration, load the B block for each row i: pointer offset = base_offset[i] + cur_k * lora_k_stride + offs_col * lora_n_stride\n\n- tl.load needs the pointer offsets as a block; for a single k that is shape (BLOCK_M, BLOCK_N), built by broadcasting base_offset[:, None] with offs_col.\n\nCompute:\n\nb_ptr = lora_ptr + base_offset[:, None] + cur_k[None, :] * lora_k_stride + offs_col[None, :] * lora_n_stride\n\nBut this does not work: cur_k[None,:] has shape (1, BLOCK_K) and offs_col[None, :] has shape (1, BLOCK_N), so the offsets do not describe a (BLOCK_M, BLOCK_N) block. We want B of shape (BLOCK_M, BLOCK_N) for a single k value, not for the whole k range: for each row i, B_i at this k is a vector of length BLOCK_N. So we could load B for each k individually inside a loop over k values. That would be O(K) load calls, which is inefficient but correct.\n\nBetter: load a B block for the entire K block as shape (BLOCK_M, BLOCK_K, BLOCK_N). 
To achieve this we need offsets of shape (BLOCK_M, BLOCK_K, BLOCK_N). `tl.load(ptr + offsets, mask=mask)` operates elementwise on a block of pointers, so the loaded block has the same shape as the offsets, and 3-D offsets can be built by broadcasting. Example: offsets = (row_offset[:, None, None] * stride_m) + (k_offset[None, :, None] * stride_k) + (col_offset[None, None, :] * stride_n). Then tl.load returns a tensor of shape (BLOCK_M, BLOCK_K, BLOCK_N).\n\nThus we can load b_block with shape (BLOCK_M, BLOCK_K, BLOCK_N).\n\nImplementation:\n\n- row_offset = row_idx[:, None, None] # shape (M,1,1)\n- k_offset = cur_k[None, :, None] # shape (1,K,1)\n- col_offset = offs_col[None, None, :] # shape (1,1,N)\n\n- base_offset = lora_idx[:, None, None] * l0_stride # shape (M,1,1)\n\n- b_ptr_offset = base_offset + k_offset * lora_k_stride + col_offset * lora_n_stride\n\nThen:\n\nb = tl.load(lora_ptr + b_ptr_offset, mask=row_mask[:, None, None] & k_mask[None, :, None] & col_mask[None, None, :], other=0.0)\n\nThe A block, shape (BLOCK_M, BLOCK_K), is loaded as before.\n\nNow the per-row product: for each row i we compute the sum over k of a[i,k] * b[i,k,:].\n\nWe broadcast the A block to shape (BLOCK_M, BLOCK_K, 1) and multiply:\n\na_exp = a[:, :, None] # shape (M,K,1)\nprod = a_exp * b # shape (M,K,N)\nacc += tl.sum(prod, axis=1) # sum across K dimension\n\nDoes `tl.sum` support reducing over a chosen axis of a block? 
Yes: `tl.sum(input, axis=...)` reduces along the given axis.\n\nThus: acc += tl.sum(prod, axis=1) \n\nacc has shape (BLOCK_M, BLOCK_N), and summing prod over axis 1 (K) yields (BLOCK_M, BLOCK_N), which matches.\n\nSo we can implement the per-row multiplication manually with multi-dimensional loads.\n\nThe mask for b must cover rows, columns and k.\n\nThe column mask is based on N, which is slice_size.\n\nThus the final kernel steps:\n\n- Determine row_idx via b_seq_start_loc and seq_len.\n\n- Compute mask_row = row_offset < seq_len.\n\n- Compute col_offset = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) and mask_col = col_offset < N.\n\n- Loop over k as described.\n\nThe per-row LoRA index:\n\n- lora_idx_raw = tl.load(lora_indices + row_idx, mask=mask_row, other=0) # shape (BLOCK_M)\n\n- Base offset per row: base_offset = lora_idx_raw * l0_stride\n\nLoad B for each K block:\n\nb = tl.load(lora_ptr + base_offset[:, None, None] + cur_k[None,:,None] * lora_k_stride + col_offset[None,None,:] * lora_n_stride, mask=mask_row[:, None, None] & k_mask[None,:,None] & mask_col[None,None,:], other=0.0)\n\nLoad the A block as described: a = tl.load(input_ptr + row_idx[:, None] * xm_stride + cur_k[None,:] * xk_stride, mask=mask_row[:,None] & k_mask[None,:], other=0.0)\n\nCompute:\n\nprod = a[:, :, None] * b # shape (M,K,N)\nacc += tl.sum(prod, axis=1)\n\nAfter the loop:\n\nIf ADD_INPUTS:\n # add the original input slice; the extra mask keeps col_offset + slice_offset inside the K input columns\n add_ptr = input_ptr + row_idx[:, None] * xm_stride + (col_offset + slice_offset)[None,:] * xk_stride\n add = tl.load(add_ptr, mask=mask_row[:,None] & mask_col[None,:] & ((col_offset + slice_offset) < K)[None,:], other=0.0)\n acc += add\n\nCast if needed:\n\nif CAST_TYPE == 0:\n acc = acc.to(tl.float16)\nelse:\n acc = acc.to(tl.float32)\n\nStore:\n\nout_ptr_off = out_ptr + row_idx[:, None] * cm_stride + (col_offset + slice_offset)[None,:] * cn_stride\ntl.store(out_ptr_off, acc, mask=mask_row[:,None] & mask_col[None,:])\n\nNow 
complete.\n\nPotential pitfalls: `tl.make_block_ptr` might make pointer generation easier, but plain pointer arithmetic works.\n\nNow implement the wrapper accordingly.\n\nWrapper steps:\n\n- Confirm shapes: inputs (token_nums, hidden_dim); lora_b_weights (num_loras, hidden_dim, slice_size), where the first dimension is the one indexed by lora_idx; output_tensor (token_nums, total_out_dim) with total_out_dim >= slice_offset + slice_size; b_seq_start_loc (batches); seq_len_tensor (batches); lora_indices (token_nums). slice_offset is the first output column this call writes: output_tensor is presumably a wider tensor holding several slices, and the kernel writes columns [slice_offset, slice_offset + slice_size). We pass slice_offset to the kernel to offset the output column indexing.\n\n- Validate contiguity. output_tensor must not be replaced by a copy (e.g. from .contiguous()), because the kernel writes into it in place:\n\ninputs = inputs.contiguous()\n...\n\n- Determine hidden_dim = inputs.shape[1]; slice_size is passed in (it is not output_tensor.shape[1] if the output holds several slices).\n\n- Validate that lora_b_weights.shape[1] == hidden_dim and lora_b_weights.shape[2] == slice_size.\n\n- Determine block sizes.\n\n- max_seq_length: provided argument.\n\n- token_nums: inputs.shape[0] (the token_nums argument should match).\n\n- Compute grid: (batches, triton.cdiv(max_seq_length, BLOCK_M), triton.cdiv(slice_size, BLOCK_N))\n\n- Launch the kernel.\n\nOn the host, `triton.cdiv` is available from the `triton` package itself (and `tl.cdiv` inside kernels), so no extra import is needed.\n\nThe kernel signature declares the block-size constants as `BLOCK_M: tl.constexpr`, etc.\n\nMask creation: masks in the kernel are plain comparisons combined with `&`; no `tl.where` is needed. For rows we use `mask_row = row_offset < seq_len`, where row_offset is a block of ints and seq_len is a scalar loaded via tl.load (e.g. tl.int32). The comparison is elementwise and returns a boolean (int1) mask. 
That's fine.\n\nrow_offset is a block and seq_len is a scalar, so broadcasting works.\n\nBuilding offsets:\n\nrow_idx = start + row_offset.\n\nPointer offsets for loads: `input_ptr + row_idx[:, None] * xm_stride + cur_k[None, :] * xk_stride`. Triton pointer addition takes integer offsets counted in elements (pointer + offset) and broadcasts.\n\nThe mask shape must match the pointer shape. For the A load we use mask = mask_row[:, None] & k_mask[None, :].\n\nFor the B load we use a 3-D shape: base_offset[:, None, None] + cur_k[None,:, None] * lora_k_stride + col_offset[None, None, :] * lora_n_stride. The mask shape is (BLOCK_M, BLOCK_K, BLOCK_N), combined as mask_row[:, None, None] & k_mask[None,:,None] & mask_col[None,None,:].\n\nFor the multiplication, a has shape (M,K) and b has shape (M,K,N). a_exp = a[:, :, None] has shape (M,K,1); multiplying by b gives (M,K,N), and summing over axis=1 (K) gives (M,N).\n\nFor reference, `tl.dot` on fp16 inputs accumulates in fp32 by default (its `out_dtype` defaults to tl.float32), but here we compute the product ourselves.\n\nDtypes: a and b are loaded with their input dtypes. If inputs are float16, a is float16; b may be float16 or float32. The product a_exp * b then stays in float16 (or is promoted if b is float32), and float16 products and partial sums can overflow or lose precision. So we cast to float32 before multiplying: a_fp32 = a.to(tl.float32); b_fp32 = b.to(tl.float32); the product is float32 and the sum accumulates in float32.\n\nAlgorithm:\n\na_fp32 = a.to(tl.float32)\nb_fp32 = b.to(tl.float32)\nprod = a_fp32[:, :, None] * b_fp32\nacc += tl.sum(prod, axis=1) # accumulates float32\n\nAlternatively, a_fp32 and b_fp32 could be passed to tl.dot.\n\nHowever, tl.dot expects B of shape (K,N), not (M,K,N). 
So we'll stick with explicit broadcast.\n\nNow for ADD_INPUTS part: We need to load the input slice of width N: that is similar to a load with k dimension? Actually we need columns from input with offset slice_offset maybe referencing columns within hidden dimension. Actually we want to add the original input values (maybe residual) of same shape as output (slice size). That would require loading from inputs with column offset slice_offset. So we load `add = tl.load(input_ptr + row_idx[:, None] * xm_stride + (col_offset + slice_offset)[None, :] * xk_stride, mask=mask_row[:, None] & mask_col[None, :], other=0.0)`. This is straightforward.\n\nNow add to acc (float32) but add is possibly float16; cast add_fp32 = add.to(tl.float32); acc += add_fp32.\n\nNow after finishing, cast based on CAST_TYPE:\n\nif CAST_TYPE == 0:\n acc = acc.to(tl.float16)\nelse:\n acc = acc.to(tl.float32)\n\nNow store.\n\nPotential issues: In case add_inputs is True but slice_offset may cause col_offset + slice_offset exceed hidden dim; but mask ensures no out-of-bounds.\n\nNow we need to compile this kernel for AMD. There's no use of CUDA specifics.\n\nNow wrapper:\n\ndef _sgmv_expand_slice(...)\n\nWe need to define inputs as torch.Tensor, etc.\n\nInside wrapper we may want to cast to appropriate dtype for lora_b_weights: they may be same dtype as inputs or maybe float16/float32.\n\nSet up type.\n\nNow we need to import triton.cdiv. Actually triton provides `cdiv` in `triton.language`. But from Python we can use `triton.cdiv`. Actually it's a function in `triton` library: `triton.cdiv`.\n\nWe'll import `import triton`.\n\nNow code skeleton:\n\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sgmv_expand_slice_kernel(...):\n ... 
# as above\n\ndef _sgmv_expand_slice(\n inputs: torch.Tensor,\n lora_b_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n token_nums: int,\n slice_offset: int,\n slice_size: int,\n add_inputs: bool = False,\n):\n # ensure contiguity\n inputs = inputs.contiguous()\n lora_b_weights = lora_b_weights.contiguous()\n output_tensor = output_tensor.contiguous()\n b_seq_start_loc = b_seq_start_loc.contiguous()\n seq_len_tensor = seq_len_tensor.contiguous()\n lora_indices_tensor = lora_indices_tensor.contiguous()\n \n # Validate dimensions\n assert inputs.dim() == 2, \"inputs must be 2D (token_num, hidden_dim)\"\n assert lora_b_weights.dim() == 3, \"lora_b_weights must be 3D (max_rank, hidden_dim, slice_size)\"\n assert output_tensor.shape == (token_nums, slice_size), \"output shape mismatch\"\n assert b_seq_start_loc.shape[0] == batches, \"b_seq_start_loc batch size mismatch\"\n assert seq_len_tensor.shape[0] == batches, \"seq_len_tensor batch size mismatch\"\n assert lora_indices_tensor.shape[0] == token_nums, \"lora_indices length mismatch\"\n \n hidden_dim = inputs.shape[1]\n max_rank = lora_b_weights.shape[0]\n assert lora_b_weights.shape[1] == hidden_dim, \"LoRA weight hidden_dim mismatch\"\n assert lora_b_weights.shape[2] == slice_size, \"LoRA weight slice size mismatch\"\n \n # Determine casting flag\n CAST_TYPE = 1 if inputs.dtype == torch.float32 else 0\n\n # Block sizes (tunable)\n BLOCK_M = 128\n BLOCK_N = 64\n BLOCK_K = 32\n\n EVEN_K = (hidden_dim % BLOCK_K) == 0\n\n # Compute grid dimensions\n grid = (\n batches,\n triton.cdiv(max_seq_length, BLOCK_M),\n triton.cdiv(slice_size, BLOCK_N),\n )\n\n # Launch kernel\n _sgmv_expand_slice_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n slice_size, # N (output slice size)\n hidden_dim, # K (input hidden dimension)\n b_seq_start_loc,\n seq_len_tensor,\n 
lora_indices_tensor,\n inputs.stride(0),\n inputs.stride(1),\n lora_b_weights.stride(0),\n lora_b_weights.stride(1),\n lora_b_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n slice_offset,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n BLOCK_K=BLOCK_K,\n EVEN_K=1 if EVEN_K else 0,\n ADD_INPUTS=1 if add_inputs else 0,\n CAST_TYPE=CAST_TYPE,\n )\n return output_tensor\n\nNow we need to ensure the kernel signature order matches call arguments. Let's list kernel signature order precisely:\n\ndef _sgmv_expand_slice_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n xm_stride,\n xk_stride, # 1\n l0_stride, # hidden_size*max_rank\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n slice_offset,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ADD_INPUTS: tl.constexpr,\n CAST_TYPE: tl.constexpr,\n)\n\nThus after out_ptr and N,K we have b_seq_start_loc, seq_lens, lora_indices, xm_stride, xk_stride, l0_stride, lora_k_stride, lora_n_stride, cm_stride, cn_stride, slice_offset, then block size constants.\n\nOur call arguments match: input_ptr = inputs, lora_ptr = lora_b_weights, out_ptr = output_tensor, N = slice_size, K = hidden_dim, b_seq_start_loc = b_seq_start_loc, seq_lens = seq_len_tensor, lora_indices = lora_indices_tensor, xm_stride = inputs.stride(0), xk_stride = inputs.stride(1), l0_stride = lora_b_weights.stride(0), lora_k_stride = lora_b_weights.stride(1), lora_n_stride = lora_b_weights.stride(2), cm_stride = output_tensor.stride(0), cn_stride = output_tensor.stride(1), slice_offset = slice_offset.\n\nYes.\n\nNow compute block sizes: Use BLOCK_M=128 etc.\n\nNow need to compute mask indices correctly inside kernel.\n\nImplementation details inside kernel:\n\n- Program IDs: batch, row_block, col_block.\n\nWe'll refer to pid0 = tl.program_id(0) (batch), pid1 = tl.program_id(1) (row block within max_seq_length), pid2 = tl.program_id(2) 
(col block within slice columns).\n\nNow code inside kernel accordingly.\n\nLet's write code skeleton.\n\nPotential issues:\n\n- Computation of start location and seq length for each batch: start = tl.load(b_seq_start_loc + pid0) (this is offset within flattened token list for the start of batch). seq_len = tl.load(seq_lens + pid0).\n\n- Compute row_offset = pid1 * BLOCK_M + tl.arange(0, BLOCK_M) # shape [BLOCK_M]\n\n- mask_row = row_offset < seq_len\n\n- global_row = start + row_offset\n\nNow compute col_offset = pid2 * BLOCK_N + tl.arange(0, BLOCK_N)\n\nmask_col = col_offset < N\n\nNow compute before loops.\n\nNow loop over K:\n\nfor k in range(0, K, BLOCK_K):\n cur_k = k + tl.arange(0, BLOCK_K) # shape [BLOCK_K]\n mask_k = cur_k < K\n\n # Load A: inputs block (block rows x block K)\n a_ptr = input_ptr + global_row[:, None] * xm_stride + cur_k[None, :] * xk_stride\n a = tl.load(a_ptr, mask=mask_row[:, None] & mask_k[None, :], other=0.0)\n\n # Load lora_idx for each row:\n lora_idx = tl.load(lora_indices + global_row, mask=mask_row, other=0)\n\n base_offset = lora_idx * l0_stride # shape [BLOCK_M]\n\n # Load B block as 3D (M, K, N)\n # Define offsets:\n # row component base_offset[:, None, None]\n # k component cur_k[None, :, None] * lora_k_stride\n # col component col_offset[None, None, :] * lora_n_stride\n\n # Need to have shape after broadcasting:\n # filter mask: mask_row[:, None, None] & mask_k[None, :, None] & mask_col[None, None, :]\n\n b_ptr = lora_ptr + base_offset[:, None, None] + cur_k[None, :, None] * lora_k_stride + col_offset[None, None, :] * lora_n_stride\n\n b = tl.load(b_ptr, mask=mask_row[:, None, None] & mask_k[None, :, None] & mask_col[None, None, :], other=0.0)\n\n # Cast to float32 for accumulation\n a_f = a.to(tl.float32)\n b_f = b.to(tl.float32)\n\n # Compute product and sum over K\n # a_f shape (M, K) -> expand to (M, K, 1)\n # b_f shape (M, K, N)\n prod = a_f[:, :, None] * b_f # shape (M, K, N)\n # sum across K dimension -> axis=1\n acc 
+= tl.sum(prod, axis=1)\n\nAfter loop, optional add inputs:\n\nif ADD_INPUTS:\n add_ptr = input_ptr + global_row[:, None] * xm_stride + (col_offset + slice_offset)[None, :] * xk_stride\n add = tl.load(add_ptr, mask=mask_row[:, None] & mask_col[None, :], other=0.0)\n acc += add.to(tl.float32)\n\nNow cast final:\n\nif CAST_TYPE == 0:\n out_val = acc.to(tl.float16)\nelse:\n out_val = acc\n\nNow store:\n\nout_ptr_off = out_ptr + global_row[:, None] * cm_stride + (col_offset + slice_offset)[None, :] * cn_stride\ntl.store(out_ptr_off, out_val, mask=mask_row[:, None] & mask_col[None, :])\n\nImportant: Use `tl.lang`? Actually we imported `triton.language as tl`.\n\nNow point about `tl.sum` axis parameter: I need to verify syntax. I think correct is `tl.sum(x, axis=1)`. In some versions you can also use `tl.sum(x, axis=0)`. According to Triton docs, `tl.sum(x, axis=0)` returns sum across dimension 0.\n\nThus we can use `tl.sum(prod, axis=1)`. Or we could reduce manually: we can do `tl.reduce(prod, axis=1, op='add')`? Not needed.\n\nNow ensure that `mask_row`, `mask_col`, `mask_k` have correct dtype (tl.int1). They should be booleans.\n\nNow final code ready.\n\nNow we need to produce output JSON with \"strategy\" explaining approach, and \"code\" with full code.\n\nWe'll write a concise strategy: \"Implemented a generic batched GEMM with per-token LoRA slice selection using provided strides, handling boundaries via seq lengths, supporting optional residual addition and casting. Uses block loops, loads per-row LoRA slice indices, accumulates in float32, casts to desired precision, AMD ROCm compatible.\"\n\nNow produce final JSON.\n\nWe must ensure no stray markdown formatting, just JSON object.\n\nPotential issues: The code may have syntax errors like missing parentheses or missing colon. Ensure proper imports indent.\n\nOne important thing: Use `tl.cdiv`? Not needed in kernel.\n\nCheck `tl.sum` axis parameter: It might be `axis=...` (parameter named `axis`). 
Let's verify with Triton 3.1 docs: `tl.sum(x, axis=0)`. Yes.\n\nNow check pointer arithmetic: addition of multiple terms works: input_ptr + global_row[:, None] * xm_stride + cur_k[None, :] * xk_stride.\n\nWill this work? It yields pointer of shape (M,K). That's permitted.\n\nNow for B pointer: base_offset[:, None, None] + cur_k[None, :, None] * lora_k_stride + col_offset[None, None, :] * lora_n_stride\n\nDimension of base_offset is (M), broadcast to (M,1,1). cur_k shape (K). The multiplication yields shape (1,K,1) broadcast to (M,K,N). All good.\n\nNow for loading B: tl.load(lora_ptr + offsets, mask=mask, other=0.0). The shape of offsets determines shape of loaded block.\n\nPotential needed dtype for `other`: specify same dtype as loaded. We can just use `0.0`.\n\nNow mask types: Must be boolean (tl.int1). Use `mask_row[:, None] & mask_k[None, :]` etc.\n\nNow final memory store: out_val shape (M,N). mask shape matches.\n\nNow about `tl.program_id(0)` etc. Must import correctly. Should be `tl.program_id(0)`, etc.\n\nNow confirm that we used all function arguments: input_ptr, lora_ptr, out_ptr, N, K, b_seq_start_loc, seq_lens, lora_indices, xm_stride, xk_stride, l0_stride, lora_k_stride, lora_n_stride, cm_stride, cn_stride, slice_offset. Indeed we used all of them. For l0_stride we used in base_offset = lora_idx * l0_stride. We used that.\n\nNow we used `BLOCK_M`, `BLOCK_N`, `BLOCK_K`, `EVEN_K` (unused, but okay). We can use EVEN_K to possibly unroll, but we can just ignore.\n\n`ADD_INPUTS` used.\n\n`CAST_TYPE` used.\n\nNow we need to ensure to call `tl.load` for `lora_indices` we used.\n\nPotential: load lora_indices per row each iteration of K loop is redundant; we can load once before loop. 
We'll move it before loop:\n\nlora_idx = tl.load(lora_indices + global_row, mask=mask_row, other=0) # shape (M)\nbase_offset = lora_idx * l0_stride\n\nThen inside loop reuse base_offset.\n\nThus we avoid reloading each iteration.\n\nNow inside kernel, we need to use `K` param (inner dimension). This param passed as argument.\n\nNow verifying we used `K` (hidden_dim) in loops. Yes.\n\nNow we need to compute the block number of columns: we used `N` param for slice_size. Good.\n\nNow compute `mask_col = col_offset < N`. Good.\n\nNow important: In our loop we compute `col_offset` used for B pointer and for output; it's consistent.\n\nNow add_inputs: we added input slice using col_offset + slice_offset; but the pointer uses `slice_offset` offset plus col_offset.\n\nOur kernel `slice_offset` is added to both compute column offset for output and input addition. Good.\n\nNow we need to handle case where `slice_offset` not 0 but output pointer col offset uses offset, fine.\n\nNow ensure we cast appropriate types: `acc` is float32. 
For final store we need to cast depending on CAST_TYPE.\n\nNow finalize code.\n\nPotential not imported `triton.cdiv`, we just use `triton.cdiv` directly.\n\nNow we produce final JSON.\n\nLet's produce final code.\n\nBefore produce final JSON, test mentally for syntax errors.\n\nBegin with imports:\n\nimport torch\nimport triton\nimport triton.language as tl\n\nDefine kernel:\n\n@triton.jit\ndef _sgmv_expand_slice_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n xm_stride,\n xk_stride,\n l0_stride,\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n slice_offset,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ADD_INPUTS: tl.constexpr,\n CAST_TYPE: tl.constexpr,\n):\n pid_batch = tl.program_id(0)\n pid_row = tl.program_id(1)\n pid_col = tl.program_id(2)\n\n # Load per-batch start location and sequence length\n batch_start = tl.load(b_seq_start_loc + pid_batch)\n seq_len = tl.load(seq_lens + pid_batch)\n\n # Compute row offsets within the batch\n offs_m = pid_row * BLOCK_M + tl.arange(0, BLOCK_M)\n row_mask = offs_m < seq_len\n\n # Global row index (flattened token index)\n row_idx = batch_start + offs_m\n\n # Column offsets for the slice\n offs_n = pid_col * BLOCK_N + tl.arange(0, BLOCK_N)\n col_mask = offs_n < N\n\n # Combined mask for valid output elements\n out_mask = row_mask[:, None] & col_mask[None, :]\n\n # Load LoRA slice index for each row\n lora_idx = tl.load(lora_indices + row_idx, mask=row_mask, other=0)\n base_offset = lora_idx * l0_stride\n\n # Initialize accumulator in float32\n acc = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n\n # Loop over K dimension\n for k in range(0, K, BLOCK_K):\n cur_k = k + tl.arange(0, BLOCK_K)\n k_mask = cur_k < K\n\n # Load A block (inputs)\n a_ptr = input_ptr + row_idx[:, None] * xm_stride + cur_k[None, :] * xk_stride\n a = tl.load(a_ptr, mask=row_mask[:, None] & k_mask[None, :], other=0.0)\n\n # Load 
B block (LoRA weights) with per‑row slice selection\n # Offsets shape (M, K, N)\n b_ptr = lora_ptr + base_offset[:, None, None] + cur_k[None, :, None] * lora_k_stride + offs_n[None, None, :] * lora_n_stride\n b = tl.load(b_ptr, mask=row_mask[:, None, None] & k_mask[None, :, None] & col_mask[None, None, :], other=0.0)\n\n # Cast to float32 for accumulation\n a_f = a.to(tl.float32)\n b_f = b.to(tl.float32)\n\n # Multiply and reduce over the K dimension\n prod = a_f[:, :, None] * b_f # (M, K, N)\n acc += tl.sum(prod, axis=1) # sum over K -> (M, N)\n\n # Optional addition of the original input slice (residual)\n if ADD_INPUTS:\n add_ptr = input_ptr + row_idx[:, None] * xm_stride + (offs_n + slice_offset)[None, :] * xk_stride\n add = tl.load(add_ptr, mask=row_mask[:, None] & col_mask[None, :], other=0.0)\n acc += add.to(tl.float32)\n\n # Cast output according to requested precision\n if CAST_TYPE == 0:\n out_val = acc.to(tl.float16)\n else:\n out_val = acc.to(tl.float32)\n\n # Store the result\n out_ptr_off = out_ptr + row_idx[:, None] * cm_stride + (offs_n + slice_offset)[None, :] * cn_stride\n tl.store(out_ptr_off, out_val, mask=out_mask)\n\n\nNow wrapper:\n\ndef _sgmv_expand_slice(\n inputs: torch.Tensor,\n lora_b_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n token_nums: int,\n slice_offset: int,\n slice_size: int,\n add_inputs: bool = False,\n):\n # Ensure tensors are contiguous\n inputs = inputs.contiguous()\n lora_b_weights = lora_b_weights.contiguous()\n output_tensor = output_tensor.contiguous()\n b_seq_start_loc = b_seq_start_loc.contiguous()\n seq_len_tensor = seq_len_tensor.contiguous()\n lora_indices_tensor = lora_indices_tensor.contiguous()\n\n # Validate shapes\n assert inputs.ndim == 2, \"inputs must be 2-D (token_nums, hidden_dim)\"\n assert lora_b_weights.ndim == 3, \"lora_b_weights must be 3-D (max_rank, 
hidden_dim, slice_size)\"\n assert output_tensor.shape == (token_nums, slice_size), \"output_tensor shape mismatch\"\n assert b_seq_start_loc.numel() == batches, \"b_seq_start_loc batch size mismatch\"\n assert seq_len_tensor.numel() == batches, \"seq_len_tensor batch size mismatch\"\n assert lora_indices_tensor.numel() == token_nums, \"lora_indices length mismatch\"\n\n hidden_dim = inputs.shape[1]\n assert lora_b_weights.shape[1] == hidden_dim, \"LoRA hidden dimension mismatch\"\n assert lora_b_weights.shape[2] == slice_size, \"LoRA slice size mismatch\"\n\n # Choose casting flag based on input dtype\n CAST_TYPE = 1 if inputs.dtype == torch.float32 else 0\n\n # Tunable block sizes\n BLOCK_M = 128\n BLOCK_N = 64\n BLOCK_K = 32\n\n EVEN_K = (hidden_dim % BLOCK_K) == 0\n\n # Grid dimensions: (batch, row_blocks, col_blocks)\n grid = (\n batches,\n triton.cdiv(max_seq_length, BLOCK_M),\n triton.cdiv(slice_size, BLOCK_N),\n )\n\n # Launch kernel\n _sgmv_expand_slice_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n slice_size, # N (output slice width)\n hidden_dim, # K (input hidden dimension)\n b_seq_start_loc", "input": "\nYou are an expert Python programmer specializing in NVIDIA Triton kernels, specifically targeting **AMD GPUs using the ROCm environment**.\nYour task is to generate a Python code snippet containing a Triton kernel based on the following request:\n\n**Target Platform:** AMD GPU (ROCm)\n\n**Request:**\nYou are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided code consists of a Triton kernel, `_sgmv_expand_slice_kernel`, and a wrapper function, `_sgmv_expand_slice`. \n The kernel performs a specialized form of matrix multiplication involving sparse generalized matrix-vector multiplication (SGMV). \n It operates on blocks of data defined by the parameters BLOCK_M, BLOCK_N, and BLOCK_K. 
\n The kernel checks for boundary conditions and handles data casting between different precisions. \n It processes sequences of a batch in parallel and uses lora_indices to access specific slices of LoRA weights for computations. \n The wrapper function sets up the input parameters, ensuring contiguity and dimension correctness, and configures the Triton launch grid to execute the kernel. \n It validates tensor shapes and data types, adjusts weights if needed, and launches the kernel with a grid configured to cover the input data based on the maximum sequence length and batch size.\n \n\n**CRITICAL FUNCTION INFORMATION:**\nBased on analysis, the implementation requires these EXACT function signatures:\n* def _sgmv_expand_slice_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n xm_stride,\n xk_stride, # 1\n l0_stride, # hidden_size*max_rank\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n slice_offset,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ADD_INPUTS: tl.constexpr,\n CAST_TYPE: tl.constexpr,\n)\n* def _sgmv_expand_slice(\n inputs: torch.Tensor,\n lora_b_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n token_nums: int,\n slice_offset: int,\n slice_size: int,\n add_inputs: bool = False,\n)\n\n\n**Output Requirements:**\n1. **AMD Compatibility:** Generate code compatible with AMD GPUs and ROCm. **DO NOT use CUDA-specific features or functions (e.g., `tl.libdevice`).**\n2. **Complete Code:** Generate a single, complete, and syntactically correct Python code block.\n3. **Triton Kernel:** The core logic must be implemented within a Triton kernel function decorated with `@triton.jit`.\n4. 
**Imports:** ALWAYS include necessary imports at the beginning:\n ```python\n import torch\n import triton\n import triton.language as tl\n # import math # Only if standard math functions are truly needed outside the kernel\n ```\n Include other imports *only if absolutely necessary*.\n5. **Function Signature (CRITICAL):**\n * Define EACH function with EXACTLY the signature shown above.\n * DO NOT change parameter names, counts, or order.\n * Ensure all parameters in function calls match their function definitions.\n * **Type Hints:** Use PyTorch tensor type hints (e.g., `x: torch.Tensor`) for tensor arguments. **DO NOT use `tl.pointer`**. Use standard Python types (e.g., `int`, `float`) or `tl.constexpr` for others.\n * **`constexpr`:** Use `tl.constexpr` **ONLY** for arguments that *must* be known at compile time, typically block sizes (like `BLOCK_SIZE`, `BLOCK_M`) or flags that change the kernel's structure (like `IS_EVEN_K`). Simple numerical values like `eps` or `dropout_p` are usually *not* `constexpr`.\n6. **Data Types:** Be precise with data types inside the kernel (e.g., `tl.float16`, `tl.float32`, `tl.int32`). Ensure type compatibility. Assume input tensors might be `torch.float16` or `torch.float32` unless specified otherwise. Pay attention to potential type promotion/conversion needs (e.g., using `.to(tl.float32)` for accumulations).\n7. **Triton Operations:**\n * Use Triton language functions correctly (`tl.load`, `tl.store`, `tl.dot`, `tl.arange`, `tl.program_id`, `tl.where`, `tl.atomic_cas`, etc.).\n * **Pointers & Masks:** Be extremely careful when constructing pointers using offsets and strides. Ensure masks in `tl.load`/`tl.store` are correctly computed and match pointer dimensions. Avoid `ValueError: Mask argument cannot be block type...` or `ValueError: Unsupported ptr type...`.\n * **`tl.dot`:** Ensure inputs are 2D blocks and have compatible types (e.g., float16, bfloat16). 
Int32 is generally not supported directly as input.\n * **`tl.arange`:** Arguments `start` and `end` **must be `tl.constexpr`**.\n * **Math:** Use functions from `tl.math` where available (e.g., `tl.math.exp`, `tl.math.sqrt`). Check function existence; avoid assuming functions like `tanh` or `log1p` exist if they don't in `tl.math`.\n8. **Triton Version:** Assume Triton version 3.1.0 or later.\n\n**FINAL VERIFICATION:**\nBefore completing, verify:\n1. ALL functions defined in the code have EXACT signatures matching the required function signatures above.\n2. ALL function calls exactly match their definitions in terms of parameter counts and names.\n3. No functions are called without being defined.\n4. No parameters are missing from your implementations.\n\n**Generated AMD ROCm Compatible Triton Kernel Code:**\n\nHere is an example snippet of code: import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _flash_decoding_fwd_kernel(\n Q, KCache, VCache, block_tables, mid_o, mid_o_lse, kv_seq_len, q_len, batch_size, kv_group_num,\n x, sm_scale, stride_qt, stride_qh, stride_qd, stride_kcb, stride_kch, stride_kcsplit_x, stride_kcs,\n stride_kcd, stride_vcb, stride_vch, stride_vcs, stride_vcd, stride_bts, stride_btb, stride_mid_ot,\n stride_mid_oh, stride_mid_ob, stride_mid_od, stride_mid_o_lset, stride_mid_o_lseh, stride_mid_o_lseb,\n BLOCK_KV: tl.constexpr, BLOCK_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr,\n):\n cur_token_idx = tl.program_id(0)\n cur_seq_idx = cur_token_idx // q_len\n if cur_seq_idx >= batch_size:\n return\n cur_token_off = (cur_token_idx % q_len) - q_len + 1\n cur_head_idx = tl.program_id(1)\n block_start_kv = tl.program_id(2)\n\n tl.static_assert(BLOCK_KV == BLOCK_SIZE)\n cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off\n if block_start_kv * BLOCK_KV >= cur_kv_seq_len:\n return\n\n offsets_dmodel = tl.arange(0, HEAD_DIM)\n offsets_block = tl.arange(0, BLOCK_SIZE)\n\n block_table_ptr = block_tables + 
cur_seq_idx * stride_bts\n cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb)\n cur_occupied_size = tl.where(\n (block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE\n )\n tl.device_assert(cur_occupied_size >= 0)\n\n offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd\n q = tl.load(Q + offsets_q)\n cur_kv_head_idx = cur_head_idx // kv_group_num\n offset_kvcache = cur_block_id * stride_kcb + cur_kv_head_idx * stride_kch\n offsets_k = (\n offset_kvcache\n + (offsets_dmodel[None, :] // x) * stride_kcsplit_x\n + (offsets_dmodel[None, :] % x) * stride_kcd\n + offsets_block[:, None] * stride_kcs\n )\n k_cur_block = tl.load(KCache + offsets_k)\n V_block_ptr = tl.make_block_ptr(\n base=VCache + offset_kvcache,\n shape=(cur_occupied_size, HEAD_DIM),\n strides=(stride_vcs, stride_vcd),\n offsets=(0, 0),\n block_shape=(BLOCK_SIZE, HEAD_DIM),\n order=(0, 1),\n )\n v_cur_block = tl.load(V_block_ptr)\n acc = tl.zeros([HEAD_DIM], dtype=tl.float32)\n S_ij = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n\n S_ij += tl.sum(q[None, :] * k_cur_block, 1)\n S_ij *= sm_scale\n S_ij += tl.where(block_start_kv * BLOCK_KV + offsets_block < cur_kv_seq_len, 0, float(\"-inf\"))\n\n m = tl.max(S_ij, 0)\n S_ij -= m\n p_ij_hat = tl.exp(S_ij)\n l_i = tl.sum(p_ij_hat, 0)\n p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)\n acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)\n acc = acc / l_i\n\n offsets_mid_o = (\n cur_token_idx * stride_mid_ot\n + cur_head_idx * stride_mid_oh\n + block_start_kv * stride_mid_ob\n + offsets_dmodel * stride_mid_od\n )\n tl.store(mid_o + offsets_mid_o, acc)\n offsets_mid_o_lse = (\n cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb\n )\n tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l_i))\n\n\n@triton.jit\ndef _alibi_flash_decoding_fwd_kernel(\n Q, KCache, VCache, block_tables, mid_o, mid_o_lse, 
kv_seq_len, q_len, batch_size, alibi_slopes,\n stride_qt, stride_qh, stride_qd, stride_cacheb, stride_cacheh, stride_cachebs, stride_cached,\n stride_bts, stride_btb, stride_mid_ot, stride_mid_oh, stride_mid_ob, stride_mid_od,\n stride_mid_o_lset, stride_mid_o_lseh, stride_mid_o_lseb, sm_scale, KV_GROUPS: tl.constexpr,\n BLOCK_KV: tl.constexpr, BLOCK_SIZE: tl.constexpr, HEAD_DIM: tl.constexpr,\n):\n cur_token_idx = tl.program_id(0)\n cur_seq_idx = cur_token_idx // q_len\n if cur_seq_idx >= batch_size:\n return\n cur_token_off = (cur_token_idx % q_len) - q_len + 1\n cur_head_idx = tl.program_id(1)\n block_start_kv = tl.program_id(2)\n\n tl.static_assert(BLOCK_KV == BLOCK_SIZE)\n cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off\n if block_start_kv * BLOCK_KV >= cur_kv_seq_len:\n return\n\n offsets_dmodel = tl.arange(0, HEAD_DIM)\n offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd\n q = tl.load(Q + offsets_q)\n block_table_ptr = block_tables + cur_seq_idx * stride_bts\n cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb)\n cur_occupied_size = tl.where(\n (block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE\n )\n tl.device_assert(cur_occupied_size >= 0)\n\n cur_kv_head_idx = cur_head_idx // KV_GROUPS\n offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh\n K_block_ptr = tl.make_block_ptr(\n base=KCache + offset_kvcache,\n shape=(cur_occupied_size, HEAD_DIM),\n strides=(stride_cachebs, stride_cached),\n offsets=(0, 0),\n block_shape=(BLOCK_SIZE, HEAD_DIM),\n order=(0, 1),\n )\n V_block_ptr = tl.make_block_ptr(\n base=VCache + offset_kvcache,\n shape=(cur_occupied_size, HEAD_DIM),\n strides=(stride_cachebs, stride_cached),\n offsets=(0, 0),\n block_shape=(BLOCK_SIZE, HEAD_DIM),\n order=(0, 1),\n )\n k_cur_block = tl.load(K_block_ptr)\n v_cur_block = tl.load(V_block_ptr)\n acc = tl.zeros([HEAD_DIM], 
dtype=tl.float32)\n S_ij = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n\n alibi_slope = tl.load(alibi_slopes + cur_head_idx)\n position_k_offset = block_start_kv * BLOCK_KV + tl.arange(0, BLOCK_SIZE)\n\n S_ij += tl.sum(q[None, :] * k_cur_block, 1)\n S_ij *= sm_scale\n S_ij -= alibi_slope * (cur_kv_seq_len - 1 - position_k_offset)\n S_ij = tl.where(cur_kv_seq_len > position_k_offset, S_ij, float(\"-inf\"))\n\n m = tl.max(S_ij, 0)\n S_ij -= m\n p_ij_hat = tl.exp(S_ij)\n l_i = tl.sum(p_ij_hat, 0)\n p_ij_hat = p_ij_hat.to(v_cur_block.type.element_ty)\n acc += tl.sum(v_cur_block * p_ij_hat[:, None], 0)\n acc = acc / l_i\n\n offsets_mid_o = (\n cur_token_idx * stride_mid_ot\n + cur_head_idx * stride_mid_oh\n + block_start_kv * stride_mid_ob\n + offsets_dmodel * stride_mid_od\n )\n tl.store(mid_o + offsets_mid_o, acc)\n offsets_mid_o_lse = (\n cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb\n )\n tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l_i))\n\n\n@triton.jit\ndef _flash_decoding_fwd_reduce_kernel(\n mid_o, mid_o_lse, O, kv_seq_len, q_len, batch_size, stride_mid_ot, stride_mid_oh,\n stride_mid_ob, stride_mid_od, stride_o_lset, stride_o_lseh, stride_o_lseb,\n stride_ot, stride_oh, stride_od, BLOCK_KV: tl.constexpr, HEAD_DIM: tl.constexpr,\n):\n cur_token_idx = tl.program_id(0)\n cur_seq_idx = cur_token_idx // q_len\n if cur_seq_idx >= batch_size:\n return\n cur_head_idx = tl.program_id(1)\n\n cur_token_off = (cur_token_idx % q_len) - q_len + 1\n cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx) + cur_token_off\n offsets_dmodel = tl.arange(0, HEAD_DIM)\n\n kv_split_num = (cur_kv_seq_len + BLOCK_KV - 1) // BLOCK_KV\n m_i = float(\"-inf\")\n l_i = 0.0\n acc = tl.zeros([HEAD_DIM], dtype=tl.float32)\n\n offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel\n offset_mid_lse = cur_token_idx * stride_o_lset + cur_head_idx * stride_o_lseh\n for block_i in range(0, 
kv_split_num, 1):\n mid_o_block = tl.load(mid_o + offsets_mid_o + block_i * stride_mid_ob)\n lse = tl.load(mid_o_lse + offset_mid_lse + block_i * stride_o_lseb)\n m_ij = tl.maximum(m_i, lse)\n scale = tl.exp(m_i - m_ij)\n acc = acc * scale\n lse -= m_ij\n exp_logic = tl.exp(lse)\n acc += exp_logic * mid_o_block\n l_i = scale * l_i + exp_logic\n m_i = m_ij\n\n acc = acc / l_i\n offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel\n tl.store(O + offsets_O, acc.to(O.type.element_ty))\n return\n\n\ndef flash_decoding_attention(\n q: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor, kv_seq_len: torch.Tensor,\n block_tables: torch.Tensor, block_size: int, max_seq_len_in_batch: int = None, output: torch.Tensor = None,\n mid_output: torch.Tensor = None, mid_output_lse: torch.Tensor = None, alibi_slopes: torch.Tensor = None,\n sm_scale: int = None, kv_group_num: int = 1, q_len: int = 1, use_new_kcache_layout: bool = False,\n):\n q = q.squeeze() if q.dim() == 4 else q\n assert q.dim() == 3, f\"Incompatible q dim: {q.dim()}\"\n n_tokens, num_heads, head_dim = q.shape\n assert n_tokens % q_len == 0, \"Invalid q_len\"\n bsz = n_tokens // q_len\n\n assert head_dim in {32, 64, 128, 256}\n assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (\n f\"Got incompatible batch size (number of seqs):\\n\"\n f\" KV seq lengths bsz {kv_seq_len.size(0)}, Block tables bsz {block_tables.size(0)}, \"\n f\"batch size {bsz}\"\n )\n assert k_cache.size(-2) == v_cache.size(-2) == block_size, (\n f\"Got incompatible block size on kv caches:\\n\"\n f\" assigned block_size {block_size}, k_cache block_size {k_cache.size(-2)}, \"\n f\"v_cache block_size {v_cache.size(-2)}\"\n )\n\n assert block_size in {16, 32, 64, 128}\n BLOCK_KV = block_size\n\n sm_scale = 1.0 / (head_dim**0.5) if sm_scale is None else sm_scale\n max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch\n kv_max_split_num = 
(max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV\n\n if mid_output is None:\n mid_output = torch.empty(\n (bsz * q_len, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device\n )\n if mid_output_lse is None:\n mid_output_lse = torch.empty((bsz * q_len, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)\n if output is None:\n output = torch.empty((bsz * q_len, num_heads * head_dim), dtype=q.dtype, device=q.device)\n\n assert (\n mid_output.size(2) == mid_output_lse.size(2) >= kv_max_split_num\n ), \"Incompatible kv split number of intermediate output tensors\"\n assert (\n mid_output.size(0) == mid_output_lse.size(0) >= output.size(0) == n_tokens\n ), f\"Incompatible first dimension of output tensors\"\n\n grid = lambda META: (\n triton.next_power_of_2(bsz * q_len),\n num_heads,\n triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), META[\"BLOCK_KV\"]),\n )\n\n if alibi_slopes is not None:\n assert (\n not use_new_kcache_layout\n ), \"Alibi Slopes will be supported with new kcache layout later when the whole triton flow is ready\"\n\n _alibi_flash_decoding_fwd_kernel[grid](\n q,\n k_cache,\n v_cache,\n block_tables,\n mid_output,\n mid_output_lse,\n kv_seq_len,\n q_len,\n bsz,\n alibi_slopes,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n block_tables.stride(0),\n block_tables.stride(1),\n mid_output.stride(0),\n mid_output.stride(1),\n mid_output.stride(2),\n mid_output.stride(3),\n mid_output_lse.stride(0),\n mid_output_lse.stride(1),\n mid_output_lse.stride(2),\n sm_scale,\n KV_GROUPS=kv_group_num,\n BLOCK_KV=block_size,\n BLOCK_SIZE=block_size,\n HEAD_DIM=head_dim,\n )\n else:\n x = head_dim\n kcsplit_x_stride, kcs_stride, kcd_stride = 0, k_cache.stride(2), k_cache.stride(3)\n if use_new_kcache_layout:\n assert (\n k_cache.dim() == 5\n and k_cache.shape[1] == v_cache.shape[1]\n and k_cache.shape[2] * k_cache.shape[4] == 
v_cache.shape[3]\n ), f\"Invalid KCache shape {k_cache.shape} and VCache shape {v_cache.shape}\"\n x = k_cache.size(-1)\n kcsplit_x_stride, kcs_stride, kcd_stride = k_cache.stride()[-3:]\n\n _flash_decoding_fwd_kernel[grid](\n q,\n k_cache,\n v_cache,\n block_tables,\n mid_output,\n mid_output_lse,\n kv_seq_len,\n q_len,\n bsz,\n kv_group_num,\n x,\n sm_scale,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n kcsplit_x_stride,\n kcs_stride,\n kcd_stride,\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(3),\n block_tables.stride(0),\n block_tables.stride(1),\n mid_output.stride(0),\n mid_output.stride(1),\n mid_output.stride(2),\n mid_output.stride(3),\n mid_output_lse.stride(0),\n mid_output_lse.stride(1),\n mid_output_lse.stride(2),\n BLOCK_KV=block_size,\n BLOCK_SIZE=block_size,\n HEAD_DIM=head_dim,\n )\n\n grid = (triton.next_power_of_2(bsz * q_len), num_heads)\n _flash_decoding_fwd_reduce_kernel[grid](\n mid_output,\n mid_output_lse,\n output,\n kv_seq_len,\n q_len,\n bsz,\n mid_output.stride(0),\n mid_output.stride(1),\n mid_output.stride(2),\n mid_output.stride(3),\n mid_output_lse.stride(0),\n mid_output_lse.stride(1),\n mid_output_lse.stride(2),\n output.stride(0),\n head_dim,\n 1,\n BLOCK_KV=block_size,\n HEAD_DIM=head_dim,\n )\n\n return output\n\nThink before writing the optimization; no further explanation is required after the thinking.\nYou should not change the function name or the parameter names, count, or order.\nOutput your answer in JSON format, as follows: {\"strategy\": \"\", \"code\": \"\"}. Please output strictly in JSON format.\nIn the \"strategy\" field, describe the strategy used to correct and optimize the code.\nIn the \"code\" field, provide the corrected and optimized code, without explanation, so that it can be run directly."} |