0xZohar committed on
Commit
c8fa637
·
verified ·
1 Parent(s): 4ef2971

Add code/cube3d/model/transformers/attention.py

Browse files
code/cube3d/model/transformers/attention.py ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from cube3d.model.transformers.norm import LayerNorm, RMSNorm
7
+
8
+
9
def init_linear(module, embed_dim: int):
    """
    Apply the default weight/bias initialization to a linear layer.

    Weights are sampled from a normal distribution whose standard deviation is
    ``sqrt(1 / embed_dim)``; the bias, when the layer has one, is set to zero.
    Any module that is not an ``nn.Linear`` is left untouched, which makes this
    function safe to use with ``nn.Module.apply``-style traversal.
    Args:
        module (nn.Module): Candidate module. Only ``nn.Linear`` instances are modified.
        embed_dim (int): Embedding dimension that determines the weight standard deviation.
    Returns:
        None
    """
    # Guard clause: nothing to do for non-linear modules.
    if not isinstance(module, nn.Linear):
        return

    weight_std = math.sqrt(1.0 / embed_dim)
    nn.init.normal_(module.weight, std=weight_std)

    # Layers built with bias=False have no bias tensor to reset.
    if module.bias is None:
        return
    nn.init.zeros_(module.bias)
24
+
25
+
26
def init_tfixup(module: nn.Module, num_layers: int):
    """Rescale projection weights following the T-Fixup scheme.

    See https://www.cs.toronto.edu/~mvolkovs/ICML2020_tfixup.pdf

    Parameters are selected by the suffix of their qualified name:
    ``c_proj``, ``up_proj`` and ``down_proj`` weights are multiplied by
    ``(4 * num_layers) ** -0.25``; ``c_v`` weights get the same factor times
    ``sqrt(2)``. All other parameters are left unchanged. The scaling is done
    in place and without gradient tracking.

    Args:
        module (nn.Module): decoder/encoder module
        num_layers (int): number of layers in the module
    """
    plain_suffixes = ("c_proj.weight", "up_proj.weight", "down_proj.weight")

    with torch.no_grad():
        for name, param in module.named_parameters():
            if name.endswith(plain_suffixes):
                extra_gain = 1.0
            elif name.endswith("c_v.weight"):
                # Value projections receive an additional sqrt(2) gain.
                extra_gain = math.sqrt(2)
            else:
                continue
            param.mul_((4 * num_layers) ** (-0.25) * extra_gain)
43
+
44
+
45
class MLP(nn.Module):
    """Two-layer feed-forward network: linear -> GELU -> linear."""

    def __init__(self, embed_dim, hidden_dim, bias=True, approximate="none"):
        """
        Build the feed-forward block.
        Args:
            embed_dim (int): Size of the input and output features.
            hidden_dim (int): Size of the intermediate (expanded) features.
            bias (bool, optional): Whether both linear layers carry a bias. Defaults to True.
            approximate (str, optional): GELU approximation mode forwarded to ``nn.GELU``.
                Defaults to "none".
        """
        super().__init__()
        # Creation order is kept (up, down, activation) so that parameter
        # initialization consumes the RNG in the same sequence.
        self.up_proj = nn.Linear(embed_dim, hidden_dim, bias=bias)
        self.down_proj = nn.Linear(hidden_dim, embed_dim, bias=bias)
        self.act_fn = nn.GELU(approximate=approximate)

    def forward(self, x):
        """
        Apply the expansion, the activation and the contraction in sequence.
        Args:
            x (torch.Tensor): Input whose last dimension is ``embed_dim``.
        Returns:
            torch.Tensor: Output with the same last dimension as the input.
        """
        hidden = self.up_proj(x)
        activated = self.act_fn(hidden)
        return self.down_proj(activated)
