| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import importlib |
| | import os |
| | import re |
| | import warnings |
| | from pathlib import Path |
| | from typing import Any, Dict, List, Optional, Union |
| |
|
| | import torch |
| | from huggingface_hub import ( |
| | model_info, |
| | ) |
| | from packaging import version |
| |
|
| | from ..utils import ( |
| | SAFETENSORS_WEIGHTS_NAME, |
| | WEIGHTS_NAME, |
| | get_class_from_dynamic_module, |
| | is_peft_available, |
| | is_transformers_available, |
| | logging, |
| | ) |
| | from ..utils.torch_utils import is_compiled_module |
| |
|
| |
|
| | if is_transformers_available(): |
| | import transformers |
| | from transformers import PreTrainedModel |
| | from transformers.utils import FLAX_WEIGHTS_NAME as TRANSFORMERS_FLAX_WEIGHTS_NAME |
| | from transformers.utils import SAFE_WEIGHTS_NAME as TRANSFORMERS_SAFE_WEIGHTS_NAME |
| | from transformers.utils import WEIGHTS_NAME as TRANSFORMERS_WEIGHTS_NAME |
| | from huggingface_hub.utils import validate_hf_hub_args |
| |
|
| | from ..utils import FLAX_WEIGHTS_NAME, ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME |
| |
|
| |
|
# Canonical filenames / module locations used throughout pipeline loading.
INDEX_FILE = "diffusion_pytorch_model.bin"
# Conventional entry file for community/custom pipelines shipped as code.
CUSTOM_PIPELINE_FILE_NAME = "pipeline.py"
# Dummy placeholder objects (raised when an optional backend is missing) live here.
DUMMY_MODULES_FOLDER = "diffusers.utils"
TRANSFORMERS_DUMMY_MODULES_FOLDER = "transformers.utils"
# Component names that connect one pipeline to another (e.g. prior -> decoder).
CONNECTED_PIPES_KEYS = ["prior"]

logger = logging.get_logger(__name__)

# Maps each supported library to the class names that can be (de)serialized as
# pipeline components, and for each class the [save_method, load_method] pair
# used to persist / restore it.
LOADABLE_CLASSES = {
    "diffusers": {
        "ModelMixin": ["save_pretrained", "from_pretrained"],
        "SchedulerMixin": ["save_pretrained", "from_pretrained"],
        "DiffusionPipeline": ["save_pretrained", "from_pretrained"],
        "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"],
    },
    "transformers": {
        "PreTrainedTokenizer": ["save_pretrained", "from_pretrained"],
        "PreTrainedTokenizerFast": ["save_pretrained", "from_pretrained"],
        "PreTrainedModel": ["save_pretrained", "from_pretrained"],
        "FeatureExtractionMixin": ["save_pretrained", "from_pretrained"],
        "ProcessorMixin": ["save_pretrained", "from_pretrained"],
        "ImageProcessingMixin": ["save_pretrained", "from_pretrained"],
    },
    "onnxruntime.training": {
        "ORTModule": ["save_pretrained", "from_pretrained"],
    },
}

# Flattened view of LOADABLE_CLASSES: class name -> [save_method, load_method],
# ignoring which library the class came from.
ALL_IMPORTABLE_CLASSES = {}
for library in LOADABLE_CLASSES:
    ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
| |
|
| |
|
def is_safetensors_compatible(filenames, variant=None, passed_components=None) -> bool:
    """
    Checking for safetensors compatibility:
    - By default, all models are saved with the default pytorch serialization, so we use the list of default pytorch
      files to know which safetensors files are needed.
    - The model is safetensors compatible only if there is a matching safetensors file for every default pytorch file.

    Converting default pytorch serialized filenames to safetensors serialized filenames:
    - For models from the diffusers library, just replace the ".bin" extension with ".safetensors"
    - For models from the transformers library, the filename changes from "pytorch_model" to "model", and the ".bin"
      extension is replaced with ".safetensors"

    Args:
        filenames: iterable of repo-relative filenames (e.g. "unet/diffusion_pytorch_model.bin").
        variant: accepted for API compatibility; currently unused by this check.
        passed_components: component folder names supplied directly by the caller —
            their files are skipped because they will not be loaded from the repo.

    Returns:
        `True` if every relevant ".bin" file has a matching ".safetensors" counterpart.
    """
    passed_components = passed_components or []

    pt_filenames = []
    sf_filenames = set()
    for filename in filenames:
        _, extension = os.path.splitext(filename)

        # "component/file.ext" entries for user-passed components need no safetensors twin.
        if len(filename.split("/")) == 2 and filename.split("/")[0] in passed_components:
            continue

        if extension == ".bin":
            pt_filenames.append(os.path.normpath(filename))
        elif extension == ".safetensors":
            sf_filenames.add(os.path.normpath(filename))

    for pt_filename in pt_filenames:
        path, basename = os.path.split(pt_filename)
        stem, _ = os.path.splitext(basename)

        # transformers saves "pytorch_model.bin" but "model.safetensors";
        # diffusers keeps the same stem for both formats.
        if stem.startswith("pytorch_model"):
            stem = stem.replace("pytorch_model", "model")

        expected_sf_filename = os.path.normpath(os.path.join(path, stem))
        expected_sf_filename = f"{expected_sf_filename}.safetensors"
        if expected_sf_filename not in sf_filenames:
            logger.warning(f"{expected_sf_filename} not found")
            return False

    return True
| |
|
| |
|
def variant_compatible_siblings(filenames, variant=None) -> Union[List[os.PathLike], str]:
    """Split `filenames` into (usable_filenames, variant_filenames).

    `variant_filenames` are the weight/index files matching `variant`
    (e.g. "diffusion_pytorch_model.fp16.safetensors"). `usable_filenames`
    contains every variant file plus any non-variant file whose variant
    counterpart is missing from the repo.
    """
    weight_names = [
        WEIGHTS_NAME,
        SAFETENSORS_WEIGHTS_NAME,
        FLAX_WEIGHTS_NAME,
        ONNX_WEIGHTS_NAME,
        ONNX_EXTERNAL_WEIGHTS_NAME,
    ]
    if is_transformers_available():
        weight_names += [TRANSFORMERS_WEIGHTS_NAME, TRANSFORMERS_SAFE_WEIGHTS_NAME, TRANSFORMERS_FLAX_WEIGHTS_NAME]

    # Weight names take the form "prefix.suffix"; sharded checkpoints insert a
    # "00001-of-00005"-style counter between the two.
    weight_prefixes = [w.split(".")[0] for w in weight_names]
    weight_suffixs = [w.split(".")[-1] for w in weight_names]
    transformers_index_format = r"\d{5}-of-\d{5}"

    def _matching(pattern, candidates):
        # Match on the basename only — files live inside component subfolders.
        return {f for f in candidates if pattern.match(f.split("/")[-1]) is not None}

    if variant is not None:
        # `diffusion_pytorch_model.fp16.bin` as well as `model.fp16-00001-of-00002.safetensors`
        variant_file_re = re.compile(
            rf"({'|'.join(weight_prefixes)})\.({variant}|{variant}-{transformers_index_format})\.({'|'.join(weight_suffixs)})$"
        )
        # `text_encoder/pytorch_model.bin.index.fp16.json`
        variant_index_re = re.compile(
            rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.{variant}\.json$"
        )
        variant_filenames = _matching(variant_file_re, filenames) | _matching(variant_index_re, filenames)
    else:
        variant_filenames = set()

    # `diffusion_pytorch_model.bin` as well as `model-00001-of-00002.safetensors`
    non_variant_file_re = re.compile(
        rf"({'|'.join(weight_prefixes)})(-{transformers_index_format})?\.({'|'.join(weight_suffixs)})$"
    )
    # `text_encoder/pytorch_model.bin.index.json`
    non_variant_index_re = re.compile(rf"({'|'.join(weight_prefixes)})\.({'|'.join(weight_suffixs)})\.index\.json")
    non_variant_filenames = _matching(non_variant_file_re, filenames) | _matching(non_variant_index_re, filenames)

    # All variant files are always usable.
    usable_filenames = set(variant_filenames)

    def convert_to_variant(filename):
        # Produce the variant-named counterpart of a non-variant filename.
        if "index" in filename:
            return filename.replace("index", f"index.{variant}")
        if re.compile(f"^(.*?){transformers_index_format}").match(filename) is not None:
            return f"{filename.split('-')[0]}.{variant}-{'-'.join(filename.split('-')[1:])}"
        return f"{filename.split('.')[0]}.{variant}.{filename.split('.')[1]}"

    # Keep a non-variant file only when its variant twin is absent.
    for non_variant_file in non_variant_filenames:
        if convert_to_variant(non_variant_file) not in usable_filenames:
            usable_filenames.add(non_variant_file)

    return usable_filenames, variant_filenames
