File size: 11,196 Bytes
d02d576 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 | from typing import Optional, Tuple
import torch
from sgl_kernel.utils import _get_cache_buf
def awq_dequantize(
    qweight: torch.Tensor, scales: torch.Tensor, qzeros: torch.Tensor
) -> torch.Tensor:
    """Dequantize AWQ weights via ``torch.ops.sgl_kernel.awq_dequantize``.

    Thin wrapper: the three tensors are forwarded positionally and the op's
    result is returned unchanged.

    Args:
        qweight: AWQ quantized weight tensor (layout as expected by the kernel).
        scales: Scale tensor used for dequantization.
        qzeros: Quantized zero-point tensor.

    Returns:
        The dequantized tensor produced by the op. NOTE(review): presumably in
        the dtype of ``scales`` (e.g. float16) — confirm against the kernel. The
        previous ``torch.ByteTensor`` annotation described a uint8 result, which
        a dequantization does not produce.
    """
    return torch.ops.sgl_kernel.awq_dequantize.default(qweight, scales, qzeros)
def int8_scaled_mm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
):
    """Scaled int8 matrix multiply via ``torch.ops.sgl_kernel.int8_scaled_mm``.

    Thin wrapper: every argument is forwarded positionally, in signature order,
    and the op's result is returned unchanged.

    Args:
        mat_a: Left GEMM operand (presumably int8 — layout defined by the kernel).
        mat_b: Right GEMM operand (presumably int8 — layout defined by the kernel).
        scales_a: Scale tensor associated with ``mat_a``.
        scales_b: Scale tensor associated with ``mat_b``.
        out_dtype: dtype requested for the result.
        bias: Optional bias forwarded to the op; ``None`` by default.
    """
    kernel = torch.ops.sgl_kernel.int8_scaled_mm.default
    return kernel(mat_a, mat_b, scales_a, scales_b, out_dtype, bias)
def fp8_blockwise_scaled_mm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    out_dtype: torch.dtype,
):
    """Block-wise scaled FP8 GEMM via ``torch.ops.sgl_kernel.fp8_blockwise_scaled_mm``.

    Thin wrapper around the custom op; unlike ``fp8_scaled_mm`` there is no
    bias argument. The op's result is returned unchanged.

    Args:
        mat_a: Left GEMM operand.
        mat_b: Right GEMM operand.
        scales_a: Scale tensor associated with ``mat_a`` (block granularity is
            defined by the kernel — confirm there).
        scales_b: Scale tensor associated with ``mat_b``.
        out_dtype: dtype requested for the result.
    """
    operands = (mat_a, mat_b, scales_a, scales_b, out_dtype)
    return torch.ops.sgl_kernel.fp8_blockwise_scaled_mm.default(*operands)
def fp8_scaled_mm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    scales_a: torch.Tensor,
    scales_b: torch.Tensor,
    out_dtype: torch.dtype,
    bias: Optional[torch.Tensor] = None,
):
    """Scaled FP8 GEMM via ``torch.ops.sgl_kernel.fp8_scaled_mm``.

    Thin wrapper: arguments are forwarded positionally in signature order and
    the op's result is returned unchanged.

    Args:
        mat_a: Left GEMM operand.
        mat_b: Right GEMM operand.
        scales_a: Scale tensor associated with ``mat_a``.
        scales_b: Scale tensor associated with ``mat_b``.
        out_dtype: dtype requested for the result.
        bias: Optional bias forwarded to the op; ``None`` by default.
    """
    fp8_gemm = torch.ops.sgl_kernel.fp8_scaled_mm.default
    result = fp8_gemm(mat_a, mat_b, scales_a, scales_b, out_dtype, bias)
    return result
def _bmm_fp8_internal(
    workspace_buffer: torch.Tensor,
    A: torch.Tensor,
    B: torch.Tensor,
    D: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
) -> None:
    """Invoke ``torch.ops.sgl_kernel.bmm_fp8`` with the current cuBLAS handle.

    The op's argument order differs from this helper's: the operands and
    destination come first, followed by the workspace buffer and the cuBLAS
    handle of the current CUDA context. Nothing is returned.

    Args:
        workspace_buffer: Scratch buffer handed to the op.
        A: Left batched operand.
        B: Right batched operand.
        D: Destination tensor passed to the op (callers pass the output buffer).
        A_scale: Scale tensor associated with ``A``.
        B_scale: Scale tensor associated with ``B``.
    """
    handle = torch.cuda.current_blas_handle()
    bmm_op = torch.ops.sgl_kernel.bmm_fp8.default
    bmm_op(A, B, D, A_scale, B_scale, workspace_buffer, handle)
