Commit: Update modeling_vits.py (+6 −2 lines changed)
File changed: modeling_vits.py
|
@@ -26,5 +26,9 @@ class ModVitsModel(VitsModel):
|
|
| 26 |
|
| 27 |
@torch.inference_mode()
|
| 28 |
def forward(self, input_ids = None, attention_mask = None, speaker_id = None, output_attentions = None, output_hidden_states = None, return_dict = None, labels = None, **kwargs):
|
| 29 |
-
speaker_id = self.config.speaker_names.index(speaker_id) * len(self.config.emotion_names) + self.config.emotion_names.index(kwargs['emotion_id'])
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
@torch.inference_mode()
def forward(self, input_ids = None, attention_mask = None, speaker_id = None, output_attentions = None, output_hidden_states = None, return_dict = None, labels = None, **kwargs):
    """Run VITS text-to-speech inference and zero the padded tail of each waveform.

    Delegates to ``VitsModel.forward`` and then masks every waveform position at or
    beyond that sample's true length (``audio.sequence_lengths``) to 0, so batched
    outputs carry no garbage in their padding region.

    Args:
        input_ids / attention_mask / speaker_id / output_attentions /
        output_hidden_states / return_dict / labels: passed through unchanged
            to the parent ``forward``.
        **kwargs: extra keywords from callers; ``emotion_id`` (if present) is
            consumed here and NOT forwarded (see NOTE below).

    Returns:
        The parent model's output object with ``waveform`` masked in place.
        NOTE(review): assumes a dict-style output exposing ``waveform`` of shape
        (batch, time) and ``sequence_lengths`` — i.e. callers use
        ``return_dict=True``/default; with ``return_dict=False`` the attribute
        access would fail. Confirm against callers.
    """
    # NOTE(review): an earlier revision folded speaker and emotion into a single
    # embedding index:
    #   speaker_id = self.config.speaker_names.index(speaker_id)
    #                * len(self.config.emotion_names)
    #                + self.config.emotion_names.index(kwargs['emotion_id'])
    # Callers may therefore still pass 'emotion_id' via **kwargs. The parent
    # forward() signature accepts no arbitrary keywords, so forwarding it would
    # raise TypeError — consume it here instead (currently unused).
    kwargs.pop('emotion_id', None)

    audio = super().forward(
        input_ids,
        attention_mask,
        speaker_id,
        output_attentions,
        output_hidden_states,
        return_dict,
        labels,
        **kwargs,
    )

    # Build a (batch, time) validity mask: position t is valid iff t < length.
    batch, max_len = audio.waveform.shape
    valid = (
        torch.arange(max_len, device=audio.waveform.device).expand(batch, max_len)
        < audio.sequence_lengths.unsqueeze(1)
    )
    # In-place zero of the padded region; safe under inference_mode because
    # the tensor was created inside this inference_mode context.
    audio.waveform.masked_fill_(~valid, 0)
    return audio
|