soiz1 committed on
Commit
37cfcaa
·
verified ·
1 Parent(s): da18641

Create vallex.py

Browse files
Files changed (1) hide show
  1. models/vallex.py +851 -0
models/vallex.py ADDED
@@ -0,0 +1,851 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 (authors: Feiteng Li)
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import random
16
+ from typing import Dict, Iterator, List, Tuple, Union
17
+ import gc
18
+
19
+ import numpy as np
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ # from icefall.utils import make_pad_mask
24
+ # from torchmetrics.classification import MulticlassAccuracy
25
+
26
+ from modules.embedding import SinePositionalEmbedding, TokenEmbedding
27
+ from modules.transformer import (
28
+ AdaptiveLayerNorm,
29
+ LayerNorm,
30
+ TransformerDecoderLayer,
31
+ TransformerEncoder,
32
+ TransformerEncoderLayer,
33
+ )
34
+
35
+ from .macros import NUM_AUDIO_TOKENS, NUM_TEXT_TOKENS
36
+
37
+ import psutil
38
def get_memory_usage():
    """Return the resident set size (RSS) of the current process in MiB.

    The value is read through ``psutil.Process().memory_info().rss`` and
    divided by ``1024 * 1024``.
    """
    bytes_per_mib = 1024 * 1024
    rss_bytes = psutil.Process().memory_info().rss
    return rss_bytes / bytes_per_mib
