valtecAI-team committed on
Commit
f8b4a6c
·
verified ·
1 Parent(s): 162c88f

Upload folder using huggingface_hub

Browse files
Files changed (36) hide show
  1. src/__init__.py +5 -0
  2. src/__pycache__/__init__.cpython-310.pyc +0 -0
  3. src/alignment/__init__.py +64 -0
  4. src/alignment/__pycache__/__init__.cpython-310.pyc +0 -0
  5. src/alignment/monotonic_align.py +46 -0
  6. src/models/__init__.py +5 -0
  7. src/models/__pycache__/__init__.cpython-310.pyc +0 -0
  8. src/models/__pycache__/synthesizer.cpython-310.pyc +0 -0
  9. src/models/synthesizer.py +1030 -0
  10. src/nn/__init__.py +8 -0
  11. src/nn/__pycache__/__init__.cpython-310.pyc +0 -0
  12. src/nn/__pycache__/attentions.cpython-310.pyc +0 -0
  13. src/nn/__pycache__/commons.cpython-310.pyc +0 -0
  14. src/nn/__pycache__/modules.cpython-310.pyc +0 -0
  15. src/nn/__pycache__/transforms.cpython-310.pyc +0 -0
  16. src/nn/attentions.py +459 -0
  17. src/nn/commons.py +160 -0
  18. src/nn/modules.py +598 -0
  19. src/nn/transforms.py +209 -0
  20. src/text/__init__.py +24 -0
  21. src/text/__pycache__/__init__.cpython-310.pyc +0 -0
  22. src/text/__pycache__/cleaner.cpython-310.pyc +0 -0
  23. src/text/__pycache__/symbols.cpython-310.pyc +0 -0
  24. src/text/cleaner.py +44 -0
  25. src/text/symbols.py +373 -0
  26. src/text/vietnamese.py +429 -0
  27. src/utils/__init__.py +5 -0
  28. src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  29. src/utils/__pycache__/helpers.cpython-310.pyc +0 -0
  30. src/utils/helpers.py +452 -0
  31. src/vietnamese/__init__.py +6 -0
  32. src/vietnamese/__pycache__/__init__.cpython-310.pyc +0 -0
  33. src/vietnamese/__pycache__/phonemizer.cpython-310.pyc +0 -0
  34. src/vietnamese/__pycache__/text_processor.cpython-310.pyc +0 -0
  35. src/vietnamese/phonemizer.py +484 -0
  36. src/vietnamese/text_processor.py +428 -0
src/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """
2
+ valtec-tts source package
3
+ """
4
+
5
+ __version__ = "1.0.0"
src/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (211 Bytes). View file
 