def bmm_fp8(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    dtype: torch.dtype,
    out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Batched FP8 matmul of ``A`` and ``B`` through the ``bmm_fp8`` custom op.

    Args:
        A: Left operand, indexed as rank-3 (its dims 0 and 1 size the output).
        B: Right operand, indexed as rank-3 (its dim 2 sizes the output).
        A_scale: Scale tensor associated with ``A``.
        B_scale: Scale tensor associated with ``B``.
        dtype: dtype of the freshly allocated output; unused when ``out`` is given.
        out: Optional preallocated destination of shape
            ``(A.shape[0], A.shape[1], B.shape[2])``.

    Returns:
        The destination tensor (``out`` itself when it was supplied).
    """
    if out is None:
        batch, rows = A.shape[0], A.shape[1]
        out = torch.empty((batch, rows, B.shape[2]), dtype=dtype, device=A.device)
    # 32 MiB scratch space fetched under a fixed key; presumably reused across
    # calls by _get_cache_buf rather than reallocated each time.
    workspace = _get_cache_buf("bmm_fp8_workspace", 32 * 1024 * 1024, A.device)
    _bmm_fp8_internal(workspace, A, B, out, A_scale, B_scale)
    return out
def dsv3_fused_a_gemm(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run ``torch.ops.sgl_kernel.dsv3_fused_a_gemm`` and return its destination.

    Args:
        mat_a: Left operand; dim 0 sizes the output rows, and its dtype/device
            are used when a fresh output is allocated.
        mat_b: Right operand; dim 1 sizes the output columns.
        output: Optional preallocated destination of shape
            ``(mat_a.shape[0], mat_b.shape[1])``; allocated when ``None``.

    Returns:
        The destination tensor handed to the op (``output`` if supplied).
    """
    if output is not None:
        dest = output
    else:
        dest = torch.empty(
            (mat_a.shape[0], mat_b.shape[1]), dtype=mat_a.dtype, device=mat_a.device
        )
    # The op takes the destination first, followed by the two operands.
    torch.ops.sgl_kernel.dsv3_fused_a_gemm.default(dest, mat_a, mat_b)
    return dest
def sgl_per_token_group_quant_8bit(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    group_size: int,
    eps: float,
    fp8_min: float,
    fp8_max: float,
    scale_ue8m0: bool = False,
    fuse_silu_and_mul: bool = False,
    masked_m: Optional[torch.Tensor] = None,
    enable_v2: Optional[bool] = None,
) -> None:
    """Per-token-group 8-bit quantization through one of two sgl_kernel ops.

    Dispatches to ``sgl_per_token_group_quant_8bit_v2`` when v2 is enabled, and
    otherwise to the legacy ``sgl_per_token_group_quant_8bit`` op. Results are
    presumably written into ``output_q`` / ``output_s`` (confirm against the
    kernels).

    Args:
        input: Tensor to quantize.
        output_q: Destination for quantized values.
        output_s: Destination for scales.
        group_size: Group size forwarded to the kernel.
        eps: Epsilon forwarded to the kernel.
        fp8_min: Lower clamp bound forwarded to the kernel. It is named for fp8,
            but the int8 legacy alias also uses it (presumably the target
            type's range).
        fp8_max: Upper clamp bound forwarded to the kernel.
        scale_ue8m0: Forwarded to both kernels.
        fuse_silu_and_mul: v2-only option.
        masked_m: v2-only option.
        enable_v2: Selects the v2 kernel. When ``None`` (default), it is read from
            the ``SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2`` environment flag.

    Raises:
        ValueError: If ``fuse_silu_and_mul`` or ``masked_m`` is requested while
            the legacy (non-v2) kernel is selected.
    """
    if enable_v2 is None:
        # Imported lazily: only needed when the caller did not choose explicitly.
        from sglang.srt.utils import get_bool_env_var

        enable_v2 = get_bool_env_var("SGLANG_PER_TOKEN_GROUP_QUANT_8BIT_V2")

    if enable_v2:
        return torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit_v2.default(
            input,
            output_q,
            output_s,
            group_size,
            eps,
            fp8_min,
            fp8_max,
            scale_ue8m0,
            fuse_silu_and_mul,
            masked_m,
        )

    # Explicit checks rather than `assert`. Under `python -O` asserts are
    # stripped, and the legacy kernel would silently ignore these options and
    # produce wrong results.
    if fuse_silu_and_mul:
        raise ValueError("only v2 support fuse_silu_and_mul")
    if masked_m is not None:
        raise ValueError("only v2 support masked_m")

    torch.ops.sgl_kernel.sgl_per_token_group_quant_8bit.default(
        input, output_q, output_s, group_size, eps, fp8_min, fp8_max, scale_ue8m0
    )
# Legacy aliases: older call sites use the fp8- and int8-specific names. Both
# resolve to the dtype-generic sgl_per_token_group_quant_8bit.
sgl_per_token_group_quant_fp8 = sgl_per_token_group_quant_int8 = (
    sgl_per_token_group_quant_8bit
)
def sgl_per_tensor_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
    is_static: bool,
) -> None:
    """Per-tensor FP8 quantization via ``torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8``.

    The op's return value is discarded. Outputs are presumably written into the
    caller-provided ``output_q`` / ``output_s`` buffers — confirm against the
    kernel.

    Args:
        input: Tensor to quantize.
        output_q: Destination for quantized values.
        output_s: Destination for the scale.
        is_static: Flag forwarded to the op. Presumably it means ``output_s``
            already holds a precomputed scale — confirm against the kernel.
    """
    quant_op = torch.ops.sgl_kernel.sgl_per_tensor_quant_fp8.default
    quant_op(input, output_q, output_s, is_static)
