zirobtc committed
Commit 6eb0e3f · verified · 1 Parent(s): 8e1bf4b

Delete models/llama_model.py

Files changed (1)
  1. models/llama_model.py +0 -1603
models/llama_model.py DELETED
@@ -1,1603 +0,0 @@
import math
from dataclasses import dataclass
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing_extensions import Self
from typing import Optional
from transformers.modeling_utils import PreTrainedModel
from torch.distributions import Categorical

@dataclass
class LLaMAHFConfig:
    block_size: int = 78
    n_layer: int = 32
    n_head: int = 32
    n_embd: int = 4096
    T5_xxl_dim: int = 768

    @classmethod
    def from_name(cls, name: str) -> Self:
        return cls(**llama_configs[name])


llama_configs = {
    "Normal_size": dict(n_layer=12, n_head=12, n_embd=768)
}

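`from_name` resolves a size preset from the `llama_configs` registry, leaving unlisted fields at their defaults. A minimal standalone sketch of that lookup (stand-ins mirroring the definitions above, not the training entry point):

```python
from dataclasses import dataclass

llama_configs = {"Normal_size": dict(n_layer=12, n_head=12, n_embd=768)}

@dataclass
class LLaMAHFConfig:
    block_size: int = 78
    n_layer: int = 32
    n_head: int = 32
    n_embd: int = 4096
    T5_xxl_dim: int = 768

    @classmethod
    def from_name(cls, name: str) -> "LLaMAHFConfig":
        # Fields absent from the preset (block_size, T5_xxl_dim) keep defaults.
        return cls(**llama_configs[name])

config = LLaMAHFConfig.from_name("Normal_size")
print(config.n_layer, config.n_embd, config.block_size)  # 12 768 78
```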

class LLaMAHF(nn.Module):
    def __init__(self, config: LLaMAHFConfig, num_diffusion_head_layers=9, input_token_dim=16, device=torch.device('cuda'), width=1792) -> None:
        super().__init__()
        assert config.block_size is not None
        self.config = config

        cond_dim = config.T5_xxl_dim

        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Linear(input_token_dim, config.n_embd),
                cond_embed=nn.Linear(cond_dim, config.n_embd),
                h=nn.ModuleList([Block(config) for _ in range(config.n_layer)]),
                ln_f=RMSNorm(config.n_embd),
            )
        )

        target_channels = input_token_dim
        from models.diffloss import DiffLoss
        self.diff_loss = DiffLoss(
            target_channels=target_channels,
            z_channels=config.n_embd,
            width=width,
            depth=num_diffusion_head_layers,
            num_sampling_steps='50',
            grad_checkpointing=False,
        )
        self.diff_loss = self.diff_loss.to(device)
        self.out_proj = nn.Linear(config.n_embd, config.n_embd)
        self.use_out_proj = True

    def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
        """Tie or clone module weights depending on whether we are using TorchScript or not."""
        output_embeddings.weight = input_embeddings.weight

        if getattr(output_embeddings, "bias", None) is not None:
            output_embeddings.bias.data = nn.functional.pad(
                output_embeddings.bias.data,
                (
                    0,
                    output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
                ),
                "constant",
                0,
            )
        if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
            output_embeddings.out_features = input_embeddings.num_embeddings

    def get_input_embeddings(self):
        return self.transformer.wte

    def set_input_embeddings(self, value):
        self.transformer.wte = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def _init_weights(self, module: nn.Module) -> None:
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02 / math.sqrt(2 * self.config.n_layer))

    def forward_sample(self, idx: torch.Tensor, clip_feature: torch.Tensor, y_mask) -> torch.Tensor:
        text_length = clip_feature.shape[1]
        if len(idx) == 0:
            x = self.llama_proj(clip_feature)[:, :int(y_mask[0].sum()), :]
        else:
            _, t = idx.size()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
            # forward the LLaMA model itself
            x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
            x = torch.cat((self.llama_proj(clip_feature)[:, :int(y_mask[0].sum()), :], x), dim=1)

        # NOTE: y_mask is passed as the block's second positional argument, so this
        # path assumes mask-aware blocks (see LengthCausalSelfAttention below).
        for block in self.transformer.h:
            x = block(x, y_mask)
        x = self.transformer.ln_f(x)
        logits = x
        return logits

    def sample_for_eval_CFG(self, text, length=196, tokenize_model=None, device=torch.device('cuda'), unit_length=4, cfg=4.0):
        max_token_len = length // unit_length
        for k in range(max_token_len):
            if k == 0:
                x = []
            else:
                x = xs

            feat_text = torch.from_numpy(tokenize_model.encode(text)).float()
            feat_text = feat_text.to(device)
            conditions = self.forward(x, feat_text)
            conditions = conditions[:, -1, :]

            empty_text = ''
            empty_feat_text = torch.from_numpy(tokenize_model.encode(empty_text)).float()
            empty_feat_text = empty_feat_text.unsqueeze(0)
            empty_feat_text = empty_feat_text.to(device)
            empty_conditions = self.forward(x, empty_feat_text)
            empty_conditions = empty_conditions[:, -1, :]
            temperature = 1.0

            # chunk
            if cfg != 1:
                mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
                sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:  # no cfg
                scaled_logits = self.diff_loss.sample(conditions, temperature=temperature, cfg=1)

            scaled_logits = scaled_logits.unsqueeze(0)

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs

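All of these samplers share one bookkeeping trick: the conditional and unconditional (empty-text) features are stacked along the batch dimension, `diff_loss.sample` is called once, and the conditional half is recovered with `chunk`. A toy sketch of that pattern, assuming (as in MAR-style diffusion heads) that the guidance mixing itself happens inside `sample`:

```python
import torch

def fake_sample(conditions: torch.Tensor) -> torch.Tensor:
    # Stand-in for DiffLoss.sample: a real head would run reverse diffusion
    # and apply the cfg mixing internally on the batched conditions.
    return conditions

conditions = torch.ones(1, 4)         # text-conditioned branch
empty_conditions = torch.zeros(1, 4)  # empty-text branch

mix = torch.cat([conditions, empty_conditions], dim=0)  # (2, 4)
sampled = fake_sample(mix)
guided, _ = sampled.chunk(2, dim=0)   # keep only the conditional half
print(guided.shape)                   # torch.Size([1, 4])
```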
    # For inference: sampling can stop early once the L2 distance between the
    # current token and the reference end latent falls below the threshold.
    def sample_for_eval_CFG_inference(self, text, length=312, tokenizer=None, device=torch.device('cuda'), unit_length=4, reference_end_latent=None, threshold=0.1, cfg=4.0, temperature=1.0):
        max_token_len = length // unit_length
        feat_text = torch.from_numpy(tokenizer.encode(text)).float()
        feat_text = feat_text.to(device)

        # CFG inference
        empty_text = ''
        empty_feat_text = torch.from_numpy(tokenizer.encode(empty_text)).float()  # torch.Size([32, 768])
        empty_feat_text = empty_feat_text.unsqueeze(0)
        empty_feat_text = empty_feat_text.to(device)

        for k in range(max_token_len):
            if k == 0:
                x = []
            else:
                x = xs

            conditions = self.forward_inference(x, feat_text)
            conditions = conditions[:, -1, :]

            empty_conditions = self.forward(x, empty_feat_text)
            empty_conditions = empty_conditions[:, -1, :]

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_latent is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_latent) ** 2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs

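A self-contained sketch of that stopping rule, with a hypothetical `reference_end_latent` standing in for the learned end token:

```python
import torch

def reached_end(latent: torch.Tensor, reference_end_latent: torch.Tensor, threshold: float = 0.1) -> bool:
    # Same L2 distance as in the loop above.
    distance_l2 = torch.sqrt(torch.sum((latent - reference_end_latent) ** 2))
    return bool(distance_l2 < threshold)

reference_end_latent = torch.zeros(1, 1, 16)
print(reached_end(torch.full((1, 1, 16), 0.01), reference_end_latent))  # True  (distance 0.04)
print(reached_end(torch.ones(1, 1, 16), reference_end_latent))          # False (distance 4.0)
```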
    def sample_for_eval_CFG_inference2(self, feat_clip_text, empty_feat_clip_text, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, threshold=3, cfg=4.5, temperature=1.0):
        import clip
        max_token_len = length // unit_length

        for k in range(max_token_len):
            if k == 0:
                x = []
            else:
                x = xs

            try:
                conditions = self.forward(x, feat_clip_text)
            except Exception:  # text feature arrived without a batch dim
                conditions = self.forward(x, feat_clip_text.unsqueeze(0))

            conditions = conditions[:, -1, :]

            empty_conditions = self.forward(x, empty_feat_clip_text)
            empty_conditions = empty_conditions[:, -1, :]

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_token is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token) ** 2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs

    def sample_for_eval_CFG_inference_next_one(self, current_token=[], feat_clip_text=None, empty_feat_clip_text=None, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, threshold=3, cfg=4.5, temperature=1.0):
        import clip
        max_token_len = length // unit_length

        for k in range(1):
            if current_token == []:
                x = []
            else:
                x = torch.cat(current_token, dim=1)

            try:
                conditions = self.forward(x, feat_clip_text)
            except Exception:  # text feature arrived without a batch dim
                conditions = self.forward(x, feat_clip_text.unsqueeze(0))

            conditions = conditions[:, -1, :]

            empty_conditions = self.forward(x, empty_feat_clip_text)
            empty_conditions = empty_conditions[:, -1, :]

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs

    def sample_for_eval_CFG_babel(self, A_text, B_text, A_motion, if_categorial=False, length=6400, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=7.0, threshold=3):
        import clip
        B_token_length = length // unit_length - A_motion.shape[0]

        if tokenizer == 'clip':
            A_text = clip.tokenize(A_text, truncate=True).to(device)
            A_feat_clip_text = clip_model.encode_text(A_text).float()
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            A_feat_clip_text = torch.from_numpy(clip_model.encode(A_text)).float()
            A_feat_clip_text = A_feat_clip_text.to(device)
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)

        A_text_embeddings = self.transformer.cond_embed(A_feat_clip_text).unsqueeze(0)
        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)

        A_motion = A_motion.unsqueeze(0)
        A_motion_embeddings = self.transformer.wte(A_motion)
        B_motion = torch.tensor([]).to(device)

        for k in range(B_token_length):
            if k == 0:
                x = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1)
            else:
                x = xs

            conditions = self.forward_babel_eval(x)
            conditions = conditions[:, -1, :]

            empty_clip_text = ''
            if tokenizer == 'clip':
                empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                empty_feat_clip_text = clip_model.encode_text(empty_text).float()
            elif tokenizer == 't5-xxl':
                empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
                empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                empty_feat_clip_text = empty_feat_clip_text.to(device)

            empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)

            if k == 0:
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)
            else:
                B_motion_embeddings = self.transformer.wte(B_motion)
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding, B_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)

            empty_conditions = empty_conditions[:, -1, :]
            temperature = 1.0

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            B_motion = torch.cat((B_motion, scaled_logits), dim=1)

            scaled_logits_embedding = self.transformer.wte(scaled_logits)
            xs = torch.cat((x, scaled_logits_embedding), dim=1)

        return xs, B_motion

    def sample_for_eval_CFG_babel_inference(self, A_text, B_text, A_motion, if_categorial=False, length=6400, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=7.0, threshold=3):
        import clip
        B_token_length = length // unit_length - A_motion.shape[0]

        if tokenizer == 'clip':
            A_text = clip.tokenize(A_text, truncate=True).to(device)
            A_feat_clip_text = clip_model.encode_text(A_text).float()
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            A_feat_clip_text = torch.from_numpy(clip_model.encode(A_text)).float()
            A_feat_clip_text = A_feat_clip_text.to(device)
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)

        A_text_embeddings = self.transformer.cond_embed(A_feat_clip_text).unsqueeze(0)
        A_text_embeddings = A_text_embeddings.unsqueeze(0)
        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)
        B_text_embeddings = B_text_embeddings.unsqueeze(0)

        A_motion = A_motion.unsqueeze(0)
        A_motion_embeddings = self.transformer.wte(A_motion)
        B_motion = torch.tensor([]).to(device)

        attention_weights = []

        for k in range(B_token_length):
            if k == 0:
                x = torch.cat([A_text_embeddings, A_motion_embeddings, B_text_embeddings], dim=1)
            else:
                x = xs

            conditions = self.forward_babel_eval(x, return_attention=False)
            conditions = conditions[:, -1, :]

            empty_clip_text = ''
            if tokenizer == 'clip':
                empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                empty_feat_clip_text = clip_model.encode_text(empty_text).float()
            elif tokenizer == 't5-xxl':
                empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
                empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                empty_feat_clip_text = empty_feat_clip_text.to(device)

            empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)

            if k == 0:
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)
            else:
                B_motion_embeddings = self.transformer.wte(B_motion)
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, empty_feat_clip_text_embedding, B_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)

            empty_conditions = empty_conditions[:, -1, :]
            temperature = 1.0

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_token is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token) ** 2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            B_motion = torch.cat((B_motion, scaled_logits), dim=1)

            scaled_logits_embedding = self.transformer.wte(scaled_logits)
            xs = torch.cat((x, scaled_logits_embedding), dim=1)

        return xs, B_motion

    def sample_for_eval_CFG_babel_inference_new(self, B_text, A_motion, if_categorial=False, length=78, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3):
        import clip
        B_token_length = length // unit_length

        if tokenizer == 'clip':
            # NOTE: this branch references A_text, which this method does not take.
            A_text = clip.tokenize(A_text, truncate=True).to(device)
            A_feat_clip_text = clip_model.encode_text(A_text).float()
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)

        empty_clip_text = ''
        if tokenizer == 'clip':
            empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
            empty_feat_clip_text = clip_model.encode_text(empty_text).float()
        elif tokenizer == 't5-xxl':
            empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
            empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
            empty_feat_clip_text = empty_feat_clip_text.to(device)

        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)

        A_motion = A_motion.unsqueeze(0)
        A_motion_embeddings = self.transformer.wte(A_motion)
        B_motion = torch.tensor([]).to(device)

        attention_weights = []

        for k in range(B_token_length):
            if k == 0:
                x = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1)
            else:
                x = xs

            conditions = self.forward_babel_eval(x, return_attention=False)
            conditions = conditions[:, -1, :]

            empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)

            if k == 0:
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)
            else:
                B_motion_embeddings = self.transformer.wte(B_motion)
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, B_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)

            empty_conditions = empty_conditions[:, -1, :]
            temperature = 1.0

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_token is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token) ** 2))
                print(distance_l2)
                if distance_l2 < threshold:
                    break

            B_motion = torch.cat((B_motion, scaled_logits), dim=1)

            scaled_logits_embedding = self.transformer.wte(scaled_logits)
            xs = torch.cat((x, scaled_logits_embedding), dim=1)

        return xs, B_motion

    def sample_for_eval_CFG_babel_inference_new_demo(self, B_text, A_motion, if_categorial=False, length=312, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4, reference_end_token=None, cfg=4.5, threshold=3, temperature=1.0):
        import clip
        B_token_length = length // unit_length - A_motion.shape[0]

        if tokenizer == 'clip':
            # NOTE: this branch references A_text, which this method does not take.
            A_text = clip.tokenize(A_text, truncate=True).to(device)
            A_feat_clip_text = clip_model.encode_text(A_text).float()
            B_text = clip.tokenize(B_text, truncate=True).to(device)
            B_feat_clip_text = clip_model.encode_text(B_text).float()
        elif tokenizer == 't5-xxl':
            B_feat_clip_text = torch.from_numpy(clip_model.encode(B_text)).float()
            B_feat_clip_text = B_feat_clip_text.to(device)

        empty_clip_text = ''
        if tokenizer == 'clip':
            empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
            empty_feat_clip_text = clip_model.encode_text(empty_text).float()
        elif tokenizer == 't5-xxl':
            empty_feat_clip_text = torch.from_numpy(clip_model.encode(empty_clip_text)).float()
            empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
            empty_feat_clip_text = empty_feat_clip_text.to(device)

        B_text_embeddings = self.transformer.cond_embed(B_feat_clip_text).unsqueeze(0)
        B_text_embeddings = B_text_embeddings.unsqueeze(0)

        A_motion = A_motion.unsqueeze(0)
        A_motion_embeddings = self.transformer.wte(A_motion)
        B_motion = torch.tensor([]).to(device)

        # Store the attention weights from every layer
        attention_weights = []

        for k in range(B_token_length):
            if k == 0:
                x = torch.cat([B_text_embeddings, A_motion_embeddings], dim=1)
            else:
                x = xs

            conditions = self.forward_babel_eval(x, return_attention=False)
            conditions = conditions[:, -1, :]

            empty_feat_clip_text_embedding = self.transformer.cond_embed(empty_feat_clip_text).unsqueeze(0)

            if k == 0:
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)
            else:
                B_motion_embeddings = self.transformer.wte(B_motion)
                empty_input = torch.cat([empty_feat_clip_text_embedding, A_motion_embeddings, B_motion_embeddings], dim=1)
                empty_conditions = self.forward_babel_eval(empty_input)

            empty_conditions = empty_conditions[:, -1, :]

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if reference_end_token is not None:
                distance_l2 = torch.sqrt(torch.sum((scaled_logits - reference_end_token) ** 2))
                print(distance_l2)
                if distance_l2 < threshold and k > 10:
                    break

            B_motion = torch.cat((B_motion, scaled_logits), dim=1)

            scaled_logits_embedding = self.transformer.wte(scaled_logits)
            xs = torch.cat((x, scaled_logits_embedding), dim=1)

        return xs, B_motion

    #--------------Test classification head--------------------
    def sample_for_eval_classification(self, clip_text, if_categorial=False, length=196, clip_model=None, device=torch.device('cuda'), tokenizer='clip', unit_length=4):
        import clip

        for k in range(51):
            if k == 0:
                x = []
            else:
                x = xs

            if tokenizer == 'clip':
                text = clip.tokenize(clip_text, truncate=True).to(device)
                feat_clip_text = clip_model.encode_text(text).float()
            elif tokenizer == 't5-xxl':
                feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()

            conditions = self.forward(x, feat_clip_text)
            conditions = conditions[:, -1, :]

            empty_clip_text = ''
            if tokenizer == 'clip':
                empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                empty_feat_clip_text = clip_model.encode_text(empty_text).float()
            elif tokenizer == 't5-xxl':
                empty_feat_clip_text = torch.from_numpy(clip_model.module.encode(empty_clip_text)).float()
                empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                empty_feat_clip_text = empty_feat_clip_text.to(device)

            empty_conditions = self.forward(x, empty_feat_clip_text)
            empty_conditions = empty_conditions[:, -1, :]

            temperature = 1.0
            cfg = 7.5

            mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
            sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

            # chunk
            if cfg != 1:
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                scaled_logits = sampled_token_latent

            # NOTE: relies on a classify_head module that is not defined in this file.
            prediction_logits = self.classify_head(conditions)
            probs = torch.sigmoid(prediction_logits)
            predicted_classes = torch.argmax(probs, dim=-1)

            scaled_logits = scaled_logits.unsqueeze(0)

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

            if predicted_classes == 1:
                break

        return xs

    #--------------------Test CFG-----------------------
    def sample_for_eval_CFG_test(self, clip_text, if_categorial=False, length=196, clip_model=None, cfg=1, device=torch.device('cuda'), tokenizer='clip', unit_length=4):
        import clip
        max_token_len = length // unit_length

        for k in range(max_token_len):
            if k == 0:
                x = []
            else:
                x = xs

            if cfg != 1:
                if tokenizer == 'clip':
                    text = clip.tokenize(clip_text, truncate=True).to(device)
                    feat_clip_text = clip_model.encode_text(text).float()
                elif tokenizer == 't5-xxl':
                    feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()

                conditions = self.forward(x, feat_clip_text)
                conditions = conditions[:, -1, :]

                empty_clip_text = ''
                if tokenizer == 'clip':
                    empty_text = clip.tokenize(empty_clip_text, truncate=True).to(device)
                    empty_feat_clip_text = clip_model.encode_text(empty_text).float()
                elif tokenizer == 't5-xxl':
                    empty_feat_clip_text = torch.from_numpy(clip_model.module.encode(empty_clip_text)).float()
                    empty_feat_clip_text = empty_feat_clip_text.unsqueeze(0)
                    empty_feat_clip_text = empty_feat_clip_text.to(device)

                empty_conditions = self.forward(x, empty_feat_clip_text)
                empty_conditions = empty_conditions[:, -1, :]
                temperature = 1.0

                mix_conditions = torch.cat([conditions, empty_conditions], dim=0)
                sampled_token_latent = self.diff_loss.sample(mix_conditions, temperature=temperature, cfg=cfg)

                # chunk
                scaled_logits, _ = sampled_token_latent.chunk(2, dim=0)
            else:
                if tokenizer == 'clip':
                    text = clip.tokenize(clip_text, truncate=True).to(device)
                    feat_clip_text = clip_model.encode_text(text).float()
                elif tokenizer == 't5-xxl':
                    feat_clip_text = torch.from_numpy(clip_model.module.encode(clip_text)).float()
                    feat_clip_text = feat_clip_text.to(device)

                conditions = self.forward(x, feat_clip_text)
                conditions = conditions[:, -1, :]
                temperature = 1.0
                sampled_token_latent = self.diff_loss.sample(conditions, temperature=temperature, cfg=cfg)
                scaled_logits = sampled_token_latent

            scaled_logits = scaled_logits.unsqueeze(0)

            if k == 0:
                xs = scaled_logits
            else:
                xs = torch.cat((xs, scaled_logits), dim=1)

        return xs
    #--------------------------------------------------

    def forward_discrete(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False, past_key_values=None) -> torch.Tensor:
        if len(idx) == 0:
            token_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(0)
        else:
            b, t = idx.size()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            # forward the LLaMA model itself
            token_embeddings = self.transformer.wte(idx)
            text_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(1)
            token_embeddings = torch.cat([text_embeddings, token_embeddings], dim=1)

        x = token_embeddings

        # ------------------- kv cache -------------------
        if use_cache:
            if past_key_values is None:
                past_key_values = [None] * len(self.transformer.h)

        for i, block in enumerate(self.transformer.h):
            if use_cache:
                last_past = past_key_values[i]
                x, presents = block(x, last_past, use_cache)
                past_key_values[i] = list(presents)
            else:
                x = block(x)
        x = self.transformer.ln_f(x)

        logits = self.lm_head(x)

        return logits

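A minimal sketch of the kv-cache bookkeeping this method (and `Block`/`CausalSelfAttention` below) implements: each layer keeps its past keys/values, and each new step concatenates the newest key/value along the time axis so only one token needs to be fed forward. The toy `attend` below is illustrative, not the model's actual attention:

```python
import torch

def attend(q, k, v):
    att = torch.softmax(q @ k.transpose(-2, -1) / k.size(-1) ** 0.5, dim=-1)
    return att @ v

past = None  # per-layer (key, value) cache, as in past_key_values[i]
for step in range(3):
    new = torch.randn(1, 1, 8)  # one new token per step
    k = v = new
    if past is not None:
        past_key, past_value = past
        k = torch.cat([past_key, k], dim=-2)   # grows to (1, step + 1, 8)
        v = torch.cat([past_value, v], dim=-2)
    past = (k, v)                 # the "present" a block would return
    y = attend(new, k, v)         # the new token attends over the whole cache
    print(step, k.shape, y.shape)
```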
    def forward(self, idx: torch.Tensor, feature: torch.Tensor) -> torch.Tensor:
        if len(idx) == 0:
            token_embeddings = self.transformer.cond_embed(feature).unsqueeze(0)
        else:
            b, t, c = idx.size()
            idx = idx.float()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            # forward the LLaMA model itself
            token_embeddings = self.transformer.wte(idx)
            text_embeddings = self.transformer.cond_embed(feature).unsqueeze(1)
            token_embeddings = torch.cat([text_embeddings, token_embeddings], dim=1)

        x = token_embeddings

        for i, block in enumerate(self.transformer.h):
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.out_proj(x)
        return logits

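`forward` prepends one projected text feature as an extra token in front of the projected motion latents, so a length-`t` input yields `t + 1` positions. A shape-only sketch with assumed sizes:

```python
import torch
import torch.nn as nn

n_embd, input_token_dim, cond_dim = 768, 16, 768  # assumed sizes
wte = nn.Linear(input_token_dim, n_embd)          # plays transformer.wte
cond_embed = nn.Linear(cond_dim, n_embd)          # plays transformer.cond_embed

idx = torch.randn(2, 10, input_token_dim)  # (b, t, c) motion latents
feature = torch.randn(2, cond_dim)         # one pooled text feature per sample

tokens = torch.cat([cond_embed(feature).unsqueeze(1), wte(idx)], dim=1)
print(tokens.shape)  # torch.Size([2, 11, 768]): 1 text token + 10 motion tokens
```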
    def forward_inference(self, idx: torch.Tensor, feature: torch.Tensor) -> torch.Tensor:
        if len(idx) == 0:
            token_embeddings = self.transformer.cond_embed(feature).unsqueeze(0)
        else:
            b, t, c = idx.size()
            idx = idx.float()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            # forward the LLaMA model itself
            token_embeddings = self.transformer.wte(idx)
            text_embeddings = self.transformer.cond_embed(feature).unsqueeze(0)
            token_embeddings = torch.cat([text_embeddings.unsqueeze(0), token_embeddings], dim=1)

        x = token_embeddings

        if len(x.shape) == 2:
            x = x.unsqueeze(0)

        for i, block in enumerate(self.transformer.h):
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.out_proj(x)
        return logits

    def babel_long(self, idx: torch.Tensor, clip_feature: torch.Tensor, use_cache=False, past_key_values=None, num_subseq=None, length=None) -> torch.Tensor:
        b, t, c = idx.size()
        idx = idx.float()
        idx = self.transformer.wte(idx)
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        for i in range(b):
            length_i = length[i][:num_subseq[i]]
            clip_feature_i = clip_feature[i][:num_subseq[i]]

            pointer = 0
            for j in range(num_subseq[i]):
                if j > 0:
                    # advance past the previous subsequence plus its inserted text token
                    pointer += length_i[j].item()
                    pointer += 1
                    pointer = int(pointer)

                clip_feature_i_j = self.transformer.cond_embed(clip_feature_i[j].unsqueeze(0)).unsqueeze(1)
                idx[i] = torch.cat([idx[i][:pointer].unsqueeze(0), clip_feature_i_j, idx[i][pointer:-1].unsqueeze(0)], dim=1)[0]

        x = idx

        if use_cache:
            if past_key_values is None:
                past_key_values = [None] * len(self.transformer.h)

        for i, block in enumerate(self.transformer.h):
            if use_cache:
                last_past = past_key_values[i]
                x, presents = block(x, last_past, use_cache)
                past_key_values[i] = list(presents)
            else:
                x = block(x)
        x = self.transformer.ln_f(x)

        logits = self.out_proj(x)
        return logits

    def forward_babel_eval(self, x, return_attention=False) -> torch.Tensor:
        layer_attentions = []
        for block in self.transformer.h:
            if return_attention:
                x, att = block(x, return_attention=True)
                layer_attentions.append(att)
            else:
                x = block(x)

        x = self.transformer.ln_f(x)
        if self.use_out_proj:
            logits = self.out_proj(x)
        else:
            logits = x

        if return_attention:
            return logits, layer_attentions
        return logits

    def forward_babel(self, idx: torch.Tensor, clip_feature: torch.Tensor, A_token_length) -> torch.Tensor:
        if len(idx) == 0:  # inference
            token_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(1)
        else:
            b, t, c = idx.size()
            idx = idx.float()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            A_feature = clip_feature[:, 0, :]
            B_feature = clip_feature[:, 1, :]

            A_text_embeddings = self.transformer.cond_embed(A_feature).unsqueeze(1)
            B_text_embeddings = self.transformer.cond_embed(B_feature).unsqueeze(1)

            token_embeddings = torch.zeros(b, self.config.block_size, self.config.n_embd).to(idx.device)
            for i in range(b):
                A_idx = idx[i, :A_token_length[i].item(), :]
                B_idx = idx[i, A_token_length[i].item():-2, :]
                token_embeddings[i, :, :] = torch.cat([A_text_embeddings[i], self.BOM_tag, self.transformer.wte(A_idx), B_text_embeddings[i], self.BOM_tag, self.transformer.wte(B_idx)], dim=0)  # token_embeddings.shape = (b, t+1, 1024)

        x = token_embeddings
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if self.use_out_proj:
            logits = self.out_proj(x)
        else:
            logits = x

        return logits

    def forward_babel2(self, idx: torch.Tensor, clip_feature: torch.Tensor) -> torch.Tensor:
        if len(idx) == 0:  # inference
            token_embeddings = self.transformer.cond_embed(clip_feature).unsqueeze(1)
        else:
            b, t, c = idx.size()
            idx = idx.float()
            assert (
                t <= self.config.block_size
            ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            B_feature = clip_feature
            B_text_embeddings = self.transformer.cond_embed(B_feature)

            idx_embeddings = self.transformer.wte(idx)

            token_embeddings = torch.cat([B_text_embeddings, idx_embeddings], dim=1)

        x = token_embeddings
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        if self.use_out_proj:
            logits = self.out_proj(x)
        else:
            logits = x

        return logits

    def resize_token_embeddings(
        self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None, using_old_initilization: bool = False
    ) -> nn.Embedding:
        """
        Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.

        Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.

        Arguments:
            new_num_tokens (`int`, *optional*):
                The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
                vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
                returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
            pad_to_multiple_of (`int`, *optional*):
                If set, will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
                `None`, will just pad the embedding to a multiple of `pad_to_multiple_of`.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
                details about this, or help on choosing the correct value for resizing, refer to this guide:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

        Return:
            `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
        """
        model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
        if new_num_tokens is None and pad_to_multiple_of is None:
            return model_embeds

        # Update base model and current model config
        self.config.vocab_size = model_embeds.weight.shape[0]
        self.vocab_size = model_embeds.weight.shape[0]

        # Tie weights again if needed
        # self.tie_weights()

        return model_embeds

    def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
        old_embeddings = self.get_input_embeddings()
        new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
        old_embeddings_requires_grad = old_embeddings.weight.requires_grad
        new_embeddings.requires_grad_(old_embeddings_requires_grad)
        self.set_input_embeddings(new_embeddings)

        # Update new_num_tokens with the actual size of new_embeddings
        if pad_to_multiple_of is not None:
            # if is_deepspeed_zero3_enabled():
            #     import deepspeed
            #     with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
            #         new_num_tokens = new_embeddings.weight.shape[0]
            # else:
            new_num_tokens = new_embeddings.weight.shape[0]

        # if word embeddings are not tied, make sure that lm head is resized as well
        # if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
        if self.get_output_embeddings() is not None and not False:
            old_lm_head = self.get_output_embeddings()
            new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens)
            # if hasattr(old_lm_head, "_hf_hook"):
            #     hook = old_lm_head._hf_hook
            #     add_hook_to_module(new_lm_head, hook)
            old_lm_head_requires_grad = old_lm_head.weight.requires_grad
            new_lm_head.requires_grad_(old_lm_head_requires_grad)
            self.set_output_embeddings(new_lm_head)

        return self.get_input_embeddings()

    def _get_resized_embeddings(
        self,
        old_embeddings: nn.Embedding,
        new_num_tokens: Optional[int] = None,
        pad_to_multiple_of: Optional[int] = None,
    ) -> nn.Embedding:
        """
        Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
        initialized vectors at the end. Reducing the size will remove vectors from the end.

        Args:
            old_embeddings (`torch.nn.Embedding`):
                Old embeddings to be resized.
            new_num_tokens (`int`, *optional*):
                New number of tokens in the embedding matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
                `torch.nn.Embedding` module of the model without doing anything.
            pad_to_multiple_of (`int`, *optional*):
                If set, will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
                `None`, will just pad the embedding to a multiple of `pad_to_multiple_of`.

                This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
                `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
                details about this, or help on choosing the correct value for resizing, refer to this guide:
                https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc

        Return:
            `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
            `new_num_tokens` is `None`
        """

        if pad_to_multiple_of is not None:
            if not isinstance(pad_to_multiple_of, int):
                raise ValueError(
                    f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not an integer. Please make sure to pass an integer"
                )
            if new_num_tokens is None:
                new_num_tokens = old_embeddings.weight.shape[0]
            new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
        else:
            print(
                "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding"
                f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available."
                " For more details about this, or help on choosing the correct value for resizing, refer to this guide:"
                " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc"
            )

        if new_num_tokens is None:
            return old_embeddings

        # if is_deepspeed_zero3_enabled():
        if False:
            import deepspeed

            with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
                old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
        else:
            old_num_tokens, old_embedding_dim = old_embeddings.weight.size()

        # if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
        if old_num_tokens == new_num_tokens and not False:
            return old_embeddings

        if not isinstance(old_embeddings, nn.Embedding):
            raise TypeError(
                f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
                " should either use a different resize function or make sure that `old_embeddings` are an instance of"
                f" {nn.Embedding}."
            )

        # Build new embeddings

        # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
        # because the shape of the new embedding layer is used across various modeling files
        # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
        # to errors when training.
        new_embeddings = nn.Embedding(
            new_num_tokens,
            old_embedding_dim,
            device=old_embeddings.weight.device,
            dtype=old_embeddings.weight.dtype,
        )

        # initialize all new embeddings (in particular added tokens)
        self._init_weights(new_embeddings)

        # Copy token embeddings from the previous weights

        # numbers of tokens to copy
        n = min(old_num_tokens, new_num_tokens)

        # if is_deepspeed_zero3_enabled():
        if False:
            import deepspeed

            params = [old_embeddings.weight, new_embeddings.weight]
            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
        else:
            new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]

        return new_embeddings

    def _get_resized_lm_head(
        self, old_lm_head: nn.Linear, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
    ) -> nn.Linear:
        """
        Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
        vectors at the end. Reducing the size will remove vectors from the end.

        Args:
            old_lm_head (`torch.nn.Linear`):
                Old lm head linear layer to be resized.
            new_num_tokens (`int`, *optional*):
                New number of tokens in the linear matrix.

                Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
                vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
                `torch.nn.Linear` module of the model without doing anything.
            transposed (`bool`, *optional*, defaults to `False`):
                Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim,
                vocab_size` else `vocab_size, lm_head_dim`.

        Return:
            `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
            `None`
        """
        if new_num_tokens is None:
            return old_lm_head

        # if is_deepspeed_zero3_enabled():
        if False:
            import deepspeed

            with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
                old_num_tokens, old_lm_head_dim = (
                    old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
                )
        else:
            old_num_tokens, old_lm_head_dim = (
                old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
            )

        # if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
        if old_num_tokens == new_num_tokens and not False:
            return old_lm_head

        if not isinstance(old_lm_head, nn.Linear):
            raise TypeError(
                f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You"
                " should either use a different resize function or make sure that `old_lm_head` is an instance of"
                f" {nn.Linear}."
            )

        # Build new lm head
        new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
        has_new_lm_head_bias = old_lm_head.bias is not None

        # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
        # because the shape of the new embedding layer is used across various modeling files
        # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
        # to errors when training.
        new_lm_head = nn.Linear(
            *new_lm_head_shape,
            bias=has_new_lm_head_bias,
            device=old_lm_head.weight.device,
            dtype=old_lm_head.weight.dtype,
        )

        # initialize new lm head (in particular added tokens)
        self._init_weights(new_lm_head)

        num_tokens_to_copy = min(old_num_tokens, new_num_tokens)

        # if is_deepspeed_zero3_enabled():
        if False:
            import deepspeed

            params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
            with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
                self._copy_lm_head_original_to_resized(
                    new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
                )
        else:
            self._copy_lm_head_original_to_resized(
                new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
            )

        return new_lm_head

    def _copy_lm_head_original_to_resized(
        self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
    ):
        # Copy old lm head weights to new lm head
        if not transposed:
            new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
        else:
            new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]

        # Copy bias weights to new lm head
        if has_new_lm_head_bias:
            new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]

    @classmethod
    def from_name(cls, name: str) -> Self:
        return cls(LLaMAHFConfig.from_name(name))


class Block(nn.Module):
    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        self.rms_1 = RMSNorm(config.n_embd)

        # sentence level:
        self.attn = CausalSelfAttention(config)
        self.rms_2 = RMSNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x: torch.Tensor, last_past=None, use_cache=False, return_attention=False) -> torch.Tensor:
        if use_cache:
            if return_attention:
                # NOTE: forward_attn does not return a kv-cache "present", so this
                # path assumes use_cache and return_attention are not combined.
                a, attn = self.attn.forward_attn(self.rms_1(x), last_past, use_cache)
            else:
                a, present = self.attn(self.rms_1(x), last_past, use_cache)
            x = x + a
        else:
            if return_attention:
                a, attn = self.attn.forward_attn(self.rms_1(x))
            else:
                a = self.attn(self.rms_1(x))
            x = x + a
        x = x + self.mlp(self.rms_2(x))

        if use_cache:
            if return_attention:
                return x, present, attn
            else:
                return x, present
        else:
            if return_attention:
                return x, attn
            else:
                return x


class CausalSelfAttention(nn.Module):
    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.block_size = config.block_size
        self.rope_cache = None

        def scaling_factor(sequence_threshold):
            return np.log2((sequence_threshold ** 2) - sequence_threshold)
        scale_init = scaling_factor(self.block_size)
        self.scale = nn.Parameter(torch.tensor(scale_init))

    def forward(self, x: torch.Tensor, last_past=None, use_cache=False) -> torch.Tensor:
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        head_size = C // self.n_head
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)

        # kv cache
        if use_cache:
            if last_past is not None:
                past_key, past_value = last_past
                k = torch.cat([past_key, k], dim=-2)
                v = torch.cat([past_value, v], dim=-2)

        if use_cache:
            present = (k, v)
        else:
            present = None

        # QK-Norm
        q = F.normalize(q, p=2, dim=-1)
        k = F.normalize(k, p=2, dim=-1)

        if self.rope_cache is None:
            # cache for future forward calls
            self.rope_cache = build_rope_cache(
                seq_len=self.block_size,
                n_elem=self.n_embd // self.n_head,
                dtype=x.dtype,
                device=x.device,
            )

        q = apply_rope(q, self.rope_cache)
        k = apply_rope(k, self.rope_cache)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        # (the equivalent manual implementation is kept in forward_attn below)

        # efficient attention using Flash Attention CUDA kernels
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True, scale=self.scale.item())

        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)

        if use_cache:
            return y, present
        return y

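The attention above L2-normalizes queries and keys ("QK-Norm") and swaps the usual 1/sqrt(d) logit scale for a single learned scalar initialized to log2(T^2 - T) with T the block size. A standalone sketch of that combination; note the `scale=` keyword of `scaled_dot_product_attention` assumes PyTorch >= 2.1:

```python
import numpy as np
import torch
import torch.nn.functional as F

block_size = 78
scale_init = float(np.log2(block_size ** 2 - block_size))  # ~12.55 for T=78

q = F.normalize(torch.randn(1, 4, 10, 16), p=2, dim=-1)  # unit-norm queries
k = F.normalize(torch.randn(1, 4, 10, 16), p=2, dim=-1)  # unit-norm keys
v = torch.randn(1, 4, 10, 16)

y = F.scaled_dot_product_attention(q, k, v, is_causal=True, scale=scale_init)
print(y.shape, round(scale_init, 2))  # torch.Size([1, 4, 10, 16]) 12.55
```

With unit-norm q and k, every logit lies in [-scale, scale], so the learned scalar controls attention sharpness directly rather than depending on head dimension.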
    def forward_attn(self, x: torch.Tensor, last_past=None, use_cache=False) -> torch.Tensor:
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        head_size = C // self.n_head
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)

        # kv cache
        if use_cache:
            if last_past is not None:
                past_key, past_value = last_past
                k = torch.cat([past_key, k], dim=-2)
                v = torch.cat([past_value, v], dim=-2)

        if use_cache:
            present = (k, v)
        else:
            present = None

        # QK-Norm
        q = F.normalize(q, p=2, dim=-1)
        k = F.normalize(k, p=2, dim=-1)

        if self.rope_cache is None:
            # cache for future forward calls
            self.rope_cache = build_rope_cache(
                seq_len=self.block_size,
                n_elem=self.n_embd // self.n_head,
                dtype=x.dtype,
                device=x.device,
            )

        q = apply_rope(q, self.rope_cache)
        k = apply_rope(k, self.rope_cache)

        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        att = F.softmax(att, dim=-1)  # [B, n_head, T, T]

        # efficient attention using Flash Attention CUDA kernels
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
        y = y.transpose(1, 2).contiguous().view(B, T, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.c_proj(y)

        return y, att

class LengthCausalSelfAttention(nn.Module):
    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        assert config.n_embd % config.n_head == 0

        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=False)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=False)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.block_size = config.block_size
        self.rope_cache = None

    def forward(self, x: torch.Tensor, y_mask: torch.Tensor) -> torch.Tensor:
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)

        head_size = C // self.n_head
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)  # (B, nh, T, hs)

        if self.rope_cache is None:
            # cache for future forward calls
            self.rope_cache = build_rope_cache(
                seq_len=self.block_size,
                n_elem=self.n_embd // self.n_head,
                dtype=x.dtype,
                device=x.device,
            )

        # q: (1, 16, 40, 64) or (128, 16, 106, 64)
        q = apply_rope(q, self.rope_cache)
        k = apply_rope(k, self.rope_cache)

        attn_mask = torch.ones(T, T, dtype=torch.bool, device=x.device)
        attn_mask = torch.tril(attn_mask)
        attn_mask = attn_mask.unsqueeze(0).expand(B, -1, -1)

        text_mask = y_mask.unsqueeze(2) * y_mask.unsqueeze(1)
        text_mask = F.pad(text_mask, (0, T - y_mask.shape[1], 0, T - y_mask.shape[1]), mode='constant', value=0)
        attn_mask = torch.logical_or(attn_mask, text_mask)

        y = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask.unsqueeze(1), dropout_p=0.0, is_causal=False)

        y = y.transpose(1, 2).contiguous().view(B, T, C)

        y = self.c_proj(y)

        return y

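The mask built in `LengthCausalSelfAttention.forward` is a lower-triangular causal mask OR-ed with a bidirectional block over the text-prefix positions described by `y_mask`: text tokens see each other in both directions while motion tokens stay causal. A small sketch of the mask construction alone:

```python
import torch
import torch.nn.functional as F

B, T, text_len = 1, 6, 3
y_mask = torch.ones(B, text_len)  # all text positions valid

attn_mask = torch.tril(torch.ones(T, T, dtype=torch.bool)).unsqueeze(0).expand(B, -1, -1)
text_mask = y_mask.unsqueeze(2) * y_mask.unsqueeze(1)           # (B, text_len, text_len)
text_mask = F.pad(text_mask, (0, T - text_len, 0, T - text_len), mode='constant', value=0)
combined = torch.logical_or(attn_mask, text_mask.bool())
print(combined[0].int())  # top-left 3x3 block all ones (bidirectional text); rest causal
```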

class MLP(nn.Module):
    def __init__(self, config: LLaMAHFConfig) -> None:
        super().__init__()
        hidden_dim = 4 * config.n_embd
        n_hidden = int(2 * hidden_dim / 3)
        N = 256
        # ensure n_hidden is a multiple of N
        n_hidden = ((n_hidden - 1) // N) * N + N

        self.c_fc1 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_fc2 = nn.Linear(config.n_embd, n_hidden, bias=False)
        self.c_proj = nn.Linear(n_hidden, config.n_embd, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = F.silu(self.c_fc1(x)) * self.c_fc2(x)
        x = self.c_proj(x)
        return x

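The MLP is a SwiGLU block: two parallel up-projections, `silu` gating, then a down-projection, with the hidden width set to 2/3 of 4*n_embd rounded up to a multiple of 256. A quick check of that arithmetic plus a forward pass, for an assumed n_embd of 768:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_embd = 768
hidden_dim = 4 * n_embd             # 3072
n_hidden = int(2 * hidden_dim / 3)  # 2048
N = 256
n_hidden = ((n_hidden - 1) // N) * N + N
print(n_hidden)  # 2048 (already a multiple of 256, so unchanged)

c_fc1 = nn.Linear(n_embd, n_hidden, bias=False)
c_fc2 = nn.Linear(n_embd, n_hidden, bias=False)
c_proj = nn.Linear(n_hidden, n_embd, bias=False)

x = torch.randn(2, 5, n_embd)
out = c_proj(F.silu(c_fc1(x)) * c_fc2(x))  # SwiGLU
print(out.shape)  # torch.Size([2, 5, 768])
```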

class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Derived from https://github.com/bzhangGo/rmsnorm/blob/master/rmsnorm_torch.py. BSD 3-Clause License:
    https://github.com/bzhangGo/rmsnorm/blob/master/LICENSE.
    """

    def __init__(self, size: int, dim: int = -1, eps: float = 1e-5) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(size))
        self.eps = eps
        self.dim = dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE: the original RMSNorm paper implementation is not equivalent
        # norm_x = x.norm(2, dim=self.dim, keepdim=True)
        # rms_x = norm_x * d_x ** (-1. / 2)
        # x_normed = x / (rms_x + self.eps)
        norm_x = torch.mean(x * x, dim=self.dim, keepdim=True)
        x_normed = x * torch.rsqrt(norm_x + self.eps)
        return self.scale * x_normed

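Unlike LayerNorm, RMSNorm only rescales by the root mean square of the features; there is no mean subtraction or bias. A quick numerical check that the normalized activations have unit RMS (before the learned `scale` is applied):

```python
import torch

x = torch.randn(2, 5, 8)
eps = 1e-5
norm_x = torch.mean(x * x, dim=-1, keepdim=True)
x_normed = x * torch.rsqrt(norm_x + eps)
print(x_normed.pow(2).mean(dim=-1).sqrt())  # all entries ~1.0
```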

def build_rope_cache(seq_len: int, n_elem: int, dtype: torch.dtype, device: torch.device, base: int = 10000) -> torch.Tensor:
    """Enhanced Transformer with Rotary Position Embedding.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
    """
    # $\Theta = {\theta_i = 10000^{\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, dtype=dtype, device=device) / n_elem))

    # Create position indexes `[0, 1, ..., seq_len - 1]`
    seq_idx = torch.arange(seq_len, dtype=dtype, device=device)

    # Calculate the product of position index and $\theta_i$
    idx_theta = torch.outer(seq_idx, theta)

    # Compute cache. Because polar only takes float32 or float64, we need to cast
    # when working with 16 bit floats (float16 or bfloat16)
    dtypes_requiring_casting = [torch.float16, torch.bfloat16, torch.int8]
    working_dtype = (
        torch.float32 if dtype in dtypes_requiring_casting else dtype
    )
    complex_dtype = (
        torch.complex32 if dtype in dtypes_requiring_casting else torch.complex64
    )
    cache = torch.polar(
        torch.ones_like(idx_theta).to(working_dtype), idx_theta.to(working_dtype)
    ).to(complex_dtype)
    return cache

def apply_rope(x: torch.Tensor, rope_cache: torch.Tensor) -> torch.Tensor:
    x = x.transpose(1, 2)

    # truncate to support variable sizes
    T = x.size(1)
    rope_cache = rope_cache[:T]
    # cast because `view_as_complex` does not support 16 bit tensors
    xc = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rope_cache = rope_cache.view(1, xc.size(1), 1, xc.size(3))
    x_out = torch.view_as_real(xc * rope_cache).flatten(3)
    return x_out.transpose(1, 2).type_as(x)
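Usage sketch for the two helpers above, assuming they are importable from this module: the complex-valued cache is built once for the full block size, and `apply_rope` truncates it to the current sequence length before rotating `(B, n_head, T, head_size)` queries or keys:

```python
import torch

# build_rope_cache / apply_rope as defined above.
seq_len, n_head, head_size = 78, 12, 64
cache = build_rope_cache(seq_len=seq_len, n_elem=head_size,
                         dtype=torch.float32, device=torch.device('cpu'))
print(cache.shape)  # torch.Size([78, 32]): one complex rotation per (position, pair)

q = torch.randn(2, n_head, 10, head_size)  # T = 10 <= seq_len
q_rot = apply_rope(q, cache)
print(q_rot.shape)  # torch.Size([2, 12, 10, 64])
```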