zirobtc committed on
Commit 8e1bf4b · verified · 1 Parent(s): 8c087f9

Upload llama_model.py with huggingface_hub

Files changed (1): llama_model.py (+1675, -0)

llama_model.py ADDED
@@ -0,0 +1,1675 @@
import math
from dataclasses import dataclass
from typing import Optional

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing_extensions import Self
from transformers.modeling_utils import PreTrainedModel
from torch.distributions import Categorical


@dataclass
class LLaMAHFConfig:
    block_size: int = 78
    n_layer: int = 32
    n_head: int = 32
    n_embd: int = 4096
    T5_xxl_dim: int = 768

    @classmethod
    def from_name(cls, name: str) -> Self:
        return cls(**llama_configs[name])


llama_configs = {
    "Normal_size": dict(n_layer=12, n_head=12, n_embd=768)
}
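
# Usage sketch (illustrative, not part of the original upload): configs are
# looked up by name in `llama_configs`, so building the "Normal_size" variant
# would look like:
#
#     config = LLaMAHFConfig.from_name("Normal_size")  # n_layer=12, n_head=12, n_embd=768
#     model = LLaMAHF(config)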

class LLaMAHF(nn.Module):
    def __init__(self, config: LLaMAHFConfig, num_diffusion_head_layers=9, input_token_dim=16,
                 device=torch.device('cuda'), width=1792) -> None:
        super().__init__()
        assert config.block_size is not None
        self.config = config

        cond_dim = config.T5_xxl_dim

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Linear(input_token_dim, config.n_embd),
                cond_embed=nn.Linear(cond_dim, config.n_embd),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=RMSNorm(config.n_embd),
            )
        )

        target_channels = input_token_dim
        from models.diffloss import DiffLoss
        self.diff_loss = DiffLoss(
            target_channels=target_channels,
            z_channels=config.n_embd,
            width=width,
            depth=num_diffusion_head_layers,
            num_sampling_steps='50',
            grad_checkpointing=False,
        )
        self.diff_loss = self.diff_loss.to(device)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd)
        self.use_out_proj = True

    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """Tie or clone module weights depending on whether we are using TorchScript or not."""
        output_embeddings.weight = input_embeddings.weight

        if getattr(output_embeddings, "bias", None) is not None:
            output_embeddings.bias.data = nn.functional.pad(
                output_embeddings.bias.data,
                (
                    0,
                    output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
                ),
                "constant",
                0,
            )
        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
            output_embeddings.out_features = input_embeddings.num_embeddings

    def get_input_embeddings(self):
        return self.transformer.wte

    def set_input_embeddings(self, value):
        self.transformer.wte = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
    def forward_sample(self, idx: torch.Tensor, clip_feature: torch.Tensor, y_mask) -> torch.Tensor:
        text_length = clip_feature.shape[1]
        if len(idx) == 0:
            x = self.llama_proj(clip_feature)[:, :int(y_mask[0].sum()), :]
        else:
            _, t = idx.size()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
            # forward the LLaMA model itself
            x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
            x = torch.cat((self.llama_proj(clip_feature)[:, :int(y_mask[0].sum()), :], x), dim=1)

        for block in self.transformer.h:
            x = block(x, y_mask)
        x = self.transformer.ln_f(x)
        logits = x
        return logits
    def sample_for_eval_CFG(self, text, length=196, tokenize_model=None, device=torch.device('cuda'),
                            unit_length=4, cfg=4.0):
        max_token_len = length // unit_length
        for k in range(max_token_len):
            if k == 0:
                x = []
            else:
                x = xs

            feat_text = torch.from_numpy(tokenize_model.encode(text)).float()
            feat_text = feat_text.to(device)
            conditions = self.forward(x, feat_text)
            conditions = conditions[:, -1, :]

            empty_text = ''
            empty_feat_text = torch.from_numpy(tokenize_model.encode(empty_text)).float()
            empty_feat_text = empty_feat_text.unsqueeze(0)
            empty_feat_text = empty_feat_text.to(device)
            empty_conditions = self.forward(x, empty_feat_text)
            empty_conditions = empty_conditions[:, -1, :]
            temperature = 1.0

            # chunk
            if cfg != 1:
                mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
                sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:  # no cfg
                scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=1)

            scaled_logits = scaled_logits.unsqueeze(0)

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs
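
    # Sketch of the CFG batching pattern used throughout this file (illustrative,
    # not part of the original upload): the conditional and unconditional
    # condition vectors are stacked along the batch dimension so the diffusion
    # head applies guidance in a single call, and only the guided (conditional)
    # half is kept afterwards:
    #
    #     mix = torch.cat([cond, uncond], dim=0)                        # (2B, n_embd)
    #     latent = self.diff_loss.sample(mix, temperature=1.0, cfg=cfg)
    #     guided, _ = latent.chunk(2, dim=0)                            # (B, input_token_dim)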
    # For inference, sampling can stop early once the distance between the current
    # token and the reference end token falls below the threshold.
    def sample_for_eval_CFG_inference(self, text, length=312, tokenizer=None, device=torch.device('cuda'),
                                      unit_length=4, reference_end_latent=None, threshold=0.1,
                                      cfg=4.0, temperature=1.0):
        max_token_len = length // unit_length
        feat_text = torch.from_numpy(tokenizer.encode(text)).float()
        feat_text = feat_text.to(device)

        # CFG inference
        empty_text = ''
        empty_feat_text = torch.from_numpy(tokenizer.encode(empty_text)).float()  # torch.Size([32, 768])
        empty_feat_text = empty_feat_text.unsqueeze(0)
        empty_feat_text = empty_feat_text.to(device)

        for k in range(max_token_len):
            if k == 0:
                x = []
            else:
                x = xs

            conditions = self.forward_inference(x, feat_text)
            conditions = conditions[:, -1, :]

            empty_conditions = self.forward(x, empty_feat_text)
            empty_conditions = empty_conditions[:, -1, :]

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_latent is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_latent) ** 2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs
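
    # Early-stop sketch (illustrative): the loop above treats generation as done
    # when the newly sampled latent lands within `threshold` (L2 distance) of a
    # known end-of-motion latent, which is equivalent to:
    #
    #     if torch.norm(scaled_logits - reference_end_latent) < threshold:
    #         break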
    def sample_for_eval_CFG_inference2(self, feat_clip_text, empty_feat_clip_text, if_categorial=False,
                                       length=312, clip_model=None, device=torch.device('cuda'),
                                       tokenizer='clip', unit_length=4, reference_end_token=None,
                                       threshold=3, cfg=4.5, temperature=1.0):
        max_token_len = length // unit_length

        for k in range(max_token_len):
            if k == 0:
                x = []
            else:
                x = xs

            try:
                conditions = self.forward(x, feat_clip_text)
            except Exception:
                # retry with an explicit batch dimension
                conditions = self.forward(x, feat_clip_text.unsqueeze(0))

            conditions = conditions[:, -1, :]

            empty_conditions = self.forward(x, empty_feat_clip_text)
            empty_conditions = empty_conditions[:, -1, :]

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_token is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token) ** 2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs
    def sample_for_eval_CFG_inference_next_one(self, current_token=[], feat_clip_text=None,
                                               empty_feat_clip_text=None, if_categorial=False, length=312,
                                               clip_model=None, device=torch.device('cuda'), tokenizer='clip',
                                               unit_length=4, reference_end_token=None, threshold=3,
                                               cfg=4.5, temperature=1.0):
        max_token_len = length // unit_length

        for k in range(1):
            if current_token == []:
                x = []
            else:
                x = torch.cat(current_token, dim=1)

            try:
                conditions = self.forward(x, feat_clip_text)
            except Exception:
                # retry with an explicit batch dimension
                conditions = self.forward(x, feat_clip_text.unsqueeze(0))

            conditions = conditions[:, -1, :]

            empty_conditions = self.forward(x, empty_feat_clip_text)
            empty_conditions = empty_conditions[:, -1, :]

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs
    def sample_for_eval_CFG_babel(self, A_text, B_text, A_motion, if_categorial=False, length=6400,
                                  clip_model=None, device=torch.device('cuda'), tokenizer='clip',
                                  unit_length=4, reference_end_token=None, cfg=7.0, threshold=3):
        import clip
        B_token_length = length // unit_length - A_motion.shape[0]

        if tokenizer == 'clip':
            A_text = clip.tokenize(A_text, truncate=True).to(device)
            A_feat_clip_text = clip_model.encode_text(A_text).float()
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            A_feat_clip_text = torch.from_numpy(clip_model.encode(A_text)).float()
            A_feat_clip_text = A_feat_clip_text.to(device)
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)

        A_text_embeddings = self.transformer.cond_embed(A_feat_clip_text).unsqueeze(0)
        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)

        A_motion = A_motion.unsqueeze(0)
        A_motion_embeddings = self.transformer.wte(A_motion)
        B_motion = torch.tensor([]).to(device)

        for k in range(B_token_length):
            if k == 0:
                x = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1)
            else:
                x = xs

            conditions = self.forward_babel_eval(x)
            conditions = conditions[:, -1, :]

            empty_clip_text = ''
            if tokenizer == 'clip':
                empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                empty_feat_clip_text = clip_model.encode_text(empty_text).float()
            elif tokenizer == 't5-xxl':
                empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
                empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                empty_feat_clip_text = empty_feat_clip_text.to(device)

            empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)

            if k == 0:
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings,
                                         empty_feat_clip_text_embedding], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)
            else:
                B_motion_embeddings = self.transformer.wte(B_motion)
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings,
                                         empty_feat_clip_text_embedding, B_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)

            empty_conditions = empty_conditions[:, -1, :]
            temperature = 1.0

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            B_motion = torch.cat((B_motion, scaled_logits), dim=1)

            scaled_logits_embedding = self.transformer.wte(scaled_logits)
            xs = torch.cat((x, scaled_logits_embedding), dim=1)

        return xs, B_motion
    def sample_for_eval_CFG_babel_inference(self, A_text, B_text, A_motion, if_categorial=False, length=6400,
                                            clip_model=None, device=torch.device('cuda'), tokenizer='clip',
                                            unit_length=4, reference_end_token=None, cfg=7.0, threshold=3):
        import clip
        B_token_length = length // unit_length - A_motion.shape[0]

        if tokenizer == 'clip':
            A_text = clip.tokenize(A_text, truncate=True).to(device)
            A_feat_clip_text = clip_model.encode_text(A_text).float()
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            A_feat_clip_text = torch.from_numpy(clip_model.encode(A_text)).float()
            A_feat_clip_text = A_feat_clip_text.to(device)
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)

        A_text_embeddings = self.transformer.cond_embed(A_feat_clip_text).unsqueeze(0)
        A_text_embeddings = A_text_embeddings.unsqueeze(0)
        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)
        B_text_embeddings = B_text_embeddings.unsqueeze(0)

        A_motion = A_motion.unsqueeze(0)
        A_motion_embeddings = self.transformer.wte(A_motion)
        B_motion = torch.tensor([]).to(device)

        attention_weights = []

        for k in range(B_token_length):
            if k == 0:
                x = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1)
            else:
                x = xs

            conditions = self.forward_babel_eval(x, return_attention=False)
            conditions = conditions[:, -1, :]

            empty_clip_text = ''
            if tokenizer == 'clip':
                empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                empty_feat_clip_text = clip_model.encode_text(empty_text).float()
            elif tokenizer == 't5-xxl':
                empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
                empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                empty_feat_clip_text = empty_feat_clip_text.to(device)

            empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)

            if k == 0:
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings,
                                         empty_feat_clip_text_embedding], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)
            else:
                B_motion_embeddings = self.transformer.wte(B_motion)
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings,
                                         empty_feat_clip_text_embedding, B_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)

            empty_conditions = empty_conditions[:, -1, :]
            temperature = 1.0

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_token is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token) ** 2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            B_motion = torch.cat((B_motion, scaled_logits), dim=1)

            scaled_logits_embedding = self.transformer.wte(scaled_logits)
            xs = torch.cat((x, scaled_logits_embedding), dim=1)

        return xs, B_motion
    def sample_for_eval_CFG_babel_inference_new(self, B_text, A_motion, if_categorial=False, length=78,
                                                clip_model=None, device=torch.device('cuda'), tokenizer='clip',
                                                unit_length=4, reference_end_token=None, cfg=4.5, threshold=3):
        import clip
        B_token_length = length // unit_length

        # Only B_text is encoded here; this method takes no A_text, and the
        # preceding motion A conditions the model through its tokens instead.
        if tokenizer == 'clip':
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)

        empty_clip_text = ''
        if tokenizer == 'clip':
            empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
            empty_feat_clip_text = clip_model.encode_text(empty_text).float()
        elif tokenizer == 't5-xxl':
            empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
            empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
            empty_feat_clip_text = empty_feat_clip_text.to(device)

        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)

        A_motion = A_motion.unsqueeze(0)
        A_motion_embeddings = self.transformer.wte(A_motion)
        B_motion = torch.tensor([]).to(device)

        attention_weights = []

        for k in range(B_token_length):
            if k == 0:
                x = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1)
            else:
                x = xs

            conditions = self.forward_babel_eval(x, return_attention=False)
            conditions = conditions[:, -1, :]

            empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)

            if k == 0:
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)
            else:
                B_motion_embeddings = self.transformer.wte(B_motion)
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings,
                                         B_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)

            empty_conditions = empty_conditions[:, -1, :]
            temperature = 1.0

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_token is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token) ** 2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            B_motion = torch.cat((B_motion, scaled_logits), dim=1)

            scaled_logits_embedding = self.transformer.wte(scaled_logits)
            xs = torch.cat((x, scaled_logits_embedding), dim=1)

        return xs, B_motion
    def sample_for_eval_CFG_babel_inference_new_demo(self, B_text, A_motion, if_categorial=False, length=312,
                                                     clip_model=None, device=torch.device('cuda'),
                                                     tokenizer='clip', unit_length=4, reference_end_token=None,
                                                     cfg=4.5, threshold=3, temperature=1.0):
        import clip
        B_token_length = length // unit_length - A_motion.shape[0]

        # Only B_text is encoded here; see sample_for_eval_CFG_babel_inference_new.
        if tokenizer == 'clip':
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)

        empty_clip_text = ''
        if tokenizer == 'clip':
            empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
            empty_feat_clip_text = clip_model.encode_text(empty_text).float()
        elif tokenizer == 't5-xxl':
            empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
            empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
            empty_feat_clip_text = empty_feat_clip_text.to(device)

        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)
        B_text_embeddings = B_text_embeddings.unsqueeze(0)

        A_motion = A_motion.unsqueeze(0)
        A_motion_embeddings = self.transformer.wte(A_motion)
        B_motion = torch.tensor([]).to(device)

        # store the attention weights from all layers
        attention_weights = []

        for k in range(B_token_length):
            if k == 0:
                x = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1)
            else:
                x = xs

            conditions = self.forward_babel_eval(x, return_attention=False)
            conditions = conditions[:, -1, :]

            empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)

            if k == 0:
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)
            else:
                B_motion_embeddings = self.transformer.wte(B_motion)
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings,
                                         B_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)

            empty_conditions = empty_conditions[:, -1, :]

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_token is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token) ** 2))
                print(distance_l2)
                if distance_l2 < threshold and k > 10:
                    break

            B_motion = torch.cat((B_motion, scaled_logits), dim=1)

            scaled_logits_embedding = self.transformer.wte(scaled_logits)
            xs = torch.cat((x, scaled_logits_embedding), dim=1)

        return xs, B_motion
    def sample_for_eval_CFG_babel_inference_two_forward(self, B_text, A_motion, if_categorial=False,
                                                        length=312, clip_model=None,
                                                        device=torch.device('cuda'), tokenizer='clip',
                                                        unit_length=4, reference_end_token=None,
                                                        cfg=4.5, threshold=3, temperature=1.0):
        """
        Inference loop that mimics the "Two-Forward" training strategy.
        This version performs two full passes over the entire sequence.
        """
        B_token_length = length // unit_length - A_motion.shape[0]

        if tokenizer == 't5-xxl':
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float().to(device)
        else:
            raise NotImplementedError("Only t5-xxl is supported for this function.")
        empty_feat_clip_text = torch.from_numpy(clip_model.encode('')).float().unsqueeze(0).to(device)

        # --- Create 3D embeddings [batch, seq, dim] ---
        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0).unsqueeze(0)
        empty_text_embeddings = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)  # [1, 1, n_embd]

        A_motion_embeddings = self.transformer.wte(A_motion.unsqueeze(0))

        # === 1. First Forward Pass (Generate Rough Draft) ===
        rough_motion_tokens = A_motion
        for k in range(B_token_length):
            current_rough_embeddings = self.transformer.wte(rough_motion_tokens.unsqueeze(0))

            # Conditioned
            x_cond = torch.cat([B_text_embeddings, current_rough_embeddings], dim=1)
            conditions = self.forward_babel_eval(x_cond, return_attention=False)[:, -1, :]

            # Unconditioned
            x_uncond = torch.cat([empty_text_embeddings, current_rough_embeddings], dim=1)
            empty_conditions = self.forward_babel_eval(x_uncond, return_attention=False)[:, -1, :]

            # Sample a rough prediction for the next token
            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            pred_xstart_rough = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
            if cfg != 1:
                pred_xstart_rough, _ = pred_xstart_rough.chunk(2, dim=0)

            rough_motion_tokens = torch.cat([rough_motion_tokens, pred_xstart_rough], dim=0)

        # === 2. Second Forward Pass (Generate Refined Motion) ===
        # The full rough draft becomes the input for the second pass.
        refined_motion_tokens = A_motion
        for k in range(B_token_length):
            # The input to the transformer is the full rough sequence
            rough_embeddings = self.transformer.wte(rough_motion_tokens.unsqueeze(0))

            # Conditioned
            x_cond_refined = torch.cat([B_text_embeddings, rough_embeddings], dim=1)
            # Take the condition corresponding to the token we want to predict
            conditions_refined = self.forward_babel_eval(x_cond_refined, return_attention=False)[:, A_motion.shape[0] + k, :]

            # Unconditioned
            x_uncond_refined = torch.cat([empty_text_embeddings, rough_embeddings], dim=1)
            empty_conditions_refined = self.forward_babel_eval(x_uncond_refined, return_attention=False)[:, A_motion.shape[0] + k, :]

            # Sample the final, refined token
            mix_conditions_refined = torch.cat([conditions_refined, empty_conditions_refined], dim=0)
            final_token, _ = self.diff_loss.sample(mix_conditions_refined, temperature=temperature, cfg=cfg).chunk(2, dim=0)

            # Append the refined token to the final output history
            refined_motion_tokens = torch.cat([refined_motion_tokens, final_token], dim=0)

            # IMPORTANT: for the next step, update the rough draft with the new refined token.
            # This mimics training, where the input is a mix of ground truth and predictions;
            # here it is a mix of the initial rough draft and the new refined tokens.
            rough_motion_tokens[A_motion.shape[0] + k] = final_token.squeeze(0)

        # Return only the newly generated tokens (B_motion)
        B_motion = refined_motion_tokens[A_motion.shape[0]:, :].unsqueeze(0)
        return None, B_motion
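
    # Two-forward refinement sketch (illustrative): pass 1 drafts the whole
    # sequence autoregressively; pass 2 re-reads the full draft and re-samples
    # each position, writing each refined token back into the draft so later
    # positions condition on refined history:
    #
    #     draft[A_len + k] = refined_k   # in-place update per step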
    # -------------- Test classification head --------------------
    def sample_for_eval_classification(self, clip_text, if_categorial=False, length=196, clip_model=None,
                                       device=torch.device('cuda'), tokenizer='clip', unit_length=4):
        import clip

        for k in range(51):
            if k == 0:
                x = []
            else:
                x = xs

            if tokenizer == 'clip':
                text = clip.tokenize(clip_text, truncate=True).to(device)
                feat_clip_text = clip_model.encode_text(text).float()
            elif tokenizer == 't5-xxl':
                feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()

            conditions = self.forward(x, feat_clip_text)
            conditions = conditions[:, -1, :]

            empty_clip_text = ''
            if tokenizer == 'clip':
                empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                empty_feat_clip_text = clip_model.encode_text(empty_text).float()
            elif tokenizer == 't5-xxl':
                empty_feat_clip_text = torch.from_numpy(clip_model.module.encode(empty_clip_text)).float()
                empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                empty_feat_clip_text = empty_feat_clip_text.to(device)

            empty_conditions = self.forward(x, empty_feat_clip_text)
            empty_conditions = empty_conditions[:, -1, :]

            temperature = 1.0
            cfg = 7.5

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            prediction_logits = self.classify_head(conditions)
            probs = torch.sigmoid(prediction_logits)
            predicted_classes = torch.argmax(probs, dim=-1)

            scaled_logits = scaled_logits.unsqueeze(0)

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

            if predicted_classes == 1:
                break

        return xs
    # -------------------- Test CFG -----------------------
    def sample_for_eval_CFG_test(self, clip_text, if_categorial=False, length=196, clip_model=None,
                                 cfg=1, device=torch.device('cuda'), tokenizer='clip', unit_length=4):
        import clip
        max_token_len = length // unit_length

        for k in range(max_token_len):
            if k == 0:
                x = []
            else:
                x = xs

            if cfg != 1:
                if tokenizer == 'clip':
                    text = clip.tokenize(clip_text, truncate=True).to(device)
                    feat_clip_text = clip_model.encode_text(text).float()
                elif tokenizer == 't5-xxl':
                    feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()

                conditions = self.forward(x, feat_clip_text)
                conditions = conditions[:, -1, :]

                empty_clip_text = ''
                if tokenizer == 'clip':
                    empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                    empty_feat_clip_text = clip_model.encode_text(empty_text).float()
                elif tokenizer == 't5-xxl':
                    empty_feat_clip_text = torch.from_numpy(clip_model.module.encode(empty_clip_text)).float()
                    empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                    empty_feat_clip_text = empty_feat_clip_text.to(device)

                empty_conditions = self.forward(x, empty_feat_clip_text)
                empty_conditions = empty_conditions[:, -1, :]
                temperature = 1.0

                mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
                sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

                # chunk
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                if tokenizer == 'clip':
                    text = clip.tokenize(clip_text, truncate=True).to(device)
                    feat_clip_text = clip_model.encode_text(text).float()
                elif tokenizer == 't5-xxl':
                    feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()
                    feat_clip_text = feat_clip_text.to(device)

                conditions = self.forward(x, feat_clip_text)
                conditions = conditions[:, -1, :]
                temperature = 1.0
                sampled_token_latent = self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg)
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs
    # --------------------------------------------------
    def forward_discrete(self, idx: torch.Tensor, clip_feature: torch.Tensor,
                         use_cache=False, past_key_values=None) -> torch.Tensor:
        if len(idx) == 0:
            token_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(0)
        else:
            b, t = idx.size()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            # forward the LLaMA model itself
            token_embeddings = self.transformer.wte(idx)
            text_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(1)
            token_embeddings = torch.cat([text_embeddings, token_embeddings], dim=1)

        x = token_embeddings

        # ------------------- kv cache -------------------
        if use_cache:
            if past_key_values is None:
                past_key_values = [None] * len(self.transformer.h)

        for i, block in enumerate(self.transformer.h):
            if use_cache:
                last_past = past_key_values[i]
                x, presents = block(x, last_past, use_cache)
                past_key_values[i] = list(presents)
            else:
                x = block(x)
        x = self.transformer.ln_f(x)

        logits = self.lm_head(x)
        return logits
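
    # KV-cache usage sketch (illustrative): `past_key_values` holds one
    # (key, value) pair per block and is updated in place, so the caller should
    # own the list and pass the same one on every step:
    #
    #     past = [None] * len(model.transformer.h)   # caller-owned cache
    #     logits = model.forward_discrete(idx, clip_feature,
    #                                     use_cache=True, past_key_values=past)
    #     # entries of `past` are filled in place; reuse `past` next step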
    def forward(self, idx: torch.Tensor, feature: torch.Tensor) -> torch.Tensor:
        if len(idx) == 0:
            token_embeddings = self.transformer.cond_embed(feature).unsqueeze(0)
        else:
            b, t, c = idx.size()
            idx = idx.float()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            # forward the LLaMA model itself
            token_embeddings = self.transformer.wte(idx)
            text_embeddings = self.transformer.cond_embed(feature).unsqueeze(1)
            token_embeddings = torch.cat([text_embeddings, token_embeddings], dim=1)

        x = token_embeddings

        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.out_proj(x)
        return logits
    def forward_inference(self, idx: torch.Tensor, feature: torch.Tensor) -> torch.Tensor:
        if len(idx) == 0:
            token_embeddings = self.transformer.cond_embed(feature).unsqueeze(0)
        else:
            b, t, c = idx.size()
            idx = idx.float()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            # forward the LLaMA model itself
            token_embeddings = self.transformer.wte(idx)
            text_embeddings = self.transformer.cond_embed(feature).unsqueeze(0)
            token_embeddings = torch.cat([text_embeddings.unsqueeze(0), token_embeddings], dim=1)

        x = token_embeddings

        if len(x.shape) == 2:
            x = x.unsqueeze(0)

        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.out_proj(x)
        return logits
    def babel_long(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False,
                   past_key_values=None, num_subseq=None, length=None) -> torch.Tensor:
        b, t, c = idx.size()
        idx = idx.float()
        idx = self.transformer.wte(idx)
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        for i in range(b):
            length_i = length[i][:num_subseq[i]]
            clip_feature_i = clip_feature[i][:num_subseq[i]]

            pointer = 0
            for j in range(num_subseq[i]):
                if j > 0:
                    pointer += length_i[j].item()
                    pointer += 1
                    pointer = int(pointer)

                clip_feature_i_j = self.transformer.cond_embed(clip_feature_i[j].unsqueeze(0)).unsqueeze(1)
                idx[i] = torch.cat([idx[i][:pointer].unsqueeze(0), clip_feature_i_j,
                                    idx[i][pointer:-1].unsqueeze(0)], dim=1)[0]

        x = idx

        if use_cache:
            if past_key_values is None:
                past_key_values = [None] * len(self.transformer.h)

        for i, block in enumerate(self.transformer.h):
            if use_cache:
                last_past = past_key_values[i]
                x, presents = block(x, last_past, use_cache)
                past_key_values[i] = list(presents)
            else:
                x = block(x)
        x = self.transformer.ln_f(x)

        logits = self.out_proj(x)
        return logits
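
    # Interleaving sketch for babel_long (illustrative; apparent intent): for
    # each sub-sequence j, its text embedding is spliced into the motion stream
    # at `pointer`, which advances by the sub-sequence length plus one slot for
    # the previously inserted text token, yielding a layout like:
    #
    #     [text_0, motion_0..., text_1, motion_1..., text_2, ...]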
    def forward_babel_eval(self, x, return_attention=False) -> torch.Tensor:
        layer_attentions = []
        for block in self.transformer.h:
            if return_attention:
                x, att = block(x, return_attention=True)
                layer_attentions.append(att)
            else:
                x = block(x)

        x = self.transformer.ln_f(x)
        if self.use_out_proj:
            logits = self.out_proj(x)
        else:
            logits = x

        if return_attention:
            return logits, layer_attentions
        return logits
    def forward_babel(self, idx: torch.Tensor, clip_feature: torch.Tensor, A_token_length) -> torch.Tensor:
        if len(idx) == 0:  # inference
            token_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(1)
        else:
            b, t, c = idx.size()
            idx = idx.float()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            A_feature = clip_feature[:, 0, :]
            B_feature = clip_feature[:, 1, :]

            A_text_embeddings = self.transformer.cond_embed(A_feature).unsqueeze(1)
            B_text_embeddings = self.transformer.cond_embed(B_feature).unsqueeze(1)

            token_embeddings = torch.zeros(b, self.config.block_size, self.config.n_embd).to(idx.device)
            for i in range(b):
                A_idx = idx[i, :A_token_length[i].item(), :]
                B_idx = idx[i, A_token_length[i].item():-2, :]
                # token_embeddings.shape = (b, t+1, 1024)
                token_embeddings[i, :, :] = torch.cat([A_text_embeddings[i], self.BOM_tag,
                                                       self.transformer.wte(A_idx), B_text_embeddings[i],
                                                       self.BOM_tag, self.transformer.wte(B_idx)], dim=0)

        x = token_embeddings
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if self.use_out_proj:
            logits = self.out_proj(x)
        else:
            logits = x

        return logits
    def forward_babel2(self, idx: torch.Tensor, clip_feature: torch.Tensor) -> torch.Tensor:
        if len(idx) == 0:  # inference
            token_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(1)
        else:
            b, t, c = idx.size()
            idx = idx.float()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            B_feature = clip_feature
            B_text_embeddings = self.transformer.cond_embed(B_feature)

            idx_embeddings = self.transformer.wte(idx)

            token_embeddings = torch.cat([B_text_embeddings, idx_embeddings], dim=1)

        x = token_embeddings
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if self.use_out_proj:
            logits = self.out_proj(x)
        else:
            logits = x

        return logits
    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None,
        using_old_initilization: bool = False
    ) -> nn.Embedding:
        """
        Resizes the input token embedding matrix of the model if `new_num_tokens != config.vocab_size`.

        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:
            new_num_tokens (`int`, *optional*):
                The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`,
                just returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing
                anything.
            pad_to_multiple_of (`int`, *optional*):
                If set, will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is
                set to `None`, will just pad the embedding to a multiple of `pad_to_multiple_of`.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute
                capability `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple
                of 128. For more details about this, or help on choosing the correct value for resizing, refer to
                this guide:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

        Return:
            `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
        """
        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        if new_num_tokens is None and pad_to_multiple_of is None:
            return model_embeds

        # Update base model and current model config
        self.config.vocab_size = model_embeds.weight.shape[0]
        self.vocab_size = model_embeds.weight.shape[0]

        # Tie weights again if needed
        # self.tie_weights()

        return model_embeds

    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
        old_embeddings = self.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
        old_embeddings_requires_grad = old_embeddings.weight.requires_grad
        new_embeddings.requires_grad_(old_embeddings_requires_grad)
        self.set_input_embeddings(new_embeddings)

        # Update new_num_tokens with the actual size of new_embeddings
        # (DeepSpeed ZeRO-3 parameter gathering from the upstream HF implementation is stubbed out here)
        if pad_to_multiple_of is not None:
            new_num_tokens = new_embeddings.weight.shape[0]

        # if word embeddings are not tied, make sure that the lm head is resized as well
        if self.get_output_embeddings() is not None:
            old_lm_head = self.get_output_embeddings()
            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
            old_lm_head_requires_grad = old_lm_head.weight.requires_grad
            new_lm_head.requires_grad_(old_lm_head_requires_grad)
            self.set_output_embeddings(new_lm_head)

        return self.get_input_embeddings()
    def _get_resized_embeddings(
        self,
        old_embeddings: nn.Embedding,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        """
        Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add
        newly initialized vectors at the end. Reducing the size will remove vectors from the end.

        Args:
            old_embeddings (`torch.nn.Embedding`):
                Old embeddings to be resized.
            new_num_tokens (`int`, *optional*):
                New number of tokens in the embedding matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
                `torch.nn.Embedding` module of the model without doing anything.
            pad_to_multiple_of (`int`, *optional*):
                If set, will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is
                set to `None`, will just pad the embedding to a multiple of `pad_to_multiple_of`.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute
                capability `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple
                of 128. For more details about this, or help on choosing the correct value for resizing, refer to
                this guide:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

        Return:
            `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
            `new_num_tokens` is `None`.
        """
        if pad_to_multiple_of is not None:
            if not isinstance(pad_to_multiple_of, int):
                raise ValueError(
                    f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not"
                    " an integer. Please make sure to pass an integer."
                )
            if new_num_tokens is None:
                new_num_tokens = old_embeddings.weight.shape[0]
            new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
        else:
            print(
                "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding"
                f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available."
                " For more details about this, or help on choosing the correct value for resizing, refer to this guide:"
                " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc"
            )

        if new_num_tokens is None:
            return old_embeddings

        # DeepSpeed ZeRO-3 parameter gathering from the upstream HF implementation is stubbed out here.
        old_num_tokens, old_embedding_dim = old_embeddings.weight.size()

        if old_num_tokens == new_num_tokens:
            return old_embeddings

        if not isinstance(old_embeddings, nn.Embedding):
            raise TypeError(
                f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
                " should either use a different resize function or make sure that `old_embeddings` are an instance of"
                f" {nn.Embedding}."
            )

        # Build new embeddings

        # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
        # because the shape of the new embedding layer is used across various modeling files
        # as well as to update the config vocab size. The shape will be 0 when using DeepSpeed init,
        # leading to errors when training.
        new_embeddings = nn.Embedding(
            new_num_tokens,
            old_embedding_dim,
            device=old_embeddings.weight.device,
            dtype=old_embeddings.weight.dtype,
        )

        # initialize all new embeddings (in particular added tokens)
        self._init_weights(new_embeddings)

        # Copy token embeddings from the previous weights

        # number of tokens to copy
        n = min(old_num_tokens, new_num_tokens)
        new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]

        return new_embeddings
    def _get_resized_lm_head(
        self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
    ) -> nn.Linear:
        """
        Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end.

        Args:
            old_lm_head (`torch.nn.Linear`):
                Old lm head linear layer to be resized.
            new_num_tokens (`int`, *optional*):
                New number of tokens in the linear matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
                `torch.nn.Linear` module of the model without doing anything.
            transposed (`bool`, *optional*, defaults to `False`):
                Whether `old_lm_head` is transposed or not. If True, `old_lm_head.size()` is `lm_head_dim,
                vocab_size`, else `vocab_size, lm_head_dim`.

        Return:
            `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens`
            is `None`.
        """
        if new_num_tokens is None:
            return old_lm_head

        # DeepSpeed ZeRO-3 parameter gathering from the upstream HF implementation is stubbed out here.
        old_num_tokens, old_lm_head_dim = (
            old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
        )

        if old_num_tokens == new_num_tokens:
            return old_lm_head

        if not isinstance(old_lm_head, nn.Linear):
            raise TypeError(
                f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You"
                " should either use a different resize function or make sure that `old_lm_head` are an instance of"
                f" {nn.Linear}."
            )

        # Build new lm head
        new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
        has_new_lm_head_bias = old_lm_head.bias is not None

        # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
        # because the shape of the new embedding layer is used across various modeling files
        # as well as to update the config vocab size. The shape will be 0 when using DeepSpeed init,
        # leading to errors when training.
        new_lm_head = nn.Linear(
            *new_lm_head_shape,
            bias=has_new_lm_head_bias,
            device=old_lm_head.weight.device,
            dtype=old_lm_head.weight.dtype,
        )

        # initialize new lm head (in particular added tokens)
        self._init_weights(new_lm_head)

        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
        self._copy_lm_head_original_to_resized(
            new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
        )

        return new_lm_head

    def _copy_lm_head_original_to_resized(
        self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
    ):
        # Copy old lm head weights to new lm head
        if not transposed:
            new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
        else:
            new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]

        # Copy bias weights to new lm head
        if has_new_lm_head_bias:
            new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]

    @classmethod
    def from_name(cls, name: str) -> Self:
        return cls(LLaMAHFConfig.from_name(name))

class Block(nn.Module):
    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        self.rms_1 = RMSNorm(config.n_embd)

        # sentence level:
        self.attn = CausalSelfAttention(config)
        self.rms_2 = RMSNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, last_past=None, use_cache=False, return_attention=False) -> torch.Tensor:
        if use_cache:
            if return_attention:
                # NOTE: forward_attn does not return a kv-cache tuple, so `present`
                # stays None on this path.
                present = None
                a, attn = self.attn.forward_attn(self.rms_1(x), last_past, use_cache)
            else:
                a, present = self.attn(self.rms_1(x), last_past, use_cache)
            x = x + a
        else:
            if return_attention:
                a, attn = self.attn.forward_attn(self.rms_1(x))
            else:
                a = self.attn(self.rms_1(x))
            x = x + a
        x = x + self.mlp(self.rms_2(x))

        if use_cache:
            if return_attention:
                return x, present, attn
            else:
                return x, present
        else:
            if return_attention:
                return x, attn
            else:
                return x
1392
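+ # Usage sketch (not part of the original file): incremental decoding with the
+ # kv cache. `cfg`, `emb`, and the token tensors are illustrative placeholders.
+ #
+ #     block = Block(cfg)
+ #     x = emb(prompt_ids)                          # (B, T, n_embd) prompt prefill
+ #     x, present = block(x, use_cache=True)        # present = (k, v) for all T positions
+ #     x_new = emb(next_id)                         # (B, 1, n_embd) one new token
+ #     x_new, present = block(x_new, last_past=present, use_cache=True)
+
+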
+ class CausalSelfAttention(nn.Module):
+     def __init__(self, config: LLaMAHFConfig) -> None:
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
+         # output projection
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
+
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.block_size = config.block_size
+         self.rope_cache = None
+
+         # learnable softmax scale for QK-normalized attention,
+         # initialized to log2(L**2 - L) for context length L
+         def scaling_factor(sequence_threshold):
+             return math.log2((sequence_threshold**2) - sequence_threshold)
+
+         scale_init = scaling_factor(self.block_size)
+         self.scale = nn.Parameter(torch.tensor(scale_init))
+
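+     # Worked example (not part of the original file): with block_size = 2048 the
+     # initial scale is log2(2048**2 - 2048) = log2(4192256) ≈ 21.999, i.e. roughly
+     # 2 * log2(L) - a much larger multiplier than the usual 1/sqrt(head_size),
+     # which suits QK-normalized attention since each q·k is a cosine bounded by 1.
+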
+     def forward(self, x: torch.Tensor, last_past=None, use_cache=False) -> torch.Tensor:
+         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+
+         head_size = C // self.n_head
+         k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+
+         # kv cache: entries are stored pre-norm and pre-RoPE, so normalization and
+         # rotation are re-applied over the full concatenated sequence below
+         if use_cache:
+             if last_past is not None:
+                 past_key, past_value = last_past
+                 k = torch.cat([past_key, k], dim=-2)
+                 v = torch.cat([past_value, v], dim=-2)
+             present = (k, v)
+         else:
+             present = None
+
+         # QK-Norm
+         q = F.normalize(q, p=2, dim=-1)
+         k = F.normalize(k, p=2, dim=-1)
+
+         if self.rope_cache is None:
+             # cache for future forward calls
+             self.rope_cache = build_rope_cache(
+                 seq_len=self.block_size,
+                 n_elem=self.n_embd // self.n_head,
+                 dtype=x.dtype,
+                 device=x.device,
+             )
+
+         # keys cover absolute positions 0..S-1; the T new queries cover the last T of them
+         S = k.size(-2)
+         k = apply_rope(k, self.rope_cache[:S])
+         q = apply_rope(q, self.rope_cache[S - T:S])
+
+         # causal self-attention via efficient Flash Attention kernels. `scale` replaces
+         # the default 1/sqrt(hs) factor; `.item()` is required because
+         # scaled_dot_product_attention only accepts a float, so no gradient reaches
+         # self.scale through this call. With a populated cache the new queries may
+         # attend to every cached position, so the causal flag is only needed on the
+         # first (prefill) call.
+         y = F.scaled_dot_product_attention(
+             q, k, v, attn_mask=None, dropout_p=0.0,
+             is_causal=last_past is None, scale=self.scale.item(),
+         )
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+
+         # output projection
+         y = self.c_proj(y)
+
+         if use_cache:
+             return y, present
+         return y
+
+     def forward_attn(self, x: torch.Tensor, last_past=None, use_cache=False) -> torch.Tensor:
+         """Variant of forward() that also returns the post-softmax attention map."""
+         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+
+         head_size = C // self.n_head
+         k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+
+         # kv cache (same convention as forward(): entries stored pre-norm, pre-RoPE)
+         if use_cache:
+             if last_past is not None:
+                 past_key, past_value = last_past
+                 k = torch.cat([past_key, k], dim=-2)
+                 v = torch.cat([past_value, v], dim=-2)
+             present = (k, v)
+         else:
+             present = None
+
+         # QK-Norm
+         q = F.normalize(q, p=2, dim=-1)
+         k = F.normalize(k, p=2, dim=-1)
+
+         if self.rope_cache is None:
+             # cache for future forward calls
+             self.rope_cache = build_rope_cache(
+                 seq_len=self.block_size,
+                 n_elem=self.n_embd // self.n_head,
+                 dtype=x.dtype,
+                 device=x.device,
+             )
+
+         S = k.size(-2)
+         k = apply_rope(k, self.rope_cache[:S])
+         q = apply_rope(q, self.rope_cache[S - T:S])
+
+         # explicit attention weights for inspection; the causal mask (bottom-right
+         # aligned so it also covers cached decoding) keeps the returned map consistent
+         # with the masked attention computed below. Note this path uses the default
+         # 1/sqrt(hs) scale rather than the learnable self.scale.
+         causal_mask = torch.ones(T, S, dtype=torch.bool, device=x.device).tril(diagonal=S - T)
+         att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
+         att = att.masked_fill(~causal_mask, float("-inf"))
+         att = F.softmax(att, dim=-1)  # [B, n_head, T, S]
+
+         # efficient attention using Flash Attention CUDA kernels
+         y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=last_past is None)
+         y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side
+
+         # output projection
+         y = self.c_proj(y)
+
+         if use_cache:
+             return y, present, att
+         return y, att
+
+
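+ # Usage sketch (not part of the original file): pulling per-head attention maps
+ # out of a block for visualization. `block` and `x` are illustrative placeholders.
+ #
+ #     out, att = block(x, return_attention=True)   # att: (B, n_head, T, T)
+ #     head0 = att[0, 0]                            # lower-triangular map for head 0
+
+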
+ class LengthCausalSelfAttention(nn.Module):
+     def __init__(self, config: LLaMAHFConfig) -> None:
+         super().__init__()
+         assert config.n_embd % config.n_head == 0
+
+         # key, query, value projections for all heads, but in a batch
+         self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
+         # output projection
+         self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)
+
+         self.n_head = config.n_head
+         self.n_embd = config.n_embd
+         self.block_size = config.block_size
+         self.rope_cache = None
+
+     def forward(self, x: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
+         B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)
+
+         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+         q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
+
+         head_size = C // self.n_head
+         k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+         q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+         v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
+
+         if self.rope_cache is None:
+             # cache for future forward calls
+             self.rope_cache = build_rope_cache(
+                 seq_len=self.block_size,
+                 n_elem=self.n_embd // self.n_head,
+                 dtype=x.dtype,
+                 device=x.device,
+             )
+
+         q = apply_rope(q, self.rope_cache)
+         k = apply_rope(k, self.rope_cache)
+
+         # start from a standard causal (lower-triangular) mask over the full sequence
+         attn_mask = torch.ones(T, T, dtype=torch.bool, device=x.device)
+         attn_mask = torch.tril(attn_mask)
+         attn_mask = attn_mask.unsqueeze(0).expand(B, -1, -1)
+
+         # then allow full bidirectional attention inside the text prefix marked by
+         # y_mask: positions i and j may attend to each other whenever both are text
+         text_mask = y_mask.unsqueeze(2) * y_mask.unsqueeze(1)
+         text_mask = F.pad(text_mask, (0, T - y_mask.shape[1], 0, T - y_mask.shape[1]), mode="constant", value=0)
+         attn_mask = torch.logical_or(attn_mask, text_mask)
+
+         y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask.unsqueeze(1), dropout_p=0.0, is_causal=False)
+
+         y = y.transpose(1, 2).contiguous().view(B, T, C)
+
+         # output projection
+         y = self.c_proj(y)
+
+         return y
+
+
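+ # Worked example (not part of the original file): with T = 4 and a text prefix of
+ # length 2 (y_mask = [1, 1]), the combined mask (1 = may attend) is
+ #
+ #     1 1 0 0      rows/cols 0-1: text tokens, bidirectional
+ #     1 1 0 0
+ #     1 1 1 0      rows 2-3: ordinary causal attention
+ #     1 1 1 1
+
+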
+ class MLP(nn.Module):
+     def __init__(self, config: LLaMAHFConfig) -> None:
+         super().__init__()
+         hidden_dim = 4 * config.n_embd
+         n_hidden = int(2 * hidden_dim / 3)
+         N = 256
+         # round n_hidden up to the next multiple of N
+         n_hidden = ((n_hidden - 1) // N) * N + N
+
+         self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
+         self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
+         self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # SwiGLU: silu-gated product of the two up-projections, then down-projection
+         x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
+         x = self.c_proj(x)
+         return x
+
+
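+ # Worked example (not part of the original file): with n_embd = 4096,
+ # hidden_dim = 16384, two thirds of that is 10922, and rounding up to the next
+ # multiple of 256 gives n_hidden = 11008 - the familiar LLaMA-7B FFN width.
+
+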
+ class RMSNorm(nn.Module):
+     """Root Mean Square Layer Normalization.
+
+     Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
+     https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
+     """
+
+     def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
+         super().__init__()
+         self.scale = nn.Parameter(torch.ones(size))
+         self.eps = eps
+         self.dim = dim
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # NOTE: the original RMSNorm paper implementation is not equivalent
+         # norm_x = x.norm(2, dim=self.dim, keepdim=True)
+         # rms_x = norm_x * d_x ** (-1. / 2)
+         # x_normed = x / (rms_x + self.eps)
+         norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
+         x_normed = x * torch.rsqrt(norm_x + self.eps)
+         return self.scale * x_normed
+
+
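+ # Numeric check (not part of the original file): for x = [3.0, 4.0] the RMS is
+ # sqrt((9 + 16) / 2) = sqrt(12.5) ≈ 3.536, so with the default unit scale
+ # RMSNorm(2)(x) returns approximately [0.8485, 1.1314] (x / 3.536, up to eps).
+
+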
+ def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000) -> torch.Tensor:
+     """Enhanced Transformer with Rotary Position Embedding.
+
+     Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
+     transformers/rope/__init__.py. MIT License:
+     https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
+     """
+     # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
+     theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))
+
+     # Create position indexes `[0, 1, ..., seq_len - 1]`
+     seq_idx = torch.arange(seq_len, dtype=dtype, device=device)
+
+     # Calculate the product of position index and $\theta_i$
+     idx_theta = torch.outer(seq_idx, theta)
+
+     # Compute cache. Because polar only takes float32 or float64, we need to cast
+     # when working with 16 bit floats (float16 or bfloat16)
+     dtypes_requiring_casting = [torch.float16, torch.bfloat16, torch.int8]
+     working_dtype = torch.float32 if dtype in dtypes_requiring_casting else dtype
+     complex_dtype = torch.complex32 if dtype in dtypes_requiring_casting else torch.complex64
+     cache = torch.polar(
+         torch.ones_like(idx_theta).to(working_dtype), idx_theta.to(working_dtype)
+     ).to(complex_dtype)
+     return cache
+
+
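+ # Shape note (not part of the original file): build_rope_cache(seq_len=8, n_elem=16,
+ # dtype=torch.float32, device="cpu") returns a complex64 tensor of shape (8, 8) -
+ # one unit-magnitude rotation e^{i * m * theta_j} per position m and frequency j
+ # (n_elem // 2 frequencies for n_elem real channels).
+
+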
+ def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
+     x = x.transpose(1, 2)
+
+     # truncate to support variable sizes
+     T = x.size(1)
+     rope_cache = rope_cache[:T]
+     # cast because `view_as_complex` does not support 16 bit tensors
+     xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+     rope_cache = rope_cache.view(1, xc.size(1), 1, xc.size(3))
+     x_out = torch.view_as_real(xc * rope_cache).flatten(3)
+     return x_out.transpose(1, 2).type_as(x)
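+
+
+ if __name__ == "__main__":
+     from types import SimpleNamespace
+
+     # Minimal smoke test (not part of the original file). A SimpleNamespace stands in
+     # for LLaMAHFConfig so the check does not depend on that class's constructor.
+     cfg = SimpleNamespace(n_embd=64, n_head=4, block_size=32)
+     block = Block(cfg)
+
+     x = torch.randn(2, 8, 64)
+     out = block(x)                                   # plain causal forward
+     assert out.shape == (2, 8, 64)
+
+     out, present = block(x, use_cache=True)          # prefill: cache all 8 positions
+     step, present = block(torch.randn(2, 1, 64), last_past=present, use_cache=True)
+     assert step.shape == (2, 1, 64)                  # one cached decoding step
+     print("smoke test passed")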