0xZohar committed on
Commit
6d7f949
·
verified ·
1 Parent(s): 8ec3dcb

Add code/cube3d/model/gpt/dual_stream_roformer.py

Browse files
code/cube3d/model/gpt/dual_stream_roformer.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Optional
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from cube3d.model.transformers.cache import Cache
8
+ from cube3d.model.transformers.dual_stream_attention import (
9
+ DualStreamDecoderLayerWithRotaryEmbedding,
10
+ )
11
+ from cube3d.model.transformers.norm import LayerNorm
12
+ from cube3d.model.transformers.roformer import DecoderLayerWithRotaryEmbedding
13
+ from cube3d.model.transformers.rope import precompute_freqs_cis
14
+
15
+
16
class DualStreamRoformer(nn.Module):
    """Dual-stream decoder-only transformer with rotary position embeddings.

    Conditioning embeddings and shape-token embeddings flow through
    ``n_layer`` dual-stream decoder blocks and then ``n_single_layer``
    single-stream blocks; see ``forward`` for the full data path.
    """

    @dataclass
    class Config:
        """Hyper-parameters for :class:`DualStreamRoformer`."""

        checkpoint_path: str = ""
        n_layer: int = 12            # number of dual-stream decoder layers
        n_single_layer: int = 0      # number of single-stream decoder layers
        # RoPE base frequency. NOTE(review): 1000 is unusually low — the
        # common default is 10000; confirm this is intentional.
        rope_theta: float = 1000

        n_head: int = 16             # attention heads per layer
        n_embd: int = 2048           # model (embedding) width
        bias: bool = False  # bias in Linears and LayerNorms
        eps: float = 1e-6  # Norm eps

        shape_model_vocab_size: int = 4096   # shape tokenizer vocab (before special tokens)
        shape_model_embed_dim: int = 16      # width of raw shape-model embeddings

        text_model_embed_dim: int = 512      # width of raw text-model embeddings
        use_pooled_text_embed: bool = False

        encoder_with_cls_token: bool = True

        use_bbox: bool = False               # if True, adds a 3 -> n_embd bbox projection

        ldr_in_embed_dim: int = 2048         # input width for the ldr projection
        ldr_out_embed_dim: int = 2048        # output width of ldr_head (forward's logits)
    def __init__(self, cfg: Config) -> None:
        """
        Initializes the DualStreamRoFormer model.
        Args:
            cfg (Config): Configuration object containing model parameters.
        Attributes:
            cfg (Config): Stores the configuration object.
            text_proj (nn.Linear): Linear layer to project text model embeddings to the desired embedding dimension.
            shape_proj (nn.Linear, optional): Linear layer to project shape model embeddings to the desired embedding
                dimension
            ldr_proj (nn.Linear): Linear layer to project ldr embeddings to the desired embedding dimension.
            dte/rte/xte/yte/zte (nn.Embedding): Embedding tables for data, rotation, and x/y/z
                position tokens; each reserves its last index as padding.
            vocab_size (int): Vocabulary size for the shape model, including special tokens.
            shape_bos_id (int): Token ID for the beginning-of-sequence (BOS) token for the shape model.
            shape_eos_id (int): Token ID for the end-of-sequence (EOS) token for the shape model.
            padding_id (int): Token ID for the padding token.
            transformer (nn.ModuleDict): Dictionary containing the following components:
                - wte (nn.Embedding): Embedding layer for the vocabulary.
                - dual_blocks (nn.ModuleList): List of dual-stream decoder layers with rotary embeddings.
                - single_blocks (nn.ModuleList): List of single-stream decoder layers with rotary embeddings.
                - ln_f (LayerNorm): Layer normalization applied to the final output.
            lm_head (nn.Linear): Linear layer mapping the final embeddings to the vocabulary size for language modeling.
            ldr_head (nn.Linear): Linear layer mapping the final embeddings to ldr_out_embed_dim
                (this head — not lm_head — produces forward()'s output).
        """

        super().__init__()

        self.cfg = cfg

        # Project text-encoder features to the transformer width.
        self.text_proj = nn.Linear(
            in_features=self.cfg.text_model_embed_dim,
            out_features=self.cfg.n_embd,
            bias=self.cfg.bias,
        )

        # Project raw shape-model embeddings to the transformer width.
        self.shape_proj = nn.Linear(self.cfg.shape_model_embed_dim, self.cfg.n_embd)

        self.ldr_proj = nn.Linear(self.cfg.ldr_in_embed_dim, self.cfg.n_embd)
        # self.postion_proj = nn.Linear(3, 3)

        self.vocab_size = self.cfg.shape_model_vocab_size

        # Discretization sizes for position/rotation tokens.
        # NOTE(review): these magic counts (251/215/525/24) presumably come
        # from the dataset's coordinate quantization — confirm against the
        # data pipeline.
        x_num = 251
        y_num = 215
        z_num = 525
        rot_num = 24

        self.x_num = x_num
        self.y_num = y_num
        self.z_num = z_num
        self.rot_num = rot_num

        # Cumulative offsets over the combined token ranges.
        self.x = x_num
        self.xy = x_num + y_num + rot_num
        self.xyz = x_num + y_num + z_num + rot_num
        self.dat_num = 1217  # 286 #604
        # Data-token embedding; index dat_num is reserved for padding.
        self.dte = nn.Embedding(
            self.dat_num + 1,
            # (self.cfg.n_embd-768),
            self.cfg.n_embd,
            padding_idx=self.dat_num,
        )

        # Rotation-token embedding; rot_num+2 rows (one padding, one extra
        # special slot), padding at index rot_num.
        self.rte = nn.Embedding(
            self.rot_num + 2,
            # (self.cfg.n_embd-768),
            self.cfg.n_embd,
            padding_idx=self.rot_num,
        )

        # X/Y/Z coordinate-token embeddings, same +2 / padding convention.
        self.xte = nn.Embedding(
            self.x_num + 2,
            # (self.cfg.n_embd-768),
            self.cfg.n_embd,
            padding_idx=self.x_num,
        )

        self.yte = nn.Embedding(
            self.y_num + 2,
            # (self.cfg.n_embd-768),
            self.cfg.n_embd,
            padding_idx=self.y_num,
        )

        self.zte = nn.Embedding(
            self.z_num + 2,
            # (self.cfg.n_embd-768),
            self.cfg.n_embd,
            padding_idx=self.z_num,
        )
        self.is_compute = False

        # Appends one special token to the shape vocabulary and returns its id.
        def add_special_token():
            token_id = self.vocab_size
            self.vocab_size += 1
            return token_id

        # Order matters: ids are assigned sequentially after the base vocab.
        self.shape_bos_id = add_special_token()  # 16384
        self.shape_eos_id = add_special_token()  # 16385
        self.padding_id = add_special_token()  # 16386

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(
                    self.vocab_size,
                    self.cfg.n_embd,
                    padding_idx=self.padding_id,
                ),
                # Only the last dual block runs with cond_pre_only=True.
                dual_blocks=nn.ModuleList(
                    [
                        DualStreamDecoderLayerWithRotaryEmbedding.from_config(
                            self.cfg, cond_pre_only=(i == self.cfg.n_layer - 1)
                        )
                        for i in range(self.cfg.n_layer)
                    ]
                ),
                single_blocks=nn.ModuleList(
                    [
                        DecoderLayerWithRotaryEmbedding.from_config(self.cfg)
                        for _ in range(self.cfg.n_single_layer)
                    ]
                ),
                # Final norm has no learnable affine parameters.
                ln_f=LayerNorm(
                    self.cfg.n_embd, elementwise_affine=False, eps=self.cfg.eps
                ),
            )
        )

        self.lm_head = nn.Linear(self.cfg.n_embd, self.vocab_size, bias=False)
        self.ldr_head = nn.Linear(self.cfg.n_embd, self.cfg.ldr_out_embed_dim, bias=False)

        if self.cfg.use_bbox:
            self.bbox_proj = nn.Linear(3, self.cfg.n_embd)
