| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from __future__ import annotations |
|
|
| import copy |
| import time |
| from collections import OrderedDict |
| from itertools import combinations |
| from typing import Any |
|
|
| import torch |
|
|
| from ..hooks import ModelHook |
| from ..utils import ( |
| is_accelerate_available, |
| logging, |
| ) |
| from ..utils.torch_utils import get_device |
|
|
|
|
| if is_accelerate_available(): |
| from accelerate.hooks import add_hook_to_module, remove_hook_from_module |
| from accelerate.state import PartialState |
| from accelerate.utils import send_to_device |
| from accelerate.utils.memory import clear_device_cache |
| from accelerate.utils.modeling import convert_file_size_to_int |
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class CustomOffloadHook(ModelHook):
    """
    A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are
    on the given device. Optionally offloads other models to the CPU before the forward pass is called.

    Args:
        execution_device(`str`, `int` or `torch.device`, *optional*):
            The device on which the model should be executed. Will default to the MPS device if it's available, then
            GPU 0 if there is a GPU, and finally to the CPU.
    """

    # Gradients are not disabled around the wrapped forward pass (accelerate ModelHook flag).
    no_grad = False

    def __init__(
        self,
        execution_device: str | int | torch.device | None = None,
        other_hooks: list["UserCustomOffloadHook"] | None = None,
        offload_strategy: "AutoOffloadStrategy" | None = None,
    ):
        # Fall back to accelerate's default device (CUDA/MPS/CPU) when none is given.
        self.execution_device = execution_device if execution_device is not None else PartialState().default_device
        # Hooks of the other managed models that may be offloaded to make room for this one.
        self.other_hooks = other_hooks
        self.offload_strategy = offload_strategy
        # Set by UserCustomOffloadHook.attach(); identifies this model in log messages.
        self.model_id = None

    def set_strategy(self, offload_strategy: "AutoOffloadStrategy"):
        # Replace the strategy used to pick which other models to offload.
        self.offload_strategy = offload_strategy

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        """
        Add a hook to the list of hooks to consider for offloading.
        """
        if self.other_hooks is None:
            self.other_hooks = []
        self.other_hooks.append(hook)

    def init_hook(self, module):
        # Managed models start on (and are offloaded back to) the CPU.
        return module.to("cpu")

    def pre_forward(self, module, *args, **kwargs):
        # Only shuffle models around when this module is not already on its execution device.
        if module.device != self.execution_device:
            if self.other_hooks is not None:
                # Candidates for offloading: peer models currently occupying the execution device.
                hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device]

                start_time = time.perf_counter()
                if self.offload_strategy is not None:
                    # Let the strategy narrow candidates to a subset that frees enough memory.
                    hooks_to_offload = self.offload_strategy(
                        hooks=hooks_to_offload,
                        model_id=self.model_id,
                        model=module,
                        execution_device=self.execution_device,
                    )
                end_time = time.perf_counter()
                logger.info(
                    f" time taken to apply offload strategy for {self.model_id}: {(end_time - start_time):.2f} seconds"
                )

                for hook in hooks_to_offload:
                    logger.info(
                        f"moving {self.model_id} to {self.execution_device}, offloading {hook.model_id} to cpu"
                    )
                    hook.offload()

                if hooks_to_offload:
                    clear_device_cache()
            # Move this model to the execution device, then its inputs alongside it.
            module.to(self.execution_device)
        return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
|
|
|
|
class UserCustomOffloadHook:
    """
    A thin handle pairing a model with its `CustomOffloadHook`. Provides convenience APIs to
    offload the model, attach/detach the hook, and register peer hooks.
    """

    def __init__(self, model_id, model, hook):
        self.model_id = model_id
        self.model = model
        self.hook = hook

    def offload(self):
        # The hook's init step moves the model back to CPU.
        self.hook.init_hook(self.model)

    def attach(self):
        # Install the offload hook on the model and tell it which model it manages.
        add_hook_to_module(self.model, self.hook)
        self.hook.model_id = self.model_id

    def remove(self):
        # Detach the hook and clear its model association.
        remove_hook_from_module(self.model)
        self.hook.model_id = None

    def add_other_hook(self, hook: "UserCustomOffloadHook"):
        # Register a peer hook the underlying offload hook may evict from the device.
        self.hook.add_other_hook(hook)
|
|
|
|
def custom_offload_with_hook(
    model_id: str,
    model: torch.nn.Module,
    execution_device: str | int | torch.device = None,
    offload_strategy: "AutoOffloadStrategy" | None = None,
):
    """
    Wrap `model` with a `CustomOffloadHook` and return the attached `UserCustomOffloadHook`
    handle that controls it.

    Args:
        model_id (`str`): Identifier used for this model in log messages.
        model (`torch.nn.Module`): The model to manage.
        execution_device (`str`, `int` or `torch.device`, *optional*): Device the model runs on.
        offload_strategy (`AutoOffloadStrategy`, *optional*):
            Strategy deciding which other models to offload when device memory runs low.

    Returns:
        `UserCustomOffloadHook`: The user-facing handle, already attached to the model.
    """
    offload_hook = CustomOffloadHook(execution_device=execution_device, offload_strategy=offload_strategy)
    wrapper = UserCustomOffloadHook(model_id=model_id, model=model, hook=offload_hook)
    wrapper.attach()
    return wrapper
