| import os |
| from typing import Dict, Optional, Union |
|
|
| import safetensors |
| import torch |
| from diffusers.utils import _get_model_file, logging |
| from safetensors import safe_open |
|
|
# Module-level logger, obtained from the `logging` helper imported from `diffusers.utils` and keyed by this module's name.
| logger = logging.get_logger(__name__) |
|
|
|
|
class CustomAdapterMixin:
    """Mixin exposing init / load / save entry points for a custom adapter.

    The public methods take care of argument plumbing and weight-file
    (de)serialization, then delegate the model-specific work to three hooks:
    ``_init_custom_adapter``, ``_load_custom_adapter`` and
    ``_save_custom_adapter``. Each hook raises ``NotImplementedError`` here
    and is meant to be overridden by the class using this mixin.
    """

    def init_custom_adapter(self, *args, **kwargs):
        """Forward every argument unchanged to ``_init_custom_adapter``."""
        self._init_custom_adapter(*args, **kwargs)

    def _init_custom_adapter(self, *args, **kwargs):
        """Hook for adapter set-up; must be overridden.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def load_custom_adapter(
        self,
        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
        weight_name: str,
        subfolder: Optional[str] = None,
        **kwargs,
    ):
        """Obtain an adapter state dict and pass it to ``_load_custom_adapter``.

        Args:
            pretrained_model_name_or_path_or_dict: Either a ``dict``, which is
                used as the state dict as-is, or any other value, which is
                handed to ``_get_model_file`` to locate the weights file
                (presumably a Hub repo id or a local path -- confirm against
                the installed diffusers version).
            weight_name: File name of the weights. A name ending in
                ``.safetensors`` is read with ``safe_open``; anything else is
                read with ``torch.load``.
            subfolder: Forwarded to ``_get_model_file``.
            **kwargs: ``cache_dir``, ``force_download``, ``proxies``,
                ``local_files_only``, ``token`` and ``revision`` are forwarded
                to ``_get_model_file``. Any other key (including
                ``resume_download``) is accepted but ignored.
        """
        if isinstance(pretrained_model_name_or_path_or_dict, dict):
            # Caller already has the weights in memory; nothing to resolve.
            state_dict = pretrained_model_name_or_path_or_dict
        else:
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name,
                cache_dir=kwargs.pop("cache_dir", None),
                force_download=kwargs.pop("force_download", False),
                proxies=kwargs.pop("proxies", None),
                local_files_only=kwargs.pop("local_files_only", None),
                token=kwargs.pop("token", None),
                revision=kwargs.pop("revision", None),
                subfolder=subfolder,
                user_agent={
                    "file_type": "attn_procs_weights",
                    "framework": "pytorch",
                },
            )
            if weight_name.endswith(".safetensors"):
                with safe_open(model_file, framework="pt", device="cpu") as f:
                    state_dict = {key: f.get_tensor(key) for key in f.keys()}
            else:
                # NOTE(security): torch.load unpickles the file, which can
                # execute arbitrary code. Only use non-safetensors weights
                # from a trusted source (or consider `weights_only=True`).
                state_dict = torch.load(model_file, map_location="cpu")

        self._load_custom_adapter(state_dict)

    def _load_custom_adapter(self, state_dict):
        """Hook that consumes the loaded state dict; must be overridden.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError

    def save_custom_adapter(
        self,
        save_directory: Union[str, os.PathLike],
        weight_name: str,
        safe_serialization: bool = False,
        **kwargs,
    ):
        """Write the adapter weights to ``save_directory / weight_name``.

        The state dict comes from ``_save_custom_adapter(**kwargs)``.

        Args:
            save_directory: Target directory. It is created if it does not
                exist. If it points to an existing file, an error is logged
                and nothing is written.
            weight_name: File name to write inside ``save_directory``.
            safe_serialization: When true, write with
                ``safetensors.torch.save_file`` (metadata ``{"format": "pt"}``);
                otherwise write with ``torch.save``.
            **kwargs: Forwarded unchanged to ``_save_custom_adapter``.
        """
        if os.path.isfile(save_directory):
            logger.error(
                f"Provided path ({save_directory}) should be a directory, not a file"
            )
            return

        # Collect the weights first so a failing hook leaves no empty directory.
        state_dict = self._save_custom_adapter(**kwargs)

        os.makedirs(save_directory, exist_ok=True)
        save_path = os.path.join(save_directory, weight_name)

        if safe_serialization:
            # Import the submodule explicitly: a bare `import safetensors`
            # does not guarantee that `safetensors.torch` has been loaded.
            from safetensors.torch import save_file

            save_file(state_dict, save_path, metadata={"format": "pt"})
        else:
            torch.save(state_dict, save_path)

        logger.info(f"Custom adapter weights saved in {save_path}")

    def _save_custom_adapter(self, **kwargs):
        """Hook that returns the state dict to save; must be overridden.

        Accepts ``**kwargs`` because ``save_custom_adapter`` forwards its
        extra keyword arguments here.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
|
|