swc2 commited on
Commit
35dd64c
·
1 Parent(s): 0be9a04

change to v3.0

Browse files
config/config.yaml CHANGED
@@ -17,7 +17,7 @@ model:
17
 
18
 
19
  test:
20
- checkpoint: "./ckpt/v3.0.pt.tar"
21
  gpu: -1
22
  sample_rate: 16000
23
 
 
17
 
18
 
19
  test:
20
+ checkpoint: "./ckpt/v2.0.pt.tar"
21
  gpu: -1
22
  sample_rate: 16000
23
 
config/config_ira.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ model:
3
+ _target_: model.spex_plus.SpEx_Plus # str, model class name
4
+ L1: 40
5
+ L2: 160
6
+ L3: 320
7
+ N: 256
8
+ B: 8
9
+ O: 256
10
+ P: 512
11
+ Q: 3
12
+ num_spks: 2350 # with speed perturbation 470 -> 1410
13
+ spk_embed_dim: 256
14
+ causal: false
15
+ is_innorm: true
16
+ fusion_type: 'cat' #cat mul film att
17
+
18
+
19
+ test:
20
+ checkpoint: "./ckpt/v3.0.pt.tar"
21
+ gpu: -1
22
+ sample_rate: 16000
23
+
decode.py CHANGED
@@ -31,7 +31,7 @@ class NnetComputer(object):
31
  aux = aux.unsqueeze(0)
32
  print("raw",raw.shape)
33
  print("aux",aux.shape)
34
- sps, sps2, sps3, spk_pred = self.nnet(raw, aux, aux_len)
35
  sp_samps = np.squeeze(sps.detach().cpu().numpy())
36
  return sp_samps
37
 
 
31
  aux = aux.unsqueeze(0)
32
  print("raw",raw.shape)
33
  print("aux",aux.shape)
34
+ sps,spk_pred,emb = self.nnet(raw, aux, aux_len)
35
  sp_samps = np.squeeze(sps.detach().cpu().numpy())
36
  return sp_samps
37
 