|
|
|
|
| |
class AutoOffloadStrategy:
    """
    Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on
    the available memory on the device.

    Args:
        memory_reserve_margin (`str` or `int`, defaults to `"3GB"`):
            Amount of device memory to keep free (e.g. for activations), expressed as a size
            string understood by accelerate's `convert_file_size_to_int`.
    """

    def __init__(self, memory_reserve_margin="3GB"):
        # Normalize e.g. "3GB" into an integer number of bytes.
        self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin)

    def __call__(self, hooks, model_id, model, execution_device):
        """
        Given the hooks of models currently on `execution_device`, return the subset that should
        be offloaded to CPU so that `model` fits. Returns an empty list when nothing needs to
        move, or all hooks when no subset frees enough memory.
        """
        if len(hooks) == 0:
            return []

        try:
            current_module_size = model.get_memory_footprint()
        except AttributeError:
            # Fixed error message: the backtick around the class name was unclosed.
            raise AttributeError(f"Do not know how to compute memory footprint of `{model.__class__.__name__}`.")

        # Query free memory on the execution device (e.g. torch.cuda.mem_get_info).
        device_type = execution_device.type
        device_module = getattr(torch, device_type, torch.cuda)
        try:
            mem_on_device = device_module.mem_get_info(execution_device.index)[0]
        except AttributeError:
            # Fixed error message: removed the duplicated word "obtain".
            raise AttributeError(f"Do not know how to obtain memory info for {str(device_module)}.")

        mem_on_device = mem_on_device - self.memory_reserve_margin
        if current_module_size < mem_on_device:
            # The incoming model already fits; nothing needs to be offloaded.
            return []

        min_memory_offload = current_module_size - mem_on_device
        logger.info(f" search for models to offload in order to free up {min_memory_offload / 1024**3:.2f} GB memory")

        # Sizes of the candidate models, largest first.
        module_sizes = dict(
            sorted(
                {hook.model_id: hook.model.get_memory_footprint() for hook in hooks}.items(),
                key=lambda x: x[1],
                reverse=True,
            )
        )

        def search_best_candidate(module_sizes, min_memory_offload):
            """
            Search for the combination of models whose total size is the smallest one that is
            still at least `min_memory_offload`. Exhaustive over all subsets; returns None when
            even offloading everything would not free enough memory.
            """
            model_ids = list(module_sizes.keys())
            best_candidate = None
            best_size = float("inf")
            for r in range(1, len(model_ids) + 1):
                for candidate_model_ids in combinations(model_ids, r):
                    candidate_size = sum(
                        module_sizes[candidate_model_id] for candidate_model_id in candidate_model_ids
                    )
                    # Keep the smallest total that still satisfies the requirement.
                    if candidate_size >= min_memory_offload and candidate_size < best_size:
                        best_candidate = candidate_model_ids
                        best_size = candidate_size

            return best_candidate

        best_offload_model_ids = search_best_candidate(module_sizes, min_memory_offload)

        if best_offload_model_ids is None:
            # No subset frees enough memory on its own; offload everything as a last resort.
            logger.warning("no combination of models to offload to cpu is found, offloading all models")
            hooks_to_offload = hooks
        else:
            hooks_to_offload = [hook for hook in hooks if hook.model_id in best_offload_model_ids]

        return hooks_to_offload
|
|
|
|
| |
| |
def summarize_dict_by_value_and_parts(d: dict[str, Any]) -> dict[str, Any]:
    """Summarizes a dictionary by collapsing keys that share the same value into their longest
    common dot-separated prefix.

    For a dictionary with dot-separated keys like: {
        'down_blocks.1.attn.0.processor': [0.6],
        'down_blocks.1.attn.1.processor': [0.6],
        'up_blocks.1.attn.0.processor': [0.3],
    }

    Returns a dictionary mapping each group's longest common prefix to the shared value: {
        'down_blocks.1.attn': [0.6], 'up_blocks.1.attn.0.processor': [0.3]
    }

    Groups whose keys share no common prefix are summarized under the empty-string key "".
    """
    # Group keys by value (lists are converted to tuples so they can be used as dict keys).
    value_to_keys: dict[Any, list[str]] = {}
    for key, value in d.items():
        value_tuple = tuple(value) if isinstance(value, list) else value
        if value_tuple not in value_to_keys:
            value_to_keys[value_tuple] = []
        value_to_keys[value_tuple].append(key)

    def find_common_prefix(keys: list[str]) -> str:
        """Find the longest common dot-separated prefix among a list of keys ("" when none)."""
        if not keys:
            return ""
        if len(keys) == 1:
            return keys[0]

        key_parts = [k.split(".") for k in keys]

        # Count how many leading parts are identical across all keys.
        common_length = 0
        for parts in zip(*key_parts):
            if len(set(parts)) == 1:
                common_length += 1
            else:
                break

        if common_length == 0:
            return ""

        return ".".join(key_parts[0][:common_length])

    summary = {}
    for value_tuple, keys in value_to_keys.items():
        prefix = find_common_prefix(keys)
        # Restore the original container type (list values were tupled above for hashability).
        # Bug fix: previously the no-common-prefix branch stored a stale loop variable instead
        # of this group's value.
        value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple
        summary[prefix] = value

    return summary
