manbeast3b committed · commit 7052af9 · verified · parent: 5021878

Create norm_attn_hook.py

Files changed (1): src/norm_attn_hook.py (+242, -0)
src/norm_attn_hook.py ADDED
# TODO: this should be a parent class for all the hooks, for the official repo
# 1: FLUX Norm

import logging
import os
import re
from collections import OrderedDict
from functools import partial

import torch
from torch import nn


class NormHooker:
    def __init__(
        self,
        pipeline: nn.Module,
        regex: str,
        dtype: torch.dtype,
        masking: str,
        dst: str,
        epsilon: float = 0.0,
        eps: float = 1e-6,
        use_log: bool = False,
        binary: bool = False,
    ):
        self.pipeline = pipeline
        self.net = pipeline.unet if hasattr(pipeline, "unet") else pipeline.transformer
        self.logger = logging.getLogger(__name__)
        self.dtype = dtype
        self.regex = regex
        self.hook_dict = {}
        self.masking = masking
        self.dst = dst
        self.epsilon = epsilon
        self.eps = eps
        self.use_log = use_log
        self.lambs = []
        self.lambs_module_names = []  # store the module names for each lambda block
        self.hook_counter = 0
        self.module_neurons = OrderedDict()
        # default; still to be discussed whether this attribute should be kept
        self.binary = binary

    def add_hooks_to_norm(self, hook_fn: callable):
        """
        Add a forward hook to every norm layer whose name matches the regex
        (here, the `norm1_context` AdaLayerNormZero modules, hooked via their
        `linear` projection).
        :param hook_fn: a callable to be added to a torch nn module as a hook
        :return: dictionary of added hooks
        """
        total_hooks = 0
        for name, module in self.net.named_modules():
            name_last_word = name.split(".")[-1]
            if "norm1_context" in name_last_word and re.match(self.regex, name):
                hook_fn_with_name = partial(hook_fn, name=name)

                if hasattr(module, "linear"):
                    actual_module = module.linear
                elif isinstance(module, nn.Linear):
                    actual_module = module
                else:
                    continue

                hook = actual_module.register_forward_hook(
                    hook_fn_with_name, with_kwargs=True
                )
                self.hook_dict[name] = hook

                # AdaLayerNormZero: the hooked module is its projection Linear
                if isinstance(actual_module, nn.Linear):
                    self.module_neurons[name] = actual_module.out_features
                else:
                    raise NotImplementedError(
                        f"Module {name} is not implemented, please check"
                    )
                self.logger.info(
                    f"Adding hook to {name}, neurons: {self.module_neurons[name]}"
                )
                total_hooks += 1
        self.logger.info(f"Total hooks added: {total_hooks}")
        return self.hook_dict

    def add_hooks(self, init_value=1.0):
        hook_fn = self.get_norm_masking_hook(init_value)
        self.add_hooks_to_norm(hook_fn)
        # initialize the lambdas (one slot per hooked module; the tensors are
        # created lazily inside the hook on the first forward pass)
        self.lambs = [None] * len(self.hook_dict)
        # initialize the lambda module names
        self.lambs_module_names = [None] * len(self.hook_dict)

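With hooks registered, a minimal usage sketch looks like this (a hedged illustration, assuming a diffusers FluxPipeline; the model id, regex, and destination path are placeholders, not taken from this commit):

import torch
from diffusers import FluxPipeline
from norm_attn_hook import NormHooker  # assumes src/ is importable

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
)
hooker = NormHooker(
    pipeline=pipe,
    regex=r"transformer_blocks\.\d+\.norm1_context",  # assumed module naming
    dtype=torch.float32,
    masking="sigmoid",
    dst="./lambs/flux_norm.pt",  # illustrative path
)
hooker.add_hooks(init_value=1.0)

Note that `add_hooks` only pre-allocates `self.lambs` with `None`; the actual lambda tensors are created inside the hook on the first forward pass.
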
    def clear_hooks(self):
        """Clear all hooks."""
        for hook in self.hook_dict.values():
            hook.remove()
        self.hook_dict.clear()

    def save(self, name: str = None):
        if name is not None:
            dst = os.path.join(os.path.dirname(self.dst), name)
        else:
            dst = self.dst
        dst_dir = os.path.dirname(dst)
        if not os.path.exists(dst_dir):
            self.logger.info(f"Creating directory {dst_dir}")
            os.makedirs(dst_dir)
        torch.save(self.lambs, dst)

    @property
    def get_lambda_block_names(self):
        return self.lambs_module_names

    def load(self, device, threshold):
        if os.path.exists(self.dst):
            self.logger.info(f"loading lambda from {self.dst}")
            self.lambs = torch.load(self.dst, weights_only=True, map_location=device)
            if self.binary:
                # hard-threshold each lambda into a 0/1 mask
                self.lambs = [
                    (torch.relu(lamb - threshold) > 0).float() for lamb in self.lambs
                ]
            else:
                self.lambs = [torch.clamp(lamb, min=0.0) for lamb in self.lambs]
                # self.lambs_module_names = [None for _ in self.lambs]
        else:
            self.logger.info("skipping loading, training from scratch")

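The save/load pair round-trips the lambdas through `self.dst`. A sketch, continuing the snippet above (the filename and threshold are illustrative): with `binary=True`, the loaded lambdas are hard-thresholded into 0/1 masks; otherwise they are only clamped to be non-negative.

hooker.save()                     # writes self.lambs to self.dst
hooker.save(name="step_1000.pt")  # hypothetical name, written next to self.dst

binary_hooker = NormHooker(
    pipeline=pipe,
    regex=r"transformer_blocks\.\d+\.norm1_context",
    dtype=torch.float32,
    masking="binary",
    dst="./lambs/flux_norm.pt",
    binary=True,
)
binary_hooker.add_hooks()
binary_hooker.load(device="cpu", threshold=0.1)  # each lambda becomes (lamb > 0.1).float()
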
    def binarize(self, scope: str, ratio: float):
        """
        Binarize each lambda to be 0 or 1.
        :param scope: either "local" (sparsity within each layer) or "global" (sparsity across the model)
        :param ratio: the ratio of the number of 1s to the total number of elements
        """
        assert scope in ["local", "global"], "scope must be either local or global"
        assert (
            not self.binary
        ), "binarization is not supported when already using a binary mask"
        if scope == "local":
            # Local binarization: keep the top-`ratio` lambdas within each layer
            for i, lamb in enumerate(self.lambs):
                num_heads = lamb.size(0)
                num_activate_heads = int(num_heads * ratio)
                # Sort the lambda values with stable sorting to maintain order for equal values
                sorted_lamb, sorted_indices = torch.sort(
                    lamb, descending=True, stable=True
                )
                # Find the threshold value
                threshold = sorted_lamb[num_activate_heads - 1]
                # Create a mask based on the sorted indices
                mask = torch.zeros_like(lamb)
                mask[sorted_indices[:num_activate_heads]] = 1.0
                # Binarize the lambda based on the threshold and the mask
                self.lambs[i] = torch.where(
                    lamb > threshold, torch.ones_like(lamb), mask
                )
        else:
            # Global binarization: keep the top-`ratio` lambdas across all layers
            all_lambs = torch.cat([lamb.flatten() for lamb in self.lambs])
            num_total = all_lambs.numel()
            num_activate = int(num_total * ratio)
            # Sort all lambda values globally with stable sorting
            sorted_lambs, sorted_indices = torch.sort(
                all_lambs, descending=True, stable=True
            )
            # Find the global threshold value
            threshold = sorted_lambs[num_activate - 1]
            # Create a global mask based on the sorted indices
            global_mask = torch.zeros_like(all_lambs)
            global_mask[sorted_indices[:num_activate]] = 1.0
            # Binarize all lambdas based on the global threshold and mask
            start_idx = 0
            for i in range(len(self.lambs)):
                end_idx = start_idx + self.lambs[i].numel()
                lamb_mask = global_mask[start_idx:end_idx].reshape(self.lambs[i].shape)
                self.lambs[i] = torch.where(
                    self.lambs[i] > threshold, torch.ones_like(self.lambs[i]), lamb_mask
                )
                start_idx = end_idx
        self.binary = True

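To make the local/global distinction concrete, a toy example (reusing `hooker` from the first sketch, with made-up lambdas): with ratio=0.5, "local" keeps the top half of each layer's lambdas independently, while "global" shares one threshold across layers, so sparsity can be distributed unevenly.

import torch

hooker.lambs = [
    torch.tensor([0.9, 0.1, 0.5, 0.2]),
    torch.tensor([0.8, 0.7, 0.6, 0.0]),
]
hooker.binarize(scope="local", ratio=0.5)
print(hooker.lambs)  # [tensor([1., 0., 1., 0.]), tensor([1., 1., 0., 0.])]
# With scope="global" on the same inputs, the shared threshold would be 0.6,
# so the first layer would keep only 0.9 while the second keeps 0.8, 0.7, 0.6.
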
    @staticmethod
    def masking_fn(hidden_states, **kwargs):
        # note: the `eps` and `use_log` kwargs are accepted but currently unused
        hidden_states_dtype = hidden_states.dtype
        lamb = kwargs["lamb"].view(1, 1, kwargs["lamb"].shape[0])
        if kwargs.get("masking", None) == "sigmoid":
            mask = torch.sigmoid(lamb)
        elif kwargs.get("masking", None) == "binary":
            mask = lamb
        elif kwargs.get("masking", None) == "continues2binary":
            # TODO: this might cause a potential issue, as it hard-thresholds at 0
            mask = (lamb > 0).float()
        elif kwargs.get("masking", None) == "no_masking":
            mask = torch.ones_like(lamb)
        else:
            raise NotImplementedError
        epsilon = kwargs.get("epsilon", 0.0)

        if hidden_states.dim() == 2:
            mask = mask.squeeze(1)

        # interpolate between the activation and epsilon-scaled Gaussian noise
        hidden_states = hidden_states * mask + torch.randn_like(
            hidden_states
        ) * epsilon * (1 - mask)
        return hidden_states.to(hidden_states_dtype)

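What `masking_fn` computes is h * m + epsilon * noise * (1 - m): a kept neuron (m = 1) passes through unchanged, while a fully masked one (m = 0) is replaced by epsilon-scaled noise. A quick check on a toy tensor:

import torch
from norm_attn_hook import NormHooker

h = torch.ones(1, 2, 3)                # [batch, tokens, channels], as in the hook
lamb = torch.tensor([4.0, 0.0, -4.0])  # one lambda per channel
out = NormHooker.masking_fn(h, lamb=lamb, masking="sigmoid", epsilon=0.0)
# sigmoid([4, 0, -4]) ~ [0.982, 0.500, 0.018], broadcast over batch and tokens:
print(out[0, 0])  # tensor([0.9820, 0.5000, 0.0180]) approximately
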
    def get_norm_masking_hook(self, init_value=1.0):
        """
        Get a hook function that masks the output of the hooked norm projection.
        """

        def hook_fn(module, args, kwargs, output, name):
            # initialize lambda with the actual head dim on the first run
            if self.lambs[self.hook_counter] is None:
                self.lambs[self.hook_counter] = (
                    torch.ones(
                        self.module_neurons[name],
                        device=self.pipeline.device,
                        dtype=self.dtype,
                    )
                    * init_value
                )
                self.lambs[self.hook_counter].requires_grad = True
                # record the norm lambda module name for logging
                self.lambs_module_names[self.hook_counter] = name

            # perform masking
            output = self.masking_fn(
                output,
                masking=self.masking,
                lamb=self.lambs[self.hook_counter],
                epsilon=self.epsilon,
                eps=self.eps,
                use_log=self.use_log,
            )
            # hooks are assumed to fire in registration order; the counter wraps
            # so each forward pass pairs every module with the same lambda
            self.hook_counter += 1
            self.hook_counter %= len(self.lambs)
            return output

        return hook_fn
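
Finally, a hedged sketch of how the lambdas could be optimized once the hooks are live (the prompt and loss below are placeholders, and `pipe` and `hooker` come from the earlier snippets; the counter-based bookkeeping above assumes hooks fire in registration order on every forward pass):

import torch

_ = pipe("a photo of a cat", num_inference_steps=1)  # first pass materializes the lambdas
optimizer = torch.optim.Adam(hooker.lambs, lr=1e-2)
for _ in range(100):
    # a real setup would add a quality term computed from the masked pipeline;
    # here only an L1-style sparsity penalty on the sigmoid masks is shown
    loss = sum(torch.sigmoid(lamb).sum() for lamb in hooker.lambs)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
hooker.save()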