leideng/QCFuse / srt /layers /quantization /marlin_utils_fp8.py
leideng's picture
download
raw
12.5 kB
# SPDX-License-Identifier: Apache-2.0
import logging
from typing import Optional
import torch
from sglang.srt.layers.quantization.marlin_utils import (
USE_FP32_REDUCE_DEFAULT,
marlin_make_workspace,
marlin_permute_bias,
marlin_permute_scales,
should_use_atomic_add_reduce,
)
from sglang.srt.layers.quantization.utils import get_scalar_types
from sglang.srt.utils import is_cuda
_is_cuda = is_cuda()
if _is_cuda:
from sgl_kernel import gptq_marlin_gemm, gptq_marlin_repack
ScalarType, scalar_types = get_scalar_types()
logger = logging.getLogger(__name__)
def fp8_fused_exponent_bias_into_scales(scales):
    """Multiply ``scales`` by the FP8 -> FP16/BF16 exponent-bias factor.

    The factor is ``2 ** exponent_bias`` with
    ``exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp8_exponent - 1)``,
    where ``fp8_exponent`` is 4 and ``target_exponent`` is 5 for
    ``torch.half`` and 8 for ``torch.bfloat16``.

    Args:
        scales: Scale tensor of dtype ``torch.half`` or ``torch.bfloat16``.

    Returns:
        A new tensor with the same shape and dtype as ``scales``, holding
        ``scales * 2 ** exponent_bias``.

    Raises:
        ValueError: If ``scales`` has any other dtype. (Previously this
            surfaced as an ``UnboundLocalError`` on ``target_exponent``.)
    """
    fp8_exponent = 4
    if scales.dtype == torch.half:
        target_exponent = 5
    elif scales.dtype == torch.bfloat16:
        target_exponent = 8
    else:
        raise ValueError(
            "fp8_fused_exponent_bias_into_scales only supports torch.half "
            f"and torch.bfloat16 scales, got {scales.dtype}"
        )
    # exponent_bias_fp16 = 2 ** 4 - 2 ** 3 = 8
    # exponent_bias_bf16 = 2 ** 7 - 2 ** 3 = 120
    exponent_bias = 2 ** (target_exponent - 1) - 2 ** (fp8_exponent - 1)
    # The factor is an exact power of two that fits in both target dtypes,
    # so a scalar multiply gives the same values as building a tensor of 2s
    # and raising it to the power, without the temporary tensors. The float
    # conversion avoids passing a Python int wider than 64 bits (2 ** 120).
    return scales * float(2**exponent_bias)
def apply_fp8_marlin_linear(
    input: torch.Tensor,
    weight: torch.Tensor,
    weight_scale: torch.Tensor,
    workspace: torch.Tensor,
    size_n: int,
    size_k: int,
    bias: Optional[torch.Tensor],
    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
) -> torch.Tensor:
    """Run a weight-only FP8 linear layer through the Marlin GEMM kernel.

    Intended for GPUs that lack FP8 hardware support, where the Marlin
    kernel is leveraged for fast weight-only FP8 quantization.

    Args:
        input: Activations; every leading dimension is flattened into rows
            before the GEMM and restored afterwards.
        weight: Weight tensor forwarded to the kernel as ``b_q_weight``.
        weight_scale: Scales forwarded to the kernel as ``b_scales``.
        workspace: Workspace tensor forwarded to the kernel.
        size_n: Size of the output's last dimension.
        size_k: Reduction size passed to the kernel.
        bias: Optional bias forwarded to the kernel as ``b_bias``.
        use_fp32_reduce: Forwarded unchanged to the kernel.

    Returns:
        The kernel output reshaped to ``input.shape[:-1] + (size_n,)``.
    """
    # Collapse all leading dimensions so the kernel sees a 2-D matrix.
    x_2d = input.reshape(-1, input.shape[-1])
    num_rows = x_2d.size(0)

    atomic_add = should_use_atomic_add_reduce(
        m=num_rows,
        n=size_n,
        k=size_k,
        device=input.device,
        dtype=input.dtype,
    )

    result = gptq_marlin_gemm(
        a=x_2d,
        c=None,
        b_q_weight=weight,
        b_bias=bias,
        b_scales=weight_scale,
        # No global scale, zero points or activation reordering are passed.
        global_scale=None,
        b_zeros=None,
        g_idx=None,
        perm=None,
        workspace=workspace,
        b_q_type=scalar_types.float8_e4m3fn,
        size_m=num_rows,
        size_n=size_n,
        size_k=size_k,
        use_atomic_add=atomic_add,
        use_fp32_reduce=use_fp32_reduce,
    )

    # Restore the caller's leading dimensions.
    return result.reshape(input.shape[:-1] + (size_n,))