|
|
|
|
| class ComponentsManager: |
| """ |
| A central registry and management system for model components across multiple pipelines. |
| |
| [`ComponentsManager`] provides a unified way to register, track, and reuse model components (like UNet, VAE, text |
| encoders, etc.) across different modular pipelines. It includes features for duplicate detection, memory |
| management, and component organization. |
| |
| > [!WARNING] > This is an experimental feature and is likely to change in the future. |
| |
| Example: |
| ```python |
| from diffusers import ComponentsManager |
| |
| # Create a components manager |
| cm = ComponentsManager() |
| |
| # Add components |
| cm.add("unet", unet_model, collection="sdxl") |
| cm.add("vae", vae_model, collection="sdxl") |
| |
| # Enable auto offloading |
| cm.enable_auto_cpu_offload() |
| |
| # Retrieve components |
| unet = cm.get_one(name="unet", collection="sdxl") |
| ``` |
| """ |
|
|
    # Metadata fields that `get_model_info` can report; requesting any field name outside this
    # list raises a ValueError.
    _available_info_fields = [
        "model_id",
        "added_time",
        "collection",
        "class_name",
        "size_gb",
        "adapters",
        "has_hook",
        "execution_device",
        "ip_adapter",
        "quantization",
    ]
|
|
| def __init__(self): |
| self.components = OrderedDict() |
| |
| self.added_time = OrderedDict() |
| self.collections = OrderedDict() |
| self.model_hooks = None |
| self._auto_offload_enabled = False |
|
|
| def _lookup_ids( |
| self, |
| name: str | None = None, |
| collection: str | None = None, |
| load_id: str | None = None, |
| components: OrderedDict | None = None, |
| ): |
| """ |
| Lookup component_ids by name, collection, or load_id. Does not support pattern matching. Returns a set of |
| component_ids |
| """ |
| if components is None: |
| components = self.components |
|
|
| if name: |
| ids_by_name = set() |
| for component_id, component in components.items(): |
| comp_name = self._id_to_name(component_id) |
| if comp_name == name: |
| ids_by_name.add(component_id) |
| else: |
| ids_by_name = set(components.keys()) |
| if collection and collection not in self.collections: |
| return set() |
| elif collection and collection in self.collections: |
| ids_by_collection = set() |
| for component_id, component in components.items(): |
| if component_id in self.collections[collection]: |
| ids_by_collection.add(component_id) |
| else: |
| ids_by_collection = set(components.keys()) |
| if load_id: |
| ids_by_load_id = set() |
| for name, component in components.items(): |
| if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id == load_id: |
| ids_by_load_id.add(name) |
| else: |
| ids_by_load_id = set(components.keys()) |
|
|
| ids = ids_by_name.intersection(ids_by_collection).intersection(ids_by_load_id) |
| return ids |
|
|
| @staticmethod |
| def _id_to_name(component_id: str): |
| return "_".join(component_id.split("_")[:-1]) |
|
|
    def add(self, name: str, component: Any, collection: str | None = None):
        """
        Add a component to the ComponentsManager.

        Args:
            name (str): The name of the component
            component (Any): The component to add
            collection (str | None): The collection to add the component to

        Returns:
            str: The unique component ID, which is generated as "{name}_{id(component)}" where
            id(component) is Python's built-in unique identifier for the object
        """
        component_id = f"{name}_{id(component)}"
        is_new_component = True

        # Check whether this exact object is already registered (possibly under another name).
        for comp_id, comp in self.components.items():
            if comp == component:
                comp_name = self._id_to_name(comp_id)
                if comp_name == name:
                    # Same object, same name: reuse the existing id instead of creating a duplicate.
                    logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'")
                    component_id = comp_id
                    is_new_component = False
                    break
                else:
                    # Same object under a different name: allowed, but warn about the duplicate.
                    logger.warning(
                        f"ComponentsManager: adding component '{name}' as '{component_id}', but it is duplicate of '{comp_id}'"
                        f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
                    )

        # Warn when another registered component was loaded from the same checkpoint (load_id).
        if hasattr(component, "_diffusers_load_id") and component._diffusers_load_id != "null":
            components_with_same_load_id = self._lookup_ids(load_id=component._diffusers_load_id)
            components_with_same_load_id = [id for id in components_with_same_load_id if id != component_id]

            if components_with_same_load_id:
                existing = ", ".join(components_with_same_load_id)
                logger.warning(
                    f"ComponentsManager: adding component '{component_id}', but it has duplicate load_id '{component._diffusers_load_id}' with existing components: {existing}. "
                    f"To remove a duplicate, call `components_manager.remove('<component_id>')`."
                )

        # Register (or re-register) the component; only brand-new components get a timestamp.
        self.components[component_id] = component
        if is_new_component:
            self.added_time[component_id] = time.time()

        if collection:
            if collection not in self.collections:
                self.collections[collection] = set()
            if component_id not in self.collections[collection]:
                # A collection keeps at most one component per name: evict any existing ones.
                comp_ids_in_collection = self._lookup_ids(name=name, collection=collection)
                for comp_id in comp_ids_in_collection:
                    logger.warning(
                        f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}"
                    )
                    self.remove_from_collection(comp_id, collection)

                self.collections[collection].add(component_id)
                logger.info(
                    f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}"
                )
        else:
            logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")

        # Re-run hook setup so the new component participates in auto offloading.
        if self._auto_offload_enabled and is_new_component:
            self.enable_auto_cpu_offload(self._auto_offload_device)

        return component_id
