Shaoan committed (verified)
Commit 4e454d3
1 Parent(s): 36f6af4

Upload folder using huggingface_hub

Files changed (1)
  1. refiner.py +495 -0
refiner.py ADDED
@@ -0,0 +1,495 @@
from typing import Optional

import torch
from einops import rearrange
from torch import nn

from layers import MLP, TextProjection, TimestepEmbedder, apply_gate, attention


class RMSNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        elementwise_affine=True,
        eps: float = 1e-6,
        device=None,
        dtype=None,
    ):
        """
        Initialize the RMSNorm normalization layer.

        Args:
            dim (int): The dimension of the input tensor.
            elementwise_affine (bool, optional): Whether to learn a per-channel scale. Default is True.
            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.

        Attributes:
            eps (float): A small value added to the denominator for numerical stability.
            weight (nn.Parameter): Learnable scaling parameter.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.eps = eps
        if elementwise_affine:
            self.weight = nn.Parameter(torch.ones(dim, **factory_kwargs))

    def _norm(self, x):
        """
        Apply the RMSNorm normalization to the input tensor.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The normalized tensor.
        """
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        """
        Forward pass through the RMSNorm layer.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after applying RMSNorm.
        """
        output = self._norm(x.float()).type_as(x)
        if hasattr(self, "weight"):
            output = output * self.weight
        return output


def get_norm_layer(norm_layer):
    """
    Get the normalization layer class.

    Args:
        norm_layer (str): The type of normalization layer, either "layer" or "rms".

    Returns:
        norm_layer (nn.Module): The normalization layer class.
    """
    if norm_layer == "layer":
        return nn.LayerNorm
    elif norm_layer == "rms":
        return RMSNorm
    else:
        raise NotImplementedError(f"Norm layer {norm_layer} is not implemented")


def get_activation_layer(act_type):
    """Get the activation layer factory.

    Args:
        act_type (str): the activation type, one of "gelu", "gelu_tanh", "relu", or "silu"

    Returns:
        Callable[[], nn.Module]: a zero-argument factory that builds the activation layer
    """
    if act_type == "gelu":
        return lambda: nn.GELU()
    elif act_type == "gelu_tanh":
        return lambda: nn.GELU(approximate="tanh")
    elif act_type == "relu":
        return nn.ReLU
    elif act_type == "silu":
        return nn.SiLU
    else:
        raise ValueError(f"Unknown activation type: {act_type}")


class IndividualTokenRefinerBlock(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        need_CA: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.need_CA = need_CA
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num
        mlp_hidden_dim = int(hidden_size * mlp_width_ratio)

        self.norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.self_attn_qkv = nn.Linear(
            hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        act_layer = get_activation_layer(act_type)
        self.mlp = MLP(
            in_channels=hidden_size,
            hidden_channels=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop_rate,
            **factory_kwargs,
        )

        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )

        if self.need_CA:
            self.cross_attnblock = CrossAttnBlock(
                hidden_size=hidden_size,
                heads_num=heads_num,
                mlp_width_ratio=mlp_width_ratio,
                mlp_drop_rate=mlp_drop_rate,
                act_type=act_type,
                qk_norm=qk_norm,
                qk_norm_type=qk_norm_type,
                qkv_bias=qkv_bias,
                **factory_kwargs,
            )
        # Zero-initialize the modulation so each block starts as an identity mapping
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
        attn_mask: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
    ):
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        qkv = self.self_attn_qkv(norm_x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num)
        # Apply QK-Norm if needed
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)

        # Self-Attention
        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)

        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)

        if self.need_CA:
            x = self.cross_attnblock(x, c, attn_mask, y)

        # FFN Layer
        x = x + apply_gate(self.mlp(self.norm2(x)), gate_mlp)

        return x


class CrossAttnBlock(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.heads_num = heads_num
        head_dim = hidden_size // heads_num

        self.norm1 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.norm1_2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        self.self_attn_q = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )
        self.self_attn_kv = nn.Linear(
            hidden_size, hidden_size * 2, bias=qkv_bias, **factory_kwargs
        )
        qk_norm_layer = get_norm_layer(qk_norm_type)
        self.self_attn_q_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_k_norm = (
            qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs)
            if qk_norm
            else nn.Identity()
        )
        self.self_attn_proj = nn.Linear(
            hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs
        )

        self.norm2 = nn.LayerNorm(
            hidden_size, elementwise_affine=True, eps=1e-6, **factory_kwargs
        )
        act_layer = get_activation_layer(act_type)

        self.adaLN_modulation = nn.Sequential(
            act_layer(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True, **factory_kwargs),
        )
        # Zero-initialize the modulation
        nn.init.zeros_(self.adaLN_modulation[1].weight)
        nn.init.zeros_(self.adaLN_modulation[1].bias)

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,  # timestep_aware_representations + context_aware_representations
        attn_mask: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
    ):
        # Only gate_msa is used; this block has no MLP branch.
        gate_msa, gate_mlp = self.adaLN_modulation(c).chunk(2, dim=1)

        norm_x = self.norm1(x)
        norm_y = self.norm1_2(y)
        q = self.self_attn_q(norm_x)
        q = rearrange(q, "B L (H D) -> B L H D", H=self.heads_num)
        kv = self.self_attn_kv(norm_y)
        k, v = rearrange(kv, "B L (K H D) -> K B L H D", K=2, H=self.heads_num)
        # Apply QK-Norm if needed
        q = self.self_attn_q_norm(q).to(v)
        k = self.self_attn_k_norm(k).to(v)

        # Cross-Attention: x provides queries, y provides keys/values
        attn = attention(q, k, v, mode="torch", attn_mask=attn_mask)

        x = x + apply_gate(self.self_attn_proj(attn), gate_msa)

        return x


class IndividualTokenRefiner(torch.nn.Module):
    def __init__(
        self,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        need_CA: bool = False,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.need_CA = need_CA
        self.blocks = nn.ModuleList(
            [
                IndividualTokenRefinerBlock(
                    hidden_size=hidden_size,
                    heads_num=heads_num,
                    mlp_width_ratio=mlp_width_ratio,
                    mlp_drop_rate=mlp_drop_rate,
                    act_type=act_type,
                    qk_norm=qk_norm,
                    qk_norm_type=qk_norm_type,
                    qkv_bias=qkv_bias,
                    need_CA=self.need_CA,
                    **factory_kwargs,
                )
                for _ in range(depth)
            ]
        )

    def forward(
        self,
        x: torch.Tensor,
        c: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
    ):
        self_attn_mask = None
        if mask is not None:
            batch_size = mask.shape[0]
            seq_len = mask.shape[1]
            mask = mask.to(x.device)
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_1 = mask.view(batch_size, 1, 1, seq_len).repeat(
                1, 1, seq_len, 1
            )
            # batch_size x 1 x seq_len x seq_len
            self_attn_mask_2 = self_attn_mask_1.transpose(2, 3)
            # batch_size x 1 x seq_len x seq_len, 1 for broadcasting of heads_num
            self_attn_mask = self_attn_mask_1.bool() & self_attn_mask_2.bool()
            # avoids self-attention weight being NaN for padding tokens
            self_attn_mask[:, :, :, 0] = True

        for block in self.blocks:
            x = block(x, c, self_attn_mask, y)

        return x


class SingleTokenRefiner(torch.nn.Module):
    """
    A single-token refiner used to refine LLM text embeddings.
    """

    def __init__(
        self,
        in_channels,
        hidden_size,
        heads_num,
        depth,
        mlp_width_ratio: float = 4.0,
        mlp_drop_rate: float = 0.0,
        act_type: str = "silu",
        qk_norm: bool = False,
        qk_norm_type: str = "layer",
        qkv_bias: bool = True,
        need_CA: bool = False,
        attn_mode: str = "torch",
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        identity_init: bool = False,
    ):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.attn_mode = attn_mode
        self.need_CA = need_CA
        assert self.attn_mode == "torch", "Only support 'torch' mode for token refiner."

        if not identity_init:
            self.input_norm = RMSNorm(in_channels, eps=1e-6, **factory_kwargs)
            self.input_embedder = nn.Linear(
                in_channels, hidden_size, bias=True, **factory_kwargs
            )
            # nn.init.trunc_normal_(self.input_embedder.weight, std=0.02)
            # nn.init.zeros_(self.input_embedder.bias)
        else:
            # Identity initialization: the embedder starts as the identity map
            # (requires in_channels == hidden_size) and the input norm is skipped.
            self.input_embedder = nn.Linear(
                in_channels, hidden_size, bias=True, **factory_kwargs
            )
            nn.init.zeros_(self.input_embedder.bias)
            nn.init.eye_(self.input_embedder.weight)
            self.input_norm = nn.Identity()

        act_layer = get_activation_layer(act_type)
        self.c_norm = nn.LayerNorm(in_channels)
        self.c_embedder = TextProjection(
            in_channels, hidden_size, act_layer, **factory_kwargs
        )

        self.individual_token_refiner = IndividualTokenRefiner(
            hidden_size=hidden_size,
            heads_num=heads_num,
            depth=depth,
            mlp_width_ratio=mlp_width_ratio,
            mlp_drop_rate=mlp_drop_rate,
            act_type=act_type,
            qk_norm=qk_norm,
            qk_norm_type=qk_norm_type,
            qkv_bias=qkv_bias,
            need_CA=need_CA,
            **factory_kwargs,
        )

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor],
        mean_start_id: int = 0,
    ):
        x = self.input_norm(x)
        # Pool tokens (from mean_start_id onward) into a single context vector.
        if mask is None:
            x_mean = x[:, mean_start_id:].mean(dim=1)
        else:
            x_mean = (x[:, mean_start_id:] * mask[:, mean_start_id:].unsqueeze(-1)).sum(
                dim=1
            ) / (mask[:, mean_start_id:].sum(dim=1, keepdim=True) + 1e-4)
        c = self.c_norm(x_mean)
        c = self.c_embedder(c)
        x = self.input_embedder(x)
        x = self.individual_token_refiner(x, c, mask)

        return x


class Qwen2Connector(torch.nn.Module):
    def __init__(
        self,
        in_channels=4096,
        hidden_size=4096,
        heads_num=32,
        depth=1,
        need_CA=False,
        device=None,
        dtype=torch.bfloat16,
        identity_init=True,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.S = SingleTokenRefiner(
            in_channels=in_channels,
            hidden_size=hidden_size,
            heads_num=heads_num,
            depth=depth,
            identity_init=identity_init,
            need_CA=need_CA,
            **factory_kwargs,
        )

    def forward(self, x, mask=None, mean_start_id=0):
        encoder_hidden_states = self.S(x, mask, mean_start_id)
        return encoder_hidden_states


if __name__ == "__main__":
    # Smoke test: with identity_init=True and zero-initialized adaLN gates,
    # the connector should reproduce its input exactly.
    model = Qwen2Connector(in_channels=4096, hidden_size=4096).to("cuda").to(torch.bfloat16)
    x = torch.randn([2, 300, 4096]).to("cuda").to(torch.bfloat16)
    out = model(x)
    print(x, " >>> x")
    print(out.shape)
    print(out)
    assert torch.allclose(out, x)
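A minimal sketch of calling the connector with a padding mask is shown below; it reuses model and x from the smoke test above, and the mask lengths are illustrative assumptions, not part of the committed file.

# Hypothetical masked usage (illustrative): padded tokens are excluded from the
# pooled context vector that conditions the refiner blocks.
mask = torch.zeros(2, 300, dtype=torch.long, device="cuda")
mask[0, :120] = 1  # first sequence has 120 valid tokens
mask[1, :] = 1     # second sequence is full length
out_masked = model(x, mask=mask, mean_start_id=0)
print(out_masked.shape)  # torch.Size([2, 300, 4096])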