| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import logging |
| |
|
| | import torch |
| | from compressed_tensors.quantization.lifecycle.forward import quantize |
| | from compressed_tensors.quantization.quant_config import QuantizationStatus |
| | from torch.nn import Module |
| |
|
| |
|
| | __all__ = [ |
| | "compress_quantized_weights", |
| | ] |
| |
|
| |
|
| | _LOGGER = logging.getLogger(__name__) |
| |
|
| |
|
def compress_quantized_weights(module: Module) -> None:
    """
    Quantizes the module weight representation to use fewer bits in memory

    apply to full model with `model.apply(compress_quantized_weights)`

    :param module: module to compress to quantized representation
    """
    scheme = getattr(module, "quantization_scheme", None)
    if not scheme or not scheme.weights:
        # no quantization scheme attached, or scheme does not cover weights:
        # nothing to compress for this module
        return

    status = getattr(module, "quantization_status", None)
    if status is QuantizationStatus.COMPRESSED:
        # already compressed; re-quantizing the stored int8 data would corrupt it
        return

    weight = getattr(module, "weight", None)
    scale = getattr(module, "weight_scale", None)
    # zero_point / g_idx may legitimately be absent (e.g. symmetric schemes,
    # no group reordering) — quantize() accepts None for both
    zero_point = getattr(module, "weight_zero_point", None)
    g_idx = getattr(module, "weight_g_idx", None)

    if weight is None or scale is None:
        # a weight scheme exists but the parameters needed for compression are
        # missing; warn rather than failing silently, then mark the module so
        # repeated passes over the model do not reprocess it
        _LOGGER.warning(
            "Skipping compression for module with quantization scheme but "
            "missing weight or weight_scale attributes"
        )
        module.quantization_status = QuantizationStatus.COMPRESSED
        return

    # compressed weights are a storage format, not trainable parameters
    module.weight.requires_grad = False
    module.weight.data = quantize(
        x=weight,
        scale=scale,
        zero_point=zero_point,
        g_idx=g_idx,
        args=scheme.weights,
        dtype=torch.int8,
    )

    module.quantization_status = QuantizationStatus.COMPRESSED
| |
|