|
|
| def remove_from_collection(self, component_id: str, collection: str): |
| """ |
| Remove a component from a collection. |
| """ |
| if collection not in self.collections: |
| logger.warning(f"Collection '{collection}' not found in ComponentsManager") |
| return |
| if component_id not in self.collections[collection]: |
| logger.warning(f"Component '{component_id}' not found in collection '{collection}'") |
| return |
| |
| self.collections[collection].remove(component_id) |
| |
| comp_colls = [coll for coll, comps in self.collections.items() if component_id in comps] |
| if not comp_colls: |
| logger.warning(f"ComponentsManager: removing component '{component_id}' from ComponentsManager") |
| self.remove(component_id) |
|
|
| def remove(self, component_id: str = None): |
| """ |
| Remove a component from the ComponentsManager. |
| |
| Args: |
| component_id (str): The ID of the component to remove |
| """ |
| if component_id not in self.components: |
| logger.warning(f"Component '{component_id}' not found in ComponentsManager") |
| return |
|
|
| component = self.components.pop(component_id) |
| self.added_time.pop(component_id) |
|
|
| for collection in self.collections: |
| if component_id in self.collections[collection]: |
| self.collections[collection].remove(component_id) |
|
|
| if self._auto_offload_enabled: |
| self.enable_auto_cpu_offload(self._auto_offload_device) |
| else: |
| if isinstance(component, torch.nn.Module): |
| component.to("cpu") |
| del component |
| import gc |
|
|
| gc.collect() |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| if torch.xpu.is_available(): |
| torch.xpu.empty_cache() |
|
|
| |
    def search_components(
        self,
        names: str | None = None,
        collection: str | None = None,
        load_id: str | None = None,
        return_dict_with_names: bool = True,
    ):
        """
        Search components by name with simple pattern matching. Optionally filter by collection or load_id.

        Args:
            names: Component name(s) or pattern(s)
                Patterns:
                - "unet" : match any component with base name "unet" (e.g., unet_123abc)
                - "!unet" : everything except components with base name "unet"
                - "unet*" : anything with base name starting with "unet"
                - "!unet*" : anything with base name NOT starting with "unet"
                - "*unet*" : anything with base name containing "unet"
                - "!*unet*" : anything with base name NOT containing "unet"
                - "refiner|vae|unet" : anything with base name exactly matching "refiner", "vae", or "unet"
                - "!refiner|vae|unet" : anything with base name NOT exactly matching "refiner", "vae", or "unet"
                - "unet*|vae*" : anything with base name starting with "unet" OR starting with "vae"
            collection: Optional collection to filter by
            load_id: Optional load_id to filter by
            return_dict_with_names:
                If True, returns a dictionary with component names as keys, throw an error if
                multiple components with the same name are found If False, returns a dictionary
                with component IDs as keys

        Returns:
            Dictionary mapping component names to components if return_dict_with_names=True, or a dictionary mapping
            component IDs to components if return_dict_with_names=False
        """
        # Narrow by collection/load_id first; name patterns are applied to this subset.
        selected_ids = self._lookup_ids(collection=collection, load_id=load_id)
        components = {k: self.components[k] for k in selected_ids}

        def get_return_dict(components, return_dict_with_names):
            """
            Create a dictionary mapping component names to components if return_dict_with_names=True, or a dictionary
            mapping component IDs to components if return_dict_with_names=False, throw an error if duplicate component
            names are found when return_dict_with_names=True
            """
            if return_dict_with_names:
                dict_to_return = {}
                for comp_id, comp in components.items():
                    comp_name = self._id_to_name(comp_id)
                    if comp_name in dict_to_return:
                        raise ValueError(
                            f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys"
                        )
                    dict_to_return[comp_name] = comp
                return dict_to_return
            else:
                return components

        # No name filter: return everything that survived the collection/load_id filters.
        if names is None:
            return get_return_dict(components, return_dict_with_names)

        elif not isinstance(names, str):
            raise ValueError(f"Invalid type for `names: {type(names)}, only support string")

        # component_id -> base name (id minus the trailing object-id suffix).
        base_names = {comp_id: self._id_to_name(comp_id) for comp_id in components.keys()}

        def matches_pattern(component_id, pattern, exact_match=False):
            """
            Helper function to check if a component matches a pattern based on its base name.

            Args:
                component_id: The component ID to check
                pattern: The pattern to match against
                exact_match: If True, only exact matches to base_name are considered
            """
            base_name = base_names[component_id]

            # Exact-match mode ignores wildcard handling entirely.
            if exact_match:
                return pattern == base_name

            # "prefix*" -> startswith match.
            elif pattern.endswith("*"):
                prefix = pattern[:-1]
                return base_name.startswith(prefix)

            # "*substr" or "*substr*" -> containment match.
            elif pattern.startswith("*"):
                search = pattern[1:-1] if pattern.endswith("*") else pattern[1:]
                return search in base_name

            # Plain pattern -> exact base-name match.
            else:
                return pattern == base_name

        # A leading "!" negates whatever pattern follows.
        is_not_pattern = names.startswith("!")
        if is_not_pattern:
            names = names[1:]

        # OR-patterns: "a|b|c" matches components satisfying any term.
        if "|" in names:
            terms = names.split("|")
            matches = {}

            for comp_id, comp in components.items():
                # Exact matching only when no term carries a wildcard.
                # NOTE(review): exact_match is also read after this loop for logging; if
                # `components` is empty it would be unbound there — confirm callers.
                exact_match = all(not (term.startswith("*") or term.endswith("*")) for term in terms)

                should_include = any(matches_pattern(comp_id, term, exact_match) for term in terms)

                if is_not_pattern:
                    should_include = not should_include

                if should_include:
                    matches[comp_id] = comp

            log_msg = "NOT " if is_not_pattern else ""
            match_type = "exactly matching" if exact_match else "matching any of patterns"
            logger.info(f"Getting components {log_msg}{match_type} {terms}: {list(matches.keys())}")

        # Exact base-name match takes priority over wildcard interpretation.
        elif any(names == base_name for base_name in base_names.values()):
            matches = {
                comp_id: comp
                for comp_id, comp in components.items()
                if (base_names[comp_id] == names) != is_not_pattern
            }

            if is_not_pattern:
                logger.info(f"Getting all components except those with base name '{names}': {list(matches.keys())}")
            else:
                logger.info(f"Getting components with base name '{names}': {list(matches.keys())}")

        # "prefix*" pattern.
        elif names.endswith("*"):
            prefix = names[:-1]
            matches = {
                comp_id: comp
                for comp_id, comp in components.items()
                if base_names[comp_id].startswith(prefix) != is_not_pattern
            }
            if is_not_pattern:
                logger.info(f"Getting components NOT starting with '{prefix}': {list(matches.keys())}")
            else:
                logger.info(f"Getting components starting with '{prefix}': {list(matches.keys())}")

        # "*substr" / "*substr*" pattern.
        elif names.startswith("*"):
            search = names[1:-1] if names.endswith("*") else names[1:]
            matches = {
                comp_id: comp
                for comp_id, comp in components.items()
                if (search in base_names[comp_id]) != is_not_pattern
            }
            if is_not_pattern:
                logger.info(f"Getting components NOT containing '{search}': {list(matches.keys())}")
            else:
                logger.info(f"Getting components containing '{search}': {list(matches.keys())}")

        # Fallback: treat the bare name as a substring of some base name.
        elif any(names in base_name for base_name in base_names.values()):
            matches = {
                comp_id: comp
                for comp_id, comp in components.items()
                if (names in base_names[comp_id]) != is_not_pattern
            }
            if is_not_pattern:
                logger.info(f"Getting components NOT containing '{names}': {list(matches.keys())}")
            else:
                logger.info(f"Getting components containing '{names}': {list(matches.keys())}")

        else:
            raise ValueError(f"Component or pattern '{names}' not found in ComponentsManager")

        if not matches:
            raise ValueError(f"No components found matching pattern '{names}'")

        return get_return_dict(matches, return_dict_with_names)
|
|
    def enable_auto_cpu_offload(self, device: str | int | torch.device = None, memory_reserve_margin="3GB"):
        """
        Enable automatic CPU offloading for all components.

        The algorithm works as follows:
        1. All models start on CPU by default
        2. When a model's forward pass is called, it's moved to the execution device
        3. If there's insufficient memory, other models on the device are moved back to CPU
        4. The system tries to offload the smallest combination of models that frees enough memory
        5. Models stay on the execution device until another model needs memory and forces them off

        Args:
            device (str | int | torch.device): The execution device where models are moved for forward passes
            memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of
                memory to keep free on the device to avoid running out of memory during model
                execution (e.g., for intermediate activations, gradients, etc.)
        """
        if not is_accelerate_available():
            raise ImportError("Make sure to install accelerate to use auto_cpu_offload")

        if device is None:
            device = get_device()
        if not isinstance(device, torch.device):
            device = torch.device(device)

        # The offload strategy needs free-memory queries; bail out early on device backends
        # whose torch module lacks mem_get_info.
        device_type = device.type
        device_module = getattr(torch, device_type, torch.cuda)
        if not hasattr(device_module, "mem_get_info"):
            raise NotImplementedError(
                f"`enable_auto_cpu_offload() relies on the `mem_get_info()` method. It's not implemented for {str(device.type)}."
            )

        # Normalize to an explicit device index (e.g. "cuda" -> "cuda:0").
        if device.index is None:
            device = torch.device(f"{device.type}:{0}")

        # Strip any pre-existing accelerate hooks so hooks don't stack on re-enable.
        for name, component in self.components.items():
            if isinstance(component, torch.nn.Module) and hasattr(component, "_hf_hook"):
                remove_hook_from_module(component, recurse=True)

        self.disable_auto_cpu_offload()
        offload_strategy = AutoOffloadStrategy(memory_reserve_margin=memory_reserve_margin)

        # Attach an offload hook to every torch module the manager holds.
        all_hooks = []
        for name, component in self.components.items():
            if isinstance(component, torch.nn.Module):
                hook = custom_offload_with_hook(name, component, device, offload_strategy=offload_strategy)
                all_hooks.append(hook)

        # Tell each hook about its peers on the same device so it can offload them when needed.
        for hook in all_hooks:
            other_hooks = [h for h in all_hooks if h is not hook]
            for other_hook in other_hooks:
                if other_hook.hook.execution_device == hook.hook.execution_device:
                    hook.add_other_hook(other_hook)

        self.model_hooks = all_hooks
        self._auto_offload_enabled = True
        self._auto_offload_device = device