src/alignment/__init__.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Monotonic alignment package
3
+ """
4
+
5
+ import numba
6
+ from numpy import zeros, int32, float32
7
+ from torch import from_numpy
8
+
9
+
10
@numba.jit(
    numba.void(
        numba.int32[:, :, ::1],
        numba.float32[:, :, ::1],
        numba.int32[::1],
        numba.int32[::1],
    ),
    nopython=True,
    nogil=True,
)
def maximum_path_jit(paths, values, t_ys, t_xs):
    """Batched monotonic alignment search (in-place, numba-compiled).

    For every batch element, a forward dynamic-programming pass accumulates
    the best monotonic-path score into ``values``, then a backward pass
    traces the argmax path and writes it into ``paths`` as 0/1.

    paths:  int32 [b, t_y_max, t_x_max] output buffer (expected zeroed).
    values: float32 [b, t_y_max, t_x_max] similarity scores; CLOBBERED by
            the forward pass.
    t_ys, t_xs: int32 [b] valid lengths of each grid along y and x.
    """
    b = paths.shape[0]
    # Acts as -inf for transitions that would break monotonicity.
    max_neg_val = -1e9
    for i in range(int(b)):
        path = paths[i]
        value = values[i]
        t_y = t_ys[i]
        t_x = t_xs[i]

        v_prev = v_cur = 0.0
        index = t_x - 1

        # Forward pass: value[y, x] becomes the best cumulative score of any
        # monotonic path ending at (y, x).
        for y in range(t_y):
            # Only x in this window can still reach (t_y-1, t_x-1) monotonically.
            for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
                if x == y:
                    # Coming straight down from (y-1, x) is impossible on the
                    # leading diagonal.
                    v_cur = max_neg_val
                else:
                    v_cur = value[y - 1, x]
                if x == 0:
                    if y == 0:
                        v_prev = 0.0
                    else:
                        v_prev = max_neg_val
                else:
                    v_prev = value[y - 1, x - 1]
                value[y, x] += max(v_prev, v_cur)

        # Backward pass: trace the argmax path from the bottom-right corner,
        # stepping left whenever the diagonal predecessor scored higher.
        for y in range(t_y - 1, -1, -1):
            path[y, index] = 1
            if index != 0 and (
                index == y or value[y - 1, index] < value[y - 1, index - 1]
            ):
                index = index - 1
53
+
54
+
55
def maximum_path(neg_cent, mask):
    """Find the best monotonic alignment path for a batch of score grids.

    Copies the inputs to CPU numpy buffers, runs the numba-compiled DP
    (`maximum_path_jit`), and returns the 0/1 path tensor on the caller's
    original device and dtype.

    neg_cent: [b, t_t, t_s] similarity scores (torch tensor).
    mask:     [b, t_t, t_s] validity mask; its row/column sums give the
              per-element valid lengths.
    """
    out_device, out_dtype = neg_cent.device, neg_cent.dtype
    scores = neg_cent.data.cpu().numpy().astype(float32)
    path = zeros(scores.shape, dtype=int32)

    # Per-batch valid lengths along each axis, recovered from the mask.
    target_lens = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
    source_lens = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)
    maximum_path_jit(path, scores, target_lens, source_lens)
    return from_numpy(path).to(device=out_device, dtype=out_dtype)
src/alignment/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.57 kB). View file
 
src/alignment/monotonic_align.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numba
2
+
3
+
4
# NOTE(review): this is a byte-for-byte duplicate of maximum_path_jit in
# src/alignment/__init__.py — consider importing it from there instead of
# maintaining two copies.
@numba.jit(
    numba.void(
        numba.int32[:, :, ::1],
        numba.float32[:, :, ::1],
        numba.int32[::1],
        numba.int32[::1],
    ),
    nopython=True,
    nogil=True,
)
def maximum_path_jit(paths, values, t_ys, t_xs):
    """Batched monotonic alignment search (in-place, numba-compiled).

    Forward DP accumulates best monotonic-path scores into ``values``
    (clobbering it), then a backward pass writes the argmax path into
    ``paths`` as 0/1. ``t_ys``/``t_xs`` hold per-batch valid lengths.
    """
    b = paths.shape[0]
    # Acts as -inf for disallowed transitions.
    max_neg_val = -1e9
    for i in range(int(b)):
        path = paths[i]
        value = values[i]
        t_y = t_ys[i]
        t_x = t_xs[i]

        v_prev = v_cur = 0.0
        index = t_x - 1

        # Forward pass: value[y, x] = best cumulative score ending at (y, x).
        for y in range(t_y):
            for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
                if x == y:
                    v_cur = max_neg_val
                else:
                    v_cur = value[y - 1, x]
                if x == 0:
                    if y == 0:
                        v_prev = 0.0
                    else:
                        v_prev = max_neg_val
                else:
                    v_prev = value[y - 1, x - 1]
                value[y, x] += max(v_prev, v_cur)

        # Backward pass: trace the argmax path from the bottom-right corner.
        for y in range(t_y - 1, -1, -1):
            path[y, index] = 1
            if index != 0 and (
                index == y or value[y - 1, index] < value[y - 1, index - 1]
            ):
                index = index - 1
src/models/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """
2
+ TTS Models package
3
+ """
4
+
5
+ from .synthesizer import SynthesizerTrn, Generator, MultiPeriodDiscriminator
src/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (295 Bytes). View file
 
src/models/__pycache__/synthesizer.cpython-310.pyc ADDED
Binary file (21.6 kB). View file
 
src/models/synthesizer.py ADDED
@@ -0,0 +1,1030 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from src.nn import commons
7
+ from src.nn import modules
8
+ from src.nn import attentions
9
+
10
+ from torch.nn import Conv1d, ConvTranspose1d, Conv2d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
12
+
13
+ from src.nn.commons import init_weights, get_padding
14
+ from src import alignment as monotonic_align
15
+
16
+
17
class DurationDiscriminator(nn.Module):  # vits2
    """Duration discriminator (VITS2).

    Scores (encoder features, duration) pairs so the duration predictor can
    be trained adversarially. ``forward`` returns one probability map for
    the real durations and one for the predicted durations.
    """

    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        self.drop = nn.Dropout(p_dropout)
        # Feature trunk: two same-padded convs with layer norm.
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        # Lifts the scalar duration track to filter_channels.
        self.dur_proj = nn.Conv1d(1, filter_channels, 1)

        # Head applied to the concatenated (features, duration) stack.
        self.pre_out_conv_1 = nn.Conv1d(
            2 * filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.pre_out_norm_1 = modules.LayerNorm(filter_channels)
        self.pre_out_conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.pre_out_norm_2 = modules.LayerNorm(filter_channels)

        if gin_channels != 0:
            # Optional global conditioning (presumably a speaker embedding —
            # TODO confirm against callers) projected onto the input features.
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

        self.output_layer = nn.Sequential(nn.Linear(filter_channels, 1), nn.Sigmoid())

    def forward_probability(self, x, x_mask, dur, g=None):
        """Probability map for one duration track.

        x: trunk features [b, filter_channels, t]; dur: durations [b, 1, t].
        Returns sigmoid probabilities of shape [b, t, 1]. ``g`` is accepted
        but unused here.
        """
        dur = self.dur_proj(dur)
        x = torch.cat([x, dur], dim=1)
        x = self.pre_out_conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.pre_out_norm_1(x)
        x = self.drop(x)
        x = self.pre_out_conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.pre_out_norm_2(x)
        x = self.drop(x)
        x = x * x_mask
        x = x.transpose(1, 2)
        output_prob = self.output_layer(x)
        return output_prob

    def forward(self, x, x_mask, dur_r, dur_hat, g=None):
        """Score real (``dur_r``) and predicted (``dur_hat``) durations.

        Inputs are detached so discriminator gradients do not flow back into
        the text encoder. Returns ``[prob_real, prob_fake]``.
        """
        x = torch.detach(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.conv_1(x * x_mask)
        x = torch.relu(x)
        x = self.norm_1(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        x = torch.relu(x)
        x = self.norm_2(x)
        x = self.drop(x)

        output_probs = []
        for dur in [dur_r, dur_hat]:
            output_prob = self.forward_probability(x, x_mask, dur, g)
            output_probs.append(output_prob)

        return output_probs
89
+
90
+
91
class TransformerCouplingBlock(nn.Module):
    """Normalizing flow built from transformer coupling layers (VITS2).

    Stacks ``n_flows`` pairs of (TransformerCouplingLayer, Flip). When
    ``share_parameter`` is True, a single FFT attention backbone is built
    here and shared by every coupling layer.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
        share_parameter=False,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()

        # Optional shared attention backbone; None lets each coupling layer
        # build its own.
        self.wn = (
            attentions.FFT(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=self.gin_channels,
            )
            if share_parameter
            else None
        )

        for i in range(n_flows):
            self.flows.append(
                modules.TransformerCouplingLayer(
                    channels,
                    hidden_channels,
                    kernel_size,
                    n_layers,
                    n_heads,
                    p_dropout,
                    filter_channels,
                    mean_only=True,
                    wn_sharing_parameter=self.wn,
                    gin_channels=self.gin_channels,
                )
            )
            # Flip swaps channel halves so the next layer transforms the
            # other half.
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the flow; ``reverse=True`` inverts it (inference path).

        In the forward direction each flow also returns a log-determinant,
        which is discarded here.
        """
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        return x
155
+
156
+
157
class StochasticDurationPredictor(nn.Module):
    """Flow-based stochastic duration predictor (VITS).

    Training (``reverse=False``): returns the negative log-likelihood of the
    observed durations ``w`` plus the posterior log-density term, one value
    per batch element. Inference (``reverse=True``): samples noise, runs the
    flow backwards, and returns predicted log-durations [b, 1, t].
    """

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        filter_channels = in_channels  # it needs to be removed from future version.
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        # Main flow over the 2-channel duration representation.
        self.log_flow = modules.Log()
        self.flows = nn.ModuleList()
        self.flows.append(modules.ElementwiseAffine(2))
        for i in range(n_flows):
            self.flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.flows.append(modules.Flip())

        # Posterior network used only in the training branch of forward().
        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.ElementwiseAffine(2))
        for i in range(4):
            self.post_flows.append(
                modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.post_flows.append(modules.Flip())

        # Conditioning encoder over the (detached) text features.
        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DDSConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        """x: text features [b, c, t]; w: ground-truth durations (training only).

        Returns NLL + log q per batch element when training, otherwise
        log-durations of shape [b, 1, t].
        """
        # Detach: duration training must not back-propagate into the encoder.
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            # --- Posterior sampling: build latent (u, z1) conditioned on w and x.
            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = (
                torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
                * x_mask
            )
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            u = torch.sigmoid(z_u) * x_mask
            # u in (0, 1) dequantizes the integer durations w.
            z0 = (w - u) * x_mask
            logdet_tot_q += torch.sum(
                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
            )
            # log q under a standard normal base plus the flow log-determinant.
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
                - logdet_tot_q
            )

            # --- Main flow: likelihood of (z0, z1).
            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = (
                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
                - logdet_tot
            )
            return nll + logq  # [b]
        else:
            flows = list(reversed(self.flows))
            flows = flows[:-2] + [flows[-1]]  # remove a useless vflow
            z = (
                torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
                * noise_scale
            )
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw
266
+
267
+
268
class DurationPredictor(nn.Module):
    """Deterministic duration predictor.

    Two conv -> relu -> layer-norm -> dropout stages followed by a 1x1
    projection to a single (log-)duration channel. The input features are
    detached so duration training does not back-propagate into the text
    encoder.
    """

    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        pad = kernel_size // 2
        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=pad
        )
        self.norm_1 = modules.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=pad
        )
        self.norm_2 = modules.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        if gin_channels != 0:
            # Optional global conditioning added onto the input features.
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        """x: features [b, in_channels, t]; x_mask: [b, 1, t] validity mask.

        Returns masked duration logits of shape [b, 1, t].
        """
        x = torch.detach(x)
        if g is not None:
            x = x + self.cond(torch.detach(g))
        # Two identical conv blocks, each applied under the sequence mask.
        for conv, norm in ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2)):
            x = self.drop(norm(torch.relu(conv(x * x_mask))))
        return self.proj(x * x_mask) * x_mask
309
+
310
+
311
class TextEncoder(nn.Module):
    """Transformer text encoder producing the VITS prior distribution.

    Sums symbol, tone, and language embeddings with projected BERT features,
    encodes them with a transformer stack, and projects to the prior mean
    and log standard deviation.
    """

    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        gin_channels=0,
        num_languages=None,
        num_tones=None,
    ):
        super().__init__()
        # Fall back to the project-wide inventories when sizes are not given
        # explicitly (local import presumably avoids a circular dependency —
        # TODO confirm).
        if num_languages is None:
            from src.text import num_languages
        if num_tones is None:
            from src.text import num_tones
        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels
        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
        self.tone_emb = nn.Embedding(num_tones, hidden_channels)
        nn.init.normal_(self.tone_emb.weight, 0.0, hidden_channels**-0.5)
        self.language_emb = nn.Embedding(num_languages, hidden_channels)
        nn.init.normal_(self.language_emb.weight, 0.0, hidden_channels**-0.5)
        # Projections for external BERT features (1024-dim and 768-dim inputs).
        self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.ja_bert_proj = nn.Conv1d(768, hidden_channels, 1)

        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.gin_channels,
        )
        # Two output channels per latent dim: mean and log-std.
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, tone, language, bert, ja_bert, g=None):
        """Encode a symbol sequence.

        Returns (encoded features [b, h, t], prior mean, prior log-std, mask).
        """
        bert_emb = self.bert_proj(bert).transpose(1, 2)
        ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)
        # Transformer-style sqrt(d) scaling of the summed embeddings.
        x = (
            self.emb(x)
            + self.tone_emb(tone)
            + self.language_emb(language)
            + bert_emb
            + ja_bert_emb
        ) * math.sqrt(
            self.hidden_channels
        )  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )

        x = self.encoder(x * x_mask, x_mask, g=g)
        stats = self.proj(x) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return x, m, logs, x_mask
382
+
383
+
384
class ResidualCouplingBlock(nn.Module):
    """Stack of affine residual coupling layers with channel flips.

    A normalizing flow used in VITS: ``n_flows`` pairs of
    (ResidualCouplingLayer, Flip). The forward direction returns the
    transformed latent; ``reverse=True`` inverts the whole stack.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.ResidualCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.append(coupling)
            # Flip swaps channel halves so the next layer transforms the
            # other half.
            self.flows.append(modules.Flip())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the flow to ``x``; ``reverse=True`` inverts it."""
        if reverse:
            # Invert the stack: apply each flow backwards, in reverse order.
            for flow in reversed(self.flows):
                x = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in self.flows:
                # The per-flow log-determinant is unused here.
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        return x
427
+
428
+
429
class PosteriorEncoder(nn.Module):
    """WaveNet-style posterior encoder q(z | spectrogram).

    Encodes the input spectrogram with a WN stack and samples a latent z
    from the predicted diagonal Gaussian. ``tau`` scales the sampling noise
    (``tau=0`` returns the mean).
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        # Two output channels per latent dim: mean and log-std.
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None, tau=1.0):
        """x: spectrogram [b, in_channels, t]; x_lengths: valid frames per item.

        Returns (sampled latent z, mean, log-std, mask).
        """
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )
        x = self.pre(x) * x_mask
        x = self.enc(x, x_mask, g=g)
        stats = self.proj(x) * x_mask
        m, logs = torch.split(stats, self.out_channels, dim=1)
        # Reparameterized sample; tau scales the noise magnitude.
        z = (m + torch.randn_like(m) * tau * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
469
+
470
+
471
class Generator(torch.nn.Module):
    """HiFi-GAN style waveform generator (the VITS decoder).

    Upsamples the latent with weight-normalized transposed convolutions,
    refines each scale with a bank of multi-receptive-field residual blocks
    (averaged), and emits a tanh-bounded waveform.
    """

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )
        # "1" selects modules.ResBlock1; any other value selects ResBlock2.
        resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            # Channel count halves at every upsampling stage.
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        upsample_initial_channel // (2**i),
                        upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # num_kernels residual blocks per upsampling stage, stored flat;
        # forward() indexes them as i * num_kernels + j.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(resblock_kernel_sizes, resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(ch, k, d))

        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(init_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """x: latent [b, initial_channel, t]; g: optional global conditioning.

        Returns a waveform tensor bounded to [-1, 1] by tanh.
        """
        x = self.conv_pre(x)
        if g is not None:
            x = x + self.cond(g)

        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the parallel residual blocks attached to this scale.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        """Strip weight norm from all layers (call once before inference/export)."""
        print("Removing weight norm...")
        for layer in self.ups:
            remove_weight_norm(layer)
        for layer in self.resblocks:
            layer.remove_weight_norm()
546
+
547
+
548
class DiscriminatorP(torch.nn.Module):
    """Period discriminator (HiFi-GAN).

    Reshapes the waveform into a 2-D [t/period, period] grid and applies a
    stack of 2-D convolutions, making it sensitive to structure that repeats
    with the given ``period``.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        self.use_spectral_norm = use_spectral_norm
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        # Convolve only along time (kernel/stride are 1 on the period axis).
        self.convs = nn.ModuleList(
            [
                norm_f(
                    Conv2d(
                        1,
                        32,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        32,
                        128,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        128,
                        512,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        512,
                        1024,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
                norm_f(
                    Conv2d(
                        1024,
                        1024,
                        (kernel_size, 1),
                        1,
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                ),
            ]
        )
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        """x: waveform [b, 1, t]. Returns (flattened logits, feature maps)."""
        fmap = []

        # 1d to 2d: pad to a multiple of the period, then fold.
        b, c, t = x.shape
        if t % self.period != 0:  # pad first
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for layer in self.convs:
            x = layer(x)
            x = F.leaky_relu(x, modules.LRELU_SLOPE)
            # Intermediate activations feed the feature-matching loss.
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
625
+
626
+
627
class DiscriminatorS(torch.nn.Module):
    """Single-scale waveform discriminator (HiFi-GAN style).

    A stack of grouped 1-D convolutions over the raw waveform. ``forward``
    returns the flattened logits plus every intermediate activation for the
    feature-matching loss.
    """

    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = weight_norm if use_spectral_norm is False else spectral_norm
        # (in_ch, out_ch, kernel, stride, groups, padding) per layer.
        layer_specs = [
            (1, 16, 15, 1, 1, 7),
            (16, 64, 41, 4, 4, 20),
            (64, 256, 41, 4, 16, 20),
            (256, 1024, 41, 4, 64, 20),
            (1024, 1024, 41, 4, 256, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList(
            norm_f(Conv1d(c_in, c_out, k, s, groups=grp, padding=pad))
            for c_in, c_out, k, s, grp, pad in layer_specs
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        """x: waveform [b, 1, t]. Returns (flattened logits, feature maps)."""
        fmap = []
        for conv in self.convs:
            x = F.leaky_relu(conv(x), modules.LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
655
+
656
+
657
class MultiPeriodDiscriminator(torch.nn.Module):
    """Ensemble of one scale discriminator and five period discriminators.

    The prime periods 2/3/5/7/11 cover disjoint periodic structure in the
    waveform. ``forward`` evaluates every sub-discriminator on both the real
    and the generated audio.
    """

    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminator, self).__init__()
        discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
        discs.extend(
            DiscriminatorP(p, use_spectral_norm=use_spectral_norm)
            for p in [2, 3, 5, 7, 11]
        )
        self.discriminators = nn.ModuleList(discs)

    def forward(self, y, y_hat):
        """Score real audio ``y`` against generated audio ``y_hat``.

        Returns (real_logits, fake_logits, real_fmaps, fake_fmaps), each a
        list with one entry per sub-discriminator.
        """
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for disc in self.discriminators:
            logits_real, feats_real = disc(y)
            logits_fake, feats_fake = disc(y_hat)
            y_d_rs.append(logits_real)
            fmap_rs.append(feats_real)
            y_d_gs.append(logits_fake)
            fmap_gs.append(feats_fake)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
682
+
683
+
684
class ReferenceEncoder(nn.Module):
    """
    inputs --- [N, Ty/r, n_mels*r] mels
    outputs --- [N, ref_enc_gru_size]

    GST-style reference encoder: six stride-2 Conv2d layers over the
    reference spectrogram followed by a GRU; the final GRU state is
    projected to a ``gin_channels``-dim embedding.
    """

    def __init__(self, spec_channels, gin_channels=0, layernorm=False):
        super().__init__()
        self.spec_channels = spec_channels
        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        K = len(ref_enc_filters)
        filters = [1] + ref_enc_filters
        # Each conv halves both the time and frequency axes (stride 2x2).
        convs = [
            weight_norm(
                nn.Conv2d(
                    in_channels=filters[i],
                    out_channels=filters[i + 1],
                    kernel_size=(3, 3),
                    stride=(2, 2),
                    padding=(1, 1),
                )
            )
            for i in range(K)
        ]
        self.convs = nn.ModuleList(convs)
        # self.wns = nn.ModuleList([weight_norm(num_features=ref_enc_filters[i]) for i in range(K)]) # noqa: E501

        # Frequency bins remaining after K stride-2 convolutions.
        out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * out_channels,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)
        if layernorm:
            # Optional normalization of the raw spectrogram frames.
            self.layernorm = nn.LayerNorm(self.spec_channels)
            print('[Ref Enc]: using layer norm')
        else:
            self.layernorm = None

    def forward(self, inputs, mask=None):
        """inputs: reference spectrogram, reshaped to [N, 1, Ty, n_freqs].

        ``mask`` is accepted but unused. Returns a [N, gin_channels]
        reference embedding.
        """
        N = inputs.size(0)

        out = inputs.view(N, 1, -1, self.spec_channels)  # [N, 1, Ty, n_freqs]
        if self.layernorm is not None:
            out = self.layernorm(out)

        for conv in self.convs:
            out = conv(out)
            # out = wn(out)
            out = F.relu(out)  # [N, 128, Ty//2^K, n_mels//2^K]

        out = out.transpose(1, 2)  # [N, Ty//2^K, 128, n_mels//2^K]
        T = out.size(1)
        N = out.size(0)
        out = out.contiguous().view(N, T, -1)  # [N, Ty//2^K, 128*n_mels//2^K]

        # Compact the (possibly fragmented) GRU weights for cuDNN.
        self.gru.flatten_parameters()
        memory, out = self.gru(out)  # out --- [1, N, 128]

        # Final hidden state -> fixed-size utterance embedding.
        return self.proj(out.squeeze(0))

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        """Length remaining after ``n_convs`` stacked convs of this geometry."""
        for i in range(n_convs):
            L = (L - kernel_size + 2 * pad) // stride + 1
        return L
750
+
751
+
752
+ class SynthesizerTrn(nn.Module):
753
+ """
754
+ Synthesizer for Training
755
+ """
756
+
757
+ def __init__(
758
+ self,
759
+ n_vocab,
760
+ spec_channels,
761
+ segment_size,
762
+ inter_channels,
763
+ hidden_channels,
764
+ filter_channels,
765
+ n_heads,
766
+ n_layers,
767
+ kernel_size,
768
+ p_dropout,
769
+ resblock,
770
+ resblock_kernel_sizes,
771
+ resblock_dilation_sizes,
772
+ upsample_rates,
773
+ upsample_initial_channel,
774
+ upsample_kernel_sizes,
775
+ n_speakers=256,
776
+ gin_channels=256,
777
+ use_sdp=True,
778
+ n_flow_layer=4,
779
+ n_layers_trans_flow=6,
780
+ flow_share_parameter=False,
781
+ use_transformer_flow=True,
782
+ use_vc=False,
783
+ num_languages=None,
784
+ num_tones=None,
785
+ norm_refenc=False,
786
+ **kwargs
787
+ ):
788
+ super().__init__()
789
+ self.n_vocab = n_vocab
790
+ self.spec_channels = spec_channels
791
+ self.inter_channels = inter_channels
792
+ self.hidden_channels = hidden_channels
793
+ self.filter_channels = filter_channels
794
+ self.n_heads = n_heads
795
+ self.n_layers = n_layers
796
+ self.kernel_size = kernel_size
797
+ self.p_dropout = p_dropout
798
+ self.resblock = resblock
799
+ self.resblock_kernel_sizes = resblock_kernel_sizes
800
+ self.resblock_dilation_sizes = resblock_dilation_sizes
801
+ self.upsample_rates = upsample_rates
802
+ self.upsample_initial_channel = upsample_initial_channel
803
+ self.upsample_kernel_sizes = upsample_kernel_sizes
804
+ self.segment_size = segment_size
805
+ self.n_speakers = n_speakers
806
+ self.gin_channels = gin_channels
807
+ self.n_layers_trans_flow = n_layers_trans_flow
808
+ self.use_spk_conditioned_encoder = kwargs.get(
809
+ "use_spk_conditioned_encoder", True
810
+ )
811
+ self.use_sdp = use_sdp
812
+ self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
813
+ self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
814
+ self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
815
+ self.current_mas_noise_scale = self.mas_noise_scale_initial
816
+ if self.use_spk_conditioned_encoder and gin_channels > 0:
817
+ self.enc_gin_channels = gin_channels
818
+ else:
819
+ self.enc_gin_channels = 0
820
+ self.enc_p = TextEncoder(
821
+ n_vocab,
822
+ inter_channels,
823
+ hidden_channels,
824
+ filter_channels,
825
+ n_heads,
826
+ n_layers,
827
+ kernel_size,
828
+ p_dropout,
829
+ gin_channels=self.enc_gin_channels,
830
+ num_languages=num_languages,
831
+ num_tones=num_tones,
832
+ )
833
+ self.dec = Generator(
834
+ inter_channels,
835
+ resblock,
836
+ resblock_kernel_sizes,
837
+ resblock_dilation_sizes,
838
+ upsample_rates,
839
+ upsample_initial_channel,
840
+ upsample_kernel_sizes,
841
+ gin_channels=gin_channels,
842
+ )
843
+ self.enc_q = PosteriorEncoder(
844
+ spec_channels,
845
+ inter_channels,
846
+ hidden_channels,
847
+ 5,
848
+ 1,
849
+ 16,
850
+ gin_channels=gin_channels,
851
+ )
852
+ if use_transformer_flow:
853
+ self.flow = TransformerCouplingBlock(
854
+ inter_channels,
855
+ hidden_channels,
856
+ filter_channels,
857
+ n_heads,
858
+ n_layers_trans_flow,
859
+ 5,
860
+ p_dropout,
861
+ n_flow_layer,
862
+ gin_channels=gin_channels,
863
+ share_parameter=flow_share_parameter,
864
+ )
865
+ else:
866
+ self.flow = ResidualCouplingBlock(
867
+ inter_channels,
868
+ hidden_channels,
869
+ 5,
870
+ 1,
871
+ n_flow_layer,
872
+ gin_channels=gin_channels,
873
+ )
874
+ self.sdp = StochasticDurationPredictor(
875
+ hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
876
+ )
877
+ self.dp = DurationPredictor(
878
+ hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
879
+ )
880
+
881
+ if n_speakers > 0:
882
+ self.emb_g = nn.Embedding(n_speakers, gin_channels)
883
+ else:
884
+ self.ref_enc = ReferenceEncoder(spec_channels, gin_channels, layernorm=norm_refenc)
885
+ self.use_vc = use_vc
886
+
887
+
888
    def forward(self, x, x_lengths, y, y_lengths, sid, tone, language, bert, ja_bert):
        """Training-time forward pass of the VITS-style synthesizer.

        Encodes text and spectrogram, aligns them with monotonic alignment
        search (MAS), computes both duration-predictor losses, and decodes a
        random latent slice into a waveform segment.

        Args:
            x, x_lengths: padded phoneme-id batch and per-item lengths.
            y, y_lengths: padded spectrogram batch and per-item lengths.
            sid: speaker ids (used only when ``self.n_speakers > 0``).
            tone, language, bert, ja_bert: extra conditioning forwarded to
                the text encoder.

        Returns:
            Tuple consumed by the training losses:
            (waveform slice, duration loss, attention, slice ids, x_mask,
            y_mask, latent tensors, duration tensors).
        """
        # Speaker conditioning: table lookup when a speaker embedding exists,
        # otherwise a reference encoder run on the target spectrogram.
        if self.n_speakers > 0:
            g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
        else:
            g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
        # In voice-conversion mode the text encoder stays speaker-agnostic.
        if self.use_vc:
            g_p = None
        else:
            g_p = g
        x, m_p, logs_p, x_mask = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, g=g_p
        )
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
        z_p = self.flow(z, y_mask, g=g)

        with torch.no_grad():
            # Negative cross-entropy between z_p and the prior N(m_p, s_p),
            # expanded into four terms so each is a cheap sum/matmul.
            s_p_sq_r = torch.exp(-2 * logs_p)  # [b, d, t]
            neg_cent1 = torch.sum(
                -0.5 * math.log(2 * math.pi) - logs_p, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent2 = torch.matmul(
                -0.5 * (z_p**2).transpose(1, 2), s_p_sq_r
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent3 = torch.matmul(
                z_p.transpose(1, 2), (m_p * s_p_sq_r)
            )  # [b, t_t, d] x [b, d, t_s] = [b, t_t, t_s]
            neg_cent4 = torch.sum(
                -0.5 * (m_p**2) * s_p_sq_r, [1], keepdim=True
            )  # [b, 1, t_s]
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
            # Optional annealed noise injected into the MAS objective.
            if self.use_noise_scaled_mas:
                epsilon = (
                    torch.std(neg_cent)
                    * torch.randn_like(neg_cent)
                    * self.current_mas_noise_scale
                )
                neg_cent = neg_cent + epsilon

            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            # Hard monotonic alignment; detached so MAS is not backpropagated.
            attn = (
                monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1))
                .unsqueeze(1)
                .detach()
            )

        # Per-phoneme durations implied by the hard alignment.
        w = attn.sum(2)

        # Stochastic duration predictor loss (normalized by valid positions).
        l_length_sdp = self.sdp(x, x_mask, w, g=g)
        l_length_sdp = l_length_sdp / torch.sum(x_mask)

        # Deterministic duration predictor: MSE in the log-duration domain.
        logw_ = torch.log(w + 1e-6) * x_mask
        logw = self.dp(x, x_mask, g=g)
        l_length_dp = torch.sum((logw - logw_) ** 2, [1, 2]) / torch.sum(
            x_mask
        )  # for averaging

        l_length = l_length_dp + l_length_sdp

        # Expand the phoneme-level prior to frame level via the alignment.
        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(1, 2)

        # Decode only a random latent slice to bound vocoder memory use.
        z_slice, ids_slice = commons.rand_slice_segments(
            z, y_lengths, self.segment_size
        )
        o = self.dec(z_slice, g=g)
        return (
            o,
            l_length,
            attn,
            ids_slice,
            x_mask,
            y_mask,
            (z, z_p, m_p, logs_p, m_q, logs_q),
            (x, logw, logw_),
        )
965
+
966
    def infer(
        self,
        x,
        x_lengths,
        sid,
        tone,
        language,
        bert,
        ja_bert,
        noise_scale=0.667,
        length_scale=1,
        noise_scale_w=0.8,
        max_len=None,
        sdp_ratio=0,
        y=None,
        g=None,
    ):
        """Inference: synthesize a waveform from phoneme ids.

        Args:
            x, x_lengths: padded phoneme-id batch and lengths.
            sid: speaker ids (only used when ``self.n_speakers > 0``).
            tone, language, bert, ja_bert: conditioning for the text encoder.
            noise_scale: stddev multiplier for sampling the latent prior.
            length_scale: global duration multiplier (>1 slows speech down).
            noise_scale_w: noise scale for the stochastic duration predictor.
            max_len: optional cap on the number of decoded frames.
            sdp_ratio: mix between stochastic (1.0) and deterministic (0.0)
                duration predictors.
            y: reference spectrogram for the reference encoder
                (required when ``self.n_speakers == 0`` and ``g`` is None).
            g: precomputed speaker conditioning tensor; skips both lookups.

        Returns:
            (waveform, attention, y_mask, (z, z_p, m_p, logs_p)).
        """
        # Resolve speaker conditioning unless the caller supplied it.
        if g is None:
            if self.n_speakers > 0:
                g = self.emb_g(sid).unsqueeze(-1)  # [b, h, 1]
            else:
                g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)
        # In voice-conversion mode the text encoder stays speaker-agnostic.
        if self.use_vc:
            g_p = None
        else:
            g_p = g
        x, m_p, logs_p, x_mask = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, g=g_p
        )
        # Blend the two duration predictors in the log domain.
        logw = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w) * (
            sdp_ratio
        ) + self.dp(x, x_mask, g=g) * (1 - sdp_ratio)
        w = torch.exp(logw) * x_mask * length_scale

        # Integer frame counts per phoneme; at least one output frame total.
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, None), 1).to(
            x_mask.dtype
        )
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        # Monotonic hard alignment implied by the predicted durations.
        attn = commons.generate_path(w_ceil, attn_mask)

        m_p = torch.matmul(attn.squeeze(1), m_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']
        logs_p = torch.matmul(attn.squeeze(1), logs_p.transpose(1, 2)).transpose(
            1, 2
        )  # [b, t', t], [b, t, d] -> [b, d, t']

        # Sample the prior, invert the flow, and vocode.
        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        return o, attn, y_mask, (z, z_p, m_p, logs_p)
1022
+
1023
    def voice_conversion(self, y, y_lengths, sid_src, sid_tgt, tau=1.0):
        """Convert utterance ``y`` from a source voice to a target voice.

        ``sid_src`` / ``sid_tgt`` are used directly as conditioning tensors
        (no embedding lookup here), so callers must pass already-computed
        speaker conditioning of shape compatible with ``g``.

        Args:
            y, y_lengths: source spectrogram batch and lengths.
            sid_src, sid_tgt: source/target speaker conditioning tensors.
            tau: posterior temperature forwarded to the posterior encoder.

        Returns:
            (converted waveform, y_mask, (z, z_p, z_hat)).
        """
        g_src = sid_src
        g_tgt = sid_tgt
        # Encode with the source voice, map to the prior space, then invert
        # the flow under the target voice before vocoding.
        z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g_src, tau=tau)
        z_p = self.flow(z, y_mask, g=g_src)
        z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
        o_hat = self.dec(z_hat * y_mask, g=g_tgt)
        return o_hat, y_mask, (z, z_p, z_hat)
src/nn/__init__.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Neural network components package
3
+ """
4
+
5
+ from .commons import *
6
+ from .attentions import *
7
+ from .modules import *
8
+ from .transforms import *
src/nn/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (288 Bytes). View file
 
src/nn/__pycache__/attentions.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
src/nn/__pycache__/commons.cpython-310.pyc ADDED
Binary file (5.71 kB). View file
 
src/nn/__pycache__/modules.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
src/nn/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (3.91 kB). View file
 
src/nn/attentions.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from . import commons
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class LayerNorm(nn.Module):
    """Layer normalization over the channel axis of a [b, channels, ...] tensor."""

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        # Learnable affine parameters, one per channel.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        # Move channels last, normalize over them, then restore the layout.
        moved = x.transpose(1, -1)
        normed = F.layer_norm(moved, (self.channels,), self.gamma, self.beta, self.eps)
        return normed.transpose(1, -1)
25
+
26
+
27
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Gated activation: tanh of the first half times sigmoid of the second.

    ``n_channels`` is a one-element int tensor giving the split point.
    """
    half = n_channels[0]
    summed = input_a + input_b
    tanh_part = torch.tanh(summed[:, :half, :])
    sigmoid_part = torch.sigmoid(summed[:, half:, :])
    return tanh_part * sigmoid_part
35
+
36
+
37
class Encoder(nn.Module):
    """Stack of relative-position self-attention + FFN layers (VITS encoder).

    When a non-zero ``gin_channels`` is passed via kwargs, a speaker vector
    ``g`` is linearly projected and added to the hidden states right before
    layer ``cond_layer_idx`` (default 2).
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=True,  # NOTE(review): accepted but never referenced in this class
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        # Default to "no conditioning": an index of n_layers is never hit
        # by the loop in forward().
        self.cond_layer_idx = self.n_layers
        if "gin_channels" in kwargs:
            self.gin_channels = kwargs["gin_channels"]
            if self.gin_channels != 0:
                self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
                self.cond_layer_idx = (
                    kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
                )
                assert (
                    self.cond_layer_idx < self.n_layers
                ), "cond_layer_idx should be less than n_layers"
        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()

        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    window_size=window_size,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, g=None):
        """Run the encoder stack.

        Args:
            x: hidden states [b, hidden_channels, t].
            x_mask: validity mask [b, 1, t].
            g: optional speaker conditioning [b, gin_channels, 1].
        """
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            # Inject the speaker vector once, just before the chosen layer.
            if i == self.cond_layer_idx and g is not None:
                g = self.spk_emb_linear(g.transpose(1, 2))
                g = g.transpose(1, 2)
                x = x + g
                x = x * x_mask
            # Post-norm residual: attention then feed-forward.
            y = self.attn_layers[i](x, x, attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x
116
+
117
+
118
class Decoder(nn.Module):
    """Transformer decoder stack: causal self-attention, cross-attention, FFN.

    Uses a lower-triangular mask for self-attention, so each position only
    attends to itself and earlier positions.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.self_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    p_dropout=p_dropout,
                    proximal_bias=proximal_bias,
                    proximal_init=proximal_init,
                )
            )
            self.norm_layers_0.append(LayerNorm(hidden_channels))
            self.encdec_attn_layers.append(
                MultiHeadAttention(
                    hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                    causal=True,  # causal conv padding to match causal attention
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """
        x: decoder input
        h: encoder output
        """
        # Causal (lower-triangular) mask for decoder self-attention.
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
            device=x.device, dtype=x.dtype
        )
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        for i in range(self.n_layers):
            # Post-norm residual sublayers: self-attn, cross-attn, FFN.
            y = self.self_attn_layers[i](x, x, self_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_0[i](x + y)

            y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x
202
+
203
+
204
class MultiHeadAttention(nn.Module):
    """Multi-head attention with optional relative positional embeddings.

    Supports the VITS variants: windowed relative attention (``window_size``),
    proximal bias, proximal initialization (keys initialized from queries),
    and block-local attention (``block_length``).
    """

    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        # Last-computed attention weights, stored for inspection/debugging.
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            # Relative position embeddings over [-window_size, +window_size].
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            # Start keys equal to queries so early attention is near-diagonal.
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        """Attend from ``x`` (queries) to ``c`` (keys/values)."""
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                query / math.sqrt(self.k_channels), key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            # Large negative rather than -inf — presumably for fp16 safety;
            # NOTE(review): confirm against training precision.
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            # Add the relative-position value contribution.
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Pad/slice the relative embedding table to 2*length-1 positions."""
        2 * self.window_size + 1  # NOTE(review): dead expression, no effect
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # pad along column
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        # add 0's in the beginning that will skew the elements after reshape
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
            length: an integer scalar.
        Returns:
            a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
402
+
403
+
404
class FFN(nn.Module):
    """Position-wise feed-forward block: conv -> activation -> dropout -> conv.

    Both convolutions preserve sequence length via either causal (left-only)
    or symmetric "same" padding, selected once at construction time.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        # Bind the padding strategy once instead of branching per call.
        self.padding = self._causal_padding if causal else self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        h = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            h = h * torch.sigmoid(1.702 * h)  # sigmoid-based GELU approximation
        else:
            h = torch.relu(h)
        h = self.drop(h)
        h = self.conv_2(self.padding(h * x_mask))
        return h * x_mask

    def _causal_padding(self, x):
        # All padding on the left so outputs never see future frames.
        if self.kernel_size == 1:
            return x
        spec = [[0, 0], [0, 0], [self.kernel_size - 1, 0]]
        return F.pad(x, commons.convert_pad_shape(spec))

    def _same_padding(self, x):
        # Symmetric padding; the extra element (even kernels) goes right.
        if self.kernel_size == 1:
            return x
        left = (self.kernel_size - 1) // 2
        right = self.kernel_size // 2
        return F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [left, right]]))
+ return x
src/nn/commons.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
def init_weights(m, mean=0.0, std=0.01):
    """Initialize conv-layer weights in place from N(mean, std).

    Non-conv modules (class name without "Conv") are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
10
+
11
+
12
def get_padding(kernel_size, dilation=1):
    """Return the 'same' padding for a dilated 1-D convolution."""
    return (dilation * (kernel_size - 1)) // 2
14
+
15
+
16
def convert_pad_shape(pad_shape):
    """Flatten a per-dimension [[l, r], ...] pad spec into F.pad's format.

    torch.nn.functional.pad expects pairs for the *last* dimension first,
    so the spec is reversed before flattening.
    """
    flat = []
    for pair in reversed(pad_shape):
        flat.extend(pair)
    return flat
20
+
21
+
22
def intersperse(lst, item):
    """Return a new list with ``item`` before, between, and after elements."""
    out = [item]
    for element in lst:
        out.append(element)
        out.append(item)
    return out
26
+
27
+
28
+ def kl_divergence(m_p, logs_p, m_q, logs_q):
29
+ """KL(P||Q)"""
30
+ kl = (logs_q - logs_p) - 0.5
31
+ kl += (
32
+ 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
33
+ )
34
+ return kl
35
+
36
+
37
def rand_gumbel(shape):
    """Sample Gumbel(0, 1) noise, clamping uniforms away from 0 and 1."""
    # Map U(0, 1) into [1e-5, 0.99999] so neither log() can overflow.
    safe_uniform = torch.rand(shape) * 0.99998 + 0.00001
    return -torch.log(-torch.log(safe_uniform))
41
+
42
+
43
def rand_gumbel_like(x):
    """Gumbel noise matching ``x``'s shape, dtype and device."""
    return rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
46
+
47
+
48
def slice_segments(x, ids_str, segment_size=4):
    """Gather a fixed-length time slice per batch element of x [b, d, t].

    ``ids_str`` holds the start index for each batch item.
    """
    out = torch.zeros_like(x[:, :, :segment_size])
    for batch_idx in range(x.size(0)):
        start = ids_str[batch_idx]
        out[batch_idx] = x[batch_idx, :, start : start + segment_size]
    return out
55
+
56
+
57
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Pick a random valid start per batch element and slice segment_size frames.

    Returns the slices and the sampled start indices so callers can slice
    aligned tensors identically.
    """
    batch, _, total_len = x.size()
    if x_lengths is None:
        x_lengths = total_len
    max_start = x_lengths - segment_size + 1
    ids_str = (torch.rand([batch]).to(device=x.device) * max_start).to(
        dtype=torch.long
    )
    return slice_segments(x, ids_str, segment_size), ids_str
65
+
66
+
67
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Sinusoidal positional-encoding signal of shape [1, channels, length].

    Half the channels carry sines, half cosines, over geometrically spaced
    timescales between ``min_timescale`` and ``max_timescale``
    (transformer-style timing signal).
    """
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
        num_timescales - 1
    )
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    # Zero-pad one channel when `channels` is odd so the view below fits.
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal
81
+
82
+
83
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add sinusoidal positional encodings to ``x`` of shape [b, c, t]."""
    _, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return x + signal.to(dtype=x.dtype, device=x.device)
87
+
88
+
89
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate sinusoidal positional encodings onto ``x`` along ``axis``."""
    _, channels, length = x.size()
    signal = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    return torch.cat([x, signal.to(dtype=x.dtype, device=x.device)], axis)
