# Meissonic LoRA loader mixin — adapted from diffusers' LoRA pipeline mixins.
import os
from typing import Callable, Dict, List, Optional, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args

from diffusers.loaders.lora_base import (
    LoraBaseMixin,
    _fetch_state_dict,
    _pack_dict_with_prefix,
)
from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    get_submodule_by_name,
    is_bitsandbytes_available,
    is_gguf_available,
    is_peft_available,
    is_peft_version,
    is_torch_version,
    is_transformers_available,
    is_transformers_version,
    logging,
)
| _LOW_CPU_MEM_USAGE_DEFAULT_LORA = False | |
| if is_torch_version(">=", "1.9.0"): | |
| if ( | |
| is_peft_available() | |
| and is_peft_version(">=", "0.13.1") | |
| and is_transformers_available() | |
| and is_transformers_version(">", "4.45.2") | |
| ): | |
| _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True | |
| logger = logging.get_logger(__name__) | |
| TRANSFORMER_NAME = "transformer" | |
| class MeissonicLoraLoaderMixin(LoraBaseMixin): | |
| r""" | |
| Load LoRA layers into [`Transformer2DModel`]. Specific to [`MeissonicPipeline`]. | |
| """ | |
| _lora_loadable_modules = ["transformer"] | |
| transformer_name = TRANSFORMER_NAME | |
| def lora_state_dict( | |
| cls, | |
| pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], | |
| return_alphas: bool = False, | |
| **kwargs, | |
| ): | |
| r""" | |
| Return state dict for lora weights and the network alphas. | |
| <Tip warning={true}> | |
| We support loading A1111 formatted LoRA checkpoints in a limited capacity. | |
| This function is experimental and might change in the future. | |
| </Tip> | |
| Parameters: | |
| pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): | |
| Can be either: | |
| - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on | |
| the Hub. | |
| - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved | |
| with [`ModelMixin.save_pretrained`]. | |
| - A [torch state | |
| dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict). | |
| cache_dir (`Union[str, os.PathLike]`, *optional*): | |
| Path to a directory where a downloaded pretrained model configuration is cached if the standard cache | |
| is not used. | |
| force_download (`bool`, *optional*, defaults to `False`): | |
| Whether or not to force the (re-)download of the model weights and configuration files, overriding the | |
| cached versions if they exist. | |
| proxies (`Dict[str, str]`, *optional*): | |
| A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', | |
| 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |
| local_files_only (`bool`, *optional*, defaults to `False`): | |
| Whether to only load local model weights and configuration files or not. If set to `True`, the model | |
| won't be downloaded from the Hub. | |
| token (`str` or *bool*, *optional*): | |
| The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from | |
| `diffusers-cli login` (stored in `~/.huggingface`) is used. | |
| revision (`str`, *optional*, defaults to `"main"`): | |
| The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier | |
| allowed by Git. | |
| subfolder (`str`, *optional*, defaults to `""`): | |
| The subfolder location of a model file within a larger model repository on the Hub or locally. | |
| return_lora_metadata (`bool`, *optional*, defaults to False): | |
| When enabled, additionally return the LoRA adapter metadata, typically found in the state dict. | |
| """ | |
| # Load the main state dict first which has the LoRA layers for either of | |
| # transformer and text encoder or both. | |
| cache_dir = kwargs.pop("cache_dir", None) | |
| force_download = kwargs.pop("force_download", False) | |
| proxies = kwargs.pop("proxies", None) | |
| local_files_only = kwargs.pop("local_files_only", None) | |
| token = kwargs.pop("token", None) | |
| revision = kwargs.pop("revision", None) | |
| subfolder = kwargs.pop("subfolder", None) | |
| weight_name = kwargs.pop("weight_name", None) | |
| use_safetensors = kwargs.pop("use_safetensors", None) | |
| return_lora_metadata = kwargs.pop("return_lora_metadata", False) | |
| allow_pickle = False | |
| if use_safetensors is None: | |
| use_safetensors = True | |
| allow_pickle = True | |
| user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"} | |
| state_dict, metadata = _fetch_state_dict( | |
| pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict, | |
| weight_name=weight_name, | |
| use_safetensors=use_safetensors, | |
| local_files_only=local_files_only, | |
| cache_dir=cache_dir, | |
| force_download=force_download, | |
| proxies=proxies, | |
| token=token, | |
| revision=revision, | |
| subfolder=subfolder, | |
| user_agent=user_agent, | |
| allow_pickle=allow_pickle, | |
| ) | |
| is_dora_scale_present = any("dora_scale" in k for k in state_dict) | |
| if is_dora_scale_present: | |
| warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new." | |
| logger.warning(warn_msg) | |
| state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k} | |
| out = (state_dict, metadata) if return_lora_metadata else state_dict | |
| return out | |
| def load_lora_weights( | |
| self, | |
| pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], | |
| adapter_name: Optional[str] = None, | |
| hotswap: bool = False, | |
| **kwargs, | |
| ): | |
| """ | |
| Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and | |
| `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See | |
| [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded. | |
| See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state | |
| dict is loaded into `self.transformer`. | |
| Parameters: | |
| pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`): | |
| See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. | |
| adapter_name (`str`, *optional*): | |
| Adapter name to be used for referencing the loaded adapter model. If not specified, it will use | |
| `default_{i}` where i is the total number of adapters being loaded. | |
| low_cpu_mem_usage (`bool`, *optional*): | |
| Speed up model loading by only loading the pretrained LoRA weights and not initializing the random | |
| weights. | |
| hotswap (`bool`, *optional*): | |
| See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. | |
| kwargs (`dict`, *optional*): | |
| See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`]. | |
| """ | |
| if not USE_PEFT_BACKEND: | |
| raise ValueError("PEFT backend is required for this method.") | |
| low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA) | |
| if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): | |
| raise ValueError( | |
| "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." | |
| ) | |
| # if a dict is passed, copy it instead of modifying it inplace | |
| if isinstance(pretrained_model_name_or_path_or_dict, dict): | |
| pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy() | |
| # First, ensure that the checkpoint is a compatible one and can be successfully loaded. | |
| kwargs["return_lora_metadata"] = True | |
| state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs) | |
| is_correct_format = all("lora" in key for key in state_dict.keys()) | |
| if not is_correct_format: | |
| raise ValueError("Invalid LoRA checkpoint.") | |
| self.load_lora_into_transformer( | |
| state_dict, | |
| transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer, | |
| adapter_name=adapter_name, | |
| metadata=metadata, | |
| _pipeline=self, | |
| low_cpu_mem_usage=low_cpu_mem_usage, | |
| hotswap=hotswap, | |
| ) | |
| def load_lora_into_transformer( | |
| cls, | |
| state_dict, | |
| transformer, | |
| adapter_name=None, | |
| _pipeline=None, | |
| low_cpu_mem_usage=False, | |
| hotswap: bool = False, | |
| metadata=None, | |
| ): | |
| """ | |
| This will load the LoRA layers specified in `state_dict` into `transformer`. | |
| Parameters: | |
| state_dict (`dict`): | |
| A standard state dict containing the lora layer parameters. The keys can either be indexed directly | |
| into the unet or prefixed with an additional `unet` which can be used to distinguish between text | |
| encoder lora layers. | |
| transformer (`SD3Transformer2DModel`): | |
| The Transformer model to load the LoRA layers into. | |
| adapter_name (`str`, *optional*): | |
| Adapter name to be used for referencing the loaded adapter model. If not specified, it will use | |
| `default_{i}` where i is the total number of adapters being loaded. | |
| low_cpu_mem_usage (`bool`, *optional*): | |
| Speed up model loading by only loading the pretrained LoRA weights and not initializing the random | |
| weights. | |
| hotswap (`bool`, *optional*): | |
| See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`]. | |
| metadata (`dict`): | |
| Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived | |
| from the state dict. | |
| """ | |
| if low_cpu_mem_usage and is_peft_version("<", "0.13.0"): | |
| raise ValueError( | |
| "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`." | |
| ) | |
| # Load the layers corresponding to transformer. | |
| logger.info(f"Loading {cls.transformer_name}.") | |
| transformer.load_lora_adapter( | |
| state_dict, | |
| network_alphas=None, | |
| adapter_name=adapter_name, | |
| metadata=metadata, | |
| _pipeline=_pipeline, | |
| low_cpu_mem_usage=low_cpu_mem_usage, | |
| hotswap=hotswap, | |
| ) | |
| # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights | |
| def save_lora_weights( | |
| cls, | |
| save_directory: Union[str, os.PathLike], | |
| transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, | |
| is_main_process: bool = True, | |
| weight_name: str = None, | |
| save_function: Callable = None, | |
| safe_serialization: bool = True, | |
| transformer_lora_adapter_metadata: Optional[dict] = None, | |
| ): | |
| r""" | |
| Save the LoRA parameters corresponding to the transformer. | |
| Arguments: | |
| save_directory (`str` or `os.PathLike`): | |
| Directory to save LoRA parameters to. Will be created if it doesn't exist. | |
| transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): | |
| State dict of the LoRA layers corresponding to the `transformer`. | |
| is_main_process (`bool`, *optional*, defaults to `True`): | |
| Whether the process calling this is the main process or not. Useful during distributed training and you | |
| need to call this function on all processes. In this case, set `is_main_process=True` only on the main | |
| process to avoid race conditions. | |
| save_function (`Callable`): | |
| The function to use to save the state dictionary. Useful during distributed training when you need to | |
| replace `torch.save` with another method. Can be configured with the environment variable | |
| `DIFFUSERS_SAVE_MODE`. | |
| safe_serialization (`bool`, *optional*, defaults to `True`): | |
| Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`. | |
| transformer_lora_adapter_metadata: | |
| LoRA adapter metadata associated with the transformer to be serialized with the state dict. | |
| """ | |
| state_dict = {} | |
| lora_adapter_metadata = {} | |
| if not transformer_lora_layers: | |
| raise ValueError("You must pass `transformer_lora_layers`.") | |
| state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name)) | |
| if transformer_lora_adapter_metadata is not None: | |
| lora_adapter_metadata.update( | |
| _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name) | |
| ) | |
| # Save the model | |
| cls.write_lora_layers( | |
| state_dict=state_dict, | |
| save_directory=save_directory, | |
| is_main_process=is_main_process, | |
| weight_name=weight_name, | |
| save_function=save_function, | |
| safe_serialization=safe_serialization, | |
| lora_adapter_metadata=lora_adapter_metadata, | |
| ) | |