|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Adapted from |
|
|
https://github.com/huggingface/transformers/blob/3a8eb74668e9c2cc563b2f5c62fac174797063e0/src/transformers/quantizers/quantizer_torchao.py |
|
|
""" |
|
|
|
|
|
import importlib |
|
|
import types |
|
|
from fnmatch import fnmatch |
|
|
from typing import TYPE_CHECKING, Any, Dict, List, Union |
|
|
|
|
|
from packaging import version |
|
|
|
|
|
from ...utils import ( |
|
|
get_module_from_name, |
|
|
is_torch_available, |
|
|
is_torch_version, |
|
|
is_torchao_available, |
|
|
is_torchao_version, |
|
|
logging, |
|
|
) |
|
|
from ..base import DiffusersQuantizer |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from ...models.modeling_utils import ModelMixin |
|
|
|
|
|
|
|
|
if is_torch_available():
    import torch
    import torch.nn as nn

    # torch 2.5 introduced the sub-byte unsigned integer dtypes (uint1..uint7),
    # so the set of dtypes eligible for quantization depends on the installed
    # torch version.
    if is_torch_version(">=", "2.5"):
        SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
            torch.int8,
            torch.float8_e4m3fn,
            torch.float8_e5m2,
            torch.uint1,
            torch.uint2,
            torch.uint3,
            torch.uint4,
            torch.uint5,
            torch.uint6,
            torch.uint7,
        )
    else:
        # Older torch: only int8 and the two float8 variants are available.
        SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION = (
            torch.int8,
            torch.float8_e4m3fn,
            torch.float8_e5m2,
        )
|
|
|
|
|
if is_torchao_available(): |
|
|
from torchao.quantization import quantize_ |
|
|
|
|
|
|
|
|
def _update_torch_safe_globals():
    """Register torch sub-byte dtypes and torchao tensor subclasses as safe globals.

    This allows checkpoints serialized with torchao to be loaded via
    `torch.load(weights_only=True)`. The registration of the plain torch dtypes
    happens unconditionally (in the `finally` block); the torchao tensor
    subclasses are added only when they can be imported.
    """
    safe_globals = [
        (torch.uint1, "torch.uint1"),
        (torch.uint2, "torch.uint2"),
        (torch.uint3, "torch.uint3"),
        (torch.uint4, "torch.uint4"),
        (torch.uint5, "torch.uint5"),
        (torch.uint6, "torch.uint6"),
        (torch.uint7, "torch.uint7"),
    ]
    try:
        from torchao.dtypes import NF4Tensor
        from torchao.dtypes.floatx.float8_layout import Float8AQTTensorImpl
        from torchao.dtypes.uintx.uint4_layout import UInt4Tensor
        from torchao.dtypes.uintx.uintx_layout import UintxAQTTensorImpl, UintxTensor

        safe_globals.extend([UintxTensor, UInt4Tensor, UintxAQTTensorImpl, Float8AQTTensorImpl, NF4Tensor])

    except (ImportError, ModuleNotFoundError) as e:
        # Fix: the module-level `logger` is assigned *after* this function is
        # invoked at import time, so referencing it here raised NameError on the
        # failure path. Obtain a logger locally instead.
        local_logger = logging.get_logger(__name__)
        local_logger.warning(
            "Unable to import `torchao` Tensor objects. This may affect loading checkpoints serialized with `torchao`"
        )
        local_logger.debug(e)

    finally:
        # Always register at least the plain torch dtypes.
        torch.serialization.add_safe_globals(safe_globals=safe_globals)
|
|
|
|
|
|
|
|
# Serialized torchao checkpoints contain tensor subclasses, which
# `torch.load(weights_only=True)` only accepts once they are registered as safe
# globals. `add_safe_globals` is available from torch 2.6.0, and the tensor
# subclass import paths used in `_update_torch_safe_globals` require
# torchao >= 0.7.0, hence the version guards.
if (
    is_torch_available()
    and is_torch_version(">=", "2.6.0")
    and is_torchao_available()
    and is_torchao_version(">=", "0.7.0")
):
    _update_torch_safe_globals()


logger = logging.get_logger(__name__)
|
|
|
|
|
|
|
|
def _quantization_type(weight):
    """Return a short human-readable description of how ``weight`` is quantized.

    Returns ``None`` when ``weight`` is not a recognized torchao quantized
    tensor subclass.
    """
    from torchao.dtypes import AffineQuantizedTensor
    from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor

    cls_name = weight.__class__.__name__

    if isinstance(weight, AffineQuantizedTensor):
        return f"{cls_name}({weight._quantization_type()})"

    if isinstance(weight, LinearActivationQuantizedTensor):
        inner = _quantization_type(weight.original_weight_tensor)
        return f"{cls_name}(activation={weight.input_quant_func}, weight={inner})"

    return None
|
|
|
|
|
|
|
|
def _linear_extra_repr(self):
    """extra_repr replacement for quantized ``nn.Linear`` modules that also
    shows the quantization scheme of the weight."""
    out_features, in_features = self.weight.shape[0], self.weight.shape[1]
    qtype = _quantization_type(self.weight)
    weight_desc = "None" if qtype is None else qtype
    return f"in_features={in_features}, out_features={out_features}, weight={weight_desc}"
|
|
|
|
|
|
|
|
class TorchAoHfQuantizer(DiffusersQuantizer):
    r"""
    Diffusers Quantizer for TorchAO: https://github.com/pytorch/ao/.
    """

    requires_calibration = False
    required_packages = ["torchao"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

    def validate_environment(self, *args, **kwargs):
        """Check that torchao is installed and recent enough, and that the
        requested `device_map`/`weights_only` combination is supported.

        Raises:
            ImportError: torchao is not installed.
            RuntimeError: torchao < 0.7.0, or torch < 2.5.0 when loading a
                pre-quantized checkpoint with `weights_only=True`.
            ValueError: cpu/disk offload requested for a pre-quantized model.
        """
        if not is_torchao_available():
            raise ImportError(
                "Loading a TorchAO quantized model requires the torchao library. Please install with `pip install torchao`"
            )
        # Fix: this previously queried `importlib.metadata.version("torch")`, so
        # the torchao minimum-version requirement was checked against the wrong
        # package and never actually enforced.
        torchao_version = version.parse(importlib.metadata.version("torchao"))
        if torchao_version < version.parse("0.7.0"):
            raise RuntimeError(
                f"The minimum required version of `torchao` is 0.7.0, but the current version is {torchao_version}. Please upgrade with `pip install -U torchao`."
            )

        self.offload = False

        device_map = kwargs.get("device_map", None)
        if isinstance(device_map, dict):
            if "cpu" in device_map.values() or "disk" in device_map.values():
                if self.pre_quantized:
                    raise ValueError(
                        "You are attempting to perform cpu/disk offload with a pre-quantized torchao model "
                        "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
                    )
                else:
                    self.offload = True

        if self.pre_quantized:
            weights_only = kwargs.get("weights_only", None)
            if weights_only:
                torch_version = version.parse(importlib.metadata.version("torch"))
                if torch_version < version.parse("2.5.0"):
                    # TorchAO tensor subclasses can only be deserialized with
                    # `weights_only=True` starting from torch 2.5.
                    raise RuntimeError(
                        f"In order to use TorchAO pre-quantized model, you need to have torch>=2.5.0. However, the current version is {torch_version}."
                    )

    def update_torch_dtype(self, torch_dtype):
        """Select/validate the `torch_dtype` used for the non-quantized layers.

        For int/uint quantization schemes only bfloat16 is supported right now,
        so a missing `torch_dtype` is overridden to `torch.bfloat16` (with a
        warning). The (possibly overridden) dtype is returned.
        """
        quant_type = self.quantization_config.quant_type

        if quant_type.startswith("int") or quant_type.startswith("uint"):
            if torch_dtype is not None and torch_dtype != torch.bfloat16:
                logger.warning(
                    f"You are trying to set torch_dtype to {torch_dtype} for int4/int8/uintx quantization, but "
                    f"only bfloat16 is supported right now. Please set `torch_dtype=torch.bfloat16`."
                )

        if torch_dtype is None:
            # A concrete dtype is required so the remaining (non-quantized)
            # layers are loaded in a known precision.
            logger.warning(
                "Overriding `torch_dtype` with `torch_dtype=torch.bfloat16` due to requirements of `torchao` "
                "to enable model loading in different precisions. Pass your own `torch_dtype` to specify the "
                "dtype of the remaining non-linear layers, or pass torch_dtype=torch.bfloat16, to remove this warning."
            )
            torch_dtype = torch.bfloat16

        return torch_dtype

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        """Map the quantization scheme to the dtype accelerate should use when
        estimating per-parameter memory for automatic device-map placement.

        Raises:
            ValueError: when no suitable target dtype can be inferred.
        """
        quant_type = self.quantization_config.quant_type

        if quant_type.startswith("int8") or quant_type.startswith("int4"):
            # int4 weights are packed into torch.int8; there is no torch.int4.
            return torch.int8
        elif quant_type == "uintx_weight_only":
            return self.quantization_config.quant_type_kwargs.get("dtype", torch.uint8)
        elif quant_type.startswith("uint"):
            # quant_type looks like "uintN..." — pick the matching sub-byte dtype.
            return {
                1: torch.uint1,
                2: torch.uint2,
                3: torch.uint3,
                4: torch.uint4,
                5: torch.uint5,
                6: torch.uint6,
                7: torch.uint7,
            }[int(quant_type[4])]
        elif quant_type.startswith("float") or quant_type.startswith("fp"):
            return torch.bfloat16

        # Fix: this was `isinstance(target_dtype, SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION)`,
        # which raises TypeError because the tuple contains torch.dtype *instances*,
        # not classes. A membership test is what is intended here.
        if target_dtype in SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION:
            return target_dtype

        # We need one of the supported dtypes to be selected in order for
        # accelerate to determine the total size of modules/parameters.
        possible_device_maps = ["auto", "balanced", "balanced_low_0", "sequential"]
        raise ValueError(
            f"You have set `device_map` as one of {possible_device_maps} on a TorchAO quantized model but a suitable target dtype "
            f"could not be inferred. The supported target_dtypes are: {SUPPORTED_TORCH_DTYPES_FOR_QUANTIZATION}. If you think the "
            f"dtype you are using should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
        )

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        """Scale per-device memory budgets down by 10% to leave headroom for
        quantization overhead during loading."""
        max_memory = {key: val * 0.9 for key, val in max_memory.items()}
        return max_memory

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        """Return True if `param_name` is an nn.Linear weight that should be
        handled by `create_quantized_param`.

        Relies on `self.modules_to_not_convert`, which is populated in
        `_process_model_before_weight_loading`.
        """
        param_device = kwargs.pop("param_device", None)
        # Check if the param_name is not in self.modules_to_not_convert
        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
            return False
        elif param_device == "cpu" and self.offload:
            # We don't quantize weights that we offload
            return False
        else:
            # We only quantize the weight of nn.Linear
            module, tensor_name = get_module_from_name(model, param_name)
            return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")

    def create_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: List[str],
        **kwargs,
    ):
        r"""
        Each nn.Linear layer that needs to be quantized is processed here. First, we set the value of the weight
        tensor, then we move it to the target device. Finally, we quantize the module.
        """
        module, tensor_name = get_module_from_name(model, param_name)

        if self.pre_quantized:
            # If the parameter is already quantized (a torchao tensor subclass
            # loaded from the checkpoint), just move it to the target device and
            # patch the repr so the quantization scheme is visible.
            module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
            if isinstance(module, nn.Linear):
                module.extra_repr = types.MethodType(_linear_extra_repr, module)
        else:
            # Plain weight: set it, move it, then quantize the module in place.
            module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
            quantize_(module, self.quantization_config.get_apply_tensor_subclass())

    def get_cuda_warm_up_factor(self):
        """
        This factor is used in caching_allocator_warmup to determine how many bytes to pre-allocate for CUDA warmup.
        - A factor of 2 means we pre-allocate the full memory footprint of the model.
        - A factor of 4 means we pre-allocate half of that, and so on

        However, when using TorchAO, calculating memory usage with param.numel() * param.element_size() doesn't give
        the correct size for quantized weights (like int4 or int8) That's because TorchAO internally represents
        quantized tensors using subtensors and metadata, and the reported element_size() still corresponds to the
        torch_dtype not the actual bit-width of the quantized data.

        To correct for this:
        - Use a division factor of 8 for int4 weights
        - Use a division factor of 4 for int8 weights
        """
        # Patterns are matched in insertion order against the quant_type name.
        map_to_target_dtype = {"int4_*": 8, "int8_*": 4, "uint*": 8, "float8*": 4}
        quant_type = self.quantization_config.quant_type
        for pattern, target_dtype in map_to_target_dtype.items():
            if fnmatch(quant_type, pattern):
                return target_dtype
        raise ValueError(f"Unsupported quant_type: {quant_type!r}")

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        device_map,
        keep_in_fp32_modules: Union[List[str], None] = None,
        **kwargs,
    ):
        """Collect the modules that must not be quantized and attach the
        quantization config to the model config.

        Fixes over the previous version: no mutable default argument, and the
        list stored on the quantization config is copied instead of being
        extended in place (which mutated shared state on the config object).
        """
        self.modules_to_not_convert = self.quantization_config.modules_to_not_convert

        if not isinstance(self.modules_to_not_convert, list):
            self.modules_to_not_convert = [self.modules_to_not_convert]
        else:
            # Copy so we never mutate the list owned by the quantization config.
            self.modules_to_not_convert = list(self.modules_to_not_convert)

        self.modules_to_not_convert.extend(keep_in_fp32_modules or [])

        # Extend `self.modules_to_not_convert` to keys that are supposed to be
        # offloaded to cpu or disk.
        if isinstance(device_map, dict) and len(device_map.keys()) > 1:
            keys_on_cpu = [key for key, value in device_map.items() if value in ["disk", "cpu"]]
            self.modules_to_not_convert.extend(keys_on_cpu)

        # Purge `None` entries (e.g. when the config attribute was None).
        self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]

        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "ModelMixin"):
        """No post-processing is needed for torchao; return the model unchanged."""
        return model

    def is_serializable(self, safe_serialization=None):
        """Return whether the quantized model can be serialized.

        Safetensors serialization is not supported; plain serialization requires
        huggingface_hub >= 0.25.0 and no offloaded, unlisted modules.
        """
        if safe_serialization:
            logger.warning(
                "torchao quantized model does not support safe serialization, please set `safe_serialization` to False."
            )
            return False

        _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
            "0.25.0"
        )

        if not _is_torchao_serializable:
            logger.warning("torchao quantized model is only serializable after huggingface_hub >= 0.25.0 ")

        if self.offload and self.quantization_config.modules_to_not_convert is None:
            logger.warning(
                "The model contains offloaded modules and these modules are not quantized. We don't recommend saving the model as we won't be able to reload them."
                "If you want to specify modules to not quantize, please specify modules_to_not_convert in the quantization_config."
            )
            return False

        return _is_torchao_serializable

    @property
    def is_trainable(self):
        # Only int8 schemes currently support further training.
        return self.quantization_config.quant_type.startswith("int8")

    @property
    def is_compileable(self) -> bool:
        # torchao-quantized models work with torch.compile.
        return True
|
|
|