46
+
47
class Transpose(nn.Identity):
    """Swap dimensions 1 and 2 of the input: (N, T, D) -> (N, D, T)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # The parameter keeps the name ``input`` to match nn.Identity.forward.
        return torch.transpose(input, 1, 2)
52
+
53
+
54
+ # NOTE: There are two ways to implement the model
55
+ # 1) [VALL-F] standard TransformerDecoder, use x as memory
56
+ # 2) [VALL-E] modified TransformerDecoder like GPT-x(e.g. causal TransformerEncoder),
57
+ # use x as the prefix of decoder inputs
58
class VALLF(nn.Module):
    """Base class of the VALL-E style codec language model.

    It implements https://arxiv.org/abs/2301.02111
    "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"

    The constructor builds an autoregressive (``ar_*``) stack and, when
    ``num_quantizers > 1``, a non-autoregressive (``nar_*``) stack.
    ``forward``, ``inference`` and ``visualize`` are left to subclasses.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        norm_first: bool = True,
        add_prenet: bool = False,
        decoder_cls: Union[
            nn.TransformerDecoder, nn.TransformerEncoder
        ] = nn.TransformerDecoder,
        decoder_layer_cls: Union[
            TransformerDecoderLayer, TransformerEncoderLayer
        ] = TransformerDecoderLayer,
        prefix_mode: int = 0,
        share_embedding: bool = True,
        nar_scale_factor: float = 1.0,
        prepend_bos: bool = True,
        num_quantizers: int = 8,
    ):
        """
        Args:
          d_model:
            The number of expected features in the input (required).
          nhead:
            The number of heads in the multiheadattention models (required).
          num_layers:
            The number of sub-decoder-layers in the decoder (required).
          norm_first:
            Forwarded to every decoder layer; when True a final norm layer
            is also attached to each decoder.
          add_prenet:
            When True, text and audio embeddings pass through a prenet;
            otherwise the prenets are ``nn.Identity``.
          decoder_cls:
            Class used to build the AR and NAR decoders.
          decoder_layer_cls:
            Class used to build the layers handed to ``decoder_cls``.
          prefix_mode:
            Stored as ``self.prefix_mode``; the modes are interpreted in
            ``_prepare_prompts``.
          share_embedding:
            When True, the weight of NAR prediction layer ``j`` is tied to
            the weight of NAR audio embedding ``j + 2``.
          nar_scale_factor:
            Multiplier applied to ``d_model``, ``nhead`` and ``num_layers``
            to size the NAR stack.
          prepend_bos:
            When True, the AR audio embedding reserves one extra ID for BOS.
          num_quantizers:
            Number of codebooks; must be >= 1. The NAR stack is only built
            when it is greater than 1.
        """
        super().__init__()
        nar_d_model = int(d_model * nar_scale_factor)

        self.ar_text_embedding = TokenEmbedding(d_model, NUM_TEXT_TOKENS)  # W_x
        self.nar_text_embedding = TokenEmbedding(nar_d_model, NUM_TEXT_TOKENS)

        # ID NUM_AUDIO_TOKENS     -> PAD
        # ID NUM_AUDIO_TOKENS + 1 -> BOS
        self.ar_audio_prepend_bos = prepend_bos
        self.ar_audio_embedding = TokenEmbedding(
            d_model, NUM_AUDIO_TOKENS + 1 + int(prepend_bos)
        )

        # PreNet (text first, then audio, to keep construction order stable)
        if add_prenet:
            self.ar_text_prenet = self._build_text_prenet(d_model)
            self.ar_audio_prenet = self._build_audio_prenet(d_model)
        else:
            self.ar_text_prenet = nn.Identity()
            self.ar_audio_prenet = nn.Identity()

        self.ar_text_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
            alpha=True,
        )
        self.ar_audio_position = SinePositionalEmbedding(
            d_model,
            dropout=0.1,
            scale=False,
            alpha=True,
        )

        self.ar_decoder = decoder_cls(
            decoder_layer_cls(
                d_model,
                nhead,
                dim_feedforward=d_model * 4,
                dropout=0.1,
                batch_first=True,
                norm_first=norm_first,
            ),
            num_layers=num_layers,
            norm=LayerNorm(d_model) if norm_first else None,
        )
        self.ar_predict_layer = nn.Linear(
            d_model, NUM_AUDIO_TOKENS + 1, bias=False
        )

        # Private, fixed-seed RNG used by _prepare_prompts (prefix_mode 2).
        self.rng = random.Random(0)
        self.num_heads = nhead
        self.prefix_mode = prefix_mode
        self.num_quantizers = num_quantizers

        assert num_quantizers >= 1
        if num_quantizers > 1:
            # The first embedding has one extra row (the PAD ID); the others
            # cover exactly NUM_AUDIO_TOKENS IDs.
            self.nar_audio_embeddings = nn.ModuleList(
                [TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS + 1)]
                + [
                    TokenEmbedding(nar_d_model, NUM_AUDIO_TOKENS)
                    for _ in range(num_quantizers - 1)
                ]
            )  # W_a

            # PreNet
            if add_prenet:
                self.nar_text_prenet = self._build_text_prenet(nar_d_model)
                self.nar_audio_prenet = self._build_audio_prenet(nar_d_model)
            else:
                self.nar_text_prenet = nn.Identity()
                self.nar_audio_prenet = nn.Identity()

            self.nar_text_position = SinePositionalEmbedding(
                nar_d_model,
                dropout=0.0,
                scale=False,
                alpha=False,
            )
            self.nar_audio_position = SinePositionalEmbedding(
                nar_d_model,
                dropout=0.1,
                scale=False,
                alpha=False,
            )

            self.nar_decoder = decoder_cls(
                decoder_layer_cls(
                    nar_d_model,
                    int(nhead * nar_scale_factor),
                    dim_feedforward=nar_d_model * 4,
                    dropout=0.1,
                    batch_first=True,
                    norm_first=norm_first,
                    adaptive_layer_norm=True,
                ),
                num_layers=int(num_layers * nar_scale_factor),
                norm=AdaptiveLayerNorm(
                    nar_d_model, norm=nn.LayerNorm(nar_d_model)
                )
                if norm_first
                else None,
            )
            self.nar_predict_layers = nn.ModuleList(
                [
                    nn.Linear(nar_d_model, NUM_AUDIO_TOKENS, bias=False)
                    for _ in range(num_quantizers - 1)
                ]
            )
            self.nar_stage_embeddings = nn.ModuleList(
                [
                    TokenEmbedding(nar_d_model, 1)
                    for _ in range(num_quantizers - 1)
                ]
            )

            if share_embedding:
                # We share the parameters of the output projection layer with
                # the parameters of the acoustic embedding Wa.
                # NOTE(Feiteng): In the experiment, this undermines accuracy
                # self.ar_predict_layer.weight = self.ar_audio_embedding.weight

                # We also share the parameters of the acoustic embedding layer
                # and the output prediction layer, which means the weights of
                # the j-th prediction layer are the same as the (j + 1)-th
                # acoustic embedding layer.
                for j in range(0, num_quantizers - 2):
                    self.nar_predict_layers[
                        j
                    ].weight = self.nar_audio_embeddings[j + 2].weight

    @staticmethod
    def _build_text_prenet(dim: int) -> nn.Sequential:
        """Build the text prenet: 3 x (Conv1d, BatchNorm1d, ReLU, Dropout) + Linear."""
        layers = [Transpose()]
        for _ in range(3):
            layers.extend(
                [
                    nn.Conv1d(dim, dim, kernel_size=5, padding="same"),
                    nn.BatchNorm1d(dim),
                    nn.ReLU(),
                    nn.Dropout(0.5),
                ]
            )
        layers.extend([Transpose(), nn.Linear(dim, dim)])
        return nn.Sequential(*layers)

    @staticmethod
    def _build_audio_prenet(dim: int) -> nn.Sequential:
        """Build the audio prenet: a dim -> 256 -> 256 -> dim MLP with dropout."""
        return nn.Sequential(
            nn.Linear(dim, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(256, dim),
        )

    def stage_parameters(self, stage: int = 1) -> Iterator[nn.Parameter]:
        """Yield (and print the name of) every parameter of a training stage.

        Stage 1 selects parameters whose name starts with ``ar_``, stage 2
        those starting with ``nar_``. Any other positive stage yields nothing.
        """
        assert stage > 0
        if stage == 1:
            for name, param in self.named_parameters():
                if name.startswith("ar_"):
                    print(f" AR parameter: {name}")
                    yield param

        if stage == 2:
            for name, param in self.named_parameters():
                if name.startswith("nar_"):
                    print(f"NAR parameter: {name}")
                    yield param

    def stage_named_parameters(
        self, stage: int = 1
    ) -> Iterator[Tuple[str, nn.Parameter]]:
        """Yield ``(name, parameter)`` pairs of a training stage.

        Same selection rule as ``stage_parameters`` but without printing.
        """
        assert stage > 0
        if stage == 1:
            for pair in self.named_parameters():
                if pair[0].startswith("ar_"):
                    yield pair

        if stage == 2:
            for pair in self.named_parameters():
                if pair[0].startswith("nar_"):
                    yield pair

    def pad_y_eos(self, y, y_mask_int, eos_id):
        """Append an EOS slot to ``y`` and derive ``(inputs, targets)``.

        One position is appended on the last dimension. ``eos_id`` is added
        at the appended position and wherever ``y_mask_int`` is non-zero
        (presumably a 0/1 padding mask -- verify against the caller).

        Returns:
          With ``ar_audio_prepend_bos``: the targets shifted right behind a
          BOS ID (``NUM_AUDIO_TOKENS + 1``), and the full targets.
          Otherwise: ``targets[:, :-1]`` and ``targets[:, 1:]``.
        """
        targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
            y_mask_int, (0, 1), value=1
        )
        # inputs, targets
        if self.ar_audio_prepend_bos:
            return (
                F.pad(targets[:, :-1], (1, 0), value=NUM_AUDIO_TOKENS + 1),
                targets,
            )

        return targets[:, :-1], targets[:, 1:]

    def _prepare_prompts(self, y, y_lens, codes, nar_stage, y_prompts_codes, prefix_mode):
        """Build the NAR audio embedding (with optional acoustic prompt).

        Args:
          y: first-codebook IDs; assumed (batch, frames) -- TODO confirm.
          y_lens: per-utterance frame counts.
          codes: all codebook IDs; indexed as ``codes[..., j]``.
            NOTE: in ``prefix_mode == 2`` the prompt region of codebook
            ``nar_stage`` is overwritten in place with ``NUM_AUDIO_TOKENS``.
          nar_stage: index of the codebook being predicted.
          y_prompts_codes: externally supplied prompt codes, only read when
            ``prefix_mode == 4``.
          prefix_mode: 0 = no prefix, 1 = prefix at the beginning of ``y``,
            2 = randomly positioned segment, 4 = caller-provided prompt.

        Returns:
          ``(y_emb, prefix_len)``.

        Raises:
          ValueError: for any other ``prefix_mode``.
        """
        # 5.1 For the NAR acoustic prompt tokens, we select a random segment
        # waveform of 3 seconds from the same utterance.
        # We implement this differently.
        if prefix_mode == 0:
            # no prefix
            prefix_len = 0
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, nar_stage):
                # Formula (4) (5)
                y_emb = y_emb + self.nar_audio_embeddings[j](codes[..., j])
        elif prefix_mode == 1:
            # prefix at beginning
            int_low = (0.25 * y_lens.min()).type(torch.int64).item()
            if int_low > 0:
                prefix_len = torch.randint(0, int_low * 2, size=()).item()
            else:
                # torch.randint requires high > low; with fewer than 4
                # frames there is no room for a prefix, so use none.
                prefix_len = 0
            prefix_len = min(prefix_len, 225)  # 24000/320 * 3s = 225 frames

            y_prompts = self.nar_audio_embeddings[0](y[:, :prefix_len])
            y_emb = self.nar_audio_embeddings[0](y[:, prefix_len:])
            for j in range(1, self.num_quantizers):
                # The prompt always sums all codebooks; the target part only
                # the codebooks below the stage being predicted.
                y_prompts += self.nar_audio_embeddings[j](
                    codes[:, :prefix_len, j]
                )
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](
                        codes[:, prefix_len:, j]
                    )
            y_emb = torch.concat([y_prompts, y_emb], axis=1)
        elif prefix_mode in [2, 4]:
            if prefix_mode == 2:
                # random prefix
                prefix_len = min(225, int(0.25 * y_lens.min().item()))

                y_prompts_codes = []
                for b in range(codes.shape[0]):
                    start = self.rng.randint(0, y_lens[b].item() - prefix_len)
                    y_prompts_codes.append(
                        torch.clone(codes[b, start : start + prefix_len])
                    )
                    # Deliberate in-place write into the caller's tensor.
                    codes[
                        b, start : start + prefix_len, nar_stage
                    ] = NUM_AUDIO_TOKENS
                y_prompts_codes = torch.stack(y_prompts_codes, dim=0)
            else:
                prefix_len = y_prompts_codes.shape[1]

            y_prompts = self.nar_audio_embeddings[0](y_prompts_codes[..., 0])
            y_emb = self.nar_audio_embeddings[0](y)
            for j in range(1, self.num_quantizers):
                y_prompts += self.nar_audio_embeddings[j](
                    y_prompts_codes[..., j]
                )
                if j < nar_stage:
                    y_emb += self.nar_audio_embeddings[j](codes[..., j])
            y_emb = torch.concat([y_prompts, y_emb], axis=1)
        else:
            raise ValueError(
                f"unsupported prefix_mode {prefix_mode!r}; expected 0, 1, 2 or 4"
            )

        return y_emb, prefix_len

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: Union[torch.Tensor],
        y_lens: Union[torch.Tensor],
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        """Training forward pass; not implemented in this class."""
        raise NotImplementedError

    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: Union[torch.Tensor, None] = None,
        top_k: int = -100,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Inference entry point; not implemented in this class."""
        raise NotImplementedError

    def visualize(
        self,
        predicts: Tuple[torch.Tensor],
        batch: Dict[str, Union[List, torch.Tensor]],
        output_dir: str,
        limit: int = 4,
    ) -> None:
        """Visualization hook; not implemented in this class."""
        raise NotImplementedError
406
+
407
+
408
class VALLE(VALLF):
    """VALL-E variant that uses the text as a prefix of the decoder inputs.

    It implements https://arxiv.org/abs/2301.02111
    "Neural Codec Language Models are Zero-Shot Text to Speech Synthesizers"

    On top of ``VALLF`` it adds per-language embeddings (``en``, ``zh``,
    ``ja``) that are summed onto the text embeddings during inference.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        num_layers: int,
        norm_first: bool = True,
        add_prenet: bool = False,
        prefix_mode: int = 0,
        share_embedding: bool = True,
        nar_scale_factor: float = 1.0,
        **kwargs,
    ):
        """
        Args:
          d_model:
            The number of expected features in the input (required).
          nhead:
            The number of heads in the multiheadattention models (required).
          num_layers:
            The number of sub-decoder-layers in the decoder (required).
          norm_first, add_prenet, prefix_mode, share_embedding,
          nar_scale_factor, **kwargs:
            Forwarded unchanged to ``VALLF.__init__``.
        """
        super(VALLE, self).__init__(
            d_model,
            nhead,
            num_layers,
            norm_first=norm_first,
            add_prenet=add_prenet,
            decoder_cls=TransformerEncoder,
            decoder_layer_cls=TransformerEncoderLayer,
            prefix_mode=prefix_mode,
            share_embedding=share_embedding,
            nar_scale_factor=nar_scale_factor,
            **kwargs,
        )
        self.language_ID = {
            'en': 0,
            'zh': 1,
            'ja': 2,
        }
        self.ar_language_embedding = TokenEmbedding(d_model, len(self.language_ID))
        # The NAR language embedding is added to the output of
        # nar_text_embedding, so it must use the NAR width; this equals
        # d_model whenever nar_scale_factor == 1.0.
        nar_d_model = int(d_model * nar_scale_factor)
        self.nar_language_embedding = TokenEmbedding(nar_d_model, len(self.language_ID))

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: Union[torch.Tensor],
        y_lens: Union[torch.Tensor],
        reduction: str = "sum",
        train_stage: int = 0,
        **kwargs,
    ):
        """Training forward pass; not implemented in this file."""
        raise NotImplementedError

    def _language_ids(self, language, device) -> torch.Tensor:
        """Map a language code (or a list/tuple of codes) to a LongTensor of IDs.

        Raises:
          KeyError: if a code is not a key of ``self.language_ID`` (this
            includes ``None``, i.e. a language that was never supplied).
        """
        names = list(language) if isinstance(language, (list, tuple)) else [language]
        try:
            ids = [self.language_ID[name] for name in names]
        except KeyError as err:
            raise KeyError(
                f"unsupported language {err.args[0]!r}; "
                f"expected one of {sorted(self.language_ID)}"
            ) from None
        return torch.tensor(ids, dtype=torch.long, device=device)

    def _run_nar_stages(self, x, y_emb, prompts, text_len, prefix_len):
        """Predict codebooks 1..num_quantizers-1 with the NAR decoder.

        ``y_emb`` is updated in place after every stage with the embedding
        of the codes just predicted. With ``prefix_mode == 0`` the prompt
        part of ``y_emb`` receives one extra codebook per stage; in every
        other mode all prompt codebooks are summed in before the first stage.

        Returns:
          A list with one tensor of greedy (argmax) codes per NAR stage.
        """
        incremental_prompt = self.prefix_mode == 0
        if not incremental_prompt:
            for j in range(1, self.num_quantizers):
                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](
                    prompts[..., j]
                )

        # The embedding of the last predicted codebook is never consumed.
        last_feedback_stage = self.num_quantizers - 2
        stage_codes = []
        for i, (predict_layer, embedding_layer) in enumerate(
            zip(
                self.nar_predict_layers,
                self.nar_audio_embeddings[1:],
            )
        ):
            y_pos = self.nar_audio_prenet(y_emb)
            y_pos = self.nar_audio_position(y_pos)
            xy_pos = torch.concat([x, y_pos], dim=1)

            xy_dec, _ = self.nar_decoder(
                (xy_pos, self.nar_stage_embeddings[i].weight)
            )
            # Keep only the positions after the text and the audio prompt.
            logits = predict_layer(xy_dec[:, text_len + prefix_len :])

            samples = torch.argmax(logits, dim=-1)
            stage_codes.append(samples)

            if i < last_feedback_stage:
                if incremental_prompt:
                    y_emb[:, :prefix_len] += embedding_layer(
                        prompts[..., i + 1]
                    )
                y_emb[:, prefix_len:] += embedding_layer(samples)

        return stage_codes

    def inference(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
        enroll_x_lens: torch.Tensor,
        top_k: int = -100,
        temperature: float = 1.0,
        prompt_language: Union[str, None] = None,
        text_language: Union[str, List[str], None] = None,
    ) -> torch.Tensor:
        """
        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, 8).
          enroll_x_lens:
            Length of the enrolled (prompt) text at the start of `x`; used to
            split `x` into the prompt part and the part to synthesize.
          top_k: (`optional`) int
            The number of highest probability tokens to keep for top-k-filtering. Default to -100.
          temperature: (`optional`) float
            The value used to module the next token probabilities. Must be strictly positive. Default to 1.0.
          prompt_language:
            Language code of the prompt; a key of ``self.language_ID``.
          text_language:
            Language code (or list of codes) of the text to synthesize.
        Returns:
          Return the predicted audio code matrix.
        Raises:
          KeyError: if ``prompt_language`` or ``text_language`` is missing or
            not a supported language code.
          SyntaxError: if the AR decoder stops before producing any token.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape

        assert torch.all(x_lens > 0)

        # NOTE: x has been padded in TextTokenCollater
        text = x
        x = self.ar_text_embedding(text)
        # Add language embedding
        prompt_language_id = self._language_ids(prompt_language, x.device)
        text_language_id = self._language_ids(text_language, x.device)
        x[:, :enroll_x_lens, :] += self.ar_language_embedding(prompt_language_id)
        x[:, enroll_x_lens:, :] += self.ar_language_embedding(text_language_id)
        x = self.ar_text_prenet(x)
        x = self.ar_text_position(x)

        text_len = x_lens.max()
        prompts = y
        prefix_len = y.shape[1]

        # AR Decoder
        # TODO: Managing decoder steps avoid repetitive computation
        y = prompts[..., 0]
        if self.ar_audio_prepend_bos:
            y = F.pad(y, (1, 0), value=NUM_AUDIO_TOKENS + 1)

        x_len = x_lens.max()
        # Text positions may attend to each other (False = not masked).
        x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)

        kv_cache = None
        use_kv_caching = True
        while True:
            y_emb = self.ar_audio_embedding(y)
            y_emb = self.ar_audio_prenet(y_emb)
            y_pos = self.ar_audio_position(y_emb)
            xy_pos = torch.concat([x, y_pos], dim=1)

            y_len = y.shape[1]
            # Text rows never see audio columns ...
            x_attn_mask_pad = F.pad(
                x_attn_mask,
                (0, y_len),
                value=True,
            )
            # ... audio rows see all text and are causal over audio.
            y_attn_mask = F.pad(
                torch.triu(
                    torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1
                ),
                (x_len, 0),
                value=False,
            )
            xy_attn_mask = torch.concat(
                [x_attn_mask_pad, y_attn_mask], dim=0
            ).to(y.device)

            if use_kv_caching and kv_cache is not None:
                # Earlier positions are served from the cache.
                xy_pos = xy_pos[:, [-1]]

            xy_dec, kv_cache = self.ar_decoder.infer(
                xy_pos,
                mask=xy_attn_mask,
                past_kv=kv_cache,
                use_cache=use_kv_caching,
            )

            logits = self.ar_predict_layer(xy_dec[:, -1])
            samples = topk_sampling(
                logits, top_k=top_k, top_p=1, temperature=temperature
            )

            # Stop on EOS (greedy or sampled) or when the output grows past
            # 16 frames per text token.
            if (
                torch.argmax(logits, dim=-1)[0] == NUM_AUDIO_TOKENS
                or samples[0, 0] == NUM_AUDIO_TOKENS
                or (y.shape[1] - prompts.shape[1]) > x_lens.max() * 16
            ):
                if prompts.shape[1] == y.shape[1]:
                    # NOTE(review): SyntaxError is an odd type for this, kept
                    # because callers may already catch it.
                    raise SyntaxError(
                        "well trained model shouldn't reach here."
                    )

                print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")

                memory_used = get_memory_usage()
                print(f"Current memory used: {memory_used:.2f} MB")
                break

            # safety measure, break if token sequence too long
            if y.shape[1] > 2250:
                print(f"VALL-E EOS [{prompts.shape[1]} -> {y.shape[1]}]")
                break

            y = torch.concat([y, samples], dim=1)

        bos_offset = int(self.ar_audio_prepend_bos)
        codes = [y[:, prefix_len + bos_offset :]]
        if self.num_quantizers == 1:
            return torch.stack(codes, dim=-1)

        # Non-AR Decoders
        y_emb = self.nar_audio_embeddings[0](y[:, bos_offset:])

        if self.prefix_mode in [2, 4]:  # Exclude enrolled_phonemes
            enrolled_len = enroll_x_lens.max().item()
            # SOS + Synthesis Text + EOS
            text = torch.concat(
                [
                    text[:, :1],
                    text[:, enrolled_len - 1 :],
                ],
                dim=1,
            )
            text_len = text_len - (enrolled_len - 2)
            assert text.shape[0] == 1

        x = self.nar_text_embedding(text)
        # Add language embedding
        # NOTE(review): in prefix modes 2/4 `text` no longer contains the
        # enrolled phonemes, yet the split below still uses enroll_x_lens --
        # confirm this matches how the NAR model was trained.
        x[:, :enroll_x_lens, :] += self.nar_language_embedding(prompt_language_id)
        x[:, enroll_x_lens:, :] += self.nar_language_embedding(text_language_id)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        codes.extend(
            self._run_nar_stages(x, y_emb, prompts, text_len, prefix_len)
        )

        assert len(codes) == self.num_quantizers
        del text_language_id, prompt_language_id, y_emb, x, y_pos, xy_pos, xy_dec, logits, samples, kv_cache, x_attn_mask, y_attn_mask, xy_attn_mask
        gc.collect()
        return torch.stack(codes, dim=-1)

    def continual(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: torch.Tensor,
    ) -> torch.Tensor:
        """
        Re-predict codebooks 1..7 for the second part of ``y``.

        The first ``min(T // 2, 225)`` frames of ``y`` act as the acoustic
        prompt; the first codebook of the remaining frames is taken from
        ``y`` as-is and only the NAR decoder is run.

        Args:
          x:
            A 2-D tensor of shape (1, S).
          x_lens:
            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
            before padding.
          y:
            A 3-D tensor of shape (1, T, 8).
        Returns:
          Return the predicted audio code matrix.
        """
        assert x.ndim == 2, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.ndim == 3, y.shape
        assert y.shape[0] == 1, y.shape

        assert torch.all(x_lens > 0)
        assert self.num_quantizers == 8

        # NOTE: x has been padded in TextTokenCollater
        text = x
        text_len = x_lens.max()

        prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)

        # The AR decoder is skipped: codebook 0 comes straight from y.
        prompts = y[:, :prefix_len]
        codes = [y[:, prefix_len:, 0]]

        # Non-AR Decoders
        x = self.nar_text_embedding(text)
        x = self.nar_text_prenet(x)
        x = self.nar_text_position(x)

        y_emb = self.nar_audio_embeddings[0](y[..., 0])

        # Shared with inference(): prenet is applied before the positional
        # embedding in every prefix mode.
        codes.extend(
            self._run_nar_stages(x, y_emb, prompts, text_len, prefix_len)
        )

        assert len(codes) == 8
        return torch.stack(codes, dim=-1)
788
+
789
+
790
+ # https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
791
def top_k_top_p_filtering(
    logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    The input tensor is left untouched; a filtered tensor is returned (the
    input itself is returned when neither filter is active).

    Args:
        logits: logits distribution, shape (batch size, vocabulary size).
        top_k: if > 0, keep only the top k tokens with highest probability
            (top-k filtering). Values <= 0 disable the filter.
        top_p: if < 1.0, keep the top tokens with cumulative probability
            >= top_p (nucleus filtering). Nucleus filtering is described in
            Holtzman et al. (http://arxiv.org/abs/1904.09751).
        filter_value: value written to every removed position.
        min_tokens_to_keep: make sure at least this many tokens per batch
            example survive.

    Returns:
        A tensor shaped like ``logits`` with removed positions set to
        ``filter_value``.

    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(
            max(top_k, min_tokens_to_keep), logits.size(-1)
        )  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        kth_best = torch.topk(logits, top_k)[0][..., -1, None]
        # masked_fill is out-of-place, so the caller's tensor is preserved.
        logits = logits.masked_fill(logits < kth_best, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[
            ..., :-1
        ].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing (along the vocabulary axis)
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, filter_value)
    return logits
834
+
835
+
836
def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
    """Sample one token per row from temperature-scaled, filtered logits.

    Args:
        logits: unnormalized scores; the last dimension is the vocabulary.
        top_k: (`optional`) int
            The number of highest probability vocabulary tokens to keep for
            top-k-filtering. Values <= 0 disable the filter. Default to 10.
        top_p: (`optional`) float
            The cumulative probability of parameter highest probability
            vocabulary tokens to keep for nucleus sampling. Must be between
            0 and 1. Default to 1.
        temperature: (`optional`) float
            The value used to module the next token probabilities. Must be
            strictly positive. Default to 1.0.

    Returns:
        The result of ``torch.multinomial(..., num_samples=1)`` on the
        softmax of the filtered logits.

    Raises:
        ValueError: if ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        # Zero would divide by zero; a negative value would silently invert
        # the distribution.
        raise ValueError(
            f"temperature must be strictly positive, got {temperature}"
        )

    # Temperature (higher temperature => more likely to sample low probability tokens)
    if temperature != 1.0:
        logits = logits / temperature
    # Top-p/top-k filtering
    logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
    # Sample
    token = torch.multinomial(F.softmax(logits, dim=-1), num_samples=1)
    return token