from typing import Optional, Tuple

import torch
from sgl_kernel.utils import _get_cache_buf

# Size in bytes (32 MiB) of the cached workspace buffer requested for bmm_fp8.
_BMM_FP8_WORKSPACE_BYTES = 32 * 1024 * 1024


def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Dequantize AWQ weights through the ``sgl_kernel.awq_dequantize`` op.

    Args:
        qweight: Forwarded to the op as the quantized weight tensor.
        scales: Forwarded to the op as the scales tensor.
        qzeros: Forwarded to the op as the quantized zero-points tensor.

    Returns:
        Whatever tensor the custom op returns.
    """
    # NOTE(review): this was previously annotated as ``torch.ByteTensor``; the
    # dequantized result is presumably floating point -- confirm against the op.
    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)


def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """Call the ``sgl_kernel.int8_scaled_mm`` op and return its result.

    All arguments are forwarded positionally and unchanged. ``bias`` is
    optional and defaults to ``None``.
    """
    return torch.ops.sgl_kernel.int8_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )


def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
    """Call the ``sgl_kernel.fp8_blockwise_scaled_mm`` op and return its result.

    All arguments are forwarded positionally and unchanged.
    """
    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
    )


def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """Call the ``sgl_kernel.fp8_scaled_mm`` op and return its result.

    All arguments are forwarded positionally and unchanged. ``bias`` is
    optional and defaults to ``None``.
    """
    return torch.ops.sgl_kernel.fp8_scaled_mm.default(
        mat_a,
        mat_b,
        scales_a,
        scales_b,
        out_dtype,
        bias,
    )


def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    """Invoke the ``sgl_kernel.bmm_fp8`` op with the current cuBLAS handle.

    ``D`` is the tensor that ``bmm_fp8`` allocates (or receives) as ``out`` and
    later returns, so the op is expected to write its result into it.
    """
    cublas_handle = torch.cuda.current_blas_handle()
    torch.ops.sgl_kernel.bmm_fp8.default(
        A,
        B,
        D,
        A_scale,
        B_scale,
        workspace_buffer,
        cublas_handle,
    )


def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Batched matrix multiply through the ``sgl_kernel.bmm_fp8`` op.

    Args:
        A: First operand; indexed as a 3-D tensor when ``out`` is allocated.
        B: Second operand; indexed as a 3-D tensor when ``out`` is allocated.
        A_scale: Scale tensor forwarded to the op for ``A``.
        B_scale: Scale tensor forwarded to the op for ``B``.
        dtype: Dtype used only when ``out`` has to be allocated here.
        out: Optional destination tensor. When ``None``, a tensor of shape
            ``(A.shape[0], A.shape[1], B.shape[2])`` is allocated on
            ``A.device``.

    Returns:
        The ``out`` tensor (either the one passed in or the one allocated).
    """
    if out is None:
        out = torch.empty(
            (A.shape[0], A.shape[1], B.shape[2]),
            device=A.device,
            dtype=dtype,
        )
    workspace_buffer = _get_cache_buf(
        "bmm_fp8_workspace", _BMM_FP8_WORKSPACE_BYTES, A.device
    )
    _bmm_fp8_internal(workspace_buffer, A, B, out, A_scale, B_scale)
    return out