def prepare_fp8_layer_for_marlin(
    layer: torch.nn.Module, size_k_first: bool = True
) -> None:
    """Rewrite an FP8 linear layer in place into the layout Marlin expects.

    Reads ``layer.output_size_per_partition``,
    ``layer.input_size_per_partition``, ``layer.weight``, ``layer.orig_dtype``,
    the optional ``layer.weight_block_size`` and ``layer.bias``, and either
    ``layer.weight_scale`` or ``layer.weight_scale_inv``.

    Writes ``layer.workspace``, ``layer.weight`` (repacked), and
    ``layer.weight_scale`` (expanded, permuted, exponent bias fused in). If
    the scales came from ``layer.weight_scale_inv`` that attribute is
    deleted. A non-``None`` ``layer.bias`` is replaced by its permuted form.

    Args:
        layer: Module to convert; mutated in place.
        size_k_first: ``True`` if ``layer.weight`` has shape
            ``(size_k, size_n)``, ``False`` if it is ``(size_n, size_k)``.

    Raises:
        AttributeError: If the layer has neither ``weight_scale`` nor
            ``weight_scale_inv``.
        AssertionError: If the weight, scale or bias shapes do not match the
            partition sizes.
    """
    # A plain ``logging.Logger`` has no ``warning_once``; it exists only if
    # the logger was patched elsewhere. Fall back to ``warning`` so an
    # unpatched logger does not raise AttributeError.
    warn = getattr(logger, "warning_once", logger.warning)
    warn(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads."
    )

    part_size_n = layer.output_size_per_partition
    part_size_k = layer.input_size_per_partition
    weight_block_size = getattr(layer, "weight_block_size", None)

    if size_k_first:
        assert layer.weight.shape == (part_size_k, part_size_n)
    else:
        assert layer.weight.shape == (part_size_n, part_size_k)

    device = layer.weight.device

    # WORKSPACE
    layer.workspace = marlin_make_workspace(device)

    # WEIGHT
    # Repack weights to marlin format
    perm = torch.empty(0, dtype=torch.int, device=device)
    qweight = pack_fp8_to_int32(layer.weight, size_k_first)
    if not size_k_first:
        qweight = qweight.T.contiguous()

    marlin_qweight = gptq_marlin_repack(
        b_q_weight=qweight,
        perm=perm,
        size_k=part_size_k,
        size_n=part_size_n,
        num_bits=8,
    )
    layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False)

    # WEIGHT SCALES
    # Permute scales
    if hasattr(layer, "weight_scale"):
        scales = layer.weight_scale.to(layer.orig_dtype)
    elif hasattr(layer, "weight_scale_inv"):
        scales = layer.weight_scale_inv.to(layer.orig_dtype)
        del layer.weight_scale_inv
    else:
        # Previously this fell through and failed later on an unbound name.
        raise AttributeError(
            "layer has neither 'weight_scale' nor 'weight_scale_inv'"
        )

    group_size = -1 if weight_block_size is None else weight_block_size[1]

    # marlin kernel only support channel-wise and group-wise quantization
    # we need to convert the scales
    if weight_block_size is None:
        num_scales = scales.nelement()
        if num_scales == 1:
            # tensor-wise quantization -> channel-wise quantization
            # (1, 1) =>(repeat)=> (1, size_n)
            scales = scales.view(1, 1).repeat_interleave(part_size_n, 1)
        elif num_scales > 1 and num_scales != part_size_n:
            assert part_size_n % num_scales == 0
            # tensor-wise quantization (for gate-up proj)
            #     -> channel-wise quantization
            # (1, s_size) =>(repeat)=> (1, size_n)
            scales = scales.view(1, num_scales)
            scales = scales.repeat_interleave(part_size_n // num_scales, 1)
        else:
            # channel-wise quantization
            # (1, size_n)
            scales = scales.view(1, part_size_n)
    else:
        # block-wise quantization -> group-wise quantization
        # (size_k // block_size[1], ceil(size_n / block_size[0]))
        #  =>(repeat)=> (size_k // block_size[1], size_n)
        if not size_k_first:
            scales = scales.T.contiguous()
        block_n = weight_block_size[0]
        scales = scales.repeat_interleave(block_n, 1)
        # size_n may not divisible by block_size[0]
        scales = scales[:, :part_size_n]

    marlin_scales = marlin_permute_scales(
        s=scales, size_k=part_size_k, size_n=part_size_n, group_size=group_size
    )
    marlin_scales = fp8_fused_exponent_bias_into_scales(marlin_scales)
    layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False)

    # BIAS
    if hasattr(layer, "bias") and layer.bias is not None:
        assert layer.bias.shape == (part_size_n,)
        bias = marlin_permute_bias(layer.bias)
        layer.bias = torch.nn.Parameter(bias, requires_grad=False)
