Spaces:
Sleeping
Sleeping
Update model.py
Browse files
model.py
CHANGED
|
@@ -149,7 +149,6 @@ class Postnet(nn.Module):
|
|
| 149 |
class Encoder(nn.Module):
|
| 150 |
def __init__(self, hparams):
|
| 151 |
super(Encoder, self).__init__()
|
| 152 |
-
|
| 153 |
convolutions = []
|
| 154 |
for _ in range(hparams.encoder_n_convolutions):
|
| 155 |
conv_layer = nn.Sequential(
|
|
@@ -162,28 +161,19 @@ class Encoder(nn.Module):
|
|
| 162 |
convolutions.append(conv_layer)
|
| 163 |
self.convolutions = nn.ModuleList(convolutions)
|
| 164 |
|
| 165 |
-
|
| 166 |
-
lstm_input_dim = hparams.encoder_embedding_dim
|
| 167 |
-
self.lstm = nn.LSTM(lstm_input_dim,
|
| 168 |
int(hparams.encoder_embedding_dim / 2), 1,
|
| 169 |
batch_first=True, bidirectional=True)
|
| 170 |
|
| 171 |
def forward(self, x, input_lengths, speaker_embedding):
|
|
|
|
|
|
|
|
|
|
| 172 |
for conv in self.convolutions:
|
| 173 |
x = F.dropout(F.relu(conv(x)), 0.5, self.training)
|
| 174 |
|
| 175 |
x = x.transpose(1, 2)
|
| 176 |
|
| 177 |
-
# No changes in the concatenation part
|
| 178 |
-
batch_size = x.size()[0]
|
| 179 |
-
num_chars = x.size()[1]
|
| 180 |
-
idx = 0 if speaker_embedding.dim() == 1 else 1
|
| 181 |
-
speaker_embedding_size = speaker_embedding.size()[idx]
|
| 182 |
-
e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
|
| 183 |
-
e = e.reshape(batch_size, speaker_embedding_size, num_chars)
|
| 184 |
-
e = e.transpose(1, 2)
|
| 185 |
-
x = torch.cat((x, e), 2)
|
| 186 |
-
|
| 187 |
input_lengths = input_lengths.cpu().numpy()
|
| 188 |
x = nn.utils.rnn.pack_padded_sequence(
|
| 189 |
x, input_lengths, batch_first=True)
|
|
@@ -193,31 +183,21 @@ class Encoder(nn.Module):
|
|
| 193 |
|
| 194 |
outputs, _ = nn.utils.rnn.pad_packed_sequence(
|
| 195 |
outputs, batch_first=True)
|
|
|
|
| 196 |
return outputs
|
| 197 |
|
| 198 |
-
def inference(self, x
|
| 199 |
for conv in self.convolutions:
|
| 200 |
x = F.dropout(F.relu(conv(x)), 0.5, self.training)
|
| 201 |
|
| 202 |
x = x.transpose(1, 2)
|
| 203 |
|
| 204 |
-
if speaker_embedding is not None:
|
| 205 |
-
batch_size = x.size()[0]
|
| 206 |
-
num_chars = x.size()[1]
|
| 207 |
-
idx = 0 if speaker_embedding.dim() == 1 else 1
|
| 208 |
-
speaker_embedding_size = speaker_embedding.size()[idx]
|
| 209 |
-
e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
|
| 210 |
-
e = e.reshape(batch_size, speaker_embedding_size, num_chars)
|
| 211 |
-
e = e.transpose(1, 2)
|
| 212 |
-
x = torch.cat((x, e), 2)
|
| 213 |
-
|
| 214 |
self.lstm.flatten_parameters()
|
| 215 |
outputs, _ = self.lstm(x)
|
| 216 |
|
| 217 |
return outputs
|
| 218 |
|
| 219 |
|
| 220 |
-
|
| 221 |
class Decoder(nn.Module):
|
| 222 |
def __init__(self, hparams):
|
| 223 |
super(Decoder, self).__init__()
|
|
@@ -409,6 +389,7 @@ class Decoder(nn.Module):
|
|
| 409 |
gate_outputs: gate outputs from the decoder
|
| 410 |
alignments: sequence of attention weights from the decoder
|
| 411 |
"""
|
|
|
|
| 412 |
decoder_input = self.get_go_frame(memory).unsqueeze(0)
|
| 413 |
decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
|
| 414 |
decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
|
|
@@ -488,17 +469,15 @@ class Tacotron2(nn.Module):
|
|
| 488 |
|
| 489 |
def parse_batch(self, batch):
|
| 490 |
text_padded, input_lengths, mel_padded, gate_padded, \
|
| 491 |
-
output_lengths
|
| 492 |
text_padded = to_gpu(text_padded).long()
|
| 493 |
input_lengths = to_gpu(input_lengths).long()
|
| 494 |
max_len = torch.max(input_lengths.data).item()
|
| 495 |
mel_padded = to_gpu(mel_padded).float()
|
| 496 |
gate_padded = to_gpu(gate_padded).float()
|
| 497 |
output_lengths = to_gpu(output_lengths).long()
|
| 498 |
-
mel_speaker = to_gpu(mel_speaker).float()
|
| 499 |
|
| 500 |
return (
|
| 501 |
-
mel_speaker,
|
| 502 |
(text_padded, input_lengths, mel_padded, max_len, output_lengths),
|
| 503 |
(mel_padded, gate_padded))
|
| 504 |
|
|
@@ -520,13 +499,13 @@ class Tacotron2(nn.Module):
|
|
| 520 |
|
| 521 |
embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
|
| 522 |
|
|
|
|
| 523 |
encoder_outputs = self.encoder(embedded_inputs, text_lengths, speaker_embedding)
|
| 524 |
|
| 525 |
mel_outputs, gate_outputs, alignments = self.decoder(
|
| 526 |
encoder_outputs, mels, memory_lengths=text_lengths)
|
| 527 |
|
| 528 |
mel_outputs_postnet = self.postnet(mel_outputs)
|
| 529 |
-
|
| 530 |
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
| 531 |
|
| 532 |
return self.parse_output(
|
|
@@ -535,7 +514,9 @@ class Tacotron2(nn.Module):
|
|
| 535 |
|
| 536 |
def inference(self, inputs, speaker_embedding):
|
| 537 |
embedded_inputs = self.embedding(inputs).transpose(1, 2)
|
| 538 |
-
|
|
|
|
|
|
|
| 539 |
mel_outputs, gate_outputs, alignments = self.decoder.inference(
|
| 540 |
encoder_outputs)
|
| 541 |
|
|
@@ -545,4 +526,5 @@ class Tacotron2(nn.Module):
|
|
| 545 |
outputs = self.parse_output(
|
| 546 |
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
|
| 547 |
|
| 548 |
-
return outputs
|
|
|
|
|
|
| 149 |
class Encoder(nn.Module):
|
| 150 |
def __init__(self, hparams):
|
| 151 |
super(Encoder, self).__init__()
|
|
|
|
| 152 |
convolutions = []
|
| 153 |
for _ in range(hparams.encoder_n_convolutions):
|
| 154 |
conv_layer = nn.Sequential(
|
|
|
|
| 161 |
convolutions.append(conv_layer)
|
| 162 |
self.convolutions = nn.ModuleList(convolutions)
|
| 163 |
|
| 164 |
+
self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
|
|
|
|
|
|
|
| 165 |
int(hparams.encoder_embedding_dim / 2), 1,
|
| 166 |
batch_first=True, bidirectional=True)
|
| 167 |
|
| 168 |
def forward(self, x, input_lengths, speaker_embedding):
|
| 169 |
+
# Modify the input x to concatenate the speaker embedding
|
| 170 |
+
x = torch.cat((x, speaker_embedding.unsqueeze(1).expand(-1, x.size(1), -1)), dim=-1)
|
| 171 |
+
|
| 172 |
for conv in self.convolutions:
|
| 173 |
x = F.dropout(F.relu(conv(x)), 0.5, self.training)
|
| 174 |
|
| 175 |
x = x.transpose(1, 2)
|
| 176 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
input_lengths = input_lengths.cpu().numpy()
|
| 178 |
x = nn.utils.rnn.pack_padded_sequence(
|
| 179 |
x, input_lengths, batch_first=True)
|
|
|
|
| 183 |
|
| 184 |
outputs, _ = nn.utils.rnn.pad_packed_sequence(
|
| 185 |
outputs, batch_first=True)
|
| 186 |
+
|
| 187 |
return outputs
|
| 188 |
|
| 189 |
+
def inference(self, x):
|
| 190 |
for conv in self.convolutions:
|
| 191 |
x = F.dropout(F.relu(conv(x)), 0.5, self.training)
|
| 192 |
|
| 193 |
x = x.transpose(1, 2)
|
| 194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
self.lstm.flatten_parameters()
|
| 196 |
outputs, _ = self.lstm(x)
|
| 197 |
|
| 198 |
return outputs
|
| 199 |
|
| 200 |
|
|
|
|
| 201 |
class Decoder(nn.Module):
|
| 202 |
def __init__(self, hparams):
|
| 203 |
super(Decoder, self).__init__()
|
|
|
|
| 389 |
gate_outputs: gate outputs from the decoder
|
| 390 |
alignments: sequence of attention weights from the decoder
|
| 391 |
"""
|
| 392 |
+
|
| 393 |
decoder_input = self.get_go_frame(memory).unsqueeze(0)
|
| 394 |
decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
|
| 395 |
decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
|
|
|
|
| 469 |
|
| 470 |
def parse_batch(self, batch):
|
| 471 |
text_padded, input_lengths, mel_padded, gate_padded, \
|
| 472 |
+
output_lengths = batch
|
| 473 |
text_padded = to_gpu(text_padded).long()
|
| 474 |
input_lengths = to_gpu(input_lengths).long()
|
| 475 |
max_len = torch.max(input_lengths.data).item()
|
| 476 |
mel_padded = to_gpu(mel_padded).float()
|
| 477 |
gate_padded = to_gpu(gate_padded).float()
|
| 478 |
output_lengths = to_gpu(output_lengths).long()
|
|
|
|
| 479 |
|
| 480 |
return (
|
|
|
|
| 481 |
(text_padded, input_lengths, mel_padded, max_len, output_lengths),
|
| 482 |
(mel_padded, gate_padded))
|
| 483 |
|
|
|
|
| 499 |
|
| 500 |
embedded_inputs = self.embedding(text_inputs).transpose(1, 2)
|
| 501 |
|
| 502 |
+
# Pass the speaker embedding to the Encoder
|
| 503 |
encoder_outputs = self.encoder(embedded_inputs, text_lengths, speaker_embedding)
|
| 504 |
|
| 505 |
mel_outputs, gate_outputs, alignments = self.decoder(
|
| 506 |
encoder_outputs, mels, memory_lengths=text_lengths)
|
| 507 |
|
| 508 |
mel_outputs_postnet = self.postnet(mel_outputs)
|
|
|
|
| 509 |
mel_outputs_postnet = mel_outputs + mel_outputs_postnet
|
| 510 |
|
| 511 |
return self.parse_output(
|
|
|
|
| 514 |
|
| 515 |
def inference(self, inputs, speaker_embedding):
|
| 516 |
embedded_inputs = self.embedding(inputs).transpose(1, 2)
|
| 517 |
+
# Pass the speaker embedding to the Encoder
|
| 518 |
+
encoder_outputs = self.encoder.inference(embedded_inputs, speaker_embedding)
|
| 519 |
+
|
| 520 |
mel_outputs, gate_outputs, alignments = self.decoder.inference(
|
| 521 |
encoder_outputs)
|
| 522 |
|
|
|
|
| 526 |
outputs = self.parse_output(
|
| 527 |
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
|
| 528 |
|
| 529 |
+
return outputs
|
| 530 |
+
|