93
+
94
+
95
def subsequent_mask(length):
    """Lower-triangular causal mask of shape [1, 1, length, length]."""
    return torch.tril(torch.ones(length, length))[None, None]
98
+
99
+
100
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # Gated activation unit: tanh(first half) * sigmoid(second half) of
    # (input_a + input_b). `n_channels` is a one-element int tensor giving
    # the channel split point.
    # NOTE(review): duplicate of the same function in nn/attentions.py —
    # consider keeping a single definition.
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts
108
+
109
+
110
def convert_pad_shape(pad_shape):
    # Flatten a per-dimension [[l, r], ...] pad spec into the reversed flat
    # list that torch.nn.functional.pad expects (last dimension first).
    # NOTE(review): this redefines convert_pad_shape declared earlier in this
    # module (same body); the duplicate could be removed.
    layer = pad_shape[::-1]
    pad_shape = [item for sublist in layer for item in sublist]
    return pad_shape
114
+
115
+
116
def shift_1d(x):
    """Shift ``x`` right by one step along the last axis, zero-filling slot 0."""
    padded = F.pad(x, convert_pad_shape([[0, 0], [0, 0], [1, 0]]))
    return padded[:, :, :-1]
119
+
120
+
121
def sequence_mask(length, max_length=None):
    """Boolean mask [b, max_length], True where position < length[b]."""
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
126
+
127
+
128
def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]

    Expand integer durations into a hard monotonic alignment path of shape
    [b, 1, t_y, t_x]: each text position x is assigned duration[b, 0, x]
    consecutive output frames.
    """

    b, _, t_y, t_x = mask.shape
    # Cumulative end frame of each text position.
    cum_duration = torch.cumsum(duration, -1)

    cum_duration_flat = cum_duration.view(b * t_x)
    # sequence_mask turns each cumulative end into a prefix of ones...
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # ...and subtracting the previous row's prefix leaves exactly the frames
    # belonging to each text position.
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path.unsqueeze(1).transpose(2, 3) * mask
    return path
143
+
144
+
145
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients element-wise to [-clip_value, clip_value] in place.

    Returns the total gradient norm (computed *before* clipping) across all
    parameters that currently have gradients. Pass ``clip_value=None`` to
    only measure the norm without clipping.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    # Skip parameters that did not receive a gradient this step.
    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    total_norm = 0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm
src/nn/modules.py ADDED
@@ -0,0 +1,598 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from torch.nn import Conv1d
7
+ from torch.nn.utils import weight_norm, remove_weight_norm
8
+
9
+ from . import commons
10
+ from .commons import init_weights, get_padding
11
+ from .transforms import piecewise_rational_quadratic_transform
12
+ from .attentions import Encoder
13
+
14
+ LRELU_SLOPE = 0.1
15
+
16
+
17
class LayerNorm(nn.Module):
    """Layer normalization over the channel axis of a [b, channels, ...] tensor.

    NOTE(review): identical to the LayerNorm defined in nn/attentions.py —
    candidates for consolidation.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        # Learnable per-channel affine parameters.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        # Move channels last, normalize over them, then restore the layout.
        x = x.transpose(1, -1)
        x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
        return x.transpose(1, -1)
30
+
31
+
32
class ConvReluNorm(nn.Module):
    """Stack of Conv1d -> LayerNorm -> ReLU -> Dropout with a residual output.

    The final 1x1 projection is zero-initialized so the module starts as an
    identity mapping around the residual connection.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # NOTE(review): message says "larger than 0" but the check requires
        # n_layers > 1 — confirm which is intended.
        assert n_layers > 1, "Number of layers should be larger than 0."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        # Zero init: the block initially contributes nothing on top of x_org.
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """Apply the conv stack; ``x_mask`` zeroes padded positions."""
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
82
+
83
+
84
class DDSConv(nn.Module):
    """
    Dialted and Depth-Separable Convolution

    Each layer: depthwise conv (dilation = kernel_size**i) -> LayerNorm ->
    GELU -> pointwise 1x1 conv -> LayerNorm -> GELU -> dropout, with a
    residual connection around the whole layer.
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for i in range(n_layers):
            # Exponentially growing dilation widens the receptive field.
            dilation = kernel_size**i
            padding = (kernel_size * dilation - dilation) // 2
            self.convs_sep.append(
                nn.Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    groups=channels,  # depthwise: one filter per channel
                    dilation=dilation,
                    padding=padding,
                )
            )
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(LayerNorm(channels))
            self.norms_2.append(LayerNorm(channels))

    def forward(self, x, x_mask, g=None):
        """Run the stack; ``g`` (optional conditioning) is added to the input."""
        if g is not None:
            x = x + g
        for i in range(self.n_layers):
            y = self.convs_sep[i](x * x_mask)
            y = self.norms_1[i](y)
            y = F.gelu(y)
            y = self.convs_1x1[i](y)
            y = self.norms_2[i](y)
            y = F.gelu(y)
            y = self.drop(y)
            x = x + y  # residual
        return x * x_mask
131
+
132
+
133
class WN(torch.nn.Module):
    """Non-causal WaveNet stack: gated dilated convolutions with residual and
    skip connections, plus optional global conditioning.

    Args:
        hidden_channels: channel width of the residual/skip path.
        kernel_size: conv kernel size; must be odd for symmetric padding.
        dilation_rate: base of the exponential dilation schedule.
        n_layers: number of gated conv layers.
        gin_channels: conditioning channels (e.g. speaker embedding);
            0 disables the conditioning layer entirely.
        p_dropout: dropout applied to the gated activations.
    """

    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        # Fix: a stray trailing comma previously stored the 1-tuple
        # `(kernel_size,)` here instead of the int.
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            # One conditioning projection shared by all layers; sliced
            # per-layer in forward().
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate**i
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # Last layer only needs the skip half (no residual carried on).
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        """Accumulate skip outputs over all gated layers; returns masked sum."""
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                # First half feeds the residual path, second half the skips.
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        """Strip weight-norm wrappers (used before inference/export)."""
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
219
+
220
+
221
class ResBlock1(torch.nn.Module):
    """HiFi-GAN-style residual block, type 1.

    Each dilation entry contributes a (dilated conv, dilation-1 conv) pair,
    with leaky-ReLU activations and a residual add around the pair.

    Generalized: the original hard-coded exactly three pairs (indexing
    dilation[0..2]); one pair is now built per entry in `dilation`. The
    default (1, 3, 5) reproduces the original behavior exactly.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                )
                for _ in dilation
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply all conv pairs with residual adds; mask is re-applied at
        every stage when provided so padded frames stay zero."""
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c2(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip weight-norm wrappers (used before inference/export)."""
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)
316
+
317
+
318
class ResBlock2(torch.nn.Module):
    """HiFi-GAN-style residual block, type 2 (single conv per residual step).

    Generalized: the original hard-coded exactly two convs (indexing
    dilation[0..1]); one conv is now built per entry in `dilation`. The
    default (1, 3) reproduces the original behavior exactly.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.convs = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=d,
                        padding=get_padding(kernel_size, d),
                    )
                )
                for d in dilation
            ]
        )
        self.convs.apply(init_weights)

    def forward(self, x, x_mask=None):
        """Apply each dilated conv with a residual add; mask re-applied when
        provided so padded frames stay zero."""
        for c in self.convs:
            xt = F.leaky_relu(x, LRELU_SLOPE)
            if x_mask is not None:
                xt = xt * x_mask
            xt = c(xt)
            x = xt + x
        if x_mask is not None:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip weight-norm wrappers (used before inference/export)."""
        for l in self.convs:
            remove_weight_norm(l)
361
+
362
+
363
class Log(nn.Module):
    """Invertible elementwise log flow.

    Forward: y = log(clamp(x, 1e-5)) * mask, with logdet = sum(-y).
    Reverse: x = exp(y) * mask (no logdet returned in reverse mode).
    """

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if reverse:
            return torch.exp(x) * x_mask
        # Clamp guards against log(0) on zero-valued (e.g. padded) frames.
        y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
        logdet = torch.sum(-y, [1, 2])
        return y, logdet
372
+
373
+
374
class Flip(nn.Module):
    """Channel-order reversal flow (volume-preserving, logdet = 0)."""

    def forward(self, x, *args, reverse=False, **kwargs):
        flipped = torch.flip(x, [1])
        if reverse:
            return flipped
        logdet = torch.zeros(
            flipped.size(0), dtype=flipped.dtype, device=flipped.device
        )
        return flipped, logdet
382
+
383
+
384
class ElementwiseAffine(nn.Module):
    """Per-channel affine flow: y = m + exp(logs) * x.

    Both parameters are zero-initialized, so the flow starts as the identity.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if reverse:
            return (x - self.m) * torch.exp(-self.logs) * x_mask
        y = (self.m + torch.exp(self.logs) * x) * x_mask
        # logdet of an elementwise scale is the sum of log-scales over the
        # unmasked positions.
        logdet = torch.sum(self.logs * x_mask, [1, 2])
        return y, logdet
400
+
401
+
402
class ResidualCouplingLayer(nn.Module):
    """Affine coupling flow with a WaveNet (WN) statistics network.

    The first half of the channels passes through unchanged and conditions an
    affine transform (mean / log-scale) of the second half. The projection is
    zero-initialized, so the layer starts as an identity flow.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WN(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            p_dropout=p_dropout,
            gin_channels=gin_channels,
        )
        # (2 - mean_only): mean+log-scale halves, or just the mean.
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply (or invert) the coupling transform; returns (x, logdet) in
        the forward direction and just x in reverse."""
        first_half, second_half = torch.split(x, [self.half_channels] * 2, 1)
        hidden = self.pre(first_half) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.post(hidden) * x_mask
        if self.mean_only:
            mean = stats
            log_scale = torch.zeros_like(mean)
        else:
            mean, log_scale = torch.split(stats, [self.half_channels] * 2, 1)

        if reverse:
            second_half = (second_half - mean) * torch.exp(-log_scale) * x_mask
            return torch.cat([first_half, second_half], 1)
        second_half = mean + second_half * torch.exp(log_scale) * x_mask
        logdet = torch.sum(log_scale, [1, 2])
        return torch.cat([first_half, second_half], 1), logdet
