from typing import Optional, Callable, Dict
from collections import OrderedDict
from warnings import warn
import logging
import torch


from modules import shared


logger = logging.getLogger(__name__)


def modules_add_field(modules, field, value=None):
    """ Add a field to a module if it isn't already added.
    Args:
        modules (list): Module or list of modules to add the field to
        field (str): Field name to add
        value (any): Value to assign to the field
    Returns:
        None

    """
    if not isinstance(modules, list):
        modules = [modules]
    for module in modules:
        if not hasattr(module, field):
            setattr(module, field, value)
        else:
            logger.warning(f"Field {field} already exists in module {module}")


def modules_remove_field(modules, field):
    """ Remove a field from a module if it exists.
    Args:
        modules (torch.nn.Module or list): Module or list of modules to remove the field from
        field (str): Field name to remove
    Returns:
        None

    """
    if not isinstance(modules, list):
        modules = [modules]
    for module in modules:
        if hasattr(module, field):
            delattr(module, field)
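
# Illustrative usage sketch for the two helpers above (commented out;
# "temp_marker" is a hypothetical field name used only for this example):
#
#     attn_modules = get_modules(module_name_filter="CrossAttention")
#     modules_add_field(attn_modules, "temp_marker", value=0)
#     # ... read/write module.temp_marker, e.g. from inside a forward hook ...
#     modules_remove_field(attn_modules, "temp_marker")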


def get_modules(network_layer_name_filter: Optional[str] = None, module_name_filter: Optional[str] = None):
    """ Get all modules from the shared.sd_model that match the filters provided. If no filters are provided, all modules are returned.

    Args:
        network_layer_name_filter (Optional[str], optional): Filters the modules by network layer name. Defaults to None. Example: "attn1" will return all modules that have "attn1" in their network layer name.
        module_name_filter (Optional[str], optional): Filters the modules by module class name. Defaults to None. Example: "CrossAttention" will return all modules that have "CrossAttention" in their class name.

    Returns:
        list: List of modules that match the filters provided.
    """
    try:
        m = shared.sd_model
        sd_model_modules = list(m.network_layer_mapping.values())

        # Apply filters if they are provided
        if network_layer_name_filter is not None:
            sd_model_modules = [mod for mod in sd_model_modules if network_layer_name_filter in mod.network_layer_name]
        if module_name_filter is not None:
            sd_model_modules = [mod for mod in sd_model_modules if module_name_filter in mod.__class__.__name__]
        return sd_model_modules
    except AttributeError:
        logger.exception("AttributeError in get_modules", stack_info=True)
        return []
    except Exception:
        logger.exception("Exception in get_modules", stack_info=True)
        return []
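
# Illustrative sketch: the two filters compose, and an empty list comes back if
# no model is loaded. Assumes shared.sd_model exposes network_layer_mapping:
#
#     # all CrossAttention modules whose network layer name contains "attn1"
#     attn1_cross = get_modules(network_layer_name_filter="attn1",
#                               module_name_filter="CrossAttention")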


# workaround for torch remove hooks issue
# thank you to @ProGamerGov for this https://github.com/pytorch/pytorch/issues/70455
def remove_module_forward_hook(
    module: torch.nn.Module, hook_fn_name: Optional[str] = None
) -> None:
    """
    This function removes all forward hooks in the specified module without requiring
    any hook handles. This lets us clean up and remove any hooks that weren't properly
    deleted.

    Warning: Various PyTorch modules and systems make use of hooks, and thus extreme
    caution should be exercised when removing all hooks. Users are recommended to give
    their hook function a unique name that can be used to safely identify and remove
    the target forward hooks.

    Args:

        module (nn.Module): The module instance to remove forward hooks from.
        hook_fn_name (str, optional): Optionally only remove specific forward hooks
            based on their function's __name__ attribute.
            Default: None
    """

    if hook_fn_name is None:
        warn("Removing all active hooks can break some PyTorch modules & systems.")

    def _remove_hooks(m: torch.nn.Module, name: Optional[str] = None) -> None:
        # Check the module currently being visited (m), not the outer `module`
        if hasattr(m, "_forward_hooks") and m._forward_hooks:
            if name is not None:
                # Keep only the hooks whose function __name__ does not match
                m._forward_hooks = OrderedDict(
                    (i, fn) for i, fn in m._forward_hooks.items() if fn.__name__ != name
                )
            else:
                m._forward_hooks: Dict[int, Callable] = OrderedDict()

    def _remove_child_hooks(
        target_module: torch.nn.Module, hook_name: Optional[str] = None
    ) -> None:
        for child in target_module._modules.values():
            if child is not None:
                _remove_hooks(child, hook_name)
                _remove_child_hooks(child, hook_name)

    # Remove hooks from target submodules
    _remove_child_hooks(module, hook_fn_name)

    # Remove hooks from the target module
    _remove_hooks(module, hook_fn_name)
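
# Illustrative sketch of the name-based removal the docstring recommends
# ("my_extension_hook" is a hypothetical hook name):
#
#     def my_extension_hook(module, args, output):
#         return output  # no-op
#
#     targets = get_modules(module_name_filter="CrossAttention")
#     for mod in targets:
#         mod.register_forward_hook(my_extension_hook)
#     # ... later, remove only the hooks named "my_extension_hook" ...
#     for mod in targets:
#         remove_module_forward_hook(mod, hook_fn_name="my_extension_hook")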


def module_add_forward_hook(module, hook_fn, hook_type="forward", with_kwargs=False):
    """ Adds a forward hook to a module.

    hook_fn should be a function that accepts the following arguments:
        forward hook, no kwargs: hook(module, args, output) -> None or modified output
        forward hook, with kwargs: hook(module, args, kwargs output) -> None or modified output

    Args:
        module (torch.nn.Module): Module to hook
        hook_fn (Callable): Function to call when the hook is triggered
        hook_type (str, optional): Type of hook to create. Defaults to "forward". Can be "forward" or "pre_forward".
        with_kwargs (bool, optional): Whether the hook function should accept keyword arguments. Defaults to False. Requires PyTorch 2.0 or newer.

    Returns:
        torch.utils.hooks.RemovableHandle: Handle for the hook
    """
    if module is None:
        raise ValueError("module must be provided")
    if not callable(hook_fn):
        raise ValueError("hook_fn must be a callable function")

    if hook_type == "forward":
        handle = module.register_forward_hook(hook_fn, with_kwargs=with_kwargs)
    elif hook_type == "pre_forward":
        handle = module.register_forward_pre_hook(hook_fn, with_kwargs=with_kwargs)
    else:
        raise ValueError(f"Invalid hook type {hook_type}. Must be 'forward' or 'pre_forward'.")

    return handle
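
# Illustrative sketch: while the returned handle is in scope, handle.remove()
# is the cleanest way to unregister ("some_module" stands in for any
# torch.nn.Module; assumes its output is a single tensor):
#
#     def log_output_shape(module, args, output):
#         logger.debug(f"{module.__class__.__name__} output: {tuple(output.shape)}")
#
#     handle = module_add_forward_hook(some_module, log_output_shape)
#     # ... run the forward pass ...
#     handle.remove()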