ashishkblink commited on
Commit
74d5ca1
·
verified ·
1 Parent(s): cbaf761

Upload f5_tts/eval/ecapa_tdnn.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. f5_tts/eval/ecapa_tdnn.py +330 -0
f5_tts/eval/ecapa_tdnn.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # just for speaker similarity evaluation, third-party code
2
+
3
+ # From https://github.com/microsoft/UniSpeech/blob/main/downstreams/speaker_verification/models/
4
+ # part of the code is borrowed from https://github.com/lawlict/ECAPA-TDNN
5
+
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+
12
+ """ Res2Conv1d + BatchNorm1d + ReLU
13
+ """
14
+
15
+
16
+ class Res2Conv1dReluBn(nn.Module):
17
+ """
18
+ in_channels == out_channels == channels
19
+ """
20
+
21
+ def __init__(self, channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True, scale=4):
22
+ super().__init__()
23
+ assert channels % scale == 0, "{} % {} != 0".format(channels, scale)
24
+ self.scale = scale
25
+ self.width = channels // scale
26
+ self.nums = scale if scale == 1 else scale - 1
27
+
28
+ self.convs = []
29
+ self.bns = []
30
+ for i in range(self.nums):
31
+ self.convs.append(nn.Conv1d(self.width, self.width, kernel_size, stride, padding, dilation, bias=bias))
32
+ self.bns.append(nn.BatchNorm1d(self.width))
33
+ self.convs = nn.ModuleList(self.convs)
34
+ self.bns = nn.ModuleList(self.bns)
35
+
36
+ def forward(self, x):
37
+ out = []
38
+ spx = torch.split(x, self.width, 1)
39
+ for i in range(self.nums):
40
+ if i == 0:
41
+ sp = spx[i]
42
+ else:
43
+ sp = sp + spx[i]
44
+ # Order: conv -> relu -> bn
45
+ sp = self.convs[i](sp)
46
+ sp = self.bns[i](F.relu(sp))
47
+ out.append(sp)
48
+ if self.scale != 1:
49
+ out.append(spx[self.nums])
50
+ out = torch.cat(out, dim=1)
51
+
52
+ return out
53
+
54
+
55
+ """ Conv1d + BatchNorm1d + ReLU
56
+ """
57
+
58
+
59
+ class Conv1dReluBn(nn.Module):
60
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=True):
61
+ super().__init__()
62
+ self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)
63
+ self.bn = nn.BatchNorm1d(out_channels)
64
+
65
+ def forward(self, x):
66
+ return self.bn(F.relu(self.conv(x)))
67
+
68
+
69
+ """ The SE connection of 1D case.
70
+ """
71
+
72
+
73
+ class SE_Connect(nn.Module):
74
+ def __init__(self, channels, se_bottleneck_dim=128):
75
+ super().__init__()
76
+ self.linear1 = nn.Linear(channels, se_bottleneck_dim)
77
+ self.linear2 = nn.Linear(se_bottleneck_dim, channels)
78
+
79
+ def forward(self, x):
80
+ out = x.mean(dim=2)
81
+ out = F.relu(self.linear1(out))
82
+ out = torch.sigmoid(self.linear2(out))
83
+ out = x * out.unsqueeze(2)
84
+
85
+ return out
86
+
87
+
88
+ """ SE-Res2Block of the ECAPA-TDNN architecture.
89
+ """
90
+
91
+ # def SE_Res2Block(channels, kernel_size, stride, padding, dilation, scale):
92
+ # return nn.Sequential(
93
+ # Conv1dReluBn(channels, 512, kernel_size=1, stride=1, padding=0),
94
+ # Res2Conv1dReluBn(512, kernel_size, stride, padding, dilation, scale=scale),
95
+ # Conv1dReluBn(512, channels, kernel_size=1, stride=1, padding=0),
96
+ # SE_Connect(channels)
97
+ # )
98
+
99
+
100
+ class SE_Res2Block(nn.Module):
101
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, scale, se_bottleneck_dim):
102
+ super().__init__()
103
+ self.Conv1dReluBn1 = Conv1dReluBn(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
104
+ self.Res2Conv1dReluBn = Res2Conv1dReluBn(out_channels, kernel_size, stride, padding, dilation, scale=scale)
105
+ self.Conv1dReluBn2 = Conv1dReluBn(out_channels, out_channels, kernel_size=1, stride=1, padding=0)
106
+ self.SE_Connect = SE_Connect(out_channels, se_bottleneck_dim)
107
+
108
+ self.shortcut = None
109
+ if in_channels != out_channels:
110
+ self.shortcut = nn.Conv1d(
111
+ in_channels=in_channels,
112
+ out_channels=out_channels,
113
+ kernel_size=1,
114
+ )
115
+
116
+ def forward(self, x):
117
+ residual = x
118
+ if self.shortcut:
119
+ residual = self.shortcut(x)
120
+
121
+ x = self.Conv1dReluBn1(x)
122
+ x = self.Res2Conv1dReluBn(x)
123
+ x = self.Conv1dReluBn2(x)
124
+ x = self.SE_Connect(x)
125
+
126
+ return x + residual
127
+
128
+
129
+ """ Attentive weighted mean and standard deviation pooling.
130
+ """
131
+
132
+
133
+ class AttentiveStatsPool(nn.Module):
134
+ def __init__(self, in_dim, attention_channels=128, global_context_att=False):
135
+ super().__init__()
136
+ self.global_context_att = global_context_att
137
+
138
+ # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
139
+ if global_context_att:
140
+ self.linear1 = nn.Conv1d(in_dim * 3, attention_channels, kernel_size=1) # equals W and b in the paper
141
+ else:
142
+ self.linear1 = nn.Conv1d(in_dim, attention_channels, kernel_size=1) # equals W and b in the paper
143
+ self.linear2 = nn.Conv1d(attention_channels, in_dim, kernel_size=1) # equals V and k in the paper
144
+
145
+ def forward(self, x):
146
+ if self.global_context_att:
147
+ context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
148
+ context_std = torch.sqrt(torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
149
+ x_in = torch.cat((x, context_mean, context_std), dim=1)
150
+ else:
151
+ x_in = x
152
+
153
+ # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
154
+ alpha = torch.tanh(self.linear1(x_in))
155
+ # alpha = F.relu(self.linear1(x_in))
156
+ alpha = torch.softmax(self.linear2(alpha), dim=2)
157
+ mean = torch.sum(alpha * x, dim=2)
158
+ residuals = torch.sum(alpha * (x**2), dim=2) - mean**2
159
+ std = torch.sqrt(residuals.clamp(min=1e-9))
160
+ return torch.cat([mean, std], dim=1)
161
+
162
+
163
+ class ECAPA_TDNN(nn.Module):
164
+ def __init__(
165
+ self,
166
+ feat_dim=80,
167
+ channels=512,
168
+ emb_dim=192,
169
+ global_context_att=False,
170
+ feat_type="wavlm_large",
171
+ sr=16000,
172
+ feature_selection="hidden_states",
173
+ update_extract=False,
174
+ config_path=None,
175
+ ):
176
+ super().__init__()
177
+
178
+ self.feat_type = feat_type
179
+ self.feature_selection = feature_selection
180
+ self.update_extract = update_extract
181
+ self.sr = sr
182
+
183
+ torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
184
+ try:
185
+ local_s3prl_path = os.path.expanduser("~/.cache/torch/hub/s3prl_s3prl_main")
186
+ self.feature_extract = torch.hub.load(local_s3prl_path, feat_type, source="local", config_path=config_path)
187
+ except: # noqa: E722
188
+ self.feature_extract = torch.hub.load("s3prl/s3prl", feat_type)
189
+
190
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
191
+ self.feature_extract.model.encoder.layers[23].self_attn, "fp32_attention"
192
+ ):
193
+ self.feature_extract.model.encoder.layers[23].self_attn.fp32_attention = False
194
+ if len(self.feature_extract.model.encoder.layers) == 24 and hasattr(
195
+ self.feature_extract.model.encoder.layers[11].self_attn, "fp32_attention"
196
+ ):
197
+ self.feature_extract.model.encoder.layers[11].self_attn.fp32_attention = False
198
+
199
+ self.feat_num = self.get_feat_num()
200
+ self.feature_weight = nn.Parameter(torch.zeros(self.feat_num))
201
+
202
+ if feat_type != "fbank" and feat_type != "mfcc":
203
+ freeze_list = ["final_proj", "label_embs_concat", "mask_emb", "project_q", "quantizer"]
204
+ for name, param in self.feature_extract.named_parameters():
205
+ for freeze_val in freeze_list:
206
+ if freeze_val in name:
207
+ param.requires_grad = False
208
+ break
209
+
210
+ if not self.update_extract:
211
+ for param in self.feature_extract.parameters():
212
+ param.requires_grad = False
213
+
214
+ self.instance_norm = nn.InstanceNorm1d(feat_dim)
215
+ # self.channels = [channels] * 4 + [channels * 3]
216
+ self.channels = [channels] * 4 + [1536]
217
+
218
+ self.layer1 = Conv1dReluBn(feat_dim, self.channels[0], kernel_size=5, padding=2)
219
+ self.layer2 = SE_Res2Block(
220
+ self.channels[0],
221
+ self.channels[1],
222
+ kernel_size=3,
223
+ stride=1,
224
+ padding=2,
225
+ dilation=2,
226
+ scale=8,
227
+ se_bottleneck_dim=128,
228
+ )
229
+ self.layer3 = SE_Res2Block(
230
+ self.channels[1],
231
+ self.channels[2],
232
+ kernel_size=3,
233
+ stride=1,
234
+ padding=3,
235
+ dilation=3,
236
+ scale=8,
237
+ se_bottleneck_dim=128,
238
+ )
239
+ self.layer4 = SE_Res2Block(
240
+ self.channels[2],
241
+ self.channels[3],
242
+ kernel_size=3,
243
+ stride=1,
244
+ padding=4,
245
+ dilation=4,
246
+ scale=8,
247
+ se_bottleneck_dim=128,
248
+ )
249
+
250
+ # self.conv = nn.Conv1d(self.channels[-1], self.channels[-1], kernel_size=1)
251
+ cat_channels = channels * 3
252
+ self.conv = nn.Conv1d(cat_channels, self.channels[-1], kernel_size=1)
253
+ self.pooling = AttentiveStatsPool(
254
+ self.channels[-1], attention_channels=128, global_context_att=global_context_att
255
+ )
256
+ self.bn = nn.BatchNorm1d(self.channels[-1] * 2)
257
+ self.linear = nn.Linear(self.channels[-1] * 2, emb_dim)
258
+
259
+ def get_feat_num(self):
260
+ self.feature_extract.eval()
261
+ wav = [torch.randn(self.sr).to(next(self.feature_extract.parameters()).device)]
262
+ with torch.no_grad():
263
+ features = self.feature_extract(wav)
264
+ select_feature = features[self.feature_selection]
265
+ if isinstance(select_feature, (list, tuple)):
266
+ return len(select_feature)
267
+ else:
268
+ return 1
269
+
270
+ def get_feat(self, x):
271
+ if self.update_extract:
272
+ x = self.feature_extract([sample for sample in x])
273
+ else:
274
+ with torch.no_grad():
275
+ if self.feat_type == "fbank" or self.feat_type == "mfcc":
276
+ x = self.feature_extract(x) + 1e-6 # B x feat_dim x time_len
277
+ else:
278
+ x = self.feature_extract([sample for sample in x])
279
+
280
+ if self.feat_type == "fbank":
281
+ x = x.log()
282
+
283
+ if self.feat_type != "fbank" and self.feat_type != "mfcc":
284
+ x = x[self.feature_selection]
285
+ if isinstance(x, (list, tuple)):
286
+ x = torch.stack(x, dim=0)
287
+ else:
288
+ x = x.unsqueeze(0)
289
+ norm_weights = F.softmax(self.feature_weight, dim=-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
290
+ x = (norm_weights * x).sum(dim=0)
291
+ x = torch.transpose(x, 1, 2) + 1e-6
292
+
293
+ x = self.instance_norm(x)
294
+ return x
295
+
296
+ def forward(self, x):
297
+ x = self.get_feat(x)
298
+
299
+ out1 = self.layer1(x)
300
+ out2 = self.layer2(out1)
301
+ out3 = self.layer3(out2)
302
+ out4 = self.layer4(out3)
303
+
304
+ out = torch.cat([out2, out3, out4], dim=1)
305
+ out = F.relu(self.conv(out))
306
+ out = self.bn(self.pooling(out))
307
+ out = self.linear(out)
308
+
309
+ return out
310
+
311
+
312
+ def ECAPA_TDNN_SMALL(
313
+ feat_dim,
314
+ emb_dim=256,
315
+ feat_type="wavlm_large",
316
+ sr=16000,
317
+ feature_selection="hidden_states",
318
+ update_extract=False,
319
+ config_path=None,
320
+ ):
321
+ return ECAPA_TDNN(
322
+ feat_dim=feat_dim,
323
+ channels=512,
324
+ emb_dim=emb_dim,
325
+ feat_type=feat_type,
326
+ sr=sr,
327
+ feature_selection=feature_selection,
328
+ update_extract=update_extract,
329
+ config_path=config_path,
330
+ )