457
+
458
+
459
class ConvFlow(nn.Module):
    """Coupling flow whose second half is transformed by a rational-quadratic
    spline parameterized from the first half via a DDSConv stack.

    Used as part of the stochastic duration predictor. The projection is
    zero-initialized so the flow starts near the identity.
    """

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        n_layers,
        num_bins=10,
        tail_bound=5.0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = in_channels // 2

        self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
        self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
        # num_bins * 3 - 1 parameters per channel: widths + heights + the
        # (num_bins - 1) interior knot derivatives.
        self.proj = nn.Conv1d(
            filter_channels, self.half_channels * (num_bins * 3 - 1), 1
        )
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply (or invert) the spline coupling.

        Returns (x, logdet) forward, or just x when reverse=True.
        """
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0)
        h = self.convs(h, x_mask, g=g)
        h = self.proj(h) * x_mask

        b, c, t = x0.shape
        # Unpack per-channel spline parameters into the trailing dimension.
        h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2)  # [b, cx?, t] -> [b, c, t, ?]

        # Scaling by sqrt(filter_channels) keeps initial logits small.
        unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
        unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
            self.filter_channels
        )
        unnormalized_derivatives = h[..., 2 * self.num_bins :]

        x1, logabsdet = piecewise_rational_quadratic_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails="linear",
            tail_bound=self.tail_bound,
        )

        x = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if not reverse:
            return x, logdet
        else:
            return x
517
+
518
+
519
class TransformerCouplingLayer(nn.Module):
    """Affine coupling layer whose statistics network is a transformer
    Encoder (optionally shared across layers via wn_sharing_parameter).

    Interface matches ResidualCouplingLayer: splits channels in half,
    predicts (mean, log-scale) for the second half from the first, and
    applies an invertible affine transform.

    Fix: an unreachable, duplicated copy of ConvFlow's spline tail that had
    been pasted after this class (referencing undefined module-scope names
    such as x1/unnormalized_widths, which would raise NameError at import)
    has been removed.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        n_layers,
        n_heads,
        p_dropout=0,
        filter_channels=0,
        mean_only=False,
        wn_sharing_parameter=None,
        gin_channels=0,
    ):
        assert n_layers == 3, n_layers
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = nn.Conv1d(self.half_channels, hidden_channels, 1)
        # Either build a fresh encoder or reuse a shared one.
        self.enc = (
            Encoder(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=gin_channels,
            )
            if wn_sharing_parameter is None
            else wn_sharing_parameter
        )
        # Zero-init so the flow starts as the identity transform.
        self.post = nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply (or invert) the coupling transform; returns (x, logdet)
        forward, or just x when reverse=True."""
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask
        if not self.mean_only:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x
src/nn/transforms.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
def piecewise_rational_quadratic_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Apply a piecewise rational-quadratic spline transform.

    With tails=None the constrained spline is used (inputs must already lie
    in its domain); otherwise the unconstrained variant maps values outside
    [-tail_bound, tail_bound] through the identity.

    Returns:
        (outputs, logabsdet) tensors shaped like `inputs`.
    """
    common_kwargs = dict(
        inputs=inputs,
        unnormalized_widths=unnormalized_widths,
        unnormalized_heights=unnormalized_heights,
        unnormalized_derivatives=unnormalized_derivatives,
        inverse=inverse,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )
    if tails is None:
        return rational_quadratic_spline(**common_kwargs)
    return unconstrained_rational_quadratic_spline(
        tails=tails, tail_bound=tail_bound, **common_kwargs
    )
43
+
44
+
45
def searchsorted(bin_locations, inputs, eps=1e-6):
    """Return the index of the bin each input falls into.

    NOTE(review): nudges the last bin edge up by `eps` *in place* so values
    exactly on the upper boundary land in the final bin — callers appear to
    pass freshly-built cumwidths/cumheights tensors, so the mutation seems
    contained; confirm before reusing elsewhere.
    """
    bin_locations[..., -1] = bin_locations[..., -1] + eps
    hits = inputs[..., None] >= bin_locations
    return hits.sum(dim=-1) - 1