def sgl_per_token_quant_fp8(
    input: torch.Tensor,
    output_q: torch.Tensor,
    output_s: torch.Tensor,
) -> None:
    """Per-token FP8 quantization via ``torch.ops.sgl_kernel.sgl_per_token_quant_fp8``.

    The op's return value is discarded. Outputs are presumably written into
    ``output_q`` (quantized values) and ``output_s`` (scales) — confirm against
    the kernel.
    """
    buffers = (input, output_q, output_s)
    torch.ops.sgl_kernel.sgl_per_token_quant_fp8.default(*buffers)
def cutlass_scaled_fp4_mm(
    a: torch.Tensor,
    b: torch.Tensor,
    block_scale_a: torch.Tensor,
    block_scale_b: torch.Tensor,
    alpha: torch.Tensor,
    out_dtype: torch.dtype,
) -> torch.Tensor:
    """Block-scaled FP4 GEMM, delegated to ``sglang.jit_kernel.nvfp4``.

    Arguments are forwarded positionally, in signature order, to the JIT
    implementation, and its result is returned unchanged.

    Args:
        a: Left operand (packed FP4 — layout defined by the JIT kernel).
        b: Right operand.
        block_scale_a: Block scales associated with ``a``.
        block_scale_b: Block scales associated with ``b``.
        alpha: Scaling tensor forwarded to the kernel.
        out_dtype: dtype requested for the result.
    """
    # Imported inside the function so the JIT module is only loaded when this
    # wrapper is actually called.
    from sglang.jit_kernel.nvfp4 import cutlass_scaled_fp4_mm as _jit_impl

    operands = (a, b, block_scale_a, block_scale_b, alpha, out_dtype)
    return _jit_impl(*operands)
