biosn2 commited on
Commit
6e82fa9
·
verified ·
1 Parent(s): 75467c2

Upload indextts/gpt/model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. indextts/gpt/model.py +708 -0
indextts/gpt/model.py ADDED
@@ -0,0 +1,708 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from transformers import GPT2Config, GPT2PreTrainedModel, LogitsProcessorList, GenerationMixin
7
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
8
+ from transformers.utils.model_parallel_utils import (assert_device_map,
9
+ get_device_map)
10
+
11
+ from indextts.gpt.conformer_encoder import ConformerEncoder
12
+ from indextts.gpt.perceiver import PerceiverResampler
13
+ from indextts.utils.arch_util import AttentionBlock
14
+ from indextts.utils.typical_sampling import TypicalLogitsWarper
15
+
16
+
17
+ def null_position_embeddings(range, dim):
18
+ return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
19
+
20
+
21
+ class ResBlock(nn.Module):
22
+ """
23
+ Basic residual convolutional block that uses GroupNorm.
24
+ """
25
+
26
+ def __init__(self, chan):
27
+ super().__init__()
28
+ self.net = nn.Sequential(
29
+ nn.Conv1d(chan, chan, kernel_size=3, padding=1),
30
+ nn.GroupNorm(chan // 8, chan),
31
+ nn.ReLU(),
32
+ nn.Conv1d(chan, chan, kernel_size=3, padding=1),
33
+ nn.GroupNorm(chan // 8, chan)
34
+ )
35
+
36
+ def forward(self, x):
37
+ return F.relu(self.net(x) + x)
38
+
39
+
40
+ class GPT2InferenceModel(GPT2PreTrainedModel, GenerationMixin):
41
+ def __init__(self, config, gpt, text_pos_emb, embeddings, norm, linear, kv_cache=False):
42
+ super().__init__(config)
43
+ # Note: the argument named `text_pos_emb` here actually represents the mel position embedding
44
+ self.transformer = gpt
45
+ self.text_pos_embedding = text_pos_emb
46
+ self.embeddings = embeddings
47
+ self.final_norm = norm
48
+ self.lm_head = nn.Sequential(norm, linear)
49
+ self.kv_cache = kv_cache
50
+
51
+ # Model parallel
52
+ self.model_parallel = False
53
+ self.device_map = None
54
+ self.cached_mel_emb = None
55
+
56
+ def parallelize(self, device_map=None):
57
+ self.device_map = (
58
+ get_device_map(len(self.transformer.h), range(max(1, torch.cuda.device_count())))
59
+ if device_map is None
60
+ else device_map
61
+ )
62
+ assert_device_map(self.device_map, len(self.transformer.h))
63
+ self.transformer.parallelize(self.device_map)
64
+ self.lm_head = self.lm_head.to(self.transformer.first_device)
65
+ self.model_parallel = True
66
+
67
+ def deparallelize(self):
68
+ self.transformer.deparallelize()
69
+ self.transformer = self.transformer.to("cpu")
70
+ self.lm_head = self.lm_head.to("cpu")
71
+ self.model_parallel = False
72
+ torch.cuda.empty_cache()
73
+ if torch.backends.mps.is_available():
74
+ torch.mps.empty_cache()
75
+
76
+ def get_output_embeddings(self):
77
+ return self.lm_head
78
+
79
+ def set_output_embeddings(self, new_embeddings):
80
+ self.lm_head = new_embeddings
81
+
82
+ def store_mel_emb(self, mel_emb):
83
+ self.cached_mel_emb = mel_emb
84
+
85
+ def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs):
86
+ token_type_ids = kwargs.get("token_type_ids", None) # usually None
87
+ if not self.kv_cache:
88
+ past_key_values = None
89
+ # only last token for inputs_ids if past is defined in kwargs
90
+ if past_key_values:
91
+ input_ids = input_ids[:, -1].unsqueeze(-1)
92
+ if token_type_ids is not None:
93
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
94
+
95
+ attention_mask = kwargs.get("attention_mask", None)
96
+ position_ids = kwargs.get("position_ids", None)
97
+
98
+ if attention_mask is not None and position_ids is None:
99
+ # create position_ids on the fly for batch generation
100
+ position_ids = attention_mask.long().cumsum(-1) - 1
101
+ position_ids.masked_fill_(attention_mask == 0, 0)
102
+ if past_key_values:
103
+ position_ids = position_ids[:, -1].unsqueeze(-1)
104
+ else:
105
+ position_ids = None
106
+ return {
107
+ "input_ids": input_ids,
108
+ "past_key_values": past_key_values,
109
+ "use_cache": kwargs.get("use_cache"),
110
+ "position_ids": position_ids,
111
+ "attention_mask": attention_mask,
112
+ "token_type_ids": token_type_ids,
113
+ }
114
+
115
+ def forward(
116
+ self,
117
+ input_ids=None,
118
+ past_key_values=None,
119
+ attention_mask=None,
120
+ token_type_ids=None,
121
+ position_ids=None,
122
+ head_mask=None,
123
+ inputs_embeds=None,
124
+ encoder_hidden_states=None,
125
+ encoder_attention_mask=None,
126
+ labels=None,
127
+ use_cache=None,
128
+ output_attentions=None,
129
+ output_hidden_states=None,
130
+ return_dict=None,
131
+ ):
132
+ assert self.cached_mel_emb is not None
133
+ assert inputs_embeds is None # Not supported by this inference model.
134
+ assert labels is None # Training not supported by this inference model.
135
+ return_dict = (
136
+ return_dict if return_dict is not None else self.config.use_return_dict
137
+ )
138
+ # Create embedding
139
+ mel_len = self.cached_mel_emb.shape[1]
140
+ if input_ids.shape[1] != 1:
141
+ text_inputs = input_ids[:, mel_len:]
142
+ text_emb = self.embeddings(text_inputs)
143
+ text_emb = text_emb + self.text_pos_embedding(text_emb)
144
+ if self.cached_mel_emb.shape[0] != text_emb.shape[0]:
145
+ mel_emb = self.cached_mel_emb.repeat_interleave(
146
+ text_emb.shape[0] // self.cached_mel_emb.shape[0], 0
147
+ )
148
+ else: # this outcome only occurs once per loop in most cases
149
+ mel_emb = self.cached_mel_emb
150
+ emb = torch.cat([mel_emb, text_emb], dim=1)
151
+ else:
152
+ emb = self.embeddings(input_ids)
153
+ emb = emb + self.text_pos_embedding.get_fixed_embedding(
154
+ attention_mask.shape[1] - mel_len, attention_mask.device
155
+ )
156
+ transformer_outputs = self.transformer(
157
+ inputs_embeds=emb,
158
+ past_key_values=past_key_values,
159
+ attention_mask=attention_mask,
160
+ token_type_ids=token_type_ids,
161
+ position_ids=position_ids,
162
+ head_mask=head_mask,
163
+ encoder_hidden_states=encoder_hidden_states,
164
+ encoder_attention_mask=encoder_attention_mask,
165
+ use_cache=use_cache,
166
+ output_attentions=output_attentions,
167
+ output_hidden_states=output_hidden_states,
168
+ return_dict=return_dict,
169
+ )
170
+ hidden_states = transformer_outputs[0]
171
+
172
+ # Set device for model parallelism
173
+ if self.model_parallel:
174
+ if torch.backends.mps.is_available():
175
+ self.to(self.transformer.first_device)
176
+ else:
177
+ torch.cuda.set_device(self.transformer.first_device)
178
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
179
+
180
+ lm_logits = self.lm_head(hidden_states)
181
+
182
+ if not return_dict:
183
+ return (lm_logits,) + transformer_outputs[1:]
184
+
185
+ return CausalLMOutputWithCrossAttentions(
186
+ loss=None,
187
+ logits=lm_logits,
188
+ past_key_values=transformer_outputs.past_key_values,
189
+ hidden_states=transformer_outputs.hidden_states,
190
+ attentions=transformer_outputs.attentions,
191
+ cross_attentions=transformer_outputs.cross_attentions,
192
+ )
193
+
194
+ @staticmethod
195
+ def _reorder_cache(past, beam_idx):
196
+ """
197
+ This function is used to re-order the :obj:`past_key_values` cache if
198
+ :meth:`~transformers.PreTrainedModel.beam_search` or :meth:`~transformers.PreTrainedModel.beam_sample` is
199
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
200
+ """
201
+ return tuple(
202
+ tuple(
203
+ past_state.index_select(0, beam_idx.to(past_state.device))
204
+ for past_state in layer_past
205
+ )
206
+ for layer_past in past
207
+ )
208
+
209
+
210
+ class ConditioningEncoder(nn.Module):
211
+ def __init__(self,
212
+ spec_dim,
213
+ embedding_dim,
214
+ attn_blocks=6,
215
+ num_attn_heads=4,
216
+ do_checkpointing=False,
217
+ mean=False):
218
+ super().__init__()
219
+ attn = []
220
+ self.init = nn.Conv1d(spec_dim, embedding_dim, kernel_size=1)
221
+ for a in range(attn_blocks):
222
+ attn.append(AttentionBlock(embedding_dim, num_attn_heads))
223
+ self.attn = nn.Sequential(*attn)
224
+ self.dim = embedding_dim
225
+ self.do_checkpointing = do_checkpointing
226
+ self.mean = mean
227
+
228
+ def forward(self, x):
229
+ h = self.init(x)
230
+ h = self.attn(h)
231
+ if self.mean:
232
+ return h.mean(dim=2)
233
+ else:
234
+ return h
235
+ # return h[:, :, 0]
236
+
237
+
238
+ class LearnedPositionEmbeddings(nn.Module):
239
+ def __init__(self, seq_len, model_dim, init=.02):
240
+ super().__init__()
241
+ self.emb = nn.Embedding(seq_len, model_dim)
242
+ # Initializing this way is standard for GPT-2
243
+ self.emb.weight.data.normal_(mean=0.0, std=init)
244
+
245
+ def forward(self, x):
246
+ sl = x.shape[1]
247
+ return self.emb(torch.arange(0, sl, device=x.device))
248
+
249
+ def get_fixed_embedding(self, ind, dev):
250
+ return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
251
+
252
+
253
+ def build_hf_gpt_transformer(layers, model_dim, heads, max_mel_seq_len, max_text_seq_len, checkpointing, activation_function):
254
+ """
255
+ GPT-2 implemented by the HuggingFace library.
256
+ """
257
+ from transformers import GPT2Config, GPT2Model
258
+ gpt_config = GPT2Config(vocab_size=256, # Unused.
259
+ n_positions=max_mel_seq_len + max_text_seq_len,
260
+ n_ctx=max_mel_seq_len + max_text_seq_len,
261
+ n_embd=model_dim,
262
+ n_layer=layers,
263
+ n_head=heads,
264
+ activation_function=activation_function or "gelu_new",
265
+ gradient_checkpointing=checkpointing,
266
+ use_cache=not checkpointing)
267
+ gpt = GPT2Model(gpt_config)
268
+ # Override the built in positional embeddings
269
+ del gpt.wpe
270
+ gpt.wpe = functools.partial(null_position_embeddings, dim=model_dim)
271
+ # Built-in token embeddings are unused.
272
+ del gpt.wte
273
+ return gpt, LearnedPositionEmbeddings(max_mel_seq_len, model_dim), LearnedPositionEmbeddings(max_text_seq_len, model_dim), \
274
+ None, None
275
+
276
+
277
+ class MelEncoder(nn.Module):
278
+ def __init__(self, channels, mel_channels=80, resblocks_per_reduction=2):
279
+ super().__init__()
280
+ self.channels = channels
281
+ self.encoder = nn.Sequential(nn.Conv1d(mel_channels, channels // 4, kernel_size=3, padding=1),
282
+ nn.Sequential(*[ResBlock(channels // 4) for _ in range(resblocks_per_reduction)]),
283
+ nn.Conv1d(channels // 4, channels // 2, kernel_size=3, stride=2, padding=1),
284
+ nn.GroupNorm(channels // 16, channels // 2),
285
+ nn.ReLU(),
286
+ nn.Sequential(*[ResBlock(channels // 2) for _ in range(resblocks_per_reduction)]),
287
+ nn.Conv1d(channels // 2, channels, kernel_size=3, stride=2, padding=1),
288
+ nn.GroupNorm(channels // 8, channels),
289
+ nn.ReLU(),
290
+ nn.Sequential(*[ResBlock(channels) for _ in range(resblocks_per_reduction)]),
291
+ )
292
+ self.reduction = 4
293
+
294
+ def forward(self, x):
295
+ for e in self.encoder:
296
+ x = e(x)
297
+ return x.permute(0, 2, 1)
298
+
299
+
300
+ class UnifiedVoice(nn.Module):
301
+ def __init__(self, layers=8, model_dim=512, heads=8, max_text_tokens=120, max_mel_tokens=250, max_conditioning_inputs=1,
302
+ mel_length_compression=1024, number_text_tokens=256,
303
+ start_text_token=0, stop_text_token=1, number_mel_codes=8194, start_mel_token=8192, stop_mel_token=8193,
304
+ train_solo_embeddings=False, use_mel_codes_as_input=True,
305
+ checkpointing=True, types=1, activation_function=None,
306
+ condition_num_latent=32, condition_type="perceiver", condition_module=None):
307
+ """
308
+ Args:
309
+ layers: Number of layers in transformer stack.
310
+ model_dim: Operating dimensions of the transformer
311
+ heads: Number of transformer heads. Must be divisible by model_dim. Recommend model_dim//64
312
+ max_text_tokens: Maximum number of text tokens that will be encountered by model.
313
+ max_mel_tokens: Maximum number of MEL tokens that will be encountered by model.
314
+ max_conditioning_inputs: Maximum number of conditioning inputs provided to the model. If (1), conditioning input can be of format (b,80,s), otherwise (b,n,80,s).
315
+ mel_length_compression: The factor between <number_input_samples> and <mel_tokens>. Used to compute MEL code padding given wav input length.
316
+ number_text_tokens:
317
+ start_text_token:
318
+ stop_text_token:
319
+ number_mel_codes:
320
+ start_mel_token:
321
+ stop_mel_token:
322
+ train_solo_embeddings:
323
+ use_mel_codes_as_input:
324
+ checkpointing:
325
+ condition_type: perceiver, gst or default encoder
326
+ """
327
+ super().__init__()
328
+ self.number_text_tokens = number_text_tokens
329
+ self.start_text_token = start_text_token
330
+ self.stop_text_token = stop_text_token
331
+ self.number_mel_codes = number_mel_codes
332
+ self.start_mel_token = start_mel_token
333
+ self.stop_mel_token = stop_mel_token
334
+ self.layers = layers
335
+ self.heads = heads
336
+ self.max_mel_tokens = max_mel_tokens
337
+ self.max_text_tokens = max_text_tokens
338
+ self.model_dim = model_dim
339
+ self.max_conditioning_inputs = max_conditioning_inputs
340
+ self.mel_length_compression = mel_length_compression
341
+ self.condition_type = condition_type
342
+ self.cond_num = condition_num_latent
343
+ self.cond_mask_pad = nn.ConstantPad1d((self.cond_num, 0), True)
344
+ if condition_type == "perceiver":
345
+ self.conditioning_encoder = ConditioningEncoder(100, model_dim, num_attn_heads=heads)
346
+ self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=model_dim, num_latents=self.cond_num)
347
+ elif condition_type == "conformer_perceiver" or condition_type == "conformer_encoder":
348
+ self.conditioning_encoder = ConformerEncoder(input_size=100,
349
+ output_size=condition_module['output_size'],
350
+ linear_units=condition_module['linear_units'],
351
+ attention_heads=condition_module['attention_heads'],
352
+ num_blocks=condition_module['num_blocks'],
353
+ input_layer=condition_module['input_layer'])
354
+ if condition_type == "conformer_perceiver":
355
+ self.perceiver_encoder = PerceiverResampler(model_dim, dim_context=condition_module['output_size'],
356
+ ff_mult=condition_module['perceiver_mult'],
357
+ heads=condition_module['attention_heads'],
358
+ num_latents=self.cond_num)
359
+ else:
360
+ self.conditioning_encoder = ConditioningEncoder(100, model_dim, num_attn_heads=heads, mean=True)
361
+
362
+ self.text_embedding = nn.Embedding(self.number_text_tokens * types + 1, model_dim)
363
+ if use_mel_codes_as_input:
364
+ self.mel_embedding = nn.Embedding(self.number_mel_codes, model_dim)
365
+ else:
366
+ self.mel_embedding = MelEncoder(model_dim, resblocks_per_reduction=1)
367
+ self.gpt, self.mel_pos_embedding, self.text_pos_embedding, self.mel_layer_pos_embedding, self.text_layer_pos_embedding = \
368
+ build_hf_gpt_transformer(layers, model_dim, heads, self.max_mel_tokens + 2 + self.max_conditioning_inputs,
369
+ self.max_text_tokens + 2, checkpointing, activation_function)
370
+ if train_solo_embeddings:
371
+ self.mel_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
372
+ self.text_solo_embedding = nn.Parameter(torch.randn(1, 1, model_dim) * .02, requires_grad=True)
373
+ else:
374
+ self.mel_solo_embedding = 0
375
+ self.text_solo_embedding = 0
376
+
377
+ self.final_norm = nn.LayerNorm(model_dim)
378
+ self.text_head = nn.Linear(model_dim, self.number_text_tokens * types + 1)
379
+ self.mel_head = nn.Linear(model_dim, self.number_mel_codes)
380
+
381
+ # Initialize the embeddings per the GPT-2 scheme
382
+ embeddings = [self.text_embedding]
383
+ if use_mel_codes_as_input:
384
+ embeddings.append(self.mel_embedding)
385
+ for module in embeddings:
386
+ module.weight.data.normal_(mean=0.0, std=.02)
387
+
388
+ def post_init_gpt2_config(self, use_deepspeed=False, kv_cache=False, half=False):
389
+ seq_length = self.max_mel_tokens + self.max_text_tokens + 2
390
+ gpt_config = GPT2Config(
391
+ vocab_size=self.number_mel_codes,
392
+ n_positions=seq_length,
393
+ n_ctx=seq_length,
394
+ n_embd=self.model_dim,
395
+ n_layer=self.layers,
396
+ n_head=self.heads,
397
+ gradient_checkpointing=False,
398
+ use_cache=True,
399
+ )
400
+ self.inference_model = GPT2InferenceModel(
401
+ gpt_config,
402
+ self.gpt,
403
+ self.mel_pos_embedding,
404
+ self.mel_embedding,
405
+ self.final_norm,
406
+ self.mel_head,
407
+ kv_cache=kv_cache,
408
+ )
409
+ if use_deepspeed and half and torch.cuda.is_available():
410
+ import deepspeed
411
+ self.ds_engine = deepspeed.init_inference(model=self.inference_model,
412
+ mp_size=1,
413
+ replace_with_kernel_inject=False,
414
+ dtype=torch.float16)
415
+ self.inference_model = self.ds_engine.module.eval()
416
+ elif use_deepspeed and torch.cuda.is_available():
417
+ import deepspeed
418
+ self.ds_engine = deepspeed.init_inference(model=self.inference_model,
419
+ mp_size=1,
420
+ replace_with_kernel_inject=False,
421
+ dtype=torch.float32)
422
+ self.inference_model = self.ds_engine.module.eval()
423
+ else:
424
+ self.inference_model = self.inference_model.eval()
425
+
426
+ # self.inference_model = PrunedGPT2InferenceModel(gpt_config, self.gpt, self.mel_pos_embedding, self.mel_embedding, self.final_norm, self.mel_head)
427
+ self.gpt.wte = self.mel_embedding
428
+
429
+ def build_aligned_inputs_and_targets(self, input, start_token, stop_token):
430
+ inp = F.pad(input, (1, 0), value=start_token)
431
+ tar = F.pad(input, (0, 1), value=stop_token)
432
+ return inp, tar
433
+
434
+ def set_mel_padding(self, mel_input_tokens, mel_lengths):
435
+ """
436
+ Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
437
+ that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
438
+ preformatting to create a working TTS model.
439
+ """
440
+ for b in range(len(mel_lengths)):
441
+ # Due to the convolutional nature of how these tokens are generated,
442
+ # it would be best if the model predicts a token past the actual last token.
443
+ actual_end = mel_lengths[b]
444
+ if actual_end < mel_input_tokens.shape[-1]:
445
+ mel_input_tokens[b, actual_end:] = self.stop_mel_token
446
+ return mel_input_tokens
447
+
448
+ def set_text_padding(self, text_input_tokens, text_lengths):
449
+ """
450
+ Given mel tokens that are derived from a padded audio clip and the actual lengths of each batch element in
451
+ that audio clip, reformats the tokens with STOP_MEL_TOKEN in place of the zero padding. This is required
452
+ preformatting to create a working TTS model.
453
+ """
454
+ for b in range(len(text_lengths)):
455
+ # Due to the convolutional nature of how these tokens are generated,
456
+ # it would be best if the model predicts a token past the actual last token.
457
+ actual_end = text_lengths[b]
458
+ if actual_end < text_input_tokens.shape[-1]:
459
+ text_input_tokens[b, actual_end:] = self.stop_text_token
460
+ return text_input_tokens
461
+
462
+ def get_logits(self, speech_conditioning_inputs, first_inputs, first_head, second_inputs=None, second_head=None, get_attns=False, return_latent=False):
463
+ if second_inputs is not None:
464
+ emb = torch.cat([speech_conditioning_inputs, first_inputs, second_inputs], dim=1)
465
+ else:
466
+ emb = torch.cat([speech_conditioning_inputs, first_inputs], dim=1)
467
+
468
+ gpt_out = self.gpt(inputs_embeds=emb, return_dict=True, output_attentions=get_attns)
469
+ if get_attns:
470
+ return gpt_out.attentions
471
+
472
+ offset = speech_conditioning_inputs.shape[1]
473
+ enc = gpt_out.last_hidden_state[:, offset:]
474
+ enc = self.final_norm(enc)
475
+
476
+ if return_latent:
477
+ return enc[:, :first_inputs.shape[1]], enc[:, -second_inputs.shape[1]:]
478
+
479
+ first_logits = enc[:, :first_inputs.shape[1]]
480
+ first_logits = first_head(first_logits)
481
+ first_logits = first_logits.permute(0, 2, 1)
482
+ if second_inputs is not None:
483
+ second_logits = enc[:, -second_inputs.shape[1]:]
484
+ second_logits = second_head(second_logits)
485
+ second_logits = second_logits.permute(0, 2, 1)
486
+ return first_logits, second_logits
487
+ else:
488
+ return first_logits
489
+
490
+ def get_conditioning(self, speech_conditioning_input, cond_mel_lengths=None):
491
+ if self.condition_type == "perceiver":
492
+ if speech_conditioning_input.ndim == 4:
493
+ speech_conditioning_input = speech_conditioning_input.squeeze(1)
494
+ speech_conditioning_input = self.conditioning_encoder(speech_conditioning_input) # (b, d, s)
495
+ conds = self.perceiver_encoder(speech_conditioning_input.transpose(1, 2)) # (b, 32, d)
496
+ elif self.condition_type == "conformer_perceiver":
497
+ speech_conditioning_input, mask = self.conditioning_encoder(speech_conditioning_input.transpose(1, 2),
498
+ cond_mel_lengths) # (b, s, d), (b, 1, s)
499
+ if self.condition_type == "conformer_perceiver":
500
+ # conds_mask = torch.cat([torch.ones((mask.shape[0], self.cond_num), dtype=torch.bool), mask.squeeze(1)], dim=1)
501
+ conds_mask = self.cond_mask_pad(mask.squeeze(1))
502
+ conds = self.perceiver_encoder(speech_conditioning_input, conds_mask) # (b, 32, d)
503
+ elif self.condition_type == "gst":
504
+ if speech_conditioning_input.ndim == 4:
505
+ speech_conditioning_input = speech_conditioning_input.squeeze(1)
506
+ conds = self.gst_encoder(speech_conditioning_input.transpose(1, 2)) # (b, 1, d)
507
+ else:
508
+ speech_conditioning_input = (
509
+ speech_conditioning_input.unsqueeze(1)
510
+ if len(speech_conditioning_input.shape) == 3
511
+ else speech_conditioning_input
512
+ )
513
+ conds = []
514
+ for j in range(speech_conditioning_input.shape[1]):
515
+ conds.append(self.conditioning_encoder(speech_conditioning_input[:, j]))
516
+ conds = torch.stack(conds, dim=1)
517
+ conds = conds.mean(dim=1)
518
+ conds = conds.unsqueeze(1)
519
+ return conds
520
+
521
+ def forward(self, speech_conditioning_latent, text_inputs, text_lengths, mel_codes, wav_lengths,
522
+ cond_mel_lengths=None, types=None, text_first=True, raw_mels=None, return_attentions=False,
523
+ return_latent=False, clip_inputs=False):
524
+ """
525
+ Forward pass that uses both text and voice in either text conditioning mode or voice conditioning mode
526
+ (actuated by `text_first`).
527
+
528
+ speech_conditioning_input: MEL float tensor, (b,1024)
529
+ text_inputs: long tensor, (b,t)
530
+ text_lengths: long tensor, (b,)
531
+ mel_inputs: long tensor, (b,m)
532
+ wav_lengths: long tensor, (b,)
533
+ raw_mels: MEL float tensor (b,80,s)
534
+
535
+ If return_attentions is specified, only logits are returned.
536
+ If return_latent is specified, loss & logits are not computed or returned. Only the predicted latents are returned.
537
+ If clip_inputs is True, the inputs will be clipped to the smallest input size across each input modality.
538
+ """
539
+
540
+ speech_conditioning_latent = self.get_conditioning(speech_conditioning_latent, cond_mel_lengths)
541
+ # Types are expressed by expanding the text embedding space.
542
+ if types is not None:
543
+ text_inputs = text_inputs * (1 + types).unsqueeze(-1)
544
+
545
+ if clip_inputs:
546
+ # This model will receive micro-batches with a ton of padding for both the text and MELs. Ameliorate this by
547
+ # chopping the inputs by the maximum actual length.
548
+ max_text_len = text_lengths.max()
549
+ text_inputs = text_inputs[:, :max_text_len]
550
+ max_mel_len = wav_lengths.max() // self.mel_length_compression
551
+ mel_codes = mel_codes[:, :max_mel_len]
552
+ if raw_mels is not None:
553
+ raw_mels = raw_mels[:, :, :max_mel_len * 4]
554
+
555
+ # Set padding areas within MEL (currently it is coded with the MEL code for <zero>).
556
+ # mel_codes_lengths = torch.div(wav_lengths, self.mel_length_compression, rounding_mode='trunc')
557
+ mel_codes_lengths = torch.ceil(wav_lengths / self.mel_length_compression).long() + 1
558
+ mel_codes = self.set_mel_padding(mel_codes, mel_codes_lengths)
559
+ text_inputs = self.set_text_padding(text_inputs, text_lengths)
560
+ text_inputs = F.pad(text_inputs, (0, 1), value=self.stop_text_token)
561
+ mel_codes = F.pad(mel_codes, (0, 1), value=self.stop_mel_token)
562
+
563
+ conds = speech_conditioning_latent
564
+ text_inputs, text_targets = self.build_aligned_inputs_and_targets(text_inputs, self.start_text_token, self.stop_text_token)
565
+ text_emb = self.text_embedding(text_inputs) + self.text_pos_embedding(text_inputs)
566
+ mel_codes, mel_targets = self.build_aligned_inputs_and_targets(mel_codes, self.start_mel_token, self.stop_mel_token)
567
+ if raw_mels is not None:
568
+ mel_inp = F.pad(raw_mels, (0, 8))
569
+ else:
570
+ mel_inp = mel_codes
571
+ mel_emb = self.mel_embedding(mel_inp)
572
+ mel_emb = mel_emb + self.mel_pos_embedding(mel_codes)
573
+
574
+ if text_first:
575
+ # print(f"conds: {conds.shape}, text_emb: {text_emb.shape}, mel_emb: {mel_emb.shape}")
576
+ text_logits, mel_logits = self.get_logits(conds, text_emb, self.text_head, mel_emb, self.mel_head, get_attns=return_attentions, return_latent=return_latent)
577
+ if return_latent:
578
+ return mel_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
579
+ else:
580
+ mel_logits, text_logits = self.get_logits(conds, mel_emb, self.mel_head, text_emb, self.text_head, get_attns=return_attentions, return_latent=return_latent)
581
+ if return_latent:
582
+ return text_logits[:, :-2] # Despite the name, these are not logits. Strip off the two tokens added by this forward pass.
583
+
584
+ if return_attentions:
585
+ return mel_logits
586
+
587
+ loss_text = F.cross_entropy(text_logits, text_targets.long())
588
+ loss_mel = F.cross_entropy(mel_logits, mel_targets.long())
589
+ return loss_text.mean(), loss_mel.mean(), mel_logits
590
+
591
+ def prepare_gpt_inputs(
592
+ self,
593
+ conditional_latents: torch.Tensor,
594
+ text_inputs: torch.Tensor,
595
+ ):
596
+
597
+ """
598
+ Prepare the inputs for the GPT2InferenceModel to generate.
599
+ Args:
600
+ conds_latent: (b, 32, dim) audio conditioning embedding by `get_conditioning()`
601
+ text_inputs: (b, L)
602
+ Returns:
603
+ input_ids: (b, s+1) the input ids for the GPT2InferenceModel.generate()
604
+ inputs_embeds: (b, s+1, dim) the input embeddings for the GPT2InferenceModel.forward()
605
+ attention_mask: (b, s+1) the attention mask for the GPT2InferenceModel.generate()
606
+ """
607
+ b, L = text_inputs.shape[:2]
608
+ device = text_inputs.device
609
+ single_cond = conditional_latents.ndim == 3 and conditional_latents.shape[0] == 1
610
+ if not single_cond:
611
+ assert conditional_latents.shape[0] == b, f"batch size mismatch: {conditional_latents.shape[0]} vs {b}"
612
+ batched_mel_emb = []
613
+ attention_masks = []
614
+ target_len = conditional_latents.shape[1] + L + 2
615
+ for i in range(b):
616
+ valid_mask = (text_inputs[i] != self.stop_text_token) & (text_inputs[i] != self.start_text_token)
617
+ text_input = text_inputs[i][valid_mask]
618
+ text_input = F.pad(text_input, (1, 0), value=self.start_text_token)
619
+ text_input = F.pad(text_input, (0, 1), value=self.stop_text_token)
620
+ text_input_pos = torch.arange(0, text_input.size(-1), device=device)
621
+ text_emb = self.text_embedding(text_input) + self.text_pos_embedding.emb(text_input_pos)
622
+ # concatenate [conditional latents][text embeddings]
623
+ conds_text_emb = [
624
+ conditional_latents.squeeze(0) if single_cond else conditional_latents[i],
625
+ text_emb,
626
+ ]
627
+ # +1 for the start_mel_token
628
+ attention_mask = torch.ones(target_len+1, dtype=torch.long, device=device)
629
+ # check this text input is padded
630
+ padding: int = L + 2 - text_input.size(-1)
631
+ # pad left of [cond][text] -> [pad][cond][text]
632
+ if padding > 0:
633
+ pad = torch.zeros((padding, conditional_latents.size(-1)), dtype=text_emb.dtype, device=device) # [p, dim]
634
+ conds_text_emb.insert(0, pad)
635
+ attention_mask[:padding] = 0
636
+ mel_emb = torch.cat(conds_text_emb) #[s, dim]
637
+ assert mel_emb.shape[0] == target_len, f"mel_emb.shape: {mel_emb.shape}, target_len: {target_len}"
638
+ batched_mel_emb.append(mel_emb)
639
+ attention_masks.append(attention_mask)
640
+ # [b, s, dim]
641
+ batched_mel_emb = torch.stack(batched_mel_emb, dim=0)
642
+ # [b, s+1]
643
+ attention_mask = torch.stack(attention_masks, dim=0)
644
+ # [b, s+1]
645
+ fake_inputs = torch.ones(
646
+ (
647
+ batched_mel_emb.shape[0],
648
+ batched_mel_emb.shape[1] + 1, # +1 for the start_mel_token
649
+ ),
650
+ dtype=torch.long,
651
+ device=device,
652
+ )
653
+ fake_inputs[:, -1] = self.start_mel_token
654
+ return fake_inputs, batched_mel_emb, attention_mask
655
+ def inference_speech(self, speech_conditioning_mel, text_inputs, cond_mel_lengths=None, input_tokens=None, num_return_sequences=1,
656
+ max_generate_length=None, typical_sampling=False, typical_mass=.9, **hf_generate_kwargs):
657
+ """
658
+ Args:
659
+ speech_conditioning_mel: (b, n_mels, frames) or (n_mels, frames)
660
+ text_inputs: (b, L)
661
+ cond_mel_lengths: lengths of the conditioning mel spectrograms in shape (b,) or (1,)
662
+ input_tokens: additional tokens for generation in shape (b, s) or (s,)
663
+ max_generate_length: limit the number of generated tokens
664
+ hf_generate_kwargs: kwargs for `GPT2InferenceModel.generate(**hf_generate_kwargs)`
665
+ """
666
+ if speech_conditioning_mel.ndim == 2:
667
+ speech_conditioning_mel = speech_conditioning_mel.unsqueeze(0)
668
+ if cond_mel_lengths is None:
669
+ cond_mel_lengths = torch.tensor([speech_conditioning_mel.shape[-1]], device=speech_conditioning_mel.device)
670
+ conds_latent = self.get_conditioning(speech_conditioning_mel, cond_mel_lengths)
671
+ input_ids, inputs_embeds, attention_mask = self.prepare_gpt_inputs(conds_latent, text_inputs)
672
+ self.inference_model.store_mel_emb(inputs_embeds)
673
+ if input_tokens is None:
674
+ inputs = input_ids
675
+ else:
676
+ if input_tokens.ndim == 1:
677
+ input_tokens = input_tokens.unsqueeze(0)
678
+ assert num_return_sequences % input_tokens.shape[0] == 0, \
679
+ "The num_return_sequences must be divisible by the batch number of input_tokens"
680
+ assert num_return_sequences % text_inputs.shape[0] == 0, \
681
+ "The num_return_sequences must be divisible by the batch number of text_inputs"
682
+ b = num_return_sequences // input_ids.shape[0]
683
+ if b > 1:
684
+ input_ids = input_ids.repeat(b, 1)
685
+ attention_mask = attention_mask.repeat(b, 1)
686
+ input_tokens = input_tokens.repeat(num_return_sequences // input_tokens.shape[0], 1)
687
+ inputs = torch.cat([input_ids, input_tokens], dim=1)
688
+ attention_mask = F.pad(attention_mask, (0, input_tokens.shape[1]), value=1)
689
+ trunc_index = inputs.shape[1]
690
+ logits_processor = LogitsProcessorList()
691
+ if typical_sampling:
692
+ # employ custom typical sampling
693
+ if not (typical_mass > 0.0 and typical_mass < 1.0):
694
+ raise ValueError(f"`typical_mass` has to be a float > 0 and < 1, but is {typical_mass}")
695
+ min_tokens_to_keep = 2 if hf_generate_kwargs.get("num_beams", 1) > 1 else 1
696
+ logits_processor.append(TypicalLogitsWarper(mass=typical_mass, min_tokens_to_keep=min_tokens_to_keep))
697
+ max_length = (trunc_index + self.max_mel_tokens - 1) if max_generate_length is None else trunc_index + max_generate_length
698
+ output = self.inference_model.generate(inputs,
699
+ bos_token_id=self.start_mel_token, pad_token_id=self.stop_mel_token,
700
+ eos_token_id=self.stop_mel_token, attention_mask=attention_mask,
701
+ max_length=max_length, logits_processor=logits_processor,
702
+ num_return_sequences=num_return_sequences,
703
+ **hf_generate_kwargs)
704
+ if isinstance(output, torch.Tensor):
705
+ return output[:, trunc_index:]
706
+ # GenerateOutput
707
+ output.sequences = output.sequences[:, trunc_index:]
708
+ return output