| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import warnings |
| | from typing import Dict, Tuple |
| |
|
| | import torch |
| | from compressed_tensors.compressors.base import BaseCompressor |
| | from compressed_tensors.quantization import ( |
| | QuantizationScheme, |
| | QuantizationStatus, |
| | initialize_module_for_quantization, |
| | ) |
| | from compressed_tensors.utils import register_offload_parameter |
| | from compressed_tensors.utils.offload import get_execution_device |
| | from torch import Tensor |
| | from torch.nn import Parameter |
| | from torch.nn.functional import linear |
| | from torch.nn.modules import Linear |
| |
|
| |
|
class CompressedLinear(Linear):
    """
    Wrapper module for running a compressed forward pass of a quantized Linear
    module.

    The weight is stored in compressed form and decompressed lazily: the first
    forward call materializes a dense ``weight`` parameter from the compressed
    parameters and flips ``quantization_status`` to ``FROZEN``, so subsequent
    calls reuse the cached dense weight instead of decompressing again.
    """

    def __init__(self, *args, **kwargs) -> None:
        # Direct construction produces a plain Linear with no compressor or
        # compressed parameters attached — only from_linear() sets those up,
        # hence the warning.
        super().__init__(*args, **kwargs)
        warnings.warn(
            "CompressedLinear should not be initialized directly. "
            "Use the from_linear method instead.",
            UserWarning,
        )

    @classmethod
    @torch.no_grad()
    def from_linear(
        cls,
        module: Linear,
        quantization_scheme: QuantizationScheme,
        quantization_format: str,
    ):
        """
        Convert a dense ``Linear`` module into a ``CompressedLinear`` in place.

        :param module: dense linear module to replace
        :param quantization_scheme: quantization config for the module to wrap
        :param quantization_format: compression format module is stored as
        :return: CompressedLinear module wrapping the input module
        """
        # Mutate the class in place rather than constructing a new module, so
        # existing references to `module` (including any attached hooks)
        # remain valid.
        module.__class__ = CompressedLinear
        module.compressor = BaseCompressor.load_from_registry(quantization_format)
        init_device = get_execution_device(module)

        # Initialize quantization parameters (scales/zero points) on the
        # module for the given scheme.
        initialize_module_for_quantization(
            module, quantization_scheme, force_zero_point=False
        )

        # Query the compressor for the name -> (shape, dtype) layout of the
        # compressed parameters that will replace the dense weight.
        compression_params: Dict[str, Tuple] = module.compressor.compression_param_info(
            module.weight.shape, quantization_scheme.weights
        )

        # The dense weight is no longer needed; it is re-created lazily on the
        # first forward pass from the compressed parameters.
        delattr(module, "weight")

        # Register empty compressed parameters; offload-aware registration so
        # they land on the module's execution device.
        for name, (shape, dtype) in compression_params.items():
            param = Parameter(
                torch.empty(shape, device=init_device, dtype=dtype), requires_grad=False
            )
            register_offload_parameter(module, name, param)

        # Mark the module so forward() knows to decompress on first call.
        module.quantization_status = QuantizationStatus.COMPRESSED

        # NOTE(review): `_old_forward` appears to be the pre-hook forward saved
        # by accelerate-style hook wrapping — confirm. Rebinding it to
        # CompressedLinear.forward makes the wrapped dispatch hit the
        # compressed forward pass instead of the original Linear.forward.
        if hasattr(module, "_old_forward"):
            module._old_forward = CompressedLinear.forward.__get__(
                module, CompressedLinear
            )

        return module

    def forward(self, input: Tensor) -> Tensor:
        """
        Run the linear forward pass, decompressing the weight on first use.

        On the first call (status COMPRESSED) the compressed parameters are
        decompressed into a dense ``weight`` parameter, cached on the module,
        and the status is set to FROZEN so later calls skip decompression.
        """
        if self.quantization_status == QuantizationStatus.COMPRESSED:
            # One-time materialization of the dense weight.
            weight_data = self.compressor.decompress_module(self)
            param = Parameter(weight_data, requires_grad=False)
            register_offload_parameter(self, "weight", param)

            # Decompression is done; subsequent forwards use the cached weight.
            self.quantization_status = QuantizationStatus.FROZEN

        return linear(input, self.weight, self.bias)
| |
|