48
+
49
+
50
def unconstrained_rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Rational-quadratic spline with linear (identity) tails.

    Inputs inside [-tail_bound, tail_bound] go through the constrained
    spline; inputs outside pass through unchanged with zero log-determinant.
    Only tails="linear" is implemented.

    Returns:
        (outputs, logabsdet) tensors shaped like `inputs`.
    """
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == "linear":
        # Pad one derivative slot on each side and pin the boundary
        # derivatives: `constant` is the softplus pre-image chosen so that
        # min_derivative + softplus(constant) == 1, i.e. slope 1 at the
        # edges to match the identity tails.
        unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
        constant = np.log(np.exp(1 - min_derivative) - 1)
        unnormalized_derivatives[..., 0] = constant
        unnormalized_derivatives[..., -1] = constant

        # Identity mapping (logdet 0) outside the spline interval.
        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))

    (
        outputs[inside_interval_mask],
        logabsdet[inside_interval_mask],
    ) = rational_quadratic_spline(
        inputs=inputs[inside_interval_mask],
        unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
        unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
        unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
        inverse=inverse,
        left=-tail_bound,
        right=tail_bound,
        bottom=-tail_bound,
        top=tail_bound,
        min_bin_width=min_bin_width,
        min_bin_height=min_bin_height,
        min_derivative=min_derivative,
    )

    return outputs, logabsdet
98
+
99
+
100
def rational_quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Monotonic rational-quadratic spline transform on [left, right].

    Maps inputs in [left, right] to [bottom, top] (or the inverse when
    inverse=True) using per-element bin widths, heights, and knot
    derivatives derived from the unnormalized parameter tensors.

    Returns:
        (outputs, logabsdet) tensors shaped like `inputs`.

    Raises:
        ValueError: if any input lies outside the domain, or the minimum
            bin width/height cannot be honored for the given bin count.
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    # Softmax then floor each bin at min_bin_width; cumulative sums give the
    # knot x-positions, with the endpoints pinned exactly to [left, right].
    widths = F.softmax(unnormalized_widths, dim=-1)
    widths = min_bin_width + (1 - min_bin_width * num_bins) * widths
    cumwidths = torch.cumsum(widths, dim=-1)
    cumwidths = F.pad(cumwidths, pad=(1, 0), mode="constant", value=0.0)
    cumwidths = (right - left) * cumwidths + left
    cumwidths[..., 0] = left
    cumwidths[..., -1] = right
    widths = cumwidths[..., 1:] - cumwidths[..., :-1]

    # Softplus keeps the knot derivatives strictly positive (monotonicity).
    derivatives = min_derivative + F.softplus(unnormalized_derivatives)

    # Same construction for the y-axis (bin heights -> knot y-positions).
    heights = F.softmax(unnormalized_heights, dim=-1)
    heights = min_bin_height + (1 - min_bin_height * num_bins) * heights
    cumheights = torch.cumsum(heights, dim=-1)
    cumheights = F.pad(cumheights, pad=(1, 0), mode="constant", value=0.0)
    cumheights = (top - bottom) * cumheights + bottom
    cumheights[..., 0] = bottom
    cumheights[..., -1] = top
    heights = cumheights[..., 1:] - cumheights[..., :-1]

    # Locate the bin for each element (by y in the inverse direction,
    # by x in the forward direction).
    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]

    # Gather the per-element bin quantities.
    input_cumwidths = cumwidths.gather(-1, bin_idx)[..., 0]
    input_bin_widths = widths.gather(-1, bin_idx)[..., 0]

    input_cumheights = cumheights.gather(-1, bin_idx)[..., 0]
    delta = heights / widths
    input_delta = delta.gather(-1, bin_idx)[..., 0]

    input_derivatives = derivatives.gather(-1, bin_idx)[..., 0]
    input_derivatives_plus_one = derivatives[..., 1:].gather(-1, bin_idx)[..., 0]

    input_heights = heights.gather(-1, bin_idx)[..., 0]

    if inverse:
        # Invert the rational-quadratic by solving the quadratic
        # a*root^2 + b*root + c = 0 for the bin-relative position `root`;
        # the (2c)/(-b - sqrt(disc)) form is the numerically stable root.
        a = (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        ) + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - (inputs - input_cumheights) * (
            input_derivatives + input_derivatives_plus_one - 2 * input_delta
        )
        c = -input_delta * (inputs - input_cumheights)

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        root = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = root * input_bin_widths + input_cumwidths

        theta_one_minus_theta = root * (1 - root)
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * root.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - root).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        # Inverse direction: negate the forward log-determinant.
        return outputs, -logabsdet
    else:
        # theta is the position within the bin, normalized to [0, 1].
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (
            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
        )
        denominator = input_delta + (
            (input_derivatives + input_derivatives_plus_one - 2 * input_delta)
            * theta_one_minus_theta
        )
        outputs = input_cumheights + numerator / denominator

        derivative_numerator = input_delta.pow(2) * (
            input_derivatives_plus_one * theta.pow(2)
            + 2 * input_delta * theta_one_minus_theta
            + input_derivatives * (1 - theta).pow(2)
        )
        logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

        return outputs, logabsdet
src/text/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .symbols import *
2
+
3
+
4
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
5
+
6
+
7
def cleaned_text_to_sequence(cleaned_text, tones, language, symbol_to_id=None):
    """Convert cleaned phoneme text to ID sequences.

    Args:
        cleaned_text: iterable of phoneme symbols.
        tones: per-phoneme tone indices (language-relative).
        language: language key used to look up tone offsets and language ID.
        symbol_to_id: optional symbol->ID map overriding the module default.

    Returns:
        (phone_ids, shifted_tone_ids, language_ids), all equal-length lists.
    """
    mapping = symbol_to_id or _symbol_to_id
    unk_id = mapping.get("UNK")
    if unk_id is None:
        # No UNK entry: unknown symbols raise KeyError as before.
        phones = [mapping[symbol] for symbol in cleaned_text]
    else:
        phones = [mapping.get(symbol, unk_id) for symbol in cleaned_text]
    tone_offset = language_tone_start_map[language]
    shifted_tones = [tone + tone_offset for tone in tones]
    lang_ids = [language_id_map[language]] * len(phones)
    return phones, shifted_tones, lang_ids
src/text/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.46 kB). View file
 
src/text/__pycache__/cleaner.cpython-310.pyc ADDED
Binary file (1.41 kB). View file
 
src/text/__pycache__/symbols.cpython-310.pyc ADDED
Binary file (3.55 kB). View file
 
src/text/cleaner.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from . import cleaned_text_to_sequence
2
+ import copy
3
+
4
+ _language_modules = {}
5
+
6
def _get_language_module(language):
    """Return the language-specific text module, importing it lazily.

    Fix: the memoization dict was populated but never consulted, so every
    call re-ran the import branch; the cache is now checked first.

    Args:
        language: language code, currently only 'VI' is supported.

    Raises:
        ValueError: if `language` is not supported.
    """
    if language not in _language_modules:
        if language == 'VI':
            from . import vietnamese
            _language_modules['VI'] = vietnamese
        else:
            raise ValueError(f"Unsupported language: {language}")

    return _language_modules[language]
15
+
16
+
17
def clean_text(text, language):
    """Normalize `text` and convert it to phonemes/tones for `language`.

    Returns:
        (norm_text, phones, tones, word2ph) as produced by the language
        module's text_normalize and g2p functions.
    """
    module = _get_language_module(language)
    normalized = module.text_normalize(text)
    phones, tones, word2ph = module.g2p(normalized)
    return normalized, phones, tones, word2ph
22
+
23
+
24
def clean_text_bert(text, language, device=None):
    """Like clean_text, but also computes BERT features for the text.

    The word-to-phone map passed to get_bert_feature is doubled (with one
    extra slot on the first word); the untouched map is returned to the
    caller alongside the features.
    """
    module = _get_language_module(language)
    normalized = module.text_normalize(text)
    phones, tones, word2ph = module.g2p(normalized)

    word2ph_bak = copy.deepcopy(word2ph)
    # Slice-assign to keep the same list object (in-place, as before).
    word2ph[:] = [2 * count for count in word2ph]
    word2ph[0] += 1
    bert = module.get_bert_feature(normalized, word2ph, device=device)

    return normalized, phones, tones, word2ph_bak, bert
36
+
37
+
38
def text_to_sequence(text, language):
    """Full pipeline: raw text -> (phone_ids, tone_ids, language_ids)."""
    _, phones, tones, _ = clean_text(text, language)
    return cleaned_text_to_sequence(phones, tones, language)
41
+
42
+
43
+ if __name__ == "__main__":
44
+ pass
src/text/symbols.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # punctuation = ["!", "?", "…", ",", ".", "'", "-"]
2
+ punctuation = ["!", "?", "…", ",", ".", "'", "-", "¿", "¡"]
3
+ pu_symbols = punctuation + ["SP", "UNK"]
4
+ pad = "_"
5
+
6
+ # chinese
7
+ zh_symbols = [
8
+ "E",
9
+ "En",
10
+ "a",
11
+ "ai",
12
+ "an",
13
+ "ang",
14
+ "ao",
15
+ "b",
16
+ "c",
17
+ "ch",
18
+ "d",
19
+ "e",
20
+ "ei",
21
+ "en",
22
+ "eng",
23
+ "er",
24
+ "f",
25
+ "g",
26
+ "h",
27
+ "i",
28
+ "i0",
29
+ "ia",
30
+ "ian",
31
+ "iang",
32
+ "iao",
33
+ "ie",
34
+ "in",
35
+ "ing",
36
+ "iong",
37
+ "ir",
38
+ "iu",
39
+ "j",
40
+ "k",
41
+ "l",
42
+ "m",
43
+ "n",
44
+ "o",
45
+ "ong",
46
+ "ou",
47
+ "p",
48
+ "q",
49
+ "r",
50
+ "s",
51
+ "sh",
52
+ "t",
53
+ "u",
54
+ "ua",
55
+ "uai",
56
+ "uan",
57
+ "uang",
58
+ "ui",
59
+ "un",
60
+ "uo",
61
+ "v",
62
+ "van",
63
+ "ve",
64
+ "vn",
65
+ "w",
66
+ "x",
67
+ "y",
68
+ "z",
69
+ "zh",
70
+ "AA",
71
+ "EE",
72
+ "OO",
73
+ ]
74
+ num_zh_tones = 6
75
+
76
+ # japanese
77
+ ja_symbols = [
78
+ "N",
79
+ "a",
80
+ "a:",
81
+ "b",
82
+ "by",
83
+ "ch",
84
+ "d",
85
+ "dy",
86
+ "e",
87
+ "e:",
88
+ "f",
89
+ "g",
90
+ "gy",
91
+ "h",
92
+ "hy",
93
+ "i",
94
+ "i:",
95
+ "j",
96
+ "k",
97
+ "ky",
98
+ "m",
99
+ "my",
100
+ "n",
101
+ "ny",
102
+ "o",
103
+ "o:",
104
+ "p",
105
+ "py",
106
+ "q",
107
+ "r",
108
+ "ry",
109
+ "s",
110
+ "sh",
111
+ "t",
112
+ "ts",
113
+ "ty",
114
+ "u",
115
+ "u:",
116
+ "w",
117
+ "y",
118
+ "z",
119
+ "zy",
120
+ ]
121
+ num_ja_tones = 1
122
+
123
+ # English
124
+ en_symbols = [
125
+ "aa",
126
+ "ae",
127
+ "ah",
128
+ "ao",
129
+ "aw",
130
+ "ay",
131
+ "b",
132
+ "ch",
133
+ "d",
134
+ "dh",
135
+ "eh",
136
+ "er",
137
+ "ey",
138
+ "f",
139
+ "g",
140
+ "hh",
141
+ "ih",
142
+ "iy",
143
+ "jh",
144
+ "k",
145
+ "l",
146
+ "m",
147
+ "n",
148
+ "ng",
149
+ "ow",
150
+ "oy",
151
+ "p",
152
+ "r",
153
+ "s",
154
+ "sh",
155
+ "t",
156
+ "th",
157
+ "uh",
158
+ "uw",
159
+ "V",
160
+ "w",
161
+ "y",
162
+ "z",
163
+ "zh",
164
+ ]
165
+ num_en_tones = 4
166
+
167
+ # Korean
168
+ kr_symbols = ['ᄌ', 'ᅥ', 'ᆫ', 'ᅦ', 'ᄋ', 'ᅵ', 'ᄅ', 'ᅴ', 'ᄀ', 'ᅡ', 'ᄎ', 'ᅪ', 'ᄑ', 'ᅩ', 'ᄐ', 'ᄃ', 'ᅢ', 'ᅮ', 'ᆼ', 'ᅳ', 'ᄒ', 'ᄆ', 'ᆯ', 'ᆷ', 'ᄂ', 'ᄇ', 'ᄉ', 'ᆮ', 'ᄁ', 'ᅬ', 'ᅣ', 'ᄄ', 'ᆨ', 'ᄍ', 'ᅧ', 'ᄏ', 'ᆸ', 'ᅭ', '(', 'ᄊ', ')', 'ᅲ', 'ᅨ', 'ᄈ', 'ᅱ', 'ᅯ', 'ᅫ', 'ᅰ', 'ᅤ', '~', '\\', '[', ']', '/', '^', ':', 'ㄸ', '*']
169
+ num_kr_tones = 1
170
+
171
+ # Spanish
172
+ es_symbols = [
173
+ "N",
174
+ "Q",
175
+ "a",
176
+ "b",
177
+ "d",
178
+ "e",
179
+ "f",
180
+ "g",
181
+ "h",
182
+ "i",
183
+ "j",
184
+ "k",
185
+ "l",
186
+ "m",
187
+ "n",
188
+ "o",
189
+ "p",
190
+ "s",
191
+ "t",
192
+ "u",
193
+ "v",
194
+ "w",
195
+ "x",
196
+ "y",
197
+ "z",
198
+ "ɑ",
199
+ "æ",
200
+ "ʃ",
201
+ "ʑ",
202
+ "ç",
203
+ "ɯ",
204
+ "ɪ",
205
+ "ɔ",
206
+ "ɛ",
207
+ "ɹ",
208
+ "ð",
209
+ "ə",
210
+ "ɫ",
211
+ "ɥ",
212
+ "ɸ",
213
+ "ʊ",
214
+ "ɾ",
215
+ "ʒ",
216
+ "θ",
217
+ "β",
218
+ "ŋ",
219
+ "ɦ",
220
+ "ɡ",
221
+ "r",
222
+ "ɲ",
223
+ "ʝ",
224
+ "ɣ",
225
+ "ʎ",
226
+ "ˈ",
227
+ "ˌ",
228
+ "ː"
229
+ ]
230
+ num_es_tones = 1
231
+
232
+ # French
233
+ fr_symbols = [
234
+ "\u0303",
235
+ "œ",
236
+ "ø",
237
+ "ʁ",
238
+ "ɒ",
239
+ "ʌ",
240
+ "ɜ",
241
+ "ɐ"
242
+ ]
243
+ num_fr_tones = 1
244
+
245
+ # German
246
+ de_symbols = [
247
+ "ʏ",
248
+ "̩"
249
+ ]
250
+ num_de_tones = 1
251
+
252
+ # Russian
253
+ ru_symbols = [
254
+ "ɭ",
255
+ "ʲ",
256
+ "ɕ",
257
+ "\"",
258
+ "ɵ",
259
+ "^",
260
+ "ɬ"
261
+ ]
262
+ num_ru_tones = 1
263
+
264
+ # Vietnamese (IPA-based, compatible with VieNeu-TTS-140h dataset)
265
+ vi_symbols = [
266
+ # Consonants (simple)
267
+ "ʈ", # tr
268
+ "ɖ", # đ
269
+ "ɗ", # implosive d (đ variant)
270
+ "ɓ", # implosive b
271
+ "ʰ", # aspiration marker
272
+ "ă", # short a (Vietnamese)
273
+ "ʷ", # labialization marker
274
+ "̆", # breve diacritic
275
+ "͡", # tie bar (for affricates)
276
+ "ʤ", # voiced postalveolar affricate
277
+ "ʧ", # voiceless postalveolar affricate
278
+ # Foreign/special characters found in dataset
279
+ "т", # Cyrillic т
280
+ "輪", # Chinese character
281
+ "и", # Cyrillic и
282
+ "л", # Cyrillic л
283
+ "р", # Cyrillic р
284
+ "µ", # micro sign
285
+ "ʂ", # s (retroflex)
286
+ "ʐ", # r (retroflex)
287
+ "ʔ", # glottal stop
288
+ "ɣ", # g (southern)
289
+ # Multi-char consonants (from vietnamese.py g2p)
290
+ "tʰ", # th
291
+ "kʰ", # kh
292
+ "kw", # qu -> kw
293
+ "tʃ", # ch
294
+ "ɹ", # r IPA
295
+ # Vowels specific to Vietnamese
296
+ "ɤ", # ơ
297
+ "ɐ", # a short
298
+ "ɑ", # a back
299
+ "ɨ", # ư variant
300
+ "ʉ", # u variant
301
+ "ɜ", # open-mid central
302
+ # Long vowels (from VieNeu-TTS dataset)
303
+ "əː", # schwa long
304
+ "aː", # a long
305
+ "ɜː", # open-mid central long
306
+ "ɑː", # open back long
307
+ "ɔː", # open-mid back long
308
+ "iː", # close front long
309
+ "uː", # close back long
310
+ "eː", # close-mid front long
311
+ "oː", # close-mid back long
312
+ # Diphthongs and special combinations
313
+ "iə", # ia/iê
314
+ "ɨə", # ưa/ươ
315
+ "uə", # ua/uô
316
+ # Additional IPA markers
317
+ "ˑ", # half-long
318
+ "̪", # dental diacritic
319
+ # Tone-related (though tones are handled separately)
320
+ "˥", # tone 1 marker
321
+ "˩", # tone marker
322
+ "˧", # tone marker
323
+ "˨", # tone marker
324
+ "˦", # tone marker
325
+ # Numbers (found in phonemized dataset)
326
+ "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
327
+ # Special characters from dataset
328
+ "$", "%", "&", "«", "»", "–", "ı",
329
+ # viphoneme specific symbols
330
+ "wʷ", # labialized w
331
+ "#", # unknown/fallback marker
332
+ "ô", # Vietnamese ô (fallback)
333
+ "ʃ", # voiceless postalveolar fricative
334
+ "ʒ", # voiced postalveolar fricative
335
+ "θ", # voiceless dental fricative
336
+ "ð", # voiced dental fricative
337
+ "æ", # near-open front unrounded
338
+ "ɪ", # near-close front unrounded
339
+ "ʊ", # near-close back rounded
340
+ # Vietnamese fallback characters (when viphoneme fails to parse)
341
+ "ẩ", "ò", "à", "á", "ủ", "ờ", "ộ", "ả", "ó", "é", "ê",
342
+ "ồ", "ấ", "ú", "ế", "ớ", "ì", "ọ", "ố", "ư", "ữ",
343
+ ]
344
+ num_vi_tones = 8 # 6 tones + 1 neutral + 1 extra for data compatibility
345
+
346
+ # combine all symbols
347
+ normal_symbols = sorted(set(zh_symbols + ja_symbols + en_symbols + kr_symbols + es_symbols + fr_symbols + de_symbols + ru_symbols + vi_symbols))
348
+ symbols = [pad] + normal_symbols + pu_symbols
349
+ sil_phonemes_ids = [symbols.index(i) for i in pu_symbols]
350
+
351
+ # combine all tones
352
+ num_tones = num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones + num_es_tones + num_fr_tones + num_de_tones + num_ru_tones + num_vi_tones
353
+
354
+ # language maps
355
+ language_id_map = {"ZH": 0, "JP": 1, "EN": 2, "ZH_MIX_EN": 3, 'KR': 4, 'ES': 5, 'SP': 5, 'FR': 6, 'VI': 7}
356
+ num_languages = len(language_id_map.keys())
357
+
358
+ language_tone_start_map = {
359
+ "ZH": 0,
360
+ "ZH_MIX_EN": 0,
361
+ "JP": num_zh_tones,
362
+ "EN": num_zh_tones + num_ja_tones,
363
+ 'KR': num_zh_tones + num_ja_tones + num_en_tones,
364
+ "ES": num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones,
365
+ "SP": num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones,
366
+ "FR": num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones + num_es_tones,
367
+ "VI": num_zh_tones + num_ja_tones + num_en_tones + num_kr_tones + num_es_tones + num_fr_tones + num_de_tones + num_ru_tones,
368
+ }
369
+
370
+ if __name__ == "__main__":
371
+ a = set(zh_symbols)
372
+ b = set(en_symbols)
373
+ print(sorted(a & b))
src/text/vietnamese.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import unicodedata
3
+ from transformers import AutoTokenizer
4
+ from . import punctuation, symbols
5
+
6
+ # Vietnamese BERT model
7
+ model_id = 'vinai/phobert-base-v2'
8
+ tokenizer = None
9
+
10
def get_tokenizer():
    """Return the shared PhoBERT tokenizer, loading it lazily on first use."""
    global tokenizer
    if tokenizer is not None:
        return tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    return tokenizer
15
+
16
# Vietnamese IPA phoneme set based on VieNeu-TTS-140h dataset
# These are extracted from the phonemized_text field in the dataset
# NOTE(review): VI_IPA_CONSONANTS / VI_IPA_VOWELS are not referenced elsewhere
# in this module — presumably reference/validation aids; confirm before removal.
VI_IPA_CONSONANTS = [
    'b', 'c', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'r', 's', 't', 'v', 'w', 'x', 'z',
    'ŋ',  # ng
    'ɲ',  # nh
    'ʈ',  # tr
    'ɖ',  # đ
    'tʰ',  # th
    'kʰ',  # kh
    'ʂ',  # s (southern)
    'ɣ',  # g (southern)
    'χ',  # x (some dialects)
]

VI_IPA_VOWELS = [
    'a', 'ă', 'â', 'e', 'ê', 'i', 'o', 'ô', 'ơ', 'u', 'ư', 'y',
    'ə',  # ơ
    'ɛ',  # e
    'ɔ',  # o
    'ɯ',  # ư
    'ɤ',  # ơ variant
    'ɐ',  # a short
    'ʊ',  # u short
    'ɪ',  # i short
    'ʌ',  # â
    'æ',  # a variant
]

# Vietnamese tone markers (numbers 1-6 or ˈ ˌ for stress)
VI_TONE_MARKERS = ['1', '2', '3', '4', '5', '6', 'ˈ', 'ˌ', 'ː']

# Combined IPA symbols used in VieNeu-TTS dataset
VI_IPA_SYMBOLS = [
    # Consonants
    'b', 'c', 'd', 'f', 'g', 'h', 'j', 'k', 'l', 'm', 'n', 'p', 'r', 's', 't', 'v', 'w', 'x', 'z',
    'ŋ', 'ɲ', 'ʈ', 'ɖ', 'ʂ', 'ɣ', 'χ', 'ʔ',
    # Vowels
    'a', 'ă', 'e', 'i', 'o', 'u', 'y',
    'ə', 'ɛ', 'ɔ', 'ɯ', 'ɤ', 'ɐ', 'ʊ', 'ɪ', 'ʌ', 'æ', 'ɑ',
    # Special markers
    'ˈ', 'ˌ', 'ː',
    # Tone numbers
    '1', '2', '3', '4', '5', '6',
]
61
+
62
def normalize_vietnamese_text(text):
    """NFC-normalize, collapse whitespace, and spell out standalone digits."""
    # canonical composed form so diacritics compare consistently
    normalized = unicodedata.normalize('NFC', text)
    # collapse runs of whitespace and trim the ends
    normalized = re.sub(r'\s+', ' ', normalized).strip()
    # basic digit-to-word conversion
    return convert_numbers_to_vietnamese(normalized)
75
+
76
def convert_numbers_to_vietnamese(text):
    """
    Replace standalone numbers that have a known Vietnamese reading.

    Digits 0-9 plus the round numbers 10/100/1000 are spelled out; any other
    standalone number (e.g. "25") is left untouched.

    Fix over the original: the regex matched only single digits (``\\b\\d\\b``),
    so the multi-digit map entries ('10', '100', '1000') were unreachable.
    """
    num_map = {
        '0': 'không', '1': 'một', '2': 'hai', '3': 'ba', '4': 'bốn',
        '5': 'năm', '6': 'sáu', '7': 'bảy', '8': 'tám', '9': 'chín',
        '10': 'mười', '100': 'trăm', '1000': 'nghìn'
    }

    def replace_num(match):
        num = match.group(0)
        # numbers without a mapping pass through unchanged
        return num_map.get(num, num)

    # \d+ (not \d) so the multi-digit map entries can match too
    return re.sub(r'\b\d+\b', replace_num, text)
94
+
95
def text_normalize(text):
    """Public normalization entry point for Vietnamese TTS."""
    return normalize_vietnamese_text(text)
99
+
100
def parse_ipa_phonemes(phonemized_text):
    """
    Parse space-separated IPA words from the VieNeu-TTS dataset
    (e.g. "ŋˈyə2j ŋˈyə2j bˈan xwˈan vˈe2") into flat sequences.

    Per word: digits 1-6 set the word's tone (applied to every phone of the
    word), 'ˈ'/'ˌ' stress marks are dropped, 'ː' is appended to the preceding
    phone, and punctuation flushes the pending word and is emitted on its own
    with tone 0.

    Returns:
        (phones, tones, word2ph) with len(phones) == len(tones) == sum(word2ph).

    Fix over the original: an unused ``word_tones`` accumulator was removed.
    """
    phones = []
    tones = []
    word2ph = []

    # Split by space to get words
    for word in phonemized_text.strip().split():
        word_phones = []
        current_tone = 0  # default tone (neutral) until a digit is seen

        i = 0
        while i < len(word):
            char = word[i]

            # tone number (1-6) applies to every phone of this word
            if char.isdigit():
                current_tone = int(char)
                i += 1
                continue

            # stress markers carry no phone of their own
            if char in ['ˈ', 'ˌ']:
                i += 1
                continue

            # length marker attaches to the previous phone
            if char == 'ː':
                if word_phones:
                    word_phones[-1] = word_phones[-1] + 'ː'
                i += 1
                continue

            # punctuation: flush the pending word, then emit the mark itself
            if char in punctuation:
                if word_phones:
                    phones.extend(word_phones)
                    tones.extend([current_tone] * len(word_phones))
                    word2ph.append(len(word_phones))
                    word_phones = []
                phones.append(char)
                tones.append(0)
                word2ph.append(1)
                i += 1
                continue

            # regular phoneme character
            word_phones.append(char)
            i += 1

        # apply the collected tone to all phones of this word
        if word_phones:
            phones.extend(word_phones)
            tones.extend([current_tone] * len(word_phones))
            word2ph.append(len(word_phones))

    return phones, tones, word2ph
169
+
170
def g2p_ipa(text):
    """
    Phonemize *text* with viphoneme when available, otherwise fall back to
    the character-based mapping; boundary "_" tokens are added either way.
    For training, pre-phonemized dataset text is preferred over this path.
    """
    try:
        from viphoneme import vi2ipa
        phones, tones, word2ph = parse_ipa_phonemes(vi2ipa(text))
    except ImportError:
        # viphoneme not installed: character-based fallback
        phones, tones, word2ph = g2p_char_based(text)

    # wrap with start/end boundary tokens
    phones = ["_"] + phones + ["_"]
    tones = [0] + tones + [0]
    word2ph = [1] + word2ph + [1]
    return phones, tones, word2ph
190
+
191
def g2p_char_based(text):
    """
    Character-based grapheme-to-phoneme fallback for Vietnamese.

    Each word is NFD-decomposed so tone diacritics become separate combining
    marks; tri-/di-graphs are matched before single letters via ``vi_to_ipa``,
    and the tone collected from the diacritic is applied to every phone of
    the word.  Punctuation is emitted as its own group with tone 0.

    Returns:
        (phones, tones, word2ph) with boundary "_" tokens added.

    NOTE(review): NFD also splits quality diacritics (circumflex, breve,
    horn), and the combining-mark skip below discards them, so precomposed
    keys such as 'ô', 'ê', 'â', 'ă', 'ơ', 'ư' in vi_to_ipa may never match —
    e.g. 'ô' appears to fall through to the plain 'o' mapping. Confirm this
    is intentional.
    """
    phones = []
    tones = []
    word2ph = []

    # Vietnamese tone marks (combining characters after NFD) -> tone number
    tone_marks = {
        '\u0300': 2,  # à - huyền (falling)
        '\u0301': 1,  # á - sắc (rising)
        '\u0303': 3,  # ã - ngã (glottalized rising)
        '\u0309': 4,  # ả - hỏi (dipping)
        '\u0323': 5,  # ạ - nặng (glottalized falling)
    }

    # Vietnamese character to IPA mapping (COMPREHENSIVE - matching training data)
    # Multi-char outputs are split into lists to avoid KeyError for missing multi-char symbols
    vi_to_ipa = {
        # Multi-char consonants (check these first - ORDER MATTERS)
        'ngh': 'ŋ',
        'ng': 'ŋ',
        'nh': 'ɲ',
        'ch': ['t', 'ʃ'],  # Vietnamese ch = IPA t + ʃ (separated in training data)
        'tr': 'ʈ',  # retroflex
        'th': ['t', 'h'],  # aspirated th
        'ph': 'f',
        'kh': 'x',  # Vietnamese 'kh' = IPA 'x' (matches training data)
        'gh': 'ɣ',
        'gi': 'z',
        'qu': 'kw',  # qu -> kw (single symbol in training data)
        # Special Vietnamese consonants
        'đ': 'ɗ',  # implosive d
        # Basic consonants that need IPA mapping
        'x': 's',  # Vietnamese 'x' = IPA 's'
        'c': 'k',  # Vietnamese 'c' = IPA 'k'
        'd': 'z',  # Vietnamese 'd' (northern) = 'z'
        'r': 'ɹ',  # Vietnamese 'r' = IPA 'ɹ' (matches training data)
        's': 's',
        'b': 'b',
        'g': 'ɣ',
        'h': 'h',
        'k': 'k',
        'l': 'l',
        'm': 'm',
        'n': 'n',
        'p': 'p',
        't': 't',
        'v': 'v',
        'f': 'f',
        'j': 'j',
        'w': 'w',
        'y': 'j',  # Vietnamese 'y' = IPA 'j' (matches training data)
        # Vowels - MUST match training data phonemes exactly!
        'a': 'aː',  # Long 'a' (matches training: aː)
        'ă': 'a',  # Short 'a'
        'â': 'ə',  # schwa
        'e': 'ɛ',  # open-mid (matches training: ɛ)
        'ê': 'e',  # close-mid
        'i': 'i',
        'o': 'ɔ',  # open-mid back (matches training: ɔ)
        'ô': 'o',  # close-mid back
        'ơ': 'əː',  # long schwa
        'u': 'u',
        'ư': 'ɯ',  # close back unrounded
    }

    words = text.split()
    for word in words:
        # Decompose to separate base char and tone mark
        decomposed = unicodedata.normalize('NFD', word)
        word_phones = []
        current_tone = 0

        i = 0
        chars = list(decomposed)
        while i < len(chars):
            char = chars[i]

            # combining tone mark: remember the word's tone, emit nothing
            if char in tone_marks:
                current_tone = tone_marks[char]
                i += 1
                continue

            # punctuation: flush the pending word, emit the mark with tone 0
            if char in punctuation:
                if word_phones:
                    phones.extend(word_phones)
                    tones.extend([current_tone] * len(word_phones))
                    word2ph.append(len(word_phones))
                    word_phones = []
                phones.append(char)
                tones.append(0)
                word2ph.append(1)
                current_tone = 0
                i += 1
                continue

            # any other combining mark (e.g. circumflex/breve/horn) is skipped
            if unicodedata.combining(char):
                i += 1
                continue

            # Check for multi-char sequences (digraphs/trigraphs)
            lower_char = char.lower()
            matched = False

            # Try trigraphs first (e.g. 'ngh')
            if i + 2 < len(chars):
                trigraph = (lower_char + chars[i+1].lower() + chars[i+2].lower())
                if trigraph in vi_to_ipa:
                    result = vi_to_ipa[trigraph]
                    if isinstance(result, list):
                        word_phones.extend(result)
                    else:
                        word_phones.append(result)
                    i += 3
                    matched = True

            # Try digraphs (e.g. 'ng', 'ch', 'th')
            if not matched and i + 1 < len(chars):
                digraph = lower_char + chars[i+1].lower()
                if digraph in vi_to_ipa:
                    result = vi_to_ipa[digraph]
                    if isinstance(result, list):
                        word_phones.extend(result)
                    else:
                        word_phones.append(result)
                    i += 2
                    matched = True

            # Single char; unmapped characters pass through lower-cased
            if not matched:
                if lower_char in vi_to_ipa:
                    result = vi_to_ipa[lower_char]
                    if isinstance(result, list):
                        word_phones.extend(result)
                    else:
                        word_phones.append(result)
                else:
                    word_phones.append(lower_char)
                i += 1

        # apply the collected tone to every phone of the word
        if word_phones:
            phones.extend(word_phones)
            tones.extend([current_tone] * len(word_phones))
            word2ph.append(len(word_phones))

    # Add boundary tokens
    phones = ["_"] + phones + ["_"]
    tones = [0] + tones + [0]
    word2ph = [1] + word2ph + [1]

    return phones, tones, word2ph
344
+
345
def g2p(text):
    """
    Main G2P function for Vietnamese (no pre-phonemized input).

    Normalizes the text, runs the character-based IPA mapping, then reshapes
    word2ph so it lines up with PhoBERT's subword tokenization.

    Returns:
        (phones, tones, word2ph)
    """
    tok = get_tokenizer()
    norm_text = text_normalize(text)

    # Tokenize for BERT alignment
    tokenized = tok.tokenize(norm_text)

    # Use character-based G2P with IPA mapping
    phones, tones, word2ph = g2p_char_based(norm_text)

    # Ensure word2ph aligns with tokenized output
    # PhoBERT uses subword tokenization, so we need to distribute phones
    if len(word2ph) != len(tokenized) + 2:  # +2 for start/end tokens
        # Redistribute word2ph to match tokenized length.
        # NOTE(review): total_phones already includes the two boundary phones,
        # and two extra 1s are appended below, so sum(word2ph) becomes
        # len(phones) + 2 on this branch — confirm downstream consumers
        # tolerate (or expect) that off-by-two.
        total_phones = sum(word2ph)
        new_word2ph = distribute_phones(total_phones, len(tokenized))
        word2ph = [1] + new_word2ph + [1]

    return phones, tones, word2ph
368
+
369
def g2p_with_phonemes(text, phonemized_text):
    """
    G2P using pre-phonemized IPA text from the dataset (recommended for
    training): parses the phoneme string, adds boundary tokens, and builds a
    word2ph that matches PhoBERT's tokenization of the raw *text*.

    Here sum(word2ph) == len(phones): the redistribution covers only the
    inner phones and the two appended 1s account for the boundary tokens.

    Returns:
        (phones, tones, word2ph)
    """
    tok = get_tokenizer()

    # Parse IPA phonemes
    phones, tones, word2ph = parse_ipa_phonemes(phonemized_text)

    # Add boundary tokens
    phones = ["_"] + phones + ["_"]
    tones = [0] + tones + [0]

    # Get tokenized text for BERT alignment
    tokenized = tok.tokenize(text)

    # Distribute word2ph to match tokenized output + boundaries
    if word2ph:
        total_phones = sum(word2ph)
        new_word2ph = distribute_phones(total_phones, len(tokenized))
        word2ph = [1] + new_word2ph + [1]
    else:
        # no phones parsed: give every token (and boundary) one slot
        word2ph = [1] + [1] * len(tokenized) + [1]

    return phones, tones, word2ph
395
+
396
def distribute_phones(n_phone, n_word):
    """Split *n_phone* into *n_word* near-equal counts; earlier words absorb
    the remainder. Returns [] when there are no words."""
    if n_word == 0:
        return []
    base, extra = divmod(n_phone, n_word)
    return [base + 1 if idx < extra else base for idx in range(n_word)]
405
+
406
def get_bert_feature(text, word2ph, device='cuda'):
    """Delegate to the shared Vietnamese BERT feature extractor (PhoBERT)."""
    from . import vietnamese_bert as _vb
    return _vb.get_bert_feature(text, word2ph, device=device, model_id=model_id)
410
+
411
+
412
if __name__ == "__main__":
    # Smoke test: exercise both G2P paths on a sample sentence.
    test_text = "Xin chào, tôi là một trợ lý AI."
    test_phonemes = "sˈin tʂˈaːw, tˈoj lˈaː2 mˈo6t tʂˈɤ4 lˈi4 ˌaːˈi."

    print("Test text:", test_text)
    print("Normalized:", text_normalize(test_text))

    # Dataset-style path: pre-phonemized IPA input
    phones, tones, word2ph = g2p_with_phonemes(test_text, test_phonemes)
    print("Phones:", phones)
    print("Tones:", tones)
    print("Word2Ph:", word2ph)

    # Fallback path: character-based mapping
    phones2, tones2, word2ph2 = g2p(test_text)
    print("\nChar-based phones:", phones2)
    print("Char-based tones:", tones2)
src/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """
2
+ Utility functions package
3
+ """
4
+
5
+ from .helpers import *
src/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (220 Bytes). View file
 
src/utils/__pycache__/helpers.cpython-310.pyc ADDED
Binary file (13.9 kB). View file
 
src/utils/helpers.py ADDED
@@ -0,0 +1,452 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import glob
3
+ import argparse
4
+ import logging
5
+ import json
6
+ import subprocess
7
+ import numpy as np
8
+ from scipy.io.wavfile import read
9
+ import torch
10
+ import torchaudio
11
+ import librosa
12
+ from src.text import cleaned_text_to_sequence
13
+ from src.text.cleaner import clean_text
14
+ from src.nn import commons
15
+
16
+ MATPLOTLIB_FLAG = False
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
+
22
def get_text_for_tts_infer(text, language_str, hps, device, symbol_to_id=None):
    """
    Convert raw text into the tensors the synthesizer consumes at inference.

    Args:
        text: input sentence.
        language_str: language code ("ZH", "EN", "VI", ...).
        hps: hyper-parameters; reads hps.data.add_blank and hps.data.disable_bert.
        device: device handed to the BERT feature extractor.
        symbol_to_id: optional symbol-table override.

    Returns:
        (bert, ja_bert, phone, tone, language) — two per-phone feature
        matrices (1024-dim and 768-dim) plus three LongTensors of ids.
    """
    norm_text, phone, tone, word2ph = clean_text(text, language_str)
    phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str, symbol_to_id)

    if hps.data.add_blank:
        # interleave blank (id 0) between every symbol; word2ph doubles to match
        phone = commons.intersperse(phone, 0)
        tone = commons.intersperse(tone, 0)
        language = commons.intersperse(language, 0)
        for i in range(len(word2ph)):
            word2ph[i] = word2ph[i] * 2
        word2ph[0] += 1

    if getattr(hps.data, "disable_bert", False):
        # BERT disabled: zero features in both slots
        bert = torch.zeros(1024, len(phone))
        ja_bert = torch.zeros(768, len(phone))
    else:
        # NOTE(review): `get_bert` is not imported in this module — confirm it
        # is provided elsewhere, otherwise this branch raises NameError.
        bert = get_bert(norm_text, word2ph, language_str, device)
        del word2ph
        assert bert.shape[-1] == len(phone), phone

        # Route the features: ZH uses the 1024-dim slot; every other language
        # the 768-dim ("ja_bert") slot, with zeros in the unused one.
        if language_str == "ZH":
            bert = bert
            ja_bert = torch.zeros(768, len(phone))
        elif language_str in ["JP", "EN", "ZH_MIX_EN", 'KR', 'SP', 'ES', 'FR', 'DE', 'RU', 'VI']:
            ja_bert = bert
            bert = torch.zeros(1024, len(phone))
        else:
            raise NotImplementedError()

    assert bert.shape[-1] == len(
        phone
    ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"

    phone = torch.LongTensor(phone)
    tone = torch.LongTensor(tone)
    language = torch.LongTensor(language)
    return bert, ja_bert, phone, tone, language
59
+
60
def load_checkpoint(checkpoint_path, model, optimizer=None, skip_optimizer=False):
    """
    Load a training checkpoint into *model* (and optionally *optimizer*).

    Model tensors missing from the checkpoint, or with mismatched shapes,
    fall back to the model's current values; legacy "ja_bert_proj" weights
    are zeroed for backward compatibility.

    Fixes over the original: the optimizer fallback branch was guarded by
    ``optimizer is None`` and then called ``optimizer.state_dict()`` — a
    guaranteed AttributeError — so it is now a try/except merge of param
    groups; a missing "optimizer" key no longer raises KeyError; deprecated
    ``logger.warn`` replaced with ``logger.warning``.

    Returns:
        (model, optimizer, learning_rate, iteration)
    """
    assert os.path.isfile(checkpoint_path)
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    iteration = checkpoint_dict.get("iteration", 0)
    learning_rate = checkpoint_dict.get("learning_rate", 0.)

    saved_opt_state = checkpoint_dict.get("optimizer")
    if optimizer is not None and not skip_optimizer and saved_opt_state is not None:
        try:
            optimizer.load_state_dict(saved_opt_state)
        except (KeyError, ValueError) as e:
            # Param-group mismatch (e.g. resuming with a changed model):
            # keep the current params but adopt the saved hyper-parameters.
            logger.warning(f"Optimizer state not directly loadable ({e}); merging param groups")
            new_opt_dict = optimizer.state_dict()
            new_opt_dict_params = new_opt_dict["param_groups"][0]["params"]
            new_opt_dict["param_groups"] = saved_opt_state["param_groups"]
            new_opt_dict["param_groups"][0]["params"] = new_opt_dict_params
            optimizer.load_state_dict(new_opt_dict)

    saved_state_dict = checkpoint_dict["model"]
    if hasattr(model, "module"):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()

    new_state_dict = {}
    for k, v in state_dict.items():
        try:
            new_state_dict[k] = saved_state_dict[k]
            assert saved_state_dict[k].shape == v.shape, (
                saved_state_dict[k].shape,
                v.shape,
            )
        except Exception as e:
            print(e)
            # For upgrading from the old version
            if "ja_bert_proj" in k:
                v = torch.zeros_like(v)
                logger.warning(
                    f"Seems you are using the old version of the model, the {k} is automatically set to zero for backward compatibility"
                )
            else:
                logger.error(f"{k} is not in the checkpoint")
            new_state_dict[k] = v

    if hasattr(model, "module"):
        model.module.load_state_dict(new_state_dict, strict=False)
    else:
        model.load_state_dict(new_state_dict, strict=False)

    logger.info(
        "Loaded checkpoint '{}' (iteration {})".format(checkpoint_path, iteration)
    )

    return model, optimizer, learning_rate, iteration
117
+
118
+
119
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
    """Serialize model + optimizer state (plus lr and step) to *checkpoint_path*."""
    logger.info(
        "Saving model and optimizer state at iteration {} to {}".format(
            iteration, checkpoint_path
        )
    )
    # unwrap DistributedDataParallel-style wrappers before snapshotting
    target = model.module if hasattr(model, "module") else model
    payload = {
        "model": target.state_dict(),
        "iteration": iteration,
        "optimizer": optimizer.state_dict(),
        "learning_rate": learning_rate,
    }
    torch.save(payload, checkpoint_path)
+ )
138
+
139
+
140
def summarize(
    writer,
    global_step,
    scalars=None,
    histograms=None,
    images=None,
    audios=None,
    audio_sampling_rate=22050,
):
    """
    Write a batch of TensorBoard summaries.

    Args:
        writer: a SummaryWriter-like object.
        global_step: step index attached to every entry.
        scalars/histograms/images/audios: name -> value mappings (optional).
        audio_sampling_rate: sample rate passed with every audio clip.

    Fix over the original: mutable ``{}`` defaults replaced with ``None`` to
    avoid the shared-mutable-default pitfall; explicit-dict callers unchanged.
    """
    for k, v in (scalars or {}).items():
        writer.add_scalar(k, v, global_step)
    for k, v in (histograms or {}).items():
        writer.add_histogram(k, v, global_step)
    for k, v in (images or {}).items():
        writer.add_image(k, v, global_step, dataformats="HWC")
    for k, v in (audios or {}).items():
        writer.add_audio(k, v, global_step, audio_sampling_rate)
157
+
158
+
159
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
    """
    Return the checkpoint in *dir_path* with the highest numeric step.

    Fixes over the original: an empty directory now raises a descriptive
    FileNotFoundError instead of a bare IndexError, and the numeric sort key
    is built from the file's basename only, so digits in the directory path
    cannot skew the ordering.
    """
    f_list = glob.glob(os.path.join(dir_path, regex))
    if not f_list:
        raise FileNotFoundError(
            f"No checkpoint matching {regex!r} found in {dir_path!r}"
        )
    f_list.sort(key=lambda f: int("".join(filter(str.isdigit, os.path.basename(f)))))
    return f_list[-1]
164
+
165
+
166
def plot_spectrogram_to_numpy(spectrogram):
    """
    Render a 2-D spectrogram as an HxWx3 uint8 image (for TensorBoard).

    Uses matplotlib's Agg backend when available; if importing matplotlib
    fails, falls back to a plain min-max-normalized grayscale rendering
    (degenerate inputs yield a black image).
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        try:
            import matplotlib

            matplotlib.use("Agg")  # headless backend for training servers
            MATPLOTLIB_FLAG = True
            mpl_logger = logging.getLogger("matplotlib")
            mpl_logger.setLevel(logging.WARNING)  # silence matplotlib chatter
        except Exception:
            # matplotlib unavailable: normalize to [0,255] and stack to RGB
            spec = np.asarray(spectrogram, dtype=np.float32)
            if spec.ndim > 2:
                spec = np.squeeze(spec)
            if spec.ndim != 2:
                return np.zeros((1, 1, 3), dtype=np.uint8)
            vmin = np.nanmin(spec)
            vmax = np.nanmax(spec)
            # constant or non-finite data cannot be normalized — return black
            if not np.isfinite(vmin) or not np.isfinite(vmax) or vmax <= vmin:
                return np.zeros((spec.shape[0], spec.shape[1], 3), dtype=np.uint8)
            img = ((spec - vmin) / (vmax - vmin) * 255.0).clip(0, 255).astype(np.uint8)
            return np.stack([img, img, img], axis=-1)
    import matplotlib.pylab as plt
    import numpy as np  # local re-import shadows the module-level alias; harmless

    fig, ax = plt.subplots(figsize=(10, 2))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    # Use buffer_rgba() instead of deprecated tostring_rgb()
    buf = fig.canvas.buffer_rgba()
    data = np.asarray(buf)[:, :, :3]  # Remove alpha channel
    plt.close()
    return data
