|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
import os |
|
|
from pathlib import Path |
|
|
from typing import Any, List, Optional, Union |
|
|
|
|
|
from huggingface_hub import ModelHubMixin, snapshot_download |
|
|
|
|
|
from optimum.quanto import freeze, qtype, quantization_map, quantize, requantize, Optimizer |
|
|
from optimum.quanto.models import is_diffusers_available |
|
|
|
|
|
from diffusers.models.model_loading_utils import load_state_dict |
|
|
from diffusers.models.modeling_utils import ModelMixin |
|
|
from diffusers.utils import ( |
|
|
CONFIG_NAME, |
|
|
SAFE_WEIGHTS_INDEX_NAME, |
|
|
SAFETENSORS_WEIGHTS_NAME, |
|
|
_get_checkpoint_shard_files, |
|
|
is_accelerate_available, |
|
|
) |
|
|
from optimum.quanto.models.shared_dict import ShardedStateDict |
|
|
|
|
|
|
|
|
class QuantizedDiffusersModel(ModelHubMixin):
    """Base class for quantized diffusers models.

    Wraps an already-quantized diffusers ``ModelMixin`` and adds Hub-compatible
    serialization: the quantization map is saved next to the model weights so
    the model can be rebuilt and requantized on reload via ``from_pretrained``.
    """

    # Prefix for the serialized quantization map file name ("<BASE_NAME>_qmap.json").
    BASE_NAME = "quanto"
    # Subclasses must set this to the concrete diffusers model class they wrap.
    base_class = None

    def __init__(self, model: ModelMixin):
        """Wrap a quantized diffusers model.

        Args:
            model: a ``ModelMixin`` whose modules have already been quantized,
                i.e. ``quantization_map(model)`` is non-empty.

        Raises:
            ValueError: if ``model`` is not a quantized diffusers model.
        """
        if not isinstance(model, ModelMixin) or len(quantization_map(model)) == 0:
            raise ValueError("The source model must be a quantized diffusers model.")
        self._wrapped = model

    def __getattr__(self, name: str) -> Any:
        """If an attribute is not found in this class, look in the wrapped module."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            # Go through __dict__ directly to avoid re-entering __getattr__.
            wrapped = self.__dict__["_wrapped"]
            return getattr(wrapped, name)

    def forward(self, *args, **kwargs):
        """Delegate the forward pass to the wrapped model."""
        return self._wrapped.forward(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        """Calling the wrapper invokes the wrapped model's forward pass."""
        return self._wrapped.forward(*args, **kwargs)

    @classmethod
    def _qmap_name(cls):
        """Return the quantization map file name for this class.

        Uses ``cls.BASE_NAME`` so subclasses that override ``BASE_NAME``
        serialize and reload under their own prefix (previously this was
        hard-coded to ``QuantizedDiffusersModel.BASE_NAME``).
        """
        return f"{cls.BASE_NAME}_qmap.json"

    @classmethod
    def quantize(
        cls,
        model: ModelMixin,
        weights: Optional[Union[str, qtype]] = None,
        activations: Optional[Union[str, qtype]] = None,
        optimizer: Optional[Optimizer] = None,
        include: Optional[Union[str, List[str]]] = None,
        exclude: Optional[Union[str, List[str]]] = None,
    ):
        """Quantize the specified model.

        Args:
            model: the diffusers model to quantize.
            weights: target weight qtype (or its string name).
            activations: target activation qtype (or its string name).
            optimizer: optional quantization optimizer.
            include: pattern(s) selecting modules to quantize.
            exclude: pattern(s) selecting modules to skip.

        Returns:
            A ``cls`` instance wrapping the quantized, frozen model.

        Raises:
            ValueError: if ``model`` is not a diffusers model.
        """
        if not isinstance(model, ModelMixin):
            raise ValueError("The source model must be a diffusers model.")

        # Module-level quanto `quantize`, not this classmethod (class scope is
        # not part of name resolution inside the method body).
        quantize(
            model, weights=weights, activations=activations, optimizer=optimizer, include=include, exclude=exclude
        )
        freeze(model)
        return cls(model)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs):
        """Reload a quantized diffusers model saved with ``save_pretrained``.

        Args:
            pretrained_model_name_or_path: local directory or Hub repo id.
            **kwargs: forwarded to ``snapshot_download`` and ``load_config``.

        Returns:
            A ``cls`` instance wrapping the requantized model in eval mode.

        Raises:
            ValueError: if ``base_class`` is unset, accelerate is missing,
                the quantization map / config / weights cannot be found, or
                the stored config class does not match ``base_class``.
        """
        if cls.base_class is None:
            raise ValueError("The `base_class` attribute needs to be configured.")

        # `init_empty_weights` is required to build the model skeleton on the
        # meta device before requantization.
        if not is_accelerate_available():
            raise ValueError("Reloading a quantized diffusers model requires the accelerate library.")
        from accelerate import init_empty_weights

        if os.path.isdir(pretrained_model_name_or_path):
            working_dir = pretrained_model_name_or_path
        else:
            working_dir = snapshot_download(pretrained_model_name_or_path, **kwargs)

        # Look for the quantization map; its absence means the checkpoint was
        # not produced by this wrapper.
        qmap_path = os.path.join(working_dir, cls._qmap_name())
        if not os.path.exists(qmap_path):
            raise ValueError(
                f"No quantization map found in {pretrained_model_name_or_path}: is this a quantized model ?"
            )

        # Look for the original model config file.
        model_config_path = os.path.join(working_dir, CONFIG_NAME)
        if not os.path.exists(model_config_path):
            raise ValueError(f"{CONFIG_NAME} not found in {pretrained_model_name_or_path}.")

        with open(qmap_path, "r", encoding="utf-8") as f:
            qmap = json.load(f)

        # Sanity check: the stored config must describe the configured class.
        with open(model_config_path, "r", encoding="utf-8") as f:
            original_model_cls_name = json.load(f)["_class_name"]
        configured_cls_name = cls.base_class.__name__
        if configured_cls_name != original_model_cls_name:
            raise ValueError(
                f"Configured base class ({configured_cls_name}) differs from what was derived from the provided configuration ({original_model_cls_name})."
            )

        # Create an empty (meta-device) model from the config.
        config = cls.base_class.load_config(pretrained_model_name_or_path, **kwargs)
        with init_empty_weights():
            model = cls.base_class.from_config(config)

        # Prefer a sharded checkpoint if its index is present.
        checkpoint_file = os.path.join(working_dir, SAFE_WEIGHTS_INDEX_NAME)
        if os.path.exists(checkpoint_file):
            # Resolve the index into the list of shard files.
            _, sharded_metadata = _get_checkpoint_shard_files(working_dir, checkpoint_file)
            # Lazy dict-like view over the sharded safetensor files.
            state_dict = ShardedStateDict(working_dir, sharded_metadata["weight_map"])
        else:
            # Fall back to a single safetensors checkpoint file.
            checkpoint_file = os.path.join(working_dir, SAFETENSORS_WEIGHTS_NAME)
            if not os.path.exists(checkpoint_file):
                raise ValueError(f"No safetensor weights found in {pretrained_model_name_or_path}.")
            state_dict = load_state_dict(checkpoint_file)

        # Requantize the empty model and load the quantized weights.
        requantize(model, state_dict=state_dict, quantization_map=qmap)
        model.eval()
        return cls(model)

    def _save_pretrained(self, save_directory: Path) -> None:
        """Save the wrapped model plus its quantization map to `save_directory`."""
        self._wrapped.save_pretrained(save_directory)

        # Persist the quantization map so the model can be reloaded.
        qmap_name = os.path.join(save_directory, self._qmap_name())
        qmap = quantization_map(self._wrapped)
        with open(qmap_name, "w", encoding="utf-8") as f:
            json.dump(qmap, f, indent=4)
|
|
|
|
|
|
|
|
|
|
|
from diffusers.models.transformers.transformer_flux2 import Flux2Transformer2DModel |
|
|
|
|
|
|
|
|
class QuantizedFlux2Transformer2DModel(QuantizedDiffusersModel):
    """Quantized FLUX.2 Transformer model.

    Concrete binding of :class:`QuantizedDiffusersModel` to the diffusers
    ``Flux2Transformer2DModel``; inherits quantize/save/load behavior.
    """
    # Required by QuantizedDiffusersModel.from_pretrained to rebuild the model.
    base_class = Flux2Transformer2DModel
|
|
|