Lekr0's picture
Add files using upload-large-folder tool
d02d576 verified
from typing import Optional, Tuple
import torch
from sgl_kernel.utils import _get_cache_buf
def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Dequantize AWQ-packed weights with the ``sgl_kernel.awq_dequantize`` op.

    Args:
        qweight: Packed quantized weight tensor, forwarded to the kernel.
        scales: Scale tensor, forwarded to the kernel.
        qzeros: Packed zero-point tensor, forwarded to the kernel.

    Returns:
        The dequantized tensor produced by the kernel. Its dtype is presumably
        that of ``scales``; confirm this against the kernel.
    """
    # The return is annotated as plain ``torch.Tensor``. ``torch.ByteTensor``
    # is the legacy CPU uint8 tensor class and does not describe a
    # dequantized result.
    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)
def int8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """Forward to the ``sgl_kernel.int8_scaled_mm`` custom op.

    Args:
        mat_a, mat_b: Operand tensors, passed through unchanged.
        scales_a, scales_b: Scale tensors passed alongside the operands.
        out_dtype: dtype requested for the result.
        bias: Optional bias tensor. ``None`` is forwarded as-is.

    Returns:
        Whatever the custom op returns.
    """
    op = torch.ops.sgl_kernel.int8_scaled_mm.default
    return op(mat_a, mat_b, scales_a, scales_b, out_dtype, bias)
def fp8_blockwise_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype):
    """Forward to the ``sgl_kernel.fp8_blockwise_scaled_mm`` custom op.

    All arguments are passed through positionally and unchanged. This variant
    takes no bias argument. Returns whatever the custom op returns.
    """
    op = torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default
    return op(mat_a, mat_b, scales_a, scales_b, out_dtype)
def fp8_scaled_mm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias=None):
    """Forward to the ``sgl_kernel.fp8_scaled_mm`` custom op.

    Args:
        mat_a, mat_b: Operand tensors, passed through unchanged.
        scales_a, scales_b: Scale tensors passed alongside the operands.
        out_dtype: dtype requested for the result.
        bias: Optional bias tensor. ``None`` is forwarded as-is.

    Returns:
        Whatever the custom op returns.
    """
    op = torch.ops.sgl_kernel.fp8_scaled_mm.default
    return op(mat_a, mat_b, scales_a, scales_b, out_dtype, bias)
def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    """Invoke ``sgl_kernel.bmm_fp8`` with ``D`` passed as the destination.

    The current cuBLAS handle is fetched first and appended after the
    workspace buffer in the op's argument list. The op's return value is
    discarded.
    """
    blas = torch.cuda.current_blas_handle()
    operands = (A, B, D, A_scale, B_scale)
    torch.ops.sgl_kernel.bmm_fp8.default(*operands, workspace_buffer, blas)
def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Batched FP8 matmul via the ``sgl_kernel.bmm_fp8`` op.

    When ``out`` is omitted, a tensor of shape
    ``(A.shape[0], A.shape[1], B.shape[2])`` with ``dtype`` on ``A``'s device
    is allocated. A is presumably (batch, m, k) and B (batch, k, n); confirm
    this against the kernel.

    The kernel uses a cached 32 MiB workspace buffer keyed by
    ``"bmm_fp8_workspace"``. Returns the tensor the kernel wrote into.
    """
    result = out
    if result is None:
        batch, rows = A.shape[0], A.shape[1]
        cols = B.shape[2]
        result = torch.empty((batch, rows, cols), device=A.device, dtype=dtype)
    workspace_bytes = 32 * 1024 * 1024
    scratch = _get_cache_buf("bmm_fp8_workspace", workspace_bytes, A.device)
    _bmm_fp8_internal(scratch, A, B, result, A_scale, B_scale)
    return result
