zirobtc committed
Commit 8c087f9 · verified · 1 Parent(s): 69b50e2

Upload llama_model_v1.py with huggingface_hub

Files changed (1)
  1. llama_model_v1.py +1685 -0
llama_model_v1.py ADDED
@@ -0,0 +1,1685 @@
import math
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing_extensions import Self
from typing import Optional
from transformers.modeling_utils import PreTrainedModel
from torch.distributions import Categorical


@dataclass
class LLaMAHFConfig:
    block_size: int = 78
    n_layer: int = 32
    n_head: int = 32
    n_embd: int = 4096
    T5_xxl_dim: int = 768

    @classmethod
    def from_name(cls, name: str) -> Self:
        return cls(**llama_configs[name])


llama_configs = {
    "Normal_size": dict(n_layer=12, n_head=12, n_embd=768)
}

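# ---------------------------------------------------------------------------
# Usage sketch (added for illustration, not part of the original upload):
# `LLaMAHFConfig.from_name` looks up one of the presets in `llama_configs`,
# and the resulting config is passed to `LLaMAHF` below. This assumes
# `RMSNorm` and `Block` (defined later in this file) and the external
# `models.diffloss.DiffLoss` dependency are importable.
#
#     config = LLaMAHFConfig.from_name("Normal_size")   # n_layer=12, n_head=12, n_embd=768
#     model = LLaMAHF(config, num_diffusion_head_layers=9,
#                     input_token_dim=16, device=torch.device('cuda'))
#     # continuous motion tokens: (batch, seq, input_token_dim)
#     # text condition:           (batch, T5_xxl_dim) == (batch, 768)
# ---------------------------------------------------------------------------
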
class LLaMAHF(nn.Module):
    def __init__(self, config: LLaMAHFConfig, num_diffusion_head_layers=9, input_token_dim=16, device=torch.device('cuda'), width=1792) -> None:
        super().__init__()
        assert config.block_size is not None
        self.config = config

        cond_dim = config.T5_xxl_dim

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Linear(input_token_dim, config.n_embd),
                cond_embed=nn.Linear(cond_dim, config.n_embd),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=RMSNorm(config.n_embd),
            )
        )

        target_channels = input_token_dim
        from models.diffloss import DiffLoss
        self.diff_loss = DiffLoss(
            target_channels=target_channels,
            z_channels=config.n_embd,
            width=width,
            depth=num_diffusion_head_layers,
            num_sampling_steps='50',
            grad_checkpointing=False,
        )
        self.diff_loss = self.diff_loss.to(device)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd)
        self.use_out_proj = True

    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """Tie or clone module weights depending on whether we are using TorchScript or not"""
        output_embeddings.weight = input_embeddings.weight

        if getattr(output_embeddings, "bias", None) is not None:
            output_embeddings.bias.data = nn.functional.pad(
                output_embeddings.bias.data,
                (
                    0,
                    output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
                ),
                "constant",
                0,
            )
        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
            output_embeddings.out_features = input_embeddings.num_embeddings

    def get_input_embeddings(self):
        return self.transformer.wte

    def set_input_embeddings(self, value):
        self.transformer.wte = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))

    def forward_sample(self, idx: torch.Tensor, clip_feature: torch.Tensor, y_mask) -> torch.Tensor:

        text_length = clip_feature.shape[1]
        # NOTE: `self.llama_proj` is not created in __init__ above; it must be attached
        # to the model elsewhere before this method can be used.
        if len(idx) == 0:
            x = self.llama_proj(clip_feature)[:, :int(y_mask[0].sum()), :]
        else:
            _, t = idx.size()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
            # forward the LLaMA model itself
            x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
            x = torch.cat((self.llama_proj(clip_feature)[:, :int(y_mask[0].sum()), :], x), dim=1)

        for block in self.transformer.h:
            x = block(x, y_mask)
        x = self.transformer.ln_f(x)
        logits = x
        return logits

    def sample_for_eval_CFG(self, text, length=196, tokenize_model=None, device=torch.device('cuda'), unit_length=4, cfg=4.0):
        max_token_len = length // unit_length
        for k in range(max_token_len):
            if k == 0:
                x = []
            else:
                x = xs

            feat_text = torch.from_numpy(tokenize_model.encode(text)).float()
            feat_text = feat_text.to(device)
            conditions = self.forward(x, feat_text)
            conditions = conditions[:, -1, :]

            empty_text = ''
            empty_feat_text = torch.from_numpy(tokenize_model.encode(empty_text)).float()
            empty_feat_text = empty_feat_text.unsqueeze(0)
            empty_feat_text = empty_feat_text.to(device)
            empty_conditions = self.forward(x, empty_feat_text)
            empty_conditions = empty_conditions[:, -1, :]
            temperature = 1.0

            # chunk
            if cfg != 1:
                mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
                sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:  # no cfg
                scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=1)

            scaled_logits = scaled_logits.unsqueeze(0)

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs

    # For inference: sampling can stop early once the L2 distance between the current
    # token and the reference end token falls below the threshold.
    def sample_for_eval_CFG_inference(self, text, length=312, tokenizer=None, device=torch.device('cuda'), unit_length=4, reference_end_latent=None, threshold=0.1, cfg=4.0, temperature=1.0):
        max_token_len = length // unit_length
        feat_text = torch.from_numpy(tokenizer.encode(text)).float()
        feat_text = feat_text.to(device)

        # CFG inference
        empty_text = ''
        empty_feat_text = torch.from_numpy(tokenizer.encode(empty_text)).float()  # torch.Size([32, 768])
        empty_feat_text = empty_feat_text.unsqueeze(0)
        empty_feat_text = empty_feat_text.to(device)

        for k in range(max_token_len):
            if k == 0:
                x = []
            else:
                x = xs

            conditions = self.forward_inference(x, feat_text)
            conditions = conditions[:, -1, :]

            empty_conditions = self.forward(x, empty_feat_text)
            empty_conditions = empty_conditions[:, -1, :]

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_latent is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_latent)**2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs
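
    # ------------------------------------------------------------------
    # Illustrative call (added for clarity, not from the original file): the
    # early-stop path above compares each newly sampled latent against a
    # caller-supplied `reference_end_latent` and stops once the L2 distance
    # drops below `threshold`. `t5_encoder` and `end_latent` are hypothetical
    # stand-ins for the caller's text encoder and a pre-computed end latent.
    #
    #     motion_latents = model.sample_for_eval_CFG_inference(
    #         "a person walks forward and sits down",
    #         length=312,
    #         tokenizer=t5_encoder,            # must expose .encode(str) -> np.ndarray
    #         reference_end_latent=end_latent, # same shape as one sampled token latent
    #         threshold=0.1,
    #         cfg=4.0,
    #     )
    # ------------------------------------------------------------------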

    def sample_for_eval_CFG_inference2(self, feat_clip_text, empty_feat_clip_text, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, threshold=3, cfg=4.5, temperature=1.0):

        import clip
        max_token_len = length // unit_length

        for k in range(max_token_len):
            if k == 0:
                x = []
            else:
                x = xs

            try:
                conditions = self.forward(x, feat_clip_text)
            except Exception:
                # fall back to adding a batch dimension if the text feature is unbatched
                conditions = self.forward(x, feat_clip_text.unsqueeze(0))

            conditions = conditions[:, -1, :]

            empty_conditions = self.forward(x, empty_feat_clip_text)
            empty_conditions = empty_conditions[:, -1, :]

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_token is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs

    def sample_for_eval_CFG_inference_next_one(self, current_token=[], feat_clip_text=None, empty_feat_clip_text=None, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, threshold=3, cfg=4.5, temperature=1.0):

        import clip
        max_token_len = length // unit_length

        for k in range(1):

            if current_token == []:
                x = []
            else:
                x = torch.cat(current_token, dim=1)

            try:
                conditions = self.forward(x, feat_clip_text)
            except Exception:
                conditions = self.forward(x, feat_clip_text.unsqueeze(0))

            conditions = conditions[:, -1, :]

            empty_conditions = self.forward(x, empty_feat_clip_text)
            empty_conditions = empty_conditions[:, -1, :]

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs

    def sample_for_eval_CFG_babel(self, A_text, B_text, A_motion, if_categorial=False, length=6400, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=7.0, threshold=3):

        import clip
        B_token_length = length // unit_length - A_motion.shape[0]

        if tokenizer == 'clip':
            A_text = clip.tokenize(A_text, truncate=True).to(device)
            A_feat_clip_text = clip_model.encode_text(A_text).float()
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            A_feat_clip_text = torch.from_numpy(clip_model.encode(A_text)).float()
            A_feat_clip_text = A_feat_clip_text.to(device)
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)

        A_text_embeddings = self.transformer.cond_embed(A_feat_clip_text).unsqueeze(0)
        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)

        A_motion = A_motion.unsqueeze(0)
        A_motion_embeddings = self.transformer.wte(A_motion)
        B_motion = torch.tensor([]).to(device)

        for k in range(B_token_length):
            if k == 0:
                x = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1)
            else:
                x = xs

            conditions = self.forward_babel_eval(x)
            conditions = conditions[:, -1, :]

            empty_clip_text = ''
            if tokenizer == 'clip':
                empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                empty_feat_clip_text = clip_model.encode_text(empty_text).float()
            elif tokenizer == 't5-xxl':
                empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
                empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                empty_feat_clip_text = empty_feat_clip_text.to(device)

            empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)

            if k == 0:
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)
            else:
                B_motion_embeddings = self.transformer.wte(B_motion)
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding, B_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)

            empty_conditions = empty_conditions[:, -1, :]
            temperature = 1.0

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            B_motion = torch.cat((B_motion, scaled_logits), dim=1)

            scaled_logits_embedding = self.transformer.wte(scaled_logits)
            xs = torch.cat((x, scaled_logits_embedding), dim=1)

        return xs, B_motion

    def sample_for_eval_CFG_babel_inference(self, A_text, B_text, A_motion, if_categorial=False, length=6400, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=7.0, threshold=3):

        import clip
        B_token_length = length // unit_length - A_motion.shape[0]

        if tokenizer == 'clip':
            A_text = clip.tokenize(A_text, truncate=True).to(device)
            A_feat_clip_text = clip_model.encode_text(A_text).float()
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            A_feat_clip_text = torch.from_numpy(clip_model.encode(A_text)).float()
            A_feat_clip_text = A_feat_clip_text.to(device)
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)

        A_text_embeddings = self.transformer.cond_embed(A_feat_clip_text).unsqueeze(0)
        A_text_embeddings = A_text_embeddings.unsqueeze(0)
        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)
        B_text_embeddings = B_text_embeddings.unsqueeze(0)

        A_motion = A_motion.unsqueeze(0)
        A_motion_embeddings = self.transformer.wte(A_motion)
        B_motion = torch.tensor([]).to(device)

        attention_weights = []

        for k in range(B_token_length):
            if k == 0:
                x = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1)
            else:
                x = xs

            conditions = self.forward_babel_eval(x, return_attention=False)
            conditions = conditions[:, -1, :]

            empty_clip_text = ''
            if tokenizer == 'clip':
                empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                empty_feat_clip_text = clip_model.encode_text(empty_text).float()
            elif tokenizer == 't5-xxl':
                empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
                empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                empty_feat_clip_text = empty_feat_clip_text.to(device)

            empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)

            if k == 0:
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)
            else:
                B_motion_embeddings = self.transformer.wte(B_motion)
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding, B_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)

            empty_conditions = empty_conditions[:, -1, :]
            temperature = 1.0

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_token is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            B_motion = torch.cat((B_motion, scaled_logits), dim=1)

            scaled_logits_embedding = self.transformer.wte(scaled_logits)
            xs = torch.cat((x, scaled_logits_embedding), dim=1)

        return xs, B_motion

    def sample_for_eval_CFG_babel_inference_new(self, B_text, A_motion, if_categorial=False, length=78, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3):

        import clip
        B_token_length = length // unit_length

        # NOTE: the 'clip' branch below references `A_text`, which is not an argument of
        # this function; only the 't5-xxl' path is usable as written.
        if tokenizer == 'clip':
            A_text = clip.tokenize(A_text, truncate=True).to(device)
            A_feat_clip_text = clip_model.encode_text(A_text).float()
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)

        empty_clip_text = ''
        if tokenizer == 'clip':
            empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
            empty_feat_clip_text = clip_model.encode_text(empty_text).float()
        elif tokenizer == 't5-xxl':
            empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
            empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
            empty_feat_clip_text = empty_feat_clip_text.to(device)

        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)

        A_motion = A_motion.unsqueeze(0)
        A_motion_embeddings = self.transformer.wte(A_motion)
        B_motion = torch.tensor([]).to(device)

        attention_weights = []

        for k in range(B_token_length):
            if k == 0:
                x = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1)
            else:
                x = xs

            conditions = self.forward_babel_eval(x, return_attention=False)
            conditions = conditions[:, -1, :]

            empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)

            if k == 0:
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)
            else:
                B_motion_embeddings = self.transformer.wte(B_motion)
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, B_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)

            empty_conditions = empty_conditions[:, -1, :]
            temperature = 1.0

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_token is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            B_motion = torch.cat((B_motion, scaled_logits), dim=1)

            scaled_logits_embedding = self.transformer.wte(scaled_logits)
            xs = torch.cat((x, scaled_logits_embedding), dim=1)

        return xs, B_motion

    def sample_for_eval_CFG_babel_inference_new_demo(self, B_text, A_motion, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3, temperature=1.0):

        import clip
        B_token_length = length // unit_length - A_motion.shape[0]

        # NOTE: the 'clip' branch below references `A_text`, which is not an argument of
        # this function; only the 't5-xxl' path is usable as written.
        if tokenizer == 'clip':
            A_text = clip.tokenize(A_text, truncate=True).to(device)
            A_feat_clip_text = clip_model.encode_text(A_text).float()
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)

        empty_clip_text = ''
        if tokenizer == 'clip':
            empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
            empty_feat_clip_text = clip_model.encode_text(empty_text).float()
        elif tokenizer == 't5-xxl':
            empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
            empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
            empty_feat_clip_text = empty_feat_clip_text.to(device)

        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)
        B_text_embeddings = B_text_embeddings.unsqueeze(0)

        A_motion = A_motion.unsqueeze(0)
        A_motion_embeddings = self.transformer.wte(A_motion)
        B_motion = torch.tensor([]).to(device)

        # store the attention weights from all layers
        attention_weights = []

        for k in range(B_token_length):
            if k == 0:
                x = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1)
            else:
                x = xs

            conditions = self.forward_babel_eval(x, return_attention=False)
            conditions = conditions[:, -1, :]

            empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)

            if k == 0:
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)
            else:
                B_motion_embeddings = self.transformer.wte(B_motion)
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, B_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)

            empty_conditions = empty_conditions[:, -1, :]

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_token is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token)**2))
                print(distance_l2)
                if distance_l2 < threshold and k > 10:
                    break

            B_motion = torch.cat((B_motion, scaled_logits), dim=1)

            scaled_logits_embedding = self.transformer.wte(scaled_logits)
            xs = torch.cat((x, scaled_logits_embedding), dim=1)

        return xs, B_motion

    def sample_for_eval_CFG_babel_inference_two_forward(self, B_text, A_motion, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3, temperature=1.0):
        """
        Inference loop that mimics the "Two-Forward" training strategy.
        This version is rewritten for clarity and to fix dimension mismatches.
        """
        import clip
        print("\n--- [DEBUG] Entering Two-Forward Inference ---")

        B_token_length = length // unit_length - A_motion.shape[0]

        if tokenizer == 't5-xxl':
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float().to(device)
        else:
            raise NotImplementedError("Only t5-xxl is supported for this function.")
        empty_feat_clip_text = torch.from_numpy(clip_model.encode('')).float().unsqueeze(0).to(device)

        # --- Consistently create 3D embeddings [batch, seq, dim] ---
        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0).unsqueeze(0)
        empty_text_embeddings = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)
        print(f"[DEBUG] Initial B_text_embeddings shape: {B_text_embeddings.shape}")
        print(f"[DEBUG] Initial empty_text_embeddings shape: {empty_text_embeddings.shape}")

        # --- Initial motion history setup ---
        A_motion_embeddings = self.transformer.wte(A_motion.unsqueeze(0))
        current_motion_tokens = A_motion
        print(f"[DEBUG] Initial A_motion_embeddings shape: {A_motion_embeddings.shape}")
        print(f"[DEBUG] Initial current_motion_tokens shape: {current_motion_tokens.shape}")

        for k in range(B_token_length):
            print(f"\n--- [DEBUG] Generating Token {k+1}/{B_token_length} ---")

            # --- Prepare motion embeddings for this step ---
            if k == 0:
                # Use the initial motion history (the embedded A_motion)
                step_motion_embeddings = A_motion_embeddings
            else:
                # Embed the current token history
                step_motion_embeddings = self.transformer.wte(current_motion_tokens.unsqueeze(0))
            print(f"[DEBUG] step_motion_embeddings shape: {step_motion_embeddings.shape}")

            # === 1. First Forward Pass (Rough Draft) ===
            x_first_pass = torch.cat([B_text_embeddings, step_motion_embeddings], dim=1)
            empty_input_first_pass = torch.cat([empty_text_embeddings, step_motion_embeddings], dim=1)
            print(f"[DEBUG] x_first_pass shape: {x_first_pass.shape}")
            print(f"[DEBUG] empty_input_first_pass shape: {empty_input_first_pass.shape}")

            # Get conditions from the transformer
            conditions_first_pass = self.forward_babel_eval(x_first_pass, return_attention=False)[:, -1, :]
            empty_conditions_first_pass = self.forward_babel_eval(empty_input_first_pass)[:, -1, :]

            # Sample a rough prediction for the next token using diffusion
            mix_conditions_first_pass = torch.cat([conditions_first_pass, empty_conditions_first_pass], dim=0)
            pred_xstart_rough = self.diff_loss.sample(mix_conditions_first_pass, temperature=temperature, cfg=cfg)
            if cfg != 1:
                pred_xstart_rough, _ = pred_xstart_rough.chunk(2, dim=0)
            print(f"[DEBUG] pred_xstart_rough shape: {pred_xstart_rough.shape}")

            # === 2. Second Forward Pass (Refined) ===
            # Create the "dirtied" input by appending the rough prediction
            updated_motion_tokens = torch.cat([current_motion_tokens, pred_xstart_rough], dim=0)
            updated_motion_embeddings = self.transformer.wte(updated_motion_tokens.unsqueeze(0))
            print(f"[DEBUG] updated_motion_embeddings shape: {updated_motion_embeddings.shape}")

            x_second_pass = torch.cat([B_text_embeddings, updated_motion_embeddings], dim=1)
            empty_input_second_pass = torch.cat([empty_text_embeddings, updated_motion_embeddings], dim=1)

            conditions_second_pass = self.forward_babel_eval(x_second_pass, return_attention=False)[:, -1, :]
            empty_conditions_second_pass = self.forward_babel_eval(empty_input_second_pass)[:, -1, :]

            # Sample the final, refined token using diffusion
            mix_conditions_second_pass = torch.cat([conditions_second_pass, empty_conditions_second_pass], dim=0)
            final_token, _ = self.diff_loss.sample(mix_conditions_second_pass, temperature=temperature, cfg=cfg).chunk(2, dim=0)
            print(f"[DEBUG] final_token shape: {final_token.shape}")

            # === 3. Update History ===
            current_motion_tokens = torch.cat([current_motion_tokens, final_token], dim=0)
            print(f"[DEBUG] New current_motion_tokens shape: {current_motion_tokens.shape}")

        print("\n--- [DEBUG] Finished Token Generation ---")
        # Return only the newly generated tokens (B_motion)
        B_motion = current_motion_tokens[A_motion.shape[0]:, :].unsqueeze(0)
        return None, B_motion
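
    # ------------------------------------------------------------------
    # Illustrative call for the two-forward loop above (added for clarity;
    # the names are hypothetical): each step first drafts a token, re-feeds
    # it through the transformer, and keeps the refined sample in the history.
    #
    #     _, B_motion = model.sample_for_eval_CFG_babel_inference_two_forward(
    #         B_text="a person turns around",
    #         A_motion=a_motion_tokens,        # (seq_A, input_token_dim) on the model device
    #         clip_model=t5_encoder,           # must expose .encode(str) -> np.ndarray
    #         tokenizer='t5-xxl',
    #         cfg=4.5,
    #     )
    # ------------------------------------------------------------------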

    #--------------Test classification head--------------------
    def sample_for_eval_classification(self, clip_text, if_categorial=False, length=196, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4):

        import clip

        for k in range(51):
            if k == 0:
                x = []
            else:
                x = xs

            if tokenizer == 'clip':
                text = clip.tokenize(clip_text, truncate=True).to(device)
                feat_clip_text = clip_model.encode_text(text).float()
            elif tokenizer == 't5-xxl':
                feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()

            conditions = self.forward(x, feat_clip_text)
            conditions = conditions[:, -1, :]

            empty_clip_text = ''
            if tokenizer == 'clip':
                empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                empty_feat_clip_text = clip_model.encode_text(empty_text).float()
            elif tokenizer == 't5-xxl':
                empty_feat_clip_text = torch.from_numpy(clip_model.module.encode(empty_clip_text)).float()
                empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                empty_feat_clip_text = empty_feat_clip_text.to(device)

            empty_conditions = self.forward(x, empty_feat_clip_text)
            empty_conditions = empty_conditions[:, -1, :]

            temperature = 1.0
            cfg = 7.5

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            # NOTE: `self.classify_head` is not created in __init__ above; it must be
            # attached to the model before this method can be called.
            prediction_logits = self.classify_head(conditions)
            probs = torch.sigmoid(prediction_logits)
            predicted_classes = torch.argmax(probs, dim=-1)

            scaled_logits = scaled_logits.unsqueeze(0)

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

            if predicted_classes == 1:
                break

        return xs

    #--------------------Test CFG-----------------------
    def sample_for_eval_CFG_test(self, clip_text, if_categorial=False, length=196, clip_model=None, cfg=1, device=torch.device('cuda'), tokenizer='clip', unit_length=4):

        import clip
        max_token_len = length // unit_length

        for k in range(max_token_len):
            if k == 0:
                x = []
            else:
                x = xs

            if cfg != 1:
                if tokenizer == 'clip':
                    text = clip.tokenize(clip_text, truncate=True).to(device)
                    feat_clip_text = clip_model.encode_text(text).float()
                elif tokenizer == 't5-xxl':
                    feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()

                conditions = self.forward(x, feat_clip_text)
                conditions = conditions[:, -1, :]

                empty_clip_text = ''
                if tokenizer == 'clip':
                    empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                    empty_feat_clip_text = clip_model.encode_text(empty_text).float()
                elif tokenizer == 't5-xxl':
                    empty_feat_clip_text = torch.from_numpy(clip_model.module.encode(empty_clip_text)).float()
                    empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                    empty_feat_clip_text = empty_feat_clip_text.to(device)

                empty_conditions = self.forward(x, empty_feat_clip_text)
                empty_conditions = empty_conditions[:, -1, :]
                temperature = 1.0

                mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
                sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

                # chunk
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)

            else:
                if tokenizer == 'clip':
                    text = clip.tokenize(clip_text, truncate=True).to(device)
                    feat_clip_text = clip_model.encode_text(text).float()
                elif tokenizer == 't5-xxl':
                    feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()
                    feat_clip_text = feat_clip_text.to(device)

                conditions = self.forward(x, feat_clip_text)
                conditions = conditions[:, -1, :]
                temperature = 1.0
                sampled_token_latent = self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg)
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs
    #--------------------------------------------------

    def forward_discrete(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False, past_key_values=None) -> torch.Tensor:
        if len(idx) == 0:
            token_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(0)
        else:
            b, t = idx.size()
            # idx = idx.float()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            # forward the LLaMA model itself
            token_embeddings = self.transformer.wte(idx)
            text_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(1)
            token_embeddings = torch.cat([text_embeddings, token_embeddings], dim=1)

        x = token_embeddings

        # -------------------kv cache-------------------
        # presents = () if use_cache else None
        if use_cache:
            if past_key_values is None:
                past_key_values = [None] * len(self.transformer.h)

        for i, block in enumerate(self.transformer.h):
            if use_cache:
                last_past = past_key_values[i]
                x, presents = block(x, last_past, use_cache)
                past_key_values[i] = list(presents)
            else:
                x = block(x)
        x = self.transformer.ln_f(x)

        # NOTE: `self.lm_head` is not created in __init__ above; this discrete head is
        # only usable if an output head is attached via set_output_embeddings elsewhere.
        logits = self.lm_head(x)

        return logits

    def forward(self, idx: torch.Tensor, feature: torch.Tensor) -> torch.Tensor:
        if len(idx) == 0:
            token_embeddings = self.transformer.cond_embed(feature).unsqueeze(0)
        else:
            b, t, c = idx.size()
            idx = idx.float()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            # forward the LLaMA model itself
            token_embeddings = self.transformer.wte(idx)
            text_embeddings = self.transformer.cond_embed(feature).unsqueeze(1)
            token_embeddings = torch.cat([text_embeddings, token_embeddings], dim=1)

        x = token_embeddings

        for i, block in enumerate(self.transformer.h):
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.out_proj(x)
        return logits

    def forward_inference(self, idx: torch.Tensor, feature: torch.Tensor) -> torch.Tensor:
        if len(idx) == 0:
            token_embeddings = self.transformer.cond_embed(feature).unsqueeze(0)
        else:
            b, t, c = idx.size()
            idx = idx.float()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            # forward the LLaMA model itself
            token_embeddings = self.transformer.wte(idx)
            text_embeddings = self.transformer.cond_embed(feature).unsqueeze(0)
            token_embeddings = torch.cat([text_embeddings.unsqueeze(0), token_embeddings], dim=1)

        x = token_embeddings

        if len(x.shape) == 2:
            x = x.unsqueeze(0)

        for i, block in enumerate(self.transformer.h):
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.out_proj(x)
        return logits

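    # ------------------------------------------------------------------
    # Shape sketch for `forward` (illustration only, assuming the default
    # "Normal_size" config with n_embd=768 and input_token_dim=16):
    #
    #     idx:     (b, t, 16)  continuous motion tokens -> wte        -> (b, t, 768)
    #     feature: (b, 768)    text embedding           -> cond_embed -> (b, 1, 768)
    #     x = cat([text, tokens], dim=1)                              -> (b, t+1, 768)
    #     after n_layer Blocks + ln_f + out_proj:                        (b, t+1, 768)
    #
    # The last position, `logits[:, -1, :]`, is what the sampling loops above
    # feed into `self.diff_loss.sample` as the diffusion-head condition.
    # ------------------------------------------------------------------
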
    def babel_long(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False, past_key_values=None, num_subseq=None, length=None) -> torch.Tensor:

        b, t, c = idx.size()
        idx = idx.float()
        idx = self.transformer.wte(idx)
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        for i in range(b):
            length_i = length[i][:num_subseq[i]]
            clip_feature_i = clip_feature[i][:num_subseq[i]]

            pointer = 0
            for j in range(num_subseq[i]):
                if j > 0:
                    pointer += length_i[j].item()
                    pointer += 1
                pointer = int(pointer)

                clip_feature_i_j = self.transformer.cond_embed(clip_feature_i[j].unsqueeze(0)).unsqueeze(1)
                idx[i] = torch.cat([idx[i][:pointer].unsqueeze(0), clip_feature_i_j, idx[i][pointer:-1].unsqueeze(0)], dim=1)[0]

        x = idx

        if use_cache:
            if past_key_values is None:
                past_key_values = [None] * len(self.transformer.h)

        for i, block in enumerate(self.transformer.h):
            if use_cache:
                last_past = past_key_values[i]
                x, presents = block(x, last_past, use_cache)
                past_key_values[i] = list(presents)
            else:
                x = block(x)
        x = self.transformer.ln_f(x)

        logits = self.out_proj(x)
        return logits

    def forward_babel_eval(self, x, return_attention=False) -> torch.Tensor:
        layer_attentions = []
        for block in self.transformer.h:
            if return_attention:
                x, att = block(x, return_attention=True)
                layer_attentions.append(att)
            else:
                x = block(x)

        x = self.transformer.ln_f(x)
        if self.use_out_proj:
            logits = self.out_proj(x)
        else:
            logits = x

        if return_attention:
            return logits, layer_attentions
        return logits

    def forward_babel(self, idx: torch.Tensor, clip_feature: torch.Tensor, A_token_length) -> torch.Tensor:
        if len(idx) == 0:  # inference
            token_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(1)
        else:
            b, t, c = idx.size()
            idx = idx.float()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            A_feature = clip_feature[:, 0, :]
            B_feature = clip_feature[:, 1, :]

            A_text_embeddings = self.transformer.cond_embed(A_feature).unsqueeze(1)
            B_text_embeddings = self.transformer.cond_embed(B_feature).unsqueeze(1)

            # NOTE: `self.BOM_tag` is not created in __init__ above and must be attached
            # to the model elsewhere before this method can be used.
            token_embeddings = torch.zeros(b, self.config.block_size, self.config.n_embd).to(idx.device)
            for i in range(b):
                A_idx = idx[i, :A_token_length[i].item(), :]
                B_idx = idx[i, A_token_length[i].item():-2, :]
                token_embeddings[i, :, :] = torch.cat([A_text_embeddings[i], self.BOM_tag, self.transformer.wte(A_idx), B_text_embeddings[i], self.BOM_tag, self.transformer.wte(B_idx)], dim=0)  # token_embeddings.shape = (b, t+1, 1024)

        x = token_embeddings
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if self.use_out_proj:
            logits = self.out_proj(x)
        else:
            logits = x

        return logits

    def forward_babel2(self, idx: torch.Tensor, clip_feature: torch.Tensor) -> torch.Tensor:
        if len(idx) == 0:  # inference
            token_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(1)
        else:
            b, t, c = idx.size()
            idx = idx.float()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            B_feature = clip_feature
            B_text_embeddings = self.transformer.cond_embed(B_feature)

            idx_embeddings = self.transformer.wte(idx)

            token_embeddings = torch.cat([B_text_embeddings, idx_embeddings], dim=1)

        x = token_embeddings
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if self.use_out_proj:
            logits = self.out_proj(x)
        else:
            logits = x

        return logits

1085
+ def resize_token_embeddings(
1086
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, using_old_initilization: bool = False
1087
+ ) -> nn.Embedding:
1088
+ """
1089
+ Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
1090
+
1091
+ Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
1092
+
1093
+ Arguments:
1094
+ new_num_tokens (`int`, *optional*):
1095
+ The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
1096
+ vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
1097
+ returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
1098
+ pad_to_multiple_of (`int`, *optional*):
1099
+ If set will pad the embedding matrix to a multiple of the provided value.If `new_num_tokens` is set to
1100
+ `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
1101
+
1102
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
1103
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
1104
+ details about this, or help on choosing the correct value for resizing, refer to this guide:
1105
+ https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
1106
+
1107
+ Return:
1108
+ `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
1109
+ """
1110
+ model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
1111
+ if new_num_tokens is None and pad_to_multiple_of is None:
1112
+ return model_embeds
1113
+
1114
+ # Update base model and current model config
1115
+ self.config.vocab_size = model_embeds.weight.shape[0]
1116
+ self.vocab_size = model_embeds.weight.shape[0]
1117
+
1118
+ # Tie weights again if needed
1119
+ # self.tie_weights()
1120
+
1121
+ return model_embeds
1122
+
1123
+ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
1124
+ old_embeddings = self.get_input_embeddings()
1125
+ new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
1126
+ old_embeddings_requires_grad = old_embeddings.weight.requires_grad
1127
+ new_embeddings.requires_grad_(old_embeddings_requires_grad)
1128
+ self.set_input_embeddings(new_embeddings)
1129
+
1130
+ # Update new_num_tokens with the actual size of new_embeddings
1131
+ if pad_to_multiple_of is not None:
1132
+ # if is_deepspeed_zero3_enabled():
1133
+ # import deepspeed
1134
+
1135
+ # with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
1136
+ # new_num_tokens = new_embeddings.weight.shape[0]
1137
+ # else:
1138
+ new_num_tokens = new_embeddings.weight.shape[0]
1139
+
1140
+ # if word embeddings are not tied, make sure that lm head is resized as well
1141
+ # if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
1142
+ if self.get_output_embeddings() is not None and not False:
1143
+ old_lm_head = self.get_output_embeddings()
1144
+ new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
1145
+ # if hasattr(old_lm_head, "_hf_hook"):
1146
+ # hook = old_lm_head._hf_hook
1147
+ # add_hook_to_module(new_lm_head, hook)
1148
+ old_lm_head_requires_grad = old_lm_head.weight.requires_grad
1149
+ new_lm_head.requires_grad_(old_lm_head_requires_grad)
1150
+ self.set_output_embeddings(new_lm_head)
1151
+
1152
+ return self.get_input_embeddings()
1153
+
1154
+ def _get_resized_embeddings(
1155
+ self,
1156
+ old_embeddings: nn.Embedding,
1157
+ new_num_tokens: Optional[int] = None,
1158
+ pad_to_multiple_of: Optional[int] = None,
1159
+ ) -> nn.Embedding:
1160
+ """
1161
+ Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
1162
+ initialized vectors at the end. Reducing the size will remove vectors from the end
1163
+
1164
+ Args:
1165
+ old_embeddings (`torch.nn.Embedding`):
1166
+ Old embeddings to be resized.
1167
+ new_num_tokens (`int`, *optional*):
1168
+ New number of tokens in the embedding matrix.
1169
+
1170
+ Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
1171
+ vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
1172
+ `torch.nn.Embedding` module of the model without doing anything.
1173
+ pad_to_multiple_of (`int`, *optional*):
1174
+ If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
1175
+ `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
1176
+
1177
+ This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
1178
+ `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
1179
+ details about this, or help on choosing the correct value for resizing, refer to this guide:
1180
+ https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
1181
+
1182
+
1183
+ Return:
1184
+ `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
1185
+ `new_num_tokens` is `None`
1186
+ """
1187
+
1188
+ if pad_to_multiple_of is not None:
1189
+ if not isinstance(pad_to_multiple_of, int):
1190
+ raise ValueError(
1191
+ f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer"
1192
+ )
1193
+ if new_num_tokens is None:
1194
+ new_num_tokens = old_embeddings.weight.shape[0]
1195
+ new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
1196
+ else:
1197
+ print(
1198
+ "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding"
1199
+ f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available."
1200
+ " For more details about this, or help on choosing the correct value for resizing, refer to this guide:"
1201
+ " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc"
1202
+ )
1203
+
1204
+ if new_num_tokens is None:
1205
+ return old_embeddings
1206
+
1207
+ # if is_deepspeed_zero3_enabled():
1208
+ if False:
1209
+ import deepspeed
1210
+
1211
+ with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
1212
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
1213
+ else:
1214
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
1215
+
1216
+ # if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
1217
+ if old_num_tokens == new_num_tokens and not False:
1218
+ return old_embeddings
1219
+
1220
+ if not isinstance(old_embeddings, nn.Embedding):
1221
+ raise TypeError(
1222
+ f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
1223
+ " should either use a different resize function or make sure that `old_embeddings` are an instance of"
1224
+ f" {nn.Embedding}."
1225
+ )
1226
+
1227
+ # Build new embeddings
1228
+
1229
+ # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
1230
+ # because the shape of the new embedding layer is used across various modeling files
1231
+ # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
1232
+ # to errors when training.
1233
+ new_embeddings = nn.Embedding(
1234
+ new_num_tokens,
1235
+ old_embedding_dim,
1236
+ device=old_embeddings.weight.device,
1237
+ dtype=old_embeddings.weight.dtype,
1238
+ )
1239
+
1240
+ # initialize all new embeddings (in particular added tokens)
1241
+ self._init_weights(new_embeddings)
1242
+
1243
+ # Copy token embeddings from the previous weights
1244
+
1245
+ # numbers of tokens to copy
1246
+ n = min(old_num_tokens, new_num_tokens)
1247
+
1248
+ # if is_deepspeed_zero3_enabled():
1249
+ if False:
1250
+ import deepspeed
1251
+
1252
+ params = [old_embeddings.weight, new_embeddings.weight]
1253
+ with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
1254
+ new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
1255
+ else:
1256
+ new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
1257
+
1258
+ return new_embeddings
1259
+
1260
+
1261
+ def _get_resized_lm_head(
1262
+ self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
1263
+ ) -> nn.Linear:
1264
+ """
1265
+ Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
1266
+ vectors at the end. Reducing the size will remove vectors from the end
1267
+
1268
+ Args:
1269
+ old_lm_head (`torch.nn.Linear`):
1270
+ Old lm head liner layer to be resized.
1271
+ new_num_tokens (`int`, *optional*):
1272
+ New number of tokens in the linear matrix.
1273
+
1274
+ Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
1275
+ vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
1276
+ `torch.nn.Linear` module of the model without doing anything. transposed (`bool`, *optional*, defaults
1277
+ to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim,
1278
+ vocab_size` else `vocab_size, lm_head_dim`.
1279
+
1280
+ Return:
1281
+ `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
1282
+ `None`
1283
+ """
1284
+ if new_num_tokens is None:
1285
+ return old_lm_head
1286
+
1287
+ # if is_deepspeed_zero3_enabled():
1288
+ if False:
1289
+ import deepspeed
1290
+
1291
+ with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
1292
+ old_num_tokens, old_lm_head_dim = (
1293
+ old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
1294
+ )
1295
+ else:
1296
+ old_num_tokens, old_lm_head_dim = (
1297
+ old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
1298
+ )
1299
+
1300
+ # if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
1301
+ if old_num_tokens == new_num_tokens and not False:
1302
+ return old_lm_head
1303
+
1304
+ if not isinstance(old_lm_head, nn.Linear):
1305
+ raise TypeError(
1306
+ f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You"
1307
+ " should either use a different resize function or make sure that `old_lm_head` are an instance of"
1308
+ f" {nn.Linear}."
1309
+ )
1310
+
1311
+ # Build new lm head
1312
+ new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
1313
+ has_new_lm_head_bias = old_lm_head.bias is not None
1314
+
1315
+ # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
1316
+ # because the shape of the new embedding layer is used across various modeling files
1317
+ # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
1318
+ # to errors when training.
1319
+ new_lm_head = nn.Linear(
1320
+ *new_lm_head_shape,
1321
+ bias=has_new_lm_head_bias,
1322
+ device=old_lm_head.weight.device,
1323
+ dtype=old_lm_head.weight.dtype,
1324
+ )
1325
+
1326
+ # initialize new lm head (in particular added tokens)
1327
+ self._init_weights(new_lm_head)
1328
+
1329
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
1330
+
1331
+ # if is_deepspeed_zero3_enabled():
1332
+ if False:
1333
+ import deepspeed
1334
+
1335
+ params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
1336
+ with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
1337
+ self._copy_lm_head_original_to_resized(
1338
+ new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
1339
+ )
1340
+ else:
1341
+ self._copy_lm_head_original_to_resized(
1342
+ new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
1343
+ )
1344
+
1345
+ return new_lm_head
1346
+
+    def _copy_lm_head_original_to_resized(
+        self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+    ):
+        # Copy old lm head weights to new lm head
+        if not transposed:
+            new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
+        else:
+            new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
+
+        # Copy bias weights to new lm head
+        if has_new_lm_head_bias:
+            new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
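+    # For a non-transposed head of shape (vocab_size, lm_head_dim), this copies
+    # rows [0, num_tokens_to_copy) and leaves any remaining new rows at their
+    # fresh initialization; for a transposed head it copies columns instead.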
+
+    @classmethod
+    def from_name(cls, name: str) -> Self:
+        return cls(LLaMAHFConfig.from_name(name))
+
+
+class Block(nn.Module):
+    def __init__(self, config: LLaMAHFConfig) -> None:
+        super().__init__()
+        self.rms_1 = RMSNorm(config.n_embd)
+
+        # sentence level:
+        self.attn = CausalSelfAttention(config)
+        self.rms_2 = RMSNorm(config.n_embd)
+        self.mlp = MLP(config)
+
+    def forward(self, x: torch.Tensor, last_past=None, use_cache=False, return_attention=False) -> torch.Tensor:
+        if use_cache:
+            if return_attention:
+                # `forward_attn` also returns the KV cache when `use_cache` is set;
+                # unpack `present` here as well (it was previously left undefined,
+                # which raised a NameError at the return below).
+                a, present, attn = self.attn.forward_attn(self.rms_1(x), last_past, use_cache)
+            else:
+                a, present = self.attn(self.rms_1(x), last_past, use_cache)
+            x = x + a
+        else:
+            if return_attention:
+                a, attn = self.attn.forward_attn(self.rms_1(x))
+            else:
+                a = self.attn(self.rms_1(x))
+            x = x + a
+        x = x + self.mlp(self.rms_2(x))
+
+        if use_cache:
+            if return_attention:
+                return x, present, attn
+            else:
+                return x, present
+        else:
+            if return_attention:
+                return x, attn
+            else:
+                return x
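+# A Block applies the pre-norm residual pattern: x = x + attn(rmsnorm(x)), then
+# x = x + mlp(rmsnorm(x)). Sketch of a cached decoding loop (the config name and
+# shapes are illustrative):
+#
+#     block = Block(LLaMAHFConfig.from_name("7B"))
+#     x = torch.randn(1, 1, block.attn.n_embd)              # one new token
+#     x, present = block(x, use_cache=True)                 # present = (k, v)
+#     x, present = block(x, last_past=present, use_cache=True)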
+
+
+class CausalSelfAttention(nn.Module):
+    def __init__(self, config: LLaMAHFConfig) -> None:
+        super().__init__()
+        assert config.n_embd % config.n_head == 0
+
+        # key, query, value projections for all heads, but in a batch
+        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
+        # output projection
+        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
+
+        self.n_head = config.n_head
+        self.n_embd = config.n_embd
+        self.block_size = config.block_size
+        self.rope_cache = None
+
+        # learnable logit scale for QK-normalized attention, used in place of the
+        # default 1/sqrt(head_size), initialized to log2(n**2 - n) for n = block_size
+        def scaling_factor(sequence_threshold):
+            return np.log2((sequence_threshold**2) - sequence_threshold)
+
+        scale_init = scaling_factor(self.block_size)
+        self.scale = nn.Parameter(torch.tensor(scale_init))
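+        # Worked example (the block_size value is illustrative): with block_size = 2048,
+        #     log2(2048**2 - 2048) = log2(4192256) ≈ 22.0
+        # so the unit-norm q·k logits start out multiplied by roughly 22 rather
+        # than by 1/sqrt(head_size).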
+
+    def forward(self, x: torch.Tensor, last_past=None, use_cache=False) -> torch.Tensor:
+        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+
+        head_size = C // self.n_head
+        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+
+        # kv_cache: prepend the cached keys/values (stored pre-QK-norm, pre-RoPE)
+        if use_cache:
+            if last_past is not None:
+                past_key, past_value = last_past
+                k = torch.cat([past_key, k], dim=-2)
+                v = torch.cat([past_value, v], dim=-2)
+            present = (k, v)
+        else:
+            present = None
+
+        # QK-Norm: unit-normalize queries and keys along the head dimension
+        q = F.normalize(q, p=2, dim=-1)
+        k = F.normalize(k, p=2, dim=-1)
+
+        if self.rope_cache is None:
+            # cache for future forward calls
+            self.rope_cache = build_rope_cache(
+                seq_len=self.block_size,
+                n_elem=self.n_embd // self.n_head,
+                dtype=x.dtype,
+                device=x.device,
+            )
+
+        # NOTE: positions always start at 0 here, so cached decoding does not offset
+        # the new queries by the cached length
+        q = apply_rope(q, self.rope_cache)
+        k = apply_rope(k, self.rope_cache)
+
+        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T),
+        # equivalent to:
+        #     att = (q @ k.transpose(-2, -1)) * self.scale
+        #     att = att.masked_fill(causal_mask == 0, float('-inf'))
+        #     y = F.softmax(att, dim=-1) @ v
+        # efficient attention using Flash Attention CUDA kernels; note that `.item()`
+        # detaches the learned scale from the autograd graph for this call
+        y = F.scaled_dot_product_attention(
+            q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, scale=self.scale.item()
+        )
+
+        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+
+        # output projection
+        y = self.c_proj(y)
+
+        if use_cache:
+            return y, present
+        return y
+
+    def forward_attn(self, x: torch.Tensor, last_past=None, use_cache=False):
+        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+
+        head_size = C // self.n_head
+        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+
+        # kv_cache: prepend the cached keys/values (stored pre-QK-norm, pre-RoPE)
+        if use_cache:
+            if last_past is not None:
+                past_key, past_value = last_past
+                k = torch.cat([past_key, k], dim=-2)
+                v = torch.cat([past_value, v], dim=-2)
+            present = (k, v)
+        else:
+            present = None
+
+        # QK-Norm
+        q = F.normalize(q, p=2, dim=-1)
+        k = F.normalize(k, p=2, dim=-1)
+
+        if self.rope_cache is None:
+            # cache for future forward calls
+            self.rope_cache = build_rope_cache(
+                seq_len=self.block_size,
+                n_elem=self.n_embd // self.n_head,
+                dtype=x.dtype,
+                device=x.device,
+            )
+
+        q = apply_rope(q, self.rope_cache)
+        k = apply_rope(k, self.rope_cache)
+
+        # explicit attention weights for inspection (default 1/sqrt(hs) scaling);
+        # mask out future positions before the softmax so the returned map matches
+        # the causal computation below (previously no mask was applied here)
+        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+        causal_mask = torch.ones(att.size(-2), att.size(-1), dtype=torch.bool, device=x.device)
+        causal_mask = causal_mask.tril(diagonal=att.size(-1) - att.size(-2))
+        att = att.masked_fill(~causal_mask, float("-inf"))
+        att = F.softmax(att, dim=-1)  # [B, n_head, T, T]
+
+        # efficient attention using Flash Attention CUDA kernels
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
+        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+
+        # output projection
+        y = self.c_proj(y)
+
+        if use_cache:
+            return y, present, att
+        return y, att
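+    # Sketch of inspecting a head's attention map (config name and shapes are
+    # illustrative):
+    #
+    #     attn = CausalSelfAttention(LLaMAHFConfig.from_name("7B"))
+    #     y, att = attn.forward_attn(torch.randn(1, 16, attn.n_embd))
+    #     att[0, 3]  # (16, 16) causal attention weights of head 3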
+
+
+class LengthCausalSelfAttention(nn.Module):
+    def __init__(self, config: LLaMAHFConfig) -> None:
+        super().__init__()
+        assert config.n_embd % config.n_head == 0
+
+        # key, query, value projections for all heads, but in a batch
+        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
+        # output projection
+        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
+
+        self.n_head = config.n_head
+        self.n_embd = config.n_embd
+        self.block_size = config.block_size
+        self.rope_cache = None
+
+    def forward(self, x: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
+        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+
+        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+
+        head_size = C // self.n_head
+        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+
+        if self.rope_cache is None:
+            # cache for future forward calls
+            self.rope_cache = build_rope_cache(
+                seq_len=self.block_size,
+                n_elem=self.n_embd // self.n_head,
+                dtype=x.dtype,
+                device=x.device,
+            )
+
+        q = apply_rope(q, self.rope_cache)
+        k = apply_rope(k, self.rope_cache)
+
+        # causal mask over the full sequence, one copy per batch element
+        attn_mask = torch.ones(T, T, dtype=torch.bool, device=x.device)
+        attn_mask = torch.tril(attn_mask)
+        attn_mask = attn_mask.unsqueeze(0).expand(B, -1, -1)
+
+        # let the text prefix (marked by y_mask) attend bidirectionally: positions
+        # i and j are mutually visible when y_mask[i] and y_mask[j] are both set
+        text_mask = y_mask.unsqueeze(2) * y_mask.unsqueeze(1)
+        text_mask = F.pad(text_mask, (0, T - y_mask.shape[1], 0, T - y_mask.shape[1]), mode='constant', value=0)
+        attn_mask = torch.logical_or(attn_mask, text_mask)
+
+        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask.unsqueeze(1), dropout_p=0.0, is_causal=False)
+
+        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+
+        # output projection
+        y = self.c_proj(y)
+
+        return y
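+    # Worked example (sizes are illustrative): with T = 4 and a text prefix of
+    # length 2 (y_mask = [1, 1]), the combined mask lets the prefix attend
+    # bidirectionally while the suffix stays causal:
+    #
+    #     causal OR text  ->  1 1 0 0
+    #                         1 1 0 0
+    #                         1 1 1 0
+    #                         1 1 1 1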
+
+
+class MLP(nn.Module):
+    def __init__(self, config: LLaMAHFConfig) -> None:
+        super().__init__()
+        hidden_dim = 4 * config.n_embd
+        n_hidden = int(2 * hidden_dim / 3)
+        N = 256
+        # round n_hidden up to the next multiple of N
+        n_hidden = ((n_hidden - 1) // N) * N + N
+
+        self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
+        self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
+        self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # SwiGLU: silu-gated linear unit
+        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
+        x = self.c_proj(x)
+        return x
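+    # Worked sizing example (the n_embd value is illustrative): for n_embd = 4096,
+    #     hidden_dim = 4 * 4096 = 16384
+    #     n_hidden   = int(2 * 16384 / 3) = 10922 -> rounded up to 11008 (43 * 256)
+    # which matches the SwiGLU feed-forward width used by LLaMA-7B.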
+
+
+class RMSNorm(nn.Module):
+    """Root Mean Square Layer Normalization.
+
+    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
+    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
+    """
+
+    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
+        super().__init__()
+        self.scale = nn.Parameter(torch.ones(size))
+        self.eps = eps
+        self.dim = dim
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # NOTE: the original RMSNorm paper implementation is not equivalent:
+        #     norm_x = x.norm(2, dim=self.dim, keepdim=True)
+        #     rms_x = norm_x * d_x ** (-1. / 2)
+        #     x_normed = x / (rms_x + self.eps)
+        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
+        x_normed = x * torch.rsqrt(norm_x + self.eps)
+        return self.scale * x_normed
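+    # In formula form: rmsnorm(x) = g * x / sqrt(mean(x_i**2) + eps), i.e. x is
+    # rescaled by its root mean square over `dim`, with no mean subtraction or bias.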
+
+
+def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000) -> torch.Tensor:
+    """Enhanced Transformer with Rotary Position Embedding.
+
+    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
+    transformers/rope/__init__.py. MIT License:
+    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
+    """
+    # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
+
+    # Create position indexes `[0, 1, ..., seq_len - 1]`
+    seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
+
+    # Calculate the product of position index and $\theta_i$
+    idx_theta = torch.outer(seq_idx, theta)
+
+    # Compute cache. Because polar only takes float32 or float64, we need to cast
+    # when working with 16 bit floats (float16 or bfloat16)
+    dtypes_requiring_casting = [torch.float16, torch.bfloat16, torch.int8]
+    working_dtype = torch.float32 if dtype in dtypes_requiring_casting else dtype
+    complex_dtype = torch.complex32 if dtype in dtypes_requiring_casting else torch.complex64
+    cache = torch.polar(
+        torch.ones_like(idx_theta).to(working_dtype), idx_theta.to(working_dtype)
+    ).to(complex_dtype)
+    return cache
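+# The returned cache is a complex tensor of shape (seq_len, n_elem // 2): entry
+# (t, i) is exp(j * t * theta_i). A quick sketch (argument values are illustrative):
+#
+#     cache = build_rope_cache(seq_len=2048, n_elem=64, dtype=torch.float32, device=torch.device("cpu"))
+#     cache.shape  # torch.Size([2048, 32])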
+
+
+def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
+    # (B, nh, T, hs) -> (B, T, nh, hs) so positions line up with the cache
+    x = x.transpose(1, 2)
+
+    # truncate to support variable sequence lengths
+    T = x.size(1)
+    rope_cache = rope_cache[:T]
+
+    # cast because `view_as_complex` does not support 16 bit tensors
+    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+    rope_cache = rope_cache.view(1, xc.size(1), 1, xc.size(3))
+    x_out = torch.view_as_real(xc * rope_cache).flatten(3)
+    return x_out.transpose(1, 2).type_as(x)
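+# Usage sketch tying the two helpers together (shapes are illustrative):
+# consecutive channel pairs of each head are rotated by position-dependent
+# angles, so the output shape matches the input:
+#
+#     q = torch.randn(1, 16, 128, 64)  # (B, n_head, T, head_size)
+#     rope = build_rope_cache(seq_len=2048, n_elem=64, dtype=torch.float32, device=q.device)
+#     q_rot = apply_rope(q, rope)      # still (1, 16, 128, 64)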