kathiasi committed on
Commit 0ddd845 · verified · 1 Parent(s): 2a44583

Uploaded hiftnet vocoder files

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ hiftnet/libritts/g_00650000 filter=lfs diff=lfs merge=lfs -text
hiftnet/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 Aaron (Yinghao) Li
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
hiftnet/README.md ADDED
@@ -0,0 +1,38 @@
+ # HiFTNet: A Fast High-Quality Neural Vocoder with Harmonic-plus-Noise Filter and Inverse Short Time Fourier Transform
+
+ ### Yinghao Aaron Li, Cong Han, Xilin Jiang, Nima Mesgarani
+
+ > Recent advancements in speech synthesis have leveraged GAN-based networks like HiFi-GAN and BigVGAN to produce high-fidelity waveforms from mel-spectrograms. However, these networks are computationally expensive and parameter-heavy. iSTFTNet addresses these limitations by integrating inverse short-time Fourier transform (iSTFT) into the network, achieving both speed and parameter efficiency. In this paper, we introduce an extension to iSTFTNet, termed HiFTNet, which incorporates a harmonic-plus-noise source filter in the time-frequency domain that uses a sinusoidal source from the fundamental frequency (F0) inferred via a pre-trained F0 estimation network for fast inference speed. Subjective evaluations on LJSpeech show that our model significantly outperforms both iSTFTNet and HiFi-GAN, achieving ground-truth-level performance. HiFTNet also outperforms BigVGAN-base on LibriTTS for unseen speakers and achieves comparable performance to BigVGAN while being four times faster with only 1/6 of the parameters. Our work sets a new benchmark for efficient, high-quality neural vocoding, paving the way for real-time applications that demand high quality speech synthesis.
+
+ Paper: [https://arxiv.org/abs/2309.09493](https://arxiv.org/abs/2309.09493)
+
+ Audio samples: [https://hiftnet.github.io/](https://hiftnet.github.io/)
+
+ **Check out our TTS work that uses HiFTNet as the speech decoder for human-level speech synthesis: https://github.com/yl4579/StyleTTS2**
+
+ ## Pre-requisites
+ 1. Python >= 3.7
+ 2. Clone this repository:
+ ```bash
+ git clone https://github.com/yl4579/HiFTNet.git
+ cd HiFTNet
+ ```
+ 3. Install Python requirements:
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ## Training
+ ```bash
+ python train.py --config config_v1.json --[args]
+ ```
+ For F0 model training, please refer to [yl4579/PitchExtractor](https://github.com/yl4579/PitchExtractor). This repo includes an F0 model pre-trained on LibriTTS. Still, you may want to train your own F0 model for the best performance, particularly for noisy or non-speech data, as we found that F0 estimation accuracy is essential to vocoder performance.
+
+ ## Inference
+ Please refer to the notebook [inference.ipynb](https://github.com/yl4579/HiFTNet/blob/main/inference.ipynb) for details.
+ ### Pre-Trained Models
+ You can download the pre-trained LJSpeech model [here](https://huggingface.co/yl4579/HiFTNet/blob/main/LJSpeech/cp_hifigan.zip) and the pre-trained LibriTTS model [here](https://huggingface.co/yl4579/HiFTNet/blob/main/LibriTTS/cp_hifigan.zip). The pre-trained models contain the optimizer and discriminator parameters, so they can be used for fine-tuning.
+
+ ## References
+ - [rishikksh20/iSTFTNet-pytorch](https://github.com/rishikksh20/iSTFTNet-pytorch)
+ - [nii-yamagishilab/project-NN-Pytorch-scripts/project/01-nsf](https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts/tree/master/project/01-nsf)
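For orientation, here is a minimal low-level inference sketch in the spirit of the notebook, using the LibriTTS checkpoint shipped in this commit. It is illustrative only: it assumes you run it from the directory containing the `hiftnet` package, and that `hiftnet/utils.py` (referenced by `models.py` but not part of this upload) is available as in the upstream repo.

```python
import json
import torch

from hiftnet.env import AttrDict
from hiftnet.models import Generator
from hiftnet.stft import TorchSTFT
from hiftnet.Utils.JDC.model import JDCNet

with open("hiftnet/libritts/config.json") as f:
    h = AttrDict(json.load(f))

F0_model = JDCNet(num_class=1, seq_len=192)       # F0 estimation network
generator = Generator(h, F0_model)
stft = TorchSTFT(filter_length=h.gen_istft_n_fft,
                 hop_length=h.gen_istft_hop_size,
                 win_length=h.gen_istft_n_fft)

state = torch.load("hiftnet/libritts/g_00650000", map_location="cpu")
generator.load_state_dict(state["generator"])
generator.remove_weight_norm()
generator.eval()

with torch.no_grad():
    mel = torch.randn(1, h.num_mels, 100)   # stand-in for a real mel-spectrogram
    spec, phase = generator(mel)            # network predicts magnitude and phase
    wav = stft.inverse(spec, phase)         # iSTFT reconstructs the waveform
```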
hiftnet/Utils/JDC/__init__.py ADDED
@@ -0,0 +1 @@
+
hiftnet/Utils/JDC/model.py ADDED
@@ -0,0 +1,190 @@
+ """
+ Implementation of model from:
+ Kum et al. - "Joint Detection and Classification of Singing Voice Melody Using
+ Convolutional Recurrent Neural Networks" (2019)
+ Link: https://www.semanticscholar.org/paper/Joint-Detection-and-Classification-of-Singing-Voice-Kum-Nam/60a2ad4c7db43bace75805054603747fcd062c0d
+ """
+ import torch
+ from torch import nn
+
+
+ class JDCNet(nn.Module):
+     """
+     Joint Detection and Classification Network model for singing voice melody.
+     """
+     def __init__(self, num_class=722, seq_len=31, leaky_relu_slope=0.01):
+         super().__init__()
+         self.num_class = num_class
+
+         # input = (b, 1, 31, 513), b = batch size
+         self.conv_block = nn.Sequential(
+             nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1, bias=False),  # out: (b, 64, 31, 513)
+             nn.BatchNorm2d(num_features=64),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.Conv2d(64, 64, 3, padding=1, bias=False),  # (b, 64, 31, 513)
+         )
+
+         # res blocks
+         self.res_block1 = ResBlock(in_channels=64, out_channels=128)   # (b, 128, 31, 128)
+         self.res_block2 = ResBlock(in_channels=128, out_channels=192)  # (b, 192, 31, 32)
+         self.res_block3 = ResBlock(in_channels=192, out_channels=256)  # (b, 256, 31, 8)
+
+         # pool block
+         self.pool_block = nn.Sequential(
+             nn.BatchNorm2d(num_features=256),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.MaxPool2d(kernel_size=(1, 4)),  # (b, 256, 31, 2)
+             nn.Dropout(p=0.2),
+         )
+
+         # maxpool layers (for auxiliary network inputs)
+         # in = (b, 64, 31, 513) from conv_block, out = (b, 64, 31, 2)
+         self.maxpool1 = nn.MaxPool2d(kernel_size=(1, 40))
+         # in = (b, 128, 31, 128) from res_block1, out = (b, 128, 31, 2)
+         self.maxpool2 = nn.MaxPool2d(kernel_size=(1, 20))
+         # in = (b, 192, 31, 32) from res_block2, out = (b, 192, 31, 2)
+         self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 10))
+
+         # in = (b, 640, 31, 2), out = (b, 256, 31, 2)
+         self.detector_conv = nn.Sequential(
+             nn.Conv2d(640, 256, 1, bias=False),
+             nn.BatchNorm2d(256),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.Dropout(p=0.2),
+         )
+
+         # input: (b, 31, 512) - resized from (b, 256, 31, 2)
+         self.bilstm_classifier = nn.LSTM(
+             input_size=512, hidden_size=256,
+             batch_first=True, bidirectional=True)  # (b, 31, 512)
+
+         # input: (b, 31, 512) - resized from (b, 256, 31, 2)
+         self.bilstm_detector = nn.LSTM(
+             input_size=512, hidden_size=256,
+             batch_first=True, bidirectional=True)  # (b, 31, 512)
+
+         # input: (b * 31, 512)
+         self.classifier = nn.Linear(in_features=512, out_features=self.num_class)  # (b * 31, num_class)
+
+         # input: (b * 31, 512)
+         self.detector = nn.Linear(in_features=512, out_features=2)  # (b * 31, 2) - binary classifier
+
+         # initialize weights
+         self.apply(self.init_weights)
+
+     def get_feature_GAN(self, x):
+         x = x.float().transpose(-1, -2)
+
+         convblock_out = self.conv_block(x)
+
+         resblock1_out = self.res_block1(convblock_out)
+         resblock2_out = self.res_block2(resblock1_out)
+         resblock3_out = self.res_block3(resblock2_out)
+         poolblock_out = self.pool_block[0](resblock3_out)
+         poolblock_out = self.pool_block[1](poolblock_out)
+
+         return poolblock_out.transpose(-1, -2)
+
+     def get_feature(self, x):
+         x = x.float().transpose(-1, -2)
+
+         convblock_out = self.conv_block(x)
+
+         resblock1_out = self.res_block1(convblock_out)
+         resblock2_out = self.res_block2(resblock1_out)
+         resblock3_out = self.res_block3(resblock2_out)
+         poolblock_out = self.pool_block[0](resblock3_out)
+         poolblock_out = self.pool_block[1](poolblock_out)
+
+         return self.pool_block[2](poolblock_out)
+
+     def forward(self, x):
+         """
+         Returns:
+             classification_prediction, GAN_feature, poolblock_out
+             classification size: (b, num_frames)
+         """
+         ###############################
+         # forward pass for classifier #
+         ###############################
+         seq_len = x.shape[-1]
+         x = x.float().transpose(-1, -2)
+
+         convblock_out = self.conv_block(x)
+
+         resblock1_out = self.res_block1(convblock_out)
+         resblock2_out = self.res_block2(resblock1_out)
+         resblock3_out = self.res_block3(resblock2_out)
+
+         poolblock_out = self.pool_block[0](resblock3_out)
+         poolblock_out = self.pool_block[1](poolblock_out)
+         GAN_feature = poolblock_out.transpose(-1, -2)
+         poolblock_out = self.pool_block[2](poolblock_out)
+
+         # (b, 256, 31, 2) => (b, 31, 256, 2) => (b, 31, 512)
+         classifier_out = poolblock_out.permute(0, 2, 1, 3).contiguous().view((-1, seq_len, 512))
+         classifier_out, _ = self.bilstm_classifier(classifier_out)  # ignore the hidden states
+
+         classifier_out = classifier_out.contiguous().view((-1, 512))  # (b * 31, 512)
+         classifier_out = self.classifier(classifier_out)
+         classifier_out = classifier_out.view((-1, seq_len, self.num_class))  # (b, 31, num_class)
+
+         # classifier output holds the predicted pitch classes per frame;
+         # it is returned together with the GAN feature map and pooled features
+         return torch.abs(classifier_out.squeeze()), GAN_feature, poolblock_out
+
+     @staticmethod
+     def init_weights(m):
+         if isinstance(m, nn.Linear):
+             nn.init.kaiming_uniform_(m.weight)
+             if m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.Conv2d):
+             nn.init.xavier_normal_(m.weight)
+         elif isinstance(m, nn.LSTM) or isinstance(m, nn.LSTMCell):
+             for p in m.parameters():
+                 if p.data is None:
+                     continue
+
+                 if len(p.shape) >= 2:
+                     nn.init.orthogonal_(p.data)
+                 else:
+                     nn.init.normal_(p.data)
+
+
+ class ResBlock(nn.Module):
+     def __init__(self, in_channels: int, out_channels: int, leaky_relu_slope=0.01):
+         super().__init__()
+         self.downsample = in_channels != out_channels
+
+         # BN / LReLU / MaxPool layer before the conv layer - see Figure 1b in the paper
+         self.pre_conv = nn.Sequential(
+             nn.BatchNorm2d(num_features=in_channels),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.MaxPool2d(kernel_size=(1, 2)),  # apply downsampling on the y axis only
+         )
+
+         # conv layers
+         self.conv = nn.Sequential(
+             nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
+                       kernel_size=3, padding=1, bias=False),
+             nn.BatchNorm2d(out_channels),
+             nn.LeakyReLU(leaky_relu_slope, inplace=True),
+             nn.Conv2d(out_channels, out_channels, 3, padding=1, bias=False),
+         )
+
+         # 1 x 1 convolution layer to match the feature dimensions
+         self.conv1by1 = None
+         if self.downsample:
+             self.conv1by1 = nn.Conv2d(in_channels, out_channels, 1, bias=False)
+
+     def forward(self, x):
+         x = self.pre_conv(x)
+         if self.downsample:
+             x = self.conv(x) + self.conv1by1(x)
+         else:
+             x = self.conv(x) + x
+         return x
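As a quick sanity check on shapes, a small sketch (not part of the repository) that runs a dummy mel batch through the network the way the vocoder's `Generator` calls it:

```python
import torch
from hiftnet.Utils.JDC.model import JDCNet

# The vocoder instantiates the F0 network with a single regression class.
model = JDCNet(num_class=1, seq_len=192)
model.eval()

# Generator.forward feeds the mel as (batch, 1, num_mels, frames).
mel = torch.randn(2, 1, 80, 192)
with torch.no_grad():
    f0, gan_feature, pool_out = model(mel)

print(f0.shape)  # torch.Size([2, 192]): one F0 estimate per frame
```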
hiftnet/Utils/__init__.py ADDED
@@ -0,0 +1 @@
+
hiftnet/__init__.py ADDED
File without changes
hiftnet/config_v1.json ADDED
@@ -0,0 +1,40 @@
+ {
+     "F0_path": "Utils/JDC/bst.t7",
+
+     "resblock": "1",
+     "num_gpus": 1,
+     "batch_size": 2,
+     "learning_rate": 0.0002,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.999,
+     "seed": 1234,
+
+     "upsample_rates": [8,8],
+     "upsample_kernel_sizes": [16,16],
+     "upsample_initial_channel": 512,
+     "resblock_kernel_sizes": [3,7,11],
+     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+     "gen_istft_n_fft": 16,
+     "gen_istft_hop_size": 4,
+
+     "segment_size": 24576,
+     "num_mels": 80,
+     "n_fft": 1024,
+     "hop_size": 256,
+     "win_size": 1024,
+
+     "sampling_rate": 22050,
+
+     "fmin": 0,
+     "fmax": 8000,
+     "fmax_for_loss": null,
+
+     "num_workers": 4,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }
hiftnet/env.py ADDED
@@ -0,0 +1,15 @@
+ import os
+ import shutil
+
+
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
+
+
+ def build_env(config, config_name, path):
+     t_path = os.path.join(path, config_name)
+     if config != t_path:
+         os.makedirs(path, exist_ok=True)
+         shutil.copyfile(config, t_path)
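To illustrate, a short sketch of how `AttrDict` exposes the JSON config as attributes, which the rest of the package relies on:

```python
import json
from hiftnet.env import AttrDict

with open("hiftnet/config_v1.json") as f:
    h = AttrDict(json.load(f))

# Keys double as attributes, so model code can write h.n_fft for h["n_fft"].
assert h.n_fft == h["n_fft"] == 1024
print(h.sampling_rate, h.hop_size)  # 22050 256
```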
hiftnet/hiftnet.py ADDED
@@ -0,0 +1,138 @@
+ import os
+ import json
+ import torch
+ import sys
+ from .env import AttrDict
+ from .meldataset import mel_spectrogram, MAX_WAV_VALUE, load_wav
+ from .models import Generator
+ from .stft import TorchSTFT
+ from .Utils.JDC.model import JDCNet
+ from scipy.io.wavfile import write
+
+
+ class HiFTNet:
+     """A class for HiFTNet inference."""
+     def __init__(self, device="cpu"):
+         self.device = device
+
+         my_dir = os.path.dirname(os.path.abspath(__file__))
+         checkpoint_path = os.path.join(my_dir, "libritts")
+
+         # Load configuration
+         config_file = os.path.join(checkpoint_path, 'config.json')
+         with open(config_file) as f:
+             json_config = json.load(f)
+         self.h = AttrDict(json_config)
+
+         # Load models
+         F0_model = JDCNet(num_class=1, seq_len=192)
+         self.generator = Generator(self.h, F0_model).to(self.device)
+         self.stft = TorchSTFT(filter_length=self.h.gen_istft_n_fft,
+                               hop_length=self.h.gen_istft_hop_size,
+                               win_length=self.h.gen_istft_n_fft).to(self.device)
+
+         # Load checkpoint
+         state_dict_g = self._load_checkpoint(os.path.join(checkpoint_path, "g_00650000"), self.device)
+         self.generator.load_state_dict(state_dict_g['generator'])
+
+         # Set to evaluation mode
+         self.generator.remove_weight_norm()
+         self.generator.eval()
+
+     def _load_checkpoint(self, filepath, device):
+         """Loads a checkpoint file."""
+         assert os.path.isfile(filepath)
+         print(f"Loading '{filepath}'")
+         checkpoint_dict = torch.load(filepath, map_location=device)
+         print("Complete.")
+         return checkpoint_dict
+
+     def _get_mel(self, x):
+         """Computes a mel-spectrogram from a raw waveform."""
+         return mel_spectrogram(x, self.h.n_fft, self.h.num_mels, self.h.sampling_rate,
+                                self.h.hop_size, self.h.win_size, self.h.fmin, self.h.fmax)
+
+     def _infer_waveform(self, mel):
+         """Private helper to run inference from a mel-spectrogram."""
+         with torch.no_grad():
+             # Run inference
+             spec, phase = self.generator(mel)
+             y_g_hat = self.stft.inverse(spec, phase)
+
+         audio = y_g_hat.squeeze()
+
+         # Post-processing: rescale to the 16-bit integer range
+         audio = audio * MAX_WAV_VALUE
+         audio = audio.cpu().numpy().astype('int16')
+
+         return audio
+
+     def analysis_synthesis(self, wav_path):
+         """
+         Synthesizes audio from a WAV file path.
+
+         Args:
+             wav_path (str): Path to the input WAV file.
+
+         Returns:
+             numpy.ndarray: The synthesized audio waveform as a 16-bit integer array.
+         """
+         # Load and pre-process audio
+         wav, sr = load_wav(wav_path)
+         print(f"Processing audio file: {wav_path}")
+         wav_tensor = torch.FloatTensor(wav / MAX_WAV_VALUE).to(self.device)
+
+         # Get mel-spectrogram
+         mel_tensor = self._get_mel(wav_tensor.unsqueeze(0))
+         print(mel_tensor.shape)
+
+         # Synthesize and return audio
+         return self._infer_waveform(mel_tensor)
+
+     def synthesize_from_mel(self, mel_tensor):
+         """
+         Synthesizes audio from a pre-computed mel-spectrogram.
+
+         Args:
+             mel_tensor (torch.FloatTensor): A mel-spectrogram tensor of shape
+                                             [batch_size, num_mels, num_frames].
+                                             Typically batch_size is 1.
+
+         Returns:
+             numpy.ndarray: The synthesized audio waveform as a 16-bit integer array.
+         """
+         print("Synthesizing from mel-spectrogram...")
+         # Ensure tensor is on the correct device
+         mel_tensor = mel_tensor.to(self.device)
+
+         # Handle 2D input [num_mels, num_frames] by adding a batch dimension
+         if mel_tensor.dim() == 2:
+             mel_tensor = mel_tensor.unsqueeze(0)
+
+         # Synthesize and return audio
+         return self._infer_waveform(mel_tensor)
+
+
+ if __name__ == '__main__':
+     # Instantiate the vocoder. It loads the model automatically.
+     vocoder = HiFTNet()
+
+     # Get the input file path from the command line
+     input_wav_path = sys.argv[1]
+
+     # Synthesize the audio from the file
+     audio_out = vocoder.analysis_synthesis(input_wav_path)
+
+     # Define the output path
+     output_wav_path = "/tmp/tmp_hift.wav"
+
+     # Save the synthesized audio
+     write(output_wav_path, vocoder.h.sampling_rate, audio_out)
+
+     # Play the synthesized audio (requires SoX)
+     os.system(f"play -q {output_wav_path}")
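Besides the command-line entry point above, the class can be used directly; a short usage sketch (file names are illustrative):

```python
import torch
from scipy.io.wavfile import write
from hiftnet.hiftnet import HiFTNet

vocoder = HiFTNet(device="cpu")

# Round-trip an existing 22.05 kHz recording through mel analysis + vocoding
audio = vocoder.analysis_synthesis("sample.wav")   # hypothetical input file
write("resynth.wav", vocoder.h.sampling_rate, audio)

# Or vocode a mel produced elsewhere, e.g. by an acoustic model
mel = torch.randn(1, 80, 120)                      # stand-in mel tensor
audio = vocoder.synthesize_from_mel(mel)
```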
hiftnet/libritts/config.json ADDED
@@ -0,0 +1,38 @@
+ {
+     "resblock": "1",
+     "num_gpus": 1,
+     "batch_size": 16,
+     "learning_rate": 0.0002,
+     "adam_b1": 0.8,
+     "adam_b2": 0.99,
+     "lr_decay": 0.999,
+     "seed": 1234,
+
+     "upsample_rates": [8,8],
+     "upsample_kernel_sizes": [16,16],
+     "upsample_initial_channel": 512,
+     "resblock_kernel_sizes": [3,7,11],
+     "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
+     "gen_istft_n_fft": 16,
+     "gen_istft_hop_size": 4,
+
+     "segment_size": 24576,
+     "num_mels": 80,
+     "n_fft": 1024,
+     "hop_size": 256,
+     "win_size": 1024,
+
+     "sampling_rate": 22050,
+
+     "fmin": 0,
+     "fmax": 8000,
+     "fmax_for_loss": null,
+
+     "num_workers": 4,
+
+     "dist_config": {
+         "dist_backend": "nccl",
+         "dist_url": "tcp://localhost:54321",
+         "world_size": 1
+     }
+ }
hiftnet/libritts/g_00650000 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:25ce532d3658397d7dfd5206ef63fda1a9aa8b91aea68f653c84be5422451f54
+ size 89846680
hiftnet/meldataset.py ADDED
@@ -0,0 +1,170 @@
+ import math
+ import os
+ import random
+ import torch
+ import torch.utils.data
+ import numpy as np
+ from librosa.util import normalize
+ from scipy.io.wavfile import read
+ from librosa.filters import mel as librosa_mel_fn
+
+ MAX_WAV_VALUE = 32768.0
+
+
+ def load_wav(full_path):
+     sampling_rate, data = read(full_path)
+     return data, sampling_rate
+
+
+ def dynamic_range_compression(x, C=1, clip_val=1e-5):
+     return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+ def dynamic_range_decompression(x, C=1):
+     return np.exp(x) / C
+
+
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+     return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+ def dynamic_range_decompression_torch(x, C=1):
+     return torch.exp(x) / C
+
+
+ def spectral_normalize_torch(magnitudes):
+     output = dynamic_range_compression_torch(magnitudes)
+     return output
+
+
+ def spectral_de_normalize_torch(magnitudes):
+     output = dynamic_range_decompression_torch(magnitudes)
+     return output
+
+
+ mel_basis = {}
+ hann_window = {}
+
+
+ def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
+     if torch.min(y) < -1.:
+         print('min value is ', torch.min(y))
+     if torch.max(y) > 1.:
+         print('max value is ', torch.max(y))
+
+     global mel_basis, hann_window
+     # cache key includes fmax and device so each combination is built only once
+     if str(fmax) + '_' + str(y.device) not in mel_basis:
+         mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
+         mel_basis[str(fmax) + '_' + str(y.device)] = torch.from_numpy(mel).float().to(y.device)
+         hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)
+
+     y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode='reflect')
+     y = y.squeeze(1)
+
+     # complex tensor as default, then use view_as_real for future pytorch compatibility
+     spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],
+                       center=center, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
+     spec = torch.view_as_real(spec)
+     spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
+
+     spec = torch.matmul(mel_basis[str(fmax) + '_' + str(y.device)], spec)
+     spec = spectral_normalize_torch(spec)
+
+     return spec
+
+
+ def get_dataset_filelist(a):
+     with open(a.input_training_file, 'r', encoding='utf-8') as fi:
+         training_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + ('.wav' if '.wav' not in x else ''))
+                           for x in fi.read().split('\n') if len(x) > 0]
+
+     with open(a.input_validation_file, 'r', encoding='utf-8') as fi:
+         validation_files = [os.path.join(a.input_wavs_dir, x.split('|')[0] + ('.wav' if '.wav' not in x else ''))
+                             for x in fi.read().split('\n') if len(x) > 0]
+     return training_files, validation_files
+
+
+ class MelDataset(torch.utils.data.Dataset):
+     def __init__(self, training_files, segment_size, n_fft, num_mels,
+                  hop_size, win_size, sampling_rate, fmin, fmax, split=True, shuffle=True, n_cache_reuse=1,
+                  device=None, fmax_loss=None, fine_tuning=False, base_mels_path=None):
+         self.audio_files = training_files
+         random.seed(1234)
+         if shuffle:
+             random.shuffle(self.audio_files)
+         self.segment_size = segment_size
+         self.sampling_rate = sampling_rate
+         self.split = split
+         self.n_fft = n_fft
+         self.num_mels = num_mels
+         self.hop_size = hop_size
+         self.win_size = win_size
+         self.fmin = fmin
+         self.fmax = fmax
+         self.fmax_loss = fmax_loss
+         self.cached_wav = None
+         self.n_cache_reuse = n_cache_reuse
+         self._cache_ref_count = 0
+         self.device = device
+         self.fine_tuning = fine_tuning
+         self.base_mels_path = base_mels_path
+
+     def __getitem__(self, index):
+         filename = self.audio_files[index]
+         if self._cache_ref_count == 0:
+             audio, sampling_rate = load_wav(filename)
+             audio = audio / MAX_WAV_VALUE
+             if not self.fine_tuning:
+                 audio = normalize(audio) * 0.95
+             self.cached_wav = audio
+             if sampling_rate != self.sampling_rate:
+                 raise ValueError("{} SR doesn't match target {} SR".format(
+                     sampling_rate, self.sampling_rate))
+             self._cache_ref_count = self.n_cache_reuse
+         else:
+             audio = self.cached_wav
+             self._cache_ref_count -= 1
+
+         audio = torch.FloatTensor(audio)
+         audio = audio.unsqueeze(0)
+
+         if not self.fine_tuning:
+             if self.split:
+                 if audio.size(1) >= self.segment_size:
+                     max_audio_start = audio.size(1) - self.segment_size
+                     audio_start = random.randint(0, max_audio_start)
+                     audio = audio[:, audio_start:audio_start+self.segment_size]
+                 else:
+                     audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
+
+             mel = mel_spectrogram(audio, self.n_fft, self.num_mels,
+                                   self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax,
+                                   center=False)
+         else:
+             mel = np.load(
+                 os.path.join(self.base_mels_path, os.path.splitext(os.path.split(filename)[-1])[0] + '.npy'))
+             mel = torch.from_numpy(mel)
+
+             if len(mel.shape) < 3:
+                 mel = mel.unsqueeze(0)
+
+             if self.split:
+                 frames_per_seg = math.ceil(self.segment_size / self.hop_size)
+
+                 if audio.size(1) >= self.segment_size:
+                     mel_start = random.randint(0, mel.size(2) - frames_per_seg - 1)
+                     mel = mel[:, :, mel_start:mel_start + frames_per_seg]
+                     audio = audio[:, mel_start * self.hop_size:(mel_start + frames_per_seg) * self.hop_size]
+                 else:
+                     mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), 'constant')
+                     audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), 'constant')
+
+         mel_loss = mel_spectrogram(audio, self.n_fft, self.num_mels,
+                                    self.sampling_rate, self.hop_size, self.win_size, self.fmin, self.fmax_loss,
+                                    center=False)
+
+         return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
+
+     def __len__(self):
+         return len(self.audio_files)
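For reference, a minimal call to `mel_spectrogram` with the parameters from `config_v1.json` (the waveform here is synthetic):

```python
import torch
from hiftnet.meldataset import mel_spectrogram

# One second of audio at the config's 22.05 kHz rate, scaled into [-1, 1]
wav = torch.randn(1, 22050).clamp(-1., 1.)

mel = mel_spectrogram(wav, n_fft=1024, num_mels=80, sampling_rate=22050,
                      hop_size=256, win_size=1024, fmin=0, fmax=8000)
print(mel.shape)  # torch.Size([1, 80, 86]): num_mels x frames at hop 256
```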
hiftnet/models.py ADDED
@@ -0,0 +1,664 @@
+ import torch
+ import torch.nn.functional as F
+ import torch.nn as nn
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
+ from .utils import init_weights, get_padding
+ import numpy as np
+ from .stft import TorchSTFT
+
+ LRELU_SLOPE = 0.1
+
+
+ class ResBlock1(torch.nn.Module):
+     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+         super(ResBlock1, self).__init__()
+         self.h = h
+         self.convs1 = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                                padding=get_padding(kernel_size, dilation[0]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                                padding=get_padding(kernel_size, dilation[1]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                                padding=get_padding(kernel_size, dilation[2])))
+         ])
+         self.convs1.apply(init_weights)
+
+         self.convs2 = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1)))
+         ])
+         self.convs2.apply(init_weights)
+
+         self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
+         self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])
+
+     def forward(self, x):
+         for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.alpha1, self.alpha2):
+             xt = x + (1 / a1) * (torch.sin(a1 * x) ** 2)  # Snake1D
+             xt = c1(xt)
+             xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2)  # Snake1D
+             xt = c2(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs1:
+             remove_weight_norm(l)
+         for l in self.convs2:
+             remove_weight_norm(l)
+
+
+ class ResBlock1_old(torch.nn.Module):
+     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
+         super(ResBlock1_old, self).__init__()
+         self.h = h
+         self.convs1 = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                                padding=get_padding(kernel_size, dilation[0]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                                padding=get_padding(kernel_size, dilation[1]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
+                                padding=get_padding(kernel_size, dilation[2])))
+         ])
+         self.convs1.apply(init_weights)
+
+         self.convs2 = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
+                                padding=get_padding(kernel_size, 1)))
+         ])
+         self.convs2.apply(init_weights)
+
+     def forward(self, x):
+         for c1, c2 in zip(self.convs1, self.convs2):
+             xt = F.leaky_relu(x, LRELU_SLOPE)
+             xt = c1(xt)
+             xt = F.leaky_relu(xt, LRELU_SLOPE)
+             xt = c2(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs1:
+             remove_weight_norm(l)
+         for l in self.convs2:
+             remove_weight_norm(l)
+
+
+ class ResBlock2(torch.nn.Module):
+     def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
+         super(ResBlock2, self).__init__()
+         self.h = h
+         self.convs = nn.ModuleList([
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
+                                padding=get_padding(kernel_size, dilation[0]))),
+             weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
+                                padding=get_padding(kernel_size, dilation[1])))
+         ])
+         self.convs.apply(init_weights)
+
+     def forward(self, x):
+         for c in self.convs:
+             xt = F.leaky_relu(x, LRELU_SLOPE)
+             xt = c(xt)
+             x = xt + x
+         return x
+
+     def remove_weight_norm(self):
+         for l in self.convs:
+             remove_weight_norm(l)
+
+
+ class SineGen(torch.nn.Module):
+     """ Definition of sine generator
+     SineGen(samp_rate, upsample_scale, harmonic_num = 0,
+             sine_amp = 0.1, noise_std = 0.003,
+             voiced_threshold = 0,
+             flag_for_pulse=False)
+     samp_rate: sampling rate in Hz
+     harmonic_num: number of harmonic overtones (default 0)
+     sine_amp: amplitude of sine-waveform (default 0.1)
+     noise_std: std of Gaussian noise (default 0.003)
+     voiced_threshold: F0 threshold for U/V classification (default 0)
+     flag_for_pulse: this SineGen is used inside PulseGen (default False)
+     Note: when flag_for_pulse is True, the first time step of a voiced
+     segment is always sin(np.pi) or cos(0)
+     """
+
+     def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
+                  sine_amp=0.1, noise_std=0.003,
+                  voiced_threshold=0,
+                  flag_for_pulse=False):
+         super(SineGen, self).__init__()
+         self.sine_amp = sine_amp
+         self.noise_std = noise_std
+         self.harmonic_num = harmonic_num
+         self.dim = self.harmonic_num + 1
+         self.sampling_rate = samp_rate
+         self.voiced_threshold = voiced_threshold
+         self.flag_for_pulse = flag_for_pulse
+         self.upsample_scale = upsample_scale
+
+     def _f02uv(self, f0):
+         # generate uv signal
+         uv = (f0 > self.voiced_threshold).type(torch.float32)
+         return uv
+
+     def _f02sine(self, f0_values):
+         """ f0_values: (batchsize, length, dim)
+         where dim indicates fundamental tone and overtones
+         """
+         # convert to F0 in rad. The integer part n can be ignored
+         # because 2 * np.pi * n doesn't affect phase
+         rad_values = (f0_values / self.sampling_rate) % 1
+
+         # initial phase noise (no noise for fundamental component)
+         rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2],
+                               device=f0_values.device)
+         rand_ini[:, 0] = 0
+         rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
+
+         # instantaneous phase sine[t] = sin(2*pi \sum_i=1^{t} rad)
+         if not self.flag_for_pulse:
+             # The original NSF code prevents torch.cumsum numerical overflow by
+             # adding -1 whenever \sum_k=1^n rad_value_k > 1 (this does not change
+             # the F0 of the sine because (x-1) * 2*pi = x * 2*pi):
+             # tmp_over_one = torch.cumsum(rad_values, 1) % 1
+             # tmp_over_one_idx = (padDiff(tmp_over_one)) < 0
+             # cumsum_shift = torch.zeros_like(rad_values)
+             # cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
+             # Here the phase is instead accumulated at the frame rate and
+             # re-upsampled:
+             rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
+                                                          scale_factor=1/self.upsample_scale,
+                                                          mode="linear").transpose(1, 2)
+
+             phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
+             phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
+                                                     scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
+             sines = torch.sin(phase)
+         else:
+             # If necessary, make sure that the first time step of every
+             # voiced segment is sin(pi) or cos(0).
+             # This is used for pulse-train generation.
+
+             # identify the last time step in unvoiced segments
+             uv = self._f02uv(f0_values)
+             uv_1 = torch.roll(uv, shifts=-1, dims=1)
+             uv_1[:, -1, :] = 1
+             u_loc = (uv < 1) * (uv_1 > 0)
+
+             # get the instantaneous phase
+             tmp_cumsum = torch.cumsum(rad_values, dim=1)
+             # different batches need to be processed differently
+             for idx in range(f0_values.shape[0]):
+                 temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
+                 temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
+                 # stores the accumulation of i.phase within
+                 # each voiced segment
+                 tmp_cumsum[idx, :, :] = 0
+                 tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
+
+             # rad_values - tmp_cumsum: remove the accumulation of i.phase
+             # within the previous voiced segment
+             i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
+
+             # get the sines
+             sines = torch.cos(i_phase * 2 * np.pi)
+         return sines
+
+     def forward(self, f0):
+         """ sine_tensor, uv = forward(f0)
+         input F0: tensor(batchsize=1, length, dim=1)
+         f0 for unvoiced steps should be 0
+         output sine_tensor: tensor(batchsize=1, length, dim)
+         output uv: tensor(batchsize=1, length, 1)
+         """
+         # fundamental component and harmonic overtones
+         fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
+
+         # generate sine waveforms
+         sine_waves = self._f02sine(fn) * self.sine_amp
+
+         # generate uv signal
+         uv = self._f02uv(f0)
+
+         # noise: for unvoiced regions it should be similar to sine_amp
+         #        (std = self.sine_amp/3 -> max value ~ self.sine_amp);
+         #        for voiced regions it is self.noise_std
+         noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
+         noise = noise_amp * torch.randn_like(sine_waves)
+
+         # first: set the unvoiced part to 0 by uv
+         # then: additive noise
+         sine_waves = sine_waves * uv + noise
+         return sine_waves, uv, noise
+
+
+ class SourceModuleHnNSF(torch.nn.Module):
+     """ SourceModule for hn-nsf
+     SourceModule(sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+                  add_noise_std=0.003, voiced_threshod=0)
+     sampling_rate: sampling_rate in Hz
+     harmonic_num: number of harmonics above F0 (default: 0)
+     sine_amp: amplitude of sine source signal (default: 0.1)
+     add_noise_std: std of additive Gaussian noise (default: 0.003)
+         note that amplitude of noise in unvoiced is decided
+         by sine_amp
+     voiced_threshod: threshold to set U/V given F0 (default: 0)
+     Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+     F0_sampled (batchsize, length, 1)
+     Sine_source (batchsize, length, 1)
+     noise_source (batchsize, length, 1)
+     uv (batchsize, length, 1)
+     """
+
+     def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
+                  add_noise_std=0.003, voiced_threshod=0):
+         super(SourceModuleHnNSF, self).__init__()
+
+         self.sine_amp = sine_amp
+         self.noise_std = add_noise_std
+
+         # to produce sine waveforms
+         self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
+                                  sine_amp, add_noise_std, voiced_threshod)
+
+         # to merge source harmonics into a single excitation
+         self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
+         self.l_tanh = torch.nn.Tanh()
+
+     def forward(self, x):
+         """
+         Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
+         F0_sampled (batchsize, length, 1)
+         Sine_source (batchsize, length, 1)
+         noise_source (batchsize, length, 1)
+         """
+         # source for harmonic branch
+         with torch.no_grad():
+             sine_wavs, uv, _ = self.l_sin_gen(x)
+         sine_merge = self.l_tanh(self.l_linear(sine_wavs))
+
+         # source for noise branch, in the same shape as uv
+         noise = torch.randn_like(uv) * self.sine_amp / 3
+         return sine_merge, noise, uv
+
+
+ def padDiff(x):
+     return F.pad(F.pad(x, (0, 0, -1, 1), 'constant', 0) - x, (0, 0, 0, -1), 'constant', 0)
+
+
+ class Generator(torch.nn.Module):
+     def __init__(self, h, F0_model):
+         super(Generator, self).__init__()
+         self.h = h
+         self.num_kernels = len(h.resblock_kernel_sizes)
+         self.num_upsamples = len(h.upsample_rates)
+         self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
+         resblock = ResBlock1 if h.resblock == '1' else ResBlock2
+
+         self.m_source = SourceModuleHnNSF(
+             sampling_rate=h.sampling_rate,
+             upsample_scale=np.prod(h.upsample_rates) * h.gen_istft_hop_size,
+             harmonic_num=8, voiced_threshod=10)
+         self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(h.upsample_rates) * h.gen_istft_hop_size)
+         self.noise_convs = nn.ModuleList()
+         self.noise_res = nn.ModuleList()
+
+         self.F0_model = F0_model
+
+         self.ups = nn.ModuleList()
+         for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
+             self.ups.append(weight_norm(
+                 ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
+                                 k, u, padding=(k-u)//2)))
+
+             c_cur = h.upsample_initial_channel // (2 ** (i + 1))
+
+             if i + 1 < len(h.upsample_rates):
+                 stride_f0 = np.prod(h.upsample_rates[i + 1:])
+                 self.noise_convs.append(Conv1d(
+                     h.gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
+                 self.noise_res.append(resblock(h, c_cur, 7, [1, 3, 5]))
+             else:
+                 self.noise_convs.append(Conv1d(h.gen_istft_n_fft + 2, c_cur, kernel_size=1))
+                 self.noise_res.append(resblock(h, c_cur, 11, [1, 3, 5]))
+
+         self.resblocks = nn.ModuleList()
+         for i in range(len(self.ups)):
+             ch = h.upsample_initial_channel//(2**(i+1))
+             for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
+                 self.resblocks.append(resblock(h, ch, k, d))
+
+         self.post_n_fft = h.gen_istft_n_fft
+         self.conv_post = weight_norm(Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
+         self.ups.apply(init_weights)
+         self.conv_post.apply(init_weights)
+         self.reflection_pad = torch.nn.ReflectionPad1d((1, 0))
+         self.stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft)
+
+     def forward(self, x):
+         f0, _, _ = self.F0_model(x.unsqueeze(1))
+         if len(f0.shape) == 1:
+             f0 = f0.unsqueeze(0)
+
+         f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs, n, t
+
+         har_source, _, _ = self.m_source(f0)
+         har_source = har_source.transpose(1, 2).squeeze(1)
+         har_spec, har_phase = self.stft.transform(har_source)
+         har = torch.cat([har_spec, har_phase], dim=1)
+
+         x = self.conv_pre(x)
+         for i in range(self.num_upsamples):
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             x_source = self.noise_convs[i](har)
+             x_source = self.noise_res[i](x_source)
+
+             x = self.ups[i](x)
+             if i == self.num_upsamples - 1:
+                 x = self.reflection_pad(x)
+
+             x = x + x_source
+             xs = None
+             for j in range(self.num_kernels):
+                 if xs is None:
+                     xs = self.resblocks[i*self.num_kernels+j](x)
+                 else:
+                     xs += self.resblocks[i*self.num_kernels+j](x)
+             x = xs / self.num_kernels
+         x = F.leaky_relu(x)
+         x = self.conv_post(x)
+         spec = torch.exp(x[:, :self.post_n_fft // 2 + 1, :])
+         phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
+
+         return spec, phase
+
+     def remove_weight_norm(self):
+         print('Removing weight norm...')
+         for l in self.ups:
+             remove_weight_norm(l)
+         for l in self.resblocks:
+             l.remove_weight_norm()
+         remove_weight_norm(self.conv_pre)
+         remove_weight_norm(self.conv_post)
+
+
+ def stft(x, fft_size, hop_size, win_length, window):
+     """Perform STFT and convert to magnitude spectrogram.
+     Args:
+         x (Tensor): Input signal tensor (B, T).
+         fft_size (int): FFT size.
+         hop_size (int): Hop size.
+         win_length (int): Window length.
+         window (Tensor): Window tensor.
+     Returns:
+         Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
+     """
+     x_stft = torch.stft(x, fft_size, hop_size, win_length, window,
+                         return_complex=True)
+
+     return torch.abs(x_stft).transpose(2, 1)
+
+
+ class SpecDiscriminator(nn.Module):
+     """Discriminator operating on STFT magnitude spectrograms."""
+
+     def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
+         super(SpecDiscriminator, self).__init__()
+         norm_f = spectral_norm if use_spectral_norm else weight_norm
+         self.fft_size = fft_size
+         self.shift_size = shift_size
+         self.win_length = win_length
+         self.window = getattr(torch, window)(win_length)
+         self.discriminators = nn.ModuleList([
+             norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
+             norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+             norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+             norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
+             norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
+         ])
+
+         self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
+
+     def forward(self, y):
+         fmap = []
+         y = y.squeeze(1)
+         y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
+         y = y.unsqueeze(1)
+         for i, d in enumerate(self.discriminators):
+             y = d(y)
+             y = F.leaky_relu(y, LRELU_SLOPE)
+             fmap.append(y)
+
+         y = self.out(y)
+         fmap.append(y)
+
+         return torch.flatten(y, 1, -1), fmap
+
+
+ class MultiResSpecDiscriminator(torch.nn.Module):
+
+     def __init__(self,
+                  fft_sizes=[1024, 2048, 512],
+                  hop_sizes=[120, 240, 50],
+                  win_lengths=[600, 1200, 240],
+                  window="hann_window"):
+
+         super(MultiResSpecDiscriminator, self).__init__()
+         self.discriminators = nn.ModuleList([
+             SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
+             SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
+             SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)
+         ])
+
+     def forward(self, y, y_hat):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for i, d in enumerate(self.discriminators):
+             y_d_r, fmap_r = d(y)
+             y_d_g, fmap_g = d(y_hat)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ class DiscriminatorP(torch.nn.Module):
+     def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
+         super(DiscriminatorP, self).__init__()
+         self.period = period
+         norm_f = spectral_norm if use_spectral_norm else weight_norm
+         self.convs = nn.ModuleList([
+             norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
+             norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+         ])
+         self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+     def forward(self, x):
+         fmap = []
+
+         # 1d to 2d
+         b, c, t = x.shape
+         if t % self.period != 0:  # pad first
+             n_pad = self.period - (t % self.period)
+             x = F.pad(x, (0, n_pad), "reflect")
+             t = t + n_pad
+         x = x.view(b, c, t // self.period, self.period)
+
+         for l in self.convs:
+             x = l(x)
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class MultiPeriodDiscriminator(torch.nn.Module):
+     def __init__(self):
+         super(MultiPeriodDiscriminator, self).__init__()
+         self.discriminators = nn.ModuleList([
+             DiscriminatorP(2),
+             DiscriminatorP(3),
+             DiscriminatorP(5),
+             DiscriminatorP(7),
+             DiscriminatorP(11),
+         ])
+
+     def forward(self, y, y_hat):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for i, d in enumerate(self.discriminators):
+             y_d_r, fmap_r = d(y)
+             y_d_g, fmap_g = d(y_hat)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ class DiscriminatorS(torch.nn.Module):
+     def __init__(self, use_spectral_norm=False):
+         super(DiscriminatorS, self).__init__()
+         norm_f = spectral_norm if use_spectral_norm else weight_norm
+         self.convs = nn.ModuleList([
+             norm_f(Conv1d(1, 128, 15, 1, padding=7)),
+             norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+             norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+             norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+             norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+             norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+             norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
+         ])
+         self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
+
+     def forward(self, x):
+         fmap = []
+         for l in self.convs:
+             x = l(x)
+             x = F.leaky_relu(x, LRELU_SLOPE)
+             fmap.append(x)
+         x = self.conv_post(x)
+         fmap.append(x)
+         x = torch.flatten(x, 1, -1)
+
+         return x, fmap
+
+
+ class MultiScaleDiscriminator(torch.nn.Module):
+     def __init__(self):
+         super(MultiScaleDiscriminator, self).__init__()
+         self.discriminators = nn.ModuleList([
+             DiscriminatorS(use_spectral_norm=True),
+             DiscriminatorS(),
+             DiscriminatorS(),
+         ])
+         self.meanpools = nn.ModuleList([
+             AvgPool1d(4, 2, padding=2),
+             AvgPool1d(4, 2, padding=2)
+         ])
+
+     def forward(self, y, y_hat):
+         y_d_rs = []
+         y_d_gs = []
+         fmap_rs = []
+         fmap_gs = []
+         for i, d in enumerate(self.discriminators):
+             if i != 0:
+                 y = self.meanpools[i-1](y)
+                 y_hat = self.meanpools[i-1](y_hat)
+             y_d_r, fmap_r = d(y)
+             y_d_g, fmap_g = d(y_hat)
+             y_d_rs.append(y_d_r)
+             fmap_rs.append(fmap_r)
+             y_d_gs.append(y_d_g)
+             fmap_gs.append(fmap_g)
+
+         return y_d_rs, y_d_gs, fmap_rs, fmap_gs
+
+
+ def feature_loss(fmap_r, fmap_g):
+     loss = 0
+     for dr, dg in zip(fmap_r, fmap_g):
+         for rl, gl in zip(dr, dg):
+             loss += torch.mean(torch.abs(rl - gl))
+
+     return loss*2
+
+
+ def discriminator_loss(disc_real_outputs, disc_generated_outputs):
+     loss = 0
+     r_losses = []
+     g_losses = []
+     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+         r_loss = torch.mean((1-dr)**2)
+         g_loss = torch.mean(dg**2)
+         loss += (r_loss + g_loss)
+         r_losses.append(r_loss.item())
+         g_losses.append(g_loss.item())
+
+     return loss, r_losses, g_losses
+
+
+ def generator_loss(disc_outputs):
+     loss = 0
+     gen_losses = []
+     for dg in disc_outputs:
+         l = torch.mean((1-dg)**2)
+         gen_losses.append(l)
+         loss += l
+
+     return loss, gen_losses
+
+
+ def discriminator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
+     loss = 0
+     for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
+         tau = 0.04
+         m_DG = torch.median((dr-dg))
+         L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
+         loss += tau - F.relu(tau - L_rel)
+     return loss
+
+
+ def generator_TPRLS_loss(disc_real_outputs, disc_generated_outputs):
+     loss = 0
+     # note the swapped roles: from the generator's perspective the
+     # generated outputs play the part of the "real" ones
+     for dg, dr in zip(disc_real_outputs, disc_generated_outputs):
+         tau = 0.04
+         m_DG = torch.median((dr-dg))
+         L_rel = torch.mean((((dr - dg) - m_DG)**2)[dr < dg + m_DG])
+         loss += tau - F.relu(tau - L_rel)
+     return loss
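To make the harmonic-plus-noise source concrete, a small sketch (not from the repo) that drives `SourceModuleHnNSF` with a flat F0 contour, wired the way `Generator` wires it (`upsample_scale` = prod(upsample_rates) * gen_istft_hop_size = 8*8*4 = 256); as above, it assumes the package's `utils` module is available:

```python
import torch
from hiftnet.models import SourceModuleHnNSF

# "voiced_threshod" (sic) is the parameter name used throughout this file.
m_source = SourceModuleHnNSF(sampling_rate=22050, upsample_scale=256,
                             harmonic_num=8, voiced_threshod=10)

# A flat 220 Hz contour for 100 mel frames, already upsampled to the sample
# rate, matching the f0_upsamp step in Generator.forward.
f0 = torch.full((1, 100 * 256, 1), 220.0)
sine_merge, noise, uv = m_source(f0)
print(sine_merge.shape, uv.shape)  # torch.Size([1, 25600, 1]) each
```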
hiftnet/requirements.txt ADDED
File without changes
hiftnet/stft.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ BSD 3-Clause License
3
+ Copyright (c) 2017, Prem Seetharaman
4
+ All rights reserved.
5
+ * Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+ * Redistributions of source code must retain the above copyright notice,
8
+ this list of conditions and the following disclaimer.
9
+ * Redistributions in binary form must reproduce the above copyright notice, this
10
+ list of conditions and the following disclaimer in the
11
+ documentation and/or other materials provided with the distribution.
12
+ * Neither the name of the copyright holder nor the names of its
13
+ contributors may be used to endorse or promote products derived from this
14
+ software without specific prior written permission.
15
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
16
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
17
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
18
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
19
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
20
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
21
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
22
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
23
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
24
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
25
+ """
26
+
27
+ import torch
28
+ import numpy as np
29
+ import torch.nn.functional as F
30
+ from torch.autograd import Variable
31
+ from scipy.signal import get_window
32
+ from librosa.util import pad_center, tiny
33
+ import librosa.util as librosa_util
34
+
35
+ def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
36
+ n_fft=800, dtype=np.float32, norm=None):
37
+ """
38
+ # from librosa 0.6
39
+ Compute the sum-square envelope of a window function at a given hop length.
40
+ This is used to estimate modulation effects induced by windowing
41
+ observations in short-time fourier transforms.
42
+ Parameters
43
+ ----------
44
+ window : string, tuple, number, callable, or list-like
45
+ Window specification, as in `get_window`
46
+ n_frames : int > 0
47
+ The number of analysis frames
48
+ hop_length : int > 0
49
+ The number of samples to advance between frames
50
+ win_length : [optional]
51
+ The length of the window function. By default, this matches `n_fft`.
52
+ n_fft : int > 0
53
+ The length of each analysis frame.
54
+ dtype : np.dtype
55
+ The data type of the output
56
+ Returns
57
+ -------
58
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
59
+ The sum-squared envelope of the window function
60
+ """
61
+ if win_length is None:
62
+ win_length = n_fft
63
+
64
+ n = n_fft + hop_length * (n_frames - 1)
65
+ x = np.zeros(n, dtype=dtype)
66
+
67
+ # Compute the squared window at the desired length
68
+ win_sq = get_window(window, win_length, fftbins=True)
69
+ win_sq = librosa_util.normalize(win_sq, norm=norm)**2
70
+ win_sq = librosa_util.pad_center(win_sq, n_fft)
71
+
72
+ # Fill the envelope
73
+ for i in range(n_frames):
74
+ sample = i * hop_length
75
+ x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
76
+ return x
77
+
78
+
79
+ class STFT(torch.nn.Module):
+     """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+     def __init__(self, filter_length=800, hop_length=200, win_length=800,
+                  window='hann'):
+         super(STFT, self).__init__()
+         self.filter_length = filter_length
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.window = window
+         self.forward_transform = None
+         scale = self.filter_length / self.hop_length
+         fourier_basis = np.fft.fft(np.eye(self.filter_length))
+ 
+         cutoff = int((self.filter_length / 2 + 1))
+         fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
+                                    np.imag(fourier_basis[:cutoff, :])])
+ 
+         forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+         inverse_basis = torch.FloatTensor(
+             np.linalg.pinv(scale * fourier_basis).T[:, None, :])
+ 
+         if window is not None:
+             assert(filter_length >= win_length)
+             # get window and zero center pad it to filter_length
+             fft_window = get_window(window, win_length, fftbins=True)
+             fft_window = pad_center(fft_window, size=filter_length)  # 'size' is keyword-only on librosa >= 0.10
+             fft_window = torch.from_numpy(fft_window).float()
+ 
+             # window the bases
+             forward_basis *= fft_window
+             inverse_basis *= fft_window
+ 
+         self.register_buffer('forward_basis', forward_basis.float())
+         self.register_buffer('inverse_basis', inverse_basis.float())
+ 
+     def transform(self, input_data):
+         num_batches = input_data.size(0)
+         num_samples = input_data.size(1)
+ 
+         self.num_samples = num_samples
+ 
+         # similar to librosa, reflect-pad the input
+         input_data = input_data.view(num_batches, 1, num_samples)
+         input_data = F.pad(
+             input_data.unsqueeze(1),
+             (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
+             mode='reflect')
+         input_data = input_data.squeeze(1)
+ 
+         forward_transform = F.conv1d(
+             input_data,
+             Variable(self.forward_basis, requires_grad=False),
+             stride=self.hop_length,
+             padding=0)
+ 
+         cutoff = int((self.filter_length / 2) + 1)
+         real_part = forward_transform[:, :cutoff, :]
+         imag_part = forward_transform[:, cutoff:, :]
+ 
+         magnitude = torch.sqrt(real_part**2 + imag_part**2)
+         phase = torch.autograd.Variable(
+             torch.atan2(imag_part.data, real_part.data))
+ 
+         return magnitude, phase
+ 
+     def inverse(self, magnitude, phase):
+         recombine_magnitude_phase = torch.cat(
+             [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
+ 
+         inverse_transform = F.conv_transpose1d(
+             recombine_magnitude_phase,
+             Variable(self.inverse_basis, requires_grad=False),
+             stride=self.hop_length,
+             padding=0)
+ 
+         if self.window is not None:
+             window_sum = window_sumsquare(
+                 self.window, magnitude.size(-1), hop_length=self.hop_length,
+                 win_length=self.win_length, n_fft=self.filter_length,
+                 dtype=np.float32)
+             # remove modulation effects
+             approx_nonzero_indices = torch.from_numpy(
+                 np.where(window_sum > tiny(window_sum))[0])
+             window_sum = torch.autograd.Variable(
+                 torch.from_numpy(window_sum), requires_grad=False)
+             window_sum = window_sum.to(inverse_transform.device) if magnitude.is_cuda else window_sum  # .device is an attribute, not a callable
+             inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
+ 
+             # scale by hop ratio
+             inverse_transform *= float(self.filter_length) / self.hop_length
+ 
+         inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
+         inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]
+ 
+         return inverse_transform
+ 
+     def forward(self, input_data):
+         self.magnitude, self.phase = self.transform(input_data)
+         reconstruction = self.inverse(self.magnitude, self.phase)
+         return reconstruction
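+ 
+ # Implementation note: forward_basis stacks the real and imaginary DFT atoms
+ # as conv1d filters, so transform() realizes the STFT as a strided 1-D
+ # convolution and inverse() undoes it with conv_transpose1d plus the
+ # window_sumsquare correction. TorchSTFT below is the equivalent built
+ # directly on torch.stft/torch.istft.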
+ 
+ 
+ class TorchSTFT(torch.nn.Module):
+     def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
+         super().__init__()
+         self.filter_length = filter_length
+         self.hop_length = hop_length
+         self.win_length = win_length
+         self.window = torch.from_numpy(get_window(window, win_length, fftbins=True).astype(np.float32))
+ 
+     def transform(self, input_data):
+         forward_transform = torch.stft(
+             input_data,
+             self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
+             return_complex=True)
+ 
+         return torch.abs(forward_transform), torch.angle(forward_transform)
+ 
+     def inverse(self, magnitude, phase):
+         inverse_transform = torch.istft(
+             magnitude * torch.exp(phase * 1j),
+             self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
+ 
+         return inverse_transform.unsqueeze(-2)  # unsqueeze to stay consistent with conv_transpose1d implementation
+ 
+     def forward(self, input_data):
+         self.magnitude, self.phase = self.transform(input_data)
+         reconstruction = self.inverse(self.magnitude, self.phase)
+         return reconstruction
+ 
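+ # A minimal round-trip sketch (comments only, assuming the default
+ # 800/200/800 filter/hop/window settings):
+ #
+ #     stft = TorchSTFT()
+ #     audio = torch.randn(1, 16000)        # (batch, samples)
+ #     mag, phase = stft.transform(audio)   # each (batch, 401, n_frames)
+ #     recon = stft.inverse(mag, phase)     # (batch, 1, samples), per unsqueeze(-2)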
hiftnet/train.py ADDED
@@ -0,0 +1,291 @@
+ import warnings
+ warnings.simplefilter(action='ignore', category=FutureWarning)
+ import itertools
+ import os
+ import time
+ import argparse
+ import json
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.tensorboard import SummaryWriter
+ from torch.utils.data import DistributedSampler, DataLoader
+ import torch.multiprocessing as mp
+ from torch.distributed import init_process_group
+ from torch.nn.parallel import DistributedDataParallel
+ from env import AttrDict, build_env
+ from meldataset import MelDataset, mel_spectrogram, get_dataset_filelist
+ from models import Generator, MultiPeriodDiscriminator, MultiResSpecDiscriminator, feature_loss, generator_loss,\
+     discriminator_loss, discriminator_TPRLS_loss, generator_TPRLS_loss
+ from utils import plot_spectrogram, scan_checkpoint, load_checkpoint, save_checkpoint
+ from stft import TorchSTFT
+ from Utils.JDC.model import JDCNet
+ 
+ torch.backends.cudnn.benchmark = True
+ 
+ 
+ def train(rank, a, h):
+     if h.num_gpus > 1:
+         init_process_group(backend=h.dist_config['dist_backend'], init_method=h.dist_config['dist_url'],
+                            world_size=h.dist_config['world_size'] * h.num_gpus, rank=rank)
+ 
+     torch.cuda.manual_seed(h.seed)
+     device = torch.device('cuda:{:d}'.format(rank))
+ 
+     F0_model = JDCNet(num_class=1, seq_len=192)
+     params = torch.load(h.F0_path)['model']
+     F0_model.load_state_dict(params)
+ 
+     generator = Generator(h, F0_model).to(device)
+     mpd = MultiPeriodDiscriminator().to(device)
+     msd = MultiResSpecDiscriminator().to(device)
+     stft = TorchSTFT(filter_length=h.gen_istft_n_fft, hop_length=h.gen_istft_hop_size, win_length=h.gen_istft_n_fft).to(device)
+ 
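+     # JDCNet is the pre-trained F0 estimation network: its weights are loaded
+     # from h.F0_path and the model is handed to the Generator, whose (spec,
+     # phase) output is synthesized into audio by the TorchSTFT instance above.
+ 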
+     if rank == 0:
+         print(generator)
+         os.makedirs(a.checkpoint_path, exist_ok=True)
+         print("checkpoints directory : ", a.checkpoint_path)
+ 
+     if os.path.isdir(a.checkpoint_path):
+         cp_g = scan_checkpoint(a.checkpoint_path, 'g_')
+         cp_do = scan_checkpoint(a.checkpoint_path, 'do_')
+ 
+     steps = 0
+     if cp_g is None or cp_do is None:
+         state_dict_do = None
+         last_epoch = -1
+     else:
+         state_dict_g = load_checkpoint(cp_g, device)
+         state_dict_do = load_checkpoint(cp_do, device)
+         generator.load_state_dict(state_dict_g['generator'])
+         mpd.load_state_dict(state_dict_do['mpd'])
+         msd.load_state_dict(state_dict_do['msd'])
+         steps = state_dict_do['steps'] + 1
+         last_epoch = state_dict_do['epoch']
+ 
+     if h.num_gpus > 1:
+         generator = DistributedDataParallel(generator, device_ids=[rank], find_unused_parameters=True).to(device)
+         mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
+         msd = DistributedDataParallel(msd, device_ids=[rank]).to(device)
+ 
+     optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
+     optim_d = torch.optim.AdamW(itertools.chain(msd.parameters(), mpd.parameters()),
+                                 h.learning_rate, betas=[h.adam_b1, h.adam_b2])
+ 
+     if state_dict_do is not None:
+         optim_g.load_state_dict(state_dict_do['optim_g'])
+         optim_d.load_state_dict(state_dict_do['optim_d'])
+ 
+     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
+     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
+ 
+     training_filelist, validation_filelist = get_dataset_filelist(a)
+ 
+     trainset = MelDataset(training_filelist, h.segment_size, h.n_fft, h.num_mels,
+                           h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, n_cache_reuse=0,
+                           shuffle=False if h.num_gpus > 1 else True, fmax_loss=h.fmax_for_loss, device=device,
+                           fine_tuning=a.fine_tuning, base_mels_path=a.input_mels_dir)
+ 
+     train_sampler = DistributedSampler(trainset) if h.num_gpus > 1 else None
+ 
+     train_loader = DataLoader(trainset, num_workers=h.num_workers, shuffle=False,
+                               sampler=train_sampler,
+                               batch_size=h.batch_size,
+                               pin_memory=True,
+                               drop_last=True)
+ 
+     if rank == 0:
+         validset = MelDataset(validation_filelist, h.segment_size, h.n_fft, h.num_mels,
+                               h.hop_size, h.win_size, h.sampling_rate, h.fmin, h.fmax, False, False, n_cache_reuse=0,
+                               fmax_loss=h.fmax_for_loss, device=device, fine_tuning=a.fine_tuning,
+                               base_mels_path=a.input_mels_dir)
+         validation_loader = DataLoader(validset, num_workers=1, shuffle=False,
+                                        sampler=None,
+                                        batch_size=1,
+                                        pin_memory=True,
+                                        drop_last=True)
+ 
+         sw = SummaryWriter(os.path.join(a.checkpoint_path, 'logs'))
+ 
+     generator.train()
+     mpd.train()
+     msd.train()
+     for epoch in range(max(0, last_epoch), a.training_epochs):
+         if rank == 0:
+             start = time.time()
+             print("Epoch: {}".format(epoch+1))
+ 
+         if h.num_gpus > 1:
+             train_sampler.set_epoch(epoch)
+ 
+         for i, batch in enumerate(train_loader):
+             if rank == 0:
+                 start_b = time.time()
+             x, y, _, y_mel = batch
+             x = torch.autograd.Variable(x.to(device, non_blocking=True))
+             y = torch.autograd.Variable(y.to(device, non_blocking=True))
+             y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
+             y = y.unsqueeze(1)
+             # y_g_hat = generator(x)
+             spec, phase = generator(x)
+ 
+             y_g_hat = stft.inverse(spec, phase)
+ 
+             y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size,
+                                           h.fmin, h.fmax_for_loss)
+ 
+             optim_d.zero_grad()
+ 
+             # MPD
+             y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
+             loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
+             loss_disc_f += discriminator_TPRLS_loss(y_df_hat_r, y_df_hat_g)
+ 
+             # MSD
+             y_ds_hat_r, y_ds_hat_g, _, _ = msd(y, y_g_hat.detach())
+             loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
+             loss_disc_s += discriminator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
+ 
+             loss_disc_all = loss_disc_s + loss_disc_f
+ 
+             loss_disc_all.backward()
+             optim_d.step()
+ 
+             # Generator
+             optim_g.zero_grad()
+ 
+             # L1 Mel-Spectrogram Loss
+             loss_mel = F.l1_loss(y_mel, y_g_hat_mel) * 45
+ 
+             y_df_hat_r, y_df_hat_g, fmap_f_r, fmap_f_g = mpd(y, y_g_hat)
+             y_ds_hat_r, y_ds_hat_g, fmap_s_r, fmap_s_g = msd(y, y_g_hat)
+             loss_fm_f = feature_loss(fmap_f_r, fmap_f_g)
+             loss_fm_s = feature_loss(fmap_s_r, fmap_s_g)
+             loss_gen_f, losses_gen_f = generator_loss(y_df_hat_g)
+             loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
+ 
+             loss_gen_f += generator_TPRLS_loss(y_df_hat_r, y_df_hat_g)
+             loss_gen_s += generator_TPRLS_loss(y_ds_hat_r, y_ds_hat_g)
+ 
+             loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
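+             # The total follows the HiFi-GAN recipe: adversarial and
+             # feature-matching terms from both discriminator families plus an
+             # L1 mel loss weighted by 45, here with the TPRLS terms folded in.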
+ 
+             loss_gen_all.backward()
+             optim_g.step()
+ 
+             if rank == 0:
+                 # STDOUT logging
+                 if steps % a.stdout_interval == 0:
+                     with torch.no_grad():
+                         mel_error = F.l1_loss(y_mel, y_g_hat_mel).item()
+ 
+                     print('Steps : {:d}, Gen Loss Total : {:4.3f}, Mel-Spec. Error : {:4.3f}, s/b : {:4.3f}'.
+                           format(steps, loss_gen_all, mel_error, time.time() - start_b))
+ 
+                 # checkpointing
+                 if steps % a.checkpoint_interval == 0 and steps != 0:
+                     checkpoint_path = "{}/g_{:08d}".format(a.checkpoint_path, steps)
+                     save_checkpoint(checkpoint_path,
+                                     {'generator': (generator.module if h.num_gpus > 1 else generator).state_dict()})
+                     checkpoint_path = "{}/do_{:08d}".format(a.checkpoint_path, steps)
+                     save_checkpoint(checkpoint_path,
+                                     {'mpd': (mpd.module if h.num_gpus > 1
+                                              else mpd).state_dict(),
+                                      'msd': (msd.module if h.num_gpus > 1
+                                              else msd).state_dict(),
+                                      'optim_g': optim_g.state_dict(), 'optim_d': optim_d.state_dict(), 'steps': steps,
+                                      'epoch': epoch})
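+                     # g_XXXXXXXX holds generator weights; do_XXXXXXXX holds the
+                     # discriminators, optimizers, step and epoch, matching the
+                     # 'g_????????' / 'do_????????' globs in utils.scan_checkpoint.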
+ 
+                 # Tensorboard summary logging
+                 if steps % a.summary_interval == 0:
+                     sw.add_scalar("training/gen_loss_total", loss_gen_all, steps)
+                     sw.add_scalar("training/mel_spec_error", mel_error, steps)
+ 
+                 # Validation
+                 if steps % a.validation_interval == 0:  # and steps != 0:
+                     generator.eval()
+                     torch.cuda.empty_cache()
+                     val_err_tot = 0
+                     with torch.no_grad():
+                         for j, batch in enumerate(validation_loader):
+                             x, y, _, y_mel = batch
+                             # y_g_hat = generator(x.to(device))
+                             spec, phase = generator(x.to(device))
+ 
+                             y_g_hat = stft.inverse(spec, phase)
+ 
+                             y_mel = torch.autograd.Variable(y_mel.to(device, non_blocking=True))
+                             y_g_hat_mel = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels, h.sampling_rate,
+                                                           h.hop_size, h.win_size,
+                                                           h.fmin, h.fmax_for_loss)
+                             val_err_tot += F.l1_loss(y_mel, y_g_hat_mel).item()
+ 
+                             if j <= 4:
+                                 if steps == 0:
+                                     sw.add_audio('gt/y_{}'.format(j), y[0], steps, h.sampling_rate)
+                                     sw.add_figure('gt/y_spec_{}'.format(j), plot_spectrogram(x[0]), steps)
+ 
+                                 sw.add_audio('generated/y_hat_{}'.format(j), y_g_hat[0], steps, h.sampling_rate)
+                                 y_hat_spec = mel_spectrogram(y_g_hat.squeeze(1), h.n_fft, h.num_mels,
+                                                              h.sampling_rate, h.hop_size, h.win_size,
+                                                              h.fmin, h.fmax)
+                                 sw.add_figure('generated/y_hat_spec_{}'.format(j),
+                                               plot_spectrogram(y_hat_spec.squeeze(0).cpu().numpy()), steps)
+ 
+                         val_err = val_err_tot / (j+1)
+                         sw.add_scalar("validation/mel_spec_error", val_err, steps)
+ 
+                     generator.train()
+ 
+             steps += 1
+ 
+         scheduler_g.step()
+         scheduler_d.step()
+ 
+         if rank == 0:
+             print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, int(time.time() - start)))
+ 
+ 
+ def main():
+     print('Initializing Training Process..')
+ 
+     parser = argparse.ArgumentParser()
+ 
+     parser.add_argument('--group_name', default=None)
+     parser.add_argument('--input_wavs_dir', default='')
+     parser.add_argument('--input_mels_dir', default='ft_dataset')
+     parser.add_argument('--input_training_file', default='LJSpeech-1.1/training.txt')
+     parser.add_argument('--input_validation_file', default='LJSpeech-1.1/validation.txt')
+     parser.add_argument('--checkpoint_path', default='cp_hifigan')
+     parser.add_argument('--config', default='config_v1.json')
+     parser.add_argument('--training_epochs', default=3100, type=int)
+     parser.add_argument('--stdout_interval', default=5, type=int)
+     parser.add_argument('--checkpoint_interval', default=5000, type=int)
+     parser.add_argument('--summary_interval', default=100, type=int)
+     parser.add_argument('--validation_interval', default=1000, type=int)
+     parser.add_argument('--fine_tuning', default=False, action='store_true')  # store_true avoids the type=bool pitfall where any non-empty string parses as True
+ 
+     a = parser.parse_args()
+ 
+     with open(a.config) as f:
+         data = f.read()
+ 
+     json_config = json.loads(data)
+     h = AttrDict(json_config)
+     build_env(a.config, 'config.json', a.checkpoint_path)
+ 
+     torch.manual_seed(h.seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(h.seed)
+         h.num_gpus = torch.cuda.device_count()
+         h.batch_size = int(h.batch_size / h.num_gpus)
+         print('Batch size per GPU :', h.batch_size)
+     else:
+         pass
+ 
+     if h.num_gpus > 1:
+         mp.spawn(train, nprocs=h.num_gpus, args=(a, h,))
+     else:
+         train(0, a, h)
+ 
+ 
+ if __name__ == '__main__':
+     main()
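+ 
+ # Example invocation (paths are illustrative; the config is expected to define
+ # the usual HiFi-GAN hyperparameters plus the F0_path, gen_istft_n_fft and
+ # gen_istft_hop_size fields used above):
+ #     python train.py --config config_v1.json \
+ #            --input_training_file LJSpeech-1.1/training.txt \
+ #            --input_validation_file LJSpeech-1.1/validation.txt \
+ #            --checkpoint_path cp_hifigan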
hiftnet/utils.py ADDED
@@ -0,0 +1,58 @@
+ import glob
+ import os
+ import matplotlib
+ import torch
+ from torch.nn.utils import weight_norm
+ matplotlib.use("Agg")
+ import matplotlib.pylab as plt
+ 
+ 
+ def plot_spectrogram(spectrogram):
+     fig, ax = plt.subplots(figsize=(10, 2))
+     im = ax.imshow(spectrogram, aspect="auto", origin="lower",
+                    interpolation='none')
+     plt.colorbar(im, ax=ax)
+ 
+     fig.canvas.draw()
+     plt.close()
+ 
+     return fig
+ 
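+ # plot_spectrogram returns the (already closed) matplotlib Figure so callers
+ # such as train.py can hand it to SummaryWriter.add_figure; closing it first
+ # keeps the Agg backend from accumulating open figures over long runs.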
+ 
+ def init_weights(m, mean=0.0, std=0.01):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         m.weight.data.normal_(mean, std)
+ 
+ 
+ def apply_weight_norm(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         weight_norm(m)
+ 
+ 
+ def get_padding(kernel_size, dilation=1):
+     return int((kernel_size*dilation - dilation)/2)
+ 
+ 
+ def load_checkpoint(filepath, device):
+     assert os.path.isfile(filepath)
+     print("Loading '{}'".format(filepath))
+     checkpoint_dict = torch.load(filepath, map_location=device)
+     print("Complete.")
+     return checkpoint_dict
+ 
+ 
+ def save_checkpoint(filepath, obj):
+     print("Saving checkpoint to {}".format(filepath))
+     torch.save(obj, filepath)
+     print("Complete.")
+ 
+ 
+ def scan_checkpoint(cp_dir, prefix):
+     pattern = os.path.join(cp_dir, prefix + '????????')
+     cp_list = glob.glob(pattern)
+     if len(cp_list) == 0:
+         return None
+     return sorted(cp_list)[-1]