from typing import TYPE_CHECKING, Any, Dict, List, Union

from ...utils import (
    get_module_from_name,
    is_accelerate_available,
    is_nvidia_modelopt_available,
    is_torch_available,
    logging,
)
from ..base import DiffusersQuantizer


if TYPE_CHECKING:
    from ...models.modeling_utils import ModelMixin

if is_torch_available():
    import torch
    import torch.nn as nn

if is_accelerate_available():
    from accelerate.utils import set_module_tensor_to_device

logger = logging.get_logger(__name__)


class NVIDIAModelOptQuantizer(DiffusersQuantizer):
    r"""
    Diffusers Quantizer for NVIDIA TensorRT Model Optimizer (the `nvidia-modelopt` package).
    """

    use_keep_in_fp32_modules = True
    requires_calibration = False
    required_packages = ["nvidia_modelopt"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

    def validate_environment(self, *args, **kwargs):
        if not is_nvidia_modelopt_available():
            raise ImportError(
                "Loading an nvidia-modelopt quantized model requires the nvidia-modelopt library (`pip install nvidia-modelopt`)."
            )

        self.offload = False

        device_map = kwargs.get("device_map", None)
        if isinstance(device_map, dict):
            if "cpu" in device_map.values() or "disk" in device_map.values():
                if self.pre_quantized:
                    raise ValueError(
                        "You are attempting to perform cpu/disk offload with a pre-quantized modelopt model. "
                        "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
                    )
                else:
                    self.offload = True

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
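        # Pre-quantized checkpoints route every parameter through `create_quantized_param`;
        # otherwise only the weights of modules already wrapped by ModelOpt are treated as
        # quantized parameters.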
        # ModelOpt imports diffusers internally. This is here to prevent circular imports
        from modelopt.torch.quantization.utils import is_quantized

        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized:
            return True
        elif is_quantized(module) and "weight" in tensor_name:
            return True
        return False

    def create_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        *args,
        **kwargs,
    ):
        """
        Create the quantized parameter by calling `calibrate()` after setting the tensor on the module.
        """
        # ModelOpt imports diffusers internally. This is here to prevent circular imports
        import modelopt.torch.quantization as mtq

        dtype = kwargs.get("dtype", torch.float32)
        module, tensor_name = get_module_from_name(model, param_name)
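        # For pre-quantized checkpoints the serialized tensor is already in its quantized form and
        # only needs to be materialized on the target device; otherwise the raw weight is loaded
        # first and then calibrated and compressed in place by ModelOpt.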
        if self.pre_quantized:
            module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
        else:
            set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
            mtq.calibrate(
                module, self.quantization_config.modelopt_config["algorithm"], self.quantization_config.forward_loop
            )
            mtq.compress(module)
            module.weight.requires_grad = False

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
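        # Leave roughly 10% of each device's budget as headroom (e.g. for temporary buffers
        # created while quantizing).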
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        if self.quantization_config.quant_type == "FP8":
            target_dtype = torch.float8_e4m3fn
        return target_dtype

    def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
        if torch_dtype is None:
            logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
            torch_dtype = torch.float32
        return torch_dtype

    def get_conv_param_names(self, model: "ModelMixin") -> List[str]:
        """
        Get parameter names for all convolutional layers in a HuggingFace ModelMixin. Includes Conv1d/2d/3d and
        ConvTranspose1d/2d/3d.
        """
        conv_types = (
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nn.ConvTranspose1d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
        )

        conv_param_names = []
        for name, module in model.named_modules():
            if isinstance(module, conv_types):
                for param_name, _ in module.named_parameters(recurse=False):
                    conv_param_names.append(f"{name}.{param_name}")
        return conv_param_names

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        # ModelOpt imports diffusers internally. This is here to prevent circular imports
        import modelopt.torch.opt as mto

        if self.pre_quantized:
            return

        modules_to_not_convert = self.quantization_config.modules_to_not_convert

        if modules_to_not_convert is None:
            modules_to_not_convert = []
        if isinstance(modules_to_not_convert, str):
            modules_to_not_convert = [modules_to_not_convert]
        modules_to_not_convert.extend(keep_in_fp32_modules)
        if self.quantization_config.disable_conv_quantization:
            modules_to_not_convert.extend(self.get_conv_param_names(model))
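        # ModelOpt's `quant_cfg` keys are wildcard patterns; wrapping each excluded name in "*...*"
        # disables quantization for every matching submodule.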
        for module in modules_to_not_convert:
            self.quantization_config.modelopt_config["quant_cfg"]["*" + module + "*"] = {"enable": False}

        self.quantization_config.modules_to_not_convert = modules_to_not_convert
        mto.apply_mode(model, mode=[("quantize", self.quantization_config.modelopt_config)])
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
        # ModelOpt imports diffusers internally. This is here to prevent circular imports
        from modelopt.torch.opt import ModeloptStateManager

        if self.pre_quantized:
            return model
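        # Keep ModelOpt state only on the root model; duplicated per-submodule state is removed.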
        for _, m in model.named_modules():
            if hasattr(m, ModeloptStateManager._state_key) and m is not model:
                ModeloptStateManager.remove_state(m)
        return model

    @property
    def is_trainable(self):
        return True

    @property
    def is_serializable(self):
        self.quantization_config.check_model_patching(operation="saving")
        return True
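

# A minimal usage sketch, assuming `NVIDIAModelOptConfig` is the companion config class exposed by
# diffusers and that the checkpoint/model names below fit your setup (both are assumptions; adjust
# as needed):
#
#     import torch
#     from diffusers import FluxTransformer2DModel, NVIDIAModelOptConfig
#
#     quant_config = NVIDIAModelOptConfig(quant_type="FP8")
#     transformer = FluxTransformer2DModel.from_pretrained(
#         "black-forest-labs/FLUX.1-dev",
#         subfolder="transformer",
#         quantization_config=quant_config,
#         torch_dtype=torch.bfloat16,
#     )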