def prepare_moe_fp8_layer_for_marlin(
    layer: torch.nn.Module, size_k_first: bool = True
) -> None:
    """Rewrite an FP8 MoE layer in place into the layout Marlin expects.

    Reads ``layer.num_experts`` (``e``), ``layer.hidden_size`` (``k``),
    ``layer.intermediate_size_per_partition`` (``n``), ``layer.orig_dtype``,
    the optional ``layer.weight_block_size``, the expert weights
    ``layer.w13_weight`` / ``layer.w2_weight``, and for each of ``w13`` and
    ``w2`` either ``<name>_weight_scale`` or ``<name>_weight_scale_inv``.

    Writes ``layer.workspace``, the repacked ``w13_weight`` / ``w2_weight``,
    and ``w13_weight_scale`` / ``w2_weight_scale``. The scale attribute that
    was read is deleted before the new one is set. Non-``None`` ``w13_bias``
    / ``w2_bias`` attributes are replaced by their per-expert permuted form.

    Args:
        layer: Module to convert; mutated in place.
        size_k_first: ``True`` if each expert weight is ``(size_k, size_n)``,
            ``False`` if it is ``(size_n, size_k)``.

    Raises:
        AttributeError: If a projection has neither scale attribute.
        AssertionError: If weight or scale shapes do not match the expected
            sizes.
    """
    # A plain ``logging.Logger`` has no ``warning_once``; it exists only if
    # the logger was patched elsewhere. Fall back to ``warning`` so an
    # unpatched logger does not raise AttributeError.
    warn = getattr(logger, "warning_once", logger.warning)
    warn(
        "Your GPU does not have native support for FP8 computation but "
        "FP8 quantization is being used. Weight-only FP8 compression will "
        "be used leveraging the Marlin kernel. This may degrade "
        "performance for compute-heavy workloads."
    )

    e = layer.num_experts
    k = layer.hidden_size
    n = layer.intermediate_size_per_partition
    weight_block_size = getattr(layer, "weight_block_size", None)

    # (size_n, size_k) of each projection's per-expert matrix.
    proj_sizes = {"w13": (n * 2, k), "w2": (k, n)}

    # WORKSPACE
    device = layer.w13_weight.device
    layer.workspace = marlin_make_workspace(device, 4)
    perm = torch.empty(0, dtype=torch.int, device=device)

    # WEIGHT
    # Repack weights to marlin format
    for name in ["w13", "w2"]:
        weight_name = name + "_weight"
        weight = getattr(layer, weight_name)
        size_n, size_k = proj_sizes[name]

        if size_k_first:
            assert weight.shape == (e, size_k, size_n)
        else:
            assert weight.shape == (e, size_n, size_k)

        tensor_list = []
        for i in range(e):
            qweight = pack_fp8_to_int32(weight[i], size_k_first)
            if not size_k_first:
                qweight = qweight.T.contiguous()

            marlin_qweight = gptq_marlin_repack(
                b_q_weight=qweight,
                perm=perm,
                size_k=size_k,
                size_n=size_n,
                num_bits=8,
            )
            tensor_list.append(marlin_qweight)

        weight = torch.nn.Parameter(torch.stack(tensor_list, 0), requires_grad=False)
        setattr(layer, weight_name, weight)

    # WEIGHT SCALES
    # Permute scales
    group_size = -1 if weight_block_size is None else weight_block_size[1]

    for name in ["w13", "w2"]:
        # Prefer ``<name>_weight_scale`` and fall back to the ``_inv`` variant.
        for suffix in ("_weight_scale", "_weight_scale_inv"):
            if hasattr(layer, name + suffix):
                scale_name = name + suffix
                break
        else:
            # Without this, a missing ``w2`` scale silently reused the
            # already-converted ``w13`` scales left over from the previous
            # iteration (and a missing ``w13`` scale hit an unbound name).
            raise AttributeError(
                f"layer has neither '{name}_weight_scale' nor "
                f"'{name}_weight_scale_inv'"
            )
        scales = getattr(layer, scale_name).to(layer.orig_dtype)
        delattr(layer, scale_name)

        size_n, size_k = proj_sizes[name]

        # marlin kernel only support channel-wise and group-wise quantization
        # we need to convert the scales
        if weight_block_size is None:
            num_scales = scales.nelement()
            if num_scales == e:
                # tensor-wise quantization -> channel-wise quantization
                # (e, 1, 1) =>(repeat)=> (e, 1, size_n)
                scales = scales.view(e, 1, 1).repeat_interleave(size_n, 2)
            elif num_scales > e and num_scales != e * size_n:
                assert (e * size_n) % num_scales == 0
                s_size = num_scales // e
                # tensor-wise quantization (for gate-up proj)
                #     -> channel-wise quantization
                # (e, 1, s_size) =>(repeat)=> (e, 1, size_n)
                scales = scales.view(e, 1, s_size)
                scales = scales.repeat_interleave(size_n // s_size, 2)
            else:
                # channel-wise quantization
                # (e, 1, size_n)
                scales = scales.view(e, 1, size_n)
        else:
            # block-wise quantization -> group-wise quantization
            # (e, size_k // block_size[1], ceil(size_n / block_size[0]))
            #  =>(repeat)=> (e, size_k // block_size[1], size_n)
            if not size_k_first:
                scales = scales.permute(0, 2, 1)
            block_n = weight_block_size[0]
            scales = scales.repeat_interleave(block_n, 2)
            # size_n may not divisible by block_size[0]
            scales = scales[..., :size_n].contiguous()

        tensor_list = [
            marlin_permute_scales(
                s=scales[i], size_k=size_k, size_n=size_n, group_size=group_size
            )
            for i in range(e)
        ]
        scales = torch.stack(tensor_list, 0)
        scales = fp8_fused_exponent_bias_into_scales(scales)
        scales = torch.nn.Parameter(scales, requires_grad=False)
        setattr(layer, name + "_weight_scale", scales)

    # BIAS
    # Permute bias
    for name in ["w13_bias", "w2_bias"]:
        bias = getattr(layer, name, None)
        # Skip a missing bias and also one that is present but None, the same
        # way the non-MoE variant checks ``layer.bias is not None``.
        if bias is None:
            continue
        bias = bias.to(layer.orig_dtype)

        tensor_list = [marlin_permute_bias(bias[i]) for i in range(e)]
        bias = torch.nn.Parameter(torch.stack(tensor_list, 0), requires_grad=False)
        setattr(layer, name, bias)
def pack_fp8_to_int32(
fp8_tensor: torch.Tensor, size_k_first: bool = True
) -> torch.Tensor:
"""
Repack FP8 weights to gptq format (packed int32 elements)
"""
assert fp8_tensor.dtype == torch.float8_e4m3fn
assert fp8_tensor.ndim == 2
fp8_tensor = fp8_tensor.T if size_k_first else fp8_tensor
fp8_tensor = fp8_tensor.contiguous()
# fp8_tensor is contiguous and have shape (N, K) now
# with `.view(torch.int32)`, it become (N, K // 4)
int32_tensor = fp8_tensor.view(torch.int32)
return int32_tensor.T.contiguous() if size_k_first else int32_tensor
def marlin_quant_fp8_torch(weight, group_size):
    """Quantize a weight matrix to FP8 and produce Marlin-format tensors.

    Args:
        weight: 2-D tensor of shape ``(size_n, size_k)``.
        group_size: Number of consecutive K elements sharing one scale, or
            ``-1`` for a single scale per row.

    Returns:
        A tuple ``(weight_ref, marlin_qweight, marlin_scales)``:
        ``weight_ref`` is the dequantized weight transposed to
        ``(size_k, size_n)``; ``marlin_qweight`` is the output of
        ``gptq_marlin_repack``; ``marlin_scales`` are the permuted scales
        with the FP8 exponent bias fused in.
    """
    size_n, size_k = weight.shape
    device = weight.device

    # -1 means one group spanning the whole K dimension of each row.
    span = size_k if group_size == -1 else group_size

    # Per-group absolute maximum, scaled by 448 (the largest finite
    # float8_e4m3fn value).
    group_absmax = weight.view(size_n, -1, span).abs().max(-1).values
    scales = group_absmax / 448

    # Expand the scales back to (size_n, size_k), then quantize/dequantize.
    expanded_scales = scales.repeat_interleave(span, 1)
    fp8_weight = (weight / expanded_scales).to(torch.float8_e4m3fn)
    weight_ref = fp8_weight.to(weight.dtype) * expanded_scales

    # fp8_weight is (N, K); pack along K, then hand a K-major tensor to repack.
    packed_weight = pack_fp8_to_int32(fp8_weight, False).T.contiguous()
    empty_perm = torch.empty(0, dtype=torch.int, device=device)
    marlin_qweight = gptq_marlin_repack(
        b_q_weight=packed_weight,
        perm=empty_perm,
        size_k=size_k,
        size_n=size_n,
        num_bits=8,
    )

    marlin_scales = fp8_fused_exponent_bias_into_scales(
        marlin_permute_scales(
            s=scales.T, size_k=size_k, size_n=size_n, group_size=group_size
        )
    )

    return weight_ref.T, marlin_qweight, marlin_scales

Xet Storage Details

Size:
12.5 kB
·
Xet hash:
609f0e23e1cd1ed45daefca23978b35e57908aaa034358a770881726e14546d5

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.