import os
from typing import Callable, Dict, List, Optional, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args

from diffusers.utils import (
    USE_PEFT_BACKEND,
    deprecate,
    get_submodule_by_name,
    is_bitsandbytes_available,
    is_gguf_available,
    is_peft_available,
    is_peft_version,
    is_torch_version,
    is_transformers_available,
    is_transformers_version,
    logging,
)
from diffusers.loaders.lora_base import (
    LoraBaseMixin,
    _fetch_state_dict,
    _pack_dict_with_prefix,
)


# Default for the `low_cpu_mem_usage` kwarg of `load_lora_weights`. It is only
# switched on when torch, peft and transformers all satisfy the version
# constraints checked below.
_LOW_CPU_MEM_USAGE_DEFAULT_LORA = False
if is_torch_version(">=", "1.9.0"):
    if (
        is_peft_available()
        and is_peft_version(">=", "0.13.1")
        and is_transformers_available()
        and is_transformers_version(">", "4.45.2")
    ):
        _LOW_CPU_MEM_USAGE_DEFAULT_LORA = True


logger = logging.get_logger(__name__)

TRANSFORMER_NAME = "transformer"


class MeissonicLoraLoaderMixin(LoraBaseMixin):
    r"""
    Load LoRA layers into [`Transformer2DModel`]. Specific to [`MeissonicPipeline`].
    """

    # Only the transformer component receives LoRA layers in this mixin.
    _lora_loadable_modules = ["transformer"]
    # Attribute name used to look up the transformer on the pipeline and to prefix saved weights.
    transformer_name = TRANSFORMER_NAME

    @classmethod
    @validate_hf_hub_args
    def lora_state_dict(
        cls,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        return_alphas: bool = False,
        **kwargs,
    ):
        r"""
        Fetch a LoRA state dict and, optionally, its adapter metadata.

        Keys containing `dora_scale` are dropped (with a warning) before the state dict is returned.

        This function is experimental and might change in the future.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
                      with [`ModelMixin.save_pretrained`].
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            return_alphas (`bool`, *optional*, defaults to `False`):
                Accepted for signature compatibility; this implementation does not read it.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            weight_name (`str`, *optional*):
                Name of the serialized state dict file to load.
            use_safetensors (`bool`, *optional*):
                When left unset (`None`), safetensors loading is requested and `allow_pickle=True` is passed to the
                fetch helper. When set explicitly, the given value is used and `allow_pickle=False` is passed.
            return_lora_metadata (`bool`, *optional*, defaults to False):
                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.

        Returns:
            The state dict, or a `(state_dict, metadata)` tuple when `return_lora_metadata` is `True`.
        """
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", None)
        use_safetensors = kwargs.pop("use_safetensors", None)
        return_lora_metadata = kwargs.pop("return_lora_metadata", False)

        # Pickle loading is only permitted when the caller did not state a
        # preference about safetensors.
        allow_pickle = False
        if use_safetensors is None:
            use_safetensors = True
            allow_pickle = True

        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}

        state_dict, metadata = _fetch_state_dict(
            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
            weight_name=weight_name,
            use_safetensors=use_safetensors,
            local_files_only=local_files_only,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            token=token,
            revision=revision,
            subfolder=subfolder,
            user_agent=user_agent,
            allow_pickle=allow_pickle,
        )

        # DoRA scale entries are filtered out rather than loaded.
        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
        if is_dora_scale_present:
            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
            logger.warning(warn_msg)
            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

        if return_lora_metadata:
            return state_dict, metadata
        return state_dict

    def load_lora_weights(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        adapter_name: Optional[str] = None,
        hotswap: bool = False,
        **kwargs,
    ):
        """
        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into the pipeline's transformer.

        The checkpoint is first resolved through `self.lora_state_dict` (all remaining kwargs are forwarded to it)
        and then handed to `self.load_lora_into_transformer`.

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                See `lora_state_dict` of this class. A dict is shallow-copied before use so the caller's object is
                not modified in place.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights. Defaults to the module-level `_LOW_CPU_MEM_USAGE_DEFAULT_LORA`.
            hotswap (`bool`, *optional*):
                Forwarded to `load_lora_into_transformer`. See
                [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for its semantics.
            kwargs (`dict`, *optional*):
                See `lora_state_dict` of this class.

        Raises:
            ValueError: If the PEFT backend is not in use, if `low_cpu_mem_usage` is requested with `peft<0.13.0`, or
                if any key of the resolved state dict does not contain `"lora"`.
        """
        if not USE_PEFT_BACKEND:
            raise ValueError("PEFT backend is required for this method.")

        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # if a dict is passed, copy it instead of modifying it inplace
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()

        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        # Metadata is always requested so it can be forwarded to the transformer loader.
        kwargs["return_lora_metadata"] = True
        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

        is_correct_format = all("lora" in key for key in state_dict)
        if not is_correct_format:
            raise ValueError("Invalid LoRA checkpoint.")

        # Prefer a literal `transformer` attribute; otherwise look the component up by `transformer_name`.
        transformer = self.transformer if hasattr(self, "transformer") else getattr(self, self.transformer_name)

        self.load_lora_into_transformer(
            state_dict,
            transformer=transformer,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=self,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    def load_lora_into_transformer(
        cls,
        state_dict,
        transformer,
        adapter_name=None,
        _pipeline=None,
        low_cpu_mem_usage=False,
        hotswap: bool = False,
        metadata=None,
    ):
        """
        This will load the LoRA layers specified in `state_dict` into `transformer`.

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. It is passed unchanged to
                `transformer.load_lora_adapter`.
            transformer:
                The transformer model to load the LoRA layers into. It must provide a `load_lora_adapter` method.
            adapter_name (`str`, *optional*):
                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
                `default_{i}` where i is the total number of adapters being loaded.
            _pipeline (*optional*):
                The calling pipeline, forwarded to `transformer.load_lora_adapter`.
            low_cpu_mem_usage (`bool`, *optional*):
                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
                weights.
            hotswap (`bool`, *optional*):
                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
            metadata (`dict`):
                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
                from the state dict.

        Raises:
            ValueError: If `low_cpu_mem_usage` is requested with `peft<0.13.0`.
        """
        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
            raise ValueError(
                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
            )

        # Load the layers corresponding to transformer.
        logger.info(f"Loading {cls.transformer_name}.")
        transformer.load_lora_adapter(
            state_dict,
            network_alphas=None,
            adapter_name=adapter_name,
            metadata=metadata,
            _pipeline=_pipeline,
            low_cpu_mem_usage=low_cpu_mem_usage,
            hotswap=hotswap,
        )

    @classmethod
    # Adapted from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
    def save_lora_weights(
        cls,
        save_directory: Union[str, os.PathLike],
        transformer_lora_layers: Optional[Dict[str, Union[torch.nn.Module, torch.Tensor]]] = None,
        is_main_process: bool = True,
        weight_name: Optional[str] = None,
        save_function: Optional[Callable] = None,
        safe_serialization: bool = True,
        transformer_lora_adapter_metadata: Optional[dict] = None,
    ):
        r"""
        Save the LoRA parameters corresponding to the transformer.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to save LoRA parameters to. Will be created if it doesn't exist.
            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
                State dict of the LoRA layers corresponding to the `transformer`. Required: passing `None` or an
                empty dict raises.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training and you
                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
                process to avoid race conditions.
            weight_name (`str`, *optional*):
                Forwarded to `write_lora_layers` as the name of the file to write.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training when you need to
                replace `torch.save` with another method. Can be configured with the environment variable
                `DIFFUSERS_SAVE_MODE`.
            safe_serialization (`bool`, *optional*, defaults to `True`):
                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
            transformer_lora_adapter_metadata:
                LoRA adapter metadata associated with the transformer to be serialized with the state dict.

        Raises:
            ValueError: If `transformer_lora_layers` is `None` or empty.
        """
        if not transformer_lora_layers:
            raise ValueError("You must pass `transformer_lora_layers`.")

        # Both the weights and the metadata are stored under the transformer name.
        state_dict = {}
        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))

        lora_adapter_metadata = {}
        if transformer_lora_adapter_metadata is not None:
            lora_adapter_metadata.update(
                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
            )

        # Save the model
        cls.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
            lora_adapter_metadata=lora_adapter_metadata,
        )