Yoshitaka16 commited on
Commit
8e26279
·
verified ·
1 Parent(s): 8c51444

Upload synthesizers.py

Browse files
Files changed (1) hide show
  1. synthesizers.py +250 -0
synthesizers.py ADDED
@@ -0,0 +1,250 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from typing import Optional
3
+ from rvc.lib.algorithm.generators.hifigan_mrf import HiFiGANMRFGenerator
4
+ from rvc.lib.algorithm.generators.hifigan_nsf import HiFiGANNSFGenerator
5
+ from rvc.lib.algorithm.generators.hifigan import HiFiGANGenerator
6
+ from rvc.lib.algorithm.generators.refinegan import RefineGANGenerator
7
+ from rvc.lib.algorithm.commons import slice_segments, rand_slice_segments
8
+ from rvc.lib.algorithm.residuals import ResidualCouplingBlock
9
+ from rvc.lib.algorithm.encoders import TextEncoder, PosteriorEncoder
10
+
11
+
12
class Synthesizer(torch.nn.Module):
    """
    Base Synthesizer model.

    Combines a text/content encoder (``enc_p``), a posterior encoder
    (``enc_q``), a normalizing flow (``flow``), a speaker-embedding table
    (``emb_g``), and a selectable neural vocoder decoder (``dec``).

    Args:
        spec_channels (int): Number of channels in the spectrogram.
        segment_size (int): Size of the audio segment sliced during training.
        inter_channels (int): Number of channels in the intermediate layers.
        hidden_channels (int): Number of channels in the hidden layers.
        filter_channels (int): Number of channels in the filter layers.
        n_heads (int): Number of attention heads.
        n_layers (int): Number of layers in the encoder.
        kernel_size (int): Size of the convolution kernel.
        p_dropout (float): Dropout probability.
        resblock (str): Type of residual block (kept for interface
            compatibility; not read directly in this class).
        resblock_kernel_sizes (list): Kernel sizes for the residual blocks.
        resblock_dilation_sizes (list): Dilation sizes for the residual blocks.
        upsample_rates (list): Upsampling rates for the decoder.
        upsample_initial_channel (int): Number of channels in the initial upsampling layer.
        upsample_kernel_sizes (list): Kernel sizes for the upsampling layers.
        spk_embed_dim (int): Dimension of the speaker embedding.
        gin_channels (int): Number of channels in the global conditioning vector.
        sr (int): Sampling rate of the audio.
        use_f0 (bool): Whether to use F0 (pitch) information.
        text_enc_hidden_dim (int): Hidden dimension for the text encoder.
        vocoder (str): Decoder architecture: "MRF HiFi-GAN", "RefineGAN", or
            any other value selects the (NSF) HiFi-GAN default.
        randomized (bool): Train on random slices (True) or full sequences (False).
        checkpointing (bool): Enable gradient checkpointing in the decoder.
        kwargs: Additional keyword arguments (ignored).
    """

    def __init__(
        self,
        spec_channels: int,
        segment_size: int,
        inter_channels: int,
        hidden_channels: int,
        filter_channels: int,
        n_heads: int,
        n_layers: int,
        kernel_size: int,
        p_dropout: float,
        resblock: str,
        resblock_kernel_sizes: list,
        resblock_dilation_sizes: list,
        upsample_rates: list,
        upsample_initial_channel: int,
        upsample_kernel_sizes: list,
        spk_embed_dim: int,
        gin_channels: int,
        sr: int,
        use_f0: bool,
        text_enc_hidden_dim: int = 768,
        vocoder: str = "RefineGAN",
        randomized: bool = True,
        checkpointing: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.segment_size = segment_size
        self.use_f0 = use_f0
        self.randomized = randomized

        self.enc_p = TextEncoder(
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            text_enc_hidden_dim,
            f0=use_f0,
        )
        print(f"Using {vocoder} vocoder")
        if use_f0:
            if vocoder == "MRF HiFi-GAN":
                self.dec = HiFiGANMRFGenerator(
                    in_channel=inter_channels,
                    upsample_initial_channel=upsample_initial_channel,
                    upsample_rates=upsample_rates,
                    upsample_kernel_sizes=upsample_kernel_sizes,
                    resblock_kernel_sizes=resblock_kernel_sizes,
                    resblock_dilations=resblock_dilation_sizes,
                    gin_channels=gin_channels,
                    sample_rate=sr,
                    harmonic_num=8,
                    checkpointing=checkpointing,
                )
            elif vocoder == "RefineGAN":
                self.dec = RefineGANGenerator(
                    sample_rate=sr,
                    # RefineGAN mirrors the upsample path for its downsampling stack.
                    downsample_rates=upsample_rates[::-1],
                    upsample_rates=upsample_rates,
                    start_channels=16,
                    num_mels=inter_channels,
                    checkpointing=checkpointing,
                )
            else:
                self.dec = HiFiGANNSFGenerator(
                    inter_channels,
                    resblock_kernel_sizes,
                    resblock_dilation_sizes,
                    upsample_rates,
                    upsample_initial_channel,
                    upsample_kernel_sizes,
                    gin_channels=gin_channels,
                    sr=sr,
                    checkpointing=checkpointing,
                )
        else:
            if vocoder == "MRF HiFi-GAN":
                # Bug fix: messages were swapped with the RefineGAN branch.
                # This branch leaves the decoder unset, so report it as
                # unsupported (training with this combination cannot proceed).
                print("MRF HiFi-GAN does not support training without pitch guidance.")
                self.dec = None
            elif vocoder == "RefineGAN":
                # Bug fix (other half of the swap): this branch DOES build a
                # decoder, so report the experimental no-pitch mode instead of
                # claiming it is unsupported.
                print("Using RefineGAN without pitch guidance (experimental).")
                self.dec = RefineGANGenerator(
                    sample_rate=sr,
                    downsample_rates=upsample_rates[::-1],
                    upsample_rates=upsample_rates,
                    start_channels=16,
                    num_mels=inter_channels,
                    checkpointing=checkpointing,
                )
            else:
                self.dec = HiFiGANGenerator(
                    inter_channels,
                    resblock_kernel_sizes,
                    resblock_dilation_sizes,
                    upsample_rates,
                    upsample_initial_channel,
                    upsample_kernel_sizes,
                    gin_channels=gin_channels,
                )
        self.enc_q = PosteriorEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,  # positional args: presumably kernel_size, dilation, n_layers (VITS layout) — confirm against PosteriorEncoder
            1,
            16,
            gin_channels=gin_channels,
        )
        self.flow = ResidualCouplingBlock(
            inter_channels,
            hidden_channels,
            5,
            1,
            3,
            gin_channels=gin_channels,
        )
        self.emb_g = torch.nn.Embedding(spk_embed_dim, gin_channels)

    def _remove_weight_norm_from(self, module):
        """Remove the weight-norm reparameterization from *module* if present."""
        # Bug fix: iterate over a snapshot of the hooks. remove_weight_norm
        # deletes its hook from module._forward_pre_hooks, and mutating the
        # dict while iterating it raises RuntimeError in CPython.
        for hook in list(module._forward_pre_hooks.values()):
            if getattr(hook, "__class__", None).__name__ == "WeightNorm":
                torch.nn.utils.remove_weight_norm(module)

    def remove_weight_norm(self):
        """Strip weight norm from the decoder, flow, and posterior encoder."""
        for module in [self.dec, self.flow, self.enc_q]:
            # Bug fix: self.dec is None for "MRF HiFi-GAN" without F0; skip it
            # instead of crashing on None._forward_pre_hooks.
            if module is not None:
                self._remove_weight_norm_from(module)

    def __prepare_scriptable__(self):
        """TorchScript hook: drop weight norm before the model is scripted."""
        self.remove_weight_norm()
        return self

    def forward(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        pitchf: Optional[torch.Tensor] = None,
        y: Optional[torch.Tensor] = None,
        y_lengths: Optional[torch.Tensor] = None,
        ds: Optional[torch.Tensor] = None,
    ):
        """
        Training forward pass.

        Args:
            phone (torch.Tensor): Phoneme/content features.
            phone_lengths (torch.Tensor): Valid lengths of ``phone``.
            pitch (torch.Tensor, optional): Coarse pitch sequence.
            pitchf (torch.Tensor, optional): Fine-grained F0 sequence.
            y (torch.Tensor, optional): Target spectrogram. When ``None``, only
                the prior statistics are returned (no decoding).
            y_lengths (torch.Tensor, optional): Valid lengths of ``y``.
            ds (torch.Tensor, optional): Speaker id(s) for the embedding table.
        """
        # Speaker conditioning vector, unsqueezed for 1D-conv broadcasting.
        g = self.emb_g(ds).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)

        if y is not None:
            z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
            z_p = self.flow(z, y_mask, g=g)
            if self.randomized:
                # Regular training: decode random slices to bound memory use.
                z_slice, ids_slice = rand_slice_segments(
                    z, y_lengths, self.segment_size
                )
                if self.use_f0:
                    pitchf = slice_segments(pitchf, ids_slice, self.segment_size, 2)
                    o = self.dec(z_slice, pitchf, g=g)
                else:
                    o = self.dec(z_slice, g=g)
                return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
            else:
                # Future use: fine-tune on the entire sequence each pass.
                if self.use_f0:
                    o = self.dec(z, pitchf, g=g)
                else:
                    o = self.dec(z, g=g)
                return o, None, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
        else:
            return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)

    @torch.jit.export
    def infer(
        self,
        phone: torch.Tensor,
        phone_lengths: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        nsff0: Optional[torch.Tensor] = None,
        sid: Optional[torch.Tensor] = None,
        rate: Optional[torch.Tensor] = None,
    ):
        """
        Inference of the model.

        Args:
            phone (torch.Tensor): Phoneme sequence.
            phone_lengths (torch.Tensor): Lengths of the phoneme sequences.
            pitch (torch.Tensor, optional): Pitch sequence.
            nsff0 (torch.Tensor, optional): Fine-grained pitch sequence.
            sid (torch.Tensor, optional): Speaker embedding index.
                (Annotation fixed to Optional to match the None default and
                satisfy TorchScript.)
            rate (torch.Tensor, optional): Rate for time-stretching.
        """
        g = self.emb_g(sid).unsqueeze(-1)
        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
        # Sample from the prior with a ~2/3 temperature on the std-dev.
        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask

        if rate is not None:
            # Keep only the trailing fraction of the sequence (time-stretch).
            head = int(z_p.shape[2] * (1.0 - rate.item()))
            z_p, x_mask = z_p[:, :, head:], x_mask[:, :, head:]
            if self.use_f0 and nsff0 is not None:
                nsff0 = nsff0[:, head:]

        z = self.flow(z_p, x_mask, g=g, reverse=True)
        o = (
            self.dec(z * x_mask, nsff0, g=g)
            if self.use_f0
            else self.dec(z * x_mask, g=g)
        )

        return o, x_mask, (z, z_p, m_p, logs_p)