def dsv3_fused_a_gemm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run ``sgl_kernel.dsv3_fused_a_gemm`` and return the destination tensor.

    When ``output`` is omitted, a ``(mat_a.shape[0], mat_b.shape[1])`` tensor
    with ``mat_a``'s dtype and device is allocated. The op receives the
    destination as its first argument.
    """
    dest = output
    if dest is None:
        shape = (mat_a.shape[0], mat_b.shape[1])
        dest = torch.empty(shape, device=mat_a.device, dtype=mat_a.dtype)
    torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(dest, mat_a, mat_b)
    return dest
def sgl_per_token_group_quant_8bit(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
    scale_ue8m0: bool = False,
    fuse_silu_and_mul: bool = False,
    masked_m: Optional[torch.Tensor] = None,
    enable_v2: Optional[bool] = None,
) -> None:
    """Per-token-group 8-bit quantization via an sgl_kernel op.

    The kernel writes its results into ``output_q`` and ``output_s``
    (presumably the quantized values and their scales; confirm against the
    kernel).

    Kernel selection:
        * ``enable_v2`` decides the version when given.
        * When it is ``None``, the environment flag
          ``SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2`` decides.
        * Only the v2 op accepts ``fuse_silu_and_mul`` and ``masked_m``. The
          v1 path asserts both are unused.

    The v2 branch returns whatever the v2 op returns. The v1 branch returns
    ``None``.
    """
    use_v2 = enable_v2
    if use_v2 is None:
        from sglang.srt.utils import get_bool_env_var

        use_v2 = get_bool_env_var("SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2")

    if not use_v2:
        # Legacy kernel: reject v2-only features before dispatching.
        assert not fuse_silu_and_mul, "only v2 support fuse_silu_and_mul"
        assert masked_m is None, "only v2 support masked_m"
        torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
            input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
        )
        return None

    return torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit_v2.default(
        input,
        output_q,
        output_s,
        group_size,
        eps,
        fp8_min,
        fp8_max,
        scale_ue8m0,
        fuse_silu_and_mul,
        masked_m,
    )
# Aliases kept for legacy callers: both names resolve to the shared 8-bit
# implementation above.
sgl_per_token_group_quant_fp8 = sgl_per_token_group_quant_int8 = (
    sgl_per_token_group_quant_8bit
)
def sgl_per_tensor_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    is_static: bool,
) -> None:
    """Run the ``sgl_kernel.sgl_per_tensor_quant_fp8`` op.

    The op receives ``output_q`` / ``output_s`` as destination arguments.
    ``is_static`` is forwarded unchanged. The op's return value is discarded.
    """
    op = torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default
    op(input, output_q, output_s, is_static)
def sgl_per_token_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
    """Run the ``sgl_kernel.sgl_per_token_quant_fp8`` op.

    The op receives ``output_q`` / ``output_s`` as destination arguments. Its
    return value is discarded.
    """
    op = torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default
    op(input, output_q, output_s)
def cutlass_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Delegate to the JIT ``cutlass_scaled_fp4_mm`` in ``sglang.jit_kernel.nvfp4``.

    The JIT module is imported lazily on first call. All arguments are passed
    through positionally, and its result is returned.
    """
    from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm as _fp4_mm_impl

    operands = (a, b, block_scale_a, block_scale_b, alpha, out_dtype)
    return _fp4_mm_impl(*operands)