| |
|
| |
|
@validate_hf_hub_args
def warn_deprecated_model_variant(pretrained_model_name_or_path, token, variant, revision, model_filenames):
    """Warn that loading a variant via `revision=` is deprecated in favor of `variant=`.

    Checks the repo's 'main' branch: if the variant files exist there, point the
    user at `variant=`; otherwise additionally ask them to open an issue so the
    files can be added.
    """
    repo_info = model_info(
        pretrained_model_name_or_path,
        token=token,
        revision=None,
    )
    sibling_filenames = {sibling.rfilename for sibling in repo_info.siblings}
    comp_model_filenames, _ = variant_compatible_siblings(sibling_filenames, variant=revision)
    # Strip the variant segment ("name.variant.suffix" -> "name.suffix") so the
    # main-branch files can be compared against the requested filenames.
    comp_model_filenames = [".".join(f.split(".")[:1] + f.split(".")[2:]) for f in comp_model_filenames]

    variant_available_on_main = set(model_filenames).issubset(set(comp_model_filenames))
    if variant_available_on_main:
        message = f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'` even though you can load it via `variant=`{revision}`. Loading model variants via `revision='{revision}'` is deprecated and will be removed in diffusers v1. Please use `variant='{revision}'` instead."
    else:
        message = f"You are loading the variant {revision} from {pretrained_model_name_or_path} via `revision='{revision}'`. This behavior is deprecated and will be removed in diffusers v1. One should use `variant='{revision}'` instead. However, it appears that {pretrained_model_name_or_path} currently does not have the required variant filenames in the 'main' branch. \n The Diffusers team and community would be very grateful if you could open an issue: https://github.com/huggingface/diffusers/issues/new with the title '{pretrained_model_name_or_path} is missing {revision} files' so that the correct variant file can be added."
    warnings.warn(message, FutureWarning)
| |
|
| |
|
def _unwrap_model(model):
    """Unwraps a model."""
    # torch.compile keeps the original module on `_orig_mod`.
    unwrapped = model._orig_mod if is_compiled_module(model) else model

    if is_peft_available():
        from peft import PeftModel

        # PEFT wraps the base model; reach through to it.
        if isinstance(unwrapped, PeftModel):
            unwrapped = unwrapped.base_model.model

    return unwrapped
| |
|
| |
|
def maybe_raise_or_warn(
    library_name, library, class_name, importable_classes, passed_class_obj, name, is_pipeline_module
):
    """Simple helper method to raise or warn in case incorrect module has been passed.

    For non-pipeline modules, verifies that the user-passed component `name` is an
    instance of the expected base class from `importable_classes`; raises ValueError
    otherwise. For pipeline modules the type cannot be verified, so only a warning
    is emitted.
    """
    if not is_pipeline_module:
        library = importlib.import_module(library_name)
        class_obj = getattr(library, class_name)
        class_candidates = {c: getattr(library, c, None) for c in importable_classes.keys()}

        # Find the expected base class among the loadable candidates. A dedicated
        # loop variable is used so the `class_name` argument is not clobbered
        # (the original loop reused `class_name` and shadowed the parameter).
        expected_class_obj = None
        for candidate_name, class_candidate in class_candidates.items():
            if class_candidate is not None and issubclass(class_obj, class_candidate):
                expected_class_obj = class_candidate

        if expected_class_obj is None:
            # Previously this fell through to `issubclass(model_cls, None)`, which
            # raises an opaque TypeError; fail with an actionable message instead.
            raise ValueError(
                f"Could not find an expected base class for {class_name} among {list(importable_classes)}."
            )

        # Dynamo/PEFT wrap the original model; compare against the unwrapped class.
        sub_model = passed_class_obj[name]
        unwrapped_sub_model = _unwrap_model(sub_model)
        model_cls = unwrapped_sub_model.__class__

        if not issubclass(model_cls, expected_class_obj):
            raise ValueError(
                f"{passed_class_obj[name]} is of type: {model_cls}, but should be" f" {expected_class_obj}"
            )
    else:
        logger.warning(
            f"You have passed a non-standard module {passed_class_obj[name]}. We cannot verify whether it"
            " has the correct type"
        )
| |
|
| |
|
def get_class_obj_and_candidates(
    library_name, class_name, importable_classes, pipelines, is_pipeline_module, component_name=None, cache_dir=None
):
    """Simple helper method to retrieve class object of module as well as potential parent class objects"""
    component_folder = os.path.join(cache_dir, component_name)
    dynamic_module_file = os.path.join(component_folder, library_name + ".py")

    if is_pipeline_module:
        # Component defined inside a diffusers pipeline module.
        submodule = getattr(pipelines, library_name)
        class_obj = getattr(submodule, class_name)
        class_candidates = dict.fromkeys(importable_classes, class_obj)
    elif os.path.isfile(dynamic_module_file):
        # Custom component shipped as a .py file alongside the cached weights.
        class_obj = get_class_from_dynamic_module(
            component_folder, module_file=library_name + ".py", class_name=class_name
        )
        class_candidates = dict.fromkeys(importable_classes, class_obj)
    else:
        # Regular import from an installed library (e.g. transformers).
        imported_library = importlib.import_module(library_name)
        class_obj = getattr(imported_library, class_name)
        class_candidates = {c: getattr(imported_library, c, None) for c in importable_classes}

    return class_obj, class_candidates