model/spex_plus_plus.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import torch as th
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from .norm import ChannelwiseLayerNorm, GlobalLayerNorm
8
+ from .cnns import Conv1D, ConvTrans1D, TCNBlock, TCNBlock_Spk, ResBlock
9
+ import warnings
10
+
11
+ # inference aux_len
12
+
13
+
14
+ class SpEx_Plus_Double(nn.Module):
15
+ def __init__(self,
16
+ L1=20,
17
+ L2=80,
18
+ L3=160,
19
+ N=256,
20
+ B=8,
21
+ O=256,
22
+ P=512,
23
+ Q=3,
24
+ num_spks=101,
25
+ spk_embed_dim=256,
26
+ causal=False,
27
+ norm_type='gLN',
28
+ fusion_type='cat',
29
+ is_innorm=False,
30
+ ):
31
+ super(SpEx_Plus_Double, self).__init__()
32
+
33
+ # n x S => n x N x T, S = 4s*8000 = 32000
34
+
35
+ self.L1 = L1
36
+ self.L2 = L2
37
+ self.L3 = L3
38
+ self.encoder_1d_short = Conv1D(1, N, L1, stride=L1 // 2, padding=0)
39
+ self.encoder_1d_middle = Conv1D(1, N, L2, stride=L1 // 2, padding=0)
40
+ self.encoder_1d_long = Conv1D(1, N, L3, stride=L1 // 2, padding=0)
41
+ # before repeat blocks, always cLN
42
+
43
+ self.instancenorm = nn.InstanceNorm1d(N)
44
+
45
+ self.decoder_1d_short = ConvTrans1D(N, 1, kernel_size=L1, stride=L1 // 2, bias=True)
46
+ self.decoder_1d_middle = ConvTrans1D(N, 1, kernel_size=L2, stride=L1 // 2, bias=True)
47
+ self.decoder_1d_long = ConvTrans1D(N, 1, kernel_size=L3, stride=L1 // 2, bias=True)
48
+ self.num_spks = num_spks
49
+ self.pred_linear = nn.Linear(spk_embed_dim, num_spks)
50
+ self.is_innorm = is_innorm
51
+
52
+ if causal and norm_type not in ["cgLN", "cLN"]:
53
+ norm_type = "cLN"
54
+ warnings.warn(
55
+ "In causal configuration cumulative layer normalization (cgLN)"
56
+ "or channel-wise layer normalization (chanLN) "
57
+ f"must be used. Changing {norm_type} to cLN"
58
+ )
59
+
60
+ self.speaker_encoder = Speaker_Model(
61
+ L1=L1,
62
+ L2=L2,
63
+ L3=L3,
64
+ N=N,
65
+ O=O,
66
+ P=P,
67
+ spk_embed_dim=spk_embed_dim,
68
+ )
69
+
70
+ self.extractor = Extractor(
71
+ L1=L1,
72
+ L2=L2,
73
+ L3=L3,
74
+ N=N,
75
+ B=B,
76
+ O=O,
77
+ P=P,
78
+ Q=Q,
79
+ num_spks=num_spks,
80
+ spk_embed_dim=spk_embed_dim,
81
+ causal=causal,
82
+ fusion_type=fusion_type,
83
+ norm_type=norm_type,
84
+ )
85
+
86
+ self.frameconv1 = Conv1D(2*N, N, 1)
87
+ self.frameconv2 = Conv1D(2*N, N, 1)
88
+ self.frameconv3 = Conv1D(2*N, N, 1)
89
+
90
+ self.fusion1 = nn.Parameter(th.tensor(0.8))
91
+ self.fusion2 = nn.Parameter(th.tensor(0.1))
92
+ self.fusion3 = nn.Parameter(th.tensor(0.1))
93
+
94
+ def align_to_w(self,frame, w):
95
+ diff = frame.shape[-1] - w.shape[-1]
96
+ if diff > 0:
97
+ frame = frame[..., :w.shape[-1]] # 裁剪
98
+ elif diff < 0:
99
+ frame = th.nn.functional.pad(frame, (0, -diff)) # 补零
100
+ return frame, w # w 保持不动
101
+
102
+ def ira(self, est1, aux, aux_len, xlen1, xlen2, xlen3, w1 ,w2, w3):
103
+ ### 2
104
+ concat_aux = th.cat((est1, aux), dim=1)
105
+ concat_aux_len = aux_len + xlen1
106
+
107
+
108
+ concat_aux_w1 = F.relu(self.encoder_1d_short(concat_aux))
109
+ concat_aux_T_shape = concat_aux_w1.shape[-1]
110
+ concat_aux_len1 = concat_aux.shape[-1]
111
+ concat_aux_len2 = (concat_aux_T_shape - 1) * (self.L1 // 2) + self.L2
112
+ concat_aux_len3 = (concat_aux_T_shape - 1) * (self.L1 // 2) + self.L3
113
+ concat_aux_w2 = F.relu(self.encoder_1d_middle(F.pad(concat_aux, (0, concat_aux_len2 - concat_aux_len1), "constant", 0)))
114
+ concat_aux_w3 = F.relu(self.encoder_1d_long(F.pad(concat_aux, (0, concat_aux_len3 - concat_aux_len1), "constant", 0)))
115
+ concat_aux = self.speaker_encoder(th.cat([concat_aux_w1, concat_aux_w2, concat_aux_w3], 1), concat_aux_len)
116
+
117
+ frame1 = F.relu(self.encoder_1d_short(est1))
118
+ frame2 = F.relu(self.encoder_1d_middle(F.pad(est1, (0, xlen2 - xlen1), "constant", 0)))
119
+ frame3 = F.relu(self.encoder_1d_long(F.pad(est1, (0, xlen3 - xlen1), "constant", 0)))
120
+
121
+ if self.is_innorm:
122
+ frame1 = self.instancenorm(frame1)
123
+ frame2 = self.instancenorm(frame2)
124
+ frame3 = self.instancenorm(frame3)
125
+
126
+ frame1, w1 = self.align_to_w(frame1, w1)
127
+ frame2, w2 = self.align_to_w(frame2, w2)
128
+ frame3, w3 = self.align_to_w(frame3, w3)
129
+
130
+ # frame2, w2 长度不匹配 4098 != 4099
131
+
132
+ # print("frame2 shape: ", frame2.shape)
133
+ # print("w2 shape: ", w2.shape)
134
+ concat1 = self.frameconv1(th.cat([frame1, w1], 1))
135
+ concat2 = self.frameconv2(th.cat([frame2, w2], 1))
136
+ concat3 = self.frameconv3(th.cat([frame3, w3], 1))
137
+
138
+ mask1, mask2, mask3 = self.extractor(concat1, concat2, concat3, concat_aux)
139
+
140
+ F1 = concat1 * mask1
141
+ F2 = concat2 * mask2
142
+ F3 = concat3 * mask3
143
+
144
+ f1 = self.decoder_1d_short(F1)
145
+ xlen1 = f1.shape[-1]
146
+ f2 = self.decoder_1d_middle(F2)[:, :xlen1]
147
+ f3 = self.decoder_1d_long(F3)[:, :xlen1]
148
+
149
+ est2 = self.fusion1 * f1 + self.fusion2 * f2 + self.fusion3 * f3
150
+
151
+ return est2
152
+
153
+
154
+
155
+ def forward(self, x, aux, aux_len):
156
+ if x.dim() >= 3:
157
+ raise RuntimeError(
158
+ "{} accept 1/2D tensor as input, but got {:d}".format(
159
+ self.__name__, x.dim()))
160
+ # when inference, only one utt
161
+ if x.dim() == 1:
162
+ x = th.unsqueeze(x, 0)
163
+ # n x 1 x S => n x N x T
164
+
165
+
166
+ w1 = F.relu(self.encoder_1d_short(x))
167
+ T = w1.shape[-1]
168
+ xlen1 = x.shape[-1]
169
+ xlen2 = (T - 1) * (self.L1 // 2) + self.L2
170
+ xlen3 = (T - 1) * (self.L1 // 2) + self.L3
171
+ w2 = F.relu(self.encoder_1d_middle(F.pad(x, (0, xlen2 - xlen1), "constant", 0)))
172
+ w3 = F.relu(self.encoder_1d_long(F.pad(x, (0, xlen3 - xlen1), "constant", 0)))
173
+ # n x 3N x T
174
+ # speaker encoder (share params from speech encoder)
175
+
176
+ if self.is_innorm:
177
+ w1 = self.instancenorm(w1)
178
+ w2 = self.instancenorm(w2)
179
+ w3 = self.instancenorm(w3)
180
+
181
+ aux_w1 = F.relu(self.encoder_1d_short(aux))
182
+ aux_T_shape = aux_w1.shape[-1]
183
+ aux_len1 = aux.shape[-1]
184
+ aux_len2 = (aux_T_shape - 1) * (self.L1 // 2) + self.L2
185
+ aux_len3 = (aux_T_shape - 1) * (self.L1 // 2) + self.L3
186
+ aux_w2 = F.relu(self.encoder_1d_middle(F.pad(aux, (0, aux_len2 - aux_len1), "constant", 0)))
187
+ aux_w3 = F.relu(self.encoder_1d_long(F.pad(aux, (0, aux_len3 - aux_len1), "constant", 0)))
188
+
189
+ aux = self.speaker_encoder(th.cat([aux_w1, aux_w2, aux_w3], 1), aux_len)
190
+
191
+
192
+ m1, m2, m3 = self.extractor(w1, w2, w3, aux)
193
+
194
+ S1 = w1 * m1
195
+ S2 = w2 * m2
196
+ S3 = w3 * m3
197
+
198
+ s1 = F.pad(self.decoder_1d_short(S1), (0, max(0, xlen1 - self.decoder_1d_short(S1).shape[1])))[:, :xlen1]
199
+ s2 = self.decoder_1d_middle(S2)[:, :xlen1]
200
+ s3 = self.decoder_1d_long(S3)[:, :xlen1]
201
+
202
+ est1 = self.fusion1 * s1 + self.fusion2 * s2 + self.fusion3 * s3
203
+
204
+
205
+ est2 = self.ira(est1, aux, aux_len,xlen1, xlen2, xlen3, w1, w2, w3)
206
+
207
+ est3 = self.ira(est2, aux, aux_len,xlen1, xlen2, xlen3, w1, w2, w3)
208
+
209
+ return est3,self.pred_linear(aux), aux
210
+
211
+ class Extractor(nn.Module):
212
+ def __init__(self,
213
+ L1=20,
214
+ L2=80,
215
+ L3=160,
216
+ N=256,
217
+ B=8,
218
+ O=256,
219
+ P=512,
220
+ Q=3,
221
+ num_spks=101,
222
+ spk_embed_dim=256,
223
+ causal=False,
224
+ fusion_type='cat',
225
+ norm_type='gLN',
226
+ ):
227
+ super(Extractor, self).__init__()
228
+ # n x N x T => n x O x T
229
+ self.ln = ChannelwiseLayerNorm(3*N)
230
+ self.proj = Conv1D(3*N, O, 1)
231
+ self.conv_block_1 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1,fusion_type=fusion_type,norm_type=norm_type)
232
+ self.conv_block_1_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal,norm_type=norm_type)
233
+ self.conv_block_2 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1,fusion_type=fusion_type,norm_type=norm_type)
234
+ self.conv_block_2_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal,norm_type=norm_type)
235
+ self.conv_block_3 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1,fusion_type=fusion_type,norm_type=norm_type)
236
+ self.conv_block_3_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal,norm_type=norm_type)
237
+ self.conv_block_4 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1,fusion_type=fusion_type,norm_type=norm_type)
238
+ self.conv_block_4_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal,norm_type=norm_type)
239
+ # n x O x T => n x N x T
240
+ self.mask1 = Conv1D(O, N, 1)
241
+ self.mask2 = Conv1D(O, N, 1)
242
+ self.mask3 = Conv1D(O, N, 1)
243
+
244
+ def _build_stacks(self, num_blocks, **block_kwargs):
245
+ """
246
+ Stack B numbers of TCN block, the first TCN block takes the speaker embedding
247
+ """
248
+ blocks = [
249
+ TCNBlock(**block_kwargs, dilation=(2**b))
250
+ for b in range(1,num_blocks)
251
+ ]
252
+ return nn.Sequential(*blocks)
253
+
254
+ def forward(self, w1, w2, w3, aux):
255
+
256
+ y = self.ln(th.cat([w1, w2, w3], 1))
257
+ # n x O x T
258
+ y = self.proj(y)
259
+ y = self.conv_block_1(y, aux)
260
+ y = self.conv_block_1_other(y)
261
+ y = self.conv_block_2(y, aux)
262
+ y = self.conv_block_2_other(y)
263
+ y = self.conv_block_3(y, aux)
264
+ y = self.conv_block_3_other(y)
265
+ y = self.conv_block_4(y, aux)
266
+ y = self.conv_block_4_other(y)
267
+
268
+ # n x N x T
269
+ m1 = F.relu(self.mask1(y))
270
+ m2 = F.relu(self.mask2(y))
271
+ m3 = F.relu(self.mask3(y))
272
+
273
+
274
+
275
+ return m1, m2, m3
276
+
277
+
278
+
279
+ class Speaker_Model(nn.Module):
280
+ def __init__(self,
281
+ L1=20,
282
+ L2=80,
283
+ L3=160,
284
+ N=256,
285
+ O=256,
286
+ P=512,
287
+ spk_embed_dim=256,
288
+ ):
289
+ super(Speaker_Model, self).__init__()
290
+ self.L1 = L1
291
+ self.L2 = L2
292
+ self.L3 = L3
293
+ self.spk_encoder = nn.Sequential(
294
+ ChannelwiseLayerNorm(3*N),
295
+ Conv1D(3*N, O, 1),
296
+ ResBlock(O, O),
297
+ ResBlock(O, P),
298
+ ResBlock(P, P),
299
+ Conv1D(P, spk_embed_dim, 1),
300
+ )
301
+ def forward(self, aux, aux_len):
302
+ aux = self.spk_encoder(aux)
303
+ aux_T = (aux_len - self.L1) // (self.L1 // 2) + 1
304
+ aux_T = ((aux_T // 3) // 3) // 3
305
+ aux = th.sum(aux, -1)/aux_T.view(-1,1).float()
306
+ return aux