def dsv3_fused_a_gemm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the ``sgl_kernel.dsv3_fused_a_gemm`` op.

    Args:
        mat_a: First operand.
        mat_b: Second operand.
        output: Optional destination tensor. When ``None``, a tensor of shape
            ``(mat_a.shape[0], mat_b.shape[1])`` with ``mat_a``'s device and
            dtype is allocated.

    Returns:
        The ``output`` tensor handed to the op.
    """
    if output is None:
        output = torch.empty(
            (mat_a.shape[0], mat_b.shape[1]),
            device=mat_a.device,
            dtype=mat_a.dtype,
        )
    torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(output, mat_a, mat_b)
    return output


def sgl_per_token_group_quant_8bit(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
    scale_ue8m0: bool = False,
    fuse_silu_and_mul: bool = False,
    masked_m: Optional[torch.Tensor] = None,
    enable_v2: Optional[bool] = None,
) -> None:
    """Per-token-group 8-bit quantization through the sgl_kernel custom ops.

    Two ops are available. The v2 op additionally accepts
    ``fuse_silu_and_mul`` and ``masked_m``; the original op does not, so both
    must be left at their defaults when v2 is disabled.

    Args:
        input: Tensor forwarded to the op as its input.
        output_q: Forwarded to the op; presumably the pre-allocated buffer for
            quantized values -- confirm against the kernel.
        output_s: Forwarded to the op; presumably the pre-allocated buffer for
            scales -- confirm against the kernel.
        group_size: Forwarded to the op unchanged.
        eps: Forwarded to the op unchanged.
        fp8_min: Forwarded to the op unchanged.
        fp8_max: Forwarded to the op unchanged.
        scale_ue8m0: Forwarded to the op unchanged.
        fuse_silu_and_mul: Only supported by the v2 op.
        masked_m: Only supported by the v2 op.
        enable_v2: Selects the v2 op. When ``None``, the value is read from the
            ``SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2`` environment variable.

    Raises:
        AssertionError: If v2 is disabled and ``fuse_silu_and_mul`` is set or
            ``masked_m`` is given.
    """
    if enable_v2 is None:
        # Imported lazily so sglang is only required when the flag is not
        # passed explicitly.
        from sglang.srt.utils import get_bool_env_var

        enable_v2 = get_bool_env_var("SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2")

    if enable_v2:
        return torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit_v2.default(
            input,
            output_q,
            output_s,
            group_size,
            eps,
            fp8_min,
            fp8_max,
            scale_ue8m0,
            fuse_silu_and_mul,
            masked_m,
        )

    assert not fuse_silu_and_mul, "only v2 support fuse_silu_and_mul"
    assert masked_m is None, "only v2 support masked_m"
    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
    )


# For legacy usage
sgl_per_token_group_quant_fp8 = sgl_per_token_group_quant_8bit
sgl_per_token_group_quant_int8 = sgl_per_token_group_quant_8bit


def sgl_per_tensor_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    is_static: bool,
) -> None:
    """Call the ``sgl_kernel.sgl_per_tensor_quant_fp8`` op.

    All arguments are forwarded unchanged; nothing is returned, so results are
    presumably written into ``output_q`` / ``output_s`` -- confirm against the
    kernel.
    """
    torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default(
        input, output_q, output_s, is_static
    )


def sgl_per_token_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
    """Call the ``sgl_kernel.sgl_per_token_quant_fp8`` op.

    All arguments are forwarded unchanged; nothing is returned, so results are
    presumably written into ``output_q`` / ``output_s`` -- confirm against the
    kernel.
    """
    torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(input, output_q, output_s)


def cutlass_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Forward to ``sglang.jit_kernel.nvfp4.cutlass_scaled_fp4_mm``.

    The JIT implementation is imported lazily at call time; all arguments are
    passed through unchanged and its result is returned.
    """
    from sglang.jit_kernel.nvfp4 import (
        cutlass_scaled_fp4_mm as jit_cutlass_scaled_fp4_mm,
    )

    return jit_cutlass_scaled_fp4_mm(
        a,
        b,
        block_scale_a,
        block_scale_b,
        alpha,
        out_dtype,
    )