def scaled_fp4_quant(
    input: torch.Tensor, input_global_scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return quantized tensor and scale.

    This function quantizes the last dimension of the given tensor `input`. For
    every 16 consecutive elements, a single dynamically computed scaling factor
    is shared. This scaling factor is quantized using the `input_global_scale`
    and is stored in a swizzled layout (see
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).

    The work is delegated to the JIT kernel
    ``sglang.jit_kernel.nvfp4.scaled_fp4_quant``, which is imported lazily.

    Args:
        input: The input tensor to be quantized to FP4.
        input_global_scale: A scalar scaling factor for the entire tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: ``(output, output_scale)``.
        ``output`` holds the FP4 values, packed two per uint8. ``output_scale``
        holds the float8_e4m3 scaling factors in the swizzled layout described
        above.
    """
    from sglang.jit_kernel.nvfp4 import scaled_fp4_quant as jit_scaled_fp4_quant

    return jit_scaled_fp4_quant(input, input_global_scale)
def qserve_w4a8_per_chn_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    w_szs: torch.Tensor,
    a_ssums: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """QServe W4A8 per-channel GEMM via ``sgl_kernel.qserve_w4a8_per_chn_gemm``.

    When ``out_feats`` is omitted, a float16 tensor of shape
    ``(in_feats.shape[0], kernel.shape[0])`` is allocated on ``in_feats``'s
    device. The op receives it as its last argument, and it is returned.
    """
    dest = out_feats
    if dest is None:
        # NOTE(HandH1998): qserve_w4a8_per_chn_gemm only supports out dtype=torch.float16 now
        shape = (in_feats.shape[0], kernel.shape[0])
        dest = torch.empty(shape, device=in_feats.device, dtype=torch.float16)
    op = torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm.default
    op(in_feats, kernel, wscales, ascales, w_szs, a_ssums, dest)
    return dest
def qserve_w4a8_per_group_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    zeros: torch.Tensor,
    scales_i8: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """QServe W4A8 per-group GEMM via ``sgl_kernel.qserve_w4a8_per_group_gemm``.

    When ``out_feats`` is omitted, a float16 tensor of shape
    ``(in_feats.shape[0], kernel.shape[0])`` is allocated on ``in_feats``'s
    device. The op receives it as its last argument, and it is returned.
    """
    dest = out_feats
    if dest is None:
        # NOTE(HandH1998): qserve_w4a8_per_group_gemm only supports out dtype=torch.float16 now
        shape = (in_feats.shape[0], kernel.shape[0])
        dest = torch.empty(shape, device=in_feats.device, dtype=torch.float16)
    op = torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm.default
    op(in_feats, kernel, zeros, scales_i8, wscales, ascales, dest)
    return dest
def dsv3_router_gemm(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Router GEMM via the ``sgl_kernel.dsv3_router_gemm`` op.

    A fresh ``(hidden_states.shape[0], router_weights.shape[0])`` tensor of
    ``out_dtype`` is allocated on ``hidden_states``'s device. The op receives
    it as its first argument, and it is returned.
    """
    out_rows = hidden_states.shape[0]
    out_cols = router_weights.shape[0]
    result = torch.empty(
        (out_rows, out_cols), device=hidden_states.device, dtype=out_dtype
    )
    # Called through the overload packet (no ``.default``), as before.
    torch.ops.sgl_kernel.dsv3_router_gemm(result, hidden_states, router_weights)
    return result
def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
    """Rearrange rows of ``input_tensor`` into a newly allocated tensor.

    The destination has shape ``output_tensor_shape`` and ``input_tensor``'s
    dtype and device. It is filled by the ``sgl_kernel.shuffle_rows`` op using
    ``dst2src_map``. Judging by its name, that map gives each destination
    row's source row; confirm this against the kernel.
    """
    shuffled = torch.empty(
        output_tensor_shape, dtype=input_tensor.dtype, device=input_tensor.device
    )
    torch.ops.sgl_kernel.shuffle_rows.default(input_tensor, dst2src_map, shuffled)
    return shuffled
def scaled_fp4_grouped_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    mask: torch.Tensor,
):
    """
    Quantize input tensor to FP4 and return quantized tensor and scale, for
    grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).

    The work is delegated to the JIT kernel
    ``sglang.jit_kernel.nvfp4.scaled_fp4_grouped_quant``, which is imported
    lazily.

    Args:
        input_tensor: The input tensor to be quantized to FP4, with shape
            (l, m, k). l is the number of groups, m is the number of tokens
            per group, and k is the number of features.
        input_global_scale: Global scaling factors, one per group, with
            shape (l,).
        mask: The mask tensor, with shape (l,). It is forwarded unchanged to
            the JIT kernel; presumably it holds the number of valid tokens per
            group (confirm against the kernel).

    Returns:
        output: The quantized tensor in FP4, with shape (m, k // 2, l) but the
            physical layout is (l, m, k // 2). `// 2` is because two fp4
            values are packed into an uint8.
        output_scales: The blockscale tensor in FP8-E4M3, with shape
            (32, 4, rm, 4, rk, l) but the physical layout is
            (l, rm, rk, 32, 4, 4).

    Note:
        For the shape of output_scales, `32 * 4 * rm` is m padded to the
        nearest multiple of 128, and `4 * rk` is `k // 16` padded to the
        nearest multiple of 4. These layout constants are required by the
        NVIDIA Blackwell MMA operations.
    """
    from sglang.jit_kernel.nvfp4 import (
        scaled_fp4_grouped_quant as jit_scaled_fp4_grouped_quant,
    )

    return jit_scaled_fp4_grouped_quant(input_tensor, input_global_scale, mask)