58
+
59
+
60
class SelfAttention(nn.Module):
    """Multi-head self-attention with RMS-normalized queries and keys."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        eps: float = 1e-6,
    ):
        """
        Set up the projections and per-head query/key normalization.
        Args:
            embed_dim (int): The dimensionality of the embedding space.
            num_heads (int): The number of attention heads.
            bias (bool, optional): Whether the linear projections carry a bias. Defaults to True.
            eps (float, optional): Accepted as part of the signature. Note that it is not
                forwarded to the query/key RMSNorm layers here. Defaults to 1e-6.
        Raises:
            AssertionError: If `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__()
        assert embed_dim % num_heads == 0
        self.num_heads = num_heads

        # Queries and keys share one fused projection; values have their own.
        self.c_qk = nn.Linear(embed_dim, 2 * embed_dim, bias=bias)
        self.c_v = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)

        # Normalization acts on each head's slice of the embedding.
        per_head = embed_dim // num_heads
        self.q_norm = RMSNorm(per_head)
        self.k_norm = RMSNorm(per_head)

    def forward(self, x, attn_mask=None, is_causal: bool = False):
        """
        Run self-attention over the input sequence.
        Args:
            x (torch.Tensor): Input tensor of shape (batch, seq_len, embed_dim).
            attn_mask (Optional[torch.Tensor]): Attention mask handed to
                ``scaled_dot_product_attention``. Default is None.
            is_causal (bool): If True and no explicit mask is given, a causal mask is
                applied. When `attn_mask` is provided the flag is ignored and the mask
                alone decides what is visible. Default is False.
        Returns:
            torch.Tensor: Tensor of shape (batch, seq_len, embed_dim) after attention
                and the output projection.
        """
        batch, seq_len, dim = x.shape

        def split_heads(t):
            # (B, T, C) -> (B, nh, T, hs)
            return t.view(batch, seq_len, self.num_heads, -1).transpose(1, 2)

        queries, keys = self.c_qk(x).chunk(2, dim=-1)
        values = self.c_v(x)

        queries = self.q_norm(split_heads(queries))
        keys = self.k_norm(split_heads(keys))
        values = split_heads(values)

        # An explicit mask takes precedence over the causal flag.
        use_causal = is_causal and attn_mask is None
        attended = torch.nn.functional.scaled_dot_product_attention(
            queries,
            keys,
            values,
            attn_mask=attn_mask,
            dropout_p=0.0,
            is_causal=use_causal,
        )

        # (B, nh, T, hs) -> (B, T, C)
        merged = attended.transpose(1, 2).reshape(batch, seq_len, dim)
        return self.c_proj(merged)
130
+
131
+
132
class CrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from `x`, keys/values from a context `c`."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        q_dim=None,
        kv_dim=None,
        bias: bool = True,
    ):
        """
        Initializes the cross attention mechanism.
        Args:
            embed_dim (int): The dimensionality of the embedding space.
            num_heads (int): The number of attention heads.
            q_dim (int, optional): The dimensionality of the query input. Defaults to `embed_dim`.
            kv_dim (int, optional): The dimensionality of the key and value inputs. Defaults to `embed_dim`.
            bias (bool, optional): Whether to include a bias term in the linear projections. Defaults to True.
        Raises:
            AssertionError: If `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__()
        assert embed_dim % num_heads == 0

        # Fall back to the embedding size when no input width is given.
        q_dim = q_dim or embed_dim
        kv_dim = kv_dim or embed_dim

        self.c_q = nn.Linear(q_dim, embed_dim, bias=bias)
        self.c_k = nn.Linear(kv_dim, embed_dim, bias=bias)
        self.c_v = nn.Linear(kv_dim, embed_dim, bias=bias)
        self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.num_heads = num_heads

    def forward(self, x, c, attn_mask=None, is_causal: bool = False):
        """
        Forward pass for the cross attention mechanism.
        Args:
            x (torch.Tensor): Query input of shape (batch, tgt_len, q_dim).
            c (torch.Tensor): Context tensor of shape (batch, src_len, kv_dim) that
                provides keys and values.
            attn_mask (torch.Tensor, optional): Attention mask handed to
                ``scaled_dot_product_attention``. Defaults to None.
            is_causal (bool, optional): Whether to apply causal masking. Only honored
                when `attn_mask` is None; an explicit mask takes precedence, matching
                the behavior of `SelfAttention`. Defaults to False.
        Returns:
            torch.Tensor: Output tensor of shape (batch, tgt_len, embed_dim).
        """

        q, k = self.c_q(x), self.c_k(c)
        v = self.c_v(c)

        b, l, d = q.shape
        s = k.shape[1]

        q = q.view(b, l, self.num_heads, -1).transpose(1, 2)  # (B, nh, T, hs)
        k = k.view(b, s, self.num_heads, -1).transpose(1, 2)  # (B, nh, S, hs)
        v = v.view(b, s, self.num_heads, -1).transpose(1, 2)  # (B, nh, S, hs)

        # scaled_dot_product_attention does not accept an explicit mask together
        # with is_causal=True, so the flag is only forwarded when no mask is given.
        # (The previous condition was inverted: it dropped the flag without a mask
        # and forwarded it alongside one.)
        is_causal = is_causal and attn_mask is None
        y = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            dropout_p=0.0,
            is_causal=is_causal,
        )

        # Merge heads back: (B, nh, T, hs) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(b, l, d)

        y = self.c_proj(y)
        return y
