ZLSCompLing committed commit b9c74dd (verified) · 1 parent: 9c3132d

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ monotonic_align/build/temp.linux-x86_64-3.10/core.o filter=lfs diff=lfs merge=lfs -text
+ monotonic_align/monotonic_align/core.cpython-310-x86_64-linux-gnu.so filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,3 +1,121 @@
- ---
- license: mit
- ---
+ ---
+ license: mit
+ language:
+ - lb
+ tags:
+ - text-to-speech
+ - tts
+ - vits2
+ - luxembourgish
+ pipeline_tag: text-to-speech
+ ---
+
+ # VITS2 - Claude (Luxembourgish Gender-Neutral Voice)
+
+ A VITS2-based text-to-speech model for Luxembourgish, featuring a synthetic gender-neutral voice.
+
+ ## Model Description
+
+ This model was trained using the VITS2 architecture on Luxembourgish speech data from the [Lëtzebuerger Online Dictionnaire (LOD)](https://lod.lu) example sentences.
+
+ "Claude" is a synthetic gender-neutral Luxembourgish voice created by modulating the original LOD recordings.
+
+ ### Model Details
+
+ - **Architecture:** VITS2 with duration discriminator and transformer flows
+ - **Language:** Luxembourgish (lb)
+ - **Speaker:** Single speaker (gender-neutral, synthetic)
+ - **Sample Rate:** 24000 Hz
+ - **Checkpoint:** G_57000 (57,000 steps)
+ - **License:** MIT
+
+ ## Usage
+
+ This model requires the included Python source files for inference.
+
+ ### Basic Usage
+
+ ```python
+ import torch
+ import scipy.io.wavfile as wavfile
+ from vits2_engine import VITS2Engine
+
+ # Load the model
+ engine = VITS2Engine(model_dir="path/to/vits2-claude")
+
+ # Generate speech
+ wav = engine.tts("Moien, wéi geet et dir?")
+
+ # Save to file
+ wavfile.write("output.wav", engine.sample_rate, wav)
+ ```
+
+ ### Command Line
+
+ ```bash
+ python inference.py "Moien, wéi geet et dir?"
+
+ # With custom parameters
+ python inference.py "Text" --noise_scale 0.5 --length_scale 1.1 -o output.wav
+ ```
+
+ ### Parameters
+
+ - `noise_scale`: Controls voice variation (default: 0.667, lower = more consistent)
+ - `noise_scale_w`: Controls duration variation (default: 0.8)
+ - `length_scale`: Controls speech speed (default: 1.0, higher = slower)
+
+ ## Technical Specifications
+
+ | Parameter | Value |
+ |-----------|-------|
+ | Hidden Channels | 192 |
+ | Filter Channels | 768 |
+ | Attention Heads | 2 |
+ | Encoder Layers | 6 |
+ | Mel Channels | 80 |
+ | FFT Size | 1024 |
+ | Hop Length | 256 |
+
+ ### Architecture Features
+
+ - Mel posterior encoder
+ - Transformer flows (pre-conv type)
+ - Duration discriminator
+ - Noise-scaled MAS (Monotonic Alignment Search)
+
+ ## Requirements
+
+ - Python 3.8+
+ - PyTorch
+ - scipy
+ - numpy
+ - Cython (for monotonic_align)
+
+ ## Citation
+
+ If you use this model, please cite:
+
+ ```bibtex
+ @misc{zls2025vits2claude,
+   title={VITS2 Claude - Luxembourgish Gender-Neutral Voice},
+   author={Zenter fir d'Lëtzebuerger Sprooch},
+   year={2025},
+   publisher={Hugging Face},
+   url={https://huggingface.co/ZLSCompLing/VITS2-Claude}
+ }
+ ```
+
+ ## Acknowledgments
+
+ Developed by [Zenter fir d'Lëtzebuerger Sprooch](https://zls.lu).
+
+ Voice data sourced from the [Lëtzebuerger Online Dictionnaire (LOD)](https://lod.lu). The original audio files are available via the [LOD linguistic data on data.public.lu](https://data.public.lu/en/datasets/letzebuerger-online-dictionnaire-lod-linguistesch-daten/), which provides an XML file containing example sentence IDs. Audio files can be accessed at:
+
+ ```
+ https://lod.lu/uploads/examples/AAC/{folder}/{id}.m4a
+ ```
+
+ where `{folder}` is the first 2 characters of `{id}`.
+
+ This model is used in [Sproochmaschinn](https://sproochmaschinn.lu), a Luxembourgish speech processing platform.
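The README's audio URL rule (the folder is the first 2 characters of the ID) can be sketched as a small helper. This is only an illustration: the function name `lod_audio_url` and the sample ID `AB1234` are made up here and are not part of the repository or of the LOD data.

```python
# URL pattern quoted from the README's Acknowledgments section.
LOD_AUDIO_TEMPLATE = "https://lod.lu/uploads/examples/AAC/{folder}/{id}.m4a"


def lod_audio_url(example_id: str) -> str:
    """Build the audio URL for one LOD example-sentence ID (hypothetical helper)."""
    if len(example_id) < 2:
        raise ValueError("example ID must have at least 2 characters")
    # Per the README, the folder is the first 2 characters of the ID.
    folder = example_id[:2]
    return LOD_AUDIO_TEMPLATE.format(folder=folder, id=example_id)


# "AB1234" is a made-up ID used only to show the shape of the result.
print(lod_audio_url("AB1234"))
```

Real IDs would come from the XML file published on data.public.lu; how that file is structured is not described in this commit, so parsing it is left out.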
__pycache__/attentions.cpython-310.pyc ADDED
Binary file (14.6 kB)
__pycache__/commons.cpython-310.pyc ADDED
Binary file (5.8 kB)
__pycache__/models.cpython-310.pyc ADDED
Binary file (26.2 kB)
__pycache__/modules.cpython-310.pyc ADDED
Binary file (11.7 kB)
__pycache__/transforms.cpython-310.pyc ADDED
Binary file (3.96 kB)
__pycache__/utils.cpython-310.pyc ADDED
Binary file (9.02 kB)
attentions.py ADDED
@@ -0,0 +1,646 @@
+ import copy
+ import math
+ import numpy as np
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torch.nn.utils import remove_weight_norm, weight_norm
+
+ import commons
+ import modules
+ from modules import LayerNorm
+
+
+ class Encoder(nn.Module):  # backward compatible vits2 encoder
+     def __init__(
+         self,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size=1,
+         p_dropout=0.0,
+         window_size=4,
+         **kwargs
+     ):
+         super().__init__()
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.window_size = window_size
+
+         self.drop = nn.Dropout(p_dropout)
+         self.attn_layers = nn.ModuleList()
+         self.norm_layers_1 = nn.ModuleList()
+         self.ffn_layers = nn.ModuleList()
+         self.norm_layers_2 = nn.ModuleList()
+         # if kwargs has gin_channels (speaker embedding size), add a linear layer that projects it to hidden_channels
+         self.cond_layer_idx = self.n_layers
+         if "gin_channels" in kwargs:
+             self.gin_channels = kwargs["gin_channels"]
+             if self.gin_channels != 0:
+                 self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
+                 # vits2 says 3rd block, so idx is 2 by default
+                 self.cond_layer_idx = (
+                     kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
+                 )
+                 assert (
+                     self.cond_layer_idx < self.n_layers
+                 ), "cond_layer_idx should be less than n_layers"
+
+         for i in range(self.n_layers):
+             self.attn_layers.append(
+                 MultiHeadAttention(
+                     hidden_channels,
+                     hidden_channels,
+                     n_heads,
+                     p_dropout=p_dropout,
+                     window_size=window_size,
+                 )
+             )
+             self.norm_layers_1.append(LayerNorm(hidden_channels))
+             self.ffn_layers.append(
+                 FFN(
+                     hidden_channels,
+                     hidden_channels,
+                     filter_channels,
+                     kernel_size,
+                     p_dropout=p_dropout,
+                 )
+             )
+             self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+     def forward(self, x, x_mask, g=None):
+         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+         x = x * x_mask
+         for i in range(self.n_layers):
+             if i == self.cond_layer_idx and g is not None:
+                 g = self.spk_emb_linear(g.transpose(1, 2))
+                 g = g.transpose(1, 2)
+                 x = x + g
+                 x = x * x_mask
+             y = self.attn_layers[i](x, x, attn_mask)
+             y = self.drop(y)
+             x = self.norm_layers_1[i](x + y)
+
+             y = self.ffn_layers[i](x, x_mask)
+             y = self.drop(y)
+             x = self.norm_layers_2[i](x + y)
+         x = x * x_mask
+         return x
+
+
+ class Decoder(nn.Module):
+     def __init__(
+         self,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers,
+         kernel_size=1,
+         p_dropout=0.0,
+         proximal_bias=False,
+         proximal_init=True,
+         **kwargs
+     ):
+         super().__init__()
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.proximal_bias = proximal_bias
+         self.proximal_init = proximal_init
+
+         self.drop = nn.Dropout(p_dropout)
+         self.self_attn_layers = nn.ModuleList()
+         self.norm_layers_0 = nn.ModuleList()
+         self.encdec_attn_layers = nn.ModuleList()
+         self.norm_layers_1 = nn.ModuleList()
+         self.ffn_layers = nn.ModuleList()
+         self.norm_layers_2 = nn.ModuleList()
+         for i in range(self.n_layers):
+             self.self_attn_layers.append(
+                 MultiHeadAttention(
+                     hidden_channels,
+                     hidden_channels,
+                     n_heads,
+                     p_dropout=p_dropout,
+                     proximal_bias=proximal_bias,
+                     proximal_init=proximal_init,
+                 )
+             )
+             self.norm_layers_0.append(LayerNorm(hidden_channels))
+             self.encdec_attn_layers.append(
+                 MultiHeadAttention(
+                     hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
+                 )
+             )
+             self.norm_layers_1.append(LayerNorm(hidden_channels))
+             self.ffn_layers.append(
+                 FFN(
+                     hidden_channels,
+                     hidden_channels,
+                     filter_channels,
+                     kernel_size,
+                     p_dropout=p_dropout,
+                     causal=True,
+                 )
+             )
+             self.norm_layers_2.append(LayerNorm(hidden_channels))
+
+     def forward(self, x, x_mask, h, h_mask):
+         """
+         x: decoder input
+         h: encoder output
+         """
+         self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+             device=x.device, dtype=x.dtype
+         )
+         encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
+         x = x * x_mask
+         for i in range(self.n_layers):
+             y = self.self_attn_layers[i](x, x, self_attn_mask)
+             y = self.drop(y)
+             x = self.norm_layers_0[i](x + y)
+
+             y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
+             y = self.drop(y)
+             x = self.norm_layers_1[i](x + y)
+
+             y = self.ffn_layers[i](x, x_mask)
+             y = self.drop(y)
+             x = self.norm_layers_2[i](x + y)
+         x = x * x_mask
+         return x
+
+
+ class MultiHeadAttention(nn.Module):
+     def __init__(
+         self,
+         channels,
+         out_channels,
+         n_heads,
+         p_dropout=0.0,
+         window_size=None,
+         heads_share=True,
+         block_length=None,
+         proximal_bias=False,
+         proximal_init=False,
+     ):
+         super().__init__()
+         assert channels % n_heads == 0
+
+         self.channels = channels
+         self.out_channels = out_channels
+         self.n_heads = n_heads
+         self.p_dropout = p_dropout
+         self.window_size = window_size
+         self.heads_share = heads_share
+         self.block_length = block_length
+         self.proximal_bias = proximal_bias
+         self.proximal_init = proximal_init
+         self.attn = None
+
+         self.k_channels = channels // n_heads
+         self.conv_q = nn.Conv1d(channels, channels, 1)
+         self.conv_k = nn.Conv1d(channels, channels, 1)
+         self.conv_v = nn.Conv1d(channels, channels, 1)
+         self.conv_o = nn.Conv1d(channels, out_channels, 1)
+         self.drop = nn.Dropout(p_dropout)
+
+         if window_size is not None:
+             n_heads_rel = 1 if heads_share else n_heads
+             rel_stddev = self.k_channels**-0.5
+             self.emb_rel_k = nn.Parameter(
+                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                 * rel_stddev
+             )
+             self.emb_rel_v = nn.Parameter(
+                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
+                 * rel_stddev
+             )
+
+         nn.init.xavier_uniform_(self.conv_q.weight)
+         nn.init.xavier_uniform_(self.conv_k.weight)
+         nn.init.xavier_uniform_(self.conv_v.weight)
+         if proximal_init:
+             with torch.no_grad():
+                 self.conv_k.weight.copy_(self.conv_q.weight)
+                 self.conv_k.bias.copy_(self.conv_q.bias)
+
+     def forward(self, x, c, attn_mask=None):
+         q = self.conv_q(x)
+         k = self.conv_k(c)
+         v = self.conv_v(c)
+
+         x, self.attn = self.attention(q, k, v, mask=attn_mask)
+
+         x = self.conv_o(x)
+         return x
+
+     def attention(self, query, key, value, mask=None):
+         # reshape [b, d, t] -> [b, n_h, t, d_k]
+         b, d, t_s, t_t = (*key.size(), query.size(2))
+         query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
+         key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+         value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
+
+         scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
+         if self.window_size is not None:
+             assert (
+                 t_s == t_t
+             ), "Relative attention is only available for self-attention."
+             key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
+             rel_logits = self._matmul_with_relative_keys(
+                 query / math.sqrt(self.k_channels), key_relative_embeddings
+             )
+             scores_local = self._relative_position_to_absolute_position(rel_logits)
+             scores = scores + scores_local
+         if self.proximal_bias:
+             assert t_s == t_t, "Proximal bias is only available for self-attention."
+             scores = scores + self._attention_bias_proximal(t_s).to(
+                 device=scores.device, dtype=scores.dtype
+             )
+         if mask is not None:
+             scores = scores.masked_fill(mask == 0, -1e4)
+             if self.block_length is not None:
+                 assert (
+                     t_s == t_t
+                 ), "Local attention is only available for self-attention."
+                 block_mask = (
+                     torch.ones_like(scores)
+                     .triu(-self.block_length)
+                     .tril(self.block_length)
+                 )
+                 scores = scores.masked_fill(block_mask == 0, -1e4)
+         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
+         p_attn = self.drop(p_attn)
+         output = torch.matmul(p_attn, value)
+         if self.window_size is not None:
+             relative_weights = self._absolute_position_to_relative_position(p_attn)
+             value_relative_embeddings = self._get_relative_embeddings(
+                 self.emb_rel_v, t_s
+             )
+             output = output + self._matmul_with_relative_values(
+                 relative_weights, value_relative_embeddings
+             )
+         output = (
+             output.transpose(2, 3).contiguous().view(b, d, t_t)
+         )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
+         return output, p_attn
+
+     def _matmul_with_relative_values(self, x, y):
+         """
+         x: [b, h, l, m]
+         y: [h or 1, m, d]
+         ret: [b, h, l, d]
+         """
+         ret = torch.matmul(x, y.unsqueeze(0))
+         return ret
+
+     def _matmul_with_relative_keys(self, x, y):
+         """
+         x: [b, h, l, d]
+         y: [h or 1, m, d]
+         ret: [b, h, l, m]
+         """
+         ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
+         return ret
+
+     def _get_relative_embeddings(self, relative_embeddings, length):
+         max_relative_position = 2 * self.window_size + 1
+         # Pad first before slice to avoid using cond ops.
+         pad_length = max(length - (self.window_size + 1), 0)
+         slice_start_position = max((self.window_size + 1) - length, 0)
+         slice_end_position = slice_start_position + 2 * length - 1
+         if pad_length > 0:
+             padded_relative_embeddings = F.pad(
+                 relative_embeddings,
+                 commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
+             )
+         else:
+             padded_relative_embeddings = relative_embeddings
+         used_relative_embeddings = padded_relative_embeddings[
+             :, slice_start_position:slice_end_position
+         ]
+         return used_relative_embeddings
+
+     def _relative_position_to_absolute_position(self, x):
+         """
+         x: [b, h, l, 2*l-1]
+         ret: [b, h, l, l]
+         """
+         batch, heads, length, _ = x.size()
+         # Concat columns of pad to shift from relative to absolute indexing.
+         x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
+
+         # Pad extra elements so the flat tensor reshapes to (len+1, 2*len-1).
+         x_flat = x.view([batch, heads, length * 2 * length])
+         x_flat = F.pad(
+             x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
+         )
+
+         # Reshape and slice out the padded elements.
+         x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
+             :, :, :length, length - 1 :
+         ]
+         return x_final
+
+     def _absolute_position_to_relative_position(self, x):
+         """
+         x: [b, h, l, l]
+         ret: [b, h, l, 2*l-1]
+         """
+         batch, heads, length, _ = x.size()
+         # pad along the column dimension
+         x = F.pad(
+             x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
+         )
+         x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
+         # add 0's in the beginning that will skew the elements after reshape
+         x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
+         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
+         return x_final
+
+     def _attention_bias_proximal(self, length):
+         """Bias for self-attention to encourage attention to close positions.
+         Args:
+           length: an integer scalar.
+         Returns:
+           a Tensor with shape [1, 1, length, length]
+         """
+         r = torch.arange(length, dtype=torch.float32)
+         diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
+         return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
+
+
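The pad-and-reshape trick in `_relative_position_to_absolute_position` is easier to follow on a tiny example. The sketch below replays the same steps on plain Python lists for a single head, with no batch dimension. The helper name `rel_to_abs` is an illustration only and does not exist in the repository. After the skew, entry `[i][j]` of the result holds the relative logit at offset `j - i`, stored in column `j - i + l - 1` of the input.

```python
def rel_to_abs(rel):
    """Convert l rows of 2*l-1 relative logits into an l x l absolute matrix."""
    l = len(rel)
    flat = []
    for row in rel:
        assert len(row) == 2 * l - 1
        # Step 1: pad one zero column on the right (row width becomes 2*l).
        flat.extend(list(row) + [0.0])
    # Step 2: pad l-1 trailing zeros so the buffer reshapes to (l+1, 2*l-1).
    flat.extend([0.0] * (l - 1))
    width = 2 * l - 1
    rows = [flat[r * width:(r + 1) * width] for r in range(l + 1)]
    # Step 3: keep the first l rows and the last l columns.
    return [row[l - 1:] for row in rows[:l]]


# Encode (row, column) into each value so the shift is visible.
rel = [[10 * i + k for k in range(5)] for i in range(3)]
print(rel_to_abs(rel))  # [[2, 3, 4], [11, 12, 13], [20, 21, 22]]
```

Each row is shifted one column further than the row above it. That per-row shift is what turns a "relative offset" axis into an "absolute key position" axis without any explicit gather.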
+ class FFN(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         filter_channels,
+         kernel_size,
+         p_dropout=0.0,
+         activation=None,
+         causal=False,
+     ):
+         super().__init__()
+         self.in_channels = in_channels
+         self.out_channels = out_channels
+         self.filter_channels = filter_channels
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.activation = activation
+         self.causal = causal
+
+         if causal:
+             self.padding = self._causal_padding
+         else:
+             self.padding = self._same_padding
+
+         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
+         self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
+         self.drop = nn.Dropout(p_dropout)
+
+     def forward(self, x, x_mask):
+         x = self.conv_1(self.padding(x * x_mask))
+         if self.activation == "gelu":
+             x = x * torch.sigmoid(1.702 * x)
+         else:
+             x = torch.relu(x)
+         x = self.drop(x)
+         x = self.conv_2(self.padding(x * x_mask))
+         return x * x_mask
+
+     def _causal_padding(self, x):
+         if self.kernel_size == 1:
+             return x
+         pad_l = self.kernel_size - 1
+         pad_r = 0
+         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+         x = F.pad(x, commons.convert_pad_shape(padding))
+         return x
+
+     def _same_padding(self, x):
+         if self.kernel_size == 1:
+             return x
+         pad_l = (self.kernel_size - 1) // 2
+         pad_r = self.kernel_size // 2
+         padding = [[0, 0], [0, 0], [pad_l, pad_r]]
+         x = F.pad(x, commons.convert_pad_shape(padding))
+         return x
+
+
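The two padding modes in `FFN` differ only in where the `kernel_size - 1` padding samples go: all on the left for the causal case, split left and right otherwise. The sketch below reproduces that arithmetic without PyTorch and checks that a stride-1 convolution keeps the sequence length. `conv_padding` and `conv_output_length` are illustrative names, not functions from the repository.

```python
def conv_padding(kernel_size: int, causal: bool) -> tuple:
    """Return (pad_left, pad_right), mirroring FFN._causal_padding / _same_padding."""
    if kernel_size == 1:
        return (0, 0)
    if causal:
        # All padding on the left: output step t only sees inputs at steps <= t.
        return (kernel_size - 1, 0)
    # Split the padding; for even kernels the extra sample goes on the right.
    return ((kernel_size - 1) // 2, kernel_size // 2)


def conv_output_length(length: int, kernel_size: int, causal: bool) -> int:
    """Sequence length after padding and an unpadded stride-1 convolution."""
    pad_l, pad_r = conv_padding(kernel_size, causal)
    return length + pad_l + pad_r - (kernel_size - 1)


for k in (1, 3, 4):
    print(k, conv_padding(k, causal=False), conv_padding(k, causal=True))
```

With the `kernel_size` of 3 from `config.json`, same padding gives `(1, 1)` and causal padding gives `(2, 0)`. Both leave the time axis unchanged, which is why the residual additions in the encoder line up.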
+ class Depthwise_Separable_Conv1D(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride=1,
+         padding=0,
+         dilation=1,
+         bias=True,
+         padding_mode="zeros",  # TODO: refine this type
+         device=None,
+         dtype=None,
+     ):
+         super().__init__()
+         self.depth_conv = nn.Conv1d(
+             in_channels=in_channels,
+             out_channels=in_channels,
+             kernel_size=kernel_size,
+             groups=in_channels,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             bias=bias,
+             padding_mode=padding_mode,
+             device=device,
+             dtype=dtype,
+         )
+         self.point_conv = nn.Conv1d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=1,
+             bias=bias,
+             device=device,
+             dtype=dtype,
+         )
+
+     def forward(self, input):
+         return self.point_conv(self.depth_conv(input))
+
+     def weight_norm(self):
+         self.depth_conv = weight_norm(self.depth_conv, name="weight")
+         self.point_conv = weight_norm(self.point_conv, name="weight")
+
+     def remove_weight_norm(self):
+         self.depth_conv = remove_weight_norm(self.depth_conv, name="weight")
+         self.point_conv = remove_weight_norm(self.point_conv, name="weight")
+
+
+ class Depthwise_Separable_TransposeConv1D(nn.Module):
+     def __init__(
+         self,
+         in_channels,
+         out_channels,
+         kernel_size,
+         stride=1,
+         padding=0,
+         output_padding=0,
+         bias=True,
+         dilation=1,
+         padding_mode="zeros",  # TODO: refine this type
+         device=None,
+         dtype=None,
+     ):
+         super().__init__()
+         self.depth_conv = nn.ConvTranspose1d(
+             in_channels=in_channels,
+             out_channels=in_channels,
+             kernel_size=kernel_size,
+             groups=in_channels,
+             stride=stride,
+             output_padding=output_padding,
+             padding=padding,
+             dilation=dilation,
+             bias=bias,
+             padding_mode=padding_mode,
+             device=device,
+             dtype=dtype,
+         )
+         self.point_conv = nn.Conv1d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=1,
+             bias=bias,
+             device=device,
+             dtype=dtype,
+         )
+
+     def forward(self, input):
+         return self.point_conv(self.depth_conv(input))
+
+     def weight_norm(self):
+         self.depth_conv = weight_norm(self.depth_conv, name="weight")
+         self.point_conv = weight_norm(self.point_conv, name="weight")
+
+     def remove_weight_norm(self):
+         remove_weight_norm(self.depth_conv, name="weight")
+         remove_weight_norm(self.point_conv, name="weight")
+
+
+ def weight_norm_modules(module, name="weight", dim=0):
+     if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
+         module, Depthwise_Separable_TransposeConv1D
+     ):
+         module.weight_norm()
+         return module
+     else:
+         return weight_norm(module, name, dim)
+
+
+ def remove_weight_norm_modules(module, name="weight"):
+     if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
+         module, Depthwise_Separable_TransposeConv1D
+     ):
+         module.remove_weight_norm()
+     else:
+         remove_weight_norm(module, name)
+
+
+ class FFT(nn.Module):
+     def __init__(
+         self,
+         hidden_channels,
+         filter_channels,
+         n_heads,
+         n_layers=1,
+         kernel_size=1,
+         p_dropout=0.0,
+         proximal_bias=False,
+         proximal_init=True,
+         isflow=False,
+         **kwargs
+     ):
+         super().__init__()
+         self.hidden_channels = hidden_channels
+         self.filter_channels = filter_channels
+         self.n_heads = n_heads
+         self.n_layers = n_layers
+         self.kernel_size = kernel_size
+         self.p_dropout = p_dropout
+         self.proximal_bias = proximal_bias
+         self.proximal_init = proximal_init
+         if isflow and "gin_channels" in kwargs and kwargs["gin_channels"] > 0:
+             cond_layer = torch.nn.Conv1d(
+                 kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
+             )
+             self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
+             self.cond_layer = weight_norm_modules(cond_layer, name="weight")
+             self.gin_channels = kwargs["gin_channels"]
+         self.drop = nn.Dropout(p_dropout)
+         self.self_attn_layers = nn.ModuleList()
+         self.norm_layers_0 = nn.ModuleList()
+         self.ffn_layers = nn.ModuleList()
+         self.norm_layers_1 = nn.ModuleList()
+         for i in range(self.n_layers):
+             self.self_attn_layers.append(
+                 MultiHeadAttention(
+                     hidden_channels,
+                     hidden_channels,
+                     n_heads,
+                     p_dropout=p_dropout,
+                     proximal_bias=proximal_bias,
+                     proximal_init=proximal_init,
+                 )
+             )
+             self.norm_layers_0.append(LayerNorm(hidden_channels))
+             self.ffn_layers.append(
+                 FFN(
+                     hidden_channels,
+                     hidden_channels,
+                     filter_channels,
+                     kernel_size,
+                     p_dropout=p_dropout,
+                     causal=True,
+                 )
+             )
+             self.norm_layers_1.append(LayerNorm(hidden_channels))
+
+     def forward(self, x, x_mask, g=None):
+         """
+         x: input sequence
+         g: optional conditioning input, projected by cond_layer
+         """
+         if g is not None:
+             g = self.cond_layer(g)
+
+         self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
+             device=x.device, dtype=x.dtype
+         )
+         x = x * x_mask
+         for i in range(self.n_layers):
+             if g is not None:
+                 x = self.cond_pre(x)
+                 cond_offset = i * 2 * self.hidden_channels
+                 g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
+                 x = commons.fused_add_tanh_sigmoid_multiply(
+                     x, g_l, torch.IntTensor([self.hidden_channels])
+                 )
+             y = self.self_attn_layers[i](x, x, self_attn_mask)
+             y = self.drop(y)
+             x = self.norm_layers_0[i](x + y)
+
+             y = self.ffn_layers[i](x, x_mask)
+             y = self.drop(y)
+             x = self.norm_layers_1[i](x + y)
+         x = x * x_mask
+         return x
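When conditioning is present, `FFT.forward` widens `x` to `2 * hidden_channels` with `cond_pre` and then calls `commons.fused_add_tanh_sigmoid_multiply`, which halves the channel count again. The following sketch shows that gated activation for a single time step on plain Python lists. The name `gated_activation` is illustrative and not part of the repository.

```python
import math


def gated_activation(x, g, n_channels):
    """tanh(first half) * sigmoid(second half) of (x + g).

    x, g: lists of 2 * n_channels values for one time step.
    Returns a list of n_channels values.
    """
    assert len(x) == len(g) == 2 * n_channels
    summed = [a + b for a, b in zip(x, g)]
    tanh_half = [math.tanh(v) for v in summed[:n_channels]]
    sigmoid_half = [1.0 / (1.0 + math.exp(-v)) for v in summed[n_channels:]]
    return [t * s for t, s in zip(tanh_half, sigmoid_half)]


# Two output channels from four input channels.
print(gated_activation([1.0, -1.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], 2))
```

The sigmoid half acts as a gate in (0, 1) on the tanh half, so the conditioning signal can scale features up or down per channel. This is why the conditioning convolution produces `2 * hidden_channels` values per layer.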
commons.py ADDED
@@ -0,0 +1,163 @@
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+
8
+ def init_weights(m, mean=0.0, std=0.01):
9
+ classname = m.__class__.__name__
10
+ if classname.find("Conv") != -1:
11
+ m.weight.data.normal_(mean, std)
12
+
13
+
14
+ def get_padding(kernel_size, dilation=1):
15
+ return int((kernel_size * dilation - dilation) / 2)
16
+
17
+
18
+ def convert_pad_shape(pad_shape):
19
+ l = pad_shape[::-1]
20
+ pad_shape = [item for sublist in l for item in sublist]
21
+ return pad_shape
22
+
23
+
24
+ def intersperse(lst, item):
25
+ result = [item] * (len(lst) * 2 + 1)
26
+ result[1::2] = lst
27
+ return result
28
+
29
+
30
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
31
+ """KL(P||Q)"""
32
+ kl = (logs_q - logs_p) - 0.5
33
+ kl += (
34
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
35
+ )
36
+ return kl
37
+
38
+
39
+ def rand_gumbel(shape):
40
+ """Sample from the Gumbel distribution, protect from overflows."""
41
+ uniform_samples = torch.rand(shape) * 0.99998 + 0.00001
42
+ return -torch.log(-torch.log(uniform_samples))
43
+
44
+
45
+ def rand_gumbel_like(x):
46
+ g = rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
47
+ return g
48
+
49
+
50
+ def slice_segments(x, ids_str, segment_size=4):
51
+ ret = torch.zeros_like(x[:, :, :segment_size])
52
+ for i in range(x.size(0)):
53
+ idx_str = ids_str[i]
54
+ idx_end = idx_str + segment_size
55
+ ret[i] = x[i, :, idx_str:idx_end]
56
+ return ret
57
+
58
+
59
+ def rand_slice_segments(x, x_lengths=None, segment_size=4):
60
+ b, d, t = x.size()
61
+ if x_lengths is None:
62
+ x_lengths = t
63
+ ids_str_max = x_lengths - segment_size + 1
64
+ ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
65
+ ret = slice_segments(x, ids_str, segment_size)
66
+ return ret, ids_str
67
+
68
+
69
+ def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
70
+ position = torch.arange(length, dtype=torch.float)
71
+ num_timescales = channels // 2
72
+ log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
73
+ num_timescales - 1
74
+ )
75
+ inv_timescales = min_timescale * torch.exp(
76
+ torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
77
+ )
78
+ scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
79
+ signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
80
+ signal = F.pad(signal, [0, 0, 0, channels % 2])
81
+ signal = signal.view(1, channels, length)
82
+ return signal
83
+
84
+
85
+ def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
86
+ b, channels, length = x.size()
87
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
88
+ return x + signal.to(dtype=x.dtype, device=x.device)
89
+
90
+
91
+ def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
92
+ b, channels, length = x.size()
93
+ signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
94
+ return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
95
+
96
+
97
+ def subsequent_mask(length):
98
+ mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
99
+ return mask
100
+
101
+
102
+ @torch.jit.script
103
+ def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
104
+ n_channels_int = n_channels[0]
105
+ in_act = input_a + input_b
106
+ t_act = torch.tanh(in_act[:, :n_channels_int, :])
107
+ s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
108
+ acts = t_act * s_act
109
+ return acts
110
+
111
+
112
+ def convert_pad_shape(pad_shape):
113
+ l = pad_shape[::-1]
114
+ pad_shape = [item for sublist in l for item in sublist]
115
+ return pad_shape
116
+
117
+
118
+ def shift_1d(x):
119
+ x = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))[:, :, :-1]
120
+ return x
121
+
122
+
123
+ def sequence_mask(length, max_length=None):
124
+ if max_length is None:
125
+ max_length = length.max()
126
+ x = torch.arange(max_length, dtype=length.dtype, device=length.device)
127
+ return x.unsqueeze(0) < length.unsqueeze(1)
128
+
129
+
130
+ def generate_path(duration, mask):
131
+ """
132
+ duration: [b, 1, t_x]
133
+ mask: [b, 1, t_y, t_x]
134
+ """
135
+ device = duration.device
136
+
137
+ b, _, t_y, t_x = mask.shape
138
+ cum_duration = torch.cumsum(duration, -1)
139
+
140
+ cum_duration_flat = cum_duration.view(b * t_x)
141
+ path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
142
+ path = path.view(b, t_x, t_y)
143
+ path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
144
+ path = path.unsqueeze(1).transpose(2, 3) * mask
145
+ return path
146
+
147
+
148
+ def clip_grad_value_(parameters, clip_value, norm_type=2):
149
+ if isinstance(parameters, torch.Tensor):
150
+ parameters = [parameters]
151
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
152
+ norm_type = float(norm_type)
153
+ if clip_value is not None:
154
+ clip_value = float(clip_value)
155
+
156
+ total_norm = 0
157
+ for p in parameters:
158
+ param_norm = p.grad.data.norm(norm_type)
159
+ total_norm += param_norm.item() ** norm_type
160
+ if clip_value is not None:
161
+ p.grad.data.clamp_(min=-clip_value, max=clip_value)
162
+ total_norm = total_norm ** (1.0 / norm_type)
163
+ return total_norm
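
The shape bookkeeping in `generate_path` above can be hard to follow, so here is a minimal pure-Python sketch of the same idea for a single utterance. It assumes integer durations per token. The helper name `durations_to_path` is hypothetical and is not part of commons.py. It mirrors the cumulative-sum, mask, and shifted-subtraction steps using plain lists instead of tensors.

```python
from itertools import accumulate


def durations_to_path(durations, t_y):
    """Build a t_y x t_x 0/1 matrix where path[y][x] == 1 if frame y belongs to token x.

    Hypothetical helper for illustration only; it is a list-based analogue of
    generate_path for one utterance, not the repository's implementation.
    """
    cum = list(accumulate(durations))  # end frame (exclusive) of every token
    t_x = len(durations)
    path = [[0] * t_x for _ in range(t_y)]
    for x, end in enumerate(cum):
        start = cum[x - 1] if x > 0 else 0  # previous token's end is this token's start
        for y in range(start, min(end, t_y)):
            path[y][x] = 1
    return path


# Three tokens lasting 2, 1 and 3 frames give a 6-frame monotonic alignment.
path = durations_to_path([2, 1, 3], t_y=6)
for row in path:
    print(row)
```

Each output frame is assigned to exactly one token, and the token index never decreases over time, which is the monotonic property the alignment relies on.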
config.json ADDED
@@ -0,0 +1,62 @@
+ {
+   "train": {
+     "log_interval": 200,
+     "eval_interval": 1000,
+     "seed": 1234,
+     "epochs": 150,
+     "learning_rate": 2e-4,
+     "betas": [0.8, 0.99],
+     "eps": 1e-9,
+     "batch_size": 32,
+     "fp16_run": true,
+     "lr_decay": 0.999875,
+     "segment_size": 8192,
+     "init_lr_ratio": 1,
+     "warmup_epochs": 0,
+     "c_mel": 45,
+     "c_kl": 1.0
+   },
+   "data": {
+     "use_mel_posterior_encoder": true,
+     "training_files": "/home/lod/liesmaschinn_v2/vits2_remy/filelists/remy_train.txt",
+     "validation_files": "/home/lod/liesmaschinn_v2/vits2_remy/filelists/remy_val.txt",
+     "text_cleaners": ["luxembourgish_cleaners"],
+     "max_wav_value": 32768.0,
+     "sampling_rate": 24000,
+     "filter_length": 1024,
+     "hop_length": 256,
+     "win_length": 1024,
+     "n_mel_channels": 80,
+     "mel_fmin": 0.0,
+     "mel_fmax": null,
+     "add_blank": true,
+     "n_speakers": 0,
+     "cleaned_text": false
+   },
+   "model": {
+     "use_mel_posterior_encoder": true,
+     "use_transformer_flows": true,
+     "transformer_flow_type": "pre_conv",
+     "use_spk_conditioned_encoder": false,
+     "use_noise_scaled_mas": true,
+     "use_sdp": false,
+     "use_duration_discriminator": true,
+     "duration_discriminator_type": "dur_disc_2",
+     "inter_channels": 192,
+     "hidden_channels": 192,
+     "filter_channels": 768,
+     "n_heads": 2,
+     "n_layers": 6,
+     "kernel_size": 3,
+     "p_dropout": 0.5,
+     "resblock": "1",
+     "resblock_kernel_sizes": [3, 7, 11],
+     "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
+     "upsample_rates": [8, 8, 2, 2],
+     "upsample_initial_channel": 512,
+     "upsample_kernel_sizes": [16, 16, 4, 4],
+     "n_layers_q": 3,
+     "use_spectral_norm": false
+   },
+   "model_dir": "./logs/luxembourgish_remy"
+ }
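
Several values in config.json are tied to each other, and a quick arithmetic check makes the relationships visible. The sketch below hard-codes the numbers from the config rather than loading the file. The variable names are illustrative, and the reading that the decoder's upsampling factors should multiply out to `hop_length` is an assumption based on how HiFi-GAN-style decoders are usually configured.

```python
import math

# Values copied from config.json above.
sampling_rate = 24000
hop_length = 256
segment_size = 8192
upsample_rates = [8, 8, 2, 2]

# Number of spectrogram frames in one training segment
# (inference.py passes segment_size // hop_length to the model).
frames_per_segment = segment_size // hop_length

# Assumption: each latent frame is upsampled to hop_length audio samples,
# so the product of the upsample rates should equal hop_length.
total_upsampling = math.prod(upsample_rates)

frame_rate_hz = sampling_rate / hop_length
segment_seconds = segment_size / sampling_rate

print(frames_per_segment, total_upsampling, frame_rate_hz, round(segment_seconds, 3))
```

If `hop_length` or `upsample_rates` is changed, the other would likely need to change with it so that the decoder still produces one hop of audio per frame.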
inference.py ADDED
@@ -0,0 +1,82 @@
+ #!/usr/bin/env python3
+ """
+ VITS2 Remy - Luxembourgish TTS Inference Script
+
+ Usage:
+     python inference.py "Moien, wéi geet et dir?"
+     python inference.py "Moien, wéi geet et dir?" -o output.wav
+     python inference.py "Moien, wéi geet et dir?" --noise_scale 0.5
+ """
+ import argparse
+ import torch
+ import scipy.io.wavfile as wavfile
+
+ import utils
+ import commons
+ from models import SynthesizerTrn
+ from text.symbols import symbols
+ from text import text_to_sequence
+
+
+ def get_text(text, hps):
+     # Convert raw text to a sequence of symbol IDs using the configured cleaners.
+     text_norm = text_to_sequence(text, hps.data.text_cleaners)
+     if hps.data.add_blank:
+         # Insert the blank token (ID 0) between symbols, as during training.
+         text_norm = commons.intersperse(text_norm, 0)
+     text_norm = torch.LongTensor(text_norm)
+     return text_norm
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="VITS2 Remy TTS")
+     parser.add_argument("text", type=str, help="Text to synthesize")
+     parser.add_argument("-o", "--output", type=str, default="output.wav", help="Output WAV file")
+     parser.add_argument("--noise_scale", type=float, default=0.667, help="Noise scale (default: 0.667)")
+     parser.add_argument("--noise_scale_w", type=float, default=0.8, help="Noise scale W (default: 0.8)")
+     parser.add_argument("--length_scale", type=float, default=1.0, help="Length scale (default: 1.0)")
+     parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
+     args = parser.parse_args()
+
+     # Fall back to CPU when --cpu is given or no CUDA device is available.
+     device = "cpu" if args.cpu or not torch.cuda.is_available() else "cuda"
+
+     # Load config
+     hps = utils.get_hparams_from_file("config.json")
+
+     # Load model
+     if getattr(hps.model, 'use_mel_posterior_encoder', False):
+         posterior_channels = hps.data.n_mel_channels
+     else:
+         posterior_channels = hps.data.filter_length // 2 + 1
+
+     net_g = SynthesizerTrn(
+         len(symbols),
+         posterior_channels,
+         hps.train.segment_size // hps.data.hop_length,
+         n_speakers=hps.data.n_speakers,
+         **hps.model
+     ).to(device)
+
+     _ = utils.load_checkpoint("model.pth", net_g, None)
+     net_g.eval()
+
+     # Synthesize
+     text = args.text.lower()
+     print(f"Synthesizing: {text}")
+
+     with torch.no_grad():
+         stn_tst = get_text(text, hps)
+         x_tst = stn_tst.to(device).unsqueeze(0)
+         x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
+
+         audio = net_g.infer(
+             x_tst, x_tst_lengths,
+             noise_scale=args.noise_scale,
+             noise_scale_w=args.noise_scale_w,
+             length_scale=args.length_scale
+         )[0][0, 0].data.cpu().float().numpy()
+
+     wavfile.write(args.output, hps.data.sampling_rate, audio)
+     print(f"Saved to: {args.output}")
+
+
+ if __name__ == "__main__":
+     main()
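
Because config.json sets `add_blank` to true, `get_text` in inference.py passes the symbol IDs through `commons.intersperse` before building the tensor. That helper's body is not shown in this part of the diff, so the sketch below is an assumption: it follows the behaviour of the upstream VITS helper of the same name, which places a blank ID before, between, and after the input IDs.

```python
def intersperse(lst, item):
    """Return lst with `item` inserted before, between and after its elements.

    Assumed to match commons.intersperse from upstream VITS; the version
    shipped in this repository may differ.
    """
    result = [item] * (len(lst) * 2 + 1)  # start with blanks everywhere
    result[1::2] = lst                    # write the real symbols into the odd slots
    return result


# Hypothetical symbol IDs for a three-symbol input, with 0 as the blank ID.
with_blanks = intersperse([5, 7, 9], 0)
print(with_blanks)
```

A sequence of n symbols therefore becomes 2n + 1 IDs, which is the length that ends up in `x_tst_lengths`.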
mel_processing.py ADDED
@@ -0,0 +1,187 @@
1
+ import warnings
2
+
3
+ # warnings.simplefilter(action='ignore', category=FutureWarning)
4
+ warnings.filterwarnings(action="ignore")
5
+
6
+ import math
7
+ import os
8
+ import random
9
+
10
+ import librosa
11
+ import librosa.util as librosa_util
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as F
15
+ import torch.utils.data
16
+ from librosa.filters import mel as librosa_mel_fn
17
+ from librosa.util import normalize, pad_center, tiny
18
+ from packaging import version
19
+ from scipy.io.wavfile import read
20
+ from scipy.signal import get_window
21
+ from torch import nn
22
+
23
+ MAX_WAV_VALUE = 32768.0
24
+
25
+
26
+ def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
27
+ """
28
+ PARAMS
29
+ ------
30
+ C: compression factor
31
+ """
32
+ return torch.log(torch.clamp(x, min=clip_val) * C)
33
+
34
+
35
+ def dynamic_range_decompression_torch(x, C=1):
36
+ """
37
+ PARAMS
38
+ ------
39
+ C: compression factor used to compress
40
+ """
41
+ return torch.exp(x) / C
42
+
43
+
44
+ def spectral_normalize_torch(magnitudes):
45
+ output = dynamic_range_compression_torch(magnitudes)
46
+ return output
47
+
48
+
49
+ def spectral_de_normalize_torch(magnitudes):
50
+ output = dynamic_range_decompression_torch(magnitudes)
51
+ return output
52
+
53
+
54
+ mel_basis = {}
55
+ hann_window = {}
56
+
57
+
58
+ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
59
+ if torch.min(y) < -1.0:
60
+ print("min value is ", torch.min(y))
61
+ if torch.max(y) > 1.0:
62
+ print("max value is ", torch.max(y))
63
+
64
+ global hann_window
65
+ dtype_device = str(y.dtype) + "_" + str(y.device)
66
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
67
+ if wnsize_dtype_device not in hann_window:
68
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
69
+ dtype=y.dtype, device=y.device
70
+ )
71
+
72
+ y = torch.nn.functional.pad(
73
+ y.unsqueeze(1),
74
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
75
+ mode="reflect",
76
+ )
77
+ y = y.squeeze(1)
78
+
79
+ if version.parse(torch.__version__) >= version.parse("2"):
80
+ spec = torch.stft(
81
+ y,
82
+ n_fft,
83
+ hop_length=hop_size,
84
+ win_length=win_size,
85
+ window=hann_window[wnsize_dtype_device],
86
+ center=center,
87
+ pad_mode="reflect",
88
+ normalized=False,
89
+ onesided=True,
90
+ return_complex=False,
91
+ )
92
+ else:
93
+ spec = torch.stft(
94
+ y,
95
+ n_fft,
96
+ hop_length=hop_size,
97
+ win_length=win_size,
98
+ window=hann_window[wnsize_dtype_device],
99
+ center=center,
100
+ pad_mode="reflect",
101
+ normalized=False,
102
+ onesided=True,
103
+ )
104
+
105
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
106
+ return spec
107
+
108
+
109
+ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
110
+ global mel_basis
111
+ dtype_device = str(spec.dtype) + "_" + str(spec.device)
112
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
113
+ if fmax_dtype_device not in mel_basis:
114
+ mel = librosa_mel_fn(
115
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
116
+ )
117
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
118
+ dtype=spec.dtype, device=spec.device
119
+ )
120
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
121
+ spec = spectral_normalize_torch(spec)
122
+ return spec
123
+
124
+
125
+ def mel_spectrogram_torch(
126
+ y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
127
+ ):
128
+ if torch.min(y) < -1.0:
129
+ print("min value is ", torch.min(y))
130
+ if torch.max(y) > 1.0:
131
+ print("max value is ", torch.max(y))
132
+
133
+ global mel_basis, hann_window
134
+ dtype_device = str(y.dtype) + "_" + str(y.device)
135
+ fmax_dtype_device = str(fmax) + "_" + dtype_device
136
+ wnsize_dtype_device = str(win_size) + "_" + dtype_device
137
+ if fmax_dtype_device not in mel_basis:
138
+ mel = librosa_mel_fn(
139
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
140
+ )
141
+ mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
142
+ dtype=y.dtype, device=y.device
143
+ )
144
+ if wnsize_dtype_device not in hann_window:
145
+ hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
146
+ dtype=y.dtype, device=y.device
147
+ )
148
+
149
+ y = torch.nn.functional.pad(
150
+ y.unsqueeze(1),
151
+ (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
152
+ mode="reflect",
153
+ )
154
+ y = y.squeeze(1)
155
+
156
+ if version.parse(torch.__version__) >= version.parse("2"):
157
+ spec = torch.stft(
158
+ y,
159
+ n_fft,
160
+ hop_length=hop_size,
161
+ win_length=win_size,
162
+ window=hann_window[wnsize_dtype_device],
163
+ center=center,
164
+ pad_mode="reflect",
165
+ normalized=False,
166
+ onesided=True,
167
+ return_complex=False,
168
+ )
169
+ else:
170
+ spec = torch.stft(
171
+ y,
172
+ n_fft,
173
+ hop_length=hop_size,
174
+ win_length=win_size,
175
+ window=hann_window[wnsize_dtype_device],
176
+ center=center,
177
+ pad_mode="reflect",
178
+ normalized=False,
179
+ onesided=True,
180
+ )
181
+
182
+ spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
183
+
184
+ spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
185
+ spec = spectral_normalize_torch(spec)
186
+
187
+ return spec
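
Both `spectrogram_torch` and `mel_spectrogram_torch` reflect-pad the waveform by `(n_fft - hop_size) / 2` samples on each side and then call `torch.stft` with `center=False`. The sketch below works out how many frames that produces, using the usual uncentered STFT frame count of `1 + (length - n_fft) // hop`. The function name `num_frames` is hypothetical and is not part of mel_processing.py.

```python
def num_frames(num_samples, n_fft, hop_size):
    """Frame count after the manual reflect padding used above, with center=False.

    Hypothetical helper for illustration; it only reproduces the arithmetic.
    """
    pad = int((n_fft - hop_size) / 2)      # padding added to each side
    padded = num_samples + 2 * pad         # waveform length after padding
    return 1 + (padded - n_fft) // hop_size


# With the values from config.json (filter_length 1024, hop_length 256):
print(num_frames(8192, 1024, 256))   # one training segment
print(num_frames(24000, 1024, 256))  # one second of 24 kHz audio
```

With this padding scheme the result works out to `num_samples // hop_size`, which is why a training segment of 8192 samples lines up with `segment_size // hop_length` frames.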
models.py ADDED
@@ -0,0 +1,1363 @@
1
+ import copy
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
7
+ from torch.nn import functional as F
8
+ from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
9
+
10
+ import attentions
11
+ import commons
12
+ import modules
13
+ import monotonic_align
14
+ from commons import get_padding, init_weights
15
+
16
+ AVAILABLE_FLOW_TYPES = [
17
+ "pre_conv",
18
+ "pre_conv2",
19
+ "fft",
20
+ "mono_layer_inter_residual",
21
+ "mono_layer_post_residual",
22
+ ]
23
+
24
+ AVAILABLE_DURATION_DISCRIMINATOR_TYPES = [
25
+ "dur_disc_1",
26
+ "dur_disc_2",
27
+ ]
28
+
29
+
30
+ class StochasticDurationPredictor(nn.Module):
31
+ def __init__(
32
+ self,
33
+ in_channels,
34
+ filter_channels,
35
+ kernel_size,
36
+ p_dropout,
37
+ n_flows=4,
38
+ gin_channels=0,
39
+ ):
40
+ super().__init__()
41
+ filter_channels = in_channels # TODO: this override should be removed in a future version.
42
+ self.in_channels = in_channels
43
+ self.filter_channels = filter_channels
44
+ self.kernel_size = kernel_size
45
+ self.p_dropout = p_dropout
46
+ self.n_flows = n_flows
47
+ self.gin_channels = gin_channels
48
+
49
+ self.log_flow = modules.Log()
50
+ self.flows = nn.ModuleList()
51
+ self.flows.append(modules.ElementwiseAffine(2))
52
+ for i in range(n_flows):
53
+ self.flows.append(
54
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
55
+ )
56
+ self.flows.append(modules.Flip())
57
+
58
+ self.post_pre = nn.Conv1d(1, filter_channels, 1)
59
+ self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
60
+ self.post_convs = modules.DDSConv(
61
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
62
+ )
63
+ self.post_flows = nn.ModuleList()
64
+ self.post_flows.append(modules.ElementwiseAffine(2))
65
+ for i in range(4):
66
+ self.post_flows.append(
67
+ modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
68
+ )
69
+ self.post_flows.append(modules.Flip())
70
+
71
+ self.pre = nn.Conv1d(in_channels, filter_channels, 1)
72
+ self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
73
+ self.convs = modules.DDSConv(
74
+ filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
75
+ )
76
+ if gin_channels != 0:
77
+ self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
78
+
79
+ def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
80
+ x = torch.detach(x)
81
+ x = self.pre(x)
82
+ if g is not None:
83
+ g = torch.detach(g)
84
+ x = x + self.cond(g)
85
+ x = self.convs(x, x_mask)
86
+ x = self.proj(x) * x_mask
87
+
88
+ if not reverse:
89
+ flows = self.flows
90
+ assert w is not None
91
+
92
+ logdet_tot_q = 0
93
+ h_w = self.post_pre(w)
94
+ h_w = self.post_convs(h_w, x_mask)
95
+ h_w = self.post_proj(h_w) * x_mask
96
+ e_q = (
97
+ torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
98
+ * x_mask
99
+ )
100
+ z_q = e_q
101
+ for flow in self.post_flows:
102
+ z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
103
+ logdet_tot_q += logdet_q
104
+ z_u, z1 = torch.split(z_q, [1, 1], 1)
105
+ u = torch.sigmoid(z_u) * x_mask
106
+ z0 = (w - u) * x_mask
107
+ logdet_tot_q += torch.sum(
108
+ (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
109
+ )
110
+ logq = (
111
+ torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
112
+ - logdet_tot_q
113
+ )
114
+
115
+ logdet_tot = 0
116
+ z0, logdet = self.log_flow(z0, x_mask)
117
+ logdet_tot += logdet
118
+ z = torch.cat([z0, z1], 1)
119
+ for flow in flows:
120
+ z, logdet = flow(z, x_mask, g=x, reverse=reverse)
121
+ logdet_tot = logdet_tot + logdet
122
+ nll = (
123
+ torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
124
+ - logdet_tot
125
+ )
126
+ return nll + logq # [b]
127
+ else:
128
+ flows = list(reversed(self.flows))
129
+ flows = flows[:-2] + [flows[-1]] # remove a useless vflow
130
+ z = (
131
+ torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
132
+ * noise_scale
133
+ )
134
+ for flow in flows:
135
+ z = flow(z, x_mask, g=x, reverse=reverse)
136
+ z0, z1 = torch.split(z, [1, 1], 1)
137
+ logw = z0
138
+ return logw
139
+
140
+
141
+ class DurationPredictor(nn.Module):
142
+ def __init__(
143
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
144
+ ):
145
+ super().__init__()
146
+
147
+ self.in_channels = in_channels
148
+ self.filter_channels = filter_channels
149
+ self.kernel_size = kernel_size
150
+ self.p_dropout = p_dropout
151
+ self.gin_channels = gin_channels
152
+
153
+ self.drop = nn.Dropout(p_dropout)
154
+ self.conv_1 = nn.Conv1d(
155
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
156
+ )
157
+ self.norm_1 = modules.LayerNorm(filter_channels)
158
+ self.conv_2 = nn.Conv1d(
159
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
160
+ )
161
+ self.norm_2 = modules.LayerNorm(filter_channels)
162
+ self.proj = nn.Conv1d(filter_channels, 1, 1)
163
+
164
+ if gin_channels != 0:
165
+ self.cond = nn.Conv1d(gin_channels, in_channels, 1)
166
+
167
+ def forward(self, x, x_mask, g=None):
168
+ x = torch.detach(x)
169
+ if g is not None:
170
+ g = torch.detach(g)
171
+ x = x + self.cond(g)
172
+ x = self.conv_1(x * x_mask)
173
+ x = torch.relu(x)
174
+ x = self.norm_1(x)
175
+ x = self.drop(x)
176
+ x = self.conv_2(x * x_mask)
177
+ x = torch.relu(x)
178
+ x = self.norm_2(x)
179
+ x = self.drop(x)
180
+ x = self.proj(x * x_mask)
181
+ return x * x_mask
182
+
183
+
184
+ class DurationDiscriminatorV1(nn.Module): # vits2
185
+ # TODO : not using "spk conditioning" for now according to the paper.
186
+ # Can be a better discriminator if we use it.
187
+ def __init__(
188
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
189
+ ):
190
+ super().__init__()
191
+
192
+ self.in_channels = in_channels
193
+ self.filter_channels = filter_channels
194
+ self.kernel_size = kernel_size
195
+ self.p_dropout = p_dropout
196
+ self.gin_channels = gin_channels
197
+
198
+ self.drop = nn.Dropout(p_dropout)
199
+ self.conv_1 = nn.Conv1d(
200
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
201
+ )
202
+ # self.norm_1 = modules.LayerNorm(filter_channels)
203
+ self.conv_2 = nn.Conv1d(
204
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
205
+ )
206
+ # self.norm_2 = modules.LayerNorm(filter_channels)
207
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
208
+
209
+ self.pre_out_conv_1 = nn.Conv1d(
210
+ 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
211
+ )
212
+ self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
213
+ self.pre_out_conv_2 = nn.Conv1d(
214
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
215
+ )
216
+ self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
217
+
218
+ # if gin_channels != 0:
219
+ # self.cond = nn.Conv1d(gin_channels, in_channels, 1)
220
+
221
+ self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
222
+
223
+ def forward_probability(self, x, x_mask, dur, g=None):
224
+ dur = self.dur_proj(dur)
225
+ x = torch.cat([x, dur], dim=1)
226
+ x = self.pre_out_conv_1(x * x_mask)
227
+ # x = torch.relu(x)
228
+ # x = self.pre_out_norm_1(x)
229
+ # x = self.drop(x)
230
+ x = self.pre_out_conv_2(x * x_mask)
231
+ # x = torch.relu(x)
232
+ # x = self.pre_out_norm_2(x)
233
+ # x = self.drop(x)
234
+ x = x * x_mask
235
+ x = x.transpose(1, 2)
236
+ output_prob = self.output_layer(x)
237
+ return output_prob
238
+
239
+ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
240
+ x = torch.detach(x)
241
+ # if g is not None:
242
+ # g = torch.detach(g)
243
+ # x = x + self.cond(g)
244
+ x = self.conv_1(x * x_mask)
245
+ # x = torch.relu(x)
246
+ # x = self.norm_1(x)
247
+ # x = self.drop(x)
248
+ x = self.conv_2(x * x_mask)
249
+ # x = torch.relu(x)
250
+ # x = self.norm_2(x)
251
+ # x = self.drop(x)
252
+
253
+ output_probs = []
254
+ for dur in [dur_r, dur_hat]:
255
+ output_prob = self.forward_probability(x, x_mask, dur, g)
256
+ output_probs.append(output_prob)
257
+
258
+ return output_probs
259
+
260
+
261
+ class DurationDiscriminatorV2(nn.Module): # vits2
262
+ # TODO : not using "spk conditioning" for now according to the paper.
263
+ # Can be a better discriminator if we use it.
264
+ def __init__(
265
+ self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
266
+ ):
267
+ super().__init__()
268
+
269
+ self.in_channels = in_channels
270
+ self.filter_channels = filter_channels
271
+ self.kernel_size = kernel_size
272
+ self.p_dropout = p_dropout
273
+ self.gin_channels = gin_channels
274
+
275
+ self.conv_1 = nn.Conv1d(
276
+ in_channels, filter_channels, kernel_size, padding=kernel_size // 2
277
+ )
278
+ self.norm_1 = modules.LayerNorm(filter_channels)
279
+ self.conv_2 = nn.Conv1d(
280
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
281
+ )
282
+ self.norm_2 = modules.LayerNorm(filter_channels)
283
+ self.dur_proj = nn.Conv1d(1, filter_channels, 1)
284
+
285
+ self.pre_out_conv_1 = nn.Conv1d(
286
+ 2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
287
+ )
288
+ self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
289
+ self.pre_out_conv_2 = nn.Conv1d(
290
+ filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
291
+ )
292
+ self.pre_out_norm_2 = modules.LayerNorm(filter_channels)
293
+
294
+ # if gin_channels != 0:
295
+ # self.cond = nn.Conv1d(gin_channels, in_channels, 1)
296
+
297
+ self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())
298
+
299
+ def forward_probability(self, x, x_mask, dur, g=None):
300
+ dur = self.dur_proj(dur)
301
+ x = torch.cat([x, dur], dim=1)
302
+ x = self.pre_out_conv_1(x * x_mask)
303
+ x = torch.relu(x)
304
+ x = self.pre_out_norm_1(x)
305
+ x = self.pre_out_conv_2(x * x_mask)
306
+ x = torch.relu(x)
307
+ x = self.pre_out_norm_2(x)
308
+ x = x * x_mask
309
+ x = x.transpose(1, 2)
310
+ output_prob = self.output_layer(x)
311
+ return output_prob
312
+
313
+ def forward(self, x, x_mask, dur_r, dur_hat, g=None):
314
+ x = torch.detach(x)
315
+ # if g is not None:
316
+ # g = torch.detach(g)
317
+ # x = x + self.cond(g)
318
+ x = self.conv_1(x * x_mask)
319
+ x = torch.relu(x)
320
+ x = self.norm_1(x)
321
+ x = self.conv_2(x * x_mask)
322
+ x = torch.relu(x)
323
+ x = self.norm_2(x)
324
+
325
+ output_probs = []
326
+ for dur in [dur_r, dur_hat]:
327
+ output_prob = self.forward_probability(x, x_mask, dur, g)
328
+ output_probs.append([output_prob])
329
+
330
+ return output_probs
331
+
332
+
333
+ class TextEncoder(nn.Module):
334
+ def __init__(
335
+ self,
336
+ n_vocab,
337
+ out_channels,
338
+ hidden_channels,
339
+ filter_channels,
340
+ n_heads,
341
+ n_layers,
342
+ kernel_size,
343
+ p_dropout,
344
+ gin_channels=0,
345
+ ):
346
+ super().__init__()
347
+ self.n_vocab = n_vocab
348
+ self.out_channels = out_channels
349
+ self.hidden_channels = hidden_channels
350
+ self.filter_channels = filter_channels
351
+ self.n_heads = n_heads
352
+ self.n_layers = n_layers
353
+ self.kernel_size = kernel_size
354
+ self.p_dropout = p_dropout
355
+ self.gin_channels = gin_channels
356
+ self.emb = nn.Embedding(n_vocab, hidden_channels)
357
+ nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
358
+
359
+ self.encoder = attentions.Encoder(
360
+ hidden_channels,
361
+ filter_channels,
362
+ n_heads,
363
+ n_layers,
364
+ kernel_size,
365
+ p_dropout,
366
+ gin_channels=self.gin_channels,
367
+ )
368
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
369
+
370
+ def forward(self, x, x_lengths, g=None):
371
+ x = self.emb(x) * math.sqrt(self.hidden_channels) # [b, t, h]
372
+ x = torch.transpose(x, 1, -1) # [b, h, t]
373
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
374
+ x.dtype
375
+ )
376
+
377
+ x = self.encoder(x * x_mask, x_mask, g=g)
378
+ stats = self.proj(x) * x_mask
379
+
380
+ m, logs = torch.split(stats, self.out_channels, dim=1)
381
+ return x, m, logs, x_mask
382
+
383
+
384
+ class ResidualCouplingTransformersLayer2(nn.Module): # vits2
385
+ def __init__(
386
+ self,
387
+ channels,
388
+ hidden_channels,
389
+ kernel_size,
390
+ dilation_rate,
391
+ n_layers,
392
+ p_dropout=0,
393
+ gin_channels=0,
394
+ mean_only=False,
395
+ ):
396
+ assert channels % 2 == 0, "channels should be divisible by 2"
397
+ super().__init__()
398
+ self.channels = channels
399
+ self.hidden_channels = hidden_channels
400
+ self.kernel_size = kernel_size
401
+ self.dilation_rate = dilation_rate
402
+ self.n_layers = n_layers
403
+ self.half_channels = channels // 2
404
+ self.mean_only = mean_only
405
+
406
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
407
+ self.pre_transformer = attentions.Encoder(
408
+ hidden_channels,
409
+ hidden_channels,
410
+ n_heads=2,
411
+ n_layers=1,
412
+ kernel_size=kernel_size,
413
+ p_dropout=p_dropout,
414
+ # window_size=None,
415
+ )
416
+ self.enc = modules.WN(
417
+ hidden_channels,
418
+ kernel_size,
419
+ dilation_rate,
420
+ n_layers,
421
+ p_dropout=p_dropout,
422
+ gin_channels=gin_channels,
423
+ )
424
+
425
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
426
+ self.post.weight.data.zero_()
427
+ self.post.bias.data.zero_()
428
+
429
+ def forward(self, x, x_mask, g=None, reverse=False):
430
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
431
+ h = self.pre(x0) * x_mask
432
+ h = h + self.pre_transformer(h * x_mask, x_mask) # vits2 residual connection
433
+ h = self.enc(h, x_mask, g=g)
434
+ stats = self.post(h) * x_mask
435
+ if not self.mean_only:
436
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
437
+ else:
438
+ m = stats
439
+ logs = torch.zeros_like(m)
440
+ if not reverse:
441
+ x1 = m + x1 * torch.exp(logs) * x_mask
442
+ x = torch.cat([x0, x1], 1)
443
+ logdet = torch.sum(logs, [1, 2])
444
+ return x, logdet
445
+ else:
446
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
447
+ x = torch.cat([x0, x1], 1)
448
+ return x
449
+
450
+
451
+ class ResidualCouplingTransformersLayer(nn.Module): # vits2
452
+ def __init__(
453
+ self,
454
+ channels,
455
+ hidden_channels,
456
+ kernel_size,
457
+ dilation_rate,
458
+ n_layers,
459
+ p_dropout=0,
460
+ gin_channels=0,
461
+ mean_only=False,
462
+ ):
463
+ assert channels % 2 == 0, "channels should be divisible by 2"
464
+ super().__init__()
465
+ self.channels = channels
466
+ self.hidden_channels = hidden_channels
467
+ self.kernel_size = kernel_size
468
+ self.dilation_rate = dilation_rate
469
+ self.n_layers = n_layers
470
+ self.half_channels = channels // 2
471
+ self.mean_only = mean_only
472
+ # vits2
473
+ self.pre_transformer = attentions.Encoder(
474
+ self.half_channels,
475
+ self.half_channels,
476
+ n_heads=2,
477
+ n_layers=2,
478
+ kernel_size=3,
479
+ p_dropout=0.1,
480
+ window_size=None,
481
+ )
482
+
483
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
484
+ self.enc = modules.WN(
485
+ hidden_channels,
486
+ kernel_size,
487
+ dilation_rate,
488
+ n_layers,
489
+ p_dropout=p_dropout,
490
+ gin_channels=gin_channels,
491
+ )
492
+ # vits2
493
+ self.post_transformer = attentions.Encoder(
494
+ self.hidden_channels,
495
+ self.hidden_channels,
496
+ n_heads=2,
497
+ n_layers=2,
498
+ kernel_size=3,
499
+ p_dropout=0.1,
500
+ window_size=None,
501
+ )
502
+
503
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
504
+ self.post.weight.data.zero_()
505
+ self.post.bias.data.zero_()
506
+
507
+ def forward(self, x, x_mask, g=None, reverse=False):
508
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
509
+ x0_ = self.pre_transformer(x0 * x_mask, x_mask) # vits2
510
+ x0_ = x0_ + x0 # vits2 residual connection
511
+ h = self.pre(x0_) * x_mask # changed from x0 to x0_ to retain x0 for the flow
512
+ h = self.enc(h, x_mask, g=g)
513
+
514
+ # vits2 - (experimental; uncomment the following 2 lines to use)
515
+ # h_ = self.post_transformer(h, x_mask)
516
+ # h = h + h_ #vits2 residual connection
517
+
518
+ stats = self.post(h) * x_mask
519
+ if not self.mean_only:
520
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
521
+ else:
522
+ m = stats
523
+ logs = torch.zeros_like(m)
524
+ if not reverse:
525
+ x1 = m + x1 * torch.exp(logs) * x_mask
526
+ x = torch.cat([x0, x1], 1)
527
+ logdet = torch.sum(logs, [1, 2])
528
+ return x, logdet
529
+ else:
530
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
531
+ x = torch.cat([x0, x1], 1)
532
+ return x
533
+
534
+
535
+ class FFTransformerCouplingLayer(nn.Module): # vits2
536
+ def __init__(
537
+ self,
538
+ channels,
539
+ hidden_channels,
540
+ kernel_size,
541
+ n_layers,
542
+ n_heads,
543
+ p_dropout=0,
544
+ filter_channels=768,
545
+ mean_only=False,
546
+ gin_channels=0,
547
+ ):
548
+ assert channels % 2 == 0, "channels should be divisible by 2"
549
+ super().__init__()
550
+ self.channels = channels
551
+ self.hidden_channels = hidden_channels
552
+ self.kernel_size = kernel_size
553
+ self.n_layers = n_layers
554
+ self.half_channels = channels // 2
555
+ self.mean_only = mean_only
556
+
557
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
558
+ self.enc = attentions.FFT(
559
+ hidden_channels,
560
+ filter_channels,
561
+ n_heads,
562
+ n_layers,
563
+ kernel_size,
564
+ p_dropout,
565
+ isflow=True,
566
+ gin_channels=gin_channels,
567
+ )
568
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
569
+ self.post.weight.data.zero_()
570
+ self.post.bias.data.zero_()
571
+
572
+ def forward(self, x, x_mask, g=None, reverse=False):
573
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
574
+ h = self.pre(x0) * x_mask
575
+ h_ = self.enc(h, x_mask, g=g)
576
+ h = h_ + h
577
+ stats = self.post(h) * x_mask
578
+ if not self.mean_only:
579
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
580
+ else:
581
+ m = stats
582
+ logs = torch.zeros_like(m)
583
+
584
+ if not reverse:
585
+ x1 = m + x1 * torch.exp(logs) * x_mask
586
+ x = torch.cat([x0, x1], 1)
587
+ logdet = torch.sum(logs, [1, 2])
588
+ return x, logdet
589
+ else:
590
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
591
+ x = torch.cat([x0, x1], 1)
592
+ return x
593
+
594
+
595
+ class MonoTransformerFlowLayer(nn.Module): # vits2
596
+ def __init__(
597
+ self,
598
+ channels,
599
+ hidden_channels,
600
+ mean_only=False,
601
+ residual_connection=False,
602
+ # according to VITS-2 paper fig 1B set residual_connection=True
603
+ ):
604
+ assert channels % 2 == 0, "channels should be divisible by 2"
605
+ super().__init__()
606
+ self.channels = channels
607
+ self.hidden_channels = hidden_channels
608
+ self.half_channels = channels // 2
609
+ self.mean_only = mean_only
610
+ self.residual_connection = residual_connection
611
+ # vits2
612
+ self.pre_transformer = attentions.Encoder(
613
+ self.half_channels,
614
+ self.half_channels,
615
+ n_heads=2,
616
+ n_layers=2,
617
+ kernel_size=3,
618
+ p_dropout=0.1,
619
+ window_size=None,
620
+ )
621
+
622
+ self.post = nn.Conv1d(
623
+ self.half_channels, self.half_channels * (2 - mean_only), 1
624
+ )
625
+ self.post.weight.data.zero_()
626
+ self.post.bias.data.zero_()
627
+
628
+ def forward(self, x, x_mask, g=None, reverse=False):
629
+ if self.residual_connection:
630
+ if not reverse:
631
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
632
+ x0_ = self.pre_transformer(x0, x_mask) # vits2
633
+ stats = self.post(x0_) * x_mask
634
+ if not self.mean_only:
635
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
636
+ else:
637
+ m = stats
638
+ logs = torch.zeros_like(m)
639
+ x1 = m + x1 * torch.exp(logs) * x_mask
640
+ x_ = torch.cat([x0, x1], 1)
641
+ x = x + x_
642
+ logdet = torch.sum(torch.log(torch.exp(logs) + 1), [1, 2])
643
+ logdet = logdet + torch.log(torch.tensor(2)) * (
644
+ x0.shape[1] * x0.shape[2]
645
+ )
646
+ return x, logdet
647
+ else:
648
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
649
+ x0 = x0 / 2
650
+ x0_ = x0 * x_mask
651
+ x0_ = self.pre_transformer(x0_, x_mask) # vits2
652
+ stats = self.post(x0_) * x_mask
653
+ if not self.mean_only:
654
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
655
+ else:
656
+ m = stats
657
+ logs = torch.zeros_like(m)
658
+ x1_ = ((x1 - m) / (1 + torch.exp(logs))) * x_mask # inverts y1 = m + x1 * (1 + exp(logs))
659
+ x = torch.cat([x0, x1_], 1)
660
+ return x
661
+ else:
662
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
663
+ x0_ = self.pre_transformer(x0 * x_mask, x_mask) # vits2
664
+ h = x0_ + x0 # vits2
665
+ stats = self.post(h) * x_mask
666
+ if not self.mean_only:
667
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
668
+ else:
669
+ m = stats
670
+ logs = torch.zeros_like(m)
671
+ if not reverse:
672
+ x1 = m + x1 * torch.exp(logs) * x_mask
673
+ x = torch.cat([x0, x1], 1)
674
+ logdet = torch.sum(logs, [1, 2])
675
+ return x, logdet
676
+ else:
677
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
678
+ x = torch.cat([x0, x1], 1)
679
+ return x
680
+
681
+
682
+ class ResidualCouplingTransformersBlock(nn.Module): # vits2
683
+ def __init__(
684
+ self,
685
+ channels,
686
+ hidden_channels,
687
+ kernel_size,
688
+ dilation_rate,
689
+ n_layers,
690
+ n_flows=4,
691
+ gin_channels=0,
692
+ use_transformer_flows=False,
693
+ transformer_flow_type="pre_conv",
694
+ ):
695
+ super().__init__()
696
+ self.channels = channels
697
+ self.hidden_channels = hidden_channels
698
+ self.kernel_size = kernel_size
699
+ self.dilation_rate = dilation_rate
700
+ self.n_layers = n_layers
701
+ self.n_flows = n_flows
702
+ self.gin_channels = gin_channels
703
+
704
+ self.flows = nn.ModuleList()
705
+ if use_transformer_flows:
706
+ if transformer_flow_type == "pre_conv":
707
+ for i in range(n_flows):
708
+ self.flows.append(
709
+ ResidualCouplingTransformersLayer(
710
+ channels,
711
+ hidden_channels,
712
+ kernel_size,
713
+ dilation_rate,
714
+ n_layers,
715
+ gin_channels=gin_channels,
716
+ mean_only=True,
717
+ )
718
+ )
719
+ self.flows.append(modules.Flip())
720
+ elif transformer_flow_type == "pre_conv2":
721
+ for i in range(n_flows):
722
+ self.flows.append(
723
+ ResidualCouplingTransformersLayer2(
724
+ channels,
725
+ hidden_channels,
726
+ kernel_size,
727
+ dilation_rate,
728
+ n_layers,
729
+ gin_channels=gin_channels,
730
+ mean_only=True,
731
+ )
732
+ )
733
+ self.flows.append(modules.Flip())
734
+ elif transformer_flow_type == "fft":
735
+ for i in range(n_flows):
736
+ self.flows.append(
737
+ FFTransformerCouplingLayer(
738
+ channels,
739
+ hidden_channels,
740
+ kernel_size,
741
+ dilation_rate,
742
+ n_layers,
743
+ gin_channels=gin_channels,
744
+ mean_only=True,
745
+ )
746
+ )
747
+ self.flows.append(modules.Flip())
748
+ elif transformer_flow_type == "mono_layer_inter_residual":
749
+ for i in range(n_flows):
750
+ self.flows.append(
751
+ modules.ResidualCouplingLayer(
752
+ channels,
753
+ hidden_channels,
754
+ kernel_size,
755
+ dilation_rate,
756
+ n_layers,
757
+ gin_channels=gin_channels,
758
+ mean_only=True,
759
+ )
760
+ )
761
+ self.flows.append(modules.Flip())
762
+ self.flows.append(
763
+ MonoTransformerFlowLayer(
764
+ channels, hidden_channels, mean_only=True
765
+ )
766
+ )
767
+ elif transformer_flow_type == "mono_layer_post_residual":
768
+ for i in range(n_flows):
769
+ self.flows.append(
770
+ modules.ResidualCouplingLayer(
771
+ channels,
772
+ hidden_channels,
773
+ kernel_size,
774
+ dilation_rate,
775
+ n_layers,
776
+ gin_channels=gin_channels,
777
+ mean_only=True,
778
+ )
779
+ )
780
+ self.flows.append(modules.Flip())
781
+ self.flows.append(
782
+ MonoTransformerFlowLayer(
783
+ channels,
784
+ hidden_channels,
785
+ mean_only=True,
786
+ residual_connection=True,
787
+ )
788
+ )
789
+ else:
790
+ for i in range(n_flows):
791
+ self.flows.append(
792
+ modules.ResidualCouplingLayer(
793
+ channels,
794
+ hidden_channels,
795
+ kernel_size,
796
+ dilation_rate,
797
+ n_layers,
798
+ gin_channels=gin_channels,
799
+ mean_only=True,
800
+ )
801
+ )
802
+ self.flows.append(modules.Flip())
803
+
804
+ def forward(self, x, x_mask, g=None, reverse=False):
805
+ if not reverse:
806
+ for flow in self.flows:
807
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
808
+ else:
809
+ for flow in reversed(self.flows):
810
+ x = flow(x, x_mask, g=g, reverse=reverse)
811
+ return x
812
+
813
+
814
+ class ResidualCouplingBlock(nn.Module):
815
+ def __init__(
816
+ self,
817
+ channels,
818
+ hidden_channels,
819
+ kernel_size,
820
+ dilation_rate,
821
+ n_layers,
822
+ n_flows=4,
823
+ gin_channels=0,
824
+ ):
825
+ super().__init__()
826
+ self.channels = channels
827
+ self.hidden_channels = hidden_channels
828
+ self.kernel_size = kernel_size
829
+ self.dilation_rate = dilation_rate
830
+ self.n_layers = n_layers
831
+ self.n_flows = n_flows
832
+ self.gin_channels = gin_channels
833
+
834
+ self.flows = nn.ModuleList()
835
+ for i in range(n_flows):
836
+ self.flows.append(
837
+ modules.ResidualCouplingLayer(
838
+ channels,
839
+ hidden_channels,
840
+ kernel_size,
841
+ dilation_rate,
842
+ n_layers,
843
+ gin_channels=gin_channels,
844
+ mean_only=True,
845
+ )
846
+ )
847
+ self.flows.append(modules.Flip())
848
+
849
+ def forward(self, x, x_mask, g=None, reverse=False):
850
+ if not reverse:
851
+ for flow in self.flows:
852
+ x, _ = flow(x, x_mask, g=g, reverse=reverse)
853
+ else:
854
+ for flow in reversed(self.flows):
855
+ x = flow(x, x_mask, g=g, reverse=reverse)
856
+ return x
857
+
858
+
859
+ class PosteriorEncoder(nn.Module):
860
+ def __init__(
861
+ self,
862
+ in_channels,
863
+ out_channels,
864
+ hidden_channels,
865
+ kernel_size,
866
+ dilation_rate,
867
+ n_layers,
868
+ gin_channels=0,
869
+ ):
870
+ super().__init__()
871
+ self.in_channels = in_channels
872
+ self.out_channels = out_channels
873
+ self.hidden_channels = hidden_channels
874
+ self.kernel_size = kernel_size
875
+ self.dilation_rate = dilation_rate
876
+ self.n_layers = n_layers
877
+ self.gin_channels = gin_channels
878
+
879
+ self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
880
+ self.enc = modules.WN(
881
+ hidden_channels,
882
+ kernel_size,
883
+ dilation_rate,
884
+ n_layers,
885
+ gin_channels=gin_channels,
886
+ )
887
+ self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
888
+
889
+ def forward(self, x, x_lengths, g=None):
890
+ x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
891
+ x.dtype
892
+ )
893
+ x = self.pre(x) * x_mask
894
+ x = self.enc(x, x_mask, g=g)
895
+ stats = self.proj(x) * x_mask
896
+ m, logs = torch.split(stats, self.out_channels, dim=1)
897
+ z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
898
+ return z, m, logs, x_mask
899
+
900
+
901
+ class Generator(torch.nn.Module):
902
+ def __init__(
903
+ self,
904
+ initial_channel,
905
+ resblock,
906
+ resblock_kernel_sizes,
907
+ resblock_dilation_sizes,
908
+ upsample_rates,
909
+ upsample_initial_channel,
910
+ upsample_kernel_sizes,
911
+ gin_channels=0,
912
+ ):
913
+ super(Generator, self).__init__()
914
+ self.num_kernels = len(resblock_kernel_sizes)
915
+ self.num_upsamples = len(upsample_rates)
916
+ self.conv_pre = Conv1d(
917
+ initial_channel, upsample_initial_channel, 7, 1, padding=3
918
+ )
919
+ resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
920
+
921
+ self.ups = nn.ModuleList()
922
+ for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
923
+ self.ups.append(
924
+ weight_norm(
925
+ ConvTranspose1d(
926
+ upsample_initial_channel // (2**i),
927
+ upsample_initial_channel // (2 ** (i + 1)),
928
+ k,
929
+ u,
930
+ padding=(k - u) // 2,
931
+ )
932
+ )
933
+ )
934
+
935
+ self.resblocks = nn.ModuleList()
936
+ for i in range(len(self.ups)):
937
+ ch = upsample_initial_channel // (2 ** (i + 1))
938
+ for j, (k, d) in enumerate(
939
+ zip(resblock_kernel_sizes, resblock_dilation_sizes)
940
+ ):
941
+ self.resblocks.append(resblock(ch, k, d))
942
+
943
+ self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
944
+ self.ups.apply(init_weights)
945
+
946
+ if gin_channels != 0:
947
+ self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
948
+
949
+ def forward(self, x, g=None):
950
+ x = self.conv_pre(x)
951
+ if g is not None:
952
+ x = x + self.cond(g)
953
+
954
+ for i in range(self.num_upsamples):
955
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
956
+ x = self.ups[i](x)
957
+ xs = None
958
+ for j in range(self.num_kernels):
959
+ if xs is None:
960
+ xs = self.resblocks[i * self.num_kernels + j](x)
961
+ else:
962
+ xs += self.resblocks[i * self.num_kernels + j](x)
963
+ x = xs / self.num_kernels
964
+ x = F.leaky_relu(x)
965
+ x = self.conv_post(x)
966
+ x = torch.tanh(x)
967
+
968
+ return x
969
+
970
+ def remove_weight_norm(self):
971
+ print("Removing weight norm...")
972
+ for l in self.ups:
973
+ remove_weight_norm(l)
974
+ for l in self.resblocks:
975
+ l.remove_weight_norm()
976
+
977
+
978
+ class DiscriminatorP(torch.nn.Module):
979
+ def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
980
+ super(DiscriminatorP, self).__init__()
981
+ self.period = period
982
+ self.use_spectral_norm = use_spectral_norm
983
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
984
+ self.convs = nn.ModuleList(
985
+ [
986
+ norm_f(
987
+ Conv2d(
988
+ 1,
989
+ 32,
990
+ (kernel_size, 1),
991
+ (stride, 1),
992
+ padding=(get_padding(kernel_size, 1), 0),
993
+ )
994
+ ),
995
+ norm_f(
996
+ Conv2d(
997
+ 32,
998
+ 128,
999
+ (kernel_size, 1),
1000
+ (stride, 1),
1001
+ padding=(get_padding(kernel_size, 1), 0),
1002
+ )
1003
+ ),
1004
+ norm_f(
1005
+ Conv2d(
1006
+ 128,
1007
+ 512,
1008
+ (kernel_size, 1),
1009
+ (stride, 1),
1010
+ padding=(get_padding(kernel_size, 1), 0),
1011
+ )
1012
+ ),
1013
+ norm_f(
1014
+ Conv2d(
1015
+ 512,
1016
+ 1024,
1017
+ (kernel_size, 1),
1018
+ (stride, 1),
1019
+ padding=(get_padding(kernel_size, 1), 0),
1020
+ )
1021
+ ),
1022
+ norm_f(
1023
+ Conv2d(
1024
+ 1024,
1025
+ 1024,
1026
+ (kernel_size, 1),
1027
+ 1,
1028
+ padding=(get_padding(kernel_size, 1), 0),
1029
+ )
1030
+ ),
1031
+ ]
1032
+ )
1033
+ self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
1034
+
1035
+ def forward(self, x):
1036
+ fmap = []
1037
+
1038
+ # 1d to 2d
1039
+ b, c, t = x.shape
1040
+ if t % self.period != 0: # pad first
1041
+ n_pad = self.period - (t % self.period)
1042
+ x = F.pad(x, (0, n_pad), "reflect")
1043
+ t = t + n_pad
1044
+ x = x.view(b, c, t // self.period, self.period)
1045
+
1046
+ for l in self.convs:
1047
+ x = l(x)
1048
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
1049
+ fmap.append(x)
1050
+ x = self.conv_post(x)
1051
+ fmap.append(x)
1052
+ x = torch.flatten(x, 1, -1)
1053
+
1054
+ return x, fmap
1055
+
1056
+
1057
+ class DiscriminatorS(torch.nn.Module):
1058
+ def __init__(self, use_spectral_norm=False):
1059
+ super(DiscriminatorS, self).__init__()
1060
+ norm_f = weight_norm if use_spectral_norm == False else spectral_norm
1061
+ self.convs = nn.ModuleList(
1062
+ [
1063
+ norm_f(Conv1d(1, 16, 15, 1, padding=7)),
1064
+ norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
1065
+ norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
1066
+ norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
1067
+ norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
1068
+ norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
1069
+ ]
1070
+ )
1071
+ self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
1072
+
1073
+ def forward(self, x):
1074
+ fmap = []
1075
+
1076
+ for l in self.convs:
1077
+ x = l(x)
1078
+ x = F.leaky_relu(x, modules.LRELU_SLOPE)
1079
+ fmap.append(x)
1080
+ x = self.conv_post(x)
1081
+ fmap.append(x)
1082
+ x = torch.flatten(x, 1, -1)
1083
+
1084
+ return x, fmap
1085
+
1086
+
1087
+ class MultiPeriodDiscriminator(torch.nn.Module):
1088
+ def __init__(self, use_spectral_norm=False):
1089
+ super(MultiPeriodDiscriminator, self).__init__()
1090
+ periods = [2, 3, 5, 7, 11]
1091
+
1092
+ discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
1093
+ discs = discs + [
1094
+ DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
1095
+ ]
1096
+ self.discriminators = nn.ModuleList(discs)
1097
+
1098
+ def forward(self, y, y_hat):
1099
+ y_d_rs = []
1100
+ y_d_gs = []
1101
+ fmap_rs = []
1102
+ fmap_gs = []
1103
+ for i, d in enumerate(self.discriminators):
1104
+ y_d_r, fmap_r = d(y)
1105
+ y_d_g, fmap_g = d(y_hat)
1106
+ y_d_rs.append(y_d_r)
1107
+ y_d_gs.append(y_d_g)
1108
+ fmap_rs.append(fmap_r)
1109
+ fmap_gs.append(fmap_g)
1110
+
1111
+ return y_d_rs, y_d_gs, fmap_rs, fmap_gs
1112
+
1113
+
1114
+ class SynthesizerTrn(nn.Module):
1115
+ """
1116
+ Synthesizer for Training
1117
+ """
1118
+
1119
+ def __init__(
1120
+ self,
1121
+ n_vocab,
1122
+ spec_channels,
1123
+ segment_size,
1124
+ inter_channels,
1125
+ hidden_channels,
1126
+ filter_channels,
1127
+ n_heads,
1128
+ n_layers,
1129
+ kernel_size,
1130
+ p_dropout,
1131
+ resblock,
1132
+ resblock_kernel_sizes,
1133
+ resblock_dilation_sizes,
1134
+ upsample_rates,
1135
+ upsample_initial_channel,
1136
+ upsample_kernel_sizes,
1137
+ n_speakers=0,
1138
+ gin_channels=0,
1139
+ use_sdp=True,
1140
+ **kwargs,
1141
+ ):
1142
+ super().__init__()
1143
+ self.n_vocab = n_vocab
1144
+ self.spec_channels = spec_channels
1145
+ self.inter_channels = inter_channels
1146
+ self.hidden_channels = hidden_channels
1147
+ self.filter_channels = filter_channels
1148
+ self.n_heads = n_heads
1149
+ self.n_layers = n_layers
1150
+ self.kernel_size = kernel_size
1151
+ self.p_dropout = p_dropout
1152
+ self.resblock = resblock
1153
+ self.resblock_kernel_sizes = resblock_kernel_sizes
1154
+ self.resblock_dilation_sizes = resblock_dilation_sizes
1155
+ self.upsample_rates = upsample_rates
1156
+ self.upsample_initial_channel = upsample_initial_channel
1157
+ self.upsample_kernel_sizes = upsample_kernel_sizes
1158
+ self.segment_size = segment_size
1159
+ self.n_speakers = n_speakers
1160
+ self.gin_channels = gin_channels
1161
+ self.use_spk_conditioned_encoder = kwargs.get(
1162
+ "use_spk_conditioned_encoder", False
1163
+ )
1164
+ self.use_transformer_flows = kwargs.get("use_transformer_flows", False)
1165
+ self.transformer_flow_type = kwargs.get(
1166
+ "transformer_flow_type", "mono_layer_post_residual"
1167
+ )
1168
+ if self.use_transformer_flows:
1169
+ assert (
1170
+ self.transformer_flow_type in AVAILABLE_FLOW_TYPES
1171
+ ), f"transformer_flow_type must be one of {AVAILABLE_FLOW_TYPES}"
1172
+ self.use_sdp = use_sdp
1173
+ # self.use_duration_discriminator = kwargs.get("use_duration_discriminator", False)
1174
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
1175
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
1176
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
1177
+
1178
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
1179
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
1180
+ self.enc_gin_channels = gin_channels
1181
+ else:
1182
+ self.enc_gin_channels = 0
1183
+ self.enc_p = TextEncoder(
1184
+ n_vocab,
1185
+ inter_channels,
1186
+ hidden_channels,
1187
+ filter_channels,
1188
+ n_heads,
1189
+ n_layers,
1190
+ kernel_size,
1191
+ p_dropout,
1192
+ gin_channels=self.enc_gin_channels,
1193
+ )
1194
+
1195
+ self.dec = Generator(
1196
+ inter_channels,
1197
+ resblock,
1198
+ resblock_kernel_sizes,
1199
+ resblock_dilation_sizes,
1200
+ upsample_rates,
1201
+ upsample_initial_channel,
1202
+ upsample_kernel_sizes,
1203
+ gin_channels=gin_channels,
1204
+ )
1205
+ self.enc_q = PosteriorEncoder(
1206
+ spec_channels,
1207
+ inter_channels,
1208
+ hidden_channels,
1209
+ 5,
1210
+ 1,
1211
+ 16,
1212
+ gin_channels=gin_channels,
1213
+ )
1214
+ # self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
1215
+ self.flow = ResidualCouplingTransformersBlock(
1216
+ inter_channels,
1217
+ hidden_channels,
1218
+ 5,
1219
+ 1,
1220
+ 4,
1221
+ gin_channels=gin_channels,
1222
+ use_transformer_flows=self.use_transformer_flows,
1223
+ transformer_flow_type=self.transformer_flow_type,
1224
+ )
1225
+
1226
+ if use_sdp:
1227
+ self.dp = StochasticDurationPredictor(
1228
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
1229
+ )
1230
+ else:
1231
+ self.dp = DurationPredictor(
1232
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
1233
+ )
1234
+
1235
+ if n_speakers > 0: # match the `self.n_speakers > 0` checks in forward/infer
1236
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
1237
+
1238
+ def forward(self, x, x_lengths, y, y_lengths, sid=None):
1239
+ if self.n_speakers > 0:
1240
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
1241
+ else:
1242
+ g = None
1243
+
1244
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, g=g)
1245
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
1246
+ z_p = self.flow(z, y_mask, g=g)
1247
+
1248
+ with torch.no_grad():
1249
+ # negative cross-entropy
1250
+ s_p_sq_r = torch.exp(-2 * logs_p) # [b, d, t]
1251
+ neg_cent1 = torch.sum(
1252
+ -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
1253
+ ) # [b, 1, t_s]
1254
+ neg_cent2 = torch.matmul(
1255
+ -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
1256
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
1257
+ neg_cent3 = torch.matmul(
1258
+ z_p.transpose(1, 2), (m_p * s_p_sq_r)
1259
+ ) # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
1260
+ neg_cent4 = torch.sum(
1261
+ -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
1262
+ ) # [b, 1, t_s]
1263
+ neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
1264
+
1265
+ if self.use_noise_scaled_mas:
1266
+ epsilon = (
1267
+ torch.std(neg_cent)
1268
+ * torch.randn_like(neg_cent)
1269
+ * self.current_mas_noise_scale
1270
+ )
1271
+ neg_cent = neg_cent + epsilon
1272
+
1273
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1274
+ attn = (
1275
+ monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
1276
+ .unsqueeze(1)
1277
+ .detach()
1278
+ )
1279
+
1280
+ w = attn.sum(2)
1281
+ if self.use_sdp:
1282
+ l_length = self.dp(x, x_mask, w, g=g)
1283
+ l_length = l_length / torch.sum(x_mask)
1284
+ logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=1.0)
1285
+ logw_ = torch.log(w + 1e-6) * x_mask
1286
+ else:
1287
+ logw_ = torch.log(w + 1e-6) * x_mask
1288
+ logw = self.dp(x, x_mask, g=g)
1289
+ l_length = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
1290
+ x_mask
1291
+ ) # for averaging
1292
+
1293
+ # expand prior
1294
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
1295
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)
1296
+
1297
+ z_slice, ids_slice = commons.rand_slice_segments(
1298
+ z, y_lengths, self.segment_size
1299
+ )
1300
+ o = self.dec(z_slice, g=g)
1301
+ return (
1302
+ o,
1303
+ l_length,
1304
+ attn,
1305
+ ids_slice,
1306
+ x_mask,
1307
+ y_mask,
1308
+ (z, z_p, m_p, logs_p, m_q, logs_q),
1309
+ (x, logw, logw_),
1310
+ )
1311
+
1312
+ def infer(
1313
+ self,
1314
+ x,
1315
+ x_lengths,
1316
+ sid=None,
1317
+ noise_scale=1,
1318
+ length_scale=1,
1319
+ noise_scale_w=1.0,
1320
+ max_len=None,
1321
+ ):
1322
+ if self.n_speakers > 0:
1323
+ g = self.emb_g(sid).unsqueeze(-1) # [b, h, 1]
1324
+ else:
1325
+ g = None
1326
+ x, m_p, logs_p, x_mask = self.enc_p(x, x_lengths, g=g)
1327
+ if self.use_sdp:
1328
+ logw = self.dp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
1329
+ else:
1330
+ logw = self.dp(x, x_mask, g=g)
1331
+ w = torch.exp(logw) * x_mask * length_scale
1332
+ w_ceil = torch.ceil(w)
1333
+ y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
1334
+ y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
1335
+ x_mask.dtype
1336
+ )
1337
+ attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
1338
+ attn = commons.generate_path(w_ceil, attn_mask)
1339
+
1340
+ m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
1341
+ 1, 2
1342
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1343
+ logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
1344
+ 1, 2
1345
+ ) # [b, t', t], [b, t, d] -> [b, d, t']
1346
+
1347
+ z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
1348
+ z = self.flow(z_p, y_mask, g=g, reverse=True)
1349
+ o = self.dec((z * y_mask)[:, :, :max_len], g=g)
1350
+ return o, attn, y_mask, (z, z_p, m_p, logs_p)
1351
+
1352
+ # NOTE: VITS2 is not currently considered capable of voice conversion.
1353
+ # Comment from choihkk:
1354
+ # with the ResidualCouplingTransformersLayer2 module, voice conversion appears to be possible.
1355
+ def voice_conversion(self, y, y_lengths, sid_src, sid_tgt):
1356
+ assert self.n_speakers > 0, "n_speakers has to be larger than 0."
1357
+ g_src = self.emb_g(sid_src).unsqueeze(-1)
1358
+ g_tgt = self.emb_g(sid_tgt).unsqueeze(-1)
1359
+ z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src)
1360
+ z_p = self.flow(z, y_mask, g=g_src)
1361
+ z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
1362
+ o_hat = self.dec(z_hat * y_mask, g=g_tgt)
1363
+ return o_hat, y_mask, (z, z_p, z_hat)
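
The forward pass of `MonoTransformerFlowLayer` with `residual_connection=True` computes `x + concat(x0, m + x1 * exp(logs))`, so the first half is doubled and the second half becomes `m + x1 * (1 + exp(logs))`. The following is a minimal pure-Python sketch of that arithmetic and of the inverse that follows from it algebraically. It works on plain lists, omits the sequence mask, and uses a hypothetical `toy_stats` function as a stand-in for the `pre_transformer` and `post` layers; none of these helper names come from the repository.

```python
import math


def toy_stats(x0):
    # Hypothetical stand-in for pre_transformer + post: any function of x0 works,
    # because a coupling layer only needs m and logs to depend on x0 alone.
    m = [0.5 * a for a in x0]
    logs = [0.1 * a for a in x0]
    return m, logs


def residual_flow_forward(x0, x1, stats_fn=toy_stats):
    m, logs = stats_fn(x0)
    y0 = [2.0 * a for a in x0]  # x0 + x0
    # x1 + (m + x1 * exp(logs)) == m + x1 * (1 + exp(logs))
    y1 = [b + mi + b * math.exp(s) for b, mi, s in zip(x1, m, logs)]
    # log|det J| = sum log(1 + exp(logs)) for the x1 half, plus log 2 per x0 element
    logdet = sum(math.log(math.exp(s) + 1.0) for s in logs)
    logdet += math.log(2.0) * len(x0)
    return y0, y1, logdet


def residual_flow_inverse(y0, y1, stats_fn=toy_stats):
    x0 = [a / 2.0 for a in y0]  # undo the doubling first, then recompute the stats
    m, logs = stats_fn(x0)
    x1 = [(b - mi) / (1.0 + math.exp(s)) for b, mi, s in zip(y1, m, logs)]
    return x0, x1


x0, x1 = [1.0, -2.0], [0.3, 0.7]
y0, y1, logdet = residual_flow_forward(x0, x1)
rx0, rx1 = residual_flow_inverse(y0, y1)
```

With `mean_only=True`, as every flow in `ResidualCouplingTransformersBlock` is configured, `logs` is zero and the scale factor is exactly 2 in both directions.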
modules.py ADDED
@@ -0,0 +1,519 @@
1
+ import copy
2
+ import math
3
+ import numpy as np
4
+ import scipy
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
10
+ from torch.nn.utils import weight_norm, remove_weight_norm
11
+
12
+ import commons
13
+ from commons import init_weights, get_padding
14
+ from transforms import piecewise_rational_quadratic_transform
15
+
16
+
17
+ LRELU_SLOPE = 0.1
18
+
19
+
20
+ class LayerNorm(nn.Module):
21
+ def __init__(self, channels, eps=1e-5):
22
+ super().__init__()
23
+ self.channels = channels
24
+ self.eps = eps
25
+
26
+ self.gamma = nn.Parameter(torch.ones(channels))
27
+ self.beta = nn.Parameter(torch.zeros(channels))
28
+
29
+ def forward(self, x):
30
+ x = x.transpose(1, -1)
31
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32
+ return x.transpose(1, -1)
33
+
34
+
35
+ class ConvReluNorm(nn.Module):
36
+ def __init__(
37
+ self,
38
+ in_channels,
39
+ hidden_channels,
40
+ out_channels,
41
+ kernel_size,
42
+ n_layers,
43
+ p_dropout,
44
+ ):
45
+ super().__init__()
46
+ self.in_channels = in_channels
47
+ self.hidden_channels = hidden_channels
48
+ self.out_channels = out_channels
49
+ self.kernel_size = kernel_size
50
+ self.n_layers = n_layers
51
+ self.p_dropout = p_dropout
52
+ assert n_layers > 1, "Number of layers should be larger than 1."
53
+
54
+ self.conv_layers = nn.ModuleList()
55
+ self.norm_layers = nn.ModuleList()
56
+ self.conv_layers.append(
57
+ nn.Conv1d(
58
+ in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
59
+ )
60
+ )
61
+ self.norm_layers.append(LayerNorm(hidden_channels))
62
+ self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
63
+ for _ in range(n_layers - 1):
64
+ self.conv_layers.append(
65
+ nn.Conv1d(
66
+ hidden_channels,
67
+ hidden_channels,
68
+ kernel_size,
69
+ padding=kernel_size // 2,
70
+ )
71
+ )
72
+ self.norm_layers.append(LayerNorm(hidden_channels))
73
+ self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
74
+ self.proj.weight.data.zero_()
75
+ self.proj.bias.data.zero_()
76
+
77
+ def forward(self, x, x_mask):
78
+ x_org = x
79
+ for i in range(self.n_layers):
80
+ x = self.conv_layers[i](x * x_mask)
81
+ x = self.norm_layers[i](x)
82
+ x = self.relu_drop(x)
83
+ x = x_org + self.proj(x)
84
+ return x * x_mask
85
+
86
+
87
+ class DDSConv(nn.Module):
88
+ """
89
+ Dilated and Depth-Separable Convolution
90
+ """
91
+
92
+ def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
93
+ super().__init__()
94
+ self.channels = channels
95
+ self.kernel_size = kernel_size
96
+ self.n_layers = n_layers
97
+ self.p_dropout = p_dropout
98
+
99
+ self.drop = nn.Dropout(p_dropout)
100
+ self.convs_sep = nn.ModuleList()
101
+ self.convs_1x1 = nn.ModuleList()
102
+ self.norms_1 = nn.ModuleList()
103
+ self.norms_2 = nn.ModuleList()
104
+ for i in range(n_layers):
105
+ dilation = kernel_size**i
106
+ padding = (kernel_size * dilation - dilation) // 2
107
+ self.convs_sep.append(
108
+ nn.Conv1d(
109
+ channels,
110
+ channels,
111
+ kernel_size,
112
+ groups=channels,
113
+ dilation=dilation,
114
+ padding=padding,
115
+ )
116
+ )
117
+ self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
118
+ self.norms_1.append(LayerNorm(channels))
119
+ self.norms_2.append(LayerNorm(channels))
120
+
121
+ def forward(self, x, x_mask, g=None):
122
+ if g is not None:
123
+ x = x + g
124
+ for i in range(self.n_layers):
125
+ y = self.convs_sep[i](x * x_mask)
126
+ y = self.norms_1[i](y)
127
+ y = F.gelu(y)
128
+ y = self.convs_1x1[i](y)
129
+ y = self.norms_2[i](y)
130
+ y = F.gelu(y)
131
+ y = self.drop(y)
132
+ x = x + y
133
+ return x * x_mask
134
+
135
+
136
+ class WN(torch.nn.Module):
137
+ def __init__(
138
+ self,
139
+ hidden_channels,
140
+ kernel_size,
141
+ dilation_rate,
142
+ n_layers,
143
+ gin_channels=0,
144
+ p_dropout=0,
145
+ ):
146
+ super(WN, self).__init__()
147
+ assert kernel_size % 2 == 1
148
+ self.hidden_channels = hidden_channels
149
+ self.kernel_size = (kernel_size,)
150
+ self.dilation_rate = dilation_rate
151
+ self.n_layers = n_layers
152
+ self.gin_channels = gin_channels
153
+ self.p_dropout = p_dropout
154
+
155
+ self.in_layers = torch.nn.ModuleList()
156
+ self.res_skip_layers = torch.nn.ModuleList()
157
+ self.drop = nn.Dropout(p_dropout)
158
+
159
+ if gin_channels != 0:
160
+ cond_layer = torch.nn.Conv1d(
161
+ gin_channels, 2 * hidden_channels * n_layers, 1
162
+ )
163
+ self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
164
+
165
+ for i in range(n_layers):
166
+ dilation = dilation_rate**i
167
+ padding = int((kernel_size * dilation - dilation) / 2)
168
+ in_layer = torch.nn.Conv1d(
169
+ hidden_channels,
170
+ 2 * hidden_channels,
171
+ kernel_size,
172
+ dilation=dilation,
173
+ padding=padding,
174
+ )
175
+ in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
176
+ self.in_layers.append(in_layer)
177
+
178
+ # last one is not necessary
179
+ if i < n_layers - 1:
180
+ res_skip_channels = 2 * hidden_channels
181
+ else:
182
+ res_skip_channels = hidden_channels
183
+
184
+ res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
185
+ res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
186
+ self.res_skip_layers.append(res_skip_layer)
187
+
188
+ def forward(self, x, x_mask, g=None, **kwargs):
189
+ output = torch.zeros_like(x)
190
+ n_channels_tensor = torch.IntTensor([self.hidden_channels])
191
+
192
+ if g is not None:
193
+ g = self.cond_layer(g)
194
+
195
+ for i in range(self.n_layers):
196
+ x_in = self.in_layers[i](x)
197
+ if g is not None:
198
+ cond_offset = i * 2 * self.hidden_channels
199
+ g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
200
+ else:
201
+ g_l = torch.zeros_like(x_in)
202
+
203
+ acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
204
+ acts = self.drop(acts)
205
+
206
+ res_skip_acts = self.res_skip_layers[i](acts)
207
+ if i < self.n_layers - 1:
208
+ res_acts = res_skip_acts[:, : self.hidden_channels, :]
209
+ x = (x + res_acts) * x_mask
210
+ output = output + res_skip_acts[:, self.hidden_channels :, :]
211
+ else:
212
+ output = output + res_skip_acts
213
+ return output * x_mask
214
+
215
+ def remove_weight_norm(self):
216
+ if self.gin_channels != 0:
217
+ torch.nn.utils.remove_weight_norm(self.cond_layer)
218
+ for l in self.in_layers:
219
+ torch.nn.utils.remove_weight_norm(l)
220
+ for l in self.res_skip_layers:
221
+ torch.nn.utils.remove_weight_norm(l)
222
+
223
+
224
+ class ResBlock1(torch.nn.Module):
225
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
226
+ super(ResBlock1, self).__init__()
227
+ self.convs1 = nn.ModuleList(
228
+ [
229
+ weight_norm(
230
+ Conv1d(
231
+ channels,
232
+ channels,
233
+ kernel_size,
234
+ 1,
235
+ dilation=dilation[0],
236
+ padding=get_padding(kernel_size, dilation[0]),
237
+ )
238
+ ),
239
+ weight_norm(
240
+ Conv1d(
241
+ channels,
242
+ channels,
243
+ kernel_size,
244
+ 1,
245
+ dilation=dilation[1],
246
+ padding=get_padding(kernel_size, dilation[1]),
247
+ )
248
+ ),
249
+ weight_norm(
250
+ Conv1d(
251
+ channels,
252
+ channels,
253
+ kernel_size,
254
+ 1,
255
+ dilation=dilation[2],
256
+ padding=get_padding(kernel_size, dilation[2]),
257
+ )
258
+ ),
259
+ ]
260
+ )
261
+ self.convs1.apply(init_weights)
262
+
263
+ self.convs2 = nn.ModuleList(
264
+ [
265
+ weight_norm(
266
+ Conv1d(
267
+ channels,
268
+ channels,
269
+ kernel_size,
270
+ 1,
271
+ dilation=1,
272
+ padding=get_padding(kernel_size, 1),
273
+ )
274
+ ),
275
+ weight_norm(
276
+ Conv1d(
277
+ channels,
278
+ channels,
279
+ kernel_size,
280
+ 1,
281
+ dilation=1,
282
+ padding=get_padding(kernel_size, 1),
283
+ )
284
+ ),
285
+ weight_norm(
286
+ Conv1d(
287
+ channels,
288
+ channels,
289
+ kernel_size,
290
+ 1,
291
+ dilation=1,
292
+ padding=get_padding(kernel_size, 1),
293
+ )
294
+ ),
295
+ ]
296
+ )
297
+ self.convs2.apply(init_weights)
298
+
299
+ def forward(self, x, x_mask=None):
300
+ for c1, c2 in zip(self.convs1, self.convs2):
301
+ xt = F.leaky_relu(x, LRELU_SLOPE)
302
+ if x_mask is not None:
303
+ xt = xt * x_mask
304
+ xt = c1(xt)
305
+ xt = F.leaky_relu(xt, LRELU_SLOPE)
306
+ if x_mask is not None:
307
+ xt = xt * x_mask
308
+ xt = c2(xt)
309
+ x = xt + x
310
+ if x_mask is not None:
311
+ x = x * x_mask
312
+ return x
313
+
314
+ def remove_weight_norm(self):
315
+ for l in self.convs1:
316
+ remove_weight_norm(l)
317
+ for l in self.convs2:
318
+ remove_weight_norm(l)
319
+
320
+
321
+ class ResBlock2(torch.nn.Module):
322
+ def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
323
+ super(ResBlock2, self).__init__()
324
+ self.convs = nn.ModuleList(
325
+ [
326
+ weight_norm(
327
+ Conv1d(
328
+ channels,
329
+ channels,
330
+ kernel_size,
331
+ 1,
332
+ dilation=dilation[0],
333
+ padding=get_padding(kernel_size, dilation[0]),
334
+ )
335
+ ),
336
+ weight_norm(
337
+ Conv1d(
338
+ channels,
339
+ channels,
340
+ kernel_size,
341
+ 1,
342
+ dilation=dilation[1],
343
+ padding=get_padding(kernel_size, dilation[1]),
344
+ )
345
+ ),
346
+ ]
347
+ )
348
+ self.convs.apply(init_weights)
349
+
350
+ def forward(self, x, x_mask=None):
351
+ for c in self.convs:
352
+ xt = F.leaky_relu(x, LRELU_SLOPE)
353
+ if x_mask is not None:
354
+ xt = xt * x_mask
355
+ xt = c(xt)
356
+ x = xt + x
357
+ if x_mask is not None:
358
+ x = x * x_mask
359
+ return x
360
+
361
+ def remove_weight_norm(self):
362
+ for l in self.convs:
363
+ remove_weight_norm(l)
364
+
365
+
366
+ class Log(nn.Module):
367
+ def forward(self, x, x_mask, reverse=False, **kwargs):
368
+ if not reverse:
369
+ y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
370
+ logdet = torch.sum(-y, [1, 2])
371
+ return y, logdet
372
+ else:
373
+ x = torch.exp(x) * x_mask
374
+ return x
375
+
376
+
377
+ class Flip(nn.Module):
378
+ def forward(self, x, *args, reverse=False, **kwargs):
379
+ x = torch.flip(x, [1])
380
+ if not reverse:
381
+ logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
382
+ return x, logdet
383
+ else:
384
+ return x
385
+
386
+
387
+ class ElementwiseAffine(nn.Module):
388
+ def __init__(self, channels):
389
+ super().__init__()
390
+ self.channels = channels
391
+ self.m = nn.Parameter(torch.zeros(channels, 1))
392
+ self.logs = nn.Parameter(torch.zeros(channels, 1))
393
+
394
+ def forward(self, x, x_mask, reverse=False, **kwargs):
395
+ if not reverse:
396
+ y = self.m + torch.exp(self.logs) * x
397
+ y = y * x_mask
398
+ logdet = torch.sum(self.logs * x_mask, [1, 2])
399
+ return y, logdet
400
+ else:
401
+ x = (x - self.m) * torch.exp(-self.logs) * x_mask
402
+ return x
403
+
404
+
405
+ class ResidualCouplingLayer(nn.Module):
406
+ def __init__(
407
+ self,
408
+ channels,
409
+ hidden_channels,
410
+ kernel_size,
411
+ dilation_rate,
412
+ n_layers,
413
+ p_dropout=0,
414
+ gin_channels=0,
415
+ mean_only=False,
416
+ ):
417
+ assert channels % 2 == 0, "channels should be divisible by 2"
418
+ super().__init__()
419
+ self.channels = channels
420
+ self.hidden_channels = hidden_channels
421
+ self.kernel_size = kernel_size
422
+ self.dilation_rate = dilation_rate
423
+ self.n_layers = n_layers
424
+ self.half_channels = channels // 2
425
+ self.mean_only = mean_only
426
+
427
+ self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
428
+ self.enc = WN(
429
+ hidden_channels,
430
+ kernel_size,
431
+ dilation_rate,
432
+ n_layers,
433
+ p_dropout=p_dropout,
434
+ gin_channels=gin_channels,
435
+ )
436
+ self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
437
+ self.post.weight.data.zero_()
438
+ self.post.bias.data.zero_()
439
+
440
+ def forward(self, x, x_mask, g=None, reverse=False):
441
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
442
+ h = self.pre(x0) * x_mask
443
+ h = self.enc(h, x_mask, g=g)
444
+ stats = self.post(h) * x_mask
445
+ if not self.mean_only:
446
+ m, logs = torch.split(stats, [self.half_channels] * 2, 1)
447
+ else:
448
+ m = stats
449
+ logs = torch.zeros_like(m)
450
+
451
+ if not reverse:
452
+ x1 = m + x1 * torch.exp(logs) * x_mask
453
+ x = torch.cat([x0, x1], 1)
454
+ logdet = torch.sum(logs, [1, 2])
455
+ return x, logdet
456
+ else:
457
+ x1 = (x1 - m) * torch.exp(-logs) * x_mask
458
+ x = torch.cat([x0, x1], 1)
459
+ return x
460
+
461
+
462
+ class ConvFlow(nn.Module):
463
+ def __init__(
464
+ self,
465
+ in_channels,
466
+ filter_channels,
467
+ kernel_size,
468
+ n_layers,
469
+ num_bins=10,
470
+ tail_bound=5.0,
471
+ ):
472
+ super().__init__()
473
+ self.in_channels = in_channels
474
+ self.filter_channels = filter_channels
475
+ self.kernel_size = kernel_size
476
+ self.n_layers = n_layers
477
+ self.num_bins = num_bins
478
+ self.tail_bound = tail_bound
479
+ self.half_channels = in_channels // 2
480
+
481
+ self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
482
+ self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
483
+ self.proj = nn.Conv1d(
484
+ filter_channels, self.half_channels * (num_bins * 3 - 1), 1
485
+ )
486
+ self.proj.weight.data.zero_()
487
+ self.proj.bias.data.zero_()
488
+
489
+ def forward(self, x, x_mask, g=None, reverse=False):
490
+ x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
491
+ h = self.pre(x0)
492
+ h = self.convs(h, x_mask, g=g)
493
+ h = self.proj(h) * x_mask
494
+
495
+ b, c, t = x0.shape
496
+ h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
497
+
498
+ unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
499
+ unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
500
+ self.filter_channels
501
+ )
502
+ unnormalized_derivatives = h[..., 2 * self.num_bins :]
503
+
504
+ x1, logabsdet = piecewise_rational_quadratic_transform(
505
+ x1,
506
+ unnormalized_widths,
507
+ unnormalized_heights,
508
+ unnormalized_derivatives,
509
+ inverse=reverse,
510
+ tails="linear",
511
+ tail_bound=self.tail_bound,
512
+ )
513
+
514
+ x = torch.cat([x0, x1], 1) * x_mask
515
+ logdet = torch.sum(logabsdet * x_mask, [1, 2])
516
+ if not reverse:
517
+ return x, logdet
518
+ else:
519
+ return x
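The flow layers in this file (`ElementwiseAffine`, `ResidualCouplingLayer`, `ConvFlow`) share one contract: the forward pass returns the transformed tensor together with a log-determinant, and `reverse=True` undoes the transform. As a minimal sketch of that contract, the snippet below repeats the `ElementwiseAffine` arithmetic on plain Python floats for a single channel. The names `affine_forward` and `affine_reverse` are illustrative and not part of the repository, and the mask is assumed to be all ones.

```python
import math


def affine_forward(x, m, logs):
    """y = m + exp(logs) * x for each time step, plus the log-determinant."""
    y = [m + math.exp(logs) * v for v in x]
    # Every time step contributes `logs` (mask assumed to be all ones).
    logdet = logs * len(x)
    return y, logdet


def affine_reverse(y, m, logs):
    """Inverse of affine_forward: x = (y - m) * exp(-logs)."""
    return [(v - m) * math.exp(-logs) for v in y]


x = [0.5, -1.0, 2.0]
y, logdet = affine_forward(x, m=0.1, logs=0.3)
x_rec = affine_reverse(y, m=0.1, logs=0.3)
```

Running the reverse pass on the forward output should recover the input up to floating-point error, which is the property the coupling layers rely on at inference time.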
monotonic_align/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ import numpy as np
2
+ import torch
3
+ from .monotonic_align.core import maximum_path_c
4
+
5
+
6
+ def maximum_path(neg_cent, mask):
7
+ """Cython optimized version.
8
+ neg_cent: [b, t_t, t_s]
9
+ mask: [b, t_t, t_s]
10
+ """
11
+ device = neg_cent.device
12
+ dtype = neg_cent.dtype
13
+ neg_cent = neg_cent.data.cpu().numpy().astype(np.float32)
14
+ path = np.zeros(neg_cent.shape, dtype=np.int32)
15
+
16
+ t_t_max = mask.sum(1)[:, 0].data.cpu().numpy().astype(np.int32)
17
+ t_s_max = mask.sum(2)[:, 0].data.cpu().numpy().astype(np.int32)
18
+ maximum_path_c(path, neg_cent, t_t_max, t_s_max)
19
+ return torch.from_numpy(path).to(device=device, dtype=dtype)
monotonic_align/build/temp.linux-x86_64-3.10/core.o ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d75191264a1468a8d701059facf526e45a2f9384688871df43bebd79ac06bf2
3
+ size 1288328
monotonic_align/core.c ADDED
The diff for this file is too large to render. See raw diff
monotonic_align/core.pyx ADDED
@@ -0,0 +1,42 @@
1
+ cimport cython
2
+ from cython.parallel import prange
3
+
4
+
5
+ @cython.boundscheck(False)
6
+ @cython.wraparound(False)
7
+ cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_y, int t_x, float max_neg_val=-1e9) nogil:
8
+ cdef int x
9
+ cdef int y
10
+ cdef float v_prev
11
+ cdef float v_cur
12
+ cdef float tmp
13
+ cdef int index = t_x - 1
14
+
15
+ for y in range(t_y):
16
+ for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
17
+ if x == y:
18
+ v_cur = max_neg_val
19
+ else:
20
+ v_cur = value[y-1, x]
21
+ if x == 0:
22
+ if y == 0:
23
+ v_prev = 0.
24
+ else:
25
+ v_prev = max_neg_val
26
+ else:
27
+ v_prev = value[y-1, x-1]
28
+ value[y, x] += max(v_prev, v_cur)
29
+
30
+ for y in range(t_y - 1, -1, -1):
31
+ path[y, index] = 1
32
+ if index != 0 and (index == y or value[y-1, index] < value[y-1, index-1]):
33
+ index = index - 1
34
+
35
+
36
+ @cython.boundscheck(False)
37
+ @cython.wraparound(False)
38
+ cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_ys, int[::1] t_xs) nogil:
39
+ cdef int b = paths.shape[0]
40
+ cdef int i
41
+ for i in prange(b, nogil=True):
42
+ maximum_path_each(paths[i], values[i], t_ys[i], t_xs[i])
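For reading or unit-testing the dynamic program without compiling the extension, the Cython routine above can be ported to plain Python. The following is a sketch under stated assumptions, not a replacement for the compiled module: the name `maximum_path_py` is hypothetical, inputs are nested lists rather than memoryviews, and it assumes `t_x <= t_y` (each column is visited at least once).

```python
def maximum_path_py(value, t_y, t_x, max_neg_val=-1e9):
    """Pure-Python port of maximum_path_each for a single example.

    value: t_y x t_x nested list of scores (copied, not modified in place).
    Returns a 0/1 path with exactly one 1 per row.
    """
    value = [row[:] for row in value]
    path = [[0] * len(value[0]) for _ in value]

    # Forward pass: accumulate the best score reaching each (y, x) cell.
    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            v_cur = max_neg_val if x == y else value[y - 1][x]
            if x == 0:
                v_prev = 0.0 if y == 0 else max_neg_val
            else:
                v_prev = value[y - 1][x - 1]
            value[y][x] += max(v_prev, v_cur)

    # Backtrack from the last column, moving left when that scored higher.
    index = t_x - 1
    for y in range(t_y - 1, -1, -1):
        path[y][index] = 1
        if index != 0 and (
            index == y or value[y - 1][index] < value[y - 1][index - 1]
        ):
            index -= 1
    return path


path_a = maximum_path_py([[1.0, 0.0], [0.0, 5.0], [0.0, 1.0]], t_y=3, t_x=2)
path_b = maximum_path_py([[1.0, 0.0], [5.0, 0.0], [0.0, 1.0]], t_y=3, t_x=2)
```

In both toy cases the path starts in column 0, ends in the last column, and never moves left, which is the monotonic constraint the alignment search enforces.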
monotonic_align/monotonic_align/.gitkeep ADDED
File without changes
monotonic_align/monotonic_align/core.cpython-310-x86_64-linux-gnu.so ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c48ad49b41e263174b00b3580f1bb024c5c6e4b4d378b8755b956b9f06fa90e5
3
+ size 707752
monotonic_align/setup.py ADDED
@@ -0,0 +1,9 @@
1
+ from setuptools import setup
2
+ from Cython.Build import cythonize
3
+ import numpy
4
+
5
+ setup(
6
+ name="monotonic_align",
7
+ ext_modules=cythonize("core.pyx"),
8
+ include_dirs=[numpy.get_include()],
9
+ )
text/__init__.py ADDED
@@ -0,0 +1,64 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+ from text import cleaners
3
+ from text.symbols import symbols
4
+
5
+
6
+ # Mappings from symbol to numeric ID and vice versa:
7
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
8
+ _id_to_symbol = {i: s for i, s in enumerate(symbols)}
9
+
10
+
11
+ def text_to_sequence(text, cleaner_names):
12
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
13
+ Args:
14
+ text: string to convert to a sequence
15
+ cleaner_names: names of the cleaner functions to run the text through
16
+ Returns:
17
+ List of integers corresponding to the symbols in the text
18
+ """
19
+ sequence = []
20
+
21
+ clean_text = _clean_text(text, cleaner_names)
22
+ for symbol in clean_text:
23
+ if symbol in _symbol_to_id.keys():
24
+ symbol_id = _symbol_to_id[symbol]
25
+ sequence += [symbol_id]
26
+ else:
27
+ continue
28
+ return sequence
29
+
30
+
31
+ def cleaned_text_to_sequence(cleaned_text):
32
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
33
+ Args:
34
+ cleaned_text: already-cleaned string to convert to a sequence
35
+ Returns:
36
+ List of integers corresponding to the symbols in the text
37
+ """
38
+ sequence = []
39
+
40
+ for symbol in cleaned_text:
41
+ if symbol in _symbol_to_id.keys():
42
+ symbol_id = _symbol_to_id[symbol]
43
+ sequence += [symbol_id]
44
+ else:
45
+ continue
46
+ return sequence
47
+
48
+
49
+ def sequence_to_text(sequence):
50
+ """Converts a sequence of IDs back to a string"""
51
+ result = ""
52
+ for symbol_id in sequence:
53
+ s = _id_to_symbol[symbol_id]
54
+ result += s
55
+ return result
56
+
57
+
58
+ def _clean_text(text, cleaner_names):
59
+ for name in cleaner_names:
60
+ cleaner = getattr(cleaners, name, None)
61
+ if not cleaner:
62
+ raise Exception("Unknown cleaner: %s" % name)
63
+ text = cleaner(text)
64
+ return text
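To make the lookup behaviour concrete, here is a small self-contained sketch of the same symbol-to-id mapping. It uses a toy four-symbol table rather than the one from `text/symbols.py`, and the helper `luxembourgish_like_cleaner` is a hypothetical stand-in for a cleaner that lowercases and collapses whitespace.

```python
import re

# Toy symbol table for illustration only; the real one lives in text/symbols.py.
symbols = ["_", " ", "a", "b"]
_symbol_to_id = {s: i for i, s in enumerate(symbols)}


def luxembourgish_like_cleaner(text):
    """Lowercase and collapse whitespace (hypothetical stand-in cleaner)."""
    return re.sub(r"\s+", " ", text.lower())


def text_to_sequence(text):
    cleaned = luxembourgish_like_cleaner(text)
    # Characters missing from the table are skipped without any warning.
    return [_symbol_to_id[ch] for ch in cleaned if ch in _symbol_to_id]


ids = text_to_sequence("Ab  x a")
```

Here `x` is not in the toy table, so it is dropped from the sequence. With real input, a character absent from the symbol list disappears from the model input in the same way, which is worth checking when adding new training text.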
text/cleaners.py ADDED
@@ -0,0 +1,129 @@
1
+ """ from https://github.com/keithito/tacotron """
2
+
3
+ """
4
+ Cleaners are transformations that run over the input text at both training and eval time.
5
+
6
+ Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
7
+ hyperparameter. Some cleaners are English-specific. You'll typically want to use:
8
+ 1. "english_cleaners" for English text
9
+ 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
10
+ the Unidecode library (https://pypi.python.org/pypi/Unidecode)
11
+ 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
12
+ the symbols in symbols.py to match your data).
13
+ """
14
+
15
+ import re
16
+ from unidecode import unidecode
17
+ from phonemizer import phonemize
18
+ from phonemizer.backend import EspeakBackend
19
+ backend = EspeakBackend("en-us", preserve_punctuation=True, with_stress=True)
20
+
21
+
22
+ # Regular expression matching whitespace:
23
+ _whitespace_re = re.compile(r"\s+")
24
+
25
+ # List of (regular expression, replacement) pairs for abbreviations:
26
+ _abbreviations = [
27
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
28
+ for x in [
29
+ ("mrs", "misess"),
30
+ ("mr", "mister"),
31
+ ("dr", "doctor"),
32
+ ("st", "saint"),
33
+ ("co", "company"),
34
+ ("jr", "junior"),
35
+ ("maj", "major"),
36
+ ("gen", "general"),
37
+ ("drs", "doctors"),
38
+ ("rev", "reverend"),
39
+ ("lt", "lieutenant"),
40
+ ("hon", "honorable"),
41
+ ("sgt", "sergeant"),
42
+ ("capt", "captain"),
43
+ ("esq", "esquire"),
44
+ ("ltd", "limited"),
45
+ ("col", "colonel"),
46
+ ("ft", "fort"),
47
+ ]
48
+ ]
49
+
50
+
51
+ def expand_abbreviations(text):
52
+ for regex, replacement in _abbreviations:
53
+ text = re.sub(regex, replacement, text)
54
+ return text
55
+
56
+
57
+ def expand_numbers(text):
58
+ return normalize_numbers(text)  # normalize_numbers is neither imported nor defined in this module
59
+
60
+
61
+ def lowercase(text):
62
+ return text.lower()
63
+
64
+
65
+ def collapse_whitespace(text):
66
+ return re.sub(_whitespace_re, " ", text)
67
+
68
+
69
+ def convert_to_ascii(text):
70
+ return unidecode(text)
71
+
72
+
73
+ def basic_cleaners(text):
74
+ """Basic pipeline that lowercases and collapses whitespace without transliteration."""
75
+ text = lowercase(text)
76
+ text = collapse_whitespace(text)
77
+ return text
78
+
79
+
80
+ def transliteration_cleaners(text):
81
+ """Pipeline for non-English text that transliterates to ASCII."""
82
+ text = convert_to_ascii(text)
83
+ text = lowercase(text)
84
+ text = collapse_whitespace(text)
85
+ return text
86
+
87
+
88
+ def english_cleaners(text):
89
+ """Pipeline for English text, including abbreviation expansion."""
90
+ text = convert_to_ascii(text)
91
+ text = lowercase(text)
92
+ text = expand_abbreviations(text)
93
+ phonemes = phonemize(text, language="en-us", backend="espeak", strip=True)
94
+ phonemes = collapse_whitespace(phonemes)
95
+ return phonemes
96
+
97
+
98
+ def english_cleaners2(text):
99
+ """Pipeline for English text, including abbreviation expansion, punctuation, and stress."""
100
+ text = convert_to_ascii(text)
101
+ text = lowercase(text)
102
+ text = expand_abbreviations(text)
103
+ phonemes = phonemize(
104
+ text,
105
+ language="en-us",
106
+ backend="espeak",
107
+ strip=True,
108
+ preserve_punctuation=True,
109
+ with_stress=True,
110
+ )
111
+ phonemes = collapse_whitespace(phonemes)
112
+ return phonemes
113
+
114
+
115
+ def english_cleaners3(text):
116
+ """Pipeline for English text, including abbreviation expansion, punctuation, and stress."""
117
+ text = convert_to_ascii(text)
118
+ text = lowercase(text)
119
+ text = expand_abbreviations(text)
120
+ phonemes = backend.phonemize([text], strip=True)[0]
121
+ phonemes = collapse_whitespace(phonemes)
122
+ return phonemes
123
+
124
+
125
+ def luxembourgish_cleaners(text):
126
+ """Pipeline for Luxembourgish text. Grapheme-based, no phonemization."""
127
+ text = lowercase(text)
128
+ text = collapse_whitespace(text)
129
+ return text
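The English pipelines above expand abbreviations with a list of compiled `(regex, replacement)` pairs. A minimal sketch of that mechanism, using only three of the pairs and an invented example sentence, might look like this. Note that `luxembourgish_cleaners` does not call this step.

```python
import re

# Subset of the abbreviation table, for illustration only.
_abbreviations = [
    (re.compile(r"\b%s\." % abbr, re.IGNORECASE), full)
    for abbr, full in [("mr", "mister"), ("dr", "doctor"), ("st", "saint")]
]


def expand_abbreviations(text):
    # Each pattern requires a word boundary before and a literal dot after.
    for regex, replacement in _abbreviations:
        text = regex.sub(replacement, text)
    return text


expanded = expand_abbreviations("Dr. Weber lives on St. Paul street")
```

Because every pattern ends in an escaped dot, ordinary words such as "street" are left alone and only the dotted forms are replaced.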
text/symbols.py ADDED
@@ -0,0 +1,16 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Luxembourgish grapheme symbols for VITS2 TTS.
4
+ Based on the character set from the Kerstin dataset.
5
+ """
6
+
7
+ _pad = "_"
8
+ # Punctuation including special quotes used in Luxembourgish text
9
+ _punctuation = "!\"'()+,-./:;?[] " + "«»‒–'" + "'" + "„"
10
+ _letters = "abcdefghijklmnopqrstuvwxyzäëéèêïöüàâîôûçáãßćğœš"
11
+
12
+ # Export all symbols:
13
+ symbols = [_pad] + list(_punctuation) + list(_letters)
14
+
15
+ # Special symbol ids
16
+ SPACE_ID = symbols.index(" ")
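`symbols` is built by concatenating character strings, and `text/__init__.py` turns it into a dict keyed by character. If a character occurs more than once in the list, the later index overwrites the earlier one and the earlier id is never emitted. As rendered here the punctuation string appears to repeat the apostrophe, though that may only reflect how typographic quotes were displayed. The sketch below shows one way to check for this; `duplicate_symbols` and the toy table are hypothetical and not part of the repository.

```python
from collections import Counter


def duplicate_symbols(symbols):
    """Return symbols that occur more than once, in first-seen order."""
    counts = Counter(symbols)
    return [s for s in dict.fromkeys(symbols) if counts[s] > 1]


# Toy table with a repeated apostrophe; not the real symbol list.
toy_symbols = ["_"] + list("!' ") + list("'ab")
dupes = duplicate_symbols(toy_symbols)

# Same construction as _symbol_to_id: the last occurrence wins.
symbol_to_id = {s: i for i, s in enumerate(toy_symbols)}
```

Running such a check against the actual `symbols` list would confirm whether every id is reachable.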
transforms.py ADDED
@@ -0,0 +1,210 @@
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
+ def piecewise_rational_quadratic_transform(
13
+ inputs,
14
+ unnormalized_widths,
15
+ unnormalized_heights,
16
+ unnormalized_derivatives,
17
+ inverse=False,
18
+ tails=None,
19
+ tail_bound=1.0,
20
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
21
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
22
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
23
+ ):
24
+ if tails is None:
25
+ spline_fn = rational_quadratic_spline
26
+ spline_kwargs = {}
27
+ else:
28
+ spline_fn = unconstrained_rational_quadratic_spline
29
+ spline_kwargs = {"tails": tails, "tail_bound": tail_bound}
30
+
31
+ outputs, logabsdet = spline_fn(
32
+ inputs=inputs,
33
+ unnormalized_widths=unnormalized_widths,
34
+ unnormalized_heights=unnormalized_heights,
35
+ unnormalized_derivatives=unnormalized_derivatives,
36
+ inverse=inverse,
37
+ min_bin_width=min_bin_width,
38
+ min_bin_height=min_bin_height,
39
+ min_derivative=min_derivative,
40
+ **spline_kwargs
41
+ )
42
+ return outputs, logabsdet
43
+
44
+
45
+ def searchsorted(bin_locations, inputs, eps=1e-6):
46
+ bin_locations[..., -1] += eps
47
+ return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1
48
+
49
+
50
+ def unconstrained_rational_quadratic_spline(
51
+ inputs,
52
+ unnormalized_widths,
53
+ unnormalized_heights,
54
+ unnormalized_derivatives,
55
+ inverse=False,
56
+ tails="linear",
57
+ tail_bound=1.0,
58
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
59
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
60
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
61
+ ):
62
+ inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
63
+ outside_interval_mask = ~inside_interval_mask
64
+
65
+ outputs = torch.zeros_like(inputs)
66
+ logabsdet = torch.zeros_like(inputs)
67
+
68
+ if tails == "linear":
69
+ unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
70
+ constant = np.log(np.exp(1 - min_derivative) - 1)
71
+ unnormalized_derivatives[..., 0] = constant
72
+ unnormalized_derivatives[..., -1] = constant
73
+
74
+ outputs[outside_interval_mask] = inputs[outside_interval_mask]
75
+ logabsdet[outside_interval_mask] = 0
76
+ else:
77
+ raise RuntimeError("{} tails are not implemented.".format(tails))
78
+
79
+ (
80
+ outputs[inside_interval_mask],
81
+ logabsdet[inside_interval_mask],
82
+ ) = rational_quadratic_spline(
83
+ inputs=inputs[inside_interval_mask],
84
+ unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
85
+ unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
86
+ unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
87
+ inverse=inverse,
88
+ left=-tail_bound,
89
+ right=tail_bound,
90
+ bottom=-tail_bound,
91
+ top=tail_bound,
92
+ min_bin_width=min_bin_width,
93
+ min_bin_height=min_bin_height,
94
+ min_derivative=min_derivative,
95
+ )
96
+
97
+ return outputs, logabsdet
98
+
99
+
100
+ def rational_quadratic_spline(
101
+ inputs,
102
+ unnormalized_widths,
103
+ unnormalized_heights,
104
+ unnormalized_derivatives,
105
+ inverse=False,
106
+ left=0.0,
107
+ right=1.0,
108
+ bottom=0.0,
109
+ top=1.0,
110
+ min_bin_width=DEFAULT_MIN_BIN_WIDTH,
111
+ min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
112
+ min_derivative=DEFAULT_MIN_DERIVATIVE,
113
+ ):
114
+ if torch.min(inputs) < left or torch.max(inputs) > right:
115
+ raise ValueError("Input to a transform is not within its domain")
116
+
117
+ num_bins = unnormalized_widths.shape[-1]
118
+
119
+ if min_bin_width * num_bins > 1.0:
120
+ raise ValueError("Minimal bin width too large for the number of bins")
121
+ if min_bin_height * num_bins > 1.0:
122
+ raise ValueError("Minimal bin height too large for the number of bins")
123
+
124
+ widths = F.softmax(unnormalized_widths, dim=-1)
125
+ widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
126
+ cumwidths = torch.cumsum(widths, dim=-1)
127
+ cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
128
+ cumwidths = (right - left) * cumwidths + left
129
+ cumwidths[..., 0] = left
130
+ cumwidths[..., -1] = right
131
+ widths = cumwidths[..., 1:] - cumwidths[..., :-1]
132
+
133
+ derivatives = min_derivative + F.softplus(unnormalized_derivatives)
134
+
135
+ heights = F.softmax(unnormalized_heights, dim=-1)
136
+ heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
137
+ cumheights = torch.cumsum(heights, dim=-1)
138
+ cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
139
+ cumheights = (top - bottom) * cumheights + bottom
140
+ cumheights[..., 0] = bottom
141
+ cumheights[..., -1] = top
142
+ heights = cumheights[..., 1:] - cumheights[..., :-1]
143
+
144
+ if inverse:
145
+ bin_idx = searchsorted(cumheights, inputs)[..., None]
146
+ else:
147
+ bin_idx = searchsorted(cumwidths, inputs)[..., None]
148
+
149
+ input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
150
+ input_bin_widths = widths.gather(-1, bin_idx)[..., 0]
151
+
152
+ input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
153
+ delta = heights / widths
154
+ input_delta = delta.gather(-1, bin_idx)[..., 0]
155
+
156
+ input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
157
+ input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]
158
+
159
+ input_heights = heights.gather(-1, bin_idx)[..., 0]
160
+
161
+ if inverse:
162
+ a = (inputs - input_cumheights) * (
163
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
164
+ ) + input_heights * (input_delta - input_derivatives)
165
+ b = input_heights * input_derivatives - (inputs - input_cumheights) * (
166
+ input_derivatives + input_derivatives_plus_one - 2 * input_delta
167
+ )
168
+ c = -input_delta * (inputs - input_cumheights)
169
+
170
+ discriminant = b.pow(2) - 4 * a * c
171
+ # Clamp to prevent numerical instability causing negative discriminant
172
+ discriminant = torch.clamp(discriminant, min=1e-7)
173
+
174
+ root = (2 * c) / (-b - torch.sqrt(discriminant))
175
+ outputs = root * input_bin_widths + input_cumwidths
176
+
177
+ theta_one_minus_theta = root * (1 - root)
178
+ denominator = input_delta + (
179
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
180
+ * theta_one_minus_theta
181
+ )
182
+ derivative_numerator = input_delta.pow(2) * (
183
+ input_derivatives_plus_one * root.pow(2)
184
+ + 2 * input_delta * theta_one_minus_theta
185
+ + input_derivatives * (1 - root).pow(2)
186
+ )
187
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
188
+
189
+ return outputs, -logabsdet
190
+ else:
191
+ theta = (inputs - input_cumwidths) / input_bin_widths
192
+ theta_one_minus_theta = theta * (1 - theta)
193
+
194
+ numerator = input_heights * (
195
+ input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
196
+ )
197
+ denominator = input_delta + (
198
+ (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
199
+ * theta_one_minus_theta
200
+ )
201
+ outputs = input_cumheights + numerator / denominator
202
+
203
+ derivative_numerator = input_delta.pow(2) * (
204
+ input_derivatives_plus_one * theta.pow(2)
205
+ + 2 * input_delta * theta_one_minus_theta
206
+ + input_derivatives * (1 - theta).pow(2)
207
+ )
208
+ logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)
209
+
210
+ return outputs, logabsdet
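The forward and inverse branches of `rational_quadratic_spline` are easier to follow on a single bin with scalar inputs. The sketch below repeats the same formulas with plain floats. The function names and the bin parameters are made up for illustration, and bin lookup, tails, and the log-determinant are left out.

```python
import math


def rq_forward(x, x0, w, y0, h, d0, d1):
    """Map x in [x0, x0 + w] to [y0, y0 + h] inside a single spline bin."""
    s = h / w  # bin slope, called input_delta in the source
    theta = (x - x0) / w
    t1mt = theta * (1 - theta)
    numerator = h * (s * theta ** 2 + d0 * t1mt)
    denominator = s + (d0 + d1 - 2 * s) * t1mt
    return y0 + numerator / denominator


def rq_inverse(y, x0, w, y0, h, d0, d1):
    """Invert rq_forward by solving the quadratic from the inverse branch."""
    s = h / w
    a = (y - y0) * (d0 + d1 - 2 * s) + h * (s - d0)
    b = h * d0 - (y - y0) * (d0 + d1 - 2 * s)
    c = -s * (y - y0)
    root = (2 * c) / (-b - math.sqrt(b * b - 4 * a * c))
    return root * w + x0


# Hypothetical bin: unit width and height, edge derivatives 2.0 and 0.5.
bin_params = dict(x0=0.0, w=1.0, y0=0.0, h=1.0, d0=2.0, d1=0.5)
y = rq_forward(0.3, **bin_params)
x_back = rq_inverse(y, **bin_params)
```

The round trip should return the original input up to floating-point error, and the bin endpoints map to the bin's bottom and top, which is what keeps the spline continuous and monotonic across bins.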
utils.py ADDED
@@ -0,0 +1,304 @@
1
+ import os
2
+ import glob
3
+ import sys
4
+ import argparse
5
+ import logging
6
+ import json
7
+ import subprocess
8
+ import numpy as np
9
+ from scipy.io.wavfile import read
10
+ import torch
11
+
12
+ MATPLOTLIB_FLAG = False
13
+
14
+ logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
15
+ logger = logging
16
+
17
+
18
+ def load_checkpoint(checkpoint_path, model, optimizer=None):
19
+ assert os.path.isfile(checkpoint_path)
20
+ checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
21
+ iteration = checkpoint_dict["iteration"]
22
+ learning_rate = checkpoint_dict["learning_rate"]
23
+ if optimizer is not None:
24
+ optimizer.load_state_dict(checkpoint_dict["optimizer"])
25
+ saved_state_dict = checkpoint_dict["model"]
26
+ if hasattr(model, "module"):
27
+ state_dict = model.module.state_dict()
28
+ else:
29
+ state_dict = model.state_dict()
30
+ new_state_dict = {}
31
+ for k, v in state_dict.items():
32
+ try:
33
+ new_state_dict[k] = saved_state_dict[k]
34
+ except KeyError:
35
+ logger.info("%s is not in the checkpoint" % k)
36
+ new_state_dict[k] = v
37
+ if hasattr(model, "module"):
38
+ model.module.load_state_dict(new_state_dict)
39
+ else:
40
+ model.load_state_dict(new_state_dict)
41
+ logger.info(
42
+ "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
43
+ )
44
+ return model, optimizer, learning_rate, iteration
45
+
46
+
47
+ def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
48
+ logger.info(
49
+ "Saving model and optimizer state at iteration {} to {}".format(
50
+ iteration, checkpoint_path
51
+ )
52
+ )
53
+ if hasattr(model, "module"):
54
+ state_dict = model.module.state_dict()
55
+ else:
56
+ state_dict = model.state_dict()
57
+ torch.save(
58
+ {
59
+ "model": state_dict,
60
+ "iteration": iteration,
61
+ "optimizer": optimizer.state_dict(),
62
+ "learning_rate": learning_rate,
63
+ },
64
+ checkpoint_path,
65
+ )
66
+
67
+
68
+ def summarize(
69
+ writer,
70
+ global_step,
71
+ scalars={},
72
+ histograms={},
73
+ images={},
74
+ audios={},
75
+ audio_sampling_rate=22050,
76
+ ):
77
+ for k, v in scalars.items():
78
+ writer.add_scalar(k, v, global_step)
79
+ for k, v in histograms.items():
80
+ writer.add_histogram(k, v, global_step)
81
+ for k, v in images.items():
82
+ writer.add_image(k, v, global_step, dataformats="HWC")
83
+ for k, v in audios.items():
84
+ writer.add_audio(k, v, global_step, audio_sampling_rate)
85
+
86
+
87
+ def scan_checkpoint(dir_path, regex):
88
+ f_list = glob.glob(os.path.join(dir_path, regex))
89
+ f_list.sort(key=lambda f: int("".join(filter(str.isdigit, f))))
90
+ if len(f_list) == 0:
91
+ return None
92
+ return f_list
93
+
94
+
95
+ def latest_checkpoint_path(dir_path, regex="G_*.pth"):
96
+ f_list = scan_checkpoint(dir_path, regex)
97
+ if not f_list:
98
+ return None
99
+ x = f_list[-1]
100
+ print(x)
101
+ return x
102
+
103
+
104
+ def remove_old_checkpoints(cp_dir, prefixes=['G_*.pth', 'D_*.pth', 'DUR_*.pth']):
105
+ for prefix in prefixes:
106
+ sorted_ckpts = scan_checkpoint(cp_dir, prefix)
107
+ if sorted_ckpts and len(sorted_ckpts) > 3:
108
+ for ckpt_path in sorted_ckpts[:-3]:
109
+ os.remove(ckpt_path)
110
+ print("removed {}".format(ckpt_path))
111
+
112
+
113
+ def plot_spectrogram_to_numpy(spectrogram):
114
+ global MATPLOTLIB_FLAG
115
+ if not MATPLOTLIB_FLAG:
116
+ import matplotlib
117
+
118
+ matplotlib.use("Agg")
119
+ MATPLOTLIB_FLAG = True
120
+ mpl_logger = logging.getLogger("matplotlib")
121
+ mpl_logger.setLevel(logging.WARNING)
122
+ import matplotlib.pylab as plt
123
+ import numpy as np
124
+
125
+ fig, ax = plt.subplots(figsize=(10, 2))
126
+ im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
127
+ plt.colorbar(im, ax=ax)
128
+ plt.xlabel("Frames")
129
+ plt.ylabel("Channels")
130
+ plt.tight_layout()
131
+
132
+ fig.canvas.draw()
133
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
134
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
135
+ plt.close()
136
+ return data
137
+
138
+
139
+ def plot_alignment_to_numpy(alignment, info=None):
140
+ global MATPLOTLIB_FLAG
141
+ if not MATPLOTLIB_FLAG:
142
+ import matplotlib
143
+
144
+ matplotlib.use("Agg")
145
+ MATPLOTLIB_FLAG = True
146
+ mpl_logger = logging.getLogger("matplotlib")
147
+ mpl_logger.setLevel(logging.WARNING)
148
+ import matplotlib.pylab as plt
149
+ import numpy as np
150
+
151
+ fig, ax = plt.subplots(figsize=(6, 4))
152
+ im = ax.imshow(
153
+ alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
154
+ )
155
+ fig.colorbar(im, ax=ax)
156
+ xlabel = "Decoder timestep"
157
+ if info is not None:
158
+ xlabel += "\n\n" + info
159
+ plt.xlabel(xlabel)
160
+ plt.ylabel("Encoder timestep")
161
+ plt.tight_layout()
162
+
163
+ fig.canvas.draw()
164
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
165
+ data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
166
+ plt.close()
167
+ return data
168
+
169
+
170
+ def load_wav_to_torch(full_path):
171
+ sampling_rate, data = read(full_path)
172
+ return torch.FloatTensor(data.astype(np.float32)), sampling_rate
173
+
174
+
175
+ def load_filepaths_and_text(filename, split="|"):
176
+ with open(filename, encoding="utf-8") as f:
177
+ filepaths_and_text = [line.strip().split(split) for line in f]
178
+ return filepaths_and_text
179
+
180
+
181
+ def get_hparams(init=True):
182
+ parser = argparse.ArgumentParser()
183
+ parser.add_argument(
184
+ "-c",
185
+ "--config",
186
+ type=str,
187
+ default="./configs/base.json",
188
+ help="JSON file for configuration",
189
+ )
190
+ parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
191
+
192
+ args = parser.parse_args()
193
+ model_dir = os.path.join("./logs", args.model)
194
+
195
+ if not os.path.exists(model_dir):
196
+ os.makedirs(model_dir)
197
+
198
+ config_path = args.config
199
+ config_save_path = os.path.join(model_dir, "config.json")
200
+ if init:
201
+ with open(config_path, "r") as f:
202
+ data = f.read()
203
+ with open(config_save_path, "w") as f:
204
+ f.write(data)
205
+ else:
206
+ with open(config_save_path, "r") as f:
207
+ data = f.read()
208
+ config = json.loads(data)
209
+
210
+ hparams = HParams(**config)
211
+ hparams.model_dir = model_dir
212
+ return hparams
213
+
214
+
215
+ def get_hparams_from_dir(model_dir):
216
+ config_save_path = os.path.join(model_dir, "config.json")
217
+ with open(config_save_path, "r") as f:
218
+ data = f.read()
219
+ config = json.loads(data)
220
+
221
+ hparams = HParams(**config)
222
+ hparams.model_dir = model_dir
223
+ return hparams
224
+
225
+
226
+ def get_hparams_from_file(config_path):
227
+ with open(config_path, "r") as f:
228
+ data = f.read()
229
+ config = json.loads(data)
230
+
231
+ hparams = HParams(**config)
232
+ return hparams
233
+
234
+
235
+ def check_git_hash(model_dir):
236
+ source_dir = os.path.dirname(os.path.realpath(__file__))
237
+ if not os.path.exists(os.path.join(source_dir, ".git")):
238
+ logger.warning(
239
+ "{} is not a git repository, therefore hash value comparison will be ignored.".format(
240
+ source_dir
241
+ )
242
+ )
243
+ return
244
+
245
+ cur_hash = subprocess.getoutput("git rev-parse HEAD")
246
+
247
+ path = os.path.join(model_dir, "githash")
248
+ if os.path.exists(path):
249
+ saved_hash = open(path).read()
250
+ if saved_hash != cur_hash:
251
+ logger.warning(
252
+ "git hash values are different. {}(saved) != {}(current)".format(
253
+ saved_hash[:8], cur_hash[:8]
254
+ )
255
+ )
256
+ else:
257
+ open(path, "w").write(cur_hash)
258
+
259
+
260
+ def get_logger(model_dir, filename="train.log"):
261
+ global logger
262
+ logger = logging.getLogger(os.path.basename(model_dir))
263
+ logger.setLevel(logging.DEBUG)
264
+
265
+ formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
266
+ if not os.path.exists(model_dir):
267
+ os.makedirs(model_dir)
268
+ h = logging.FileHandler(os.path.join(model_dir, filename))
269
+ h.setLevel(logging.DEBUG)
270
+ h.setFormatter(formatter)
271
+ logger.addHandler(h)
272
+ return logger
273
+
274
+
275
+ class HParams:
276
+ def __init__(self, **kwargs):
277
+ for k, v in kwargs.items():
278
+ if type(v) == dict:
279
+ v = HParams(**v)
280
+ self[k] = v
281
+
282
+ def keys(self):
283
+ return self.__dict__.keys()
284
+
285
+ def items(self):
286
+ return self.__dict__.items()
287
+
288
+ def values(self):
289
+ return self.__dict__.values()
290
+
291
+ def __len__(self):
292
+ return len(self.__dict__)
293
+
294
+ def __getitem__(self, key):
295
+ return getattr(self, key)
296
+
297
+ def __setitem__(self, key, value):
298
+ return setattr(self, key, value)
299
+
300
+ def __contains__(self, key):
301
+ return key in self.__dict__
302
+
303
+ def __repr__(self):
304
+ return self.__dict__.__repr__()
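`HParams` turns a nested config dict into an object that supports both attribute and item access. A condensed, self-contained sketch of that behaviour follows. The config keys and the learning rate are hypothetical, while the 24000 Hz sampling rate is taken from the README.

```python
class HParams:
    """Condensed copy of the container above, for illustration."""

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, dict):
                v = HParams(**v)  # nested dicts become nested HParams
            setattr(self, k, v)

    def __getitem__(self, key):
        return getattr(self, key)

    def __contains__(self, key):
        return key in self.__dict__


# Hypothetical config values, shaped like a JSON config file.
config = {"train": {"learning_rate": 2e-4}, "data": {"sampling_rate": 24000}}
hps = HParams(**config)
```

Both `hps.data.sampling_rate` and `hps["data"].sampling_rate` then resolve to the same value, which is how the helper functions above expose a loaded `config.json`.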
vits2-claude.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d63e6074bc3ddf6e3b41be1f8185b841c9fd35b6e6262de834726e3492e3332e
3
+ size 444644814