File size: 2,234 Bytes
604379a
 
 
558b8b1
604379a
 
 
 
 
 
 
 
 
 
 
 
38bcf03
 
 
558b8b1
604379a
6c02118
604379a
 
38bcf03
6c02118
 
 
558b8b1
b02228a
 
 
 
 
558b8b1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from safetensors.torch import load_file
from torch import nn
from torch.nn.utils import parametrize
from torch.nn.utils.parametrizations import weight_norm
from transformers import VitsModel, VitsConfig

class ModVitsModel(VitsModel):
    """VITS model whose speaker embedding table holds every (speaker, emotion) pair.

    The checkpoint stores separate ``embed_speaker`` and ``embed_emotion`` tables.
    At load time they are summed pairwise into one ``embed_speaker`` table with
    ``len(speaker_names) * len(emotion_names)`` rows, laid out speaker-major. Row
    ``speaker_id * len(emotion_names) + style_id`` is therefore that speaker with
    that emotion, which is the same index arithmetic ``forward`` uses.

    Expects custom config attributes ``speaker_names``, ``emotion_names`` and
    ``undefined_emotion_index`` (a sequence of ``embed_emotion`` rows to drop),
    plus the standard ``speaker_embedding_size``.
    """

    def __init__(self, config: VitsConfig):
        # Size the speaker embedding table for every (speaker, emotion) pair.
        # NOTE: this mutates the caller's config object.
        config.num_speakers = len(config.emotion_names) * len(config.speaker_names)
        super().__init__(config)

    def init_weights(self):
        """Register weight norm on the decoder convolutions, then run the default init.

        Covers ``decoder.upsampler`` and the ``convs1``/``convs2`` of every
        ``decoder.resblocks`` entry. Registering the parametrization before weights
        are loaded presumably makes parameter names match a checkpoint saved with
        weight norm (confirm against the checkpoint keys). Layers whose ``weight``
        is already parametrized are skipped, so repeated calls do not stack a
        second weight norm.
        """
        layers = list(self.decoder.upsampler)
        for block in self.decoder.resblocks:
            layers.extend(block.convs1)
            layers.extend(block.convs2)
        for layer in layers:
            if not parametrize.is_parametrized(layer, "weight"):
                weight_norm(layer)  # registers the parametrization on ``layer`` in place
        return super().init_weights()

    @staticmethod
    def _load_pretrained_model(model, state_dict, checkpoint_files, load_config):
        """Fold ``embed_emotion`` into ``embed_speaker`` before the standard load.

        When ``checkpoint_files`` is given, the state dict is rebuilt from every
        listed safetensors file. Any incoming ``state_dict`` is then ignored.
        ``embed_speaker.weight`` is replaced by the speaker-major sum of each of
        the first ``len(speaker_names)`` speaker rows with each defined emotion
        row. ``embed_emotion.weight`` is removed. The rewritten dict is handed to
        the parent loader.

        A state dict without ``embed_emotion.weight`` (e.g. one already folded
        and saved from this class) is passed through unchanged.

        Raises:
            ValueError: if ``undefined_emotion_index`` does not leave exactly
                ``len(emotion_names)`` emotion rows.
        """
        if checkpoint_files:
            # Merge every shard so keys stored outside the first file are not lost.
            state_dict = {}
            for path in checkpoint_files:
                state_dict.update(load_file(path))

        if state_dict is not None and "embed_emotion.weight" in state_dict:
            config = model.config
            num_emotions = len(config.emotion_names)
            undefined = set(config.undefined_emotion_index)
            # The checkpoint's emotion table has the defined emotions interleaved
            # with the rows listed in ``undefined_emotion_index``; keep the rest.
            span = num_emotions + len(config.undefined_emotion_index)
            kept = [i for i in range(span) if i not in undefined]
            if len(kept) != num_emotions:
                raise ValueError(
                    f"undefined_emotion_index={config.undefined_emotion_index!r} leaves "
                    f"{len(kept)} emotion rows, expected {num_emotions}"
                )
            emotions = state_dict["embed_emotion.weight"][kept]
            speakers = state_dict["embed_speaker.weight"][: len(config.speaker_names)]
            # Broadcast to (speakers, emotions, dim). Flattening is speaker-major,
            # matching ``speaker_id * len(emotion_names) + style_id`` in ``forward``.
            combined = speakers[:, None, :] + emotions[None, :, :]
            state_dict["embed_speaker.weight"] = combined.reshape(-1, config.speaker_embedding_size)
            del state_dict["embed_emotion.weight"]

        # Explicit two-argument super: this is a staticmethod, so resolve the parent
        # implementation through ``model``'s MRO. Assumes the parent method is a
        # staticmethod with this signature (recent transformers); confirm per version.
        return super(ModVitsModel, model)._load_pretrained_model(
            model, state_dict, checkpoint_files, load_config
        )

    @torch.inference_mode()
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        speaker_id=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        labels=None,
        **kwargs,
    ):
        """Synthesize speech for ``speaker_id`` speaking with emotion ``style_id``.

        ``style_id`` must be passed as a keyword argument. It is consumed here and
        not forwarded to ``VitsModel.forward``. The pair is mapped to the combined
        embedding row ``speaker_id * len(emotion_names) + style_id``. Waveform
        samples at or beyond each item's sequence length are zeroed in place, so
        padded batch items end in silence.

        Runs under ``torch.inference_mode()``, so it cannot be used for training.

        Raises:
            KeyError: if ``style_id`` is not supplied.
        """
        style_id = kwargs.pop("style_id")
        speaker_id = speaker_id * len(self.config.emotion_names) + style_id
        outputs = super().forward(
            input_ids,
            attention_mask,
            speaker_id,
            output_attentions,
            output_hidden_states,
            return_dict,
            labels,
            **kwargs,
        )

        # With return_dict=False the parent returns a plain tuple, presumably
        # ordered (waveform, sequence_lengths, ...) like VitsModelOutput; confirm.
        if isinstance(outputs, tuple):
            waveform, lengths = outputs[0], outputs[1]
        else:
            waveform, lengths = outputs.waveform, outputs.sequence_lengths

        _, num_samples = waveform.shape
        positions = torch.arange(num_samples, device=waveform.device)
        padding = positions.unsqueeze(0) >= lengths.unsqueeze(1)
        waveform.masked_fill_(padding, 0)
        return outputs