shethjenil committed
Commit e669e36 · verified · 1 Parent(s): 69a4f0c

Upload 6 files
Files changed (2):
  1. config.json (+6 -2)
  2. modeling_vits.py (+177 -0)
config.json CHANGED
@@ -2,9 +2,13 @@
   "_name_or_path": "rasa_boosted",
   "activation_dropout": 0.1,
   "architectures": [
-    "VitsModel"
+    "IndicVitsModel"
   ],
   "attention_dropout": 0.1,
+  "auto_map": {
+    "AutoConfig": "VitsConfig",
+    "AutoModel": "modeling_vits.IndicVitsModel"
+  },
   "depth_separable_channels": 2,
   "depth_separable_num_layers": 3,
   "duration_predictor_dropout": 0.5,
@@ -24,7 +28,7 @@
   "layer_norm_eps": 1e-05,
   "layerdrop": 0.1,
   "leaky_relu_slope": 0.1,
-  "model_type": "vits",
+  "model_type": "indic_vits_model",
   "noise_scale": 0.667,
   "noise_scale_duration": 0.8,
   "num_attention_heads": 2,
modeling_vits.py ADDED
from typing import Any, Optional, Union

import numpy as np
import torch
from torch import nn
from transformers import VitsConfig, VitsPreTrainedModel
from transformers.models.vits.modeling_vits import (
    VitsDurationPredictor,
    VitsHifiGan,
    VitsModelOutput,
    VitsPosteriorEncoder,
    VitsResidualCouplingBlock,
    VitsStochasticDurationPredictor,
    VitsTextEncoder,
)
from transformers.utils import auto_docstring


@auto_docstring(
    custom_intro="""
    The complete VITS model, for text-to-speech synthesis.
    """
)
class IndicVitsModel(VitsPreTrainedModel):
    def __init__(self, config: VitsConfig):
        super().__init__(config)
        self.config = config
        self.text_encoder = VitsTextEncoder(config)
        self.flow = VitsResidualCouplingBlock(config)
        self.decoder = VitsHifiGan(config)

        if config.use_stochastic_duration_prediction:
            self.duration_predictor = VitsStochasticDurationPredictor(config)
        else:
            self.duration_predictor = VitsDurationPredictor(config)

        if config.num_speakers > 1:
            self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)

        if config.num_emotions > 1:
            self.embed_emotion = nn.Embedding(config.num_emotions, config.emotion_embedding_size)

        # This is used only for training.
        self.posterior_encoder = VitsPosteriorEncoder(config)

        # These parameters control the synthesised speech properties
        self.speaking_rate = config.speaking_rate
        self.noise_scale = config.noise_scale
        self.noise_scale_duration = config.noise_scale_duration

        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        speaker_id: Optional[int] = None,
        emotion_id: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        labels: Optional[torch.FloatTensor] = None,
        **kwargs,
    ) -> Union[tuple[Any], VitsModelOutput]:
        r"""
        speaker_id (`int`, *optional*):
            Which speaker embedding to use. Only used for multispeaker models.
        emotion_id (`int`, *optional*):
            Which emotion embedding to use. Only used for models trained with emotion embeddings.
        labels (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`, *optional*):
            Float values of target spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
            computation.

        Example:

        ```python
        >>> from transformers import VitsTokenizer, VitsModel, set_seed
        >>> import torch

        >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
        >>> model = VitsModel.from_pretrained("facebook/mms-tts-eng")

        >>> inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt")

        >>> set_seed(555)  # make deterministic

        >>> with torch.no_grad():
        ...     outputs = model(inputs["input_ids"])
        >>> outputs.waveform.shape
        torch.Size([1, 45824])
        ```
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            raise NotImplementedError("Training of VITS is not supported yet.")

        mask_dtype = self.text_encoder.embed_tokens.weight.dtype
        if attention_mask is not None:
            input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
        else:
            input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)

        if self.config.num_speakers > 1 and speaker_id is not None:
            if emotion_id is None:
                raise ValueError("`emotion_id` must be provided together with `speaker_id`.")
            if not 0 <= speaker_id < self.config.num_speakers:
                raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}.")
            if not 0 <= emotion_id < self.config.num_emotions:
                raise ValueError(f"Set `emotion_id` in the range 0-{self.config.num_emotions - 1}.")
            if isinstance(speaker_id, int):
                speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
            if isinstance(emotion_id, int):
                emotion_id = torch.full(size=(1,), fill_value=emotion_id, device=self.device)
            # Condition the model on the sum of the speaker and emotion embeddings.
            speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1) + self.embed_emotion(emotion_id).unsqueeze(-1)
        else:
            speaker_embeddings = None

        text_encoder_output = self.text_encoder(
            input_ids=input_ids,
            padding_mask=input_padding_mask,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
        hidden_states = hidden_states.transpose(1, 2)
        input_padding_mask = input_padding_mask.transpose(1, 2)
        prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
        prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances

        if self.config.use_stochastic_duration_prediction:
            log_duration = self.duration_predictor(
                hidden_states,
                input_padding_mask,
                speaker_embeddings,
                reverse=True,
                noise_scale=self.noise_scale_duration,
            )
        else:
            log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)

        # Convert log-durations to integer frame counts and total output lengths.
        length_scale = 1.0 / self.speaking_rate
        duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
        predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()

        # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
        indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
        output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
        output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)

        # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
        attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
        batch_size, _, output_length, input_length = attn_mask.shape
        cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
        indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
        valid_indices = indices.unsqueeze(0) < cum_duration
        valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
        padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
        attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask

        # Expand prior distribution
        prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
        prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)

        prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
        # Run the latents through the inverse flow and vocode to a waveform.
        latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)

        spectrogram = latents * output_padding_mask
        waveform = self.decoder(spectrogram, speaker_embeddings)
        waveform = waveform.squeeze(1)
        sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)

        if not return_dict:
            outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
            return outputs

        return VitsModelOutput(
            waveform=waveform,
            sequence_lengths=sequence_lengths,
            spectrogram=spectrogram,
            hidden_states=text_encoder_output.hidden_states,
            attentions=text_encoder_output.attentions,
        )


__all__ = ["IndicVitsModel"]
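
For completeness, a sketch of calling the extended `forward`, which now conditions generation on both a speaker and an emotion embedding. The repo id and tokenizer class are assumptions, since this commit only adds the modeling file:

```python
import torch
from transformers import AutoModel, AutoTokenizer

repo = "<user>/<repo>"  # placeholder, not shown in this commit
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModel.from_pretrained(repo, trust_remote_code=True)

inputs = tokenizer(text="Namaste", return_tensors="pt")

with torch.no_grad():
    # speaker_id indexes embed_speaker and emotion_id indexes embed_emotion;
    # the two embeddings are summed into one conditioning vector.
    outputs = model(inputs["input_ids"], speaker_id=0, emotion_id=0)

waveform = outputs.waveform  # shape: (batch, num_samples)
```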