Speaker Embeddings integration
model.py CHANGED
@@ -166,16 +166,28 @@ class Encoder(nn.Module):
         convolutions.append(conv_layer)
         self.convolutions = nn.ModuleList(convolutions)

-        self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
+        self.lstm = nn.LSTM(hparams.encoder_embedding_dim + hparams.speaker_embedding_dim,
                             int(hparams.encoder_embedding_dim / 2), 1,
                             batch_first=True, bidirectional=True)

-    def forward(self, x, input_lengths):
+    def forward(self, x, input_lengths, speaker_embedding):
         for conv in self.convolutions:
             x = F.dropout(F.relu(conv(x)), 0.5, self.training)

         x = x.transpose(1, 2)

+        # this concatenation part is largely from https://github.com/CorentinJ/Real-Time-Voice-Cloning
+        batch_size = x.size()[0]
+        num_chars = x.size()[1]
+        idx = 0 if speaker_embedding.dim() == 1 else 1
+        speaker_embedding_size = speaker_embedding.size()[idx]
+        e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
+        # Reshape & transpose
+        e = e.reshape(batch_size, speaker_embedding_size, num_chars)
+        e = e.transpose(1, 2)
+        # Concatenate the tiled speaker embedding with the encoder output
+        x = torch.cat((x, e), 2)
+
         # pytorch tensor are not reversible, hence the conversion
         input_lengths = input_lengths.cpu().numpy()
         x = nn.utils.rnn.pack_padded_sequence(
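The widened LSTM input assumes hparams also gains a speaker_embedding_dim field, which is not part of this diff. As a sanity check on the tiling arithmetic, here is a minimal standalone sketch with illustrative sizes (batch 2, 5 input characters, encoder_embedding_dim 512, speaker_embedding_dim 256):

    import torch

    x = torch.randn(2, 5, 512)  # encoder conv output: (batch, num_chars, channels)
    spk = torch.randn(2, 256)   # one speaker embedding per utterance

    # Repeat each embedding value num_chars times, fold back into
    # (batch, embed_dim, num_chars), then swap to (batch, num_chars, embed_dim).
    e = spk.repeat_interleave(5, dim=1)       # (2, 1280)
    e = e.reshape(2, 256, 5).transpose(1, 2)  # (2, 5, 256)

    x = torch.cat((x, e), 2)                  # (2, 5, 768)
    print(x.shape)  # torch.Size([2, 5, 768]), matching the widened LSTM input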
@@ -186,15 +198,24 @@ class Encoder(nn.Module):

         outputs, _ = nn.utils.rnn.pad_packed_sequence(
             outputs, batch_first=True)
-
         return outputs

-    def inference(self, x):
+    def inference(self, x, speaker_embedding=None):
         for conv in self.convolutions:
             x = F.dropout(F.relu(conv(x)), 0.5, self.training)

         x = x.transpose(1, 2)

+        if speaker_embedding is not None:
+            batch_size = x.size()[0]
+            num_chars = x.size()[1]
+            idx = 0 if speaker_embedding.dim() == 1 else 1
+            speaker_embedding_size = speaker_embedding.size()[idx]
+            e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
+            e = e.reshape(batch_size, speaker_embedding_size, num_chars)
+            e = e.transpose(1, 2)
+            x = torch.cat((x, e), 2)
+
         self.lstm.flatten_parameters()
         outputs, _ = self.lstm(x)

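The idx switch makes the same tiling work for a single unbatched embedding at inference time. A sketch of that path, again with assumed sizes:

    import torch

    x = torch.randn(1, 5, 512)  # one encoded utterance
    spk = torch.randn(256)      # 1-D embedding, so dim() == 1 and idx = 0

    e = spk.repeat_interleave(5, dim=0)       # (1280,)
    e = e.reshape(1, 256, 5).transpose(1, 2)  # (1, 5, 256)
    x = torch.cat((x, e), 2)                  # (1, 5, 768)

    # Note: if speaker_embedding is None the concatenation is skipped and x
    # stays 512-wide, which no longer matches the LSTM's 768-wide input.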
@@ -392,7 +413,6 @@ class Decoder(nn.Module):
             gate_outputs: gate outputs from the decoder
             alignments: sequence of attention weights from the decoder
         """
-
         decoder_input = self.get_go_frame(memory).unsqueeze(0)
         decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
         decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
@@ -472,15 +492,17 @@ class Tacotron2(nn.Module):

     def parse_batch(self, batch):
         text_padded, input_lengths, mel_padded, gate_padded, \
-            output_lengths = batch
+            output_lengths, mel_speaker = batch
         text_padded = to_gpu(text_padded).long()
         input_lengths = to_gpu(input_lengths).long()
         max_len = torch.max(input_lengths.data).item()
         mel_padded = to_gpu(mel_padded).float()
         gate_padded = to_gpu(gate_padded).float()
         output_lengths = to_gpu(output_lengths).long()
+        mel_speaker = to_gpu(mel_speaker).float()

         return (
+            mel_speaker,
             (text_padded, input_lengths, mel_padded, max_len, output_lengths),
             (mel_padded, gate_padded))

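parse_batch now returns a 3-tuple with the reference mel first, so a training step would unpack it along these lines. This is a hedged sketch: speaker_encoder stands in for whatever external model (e.g. the Real-Time-Voice-Cloning encoder) maps mel_speaker to an embedding, and model, criterion, and batch are assumed to exist in the training script:

    # Hypothetical training step built around the new parse_batch contract.
    mel_speaker, x, y = model.parse_batch(batch)
    speaker_embedding = speaker_encoder(mel_speaker)  # (batch, speaker_embedding_dim)
    y_pred = model(x, speaker_embedding)              # Tacotron2.forward
    loss = criterion(y_pred, y)
    loss.backward()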
@@ -496,27 +518,28 @@ class Tacotron2(nn.Module):

         return outputs

-    def forward(self, inputs):
+    def forward(self, inputs, speaker_embedding):
         text_inputs, text_lengths, mels, max_len, output_lengths = inputs
         text_lengths, output_lengths = text_lengths.data, output_lengths.data

         embedded_inputs = self.embedding(text_inputs).transpose(1, 2)

-        encoder_outputs = self.encoder(embedded_inputs, text_lengths)
+        encoder_outputs = self.encoder(embedded_inputs, text_lengths, speaker_embedding)

         mel_outputs, gate_outputs, alignments = self.decoder(
             encoder_outputs, mels, memory_lengths=text_lengths)

         mel_outputs_postnet = self.postnet(mel_outputs)
+
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet

         return self.parse_output(
             [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
             output_lengths)

-    def inference(self, inputs):
+    def inference(self, inputs, speaker_embedding):
         embedded_inputs = self.embedding(inputs).transpose(1, 2)
-        encoder_outputs = self.encoder.inference(embedded_inputs)
+        encoder_outputs = self.encoder.inference(embedded_inputs, speaker_embedding)
         mel_outputs, gate_outputs, alignments = self.decoder.inference(
             encoder_outputs)

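With both signatures changed, end-to-end synthesis would look roughly as follows. sequence and speaker_embedding are assumptions about the surrounding app code: an encoded-text LongTensor and the output of the external speaker encoder, respectively:

    # Hedged inference sketch; names outside model.py are placeholders.
    model.eval()
    with torch.no_grad():
        mel_outputs, mel_outputs_postnet, gate_outputs, alignments = \
            model.inference(sequence, speaker_embedding)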
@@ -526,4 +549,4 @@ class Tacotron2(nn.Module):
         outputs = self.parse_output(
             [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])

-        return outputs
+        return outputs