server / source /compressed_tensors /linear /compressed_linear.py
Harmony18090's picture
Add source batch 2/11
76f9669 verified
raw
history blame
4.05 kB
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import Dict, Tuple
import torch
from compressed_tensors.compressors.base import BaseCompressor
from compressed_tensors.quantization import (
QuantizationScheme,
QuantizationStatus,
initialize_module_for_quantization,
)
from compressed_tensors.utils import register_offload_parameter
from compressed_tensors.utils.offload import get_execution_device
from torch import Tensor
from torch.nn import Parameter
from torch.nn.functional import linear
from torch.nn.modules import Linear
class CompressedLinear(Linear):
    """
    Wrapper module for running a compressed forward pass of a quantized Linear module.

    The compressed weight is decompressed on the first forward call. After that
    call the module keeps the decompressed weight and is marked as frozen, so
    subsequent forward calls skip decompression.
    """

    def __init__(self, *args, **kwargs) -> None:
        """
        Build a regular Linear layer and warn that this is not the supported path.

        All arguments are forwarded unchanged to ``torch.nn.Linear``. Instances
        are expected to be created through ``from_linear`` instead.
        """
        super().__init__(*args, **kwargs)
        # stacklevel=2 attributes the warning to the code that constructed the
        # module rather than to this line inside the library
        warnings.warn(
            "CompressedLinear should not be initialized directly. "
            "Use the from_linear method instead.",
            UserWarning,
            stacklevel=2,
        )

    @classmethod
    @torch.no_grad()
    def from_linear(
        cls,
        module: Linear,
        quantization_scheme: QuantizationScheme,
        quantization_format: str,
    ):
        """
        Convert a dense linear module, in place, into a compressed linear module.

        :param module: dense linear module to replace
        :param quantization_scheme: quantization config for the module to wrap
        :param quantization_format: compression format module is stored as
        :return: the same module object, re-classed to ``cls`` and holding
            uninitialized compressed parameters in place of its dense weight
        """
        # swap the class in place (no __init__ call, so no warning is emitted);
        # use cls rather than a hard-coded class so subclasses keep their type
        module.__class__ = cls
        module.compressor = BaseCompressor.load_from_registry(quantization_format)

        # resolved before the dense weight is removed below
        init_device = get_execution_device(module)

        # this will initialize all the scales and zero points
        initialize_module_for_quantization(
            module, quantization_scheme, force_zero_point=False
        )

        # get the shape and dtype of compressed parameters
        compression_params: Dict[str, Tuple] = module.compressor.compression_param_info(
            module.weight.shape, quantization_scheme.weights
        )

        # no need for this once quantization is initialized, will be replaced
        # with the compressed parameter
        delattr(module, "weight")

        # populate compressed weights and quantization parameters; the tensors
        # are allocated with torch.empty, so their contents are uninitialized
        # until real values are loaded into them
        for name, (shape, dtype) in compression_params.items():
            param = Parameter(
                torch.empty(shape, device=init_device, dtype=dtype),
                requires_grad=False,
            )
            register_offload_parameter(module, name, param)

        # mark module as compressed
        module.quantization_status = QuantizationStatus.COMPRESSED

        # handles case where forward is wrapped in new_forward by accelerate hooks:
        # rebind the saved original forward to this class's forward
        if hasattr(module, "_old_forward"):
            module._old_forward = cls.forward.__get__(module, cls)

        return module

    def forward(self, input: Tensor) -> Tensor:
        """
        Decompress the weight if it is still compressed, then run the linear pass.

        :param input: input tensor passed to ``torch.nn.functional.linear``
        :return: result of ``linear(input, self.weight, self.bias)``
        """
        if self.quantization_status == QuantizationStatus.COMPRESSED:
            # one-time decompression: register the decompressed tensor as
            # "weight" and flip the status so this branch is skipped next time
            weight_data = self.compressor.decompress_module(self)
            param = Parameter(weight_data, requires_grad=False)
            register_offload_parameter(self, "weight", param)
            self.quantization_status = QuantizationStatus.FROZEN

        return linear(input, self.weight, self.bias)