200
+
201
+
202
class EncoderLayer(nn.Module):
    """Pre-norm transformer block: self-attention followed by an MLP, each with a residual."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        bias: bool = True,
        eps: float = 1e-6,
    ) -> None:
        """
        Build the normalization, attention and feed-forward sub-modules.
        Args:
            embed_dim (int): The dimensionality of the embedding space.
            num_heads (int): The number of attention heads.
            bias (bool, optional): Whether the linear layers carry a bias. Defaults to True.
            eps (float, optional): Epsilon passed to the layer norms and to the
                attention module's constructor. Defaults to 1e-6.
        """
        super().__init__()
        # Attention branch (norm without learnable affine parameters).
        self.ln_1 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
        self.attn = SelfAttention(embed_dim, num_heads, bias=bias, eps=eps)
        # Feed-forward branch with a 4x hidden expansion.
        self.ln_2 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
        self.mlp = MLP(embed_dim=embed_dim, hidden_dim=embed_dim * 4, bias=bias)

    def forward(self, x, attn_mask=None, is_causal: bool = False):
        """
        Apply the attention and feed-forward sub-blocks with residual connections.
        Args:
            x (torch.Tensor): The input tensor.
            attn_mask (torch.Tensor, optional): Mask forwarded to the attention
                module. Default is None.
            is_causal (bool, optional): Causal flag forwarded to the attention
                module. Default is False.
        Returns:
            torch.Tensor: The output tensor of the same shape as the input.
        """
        attn_out = self.attn(self.ln_1(x), attn_mask=attn_mask, is_causal=is_causal)
        hidden = x + attn_out

        mlp_out = self.mlp(self.ln_2(hidden))
        return hidden + mlp_out
240
+
241
+
242
class EncoderCrossAttentionLayer(nn.Module):
    """Pre-norm transformer block: cross-attention followed by an MLP, each with a residual."""

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        q_dim=None,
        kv_dim=None,
        bias: bool = True,
        eps: float = 1e-6,
    ) -> None:
        """
        Build the EncoderCrossAttentionLayer: cross-attention, input norms and a
        feed-forward MLP.
        Args:
            embed_dim (int): The dimensionality of the embedding space.
            num_heads (int): The number of attention heads.
            q_dim (int, optional): Dimensionality of the query input. Defaults to `embed_dim`.
            kv_dim (int, optional): Dimensionality of the key and value inputs. Defaults to `embed_dim`.
            bias (bool, optional): Whether the linear layers carry a bias. Defaults to True.
            eps (float, optional): Epsilon passed to the layer normalization modules.
                Defaults to 1e-6.
        """
        super().__init__()

        # Resolve the optional input widths before wiring up the sub-modules.
        q_dim = q_dim or embed_dim
        kv_dim = kv_dim or embed_dim

        self.attn = CrossAttention(
            embed_dim,
            num_heads,
            q_dim=q_dim,
            kv_dim=kv_dim,
            bias=bias,
        )

        # Separate norms for the query stream and the context stream.
        self.ln_1 = LayerNorm(q_dim, elementwise_affine=False, eps=eps)
        self.ln_2 = LayerNorm(kv_dim, elementwise_affine=False, eps=eps)

        # Norm and 4x-expansion MLP for the feed-forward branch.
        self.ln_f = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
        self.mlp = MLP(embed_dim=embed_dim, hidden_dim=embed_dim * 4, bias=bias)

    def forward(self, x, c, attn_mask=None, is_causal: bool = False):
        """
        Apply cross-attention and the MLP, each wrapped in a residual connection.
        Args:
            x (torch.Tensor): The query-side input tensor.
            c (torch.Tensor): The context tensor used for cross-attention.
            attn_mask (torch.Tensor, optional): Mask forwarded to the attention
                module. Defaults to None.
            is_causal (bool, optional): Causal flag forwarded to the attention
                module. Defaults to False.
        Returns:
            torch.Tensor: The output tensor after applying attention and MLP layers.
        """
        normed_x = self.ln_1(x)
        normed_c = self.ln_2(c)
        hidden = x + self.attn(
            normed_x, normed_c, attn_mask=attn_mask, is_causal=is_causal
        )

        return hidden + self.mlp(self.ln_f(hidden))