|
|
| def disable_auto_cpu_offload(self): |
| """ |
| Disable automatic CPU offloading for all components. |
| """ |
| if self.model_hooks is None: |
| self._auto_offload_enabled = False |
| return |
|
|
| for hook in self.model_hooks: |
| hook.offload() |
| hook.remove() |
| if self.model_hooks: |
| clear_device_cache() |
| self.model_hooks = None |
| self._auto_offload_enabled = False |
|
|
    def get_model_info(
        self,
        component_id: str,
        fields: str | list[str] | None = None,
    ) -> dict[str, Any] | None:
        """Get comprehensive information about a component.

        Args:
            component_id (str): Name of the component to get info for
            fields (str | list[str] | None):
                Field(s) to return. Can be a string for single field or list of fields. If None, uses the
                available_info_fields setting.

        Returns:
            Dictionary containing requested component metadata. If fields is specified, returns only those fields.
            Otherwise, returns all fields.
        """
        if component_id not in self.components:
            raise ValueError(f"Component '{component_id}' not found in ComponentsManager")

        component = self.components[component_id]

        # Validate requested fields up-front against the supported field list.
        if fields is not None:
            if isinstance(fields, str):
                fields = [fields]
            for field in fields:
                if field not in self._available_info_fields:
                    raise ValueError(f"Field '{field}' not found in available_info_fields")

        # Base metadata available for every component type.
        info = {
            "model_id": component_id,
            "added_time": self.added_time[component_id],
            "collection": ", ".join([coll for coll, comps in self.collections.items() if component_id in comps])
            or None,
        }

        # Extra metadata only available for torch modules.
        if isinstance(component, torch.nn.Module):
            # Accelerate hooks attach themselves as `_hf_hook` on the module.
            has_hook = hasattr(component, "_hf_hook")
            execution_device = None
            if has_hook and hasattr(component._hf_hook, "execution_device"):
                execution_device = component._hf_hook.execution_device

            info.update(
                {
                    "class_name": component.__class__.__name__,
                    "size_gb": component.get_memory_footprint() / (1024**3),
                    "adapters": None,
                    "has_hook": has_hook,
                    "execution_device": execution_device,
                }
            )

            # PEFT adapters, when the model carries a peft_config.
            if hasattr(component, "peft_config"):
                info["adapters"] = list(component.peft_config.keys())

            # IP-Adapter scale summary, when IP-Adapter attention processors are installed.
            if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"):
                processors = copy.deepcopy(component.attn_processors)
                # Deep copy so reading scales cannot mutate the live processors.
                processor_types = [v.__class__.__name__ for v in processors.values()]
                if any("IPAdapter" in ptype for ptype in processor_types):
                    # Collect per-processor scales and collapse them by common key prefix.
                    scales = {
                        k: v.scale
                        for k, v in processors.items()
                        if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
                    }
                    if scales:
                        info["ip_adapter"] = summarize_dict_by_value_and_parts(scales)

            # Quantization config, when the model was loaded through an hf_quantizer.
            hf_quantizer = getattr(component, "hf_quantizer", None)
            if hf_quantizer is not None:
                quant_config = hf_quantizer.quantization_config
                if hasattr(quant_config, "to_diff_dict"):
                    info["quantization"] = quant_config.to_diff_dict()
                else:
                    info["quantization"] = quant_config.to_dict()
            else:
                info["quantization"] = None

        # Project down to the requested fields, if any.
        if fields is not None:
            return {k: v for k, v in info.items() if k in fields}
        else:
            return info
