|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import TYPE_CHECKING, Optional |
|
|
|
|
|
from ..utils.logging import tqdm |
|
|
from .base import HfQuantizer |
|
|
from .quantizers_utils import get_module_from_name |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from ..modeling_utils import PreTrainedModel |
|
|
|
|
|
from ..utils import is_accelerate_available, is_flute_available, is_hadamard_available, is_torch_available, logging |
|
|
from ..utils.quantization_config import QuantizationConfigMixin |
|
|
|
|
|
|
|
|
if is_torch_available(): |
|
|
import torch |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
class HiggsHfQuantizer(HfQuantizer):
    """
    Quantizer of the HIGGS method. Enables the loading of prequantized models and in-flight quantization of full-precision models.
    """

    requires_calibration = False
    # Weights are quantized one tensor at a time through `create_quantized_param`.
    requires_parameters_quantization = True
    required_packages = ["flute-kernel", "fast_hadamard_transform"]

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, device_map, **kwargs):
        """
        Check that torch + CUDA, accelerate, FLUTE and fast_hadamard_transform are available,
        and that `device_map` only targets GPU devices.

        Raises:
            ImportError: if torch, accelerate, flute-kernel or fast_hadamard_transform is missing.
            NotImplementedError: if no CUDA device is available.
            ValueError: if `device_map` is missing or routes modules to CPU/disk.
        """
        # Probe torch availability explicitly: the module-level `import torch` is guarded by
        # `is_torch_available()`, so touching `torch.cuda` without this check would raise a
        # confusing NameError instead of a helpful ImportError.
        if not is_torch_available():
            raise ImportError("Using `higgs` quantization requires torch: `pip install torch`")

        if not torch.cuda.is_available():
            raise NotImplementedError("HIGGS quantization is only supported on GPU. Please use a different quantizer.")

        if not is_accelerate_available():
            raise ImportError("Using `higgs` quantization requires Accelerate: `pip install accelerate`")

        if not is_flute_available():
            raise ImportError("Using `higgs` quantization requires FLUTE: `pip install flute-kernel>=0.3.0`")

        if not is_hadamard_available():
            raise ImportError(
                "Using `higgs` quantization requires fast_hadamard_transform: `pip install fast_hadamard_transform`"
            )

        if device_map is None:
            raise ValueError(
                "You are attempting to load a HIGGS model without setting device_map."
                " Please set device_map comprised of 'cuda' devices."
            )
        elif isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
            raise ValueError(
                "You are attempting to load a HIGGS model with a device_map that contains a CPU or disk device."
                " This is not supported. Please remove the CPU or disk device from the device_map."
            )

    def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
        """Default `dtype` to float16 and reject anything other than float16/bfloat16 (required by FLUTE)."""
        if dtype is None:
            logger.info("`dtype` is None. Setting `dtype=torch.float16` for FLUTE compatibility.")
            dtype = torch.float16
        elif dtype not in (torch.float16, torch.bfloat16):
            raise ValueError(
                f"Invalid `dtype` {dtype}. HIGGS quantization only supports `dtype=torch.float16` or `dtype=torch.bfloat16`."
            )

        return dtype

    def create_quantized_param(
        self,
        model: "PreTrainedModel",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        **kwargs,
    ):
        """
        Quantize a single full-precision weight with HIGGS and install the resulting FLUTE tensors
        (parameters, buffers, and kernel-tuning metadata) on the owning module.

        Args:
            model: the model being loaded; the target module is resolved from `param_name`.
            param_value: the full-precision weight to quantize (moved to `target_device` first).
            param_name: fully-qualified name of the parameter inside `model`.
            target_device: device on which quantization is performed.
        """
        from ..integrations import quantize_with_higgs

        flute_dict = quantize_with_higgs(
            param_value.to(target_device),
            self.quantization_config.bits,
            self.quantization_config.p,
            self.quantization_config.group_size,
            self.quantization_config.hadamard_size,
        )
        # Release the full-precision tensor eagerly to keep peak device memory down.
        del param_value

        module, _ = get_module_from_name(model, param_name)
        module_name = ".".join(param_name.split(".")[:-1])
        for key, value in flute_dict.items():
            if key in module._parameters:
                module._parameters[key] = torch.nn.Parameter(value, requires_grad=False)
            elif key in module._buffers:
                # NOTE: `torch.nn.Buffer` requires a recent torch version — presumably guaranteed
                # by the flute-kernel requirement; verify against the supported torch range.
                module._buffers[key] = torch.nn.Buffer(value)
            elif key == "tune_metadata":
                # Keep tuning metadata both on the module (runtime use) and in the quantization
                # config (so it is serialized alongside the model).
                module.tune_metadata = value
                self.quantization_config.tune_metadata[module_name] = value.to_dict()
            else:
                raise ValueError(f"Unexpected key {key} in module {module}")

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        keep_in_fp32_modules: Optional[list[str]] = None,
        **kwargs,
    ):
        """Replace eligible modules with HIGGS linear layers (via `replace_with_higgs_linear`)
        before checkpoint weights are loaded, and record the config on the model."""
        from ..integrations import replace_with_higgs_linear

        self.modules_to_not_convert = self.get_modules_to_not_convert(
            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
        )

        replace_with_higgs_linear(
            model,
            quantization_config=self.quantization_config,
            modules_to_not_convert=self.modules_to_not_convert,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        """Attach FLUTE workspaces and tune/repack every `HiggsLinear` weight after loading."""
        from flute.tune import TuneMetaData, maybe_tune_and_repack
        from flute.utils import make_workspace_streamk

        from ..integrations import HiggsLinear

        flute_workspaces = {}
        flute_modules = {name: module for name, module in model.named_modules() if isinstance(module, HiggsLinear)}
        for name, module in tqdm(flute_modules.items(), desc="Repacking HIGGS modules", leave=False):
            # Workspaces are per-device scratch buffers: create each lazily and share it across
            # all modules living on the same device.
            if module.weight.device not in flute_workspaces:
                flute_workspaces[module.weight.device] = make_workspace_streamk(device=module.weight.device)
            module.workspace = flute_workspaces[module.weight.device]

            module.tune_metadata = TuneMetaData.from_dict(self.quantization_config.tune_metadata[name])
            module.weight.data, module.tune_metadata = maybe_tune_and_repack(
                weight=module.weight.data,
                scales=module.scales.data,
                metadata=module.tune_metadata,
            )
            # Persist the (possibly updated) tuning metadata back into the config for serialization.
            self.quantization_config.tune_metadata[name] = module.tune_metadata.to_dict()

    def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
        """
        Filter out of `missing_keys` the non-weight/bias tensors that belong to `HiggsLinear`
        modules: those are created at quantization time and are legitimately absent from the
        checkpoint.
        """
        from ..integrations import HiggsLinear

        higgs_names = {name for name, module in model.named_modules() if isinstance(module, HiggsLinear)}

        def should_update(key: str) -> bool:
            # `.weight` / `.bias` keys are real checkpoint tensors and must stay "missing".
            if key.endswith(".weight") or key.endswith(".bias"):
                return False
            full_key = f"{prefix}.{key}"
            return any(name in key or name in full_key for name in higgs_names)

        return [key for key in missing_keys if not should_update(key)]

    @property
    def is_trainable(self) -> bool:
        # HIGGS-quantized models are inference-only.
        return False

    def is_serializable(self, safe_serialization=None):
        return True

    def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
        """Only the `weight` tensors of `HiggsLinear` modules go through `create_quantized_param`."""
        from ..integrations import HiggsLinear

        module, tensor_name = get_module_from_name(model, param_name)
        return isinstance(module, HiggsLinear) and tensor_name == "weight"

    def _dequantize(self, model):
        """Convert quantized modules back to full precision (delegates to `dequantize_higgs`)."""
        from ..integrations import dequantize_higgs

        model = dequantize_higgs(model)
        return model
|
|
|