shethjenil committed (verified)
Commit 558b8b1 · Parent(s): 66bf596

Update modeling_vits.py

Files changed (1):
  1. modeling_vits.py (+4 −2)
modeling_vits.py CHANGED
@@ -1,7 +1,7 @@
 from transformers import VitsModel , VitsConfig
 from torch import nn
 from torch.nn.utils.parametrizations import weight_norm
-from transformers.modeling_utils import load_state_dict
+from safetensors.torch import load_file
 import torch
 
 class ModVitsModel(VitsModel):
@@ -17,7 +17,7 @@ class ModVitsModel(VitsModel):
 
     @staticmethod
     def _load_pretrained_model(model, state_dict, checkpoint_files, load_config):
-        state_dict = load_state_dict(checkpoint_files[0])
+        state_dict = load_file(checkpoint_files[0])
         speakers = state_dict['embed_speaker.weight'][:len(model.config.speaker_names)]
         emotions = state_dict['embed_emotion.weight'][:len(model.config.emotion_names)+len(model.config.undefined_emotion_index)][[i for i in range(len(model.config.emotion_names)+len(model.config.undefined_emotion_index)) if i not in model.config.undefined_emotion_index]]
         state_dict['embed_speaker.weight'] = torch.stack([s + e for s in speakers for e in emotions]).reshape(-1, model.config.speaker_embedding_size)
@@ -26,8 +26,10 @@ class ModVitsModel(VitsModel):
 
     @torch.inference_mode()
     def forward(self, input_ids = None, attention_mask = None, speaker_id = None, output_attentions = None, output_hidden_states = None, return_dict = None, labels = None, **kwargs):
+        speaker_id = speaker_id * len(self.config.emotion_names) + kwargs['style_id']
         audio = super().forward(input_ids, attention_mask, speaker_id, output_attentions, output_hidden_states, return_dict, labels, **kwargs)
         B, T = audio.waveform.shape
         mask = torch.arange(T, device=audio.waveform.device).expand(B, T) < audio.sequence_lengths.unsqueeze(1)
         audio.waveform.masked_fill_(~mask, 0)
         return audio
+
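
With this change the checkpoint is read with safetensors.torch.load_file (which returns a plain dict of tensors) instead of transformers' load_state_dict helper, and forward() now looks up a combined speaker/style row: the rebuilt embed_speaker.weight holds every speaker + emotion sum laid out speaker-major, so the row index is speaker_id * len(emotion_names) + kwargs['style_id']. Below is a minimal sketch of that indexing, using made-up sizes rather than this repo's actual config values:

import torch

# Made-up sizes for illustration; the real values come from model.config.
num_speakers, num_emotions, embed_dim = 3, 4, 256
speakers = torch.randn(num_speakers, embed_dim)
emotions = torch.randn(num_emotions, embed_dim)

# Same construction as _load_pretrained_model: every speaker + emotion sum,
# speaker-major, flattened into one embedding table.
combined = torch.stack([s + e for s in speakers for e in emotions]).reshape(-1, embed_dim)

# Same lookup as forward(): row = speaker_id * num_emotions + style_id.
speaker_id, style_id = 2, 1
row = speaker_id * num_emotions + style_id
assert torch.allclose(combined[row], speakers[speaker_id] + emotions[style_id])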