shethjenil committed on
Commit
38bcf03
·
verified ·
1 Parent(s): 604379a

Update modeling_vits.py

Browse files
Files changed (1) hide show
  1. modeling_vits.py +5 -5
modeling_vits.py CHANGED
@@ -14,16 +14,16 @@ class ModVitsModel(VitsModel):
14
  block.convs1 = nn.ModuleList([weight_norm(layer) for layer in block.convs1])
15
  block.convs2 = nn.ModuleList([weight_norm(layer) for layer in block.convs2])
16
  return super().init_weights()
17
@classmethod
def _load_pretrained_model(cls, model, state_dict, checkpoint_files, pretrained_model_name_or_path, ignore_mismatched_sizes=False, sharded_metadata=None, device_map=None, disk_offload_folder=None, dtype=None, hf_quantizer=None, keep_in_fp32_regex=None, device_mesh=None, key_mapping=None, weights_only=True):
    """Fuse the emotion embedding into the speaker embedding, then delegate to the stock loader."""
    # Re-read the raw checkpoint; the incoming `state_dict` is deliberately ignored.
    state_dict = load_state_dict(checkpoint_files[0], weights_only=weights_only)
    cfg = model.config
    # Keep only the named emotion rows, skipping the "undefined" slots.
    total_emotions = len(cfg.emotion_names) + len(cfg.undefined_emotion_index)
    kept_rows = [i for i in range(total_emotions) if i not in cfg.undefined_emotion_index]
    emotions = state_dict['embed_emotion.weight'][:total_emotions][kept_rows]
    speakers = state_dict['embed_speaker.weight'][:len(cfg.speaker_names)]
    # One fused embedding row per (speaker, emotion) pair, speaker-major order.
    fused = [spk + emo for spk in speakers for emo in emotions]
    state_dict['embed_speaker.weight'] = torch.stack(fused).reshape(-1, cfg.speaker_embedding_size)
    del state_dict['embed_emotion.weight']
    return super()._load_pretrained_model(model, state_dict, checkpoint_files, pretrained_model_name_or_path, ignore_mismatched_sizes, sharded_metadata, device_map, disk_offload_folder, dtype, hf_quantizer, keep_in_fp32_regex, device_mesh, key_mapping, weights_only)
25
 
26
def forward(self, input_ids=None, attention_mask=None, speaker_id='Marathi-Male', emotion_id='ALEXA', output_attentions=None, output_hidden_states=None, return_dict=None, labels=None):
    """Resolve a (speaker name, emotion name) pair to its fused embedding row and run the base forward."""
    speaker_row = self.config.speaker_names.index(speaker_id)
    emotion_col = self.config.emotion_names.index(emotion_id)
    # Speaker-major layout: len(emotion_names) consecutive rows per speaker.
    fused_index = speaker_row * len(self.config.emotion_names) + emotion_col
    return super().forward(input_ids, attention_mask, fused_index, output_attentions, output_hidden_states, return_dict, labels)
29
-
 
14
  block.convs1 = nn.ModuleList([weight_norm(layer) for layer in block.convs1])
15
  block.convs2 = nn.ModuleList([weight_norm(layer) for layer in block.convs2])
16
  return super().init_weights()
17
@staticmethod
def _load_pretrained_model(model, state_dict, checkpoint_files, load_config):
    """Rewrite the checkpoint so speaker and emotion embeddings are fused into a
    single `embed_speaker.weight` table, then delegate to the standard loader.

    Parameters mirror the new-style transformers
    `PreTrainedModel._load_pretrained_model(model, state_dict, checkpoint_files, load_config)`.
    """
    # NOTE(review): the incoming `state_dict` is deliberately discarded; the
    # checkpoint is re-read from the first shard — assumes a single-file
    # (non-sharded) checkpoint. TODO confirm against the hosted weights.
    state_dict = load_state_dict(checkpoint_files[0])
    cfg = model.config
    # Keep only the named emotion rows, skipping the "undefined" slots.
    total_emotions = len(cfg.emotion_names) + len(cfg.undefined_emotion_index)
    kept_rows = [i for i in range(total_emotions) if i not in cfg.undefined_emotion_index]
    emotions = state_dict['embed_emotion.weight'][:total_emotions][kept_rows]
    speakers = state_dict['embed_speaker.weight'][:len(cfg.speaker_names)]
    # One fused row per (speaker, emotion) pair, speaker-major — must match the
    # index arithmetic in `forward` (speaker_idx * n_emotions + emotion_idx).
    state_dict['embed_speaker.weight'] = torch.stack(
        [spk + emo for spk in speakers for emo in emotions]
    ).reshape(-1, cfg.speaker_embedding_size)
    del state_dict['embed_emotion.weight']
    # Bug fix: zero-arg `super()` inside a @staticmethod only resolved because
    # CPython happens to feed the frame's first argument (`model`) to the
    # implicit super(); spell the delegation out explicitly so it cannot break.
    return super(ModVitsModel, model)._load_pretrained_model(model, state_dict, checkpoint_files, load_config)
26
 
27
def forward(self, input_ids=None, attention_mask=None, speaker_id='Marathi-Male', emotion_id='ALEXA', output_attentions=None, output_hidden_states=None, return_dict=None, labels=None):
    """Run VITS synthesis for a named speaker/emotion pair.

    `speaker_id` and `emotion_id` are human-readable names, mapped to the row
    of the fused speaker+emotion embedding built at load time (speaker-major,
    `len(emotion_names)` rows per speaker).

    Raises:
        ValueError: if either name is not present in the model config.
    """
    fused_id = (
        self.config.speaker_names.index(speaker_id) * len(self.config.emotion_names)
        + self.config.emotion_names.index(emotion_id)
    )
    # Robustness fix: pass by keyword so this wrapper does not silently break
    # if the base `VitsModel.forward` signature ever gains a parameter.
    return super().forward(
        input_ids=input_ids,
        attention_mask=attention_mask,
        speaker_id=fused_id,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        labels=labels,
    )