def scaled_fp4_quant(
    input: torch.Tensor, input_global_scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return quantized tensor and scale.

    This function quantizes the last dimension of the given tensor `input`. For
    every 16 consecutive elements, a single dynamically computed scaling factor
    is shared. This scaling factor is quantized using the `input_global_scale`
    and is stored in a swizzled layout (see
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).

    Args:
        input: The input tensor to be quantized to FP4
        input_global_scale: A scalar scaling factor for the entire tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 but every
            two values are packed into a uint8 and float8_e4m3 scaling factors
            in a swizzled layout.
    """
    from sglang.jit_kernel.nvfp4 import scaled_fp4_quant as jit_scaled_fp4_quant

    return jit_scaled_fp4_quant(input, input_global_scale)


def qserve_w4a8_per_chn_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    w_szs: torch.Tensor,
    a_ssums: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the ``sgl_kernel.qserve_w4a8_per_chn_gemm`` op.

    Args:
        in_feats: Input features forwarded to the op.
        kernel: Weight tensor forwarded to the op.
        wscales: Forwarded to the op unchanged.
        ascales: Forwarded to the op unchanged.
        w_szs: Forwarded to the op unchanged.
        a_ssums: Forwarded to the op unchanged.
        out_feats: Optional destination tensor. When ``None``, a float16 tensor
            of shape ``(in_feats.shape[0], kernel.shape[0])`` is allocated on
            ``in_feats.device``.

    Returns:
        The ``out_feats`` tensor handed to the op.
    """
    if out_feats is None:
        # NOTE(HandH1998): qserve_w4a8_per_chn_gemm only supports out dtype=torch.float16 now
        out_feats = torch.empty(
            (in_feats.shape[0], kernel.shape[0]),
            device=in_feats.device,
            dtype=torch.float16,
        )
    torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm.default(
        in_feats, kernel, wscales, ascales, w_szs, a_ssums, out_feats
    )
    return out_feats


def qserve_w4a8_per_group_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    zeros: torch.Tensor,
    scales_i8: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the ``sgl_kernel.qserve_w4a8_per_group_gemm`` op.

    Args:
        in_feats: Input features forwarded to the op.
        kernel: Weight tensor forwarded to the op.
        zeros: Forwarded to the op unchanged.
        scales_i8: Forwarded to the op unchanged.
        wscales: Forwarded to the op unchanged.
        ascales: Forwarded to the op unchanged.
        out_feats: Optional destination tensor. When ``None``, a float16 tensor
            of shape ``(in_feats.shape[0], kernel.shape[0])`` is allocated on
            ``in_feats.device``.

    Returns:
        The ``out_feats`` tensor handed to the op.
    """
    if out_feats is None:
        # NOTE(HandH1998): qserve_w4a8_per_group_gemm only supports out dtype=torch.float16 now
        out_feats = torch.empty(
            (in_feats.shape[0], kernel.shape[0]),
            device=in_feats.device,
            dtype=torch.float16,
        )
    torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm.default(
        in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats
    )
    return out_feats


def dsv3_router_gemm(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Run the ``sgl_kernel.dsv3_router_gemm`` op into a fresh output tensor.

    Args:
        hidden_states: First operand forwarded to the op.
        router_weights: Second operand forwarded to the op.
        out_dtype: Dtype of the allocated output (defaults to bfloat16).

    Returns:
        A newly allocated tensor of shape
        ``(hidden_states.shape[0], router_weights.shape[0])`` on
        ``hidden_states.device``, passed to the op as its output argument.
    """
    output = torch.empty(
        hidden_states.shape[0],
        router_weights.shape[0],
        device=hidden_states.device,
        dtype=out_dtype,
    )
    torch.ops.sgl_kernel.dsv3_router_gemm(
        output,
        hidden_states,
        router_weights,
    )
    return output


def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
    """Run the ``sgl_kernel.shuffle_rows`` op into a fresh output tensor.

    Args:
        input_tensor: Source tensor; its device and dtype are used for the
            output.
        dst2src_map: Forwarded to the op; presumably maps each destination row
            to its source row -- confirm against the kernel.
        output_tensor_shape: Shape passed to ``torch.empty`` for the output.

    Returns:
        The newly allocated output tensor handed to the op.
    """
    output_tensor = torch.empty(
        output_tensor_shape,
        device=input_tensor.device,
        dtype=input_tensor.dtype,
    )
    torch.ops.sgl_kernel.shuffle_rows.default(input_tensor, dst2src_map, output_tensor)
    return output_tensor


def scaled_fp4_grouped_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    mask: torch.Tensor,
):
    """
    Quantize input tensor to FP4 and return quantized tensor and scale, for
    grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).
    Args:
        input_tensor: The input tensor to be quantized to FP4, with shape (l, m, k)
            l is number of groups, m is number of tokens per group, k is number of features.
        input_global_scale: A scalar scaling factor for the entire tensor, with
            shape (l,).
        mask: The mask tensor, with shape (l,)
    Outputs:
        output: The quantized tensor in FP4, with shape (m, k // 2, l) but the physical
            layout is (l, m, k // 2). `// 2` is because two fp4 values are packed into
            an uint8.
        output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l)
            but the physical layout is (l, rm, rk, 32, 4, 4).
    Note:
        For the shape of output_scales, `32 * 4 * rm` is a padded m to nearest multiple of 128.
        `4 * rk` is a padded `k // 16` to nearest multiple of 4. These layout constants are
        required by the NVIDIA Blackwell MMA operations.
    """
    from sglang.jit_kernel.nvfp4 import (
        scaled_fp4_grouped_quant as jit_scaled_fp4_grouped_quant,
    )

    return jit_scaled_fp4_grouped_quant(input_tensor, input_global_scale, mask)


def silu_and_mul_scaled_fp4_grouped_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    mask: torch.Tensor,
):
    """
    Quantize input tensor to FP4 and return quantized tensor and scale, for
    grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).
    Args:
        input_tensor: The input tensor to be quantized to FP4, with shape (l, m, k * 2)
            l is number of groups, m is number of tokens per group, k is number of features.
        input_global_scale: A scalar scaling factor for the entire tensor, with
            shape (l,).
        mask: The mask tensor, with shape (l,)
    Outputs:
        output: The quantized tensor in FP4, with shape (m, k // 2, l) but the physical
            layout is (l, m, k // 2). `// 2` is because two fp4 values are packed into
            an uint8.
        output_scales: The blockscale tensor in FP8-E4M3, with shape (32, 4, rm, 4, rk, l)
            but the physical layout is (l, rm, rk, 32, 4, 4).
    Note:
        For the shape of output_scales, `32 * 4 * rm` is a padded m to nearest multiple of 128.
        `4 * rk` is a padded `k // 16` to nearest multiple of 4. These layout constants are
        required by the NVIDIA Blackwell MMA operations.
    """
    from sglang.jit_kernel.nvfp4 import (
        silu_and_mul_scaled_fp4_grouped_quant as jit_silu_and_mul_scaled_fp4_grouped_quant,
    )

    return jit_silu_and_mul_scaled_fp4_grouped_quant(
        input_tensor,
        input_global_scale,
        mask,
    )


# GPTQ kernels
def gptq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_shuffle: bool,
    bit: int,
) -> torch.Tensor:
    """Call the ``sgl_kernel.gptq_gemm`` op and return its result.

    All arguments are forwarded positionally and unchanged.
    """
    return torch.ops.sgl_kernel.gptq_gemm(
        a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit
    )


def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
    """Call the ``sgl_kernel.gptq_shuffle`` op.

    All arguments are forwarded unchanged and nothing is returned, so the op
    presumably reorders ``q_weight`` in place -- confirm against the kernel.
    """
    # Was ``torch.torch.ops...``: a typo that only resolved because the torch
    # package happens to expose itself as an attribute.
    torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit)