diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b1c45d3a183bdadbececfad8ed2865e48d46569 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/aqlm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/aqlm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..595c2aeed7ce70e3de6137e5432e35b03f8be40e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/aqlm.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/awq.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/awq.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9faea7b2b781bf0d087fb2715bfb8c515791f81b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/awq.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/awq_marlin.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/awq_marlin.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b36eba5290baa5edd2844002a01b296e65dd0585 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/awq_marlin.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/awq_triton.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/awq_triton.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d160ef1796bd1cd86e0b4270609941fa30362f4 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/awq_triton.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/base_config.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/base_config.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f6cff88f00c7e85e8e7e22baf4cb30cbfeaaa01 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/base_config.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/bitsandbytes.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/bitsandbytes.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3eadaa158e70cb2def8f4f7d906aeb5d9db6b41 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/bitsandbytes.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/deepspeedfp.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/deepspeedfp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9a1480c6a7381317457c9707c3ee0015e3391e7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/deepspeedfp.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/experts_int8.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/experts_int8.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3edbf30104b9b595b9ced05986c42dffb9d9bdf0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/experts_int8.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/fbgemm_fp8.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/fbgemm_fp8.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed3463e2124bc382013f954986375751039e0f9c Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/fbgemm_fp8.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/fp8.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/fp8.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d532899dbe4f2fd4e15675eb2c6058cb6098d6b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/fp8.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/gguf.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/gguf.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65712af20c96314ae90d46931dcb87f6169d1ad2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/gguf.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/gptq.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/gptq.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..811abff5ea43d51d28159f85d2417d547a0b5468 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/gptq.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/gptq_marlin.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/gptq_marlin.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fe1cf37e781ef561b0600d8e4b0de19406244cd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/gptq_marlin.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/gptq_marlin_24.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/gptq_marlin_24.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9af60ce65649ebee97aaf08a7162974bdec0ea84 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/gptq_marlin_24.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/hqq_marlin.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/hqq_marlin.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cdefaf96f6e5be3f4aa5bdbb53aee27872d28bf6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/hqq_marlin.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/ipex_quant.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/ipex_quant.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c21a583121d29d58c5ed092fc242235e53e8158 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/ipex_quant.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/kv_cache.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/kv_cache.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..826df110fe2ca01fe5d7a9b3a50950ea08796a95 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/kv_cache.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/marlin.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/marlin.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..515792f4c479d8f381156f0f04b071d8dff623e3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/marlin.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/modelopt.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/modelopt.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e3c6fed30b890462e9d0e04b5f0ffa99ecc065f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/modelopt.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/moe_wna16.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/moe_wna16.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25e428eeee47a11973facf70fb31c2e220cb46ba Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/moe_wna16.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/neuron_quant.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/neuron_quant.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52ccb368023c92426c5a6bd7814978c9de7aa2b3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/neuron_quant.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/qqq.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/qqq.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef333a9121f4cabb7f9dec22e5a2db03b412a453 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/qqq.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/schema.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/schema.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfd2d39805f6d9647ca71490549ee5232a82cac2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/schema.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/tpu_int8.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/tpu_int8.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4987954d3af4c1900d146c9985dd76166f75ea66 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/__pycache__/tpu_int8.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3190cba5e47ebdfd21dcbfd9108627d4e85a279e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py new file mode 100644 index 0000000000000000000000000000000000000000..c06befaf3b5ad877a9949741916a829138ed90ef --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py @@ -0,0 +1,89 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Optional, Tuple + +import torch + +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.scalar_type import ScalarType + + +@dataclass +class MPLinearLayerConfig: + full_weight_shape: Tuple[int, int] # [in, out] + partition_weight_shape: Tuple[int, int] + weight_type: ScalarType + act_type: torch.dtype + group_size: int + zero_points: bool + has_g_idx: bool + + +class 
MPLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, + c: MPLinearLayerConfig, + w_q_param_name: str, + w_s_param_name: str, + w_zp_param_name: Optional[str] = None, + w_gidx_param_name: Optional[str] = None) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + if c.zero_points: + assert w_zp_param_name is not None + if c.has_g_idx: + assert w_gidx_param_name is not None + self.w_zp_name = w_zp_param_name + self.w_gidx_name = w_gidx_param_name + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _transform_param(self, layer: torch.nn.Module, name: Optional[str], + fn: Callable) -> None: + if name is not None and getattr(layer, name, None) is not None: + + old_param = getattr(layer, name) + new_param = fn(old_param) + # replace the parameter with torch.nn.Parameter for TorchDynamo + # compatibility + replace_parameter( + layer, name, + torch.nn.Parameter(new_param.data, requires_grad=False)) + + def _get_weight_params( + self, layer: torch.nn.Module) -> Tuple[ + torch.Tensor, # w_q + torch.Tensor, # w_s + Optional[torch.Tensor], # w_zp, + Optional[torch.Tensor] # w_gidx + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.w_zp_name or "", None), + getattr(layer, self.w_gidx_name or "", None), + ) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bcfdb1677716656d1de72159bddeba967cee5aed --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py @@ -0,0 +1,76 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Type + +import vllm.envs as envs +from vllm.model_executor.layers.quantization.kernels.mixed_precision.exllama import ( # noqa: E501 + ExllamaLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.machete import ( # noqa: E501 + MacheteLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.marlin import ( # noqa: E501 + MarlinLinearKernel) +from vllm.model_executor.layers.quantization.kernels.mixed_precision.MPLinearKernel import ( # noqa: E501 + MPLinearKernel, MPLinearLayerConfig) +from vllm.platforms import current_platform + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [ + MacheteLinearKernel, + MarlinLinearKernel, + ExllamaLinearKernel, +] + + +def choose_mp_linear_kernel( + config: MPLinearLayerConfig, + compute_capability: Optional[int] = None) -> Type[MPLinearKernel]: + """ + Choose an MPLinearKernel that can implement the given config for the given + compute capability. Attempts to choose the best kernel in terms of + performance. 
+ + Args: + config (MPLinearLayerConfig): Description of the linear layer to be + implemented. + compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the compute + capability. Defaults to None. + + Raises: + ValueError: If no kernel can implement the given config. + + Returns: + Type[MPLinearKernel]: Chosen kernel. + """ + if compute_capability is None: + if current_platform is None: + raise ValueError("Cannot determine compute capability") + _cc = current_platform.get_device_capability() + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS: + if kernel.__name__ in envs.VLLM_DISABLED_KERNELS: + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + if kernel.get_min_capability() > compute_capability: + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel.get_min_capability()}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "WNA16 linear layer. Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/MPLinearKernel.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/MPLinearKernel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1260e660720f7a07dec20956a8da145064bc196 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/MPLinearKernel.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c80945af01eab7e247d7d4ae678da379e5a1606 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/exllama.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/exllama.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7938e119095d0406066f49b85b186b2eed7e46a0 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/exllama.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/machete.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/machete.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..322795e31fa9a2668336aae8ef8b4147432bfc25 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/machete.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/marlin.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/marlin.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8345419583a1f54c19d05539d061c3801349970f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/__pycache__/marlin.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py new file mode 100644 index 0000000000000000000000000000000000000000..2706fbb539ab4e7d6c54526b5f51f930b57841a5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_quantized_values_into_int32) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) +from vllm.scalar_type import scalar_types + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class ExllamaLinearKernel(MPLinearKernel): + SUPPORTED_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] + # In theory supports `scalar_types.uint2b2, scalar_types.uint3b4` too but + # currently untested so not added to the list + + @classmethod + def get_min_capability(cls) -> int: + return 60 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx and\ + c.partition_weight_shape[0] != c.full_weight_shape[0]: + return False, "Act reordering currently not supported by Exllama, "\ + "when the input features are partitioned across "\ + "devices" + + if c.partition_weight_shape[1] % (32 // c.weight_type.size_bits) != 0: + return False, "Output features must be a multiple of the pack " \ + "factor (32 / num_bits) so that we can correctly " \ + "pack the zero points" + + if c.act_type != torch.float16: + return False, "Exllama only supports float16 activations" + + if c.weight_type not in cls.SUPPORTED_QUANT_TYPES: + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Exllama, supported types are: "\ + f"{cls.SUPPORTED_QUANT_TYPES}" + + if c.full_weight_shape[0] % c.group_size != 0: + return False, f"Group size ({c.group_size}) does not evenly divide"\ + " the number of input features "\ + f"({c.full_weight_shape[0]})" + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + # For Exllama, we need to set a zero-point tensor if there is not one + if not c.zero_points: + self.w_zp_name = "qzeros" + device = getattr(layer, self.w_q_name).device + groups = c.partition_weight_shape[0] // c.group_size + out_features = c.partition_weight_shape[1] + + if c.weight_type.has_bias(): + # if the type has a bias we have to create a zeros tensor that + # contains the bias values repeated for each group (-1 due to + # a bug in the original GPTQ checkpoint 
format leading to + # exllama kernel adding 1 to the zero points during inference) + # Documentation of the bug can be found here: + # https://garden.danieldk.eu/GPTQ-Checkpoint-Format + zeros = torch.full((groups, out_features), + c.weight_type.bias - 1, + dtype=torch.int32, + device=device) + else: + raise NotImplementedError( + "A 0 zero-point is not supported by Exllama due to " + "a bug in the original GPTQ checkpoint format leading to " + "exllama kernel adding 1 to the zero points during " + "inference") + zeros = pack_quantized_values_into_int32(zeros, + c.weight_type, + packed_dim=1) + setattr(layer, self.w_zp_name, + torch.nn.Parameter(zeros, requires_grad=False)) + + if c.has_g_idx: + + def transform_w_g_idx(x): + # Exllama wants the permutation array instead of the group + # indices + return torch.argsort(x).to(torch.int) + + self._transform_param(layer, self.w_gidx_name, transform_w_g_idx) + else: + self.w_gidx_name = "g_idx" + empty_g_idx = torch.nn.Parameter(torch.empty((0, ), + dtype=torch.int, + device=device), + requires_grad=False) + setattr(layer, self.w_gidx_name, empty_g_idx) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + assert self.w_gidx_name is not None + g_idx = getattr(layer, self.w_gidx_name) + + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x_cont = x.data.contiguous() + ops.gptq_shuffle(x_cont, g_idx, c.weight_type.size_bits) + return x_cont + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x.to(dtype=c.act_type) + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer) + + assert w_zp is not None, "Zero points are required by Exllama" + assert w_g_idx is not None, "Group index is required by Exllama" + output = ops.gptq_gemm(x_2d, w_q, w_zp, w_s, w_g_idx, True, + c.weight_type.size_bits) + + if bias is not None: + output.add_(bias) + return output.reshape(out_shape) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0586f6e30d6a02f95a36196b88f96d16a1a15e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 + +from functools import partial +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.machete_utils import ( + MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape, + query_machete_supported_quant_types) +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + pack_quantized_values_into_int32, unpack_quantized_values_into_int32) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class 
MacheteLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 90 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.has_g_idx and\ + c.partition_weight_shape[0] != c.full_weight_shape[0]: + return False, "Act reordering currently not supported by Machete, "\ + "when the input features are partitioned across "\ + "devices" + + if c.zero_points: + return False, "Zero points currently not supported by "\ + " Compressed Tensors + Machete. (Kernel supports it"\ + " but CompressedTensorsWNA16 does not so support has"\ + " not been added to MacheteWNA16Kernel yet" + + if c.weight_type not in query_machete_supported_quant_types( + c.zero_points): + return False, f"Quant type ({c.weight_type}) not supported by "\ + "Machete, supported types are: "\ + f"{query_machete_supported_quant_types(c.zero_points)}" + + if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Machete, supported group sizes are: "\ + f"{MACHETE_SUPPORTED_GROUP_SIZES}" + + return check_machete_supports_shape(c.partition_weight_shape[0], + c.partition_weight_shape[1]) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module): + c = self.config + + if c.has_g_idx: + assert self.w_gidx_name is not None + perm = torch.argsort(getattr(layer, self.w_gidx_name))\ + .to(torch.int) + + self.act_perm = lambda x: x[:, perm] + # use `ops.permute_cols` if possible + if c.act_type in [torch.float16, torch.bfloat16] \ + and c.partition_weight_shape[0] % 8 == 0: + self.act_perm = partial(ops.permute_cols, perm=perm) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + if c.has_g_idx: + x_unpacked = unpack_quantized_values_into_int32(x.data, + c.weight_type, + packed_dim=0) + x_perm = x_unpacked[perm, :] + x.data = pack_quantized_values_into_int32(x_perm, + c.weight_type, + packed_dim=0) + x.data = ops.machete_prepack_B(x.data.t().contiguous().t(), + a_type=c.act_type, + b_type=c.weight_type, + group_scales_type=c.act_type) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = x.data.contiguous() + return x + + # Repack weights and scales for Machete + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, _, _ = self._get_weight_params(layer) + + x_2d = x.reshape(-1, x.shape[-1]) + out_shape = x.shape[:-1] + (c.partition_weight_shape[1], ) + + if c.has_g_idx: + x_2d = self.act_perm(x_2d) + + output = ops.machete_mm(a=x_2d, + b_q=w_q, + b_type=c.weight_type, + b_group_zeros=None, + b_group_scales=w_s, + b_group_size=c.group_size) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py new file mode 100644 index 
0000000000000000000000000000000000000000..e21801cf6a7857ae700c30f2f0d15484993e2044 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py @@ -0,0 +1,135 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils.marlin_utils import ( + MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear, + check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx, + marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx, + query_marlin_supported_quant_types) +from vllm.model_executor.parameter import (BasevLLMParameter, + permute_param_layout_) + +from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig + + +class MarlinLinearKernel(MPLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def can_implement(cls, + c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if c.zero_points: + return False, "Zero points currently not supported by "\ + " MarlinLinearKernel. Will be added when AWQMarlin "\ + "is migrated over to using MPLinearKernel backend" + + quant_types = query_marlin_supported_quant_types(c.zero_points) + if c.weight_type not in quant_types: + return False, f"Quant type ({c.weight_type}) not supported by"\ + f" Marlin, supported types are: {quant_types}" + + if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES: + return False, f"Group size ({c.group_size}) not supported by "\ + "Marlin, supported group sizes are: "\ + f"{MARLIN_SUPPORTED_GROUP_SIZES}" + + return check_marlin_supports_shape( + c.partition_weight_shape[1], # out_features + c.partition_weight_shape[0], # in_features + c.full_weight_shape[0], # in_features + c.group_size) + + # note assumes that + # `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0} + # `weight_scale` is: {input_dim = 0, output_dim = 1} + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + device = getattr(layer, self.w_q_name).device + c = self.config + + row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0]) + self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel) + + # Allocate marlin workspace. 
+ self.workspace = marlin_make_workspace(c.partition_weight_shape[1], + device) + + # Default names since marlin requires empty parameters for these, + # TODO: remove this requirement from marlin (allow optional tensors) + if self.w_gidx_name is None: + self.w_gidx_name = "g_idx" + if self.w_zp_name is None: + self.w_zp_name = "w_zp" + + if c.has_g_idx: + g_idx, g_idx_sort_indices = marlin_sort_g_idx( + getattr(layer, self.w_gidx_name)) + self._transform_param(layer, self.w_gidx_name, lambda _: g_idx) + layer.g_idx_sort_indices = g_idx_sort_indices + else: + setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device)) + layer.g_idx_sort_indices = marlin_make_empty_g_idx(device) + + if c.zero_points: + pass + # TODO (lucas): add the following when AWQMarlin is migrated over to + # using MPLinearKernel backend + # self._transform_param(layer, self.w_zp_name, lambda x: \ + # marlin_zero_points( + # x, + # size_k=c.partition_weight_shape[0], + # size_n=c.partition_weight_shape[1], + # num_bits=c.weight_type.size_bits)) + else: + setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device)) + + def transform_w_q(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0) + x.data = ops.gptq_marlin_repack(x.data.contiguous(), + perm=layer.g_idx_sort_indices, + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + num_bits=c.weight_type.size_bits) + return x + + def transform_w_s(x): + assert isinstance(x, BasevLLMParameter) + permute_param_layout_(x, input_dim=0, output_dim=1) + x.data = marlin_permute_scales(x.data.contiguous(), + size_k=c.partition_weight_shape[0], + size_n=c.partition_weight_shape[1], + group_size=c.group_size) + return x + + self._transform_param(layer, self.w_q_name, transform_w_q) + self._transform_param(layer, self.w_s_name, transform_w_s) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + c = self.config + w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer) + + # `process_weights_after_loading` will ensure w_zp and w_gidx are not + # None for marlin + return apply_gptq_marlin_linear( + input=x, + weight=w_q, + weight_scale=w_s, + weight_zp=w_zp, # type: ignore + g_idx=w_gidx, # type: ignore + g_idx_sort_indices=layer.g_idx_sort_indices, + workspace=self.workspace, + wtype=c.weight_type, + input_size_per_partition=c.partition_weight_shape[0], + output_size_per_partition=c.partition_weight_shape[1], + is_k_full=self.is_k_full, + bias=bias) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py new file mode 100644 index 0000000000000000000000000000000000000000..91e7654053f9d1ad1b8a5539f41a4d30a36a891e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + + +@dataclass +class ScaledMMLinearLayerConfig: + is_channelwise: bool + is_static_input_scheme: bool + input_symmetric: bool + + +class ScaledMMLinearKernel(ABC): + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + raise NotImplementedError + + @classmethod + @abstractmethod + def 
can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + raise NotImplementedError + + def __init__(self, c: ScaledMMLinearLayerConfig, w_q_param_name: str, + w_s_param_name: str, i_s_param_name: str, + i_zp_param_name: str, azp_adj_param_name: str) -> None: + assert self.can_implement(c) + self.config = c + self.w_q_name = w_q_param_name + self.w_s_name = w_s_param_name + self.i_s_name = i_s_param_name + self.i_zp_name = i_zp_param_name + self.azp_adj_name = azp_adj_param_name + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + raise NotImplementedError + + def _get_weight_params( + self, layer: torch.nn.Module) -> Tuple[ + torch.Tensor, # weight + torch.Tensor, # weight_scale + Optional[torch.Tensor], # input_scale, + Optional[torch.Tensor], # input_zp + Optional[torch.Tensor], # azp_adj + ]: + return ( + getattr(layer, self.w_q_name), + getattr(layer, self.w_s_name), + getattr(layer, self.i_s_name), + getattr(layer, self.i_zp_name), + getattr(layer, self.azp_adj_name), + ) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a5967995ac88d8e18d69adff5910fe671492a1a2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py @@ -0,0 +1,84 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os +from typing import Dict, List, Optional, Type + +from vllm.model_executor.layers.quantization.kernels.scaled_mm.cutlass import ( + CutlassScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.ScaledMMLinearKernel import ( # noqa: E501 + ScaledMMLinearKernel, ScaledMMLinearLayerConfig) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.triton import ( + TritonScaledMMLinearKernel) +from vllm.model_executor.layers.quantization.kernels.scaled_mm.xla import ( + XLAScaledMMLinearKernel) +from vllm.platforms import PlatformEnum, current_platform + +# in priority/performance order (when available) +_POSSIBLE_KERNELS: Dict[PlatformEnum, List[Type[ScaledMMLinearKernel]]] = { + PlatformEnum.CPU: [CutlassScaledMMLinearKernel], + PlatformEnum.CUDA: [CutlassScaledMMLinearKernel], + PlatformEnum.ROCM: [TritonScaledMMLinearKernel], + PlatformEnum.TPU: [XLAScaledMMLinearKernel], +} + + +def choose_scaled_mm_linear_kernel( + config: ScaledMMLinearLayerConfig, + compute_capability: Optional[int] = None +) -> Type[ScaledMMLinearKernel]: + """ + Choose an ScalledMMLinearKernel that can implement the given config for the + given compute capability. Attempts to choose the best kernel in terms of + performance. + + Args: + config (ScaledMMLinearLayerConfig): Description of the linear layer + to be implemented. + compute_capability (Optional[int], optional): The compute capability of + the target device, if None uses `current_platform` to get the + compute capability. Defaults to None. + + Raises: + ValueError: If no kernel can implement the given config. + + Returns: + Type[ScaledMMLinearKernel]: Chosen kernel. 
+ """ + + if compute_capability is None: + _cc = current_platform.get_device_capability() + if _cc is not None: + compute_capability = _cc[0] * 10 + _cc[1] + + failure_reasons = [] + for kernel in _POSSIBLE_KERNELS[current_platform._enum]: + if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\ + .split(","): + failure_reasons.append( + f' {kernel.__name__} disabled by environment variable') + continue + + # If the current platform uses compute_capability, + # make sure the kernel supports the compute cability. + if compute_capability is not None: + kernel_min_capability = kernel.get_min_capability() + if (kernel_min_capability is not None + and kernel_min_capability > compute_capability): + failure_reasons.append( + f"{kernel.__name__} requires capability " + f"{kernel_min_capability}, current compute capability " + f"is {compute_capability}") + continue + + can_implement, failure_reason = kernel.can_implement(config) + if can_implement: + return kernel + else: + failure_reasons.append( + f' {kernel.__name__} cannot implement due to: {failure_reason}' + ) + + raise ValueError( + "Failed to find a kernel that can implement the "\ + "ScaledMM linear layer. Reasons: \n" + + '\n'.join(failure_reasons)) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/ScaledMMLinearKernel.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/ScaledMMLinearKernel.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1093e997dc9b286468344ed4a05b59f85dc44d91 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/ScaledMMLinearKernel.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7391729f5f48c6c44d04cc31ef18cfb93e4a2800 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/cutlass.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/cutlass.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b90f63bd97e6feeae15f6b65972570fed1528c41 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/cutlass.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/triton.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/triton.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1be369045d9774a6e5cd61e18bd5833f4e86dc9b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/triton.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/xla.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/xla.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d93212ba713b51dc957964f9f4531e36bf5795cc Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/__pycache__/xla.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py new file mode 100644 index 0000000000000000000000000000000000000000..2bf21a05c46d9ef614ec8b25eeed75038e579fb5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.platforms import current_platform + +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, + ScaledMMLinearLayerConfig) + + +class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + if (not current_platform.is_cuda() and not current_platform.is_cpu()): + return False, "CutlassScaledMM requires running on CUDA or CPU." + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # WEIGHT + # Cutlass kernels need transposed weight. + weight = getattr(layer, self.w_q_name) + replace_parameter( + layer, self.w_q_name, + torch.nn.Parameter(weight.t().data, requires_grad=False)) + + # WEIGHT SCALE + # Cutlass kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. 
+ is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, self.w_s_name) + if is_fused_module and not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + # INPUT SCALE + if self.config.is_static_input_scheme: + input_scale = getattr(layer, self.i_s_name) + + if self.config.input_symmetric: + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(input_scale.max(), requires_grad=False)) + setattr(layer, self.i_zp_name, None) + else: + input_zero_point = getattr(layer, self.i_zp_name) + + # reconstruct the ranges + int8_traits = torch.iinfo(torch.int8) + azps = input_zero_point.to(dtype=torch.int32) + range_max = (input_scale * (int8_traits.max - azps)).max() + range_min = (input_scale * (int8_traits.min - azps)).min() + + scale = (range_max - range_min) / (int8_traits.max - + int8_traits.min) + replace_parameter( + layer, self.i_s_name, + torch.nn.Parameter(scale, requires_grad=False)) + + # AZP loaded as int8 but used as int32 + azp = (int8_traits.min - + range_min / scale).to(dtype=torch.int32) + replace_parameter(layer, self.i_zp_name, + torch.nn.Parameter(azp, requires_grad=False)) + + else: + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + + # azp_adj is the AZP adjustment term, used to account for weights. + # It does not depend on scales or azp, so it is the same for + # static and dynamic quantization. + # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md + # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md + if not self.config.input_symmetric: + weight = getattr(layer, self.w_q_name) + azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32) + if self.config.is_static_input_scheme: + # cutlass_w8a8 requires azp to be folded into azp_adj + # in the per-tensor case + azp_adj = getattr(layer, self.i_zp_name) * azp_adj + setattr(layer, self.azp_adj_name, + torch.nn.Parameter(azp_adj, requires_grad=False)) + else: + setattr(layer, self.azp_adj_name, None) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, i_s, i_zp, azp_adj = self._get_weight_params(layer) + + # ops.scaled_int8_quant supports both dynamic and static quant: + # * dynamic, i_s is None and x_s computed from x. + # * static, i_s is scalar and x_s is i_s. 
+ symmetric = azp_adj is None + x_q, x_s, x_zp = ops.scaled_int8_quant(x, + i_s, + i_zp, + symmetric=symmetric) + + if x_zp is not None: + # Currently, static is always per-tensor and dynamic is per-token + static = i_zp is not None + azp = None if static else x_zp + return ops.cutlass_scaled_mm_azp(x_q, + w_q, + scale_a=x_s, + scale_b=w_s, + out_dtype=x.dtype, + azp_adj=azp_adj, + azp=azp, + bias=bias) + return ops.cutlass_scaled_mm(x_q, + w_q, + scale_a=x_s, + scale_b=w_s, + out_dtype=x.dtype, + bias=bias) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py new file mode 100644 index 0000000000000000000000000000000000000000..5da5df8efaeb0e55e48871fa105ae437968a4d8d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py @@ -0,0 +1,40 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional, Tuple + +import torch + +from vllm.platforms import current_platform + +from .cutlass import CutlassScaledMMLinearKernel +from .ScaledMMLinearKernel import ScaledMMLinearLayerConfig + + +class TritonScaledMMLinearKernel(CutlassScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + return 75 + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + if current_platform.is_cpu(): + return ( + False, + "TritonScaledMMLinearKernel requires Triton which is not " + + "currently supported on CPU.") + if not c.input_symmetric: + return (False, + "TritonScaledMMLinearKernel only supports symmetric " + + "quantization.") + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + super().process_weights_after_loading(layer) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + return super().apply_weights(layer, x, bias) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py new file mode 100644 index 0000000000000000000000000000000000000000..0bf090d7fab3ca13d7d5109e7461633a0b736629 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py @@ -0,0 +1,103 @@ +# SPDX-License-Identifier: Apache-2.0 + +import warnings +from typing import Optional, Tuple + +import torch +from functorch.experimental.control_flow import cond # noqa: F401 + +from vllm.model_executor.layers.quantization.utils import replace_parameter +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + convert_to_channelwise) +from vllm.platforms import current_platform + +from .ScaledMMLinearKernel import (ScaledMMLinearKernel, + ScaledMMLinearLayerConfig) + + +class XLAScaledMMLinearKernel(ScaledMMLinearKernel): + + @classmethod + def get_min_capability(cls) -> int: + raise NotImplementedError( + "TPU platform does have a concept of compute capability, " + "this method should not be called.") + + @classmethod + def can_implement( + cls, c: ScaledMMLinearLayerConfig) -> Tuple[bool, Optional[str]]: + + if not current_platform.is_tpu(): + return False, "ScaledMMXLA requires running on TPU." 
+ + if c.is_static_input_scheme: + return False, "ScaledMMXLA requires dynamic activation scales." + + if not c.input_symmetric: + return False, "ScaledMMXLA requires symmetric activation scales." + + if not c.is_channelwise: + return False, "ScaledMMXLA requires channelwise weight scales" + + return True, None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # WEIGHT + # [out, in] (different than cutlass_scaled_mm) + weight = getattr(layer, self.w_q_name) + replace_parameter(layer, self.w_q_name, + torch.nn.Parameter(weight.data, requires_grad=False)) + + # WEIGHT SCALE + # XLA kernels support only per-tensor and per-channel. + # If we have a fused module (QKV, MLP) with per tensor scales (thus N + # scales being passed to the kernel), convert to the per-channel case. + is_fused_module = len(layer.logical_widths) > 1 + weight_scale = getattr(layer, self.w_s_name) + if is_fused_module and not self.config.is_channelwise: + weight_scale = convert_to_channelwise(weight_scale, + layer.logical_widths) + + # [out_channel,] (different than cutlass_scaled_mm) + weight_scale = weight_scale.squeeze(-1) + replace_parameter( + layer, self.w_s_name, + torch.nn.Parameter(weight_scale.data, requires_grad=False)) + + # Only support symmetric dynamic activation quantization. + setattr(layer, self.i_s_name, None) + setattr(layer, self.i_zp_name, None) + setattr(layer, self.azp_adj_name, None) + + # Filter warning for cond usage in apply_weights. It is okay + # to specialize the graph since bias is not dynamic. + warnings.filterwarnings( + "ignore", + message= + "Pred is a Python constant. When used with torch.cond, it specializes on one of the branches." # noqa: E501 + ) + + def no_add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): + return x + + def add_bias(self, x: torch.Tensor, bias: Optional[torch.Tensor]): + return x + bias + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + w_q, w_s, _, _, _ = self._get_weight_params(layer) + + import torch_xla.experimental.xla_quantized_matmul # noqa: F401 + out = torch.ops.xla.quantized_matmul(x, + w_q, + w_s, + zero_point=None, + block_size=-1, + int4_weight=False, + quantize_activation=True) + + # Explicitly capture control flow to make dynamo happy. 
+ # https://pytorch.org/docs/main/generated/exportdb/index.html#cond-branch-class-method # noqa: E501 + return cond(bias is None, self.no_add_bias, self.add_bias, [out, bias]) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/__init__.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6bd9a0b0877ab25c3416a9a3de6b7cdd61fcb9a7 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/__pycache__/quark.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/__pycache__/quark.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa0aabb7c1a646dd03592907997de167b842fb1d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/__pycache__/quark.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/__pycache__/quark_moe.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/__pycache__/quark_moe.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00e68c8eaf35247cf04a651e850bc8be49243591 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/__pycache__/quark_moe.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68aadd2f7895890492a7f8628bf960bc7c00947f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/quark.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/quark.py new file mode 100644 index 0000000000000000000000000000000000000000..ba123565a0ecc70396946da3ec4267081e11ea1e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/quark.py @@ -0,0 +1,389 @@ +# SPDX-License-Identifier: Apache-2.0 + +import fnmatch +import re +from typing import Any, Dict, List, Optional, cast + +import torch + +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) +from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.quantization.quark.quark_moe import ( # noqa: E501 + 
QuarkMoEMethod) +from vllm.model_executor.layers.quantization.quark.schemes import ( + QuarkScheme, QuarkW8A8Fp8, QuarkW8A8Int8) +from vllm.model_executor.layers.quantization.quark.utils import ( + deep_compare, should_ignore_layer) +from vllm.platforms import current_platform + +__all__ = ["QuarkLinearMethod"] + + +class QuarkConfig(QuantizationConfig): + + def __init__(self, + quant_config: Dict[str, Any], + kv_cache_group: Optional[List[str]] = None, + kv_cache_config: Optional[Dict[str, Any]] = None, + pack_method: str = "reorder"): + if kv_cache_group is None: + kv_cache_group = [] + self.quant_config = quant_config + self.kv_cache_group = kv_cache_group + self.kv_cache_config = kv_cache_config + self.pack_method = pack_method + + def get_linear_method(self) -> "QuarkLinearMethod": + return QuarkLinearMethod(self) + + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + def get_name(self) -> str: + return "quark" + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["QuantizeMethodBase"]: + from vllm.attention.layer import Attention # Avoid circular import + + # Check if the layer is skipped for quantization. + exclude_layers = cast(List[str], self.quant_config.get("exclude")) + if should_ignore_layer(prefix, + ignore=exclude_layers, + fused_mapping=self.packed_modules_mapping): + return UnquantizedLinearMethod() + if isinstance(layer, LinearBase): + scheme = self.get_scheme(layer=layer, layer_name=prefix) + layer.scheme = scheme + return QuarkLinearMethod(self) + if isinstance(layer, Attention): + return QuarkKVCacheMethod(self) + if isinstance(layer, FusedMoE): + return QuarkMoEMethod.get_moe_method(self, + module=layer, + layer_name=prefix) + return None + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "QuarkConfig": + export_config = config.get("export") + if export_config is None: + raise ValueError("The export key should be included in " + "the configurations of Quark quantized model") + kv_cache_group = cast(List[str], export_config.get("kv_cache_group")) + pack_method = cast(str, export_config.get("pack_method")) + + # In the export model of quark, the quantization configuration + # of kv_cache is stored in layer_quant_config. First, it is + # judged whether kv_cache_group exists, and then it is judged + # whether layer_quant_config has a quantization configuration + # that matches kv_cache. 
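        # Editorial sketch (hypothetical values, not taken from any real
        # checkpoint): the handling below assumes a Quark export config of
        # roughly this shape:
        #   {
        #       "export": {"kv_cache_group": ["*k_proj", "*v_proj"],
        #                  "pack_method": "reorder"},
        #       "layer_quant_config": {
        #           "*k_proj": {"output_tensors": {"dtype": "fp8_e4m3"}, ...},
        #           "*v_proj": {"output_tensors": {"dtype": "fp8_e4m3"}, ...},
        #       },
        #   }
        # Every kv_cache_group entry must be a key of layer_quant_config, all
        # such entries must describe the same scheme, and the shared
        # "output_tensors" block is extracted as kv_cache_config.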
+ if len(kv_cache_group) == 0: + kv_cache_config = None + else: + kv_cache_set = set(kv_cache_group) + layer_quant_config = cast(Dict[str, Any], + config.get("layer_quant_config")) + layer_quant_names = list(layer_quant_config.keys()) + layer_quant_set = set(layer_quant_names) + + if not kv_cache_set.issubset(layer_quant_set): + raise ValueError("The Quark quantized model has the " + "kv_cache_group parameter setting, " + "but no kv_cache quantization settings " + "were found in the quantization " + "configuration.") + + q_configs = [ + cast(Dict[str, Any], layer_quant_config.get(name)) + for name in kv_cache_group + ] + if not all( + deep_compare(q_config, q_configs[0]) + for q_config in q_configs): + raise ValueError( + "The quantization method used for kv_cache should " + "be the same, but the quantization method for the " + "kv_cache layer in the config is different.") + kv_cache_config = q_configs[0].get("output_tensors") + if kv_cache_config is None: + raise ValueError( + "The kv_cache quantization configuration is empty.") + + # Since we have already set kv_cache quantization configurations, + # we will remove the quantization configuration for the + # output_tensors corresponding to the kv_cache layer. + for q_config in q_configs: + q_config["output_tensors"] = None + + return cls(quant_config=config, + kv_cache_group=kv_cache_group, + kv_cache_config=kv_cache_config, + pack_method=pack_method) + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [] + + def _check_scheme_supported(self, + min_capability: int, + error: bool = True) -> bool: + capability_tuple = current_platform.get_device_capability() + + if capability_tuple is not None: + capability = capability_tuple.to_int() + supported = capability >= min_capability + if error and not supported: + raise RuntimeError( + "Quantization scheme is not supported for ", + f"the current GPU. Min capability: {min_capability}. ", + f"Current capability: {capability}.") + return supported + else: + return False + + def _is_fp8_w8a8(self, weight_quant: Optional[Dict[str, Any]], + input_quant: Optional[Dict[str, Any]]) -> bool: + # Confirm weights and input quantized. + if weight_quant is None or input_quant is None: + return False + + # Confirm weight scheme is supported + is_fp8_dtype = (weight_quant.get("dtype") == "fp8_e4m3" + and input_quant.get("dtype") == "fp8_e4m3") + is_static_weight = not weight_quant.get("is_dynamic") + is_per_tensor_or_channel_weight = (weight_quant.get("qscheme") + in ["per_tensor", "per_channel"]) + + if not (is_fp8_dtype and is_static_weight + and is_per_tensor_or_channel_weight): + return False + + # Dynamic quantization is always supported if weights supported. + if input_quant.get("is_dynamic"): + return True + + # Confirm activation scheme is supported. + is_per_tensor_activation = (input_quant.get("qscheme") == "per_tensor") + return is_per_tensor_activation + + def _is_static_tensor_w8a8(self, weight_quant: Optional[Dict[str, Any]], + input_quant: Optional[Dict[str, Any]]) -> bool: + # Confirm weights and input quantized. 
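        # Editorial sketch (hypothetical values): a configuration fragment
        # accepted by this check could look like
        #   weight_quant = {"dtype": "int8", "qscheme": "per_channel",
        #                   "is_dynamic": False, "symmetric": True}
        #   input_quant  = {"dtype": "int8", "qscheme": "per_tensor",
        #                   "is_dynamic": False, "symmetric": False}
        # i.e. static int8 quantization with symmetric weights; the activation
        # side may be symmetric or asymmetric.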
+ if weight_quant is None or input_quant is None: + return False + + is_int8_dtype = (weight_quant.get("dtype") == "int8" + and input_quant.get("dtype") == "int8") + + is_tensor = (weight_quant.get("qscheme") + in ["per_tensor", "per_channel"] + and input_quant.get("qscheme") == "per_tensor") + + is_static = (not weight_quant.get("is_dynamic") + and not input_quant.get("is_dynamic")) + + is_weight_symmetric = (weight_quant.get("symmetric") is True) + + # Both symmetric and asymmetric input quantization supported. + # Only symmetric weight quantization supported. + return is_int8_dtype and is_tensor and is_weight_symmetric and is_static + + def _find_matched_config(self, layer_name: str, + module: torch.nn.Module) -> Dict[str, Any]: + + proj_name = layer_name.split(".")[-1] + if proj_name in self.packed_modules_mapping: + shard_proj_names = self.packed_modules_mapping[proj_name] + + # Convert fused_name --> [shard_names] + shard_names = [ + layer_name.replace(proj_name, shard_proj_name) + for shard_proj_name in shard_proj_names + ] + shard_configs = [ + self._find_matched_config(shard_name, module) + for shard_name in shard_names + ] + if not all( + deep_compare(q_config, shard_configs[0]) + for q_config in shard_configs): + raise ValueError( + f"Found a different quantization configuration for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme.") + return shard_configs[0] + else: + layer_quant_config = cast( + Dict[str, Any], self.quant_config.get("layer_quant_config")) + for name_pattern in layer_quant_config: + if fnmatch.fnmatch(layer_name, name_pattern): + return layer_quant_config[name_pattern] + + layer_type = cast(str, type(module)) + layer_type_quant_config = cast( + Dict[str, Any], + self.quant_config.get("layer_type_quant_config")) + if layer_type in layer_type_quant_config: + return layer_type_quant_config[layer_type] + + global_quant_config = cast( + Dict[str, Any], self.quant_config.get("global_quant_config")) + return global_quant_config + + def _get_scheme_from_config(self, config: Dict[str, Any]) -> "QuarkScheme": + if config.get("output_tensors") or config.get("bias"): + raise NotImplementedError( + "Currently, Quark models with output_tensors " + "and bias quantized are not supported") + weight_config = cast(Dict[str, Any], config.get("weight")) + input_config = cast(Dict[str, Any], config.get("input_tensors")) + + if self._is_fp8_w8a8(weight_config, input_config): + is_fp8_w8a8_supported = self._check_scheme_supported( + QuarkW8A8Fp8.get_min_capability(), error=False) + if is_fp8_w8a8_supported: + weight_qscheme = cast(str, weight_config.get("qscheme")) + input_static = (input_config is not None and + not cast(bool, input_config.get("is_dynamic"))) + return QuarkW8A8Fp8(qscheme=weight_qscheme, + is_static_input_scheme=input_static) + elif self._is_static_tensor_w8a8(weight_config, input_config): + weight_qscheme = cast(str, weight_config.get("qscheme")) + return QuarkW8A8Int8(qscheme=weight_qscheme, + is_static_input_scheme=True, + input_symmetric=input_config.get("symmetric")) + + raise NotImplementedError("No quark compatible scheme was found. " + f"Weight config: {weight_config}, " + f"Input config: {input_config}") + + def get_scheme(self, layer: torch.nn.Module, + layer_name: str) -> "QuarkScheme": + + layer_quant_config = self._find_matched_config(layer_name, layer) + + # Find the quant_scheme + scheme = self._get_scheme_from_config(layer_quant_config) + # Raise error if device does not support the scheme + # (e.g. 
fp8 needs ada lovelace) + self._check_scheme_supported(scheme.get_min_capability()) + + return scheme + + def get_cache_scale(self, name: str) -> Optional[str]: + """ + Check whether the param name matches the format for k/v cache scales + in quark. If this is the case, return its equivalent param name + expected by vLLM + + :param name: param name + :return: matching param name for KV cache scale in vLLM + """ + if self.kv_cache_group is None or len(self.kv_cache_group) == 0: + return None + + kv_proj_names = [ + re.split(r"[*.]", kv_cache)[-1] for kv_cache in self.kv_cache_group + ] + if name.endswith(".output_scale"): + if len(kv_proj_names) == 1 and kv_proj_names[0] in name: + kv_output_scale_name = "." + kv_proj_names[0] + ".output_scale" + return name.replace(kv_output_scale_name, ".attn.k_scale") + + elif len(kv_proj_names) == 2: + for kv_proj_name in kv_proj_names: + if kv_proj_name in name and kv_proj_name == "k_proj": + return name.replace(".k_proj.output_scale", + ".attn.k_scale") + elif kv_proj_name in name and kv_proj_name == "v_proj": + return name.replace(".v_proj.output_scale", + ".attn.v_scale") + + # If no matches, return None + return None + + +class QuarkLinearMethod(LinearMethodBase): + + def __init__(self, quantization_config: QuarkConfig): + self.quantization_config = quantization_config + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + layer.scheme.process_weights_after_loading(layer) + + def create_weights(self, layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, + output_size: int, params_dtype: torch.dtype, + **extra_weight_attrs): + """ + Use the CompressedTensorsScheme associated with each layer to create + the necessary parameters for the layer. See LinearMethodBase for param + details + """ + weight_loader = extra_weight_attrs.get("weight_loader") + layer.scheme.create_weights( + layer=layer, + input_size=input_size, + input_size_per_partition=input_size_per_partition, + output_partition_sizes=output_partition_sizes, + output_size=output_size, + params_dtype=params_dtype, + weight_loader=weight_loader) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None): + """ + Use the output of create_weights and the CompressedTensorsScheme + associated with the layer to apply the forward pass with the + layer input. See LinearMethodBase for param details + + """ + scheme = layer.scheme + if scheme is None: + raise ValueError("A scheme must be defined for each layer") + return scheme.apply_weights(layer, x, bias=bias) + + +class QuarkKVCacheMethod(BaseKVCacheMethod): + """ + Supports loading kv-cache scaling factors from quark checkpoints. + """ + + def __init__(self, quant_config: QuarkConfig): + self.validate_kv_cache_config(quant_config.kv_cache_config) + super().__init__(quant_config) + + @staticmethod + def validate_kv_cache_config(kv_cache_config: Optional[Dict[str, Any]]): + """ + Validator for the kv cache configuration. 
Useful for controlling the + kv cache quantization schemes, that are being supported in vLLM + :param kv_cache_config: the quark kv cache scheme + """ + if kv_cache_config is None: + return + + dtype = kv_cache_config.get("dtype") + if dtype != "fp8_e4m3": + raise NotImplementedError( + "Currently supported kv cache quantization is " + f"dtype=fp8_e4m3, however received {dtype}") + + qscheme = kv_cache_config.get("qscheme") + if qscheme != "per_tensor": + raise NotImplementedError( + "Only support per-tensor scaling factor " + "for quark KV cache. " + f"Expected qscheme: per_tensor, found qscheme: {qscheme}") diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/quark_moe.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/quark_moe.py new file mode 100644 index 0000000000000000000000000000000000000000..98743b15e4b25a5b5e6d78a8e9c31fb5bbba84c2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/quark_moe.py @@ -0,0 +1,229 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Dict, Optional + +import torch + +import vllm.model_executor.layers.fused_moe # noqa +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +__all__ = ["QuarkMoEMethod", "QuarkW8A8Fp8MoEMethod"] + + +class QuarkMoEMethod(FusedMoEMethodBase): + + @staticmethod + def get_moe_method( + quant_config: "QuarkConfig", # type: ignore # noqa E501 # noqa F821 + module: torch.nn.Module, + layer_name: str) -> "QuarkMoEMethod": + layer_quant_config = quant_config._find_matched_config( + layer_name, module) + + if (layer_quant_config.get("output_tensors") + or layer_quant_config.get("bias")): + raise NotImplementedError("Currently, Quark models with " + "output_tensors and bias " + "quantized are not supported") + weight_config = layer_quant_config.get("weight") + input_config = layer_quant_config.get("input_tensors") + + if quant_config._is_fp8_w8a8(weight_config, input_config): + return QuarkW8A8Fp8MoEMethod(weight_config, input_config) + else: + raise RuntimeError("Unsupported FusedMoe scheme") + + +class QuarkW8A8Fp8MoEMethod(QuarkMoEMethod): + + def __init__(self, weight_config: Dict[str, Any], input_config: Dict[str, + Any]): + self.weight_quant = weight_config + self.input_quant = input_config + + weight_qscheme = self.weight_quant.get("qscheme") + input_qscheme = self.input_quant.get("qscheme") + if not (weight_qscheme == "per_tensor" + and input_qscheme == "per_tensor"): + raise ValueError( + "For FP8 Fused MoE layers, only per-tensor scales" + "for weights and activations are supported. 
Found " + f"{weight_qscheme}, {input_qscheme}") # noqa E501 + + self.static_input_scales = not self.input_quant.get("is_dynamic") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size_per_partition: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + params_dtype = torch.float8_e4m3fn + + # WEIGHTS + w13_weight = torch.nn.Parameter(torch.empty( + num_experts, + 2 * intermediate_size_per_partition, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty( + num_experts, + hidden_size, + intermediate_size_per_partition, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + # WEIGHT_SCALES + # Allocate 2 scales for w1 and w3 respectively. + # They will be combined to a single scale after weight loading. + w13_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + 2, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_weight_scale) + + w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts, + dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_weight_scale) + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) + + # INPUT_SCALES + if self.static_input_scales: + w13_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w13_input_scale", w13_input_scale) + set_weight_attrs(w13_input_scale, extra_weight_attrs) + + w2_input_scale = torch.nn.Parameter(torch.ones( + num_experts, dtype=torch.float32), + requires_grad=False) + layer.register_parameter("w2_input_scale", w2_input_scale) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: + layer.w13_input_scale = None + layer.w2_input_scale = None + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # Fp8 moe kernels require a single activation scale. + # We take the max of all the scales in case they differ. + if self.static_input_scales: + if (layer.w13_input_scale is None or layer.w2_input_scale is None): + raise ValueError( + "QuantConfig has static quantization, but found " + "activation scales are None.") + if (not all_close_1d(layer.w13_input_scale) + or not all_close_1d(layer.w2_input_scale)): + logger.warning_once( + "Found input_scales that are not equal for " + "fp8 MoE layer. Using the maximum across experts " + "for each layer. 
") + layer.w13_input_scale = torch.nn.Parameter( + layer.w13_input_scale.max(), requires_grad=False) + layer.w2_input_scale = torch.nn.Parameter( + layer.w2_input_scale.max(), requires_grad=False) + + # If rocm, normalize the weights and scales to e4m3fnuz + if current_platform.is_rocm(): + # Normalize the weights and scales + w13_weight, w13_weight_scale, w13_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w13_weight, layer.w13_weight_scale, + layer.w13_input_scale) + w2_weight, w2_weight_scale, w2_input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + layer.w2_weight, layer.w2_weight_scale, + layer.w2_input_scale) + # Reset the parameter + layer.w13_weight = torch.nn.Parameter(w13_weight, + requires_grad=False) + layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, + requires_grad=False) + if w13_input_scale is not None: + layer.w13_input_scale = torch.nn.Parameter(w13_input_scale, + requires_grad=False) + layer.w2_weight = torch.nn.Parameter(w2_weight, + requires_grad=False) + layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, + requires_grad=False) + if w2_input_scale is not None: + layer.w2_input_scale = torch.nn.Parameter(w2_input_scale, + requires_grad=False) + + # Fp8 moe kernel needs single weight scale for w13 per expert. + # We take the max then dequant and requant each expert. + assert layer.w13_weight_scale is not None + shard_size = layer.intermediate_size_per_partition + max_w13_scales = layer.w13_weight_scale.max(dim=1).values + for expert_id in range(layer.num_experts): + start = 0 + for shard_id in range(2): + dq_weight = per_tensor_dequantize( + layer.w13_weight[expert_id][start:start + shard_size, :], + layer.w13_weight_scale[expert_id][shard_id]) + layer.w13_weight[expert_id][ + start:start + shard_size, :], _ = ops.scaled_fp8_quant( + dq_weight, max_w13_scales[expert_id]) + start += shard_size + + layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, + requires_grad=False) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool, + use_grouped_topk: bool = False, + topk_group: Optional[int] = None, + num_expert_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + scoring_func: str = "softmax", + e_score_correction_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from vllm.model_executor.layers.fused_moe import fused_experts + + topk_weights, topk_ids = FusedMoE.select_experts( + hidden_states=x, + router_logits=router_logits, + use_grouped_topk=use_grouped_topk, + top_k=top_k, + renormalize=renormalize, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + e_score_correction_bias=e_score_correction_bias) + + return fused_experts(x, + layer.w13_weight, + layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=True, + use_fp8_w8a8=True, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + a1_scale=layer.w13_input_scale, + a2_scale=layer.w2_input_scale) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__init__.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9069b5a0d515d78eb5d3f68b0fb162f5292db8ec --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__init__.py @@ -0,0 
+1,7 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .quark_scheme import QuarkScheme +from .quark_w8a8_fp8 import QuarkW8A8Fp8 +from .quark_w8a8_int8 import QuarkW8A8Int8 + +__all__ = ["QuarkScheme", "QuarkW8A8Fp8", "QuarkW8A8Int8"] diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c4bbfd22594a298ddf0a89cbc184656bd749bdf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_scheme.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_scheme.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93512499dedf279a03bbd14f8332cab44bb61f00 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_scheme.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_w8a8_fp8.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_w8a8_fp8.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff09ae9d9da0d5ac184571ef6e56c8871c834a78 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_w8a8_fp8.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_w8a8_int8.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_w8a8_int8.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e348c427b1de0039437a7bb99ccba3a73bd2fd5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/__pycache__/quark_w8a8_int8.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py new file mode 100644 index 0000000000000000000000000000000000000000..40c8ea86d3c385417f7810c774b5ebe85baf3a02 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py @@ -0,0 +1,54 @@ +# SPDX-License-Identifier: Apache-2.0 + +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +__all__ = ["QuarkScheme"] + + +class QuarkScheme(ABC): + """ + Abstract class used to describe the weight creation and forward pass + of different quantization schemes supported by Quark. + """ + + @classmethod + @abstractmethod + def get_min_capability(cls) -> int: + """ + Get minimum device capability. + """ + raise NotImplementedError + + @abstractmethod + def create_weights(self, *args, **kwargs): + """ + Weight creation for the particular scheme. 
Inputs to this function + + """ + raise NotImplementedError + + @abstractmethod + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]): + """ + Run the forward pass for the particular scheme. This is where + scheme-specific dequant/quant steps/kernels should be applied. + + :param layer: torch.nn.Module with the registered weights and + other parameters relevant to the particular scheme. + :param x: input to the layer + :param bias: bias parameter + + """ + raise NotImplementedError + + @abstractmethod + def process_weights_after_loading(self, layer: torch.nn.Module): + """ + Called after weight loading is complete for any cleanup that + needs to occur. + """ + raise NotImplementedError diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..c885e98a4d66e2c7b24579751da081c2f304b66d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py @@ -0,0 +1,142 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, List, Optional + +import torch +from torch.nn import Parameter + +from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + apply_fp8_linear, cutlass_fp8_supported, normalize_e4m3fn_to_e4m3fnuz, + requantize_with_max_scale) +from vllm.model_executor.parameter import (ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) +from vllm.platforms import current_platform + +__all__ = ["QuarkW8A8Fp8"] + + +class QuarkW8A8Fp8(QuarkScheme): + + def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool]): + self.qscheme = qscheme + self.is_static_input_scheme = is_static_input_scheme + self.cutlass_fp8_supported = cutlass_fp8_supported() + + @classmethod + def get_min_capability(cls) -> int: + # lovelace and up + return 89 + + def process_weights_after_loading(self, layer) -> None: + # If per tensor, when we have a fused module (e.g. QKV) with per + # tensor scales (thus N scales being passed to the kernel), + # requantize so we can always run per tensor + if self.qscheme == "per_tensor": + max_w_scale, weight = requantize_with_max_scale( + weight=layer.weight, + weight_scale=layer.weight_scale, + logical_widths=layer.logical_widths, + ) + + if current_platform.is_rocm(): + weight, max_w_scale, input_scale = normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=max_w_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + + layer.weight = Parameter(weight.t(), requires_grad=False) + layer.weight_scale = Parameter(max_w_scale, requires_grad=False) + + # If channelwise, scales are already lined up, so just transpose. 
+ elif self.qscheme == "per_channel": + weight = layer.weight + + if current_platform.is_rocm(): + weight, weight_scale, input_scale = \ + normalize_e4m3fn_to_e4m3fnuz( + weight=weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale) + if input_scale is not None: + layer.input_scale = Parameter(input_scale, + requires_grad=False) + else: + weight_scale = layer.weight_scale.data + + layer.weight = Parameter(weight.t(), requires_grad=False) + # required by torch.compile to be torch.nn.Parameter + layer.weight_scale = Parameter(weight_scale, requires_grad=False) + + else: + raise ValueError(f"Unknown quantization scheme {self.qscheme}") + + # INPUT SCALE + if self.is_static_input_scheme: + layer.input_scale = Parameter(layer.input_scale.max(), + requires_grad=False) + else: + layer.input_scale = None + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + output_size_per_partition = sum(output_partition_sizes) + layer.logical_widths = output_partition_sizes + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + output_size_per_partition, + input_size_per_partition, + dtype=torch.float8_e4m3fn), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + # TODO: update create_xxx_parameter functions to return + # the newly added parameters + if self.qscheme == "per_channel": + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + else: + assert self.qscheme == "per_tensor" + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + + # min requirement for fp8 kernels + weight_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + input_scale[:] = torch.finfo(torch.float32).min + layer.register_parameter("input_scale", input_scale) + + def apply_weights(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + + return apply_fp8_linear( + input=x, + weight=layer.weight, + weight_scale=layer.weight_scale, + input_scale=layer.input_scale, + bias=bias, + cutlass_fp8_supported=self.cutlass_fp8_supported, + use_per_token_if_dynamic=True) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py new file mode 100644 index 0000000000000000000000000000000000000000..1bf34b098938c1282e4ab597db849f03a4ee610f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py @@ -0,0 +1,107 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Callable, List, Optional, Set + +import torch + +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.kernels.scaled_mm import ( + ScaledMMLinearLayerConfig, choose_scaled_mm_linear_kernel) +from vllm.model_executor.layers.quantization.quark.schemes import QuarkScheme +from 
vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + ModelWeightParameter, + PerTensorScaleParameter) + +logger = init_logger(__name__) + + +class QuarkW8A8Int8(QuarkScheme): + _kernel_backends_being_used: Set[str] = set() + + def __init__(self, qscheme: str, is_static_input_scheme: Optional[bool], + input_symmetric: Optional[bool]): + self.qscheme = qscheme + self.is_static_input_scheme = is_static_input_scheme + self.input_symmetric = input_symmetric + + @classmethod + def get_min_capability(cls) -> int: + # turing and up + return 75 + + def create_weights(self, layer: torch.nn.Module, + output_partition_sizes: List[int], + input_size_per_partition: int, + params_dtype: torch.dtype, weight_loader: Callable, + **kwargs): + self.logical_widths = output_partition_sizes + + scaled_mm_linear_kernel_config = ScaledMMLinearLayerConfig( + is_channelwise=(self.qscheme == "per_channel"), + is_static_input_scheme=(self.is_static_input_scheme is True), + input_symmetric=(self.input_symmetric is True)) + + kernel_type = choose_scaled_mm_linear_kernel( + scaled_mm_linear_kernel_config) + + if kernel_type.__name__ not in self._kernel_backends_being_used: + logger.info("Using %s for QuarkW8A8Int8", kernel_type.__name__) + self._kernel_backends_being_used.add(kernel_type.__name__) + + # WEIGHT + weight = ModelWeightParameter(data=torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=torch.int8), + input_dim=1, + output_dim=0, + weight_loader=weight_loader) + + layer.register_parameter("weight", weight) + + # WEIGHT SCALE + if self.qscheme == "per_channel": + weight_scale = ChannelQuantScaleParameter( + data=torch.empty((sum(output_partition_sizes), 1), + dtype=torch.float32), + output_dim=0, + weight_loader=weight_loader) + else: + assert self.qscheme == "per_tensor" + weight_scale = PerTensorScaleParameter(data=torch.empty( + len(output_partition_sizes), dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("weight_scale", weight_scale) + + # INPUT SCALE + if self.is_static_input_scheme: + input_scale = BasevLLMParameter(data=torch.empty( + 1, dtype=torch.float32), + weight_loader=weight_loader) + layer.register_parameter("input_scale", input_scale) + + if not self.input_symmetric: + # Note: quark stores the zp using the same dtype + # as the weights + # AZP loaded as int8 but used as int32 + input_zero_point = BasevLLMParameter( + data=torch.empty(1, dtype=torch.int8), + weight_loader=weight_loader) + layer.register_parameter("input_zero_point", input_zero_point) + + self.kernel = kernel_type(c=scaled_mm_linear_kernel_config, + w_q_param_name="weight", + w_s_param_name="weight_scale", + i_s_param_name="input_scale", + i_zp_param_name="input_zero_point", + azp_adj_param_name="azp_adj") + + # Checkpoints are serialized in quark format, which is + # different from the format the kernel may want. Handle repacking here. 
+ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + self.kernel.process_weights_after_loading(layer) + + def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return self.kernel.apply_weights(layer, x, bias) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/utils.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..17e0df021085a9eff8334419e438dfeafe151b5e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/quark/utils.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 + +import re +from types import MappingProxyType +from typing import Any, Iterable, List, Mapping, Optional + + +def deep_compare(dict1: Any, dict2: Any) -> bool: + if type(dict1) is not type(dict2): + return False + if isinstance(dict1, dict): + if dict1.keys() != dict2.keys(): + return False + return all(deep_compare(dict1[k], dict2[k]) for k in dict1) + elif isinstance(dict1, list): + return set(dict1) == set(dict2) + else: + return dict1 == dict2 + + +def should_ignore_layer( + layer_name: Optional[str], + ignore: Iterable[str], + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) +) -> bool: + if layer_name is None: + return False + + # layer_name = model.layers.0.self_attn.qkv_proj + # proj_name = qkv_proj + proj_name = layer_name.split(".")[-1] + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. + if proj_name in fused_mapping: + shard_proj_names = fused_mapping[proj_name] + + # Convert fused_name --> [shard_names] + shard_names = [ + layer_name.replace(proj_name, shard_proj_name) + for shard_proj_name in shard_proj_names + ] + + # Layer should be ignored if shards are ignored. + should_ignore_layer = None + for shard_name in shard_names: + should_ignore_shard = check_equal_or_regex_match( + layer_name=shard_name, targets=ignore) + + # If shard_idx=0, set layer ignore to match shard. + if should_ignore_layer is None: + should_ignore_layer = should_ignore_shard + + # If shard_idx=1+ confirm scheme matches prior shards. + elif should_ignore_shard != should_ignore_layer: + raise ValueError(f"Found a different quantization schemes for " + f"{shard_proj_names} in {layer_name}. vLLM " + "requires all to use the same scheme.") + + # Unfused layers like down_proj and o_proj will match + # the safetensors checkpoint already. + else: + should_ignore_layer = check_equal_or_regex_match(layer_name=layer_name, + targets=ignore) + + assert should_ignore_layer is not None + return should_ignore_layer + + +def check_equal_or_regex_match(layer_name: str, + targets: Iterable[str]) -> bool: + """ + Checks whether a layer_name is exactly equal or a regex match for + if target starts with 're:' to any target in list. + """ + for target in targets: + if _is_equal_or_regex_match(layer_name, target): + return True + return False + + +def _is_equal_or_regex_match(value: str, + target: str, + check_contains: bool = False) -> bool: + """ + Checks whether a value is exactly equal or a regex match for target + if target starts with 're:'. If check_contains is set to True, + additionally checks if the target string is contained within the value. 
+ """ + + if target.startswith("re:"): + pattern = target[3:] + if re.match(pattern, value): + return True + elif check_contains: + if target.lower() in value.lower(): + return True + elif target == value: + return True + return False diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/__init__.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f7ee4728851408d2fe796b3910e04d270ba0faeb --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/__init__.py @@ -0,0 +1,5 @@ +# SPDX-License-Identifier: Apache-2.0 + +from .layer_utils import replace_parameter, update_tensor_inplace + +__all__ = ['update_tensor_inplace', 'replace_parameter'] diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/__pycache__/layer_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/__pycache__/layer_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8b72644cf0f4553b57a14175e6c8fe2f58623db Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/__pycache__/layer_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8155f182df5c3ddcc743d8f3c97715285d40510 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/__pycache__/marlin_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/fp8_utils.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/fp8_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9895537c219ab6875546ae91f63f695d0ec392f1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/fp8_utils.py @@ -0,0 +1,533 @@ +# SPDX-License-Identifier: Apache-2.0 + +# Adapted from https://github.com/sgl-project/sglang/pull/2575 +import functools +import json +import os +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import triton +import triton.language as tl + +from vllm import _custom_ops as ops +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.utils.quant_utils import ( + _normalize_quant_group_shape, scaled_dequantize) +from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( + CUTLASS_BLOCK_FP8_SUPPORTED, CUTLASS_FP8_SUPPORTED, apply_fp8_linear) +from vllm.platforms import current_platform + +logger = init_logger(__name__) + +current_platform_fp8_dtype = (torch.float8_e4m3fnuz + if current_platform.is_rocm() else + torch.float8_e4m3fn) + + +def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool: + if isinstance(x, torch.Tensor): + x = x.dtype + return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz + + +def apply_w8a8_block_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + block_size: List[int], + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + 
cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, +) -> torch.Tensor: + assert input_scale is None + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[0]] + + shape_supported_by_cutlass = (weight.shape[0] % 128 == 0 + and weight.shape[1] % 128 == 0) + if current_platform.is_rocm(): + scale_a_shape = ((input_2d.shape[-1] // block_size[1], ) + + input_2d.shape[:-1])[::-1] + scale_b_shape = (weight_scale.view(-1, 1) + if weight_scale.dim() <= 1 else weight_scale.T).shape + ar, ac = scale_a_shape + br, bc = scale_b_shape + if (ac > 1 or bc > 1 or ar not in (1, input_2d.shape[0]) + or br not in (1, weight.shape[0])): + shape_supported_by_cutlass = False + if cutlass_block_fp8_supported and shape_supported_by_cutlass: + q_input, x_scale = per_token_group_quant_fp8(input_2d, + block_size[1], + column_major_scales=True) + output = ops.cutlass_scaled_mm(q_input, + weight.T, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale.T) + else: + q_input, x_scale = per_token_group_quant_fp8(input_2d, + block_size[1], + column_major_scales=False) + output = w8a8_block_fp8_matmul(q_input, + weight, + x_scale, + weight_scale, + block_size, + output_dtype=input.dtype) + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + +# Unify the interface between `apply_w8a8_block_fp8_linear` and +# `apply_fp8_linear` +# NOTE(lucas): this is quite messy, we should think through this more formally +def apply_fp8_linear_generic( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_group_shape: Tuple[int, int], + weight_group_shape: Tuple[int, int], + input_scale: Optional[torch.Tensor] = None, # static scale if one + cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED, + cutlass_block_fp8_supported: bool = CUTLASS_BLOCK_FP8_SUPPORTED, +) -> torch.Tensor: + # View input as 2D matrix for fp8 methods + input = input.view(-1, input.shape[-1]) + + weight_group_shape = _normalize_quant_group_shape(\ + weight, weight_group_shape) + input_group_shape = _normalize_quant_group_shape(input, input_group_shape) + + def is_dim_blocked(dim, shape, group_shape): + return group_shape < shape[dim] and group_shape > 1 + + if is_dim_blocked(0, weight.shape, weight_group_shape[0])\ + and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\ + input_group_shape == (1, weight_group_shape[1]): + return apply_w8a8_block_fp8_linear( + input, + weight, + list(weight_group_shape), + weight_scale, + cutlass_block_fp8_supported=cutlass_block_fp8_supported) + else: + # Despite having linear in the it doesn't conform to + # `torch.nn.functional.linear` which is defined as `input @ weight.T` + # so we explicitly transpose the weight matrix here + return apply_fp8_linear(input, weight.T, weight_scale.T, + cutlass_fp8_supported=cutlass_fp8_supported, + use_per_token_if_dynamic=\ + (input_group_shape == (1, input.shape[1]))) + + +def input_to_float8( + x: torch.Tensor, + dtype: Optional[torch.dtype] = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """This function quantizes input values to float8 values " + "with tensor-wise quantization.""" + if dtype is None: + dtype = (torch.float8_e4m3fnuz + if current_platform.is_rocm() else torch.float8_e4m3fn) + finfo = torch.finfo(dtype) + min_val, max_val = x.aminmax() + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + x_scl_sat = (x * scale).clamp(min=finfo.min, 
max=finfo.max) + return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal() + + +def block_quant_to_tensor_quant( + x_q_block: torch.Tensor, + x_s: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """This function converts block-wise quantization to tensor-wise + quantization. The inputs are block-wise quantization tensor `x_q_block`, + block-wise quantization scale and the block size. + The outputs are tensor-wise quantization tensor and tensor-wise + quantization scale. Note only float8 is supported for now. + """ + x_dq_block = scaled_dequantize(x_q_block, x_s) + x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype) + return x_q_tensor, scale + + +@triton.jit +def _per_token_group_quant_fp8( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + y_s_ptr += g_id + + cols = tl.arange(0, BLOCK) # N <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +@triton.jit +def _per_token_group_quant_fp8_colmajor( + # Pointers to inputs and output + y_ptr, + y_q_ptr, + y_s_ptr, + group_size, + # Num columns of y + y_num_columns, + # Stride from one column to the next of y_s + y_s_col_stride, + # Avoid to divide zero + eps, + # Information for float8 + fp8_min, + fp8_max, + # Meta-parameters + BLOCK: tl.constexpr, +): + """A Triton-accelerated function to perform per-token-group + quantization on a tensor. + This function converts the tensor values into float8 values. + """ + # Map the program id to the row of X and Y it should compute. + g_id = tl.program_id(0) + y_ptr += g_id * group_size + y_q_ptr += g_id * group_size + + # Convert g_id the flattened block coordinate to 2D so we can index + # into the output y_scales matrix + blocks_per_row = y_num_columns // group_size + scale_col = g_id % blocks_per_row + scale_row = g_id // blocks_per_row + y_s_ptr += scale_col * y_s_col_stride + scale_row + + cols = tl.arange(0, BLOCK) # group_size <= BLOCK + mask = cols < group_size + + y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32) + # Quant + _absmax = tl.maximum(tl.max(tl.abs(y)), eps) + y_s = _absmax / fp8_max + y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty) + + tl.store(y_q_ptr + cols, y_q, mask=mask) + tl.store(y_s_ptr, y_s) + + +def per_token_group_quant_fp8( + x: torch.Tensor, + group_size: int, + eps: float = 1e-10, + dtype: Optional[torch.dtype] = None, + column_major_scales: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Function to perform per-token-group quantization on an input tensor `x`. + It converts the tensor values into signed float8 values and returns the + quantized tensor along with the scaling factor used for quantization. + Args: + x: The input tensor with ndim >= 2. + group_size: The group size used for quantization. + eps: The minimum to avoid dividing zero. 
+ dtype: The dype of output tensor. Note that only `torch.float8_e4m3fn` + is supported for now. + Returns: + Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the + scaling factor for quantization. + """ + if dtype is None: + dtype = (torch.float8_e4m3fnuz + if current_platform.is_rocm() else torch.float8_e4m3fn) + assert (x.shape[-1] % group_size == 0), ( + f"the last dimension of `x` {x.shape[-1]} must be divisible " + f"by `group_size` {group_size}") + assert x.is_contiguous(), "`x` must be contiguous" + + finfo = torch.finfo(dtype) + fp8_min = finfo.min + fp8_max = finfo.max + + x_q = torch.empty_like(x, device=x.device, dtype=dtype) + M = x.numel() // group_size + N = group_size + if column_major_scales: + shape = (x.shape[-1] // group_size, ) + x.shape[:-1] + x_s = torch.empty(shape, device=x.device, + dtype=torch.float32).permute(-1, -2) + else: + shape = x.shape[:-1] + (x.shape[-1] // group_size, ) + x_s = torch.empty(shape, device=x.device, dtype=torch.float32) + + BLOCK = triton.next_power_of_2(N) + # heuristics for number of warps + num_warps = min(max(BLOCK // 256, 1), 8) + num_stages = 1 + if column_major_scales: + _per_token_group_quant_fp8_colmajor[(M, )]( + x, + x_q, + x_s, + group_size, + x.shape[1], + x_s.stride(1), + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + else: + _per_token_group_quant_fp8[(M, )]( + x, + x_q, + x_s, + group_size, + eps, + fp8_min=fp8_min, + fp8_max=fp8_max, + BLOCK=BLOCK, + num_warps=num_warps, + num_stages=num_stages, + ) + + return x_q, x_s + + +@triton.jit +def _w8a8_block_fp8_matmul( + # Pointers to inputs and output + A, + B, + C, + As, + Bs, + # Shape for matmul + M, + N, + K, + # Block size for block-wise quantization + group_n, + group_k, + # Stride for inputs and output + stride_am, + stride_ak, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_As_m, + stride_As_k, + stride_Bs_k, + stride_Bs_n, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, +): + """Triton-accelerated function used to perform linear operations (dot + product) on input tensors `A` and `B` with block-wise quantization, and + store the result in output tensor `C`. 
+ """ + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) + b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn) + + As_ptrs = As + offs_am * stride_As_m + offs_bsn = offs_bn // group_n + Bs_ptrs = Bs + offs_bsn * stride_Bs_n + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + a = tl.load(a_ptrs, + mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + + k_start = k * BLOCK_SIZE_K + offs_ks = k_start // group_k + a_s = tl.load(As_ptrs + offs_ks * stride_As_k) + b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k) + + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if C.dtype.element_ty == tl.bfloat16: + c = accumulator.to(tl.bfloat16) + elif C.dtype.element_ty == tl.float16: + c = accumulator.to(tl.float16) + else: + c = accumulator.to(tl.float32) + + offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :] + c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + +@functools.lru_cache +def get_w8a8_block_fp8_configs(N: int, K: int, block_n: int, + block_k: int) -> Optional[Dict[int, Any]]: + """ + Return optimized configurations for the w8a8 block fp8 kernel. + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the w8a8 block fp8 kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. + """ + + # First look up if an optimized configuration is available in the configs + # directory + device_name = current_platform.get_device_name().replace(" ", "_") + json_file_name = f"N={N},K={K},device_name={device_name},dtype=fp8_w8a8,block_shape=[{block_n},{block_k}].json" # noqa: E501 + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info( + "Using configuration from %s for W8A8 Block FP8 kernel.", + config_file_path, + ) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + logger.warning( + "Using default W8A8 Block FP8 kernel config. Performance might " + "be sub-optimal! 
Config file not found at %s", + config_file_path, + ) + return None + + +def w8a8_block_fp8_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + block_size: List[int], + output_dtype: torch.dtype = torch.float16, +) -> torch.Tensor: + """This function performs matrix multiplication with block-wise + quantization. + It takes two input tensors `A` and `B` with scales `As` and `Bs`. + The output is returned in the specified `output_dtype`. + Args: + A: The input tensor, e.g., activation. + B: The input tensor, e.g., weight. + As: The per-token-group quantization scale for `A`. + Bs: The per-block quantization scale for `B`. + block_size: The block size for per-block quantization. It should + be 2-dim, e.g., [128, 128]. + output_dytpe: The dtype of the returned tensor. + Returns: + torch.Tensor: The result of matmul. + """ + assert len(block_size) == 2 + block_n, block_k = block_size[0], block_size[1] + + assert A.shape[-1] == B.shape[-1] + assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous() + assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1] + M = A.numel() // A.shape[-1] + + assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2 + N, K = B.shape + assert triton.cdiv(N, block_n) == Bs.shape[0] + assert triton.cdiv(K, block_k) == Bs.shape[1] + + C_shape = A.shape[:-1] + (N, ) + C = A.new_empty(C_shape, dtype=output_dtype) + + configs = get_w8a8_block_fp8_configs(N, K, block_size[0], block_size[1]) + if configs: + # Get the optimal config if there is one + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Default config + # Block-wise quant: BLOCK_SIZE_N must be divisible by block_size[0] + # BLOCK_SIZE_K must be divisible by block_size[1] + config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": block_size[0], + "BLOCK_SIZE_K": block_size[1], + "GROUP_SIZE_M": 32, + "num_warps": 4, + "num_stages": 2, + } + + def grid(META): + return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * + triton.cdiv(N, META["BLOCK_SIZE_N"]), ) + + _w8a8_block_fp8_matmul[grid]( + A, + B, + C, + As, + Bs, + M, + N, + K, + block_n, + block_k, + A.stride(-2), + A.stride(-1), + B.stride(1), + B.stride(0), + C.stride(-2), + C.stride(-1), + As.stride(-2), + As.stride(-1), + Bs.stride(1), + Bs.stride(0), + **config, + ) + + return C diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/layer_utils.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/layer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5acae7ca3b84f2047608a7688a3271a37d81331f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/layer_utils.py @@ -0,0 +1,39 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Union + +import torch + + +def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor): + assert dst.dtype == src.dtype, "Tensors must have the same dtype" + + # update tensor shape and stride + dst.as_strided_(src.shape, src.stride()) + + # If not the same underlying storage move tensor data + if dst.data_ptr() != src.data_ptr(): + dst.copy_(src) + del src + + +# Newly generated tensors need to replace existing tensors that are +# already registered as parameters by vLLM (and won't be freed) +def replace_parameter(mod: torch.nn.Module, name: str, + new: Union[torch.Tensor, torch.nn.Parameter]) -> None: + + old = getattr(mod, name) + if type(old) is type(new) and old.dtype == new.dtype and \ + old.untyped_storage().nbytes() == 
new.untyped_storage().nbytes(): + # If we can just update in-place to avoid re-registering + # can be faster if the underlying storage is the same + update_tensor_inplace(old, new) + else: + # Fallback re-register parameter, convert to Parameter if necessary + # this not only ensures we don't register a tensor as a parameter, but + # also ensures that all parameter subclasses get re-registered as + # parameters for `torch.compile` compatibility + if not isinstance(new, torch.nn.Parameter): + new = torch.nn.Parameter(new, requires_grad=False) + mod.register_parameter(name, + torch.nn.Parameter(new, requires_grad=False)) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/machete_utils.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/machete_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cb7d49ed6f1ca046d58bf6e5c333130963a713a0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/machete_utils.py @@ -0,0 +1,32 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Tuple + +import torch + +from vllm.scalar_type import ScalarType, scalar_types + +MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128] +MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128] + + +def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]: + if zero_points: + return [scalar_types.uint4, scalar_types.uint8] + else: + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]: + return [torch.float16, torch.bfloat16] + + +def check_machete_supports_shape(in_features: int, out_featrues: int) \ + -> Tuple[bool, Optional[str]]: + if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0: + return False, "Input features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}" + if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0: + return False, "Output features size must be divisible by "\ + f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}" + return True, None diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3beba30832441deca56e61056e6c285ca37ec76b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils.py @@ -0,0 +1,352 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Tuple + +import numpy +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform +from vllm.scalar_type import ScalarType, scalar_types + +from .quant_utils import pack_cols, unpack_cols + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + +# In case there is a performance issue with Marlin, the variable below can be +# changed to False, which allows Marlin to perform global reductions in fp16 +# precision (instead of fp32), and therefore, save on some memory movements. +USE_FP32_REDUCE_DEFAULT = True + + +# For binary size and compile time, we don't support the same types for with and +# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ. 
+# TODO: we may want to move this into the C++ so its closer to the actual impl +def query_marlin_supported_quant_types(has_zp: bool, + device_capability: Optional[int] = None + ): + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + if device_capability < 80: + return [] + + if has_zp: + # AWQ style, unsigned + runtime zero-point + return [scalar_types.uint4, scalar_types.uint8] + else: + # GPTQ style, unsigned + symmetric bias + # TODO: once fp8_marlin is merged into "gptq_marlin" we should be able + # to add `scalar_types.float8_e4m3fn` here + return [scalar_types.uint4b8, scalar_types.uint8b128] + + +def _check_marlin_supported( + quant_type: ScalarType, + group_size: Optional[int], + has_zp: bool, + device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]: + + if device_capability is None: + capability_tuple = current_platform.get_device_capability() + device_capability = (-1 if capability_tuple is None else + capability_tuple.to_int()) + + supported_types = query_marlin_supported_quant_types( + has_zp, device_capability) + + if quant_type not in supported_types: + return (False, f"Marlin does not support weight_bits = {quant_type}. " + f"Only types = {supported_types} " + f"are supported (for group_size = {group_size}, " + f"device_capability = {device_capability}, zp = {has_zp}).") + if (group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES): + return (False, f"Marlin does not support group_size = {group_size}. " + f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") + + return True, None + + +def check_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False, + device_capability: Optional[int] = None) -> bool: + cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, + device_capability) + return cond + + +def verify_marlin_supported(quant_type: ScalarType, + group_size: int, + has_zp: bool = False) -> None: + cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp) + if not cond: + assert err_msg is not None + raise ValueError(err_msg) + + +def verify_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) -> None: + + # Validate output_size_per_partition + if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0: + raise ValueError(f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + # Validate input_size_per_partition + if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0: + raise ValueError(f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. " + "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}." 
+ "Consider reducing tensor_parallel_size or running " + "with --quantization gptq.") + + +def check_marlin_supports_shape(output_size_per_partition: int, + input_size_per_partition: int, + input_size: int, group_size: int) \ + -> Tuple[bool, Optional[str]]: + try: + verify_marlin_supports_shape(output_size_per_partition, + input_size_per_partition, input_size, + group_size) + except ValueError as e: + return False, e.__str__() + return True, None + + +def marlin_make_workspace(output_size_per_partition: int, + device: torch.device) -> torch.Tensor: + max_workspace_size = (output_size_per_partition // + GPTQ_MARLIN_MIN_THREAD_N) * GPTQ_MARLIN_MAX_PARALLEL + + return torch.zeros(max_workspace_size, + dtype=torch.int, + device=device, + requires_grad=False) + + +def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool: + return (not act_order) or (act_order and not is_row_parallel) + + +def marlin_repeat_scales_on_all_ranks(act_order: bool, group_size: int, + is_row_parallel: bool) -> bool: + # Need to repeat scales on every rank if act_ordering or + # channelwise and RowParallelLinear + is_channelwise = group_size == -1 + return act_order or (is_channelwise and is_row_parallel) + + +def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def marlin_make_empty_zp(device: torch.device) -> torch.Tensor: + return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device), + requires_grad=False) + + +def marlin_sort_g_idx( + g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + g_idx_sort_indices = torch.argsort(g_idx).to(torch.int) + return g_idx[g_idx_sort_indices], g_idx_sort_indices + + +def get_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_moe_permute_scales( + s: torch.Tensor, + size_k: int, + size_n: int, + group_size: int, +): + num_experts = s.shape[0] + output = torch.empty( + (num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype, + ) + + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size) + return output + + +def marlin_zero_points(zp: torch.Tensor, size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + # Permute zero-points in a similar way to scales, but do not use the + # "single" permutation, since zero-points are applied on every MMA + scale_perm, _ = get_scale_perms() + zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm] + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel() + zp = zp.reshape((-1, size_n)).contiguous() + zp 
= pack_cols(zp, num_bits, size_k, size_n) + + return zp + + +def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int) -> torch.Tensor: + # AWQ zero-points are quantized and packed on the column dim. + # In addition, the values are permuted based on dequantizer. + # Here we undo both of these, and then apply marlin permutation + # and pack it back. + q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n) + + # Undo interleaving (use argsort(..) to get inverse perm) + if num_bits == 4: + undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7])) + elif num_bits == 8: + undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3])) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel() + q_zp = q_zp.reshape((-1, size_n)).contiguous() + + marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits) + return marlin_zp + + +def moe_awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int, + size_n: int, num_bits: int): + num_experts = q_zp_packed.shape[0] + output = torch.empty( + (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]), + device=q_zp_packed.device, + dtype=q_zp_packed.dtype, + ) + for e in range(num_experts): + output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, + num_bits) + return output + + +def apply_gptq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + wtype: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + is_k_full: bool, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + output = ops.gptq_marlin_gemm(reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + wtype, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=is_k_full, + has_zp=False, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def apply_awq_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + weight_zp: torch.Tensor, + g_idx: torch.Tensor, + g_idx_sort_indices: torch.Tensor, + workspace: torch.Tensor, + quant_type: ScalarType, + output_size_per_partition: int, + input_size_per_partition: int, + bias: Optional[torch.Tensor] = None, + use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT) -> torch.Tensor: + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (output_size_per_partition, ) + + output = ops.gptq_marlin_gemm(reshaped_x, + weight, + weight_scale, + weight_zp, + g_idx, + g_idx_sort_indices, + workspace, + quant_type, + size_m=reshaped_x.shape[0], + size_n=output_size_per_partition, + size_k=input_size_per_partition, + is_k_full=True, + has_zp=True, + use_fp32_reduce=use_fp32_reduce, + is_zp_float=False) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py 
b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py new file mode 100644 index 0000000000000000000000000000000000000000..6120a8e66aef45227b37f5c6c5031800d832baa1 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py @@ -0,0 +1,110 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import torch + +import vllm._custom_ops as ops +from vllm.logger import init_logger +from vllm.platforms import current_platform + +from .marlin_utils import marlin_make_workspace, marlin_permute_scales + +logger = init_logger(__name__) + + +def is_fp8_marlin_supported(): + return current_platform.has_device_capability(80) + + +def apply_fp8_marlin_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + workspace: torch.Tensor, + size_n: int, + size_k: int, + bias: Optional[torch.Tensor], +) -> torch.Tensor: + # For GPUs that lack FP8 hardware support, we can leverage the + # Marlin kernel for fast weight-only FP8 quantization + + reshaped_x = input.reshape(-1, input.shape[-1]) + out_shape = input.shape[:-1] + (size_n, ) + + output = ops.fp8_marlin_gemm( + a=reshaped_x, + b_q_weight=weight, + b_scales=weight_scale, + workspace=workspace, + num_bits=8, + size_m=reshaped_x.shape[0], + size_n=size_n, + size_k=size_k, + ) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) + + +def prepare_fp8_layer_for_marlin(layer: torch.nn.Module, + strategy: str = "tensor") -> None: + logger.warning_once( + "Your GPU does not have native support for FP8 computation but " + "FP8 quantization is being used. Weight-only FP8 compression will " + "be used leveraging the Marlin kernel. 
This may degrade " + "performance for compute-heavy workloads.") + + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + + device = layer.weight.device + + # WORKSPACE + layer.workspace = marlin_make_workspace(part_size_n, device) + + # WEIGHT + # Repack weights to marlin format + marlin_qweight = ops.gptq_marlin_repack(b_q_weight=pack_fp8_to_int32( + layer.weight), + perm=torch.empty(0, + dtype=torch.int, + device=device), + size_k=part_size_k, + size_n=part_size_n, + num_bits=8) + layer.weight = torch.nn.Parameter(marlin_qweight, requires_grad=False) + + # WEIGHT SCALES + scales = layer.weight_scale.to(layer.orig_dtype) + # Permute scales + marlin_scales = marlin_permute_scales(s=scales, + size_k=part_size_k, + size_n=part_size_n, + group_size=-1) + layer.weight_scale = torch.nn.Parameter(marlin_scales, requires_grad=False) + + +def pack_fp8_to_int32(fp8_tensor: torch.Tensor) -> torch.Tensor: + """ + Repack FP8 weights to gptq format (packed int32 elements) + """ + assert fp8_tensor.dtype == torch.float8_e4m3fn + assert fp8_tensor.shape[0] % 4 == 0 + + # Reshape to prepare for packing + reshaped = fp8_tensor.reshape(-1, 4, *fp8_tensor.shape[1:]) + + # Convert fp8 to uint8 (byte) representation + byte_tensor = reshaped.view(torch.uint8) + + # Pack 4 uint8 values into one int32 + packed = (byte_tensor[:, 0].to(torch.int32) | + (byte_tensor[:, 1].to(torch.int32) << 8) | + (byte_tensor[:, 2].to(torch.int32) << 16) | + (byte_tensor[:, 3].to(torch.int32) << 24)) + + return packed.view(fp8_tensor.shape[0] // 4, + *fp8_tensor.shape[1:]).contiguous() diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py new file mode 100644 index 0000000000000000000000000000000000000000..fb557a31393caf90669aac94572b75858e810617 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_test.py @@ -0,0 +1,164 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Utility functions used for tests and benchmarks""" + +from typing import List, Optional + +import numpy as np +import torch + +from vllm.scalar_type import ScalarType + +from .marlin_utils import (GPTQ_MARLIN_TILE, marlin_permute_scales, + marlin_zero_points) +from .quant_utils import (get_pack_factor, gptq_quantize_weights, + quantize_weights, sort_weights) + + +class MarlinWorkspace: + + def __init__(self, out_features, min_thread_n, max_parallel): + assert (out_features % min_thread_n == 0), ( + "out_features = {} is undivisible by min_thread_n = {}".format( + out_features, min_thread_n)) + + max_workspace_size = ((out_features // min_thread_n) * max_parallel) + + self.scratch = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda") + + +def marlin_permute_weights(q_w, size_k, size_n, perm, tile=GPTQ_MARLIN_TILE): + assert q_w.shape == (size_k, size_n) + assert size_k % tile == 0, f"size_k = {size_k}, tile = {tile}" + assert size_n % tile == 0, f"size_k = {size_n}, tile = {tile}" + + # Permute weights to 16x64 marlin tiles + q_w = q_w.reshape((size_k // tile, tile, size_n // tile, tile)) + q_w = q_w.permute((0, 2, 1, 3)) + q_w = q_w.reshape((size_k // tile, size_n * tile)) + + q_w = q_w.reshape((-1, perm.numel()))[:, perm].reshape(q_w.shape) + + return q_w + + +def marlin_weights(q_w, size_k, size_n, num_bits, perm): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + 
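    # Editorial sketch (not part of the original file): as an illustration of
    # the permute step above and the pack step below, with
    # GPTQ_MARLIN_TILE == 16 and num_bits == 4 a 32 x 64 input is reshaped to
    # (2, 16, 4, 16), permuted so each 16x16 tile is contiguous, flattened to
    # (2, 1024) and reordered in-tile by `perm`; packing then folds 8
    # consecutive 4-bit values into each int32, giving a (2, 128) result.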
# Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(np.uint32) + + q_packed = np.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=np.uint32) + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(np.int32)).to(orig_device) + + return q_packed + + +def get_weight_perm(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = np.array(perm_list) + + if num_bits == 4: + interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = np.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_quantize(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): + size_k, size_n = w.shape + num_bits = quant_type.size_bits + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Quantize (and apply act_order if provided) + w_ref, q_w, s, g_idx, rand_perm = gptq_quantize_weights( + w, quant_type, group_size, act_order, test_perm) + + # For act_order, sort the "weights" and "g_idx" so that group ids are + # increasing + sort_indices = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + q_w, g_idx, sort_indices = sort_weights(q_w, g_idx) + + # Reformat to marlin + weight_perm = get_weight_perm(num_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, num_bits, weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list + + +def awq_marlin_quantize(w: torch.Tensor, quant_type: ScalarType, + group_size: int): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Detect num groups + assert size_k % group_size == 0 + num_groups = size_k // group_size + + # Quantize with zp + w_ref, q_w, s, zp = quantize_weights(w, + quant_type, + group_size, + zero_points=True) + + # Reformat to marlin + weight_perm = get_weight_perm(quant_type.size_bits) + marlin_q_w = marlin_weights(q_w, size_k, size_n, quant_type.size_bits, + weight_perm) + marlin_s = marlin_permute_scales(s, size_k, size_n, group_size) + marlin_zp = marlin_zero_points(zp, num_groups, size_n, + quant_type.size_bits) + + # Create result + res_list = [w_ref, marlin_q_w, marlin_s, marlin_zp] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py new file mode 100644 index 0000000000000000000000000000000000000000..3654268e27af31ecf8b28f0f98d1126db33a7d42 --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py @@ -0,0 +1,464 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Utility functions used for tests and benchmarks""" + +import random +from typing import List + +import numpy +import torch + +from vllm.scalar_type import ScalarType + +from .marlin_utils_test import marlin_weights +from .quant_utils import gptq_quantize_weights + + +# This is PyTorch implementation of main part of reorder_meta() +# function, from tools/util/include/cutlass/util/host_reorder.h file +# of CUTLASS source tree. Furthermore, CUTLASS template for sparse +# GEMM decides upon layout of this matrix, and at the moment for the +# sparse GEMM executed on tensor cores, this is layout described by +# ColumnMajorInterleaved<2> data structure, in +# include/cutlass/layout/matrix.h of CUTLASS source tree. The +# reordering of meta matrix into meta_reordered matrix calculated +# according to these segments of CUTLASS code is re-implemented here. +# Note that this calculation produces offsets for scattering metadata +# matrix elements into reordered metadata matrix elements (or, +# equivalently, for gathering reordered metadata matrix element back +# into metadata matrix elements). +def _calculate_meta_reordering_scatter_offsets(m, meta_ncols, meta_dtype, + device): + dst_rows = torch.arange(0, m, device=device)[:, None].repeat(1, meta_ncols) + dst_cols = torch.arange(0, meta_ncols, device=device).repeat(m, 1) + + # Reorder the rows, then swizzle the 2x2 blocks. + group_x = 64 + group_y = 32 if meta_dtype.itemsize == 2 else 16 + + dst_rows = (dst_rows // group_x * group_x + (dst_rows % 2) * 2 + + (dst_rows % 8) // 4 + ((dst_rows % group_y) % 4) // 2 * 32 + + ((dst_rows % group_x) // 8) * 4) + + topright = ((dst_rows % 2 == 0) & (dst_cols % 2 == 1)).to(torch.int8) + bottomleft = ((dst_rows % 2 == 1) & (dst_cols % 2 == 0)).to(torch.int8) + dst_rows += topright - bottomleft + dst_cols -= topright - bottomleft + + # Assumed that meta tensor is to be stored in CUTLASS + # InterleavedColumnMajor layout, and reverse engineered + # corresponding code to store values into this tensor. + interleave = 2 + cols_maj = dst_cols // interleave + cols_min = dst_cols % interleave + return (cols_maj * m * interleave + dst_rows * interleave + + cols_min).view(-1) + + +# This function converts dense matrix into sparse semi-structured +# representation, producing "compressed" matrix, in the layout used by +# CUTLASS backend, and corresponding metadata matrix. 
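# Editorial usage sketch (not part of the original file): round-tripping a
# 2:4-sparse fp16 weight through the helpers below, assuming a CUDA device
# (inject_24, defined later in this file, moves its mask to CUDA):
#
#   dense = torch.randn(64, 128, dtype=torch.half, device="cuda")
#   dense, _ = inject_24(dense, 64, 128)   # enforce 2:4 sparsity
#   sparse, meta = sparse_semi_structured_from_dense_cutlass(dense)
#   # sparse: (64, 64)  - the two kept values from every 4-element group
#   # meta:   (64, 8)   - int16, four 4-bit position codes per element
#   restored = sparse_semi_structured_to_dense_cutlass(sparse, meta)
#   # restored is expected to equal `dense` exactly for a 2:4 pattern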
+def sparse_semi_structured_from_dense_cutlass(dense): + if dense.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional dense tensor, got {dense.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = dense.shape + device = dense.device + + meta_dtype = torch.int8 + if dense.dtype == torch.int8: + meta_dtype = torch.int32 + elif dense.dtype in [torch.half, torch.bfloat16, torch.float, torch.int32]: + meta_dtype = torch.int16 + else: + raise RuntimeError(f"Invalid datatype {dense.dtype} of dense matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + if quadbits_per_meta_elem not in (4, 8): + raise RuntimeError( + "Invalid number of elements per meta element calculated") + + if meta_dtype == torch.int32: + if m % 16 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 16") + else: + if m % 32 != 0: + raise RuntimeError( + f"Number of rows of dense matrix {m} must be divisible by 32") + if k % (4 * quadbits_per_meta_elem) != 0: + raise RuntimeError( + f"Number of columns of dense matrix {k} must be divisible by {4 * quadbits_per_meta_elem}" # noqa: E501 + ) + + if dense.dtype != torch.float: + ksparse = 4 + dense_4 = dense.view(-1, k // ksparse, ksparse) + m0, m1, m2, m3 = (dense_4 != 0).unbind(-1) + else: + ksparse = 2 + dense_2 = dense.view(-1, k // ksparse, ksparse) + m0, m2 = m1, m3 = (dense_2 != 0).unbind(-1) + meta_ncols = k // (ksparse * quadbits_per_meta_elem) + + # Encoding quadruples of True/False values as follows: + # [True, True, False, False] -> 0b0100 + # [True, False, True, False] -> 0b1000 + # [False, True, True, False] -> 0b1001 + # [True, False, False, True ] -> 0b1100 + # [False, True, False, True ] -> 0b1101 + # [False, False, True, True ] -> 0b1110 + # Thus, lower two bits in the encoding are index of the True value + # at the lowest index in the quadruple, and the higher two bits in + # the encoding are index of the other True value in the quadruple. + # In case there are less than two True values, than False value or + # values at some index or indices are considered True for the + # encoding. In case there are more than two True values, then the + # excess True value(s) at some indices are considered False for + # the encoding. The exact encodings used for these cases are as + # follows: + # [False, False, False, False] -> 0b1110 + # [False, False, False, True ] -> 0b1110 + # [False, False, True, False] -> 0b1110 + # [False, True, False, False] -> 0b1001 + # [False, True, True, True ] -> 0b1101 + # [True, False, False, False] -> 0b1000 + # [True, False, True, True ] -> 0b1100 + # [True, True, False, True ] -> 0b0100 + # [True, True, True, False] -> 0b0100 + # [True, True, True, True ] -> 0b0100 + # These particular encodings are chosen, with the help of Espresso + # logic minimizer software, for the purpose of minimization of + # corresponding Boolean functions, that translate non-zero flags + # into encoding bits. Note also possible choices for the first + # and last of these encodings were limited only to (0b0100, + # 0b1110), in order to produce valid encodings for 1:2 sparsity + # case. 
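    # Editorial worked example (not part of the original file): for the
    # quadruple [True, False, True, False] (m0=1, m1=0, m2=1, m3=0) the
    # expressions below give bit0=0, bit1=0 -> idxs0 = 0 and bit2=0,
    # bit3=1 -> idxs1 = 2, so the 4-bit code idxs0 | (idxs1 << 2) is 0b1000,
    # matching the table above; the two indices are simply the positions of
    # the kept values (0 and 2).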
+ + expr0 = m0 & m1 + expr1 = ~m0 & m1 + expr2 = ~m0 & ~m1 + bit0 = expr1 + bit1 = expr2 + bit2 = expr0 | expr2 | m3 + bit3 = expr1 | ~m1 + idxs0 = bit0 | (bit1.to(torch.int64) << 1) + idxs1 = bit2 | (bit3.to(torch.int64) << 1) + + if dense.dtype != torch.float: + sparse0 = dense_4.gather( + -1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined] + sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1)) + sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2) + else: + sparse = dense_2.gather(-1, + idxs0.unsqueeze(-1) // 2).view( + m, + k // 2) # type: ignore[possibly-undefined] + + meta_4 = idxs0 | (idxs1 << 2) + meta_n = meta_4.view( + (-1, meta_ncols, quadbits_per_meta_elem)).to(meta_dtype) + + if quadbits_per_meta_elem == 4: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12)) + elif quadbits_per_meta_elem == 8: + meta = (meta_n[:, :, 0] + | (meta_n[:, :, 1] << 4) + | (meta_n[:, :, 2] << 8) + | (meta_n[:, :, 3] << 12) + | (meta_n[:, :, 4] << 16) + | (meta_n[:, :, 5] << 20) + | (meta_n[:, :, 6] << 24) + | (meta_n[:, :, 7] << 28)) + + # Reorder meta tensor elements. + meta_reordered = meta.new_empty( + (m * meta_ncols, )) # type: ignore[possibly-undefined] + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta_reordered.scatter_(0, meta_offsets, meta.view(-1)) + + return (sparse, meta_reordered.view(m, meta_ncols)) + + +# This function performs reverse of the function above - it +# reconstructs dense matrix from a pair of "compressed" matrix, given +# in the layout used by CUTLASS backend, and accompanying metadata +# matrix. +def sparse_semi_structured_to_dense_cutlass(sparse, meta_reordered): + if sparse.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional sparse tensor, got {sparse.dim()}-dimensional tensor" # noqa: E501 + ) + + m, k = sparse.shape + device = sparse.device + + if meta_reordered.dim() != 2: + raise RuntimeError( + f"Expected 2-dimensional meta tensor, got {meta_reordered.dim()}-dimensional tensor" # noqa: E501 + ) + if meta_reordered.device != device: + raise RuntimeError( + f"Expected meta matrix to be on {device} device, got matrix on {meta_reordered.device} device" # noqa: E501 + ) + + meta_dtype = meta_reordered.dtype + if meta_dtype not in (torch.int16, torch.int32): + raise RuntimeError(f"Invalid datatype {meta_dtype} of meta matrix") + quadbits_per_meta_elem = meta_dtype.itemsize * 8 // 4 + + ksparse = 4 if sparse.dtype != torch.float else 2 + + meta_nrows, meta_ncols = meta_reordered.shape + if meta_nrows != m: + raise RuntimeError( + f"Number of rows of meta matrix {meta_nrows} must be equal to number of columns of spase matrix {m}" # noqa: E501 + ) + if meta_ncols * ksparse * quadbits_per_meta_elem != 2 * k: + raise RuntimeError( + f"Number of columns of sparse matrix {k} different from the {meta_ncols * ksparse * quadbits_per_meta_elem // 2}, " # noqa: E501 + "expected according to the number of columns of meta matrix") + + # Undo meta tensor elements reordering. + meta_offsets = _calculate_meta_reordering_scatter_offsets( + m, meta_ncols, meta_dtype, device) + meta = torch.gather(meta_reordered.view(-1), 0, + meta_offsets).view(m, meta_ncols) + + # Unpack sparse tensor back to original dense tensor, using + # information provided by meta tensor. 
Note that torch.float + # datatype is handled pretty much the same as + # torch.half/torch.bfloat16, as metadata for a pair of torch.float + # value is encoded as if underlying 8 bytes contain four + # torch.half/torch.bfloat16 values, where either first two or last + # two are zeros. + meta_2 = torch.empty( + (m, meta_ncols, 2 * quadbits_per_meta_elem), + dtype=meta_dtype, + device=device, + ) + if quadbits_per_meta_elem == 4: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + elif quadbits_per_meta_elem == 8: + meta_2[:, :, 0] = meta & 0b11 + meta_2[:, :, 1] = (meta >> 2) & 0b11 + meta_2[:, :, 2] = (meta >> 4) & 0b11 + meta_2[:, :, 3] = (meta >> 6) & 0b11 + meta_2[:, :, 4] = (meta >> 8) & 0b11 + meta_2[:, :, 5] = (meta >> 10) & 0b11 + meta_2[:, :, 6] = (meta >> 12) & 0b11 + meta_2[:, :, 7] = (meta >> 14) & 0b11 + meta_2[:, :, 8] = (meta >> 16) & 0b11 + meta_2[:, :, 9] = (meta >> 18) & 0b11 + meta_2[:, :, 10] = (meta >> 20) & 0b11 + meta_2[:, :, 11] = (meta >> 22) & 0b11 + meta_2[:, :, 12] = (meta >> 24) & 0b11 + meta_2[:, :, 13] = (meta >> 26) & 0b11 + meta_2[:, :, 14] = (meta >> 28) & 0b11 + meta_2[:, :, 15] = (meta >> 30) & 0b11 + + dense_offsets = meta_2.view(-1) + ( + torch.arange(0, 2 * m * k // ksparse, device=device) * 4).view( + -1, 1).repeat(1, 2).view(-1) + + dense = torch.zeros((m * 2 * k, ), dtype=sparse.dtype, device=device) + if sparse.dtype != torch.float: + # dense.scatter_(0, dense_offsets, sparse.view(-1)) + dense.scatter_(0, dense_offsets, sparse.reshape(-1)) + else: + dense.view(torch.half).scatter_(0, dense_offsets, + sparse.view(torch.half).view(-1)) + + return dense.view(m, 2 * k) + + +def mask_creator(tensor): + """ + Class for creating N:M sparsity masks. + Masks will be created using the N:M ratio, where for every block of + M weights, N will be pruned based on ranked weight value. Each mask + will correspond to the given tensor. 
+ + :param N: The number of weights in a group to keep + :param M: The size of a weight group + """ + N = 2 + M = 4 + + mask = None + # for i, tensor in enumerate(tensors): + if tensor.numel() % M != 0: + raise ValueError( + f"Tensor of size {tensor.shape} can't be evenly divided into " + f"{M} groups") + + num_groups = tensor.numel() // M + + # N:M sparsity for linear layers + tensor_temp = tensor.detach().abs().reshape(num_groups, M) + index = torch.argsort(tensor_temp, dim=1)[:, :int(M - N)] + + w_b = torch.ones(tensor_temp.shape, device=tensor_temp.device) + mask = w_b.scatter_(dim=1, index=index, value=0).reshape(tensor.shape) + + return mask + + +def inject_24(w, size_k, size_n): + assert w.shape == (size_k, size_n) + + mask = mask_creator(w.t()).t().cuda().bool() + + return (mask * w).contiguous(), mask.contiguous() + + +def check_24(w, num_rows_to_sample=50, _verbose=False): + BLOCK_SIZE = 4 + MAX_NON_ZEROS = 2 + + w = w.t().contiguous() + + print("check_24: w.shape = {}".format(w.shape)) + + num_rows, num_cols = w.shape + sampled_row_idxs = random.choices(range(num_rows), k=num_rows_to_sample) + if _verbose: + print(f"Sampled row idxs = {sampled_row_idxs}") + + total_segments = 0 + non_24_segments = 0 + for i in sampled_row_idxs: + for j in range(0, num_cols - BLOCK_SIZE, BLOCK_SIZE): + total_segments += 1 + block = w[i, j:j + BLOCK_SIZE] + num_nonzero = torch.count_nonzero(block) + if num_nonzero > MAX_NON_ZEROS: + print("i = {} j = {} block = {}".format(i, j, block)) + non_24_segments += 1 + + print(f"{non_24_segments} / {total_segments} do not have 2:4 structure.") + + +def compress_quantized_24_weight(q_24, size_k, size_n, wtype: ScalarType): + assert q_24.shape == (size_k, size_n) + + # Remove bias to normalize over 0 + q_24_no_zp = q_24 - wtype.bias + + # Compress + q_24_no_zp = q_24_no_zp.t().contiguous() + q_24_no_zp_comp, meta = sparse_semi_structured_from_dense_cutlass( + q_24_no_zp) + q_24_no_zp_comp = q_24_no_zp_comp.t().contiguous() + + # Restore bias + q_24_comp = q_24_no_zp_comp + wtype.bias + + # Resize meta to its actual shape (without moving any data) + meta = meta.resize_(meta.shape[1] // 2, meta.shape[0] * 2) + + return q_24_comp, meta + + +def get_scale_perms_24(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i * 8 + j for j in [0, 4, 1, 5, 2, 6, 3, 7]]) + scale_perm_single: List[int] = [] + for i in range(8): + scale_perm_single.extend([8 * i + j for j in [0, 1, 2, 3, 4, 5, 6, 7]]) + return scale_perm, scale_perm_single + + +def get_weight_perm_24(num_bits: int): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + col_o = col // 2 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col_o * 256 + 8 * (col % 2) + + 4 * block) + for j in range(4): + perm_list.extend([p + 1 * j for p in perm1]) + perm = numpy.array(perm_list) + + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise ValueError("num_bits must be 4 or 8, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_permute_scales_24(s: torch.Tensor, size_k: int, size_n: int, + group_size: int) -> torch.Tensor: + + scale_perm, scale_perm_single = get_scale_perms_24() + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, 
scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +def marlin_24_quantize( + w: torch.Tensor, + quant_type: ScalarType, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + # Inject 2:4 sparsity + w_24, mask_24 = inject_24(w, size_k, size_n) + + # Quantize + w_24_ref, q_w_24, s, g_idx, rand_perm = gptq_quantize_weights( + w_24, quant_type, group_size, act_order=False) + + # Compress quantized weight + q_w_24_comp, meta = compress_quantized_24_weight(q_w_24, size_k, size_n, + quant_type) + size_k_comp = size_k // 2 + + # Reformat to marlin + weight_perm = get_weight_perm_24(quant_type.size_bits) + marlin_24_q_w_comp = marlin_weights(q_w_24_comp, size_k_comp, size_n, + quant_type.size_bits, weight_perm) + marlin_24_s = marlin_permute_scales_24(s, size_k, size_n, group_size) + + # Create result + res_list = [w_24_ref, marlin_24_q_w_comp, meta, marlin_24_s] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py new file mode 100644 index 0000000000000000000000000000000000000000..176b2947ab09e73a87217a167a6cc00a32d940b3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List + +import numpy +import torch + +from .marlin_utils_test import marlin_permute_weights +from .quant_utils import get_pack_factor, qqq_quantize_weights + + +def marlin_qqq_weights(q_w, size_k, size_n, num_bits, perm, group_size): + # Permute + q_w = marlin_permute_weights(q_w, size_k, size_n, perm) + + # Pack + pack_factor = get_pack_factor(num_bits) + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_packed = numpy.zeros((q_w.shape[0], q_w.shape[1] // pack_factor), + dtype=numpy.uint32) + if group_size == size_k: + for i in range(pack_factor): + q_packed |= (q_w[:, i::pack_factor] & 0xF) << num_bits * i + else: + for i in range(pack_factor): + q_packed |= q_w[:, i::pack_factor] << num_bits * i + + q_packed = torch.from_numpy(q_packed.astype(numpy.int32)).to(orig_device) + + return q_packed + + +def get_qqq_scale_perms(): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + +# NOTE(HandH1998): QQQ employs different perms for per-group and per-channel weight quantization. 
# noqa: E501 +def get_qqq_weight_perm(num_bits: int, quant_type: str): + perm_list: List[int] = [] + for i in range(32): + perm1: List[int] = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 4 * (i % 4), + 4 * (i % 4) + 1, + 4 * (i % 4) + 2, + 4 * (i % 4) + 3, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm_list.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm_list) + + assert quant_type in ["per-channel", + "per-group"], "not supported quantization type" + if num_bits == 4: + if quant_type == "per-channel": + interleave = numpy.array([4, 0, 5, 1, 6, 2, 7, 3]) + else: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + else: + raise Exception("num_bits must be 4, got {}".format(num_bits)) + + perm = perm.reshape((-1, len(interleave)))[:, interleave].ravel() + perm = torch.from_numpy(perm) + return perm + + +def marlin_qqq_permute_scales(s_group, s_channel, size_k, size_n, group_size): + scale_perm, scale_perm_single = get_qqq_scale_perms() + if group_size < size_k and group_size != -1: + s_group = s_group.reshape((-1, len(scale_perm)))[:, scale_perm] + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_group = s_group.reshape((-1, size_n)).contiguous() + else: + s_channel = s_channel.reshape( + (-1, len(scale_perm_single)))[:, scale_perm_single] + s_channel = s_channel.reshape((-1, size_n)).contiguous() + + return s_group, s_channel + + +def marlin_qqq_quantize( + w: torch.Tensor, + num_bits: int, + group_size: int, +): + size_k, size_n = w.shape + + # Normalize group_size + if group_size == -1: + group_size = size_k + assert group_size <= size_k + quant_type = "per-channel" if group_size == size_k else "per-group" + + # Quantize + w_ref, q_w, s_group, s_channel = qqq_quantize_weights( + w, num_bits, group_size) + + # Reformat to marlin_qqq + weight_perm = get_qqq_weight_perm(num_bits, quant_type) + marlin_qqq_q_w = marlin_qqq_weights(q_w, size_k, size_n, num_bits, + weight_perm, group_size) + marlin_qqq_s_group, marlin_qqq_s_channel = marlin_qqq_permute_scales( + s_group, s_channel, size_k, size_n, group_size) + + # Create result + res_list = [ + w_ref, marlin_qqq_q_w, marlin_qqq_s_group, marlin_qqq_s_channel + ] + for i in range(len(res_list)): + res_list[i] = res_list[i].to(w.device) + + return res_list diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/quant_utils.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/quant_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c7ce3a42c81f99ac8ec5d7cc645806a88e6c0b54 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -0,0 +1,571 @@ +# SPDX-License-Identifier: Apache-2.0 +"""This file is used for /tests and /benchmarks""" +from types import MappingProxyType +from typing import List, Mapping, Optional, Tuple + +import numpy +import torch + +from vllm.model_executor.layers.quantization.qqq import ( + MARLIN_QQQ_SUPPORTED_NUM_BITS) +from vllm.scalar_type import ScalarType, scalar_types + +SUPPORTED_GPTQ_QUANT_TYPES = [scalar_types.uint4b8, scalar_types.uint8b128] +SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] + + +# Normalize the group_shape to the full extent for any dims that are -1 +def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int, + int]): + # -1 means full extent + return (group_shape[0] if group_shape[0] > 0 else x.shape[-2], + group_shape[1] if 
group_shape[1] > 0 else x.shape[-1]) + + +# Useful when treating N-dimensional group scaling as extended numpy-style +# broadcasting. In numpy, broadcasting simply stretches dimensions with an +# extent of 1 to match the target shape by repeating the data along that +# dimension. We extend these semantics: if the extent of a dimension in the +# source shape is not 1 and does not match the target shape, we repeat each +# element along that dimension target_shape[dim] // src_shape[dim] times. +# For example, if we have: +# a = [[1, 2], and target_shape = (2, 4) +# [3, 4]] +# then we would expand a to: +# a = [[1, 1, 2, 2], +# [3, 3, 4, 4]] +# NOTE: this function does not explicitly broadcast dimensions +# with an extent of 1, since this can be done implicitly by pytorch +def group_broadcast(t, shape): + for i, s in enumerate(shape): + if t.shape[i] != s and t.shape[i] != 1: + assert s % t.shape[i] == 0 + t = t.unsqueeze(i + 1)\ + .expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\ + .flatten(i, i + 1) + return t + + +# Quantize assuming one scale per group of elements with shape group_shape, +# example group shapes: +# * (-1, -1) for per-tensor quantization +# * (1, -1) for per-row quantization +# * (-1, 1) for per-column quantization +# * (128, 128) for 128x128 deepseek style block quantization +# * (1, 128) for deepseek style activation quantization +# (i.e. per-token-per-group) +def scaled_quantize( + x: torch.Tensor, + group_shape: Tuple[int, int], + quant_dtype: torch.dtype, +) -> Tuple[torch.Tensor, torch.Tensor]: + group_shape = _normalize_quant_group_shape(x, group_shape) + assert quant_dtype.is_floating_point, \ + "currently `scaled_quantize` only supports floating point dtypes " \ + "but could be extended to support other dtypes" + + finfo = torch.finfo(quant_dtype) + + # Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N) + assert x.ndim == 2 + assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0 + blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1] + x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1]) + + # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N) + x_blkd_permd = x_blkd.permute(0, 2, 1, 3) + # Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) + x_blkd_permd = x_blkd_permd.flatten(start_dim=2) + + # Compute scales + min_val, max_val = x_blkd_permd.aminmax(dim=-1) + amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12) + scale = finfo.max / amax + + # Apply scale and convert from: + # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N) + x_scl_sat = (x_blkd_permd * scale.unsqueeze(-1))\ + .clamp(min=finfo.min, max=finfo.max)\ + .reshape(blk_m, blk_n, group_shape[0], group_shape[1])\ + .permute(0, 2, 1, 3)\ + .reshape(x.shape) + + return x_scl_sat.to(quant_dtype).contiguous(), scale.float().reciprocal() + + +# Inverse of `scaled_quantize` +def scaled_dequantize( + x_q: torch.Tensor, + x_s: torch.Tensor, + group_shape: Optional[Tuple[int, int]] = None, + out_dtype: torch.dtype = torch.float32, +) -> Tuple[torch.Tensor, torch.Tensor]: + if group_shape is not None: + group_shape = _normalize_quant_group_shape(x_q, group_shape) + + if x_s.ndim == 0: # scalar + x_s = x_s.unsqueeze(-1).unsqueeze(-1) # convert to (1, 1) tensor + if x_s.ndim == 1: + if group_shape is None: + raise AssertionError( + "if x_s is 1D tensor, group_shape must be provided otherwise " + "it's ambiguous which dimension to broadcast x_s to") + # unsqueeze the scales for the dimension where we want
to broadcast + # across the full extent + if group_shape[0] == x_q.shape[-2]: + x_s = x_s.unsqueeze(-2) + elif group_shape[1] == x_q.shape[-1]: + x_s = x_s.unsqueeze(-1) + else: + raise AssertionError( + "if x_s is a vector we should be broadcasting it to the full " + "extent of one of the dimensions") + + if group_shape is not None: + assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1] + assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0] + x_s = group_broadcast(x_s.to(torch.float32), x_q.shape) + return (x_q.to(torch.float32) * x_s).to(out_dtype) + + +def pack_quantized_values_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + assert w_q_perm.shape[-1] % pack_factor == 0 + new_shape_perm[-1] //= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i + + return res.permute(inv_perm) + + +def unpack_quantized_values_into_int32(w_q: torch.Tensor, + wtype: ScalarType, + packed_dim: int = 0): + # move dim to pack to the end + perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim) + inv_perm = tuple(perm.index(i) for i in range(len(perm))) + w_q_perm = w_q.permute(perm) + + pack_factor = 32 // wtype.size_bits + mask = (1 << wtype.size_bits) - 1 + + new_shape_perm = list(w_q_perm.shape) + new_shape_perm[-1] *= pack_factor + + res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device) + for i in range(pack_factor): + res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask + + return res.permute(inv_perm) + + +def is_layer_skipped( + prefix: str, + ignored_layers: List[str], + fused_mapping: Mapping[str, List[str]] = MappingProxyType({}) +) -> bool: + # prefix: model.layers.0.self_attn.q_proj + # proj_name: q_proj + proj_name = prefix.split(".")[-1] + + # Fused layers like gate_up_proj or qkv_proj will not be fused + # in the safetensors checkpoint. So, we convert the name + # from the fused version to unfused + check to make sure that + # each shard of the fused layer has the same scheme. + if proj_name in fused_mapping: + shard_prefixes = [ + prefix.replace(proj_name, shard_proj_name) + for shard_proj_name in fused_mapping[proj_name] + ] + + is_skipped = None + for shard_prefix in shard_prefixes: + is_shard_skipped = shard_prefix in ignored_layers + + if is_skipped is None: + is_skipped = is_shard_skipped + elif is_shard_skipped != is_skipped: + raise ValueError( + f"Detected some but not all shards of {prefix} " + "are quantized. 
All shards of fused layers " + "must have the same precision.") + else: + is_skipped = prefix in ignored_layers + + assert is_skipped is not None + return is_skipped + + +def get_pack_factor(num_bits): + assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}" + return 32 // num_bits + + +def permute_rows(q_w: torch.Tensor, + w_ref: torch.Tensor, + group_size: int, + test_perm: Optional[torch.Tensor] = None): + assert q_w.shape == w_ref.shape + + orig_device = q_w.device + k_size, _ = q_w.shape + + g_idx = torch.zeros((k_size, ), dtype=torch.int32) + for i in range(k_size): + g_idx[i] = i // group_size + + # Simulate act_order by doing a random permutation on K + rand_perm = test_perm if test_perm is not None else torch.randperm(k_size) + + g_idx = g_idx[rand_perm].contiguous() + q_w = q_w[rand_perm, :].contiguous() + w_ref = w_ref[rand_perm, :].contiguous() + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + rand_perm.to(device=orig_device), + ) + + +def quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: Optional[int], + zero_points: bool = False, + ref_zero_points_after_scales: bool = False): + assert quant_type.is_integer(), \ + "Floating point quantization may work but has not been tested" + assert not zero_points or group_size is not None, \ + "to have group zero points, group_size must be provided "\ + "(-1 group_size is channelwise)" + + orig_device = w.device + orig_type = w.dtype + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + + if group_size == -1: + group_size = size_k + + # Reshape to [groupsize, -1] + if group_size is not None and group_size < size_k: + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + # Compute scale for each group + max_val = torch.max(w, 0, keepdim=True).values + min_val = torch.min(w, 0, keepdim=True).values + + max_q_val = quant_type.max() + min_q_val = quant_type.min() + + w_s = torch.Tensor([1.0]).to(w.device) # unscaled case + maybe_w_zp = None + if group_size is not None: + if zero_points: + assert not quant_type.is_signed() and quant_type.max() > 0 + w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max() + maybe_w_zp = torch.round(torch.abs(min_val / w_s)) \ + .clamp(min_q_val, max_q_val).int() + else: + # If the bias is such that there are no possible negative/positive + # values, set the max value to inf to avoid divide by 0 + w_s = torch.max( + abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)), + abs(min_val / (min_q_val if min_q_val != 0 else torch.inf))) + + # Quantize + w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0) + w_q = torch.clamp(w_q, min_q_val, max_q_val) + + # Compute ref (dequantized) + # For some kernels (namely Machete) the zero-points are applied after the + # scales are applied. For this case, computing the reference in a similar + # way allows us to use tighter error tolerances in our unit tests.
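    # Editorial note (not part of the original file): both branches below are
    # algebraically identical, e.g. with w_q = 3, zp = 8, s = 0.1 we get
    # 3 * 0.1 - 8 * 0.1 = (3 - 8) * 0.1 = -0.5; they differ only in the
    # floating-point rounding order, so the reference is computed in the same
    # order as the kernel under test applies scales and zero-points.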
+ if ref_zero_points_after_scales and maybe_w_zp is not None: + w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s + else: + w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s + + if quant_type.has_bias(): + w_q += quant_type.bias + + # Restore original shapes + if group_size is not None and group_size < size_k: + + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + w_q = reshape_w(w_q) + w_ref = reshape_w(w_ref) + w_s = w_s.reshape((-1, size_n)).contiguous() + + if maybe_w_zp is not None: + maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous() + maybe_w_zp = maybe_w_zp.to(device=orig_device) + + return ( + w_ref.to(device=orig_device), + w_q.to(device=orig_device), + w_s if group_size is not None else None, + maybe_w_zp, + ) + + +def gptq_quantize_weights(w: torch.Tensor, + quant_type: ScalarType, + group_size: int, + act_order: bool, + test_perm: Optional[torch.Tensor] = None): + size_k, _ = w.shape + + assert w.is_floating_point(), "w must be float" + assert quant_type in SUPPORTED_GPTQ_QUANT_TYPES, \ + f"Unsupported gptq type = {quant_type}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + w_ref, w_q, w_s, _ = quantize_weights(w, quant_type, group_size) + + # Apply act_order + g_idx = torch.empty(0, dtype=torch.int, device=w.device) + rand_perm = torch.empty(0, dtype=torch.int, device=w.device) + if act_order: + assert ( + group_size < size_k + ), "For act_order, groupsize = {} must be less than size_k = {}".format( + group_size, size_k) + + w_ref, w_q, g_idx, rand_perm = permute_rows(w_q, w_ref, group_size, + test_perm) + + return w_ref, w_q, w_s, g_idx, rand_perm + + +# QQQ employs different quant schemes for per-group and +# per-channel quantization. 
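# Editorial usage sketch (not part of the original file), assuming 4-bit
# per-group QQQ quantization of a small fp16 weight:
#
#   w = torch.randn(256, 128, dtype=torch.half)
#   w_ref, q_w, s_group, s_channel = qqq_quantize_weights(w, num_bits=4,
#                                                         group_size=64)
#   # q_w:       (256, 128) integer codes in [0, 15]
#   # s_group:   (4, 128) fp16 group scales (already divided by s_channel)
#   # s_channel: (1, 128) fp32 per-channel int8 scales
#   # w_ref:     (256, 128) fp16 reference dequantization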
+def qqq_quantize_weights(w: torch.Tensor, num_bits: int, group_size: int): + orig_device = w.device + size_k, size_n = w.shape + + assert w.is_floating_point(), "w must be float" + assert num_bits in MARLIN_QQQ_SUPPORTED_NUM_BITS, \ + f"Unsupported num_bits = {num_bits}" + assert group_size in SUPPORTED_GROUP_SIZES + [ + size_k + ], f"Unsupported groupsize = {group_size}" + + if group_size == -1: + group_size = size_k + assert group_size <= size_k + + if group_size < size_k: + # Reshape to [groupsize, -1] + w = w.reshape((-1, group_size, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((group_size, -1)) + + max_q_val = 2**num_bits - 1 + half_q_val = (max_q_val + 1) // 2 + + # Compute scale for each group + s_group = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_group *= 2 / max_q_val # 2 => symmetric + + # Quantize + q_w = torch.round(w / s_group).int() + q_w += half_q_val + q_w = torch.clamp(q_w, 0, max_q_val) + # Compute ref (dequantized) + w_ref = (q_w - half_q_val).half() * s_group + + # Restore original shapes + def reshape_w(w): + w = w.reshape((group_size, -1, size_n)) + w = w.permute(1, 0, 2) + w = w.reshape((size_k, size_n)).contiguous() + return w + + q_w = reshape_w(q_w) + w_ref = reshape_w(w_ref) + + # Compute int8 quantization scale for each channel + s_channel = torch.max(torch.abs(w_ref), 0, keepdim=True)[0] + s_channel /= 127.0 + t_int8 = (w_ref / s_channel).round().clamp(-128, 127).to(torch.int8) + w_ref = t_int8.half() * s_channel + s_channel = s_channel.reshape(1, -1).to(dtype=torch.float) + + # Fuse scales + s_group = (s_group.reshape(-1, size_n).contiguous() / + s_channel).to(dtype=torch.half) + else: + max_q_val = 2**(num_bits - 1) - 1 + + # Compute scale for each channel + s_channel = torch.max(torch.abs(w), 0, keepdim=True)[0] + s_channel /= max_q_val + + # Quantize + q_w = torch.round(w / s_channel).int() + q_w = torch.clamp(q_w, -max_q_val, max_q_val) + # Compute ref (dequantized) + w_ref = q_w.half() * s_channel + + s_group = torch.tensor([], dtype=torch.half) + # div 2 ** (8 - self.bits)) to offset right shift in unpacking + s_channel /= (2**(8 - num_bits)) + s_channel = s_channel.reshape(-1, size_n).contiguous().to(torch.float) + + return ( + w_ref.to(device=orig_device), + q_w.to(device=orig_device), + s_group.to(device=orig_device), + s_channel.to(device=orig_device), + ) + + +def sort_weights(q_w: torch.Tensor, g_idx: torch.Tensor): + orig_device = q_w.device + + sort_indices = torch.argsort(g_idx).to( + dtype=torch.int32) # Sort based on g_idx + + g_idx = g_idx[sort_indices].contiguous() + q_w = q_w[sort_indices, :].contiguous() + + return ( + q_w.to(device=orig_device), + g_idx.to(device=orig_device), + sort_indices.to(device=orig_device), + ) + + +def pack_rows( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_k % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + + q_res = numpy.zeros((size_k // pack_factor, size_n), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[i::pack_factor, :] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + return q_res + + +def pack_cols( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + + orig_device = q_w.device + + q_w = q_w.cpu().numpy().astype(numpy.uint32) + 
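    # Editorial note (not part of the original file): the loop below packs
    # pack_factor adjacent source columns into each int32 output column,
    # e.g. with num_bits = 4 (pack_factor = 8) packed column j holds source
    # columns 8*j .. 8*j + 7, column 8*j + i occupying bits [4*i, 4*i + 4).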
+ q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32) + + for i in range(pack_factor): + q_res |= q_w[:, i::pack_factor] << num_bits * i + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def unpack_cols( + packed_q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + pack_factor = get_pack_factor(num_bits) + assert size_n % pack_factor == 0 + assert packed_q_w.shape == ( + size_k, size_n // pack_factor + ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format( + packed_q_w.shape, size_k, size_n, pack_factor) + + orig_device = packed_q_w.device + + packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) + q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32) + + mask = (1 << num_bits) - 1 + for i in range(pack_factor): + vals = packed_q_w_cpu & mask + packed_q_w_cpu >>= num_bits + q_res[:, i::pack_factor] = vals + + q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device) + q_res = q_res.contiguous() + + return q_res + + +def gptq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + return pack_rows(q_w, num_bits, size_k, size_n) + + +def awq_pack( + q_w: torch.Tensor, + num_bits: int, + size_k: int, + size_n: int, +): + assert q_w.shape == (size_k, size_n) + + # Interleave column dim (for the dequantize code) and pack it to int32 + if num_bits == 4: + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + elif num_bits == 8: + interleave = numpy.array([0, 2, 1, 3]) + else: + raise Exception("num_bits must be 4 or 8, got {}".format(num_bits)) + + q_w = q_w.reshape((-1, len(interleave)))[:, interleave].ravel() + q_w = q_w.reshape((-1, size_n)).contiguous() + + return pack_cols(q_w, num_bits, size_k, size_n) diff --git a/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/w8a8_utils.py b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/w8a8_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dedeb0c296bd45f032b8b4be50203392d0418b4c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/vllm/model_executor/layers/quantization/utils/w8a8_utils.py @@ -0,0 +1,241 @@ +# SPDX-License-Identifier: Apache-2.0 + +from typing import List, Optional, Tuple, Union + +import torch + +from vllm import _custom_ops as ops +from vllm.platforms import current_platform + +# Input scaling factors are no longer optional in _scaled_mm starting +# from pytorch 2.5. 
Allocating a dummy tensor to pass as input_scale +TORCH_DEVICE_IDENTITY = torch.ones(1, dtype=torch.float32) + + +def sparse_cutlass_supported() -> bool: + if not current_platform.is_cuda(): + return False + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + + return ops.cutlass_sparse_scaled_mm_supported(capability) + + +def cutlass_fp8_supported() -> bool: + if not current_platform.is_cuda(): + return False + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + + return ops.cutlass_scaled_mm_supports_fp8(capability) + + +def cutlass_block_fp8_supported() -> bool: + if not current_platform.is_cuda(): + return False + + capability_tuple = current_platform.get_device_capability() + capability = -1 if capability_tuple is None else capability_tuple.to_int() + + return ops.cutlass_scaled_mm_supports_block_fp8(capability) + + +CUTLASS_FP8_SUPPORTED = cutlass_fp8_supported() +CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported() + + +def per_tensor_dequantize( + tensor: torch.Tensor, inv_scale: Union[float, + torch.Tensor]) -> torch.Tensor: + fake_qweight = tensor.to(torch.float16) + dq_weight = fake_qweight * inv_scale + return dq_weight + + +def all_close_1d(x: torch.Tensor) -> bool: + assert len(x.shape) == 1 + return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0])) + + +def convert_to_channelwise( + weight_scale: torch.Tensor, + logical_widths: List[int]) -> torch.Tensor: + # Create channelwise buffer + weight_scale_channel = torch.empty((sum(logical_widths), 1), + dtype=torch.float32, + device=weight_scale.device) + + # Expand each scale to match the size of each logical matrix. + start = 0 + for idx, logical_width in enumerate(logical_widths): + end = start + logical_width + weight_scale_channel[start:end, :] = weight_scale[idx] + start = end + + return weight_scale_channel + + +def requantize_with_max_scale( + weight: torch.Tensor, weight_scale: torch.Tensor, + logical_widths: List[int]) -> Tuple[torch.Tensor, torch.Tensor]: + # Max scale to be used for requantization. + max_w_scale = weight_scale.max() + + # QKV / MLP is fused in the on-disk checkpoint if any of the + # weight scales are still set to the default since we initialize + # N weight scales for N shards but we only load 1 weight scale + # from disk in this case. Skip requantization in this case since + # we are already quantized with the single scale. + # * Sample Model: nm-testing/Phi-3-mini-128k-instruct-FP8 + unfused_module_in_checkpoint = (weight_scale[-1] + > torch.finfo(torch.float8_e4m3fn).min) + + # If unfused checkpoint, we need to requantize with the single scale. + if unfused_module_in_checkpoint: + start = 0 + for idx, logical_width in enumerate(logical_widths): + end = start + logical_width + weight_dq = per_tensor_dequantize(weight[start:end, :], + weight_scale[idx]) + weight[start:end, :], _ = ops.scaled_fp8_quant( + weight_dq, max_w_scale) + start = end + + return max_w_scale, weight + + +def apply_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None, + input_scale_ub: Optional[torch.Tensor] = None, + bias: Optional[torch.Tensor] = None, + cutlass_fp8_supported: bool = CUTLASS_FP8_SUPPORTED, + use_per_token_if_dynamic: bool = False, +) -> torch.Tensor: + # ops.scaled_fp8_quant supports both dynamic and static quant.
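# (Roughly, the dynamic path derives the scale from the data itself, e.g. a
# per-tensor scale of x.abs().max() / torch.finfo(fp8_dtype).max, or a
# per-token row maximum when use_per_token_if_dynamic=True.)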
+ # If dynamic, layer.input_scale is None and x_scale computed from x. + # If static, layer.input_scale is scalar and x_scale is input_scale. + + # View input as 2D matrix for fp8 methods + input_2d = input.view(-1, input.shape[-1]) + output_shape = [*input.shape[:-1], weight.shape[1]] + + # cutlass_scaled_mm supports per tensor/channel W and per tensor/token A + if cutlass_fp8_supported: + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + scale_ub=input_scale_ub, + use_per_token_if_dynamic=use_per_token_if_dynamic) + + # Fused GEMM_DQ + output = ops.cutlass_scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) + return output.view(*output_shape) + + # torch.scaled_mm supports per tensor weights + activations only + # so fallback to naive if per channel or per token + else: + # Note: we pad the input because torch._scaled_mm is more performant + # for matrices with batch dimension > 16. + # This could change in the future. + qinput, x_scale = ops.scaled_fp8_quant( + input_2d, + input_scale, + num_token_padding=17, + use_per_token_if_dynamic=use_per_token_if_dynamic) + + per_tensor_weights = (weight_scale.numel() == 1) + per_tensor_activations = (x_scale.numel() == 1) + + if per_tensor_weights and per_tensor_activations: + # Fused GEMM_DQ + output = torch._scaled_mm(qinput, + weight, + out_dtype=input.dtype, + scale_a=x_scale, + scale_b=weight_scale, + bias=bias) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + + return torch.narrow(output, 0, 0, + input_2d.shape[0]).view(*output_shape) + + else: + # Fallback for channelwise case, where we use unfused DQ + # due to limitations with scaled_mm + + # Symmetric quantized GEMM by definition computes the following: + # C = (s_x * X) (s_w * W) + bias + # This is equivalent to dequantizing the weights and activations + # before applying a GEMM. + # + # In order to compute quantized operands, a quantized kernel + # will rewrite the above like so: + # C = s_w * s_x * (X * W) + bias + # + # For the scaled_mm fallback case, we break this down, since it + # does not support s_w being a vector. + + # Making sure the dummy tensor is on the same device as the weight + global TORCH_DEVICE_IDENTITY + if TORCH_DEVICE_IDENTITY.device != weight.device: + TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device) + + # GEMM + # This computes C = (X * W). 
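# Identity (1.0) scales are passed so that _scaled_mm returns the raw
# product; the real per-token x_scale and per-channel weight_scale are
# applied in the dequantization step below.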
+ # Output in fp32 to allow subsequent ops to happen in-place + output = torch._scaled_mm(qinput, + weight, + scale_a=TORCH_DEVICE_IDENTITY, + scale_b=TORCH_DEVICE_IDENTITY, + out_dtype=torch.float32) + # A fix for discrepancy in scaled_mm which returns tuple + # for torch < 2.5 and a single value in torch >= 2.5 + if type(output) is tuple and len(output) == 2: + output = output[0] + # Unpad (undo num_token_padding) + output = torch.narrow(output, 0, 0, input_2d.shape[0]) + x_scale = torch.narrow(x_scale, 0, 0, input_2d.shape[0]) + + # DQ + # C = sw * sx * (X * W) + bias + output = output * x_scale * weight_scale.t() + if bias is not None: + output = output + bias + return output.to(dtype=input.dtype).view(*output_shape) + + +def normalize_e4m3fn_to_e4m3fnuz( + weight: torch.Tensor, + weight_scale: torch.Tensor, + input_scale: Optional[torch.Tensor] = None +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + assert weight.dtype == torch.float8_e4m3fn + # The bits pattern 10000000(-128) represents zero in e4m3fn + # but NaN in e4m3fnuz. So here we set it to 0. + # https://onnx.ai/onnx/technical/float8.html + weight_as_int8 = weight.view(torch.int8) + ROCM_FP8_NAN_AS_INT = -128 + weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0 + weight = weight_as_int8.view(torch.float8_e4m3fnuz) + + # For the same bits representation, e4m3fnuz value is half of + # the e4m3fn value, so we should double the scaling factor to + # get the same dequantized value. + # https://onnx.ai/onnx/technical/float8.html + weight_scale = weight_scale * 2.0 + if input_scale is not None: + input_scale = input_scale * 2.0 + return weight, weight_scale, input_scale
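To make the packing layout above concrete, here is a minimal, self-contained sketch of the column-packing scheme that pack_cols and unpack_cols implement, assuming num_bits=4 (eight 4-bit values per uint32). The shapes and data are made up for illustration, and the snippet uses plain NumPy rather than the vllm helpers:

import numpy

num_bits = 4
pack_factor = 32 // num_bits          # eight 4-bit values per uint32
size_k, size_n = 4, 16                # toy shapes; size_n % pack_factor == 0

rng = numpy.random.default_rng(0)
q_w = rng.integers(0, 2**num_bits, size=(size_k, size_n), dtype=numpy.uint32)

# Pack: source column j*pack_factor + i lands in bits
# [num_bits*i, num_bits*(i+1)) of packed column j, the same loop as pack_cols.
packed = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
for i in range(pack_factor):
    packed |= q_w[:, i::pack_factor] << numpy.uint32(num_bits * i)

# Unpack: peel off num_bits at a time, mirroring unpack_cols.
mask = numpy.uint32((1 << num_bits) - 1)
tmp = packed.copy()
unpacked = numpy.zeros_like(q_w)
for i in range(pack_factor):
    unpacked[:, i::pack_factor] = tmp & mask
    tmp >>= numpy.uint32(num_bits)

assert numpy.array_equal(unpacked, q_w)   # lossless round trip

Similarly, a quick numerical check of the scale-factoring identity that the channelwise fallback in apply_fp8_linear relies on, C = (s_x * X_q)(s_w * W_q) + bias = s_x * s_w^T * (X_q W_q) + bias, using made-up fp32 tensors as stand-ins for the real fp8 operands:

import torch

M, K, N = 4, 8, 6
x_q = torch.randn(M, K)          # stand-in for quantized activations
w_q = torch.randn(K, N)          # stand-in for quantized weights
x_scale = torch.rand(M, 1)       # per-token activation scales
weight_scale = torch.rand(N, 1)  # per-output-channel weight scales

dq_first = (x_scale * x_q) @ (w_q * weight_scale.t())   # dequantize, then GEMM
gemm_first = (x_q @ w_q) * x_scale * weight_scale.t()   # GEMM, then scale (fallback order)
assert torch.allclose(dq_first, gemm_first, atol=1e-5)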