shethjenil committed
Commit 1f135f7 · verified · Parent: aee20da

Upload 6 files

Files changed (2)
  1. config.json +4 -3
  2. modeling_vits.py +182 -176
config.json CHANGED
@@ -2,11 +2,12 @@
   "_name_or_path": "rasa_boosted",
   "activation_dropout": 0.1,
   "architectures": [
-    "VitsModel"
+    "IndicVitsModel"
   ],
   "attention_dropout": 0.1,
   "auto_map": {
-    "AutoModel": "modeling_vits.VitsModel"
+    "AutoConfig": "transformers.VitsConfig",
+    "AutoModel": "modeling_vits.IndicVitsModel"
   },
   "depth_separable_channels": 2,
   "depth_separable_num_layers": 3,
@@ -27,7 +28,7 @@
   "layer_norm_eps": 1e-05,
   "layerdrop": 0.1,
   "leaky_relu_slope": 0.1,
-  "model_type": "vits",
+  "model_type": "indic_vits_model",
   "noise_scale": 0.667,
   "noise_scale_duration": 0.8,
   "num_attention_heads": 2,
modeling_vits.py CHANGED
@@ -1,177 +1,183 @@
 from typing import Any, Optional, Union
 import numpy as np
 import torch
 from torch import nn
 from transformers import VitsPreTrainedModel, VitsConfig
 from transformers.models.vits.modeling_vits import VitsTextEncoder, VitsResidualCouplingBlock, VitsHifiGan, VitsStochasticDurationPredictor, VitsDurationPredictor, VitsPosteriorEncoder, VitsModelOutput
 from transformers.utils import auto_docstring
-
+from torch.nn.utils.parametrizations import weight_norm
 @auto_docstring(
     custom_intro="""
     The complete VITS model, for text-to-speech synthesis.
     """
 )
 class VitsModel(VitsPreTrainedModel):
     def __init__(self, config: VitsConfig):
         super().__init__(config)
         self.config = config
         self.text_encoder = VitsTextEncoder(config)
         self.flow = VitsResidualCouplingBlock(config)
         self.decoder = VitsHifiGan(config)

         if config.use_stochastic_duration_prediction:
             self.duration_predictor = VitsStochasticDurationPredictor(config)
         else:
             self.duration_predictor = VitsDurationPredictor(config)

         if config.num_speakers > 1:
             self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)

         if config.num_emotions > 1:
             self.embed_emotion = nn.Embedding(config.num_emotions, config.emotion_embedding_size)

         # This is used only for training.
         self.posterior_encoder = VitsPosteriorEncoder(config)

         # These parameters control the synthesised speech properties
         self.speaking_rate = config.speaking_rate
         self.noise_scale = config.noise_scale
         self.noise_scale_duration = config.noise_scale_duration

+        # Weight Norm Apply
+        for block in self.decoder.resblocks:
+            block.convs1 = nn.ModuleList([weight_norm(layer) for layer in block.convs1])
+            block.convs2 = nn.ModuleList([weight_norm(layer) for layer in block.convs2])
+        self.decoder.upsampler = nn.ModuleList([weight_norm(layer) for layer in self.decoder.upsampler])
+
         # Initialize weights and apply final processing
         self.post_init()

     @auto_docstring
     def forward(
         self,
         input_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         speaker_id: Optional[int] = None,
         emotion_id: Optional[int] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.FloatTensor] = None,
         **kwargs,
     ) -> Union[tuple[Any], VitsModelOutput]:
         r"""
         speaker_id (`int`, *optional*):
             Which speaker embedding to use. Only used for multispeaker models.
         labels (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`, *optional*):
             Float values of target spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
             computation.

         Example:

         ```python
         >>> from transformers import VitsTokenizer, VitsModel, set_seed
         >>> import torch

         >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
         >>> model = VitsModel.from_pretrained("facebook/mms-tts-eng")

         >>> inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt")

         >>> set_seed(555)  # make deterministic

         >>> with torch.no_grad():
         ...     outputs = model(inputs["input_ids"])
         >>> outputs.waveform.shape
         torch.Size([1, 45824])
         ```
         """
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

         if labels is not None:
             raise NotImplementedError("Training of VITS is not supported yet.")

         mask_dtype = self.text_encoder.embed_tokens.weight.dtype
         if attention_mask is not None:
             input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
         else:
             input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)

         if self.config.num_speakers > 1 and speaker_id is not None:
             if not 0 <= speaker_id < self.config.num_speakers:
                 raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}. or Set `emotion_id` in the range 0-{self.config.num_emotions - 1}.")
             if isinstance(speaker_id, int) and isinstance(emotion_id, int):
                 speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
                 emotion_id = torch.full(size=(1,), fill_value=emotion_id, device=self.device)
             speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1) + self.embed_emotion(emotion_id).unsqueeze(-1)
         else:
             speaker_embeddings = None

         text_encoder_output = self.text_encoder(
             input_ids=input_ids,
             padding_mask=input_padding_mask,
             attention_mask=attention_mask,
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
         hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
         hidden_states = hidden_states.transpose(1, 2)
         input_padding_mask = input_padding_mask.transpose(1, 2)
         prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
         prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances

         if self.config.use_stochastic_duration_prediction:
             log_duration = self.duration_predictor(
                 hidden_states,
                 input_padding_mask,
                 speaker_embeddings,
                 reverse=True,
                 noise_scale=self.noise_scale_duration,
             )
         else:
             log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)

         length_scale = 1.0 / self.speaking_rate
         duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
         predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()

         # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
         indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
         output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
         output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)

         # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
         attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
         batch_size, _, output_length, input_length = attn_mask.shape
         cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
         indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
         valid_indices = indices.unsqueeze(0) < cum_duration
         valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
         padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
         attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask

         # Expand prior distribution
         prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
         prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)

         prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
         latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)

         spectrogram = latents * output_padding_mask
         waveform = self.decoder(spectrogram, speaker_embeddings)
         waveform = waveform.squeeze(1)
         sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)

         if not return_dict:
             outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
             return outputs

         return VitsModelOutput(
             waveform=waveform,
             sequence_lengths=sequence_lengths,
             spectrogram=spectrogram,
             hidden_states=text_encoder_output.hidden_states,
             attentions=text_encoder_output.attentions,
         )


 __all__ = ["VitsModel"]
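
For reference, a minimal sketch of the parametrization-based weight norm that the new `__init__` block applies to the HiFi-GAN decoder convolutions (assumes PyTorch 2.1+, where `torch.nn.utils.parametrizations.weight_norm` is available; the layer shape below is illustrative only):

```python
import torch
from torch import nn
from torch.nn.utils.parametrizations import weight_norm

# A plain Conv1d standing in for one of the decoder resblock convolutions.
conv = nn.Conv1d(192, 192, kernel_size=3, padding=1)
conv = weight_norm(conv)  # weight is now recomputed from a magnitude/direction parametrization

print(hasattr(conv, "parametrizations"))  # True
out = conv(torch.randn(1, 192, 100))      # still behaves like an ordinary Conv1d
print(out.shape)                          # torch.Size([1, 192, 100])
```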