cjayic committed
Commit f4b9544
Parent(s): b89d7de
.gitattributes CHANGED
@@ -28,6 +28,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.tflite filter=lfs diff=lfs merge=lfs -text
 *.tgz filter=lfs diff=lfs merge=lfs -text
 *.wasm filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2021 Benjamin van Niekerk
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md CHANGED
@@ -1,8 +1,8 @@
 ---
-title: Soft Vc Widowmaker
-emoji: 👁
-colorFrom: green
-colorTo: blue
+title: Soft-VC Widowmaker
+emoji: 🕷️
+colorFrom: black
+colorTo: purple
 sdk: gradio
 sdk_version: 3.15.0
 app_file: app.py
acoustic/__init__.py ADDED
@@ -0,0 +1 @@
+from .model import AcousticModel, hubert_discrete, hubert_soft
acoustic/dataset.py ADDED
@@ -0,0 +1,55 @@
+from pathlib import Path
+import numpy as np
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import Dataset
+from torch.nn.utils.rnn import pad_sequence
+
+
+class MelDataset(Dataset):
+    def __init__(self, root: Path, train: bool = True, discrete: bool = False):
+        self.discrete = discrete
+        self.mels_dir = root / "mels"
+        self.units_dir = root / "discrete" if discrete else root / "soft"
+
+        pattern = "train/**/*.npy" if train else "dev/**/*.npy"
+        self.metadata = [
+            path.relative_to(self.mels_dir).with_suffix("")
+            for path in self.mels_dir.rglob(pattern)
+        ]
+
+    def __len__(self):
+        return len(self.metadata)
+
+    def __getitem__(self, index):
+        path = self.metadata[index]
+        mel_path = self.mels_dir / path
+        units_path = self.units_dir / path
+
+        mel = np.load(mel_path.with_suffix(".npy")).T
+        units = np.load(units_path.with_suffix(".npy"))
+
+        length = 2 * units.shape[0]
+
+        mel = torch.from_numpy(mel[:length, :])
+        mel = F.pad(mel, (0, 0, 1, 0))
+        units = torch.from_numpy(units)
+        if self.discrete:
+            units = units.long()
+        return mel, units
+
+    def pad_collate(self, batch):
+        mels, units = zip(*batch)
+
+        mels, units = list(mels), list(units)
+
+        mels_lengths = torch.tensor([x.size(0) - 1 for x in mels])
+        units_lengths = torch.tensor([x.size(0) for x in units])
+
+        mels = pad_sequence(mels, batch_first=True)
+        units = pad_sequence(
+            units, batch_first=True, padding_value=100 if self.discrete else 0
+        )
+
+        return mels, mels_lengths, units, units_lengths
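
Note: a minimal sketch of how MelDataset is meant to be consumed, assuming a dataset root laid out as root/mels/train/**/*.npy with matching unit files under root/soft/ (the paths hard-coded in __init__ above); the root name "data" and the batch size are illustrative only.

    from pathlib import Path
    from torch.utils.data import DataLoader

    dataset = MelDataset(Path("data"), train=True, discrete=False)
    loader = DataLoader(
        dataset,
        batch_size=32,
        shuffle=True,
        collate_fn=dataset.pad_collate,  # pads mels and units to the longest item in the batch
    )
    mels, mels_lengths, units, units_lengths = next(iter(loader))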
acoustic/model.py ADDED
@@ -0,0 +1,168 @@
+import torch
+import torch.nn as nn
+from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
+
+URLS = {
+    "hubert-discrete": "https://github.com/bshall/acoustic-model/releases/download/v0.1/hubert-discrete-d49e1c77.pt",
+    "hubert-soft": "https://github.com/bshall/acoustic-model/releases/download/v0.1/hubert-soft-0321fd7e.pt",
+}
+
+
+class AcousticModel(nn.Module):
+    def __init__(self, discrete: bool = False, upsample: bool = True):
+        super().__init__()
+        self.encoder = Encoder(discrete, upsample)
+        self.decoder = Decoder()
+
+    def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
+        x = self.encoder(x)
+        return self.decoder(x, mels)
+
+    @torch.inference_mode()
+    def generate(self, x: torch.Tensor) -> torch.Tensor:
+        x = self.encoder(x)
+        return self.decoder.generate(x)
+
+
+class Encoder(nn.Module):
+    def __init__(self, discrete: bool = False, upsample: bool = True):
+        super().__init__()
+        self.embedding = nn.Embedding(100 + 1, 256) if discrete else None
+        self.prenet = PreNet(256, 256, 256)
+        self.convs = nn.Sequential(
+            nn.Conv1d(256, 512, 5, 1, 2),
+            nn.ReLU(),
+            nn.InstanceNorm1d(512),
+            nn.ConvTranspose1d(512, 512, 4, 2, 1) if upsample else nn.Identity(),
+            nn.Conv1d(512, 512, 5, 1, 2),
+            nn.ReLU(),
+            nn.InstanceNorm1d(512),
+            nn.Conv1d(512, 512, 5, 1, 2),
+            nn.ReLU(),
+            nn.InstanceNorm1d(512),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.embedding is not None:
+            x = self.embedding(x)
+        x = self.prenet(x)
+        x = self.convs(x.transpose(1, 2))
+        return x.transpose(1, 2)
+
+
+class Decoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.prenet = PreNet(128, 256, 256)
+        self.lstm1 = nn.LSTM(512 + 256, 768, batch_first=True)
+        self.lstm2 = nn.LSTM(768, 768, batch_first=True)
+        self.lstm3 = nn.LSTM(768, 768, batch_first=True)
+        self.proj = nn.Linear(768, 128, bias=False)
+
+    def forward(self, x: torch.Tensor, mels: torch.Tensor) -> torch.Tensor:
+        mels = self.prenet(mels)
+        x, _ = self.lstm1(torch.cat((x, mels), dim=-1))
+        res = x
+        x, _ = self.lstm2(x)
+        x = res + x
+        res = x
+        x, _ = self.lstm3(x)
+        x = res + x
+        return self.proj(x)
+
+    @torch.inference_mode()
+    def generate(self, xs: torch.Tensor) -> torch.Tensor:
+        m = torch.zeros(xs.size(0), 128, device=xs.device)
+        h1 = torch.zeros(1, xs.size(0), 768, device=xs.device)
+        c1 = torch.zeros(1, xs.size(0), 768, device=xs.device)
+        h2 = torch.zeros(1, xs.size(0), 768, device=xs.device)
+        c2 = torch.zeros(1, xs.size(0), 768, device=xs.device)
+        h3 = torch.zeros(1, xs.size(0), 768, device=xs.device)
+        c3 = torch.zeros(1, xs.size(0), 768, device=xs.device)
+
+        mel = []
+        for x in torch.unbind(xs, dim=1):
+            m = self.prenet(m)
+            x = torch.cat((x, m), dim=1).unsqueeze(1)
+            x1, (h1, c1) = self.lstm1(x, (h1, c1))
+            x2, (h2, c2) = self.lstm2(x1, (h2, c2))
+            x = x1 + x2
+            x3, (h3, c3) = self.lstm3(x, (h3, c3))
+            x = x + x3
+            m = self.proj(x).squeeze(1)
+            mel.append(m)
+        return torch.stack(mel, dim=1)
+
+
+class PreNet(nn.Module):
+    def __init__(
+        self,
+        input_size: int,
+        hidden_size: int,
+        output_size: int,
+        dropout: float = 0.5,
+    ):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(input_size, hidden_size),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+            nn.Linear(hidden_size, output_size),
+            nn.ReLU(),
+            nn.Dropout(dropout),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.net(x)
+
+
+def _acoustic(
+    name: str,
+    discrete: bool,
+    upsample: bool,
+    pretrained: bool = True,
+    progress: bool = True,
+) -> AcousticModel:
+    acoustic = AcousticModel(discrete, upsample)
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(URLS[name], progress=progress)
+        consume_prefix_in_state_dict_if_present(checkpoint["acoustic-model"], "module.")
+        acoustic.load_state_dict(checkpoint["acoustic-model"])
+        acoustic.eval()
+    return acoustic
+
+
+def hubert_discrete(
+    pretrained: bool = True,
+    progress: bool = True,
+) -> AcousticModel:
+    r"""HuBERT-Discrete acoustic model from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
+    Args:
+        pretrained (bool): load pretrained weights into the model
+        progress (bool): show progress bar when downloading model
+    """
+    return _acoustic(
+        "hubert-discrete",
+        discrete=True,
+        upsample=True,
+        pretrained=pretrained,
+        progress=progress,
+    )
+
+
+def hubert_soft(
+    pretrained: bool = True,
+    progress: bool = True,
+) -> AcousticModel:
+    r"""HuBERT-Soft acoustic model from `"A Comparison of Discrete and Soft Speech Units for Improved Voice Conversion"`.
+    Args:
+        pretrained (bool): load pretrained weights into the model
+        progress (bool): show progress bar when downloading model
+    """
+    return _acoustic(
+        "hubert-soft",
+        discrete=False,
+        upsample=True,
+        pretrained=pretrained,
+        progress=progress,
+    )
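
Note: a minimal usage sketch for the loaders above; the unit tensor is random noise standing in for real HuBERT-Soft features of shape (batch, frames, 256), and the shapes in the comments follow from the 2x ConvTranspose1d upsampling in the Encoder.

    import torch
    from acoustic import hubert_soft

    acoustic = hubert_soft(pretrained=True)   # downloads weights from URLS above
    units = torch.randn(1, 50, 256)           # dummy soft speech units
    mel = acoustic.generate(units)            # (1, 100, 128): frames doubled, 128 mel bins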
acoustic/utils.py ADDED
@@ -0,0 +1,99 @@
+import torch
+import torch.nn.functional as F
+import matplotlib
+
+import torchaudio.transforms as transforms
+
+matplotlib.use("Agg")
+import matplotlib.pylab as plt
+
+
+class Metric:
+    def __init__(self):
+        self.steps = 0
+        self.value = 0
+
+    def update(self, value):
+        self.steps += 1
+        self.value += (value - self.value) / self.steps
+        return self.value
+
+    def reset(self):
+        self.steps = 0
+        self.value = 0
+
+
+class LogMelSpectrogram(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.melspctrogram = transforms.MelSpectrogram(
+            sample_rate=16000,
+            n_fft=1024,
+            win_length=1024,
+            hop_length=160,
+            center=False,
+            power=1.0,
+            norm="slaney",
+            onesided=True,
+            n_mels=128,
+            mel_scale="slaney",
+        )
+
+    def forward(self, wav):
+        padding = (1024 - 160) // 2
+        wav = F.pad(wav, (padding, padding), "reflect")
+        mel = self.melspctrogram(wav)
+        logmel = torch.log(torch.clamp(mel, min=1e-5))
+        return logmel
+
+
+def save_checkpoint(
+    checkpoint_dir,
+    acoustic,
+    optimizer,
+    step,
+    loss,
+    best,
+    logger,
+):
+    state = {
+        "acoustic-model": acoustic.state_dict(),
+        "optimizer": optimizer.state_dict(),
+        "step": step,
+        "loss": loss,
+    }
+    checkpoint_dir.mkdir(exist_ok=True, parents=True)
+    checkpoint_path = checkpoint_dir / f"model-{step}.pt"
+    torch.save(state, checkpoint_path)
+    if best:
+        best_path = checkpoint_dir / "model-best.pt"
+        torch.save(state, best_path)
+    logger.info(f"Saved checkpoint: {checkpoint_path.stem}")
+
+
+def load_checkpoint(
+    load_path,
+    acoustic,
+    optimizer,
+    rank,
+    logger,
+):
+    logger.info(f"Loading checkpoint from {load_path}")
+    checkpoint = torch.load(load_path, map_location={"cuda:0": f"cuda:{rank}"})
+    acoustic.load_state_dict(checkpoint["acoustic-model"])
+    if "optimizer" in checkpoint:
+        optimizer.load_state_dict(checkpoint["optimizer"])
+    step = checkpoint.get("step", 0)
+    loss = checkpoint.get("loss", float("inf"))
+    return step, loss
+
+
+def plot_spectrogram(spectrogram):
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+    plt.colorbar(im, ax=ax)
+
+    fig.canvas.draw()
+    plt.close()
+
+    return fig
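
Note: a quick sketch of LogMelSpectrogram on one second of dummy 16 kHz audio; with the 160-sample hop, 1024-sample window, and reflect padding above, one second of audio yields exactly 100 mel frames.

    import torch

    logmel = LogMelSpectrogram()
    wav = torch.randn(1, 16000)   # (batch, samples) at 16 kHz
    mel = logmel(wav)             # (1, 128, 100): 128 mel bins, 100 frames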
app.py ADDED
@@ -0,0 +1,76 @@
+import torch
+import torchaudio
+import gradio as gr
+
+from hifigan.generator import HifiganGenerator
+from acoustic import AcousticModel
+from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
+
+hubert = torch.hub.load("bshall/hubert:main", "hubert_soft").cpu()
+
+acoustic = AcousticModel(False, True)
+checkpoint = torch.load("models/acoustic-model-100000.pt", map_location=torch.device("cpu"))
+consume_prefix_in_state_dict_if_present(checkpoint["acoustic-model"], "module.")
+acoustic.load_state_dict(checkpoint["acoustic-model"])
+acoustic.eval()
+
+# hifigan = torch.hub.load("bshall/hifigan:main", "hifigan_hubert_soft").cpu()  # or .cuda()
+hifigan = HifiganGenerator()
+checkpoint = torch.load("models/hifigan-model-best.pt", map_location=torch.device("cpu"))
+consume_prefix_in_state_dict_if_present(checkpoint["generator"]["model"], "module.")
+hifigan.load_state_dict(checkpoint["generator"]["model"])
+hifigan.eval()
+
+
+def run_conversion(audio_in):
+    sr, source = audio_in
+
+    source = torch.Tensor(source)
+    if source.dim() == 1:
+        source = source.unsqueeze(1)
+    source = source.T
+
+    # resample to 16 kHz
+    source = torchaudio.functional.resample(source, sr, 16000)
+
+    # convert to mono and add batch/channel dimensions
+    source = torch.mean(source, dim=0).unsqueeze(0)
+    source = source.unsqueeze(0)
+
+    with torch.inference_mode():
+        # extract speech units
+        units = hubert.units(source)
+        # generate target spectrogram
+        mel = acoustic.generate(units).transpose(1, 2)
+        # generate audio waveform
+        target = hifigan(mel)
+
+    result = target.squeeze().cpu().multiply(32767).to(torch.int16).numpy()
+    return (16000, result)
+
+
+with gr.Blocks() as demo:
+    with gr.Column(variant="panel"):
+        with gr.Row(variant="compact"):
+            input_audio = gr.Audio(
+                label="Audio to be converted",
+            ).style(
+                container=False,
+            )
+            btn = gr.Button("Widowify").style(full_width=False)
+        output_audio = gr.Audio(
+            label="Converted Audio",
+            elem_id="output_audio",
+            interactive=False,
+        ).style(height="auto")
+
+        btn.click(run_conversion, input_audio, output_audio)
+        gr.Examples(["examples/jermacraft.wav", "examples/meatgrinder.wav"], inputs=[input_audio])
+
+demo.launch()
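
Note: run_conversion can also be exercised without the Gradio UI, provided the checkpoints under models/ are present; the input below mimics the (sample_rate, int16 waveform) pair that gr.Audio produces, with dummy noise standing in for speech.

    import numpy as np

    sr = 48000
    wav = (np.random.randn(sr) * 0.1 * 32767).astype(np.int16)   # 1 s of noise
    out_sr, converted = run_conversion((sr, wav))                # out_sr == 16000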
examples/jermacraft.wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4a71412c1b685bf3e1e5bab0685e08fd88a51b18e682613a548ab9e8ca68835c
+size 450510
examples/meatgrinder.wav ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a29324c4e5909f7eff663b3f3a17100fdf36fef1c6707ba16b4175bb21b3cb84
+size 1460740
hifigan/__init__.py ADDED
@@ -0,0 +1 @@
+from .generator import hifigan, hifigan_hubert_discrete, hifigan_hubert_soft
hifigan/dataset.py ADDED
@@ -0,0 +1,126 @@
+from pathlib import Path
+import math
+import random
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from torch.utils.data import Dataset
+
+import torchaudio
+import torchaudio.transforms as transforms
+
+
+class LogMelSpectrogram(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.melspctrogram = transforms.MelSpectrogram(
+            sample_rate=16000,
+            n_fft=1024,
+            win_length=1024,
+            hop_length=160,
+            center=False,
+            power=1.0,
+            norm="slaney",
+            onesided=True,
+            n_mels=128,
+            mel_scale="slaney",
+        )
+
+    def forward(self, wav):
+        wav = F.pad(wav, ((1024 - 160) // 2, (1024 - 160) // 2), "reflect")
+        mel = self.melspctrogram(wav)
+        logmel = torch.log(torch.clamp(mel, min=1e-5))
+        return logmel
+
+
+class MelDataset(Dataset):
+    def __init__(
+        self,
+        root: Path,
+        segment_length: int,
+        sample_rate: int,
+        hop_length: int,
+        train: bool = True,
+        finetune: bool = False,
+    ):
+        self.wavs_dir = root / "wavs"
+        self.mels_dir = root / "mels"
+        self.data_dir = self.wavs_dir if not finetune else self.mels_dir
+
+        self.segment_length = segment_length
+        self.sample_rate = sample_rate
+        self.hop_length = hop_length
+        self.train = train
+        self.finetune = finetune
+
+        suffix = ".wav" if not finetune else ".npy"
+        pattern = f"train/**/*{suffix}" if train else f"dev/**/*{suffix}"
+
+        self.metadata = [
+            path.relative_to(self.data_dir).with_suffix("")
+            for path in self.data_dir.rglob(pattern)
+        ]
+
+        self.logmel = LogMelSpectrogram()
+
+    def __len__(self):
+        return len(self.metadata)
+
+    def __getitem__(self, index):
+        path = self.metadata[index]
+        wav_path = self.wavs_dir / path
+
+        info = torchaudio.info(wav_path.with_suffix(".wav"))
+        if info.sample_rate != self.sample_rate:
+            raise ValueError(
+                f"Sample rate {info.sample_rate} doesn't match target of {self.sample_rate}"
+            )
+
+        if self.finetune:
+            mel_path = self.mels_dir / path
+            src_logmel = torch.from_numpy(np.load(mel_path.with_suffix(".npy")))
+            src_logmel = src_logmel.unsqueeze(0)
+
+            mel_frames_per_segment = math.ceil(self.segment_length / self.hop_length)
+            mel_diff = src_logmel.size(-1) - mel_frames_per_segment if self.train else 0
+            mel_offset = random.randint(0, max(mel_diff, 0))
+
+            frame_offset = self.hop_length * mel_offset
+        else:
+            frame_diff = info.num_frames - self.segment_length
+            frame_offset = random.randint(0, max(frame_diff, 0))
+
+        wav, _ = torchaudio.load(
+            filepath=wav_path.with_suffix(".wav"),
+            frame_offset=frame_offset if self.train else 0,
+            num_frames=self.segment_length if self.train else -1,
+        )
+
+        if wav.size(-1) < self.segment_length:
+            wav = F.pad(wav, (0, self.segment_length - wav.size(-1)))
+
+        if not self.finetune and self.train:
+            gain = random.random() * (0.99 - 0.4) + 0.4
+            flip = -1 if random.random() > 0.5 else 1
+            wav = flip * gain * wav / wav.abs().max()
+
+        tgt_logmel = self.logmel(wav.unsqueeze(0)).squeeze(0)
+
+        if self.finetune:
+            if self.train:
+                src_logmel = src_logmel[
+                    :, :, mel_offset : mel_offset + mel_frames_per_segment
+                ]
+
+            if src_logmel.size(-1) < mel_frames_per_segment:
+                src_logmel = F.pad(
+                    src_logmel,
+                    (0, mel_frames_per_segment - src_logmel.size(-1)),
+                    "constant",
+                    src_logmel.min(),
+                )
+        else:
+            src_logmel = tgt_logmel.clone()
+
+        return wav, src_logmel, tgt_logmel
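
Note: a sketch of drawing one training example from this MelDataset, assuming 16 kHz clips under root/wavs/train/ as __getitem__ expects; segment_length=8192 and hop_length=160 are plausible HiFi-GAN training values, not ones fixed by this file.

    from pathlib import Path

    dataset = MelDataset(
        Path("data"), segment_length=8192, sample_rate=16000, hop_length=160
    )
    wav, src_logmel, tgt_logmel = dataset[0]   # waveform plus source/target log-mels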
hifigan/discriminator.py ADDED
@@ -0,0 +1,262 @@
+# adapted from https://github.com/jik876/hifi-gan/blob/master/models.py
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Tuple, List
+
+from hifigan.utils import get_padding
+
+
+LRELU_SLOPE = 0.1
+
+
+class PeriodDiscriminator(torch.nn.Module):
+    """HiFiGAN Period Discriminator"""
+
+    def __init__(
+        self,
+        period: int,
+        kernel_size: int = 5,
+        stride: int = 3,
+        use_spectral_norm: bool = False,
+    ) -> None:
+        super().__init__()
+        self.period = period
+        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(
+                    nn.Conv2d(
+                        1,
+                        32,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        32,
+                        128,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        128,
+                        512,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(
+                    nn.Conv2d(
+                        512,
+                        1024,
+                        (kernel_size, 1),
+                        (stride, 1),
+                        padding=(get_padding(5, 1), 0),
+                    )
+                ),
+                norm_f(nn.Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """
+        Args:
+            x (Tensor): input waveform.
+        Returns:
+            [Tensor]: discriminator scores per sample in the batch.
+            [List[Tensor]]: list of features from each convolutional layer.
+        """
+        feat = []
+
+        # 1d to 2d
+        b, c, t = x.shape
+        if t % self.period != 0:  # pad first
+            n_pad = self.period - (t % self.period)
+            x = F.pad(x, (0, n_pad), "reflect")
+            t = t + n_pad
+        x = x.view(b, c, t // self.period, self.period)
+
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            feat.append(x)
+        x = self.conv_post(x)
+        feat.append(x)
+        x = torch.flatten(x, 1, -1)
+
+        return x, feat
+
+
+class MultiPeriodDiscriminator(torch.nn.Module):
+    """HiFiGAN Multi-Period Discriminator (MPD)"""
+
+    def __init__(self):
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [
+                PeriodDiscriminator(2),
+                PeriodDiscriminator(3),
+                PeriodDiscriminator(5),
+                PeriodDiscriminator(7),
+                PeriodDiscriminator(11),
+            ]
+        )
+
+    def forward(
+        self, x: torch.Tensor
+    ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
+        """
+        Args:
+            x (Tensor): input waveform.
+        Returns:
+            [List[Tensor]]: list of scores from each discriminator.
+            [List[List[Tensor]]]: list of features from each discriminator's convolutional layers.
+        """
+        scores = []
+        feats = []
+        for d in self.discriminators:
+            score, feat = d(x)
+            scores.append(score)
+            feats.append(feat)
+        return scores, feats
+
+
+class ScaleDiscriminator(torch.nn.Module):
+    """HiFiGAN Scale Discriminator."""
+
+    def __init__(self, use_spectral_norm: bool = False) -> None:
+        super().__init__()
+        norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.weight_norm
+        self.convs = nn.ModuleList(
+            [
+                norm_f(nn.Conv1d(1, 128, 15, 1, padding=7)),
+                norm_f(nn.Conv1d(128, 128, 41, 2, groups=4, padding=20)),
+                norm_f(nn.Conv1d(128, 256, 41, 2, groups=16, padding=20)),
+                norm_f(nn.Conv1d(256, 512, 41, 4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
+                norm_f(nn.Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
+                norm_f(nn.Conv1d(1024, 1024, 5, 1, padding=2)),
+            ]
+        )
+        self.conv_post = norm_f(nn.Conv1d(1024, 1, 3, 1, padding=1))
+
+    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """
+        Args:
+            x (Tensor): input waveform.
+        Returns:
+            Tensor: discriminator scores.
+            List[Tensor]: list of features from the convolutional layers.
+        """
+        feat = []
+        for l in self.convs:
+            x = l(x)
+            x = F.leaky_relu(x, LRELU_SLOPE)
+            feat.append(x)
+        x = self.conv_post(x)
+        feat.append(x)
+        x = torch.flatten(x, 1, -1)
+        return x, feat
+
+
+class MultiScaleDiscriminator(torch.nn.Module):
+    """HiFiGAN Multi-Scale Discriminator."""
+
+    def __init__(self):
+        super().__init__()
+        self.discriminators = nn.ModuleList(
+            [
+                ScaleDiscriminator(use_spectral_norm=True),
+                ScaleDiscriminator(),
+                ScaleDiscriminator(),
+            ]
+        )
+        self.meanpools = nn.ModuleList(
+            [nn.AvgPool1d(4, 2, padding=2), nn.AvgPool1d(4, 2, padding=2)]
+        )
+
+    def forward(
+        self, x: torch.Tensor
+    ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
+        """
+        Args:
+            x (Tensor): input waveform.
+        Returns:
+            List[Tensor]: discriminator scores.
+            List[List[Tensor]]: list of features from each discriminator's convolutional layers.
+        """
+        scores = []
+        feats = []
+        for i, d in enumerate(self.discriminators):
+            if i != 0:
+                x = self.meanpools[i - 1](x)
+            score, feat = d(x)
+            scores.append(score)
+            feats.append(feat)
+        return scores, feats
+
+
+class HifiganDiscriminator(nn.Module):
+    """HiFiGAN discriminator"""
+
+    def __init__(self):
+        super().__init__()
+        self.mpd = MultiPeriodDiscriminator()
+        self.msd = MultiScaleDiscriminator()
+
+    def forward(
+        self, x: torch.Tensor
+    ) -> Tuple[List[torch.Tensor], List[List[torch.Tensor]]]:
+        """
+        Args:
+            x (Tensor): input waveform.
+        Returns:
+            List[Tensor]: discriminator scores.
+            List[List[Tensor]]: list of features from each discriminator's convolutional layers.
+        """
+        scores, feats = self.mpd(x)
+        scores_, feats_ = self.msd(x)
+        return scores + scores_, feats + feats_
+
+
+def feature_loss(
+    features_real: List[List[torch.Tensor]], features_generate: List[List[torch.Tensor]]
+) -> torch.Tensor:
+    loss = 0
+    for r, g in zip(features_real, features_generate):
+        for rl, gl in zip(r, g):
+            loss += torch.mean(torch.abs(rl - gl))
+    return loss * 2
+
+
+def discriminator_loss(real, generated):
+    loss = 0
+    real_losses = []
+    generated_losses = []
+    for r, g in zip(real, generated):
+        r_loss = torch.mean((1 - r) ** 2)
+        g_loss = torch.mean(g ** 2)
+        loss += r_loss + g_loss
+        real_losses.append(r_loss.item())
+        generated_losses.append(g_loss.item())
+
+    return loss, real_losses, generated_losses
+
+
+def generator_loss(discriminator_outputs):
+    loss = 0
+    generator_losses = []
+    for x in discriminator_outputs:
+        l = torch.mean((1 - x) ** 2)
+        generator_losses.append(l)
+        loss += l
+
+    return loss, generator_losses
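
Note: a sketch of how the discriminator and the three losses above combine in one GAN step; the tensors are random stand-ins for real and generated audio, and optimizer updates are omitted.

    import torch

    disc = HifiganDiscriminator()
    real = torch.randn(2, 1, 8192)   # (batch, 1, samples)
    fake = torch.randn(2, 1, 8192)   # stand-in for generator output

    scores_r, feats_r = disc(real)
    scores_f, _ = disc(fake.detach())
    d_loss, _, _ = discriminator_loss(scores_r, scores_f)   # discriminator update

    scores_f, feats_f = disc(fake)
    g_loss = generator_loss(scores_f)[0] + feature_loss(feats_r, feats_f)  # generator update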
hifigan/generator.py ADDED
@@ -0,0 +1,282 @@
+# adapted from https://github.com/jik876/hifi-gan/blob/master/models.py
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.utils import remove_weight_norm, weight_norm
+from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
+from typing import Tuple
+
+from hifigan.utils import get_padding
+
+
+URLS = {
+    "hifigan": "https://github.com/bshall/hifigan/releases/download/v0.1/hifigan-67926ec6.pt",
+    "hifigan-hubert-soft": "https://github.com/bshall/hifigan/releases/download/v0.1/hifigan-hubert-soft-65f03469.pt",
+    "hifigan-hubert-discrete": "https://github.com/bshall/hifigan/releases/download/v0.1/hifigan-hubert-discrete-bbad3043.pt",
+}
+
+LRELU_SLOPE = 0.1
+
+
+class HifiganGenerator(torch.nn.Module):
+    def __init__(
+        self,
+        in_channels: int = 128,
+        resblock_dilation_sizes: Tuple[Tuple[int, ...], ...] = (
+            (1, 3, 5),
+            (1, 3, 5),
+            (1, 3, 5),
+        ),
+        resblock_kernel_sizes: Tuple[int, ...] = (3, 7, 11),
+        upsample_kernel_sizes: Tuple[int, ...] = (20, 8, 4, 4),
+        upsample_initial_channel: int = 512,
+        upsample_factors: Tuple[int, ...] = (10, 4, 2, 2),
+        inference_padding: int = 5,
+        sample_rate: int = 16000,
+    ) -> None:
+        r"""HiFiGAN Generator
+        Args:
+            in_channels (int): number of input channels.
+            resblock_dilation_sizes (Tuple[Tuple[int, ...], ...]): list of dilation values in each layer of a `ResBlock`.
+            resblock_kernel_sizes (Tuple[int, ...]): list of kernel sizes for each `ResBlock`.
+            upsample_kernel_sizes (Tuple[int, ...]): list of kernel sizes for each transposed convolution.
+            upsample_initial_channel (int): number of channels for the first upsampling layer. This is divided by 2
+                for each consecutive upsampling layer.
+            upsample_factors (Tuple[int, ...]): upsampling factors (stride) for each upsampling layer.
+            inference_padding (int): constant padding applied to the input at inference time.
+            sample_rate (int): sample rate of the generated audio.
+        """
+        super().__init__()
+        self.inference_padding = inference_padding
+        self.num_kernels = len(resblock_kernel_sizes)
+        self.num_upsamples = len(upsample_factors)
+        self.sample_rate = sample_rate
+        # initial upsampling layers
+        self.conv_pre = weight_norm(
+            nn.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
+        )
+
+        # upsampling layers
+        self.ups = nn.ModuleList()
+        for i, (u, k) in enumerate(zip(upsample_factors, upsample_kernel_sizes)):
+            self.ups.append(
+                weight_norm(
+                    nn.ConvTranspose1d(
+                        upsample_initial_channel // (2 ** i),
+                        upsample_initial_channel // (2 ** (i + 1)),
+                        k,
+                        u,
+                        padding=(k - u) // 2,
+                    )
+                )
+            )
+        # MRF blocks
+        self.resblocks = nn.ModuleList()
+        for i in range(len(self.ups)):
+            ch = upsample_initial_channel // (2 ** (i + 1))
+            for _, (k, d) in enumerate(
+                zip(resblock_kernel_sizes, resblock_dilation_sizes)
+            ):
+                self.resblocks.append(ResBlock1(ch, k, d))
+        # post convolution layer
+        self.conv_post = weight_norm(nn.Conv1d(ch, 1, 7, 1, padding=3))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        o = self.conv_pre(x)
+        for i in range(self.num_upsamples):
+            o = F.leaky_relu(o, LRELU_SLOPE)
+            o = self.ups[i](o)
+            z_sum = None
+            for j in range(self.num_kernels):
+                if z_sum is None:
+                    z_sum = self.resblocks[i * self.num_kernels + j](o)
+                else:
+                    z_sum += self.resblocks[i * self.num_kernels + j](o)
+            o = z_sum / self.num_kernels
+        o = F.leaky_relu(o)
+        o = self.conv_post(o)
+        o = torch.tanh(o)
+        return o
+
+    @torch.no_grad()
+    def generate(self, x: torch.Tensor) -> Tuple[torch.Tensor, int]:
+        x = F.pad(x, (self.inference_padding, self.inference_padding), "replicate")
+        return self(x), self.sample_rate
+
+    def remove_weight_norm(self):
+        print("Removing weight norm...")
+        for l in self.ups:
+            remove_weight_norm(l)
+        for l in self.resblocks:
+            l.remove_weight_norm()
+        remove_weight_norm(self.conv_pre)
+        remove_weight_norm(self.conv_post)
+
+
+class ResBlock1(torch.nn.Module):
+    def __init__(
+        self, channels: int, kernel_size: int = 3, dilation: Tuple[int, ...] = (1, 3, 5)
+    ) -> None:
+        super().__init__()
+        self.convs1 = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[2],
+                        padding=get_padding(kernel_size, dilation[2]),
+                    )
+                ),
+            ]
+        )
+
+        self.convs2 = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=1,
+                        padding=get_padding(kernel_size, 1),
+                    )
+                ),
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for c1, c2 in zip(self.convs1, self.convs2):
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c1(xt)
+            xt = F.leaky_relu(xt, LRELU_SLOPE)
+            xt = c2(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs1:
+            remove_weight_norm(l)
+        for l in self.convs2:
+            remove_weight_norm(l)
+
+
+class ResBlock2(torch.nn.Module):
+    def __init__(
+        self, channels: int, kernel_size: int = 3, dilation: Tuple[int, ...] = (1, 3)
+    ) -> None:
+        super().__init__()
+        self.convs = nn.ModuleList(
+            [
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[0],
+                        padding=get_padding(kernel_size, dilation[0]),
+                    )
+                ),
+                weight_norm(
+                    nn.Conv1d(
+                        channels,
+                        channels,
+                        kernel_size,
+                        1,
+                        dilation=dilation[1],
+                        padding=get_padding(kernel_size, dilation[1]),
+                    )
+                ),
+            ]
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        for c in self.convs:
+            xt = F.leaky_relu(x, LRELU_SLOPE)
+            xt = c(xt)
+            x = xt + x
+        return x
+
+    def remove_weight_norm(self):
+        for l in self.convs:
+            remove_weight_norm(l)
+
+
+def _hifigan(
+    name: str,
+    pretrained: bool = True,
+    progress: bool = True,
+    map_location=None,
+) -> HifiganGenerator:
+    hifigan = HifiganGenerator()
+    if pretrained:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            URLS[name], map_location=map_location, progress=progress
+        )
+        consume_prefix_in_state_dict_if_present(checkpoint, "module.")
+        hifigan.load_state_dict(checkpoint)
+        hifigan.eval()
+        hifigan.remove_weight_norm()
+    return hifigan
+
+
+def hifigan(
+    pretrained: bool = True, progress: bool = True, map_location=None
+) -> HifiganGenerator:
+    return _hifigan("hifigan", pretrained, progress, map_location)
+
+
+def hifigan_hubert_soft(
+    pretrained: bool = True, progress: bool = True, map_location=None
+) -> HifiganGenerator:
+    return _hifigan("hifigan-hubert-soft", pretrained, progress, map_location)
+
+
+def hifigan_hubert_discrete(
+    pretrained: bool = True, progress: bool = True, map_location=None
+) -> HifiganGenerator:
+    return _hifigan("hifigan-hubert-discrete", pretrained, progress, map_location)
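
Note: a minimal vocoding sketch for HifiganGenerator; the mel input is random noise standing in for a real 128-bin log-mel. The upsample factors (10, 4, 2, 2) multiply to the 160-sample hop, so each mel frame becomes 160 samples, plus 5 frames of inference padding on each side.

    import torch

    vocoder = HifiganGenerator()
    mel = torch.randn(1, 128, 100)    # (batch, n_mels, frames)
    wav, sr = vocoder.generate(mel)   # wav: (1, 1, (100 + 10) * 160), sr == 16000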
hifigan/utils.py ADDED
@@ -0,0 +1,84 @@
+import torch
+import matplotlib
+
+matplotlib.use("Agg")
+import matplotlib.pylab as plt
+
+
+def get_padding(k, d):
+    return int((k * d - d) / 2)
+
+
+def plot_spectrogram(spectrogram):
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
+    plt.colorbar(im, ax=ax)
+
+    fig.canvas.draw()
+    plt.close()
+
+    return fig
+
+
+def save_checkpoint(
+    checkpoint_dir,
+    generator,
+    discriminator,
+    optimizer_generator,
+    optimizer_discriminator,
+    scheduler_generator,
+    scheduler_discriminator,
+    step,
+    loss,
+    best,
+    logger,
+):
+    state = {
+        "generator": {
+            "model": generator.state_dict(),
+            "optimizer": optimizer_generator.state_dict(),
+            "scheduler": scheduler_generator.state_dict(),
+        },
+        "discriminator": {
+            "model": discriminator.state_dict(),
+            "optimizer": optimizer_discriminator.state_dict(),
+            "scheduler": scheduler_discriminator.state_dict(),
+        },
+        "step": step,
+        "loss": loss,
+    }
+    checkpoint_dir.mkdir(exist_ok=True, parents=True)
+    checkpoint_path = checkpoint_dir / f"model-{step}.pt"
+    torch.save(state, checkpoint_path)
+    if best:
+        best_path = checkpoint_dir / "model-best.pt"
+        torch.save(state, best_path)
+    logger.info(f"Saved checkpoint: {checkpoint_path.stem}")
+
+
+def load_checkpoint(
+    load_path,
+    generator,
+    discriminator,
+    optimizer_generator,
+    optimizer_discriminator,
+    scheduler_generator,
+    scheduler_discriminator,
+    rank,
+    logger,
+    finetune=False,
+):
+    logger.info(f"Loading checkpoint from {load_path}")
+    checkpoint = torch.load(load_path, map_location={"cuda:0": f"cuda:{rank}"})
+    generator.load_state_dict(checkpoint["generator"]["model"])
+    discriminator.load_state_dict(checkpoint["discriminator"]["model"])
+    if not finetune:
+        optimizer_generator.load_state_dict(checkpoint["generator"]["optimizer"])
+        scheduler_generator.load_state_dict(checkpoint["generator"]["scheduler"])
+        optimizer_discriminator.load_state_dict(
+            checkpoint["discriminator"]["optimizer"]
+        )
+        scheduler_discriminator.load_state_dict(
+            checkpoint["discriminator"]["scheduler"]
+        )
+    return checkpoint["step"], checkpoint["loss"]
models/acoustic-model-100000.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bab1ca079f6d3454cbe20be736c2fea003ddb8425acf5a451bc0b8e8975d6d99
+size 225997291
models/hifigan-model-best.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e2c4c04b6a829854ccd9eb5eac3b0f7a434fc1e94809e6662e2be79e6f930c49
+size 1021686329
requirements.txt ADDED
@@ -0,0 +1,3 @@
+torch
+torchaudio
+gradio