204
+
205
+
206
def plot_alignment_to_numpy(alignment, info=None):
    """
    Render an encoder/decoder alignment matrix as an HxWx3 uint8 image.

    *info*, when given, is appended to the x-axis label. Uses matplotlib's
    Agg backend when available; if importing matplotlib fails, falls back to
    a min-max-normalized grayscale rendering of the raw matrix.
    """
    global MATPLOTLIB_FLAG
    if not MATPLOTLIB_FLAG:
        try:
            import matplotlib

            matplotlib.use("Agg")  # headless backend for training servers
            MATPLOTLIB_FLAG = True
            mpl_logger = logging.getLogger("matplotlib")
            mpl_logger.setLevel(logging.WARNING)  # silence matplotlib chatter
        except Exception:
            # matplotlib unavailable: normalize to [0,255] and stack to RGB
            ali = np.asarray(alignment, dtype=np.float32)
            if ali.ndim > 2:
                ali = np.squeeze(ali)
            if ali.ndim != 2:
                return np.zeros((1, 1, 3), dtype=np.uint8)
            vmin = np.nanmin(ali)
            vmax = np.nanmax(ali)
            # constant or non-finite data cannot be normalized — return black
            if not np.isfinite(vmin) or not np.isfinite(vmax) or vmax <= vmin:
                return np.zeros((ali.shape[0], ali.shape[1], 3), dtype=np.uint8)
            img = ((ali - vmin) / (vmax - vmin) * 255.0).clip(0, 255).astype(np.uint8)
            return np.stack([img, img, img], axis=-1)
    import matplotlib.pylab as plt
    import numpy as np  # local re-import shadows the module-level alias; harmless

    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(
        alignment.transpose(), aspect="auto", origin="lower", interpolation="none"
    )
    fig.colorbar(im, ax=ax)
    xlabel = "Decoder timestep"
    if info is not None:
        xlabel += "\n\n" + info
    plt.xlabel(xlabel)
    plt.ylabel("Encoder timestep")
    plt.tight_layout()

    fig.canvas.draw()
    # Use buffer_rgba() instead of deprecated tostring_rgb()
    buf = fig.canvas.buffer_rgba()
    data = np.asarray(buf)[:, :, :3]  # Remove alpha channel
    plt.close()
    return data
249
+
250
+
251
def load_wav_to_torch(full_path):
    """Read a WAV file via scipy and return (float32 tensor, sampling_rate)."""
    sampling_rate, samples = read(full_path)
    tensor = torch.FloatTensor(samples.astype(np.float32))
    return tensor, sampling_rate
254
+
255
+
256
def load_wav_to_torch_new(full_path):
    """Load audio with torchaudio, downmix channels to mono, return (tensor, sr)."""
    waveform, sampling_rate = torchaudio.load(
        full_path, frame_offset=0, num_frames=-1, normalize=True, channels_first=True
    )
    mono = waveform.mean(dim=0)
    return mono, sampling_rate
260
+
261
def load_wav_to_torch_librosa(full_path, sr):
    """Load audio with librosa at the requested rate (mono); return (FloatTensor, sr)."""
    samples, sampling_rate = librosa.load(full_path, sr=sr, mono=True)
    tensor = torch.FloatTensor(samples.astype(np.float32))
    return tensor, sampling_rate
264
+
265
+
266
def load_filepaths_and_text(filename, split="|"):
    """Read a delimiter-separated metadata file into a list of field lists."""
    with open(filename, encoding="utf-8") as handle:
        return [row.strip().split(split) for row in handle]
