shethjenil commited on
Commit
c9d527b
·
verified ·
1 Parent(s): c35e06b

Update modeling_vits.py

Browse files
Files changed (1) hide show
  1. modeling_vits.py +185 -182
modeling_vits.py CHANGED
@@ -1,183 +1,186 @@
1
- from typing import Any, Optional, Union
2
- import numpy as np
3
- import torch
4
- from torch import nn
5
- from transformers import VitsPreTrainedModel , VitsConfig
6
- from transformers.models.vits.modeling_vits import VitsTextEncoder , VitsResidualCouplingBlock , VitsHifiGan , VitsStochasticDurationPredictor , VitsDurationPredictor , VitsPosteriorEncoder , VitsModelOutput
7
- from transformers.utils import auto_docstring
8
- from torch.nn.utils.parametrizations import weight_norm
9
- @auto_docstring(
10
- custom_intro="""
11
- The complete VITS model, for text-to-speech synthesis.
12
- """
13
- )
14
- class VitsModel(VitsPreTrainedModel):
15
- def __init__(self, config: VitsConfig):
16
- super().__init__(config)
17
- self.config = config
18
- self.text_encoder = VitsTextEncoder(config)
19
- self.flow = VitsResidualCouplingBlock(config)
20
- self.decoder = VitsHifiGan(config)
21
-
22
- if config.use_stochastic_duration_prediction:
23
- self.duration_predictor = VitsStochasticDurationPredictor(config)
24
- else:
25
- self.duration_predictor = VitsDurationPredictor(config)
26
-
27
- if config.num_speakers > 1:
28
- self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)
29
-
30
- if config.num_emotions > 1:
31
- self.embed_emotion = nn.Embedding(config.num_emotions, config.emotion_embedding_size)
32
-
33
- # This is used only for training.
34
- self.posterior_encoder = VitsPosteriorEncoder(config)
35
-
36
- # These parameters control the synthesised speech properties
37
- self.speaking_rate = config.speaking_rate
38
- self.noise_scale = config.noise_scale
39
- self.noise_scale_duration = config.noise_scale_duration
40
-
41
- # Weight Norm Apply
42
- for block in self.decoder.resblocks:
43
- block.convs1 = nn.ModuleList([weight_norm(layer) for layer in block.convs1])
44
- block.convs2 = nn.ModuleList([weight_norm(layer) for layer in block.convs2])
45
- self.decoder.upsampler = nn.ModuleList([weight_norm(layer) for layer in self.decoder.upsampler])
46
-
47
- # Initialize weights and apply final processing
48
- self.post_init()
49
-
50
- @auto_docstring
51
- def forward(
52
- self,
53
- input_ids: Optional[torch.Tensor] = None,
54
- attention_mask: Optional[torch.Tensor] = None,
55
- speaker_id: Optional[int] = None,
56
- emotion_id: Optional[int] = None,
57
- output_attentions: Optional[bool] = None,
58
- output_hidden_states: Optional[bool] = None,
59
- return_dict: Optional[bool] = None,
60
- labels: Optional[torch.FloatTensor] = None,
61
- **kwargs,
62
- ) -> Union[tuple[Any], VitsModelOutput]:
63
- r"""
64
- speaker_id (`int`, *optional*):
65
- Which speaker embedding to use. Only used for multispeaker models.
66
- labels (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`, *optional*):
67
- Float values of target spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
68
- computation.
69
-
70
- Example:
71
-
72
- ```python
73
- >>> from transformers import VitsTokenizer, VitsModel, set_seed
74
- >>> import torch
75
-
76
- >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
77
- >>> model = VitsModel.from_pretrained("facebook/mms-tts-eng")
78
-
79
- >>> inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt")
80
-
81
- >>> set_seed(555) # make deterministic
82
-
83
- >>> with torch.no_grad():
84
- ... outputs = model(inputs["input_ids"])
85
- >>> outputs.waveform.shape
86
- torch.Size([1, 45824])
87
- ```
88
- """
89
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
90
- output_hidden_states = (
91
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
92
- )
93
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
94
-
95
- if labels is not None:
96
- raise NotImplementedError("Training of VITS is not supported yet.")
97
-
98
- mask_dtype = self.text_encoder.embed_tokens.weight.dtype
99
- if attention_mask is not None:
100
- input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
101
- else:
102
- input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)
103
-
104
- if self.config.num_speakers > 1 and speaker_id is not None:
105
- if not 0 <= speaker_id < self.config.num_speakers:
106
- raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}. or Set `emotion_id` in the range 0-{self.config.num_emotions - 1}.")
107
- if isinstance(speaker_id, int) and isinstance(emotion_id, int):
108
- speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
109
- emotion_id = torch.full(size=(1,), fill_value=emotion_id, device=self.device)
110
- speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1) + self.embed_emotion(emotion_id).unsqueeze(-1)
111
- else:
112
- speaker_embeddings = None
113
-
114
- text_encoder_output = self.text_encoder(
115
- input_ids=input_ids,
116
- padding_mask=input_padding_mask,
117
- attention_mask=attention_mask,
118
- output_attentions=output_attentions,
119
- output_hidden_states=output_hidden_states,
120
- return_dict=return_dict,
121
- )
122
- hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
123
- hidden_states = hidden_states.transpose(1, 2)
124
- input_padding_mask = input_padding_mask.transpose(1, 2)
125
- prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
126
- prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances
127
-
128
- if self.config.use_stochastic_duration_prediction:
129
- log_duration = self.duration_predictor(
130
- hidden_states,
131
- input_padding_mask,
132
- speaker_embeddings,
133
- reverse=True,
134
- noise_scale=self.noise_scale_duration,
135
- )
136
- else:
137
- log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
138
-
139
- length_scale = 1.0 / self.speaking_rate
140
- duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
141
- predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
142
-
143
- # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
144
- indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
145
- output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
146
- output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
147
-
148
- # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
149
- attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
150
- batch_size, _, output_length, input_length = attn_mask.shape
151
- cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
152
- indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
153
- valid_indices = indices.unsqueeze(0) < cum_duration
154
- valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
155
- padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
156
- attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
157
-
158
- # Expand prior distribution
159
- prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
160
- prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)
161
-
162
- prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
163
- latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)
164
-
165
- spectrogram = latents * output_padding_mask
166
- waveform = self.decoder(spectrogram, speaker_embeddings)
167
- waveform = waveform.squeeze(1)
168
- sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)
169
-
170
- if not return_dict:
171
- outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
172
- return outputs
173
-
174
- return VitsModelOutput(
175
- waveform=waveform,
176
- sequence_lengths=sequence_lengths,
177
- spectrogram=spectrogram,
178
- hidden_states=text_encoder_output.hidden_states,
179
- attentions=text_encoder_output.attentions,
180
- )
181
-
182
-
 
 
 