| |
|
| |
|
| | def _get_pipeline_class( |
| | class_obj, |
| | config=None, |
| | load_connected_pipeline=False, |
| | custom_pipeline=None, |
| | repo_id=None, |
| | hub_revision=None, |
| | class_name=None, |
| | cache_dir=None, |
| | revision=None, |
| | ): |
| | if custom_pipeline is not None: |
| | if custom_pipeline.endswith(".py"): |
| | path = Path(custom_pipeline) |
| | |
| | file_name = path.name |
| | custom_pipeline = path.parent.absolute() |
| | elif repo_id is not None: |
| | file_name = f"{custom_pipeline}.py" |
| | custom_pipeline = repo_id |
| | else: |
| | file_name = CUSTOM_PIPELINE_FILE_NAME |
| |
|
| | if repo_id is not None and hub_revision is not None: |
| | |
| | |
| | revision = hub_revision |
| |
|
| | return get_class_from_dynamic_module( |
| | custom_pipeline, |
| | module_file=file_name, |
| | class_name=class_name, |
| | cache_dir=cache_dir, |
| | revision=revision, |
| | ) |
| |
|
| | if class_obj.__name__ != "DiffusionPipeline": |
| | return class_obj |
| |
|
| | diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) |
| | class_name = class_name or config["_class_name"] |
| | if not class_name: |
| | raise ValueError( |
| | "The class name could not be found in the configuration file. Please make sure to pass the correct `class_name`." |
| | ) |
| |
|
| | class_name = class_name[4:] if class_name.startswith("Flax") else class_name |
| |
|
| | pipeline_cls = getattr(diffusers_module, class_name) |
| |
|
| | if load_connected_pipeline: |
| | from .auto_pipeline import _get_connected_pipeline |
| |
|
| | connected_pipeline_cls = _get_connected_pipeline(pipeline_cls) |
| | if connected_pipeline_cls is not None: |
| | logger.info( |
| | f"Loading connected pipeline {connected_pipeline_cls.__name__} instead of {pipeline_cls.__name__} as specified via `load_connected_pipeline=True`" |
| | ) |
| | else: |
| | logger.info(f"{pipeline_cls.__name__} has no connected pipeline class. Loading {pipeline_cls.__name__}.") |
| |
|
| | pipeline_cls = connected_pipeline_cls or pipeline_cls |
| |
|
| | return pipeline_cls |
| |
|
| |
|
def load_sub_model(
    library_name: str,
    class_name: str,
    importable_classes: List[Any],
    pipelines: Any,
    is_pipeline_module: bool,
    pipeline_class: Any,
    torch_dtype: torch.dtype,
    provider: Any,
    sess_options: Any,
    device_map: Optional[Union[Dict[str, torch.device], str]],
    max_memory: Optional[Dict[Union[int, str], Union[int, str]]],
    offload_folder: Optional[Union[str, os.PathLike]],
    offload_state_dict: bool,
    model_variants: Dict[str, str],
    name: str,
    from_flax: bool,
    variant: str,
    low_cpu_mem_usage: bool,
    cached_folder: Union[str, os.PathLike],
):
    """Helper method to load the module `name` from `library_name` and `class_name`"""
    # Resolve the class to instantiate and the candidate base classes that
    # determine which load method applies to it.
    class_obj, class_candidates = get_class_obj_and_candidates(
        library_name,
        class_name,
        importable_classes,
        pipelines,
        is_pipeline_module,
        component_name=name,
        cache_dir=cached_folder,
    )

    load_method_name = None
    # Pick the load method (index 1, i.e. "from_pretrained") registered for a
    # matching candidate base class. No `break`: the LAST matching candidate wins.
    for class_name, class_candidate in class_candidates.items():
        if class_candidate is not None and issubclass(class_obj, class_candidate):
            load_method_name = importable_classes[class_name][1]

    # No matching base class: either a dummy placeholder (optional backend is
    # missing) or a genuinely unloadable component.
    if load_method_name is None:
        none_module = class_obj.__module__
        is_dummy_path = none_module.startswith(DUMMY_MODULES_FOLDER) or none_module.startswith(
            TRANSFORMERS_DUMMY_MODULES_FOLDER
        )
        if is_dummy_path and "dummy" in none_module:
            # Instantiating a dummy object raises the helpful "backend missing" error.
            class_obj()

        raise ValueError(
            f"The component {class_obj} of {pipeline_class} cannot be loaded as it does not seem to have"
            f" any of the loading methods defined in {ALL_IMPORTABLE_CLASSES}."
        )

    load_method = getattr(class_obj, load_method_name)

    # Build the kwargs for the load method based on what kind of component this is.
    diffusers_module = importlib.import_module(__name__.split(".")[0])
    loading_kwargs = {}
    if issubclass(class_obj, torch.nn.Module):
        loading_kwargs["torch_dtype"] = torch_dtype
    if issubclass(class_obj, diffusers_module.OnnxRuntimeModel):
        loading_kwargs["provider"] = provider
        loading_kwargs["sess_options"] = sess_options

    is_diffusers_model = issubclass(class_obj, diffusers_module.ModelMixin)

    if is_transformers_available():
        transformers_version = version.parse(version.parse(transformers.__version__).base_version)
    else:
        # String sentinel; never compared below because `is_transformers_model`
        # short-circuits on `is_transformers_available()` first.
        transformers_version = "N/A"

    is_transformers_model = (
        is_transformers_available()
        and issubclass(class_obj, PreTrainedModel)
        # device_map/offload-style kwargs below require transformers >= 4.20.0
        and transformers_version >= version.parse("4.20.0")
    )

    # Only diffusers ModelMixin and (recent-enough) transformers PreTrainedModel
    # subclasses understand the device placement / offload / variant kwargs.
    if is_diffusers_model or is_transformers_model:
        loading_kwargs["device_map"] = device_map
        loading_kwargs["max_memory"] = max_memory
        loading_kwargs["offload_folder"] = offload_folder
        loading_kwargs["offload_state_dict"] = offload_state_dict
        # `pop` so each component consumes its own variant entry exactly once.
        loading_kwargs["variant"] = model_variants.pop(name, None)

        if from_flax:
            loading_kwargs["from_flax"] = True

        # `variant` support for transformers models only landed in 4.27.0;
        # older versions must error out rather than silently ignore it.
        if (
            is_transformers_model
            and loading_kwargs["variant"] is not None
            and transformers_version < version.parse("4.27.0")
        ):
            raise ImportError(
                f"When passing `variant='{variant}'`, please make sure to upgrade your `transformers` version to at least 4.27.0.dev0"
            )
        elif is_transformers_model and loading_kwargs["variant"] is None:
            # transformers does not accept `variant=None`; drop the key entirely.
            loading_kwargs.pop("variant")

        # Loading flax weights through a transformers model is incompatible with
        # `low_cpu_mem_usage`, so force it off in that combination.
        if not (from_flax and is_transformers_model):
            loading_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
        else:
            loading_kwargs["low_cpu_mem_usage"] = False

    # Components usually live in a subfolder named after them; single-component
    # repos keep the weights at the root of `cached_folder`.
    if os.path.isdir(os.path.join(cached_folder, name)):
        loaded_sub_model = load_method(os.path.join(cached_folder, name), **loading_kwargs)
    else:
        loaded_sub_model = load_method(cached_folder, **loading_kwargs)

    return loaded_sub_model
| |
|
| |
|
def _fetch_class_library_tuple(module):
    """Return a (library, class_name) tuple identifying `module` for the pipeline config."""
    # Imported lazily via the package root to avoid a circular import.
    root_module = importlib.import_module(__name__.split(".")[0])
    pipelines = getattr(root_module, "pipelines")

    # Register the config from the original module, not the dynamo-compiled wrapper.
    unwrapped = _unwrap_model(module)
    module_parts = unwrapped.__module__.split(".")
    library = module_parts[0]

    # A module path like "diffusers.pipelines.<pipeline_dir>.<file>" marks a
    # component defined inside a pipeline module.
    pipeline_dir = module_parts[-2] if len(module_parts) > 2 else None
    is_pipeline_module = pipeline_dir in module_parts and hasattr(pipelines, pipeline_dir)

    if is_pipeline_module:
        library = pipeline_dir
    elif library not in LOADABLE_CLASSES:
        # Unknown library: fall back to the fully qualified module path.
        library = unwrapped.__module__

    class_name = unwrapped.__class__.__name__

    return (library, class_name)
| |
|