|
|
| |
    def __repr__(self):
        """Render a table of registered models/components plus adapter/quantization extras."""
        if not self.components:
            return "Components:\n" + "=" * 50 + "\nNo components registered.\n" + "=" * 50

        def get_load_id(component):
            # Loader-assigned identifier, when the component carries one.
            if hasattr(component, "_diffusers_load_id"):
                return component._diffusers_load_id
            return "N/A"

        def format_device(component, info):
            # Hooked models show "actual(execution)" device; plain models show just their device.
            if not info["has_hook"]:
                return str(getattr(component, "device", "N/A"))
            else:
                device = str(getattr(component, "device", "N/A"))
                exec_device = str(info["execution_device"] or "N/A")
                return f"{device}({exec_device})"

        # Pre-compute column widths so all rows line up.
        load_ids = [
            get_load_id(component)
            for component in self.components.values()
            if isinstance(component, torch.nn.Module) and hasattr(component, "_diffusers_load_id")
        ]
        max_load_id_len = max([15] + [len(str(lid)) for lid in load_ids]) if load_ids else 15

        # component_id -> list of collections it belongs to (["N/A"] when none).
        component_collections = {}
        for name in self.components.keys():
            component_collections[name] = []
            for coll, comps in self.collections.items():
                if name in comps:
                    component_collections[name].append(coll)
            if not component_collections[name]:
                component_collections[name] = ["N/A"]

        all_collections = [coll for colls in component_collections.values() for coll in colls]
        max_collection_len = max(10, max(len(str(c)) for c in all_collections)) if all_collections else 10

        col_widths = {
            "id": max(15, max(len(name) for name in self.components.keys())),
            "class": max(25, max(len(component.__class__.__name__) for component in self.components.values())),
            "device": 20,
            "dtype": 15,
            "size": 10,
            "load_id": max_load_id_len,
            "collection": max_collection_len,
        }

        # Horizontal separators sized to the full table width.
        sep_line = "=" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"
        dash_line = "-" * (sum(col_widths.values()) + len(col_widths) * 3 - 1) + "\n"

        output = "Components:\n" + sep_line

        # Split into torch modules ("Models") and everything else ("Other Components").
        models = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
        others = {k: v for k, v in self.components.items() if not isinstance(v, torch.nn.Module)}

        if models:
            output += "Models:\n" + dash_line
            # Header row for the models table.
            output += f"{'Name_ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
            output += f"{'Device: act(exec)':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | "
            output += f"{'Size (GB)':<{col_widths['size']}} | {'Load ID':<{col_widths['load_id']}} | Collection\n"
            output += dash_line

            for name, component in models.items():
                info = self.get_model_info(name)
                device_str = format_device(component, info)
                dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
                load_id = get_load_id(component)

                first_collection = component_collections[name][0] if component_collections[name] else "N/A"

                output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | "
                output += f"{device_str:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | "
                output += f"{info['size_gb']:<{col_widths['size']}.2f} | {load_id:<{col_widths['load_id']}} | {first_collection}\n"

                # Extra collections go on continuation rows with the other columns left blank.
                for i in range(1, len(component_collections[name])):
                    collection = component_collections[name][i]
                    output += f"{'':<{col_widths['id']}} | {'':<{col_widths['class']}} | "
                    output += f"{'':<{col_widths['device']}} | {'':<{col_widths['dtype']}} | "
                    output += f"{'':<{col_widths['size']}} | {'':<{col_widths['load_id']}} | {collection}\n"

            output += dash_line

        if others:
            if models:
                output += "\n"
            output += "Other Components:\n" + dash_line
            # Header row for the non-module table (fewer columns apply).
            output += f"{'ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | Collection\n"
            output += dash_line

            for name, component in others.items():
                info = self.get_model_info(name)

                first_collection = component_collections[name][0] if component_collections[name] else "N/A"

                output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}} | {first_collection}\n"

                # Extra collections on continuation rows, as above.
                for i in range(1, len(component_collections[name])):
                    collection = component_collections[name][i]
                    output += f"{'':<{col_widths['id']}} | {'':<{col_widths['class']}} | {collection}\n"

            output += dash_line

        # Trailing section: adapters / IP-Adapter / quantization details per component.
        output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
        for name in self.components:
            info = self.get_model_info(name)
            if info is not None and (
                info.get("adapters") is not None or info.get("ip_adapter") or info.get("quantization")
            ):
                output += f"\n{name}:\n"
                if info.get("adapters") is not None:
                    output += f" Adapters: {info['adapters']}\n"
                if info.get("ip_adapter"):
                    output += " IP-Adapter: Enabled\n"
                if info.get("quantization"):
                    output += f" Quantization: {info['quantization']}\n"

        return output