+ def encode_embed(self, ldr_embed):
173
+ """
174
+ Encodes the given ldr embeddings by projecting them through a linear transformation.
175
+ Args:
176
+ ldr_embed (torch.Tensor): A tensor representing the text embeddings to be encoded.
177
+ Returns:
178
+ torch.Tensor: The projected ldr embeddings after applying the linear transformation.
179
+ """
180
+
181
+ return self.ldr_proj(ldr_embed)
182
+
183
+ def encode_text(self, text_embed):
184
+ """
185
+ Encodes the given text embeddings by projecting them through a linear transformation.
186
+ Args:
187
+ text_embed (torch.Tensor): A tensor representing the text embeddings to be encoded.
188
+ Returns:
189
+ torch.Tensor: The projected text embeddings after applying the linear transformation.
190
+ """
191
+
192
+ return self.text_proj(text_embed)
193
+
194
+ def encode_token(self, tokens):
195
+ """
196
+ Encodes the input tokens using the word token embedding layer of the transformer model.
197
+ Args:
198
+ tokens (torch.Tensor): A tensor containing the input tokens to be encoded.
199
+ Returns:
200
+ torch.Tensor: A tensor containing the encoded token embeddings.
201
+ """
202
+
203
+ return self.transformer.wte(tokens)
204
+
205
+ def init_kv_cache(
206
+ self,
207
+ batch_size: int,
208
+ cond_len: int,
209
+ max_shape_tokens: int,
210
+ dtype: torch.dtype,
211
+ device: torch.device,
212
+ ) -> list[Cache]:
213
+ """
214
+ Initializes the key-value cache for the transformer model.
215
+ This method creates a list of `Cache` objects to store the key and value
216
+ states for both dual-stream and single-stream transformer blocks. The
217
+ cache is pre-allocated with zeros and is used to optimize the computation
218
+ of attention mechanisms during model inference.
219
+ Args:
220
+ batch_size (int): The batch size for the input data.
221
+ cond_len (int): The length of the conditioning sequence.
222
+ max_shape_tokens (int): The maximum number of tokens in the shape sequence.
223
+ dtype (torch.dtype): The data type for the tensors (e.g., torch.float32).
224
+ device (torch.device): The device on which the tensors will be allocated
225
+ (e.g., torch.device('cuda') or torch.device('cpu')).
226
+ Returns:
227
+ list[Cache]: A list of `Cache` objects containing pre-allocated key and
228
+ value states for each transformer block.
229
+ """
230
+ num_heads = self.cfg.n_head
231
+ max_all_tokens = cond_len + max_shape_tokens
232
+ per_head_dim = self.cfg.n_embd // num_heads
233
+
234
+ kv_cache = [
235
+ Cache(
236
+ key_states=torch.zeros(
237
+ (batch_size, num_heads, max_all_tokens, per_head_dim),
238
+ dtype=dtype,
239
+ device=device,
240
+ ),
241
+ value_states=torch.zeros(
242
+ (batch_size, num_heads, max_all_tokens, per_head_dim),
243
+ dtype=dtype,
244
+ device=device,
245
+ ),
246
+ )
247
+ for _ in range(len(self.transformer.dual_blocks))
248
+ ]
249
+ kv_cache += [
250
+ Cache(
251
+ key_states=torch.zeros(
252
+ (batch_size, num_heads, max_shape_tokens, per_head_dim),
253
+ dtype=dtype,
254
+ device=device,
255
+ ),
256
+ value_states=torch.zeros(
257
+ (batch_size, num_heads, max_shape_tokens, per_head_dim),
258
+ dtype=dtype,
259
+ device=device,
260
+ ),
261
+ )
262
+ for _ in range(len(self.transformer.single_blocks))
263
+ ]
264
+ return kv_cache
265
+
266
    def forward(
        self,
        embed: torch.Tensor,
        cond: torch.Tensor,
        kv_cache: Optional[list[Cache]] = None,
        curr_pos_id: Optional[torch.Tensor] = None,
        decode: bool = False,
        **kwargs,
    ):
        """
        Forward pass for the dual-stream RoFormer model.
        Args:
            embed (torch.Tensor): The input embedding tensor, shape (batch, seq_len, n_embd).
            cond (torch.Tensor): The conditioning tensor, shape (batch, cond_len, n_embd).
            kv_cache (Optional[list[Cache]]): A list of key-value caches for each layer, used for decoding.
                Dual-block caches first, then single-block caches (same order as init_kv_cache). Default is None.
            curr_pos_id (Optional[torch.Tensor]): The current position ID tensor of shape (batch_size,). Required if `decode` is True. Default is None.
            decode (bool): Whether the model is in decoding mode. Default is False.
        Returns:
            torch.Tensor: The output of ldr_head, shape (..., ldr_out_embed_dim).
                Despite the name "logits", this is the ldr projection, not lm_head output.
        """
        b, l = embed.shape[:2]
        s = cond.shape[1]
        device = embed.device

        # attn_mask = torch.tril(
        #     torch.ones(s + l, s + l, dtype=torch.bool, device=device)
        # ) #Causal Attention Mask
        # NOTE(review): the causal mask above was replaced with an all-True
        # (fully visible) mask, yet the dual blocks are still called with
        # is_causal=True below — confirm which one the attention
        # implementation honors when both are supplied.
        attn_mask = torch.ones(s + l, s + l, dtype=torch.bool, device=device) #Without Attention Mask

        # positions = torch.arange(s + l, device=device)
        # mask_1d = (positions > 1) & ((positions % 5 == 0) | (positions % 5 == 1) | (positions % 5 == 4))
        # attn_mask[mask_1d, :] = False
        # attn_mask[:, mask_1d] = False

        # Positions for the shape tokens only: 0..l-1, broadcast over batch.
        position_ids = torch.arange(l, dtype=torch.long, device=device)  # shape (t)
        position_ids = position_ids.unsqueeze_(0).expand(b, -1)  # in-place unsqueeze, then view
        #position_ids = position_ids.unsqueeze(0).expand(b, -1)

        # Rotary frequencies for the shape-only stream (used by single blocks).
        s_freqs_cis = precompute_freqs_cis(
            dim=self.cfg.n_embd // self.cfg.n_head,  # 128
            t=position_ids,
            theta=self.cfg.rope_theta,  # 10000.0
        )

        # Prepend zeros for the conditioning tokens: all cond positions share
        # rotary position 0, shape tokens keep 0..l-1.
        position_ids = torch.cat(
            [
                torch.zeros([b, s], dtype=torch.long, device=position_ids.device),
                position_ids,
            ],
            dim=1,
        )  # full position_ids

        # Rotary frequencies over cond + shape (used by dual blocks).
        d_freqs_cis = precompute_freqs_cis(
            dim=self.cfg.n_embd // self.cfg.n_head,
            t=position_ids,
            theta=self.cfg.rope_theta,
        )  # full position embedding

        # import ipdb; ipdb.set_trace()
        # Incremental decoding: only feed the embedding(s) at curr_pos_id;
        # earlier positions are served from kv_cache.
        if kv_cache is not None and decode:
            assert curr_pos_id is not None
            embed = embed[:, curr_pos_id, :]
        # print(decode)

        h = embed
        c = cond

        # layer_idx walks the flat kv_cache list across both block types,
        # matching the order produced by init_kv_cache.
        layer_idx = 0
        for block in self.transformer.dual_blocks:
            h, c = block(
                h,
                c=c,
                freqs_cis=d_freqs_cis,
                attn_mask=attn_mask,
                is_causal=True,
                kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
                # kv_cache=None,
                # Dual blocks index positions after the cond prefix.
                curr_pos_id=curr_pos_id + s if curr_pos_id is not None else None,
                decode=decode,
            )
            layer_idx += 1

        for block in self.transformer.single_blocks:
            h = block(
                h,
                freqs_cis=s_freqs_cis,
                attn_mask=None,
                is_causal=True,
                kv_cache=kv_cache[layer_idx] if kv_cache is not None else None,
                # kv_cache=None,
                # Single blocks see only the shape stream, so no +s offset.
                curr_pos_id=curr_pos_id,
                decode=decode,
            )
            layer_idx += 1

        # import ipdb; ipdb.set_trace()
        # Normalization
        h = self.transformer.ln_f(h)
        logits = self.ldr_head(h)

        return logits