Yoshitaka16 committed
Commit 78e292b · verified · 1 parent: b6e3132

Upload refinegan.py

Files changed (1): refinegan.py (+451 −0)

refinegan.py ADDED
import numpy as np
import torch
import torchaudio
from torch import nn
from torch.nn import functional as F
from torch.nn.utils.parametrizations import weight_norm

# parametrizations.weight_norm registers a parametrization, so it has to be
# undone with remove_parametrizations; the legacy torch.nn.utils.remove_weight_norm
# only handles the old hook-based weight_norm and would raise here.
from torch.nn.utils.parametrize import remove_parametrizations
from torch.utils.checkpoint import checkpoint

from rvc.lib.algorithm.commons import init_weights, get_padding

class ResBlock(nn.Module):
    """
    Residual block with multiple dilated convolutions.

    This block applies a sequence of dilated convolutional layers with Leaky ReLU
    activation. It is designed to capture information at different scales thanks
    to the varying dilation rates.

    Args:
        channels (int): Number of input and output channels.
        kernel_size (int, optional): Kernel size for the convolutional layers. Defaults to 7.
        dilation (tuple[int], optional): Dilation rates for the first convolution of each pair. Defaults to (1, 3, 5).
        leaky_relu_slope (float, optional): Slope for the Leaky ReLU activation. Defaults to 0.2.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int = 7,
        dilation: tuple[int, ...] = (1, 3, 5),
        leaky_relu_slope: float = 0.2,
    ):
        super().__init__()

        self.leaky_relu_slope = leaky_relu_slope

        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    nn.Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        stride=1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x: torch.Tensor):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, self.leaky_relu_slope)
            xt = c1(xt)
            xt = F.leaky_relu(xt, self.leaky_relu_slope)
            xt = c2(xt)
            x = xt + x

        return x

    def remove_weight_norm(self):
        for c1, c2 in zip(self.convs1, self.convs2):
            remove_parametrizations(c1, "weight")
            remove_parametrizations(c2, "weight")

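# Quick shape check (illustrative sketch, not part of the original upload):
# a ResBlock preserves (batch, channels, time), since get_padding pads each
# convolution to "same" length:
#
#     block = ResBlock(channels=64)
#     block(torch.randn(2, 64, 100)).shape  # -> torch.Size([2, 64, 100])
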
class AdaIN(nn.Module):
    """
    Adaptive Instance Normalization layer.

    Despite the name, this layer injects per-channel Gaussian noise scaled by a
    learnable weight, followed by a Leaky ReLU activation.

    Args:
        channels (int): Number of input channels.
        leaky_relu_slope (float, optional): Slope for the Leaky ReLU activation applied after the noise injection. Defaults to 0.2.
    """

    def __init__(
        self,
        *,
        channels: int,
        leaky_relu_slope: float = 0.2,
    ):
        super().__init__()

        self.weight = nn.Parameter(torch.ones(channels) * 1e-4)
        # safe to use in-place as it is applied to the new x + gaussian tensor
        self.activation = nn.LeakyReLU(leaky_relu_slope, inplace=True)

    def forward(self, x: torch.Tensor):
        gaussian = torch.randn_like(x) * self.weight[None, :, None]

        return self.activation(x + gaussian)

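# Note (editor's sketch, not in the original file): the noise weight starts at
# 1e-4, so immediately after initialization AdaIN behaves almost like a plain
# leaky_relu(x); the per-channel noise scale is then learned during training.
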
class ParallelResBlock(nn.Module):
    """
    Parallel residual block that applies multiple residual blocks with different kernel sizes in parallel.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        kernel_sizes (tuple[int], optional): Kernel sizes for the parallel residual blocks. Defaults to (3, 7, 11).
        dilation (tuple[int], optional): Dilation rates for the convolutional layers within the residual blocks. Defaults to (1, 3, 5).
        leaky_relu_slope (float, optional): Slope for the Leaky ReLU activation. Defaults to 0.2.
    """

    def __init__(
        self,
        *,
        in_channels: int,
        out_channels: int,
        kernel_sizes: tuple[int, ...] = (3, 7, 11),
        dilation: tuple[int, ...] = (1, 3, 5),
        leaky_relu_slope: float = 0.2,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # wrapped in weight_norm so remove_weight_norm below can strip it
        self.input_conv = weight_norm(
            nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=7,
                stride=1,
                padding=3,
            )
        )

        self.input_conv.apply(init_weights)

        self.blocks = nn.ModuleList(
            [
                nn.Sequential(
                    AdaIN(channels=out_channels),
                    ResBlock(
                        out_channels,
                        kernel_size=kernel_size,
                        dilation=dilation,
                        leaky_relu_slope=leaky_relu_slope,
                    ),
                    AdaIN(channels=out_channels),
                )
                for kernel_size in kernel_sizes
            ]
        )

    def forward(self, x: torch.Tensor):
        x = self.input_conv(x)
        # average the parallel branches
        return torch.stack([block(x) for block in self.blocks], dim=0).mean(dim=0)

    def remove_weight_norm(self):
        remove_parametrizations(self.input_conv, "weight")
        for block in self.blocks:
            block[1].remove_weight_norm()

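# Shape-wise (editor's sketch, not part of the original upload): the block maps
# (batch, in_channels, time) to (batch, out_channels, time); every kernel size
# sees the same input and the branch outputs are averaged:
#
#     prb = ParallelResBlock(in_channels=640, out_channels=256)
#     prb(torch.randn(2, 640, 100)).shape  # -> torch.Size([2, 256, 100])
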
class SineGenerator(nn.Module):
    """
    Definition of the sine generator.

    Generates sine waveforms with optional harmonics and additive noise.
    Can be used to create a harmonic noise source for neural vocoders.

    Args:
        samp_rate (int): Sampling rate in Hz.
        harmonic_num (int): Number of harmonic overtones (default 0).
        sine_amp (float): Amplitude of the sine waveform (default 0.1).
        noise_std (float): Standard deviation of the Gaussian noise (default 0.003).
        voiced_threshold (float): F0 threshold for voiced/unvoiced classification (default 0).
    """

    def __init__(
        self,
        samp_rate,
        harmonic_num=0,
        sine_amp=0.1,
        noise_std=0.003,
        voiced_threshold=0,
    ):
        super().__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold

        self.merge = nn.Sequential(
            nn.Linear(self.dim, 1, bias=False),
            nn.Tanh(),
        )

    def _f02uv(self, f0):
        # generate the voiced/unvoiced (uv) mask
        uv = torch.ones_like(f0)
        uv = uv * (f0 > self.voiced_threshold)
        return uv

    def _f02sine(self, f0_values):
        """f0_values: (batchsize, length, dim)
        where dim indicates the fundamental tone and overtones
        """
        # convert F0 to per-sample phase increments in revolutions. The integer
        # part n can be ignored because 2 * np.pi * n does not affect the phase
        rad_values = (f0_values / self.sampling_rate) % 1

        # initial phase noise (no noise for the fundamental component)
        rand_ini = torch.rand(
            f0_values.shape[0], f0_values.shape[2], device=f0_values.device
        )
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini

        # instantaneous phase: sine[t] = sin(2 * pi * sum_{i=1}^{t} rad_i)
        tmp_over_one = torch.cumsum(rad_values, 1) % 1
        tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
        cumsum_shift = torch.zeros_like(rad_values)
        cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0

        sines = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * np.pi)

        return sines

    def forward(self, f0):
        with torch.no_grad():
            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
            # fundamental component
            f0_buf[:, :, 0] = f0[:, :, 0]
            for idx in range(self.harmonic_num):
                f0_buf[:, :, idx + 1] = f0_buf[:, :, 0] * (idx + 2)

            sine_waves = self._f02sine(f0_buf) * self.sine_amp

            uv = self._f02uv(f0)

            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
            noise = noise_amp * torch.randn_like(sine_waves)

            sine_waves = sine_waves * uv + noise

        # merge with grad
        return self.merge(sine_waves)

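# Usage sketch (editor's illustration, not part of the original upload):
#
#     gen = SineGenerator(samp_rate=44100)
#     f0 = torch.full((1, 44100, 1), 220.0)  # (batch, samples, 1), F0 in Hz
#     wave = gen(f0)                          # -> (1, 44100, 1), values in [-1, 1]
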
class RefineGANGenerator(nn.Module):
    """
    RefineGAN generator for audio synthesis.

    This generator uses a combination of downsampling, residual blocks, and parallel residual blocks
    to refine an input mel-spectrogram and fundamental frequency (F0) into an audio waveform.
    It can also incorporate global conditioning.

    Args:
        sample_rate (int, optional): Sampling rate of the audio. Defaults to 44100.
        downsample_rates (tuple[int], optional): Downsampling rates for the downsampling blocks. Defaults to (2, 2, 8, 8). Currently unused.
        upsample_rates (tuple[int], optional): Upsampling rates for the upsampling blocks. Defaults to (8, 8, 2, 2).
        leaky_relu_slope (float, optional): Slope for the Leaky ReLU activation. Defaults to 0.2.
        num_mels (int, optional): Number of mel-frequency bins in the input mel-spectrogram. Defaults to 128.
        start_channels (int, optional): Number of channels in the initial convolutional layer. Defaults to 16.
        gin_channels (int, optional): Number of channels for the global conditioning input. Defaults to 256.
        checkpointing (bool, optional): Whether to use gradient checkpointing for memory efficiency. Defaults to False.
        upsample_initial_channel (int, optional): Number of channels at the input of the upsampling stack. Defaults to 512.
    """

    def __init__(
        self,
        *,
        sample_rate: int = 44100,
        downsample_rates: tuple[int, ...] = (2, 2, 8, 8),  # unused
        upsample_rates: tuple[int, ...] = (8, 8, 2, 2),
        leaky_relu_slope: float = 0.2,
        num_mels: int = 128,
        start_channels: int = 16,
        gin_channels: int = 256,
        checkpointing: bool = False,
        upsample_initial_channel: int = 512,
    ):
        super().__init__()
        self.upsample_rates = upsample_rates
        self.leaky_relu_slope = leaky_relu_slope
        self.checkpointing = checkpointing

        self.upp = int(np.prod(upsample_rates))
        self.m_source = SineGenerator(sample_rate)

        # expanded f0 sinegen -> match mel_conv
        # (8, 1, 17280) -> (8, 16, 17280)
        self.pre_conv = weight_norm(
            nn.Conv1d(
                1,
                start_channels,
                7,
                1,
                padding=3,
            )
        )

        # (8, 16, 17280) = 4th upscale
        # (8, 32, 8640) = 3rd upscale
        # (8, 64, 4320) = 2nd upscale
        # (8, 128, 432) = 1st upscale
        # (8, 256, 36) merged to mel

        # f0 downsampling and upchanneling
        channels = start_channels
        size = self.upp
        self.downsample_blocks = nn.ModuleList([])
        self.df0 = []
        for rate in reversed(upsample_rates):
            new_size = size // rate
            # T dimension factors for torchaudio.functional.resample
            self.df0.append([size, new_size])
            size = new_size

            new_channels = channels * 2
            self.downsample_blocks.append(
                weight_norm(nn.Conv1d(channels, new_channels, 7, 1, padding=3))
            )
            channels = new_channels

        # mel handling
        channels = upsample_initial_channel

        self.mel_conv = weight_norm(
            nn.Conv1d(
                num_mels,
                channels // 2,
                7,
                1,
                padding=3,
            )
        )

        self.mel_conv.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, channels // 2, 1)

        self.upsample_blocks = nn.ModuleList([])
        self.upsample_conv_blocks = nn.ModuleList([])

        for rate in upsample_rates:
            new_channels = channels // 2

            self.upsample_blocks.append(nn.Upsample(scale_factor=rate, mode="linear"))

            self.upsample_conv_blocks.append(
                ParallelResBlock(
                    in_channels=channels + channels // 4,
                    out_channels=new_channels,
                    kernel_sizes=(3, 7, 11),
                    dilation=(1, 3, 5),
                    leaky_relu_slope=leaky_relu_slope,
                )
            )

            channels = new_channels

        self.conv_post = weight_norm(
            nn.Conv1d(channels, 1, 7, 1, padding=3, bias=False)
        )
        self.conv_post.apply(init_weights)

    def forward(self, mel: torch.Tensor, f0: torch.Tensor, g: torch.Tensor = None):
        f0_size = mel.shape[-1]
        # stretch the f0 helper to the full audio length
        f0 = F.interpolate(f0.unsqueeze(1), size=f0_size * self.upp, mode="linear")
        # turn f0 into sine harmonics
        har_source = self.m_source(f0.transpose(1, 2)).transpose(1, 2)
        # prepare for fusion with the mel
        x = self.pre_conv(har_source)
        # downsampled/upchanneled versions for each upscale
        downs = []
        for block, (old_size, new_size) in zip(self.downsample_blocks, self.df0):
            x = F.leaky_relu(x, self.leaky_relu_slope)
            downs.append(x)
            # attempt to cancel spectral aliasing
            x = torchaudio.functional.resample(
                x.contiguous(),
                orig_freq=int(f0_size * old_size),
                new_freq=int(f0_size * new_size),
                lowpass_filter_width=64,
                rolloff=0.9475937167399596,
                resampling_method="sinc_interp_kaiser",
                beta=14.769656459379492,
            )
            x = block(x)

        # project the mel-spectrogram (num_mels channels) up to channels // 2
        mel = self.mel_conv(mel)
        if g is not None:
            # add the projected speaker embedding
            mel = mel + self.cond(g)

        x = torch.cat([mel, x], dim=1)

        for ups, res, down in zip(
            self.upsample_blocks,
            self.upsample_conv_blocks,
            reversed(downs),
        ):
            x = F.leaky_relu(x, self.leaky_relu_slope)

            if self.training and self.checkpointing:
                x = checkpoint(ups, x, use_reentrant=False)
                x = torch.cat([x, down], dim=1)
                x = checkpoint(res, x, use_reentrant=False)
            else:
                x = ups(x)
                x = torch.cat([x, down], dim=1)
                x = res(x)

        x = F.leaky_relu(x, self.leaky_relu_slope)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        remove_parametrizations(self.pre_conv, "weight")
        remove_parametrizations(self.mel_conv, "weight")
        remove_parametrizations(self.conv_post, "weight")

        # downsample_blocks hold weight-normed Conv1d modules, so the
        # parametrization is stripped directly rather than via a method
        for block in self.downsample_blocks:
            remove_parametrizations(block, "weight")

        for block in self.upsample_conv_blocks:
            block.remove_weight_norm()
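

if __name__ == "__main__":
    # Minimal smoke-test sketch added for illustration; it is not part of the
    # original upload and assumes rvc.lib.algorithm.commons is importable.
    # With the defaults, output length = mel frames * prod(upsample_rates).
    model = RefineGANGenerator()
    mel = torch.randn(1, 128, 36)    # (batch, num_mels, frames)
    f0 = torch.full((1, 36), 220.0)  # (batch, frames), F0 in Hz
    g = torch.randn(1, 256, 1)       # global (speaker) conditioning
    audio = model(mel, f0, g)
    print(audio.shape)               # expected: torch.Size([1, 1, 9216])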