def scaled_fp4_quant(
    input: torch.Tensor, input_global_scale: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize input tensor to FP4 and return quantized tensor and scale.

    This function quantizes the last dimension of the given tensor `input`. For
    every 16 consecutive elements, a single dynamically computed scaling factor
    is shared. This scaling factor is quantized using the `input_global_scale`
    and is stored in a swizzled layout (see
    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).

    The work is delegated to ``sglang.jit_kernel.nvfp4.scaled_fp4_quant``.

    Args:
        input: The input tensor to be quantized to FP4.
        input_global_scale: A scalar scaling factor for the entire tensor.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4, with every
        two values packed into a uint8, and the float8_e4m3 scaling factors in
        a swizzled layout.
    """
    # Imported lazily so the JIT module is only loaded when this function is called.
    from sglang.jit_kernel.nvfp4 import scaled_fp4_quant as jit_scaled_fp4_quant

    return jit_scaled_fp4_quant(input, input_global_scale)
def qserve_w4a8_per_chn_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    w_szs: torch.Tensor,
    a_ssums: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """QServe W4A8 per-channel GEMM via ``torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm``.

    Args:
        in_feats: Activations; dim 0 sizes the output rows.
        kernel: Quantized weights; dim 0 sizes the output columns.
        wscales: Weight scales forwarded to the op.
        ascales: Activation scales forwarded to the op.
        w_szs: Forwarded to the op (semantics defined by the kernel).
        a_ssums: Forwarded to the op (semantics defined by the kernel).
        out_feats: Optional destination of shape
            ``(in_feats.shape[0], kernel.shape[0])``; allocated when ``None``.

    Returns:
        The destination tensor handed to the op.
    """
    # NOTE(HandH1998): the kernel currently only supports float16 outputs, so
    # the default destination is allocated as torch.float16.
    dest = (
        out_feats
        if out_feats is not None
        else torch.empty(
            (in_feats.shape[0], kernel.shape[0]),
            dtype=torch.float16,
            device=in_feats.device,
        )
    )
    gemm = torch.ops.sgl_kernel.qserve_w4a8_per_chn_gemm.default
    gemm(in_feats, kernel, wscales, ascales, w_szs, a_ssums, dest)
    return dest
def qserve_w4a8_per_group_gemm(
    in_feats: torch.Tensor,
    kernel: torch.Tensor,
    zeros: torch.Tensor,
    scales_i8: torch.Tensor,
    wscales: torch.Tensor,
    ascales: torch.Tensor,
    out_feats: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """QServe W4A8 per-group GEMM via ``torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm``.

    Args:
        in_feats: Activations; dim 0 sizes the output rows.
        kernel: Quantized weights; dim 0 sizes the output columns.
        zeros: Zero points forwarded to the op.
        scales_i8: int8 group scales forwarded to the op (per the name — confirm).
        wscales: Weight scales forwarded to the op.
        ascales: Activation scales forwarded to the op.
        out_feats: Optional destination of shape
            ``(in_feats.shape[0], kernel.shape[0])``; allocated when ``None``.

    Returns:
        The destination tensor handed to the op.
    """
    if out_feats is None:
        # NOTE(HandH1998): only torch.float16 output is supported by the kernel
        # for now, hence the hard-coded dtype.
        n_rows, n_cols = in_feats.shape[0], kernel.shape[0]
        out_feats = torch.empty(
            (n_rows, n_cols), dtype=torch.float16, device=in_feats.device
        )
    call_args = (in_feats, kernel, zeros, scales_i8, wscales, ascales, out_feats)
    torch.ops.sgl_kernel.qserve_w4a8_per_group_gemm.default(*call_args)
    return out_feats
def dsv3_router_gemm(
    hidden_states: torch.Tensor,
    router_weights: torch.Tensor,
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Run ``torch.ops.sgl_kernel.dsv3_router_gemm`` into a freshly allocated tensor.

    Args:
        hidden_states: Left operand; dim 0 sizes the output rows.
        router_weights: Weight matrix; dim 0 sizes the output columns
            (presumably one per expert — confirm against the caller).
        out_dtype: dtype of the allocated output (``torch.bfloat16`` by default).

    Returns:
        A new tensor of shape ``(hidden_states.shape[0], router_weights.shape[0])``
        that the op writes into.
    """
    n_rows = hidden_states.shape[0]
    n_cols = router_weights.shape[0]
    logits = torch.empty((n_rows, n_cols), dtype=out_dtype, device=hidden_states.device)
    # The op takes the destination first, then the two operands.
    torch.ops.sgl_kernel.dsv3_router_gemm(logits, hidden_states, router_weights)
    return logits
def shuffle_rows(input_tensor, dst2src_map, output_tensor_shape):
    """Rearrange rows of ``input_tensor`` into a new tensor via ``sgl_kernel.shuffle_rows``.

    Args:
        input_tensor: Source tensor; its dtype and device are used for the output.
        dst2src_map: Row mapping forwarded to the op. Per its name, output row i
            presumably comes from source row ``dst2src_map[i]`` — confirm.
        output_tensor_shape: Shape of the freshly allocated output.

    Returns:
        The newly allocated tensor that the op writes into.
    """
    shuffled = torch.empty(
        output_tensor_shape, dtype=input_tensor.dtype, device=input_tensor.device
    )
    torch.ops.sgl_kernel.shuffle_rows.default(input_tensor, dst2src_map, shuffled)
    return shuffled
def scaled_fp4_grouped_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    mask: torch.Tensor,
):
    """
    Quantize input tensor to FP4 and return quantized tensor and scale, for
    grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).

    Delegates to ``sglang.jit_kernel.nvfp4.scaled_fp4_grouped_quant``.

    Args:
        input_tensor: The input tensor to be quantized to FP4, with shape
            (l, m, k), where l is the number of groups, m is the number of
            tokens per group, and k is the number of features.
        input_global_scale: Global scaling factors, one per group, with shape (l,).
        mask: The mask tensor, with shape (l,). NOTE(review): presumably holds
            the number of valid tokens in each group, as used by masked grouped
            GEMMs — confirm against the kernel.

    Returns:
        output: The quantized tensor in FP4, with shape (m, k // 2, l) but the
            physical layout is (l, m, k // 2). `// 2` is because two fp4 values
            are packed into a uint8.
        output_scales: The blockscale tensor in FP8-E4M3, with shape
            (32, 4, rm, 4, rk, l) but the physical layout is (l, rm, rk, 32, 4, 4).

    Note:
        For the shape of output_scales, `32 * 4 * rm` is m padded to the nearest
        multiple of 128, and `4 * rk` is `k // 16` padded to the nearest multiple
        of 4. These layout constants are required by the NVIDIA Blackwell MMA
        operations.
    """
    # Imported lazily so the JIT module is only loaded when this function is called.
    from sglang.jit_kernel.nvfp4 import (
        scaled_fp4_grouped_quant as jit_scaled_fp4_grouped_quant,
    )

    return jit_scaled_fp4_grouped_quant(input_tensor, input_global_scale, mask)