|
|
| def get_one( |
| self, |
| component_id: str | None = None, |
| name: str | None = None, |
| collection: str | None = None, |
| load_id: str | None = None, |
| ) -> Any: |
| """ |
| Get a single component by either: |
| - searching name (pattern matching), collection, or load_id. |
| - passing in a component_id |
| Raises an error if multiple components match or none are found. |
| |
| Args: |
| component_id (str | None): Optional component ID to get |
| name (str | None): Component name or pattern |
| collection (str | None): Optional collection to filter by |
| load_id (str | None): Optional load_id to filter by |
| |
| Returns: |
| A single component |
| |
| Raises: |
| ValueError: If no components match or multiple components match |
| """ |
|
|
| if component_id is not None and (name is not None or collection is not None or load_id is not None): |
| raise ValueError("If searching by component_id, do not pass name, collection, or load_id") |
|
|
| |
| if component_id is not None: |
| if component_id not in self.components: |
| raise ValueError(f"Component '{component_id}' not found in ComponentsManager") |
| return self.components[component_id] |
| |
| results = self.search_components(name, collection, load_id) |
|
|
| if not results: |
| raise ValueError(f"No components found matching '{name}'") |
|
|
| if len(results) > 1: |
| raise ValueError(f"Multiple components found matching '{name}': {list(results.keys())}") |
|
|
| return next(iter(results.values())) |
|
|
| def get_ids(self, names: str | list[str] = None, collection: str | None = None): |
| """ |
| Get component IDs by a list of names, optionally filtered by collection. |
| |
| Args: |
| names (str | list[str]): list of component names |
| collection (str | None): Optional collection to filter by |
| |
| Returns: |
| list[str]: list of component IDs |
| """ |
| ids = set() |
| if not isinstance(names, list): |
| names = [names] |
| for name in names: |
| ids.update(self._lookup_ids(name=name, collection=collection)) |
| return list(ids) |
|
|
| def get_components_by_ids(self, ids: list[str], return_dict_with_names: bool | None = True): |
| """ |
| Get components by a list of IDs. |
| |
| Args: |
| ids (list[str]): |
| list of component IDs |
| return_dict_with_names (bool | None): |
| Whether to return a dictionary with component names as keys: |
| |
| Returns: |
| dict[str, Any]: Dictionary of components. |
| - If return_dict_with_names=True, keys are component names. |
| - If return_dict_with_names=False, keys are component IDs. |
| |
| Raises: |
| ValueError: If duplicate component names are found in the search results when return_dict_with_names=True |
| """ |
| components = {id: self.components[id] for id in ids} |
|
|
| if return_dict_with_names: |
| dict_to_return = {} |
| for comp_id, comp in components.items(): |
| comp_name = self._id_to_name(comp_id) |
| if comp_name in dict_to_return: |
| raise ValueError( |
| f"Duplicate component names found in the search results: {comp_name}, please set `return_dict_with_names=False` to return a dictionary with component IDs as keys" |
| ) |
| dict_to_return[comp_name] = comp |
| return dict_to_return |
| else: |
| return components |
|
|
| def get_components_by_names(self, names: list[str], collection: str | None = None): |
| """ |
| Get components by a list of names, optionally filtered by collection. |
| |
| Args: |
| names (list[str]): list of component names |
| collection (str | None): Optional collection to filter by |
| |
| Returns: |
| dict[str, Any]: Dictionary of components with component names as keys |
| |
| Raises: |
| ValueError: If duplicate component names are found in the search results |
| """ |
| ids = self.get_ids(names, collection) |
| return self.get_components_by_ids(ids) |
|
|