from typing import Optional, Callable, Dict
from collections import OrderedDict
from warnings import warn
import logging

import torch

from modules import shared

logger = logging.getLogger(__name__)

# Per-hook bookkeeping dicts that may exist next to ``_forward_hooks``.
# NOTE(review): assumes these are dicts keyed by the same hook id as
# ``_forward_hooks`` (as in PyTorch 2.x) -- confirm against the installed version.
_FORWARD_HOOK_SIDE_TABLES = ("_forward_hooks_with_kwargs", "_forward_hooks_always_called")


def modules_add_field(modules, field, value=None):
    """
    Add a field to each module unless the module already has it.

    Existing attributes are never overwritten; a warning is logged instead.

    Args:
        modules (list | tuple | object): A single module, or a list/tuple of modules
        field (str): Name of the attribute to add
        value (any): Value to assign to the attribute. Defaults to None.

    Returns:
        None
    """
    # A single module is wrapped so the loop below handles both call styles.
    if not isinstance(modules, (list, tuple)):
        modules = [modules]
    for module in modules:
        if hasattr(module, field):
            logger.warning("Field %s already exists in module %s", field, module)
        else:
            setattr(module, field, value)


def modules_remove_field(modules, field):
    """
    Remove a field from each module if it exists.

    Modules that do not have the field are silently skipped.

    Args:
        modules (list | tuple | object): A single module, or a list/tuple of modules
        field (str): Name of the attribute to remove

    Returns:
        None
    """
    if not isinstance(modules, (list, tuple)):
        modules = [modules]
    for module in modules:
        # A missing field is deliberately not an error: removal is best-effort.
        if hasattr(module, field):
            delattr(module, field)


def get_modules(network_layer_name_filter: Optional[str] = None, module_name_filter: Optional[str] = None):
    """
    Get the modules registered in shared.sd_model.network_layer_mapping that match the filters.

    If no filters are provided, all mapped modules are returned. When both filters
    are provided, a module must match both.

    Args:
        network_layer_name_filter (Optional[str], optional): Substring that must occur in the
            module's ``network_layer_name``. Defaults to None.
            Example: "attn1" will return all modules that have "attn1" in their network layer name.
        module_name_filter (Optional[str], optional): Substring that must occur in the module's
            class name. Defaults to None.
            Example: "CrossAttention" will return all modules that have "CrossAttention" in
            their class name.

    Returns:
        list: Modules that match the filters. An empty list is returned (and the error is
            logged) if the lookup or the filtering raises.
    """
    try:
        # Materialise a real list so the return type is the same with and without filters
        # (a bare dict view is not a list and cannot be passed to modules_add_field).
        sd_model_modules = list(shared.sd_model.network_layer_mapping.values())

        if network_layer_name_filter is not None:
            sd_model_modules = [
                candidate
                for candidate in sd_model_modules
                if network_layer_name_filter in candidate.network_layer_name
            ]
        if module_name_filter is not None:
            sd_model_modules = [
                candidate
                for candidate in sd_model_modules
                if module_name_filter in candidate.__class__.__name__
            ]
        return sd_model_modules
    except AttributeError:
        logger.exception("AttributeError in get_modules", stack_info=True)
        return []
    except Exception:
        logger.exception("Exception in get_modules", stack_info=True)
        return []


def _remove_forward_hooks(target: torch.nn.Module, hook_fn_name: Optional[str] = None) -> None:
    """Remove forward hooks from ``target`` only (not its children), optionally filtered by name."""
    hooks: Optional[Dict[int, Callable]] = getattr(target, "_forward_hooks", None)
    if not hooks:
        return

    if hook_fn_name is None:
        doomed_ids = list(hooks)
    else:
        # Hooks without a __name__ (functools.partial, callable instances) never match
        # a name filter, so they are left in place instead of raising AttributeError.
        doomed_ids = [
            hook_id
            for hook_id, hook_fn in hooks.items()
            if getattr(hook_fn, "__name__", None) == hook_fn_name
        ]

    # Mutate the existing dict rather than rebinding the attribute: handles returned by
    # register_forward_hook refer to this dict object, so replacing it would leave the
    # handles of the hooks we keep unable to remove them later.
    for hook_id in doomed_ids:
        del hooks[hook_id]
        for table_name in _FORWARD_HOOK_SIDE_TABLES:
            table = getattr(target, table_name, None)
            if isinstance(table, dict):
                table.pop(hook_id, None)


# workaround for torch remove hooks issue
# thank you to @ProGamerGov for this https://github.com/pytorch/pytorch/issues/70455
def remove_module_forward_hook(
    module: torch.nn.Module, hook_fn_name: Optional[str] = None
) -> None:
    """
    Remove forward hooks from the specified module and all of its submodules,
    without requiring any hook handles. This lets us clean up & remove any hooks
    that weren't properly deleted.

    Warning: Various PyTorch modules and systems make use of hooks, and thus extreme
    caution should be exercised when removing all hooks. Users are recommended to give
    their hook function a unique name that can be used to safely identify and remove
    the target forward hooks.

    Args:
        module (nn.Module): The module instance to remove forward hooks from.
        hook_fn_name (str, optional): Optionally only remove specific forward hooks
            based on their function's __name__ attribute. Hooks that have no
            __name__ attribute are kept when a name is given.
            Default: None

    Returns:
        None
    """
    if hook_fn_name is None:
        warn("Removing all active hooks can break some PyTorch modules & systems.")

    # modules() yields the module itself plus every descendant exactly once.
    for target in module.modules():
        _remove_forward_hooks(target, hook_fn_name)


def module_add_forward_hook(
    module: torch.nn.Module,
    hook_fn: Callable,
    hook_type: str = "forward",
    with_kwargs: bool = False,
):
    """
    Adds a forward hook or forward pre-hook to a module.

    hook_fn should be a function that accepts the following arguments:
        forward hook, no kwargs: hook(module, args, output) -> None or modified output
        forward hook, with kwargs: hook(module, args, kwargs, output) -> None or modified output
        pre_forward hook, no kwargs: hook(module, args) -> None or modified input
        pre_forward hook, with kwargs: hook(module, args, kwargs) -> None or a tuple of
            modified input and kwargs

    Args:
        module (torch.nn.Module): Module to hook
        hook_fn (Callable): Function to call when the hook is triggered
        hook_type (str, optional): Type of hook to create. Defaults to "forward".
            Can be "forward" or "pre_forward".
        with_kwargs (bool, optional): Whether the hook function should accept keyword
            arguments. Defaults to False.

    Returns:
        torch.utils.hooks.RemovableHandle: Handle for the hook

    Raises:
        ValueError: If module is None, hook_fn is not callable, or hook_type is not
            one of the supported values.
    """
    if module is None:
        raise ValueError("module must be provided")
    if not callable(hook_fn):
        raise ValueError("hook_fn must be a callable function")

    if hook_type == "forward":
        return module.register_forward_hook(hook_fn, with_kwargs=with_kwargs)
    if hook_type == "pre_forward":
        return module.register_forward_pre_hook(hook_fn, with_kwargs=with_kwargs)
    raise ValueError(f"Invalid hook type {hook_type}. Must be 'forward' or 'pre_forward'.")