def silu_and_mul_scaled_fp4_grouped_quant(
    input_tensor: torch.Tensor,
    input_global_scale: torch.Tensor,
    mask: torch.Tensor,
):
    """
    Apply SiLU-and-mul, then quantize to FP4 and return quantized tensor and
    scale, for grouped gemm inputs (e.g., grouped_gemm_nt_masked for flashinfer).

    Delegates to ``sglang.jit_kernel.nvfp4.silu_and_mul_scaled_fp4_grouped_quant``.

    Args:
        input_tensor: The input tensor, with shape (l, m, k * 2), where l is the
            number of groups, m is the number of tokens per group, and k is the
            number of features after the SiLU-and-mul fusion. NOTE(review):
            presumably the two halves of the last dimension are combined as
            silu(x) * y — confirm the half ordering against the kernel.
        input_global_scale: Global scaling factors, one per group, with shape (l,).
        mask: The mask tensor, with shape (l,).

    Returns:
        output: The quantized tensor in FP4, with shape (m, k // 2, l) but the
            physical layout is (l, m, k // 2). `// 2` is because two fp4 values
            are packed into a uint8.
        output_scales: The blockscale tensor in FP8-E4M3, with shape
            (32, 4, rm, 4, rk, l) but the physical layout is (l, rm, rk, 32, 4, 4).

    Note:
        For the shape of output_scales, `32 * 4 * rm` is m padded to the nearest
        multiple of 128, and `4 * rk` is `k // 16` padded to the nearest multiple
        of 4. These layout constants are required by the NVIDIA Blackwell MMA
        operations.
    """
    # Imported lazily so the JIT module is only loaded when this function is called.
    from sglang.jit_kernel.nvfp4 import (
        silu_and_mul_scaled_fp4_grouped_quant as jit_silu_and_mul_scaled_fp4_grouped_quant,
    )

    return jit_silu_and_mul_scaled_fp4_grouped_quant(
        input_tensor,
        input_global_scale,
        mask,
    )
# GPTQ kernels
def gptq_gemm(
    a: torch.Tensor,
    b_q_weight: torch.Tensor,
    b_gptq_qzeros: torch.Tensor,
    b_gptq_scales: torch.Tensor,
    b_g_idx: torch.Tensor,
    use_shuffle: bool,
    bit: int,
) -> torch.Tensor:
    """GPTQ-quantized GEMM via ``torch.ops.sgl_kernel.gptq_gemm``.

    Thin wrapper: all arguments are forwarded positionally, in signature order,
    and the op's result is returned unchanged.

    Args:
        a: Activation operand.
        b_q_weight: GPTQ packed quantized weights.
        b_gptq_qzeros: GPTQ quantized zero points.
        b_gptq_scales: GPTQ scales.
        b_g_idx: Group index tensor forwarded to the op.
        use_shuffle: Flag forwarded to the op (presumably whether the weights
            were pre-shuffled with ``gptq_shuffle`` — confirm).
        bit: Quantization bit width forwarded to the op.
    """
    gemm = torch.ops.sgl_kernel.gptq_gemm
    return gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_shuffle, bit)
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, bit: int) -> None:
    """Shuffle GPTQ quantized weights via ``torch.ops.sgl_kernel.gptq_shuffle``.

    The op's return value is discarded. It presumably rearranges ``q_weight``
    in place according to ``q_perm`` and ``bit`` — confirm against the kernel.

    Args:
        q_weight: GPTQ packed quantized weights.
        q_perm: Permutation tensor forwarded to the op.
        bit: Quantization bit width forwarded to the op.
    """
    # Reach the op through the plain `torch.ops` namespace like every other
    # wrapper in this module; the previous spelling went through a redundant
    # extra `torch.` attribute hop.
    torch.ops.sgl_kernel.gptq_shuffle(q_weight, q_perm, bit)
|