0xZohar committed on
Commit
a3e82d1
·
verified ·
1 Parent(s): c8fa637

Add code/cube3d/model/transformers/dual_stream_attention.py

Browse files
code/cube3d/model/transformers/dual_stream_attention.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from cube3d.model.transformers.cache import Cache
7
+ from cube3d.model.transformers.norm import LayerNorm, RMSNorm
8
+ from cube3d.model.transformers.roformer import SwiGLUMLP
9
+ from cube3d.model.transformers.rope import scaled_dot_product_attention_with_rotary_emb
10
+
11
+
12
class DismantledPreAttention(nn.Module):
    """Pre-attention projections for one stream of a dual stream block.

    Projects an input sequence to per-head key and value tensors and,
    when ``query`` is True, a per-head query tensor as well. Queries and
    keys are passed through an ``RMSNorm`` over the head dimension.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        query: bool = True,
        bias: bool = True,
    ) -> None:
        """
        Initializes the DismantledPreAttention module.
        Args:
            embed_dim (int): The dimensionality of the embedding space.
            num_heads (int): The number of attention heads.
            query (bool, optional): Whether to include query-key projection. Defaults to True.
            bias (bool, optional): Whether to include bias in the key-only and value
                linear layers. The fused query-key projection is always created
                without bias. Defaults to True.
        Raises:
            AssertionError: If `embed_dim` is not divisible by `num_heads`.
        """
        super().__init__()
        assert embed_dim % num_heads == 0
        self.query = query
        # Stored so that `to_mha` can be a regular method. A lambda kept on the
        # instance cannot be pickled, which breaks whole-module serialization.
        self.num_heads = num_heads

        head_dim = embed_dim // num_heads
        # key, query, value projections for all heads, but in a batch
        if query:
            self.c_qk = nn.Linear(embed_dim, 2 * embed_dim, bias=False)
            self.q_norm = RMSNorm(head_dim)
        else:
            self.c_k = nn.Linear(embed_dim, embed_dim, bias=bias)
        # Keys are normalized and values projected in both modes.
        self.k_norm = RMSNorm(head_dim)
        self.c_v = nn.Linear(embed_dim, embed_dim, bias=bias)

    def to_mha(self, x):
        """
        Splits the last dimension into attention heads.
        Args:
            x (torch.Tensor): Tensor of shape (B, T, C).
        Returns:
            torch.Tensor: View of shape (B, nh, T, hs) with nh == num_heads.
        """
        return x.view(*x.shape[:2], self.num_heads, -1).transpose(1, 2)

    def forward(self, x):
        """
        Forward pass for the dismantled pre-attention mechanism.
        Args:
            x (torch.Tensor): Input tensor of shape (B, T, embed_dim).
        Returns:
            tuple: A tuple containing:
                - q (torch.Tensor or None): Query tensor after normalization and transformation,
                  or None if `self.query` is False.
                - k (torch.Tensor): Key tensor after normalization and transformation.
                - v (torch.Tensor): Value tensor after transformation.
        """
        if self.query:
            # One fused matmul yields both queries and keys.
            q, k = self.c_qk(x).chunk(2, dim=-1)
            q = self.q_norm(self.to_mha(q))
        else:
            q = None
            k = self.c_k(x)

        k = self.k_norm(self.to_mha(k))
        v = self.to_mha(self.c_v(x))

        return (q, k, v)
71
+
72
+
73
class DismantledPostAttention(nn.Module):
    """Post-attention half of a decoder block: output projection, then MLP.

    Both steps are applied as residual updates to the stream.
    """

    def __init__(self, embed_dim, bias: bool = True, eps: float = 1e-6) -> None:
        """
        Builds the output projection, the pre-MLP layer norm and the MLP.
        Args:
            embed_dim (int): Width of the stream; the MLP hidden width is four times this.
            bias (bool, optional): Passed to the linear projection and the MLP. Defaults to True.
            eps (float, optional): Epsilon handed to the layer normalization. Defaults to 1e-6.
        """
        super().__init__()
        mlp_hidden_dim = embed_dim * 4
        self.c_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.ln_3 = LayerNorm(embed_dim, elementwise_affine=False, eps=eps)
        self.mlp = SwiGLUMLP(embed_dim, mlp_hidden_dim, bias=bias)

    def forward(self, x, a):
        """
        Merges an attention output into the residual stream and applies the MLP.
        Args:
            x (torch.Tensor): Residual stream entering the block.
            a (torch.Tensor): Attention output to be projected and added to `x`.
        Returns:
            torch.Tensor: `x` after the projected-attention residual and the
            normalized-MLP residual have both been added.
        """
        after_attn = x + self.c_proj(a)
        mlp_update = self.mlp(self.ln_3(after_attn))
        return after_attn + mlp_update
106
+
107
+
108
class DualStreamAttentionWithRotaryEmbedding(nn.Module):
    """Attention over a condition stream and a hidden-state stream.

    Each stream has its own pre-attention projections. Outside of cached
    decoding the condition keys/values are prepended to those of the hidden
    states before a single attention call.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        cond_pre_only: bool = False,
        bias: bool = True,
    ):
        """
        Initializes the DualStreamAttention module.
        Args:
            embed_dim (int): The dimensionality of the embedding space.
            num_heads (int): The number of attention heads.
            cond_pre_only (bool, optional): If True, the conditional pre-attention
                will only process the key and value, not the query. Defaults to False.
            bias (bool, optional): Whether to include a bias term in the attention layers.
                Defaults to True.
        """
        super().__init__()

        self.cond_pre_only = cond_pre_only

        self.pre_x = DismantledPreAttention(
            embed_dim=embed_dim, num_heads=num_heads, query=True, bias=bias
        )

        self.pre_c = DismantledPreAttention(
            embed_dim=embed_dim, num_heads=num_heads, query=not cond_pre_only, bias=bias
        )

    def forward(
        self,
        x,
        c: Optional[torch.Tensor],
        freqs_cis,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = False,
        kv_cache: Optional[Cache] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Forward pass for dual stream Multi-Head Attention.

        Parameters
        ----------
        x : torch.Tensor
            Hidden states [B, L, D]
        c : torch.Tensor, optional
            Condition [B, S, D]. Required unless a kv cache is given and
            ``decode`` is True.
        freqs_cis: torch.Tensor
            Precomputed RoPE matrix from precompute_freqs_cis [B, S+L, Hd]
        attn_mask : torch.Tensor, optional
            Attention mask [B, S+L, S+L], by default None. Its rows are trimmed
            to the query positions before the attention call.
        is_causal : bool, optional
            Forwarded to the attention helper; forced to False when decoding
            from a kv cache. By default False.
        kv_cache : Cache, optional
            Key/value cache; None disables caching. On prefill the leading
            positions of its buffers are overwritten with the new keys/values,
            on decode it is updated at ``curr_pos_id``. In both cases attention
            then runs over the cache's full key/value buffers.
        curr_pos_id : torch.Tensor, optional
            Position of the current token. Must be given when decoding with a
            kv cache, and when ``decode`` is True together with ``attn_mask``.
        decode : bool, optional
            True for single-token decoding, False for training or prefill.

        Returns
        -------
        Tuple[torch.Tensor, Optional[torch.Tensor]]
            ``(y_x, y_c)``: attention output for the hidden states [B, L, D]
            and for the condition stream. ``y_c`` is None when the attention
            output has the same length as ``x``.

        Raises
        ------
        ValueError
            If ``c`` is None on the training/prefill path.
        """
        if kv_cache is None or not decode:
            # Either training or prefill
            if c is None:
                # Fail here rather than with an opaque error inside nn.Linear.
                raise ValueError("Conditioning is required for dual stream attention")
            qkv_c = self.pre_c(c)
            qkv_x = self.pre_x(x)
            # prepend condition stream
            # (B, nh, Tc, hs) + (B, nh, Tx, hs) -> (B, nh, Tc+Tx, hs)
            if self.cond_pre_only:
                # The condition stream contributes no queries in this mode.
                q = qkv_x[0]
            else:
                q = torch.cat([qkv_c[0], qkv_x[0]], dim=2)
            k = torch.cat([qkv_c[1], qkv_x[1]], dim=2)
            v = torch.cat([qkv_c[2], qkv_x[2]], dim=2)

        else:
            # if using kv cache, query would only be the last token in the sequence, hence is_causal is False
            assert x.shape[1] == 1
            is_causal = False
            q, k, v = self.pre_x(x)

        if kv_cache is not None:
            if not decode:
                # Prefill: write in place into the leading slots of the cache.
                kv_cache.key_states[:, :, : k.shape[2], :].copy_(k)
                kv_cache.value_states[:, :, : k.shape[2], :].copy_(v)
            else:
                assert curr_pos_id is not None
                kv_cache.update(curr_pos_id, k, v)
            k = kv_cache.key_states
            v = kv_cache.value_states

        if attn_mask is not None:
            # trim attention mask to length
            if decode:
                assert curr_pos_id is not None
                attn_mask = attn_mask[..., curr_pos_id, :]
            else:
                attn_mask = attn_mask[..., -q.shape[2] :, :]

        # Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        y = scaled_dot_product_attention_with_rotary_emb(
            q,
            k,
            v,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            curr_pos_id=curr_pos_id if decode else None,
            is_causal=is_causal,
        )

        # re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(x.shape[0], -1, x.shape[2])

        if y.shape[1] == x.shape[1]:
            # Only hidden-state queries were used, so there is no condition output.
            y_c = None
            y_x = y
        else:
            assert c is not None, "Conditioning is required for dual stream attention"
            y_c, y_x = torch.split(y, [c.shape[1], x.shape[1]], dim=1)
        return y_x, y_c
241
+
242
+
243
class DualStreamDecoderLayerWithRotaryEmbedding(nn.Module):
    """Decoder layer for the dual stream GPT model.

    Normalizes both streams, runs dual stream attention, then applies a
    separate post-attention block to each stream that produced an output.
    """

    def __init__(
        self,
        embed_dim,
        num_heads: int,
        cond_pre_only: bool = False,
        bias: bool = True,
        eps: float = 1.0e-6,
    ) -> None:
        """
        Builds the layer norms, the attention module and the post-attention blocks.
        Args:
            embed_dim (int): The dimensionality of the embedding space.
            num_heads (int): The number of attention heads.
            cond_pre_only (bool, optional): Forwarded to the attention module; when True
                no post-attention block is created for the condition stream. Defaults to False.
            bias (bool, optional): Forwarded to the attention and post-attention layers. Defaults to True.
            eps (float, optional): Epsilon used by the layer normalizations. Defaults to 1.0e-6.
        """
        super().__init__()

        norm_kwargs = {"elementwise_affine": False, "eps": eps}
        self.ln_1 = LayerNorm(embed_dim, **norm_kwargs)
        self.ln_2 = LayerNorm(embed_dim, **norm_kwargs)

        self.attn = DualStreamAttentionWithRotaryEmbedding(
            embed_dim=embed_dim,
            num_heads=num_heads,
            cond_pre_only=cond_pre_only,
            bias=bias,
        )

        self.post_1 = DismantledPostAttention(embed_dim, bias=bias, eps=eps)
        if cond_pre_only:
            # The condition stream gets no attention output in this mode.
            return
        self.post_2 = DismantledPostAttention(embed_dim, bias=bias, eps=eps)

    @classmethod
    def from_config(cls, cfg, cond_pre_only: bool = False):
        """
        Alternate constructor that reads the layer sizes from a config object.
        Args:
            cfg: Object exposing the attributes read here:
                - n_embd (int): The size of the embedding dimension.
                - n_head (int): The number of attention heads.
                - bias (bool): Whether to include a bias term.
                - eps (float): Epsilon for the layer normalizations.
            cond_pre_only (bool, optional): Passed through to the constructor.
                Defaults to False.
        Returns:
            An instance of `cls` built from the configuration values.
        """
        embed_dim = cfg.n_embd
        layer_kwargs = {
            "num_heads": cfg.n_head,
            "cond_pre_only": cond_pre_only,
            "bias": cfg.bias,
            "eps": cfg.eps,
        }
        return cls(embed_dim, **layer_kwargs)

    def forward(
        self,
        x,
        c,
        freqs_cis: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        is_causal: bool = True,
        kv_cache: Optional[Cache] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
    ):
        """
        Forward pass for DualStreamDecoderLayerWithRotaryEmbedding.

        Parameters
        ----------
        x : torch.Tensor
            Hidden states [B, L, D]
        c : torch.Tensor or None
            Condition [B, S, D]; may be None when decoding with a kv cache
        freqs_cis: torch.Tensor
            Positional embedding from RoPE [B, S+L, hd]
        attn_mask : torch.Tensor, optional
            Attention mask [B, S+L, S+L], by default None
        is_causal : bool, optional
            Forwarded to the attention module, by default True
        kv_cache : Cache, optional
            Key/value cache forwarded to the attention module, by default None
        curr_pos_id : torch.Tensor, optional
            Current position forwarded to the attention module, by default None
        decode : bool, optional
            Decode-mode flag forwarded to the attention module, by default False

        Returns
        -------
        torch.Tensor
            Updated hidden states
        torch.Tensor or None
            Updated condition stream, or None when the attention module
            returned no output for the condition stream
        """
        hidden_normed = self.ln_1(x)
        # NOTE condition could be none if using kv cache
        cond_normed = None if c is None else self.ln_2(c)

        attn_x, attn_c = self.attn(
            hidden_normed,
            cond_normed,
            freqs_cis=freqs_cis,
            attn_mask=attn_mask,
            is_causal=is_causal,
            kv_cache=kv_cache,
            curr_pos_id=curr_pos_id,
            decode=decode,
        )

        x = self.post_1(x, attn_x)
        if attn_c is None:
            return x, None
        return x, self.post_2(c, attn_c)