183
  __all__ = ["VitsModel"]
 
1
+ from typing import Any, Optional, Union
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from transformers import VitsPreTrainedModel , VitsConfig
6
+ from transformers.models.vits.modeling_vits import VitsTextEncoder , VitsResidualCouplingBlock , VitsHifiGan , VitsStochasticDurationPredictor , VitsDurationPredictor , VitsPosteriorEncoder , VitsModelOutput
7
+ from transformers.utils import auto_docstring
8
+ from torch.nn.utils.parametrizations import weight_norm
9
+ @auto_docstring(
10
+ custom_intro="""
11
+ The complete VITS model, for text-to-speech synthesis.
12
+ """
13
+ )
14
+ class VitsModel(VitsPreTrainedModel):
15
+ def __init__(self, config: VitsConfig):
16
+ super().__init__(config)
17
+ self.config = config
18
+ self.text_encoder = VitsTextEncoder(config)
19
+ self.flow = VitsResidualCouplingBlock(config)
20
+ self.decoder = VitsHifiGan(config)
21
+
22
+ if config.use_stochastic_duration_prediction:
23
+ self.duration_predictor = VitsStochasticDurationPredictor(config)
24
+ else:
25
+ self.duration_predictor = VitsDurationPredictor(config)
26
+
27
+ if config.num_speakers > 1:
28
+ self.embed_speaker = nn.Embedding(config.num_speakers, config.speaker_embedding_size)
29
+
30
+ if config.num_emotions > 1:
31
+ self.embed_emotion = nn.Embedding(config.num_emotions, config.emotion_embedding_size)
32
+
33
+ # This is used only for training.
34
+ self.posterior_encoder = VitsPosteriorEncoder(config)
35
+
36
+ # These parameters control the synthesised speech properties
37
+ self.speaking_rate = config.speaking_rate
38
+ self.noise_scale = config.noise_scale
39
+ self.noise_scale_duration = config.noise_scale_duration
40
+
41
+ # Weight Norm Apply
42
+ for block in self.decoder.resblocks:
43
+ block.convs1 = nn.ModuleList([weight_norm(layer) for layer in block.convs1])
44
+ block.convs2 = nn.ModuleList([weight_norm(layer) for layer in block.convs2])
45
+ self.decoder.upsampler = nn.ModuleList([weight_norm(layer) for layer in self.decoder.upsampler])
46
+
47
+ # Initialize weights and apply final processing
48
+ self.post_init()
49
+
50
+ @auto_docstring
51
+ def forward(
52
+ self,
53
+ input_ids: Optional[torch.Tensor] = None,
54
+ attention_mask: Optional[torch.Tensor] = None,
55
+ speaker_id: Optional[int] = None,
56
+ emotion_id: Optional[int] = None,
57
+ output_attentions: Optional[bool] = None,
58
+ output_hidden_states: Optional[bool] = None,
59
+ return_dict: Optional[bool] = None,
60
+ labels: Optional[torch.FloatTensor] = None,
61
+ **kwargs,
62
+ ) -> Union[tuple[Any], VitsModelOutput]:
63
+ r"""
64
+ speaker_id (`int`, *optional*):
65
+ Which speaker embedding to use. Only used for multispeaker models.
66
+ emotion_id (`int`, *optional*):
67
+ Which emotion embedding to use. Only used for multiemotion models.
68
+
69
+ labels (`torch.FloatTensor` of shape `(batch_size, config.spectrogram_bins, sequence_length)`, *optional*):
70
+ Float values of target spectrogram. Timesteps set to `-100.0` are ignored (masked) for the loss
71
+ computation.
72
+
73
+ Example:
74
+
75
+ ```python
76
+ >>> from transformers import VitsTokenizer, VitsModel, set_seed
77
+ >>> import torch
78
+
79
+ >>> tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
80
+ >>> model = VitsModel.from_pretrained("facebook/mms-tts-eng")
81
+
82
+ >>> inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt")
83
+
84
+ >>> set_seed(555) # make deterministic
85
+
86
+ >>> with torch.no_grad():
87
+ ... outputs = model(inputs["input_ids"])
88
+ >>> outputs.waveform.shape
89
+ torch.Size([1, 45824])
90
+ ```
91
+ """
92
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
93
+ output_hidden_states = (
94
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
95
+ )
96
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
97
+
98
+ if labels is not None:
99
+ raise NotImplementedError("Training of VITS is not supported yet.")
100
+
101
+ mask_dtype = self.text_encoder.embed_tokens.weight.dtype
102
+ if attention_mask is not None:
103
+ input_padding_mask = attention_mask.unsqueeze(-1).to(mask_dtype)
104
+ else:
105
+ input_padding_mask = torch.ones_like(input_ids).unsqueeze(-1).to(mask_dtype)
106
+
107
+ if self.config.num_speakers > 1 and speaker_id is not None:
108
+ if not 0 <= speaker_id < self.config.num_speakers:
109
+ raise ValueError(f"Set `speaker_id` in the range 0-{self.config.num_speakers - 1}. or Set `emotion_id` in the range 0-{self.config.num_emotions - 1}.")
110
+ if isinstance(speaker_id, int) and isinstance(emotion_id, int):
111
+ speaker_id = torch.full(size=(1,), fill_value=speaker_id, device=self.device)
112
+ emotion_id = torch.full(size=(1,), fill_value=emotion_id, device=self.device)
113
+ speaker_embeddings = self.embed_speaker(speaker_id).unsqueeze(-1) + self.embed_emotion(emotion_id).unsqueeze(-1)
114
+ else:
115
+ speaker_embeddings = None
116
+
117
+ text_encoder_output = self.text_encoder(
118
+ input_ids=input_ids,
119
+ padding_mask=input_padding_mask,
120
+ attention_mask=attention_mask,
121
+ output_attentions=output_attentions,
122
+ output_hidden_states=output_hidden_states,
123
+ return_dict=return_dict,
124
+ )
125
+ hidden_states = text_encoder_output[0] if not return_dict else text_encoder_output.last_hidden_state
126
+ hidden_states = hidden_states.transpose(1, 2)
127
+ input_padding_mask = input_padding_mask.transpose(1, 2)
128
+ prior_means = text_encoder_output[1] if not return_dict else text_encoder_output.prior_means
129
+ prior_log_variances = text_encoder_output[2] if not return_dict else text_encoder_output.prior_log_variances
130
+
131
+ if self.config.use_stochastic_duration_prediction:
132
+ log_duration = self.duration_predictor(
133
+ hidden_states,
134
+ input_padding_mask,
135
+ speaker_embeddings,
136
+ reverse=True,
137
+ noise_scale=self.noise_scale_duration,
138
+ )
139
+ else:
140
+ log_duration = self.duration_predictor(hidden_states, input_padding_mask, speaker_embeddings)
141
+
142
+ length_scale = 1.0 / self.speaking_rate
143
+ duration = torch.ceil(torch.exp(log_duration) * input_padding_mask * length_scale)
144
+ predicted_lengths = torch.clamp_min(torch.sum(duration, [1, 2]), 1).long()
145
+
146
+ # Create a padding mask for the output lengths of shape (batch, 1, max_output_length)
147
+ indices = torch.arange(predicted_lengths.max(), dtype=predicted_lengths.dtype, device=predicted_lengths.device)
148
+ output_padding_mask = indices.unsqueeze(0) < predicted_lengths.unsqueeze(1)
149
+ output_padding_mask = output_padding_mask.unsqueeze(1).to(input_padding_mask.dtype)
150
+
151
+ # Reconstruct an attention tensor of shape (batch, 1, out_length, in_length)
152
+ attn_mask = torch.unsqueeze(input_padding_mask, 2) * torch.unsqueeze(output_padding_mask, -1)
153
+ batch_size, _, output_length, input_length = attn_mask.shape
154
+ cum_duration = torch.cumsum(duration, -1).view(batch_size * input_length, 1)
155
+ indices = torch.arange(output_length, dtype=duration.dtype, device=duration.device)
156
+ valid_indices = indices.unsqueeze(0) < cum_duration
157
+ valid_indices = valid_indices.to(attn_mask.dtype).view(batch_size, input_length, output_length)
158
+ padded_indices = valid_indices - nn.functional.pad(valid_indices, [0, 0, 1, 0, 0, 0])[:, :-1]
159
+ attn = padded_indices.unsqueeze(1).transpose(2, 3) * attn_mask
160
+
161
+ # Expand prior distribution
162
+ prior_means = torch.matmul(attn.squeeze(1), prior_means).transpose(1, 2)
163
+ prior_log_variances = torch.matmul(attn.squeeze(1), prior_log_variances).transpose(1, 2)
164
+
165
+ prior_latents = prior_means + torch.randn_like(prior_means) * torch.exp(prior_log_variances) * self.noise_scale
166
+ latents = self.flow(prior_latents, output_padding_mask, speaker_embeddings, reverse=True)
167
+
168
+ spectrogram = latents * output_padding_mask
169
+ waveform = self.decoder(spectrogram, speaker_embeddings)
170
+ waveform = waveform.squeeze(1)
171
+ sequence_lengths = predicted_lengths * np.prod(self.config.upsample_rates)
172
+
173
+ if not return_dict:
174
+ outputs = (waveform, sequence_lengths, spectrogram) + text_encoder_output[3:]
175
+ return outputs
176
+
177
+ return VitsModelOutput(
178
+ waveform=waveform,
179
+ sequence_lengths=sequence_lengths,
180
+ spectrogram=spectrogram,
181
+ hidden_states=text_encoder_output.hidden_states,
182
+ attentions=text_encoder_output.attentions,
183
+ )
184
+
185
+
186
  __all__ = ["VitsModel"]