270
+
271
+
272
def get_hparams(init=True):
    """Parse command-line arguments, snapshot (or reload) the run's config
    under ./logs/<model>/config.json, and return it as an HParams tree."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        default="./configs/base.json",
        help="JSON file for configuration",
    )
    parser.add_argument('--local_rank', type=int, default=0)
    parser.add_argument('--world-size', type=int, default=1)
    parser.add_argument('--port', type=int, default=10000)
    parser.add_argument("-m", "--model", type=str, required=True, help="Model name")
    parser.add_argument('--pretrain_G', type=str, default=None,
                        help='pretrain model')
    parser.add_argument('--pretrain_D', type=str, default=None,
                        help='pretrain model D')
    parser.add_argument('--pretrain_dur', type=str, default=None,
                        help='pretrain model duration')

    args = parser.parse_args()
    model_dir = os.path.join("./logs", args.model)
    os.makedirs(model_dir, exist_ok=True)

    config_save_path = os.path.join(model_dir, "config.json")
    if init:
        # first run: copy the chosen config into the model directory
        with open(args.config, "r") as src:
            data = src.read()
        with open(config_save_path, "w") as dst:
            dst.write(data)
    else:
        # resume: the snapshot in the model directory is authoritative
        with open(config_save_path, "r") as src:
            data = src.read()
    config = json.loads(data)

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    hparams.pretrain_G = args.pretrain_G
    hparams.pretrain_D = args.pretrain_D
    hparams.pretrain_dur = args.pretrain_dur
    hparams.port = args.port
    return hparams
316
+
317
+
318
def clean_checkpoints(path_to_models="logs/44k/", n_ckpts_to_keep=2, sort_by_time=True):
    """Freeing up space by deleting saved ckpts

    Arguments:
    path_to_models -- Path to the model directory
    n_ckpts_to_keep -- Number of ckpts to keep, excluding G_0.pth and D_0.pth
    sort_by_time -- True -> chronologically delete ckpts
                    False -> lexicographically delete ckpts

    Fixes over the original: the numeric key no longer crashes on filenames
    that don't match ``X_<step>.pth`` (they sort first instead), the regex is
    a raw string, and deletion is a plain loop rather than a side-effect
    list comprehension.
    """
    import re

    ckpts_files = [
        f
        for f in os.listdir(path_to_models)
        if os.path.isfile(os.path.join(path_to_models, f))
    ]

    step_pattern = re.compile(r"._(\d+)\.pth")

    def name_key(_f):
        # unmatched names sort first so they are pruned before real ckpts
        m = step_pattern.match(_f)
        return int(m.group(1)) if m else -1

    def time_key(_f):
        return os.path.getmtime(os.path.join(path_to_models, _f))

    sort_key = time_key if sort_by_time else name_key

    def x_sorted(_x):
        # all ckpts with the given prefix, oldest/lowest first, never *_0.pth
        return sorted(
            [f for f in ckpts_files if f.startswith(_x) and not f.endswith("_0.pth")],
            key=sort_key,
        )

    to_del = [
        os.path.join(path_to_models, fn)
        for fn in (x_sorted("G")[:-n_ckpts_to_keep] + x_sorted("D")[:-n_ckpts_to_keep])
    ]

    for fn in to_del:
        os.remove(fn)
        logger.info(f".. Free up space by deleting ckpt {fn}")
361
+
362
+
363
def get_hparams_from_dir(model_dir):
    """Load the config.json snapshot stored in *model_dir* as an HParams tree."""
    config_save_path = os.path.join(model_dir, "config.json")
    with open(config_save_path, "r", encoding="utf-8") as handle:
        config = json.loads(handle.read())

    hparams = HParams(**config)
    hparams.model_dir = model_dir
    return hparams
372
+
373
+
374
def get_hparams_from_file(config_path):
    """Parse a JSON config file into an HParams tree."""
    with open(config_path, "r", encoding="utf-8") as handle:
        config = json.loads(handle.read())
    return HParams(**config)
381
+
382
+
383
def check_git_hash(model_dir):
    """
    Compare the repository's current git hash with the one recorded in
    *model_dir*/githash; record it on first run, warn on mismatch.

    Fixes over the original: files are opened via context managers (no leaked
    handles) and deprecated ``logger.warn`` is replaced with ``logger.warning``.
    """
    source_dir = os.path.dirname(os.path.realpath(__file__))
    if not os.path.exists(os.path.join(source_dir, ".git")):
        logger.warning(
            "{} is not a git repository, therefore hash value comparison will be ignored.".format(
                source_dir
            )
        )
        return

    cur_hash = subprocess.getoutput("git rev-parse HEAD")

    path = os.path.join(model_dir, "githash")
    if os.path.exists(path):
        with open(path) as f:
            saved_hash = f.read()
        if saved_hash != cur_hash:
            logger.warning(
                "git hash values are different. {}(saved) != {}(current)".format(
                    saved_hash[:8], cur_hash[:8]
                )
            )
    else:
        with open(path, "w") as f:
            f.write(cur_hash)
406
+
407
+
408
def get_logger(model_dir, filename="train.log"):
    """Create (and register as the module-global *logger*) a DEBUG-level
    file logger writing to *model_dir*/*filename*."""
    global logger
    logger = logging.getLogger(os.path.basename(model_dir))
    logger.setLevel(logging.DEBUG)

    formatter = logging.Formatter("%(asctime)s\t%(name)s\t%(levelname)s\t%(message)s")
    if not os.path.exists(model_dir):
        os.makedirs(model_dir, exist_ok=True)
    file_handler = logging.FileHandler(os.path.join(model_dir, filename))
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    return logger
421
+
422
+
423
class HParams:
    """Dot- and key-accessible hyper-parameter container.

    Nested dicts are converted recursively, so ``hps.data.add_blank`` works
    for a config like ``{"data": {"add_blank": True}}``.

    Fix over the original: ``type(v) == dict`` replaced with
    ``isinstance(v, dict)`` so dict subclasses (e.g. OrderedDict) are also
    converted.
    """

    def __init__(self, **kwargs):
        for k, v in kwargs.items():
            if isinstance(v, dict):
                # recurse so nested sections get attribute access too
                v = HParams(**v)
            self[k] = v

    def keys(self):
        return self.__dict__.keys()

    def items(self):
        return self.__dict__.items()

    def values(self):
        return self.__dict__.values()

    def __len__(self):
        return len(self.__dict__)

    def __getitem__(self, key):
        return getattr(self, key)

    def __setitem__(self, key, value):
        return setattr(self, key, value)

    def __contains__(self, key):
        return key in self.__dict__

    def __repr__(self):
        return self.__dict__.__repr__()
src/vietnamese/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """
2
+ Vietnamese language support package
3
+ """
4
+
5
+ from .phonemizer import text_to_phonemes, VIPHONEME_AVAILABLE, get_all_phonemes
6
+ from .text_processor import process_vietnamese_text
src/vietnamese/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (381 Bytes). View file
 
src/vietnamese/__pycache__/phonemizer.cpython-310.pyc ADDED
Binary file (9.03 kB). View file
 
src/vietnamese/__pycache__/text_processor.cpython-310.pyc ADDED
Binary file (11.1 kB). View file
 
src/vietnamese/phonemizer.py ADDED
@@ -0,0 +1,484 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import atexit
2
+ import contextlib
3
+ import importlib.util
4
+ import io
5
+ import os
6
+ import re
7
+ import shutil
8
+ import sys
9
+ import tempfile
10
+ import unicodedata
11
+ from typing import List, Tuple
12
+ from viphoneme import vi2IPA
13
+
14
+ try:
15
+ import fcntl # type: ignore
16
+ except Exception:
17
+ fcntl = None
18
+
19
# viphoneme is imported unconditionally above, so it is always available here.
VIPHONEME_AVAILABLE = True
# Lazily created scratch directory used as the cwd for viphoneme calls
# (see _get_viphoneme_workdir); removed at interpreter exit.
_VIPHONEME_WORKDIR = None
# Parent directory of this process's isolated copy of the ``vinorm`` package,
# or None until _ensure_vinorm_isolated() has run.
_VINORM_ISOLATED_PARENT = None
22
+
23
+
24
def _get_viphoneme_workdir() -> str:
    """Return the per-process scratch directory for viphoneme, creating it once.

    The directory is registered for best-effort removal at interpreter exit
    and cached in the module-level ``_VIPHONEME_WORKDIR`` global.
    """
    global _VIPHONEME_WORKDIR
    if _VIPHONEME_WORKDIR is not None:
        return _VIPHONEME_WORKDIR
    workdir = tempfile.mkdtemp(prefix="viphoneme_")
    atexit.register(shutil.rmtree, workdir, ignore_errors=True)
    _VIPHONEME_WORKDIR = workdir
    return workdir
30
+
31
+
32
def _ensure_vinorm_isolated() -> None:
    """Give this process its own private copy of the ``vinorm`` package.

    The copy excludes ``input.txt``/``output.txt`` — apparently vinorm's
    scratch I/O files inside its package directory — so concurrent processes
    sharing one installation do not clobber each other (assumption inferred
    from the excluded file names; confirm against vinorm's sources).
    ``__init__.py`` is copied for real, everything else is symlinked (or
    copied as a fallback) into a temp dir that is prepended to ``sys.path``,
    and any already-imported ``vinorm`` module is evicted so the next import
    resolves to the isolated copy.

    Controlled by the VIPHONEME_ISOLATE_VINORM env var (default enabled).
    Runs at most once per process; the temp dir is removed at exit.
    """
    global _VINORM_ISOLATED_PARENT
    if os.environ.get("VIPHONEME_ISOLATE_VINORM", "1") not in {"1", "true", "True", "YES", "yes"}:
        return
    if _VINORM_ISOLATED_PARENT is not None:
        # Already isolated in this process.
        return

    spec = importlib.util.find_spec("vinorm")
    if spec is None or spec.origin is None:
        # vinorm is not installed; nothing to isolate.
        return

    src_dir = os.path.dirname(spec.origin)
    if not os.path.isfile(os.path.join(src_dir, "__init__.py")):
        return

    parent = tempfile.mkdtemp(prefix="vinorm_")
    dst_dir = os.path.join(parent, "vinorm")
    os.makedirs(dst_dir, exist_ok=True)

    # __init__.py must be a real copy so the package loads from the new path.
    shutil.copy2(os.path.join(src_dir, "__init__.py"), os.path.join(dst_dir, "__init__.py"))

    for name in os.listdir(src_dir):
        # Skip caches and the per-process scratch files.
        if name in {"__init__.py", "__pycache__", "input.txt", "output.txt"}:
            continue
        src = os.path.join(src_dir, name)
        dst = os.path.join(dst_dir, name)
        if os.path.exists(dst):
            continue
        try:
            # Prefer symlinks to avoid duplicating large data files.
            os.symlink(src, dst)
        except Exception:
            # Fall back to copying (e.g. Windows without symlink privileges).
            if os.path.isdir(src):
                shutil.copytree(src, dst)
            elif os.path.isfile(src):
                shutil.copy2(src, dst)

    if parent not in sys.path:
        sys.path.insert(0, parent)
    if "vinorm" in sys.modules:
        # Force the next import to resolve to the isolated copy.
        del sys.modules["vinorm"]

    _VINORM_ISOLATED_PARENT = parent
    atexit.register(shutil.rmtree, parent, ignore_errors=True)
75
+
76
+
77
@contextlib.contextmanager
def _redirect_fds_to_devnull():
    """Temporarily point OS-level stdout/stderr (fds 1 and 2) at /dev/null.

    Unlike ``contextlib.redirect_stdout``, this also silences output written
    by C extensions and child processes, which bypass ``sys.stdout``. The
    original descriptors are duplicated up front and restored in ``finally``,
    so output resumes even if the wrapped body raises.
    """
    devnull_fd = os.open(os.devnull, os.O_WRONLY)
    saved_stdout_fd = os.dup(1)
    saved_stderr_fd = os.dup(2)
    try:
        os.dup2(devnull_fd, 1)
        os.dup2(devnull_fd, 2)
        yield
    finally:
        # Restore the original descriptors, then release every duplicate.
        os.dup2(saved_stdout_fd, 1)
        os.dup2(saved_stderr_fd, 2)
        os.close(saved_stdout_fd)
        os.close(saved_stderr_fd)
        os.close(devnull_fd)
92
+
93
+
94
@contextlib.contextmanager
def _viphoneme_global_lock():
    """Optionally hold a cross-process file lock around a viphoneme call.

    Intended to serialize access to shared vinorm state when per-process
    isolation is disabled. Defaults: locking is OFF when
    VIPHONEME_ISOLATE_VINORM is enabled (each process has its own copy) and
    ON otherwise; VIPHONEME_USE_LOCK overrides either way. A no-op on
    platforms without ``fcntl`` (e.g. Windows).

    Env vars: VIPHONEME_LOCK_PATH (lock file, default /tmp/viphoneme.lock),
    VIPHONEME_USE_LOCK, VIPHONEME_ISOLATE_VINORM.
    """
    lock_path = os.environ.get("VIPHONEME_LOCK_PATH", "/tmp/viphoneme.lock")
    use_lock = os.environ.get("VIPHONEME_USE_LOCK")
    if use_lock is None:
        # Default: lock only when vinorm is NOT isolated per process.
        use_lock = "0" if os.environ.get("VIPHONEME_ISOLATE_VINORM", "1") in {"1", "true", "True", "YES", "yes"} else "1"
    if use_lock not in {"1", "true", "True", "YES", "yes"}:
        yield
        return
    if fcntl is None:
        # fcntl unavailable (non-POSIX): proceed without locking.
        yield
        return
    fd = os.open(lock_path, os.O_CREAT | os.O_RDWR, 0o666)
    try:
        fcntl.flock(fd, fcntl.LOCK_EX)
        yield
    finally:
        try:
            fcntl.flock(fd, fcntl.LOCK_UN)
        finally:
            os.close(fd)
115
+
116
+
117
# Vietnamese tone diacritics to tone number mapping
# (NFD combining code points; see extract_tone() for how they are applied).
TONE_MARKS = {
    '\u0300': 2,  # ̀ huyền (falling)
    '\u0301': 1,  # ́ sắc (rising)
    '\u0303': 3,  # ̃ ngã (broken)
    '\u0309': 4,  # ̉ hỏi (dipping)
    '\u0323': 5,  # ̣ nặng (heavy/glottalized)
}

# Default tone (no diacritic) = 0 (ngang/level)

# Vietnamese orthography to IPA mapping.
# Longest-match lookup table consumed by syllable_to_ipa(): trigraphs are
# tried before digraphs before single letters.
VI_TO_IPA = {
    # Trigraphs (check first)
    'ngh': 'ŋ',

    # Digraphs
    'ng': 'ŋ',
    'nh': 'ɲ',
    'ch': 'c',  # Vietnamese ch = palatal stop
    'tr': 'ʈ',  # Retroflex
    'th': 'tʰ',  # Aspirated
    'ph': 'f',
    'kh': 'x',  # Voiceless velar fricative
    'gh': 'ɣ',
    'gi': 'z',
    'qu': 'kw',

    # Special consonants
    'đ': 'ɗ',  # Implosive d

    # Simple consonants
    'b': 'ɓ',  # Implosive b (can also be plain b)
    'c': 'k',
    'd': 'z',  # Northern: z, Southern: j
    'g': 'ɣ',
    'h': 'h',
    'k': 'k',
    'l': 'l',
    'm': 'm',
    'n': 'n',
    'p': 'p',
    'r': 'ʐ',  # Retroflex (varies by dialect)
    's': 's',
    't': 't',
    'v': 'v',
    'x': 's',  # Vietnamese x = s

    # Vowels
    'a': 'aː',
    'ă': 'a',  # Short a
    'â': 'ə',  # Schwa
    'e': 'ɛ',
    'ê': 'e',
    'i': 'i',
    'y': 'i',  # Same as i
    'o': 'ɔ',
    'ô': 'o',
    'ơ': 'əː',  # Long schwa
    'u': 'u',
    'ư': 'ɯ',  # Unrounded u

    # Diphthongs (handled separately)
}

# Final consonants (codas)
# NOTE(review): not referenced anywhere in this module — possibly consumed by
# other files or reserved for future coda handling; confirm before removing.
FINAL_CONSONANTS = {
    'c': 'k',
    'ch': 'c',
    'm': 'm',
    'n': 'n',
    'ng': 'ŋ',
    'nh': 'ɲ',
    'p': 'p',
    't': 't',
}

# Punctuation to keep
PUNCTUATION = set(',.!?;:\'"--—…()[]{}')

# Punctuation that creates pauses (SP = short pause)
# NOTE(review): PAUSE_PUNCTUATION / STOP_PUNCTUATION are not used in this
# module; presumably consumed by downstream pause/duration logic — confirm.
PAUSE_PUNCTUATION = {',', ';', ':'}
STOP_PUNCTUATION = {'.', '!', '?', '…'}
200
+
201
def extract_tone(char: str) -> Tuple[str, int]:
    """Split a Vietnamese character into its base letter(s) and tone number.

    The input is NFD-decomposed: a combining tone mark (if present) selects
    the tone, all other combining marks are dropped from the base, and every
    non-combining code point is kept in the base string.

    Returns:
        (base_char, tone_number) — tone 0 means ngang (no diacritic).
    """
    tone = 0
    base_parts = []
    for cp in unicodedata.normalize('NFD', char):
        if cp in TONE_MARKS:
            tone = TONE_MARKS[cp]
        elif not unicodedata.combining(cp):
            base_parts.append(cp)
    return ''.join(base_parts), tone
218
+
219
+
220
def syllable_to_ipa(syllable: str) -> Tuple[List[str], int]:
    """Map one Vietnamese syllable to IPA phonemes plus its tone number.

    Tone diacritics are stripped first (the last mark found wins), then the
    remaining letters are consumed greedily against VI_TO_IPA: trigraphs,
    then digraphs, then single characters. Unmapped alphabetic characters
    pass through unchanged; anything else is dropped.

    Returns:
        (phonemes, tone)
    """
    tone = 0
    stripped = []
    for ch in syllable.lower():
        base, ch_tone = extract_tone(ch)
        if ch_tone > 0:
            tone = ch_tone
        stripped.append(base)
    text = ''.join(stripped)

    phonemes: List[str] = []
    pos = 0
    while pos < len(text):
        # Greedy longest-match: try a 3-gram, then a 2-gram.
        for width in (3, 2):
            chunk = text[pos:pos + width]
            if len(chunk) == width and chunk in VI_TO_IPA:
                phonemes.append(VI_TO_IPA[chunk])
                pos += width
                break
        else:
            ch = text[pos]
            if ch in VI_TO_IPA:
                phonemes.append(VI_TO_IPA[ch])
            elif ch.isalpha():
                # Keep unmapped letters as-is.
                phonemes.append(ch)
            pos += 1

    return phonemes, tone
269
+
270
+
271
def text_to_phonemes_viphoneme(text: str) -> Tuple[List[str], List[int], List[int]]:
    """
    Convert text to phonemes using viphoneme library.
    Returns (phones, tones, word2ph)

    viphoneme output format:
    - Syllables separated by space
    - Compound words joined by underscore: hom1_năj1
    - Tone number (1-6) at end of each syllable
    - Punctuation as separate tokens

    Falls back to text_to_phonemes_charbased() whenever vi2IPA raises or
    returns an empty/punctuation-only result.
    """
    import warnings

    # Call viphoneme (ICU warnings will appear but won't affect results)
    # Note: viphoneme may not work on Windows due to platform-specific binaries
    try:
        _ensure_vinorm_isolated()
        workdir = _get_viphoneme_workdir()
        with _viphoneme_global_lock():
            # Run from the private workdir (vi2IPA appears to use cwd-relative
            # scratch files — see _ensure_vinorm_isolated) and restore cwd after.
            cwd = os.getcwd()
            os.chdir(workdir)
            try:
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    with _redirect_fds_to_devnull():
                        ipa_text = vi2IPA(text)
            finally:
                os.chdir(cwd)
    except Exception:
        # Fallback to char-based on error (e.g. Windows compatibility issues)
        return text_to_phonemes_charbased(text)

    # Check if viphoneme returned empty or invalid result
    if not ipa_text or ipa_text.strip() in ['', '.', '..', '...']:
        return text_to_phonemes_charbased(text)

    phones = []
    tones = []
    word2ph = []

    # viphoneme tone mapping: 1=ngang, 2=huyền, 3=ngã, 4=hỏi, 5=sắc, 6=nặng
    # Our internal: 0=ngang, 1=sắc, 2=huyền, 3=ngã, 4=hỏi, 5=nặng
    VIPHONEME_TONE_MAP = {1: 0, 2: 2, 3: 3, 4: 4, 5: 1, 6: 5}

    # Characters to skip (combining marks, ties)
    # NOTE(review): SKIP_CHARS is never referenced below — the loop relies on
    # unicodedata.combining() and inline sets instead; confirm before removing.
    SKIP_CHARS = {'\u0306', '\u0361', '\u032f', '\u0330', '\u0329'} # breve, tie, etc.

    # Split by space
    tokens = ipa_text.strip().split()

    for token in tokens:
        # Handle punctuation-only tokens: each mark becomes its own
        # toneless single-phone "word".
        if all(c in PUNCTUATION or c == '.' for c in token):
            for c in token:
                if c in PUNCTUATION:
                    phones.append(c)
                    tones.append(0)
                    word2ph.append(1)
            continue

        # Split compound words by underscore
        syllables = token.split('_')

        for syllable in syllables:
            if not syllable:
                continue

            syllable_phones = []
            syllable_tone = 0
            i = 0

            while i < len(syllable):
                char = syllable[i]

                # Tone number at end
                if char.isdigit():
                    syllable_tone = VIPHONEME_TONE_MAP.get(int(char), 0)
                    i += 1
                    continue

                # Skip combining marks (they modify previous char, already handled)
                if unicodedata.combining(char):
                    i += 1
                    continue

                # Skip modifier letters like ʷ ʰ (append to previous if exists)
                if char in {'ʷ', 'ʰ', 'ː'}:
                    if syllable_phones:
                        syllable_phones[-1] = syllable_phones[-1] + char
                    i += 1
                    continue

                # Skip tie bars and other special marks
                # NOTE(review): '\u0361' appears twice in this set literal
                # (harmless) and is already caught by combining() above.
                if char in {'\u0361', '\u035c', '\u0361'}: # tie bars
                    i += 1
                    continue

                # Punctuation within syllable
                if char in PUNCTUATION:
                    i += 1
                    continue

                # Regular phoneme character
                syllable_phones.append(char)
                i += 1

            if syllable_phones:
                # All phones of one syllable share the syllable's tone.
                phones.extend(syllable_phones)
                tones.extend([syllable_tone] * len(syllable_phones))
                word2ph.append(len(syllable_phones))

    return phones, tones, word2ph
383
+
384
+
385
def text_to_phonemes_charbased(text: str) -> Tuple[List[str], List[int], List[int]]:
    """Fallback grapheme-based phonemizer (no external dependencies).

    Splits on whitespace, peels punctuation off both ends of each token,
    and converts the remaining syllable with ``syllable_to_ipa``. Each
    punctuation mark counts as its own toneless single-phone "word".

    Returns:
        (phones, tones, word2ph)
    """
    phones: List[str] = []
    tones: List[int] = []
    word2ph: List[int] = []

    def _emit_punct(mark: str) -> None:
        phones.append(mark)
        tones.append(0)
        word2ph.append(1)

    for token in text.split():
        # Strip trailing punctuation first, then leading (original order).
        trailing = []
        while token and token[-1] in PUNCTUATION:
            trailing.insert(0, token[-1])
            token = token[:-1]

        leading = []
        while token and token[0] in PUNCTUATION:
            leading.append(token[0])
            token = token[1:]

        for mark in leading:
            _emit_punct(mark)

        if token:
            token_phones, tone = syllable_to_ipa(token)
            if token_phones:
                phones.extend(token_phones)
                tones.extend([tone] * len(token_phones))
                word2ph.append(len(token_phones))

        for mark in trailing:
            _emit_punct(mark)

    return phones, tones, word2ph
430
+
431
+
432
def text_to_phonemes(text: str, use_viphoneme: bool = True) -> Tuple[List[str], List[int], List[int]]:
    """
    Main function to convert Vietnamese text to phonemes.

    Args:
        text: Vietnamese text
        use_viphoneme: Whether to use viphoneme library (if available)

    Returns:
        phones: List of IPA phonemes
        tones: List of tone numbers (0-5)
        word2ph: List of phone counts per word
    """
    backend = (
        text_to_phonemes_viphoneme
        if use_viphoneme and VIPHONEME_AVAILABLE
        else text_to_phonemes_charbased
    )
    phones, tones, word2ph = backend(text)

    # Wrap the sequence in "_" boundary markers (tone 0, one phone per word).
    return ["_"] + phones + ["_"], [0] + tones + [0], [1] + word2ph + [1]
456
+
457
+
458
def get_all_phonemes() -> List[str]:
    """Return the sorted inventory of phoneme symbols for the symbol table.

    Combines the VI_TO_IPA targets (plus long-vowel ``ː`` variants of the
    single-character ones), a fixed set of common IPA symbols, and the
    punctuation set.
    """
    inventory = set()

    # Symbols produced by the orthography table.
    for ipa in VI_TO_IPA.values():
        if isinstance(ipa, str):
            inventory.add(ipa)
            if len(ipa) == 1:
                # Also admit the lengthened variant of single-char symbols.
                inventory.add(ipa + 'ː')

    # Common IPA symbols
    inventory.update([
        # Consonants
        'b', 'ɓ', 'c', 'd', 'ɗ', 'f', 'g', 'ɣ', 'h', 'j', 'k', 'l', 'm', 'n',
        'ŋ', 'ɲ', 'p', 'r', 'ʐ', 's', 'ʂ', 't', 'tʰ', 'ʈ', 'v', 'w', 'x', 'z',
        # Vowels
        'a', 'aː', 'ə', 'əː', 'ɛ', 'e', 'i', 'ɪ', 'o', 'ɔ', 'u', 'ʊ', 'ɯ', 'ɤ',
        # Special
        '_', ' ',
    ])

    # Punctuation
    inventory.update(PUNCTUATION)

    return sorted(inventory)
src/vietnamese/text_processor.py ADDED
@@ -0,0 +1,428 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Vietnamese Text Processor for TTS
4
+ Handles normalization of numbers, dates, times, currencies, etc.
5
+ """
6
+
7
+ import re
8
+ import unicodedata
9
+
10
+
11
# Vietnamese number words
DIGITS = {
    '0': 'không', '1': 'một', '2': 'hai', '3': 'ba', '4': 'bốn',
    '5': 'năm', '6': 'sáu', '7': 'bảy', '8': 'tám', '9': 'chín'
}

TEENS = {
    '10': 'mười', '11': 'mười một', '12': 'mười hai', '13': 'mười ba',
    '14': 'mười bốn', '15': 'mười lăm', '16': 'mười sáu', '17': 'mười bảy',
    '18': 'mười tám', '19': 'mười chín'
}

TENS = {
    '2': 'hai mươi', '3': 'ba mươi', '4': 'bốn mươi', '5': 'năm mươi',
    '6': 'sáu mươi', '7': 'bảy mươi', '8': 'tám mươi', '9': 'chín mươi'
}


def number_to_words(num_str):
    """Spell out a (possibly signed) integer string in Vietnamese.

    Handles magnitudes up to the billions; anything >= 10^12 is read digit
    by digit. Leading zeros are stripped first, a leading ``-`` produces
    "âm ...", and non-numeric input is returned unchanged.
    """
    # Drop leading zeros but keep at least one digit.
    num_str = num_str.lstrip('0') or '0'

    if num_str.startswith('-'):
        return 'âm ' + number_to_words(num_str[1:])

    try:
        value = int(num_str)
    except ValueError:
        return num_str

    if value < 10:
        # Covers 0 as well: DIGITS['0'] == 'không'.
        return DIGITS[str(value)]

    if value < 20:
        return TEENS[str(value)]

    if value < 100:
        tens_word = TENS[str(value // 10)]
        unit = value % 10
        # Irregular unit forms after a tens word: mốt / tư / lăm.
        irregular = {0: '', 1: ' mốt', 4: ' tư', 5: ' lăm'}
        if unit in irregular:
            return tens_word + irregular[unit]
        return tens_word + ' ' + DIGITS[str(unit)]

    if value < 1000:
        head = DIGITS[str(value // 100)] + ' trăm'
        rest = value % 100
        if rest == 0:
            return head
        if rest < 10:
            # "lẻ" bridges a skipped tens place: 105 -> "một trăm lẻ năm".
            return head + ' lẻ ' + DIGITS[str(rest)]
        return head + ' ' + number_to_words(str(rest))

    # Thousands / millions / billions share the same recursive structure;
    # only "nghìn" pads a small remainder with "không trăm".
    for limit, unit_word, pad_under in (
        (1000000, ' nghìn', 100),
        (1000000000, ' triệu', None),
        (1000000000000, ' tỷ', None),
    ):
        if value < limit:
            scale = limit // 1000
            head = number_to_words(str(value // scale)) + unit_word
            rest = value % scale
            if rest == 0:
                return head
            if pad_under is not None and rest < pad_under:
                # e.g. 1005 -> "một nghìn không trăm năm"
                return head + ' không trăm ' + number_to_words(str(rest))
            return head + ' ' + number_to_words(str(rest))

    # 10^12 and beyond: read digit by digit.
    return ' '.join(DIGITS.get(d, d) for d in num_str)
112
+
113
+
114
def convert_decimal(text):
    """Spell out short decimals: ``3.14`` -> ``ba phẩy mười bốn``.

    Only 1-2 fractional digits are matched (with a non-digit lookahead) so
    dotted thousand separators like ``100.000`` are left untouched.
    """
    def _spell(match):
        whole, frac = match.group(1), match.group(2)
        spoken_frac = number_to_words(frac.lstrip('0') or '0')
        return f"{number_to_words(whole)} phẩy {spoken_frac}"

    return re.sub(r'(\d+)\.(\d{1,2})(?=\s|$|[^\d])', _spell, text)
131
+
132
+
133
def convert_percentage(text):
    """Spell out percentages: ``50%`` -> ``năm mươi phần trăm``."""
    return re.sub(
        r'(\d+(?:[.,]\d+)?)\s*%',
        lambda m: number_to_words(m.group(1)) + ' phần trăm',
        text,
    )
141
+
142
+
143
def convert_currency(text):
    """Convert VND and USD amounts to words.

    '.' and ',' inside the matched amount are stripped before conversion,
    so they act as thousand separators: ``100.000đ`` reads as one hundred
    thousand dong.
    """
    # Vietnamese Dong - be specific to avoid matching "đ" in other words like "độ"
    def replace_vnd(match):
        num = match.group(1).replace('.', '').replace(',', '')
        return number_to_words(num) + ' đồng'

    # Only match currency patterns: number followed by currency symbol at word boundary
    text = re.sub(r'(\d+(?:[.,]\d+)*)\s*(?:đồng|VND|vnđ)\b', replace_vnd, text, flags=re.IGNORECASE)
    # Bare "đ" suffix, only when not followed by another letter.
    text = re.sub(r'(\d+(?:[.,]\d+)*)đ(?![a-zà-ỹ])', replace_vnd, text, flags=re.IGNORECASE)

    # USD: "$12", "12 USD", "12$"
    def replace_usd(match):
        num = match.group(1).replace('.', '').replace(',', '')
        return number_to_words(num) + ' đô la'

    text = re.sub(r'\$\s*(\d+(?:[.,]\d+)*)', replace_usd, text)
    text = re.sub(r'(\d+(?:[.,]\d+)*)\s*(?:USD|\$)', replace_usd, text, flags=re.IGNORECASE)

    return text
163
+
164
+
165
def convert_time(text):
    """Convert time expressions: 2 giờ 20 phút -> hai giờ hai mươi phút

    Handles "HH:MM[:SS]" clock notation plus digit-based Vietnamese
    spellings ("X giờ Y phút", "X giờ").
    """
    def replace_time(match):
        hour = match.group(1)
        minute = match.group(2) if match.group(2) else None
        second = match.group(3) if len(match.groups()) > 2 and match.group(3) else None

        result = number_to_words(hour) + ' giờ'
        if minute:
            # NOTE(review): a minute like "05" reads as just "năm" — the
            # leading zero is stripped inside number_to_words; confirm that
            # "mười bốn giờ năm phút" is the intended reading of 14:05.
            result += ' ' + number_to_words(minute) + ' phút'
        if second:
            result += ' ' + number_to_words(second) + ' giây'
        return result

    # HH:MM:SS or HH:MM
    text = re.sub(r'(\d{1,2}):(\d{2})(?::(\d{2}))?', replace_time, text)

    # X giờ Y phút
    def replace_time_vn(match):
        hour = match.group(1)
        minute = match.group(2)
        return number_to_words(hour) + ' giờ ' + number_to_words(minute) + ' phút'

    text = re.sub(r'(\d+)\s*giờ\s*(\d+)\s*phút', replace_time_vn, text)

    # X giờ (without minute) — the lookahead avoids re-matching "X giờ Y".
    def replace_hour(match):
        hour = match.group(1)
        return number_to_words(hour) + ' giờ'

    text = re.sub(r'(\d+)\s*giờ(?!\s*\d)', replace_hour, text)

    return text
198
+
199
+
200
def convert_date(text):
    """Convert date expressions (DD/MM/YYYY, "X tháng Y", "tháng X", "ngày X").

    Patterns run from most to least specific so the generic "ngày X" /
    "tháng X" rules only see what the fuller forms did not consume.
    """
    # DD/MM/YYYY or DD-MM-YYYY
    def replace_date_full(match):
        day = match.group(1)
        month = match.group(2)
        year = match.group(3)
        return f"ngày {number_to_words(day)} tháng {number_to_words(month)} năm {number_to_words(year)}"

    # First, replace "Sinh ngày DD/MM/YYYY" pattern to avoid double "ngày"
    text = re.sub(r'(Sinh|sinh)\s+ngày\s+(\d{1,2})[/-](\d{1,2})[/-](\d{4})',
                  lambda m: f"{m.group(1)} ngày {number_to_words(m.group(2))} tháng {number_to_words(m.group(3))} năm {number_to_words(m.group(4))}", text)

    text = re.sub(r'(\d{1,2})[/-](\d{1,2})[/-](\d{4})', replace_date_full, text)

    # X tháng Y
    def replace_month_day(match):
        day = match.group(1)
        month = match.group(2)
        return f"ngày {number_to_words(day)} tháng {number_to_words(month)}"

    text = re.sub(r'(\d+)\s*tháng\s*(\d+)', replace_month_day, text)

    # tháng X (month only)
    def replace_month(match):
        month = match.group(1)
        return 'tháng ' + number_to_words(month)

    text = re.sub(r'tháng\s*(\d+)', replace_month, text)

    # ngày X
    def replace_day(match):
        day = match.group(1)
        return 'ngày ' + number_to_words(day)

    text = re.sub(r'ngày\s*(\d+)', replace_day, text)

    return text
238
+
239
+
240
def convert_year_range(text):
    """Read a 4-digit year range aloud: ``1873-1907`` -> ``... đến ...``.

    Accepts hyphen, en-dash, or em-dash between the years.
    """
    return re.sub(
        r'(\d{4})\s*[-–—]\s*(\d{4})',
        lambda m: number_to_words(m.group(1)) + ' đến ' + number_to_words(m.group(2)),
        text,
    )
249
+
250
+
251
def convert_ordinal(text):
    """Spell ordinals after counter words: ``thứ 2`` -> ``thứ hai``.

    Numbers 1-10 use the ordinal forms (nhất/tư, ...); larger numbers fall
    back to the cardinal reading via number_to_words.
    """
    special_forms = {
        '1': 'nhất', '2': 'hai', '3': 'ba', '4': 'tư', '5': 'năm',
        '6': 'sáu', '7': 'bảy', '8': 'tám', '9': 'chín', '10': 'mười'
    }

    def _spell(match):
        prefix, num = match.group(1), match.group(2)
        word = special_forms[num] if num in special_forms else number_to_words(num)
        return f"{prefix} {word}"

    # thứ X, lần X, bước X, phần X, ...
    return re.sub(
        r'(thứ|lần|bước|phần|chương|tập|số)\s*(\d+)',
        _spell,
        text,
        flags=re.IGNORECASE,
    )
269
+
270
+
271
def convert_standalone_numbers(text):
    """Spell out any remaining bare integers (word-boundary delimited)."""
    return re.sub(r'\b\d+\b', lambda m: number_to_words(m.group(0)), text)
281
+
282
+
283
def convert_phone_number(text):
    """Read Vietnamese phone numbers digit by digit.

    Matches domestic ``0XXXXXXXXX[X]`` and international ``+84...`` forms.
    """
    def _spell(match):
        digit_words = (DIGITS.get(d, d) for d in re.findall(r'\d', match.group(0)))
        return ' '.join(digit_words)

    text = re.sub(r'0\d{9,10}', _spell, text)
    text = re.sub(r'\+84\d{9,10}', _spell, text)
    return text
295
+
296
+
297
def normalize_unicode(text):
    """Return *text* recomposed into Unicode NFC form."""
    return unicodedata.normalize('NFC', text)
300
+
301
+
302
def clean_whitespace(text):
    """Collapse whitespace runs to single spaces and trim both ends."""
    return re.sub(r'\s+', ' ', text).strip()
306
+
307
+
308
def remove_special_chars(text):
    """Remove or verbalize characters that have no spoken form.

    URLs and e-mail addresses are stripped FIRST — before ``@`` is replaced
    with " a còng ", which previously made the e-mail pattern unmatchable
    (dead code). Then common symbols are replaced with their Vietnamese
    reading or deleted.
    """
    # Remove URLs and e-mail addresses before any symbol substitution.
    text = re.sub(r'https?://\S+', '', text)
    text = re.sub(r'www\.\S+', '', text)
    text = re.sub(r'\S+@\S+\.\S+', '', text)

    # Replace speakable symbols with words.
    text = text.replace('&', ' và ')
    text = text.replace('@', ' a còng ')
    text = text.replace('#', ' thăng ')

    # Drop purely decorative characters; '_' becomes a space.
    text = text.replace('*', '')
    text = text.replace('_', ' ')
    text = text.replace('~', '')
    text = text.replace('`', '')
    text = text.replace('^', '')

    return text
331
+
332
+
333
def normalize_punctuation(text):
    """Normalize quotes, dashes and ellipses to plain ASCII forms.

    Runs of repeated ``!``/``?`` collapse to a single mark and ``..``
    collapses to ``.``, but a normalized ellipsis ("...") is preserved.
    (Previously the collapse rule ``([!?.]){2,}`` reduced the freshly
    normalized "..." back to a single dot, defeating the ellipsis handling.)
    """
    # Normalize quotes
    text = re.sub(r'[""„‟]', '"', text)
    text = re.sub(r"[''‚‛]", "'", text)

    # Normalize dashes
    text = re.sub(r'[–—−]', '-', text)

    # Normalize ellipsis: map the single char first, then clamp any run of
    # 3+ dots to exactly "..." (also covers '…' adjacent to dots).
    text = text.replace('…', '...')
    text = re.sub(r'\.{3,}', '...', text)

    # Collapse repeated terminal punctuation, keeping "..." intact:
    # runs of !/? collapse to one mark; a lone ".." becomes ".".
    text = re.sub(r'([!?]){2,}', r'\1', text)
    text = re.sub(r'(?<!\.)\.\.(?!\.)', '.', text)

    return text
350
+
351
+
352
def process_vietnamese_text(text):
    """
    Main function to process Vietnamese text for TTS.
    Applies all normalization steps in the correct order.

    The ordering matters: year ranges must run before generic number
    expansion, decimals after currency (so thousand separators are not
    misread), and whitespace cleanup comes last.

    Args:
        text: Raw Vietnamese text

    Returns:
        Normalized text suitable for TTS
    """
    pipeline = (
        normalize_unicode,           # 1. Unicode NFC
        remove_special_chars,        # 2. symbols / URLs / e-mails
        normalize_punctuation,       # 3. quotes, dashes, ellipses
        convert_year_range,          # 4. before other number conversions
        convert_date,                # 5.
        convert_time,                # 6.
        convert_ordinal,             # 7.
        convert_currency,            # 8.
        convert_percentage,          # 9.
        convert_phone_number,        # 10.
        convert_decimal,             # 11. after currency, before bare numbers
        convert_standalone_numbers,  # 12.
        clean_whitespace,            # 13.
    )
    for step in pipeline:
        text = step(text)
    return text
403
+
404
+
405
if __name__ == "__main__":
    # Smoke-test the normalizer on a few representative sentences.
    samples = [
        "Lúc khoảng 2 giờ 20 phút sáng ngày thứ Bảy hay 8 tháng 11",
        "Alfred Jarry 1873-1907 hợp những nhà văn",
        "ông Derringer 44 ly, dí sát đầu tổng thống",
        "Giá sản phẩm là 100.000đ",
        "Tỷ lệ thành công đạt 85%",
        "Họp lúc 14:30",
        "Sinh ngày 15/08/1990",
        "Chương 3: Hành trình mới",
        "Số điện thoại: 0912345678",
        "Nhiệt độ 25.5 độ C",
        "Công ty XYZ có 1500 nhân viên",
    ]

    banner = "=" * 60
    print(banner)
    print("Vietnamese Text Processor Test")
    print(banner)

    for sample in samples:
        print(f"\nOriginal: {sample}")
        print(f"Processed: {process_vietnamese_text(sample)}")