File size: 6,351 Bytes
3dabe4a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
from typing import Optional, Callable, Dict
from collections import OrderedDict
from warnings import warn
import logging
import torch
from modules import shared
# Module-level logger used by the helpers below to report warnings and exceptions.
logger = logging.getLogger(__name__)
def modules_add_field(modules, field, value=None):
    """ Attach ``field`` to every given module that does not already have it.

    Modules that already expose ``field`` are left untouched and a warning is logged.

    Args:
        modules (list): A single module, or a list of modules, to receive the field
        field (str): Name of the attribute to create
        value (any): Initial value for the attribute. Defaults to None.
    Returns:
        None
    """
    targets = modules if isinstance(modules, list) else [modules]
    for target in targets:
        if hasattr(target, field):
            # Never overwrite an existing attribute; just report it.
            logger.warning(f"Field {field} already exists in module {target}")
            continue
        setattr(target, field, value)
def modules_remove_field(modules, field):
    """ Remove a field from each module that has it.

    Modules that do not own the field (either it was never set, or it only
    exists as a class-level attribute) are skipped silently.

    Args:
        modules (list): Module or list of modules to remove the field from
        field (str): Field name to remove
    Returns:
        None
    """
    if not isinstance(modules, list):
        modules = [modules]
    for module in modules:
        # EAFP: hasattr() is also True for class-level attributes, which
        # delattr() on the instance cannot remove; treat those as "not present".
        try:
            delattr(module, field)
        except AttributeError:
            pass
def get_modules(network_layer_name_filter: Optional[str] = None, module_name_filter: Optional[str] = None):
    """ Get all modules from the shared.sd_model that match the filters provided. If no filters are provided, all modules are returned.

    Modules are taken from the values of ``shared.sd_model.network_layer_mapping``.

    Args:
        network_layer_name_filter (Optional[str], optional): Filters the modules by network layer name. Defaults to None. Example: "attn1" will return all modules that have "attn1" in their network layer name.
        module_name_filter (Optional[str], optional): Filters the modules by module class name. Defaults to None. Example: "CrossAttention" will return all modules that have "CrossAttention" in their class name.
    Returns:
        list: List of modules that match the filters provided. An empty list is
        returned (and the exception is logged) if anything goes wrong, e.g. the
        model or its ``network_layer_mapping`` is not available.
    """
    try:
        # Always materialize a list so the return type does not depend on
        # whether any filter was applied (a bare .values() is a live dict view).
        sd_model_modules = list(shared.sd_model.network_layer_mapping.values())
        if network_layer_name_filter is not None:
            sd_model_modules = [
                module for module in sd_model_modules
                if network_layer_name_filter in module.network_layer_name
            ]
        if module_name_filter is not None:
            sd_model_modules = [
                module for module in sd_model_modules
                if module_name_filter in module.__class__.__name__
            ]
        return sd_model_modules
    except AttributeError:
        logger.exception("AttributeError in get_modules", stack_info=True)
        return []
    except Exception:
        logger.exception("Exception in get_modules", stack_info=True)
        return []
# workaround for torch remove hooks issue
# thank you to @ProGamerGov for this https://github.com/pytorch/pytorch/issues/70455
def remove_module_forward_hook(
    module: torch.nn.Module, hook_fn_name: Optional[str] = None
) -> None:
    """
    This function removes forward hooks from the specified module and all of its
    submodules, without requiring any hook handles. This lets us clean up & remove
    any hooks that weren't properly deleted.

    Warning: Various PyTorch modules and systems make use of hooks, and thus extreme
    caution should be exercised when removing all hooks. Users are recommended to give
    their hook function a unique name that can be used to safely identify and remove
    the target forward hooks.

    Args:
        module (nn.Module): The module instance to remove forward hooks from.
        hook_fn_name (str, optional): Optionally only remove specific forward hooks
            based on their function's __name__ attribute. Hooks whose callable has
            no __name__ (e.g. functools.partial objects) never match and are kept.
            Default: None (remove every forward hook).
    Returns:
        None
    """
    if hook_fn_name is None:
        warn("Removing all active hooks can break some PyTorch modules & systems.", stacklevel=2)

    def _remove_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None:
        """Delete matching forward hooks from a single module, in place."""
        hooks = getattr(m, "_forward_hooks", None)
        if not hooks:
            return
        if name is None:
            doomed = list(hooks.keys())
        else:
            doomed = [
                hook_id for hook_id, fn in hooks.items()
                if getattr(fn, "__name__", None) == name
            ]
        for hook_id in doomed:
            # Delete in place rather than rebinding the dict so any outstanding
            # RemovableHandle still refers to the live hook dict.
            del hooks[hook_id]
            # Newer torch versions keep per-hook flags in side dicts keyed by the
            # same id; drop those too so no stale entries remain.
            for side_attr in ("_forward_hooks_with_kwargs", "_forward_hooks_always_called"):
                side = getattr(m, side_attr, None)
                if side is not None:
                    side.pop(hook_id, None)

    # modules() yields the module itself followed by every (deduplicated) submodule.
    for submodule in module.modules():
        _remove_hooks(submodule, hook_fn_name)
def module_add_forward_hook(module, hook_fn, hook_type="forward", with_kwargs=False):
    """ Adds a forward or forward-pre hook to a module.

    hook_fn should be a function that accepts the following arguments:
        forward hook, no kwargs:       hook(module, args, output) -> None or modified output
        forward hook, with kwargs:     hook(module, args, kwargs, output) -> None or modified output
        pre_forward hook, no kwargs:   hook(module, args) -> None or modified input
        pre_forward hook, with kwargs: hook(module, args, kwargs) -> None or a tuple of (modified args, modified kwargs)

    Args:
        module (torch.nn.Module): Module to hook
        hook_fn (Callable): Function to call when the hook is triggered
        hook_type (str, optional): Type of hook to create. Defaults to "forward". Can be "forward" or "pre_forward".
        with_kwargs (bool, optional): Whether the hook function should accept keyword arguments. Defaults to False.
    Returns:
        torch.utils.hooks.RemovableHandle: Handle for the hook; call ``.remove()`` on it to unregister the hook.
    Raises:
        ValueError: If ``module`` is None, ``hook_fn`` is not callable, or ``hook_type`` is not recognized.
    """
    if module is None:
        raise ValueError("module must be provided")
    if not callable(hook_fn):
        raise ValueError("hook_fn must be a callable function")
    if hook_type == "forward":
        handle = module.register_forward_hook(hook_fn, with_kwargs=with_kwargs)
    elif hook_type == "pre_forward":
        handle = module.register_forward_pre_hook(hook_fn, with_kwargs=with_kwargs)
    else:
        raise ValueError(f"Invalid hook type {hook_type}. Must be 'forward' or 'pre_forward'.")
    return handle