def silu_and_mul_scaled_fp4_grouped_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    mask: torch.Tensor,
):
    """
    Apply SiLU-and-mul to the input and quantize the result to FP4, returning
    the quantized tensor and scale, for grouped gemm inputs (e.g.,
    grouped_gemm_nt_masked for flashinfer).

    The fused activation and quantization are delegated to the JIT kernel
    ``sglang.jit_kernel.nvfp4.silu_and_mul_scaled_fp4_grouped_quant``, which
    is imported lazily. The last input dimension is 2 * k and the output
    covers k features.

    Args:
        input_tensor: The input tensor, with shape (l, m, k * 2). l is the
            number of groups, m is the number of tokens per group, and k is
            the number of features after SiLU-and-mul.
        input_global_scale: Global scaling factors, one per group, with
            shape (l,).
        mask: The mask tensor, with shape (l,). It is forwarded unchanged to
            the JIT kernel; presumably it holds the number of valid tokens per
            group (confirm against the kernel).

    Returns:
        output: The quantized tensor in FP4, with shape (m, k // 2, l) but the
            physical layout is (l, m, k // 2). `// 2` is because two fp4
            values are packed into an uint8.
        output_scales: The blockscale tensor in FP8-E4M3, with shape
            (32, 4, rm, 4, rk, l) but the physical layout is
            (l, rm, rk, 32, 4, 4).

    Note:
        For the shape of output_scales, `32 * 4 * rm` is m padded to the
        nearest multiple of 128, and `4 * rk` is `k // 16` padded to the
        nearest multiple of 4. These layout constants are required by the
        NVIDIA Blackwell MMA operations.
    """
    from sglang.jit_kernel.nvfp4 import (
        silu_and_mul_scaled_fp4_grouped_quant as jit_silu_and_mul_scaled_fp4_grouped_quant,
    )

    return jit_silu_and_mul_scaled_fp4_grouped_quant(
        input_tensor,
        input_global_scale,
        mask,
    )
# GPTQ kernels
def gptq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_shuffle: bool,
    bit: int,
) -> torch.Tensor:
    """GPTQ GEMM via the ``sgl_kernel.gptq_gemm`` op.

    All arguments are forwarded positionally, and the op's result is
    returned. The op is called through its overload packet (no
    ``.default``).
    """
    gptq_args = (a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx)
    return torch.ops.sgl_kernel.gptq_gemm(*gptq_args, use_shuffle, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
    """Run the ``sgl_kernel.gptq_shuffle`` op on ``q_weight``.

    The op's return value is discarded, so any effect is whatever the kernel
    does to its arguments. Presumably it reorders ``q_weight`` in place using
    ``q_perm`` for the given ``bit`` width; confirm this against the kernel.
    """
    # Reach the op through ``torch.ops`` like every other wrapper in this
    # module. The previous spelling went through a redundant extra ``torch``
    # attribute on the package.
    torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit)