Spaces:
Configuration error
Configuration error
| # coding=utf-8 | |
| # Copyright 2024 The HuggingFace Inc. team. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import importlib | |
| import os | |
| import re | |
| import warnings | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Optional, Union | |
| import torch | |
| from huggingface_hub import model_info | |
| from huggingface_hub.utils import validate_hf_hub_args | |
| from packaging import version | |
| from .. import __version__ | |
| from ..utils import ( | |
| FLAX_WEIGHTS_NAME, | |
| ONNX_EXTERNAL_WEIGHTS_NAME, | |
| ONNX_WEIGHTS_NAME, | |
| SAFETENSORS_WEIGHTS_NAME, | |
| WEIGHTS_NAME, | |
| get_class_from_dynamic_module, | |
| is_accelerate_available, | |
| is_peft_available, | |
| is_transformers_available, | |
| logging, | |
| ) | |
| from ..utils.torch_utils import is_compiled_module | |
| if is_transformers_available(): | |
| import transformers | |
| from transformers import PreTrainedModel | |
| from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME | |
| from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME | |
| from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME | |
| if is_accelerate_available(): | |
| import accelerate | |
| from accelerate import dispatch_model | |
| from accelerate.hooks import remove_hook_from_module | |
| from accelerate.utils import compute_module_sizes, get_max_memory | |
# Module-level names shared by the pipeline loading helpers below.
INDEX_FILE = "diffusion_pytorch_model.bin"

# Module file used for a custom pipeline when neither an explicit `.py` path nor a repo id is given.
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"

# Module-path prefixes checked (together with the substring "dummy") to recognise
# placeholder classes whose requirements are missing.
DUMMY_MODULES_FOLDER = "diffusers.utils"
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"

CONNECTED_PIPES_KEYS = ["prior"]

logger = logging.get_logger(__name__)
# Maps each supported library to its loadable base classes; every base class is
# associated with the pair `[save method name, load method name]`.
LOADABLE_CLASSES = {
    library_name: {base_class: ["save_pretrained", "from_pretrained"] for base_class in base_classes}
    for library_name, base_classes in (
        (
            "diffusers",
            ("ModelMixin", "SchedulerMixin", "DiffusionPipeline", "OnnxRuntimeModel"),
        ),
        (
            "transformers",
            (
                "PreTrainedTokenizer",
                "PreTrainedTokenizerFast",
                "PreTrainedModel",
                "FeatureExtractionMixin",
                "ProcessorMixin",
                "ImageProcessingMixin",
            ),
        ),
        (
            "onnxruntime.training",
            ("ORTModule",),
        ),
    )
}

# Flattened view of the table above: base class name -> [save method, load method].
ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
    ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
    """
    Tell whether every default PyTorch weight file has a safetensors counterpart.

    The `.bin` files in `filenames` define which safetensors files are required:
        - a `.bin` file whose stem starts with "pytorch_model" (the transformers naming) is expected as
          "model...safetensors" in the same folder;
        - any other `.bin` file is expected under the same name with the ".safetensors" extension.

    Files of the form `<component>/<file>` whose component is listed in `passed_components` are ignored.
    `variant` is accepted but not used by this function.

    Returns `False` (after logging a warning) as soon as one expected safetensors file is missing, else `True`.
    """
    ignored_components = passed_components or []
    torch_files = []
    safetensors_files = set()

    for candidate in filenames:
        path_parts = candidate.split("/")
        if len(path_parts) == 2 and path_parts[0] in ignored_components:
            continue

        suffix = os.path.splitext(candidate)[1]
        if suffix == ".bin":
            torch_files.append(os.path.normpath(candidate))
        elif suffix == ".safetensors":
            safetensors_files.add(os.path.normpath(candidate))

    for torch_file in torch_files:
        # 'foo/bar/baz.bin' -> folder = 'foo/bar', stem = 'baz'
        folder, basename = os.path.split(torch_file)
        stem = os.path.splitext(basename)[0]
        if stem.startswith("pytorch_model"):
            stem = stem.replace("pytorch_model", "model")

        expected_sf_filename = os.path.normpath(os.path.join(folder, stem)) + ".safetensors"
        if expected_sf_filename not in safetensors_files:
            logger.warning(f"{expected_sf_filename} not found")
            return False

    return True
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
    """
    Select the weight/index files that should be used for a given `variant`.

    Args:
        filenames: iterable of repo-relative file names using "/" as separator.
        variant: variant tag such as "fp16", or `None` for no variant.

    Returns:
        A tuple `(usable_filenames, variant_filenames)` of two sets (note: the return annotation is kept for
        interface compatibility but does not describe this tuple). `variant_filenames` holds the files whose
        base name matches the variant naming scheme; `usable_filenames` holds all of those plus every
        non-variant weight/index file that has no variant counterpart among them.
    """
    weight_names = [
        WEIGHTS_NAME,
        SAFETENSORS_WEIGHTS_NAME,
        FLAX_WEIGHTS_NAME,
        ONNX_WEIGHTS_NAME,
        ONNX_EXTERNAL_WEIGHTS_NAME,
    ]
    if is_transformers_available():
        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]

    # model_pytorch, diffusion_model_pytorch, ...
    weight_prefixes = "|".join(w.split(".")[0] for w in weight_names)
    # .bin, .safetensors, ...
    weight_suffixs = "|".join(w.split(".")[-1] for w in weight_names)
    # -00001-of-00002
    transformers_index_format = r"\d{5}-of-\d{5}"
    sharded_name_re = re.compile(f"^(.*?){transformers_index_format}")

    def basename(filename):
        return filename.split("/")[-1]

    if variant is not None:
        # Escape the tag so characters such as "." or "+" are matched literally instead of
        # being interpreted as regular-expression syntax.
        variant_pattern = re.escape(str(variant))
        # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
        variant_file_re = re.compile(
            rf"({weight_prefixes})\.({variant_pattern}|{variant_pattern}-{transformers_index_format})\.({weight_suffixs})$"
        )
        # `text_encoder/pytorch_model.bin.index.fp16.json`
        variant_index_re = re.compile(rf"({weight_prefixes})\.({weight_suffixs})\.index\.{variant_pattern}\.json$")

        variant_weights = {f for f in filenames if variant_file_re.match(basename(f)) is not None}
        variant_indexes = {f for f in filenames if variant_index_re.match(basename(f)) is not None}
        variant_filenames = variant_weights | variant_indexes
    else:
        variant_filenames = set()

    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
    non_variant_file_re = re.compile(rf"({weight_prefixes})(-{transformers_index_format})?\.({weight_suffixs})$")
    # `text_encoder/pytorch_model.bin.index.json`
    non_variant_index_re = re.compile(rf"({weight_prefixes})\.({weight_suffixs})\.index\.json")

    non_variant_weights = {f for f in filenames if non_variant_file_re.match(basename(f)) is not None}
    non_variant_indexes = {f for f in filenames if non_variant_index_re.match(basename(f)) is not None}
    non_variant_filenames = non_variant_weights | non_variant_indexes

    # all variant filenames will be used by default
    usable_filenames = set(variant_filenames)

    def convert_to_variant(filename):
        # Only the base name follows the weight naming scheme. Rewriting the full path would
        # corrupt folder names that happen to contain "-", "." or "index".
        folder, separator, name = filename.rpartition("/")
        if "index" in name:
            variant_name = name.replace("index", f"index.{variant}")
        elif sharded_name_re.match(name) is not None:
            head, _, shard = name.partition("-")
            variant_name = f"{head}.{variant}-{shard}"
        else:
            name_parts = name.split(".")
            variant_name = f"{name_parts[0]}.{variant}.{name_parts[1]}"
        return f"{folder}{separator}{variant_name}"

    # A non-variant file is only used when its variant counterpart is not already selected.
    for f in non_variant_filenames:
        if convert_to_variant(f) not in usable_filenames:
            usable_filenames.add(f)

    return usable_filenames, variant_filenames
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
    """
    Emit a `FutureWarning` about loading a model variant through `revision` instead of `variant`.

    The repo file listing is fetched with `model_info(..., revision=None)` and filtered with
    `variant_compatible_siblings` using `revision` as the variant tag. The second dot-separated part is then
    removed from each selected name; if `model_filenames` is a subset of the result, the warning points the
    user to `variant=`, otherwise it additionally asks them to open an issue about the missing variant files.

    `variant` is accepted but not used by this function.
    """
    hub_info = model_info(pretrained_model_name_or_path, token=token, revision=None)
    hub_filenames = {sibling.rfilename for sibling in hub_info.siblings}
    variant_candidates, _ = variant_compatible_siblings(hub_filenames, variant=revision)

    # 'a.fp16.bin' -> 'a.bin': drop the second dot-separated part of every candidate.
    stripped_candidates = set()
    for candidate in variant_candidates:
        parts = candidate.split(".")
        stripped_candidates.add(".".join(parts[:1] + parts[2:]))

    if set(model_filenames).issubset(stripped_candidates):
        message = f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead."
    else:
        message = f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added."

    warnings.warn(message, FutureWarning)
def _unwrap_model(model):
    """
    Return the model underneath compilation and PEFT wrappers.

    A compiled module (per `is_compiled_module`) is replaced by its `_orig_mod`; afterwards, if PEFT is
    available and the result is a `PeftModel`, its `base_model.model` is returned.
    """
    unwrapped = model._orig_mod if is_compiled_module(model) else model

    if not is_peft_available():
        return unwrapped

    from peft import PeftModel

    if isinstance(unwrapped, PeftModel):
        return unwrapped.base_model.model
    return unwrapped
def maybe_raise_or_warn(
    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
    """
    Check that the user-passed component `passed_class_obj[name]` has the expected type.

    For pipeline modules no check is possible, so only a warning is logged. Otherwise `library_name` is
    imported (the `library` argument is overwritten), the expected base class is the last entry of
    `importable_classes` that `class_name` is a subclass of, and a `ValueError` is raised when the unwrapped
    component is not an instance of a subclass of it.

    Note: if no entry of `importable_classes` matches, the expected class stays `None` and the final
    `issubclass` call raises `TypeError`.
    """
    if is_pipeline_module:
        logger.warning(
            f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
            " has the correct type"
        )
        return

    library = importlib.import_module(library_name)
    class_obj = getattr(library, class_name)
    candidates = {candidate_name: getattr(library, candidate_name, None) for candidate_name in importable_classes.keys()}

    expected_class_obj = None
    for candidate in candidates.values():
        if candidate is not None and issubclass(class_obj, candidate):
            expected_class_obj = candidate

    # Dynamo wraps the original model in a private class, so compare against the unwrapped model's class.
    sub_model = passed_class_obj[name]
    model_cls = _unwrap_model(sub_model).__class__

    if not issubclass(model_cls, expected_class_obj):
        raise ValueError(f"{passed_class_obj[name]} is of type: {model_cls}, but should be {expected_class_obj}")
def get_class_obj_and_candidates(
    library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
    """
    Retrieve the class object of a component together with its potential parent classes.

    Args:
        library_name: name of the pipeline sub-module, custom module file (without ".py") or importable library.
        class_name: name of the class to fetch.
        importable_classes: mapping whose keys are the candidate base class names.
        pipelines: object exposing pipeline sub-modules as attributes (only read when `is_pipeline_module`).
        is_pipeline_module: whether `library_name` is an attribute of `pipelines`.
        component_name: optional component sub-folder inside `cache_dir`.
        cache_dir: optional folder that may contain `<component_name>/<library_name>.py`.

    Returns:
        `(class_obj, class_candidates)`. For pipeline modules and custom components every candidate maps to
        `class_obj`; otherwise each candidate name maps to the attribute of the imported library, or `None`
        when the library does not define it.
    """
    # Both arguments default to `None`; only build the folder when both are given so that
    # the defaults do not crash `os.path.join`.
    if cache_dir is not None and component_name is not None:
        component_folder = os.path.join(cache_dir, component_name)
    else:
        component_folder = None

    if is_pipeline_module:
        pipeline_module = getattr(pipelines, library_name)

        class_obj = getattr(pipeline_module, class_name)
        class_candidates = {c: class_obj for c in importable_classes.keys()}
    elif component_folder is not None and os.path.isfile(os.path.join(component_folder, library_name + ".py")):
        # load custom component
        class_obj = get_class_from_dynamic_module(
            component_folder, module_file=library_name + ".py", class_name=class_name
        )
        class_candidates = {c: class_obj for c in importable_classes.keys()}
    else:
        # else we just import it from the library.
        library = importlib.import_module(library_name)

        class_obj = getattr(library, class_name)
        class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

    return class_obj, class_candidates
def _get_custom_pipeline_class(
    custom_pipeline,
    repo_id=None,
    hub_revision=None,
    class_name=None,
    cache_dir=None,
    revision=None,
):
    """
    Resolve a custom pipeline class through `get_class_from_dynamic_module`.

    The module location is derived from `custom_pipeline`:
        - a value ending in ".py" is split into its absolute parent folder and its file name;
        - otherwise, when `repo_id` is given, `<custom_pipeline>.py` is looked up in `repo_id`;
        - otherwise `custom_pipeline` itself is the source and `CUSTOM_PIPELINE_FILE_NAME` the file.

    When both `repo_id` and `hub_revision` are given, `hub_revision` replaces `revision`.
    """
    if custom_pipeline.endswith(".py"):
        # decompose into folder & file
        local_file = Path(custom_pipeline)
        module_file, module_source = local_file.name, local_file.parent.absolute()
    elif repo_id is not None:
        module_file, module_source = f"{custom_pipeline}.py", repo_id
    else:
        module_file, module_source = CUSTOM_PIPELINE_FILE_NAME, custom_pipeline

    # The pipeline code comes from the Hub: its revision takes precedence.
    loading_from_hub = repo_id is not None and hub_revision is not None
    effective_revision = hub_revision if loading_from_hub else revision

    return get_class_from_dynamic_module(
        module_source,
        module_file=module_file,
        class_name=class_name,
        cache_dir=cache_dir,
        revision=effective_revision,
    )
| def _get_pipeline_class( | |
| class_obj, | |
| config=None, | |
| load_connected_pipeline=False, | |
| custom_pipeline=None, | |
| repo_id=None, | |
| hub_revision=None, | |
| class_name=None, | |
| cache_dir=None, | |
| revision=None, | |
| ): | |
| if custom_pipeline is not None: | |
| return _get_custom_pipeline_class( | |
| custom_pipeline, | |
| repo_id=repo_id, | |
| hub_revision=hub_revision, | |
| class_name=class_name, | |
| cache_dir=cache_dir, | |
| revision=revision, | |
| ) | |
| if class_obj.__name__ != "DiffusionPipeline": | |
| return class_obj | |
| diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) | |
| class_name = class_name or config["_class_name"] | |
| if not class_name: | |
| raise ValueError( | |
| "The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`." | |
| ) | |
| class_name = class_name[4:] if class_name.startswith("Flax") else class_name | |
| pipeline_cls = getattr(diffusers_module, class_name) | |
| if load_connected_pipeline: | |
| from .auto_pipeline import _get_connected_pipeline | |
| connected_pipeline_cls = _get_connected_pipeline(pipeline_cls) | |
| if connected_pipeline_cls is not None: | |
| logger.info( | |
| f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`" | |
| ) | |
| else: | |
| logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.") | |
| pipeline_cls = connected_pipeline_cls or pipeline_cls | |
| return pipeline_cls | |
def _load_empty_model(
    library_name: str,
    class_name: str,
    importable_classes: List[Any],
    pipelines: Any,
    is_pipeline_module: bool,
    name: str,
    torch_dtype: Union[str, torch.dtype],
    cached_folder: Union[str, os.PathLike],
    **kwargs,
):
    """
    Instantiate the component `name` under `accelerate.init_empty_weights()` from its config only.

    Diffusers models (subclasses of `ModelMixin`) are built with `load_config` + `from_config`; transformers
    models (subclasses of `PreTrainedModel`, transformers >= 4.20.0) are built from
    `config_class.from_pretrained`. The download-related options `force_download`, `resume_download`,
    `proxies`, `local_files_only`, `token`, `revision` (and `subfolder` for diffusers models) are read
    from `kwargs`.

    Returns:
        The model cast to `torch_dtype`, or `None` when the class is neither kind of model.

    Raises:
        ValueError: if a transformers model class has no `config_class`.
    """
    # retrieve class objects.
    class_obj, _ = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    transformers_version = (
        version.parse(version.parse(transformers.__version__).base_version)
        if is_transformers_available()
        else "N/A"
    )

    # Determine library.
    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    # Options shared by both config loaders.
    shared_loading_kwargs = {
        "force_download": kwargs.pop("force_download", False),
        "resume_download": kwargs.pop("resume_download", None),
        "proxies": kwargs.pop("proxies", None),
        "local_files_only": kwargs.pop("local_files_only", False),
        "token": kwargs.pop("token", None),
        "revision": kwargs.pop("revision", None),
        "user_agent": {
            "diffusers": __version__,
            "file_type": "model",
            "framework": "pytorch",
        },
    }

    model = None
    if is_diffusers_model:
        # Load config and then the model on meta.
        config, unused_kwargs, commit_hash = class_obj.load_config(
            os.path.join(cached_folder, name),
            cache_dir=cached_folder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            subfolder=kwargs.pop("subfolder", None),
            **shared_loading_kwargs,
        )
        with accelerate.init_empty_weights():
            model = class_obj.from_config(config, **unused_kwargs)
    elif is_transformers_model:
        config_class = getattr(class_obj, "config_class", None)
        if config_class is None:
            raise ValueError("`config_class` cannot be None. Please double-check the model.")

        config = config_class.from_pretrained(cached_folder, subfolder=name, **shared_loading_kwargs)
        with accelerate.init_empty_weights():
            model = class_obj(config)

    if model is None:
        return model
    return model.to(dtype=torch_dtype)
| def _assign_components_to_devices( | |
| module_sizes: Dict[str, float], device_memory: Dict[str, float], device_mapping_strategy: str = "balanced" | |
| ): | |
| device_ids = list(device_memory.keys()) | |
| device_cycle = device_ids + device_ids[::-1] | |
| device_memory = device_memory.copy() | |
| device_id_component_mapping = {} | |
| current_device_index = 0 | |
| for component in module_sizes: | |
| device_id = device_cycle[current_device_index % len(device_cycle)] | |
| component_memory = module_sizes[component] | |
| curr_device_memory = device_memory[device_id] | |
| # If the GPU doesn't fit the current component offload to the CPU. | |
| if component_memory > curr_device_memory: | |
| device_id_component_mapping["cpu"] = [component] | |
| else: | |
| if device_id not in device_id_component_mapping: | |
| device_id_component_mapping[device_id] = [component] | |
| else: | |
| device_id_component_mapping[device_id].append(component) | |
| # Update the device memory. | |
| device_memory[device_id] -= component_memory | |
| current_device_index += 1 | |
| return device_id_component_mapping | |
def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dict, library, max_memory, **kwargs):
    """
    Derive a component-level device map such as `{"unet": 0, "text_encoder": 1, "vae": 1, ...}`.

    Every component of `init_dict` is either taken from `passed_class_obj` (after a type check) or
    instantiated empty via `_load_empty_model`. The `torch.nn.Module` components are then sized with
    `compute_module_sizes`, sorted from largest to smallest, and assigned to the non-"cpu" devices reported
    by `get_max_memory` (sorted by decreasing memory).

    Returns:
        The mapping component name -> device id, or `None` when no non-"cpu" device is available.

    Raises:
        ValueError: if a component class name starts with "Flax".
    """
    # To avoid circular import problem.
    from diffusers import pipelines

    torch_dtype = kwargs.get("torch_dtype", torch.float32)

    # Load each module in the pipeline on a meta device so that we can derive the device map.
    init_empty_modules = {}
    for name, (library_name, class_name) in init_dict.items():
        if class_name.startswith("Flax"):
            raise ValueError("Flax pipelines are not supported with `device_map`.")

        # Define all importable classes
        is_pipeline_module = hasattr(pipelines, library_name)
        importable_classes = ALL_IMPORTABLE_CLASSES

        # Use passed sub model or load class_name from library_name
        if name in passed_class_obj:
            # check that passed_class_obj has correct parent class
            maybe_raise_or_warn(
                library_name,
                library,
                class_name,
                importable_classes,
                passed_class_obj,
                name,
                is_pipeline_module,
            )
            with accelerate.init_empty_weights():
                sub_model = passed_class_obj[name]
        else:
            sub_model = _load_empty_model(
                library_name=library_name,
                class_name=class_name,
                importable_classes=importable_classes,
                pipelines=pipelines,
                is_pipeline_module=is_pipeline_module,
                pipeline_class=pipeline_class,
                name=name,
                torch_dtype=torch_dtype,
                cached_folder=kwargs.get("cached_folder", None),
                force_download=kwargs.get("force_download", None),
                resume_download=kwargs.get("resume_download", None),
                proxies=kwargs.get("proxies", None),
                local_files_only=kwargs.get("local_files_only", None),
                token=kwargs.get("token", None),
                revision=kwargs.get("revision", None),
            )

        if sub_model is not None:
            init_empty_modules[name] = sub_model

    # Sizes of the model-level components, largest first.
    unsorted_sizes = {}
    for module_name, module in init_empty_modules.items():
        if isinstance(module, torch.nn.Module):
            unsorted_sizes[module_name] = compute_module_sizes(module, dtype=torch_dtype)[""]
    module_sizes = dict(sorted(unsorted_sizes.items(), key=lambda item: item[1], reverse=True))

    # Maximum memory available per device (GPUs only), largest first.
    max_memory = get_max_memory(max_memory)
    max_memory = dict(sorted(max_memory.items(), key=lambda item: item[1], reverse=True))
    max_memory = {k: v for k, v in max_memory.items() if k != "cpu"}

    if len(max_memory) == 0:
        return None

    # Map the components to the available devices based on the maximum memory and the model sizes.
    device_id_component_mapping = _assign_components_to_devices(
        module_sizes, max_memory, device_mapping_strategy=device_map
    )

    # Invert "device -> components" into "component -> device".
    final_device_map = {}
    for device_id, components in device_id_component_mapping.items():
        final_device_map.update({component: device_id for component in components})

    return final_device_map
def load_sub_model(
    library_name: str,
    class_name: str,
    importable_classes: List[Any],
    pipelines: Any,
    is_pipeline_module: bool,
    pipeline_class: Any,
    torch_dtype: torch.dtype,
    provider: Any,
    sess_options: Any,
    device_map: Optional[Union[Dict[str, torch.device], str]],
    max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
    offload_folder: Optional[Union[str, os.PathLike]],
    offload_state_dict: bool,
    model_variants: Dict[str, str],
    name: str,
    from_flax: bool,
    variant: str,
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
):
    """
    Load the component `name` of a pipeline from `library_name` and `class_name`.

    The loading method is looked up in `importable_classes` for the last candidate base class that the
    component class derives from. The keyword arguments passed to it depend on the kind of class
    (`torch.nn.Module`, `OnnxRuntimeModel`, diffusers model, transformers model). The component is loaded
    from `<cached_folder>/<name>` when that directory exists, else from `cached_folder` itself. When
    `device_map` is a dict and the result is a `torch.nn.Module`, its hooks are removed and the model is
    re-dispatched with `dispatch_model`.

    Note: the entry for `name` is popped from `model_variants`.

    Raises:
        ValueError: if no loading method can be determined for the component class.
        ImportError: if a variant is requested for a transformers model with transformers < 4.27.0.
    """
    # retrieve class candidates
    class_obj, class_candidates = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    # retrieve load method name (the last matching candidate wins)
    load_method_name = None
    for candidate_name, candidate_cls in class_candidates.items():
        if candidate_cls is not None and issubclass(class_obj, candidate_cls):
            load_method_name = importable_classes[candidate_name][1]

    # if load method name is None, then we have a dummy module -> raise Error
    if load_method_name is None:
        none_module = class_obj.__module__
        is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
            TRANSFORMERS_DUMMY_MODULES_FOLDER
        )
        if is_dummy_path and "dummy" in none_module:
            # call class_obj for nice error message of missing requirements
            class_obj()

        raise ValueError(
            f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
            f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
        )

    load_method = getattr(class_obj, load_method_name)

    # add kwargs to loading method
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    loading_kwargs = {}
    if issubclass(class_obj, torch.nn.Module):
        loading_kwargs["torch_dtype"] = torch_dtype
    if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
        loading_kwargs.update(provider=provider, sess_options=sess_options)

    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    transformers_version = (
        version.parse(version.parse(transformers.__version__).base_version)
        if is_transformers_available()
        else "N/A"
    )
    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        and transformers_version >= version.parse("4.20.0")
    )

    # When loading a transformers model, if the device_map is None, the weights will be initialized as opposed to diffusers.
    # To make default loading faster we set the `low_cpu_mem_usage=low_cpu_mem_usage` flag which is `True` by default.
    # This makes sure that the weights won't be initialized which significantly speeds up loading.
    if is_diffusers_model or is_transformers_model:
        component_variant = model_variants.pop(name, None)
        loading_kwargs.update(
            device_map=device_map,
            max_memory=max_memory,
            offload_folder=offload_folder,
            offload_state_dict=offload_state_dict,
            variant=component_variant,
        )
        if from_flax:
            loading_kwargs["from_flax"] = True

        # the following can be deleted once the minimum required `transformers` version
        # is higher than 4.27
        if is_transformers_model:
            if component_variant is None:
                del loading_kwargs["variant"]
            elif transformers_version < version.parse("4.27.0"):
                raise ImportError(
                    f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
                )

        # if `from_flax` and model is transformer model, can currently not load with `low_cpu_mem_usage`
        loading_kwargs["low_cpu_mem_usage"] = False if (from_flax and is_transformers_model) else low_cpu_mem_usage

    # Load from the component sub-directory when it exists, else from the root directory.
    component_folder = os.path.join(cached_folder, name)
    load_source = component_folder if os.path.isdir(component_folder) else cached_folder
    loaded_sub_model = load_method(load_source, **loading_kwargs)

    if isinstance(loaded_sub_model, torch.nn.Module) and isinstance(device_map, dict):
        # remove hooks
        remove_hook_from_module(loaded_sub_model, recurse=True)

        dispatch_kwargs = {"device_map": device_map, "force_hooks": True}
        if device_map[""] == "cpu":
            # offloading to CPU: pass the state dict and use device 0 as the main device
            dispatch_kwargs = {
                "state_dict": loaded_sub_model.state_dict(),
                "device_map": device_map,
                "force_hooks": True,
                "main_device": 0,
            }
        dispatch_model(loaded_sub_model, **dispatch_kwargs)

    return loaded_sub_model
def _fetch_class_library_tuple(module):
    """
    Return the `(library, class_name)` pair describing `module`.

    `library` is, in order of precedence:
        - the second-to-last part of the module path, when that name is an attribute of `pipelines`
          (the module lives inside a pipeline folder);
        - the top-level package name, when it is a key of `LOADABLE_CLASSES`;
        - the full module path otherwise (custom module).

    The module is unwrapped first so that the original class is inspected, not the dynamo compiled one.
    """
    # import it here to avoid circular import
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    pipelines = diffusers_module.pipelines

    not_compiled_module = _unwrap_model(module)
    full_module_path = not_compiled_module.__module__
    module_path_items = full_module_path.split(".")
    top_level_package = module_path_items[0]

    # check if the module is a pipeline module
    pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
    is_pipeline_module = pipeline_dir in module_path_items and hasattr(pipelines, pipeline_dir)

    if is_pipeline_module:
        library = pipeline_dir
    elif top_level_package in LOADABLE_CLASSES:
        library = top_level_package
    else:
        library = full_module_path

    return (library, not_compiled_module.__class__.__name__)