swc2 committed on
Commit c7ba938 · 1 Parent(s): bb19896
app.py ADDED
@@ -0,0 +1,87 @@
import os

import gradio as gr
import torch as th
import numpy as np
from nnet.spex_plus import SpEx_Plus
from utils.logger import get_logger
from utils.audio import WaveReader, write_wav

logger = get_logger(__name__)

class NnetComputer(object):
    def __init__(self, cpt_dir, gpuid, nnet_conf):
        self.device = th.device("cuda:{}".format(gpuid)) if gpuid >= 0 else th.device("cpu")
        nnet = self._load_nnet(cpt_dir, nnet_conf)
        self.nnet = nnet.to(self.device) if gpuid >= 0 else nnet
        # set eval mode
        self.nnet.eval()

    def _load_nnet(self, cpt_dir, nnet_conf):
        nnet = SpEx_Plus(**nnet_conf)
        cpt_fname = os.path.join(cpt_dir, "59.pt.tar")
        cpt = th.load(cpt_fname, map_location="cpu")
        nnet.load_state_dict(cpt["model_state_dict"])
        logger.info("Load checkpoint from {}, epoch {:d}".format(
            cpt_fname, cpt["epoch"]))
        return nnet

    def compute(self, samps, aux_samps, aux_samps_len):
        with th.no_grad():
            raw = th.tensor(samps, dtype=th.float32, device=self.device)
            aux = th.tensor(aux_samps, dtype=th.float32, device=self.device)
            aux_len = th.tensor(aux_samps_len, dtype=th.float32, device=self.device)
            aux = aux.unsqueeze(0)
            sps, sps2, sps3, spk_pred = self.nnet(raw, aux, aux_len)
            sp_samps = np.squeeze(sps.detach().cpu().numpy())
            return sp_samps

def compute_output(input_audio, use_gpu, checkpoint, output_dir):
    # Prepare mix_input and aux_input based on the input_audio
    mix_input = {}  # Modify this to include your mix_input
    aux_input = {}  # Modify this to include your aux_input

    # Set GPU index based on the user's choice
    gpu_index = -1 if not use_gpu else 0

    # Run the computation
    nnet_conf = {
        "L1": int(0.0025 * 16000),  # 40-sample (2.5 ms) short window
        "L2": int(0.01 * 16000),    # 160-sample (10 ms) middle window
        "L3": int(0.02 * 16000),    # 320-sample (20 ms) long window
        "N": 256,
        "B": 8,
        "O": 256,
        "P": 512,
        "Q": 3,
        "num_spks": 395,
        "spk_embed_dim": 256,
        "causal": False
    }
    computer = NnetComputer(checkpoint, gpu_index, nnet_conf)
    for key, mix_samps in mix_input.items():  # iterate over key/value pairs, not the bare dict
        aux_samps = aux_input[key]
        logger.info("Compute on utterance {}...".format(key))
        samps = computer.compute(mix_samps, aux_samps, len(aux_samps))
        norm = np.linalg.norm(mix_samps, np.inf)
        samps = samps[:mix_samps.size]
        # Normalize the output
        samps = samps * norm / np.max(np.abs(samps))
        # Write output to the specified directory
        # (16 kHz matches nnet_conf above; the original referenced an undefined args.sample_rate)
        write_wav(os.path.join(output_dir, "{}.wav".format(key)), samps, sample_rate=16000)
    logger.info("Compute over {:d} utterances".format(len(mix_input)))

# Define the Gradio interface
# (Gradio components take label=, not name=; the text component is gr.Textbox)
inputs = [
    gr.Audio(label="Input Audio"),
    gr.Checkbox(label="Use GPU"),
    gr.Textbox(label="Checkpoint Directory"),
    gr.Textbox(label="Output Directory")
]
output = gr.Interface(
    fn=compute_output,
    inputs=inputs,
    outputs=None,
    title="Audio Processing with Neural Network",
    description="Process audio input using a neural network model.",
    theme="compact"
)
output.launch()
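
A minimal sketch of driving NnetComputer offline, outside the Gradio UI, assuming Kaldi-style mix.scp/aux.scp lists (the paths below are hypothetical) and the nnet_conf dict built in compute_output above:

from utils.audio import WaveReader

# hypothetical scp files mapping utterance keys to wav paths
mix_input = WaveReader("data/mix.scp", sample_rate=16000)
aux_input = WaveReader("data/aux.scp", sample_rate=16000)
computer = NnetComputer("exp/spex_plus", gpuid=-1, nnet_conf=nnet_conf)
for key, mix_samps in mix_input:
    samps = computer.compute(mix_samps, aux_input[key], len(aux_input[key]))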
nnet/ResNet34.py ADDED
@@ -0,0 +1,213 @@
#! /usr/bin/python
# -*- encoding: utf-8 -*-
'''
Fast ResNet
https://arxiv.org/pdf/2003.11982.pdf
'''

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
try:
    from .pooling import *
except:
    from pooling import *

class SEBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
        super(SEBasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)
        return out


class SEBottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None, reduction=8):
        super(SEBottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.se = SELayer(planes * 4, reduction)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)
        out = self.se(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class SELayer(nn.Module):
    def __init__(self, channel, reduction=8):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        return x * y


class ResNetSE(nn.Module):
    def __init__(self, block, layers, num_filters, embedding_dim, n_mels=80, pooling_type="TSP", **kwargs):
        super(ResNetSE, self).__init__()

        self.inplanes = num_filters[0]
        self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size=3, stride=(1, 1), padding=1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(num_filters[0])
        self.relu = nn.ReLU(inplace=True)

        self.layer1 = self._make_layer(block, num_filters[0], layers[0])
        self.layer2 = self._make_layer(block, num_filters[1], layers[1], stride=(2, 2))
        self.layer3 = self._make_layer(block, num_filters[2], layers[2], stride=(2, 2))
        self.layer4 = self._make_layer(block, num_filters[3], layers[3], stride=(2, 2))

        out_dim = num_filters[3] * block.expansion * (n_mels // 8)

        if pooling_type == "Temporal_Average_Pooling" or pooling_type == "TAP":
            self.pooling = Temporal_Average_Pooling()
            self.bn2 = nn.BatchNorm1d(out_dim)
            self.fc = nn.Linear(out_dim, embedding_dim)
            self.bn3 = nn.BatchNorm1d(embedding_dim)

        elif pooling_type == "Temporal_Statistics_Pooling" or pooling_type == "TSP":
            self.pooling = Temporal_Statistics_Pooling()
            self.bn2 = nn.BatchNorm1d(out_dim * 2)
            self.fc = nn.Linear(out_dim * 2, embedding_dim)
            self.bn3 = nn.BatchNorm1d(embedding_dim)

        elif pooling_type == "Self_Attentive_Pooling" or pooling_type == "SAP":
            self.pooling = Self_Attentive_Pooling(out_dim)
            self.bn2 = nn.BatchNorm1d(out_dim)
            self.fc = nn.Linear(out_dim, embedding_dim)
            self.bn3 = nn.BatchNorm1d(embedding_dim)

        elif pooling_type == "Attentive_Statistics_Pooling" or pooling_type == "ASP":
            self.pooling = Attentive_Statistics_Pooling(out_dim)
            self.bn2 = nn.BatchNorm1d(out_dim * 2)
            self.fc = nn.Linear(out_dim * 2, embedding_dim)
            self.bn3 = nn.BatchNorm1d(embedding_dim)

        else:
            raise ValueError('{} pooling type is not defined'.format(pooling_type))

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = x.unsqueeze(1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = x.reshape(x.shape[0], -1, x.shape[-1])

        x = self.pooling(x)
        x = self.bn2(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        x = self.bn3(x)
        return x


def Speaker_Encoder(embedding_dim=256, **kwargs):
    # Number of filters
    num_filters = [32, 64, 128, 256]
    model = ResNetSE(SEBasicBlock, [3, 4, 6, 3], num_filters, embedding_dim, **kwargs)
    return model

if __name__ == '__main__':
    model = Speaker_Encoder()
    total = sum([param.nelement() for param in model.parameters()])
    print(total / 1e6)
    data = torch.randn(10, 80, 100)
    out = model(data)
    print(data.shape)
    print(out.shape)
nnet/__init__.py ADDED
File without changes
nnet/cnns.py ADDED
@@ -0,0 +1,186 @@
#!/usr/bin/env python

import torch as th
import torch.nn as nn

from .norm import ChannelwiseLayerNorm, GlobalLayerNorm

class Conv1D(nn.Conv1d):
    """
    1D Conv based on nn.Conv1d for 2D or 3D tensor
    Input: 2D or 3D tensor with [N, L_in] or [N, C_in, L_in]
    Output: Default 3D tensor with [N, C_out, L_out]
    If C_out=1 and squeeze is true, return 2D tensor
    """

    def __init__(self, *args, **kwargs):
        super(Conv1D, self).__init__(*args, **kwargs)

    def forward(self, x, squeeze=False):
        if x.dim() not in [2, 3]:
            raise RuntimeError("{} requires a 2/3D tensor input".format(
                type(self).__name__))  # nn.Module has no __name__; use the class name
        x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))
        if squeeze:
            x = th.squeeze(x)
        return x


class ConvTrans1D(nn.ConvTranspose1d):
    """
    1D Transposed Conv based on nn.ConvTranspose1d for 2D or 3D tensor
    Input: 2D or 3D tensor with [N, L_in] or [N, C_in, L_in]
    Output: 2D tensor with [N, L_out]
    """

    def __init__(self, *args, **kwargs):
        super(ConvTrans1D, self).__init__(*args, **kwargs)

    def forward(self, x):
        if x.dim() not in [2, 3]:
            raise RuntimeError("{} requires a 2/3D tensor input".format(
                type(self).__name__))
        x = super().forward(x if x.dim() == 3 else th.unsqueeze(x, 1))

        # squeeze the channel dimension 1 after reconstructing the signal
        return th.squeeze(x, 1)

class TCNBlock(nn.Module):
    """
    Temporal convolutional network block,
    1x1Conv - PReLU - Norm - DConv - PReLU - Norm - SConv
    Input: 3D tensor with [N, C_in, L_in]
    Output: 3D tensor with [N, C_out, L_out]
    """

    def __init__(self,
                 in_channels=256,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1,
                 causal=False):
        super(TCNBlock, self).__init__()
        self.conv1x1 = Conv1D(in_channels, conv_channels, 1)
        self.prelu1 = nn.PReLU()
        self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
            ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
        dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
            dilation * (kernel_size - 1))
        self.dconv = nn.Conv1d(
            conv_channels,
            conv_channels,
            kernel_size,
            groups=conv_channels,
            padding=dconv_pad,
            dilation=dilation,
            bias=True)
        self.prelu2 = nn.PReLU()
        self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
            ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
        self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
        self.causal = causal
        self.dconv_pad = dconv_pad

    def forward(self, x):
        y = self.conv1x1(x)
        y = self.norm1(self.prelu1(y))
        y = self.dconv(y)
        if self.causal:
            y = y[:, :, :-self.dconv_pad]
        y = self.norm2(self.prelu2(y))
        y = self.sconv(y)
        y += x
        return y

class TCNBlock_Spk(nn.Module):
    """
    Temporal convolutional network block,
    1x1Conv - PReLU - Norm - DConv - PReLU - Norm - SConv
    The first tcn block takes an additional speaker embedding as input
    Input: 3D tensor with [N, C_in, L_in]
    Input Speaker Embedding: 2D tensor with [N, D]
    Output: 3D tensor with [N, C_out, L_out]
    """

    def __init__(self,
                 in_channels=256,
                 spk_embed_dim=100,
                 conv_channels=512,
                 kernel_size=3,
                 dilation=1,
                 causal=False):
        super(TCNBlock_Spk, self).__init__()
        self.conv1x1 = Conv1D(in_channels + spk_embed_dim, conv_channels, 1)
        self.prelu1 = nn.PReLU()
        self.norm1 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
            ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
        dconv_pad = (dilation * (kernel_size - 1)) // 2 if not causal else (
            dilation * (kernel_size - 1))
        self.dconv = nn.Conv1d(
            conv_channels,
            conv_channels,
            kernel_size,
            groups=conv_channels,
            padding=dconv_pad,
            dilation=dilation,
            bias=True)
        self.prelu2 = nn.PReLU()
        self.norm2 = GlobalLayerNorm(conv_channels, elementwise_affine=True) if not causal else (
            ChannelwiseLayerNorm(conv_channels, elementwise_affine=True))
        self.sconv = nn.Conv1d(conv_channels, in_channels, 1, bias=True)
        self.causal = causal
        self.dconv_pad = dconv_pad
        self.dilation = dilation

    def forward(self, x, aux):
        # Repeatedly concatenate the speaker embedding aux to each frame of the representation x
        T = x.shape[-1]
        aux = th.unsqueeze(aux, -1)
        aux = aux.repeat(1, 1, T)
        y = th.cat([x, aux], 1)
        y = self.conv1x1(y)
        y = self.norm1(self.prelu1(y))
        y = self.dconv(y)
        if self.causal:
            y = y[:, :, :-self.dconv_pad]
        y = self.norm2(self.prelu2(y))
        y = self.sconv(y)
        y += x
        return y

class ResBlock(nn.Module):
    """
    Resnet block for speaker encoder to obtain speaker embedding
    ref to
    https://github.com/fatchord/WaveRNN/blob/master/models/fatchord_version.py
    and
    https://github.com/Jungjee/RawNet/blob/master/PyTorch/model_RawNet.py
    """
    def __init__(self, in_dims, out_dims):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(out_dims, out_dims, kernel_size=1, bias=False)
        self.batch_norm1 = nn.BatchNorm1d(out_dims)
        self.batch_norm2 = nn.BatchNorm1d(out_dims)
        self.prelu1 = nn.PReLU()
        self.prelu2 = nn.PReLU()
        self.maxpool = nn.MaxPool1d(3)
        if in_dims != out_dims:
            self.downsample = True
            self.conv_downsample = nn.Conv1d(in_dims, out_dims, kernel_size=1, bias=False)
        else:
            self.downsample = False

    def forward(self, x):
        y = self.conv1(x)
        y = self.batch_norm1(y)
        y = self.prelu1(y)
        y = self.conv2(y)
        y = self.batch_norm2(y)
        if self.downsample:
            y += self.conv_downsample(x)
        else:
            y += x
        y = self.prelu2(y)
        return self.maxpool(y)
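
The dconv_pad arithmetic above keeps the depthwise convolution length-preserving: non-causal blocks pad symmetrically by dilation*(kernel_size-1)//2, while causal blocks pad by the full dilation*(kernel_size-1) on both sides and then trim that amount from the tail, which amounts to left-only padding. A small shape-check sketch:

import torch as th

for causal in (False, True):
    block = TCNBlock(in_channels=256, conv_channels=512, kernel_size=3,
                     dilation=4, causal=causal)
    x = th.randn(2, 256, 100)
    y = block(x)
    # both variants preserve the frame count, so the residual y += x is well-defined
    assert y.shape == x.shape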
nnet/norm.py ADDED
@@ -0,0 +1,59 @@
#!/usr/bin/env python

import torch as th
import torch.nn as nn

class ChannelwiseLayerNorm(nn.LayerNorm):
    """
    Channel-wise layer normalization based on nn.LayerNorm
    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with same shape
    """

    def __init__(self, *args, **kwargs):
        super(ChannelwiseLayerNorm, self).__init__(*args, **kwargs)

    def forward(self, x):
        if x.dim() != 3:
            raise RuntimeError("{} requires a 3D tensor input".format(
                type(self).__name__))  # nn.Module has no __name__; use the class name
        x = th.transpose(x, 1, 2)
        x = super().forward(x)
        x = th.transpose(x, 1, 2)
        return x

class GlobalLayerNorm(nn.Module):
    """
    Global layer normalization
    Input: 3D tensor with [batch_size(N), channel_size(C), frame_num(T)]
    Output: 3D tensor with same shape
    """

    def __init__(self, dim, eps=1e-05, elementwise_affine=True):
        super(GlobalLayerNorm, self).__init__()
        self.eps = eps
        self.normalized_dim = dim
        self.elementwise_affine = elementwise_affine
        if elementwise_affine:
            self.beta = nn.Parameter(th.zeros(dim, 1))
            self.gamma = nn.Parameter(th.ones(dim, 1))
        else:
            # keep the parameter names consistent with the affine branch
            self.register_parameter("gamma", None)
            self.register_parameter("beta", None)

    def forward(self, x):
        if x.dim() != 3:
            raise RuntimeError("{} requires a 3D tensor input".format(
                type(self).__name__))
        # calculate the mean and variance over the channel and time dimensions
        mean = th.mean(x, (1, 2), keepdim=True)
        var = th.mean((x - mean)**2, (1, 2), keepdim=True)
        if self.elementwise_affine:
            x = self.gamma * (x - mean) / th.sqrt(var + self.eps) + self.beta
        else:
            x = (x - mean) / th.sqrt(var + self.eps)
        return x

    def extra_repr(self):
        return "{normalized_dim}, eps={eps}, " \
            "elementwise_affine={elementwise_affine}".format(**self.__dict__)
nnet/pooling.py ADDED
@@ -0,0 +1,100 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from speechbrain.lobes.models.ECAPA_TDNN import AttentiveStatisticsPooling

class Temporal_Average_Pooling(nn.Module):
    def __init__(self, **kwargs):
        """TAP
        Paper: Multi-Task Learning with High-Order Statistics for X-vector based Text-Independent Speaker Verification
        Link: https://arxiv.org/pdf/1903.12058.pdf
        """
        super(Temporal_Average_Pooling, self).__init__()

    def forward(self, x):
        """Computes Temporal Average Pooling Module
        Args:
            x (torch.Tensor): Input tensor (#batch, channels, frames).
        Returns:
            torch.Tensor: Output tensor (#batch, channels)
        """
        x = torch.mean(x, axis=2)
        return x


class Temporal_Statistics_Pooling(nn.Module):
    def __init__(self, **kwargs):
        """TSP
        Paper: X-vectors: Robust DNN Embeddings for Speaker Recognition
        Link: http://www.danielpovey.com/files/2018_icassp_xvectors.pdf
        """
        super(Temporal_Statistics_Pooling, self).__init__()

    def forward(self, x):
        """Computes Temporal Statistics Pooling Module
        Args:
            x (torch.Tensor): Input tensor (#batch, channels, frames).
        Returns:
            torch.Tensor: Output tensor (#batch, channels*2)
        """
        mean = torch.mean(x, axis=2)
        var = torch.var(x, axis=2)
        x = torch.cat((mean, var), axis=1)
        return x


''' Self attentive weighted mean pooling.
'''
class Self_Attentive_Pooling(nn.Module):
    def __init__(self, dim, **kwargs):
        # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
        # attention dim = 128
        super(Self_Attentive_Pooling, self).__init__()
        self.linear1 = nn.Conv1d(dim, dim, kernel_size=1)  # equals W and b in the paper
        self.linear2 = nn.Conv1d(dim, dim, kernel_size=1)  # equals V and k in the paper

    def forward(self, x):
        # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
        alpha = torch.tanh(self.linear1(x))
        alpha = torch.softmax(self.linear2(alpha), dim=2)
        mean = torch.sum(alpha * x, dim=2)
        return mean


''' Attentive weighted mean and standard deviation pooling.
'''
class Attentive_Statistics_Pooling(nn.Module):
    def __init__(self, dim, **kwargs):
        # Use AttentiveStatisticsPooling and BatchNorm1d from speechbrain
        super(Attentive_Statistics_Pooling, self).__init__()
        self.pooling = AttentiveStatisticsPooling(dim)

    def forward(self, x):
        x = self.pooling(x)
        return x

# class Attentive_Statistics_Pooling(nn.Module):
#     def __init__(self, dim, **kwargs):
#         # Use Conv1d with stride == 1 rather than Linear, then we don't need to transpose inputs.
#         # attention dim = 128
#         super(Attentive_Statistics_Pooling, self).__init__()
#         self.linear1 = nn.Conv1d(dim, dim, kernel_size=1)  # equals W and b in the paper
#         self.linear2 = nn.Conv1d(dim, dim, kernel_size=1)  # equals V and k in the paper
#
#     def forward(self, x):
#         # DON'T use ReLU here! In experiments, I find ReLU hard to converge.
#         alpha = torch.tanh(self.linear1(x))
#         alpha = torch.softmax(self.linear2(alpha), dim=2)
#         mean = torch.sum(alpha * x, dim=2)
#         residuals = torch.sum(alpha * x ** 2, dim=2) - mean ** 2
#         std = torch.sqrt(residuals.clamp(min=1e-9))
#         return torch.cat([mean, std], dim=1)


if __name__ == "__main__":
    data = torch.randn(10, 128, 100)
    pooling = Self_Attentive_Pooling(128)
    out = pooling(data)
    print(data.shape)
    print(out.shape)
nnet/speaker_encoder.py ADDED
@@ -0,0 +1,47 @@
import torch
import torchaudio
import torch.nn as nn
from torch.nn import functional as F
from .ResNet34 import Speaker_Encoder


class Speaker_Model(torch.nn.Module):
    # class Speaker_Model(LightningModule):
    def __init__(self, pooling_type, spk_embed_dim, sample_rate, n_mels):
        super().__init__()
        # self.save_hyperparameters()

        self.pooling_type = pooling_type
        self.spk_embed_dim = spk_embed_dim
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        sr = self.sample_rate

        self.mel_trans = torch.nn.Sequential(
            PreEmphasis(),
            torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=512,
                                                 win_length=sr * 25 // 1000, hop_length=sr * 10 // 1000,
                                                 window_fn=torch.hamming_window, n_mels=self.n_mels)
        )
        self.instancenorm = nn.InstanceNorm1d(self.n_mels)

        self.hparams = {'embedding_dim': self.spk_embed_dim, 'pooling_type': self.pooling_type, 'n_mels': self.n_mels}

        self.speaker_encoder = Speaker_Encoder(**dict(self.hparams))

class PreEmphasis(torch.nn.Module):
    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        # make kernel
        # In PyTorch, the convolution operation is cross-correlation, so the filter is flipped.
        self.register_buffer(
            'flipped_filter', torch.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
        )

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        assert len(inputs.size()) == 2, 'The number of dimensions of inputs tensor must be 2!'
        # reflect padding to match lengths of in/out
        inputs = inputs.unsqueeze(1)
        inputs = F.pad(inputs, (1, 0), 'reflect')
        return F.conv1d(inputs, self.flipped_filter).squeeze(1)
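
PreEmphasis computes y[t] = x[t] - coef * x[t-1] as a two-tap convolution; the kernel is stored flipped because F.conv1d performs cross-correlation, and one sample of reflect padding keeps input and output lengths equal. A sanity sketch:

import torch

pre = PreEmphasis(coef=0.97)
x = torch.randn(3, 16000)   # (batch, samples)
y = pre(x)
assert y.shape == x.shape
# interior samples follow the first-order difference exactly
assert torch.allclose(y[:, 1:], x[:, 1:] - 0.97 * x[:, :-1], atol=1e-6)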
nnet/spex_plus.py ADDED
@@ -0,0 +1,247 @@
#!/usr/bin/env python

import torch as th
import torch.nn as nn
import torch.nn.functional as F

from .norm import ChannelwiseLayerNorm, GlobalLayerNorm
from .cnns import Conv1D, ConvTrans1D, TCNBlock, TCNBlock_Spk, ResBlock

import torchaudio
from .ResNet34 import Speaker_Encoder
# from .sunine.trainer.utils import PreEmphasis


# Two possibilities to consider: in the frequency domain there is unlikely to be any
# multi-time-scale notion, so the speaker branch surely takes the spectrogram directly;
# what about the speech branch?
# Watch the dimension order: is it B x N x T or B x T x N?

class SpEx_Plus(nn.Module):
    def __init__(self,
                 L1=20,
                 L2=80,
                 L3=160,
                 N=256,
                 B=8,
                 O=256,
                 P=512,
                 Q=3,
                 num_spks=101,
                 spk_embed_dim=256,
                 sample_rate=16000,
                 n_mels=80,
                 causal=False,
                 ):
        super(SpEx_Plus, self).__init__()
        # n x S => n x N x T, S = 4s*8000 = 32000
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        self.L1 = L1
        self.L2 = L2
        self.L3 = L3
        self.encoder_1d_short = Conv1D(1, N, L1, stride=L1 // 2, padding=0)
        self.encoder_1d_middle = Conv1D(1, N, L2, stride=L1 // 2, padding=0)
        self.encoder_1d_long = Conv1D(1, N, L3, stride=L1 // 2, padding=0)
        # before repeat blocks, always cLN
        self.ln = ChannelwiseLayerNorm(3 * N)
        # n x N x T => n x O x T
        self.proj = Conv1D(3 * N, O, 1)
        self.conv_block_1 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
        self.conv_block_1_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
        self.conv_block_2 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
        self.conv_block_2_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
        self.conv_block_3 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
        self.conv_block_3_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
        self.conv_block_4 = TCNBlock_Spk(spk_embed_dim=spk_embed_dim, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal, dilation=1)
        self.conv_block_4_other = self._build_stacks(num_blocks=B, in_channels=O, conv_channels=P, kernel_size=Q, causal=causal)
        # n x O x T => n x N x T
        self.mask1 = Conv1D(O, N, 1)
        self.mask2 = Conv1D(O, N, 1)
        self.mask3 = Conv1D(O, N, 1)
        # using ConvTrans1D: n x N x T => n x 1 x To
        # To = (T - 1) * L // 2 + L
        #############################################################
        self.decoder_1d_short = ConvTrans1D(N, 1, kernel_size=L1, stride=L1 // 2, bias=True)
        self.decoder_1d_middle = ConvTrans1D(N, 1, kernel_size=L2, stride=L1 // 2, bias=True)
        self.decoder_1d_long = ConvTrans1D(N, 1, kernel_size=L3, stride=L1 // 2, bias=True)
        self.num_spks = num_spks
        # self.spk_encoder = nn.Sequential(
        #     ChannelwiseLayerNorm(3*N),
        #     Conv1D(3*N, O, 1),
        #     ResBlock(O, O),
        #     ResBlock(O, P),
        #     ResBlock(P, P),
        #     Conv1D(P, spk_embed_dim, 1),
        # )

        # self.pred_linear = nn.Linear(spk_embed_dim, num_spks)

        # Switched to a pretrained speaker model.
        # (Same question as above: the frequency domain has no real multi-time-scale notion,
        #  so the speaker branch takes the spectrogram directly; what about the speech branch?)
        # /work105/youzhenghai/model/resnet_asp_aam_adamw_welr
        # import ..sunine/trainer/speaker encoder
        # No need to worry about **kwargs; just locate self.hparams and adapt it following main_infer
        #############################################################

        # # 1. Acoustic Feature
        # self.mel_trans = th.nn.Sequential(
        #     PreEmphasis(),
        #     torchaudio.transforms.MelSpectrogram(sample_rate=self.sample_rate, n_fft=512,
        #         win_length=400, hop_length=160, window_fn=th.hamming_window, n_mels=self.n_mels)
        # )

        # self.instancenorm = nn.InstanceNorm1d(self.n_mels)

        # # Set the hyperparameters at the call site; remember to pass them in as arguments later
        # self.hparams = {'embedding_dim': spk_embed_dim, 'pooling_type': 'ASP', 'n_mels': self.n_mels}
        # # Call the function with **self.hparams
        # self.speaker_encoder = Speaker_Encoder(**self.hparams)
        self.speaker_embedding_extracter = Speaker_Model(pooling_type='ASP', spk_embed_dim=spk_embed_dim, sample_rate=self.sample_rate, n_mels=self.n_mels)
        self.pred_linear = nn.Linear(spk_embed_dim, num_spks)

        #############################################################

        # # 3. Loss / Classifier
        # if not self.hparams.evaluate:
        #     LossFunction = importlib.import_module('trainer.loss.'+self.hparams.loss_type).__getattribute__('LossFunction')
        #     self.loss = LossFunction(**dict(self.hparams))


    def _build_stacks(self, num_blocks, **block_kwargs):
        """
        Stack B TCN blocks; the first TCN block (built separately) takes the speaker embedding
        """
        blocks = [
            TCNBlock(**block_kwargs, dilation=(2**b))
            for b in range(1, num_blocks)
        ]
        return nn.Sequential(*blocks)
        # Watch the dimension order: is it B x N x T or B x T x N?


    def forward(self, x, aux, aux_len):
        if x.dim() >= 3:
            raise RuntimeError(
                "{} accepts 1/2D tensor as input, but got {:d}".format(
                    type(self).__name__, x.dim()))
        # when inference, only one utt
        if x.dim() == 1:
            x = th.unsqueeze(x, 0)

        # n x 1 x S => n x N x T
        w1 = F.relu(self.encoder_1d_short(x))
        T = w1.shape[-1]
        xlen1 = x.shape[-1]
        xlen2 = (T - 1) * (self.L1 // 2) + self.L2
        xlen3 = (T - 1) * (self.L1 // 2) + self.L3
        w2 = F.relu(self.encoder_1d_middle(F.pad(x, (0, xlen2 - xlen1), "constant", 0)))
        w3 = F.relu(self.encoder_1d_long(F.pad(x, (0, xlen3 - xlen1), "constant", 0)))

        # n x 3N x T
        y = self.ln(th.cat([w1, w2, w3], 1))
        # n x O x T
        y = self.proj(y)

        # speaker encoder (share params from speech encoder)
        # aux_w1 = F.relu(self.encoder_1d_short(aux))
        # aux_T_shape = aux_w1.shape[-1]
        # aux_len1 = aux.shape[-1]
        # aux_len2 = (aux_T_shape - 1) * (self.L1 // 2) + self.L2
        # aux_len3 = (aux_T_shape - 1) * (self.L1 // 2) + self.L3
        # aux_w2 = F.relu(self.encoder_1d_middle(F.pad(aux, (0, aux_len2 - aux_len1), "constant", 0)))
        # aux_w3 = F.relu(self.encoder_1d_long(F.pad(aux, (0, aux_len3 - aux_len1), "constant", 0)))

        # spk_encoder + mean pooling
        # aux = self.spk_encoder(th.cat([aux_w1, aux_w2, aux_w3], 1))
        # aux_T = (aux_len - self.L1) // (self.L1 // 2) + 1
        # aux_T = ((aux_T // 3) // 3) // 3
        # aux = th.sum(aux, -1)/aux_T.view(-1,1).float()

        # spk_encoder + TAP pooling
        aux = self.speaker_embedding_extracter(aux)

        # aux = torch.mean(aux, axis=0)

        # aux = aux.cpu().detach().numpy()

        # No reshape needed; N x D is already the correct shape
        # aux = aux.reshape(-1, self.hparams.nPerSpeaker, self.spk_embed_dim)
        # loss, acc = self.loss(x, label)
        # return loss.mean(), acc
        # Consider whether the loss should be handled here as well

        y = self.conv_block_1(y, aux)
        y = self.conv_block_1_other(y)
        y = self.conv_block_2(y, aux)
        y = self.conv_block_2_other(y)
        y = self.conv_block_3(y, aux)
        y = self.conv_block_3_other(y)
        y = self.conv_block_4(y, aux)
        y = self.conv_block_4_other(y)

        # n x N x T
        m1 = F.relu(self.mask1(y))
        m2 = F.relu(self.mask2(y))
        m3 = F.relu(self.mask3(y))
        S1 = w1 * m1
        S2 = w2 * m2
        S3 = w3 * m3

        return self.decoder_1d_short(S1), self.decoder_1d_middle(S2)[:, :xlen1], self.decoder_1d_long(S3)[:, :xlen1], self.pred_linear(aux)

class PreEmphasis(th.nn.Module):
    def __init__(self, coef: float = 0.97):
        super().__init__()
        self.coef = coef
        # make kernel
        # In PyTorch, the convolution operation is cross-correlation, so the filter is flipped.
        self.register_buffer(
            'flipped_filter', th.FloatTensor([-self.coef, 1.]).unsqueeze(0).unsqueeze(0)
        )

    def forward(self, inputs: th.Tensor) -> th.Tensor:
        assert len(inputs.size()) == 2, 'The number of dimensions of inputs tensor must be 2!'
        # reflect padding to match lengths of in/out
        inputs = inputs.unsqueeze(1)
        inputs = F.pad(inputs, (1, 0), 'reflect')
        return F.conv1d(inputs, self.flipped_filter).squeeze(1)


class Speaker_Model(nn.Module):
    # class Speaker_Model(LightningModule):
    def __init__(self, pooling_type, spk_embed_dim, sample_rate, n_mels):
        super().__init__()
        # self.save_hyperparameters()

        self.pooling_type = pooling_type
        self.spk_embed_dim = spk_embed_dim
        self.sample_rate = sample_rate
        self.n_mels = n_mels
        sr = self.sample_rate

        self.mel_trans = th.nn.Sequential(
            PreEmphasis(),
            torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=512,
                                                 win_length=sr * 25 // 1000, hop_length=sr * 10 // 1000,
                                                 window_fn=th.hamming_window, n_mels=self.n_mels)
        )
        self.instancenorm = nn.InstanceNorm1d(self.n_mels)

        self.hparams = {'embedding_dim': self.spk_embed_dim, 'pooling_type': self.pooling_type, 'n_mels': self.n_mels}

        self.speaker_encoder = Speaker_Encoder(**dict(self.hparams))

    def extract_speaker_embedding(self, data):
        x = data.reshape(-1, data.size()[-1])
        x = self.mel_trans(x) + 1e-6
        x = x.log()
        x = self.instancenorm(x)
        x = self.speaker_encoder(x)
        return x

    def forward(self, x):
        x = self.extract_speaker_embedding(x)
        return x
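
A rough shape check of the full model under the configuration app.py uses (40/160/320-sample windows at 16 kHz); a sketch, assuming a 4 s mixture and a 2 s enrollment utterance:

import torch as th
from nnet.spex_plus import SpEx_Plus

net = SpEx_Plus(L1=40, L2=160, L3=320, N=256, B=8, O=256, P=512, Q=3,
                num_spks=395, spk_embed_dim=256, causal=False)
net.eval()   # BatchNorm layers need eval mode for a single-utterance pass
mix = th.randn(1, 64000)        # 4 s mixture at 16 kHz
aux = th.randn(1, 32000)        # 2 s enrollment utterance
aux_len = th.tensor([32000.])
with th.no_grad():
    s1, s2, s3, spk_logits = net(mix, aux, aux_len)
# s1/s2/s3: (1, 64000) short/middle/long-window reconstructions; spk_logits: (1, 395)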
requirement.txt ADDED
@@ -0,0 +1,5 @@
torch==1.8.0
torchaudio==0.8.0
speechbrain==0.5.10
soundfile
gradio
utils/__init__.py ADDED
File without changes
utils/audio.py ADDED
#!/usr/bin/env python

import os
import numpy as np
import soundfile as sf

def write_wav(fname, samps, sample_rate=16000, normalize=True):
    """
    Write wav files in float32, support single/multi-channel
    """

    # wham and whamr mixture and clean data are float32, so we cannot use scipy.io.wavfile
    # to read and write int16; switched to soundfile, which reads int16 reference speech
    # and still outputs floats
    fdir = os.path.dirname(fname)
    if fdir and not os.path.exists(fdir):
        os.makedirs(fdir)
    sf.write(fname, samps, sample_rate, subtype='FLOAT')


def read_wav(fname, normalize=True, return_rate=False):
    """
    Read wave files (support multi-channel)
    """

    # (same soundfile rationale as write_wav above)
    samps, samp_rate = sf.read(fname)
    if return_rate:
        return samp_rate, samps
    return samps

def parse_scripts(scp_path, value_processor=lambda x: x, num_tokens=2):
    """
    Parse kaldi's script(.scp) file
    If num_tokens >= 2, function will check token number
    """
    scp_dict = dict()
    line = 0
    with open(scp_path, "r") as f:
        for raw_line in f:
            scp_tokens = raw_line.strip().split()
            line += 1
            if num_tokens >= 2 and len(scp_tokens) != num_tokens or len(
                    scp_tokens) < 2:
                raise RuntimeError(
                    "For {}, format error in line[{:d}]: {}".format(
                        scp_path, line, raw_line))
            if num_tokens == 2:
                key, value = scp_tokens
            else:
                key, value = scp_tokens[0], scp_tokens[1:]
            if key in scp_dict:
                raise ValueError("Duplicated key \'{0}\' exists in {1}".format(
                    key, scp_path))
            scp_dict[key] = value_processor(value)
    return scp_dict


class Reader(object):
    """
    Basic Reader Class
    """

    def __init__(self, scp_path, value_processor=lambda x: x):
        self.index_dict = parse_scripts(
            scp_path, value_processor=value_processor, num_tokens=2)
        self.index_keys = list(self.index_dict.keys())

    def _load(self, key):
        # return path
        return self.index_dict[key]

    # number of utterances
    def __len__(self):
        return len(self.index_dict)

    # avoid key error
    def __contains__(self, key):
        return key in self.index_dict

    # sequential index
    def __iter__(self):
        for key in self.index_keys:
            yield key, self._load(key)

    # random index, support str/int as index
    def __getitem__(self, index):
        if type(index) not in [int, str]:
            raise IndexError("Unsupported index type: {}".format(type(index)))
        if type(index) == int:
            # from int index to key
            num_utts = len(self.index_keys)
            if index >= num_utts or index < 0:
                raise KeyError(
                    "Integer index out of range, {:d} vs {:d}".format(
                        index, num_utts))
            index = self.index_keys[index]
        if index not in self.index_dict:
            raise KeyError("Missing utterance {}!".format(index))
        return self._load(index)


class WaveReader(Reader):
    """
    Sequential/Random Reader for single channel wave
    Format of wav.scp follows Kaldi's definition:
        key1 /path/to/wav
        ...
    """

    def __init__(self, wav_scp, sample_rate=None, normalize=True):
        super(WaveReader, self).__init__(wav_scp)
        self.samp_rate = sample_rate
        self.normalize = normalize

    def _load(self, key):
        # return C x N or N
        samp_rate, samps = read_wav(
            self.index_dict[key], normalize=self.normalize, return_rate=True)
        # if given samp_rate, check it
        if self.samp_rate is not None and samp_rate != self.samp_rate:
            raise RuntimeError("SampleRate mismatch: {:d} vs {:d}".format(
                samp_rate, self.samp_rate))
        return samps
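
The readers above consume Kaldi-style .scp files, one "key path" pair per line. A usage sketch with a hypothetical wav.scp:

# wav.scp (hypothetical):
#   utt1 /data/wav/utt1.wav
#   utt2 /data/wav/utt2.wav
reader = WaveReader("wav.scp", sample_rate=16000)
print(len(reader))          # number of utterances
samps = reader["utt1"]      # random access by key (or integer index)
for key, samps in reader:   # sequential iteration
    print(key, samps.shape)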
utils/dataset copy.py ADDED
@@ -0,0 +1,284 @@
#!/usr/bin/env python

import random
import torch as th
import numpy as np

from torch.utils.data.dataloader import default_collate
import torch.utils.data as dat
from torch.nn.utils.rnn import pad_sequence

from .audio import WaveReader

import soundfile as sf

# random_seed = 1453
# random.seed(random_seed)

def make_dataloader(train=True,
                    utt_scp_file=None,
                    spk_list=None,
                    sample_rate=16000,
                    num_workers=4,
                    chunk_size=32000,
                    batch_size=16):
    dataset = Dataset(utt_scp_file=utt_scp_file,
                      spk_list=spk_list,
                      chunk_size=chunk_size,
                      sample_rate=sample_rate)
    return DataLoader(dataset,
                      train=train,
                      chunk_size=chunk_size,
                      batch_size=batch_size,
                      num_workers=num_workers)

class Dataset(object):
    """
    Per Utterance Loader
    """
    def __init__(self, utt_scp_file="", spk_list=None, chunk_size=32000, sample_rate=8000):
        self.sample_rate = sample_rate
        self.spk_list = self._load_spk(spk_list)

        self.seg_least = int(chunk_size // 2)

        # self.mix = WaveReader(mix_scp, sample_rate=sample_rate)
        # self.ref = WaveReader(ref_scp, sample_rate=sample_rate)
        # self.aux = WaveReader(aux_scp, sample_rate=sample_rate)

        with open(utt_scp_file, 'r') as f:
            lines = f.readlines()
        self.data = []
        self.total_lines = len(self.data)
        for line in lines:
            parts = line.strip().split()
            sentence_id = parts[0]
            sentence_path = parts[1]
            data_len = parts[2]
            spk_id = (sentence_id.split('-')[0])[1:5]
            self.data.append((sentence_id, spk_id, sentence_path, data_len))

        if not self.data:
            raise ValueError("No valid lines found in the input file.")
        self.total_lines = len(self.data)

    def _load_spk(self, spk_list_path):
        if spk_list_path is None:
            return []
        lines = open(spk_list_path).readlines()
        new_lines = []
        for line in lines:
            new_lines.append(line.strip())

        return new_lines

    def __len__(self):
        return len(self.data)

    def _get_segment_start_stop(self, seg_len, length):
        if seg_len is not None:
            start = random.randint(0, length - seg_len)
            stop = start + seg_len
        else:
            start = 0
            stop = None
        return start, stop

    def _mix(self, sources_list):

        # if self.seg_len:
        #     mix_length = self.seg_len
        # else:
        #     mix_length = self.common_length
        mix_length = self.common_length
        mixture = np.zeros(mix_length)
        for i, _ in enumerate(sources_list):
            mixture += sources_list[i]

        return mixture

    def __getitem__(self, idx):
        source_id, source_spk, source_path, all_source_length = self.data[idx]
        all_source_length = int(all_source_length)
        spk_idx = self.spk_list.index(source_spk)

        other_counter = 0
        while True:
            random_idx = np.random.randint(0, self.total_lines)
            if self.data[random_idx][1] != source_spk:
                other_id, other_spk, other_path, other_length = self.data[random_idx]
                other_length = int(other_length)

                if other_length > self.seg_least:
                    break

            other_counter += 1

            if other_counter >= self.total_lines:
                raise ValueError("All data is too short to mix")

        enroll_counter = 0

        while True:
            random_idx = np.random.randint(0, self.total_lines)
            if self.data[random_idx][1] == source_spk:
                enroll_id, enroll_spk, enroll_path, all_enroll_length = self.data[random_idx]
                all_enroll_length = int(all_enroll_length)
                if all_enroll_length > self.seg_least:
                    break

            enroll_counter += 1
            if enroll_counter >= self.total_lines:
                raise ValueError("All data is too short to enroll")
        # lengths = [all_source_length, other_length]

        if all_source_length >= other_length:
            self.common_length = other_length
            start, stop = self._get_segment_start_stop(other_length, all_source_length)
            source_tmp, _ = sf.read(source_path, dtype="float32", start=start, stop=stop)
            other_tmp, _ = sf.read(other_path, dtype="float32")
        elif all_source_length <= other_length:
            self.common_length = all_source_length
            start, stop = self._get_segment_start_stop(all_source_length, other_length)
            source_tmp, _ = sf.read(source_path, dtype="float32")
            other_tmp, _ = sf.read(other_path, dtype="float32", start=start, stop=stop)

        source = source_tmp[:, np.random.randint(0, source_tmp.shape[1])]

        other = other_tmp[:, np.random.randint(0, other_tmp.shape[1])]

        mixture = self._mix([source, other])
        mixture = mixture.astype(np.float32)

        enroll_tmp, _ = sf.read(enroll_path, dtype="float32")
        enroll = enroll_tmp[:, np.random.randint(0, enroll_tmp.shape[1])]

        return {
            "mix": mixture,
            "ref": source,
            "aux": enroll,
            "aux_len": len(enroll),
            "spk_idx": spk_idx
        }

class ChunkSplitter(object):
    """
    Split utterance into small chunks
    """
    def __init__(self, chunk_size, train=True, least=16000):
        self.chunk_size = chunk_size
        self.least = least
        self.train = train

    def _make_chunk(self, eg, s):
        """
        Make a chunk instance, which contains:
            "mix": ndarray,
            "ref": [ndarray...]
        """
        chunk = dict()
        chunk["mix"] = eg["mix"][s:s + self.chunk_size]
        chunk["ref"] = eg["ref"][s:s + self.chunk_size]
        chunk["aux"] = eg["aux"]
        chunk["aux_len"] = eg["aux_len"]
        chunk["valid_len"] = int(self.chunk_size)
        chunk["spk_idx"] = eg["spk_idx"]
        return chunk

    def split(self, eg):
        N = eg["mix"].size
        # too short, throw away
        if N < self.least:
            return []
        chunks = []
        # padding zeros
        if N < self.chunk_size:
            P = self.chunk_size - N
            chunk = dict()
            chunk["mix"] = np.pad(eg["mix"], (0, P), "constant")
            chunk["ref"] = np.pad(eg["ref"], (0, P), "constant")
            chunk["aux"] = eg["aux"]
            chunk["aux_len"] = eg["aux_len"]
            chunk["valid_len"] = int(N)
            chunk["spk_idx"] = eg["spk_idx"]
            chunks.append(chunk)
        else:
            # randomly select a start point for training
            s = random.randint(0, N % self.least) if self.train else 0
            while True:
                if s + self.chunk_size > N:
                    break
                chunk = self._make_chunk(eg, s)
                chunks.append(chunk)
                s += self.least
        return chunks


class DataLoader(object):
    """
    Online dataloader for chunk-level
    """
    def __init__(self,
                 dataset,
                 num_workers=4,
                 chunk_size=32000,
                 batch_size=16,
                 train=True):
        self.batch_size = batch_size
        self.train = train
        self.splitter = ChunkSplitter(chunk_size,
                                      train=train,
                                      least=chunk_size // 2)
        # just return batch of egs, support multiple workers
        self.eg_loader = dat.DataLoader(dataset,
                                        batch_size=batch_size // 2,
                                        num_workers=num_workers,
                                        shuffle=train,
                                        collate_fn=self._collate)

    def _collate(self, batch):
        """
        Online split utterances
        """
        chunk = []
        for eg in batch:
            chunk += self.splitter.split(eg)
        return chunk

    def _pad_aux(self, chunk_list):
        lens_list = []
        for chunk_item in chunk_list:
            lens_list.append(chunk_item['aux_len'])
        max_len = np.max(lens_list)

        for idx in range(len(chunk_list)):
            P = max_len - len(chunk_list[idx]["aux"])
            chunk_list[idx]["aux"] = np.pad(chunk_list[idx]["aux"], (0, P), "constant")

        return chunk_list

    def _merge(self, chunk_list):
        """
        Merge chunk list into mini-batch
        """
        N = len(chunk_list)
        if self.train:
            random.shuffle(chunk_list)
        blist = []
        for s in range(0, N - self.batch_size + 1, self.batch_size):
            # padding aux info
            # self._pad_aux(chunk_list[s:s + self.batch_size])
            batch = default_collate(self._pad_aux(chunk_list[s:s + self.batch_size]))
            blist.append(batch)
        rn = N % self.batch_size
        return blist, chunk_list[-rn:] if rn else []

    def __iter__(self):
        chunk_list = []
        for chunks in self.eg_loader:
            chunk_list += chunks
            batch, chunk_list = self._merge(chunk_list)
            for obj in batch:
                yield obj
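
ChunkSplitter drops utterances shorter than least, zero-pads anything shorter than chunk_size, and otherwise slides a chunk_size window forward by least samples. A small sketch of the sliding case (eval-style, train=False):

import numpy as np

splitter = ChunkSplitter(chunk_size=32000, train=False, least=16000)
eg = {"mix": np.zeros(50000, dtype=np.float32),
      "ref": np.zeros(50000, dtype=np.float32),
      "aux": np.zeros(24000, dtype=np.float32),
      "aux_len": 24000,
      "spk_idx": 0}
chunks = splitter.split(eg)
print(len(chunks))   # 2 chunks, starting at samples 0 and 16000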
utils/dataset.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ import random
4
+ import torch as th
5
+ import numpy as np
6
+
7
+ from torch.utils.data.dataloader import default_collate
8
+ import torch.utils.data as dat
9
+ from torch.nn.utils.rnn import pad_sequence
10
+
11
+ from .audio import WaveReader
12
+
13
+ import soundfile as sf
14
+
15
+ # random_seed = 1453
16
+ # random.seed(random_seed)
17
+
18
+ # "aux_len": all_enroll_length,
19
+
20
+ EPS = 1e-10
21
+ def make_dataloader(train=True,
22
+ mix_scp_file=None,
23
+ enroll_scp_file=None,
24
+ noise_scp_file=None,
25
+ spk_list=None,
26
+ sample_rate=16000,
27
+ num_workers=4,
28
+ chunk_size=32000,
29
+ batch_size=16):
30
+ dataset = Dataset(mix_scp_file=mix_scp_file,
31
+ enroll_scp_file=enroll_scp_file,
32
+ noise_scp_file=noise_scp_file,
33
+ spk_list=spk_list,
34
+ chunk_size=chunk_size,
35
+ sample_rate=sample_rate)
36
+ return DataLoader(dataset,
37
+ train=train,
38
+ chunk_size=chunk_size,
39
+ batch_size=batch_size,
40
+ num_workers=num_workers)
41
+
42
+ class Dataset(object):
43
+ """
44
+ Per Utterance Loader
45
+ """
46
+ def __init__(self, mix_scp_file="", enroll_scp_file="", noise_scp_file="", spk_list=None,chunk_size=32000, sample_rate=8000):
47
+ self.sample_rate = sample_rate
48
+ self.spk_list = self._load_spk(spk_list)
49
+
50
+ self.seg_least= int(chunk_size // 2 )
51
+
52
+ with open(mix_scp_file, 'r') as f:
53
+ lines = f.readlines()
54
+ self.data = []
55
+
56
+
57
+ for line in lines:
58
+ parts = line.strip().split()
59
+ sentence_id = parts[0]
60
+ sentence_path = parts[1]
61
+ data_len = parts[2]
62
+ spk_id = (sentence_id.split('-')[0])[1:5]
63
+ self.data.append((sentence_id, spk_id, sentence_path, data_len))
64
+
65
+ with open(enroll_scp_file, 'r') as f:
66
+ enroll_lines = f.readlines()
67
+ self.enroll_data = []
68
+
69
+ for line in enroll_lines:
70
+ parts = line.strip().split()
71
+ sentence_id = parts[0]
72
+ sentence_path = parts[1]
73
+ data_len = parts[2]
74
+ spk_id = (sentence_id.split('-')[0])[1:5]
75
+ self.enroll_data.append((sentence_id, spk_id, sentence_path, data_len))
76
+
77
+
78
+ with open(noise_scp_file, 'r') as f:
79
+ noise_lines = f.readlines()
80
+ self.noise_data = []
81
+
82
+ for line in noise_lines:
83
+ parts = line.strip().split()
84
+ sentence_id = parts[0]
85
+ sentence_path = parts[1]
86
+ data_len = parts[2]
87
+ # spk_id = (sentence_id.split('-')[0])[1:5]
88
+ self.noise_data.append((sentence_id, sentence_path, data_len))
89
+
90
+ self.total_lines = len(self.data)
91
+ self.total_enroll = self._enroll_data_len()
92
+ self.total_noise = self._noise_data_len()
93
+
94
+ if not self.data:
95
+ raise ValueError("No valid lines found in the input file.")
96
+
97
+
98
+ def _load_spk(self, spk_list_path):
99
+ if spk_list_path is None:
100
+ return []
101
+ lines = open(spk_list_path).readlines()
102
+ new_lines = []
103
+ for line in lines:
104
+ new_lines.append(line.strip())
105
+
106
+ return new_lines
107
+
108
+ def __len__(self):
109
+ return len(self.data)
110
+
111
+ def _enroll_data_len(self):
112
+ return len(self.enroll_data)
113
+
114
+ def _noise_data_len(self):
115
+ return len(self.noise_data)
116
+
117
+ def _get_segment_start_stop(self, seg_len, length):
118
+ if seg_len is not None:
119
+ start = random.randint(0, length - seg_len)
120
+ stop = start + seg_len
121
+ else:
122
+ start = 0
123
+ stop = None
124
+ return start, stop
125
+
126
+ def _mix(self, sources_list):
127
+
128
+ # if self.seg_len:
129
+ # mix_length = self.seg_len
130
+
131
+ # else:
132
+ # mix_length = self.common_length
133
+ mix_length = self.common_length
134
+ mixture = np.zeros(mix_length)
135
+ for i, _ in enumerate(sources_list):
136
+ mixture += sources_list[i]
137
+
138
+ return mixture
139
+
+     def __getitem__(self, idx):
+         source_id, source_spk, source_path, all_source_length = self.data[idx]
+         all_source_length = int(all_source_length)
+         spk_idx = self.spk_list.index(source_spk)
+
+         # randomly pick an interfering utterance from a different speaker
+         other_counter = 0
+         while True:
+             random_idx = np.random.randint(0, self.total_lines)
+             if self.data[random_idx][1] != source_spk:
+                 other_id, other_spk, other_path, other_length = self.data[random_idx]
+                 other_length = int(other_length)
+                 if other_length > self.seg_least:
+                     break
+             other_counter += 1
+             if other_counter >= self.total_lines:
+                 raise ValueError("All candidate utterances are too short to mix")
+
+         # crop the longer of the two signals so both share a common length
+         if all_source_length >= other_length:
+             self.common_length = other_length
+             start, stop = self._get_segment_start_stop(self.common_length, all_source_length)
+             source_tmp, _ = sf.read(source_path, dtype="float32", start=start, stop=stop)
+             other_tmp, _ = sf.read(other_path, dtype="float32")
+         else:
+             self.common_length = all_source_length
+             start, stop = self._get_segment_start_stop(self.common_length, other_length)
+             source_tmp, _ = sf.read(source_path, dtype="float32")
+             other_tmp, _ = sf.read(other_path, dtype="float32", start=start, stop=stop)
+
+         # pick a noise clip at least as long as the common length
+         noise_counter = 0
+         while True:
+             random_idx = np.random.randint(0, self.total_noise)
+             noise_id, noise_path, all_noise_length = self.noise_data[random_idx]
+             all_noise_length = int(all_noise_length)
+             if all_noise_length >= self.common_length:
+                 break
+             noise_counter += 1
+             if noise_counter >= self.total_noise:
+                 raise ValueError("No noise clip is long enough for this mixture")
+
+         # pick an enrollment utterance from the same (target) speaker
+         enroll_counter = 0
+         while True:
+             random_idx = np.random.randint(0, self.total_enroll)
+             if self.enroll_data[random_idx][1] == source_spk:
+                 enroll_id, enroll_spk, enroll_path, all_enroll_length = self.enroll_data[random_idx]
+                 all_enroll_length = int(all_enroll_length)
+                 break
+             enroll_counter += 1
+             if enroll_counter >= self.total_enroll:
+                 raise ValueError("No enrollment utterance found for this speaker")
+
+         # randomly select one channel of the multi-channel recordings
+         source = source_tmp[:, np.random.randint(0, source_tmp.shape[1])]
+         other = other_tmp[:, np.random.randint(0, other_tmp.shape[1])]
+
+         noise_start, noise_stop = self._get_segment_start_stop(self.common_length, all_noise_length)
+         noise, _ = sf.read(noise_path, dtype="float32", start=noise_start, stop=noise_stop)  # single channel?
+
+         # rescale the noise so the source-to-noise SNR hits a random target in [-4, 4] dB
+         desired_snr = np.random.uniform(-4, 4)
+         current_snr = 10 * np.log10(np.mean(source ** 2) / (np.mean(noise ** 2) + EPS) + EPS)
+         scale_factor = 10 ** ((current_snr - desired_snr) / 20)
+         scaled_noise = noise * scale_factor
+
+         mixture = self._mix([source, other, scaled_noise]).astype(np.float32)
+
+         enroll_tmp, _ = sf.read(enroll_path, dtype="float32")
+         enroll = enroll_tmp[:, np.random.randint(0, enroll_tmp.shape[1])]
+
+         return {
+             "mix": mixture,
+             "ref": source,
+             "aux": enroll,
+             "aux_len": all_enroll_length,
+             "spk_idx": spk_idx
+         }
+
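The noise scaling above works because SNR is a power ratio: multiplying the noise by g shifts the SNR by -20 * log10(g) dB, so g = 10 ** ((current_snr - desired_snr) / 20) lands exactly on the target. A quick self-contained check with synthetic signals (not the dataset's audio):

import numpy as np

EPS = 1e-8  # same role as the module-level EPS used above
rng = np.random.default_rng(0)
source = rng.standard_normal(16000)
noise = rng.standard_normal(16000)

desired_snr = 2.0
current_snr = 10 * np.log10(np.mean(source ** 2) / (np.mean(noise ** 2) + EPS) + EPS)
scaled_noise = noise * 10 ** ((current_snr - desired_snr) / 20)
print(10 * np.log10(np.mean(source ** 2) / np.mean(scaled_noise ** 2)))  # ~2.0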
+ class ChunkSplitter(object):
+     """
+     Split utterances into fixed-size chunks
+     """
+     def __init__(self, chunk_size, train=True, least=16000):
+         self.chunk_size = chunk_size
+         self.least = least
+         self.train = train
+
+     def _make_chunk(self, eg, s):
+         """
+         Make a chunk instance starting at sample s, containing
+         "mix"/"ref" ndarray slices plus "aux", "aux_len",
+         "valid_len" and "spk_idx"
+         """
+         chunk = dict()
+         chunk["mix"] = eg["mix"][s:s + self.chunk_size]
+         chunk["ref"] = eg["ref"][s:s + self.chunk_size]
+         chunk["aux"] = eg["aux"]
+         chunk["aux_len"] = eg["aux_len"]
+         chunk["valid_len"] = int(self.chunk_size)
+         chunk["spk_idx"] = eg["spk_idx"]
+         return chunk
+
+     def split(self, eg):
+         N = eg["mix"].size
+         # too short: throw the utterance away
+         if N < self.least:
+             return []
+         chunks = []
+         if N < self.chunk_size:
+             # shorter than one chunk: pad with zeros
+             P = self.chunk_size - N
+             chunk = dict()
+             chunk["mix"] = np.pad(eg["mix"], (0, P), "constant")
+             chunk["ref"] = np.pad(eg["ref"], (0, P), "constant")
+             chunk["aux"] = eg["aux"]
+             chunk["aux_len"] = eg["aux_len"]
+             chunk["valid_len"] = int(N)
+             chunk["spk_idx"] = eg["spk_idx"]
+             chunks.append(chunk)
+         else:
+             if self.train:
+                 # randomly select a start point for training
+                 s = random.randint(0, N - self.chunk_size)
+                 chunks.append(self._make_chunk(eg, s))
+             else:
+                 # deterministic sliding window (hop = least) for evaluation
+                 s = 0
+                 while s + self.chunk_size <= N:
+                     chunks.append(self._make_chunk(eg, s))
+                     s += self.least
+         return chunks
+
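A small usage sketch of ChunkSplitter on a synthetic example dict (the array sizes are arbitrary):

import numpy as np

eg = {
    "mix": np.zeros(40000, dtype=np.float32),
    "ref": np.zeros(40000, dtype=np.float32),
    "aux": np.zeros(24000, dtype=np.float32),
    "aux_len": 24000,
    "spk_idx": 7,
}
splitter = ChunkSplitter(chunk_size=32000, train=True, least=16000)
chunks = splitter.split(eg)
print(len(chunks), chunks[0]["valid_len"])  # 1 32000: one random chunk in train mode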
+ class DataLoader(object):
+     """
+     Online chunk-level dataloader
+     """
+     def __init__(self,
+                  dataset,
+                  num_workers=4,
+                  chunk_size=32000,
+                  batch_size=16,
+                  train=True):
+         self.batch_size = batch_size
+         self.train = train
+         self.splitter = ChunkSplitter(chunk_size,
+                                       train=train,
+                                       least=chunk_size // 2)
+         # eg-level loader: just returns batches of egs, supports multiple workers
+         self.eg_loader = dat.DataLoader(dataset,
+                                         batch_size=batch_size // 2,
+                                         num_workers=num_workers,
+                                         shuffle=train,
+                                         collate_fn=self._collate)
+
+     def _collate(self, batch):
+         """
+         Split utterances into chunks online
+         """
+         chunk = []
+         for eg in batch:
+             chunk += self.splitter.split(eg)
+         return chunk
+
+     def _pad_aux(self, chunk_list):
+         # zero-pad every enrollment ("aux") signal to the longest one in the batch
+         lens_list = [chunk_item["aux_len"] for chunk_item in chunk_list]
+         max_len = np.max(lens_list)
+         for idx in range(len(chunk_list)):
+             P = max_len - len(chunk_list[idx]["aux"])
+             chunk_list[idx]["aux"] = np.pad(chunk_list[idx]["aux"], (0, P), "constant")
+         # (an alternative is to repeat the enrollment circularly instead of zero-padding)
+         return chunk_list
+
+     def _merge(self, chunk_list):
+         """
+         Merge the chunk list into mini-batches; leftover chunks are returned
+         so they can be carried over to the next iteration
+         """
+         N = len(chunk_list)
+         if self.train:
+             random.shuffle(chunk_list)
+         blist = []
+         for s in range(0, N - self.batch_size + 1, self.batch_size):
+             # pad aux info, then collate the chunk dicts into tensors
+             batch = default_collate(self._pad_aux(chunk_list[s:s + self.batch_size]))
+             blist.append(batch)
+         rn = N % self.batch_size
+         return blist, chunk_list[-rn:] if rn else []
+
+     def __iter__(self):
+         chunk_list = []
+         for chunks in self.eg_loader:
+             chunk_list += chunks
+             batch, chunk_list = self._merge(chunk_list)
+             for obj in batch:
+                 yield obj
+         # note: chunks left over at the end of the epoch are dropped
+
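Assuming the Dataset class defined earlier in this file, a typical training loop over this loader would look roughly like:

train_loader = DataLoader(dataset,  # the Dataset instance defined above
                          num_workers=4,
                          chunk_size=32000,
                          batch_size=16,
                          train=True)
for batch in train_loader:
    # batch is a default_collate'd dict with keys "mix", "ref", "aux",
    # "aux_len", "valid_len" and "spk_idx", each carrying a batch dimension
    mix, ref, aux = batch["mix"], batch["ref"], batch["aux"]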
+ # def snr_xy(x, y):
+ #     return 10 * np.log10(np.mean(x ** 2) / (np.mean(y ** 2) + EPS) + EPS)
+
+ # (apparently leftover, commented-out snippet from a WHAM noise augmentation script)
+ # def main(args):
+ #     wham_noise_dir = args.wham_dir
+ #     # Get train dir
+ #     subdir = os.path.join(wham_noise_dir, 'tr')
+ #     # List files in that dir
+ #     sound_paths = glob.glob(os.path.join(subdir, '**/*.wav'), recursive=True)
+ #     # Avoid running this script if it has already been run
+ #     if len(sound_paths) == 60000:
+ #         print("It appears that augmented files have already been generated.\n"
+ #               "Skipping data augmentation.")
+ #         return
+ #     elif len(sound_paths) != 20000:
+ #         print("It appears that augmented files have not been generated properly.\n"
+ #               "Resuming augmentation.")
+ #         originals = [x for x in sound_paths if 'sp' not in x]
+ #         to_be_removed_08 = [x.replace('sp08', '') for x in sound_paths if 'sp08' in x]
+ #         to_be_removed_12 = [x.replace('sp12', '') for x in sound_paths if 'sp12' in x]
+ #         sound_paths_08 = list(set(originals) - set(to_be_removed_08))
+ #         sound_paths_12 = list(set(originals) - set(to_be_removed_12))
+ #         augment_noise(sound_paths_08, 0.8)
+ #         augment_noise(sound_paths_12, 1.2)
+ #     else:
+ #         print(f'Augmenting {subdir} files')
+ #         # Transform audio speed
+ #         augment_noise(sound_paths, 0.8)
+ #         augment_noise(sound_paths, 1.2)
utils/load_obj.py ADDED
@@ -0,0 +1,18 @@
+ #!/usr/bin/env python
+
+ import torch as th
+
+ def load_obj(obj, device):
+     """
+     Recursively move every tensor in obj (dict / list / tensor) to the given device
+     """
+
+     def cuda(obj):
+         return obj.to(device) if isinstance(obj, th.Tensor) else obj
+
+     if isinstance(obj, dict):
+         return {key: load_obj(obj[key], device) for key in obj}
+     elif isinstance(obj, list):
+         return [load_obj(val, device) for val in obj]
+     else:
+         return cuda(obj)
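For example, a collated batch (a nested dict of tensors) can be moved to the GPU in one call:

import torch as th

batch = {"mix": th.zeros(4, 32000), "meta": {"spk_idx": th.tensor([0, 1, 2, 3])}}
device = th.device("cuda:0" if th.cuda.is_available() else "cpu")
batch = load_obj(batch, device)  # tensors are moved, other leaves are untouched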
utils/logger.py ADDED
@@ -0,0 +1,22 @@
+ #!/usr/bin/env python
+
+ import logging
+
+ def get_logger(
+         name,
+         format_str="%(asctime)s [%(pathname)s:%(lineno)s - %(levelname)s ] %(message)s",
+         date_format="%Y-%m-%d %H:%M:%S",
+         file=False):
+     """
+     Get a python logger instance
+     """
+     logger = logging.getLogger(name)
+     logger.setLevel(logging.INFO)
+     # console handler, or a file handler using the logger name as the file path
+     handler = logging.StreamHandler() if not file else logging.FileHandler(name)
+     handler.setLevel(logging.INFO)
+     formatter = logging.Formatter(fmt=format_str, datefmt=date_format)
+     handler.setFormatter(formatter)
+     logger.addHandler(handler)
+     return logger
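Typical usage: pass __name__ for console logging, or set file=True, in which case the logger name doubles as the log file path:

logger = get_logger(__name__)
logger.info("training started")

file_logger = get_logger("train.log", file=True)  # writes to ./train.log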
utils/sisdr.py ADDED
@@ -0,0 +1,23 @@
+ #!/usr/bin/env python
+
+ import numpy as np
+
+ def sisdr(x, s, remove_dc=True):
+     """
+     Compute SI-SDR
+     x: extracted signal
+     s: reference signal (ground truth)
+     """
+
+     def vec_l2norm(x):
+         return np.linalg.norm(x, 2)
+
+     if remove_dc:
+         x_zm = x - np.mean(x)
+         s_zm = s - np.mean(s)
+         # t: projection of x onto s; n: residual (distortion + noise)
+         t = np.inner(x_zm, s_zm) * s_zm / vec_l2norm(s_zm) ** 2
+         n = x_zm - t
+     else:
+         t = np.inner(x, s) * s / vec_l2norm(s) ** 2
+         n = x - t
+     return 20 * np.log10(vec_l2norm(t) / vec_l2norm(n))
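A quick sanity check: this SI-SDR is scale-invariant, so rescaling the estimate leaves the score unchanged:

import numpy as np

rng = np.random.default_rng(0)
s = rng.standard_normal(16000)             # reference
x = s + 0.1 * rng.standard_normal(16000)   # noisy estimate
print(round(sisdr(x, s), 3))
print(round(sisdr(3.0 * x, s), 3))         # identical: scaling x cancels out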
utils/timer.py ADDED
@@ -0,0 +1,17 @@
+ #!/usr/bin/env python
+
+ import time
+
+ class Timer(object):
+     """
+     A timer to record the elapsed time
+     """
+
+     def __init__(self):
+         self.reset()
+
+     def reset(self):
+         self.start = time.time()
+
+     def elapsed(self):
+         # elapsed time in minutes
+         return (time.time() - self.start) / 60
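Usage is the obvious pattern; note that elapsed() reports minutes, not seconds:

timer = Timer()
# ... run an epoch ...
print("epoch took {:.2f} min".format(timer.elapsed()))
timer.reset()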