# CloverLM / vllm_plugin / quartet2_quant.py
# (Hugging Face page metadata, preserved as a comment so the module parses:
#  uploaded by mansaripo via huggingface_hub, commit 954e44f verified, 3.99 kB)
import torch
import torch.nn.functional as F
from torch.nn import Parameter
from vllm.model_executor.layers.quantization import (
register_quantization_config,
)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.parameter import ModelWeightParameter
@register_quantization_config("quartet2")
class QuartetIIConfig(QuantizationConfig):
    """vLLM quantization config for the Quartet II (FP4) scheme.

    The scheme is parameter-free: no per-model config files are read and
    ``from_config`` ignores the checkpoint's quantization dict.
    """

    def get_name(self) -> str:
        """Identifier under which this scheme is registered with vLLM."""
        return "quartet2"

    def get_supported_act_dtypes(self) -> list:
        """Activation dtypes this scheme accepts (bfloat16 only)."""
        return [torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        """Minimum CUDA compute capability: Blackwell (SM 10.0)."""
        return 100

    @staticmethod
    def get_config_filenames() -> list[str]:
        """No extra on-disk config files are required."""
        return []

    @classmethod
    def from_config(cls, config: dict) -> "QuartetIIConfig":
        """Build a config instance; the checkpoint dict carries no options."""
        return cls()

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> QuantizeMethodBase | None:
        """Quantize linear layers only; every other module is left untouched."""
        if not isinstance(layer, LinearBase):
            return None
        return QuartetIILinearMethod(self)
class QuartetIILinearMethod(LinearMethodBase):
    """Linear-layer method that quantizes both activations and weights to
    FP4 (Quartet II / NVFP4 "four-six" mode) and runs a fused FP4 matmul.

    NOTE(review): weights are re-quantized on every forward call, and
    ``layer.had`` is attached but not referenced in ``apply`` — presumably
    both are consumed inside the ``quartet2`` kernels; verify upstream.
    """

    def __init__(self, config: QuartetIIConfig):
        self.config = config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Allocate the full-precision weight parameter for this partition.

        The weight is stored unquantized here; quantization happens on the
        fly in ``apply``.
        """
        total_out_size = sum(output_partition_sizes)
        weight = ModelWeightParameter(
            data=torch.empty(
                total_out_size, input_size_per_partition, dtype=params_dtype,
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=extra_weight_attrs.get("weight_loader"),
        )
        layer.register_parameter("weight", weight)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Attach per-layer auxiliary tensors once weights are loaded.

        Adds an orthonormal 128x128 Hadamard matrix (``H / sqrt(128)``) and
        a scalar uint32 scratch slot for amax reductions.
        """
        from scipy.linalg import hadamard as scipy_hadamard

        device = layer.weight.device
        layer.had = torch.tensor(
            scipy_hadamard(128) * 128 ** -0.5,
            dtype=torch.bfloat16,
            device=device,
            requires_grad=False,
        )
        layer.scratch_amax = torch.empty((), dtype=torch.uint32, device=device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Quantize ``x`` and ``layer.weight`` to FP4 and multiply them.

        Flattens leading dims of ``x``, zero-pads the row count up to a
        multiple of 128 (a Quartet II kernel requirement), runs the FP4
        matmul, then strips the padding and restores the original shape.
        """
        from quartet2.quant import quant_fp4, NVFP4QuantMode
        from quartet2.linear import abs_max, _fp4_mm

        weight = layer.weight
        leading_shape = x.shape[:-1]
        x2d = x.reshape(-1, x.shape[-1])

        # Rows must be a multiple of 128; (-rows) % 128 is the pad amount
        # (zero when already aligned).
        rows = x2d.shape[0]
        pad_rows = (-rows) % 128
        if pad_rows:
            x2d = F.pad(x2d, (0, 0, 0, pad_rows))

        act_amax = abs_max(x2d)
        wgt_amax = abs_max(weight)

        mode = NVFP4QuantMode.FOUR_SIX
        act_q = quant_fp4(
            x2d, amax=act_amax, scale_override=1.0, mode=mode,
        )
        wgt_q = quant_fp4(
            weight, amax=wgt_amax, scale_override=1.0, mode=mode,
        )

        out = _fp4_mm(
            act_q.fp4, wgt_q.fp4,
            act_q.micro_scales, wgt_q.micro_scales,
            act_q.tensor_scale * wgt_q.tensor_scale,
        )

        # Drop padding rows, then restore the caller's leading dimensions.
        if pad_rows:
            out = out[:rows]
        out = out.reshape(*leading_shape, out.shape[-1])
        return out if bias is None else out + bias