Commit: "Update modeling_vits.py" — one file changed, modeling_vits.py (+5 lines, −5 lines). The diff below shows the pre-change and post-change versions of the modified region.
|
@@ -14,16 +14,16 @@ class ModVitsModel(VitsModel):
|
|
| 14 |
block.convs1 = nn.ModuleList([weight_norm(layer) for layer in block.convs1])
|
| 15 |
block.convs2 = nn.ModuleList([weight_norm(layer) for layer in block.convs2])
|
| 16 |
return super().init_weights()
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
|
|
|
| 20 |
emotions = state_dict['embed_emotion.weight'][:len(model.config.emotion_names)+len(model.config.undefined_emotion_index)][[i for i in range(len(model.config.emotion_names)+len(model.config.undefined_emotion_index)) if i not in model.config.undefined_emotion_index]]
|
| 21 |
speakers = state_dict['embed_speaker.weight'][:len(model.config.speaker_names)]
|
| 22 |
state_dict['embed_speaker.weight'] = torch.stack([s + e for s in speakers for e in emotions]).reshape(-1, model.config.speaker_embedding_size)
|
| 23 |
del state_dict['embed_emotion.weight']
|
| 24 |
-
return super()._load_pretrained_model(model, state_dict, checkpoint_files,
|
| 25 |
|
| 26 |
def forward(self, input_ids = None, attention_mask = None, speaker_id = 'Marathi-Male', emotion_id = 'ALEXA',output_attentions = None, output_hidden_states = None, return_dict = None, labels = None):
    """Run the parent VITS forward pass with a (speaker, emotion) pair mapped to one flat embedding index.

    speaker_id / emotion_id are *names* looked up in ``config.speaker_names`` /
    ``config.emotion_names``; unknown names raise ``ValueError`` from ``list.index``.
    The flat index assumes the speaker embedding table is laid out as
    speaker-major blocks of ``len(emotion_names)`` rows each — see the
    ``embed_speaker.weight`` construction in ``_load_pretrained_model``.
    All other arguments are forwarded unchanged to ``VitsModel.forward``.
    """
    # Flatten (speaker, emotion) into a single row index of the combined table.
    speaker_id = self.config.speaker_names.index(speaker_id) * len(self.config.emotion_names) + self.config.emotion_names.index(emotion_id)
    # NOTE(review): positional pass-through — must stay in sync with the parent
    # VitsModel.forward signature order; confirm against the installed transformers version.
    return super().forward(input_ids, attention_mask, speaker_id, output_attentions, output_hidden_states, return_dict, labels)
|
| 29 |
-
|
|
|
|
| 14 |
block.convs1 = nn.ModuleList([weight_norm(layer) for layer in block.convs1])
|
| 15 |
block.convs2 = nn.ModuleList([weight_norm(layer) for layer in block.convs2])
|
| 16 |
return super().init_weights()
|
| 17 |
+
|
| 18 |
+
@staticmethod
def _load_pretrained_model(model, state_dict, checkpoint_files, load_config):
    """Rewrite the checkpoint's speaker/emotion embeddings into one combined speaker table, then delegate loading.

    Replaces ``embed_speaker.weight`` with the elementwise sum of every
    (speaker, emotion) embedding pair — speaker-major order, matching the
    index computed in ``forward`` — and removes ``embed_emotion.weight``
    so the parent loader sees only keys the base VitsModel expects.
    """
    # Ignores the passed-in state_dict and reloads from the first checkpoint file.
    # NOTE(review): assumes checkpoint_files is non-empty and load_state_dict is
    # the transformers helper imported elsewhere in this file — confirm.
    state_dict = load_state_dict(checkpoint_files[0])
    # Take the first (num_emotions + num_undefined) emotion rows, then drop the
    # rows whose index appears in config.undefined_emotion_index.
    emotions = state_dict['embed_emotion.weight'][:len(model.config.emotion_names)+len(model.config.undefined_emotion_index)][[i for i in range(len(model.config.emotion_names)+len(model.config.undefined_emotion_index)) if i not in model.config.undefined_emotion_index]]
    speakers = state_dict['embed_speaker.weight'][:len(model.config.speaker_names)]
    # Combined table: row (s * num_emotions + e) = speakers[s] + emotions[e].
    state_dict['embed_speaker.weight'] = torch.stack([s + e for s in speakers for e in emotions]).reshape(-1, model.config.speaker_embedding_size)
    del state_dict['embed_emotion.weight']
    # NOTE(review): zero-arg super() inside a @staticmethod binds via the first
    # positional argument (model); the base transformers implementation is a
    # classmethod — verify this override actually dispatches as intended.
    return super()._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
|
| 26 |
|
| 27 |
def forward(self, input_ids = None, attention_mask = None, speaker_id = 'Marathi-Male', emotion_id = 'ALEXA',output_attentions = None, output_hidden_states = None, return_dict = None, labels = None):
    """Forward pass that resolves a named (speaker, emotion) pair to a flat embedding row.

    ``speaker_id`` and ``emotion_id`` are names looked up in
    ``config.speaker_names`` / ``config.emotion_names`` (``ValueError`` if
    absent). The flat index matches the speaker-major layout of the combined
    ``embed_speaker`` table built in ``_load_pretrained_model``; everything
    else is handed straight to ``VitsModel.forward``.
    """
    speaker_pos = self.config.speaker_names.index(speaker_id)
    emotion_pos = self.config.emotion_names.index(emotion_id)
    # Row index in the combined table: one block of emotions per speaker.
    combined_index = speaker_pos * len(self.config.emotion_names) + emotion_pos
    return super().forward(input_ids, attention_mask, combined_index, output_attentions, output_hidden_states, return_dict, labels)
|
|
|