Update modeling_vits.py
Browse files- modeling_vits.py +4 -2
modeling_vits.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from transformers import VitsModel , VitsConfig
|
| 2 |
from torch import nn
|
| 3 |
from torch.nn.utils.parametrizations import weight_norm
|
| 4 |
-
from
|
| 5 |
import torch
|
| 6 |
|
| 7 |
class ModVitsModel(VitsModel):
|
|
@@ -17,7 +17,7 @@ class ModVitsModel(VitsModel):
|
|
| 17 |
|
| 18 |
@staticmethod
|
| 19 |
def _load_pretrained_model(model, state_dict, checkpoint_files, load_config):
|
| 20 |
-
state_dict =
|
| 21 |
speakers = state_dict['embed_speaker.weight'][:len(model.config.speaker_names)]
|
| 22 |
emotions = state_dict['embed_emotion.weight'][:len(model.config.emotion_names)+len(model.config.undefined_emotion_index)][[i for i in range(len(model.config.emotion_names)+len(model.config.undefined_emotion_index)) if i not in model.config.undefined_emotion_index]]
|
| 23 |
state_dict['embed_speaker.weight'] = torch.stack([s + e for s in speakers for e in emotions]).reshape(-1, model.config.speaker_embedding_size)
|
|
@@ -26,8 +26,10 @@ class ModVitsModel(VitsModel):
|
|
| 26 |
|
| 27 |
@torch.inference_mode()
|
| 28 |
def forward(self, input_ids = None, attention_mask = None, speaker_id = None, output_attentions = None, output_hidden_states = None, return_dict = None, labels = None, **kwargs):
|
|
|
|
| 29 |
audio = super().forward(input_ids, attention_mask, speaker_id, output_attentions, output_hidden_states, return_dict, labels, **kwargs)
|
| 30 |
B, T = audio.waveform.shape
|
| 31 |
mask = torch.arange(T, device=audio.waveform.device).expand(B, T) < audio.sequence_lengths.unsqueeze(1)
|
| 32 |
audio.waveform.masked_fill_(~mask, 0)
|
| 33 |
return audio
|
|
|
|
|
|
| 1 |
from transformers import VitsModel , VitsConfig
|
| 2 |
from torch import nn
|
| 3 |
from torch.nn.utils.parametrizations import weight_norm
|
| 4 |
+
from safetensors.torch import load_file
|
| 5 |
import torch
|
| 6 |
|
| 7 |
class ModVitsModel(VitsModel):
|
|
|
|
| 17 |
|
| 18 |
@staticmethod
|
| 19 |
def _load_pretrained_model(model, state_dict, checkpoint_files, load_config):
|
| 20 |
+
state_dict = load_file(checkpoint_files[0])
|
| 21 |
speakers = state_dict['embed_speaker.weight'][:len(model.config.speaker_names)]
|
| 22 |
emotions = state_dict['embed_emotion.weight'][:len(model.config.emotion_names)+len(model.config.undefined_emotion_index)][[i for i in range(len(model.config.emotion_names)+len(model.config.undefined_emotion_index)) if i not in model.config.undefined_emotion_index]]
|
| 23 |
state_dict['embed_speaker.weight'] = torch.stack([s + e for s in speakers for e in emotions]).reshape(-1, model.config.speaker_embedding_size)
|
|
|
|
| 26 |
|
| 27 |
@torch.inference_mode()
|
| 28 |
def forward(self, input_ids = None, attention_mask = None, speaker_id = None, output_attentions = None, output_hidden_states = None, return_dict = None, labels = None, **kwargs):
|
| 29 |
+
speaker_id = speaker_id * len(self.config.emotion_names) + kwargs['style_id']
|
| 30 |
audio = super().forward(input_ids, attention_mask, speaker_id, output_attentions, output_hidden_states, return_dict, labels, **kwargs)
|
| 31 |
B, T = audio.waveform.shape
|
| 32 |
mask = torch.arange(T, device=audio.waveform.device).expand(B, T) < audio.sequence_lengths.unsqueeze(1)
|
| 33 |
audio.waveform.masked_fill_(~mask, 0)
|
| 34 |
return audio
|
| 35 |
+
|