lord-reso committed (verified)
Commit ee51773 · 1 Parent(s): a4ba75c

Update model.py

Files changed (1):
  1. model.py +14 -32
model.py CHANGED
@@ -149,7 +149,6 @@ class Postnet(nn.Module):
 class Encoder(nn.Module):
     def __init__(self, hparams):
         super(Encoder, self).__init__()
-
         convolutions = []
         for _ in range(hparams.encoder_n_convolutions):
             conv_layer = nn.Sequential(
@@ -162,28 +161,19 @@ class Encoder(nn.Module):
             convolutions.append(conv_layer)
         self.convolutions = nn.ModuleList(convolutions)

-        # Modify the input dimensionality for LSTM
-        lstm_input_dim = hparams.encoder_embedding_dim
-        self.lstm = nn.LSTM(lstm_input_dim,
+        self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
                             int(hparams.encoder_embedding_dim / 2), 1,
                             batch_first=True, bidirectional=True)

     def forward(self, x, input_lengths, speaker_embedding):
+        # Modify the input x to concatenate the speaker embedding
+        x = torch.cat((x, speaker_embedding.unsqueeze(1).expand(-1, x.size(1), -1)), dim=-1)
+
         for conv in self.convolutions:
             x = F.dropout(F.relu(conv(x)), 0.5, self.training)

         x = x.transpose(1, 2)

-        # No changes in the concatenation part
-        batch_size = x.size()[0]
-        num_chars = x.size()[1]
-        idx = 0 if speaker_embedding.dim() == 1 else 1
-        speaker_embedding_size = speaker_embedding.size()[idx]
-        e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
-        e = e.reshape(batch_size, speaker_embedding_size, num_chars)
-        e = e.transpose(1, 2)
-        x = torch.cat((x, e), 2)
-
         input_lengths = input_lengths.cpu().numpy()
         x = nn.utils.rnn.pack_padded_sequence(
             x, input_lengths, batch_first=True)
@@ -193,31 +183,21 @@ class Encoder(nn.Module):

         outputs, _ = nn.utils.rnn.pad_packed_sequence(
             outputs, batch_first=True)
+
         return outputs

-    def inference(self, x, speaker_embedding=None):
+    def inference(self, x):
         for conv in self.convolutions:
             x = F.dropout(F.relu(conv(x)), 0.5, self.training)

         x = x.transpose(1, 2)

-        if speaker_embedding is not None:
-            batch_size = x.size()[0]
-            num_chars = x.size()[1]
-            idx = 0 if speaker_embedding.dim() == 1 else 1
-            speaker_embedding_size = speaker_embedding.size()[idx]
-            e = speaker_embedding.repeat_interleave(num_chars, dim=idx)
-            e = e.reshape(batch_size, speaker_embedding_size, num_chars)
-            e = e.transpose(1, 2)
-            x = torch.cat((x, e), 2)
-
         self.lstm.flatten_parameters()
         outputs, _ = self.lstm(x)

         return outputs


-
 class Decoder(nn.Module):
     def __init__(self, hparams):
         super(Decoder, self).__init__()
@@ -409,6 +389,7 @@ class Decoder(nn.Module):
         gate_outputs: gate outputs from the decoder
         alignments: sequence of attention weights from the decoder
         """
+
         decoder_input = self.get_go_frame(memory).unsqueeze(0)
         decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
         decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
@@ -488,17 +469,15 @@ class Tacotron2(nn.Module):

     def parse_batch(self, batch):
         text_padded, input_lengths, mel_padded, gate_padded, \
-            output_lengths, mel_speaker = batch
+            output_lengths = batch
         text_padded = to_gpu(text_padded).long()
         input_lengths = to_gpu(input_lengths).long()
         max_len = torch.max(input_lengths.data).item()
         mel_padded = to_gpu(mel_padded).float()
         gate_padded = to_gpu(gate_padded).float()
         output_lengths = to_gpu(output_lengths).long()
-        mel_speaker = to_gpu(mel_speaker).float()

         return (
-            mel_speaker,
             (text_padded, input_lengths, mel_padded, max_len, output_lengths),
             (mel_padded, gate_padded))

@@ -520,13 +499,13 @@ class Tacotron2(nn.Module):

         embedded_inputs = self.embedding(text_inputs).transpose(1, 2)

+        # Pass the speaker embedding to the Encoder
         encoder_outputs = self.encoder(embedded_inputs, text_lengths, speaker_embedding)

         mel_outputs, gate_outputs, alignments = self.decoder(
             encoder_outputs, mels, memory_lengths=text_lengths)

         mel_outputs_postnet = self.postnet(mel_outputs)
-
         mel_outputs_postnet = mel_outputs + mel_outputs_postnet

         return self.parse_output(
@@ -535,7 +514,9 @@ class Tacotron2(nn.Module):

     def inference(self, inputs, speaker_embedding):
         embedded_inputs = self.embedding(inputs).transpose(1, 2)
-        encoder_outputs = self.encoder.inference(embedded_inputs,speaker_embedding)
+        # Pass the speaker embedding to the Encoder
+        encoder_outputs = self.encoder.inference(embedded_inputs, speaker_embedding)
+
         mel_outputs, gate_outputs, alignments = self.decoder.inference(
             encoder_outputs)

@@ -545,4 +526,5 @@ class Tacotron2(nn.Module):
         outputs = self.parse_output(
             [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])

-        return outputs
+        return outputs
+
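For reference, a minimal standalone sketch of the broadcast-and-concatenate pattern the new forward introduces: a per-utterance speaker embedding is repeated across the sequence axis and appended to the feature axis. All sizes below are hypothetical, and the sketch assumes x is laid out as (batch, time, features); in the commit the concatenation runs before the convolutions, where x is still (batch, features, time), so which axis gets expanded is worth double-checking.

import torch

# Hypothetical sizes for illustration only.
batch, time, feat, spk = 2, 5, 512, 256

x = torch.randn(batch, time, feat)           # encoder activations, (B, T, C)
speaker_embedding = torch.randn(batch, spk)  # one embedding per utterance, (B, E)

# Repeat the embedding across the time axis, then append it to the features.
e = speaker_embedding.unsqueeze(1).expand(-1, x.size(1), -1)  # (B, T, E)
y = torch.cat((x, e), dim=-1)                                 # (B, T, C + E)

assert y.shape == (batch, time, feat + spk)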
 
 
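With mel_speaker dropped from parse_batch, the collate function is now expected to yield a plain five-tuple. A sketch of the new batch contract, with names taken from the diff and all tensor sizes invented for illustration:

import torch

# Names match parse_batch above; sizes are made up.
text_padded = torch.zeros(4, 60, dtype=torch.long)  # (B, max_text_len)
input_lengths = torch.tensor([60, 55, 50, 48])
mel_padded = torch.zeros(4, 80, 700)                # (B, n_mel_channels, max_mel_len)
gate_padded = torch.zeros(4, 700)
output_lengths = torch.tensor([700, 680, 650, 600])

batch = (text_padded, input_lengths, mel_padded, gate_padded, output_lengths)
# The speaker embedding no longer travels in the batch tuple; it appears to be
# passed to Tacotron2.forward / Tacotron2.inference as a separate argument.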
 
 
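One loose end: the committed Encoder.inference is declared as def inference(self, x), while Tacotron2.inference still calls self.encoder.inference(embedded_inputs, speaker_embedding) with two arguments. A hypothetical reconciliation, not part of the commit, that matches the call site by mirroring the concatenation in forward (torch and F are already imported in model.py):

    def inference(self, x, speaker_embedding):
        # Mirror forward: append the speaker embedding before the convolutions
        # (hypothetical signature; the committed version takes x only).
        x = torch.cat((x, speaker_embedding.unsqueeze(1).expand(-1, x.size(1), -1)), dim=-1)

        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        return outputs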