huazai676 commited on
Commit
ba80248
·
verified ·
1 Parent(s): 0adc7af

Upload 6 files

Browse files
StutterNet/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .models import *
2
+ from .io import *
3
+ from .losses import *
4
+ from .metrics import *
5
+ from .train import *
StutterNet/io.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import pandas as pd
4
+ import librosa
5
+ import torchaudio as audio
6
+
7
class SEP28KDataset(torch.utils.data.Dataset):
    """SEP-28k Dataset.

    Wraps pre-sliced waveform clips and their multi-label annotations,
    forcing every clip to exactly 3 s at 16 kHz (48000 samples).
    """

    def __init__(self, x, y, unsqueeze=False, transform=None):
        """
        Args:
            x (hdf5): hdf5 data one of 'Xtrain', 'Xtest', or 'Xvalid'
            y (hdf5): hdf5 file one of 'Ytrain', 'Ytest', or 'Yvalid'
            unsqueeze (bool, Optional): Whether or not to unsqueeze the feature.
                May be required for models that require image-like inputs.
            transform (callable, Optional): Optional transform to be applied
                on a sample.
        """
        self.data = x
        self.labels = y
        self.unsqueeze = unsqueeze
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        # load sliced clip and pad/truncate to exactly 3 s @ 16 kHz
        wav = self.data[idx]
        wav = self.pad_trunc(wav, 3000, 16000).astype('float32')
        wav = torch.tensor(wav)

        # get labels
        labels = self.labels[idx].astype('float32')

        if self.transform is not None:
            wav = self.transform(wav)

        if self.unsqueeze:
            # add a leading channel axis for image-like models
            wav = torch.unsqueeze(wav, 0)

        # BUG FIX: wav is already a tensor here; the original
        # torch.tensor(wav).clone().detach() re-wrapped it (which warns
        # and copies twice). clone().detach() alone is the supported idiom.
        return wav.clone().detach(), torch.tensor(labels)

    @staticmethod
    def pad_trunc(sig, max_ms, sr):
        """Pad (with zeros at a random offset) or truncate a 1-D signal
        to exactly ``sr // 1000 * max_ms`` samples.

        Args:
            sig (np.ndarray): 1-D waveform.
            max_ms (int): target duration in milliseconds.
            sr (int): sample rate in Hz.

        Returns:
            np.ndarray: signal of length ``sr // 1000 * max_ms``.
        """
        sig_len = sig.shape[0]
        max_len = sr // 1000 * max_ms

        if sig_len > max_len:
            # Truncate the signal to the given length.
            # BUG FIX: sig is 1-D (sig_len = sig.shape[0]); the original
            # 2-D slice sig[:, :max_len] raised IndexError on this path.
            sig = sig[:max_len]
        elif sig_len < max_len:
            # Length of padding to add at the beginning and end of the signal
            pad_begin_len = np.random.randint(0, max_len - sig_len)
            pad_end_len = max_len - sig_len - pad_begin_len

            # Pad with 0s
            pad_begin = np.zeros(pad_begin_len)
            pad_end = np.zeros(pad_end_len)

            sig = np.concatenate((pad_begin, sig, pad_end), 0)

        return sig
StutterNet/losses.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torchvision.ops import sigmoid_focal_loss
4
+
5
class CCCLoss(nn.Module):
    '''Concordance correlation coefficient (CCC) loss.

    Computes ``1 - CCC`` per target dimension, where
    ``CCC = 2*cov(y, y_hat) / (var(y) + var(y_hat) + (mu_y - mu_y_hat)^2)``
    (Lin, 1989). Inputs are (batch, n_targets); statistics are taken over
    the batch dimension.
    '''
    def __init__(self, eps=1e-7):
        '''
        Args:
            eps (float, optional): stabilizing term added to the denominator
        '''
        super(CCCLoss, self).__init__()
        self.eps = eps

    def forward(self, y_hat, y):
        '''
        Args:
            y_hat (Tensor): predictions, shape (batch, n_targets)
            y (Tensor): targets, shape (batch, n_targets)

        Returns:
            Tensor: per-target ``1 - CCC``, shape (n_targets,)
        '''
        gold_mean = torch.mean(y.T, dim=-1, keepdim=True)
        pred_mean = torch.mean(y_hat.T, dim=-1, keepdim=True)
        # BUG FIX: covariance is the *mean* of the centered products; the
        # original kept the raw per-sample products, so "CCC" was computed
        # element-wise instead of per target dimension.
        covariance = torch.mean((y.T - gold_mean) * (y_hat.T - pred_mean),
                                dim=-1, keepdim=True)
        gold_var = torch.mean(torch.square(y.T - gold_mean), dim=-1, keepdim=True)
        pred_var = torch.mean(torch.square(y_hat.T - pred_mean), dim=-1, keepdim=True)
        ccc = 2 * covariance / (gold_var + pred_var
                                + torch.square(gold_mean - pred_mean) + self.eps)
        return torch.mean(1 - ccc, dim=-1)
23
+
24
class SigmoidFocalLoss(nn.Module):
    '''Thin ``nn.Module`` wrapper around torchvision's ``sigmoid_focal_loss``.'''

    def __init__(self, reduction=None):
        '''
        Args:
            reduction (str, optional): reduction forwarded verbatim to
                ``sigmoid_focal_loss`` (e.g. 'none', 'mean', 'sum').
        '''
        super(SigmoidFocalLoss, self).__init__()
        self.reduction = reduction

    def forward(self, y_hat, y):
        '''Return the focal loss of logits ``y_hat`` against targets ``y``.'''
        return sigmoid_focal_loss(y_hat, y, reduction=self.reduction)
32
+
33
class StutterLoss(nn.Module):
    '''SEP-28k joint loss: CCC on the binary targets plus sigmoid focal
    loss on the disfluency-class targets.

    Expects predictions/targets of shape (batch, 12): the first 6 columns
    are disfluency-class logits/labels, the last 6 binary ones.
    '''
    def __init__(self, alpha=1, beta=1, stutter_weights=None, reduction='mean'):
        '''
        Args:
            alpha (float, optional): weight of the CCC (stutter) term
            beta (float, optional): weight of the focal (disfluency) term
            stutter_weights (Tensor, optional): per-class weights applied
                to the disfluency term; when given, reduction should be
                'none' so per-class losses survive to be weighted
            reduction (str, optional): reduction for the focal loss
        '''
        super(StutterLoss, self).__init__()
        self.stutter_loss = CCCLoss()
        self.disfluency_loss = SigmoidFocalLoss(reduction=reduction)
        self.alpha = alpha
        self.beta = beta
        self.stutter_weights = stutter_weights
        if isinstance(self.stutter_weights, torch.Tensor):
            # keep as a (1, n_classes) row vector for the matmul below
            self.stutter_weights = self.stutter_weights.reshape((1, -1))

    def forward(self, y_hat, y):
        '''expects (batch, 12) predictions and targets'''
        y_class, y_bin = torch.split(y, [6, 6], dim=-1)
        y_hat_class, y_hat_bin = torch.split(y_hat, [6, 6], dim=-1)
        disfluency_loss = self.disfluency_loss(y_hat_class, y_class)
        stutter_loss = torch.mean(self.stutter_loss(y_hat_bin, y_bin))
        if not isinstance(self.stutter_weights, torch.Tensor):
            return self.alpha * stutter_loss + self.beta * torch.mean(disfluency_loss, dim=0)
        # BUG FIX: the original *called* the already-computed loss tensor
        # (`disfluency_loss(y_hat_class, y_class)`), which raised TypeError.
        # Weight the per-class mean losses instead.
        # NOTE(review): assumes reduction='none' on this path so that
        # disfluency_loss is (batch, 6) -- confirm against callers.
        return self.alpha * stutter_loss + self.beta * (
            self.stutter_weights @ torch.mean(disfluency_loss, dim=0))
StutterNet/metrics.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sklearn.metrics import f1_score
2
+ import numpy as np
3
+
4
+ #TODO: implement as nn.Module subclass
5
+
6
def f1(y_hat, y):
    """Sample-averaged F1 of multi-label logits ``y_hat`` against binary
    targets ``y`` (both torch tensors), thresholding sigmoid(logits) at 0.5.
    """
    # BUG FIX: `sigmoid` was never defined or imported in this module
    # (it lives in train.py), so calling f1() raised NameError.
    probs = 1.0 / (1.0 + np.exp(-y_hat.cpu().detach().numpy()))
    per_class_score = f1_score(y.cpu().detach().numpy().astype('int'),
                               (probs > 0.5).astype('int'),
                               average='samples', zero_division=1)
    # f1_score with average='samples' already returns a scalar; np.mean is
    # a harmless no-op kept for interface stability
    return np.mean(per_class_score)
11
+
12
def accuracy(outputs, labels):
    """Mean per-class accuracy of multi-label logits ``outputs`` against
    binary targets ``labels`` (both torch tensors), thresholding
    sigmoid(logits) at 0.5.
    """
    # BUG FIX: `sigmoid` was never defined or imported in this module
    # (it lives in train.py), so calling accuracy() raised NameError.
    y_hat = (1.0 / (1.0 + np.exp(-outputs.cpu().detach().numpy())) > 0.5).astype('int')
    y = labels.cpu().detach().numpy().astype('int')
    batch_size = y.shape[0]
    # fraction of correct predictions per class, then averaged over classes
    per_class_acc = np.sum(y == y_hat, axis=0) / batch_size
    return np.mean(per_class_acc)
StutterNet/models.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torchaudio as audio
4
+ from torch import Tensor
5
+
6
class StutterNet(nn.Module):
    def __init__(self, n_mels=40,
                 dropout=0.0, use_batchnorm=False, scale=1):
        '''Implementation of StutterNet
        from Sheikh et al. "StutterNet:
        Stuttering Detection Using
        Time Delay Neural Network" 2021

        Args:
            n_mels (int, optional): number of mel filter banks
            dropout (float, optional): dropout probability applied after the
                two hidden linear layers
            use_batchnorm (bool, optional): accepted for interface
                compatibility; batchnorm is always applied in the TDNN stack
            scale (float, optional): width scale factor
        '''
        super(StutterNet, self).__init__()

        self.n_mels = n_mels

        # dilated TDNN (1-D conv) feature extractor
        self.tdnn_1 = nn.Conv1d(n_mels, int(512*scale), 5, dilation=1)
        self.tdnn_2 = nn.Conv1d(int(512*scale), int(1536*scale), 5, dilation=2)
        self.tdnn_3 = nn.Conv1d(int(1536*scale), int(512*scale), 7, dilation=3)
        self.tdnn_4 = nn.Conv1d(int(512*scale), int(512*scale), 1)
        self.tdnn_5 = nn.Conv1d(int(512*scale), int(1500*scale), 1)
        # BUG FIX: fc_1 input must equal cat(mean, std) of the tdnn_5
        # output, i.e. 2*int(1500*scale); the original int(3000*scale)
        # differs for fractional scales (e.g. scale=0.333 -> 999 vs 998).
        self.fc_1 = nn.Linear(2*int(1500*scale), 512)
        self.relu = nn.ReLU()
        self.bn_1 = nn.BatchNorm1d(int(512*scale))
        self.bn_2 = nn.BatchNorm1d(int(1536*scale))
        self.bn_3 = nn.BatchNorm1d(int(512*scale))
        self.bn_4 = nn.BatchNorm1d(int(512*scale))
        self.bn_5 = nn.BatchNorm1d(int(1500*scale))

        nn.init.xavier_uniform_(self.fc_1.weight)
        self.dropout_1 = nn.Dropout(dropout)
        self.fc_2 = nn.Linear(512, 512)
        # BUG FIX: fc_2 was never initialized -- the original
        # xavier-initialized fc_1 twice (copy-paste error).
        nn.init.xavier_uniform_(self.fc_2.weight)
        self.dropout_2 = nn.Dropout(dropout)

        # two 6-way output heads: binary (stutter) and class (disfluency)
        self.binary_head = nn.Linear(512, 6)
        self.class_head = nn.Linear(512, 6)

        self.sig = nn.Sigmoid()

    def forward(self, x):
        '''Forward pass.

        Args:
            x (Tensor): (batch, n_mels, time) features; time must exceed the
                TDNN receptive field (30 frames of total reduction).

        Returns:
            Tensor: (batch, 12) logits -- 6 binary then 6 class outputs.
        '''
        batch_size = x.shape[0]

        # TDNN stack: conv -> ReLU -> batchnorm
        x = self.bn_1(self.relu(self.tdnn_1(x)))
        x = self.bn_2(self.relu(self.tdnn_2(x)))
        x = self.bn_3(self.relu(self.tdnn_3(x)))
        x = self.bn_4(self.relu(self.tdnn_4(x)))
        x = self.bn_5(self.relu(self.tdnn_5(x)))

        # statistics pooling over time: concat of mean and std
        mean = torch.mean(x, -1)
        std = torch.std(x, -1)
        x = torch.cat((mean, std), 1)
        # NOTE(review): no activation between fc_1 and fc_2, matching the
        # original implementation -- confirm this is intended
        x = self.dropout_1(self.fc_1(x))
        x = self.dropout_2(self.fc_2(x))

        binary = self.binary_head(x)
        classes = self.class_head(x)

        return torch.cat((binary, classes), dim=-1)
95
+
96
class ResBlock1d(nn.Module):
    '''1-D residual block: an entry conv (optionally strided for
    downsampling) followed by ``depth - 1`` same-padded convs, with an
    identity (or strided-conv) skip connection added at the end. No
    activation follows the final conv -- callers apply it after the sum.'''

    def __init__(self, input_dims, output_dims, depth=2, kernel_size=3,
                 use_batchnorm=False, downsample=False, dropout=0.0):
        super(ResBlock1d, self).__init__()

        self.depth = depth
        self.use_batchnorm = use_batchnorm
        self.up = None

        stride = 1
        if downsample:
            # strided conv projects the skip path to the new width/length
            self.down = nn.Conv1d(int(input_dims), int(output_dims), 3, 2, padding=1)
            stride = 2

        self.downsample = downsample

        self.conv_1 = nn.Conv1d(int(input_dims), output_dims, 3,
                                stride=stride, padding=1)
        self.convs = nn.ModuleList([
            nn.Conv1d(output_dims, output_dims, kernel_size, padding='same')
            for _ in range(depth - 1)])

        self.bn_1 = nn.BatchNorm1d(output_dims)
        self.bn = None
        if use_batchnorm:
            self.bn = nn.ModuleList([
                nn.BatchNorm1d(output_dims) for _ in range(depth - 1)])

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        '''Apply the block; returns a tensor with ``output_dims`` channels.'''
        residual = self.down(x) if self.downsample else x

        out = self.bn_1(self.conv_1(x))
        for i, conv in enumerate(self.convs):
            out = self.dropout(conv(out))
            if self.use_batchnorm:
                out = self.bn[i](out)
            # no activation after the last conv of the stack
            if i != self.depth - 2:
                out = self.relu(out)

        return residual + out
154
+
155
class ResNet1D(nn.Module):
    def __init__(self, n_mels=100, n_classes=12, kernel_size=3,
                 dropout=0.0, use_batchnorm=False, scale=1):
        '''1-D ResNet classifier over framed audio features: an entry conv,
        four stages of three ResBlock1d each (the first block of every
        stage downsamples time by 2), statistics pooling (mean + std over
        time), and a linear classifier.

        (Docstring fixed -- it was copy-pasted from StutterNet.)

        Args:
            n_mels (int, optional): number of mel filter banks
            n_classes (int, optional): number of classes in output layer
            kernel_size (int, optional): kernel size inside residual blocks
            dropout (float, optional): dropout used inside residual blocks
            use_batchnorm (bool, optional): whether to batchnorm inside the
                residual blocks
            scale (float, optional): width scale factor
        '''
        super(ResNet1D, self).__init__()

        self.n_mels = n_mels

        self.tdnn_1 = nn.Conv1d(n_mels, int(64*scale), 3, padding=1, bias=False)

        # stage 1: 64*scale channels
        self.res_1_1 = ResBlock1d(int(64*scale), int(64*scale), kernel_size=kernel_size, downsample=True, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_1_2 = ResBlock1d(int(64*scale), int(64*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_1_3 = ResBlock1d(int(64*scale), int(64*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)

        # stage 2: 128*scale channels
        self.res_2_1 = ResBlock1d(int(64*scale), int(128*scale), kernel_size=kernel_size, downsample=True, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_2_2 = ResBlock1d(int(128*scale), int(128*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_2_3 = ResBlock1d(int(128*scale), int(128*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)

        # stage 3: 256*scale channels
        self.res_3_1 = ResBlock1d(int(128*scale), int(256*scale), kernel_size=kernel_size, downsample=True, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_3_2 = ResBlock1d(int(256*scale), int(256*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_3_3 = ResBlock1d(int(256*scale), int(256*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)

        # stage 4: 512*scale channels
        self.res_4_1 = ResBlock1d(int(256*scale), int(512*scale), kernel_size=kernel_size, downsample=True, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_4_2 = ResBlock1d(int(512*scale), int(512*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)
        self.res_4_3 = ResBlock1d(int(512*scale), int(512*scale), kernel_size=kernel_size, downsample=False, use_batchnorm=use_batchnorm, dropout=dropout)

        self.relu = nn.ReLU()
        # BUG FIX: fc input must equal cat(mean, std) of int(512*scale)
        # channels, i.e. 2*int(512*scale); the original int(1024*scale)
        # differs for fractional scales (e.g. scale=0.333 -> 341 vs 340).
        self.fc = nn.Linear(2*int(512*scale), n_classes)

    def forward(self, x):
        '''(batch, n_mels, time) -> (batch, n_classes) logits.'''
        x = self.tdnn_1(x)

        x = self.relu(self.res_1_1(x))
        x = self.relu(self.res_1_2(x))
        x = self.relu(self.res_1_3(x))

        x = self.relu(self.res_2_1(x))
        x = self.relu(self.res_2_2(x))
        x = self.relu(self.res_2_3(x))

        x = self.relu(self.res_3_1(x))
        x = self.relu(self.res_3_2(x))
        x = self.relu(self.res_3_3(x))

        x = self.relu(self.res_4_1(x))
        x = self.relu(self.res_4_2(x))
        x = self.relu(self.res_4_3(x))

        # statistics pooling over time: concat of mean and std
        mean = torch.mean(x, -1)
        std = torch.std(x, -1)
        x = torch.cat((mean, std), 1)
        return self.fc(x)
248
+
249
+ from torch import Tensor
250
+
251
+ '''credit: https://github.com/roman-vygon/BCResNet'''
252
+
253
class SubSpectralNorm(nn.Module):
    '''Sub-spectral normalization: batch-normalizes each of ``S``
    frequency sub-bands with its own statistics by folding the band index
    into the channel axis. Requires F divisible by S.'''

    def __init__(self, C, S, eps=1e-5):
        super(SubSpectralNorm, self).__init__()
        self.S = S
        self.eps = eps
        self.bn = nn.BatchNorm2d(C * S)

    def forward(self, x):
        # x: (N, C, F, T) features; split F into S bands folded into C
        batch, channels, freq, time = x.size()
        grouped = x.view(batch, channels * self.S, freq // self.S, time)
        normed = self.bn(grouped)
        return normed.view(batch, channels, freq, time)
269
+
270
+
271
class BroadcastedBlock(nn.Module):
    '''BC-ResNet "normal" broadcasted residual block: a frequency
    depthwise conv path (f2), frequency average pooling, a temporal
    depthwise conv path (f1), and a broadcast residual sum of identity,
    f2 and f1 outputs.'''

    def __init__(
            self,
            planes: int,
            dilation=1,
            stride=1,
            temp_pad=(0, 1),
    ) -> None:
        super(BroadcastedBlock, self).__init__()

        # f2: depthwise conv along frequency
        self.freq_dw_conv = nn.Conv2d(planes, planes, kernel_size=(3, 1), padding=(1, 0), groups=planes,
                                      dilation=dilation,
                                      stride=stride, bias=False)
        self.ssn1 = SubSpectralNorm(planes, 5)
        # f1: depthwise conv along time
        self.temp_dw_conv = nn.Conv2d(planes, planes, kernel_size=(1, 3), padding=temp_pad, groups=planes,
                                      dilation=dilation, stride=stride, bias=False)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.channel_drop = nn.Dropout2d(p=0.5)
        self.swish = nn.SiLU()
        self.conv1x1 = nn.Conv2d(planes, planes, kernel_size=(1, 1), bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # f2 branch
        freq_feat = self.ssn1(self.freq_dw_conv(x))

        # frequency average pooling, then the f1 branch
        pooled = freq_feat.mean(2, keepdim=True)
        temp_feat = self.channel_drop(
            self.conv1x1(self.swish(self.bn(self.temp_dw_conv(pooled)))))

        # broadcast residual: f1 output + identity + f2 output
        return self.relu(temp_feat + x + freq_feat)
318
+
319
+
320
class TransitionBlock(nn.Module):
    '''BC-ResNet transition block: a 1x1 channel projection followed by the
    frequency depthwise path (f2), frequency average pooling, the temporal
    depthwise path (f1), and a broadcast sum of the two branch outputs
    (no identity connection -- the channel count changes).'''

    def __init__(
            self,
            inplanes: int,
            planes: int,
            dilation=1,
            stride=1,
            temp_pad=(0, 1),
    ) -> None:
        super(TransitionBlock, self).__init__()

        self.freq_dw_conv = nn.Conv2d(planes, planes, kernel_size=(3, 1), padding=(1, 0), groups=planes,
                                      stride=stride,
                                      dilation=dilation, bias=False)
        self.ssn = SubSpectralNorm(planes, 5)
        self.temp_dw_conv = nn.Conv2d(planes, planes, kernel_size=(1, 3), padding=temp_pad, groups=planes,
                                      dilation=dilation, stride=stride, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.channel_drop = nn.Dropout2d(p=0.5)
        self.swish = nn.SiLU()
        self.conv1x1_1 = nn.Conv2d(inplanes, planes, kernel_size=(1, 1), bias=False)
        self.conv1x1_2 = nn.Conv2d(planes, planes, kernel_size=(1, 1), bias=False)

    def forward(self, x: Tensor) -> Tensor:
        # f2: 1x1 projection (conv-bn-relu) then frequency depthwise conv
        proj = self.relu(self.bn1(self.conv1x1_1(x)))
        freq_feat = self.ssn(self.freq_dw_conv(proj))

        # frequency average pooling feeds the f1 branch
        pooled = freq_feat.mean(2, keepdim=True)
        temp_feat = self.channel_drop(
            self.conv1x1_2(self.swish(self.bn2(self.temp_dw_conv(pooled)))))

        # broadcast sum of the two branches
        return self.relu(freq_feat + temp_feat)
371
+
372
class BCResNet(torch.nn.Module):
    '''BC-ResNet for (N, 1, F, T) spectrogram input: an entry conv that
    halves frequency, four stages of transition + broadcasted blocks,
    depthwise/pointwise head convs, temporal average pooling, and a
    (N, 12) logit output.'''

    def __init__(self):
        super(BCResNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, stride=(2, 1), padding=(2, 2))
        self.block1_1 = TransitionBlock(16, 8)
        self.block1_2 = BroadcastedBlock(8)

        self.block2_1 = TransitionBlock(8, 12, stride=(2, 1), dilation=(1, 2), temp_pad=(0, 2))
        self.block2_2 = BroadcastedBlock(12, dilation=(1, 2), temp_pad=(0, 2))

        self.block3_1 = TransitionBlock(12, 16, stride=(2, 1), dilation=(1, 4), temp_pad=(0, 4))
        self.block3_2 = BroadcastedBlock(16, dilation=(1, 4), temp_pad=(0, 4))
        self.block3_3 = BroadcastedBlock(16, dilation=(1, 4), temp_pad=(0, 4))
        self.block3_4 = BroadcastedBlock(16, dilation=(1, 4), temp_pad=(0, 4))

        self.block4_1 = TransitionBlock(16, 20, dilation=(1, 8), temp_pad=(0, 8))
        self.block4_2 = BroadcastedBlock(20, dilation=(1, 8), temp_pad=(0, 8))
        self.block4_3 = BroadcastedBlock(20, dilation=(1, 8), temp_pad=(0, 8))
        self.block4_4 = BroadcastedBlock(20, dilation=(1, 8), temp_pad=(0, 8))

        self.conv2 = nn.Conv2d(20, 20, 5, groups=20, padding=(0, 2))
        self.conv3 = nn.Conv2d(20, 32, 1, bias=False)
        self.conv4 = nn.Conv2d(32, 12, 1, bias=False)

    def forward(self, x):
        out = self.conv1(x)

        # the residual stages run strictly in sequence
        for block in (self.block1_1, self.block1_2,
                      self.block2_1, self.block2_2,
                      self.block3_1, self.block3_2, self.block3_3, self.block3_4,
                      self.block4_1, self.block4_2, self.block4_3, self.block4_4):
            out = block(out)

        out = self.conv2(out)
        out = self.conv3(out)
        out = out.mean(-1, keepdim=True)  # temporal average pooling
        out = self.conv4(out)

        return out.reshape((-1, 12))
StutterNet/train.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import numpy as np
4
+ import argparse
5
+
6
def sigmoid(x):
    '''Elementwise logistic function 1 / (1 + e^-x) (numpy).'''
    z = np.exp(-x)
    return 1 / (1 + z)
8
+
9
def parser():
    '''Build and evaluate the command-line argument parser.

    #TODO: no options are defined yet; this currently just validates argv.
    '''
    arg_parser = argparse.ArgumentParser()
    return arg_parser.parse_args()
13
+
14
def train(net, trainloader, criterion, batch_size, target_names,
          validationloader=None, optimizer=None,
          scheduler=None, epochs=50, logdir=None, metrics=None,
          verbose=True, tuner=False, checkpoint_dir=None):
    ''' training loop function for simple
    supervised learning task.

    Trains `net` with `criterion`, logging per-step train loss and a
    per-epoch sklearn classification report (train, and validation if
    provided) to tensorboard.

    Args:
        net (torch.nn.Module): network to train
        trainloader (torch.utils.data.DataLoader):
            train data loader
        criterion (torch.nn.object): criterion with which
            to optimize the provided network
        batch_size (int): batch of trainloader and validationloader
        target_names (list of str): label names passed to sklearn's
            classification_report (12 labels expected -- see buffers below)
        validationloader (torch.utils.data.DataLoader, optional):
            validation data loader
        optimizer (torch.optim.Optimizer, optional):
            optimizer function, defaults to torch.nn.optim.Adam w/ amsgrad
        scheduler (torch.optim.lr_scheduler, optional):
            learning rate scheduler object
        epochs (int, optional): number of epochs to train network,
            defaults to 50
        logdir (string, optional): path to tensorboard log directory,
            if None logging default to ./runs/ directory
        metrics (list of tuples, optional): metrics to be logged with
            name and metric being the first and second element of the
            each tuple respectively
        verbose (bool, optional): whether or not to print information
            to console
        tuner (bool, optional): whether to employ ray tune
        checkpoint_dir (str, optional): path to a torch checkpoint dict
            with keys 'epoch', 'state_dict' and 'optimizer' to resume from
    '''
    # local imports keep heavyweight/optional deps out of module import time
    from torch.utils.tensorboard import SummaryWriter
    from sklearn.metrics import classification_report
    writer = SummaryWriter(log_dir=logdir)

    if (verbose):
        # keras Progbar is only needed for the console progress display
        from tensorflow.keras.utils import Progbar

    if (optimizer is None):
        optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, amsgrad=True)

    start_epoch = 0

    if (checkpoint_dir is not None):
        # resume model/optimizer state from a saved checkpoint dict
        # state, optim_state = torch.load(os.path.join(
        #     checkpoint_dir, "checkpoint"))
        state = torch.load(checkpoint_dir)
        start_epoch = state['epoch']
        net.load_state_dict(state['state_dict'])
        optimizer.load_state_dict(state['optimizer'])

    assert epochs > 0, "Assertion failed. epochs must be greater than 0!"

    steps = 0  # global step counter used as tensorboard x-axis

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') # get device

    net.train(True)

    # net.to(device)

    if (tuner):
        from ray import tune
        import os

    for i in range(start_epoch, start_epoch + epochs):
        num_batches = len(trainloader)
        # NOTE(review): assumes every batch is full -- a trailing partial
        # batch leaves zero rows at the end of the report buffers; confirm
        # the loaders drop the last batch
        num_samples = num_batches * batch_size

        if (verbose):
            print("\nepoch {}/{}".format(i+1, start_epoch+epochs))
            pbar = Progbar(target=num_batches)

        # if (metrics is not None):
        #     train_metrics = [0 for metric in metrics]

        # epoch-level buffers for the classification report;
        # NOTE(review): label width hard-coded to 12 -- should match
        # len(target_names)
        y_true = np.zeros((num_samples, 12))
        y_pred = np.zeros((num_samples, 12))
        idx = 0

        for j, data in enumerate(iter(trainloader)):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data[0].to(device), data[1].to(device)
            # inputs, labels = data[0].to(device), [data[1][0].to(device), data[1][1].to(device)]

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = net(inputs)
            train_loss = criterion(outputs, labels)
            train_loss.backward()
            optimizer.step()

            # accumulate this batch's targets/logits for the epoch report
            y_true[idx:idx+outputs.shape[0], :] = labels.detach().cpu().numpy()
            y_pred[idx:idx+outputs.shape[0], :] = outputs.detach().cpu().numpy()

            idx += outputs.shape[0]

            # NOTE(review): scheduler is stepped per *batch*, not per epoch
            # -- appropriate for e.g. OneCycleLR; confirm for the scheduler
            # actually passed in
            if (scheduler is not None):
                scheduler.step()

            if (verbose):
                pbar.update(j, values=[("loss",
                                        train_loss.detach().cpu().numpy().item())])

            steps += 1

            writer.add_scalar('Loss/train',
                              train_loss.detach().cpu().numpy().item(), steps)

        # if (metrics is not None):
        #     for (j, metric) in enumerate(metrics):
        #         # train_metrics[j] += metric[1](outputs, labels).detach().cpu().numpy()
        #         train_metrics[j] += metric[1](outputs, labels)

        # epoch train report; `sigmoid` is the module-level helper above,
        # logits are thresholded at probability 0.5
        rep = classification_report(y_true.astype('int'),
                                    (sigmoid(y_pred) > 0.5).astype('int'), target_names=target_names,
                                    output_dict=True)

        # log every per-label / averaged statistic from the report dict
        for k in rep.keys():
            for j in rep[k].keys():
                writer.add_scalar(j + '/' + k + '/train',
                                  rep[k][j], steps)

        # if (metrics is not None):
        #     for (j, metric) in enumerate(metrics):
        #         # writer.add_scalar(metric[0] + '/train',
        #         #     train_metrics[j] / num_samples, steps)
        #         writer.add_scalar(metric[0] + '/train',
        #                           train_metrics[j] / num_batches, steps)

        if (validationloader is not None):
            net.train(False)  # switch to eval mode for validation
            val_loss = 0
            # if (metrics is not None):
            #     val_metrics = [0 for metric in metrics]
            num_val_batches = len(validationloader)
            num_val_samples = num_val_batches * batch_size

            y_val_true = np.zeros((num_val_samples, 12))
            y_val_pred = np.zeros((num_val_samples, 12))

            idx = 0

            # NOTE(review): not wrapped in torch.no_grad(), so validation
            # builds autograd graphs it never uses -- consider adding
            for data in iter(validationloader):
                # get the inputs; data is a list of [inputs, labels]
                inputs, labels = data[0].to(device), data[1].to(device)
                # inputs, labels = data[0].to(device), [data[1][0].to(device), data[1][1].to(device)]

                outputs = net(inputs)
                val_loss += criterion(outputs, labels).detach().cpu().numpy()

                y_val_true[idx:idx+outputs.shape[0], :] = labels.detach().cpu().numpy()
                y_val_pred[idx:idx+outputs.shape[0], :] = outputs.detach().cpu().numpy()

                idx += outputs.shape[0]

                # if (metrics is not None):
                #     for (j, metric) in enumerate(metrics):
                #         # val_metrics[j] += metric[1](outputs, labels).detach().cpu().numpy()
                #         val_metrics[j] += metric[1](outputs, labels)

            val_loss /= (num_val_batches) # assume all validation set used
            # scheduler.step(val_loss)

            # validation report: logged as a dict and printed human-readable
            rep = classification_report(y_val_true.astype('int'),
                                        (sigmoid(y_val_pred) > 0.5).astype('int'), target_names=target_names,
                                        output_dict=True)
            print(classification_report(y_val_true.astype('int'),
                                        (sigmoid(y_val_pred) > 0.5).astype('int'), target_names=target_names))
            # output_dict=False)
            #print(rep2)

            for k in rep.keys():
                for j in rep[k].keys():
                    writer.add_scalar(j + '/' + k + '/valid',
                                      rep[k][j], steps)

            writer.add_scalar('Loss/valid', val_loss, steps)

            # if (metrics is not None):
            #     for (j, metric) in enumerate(metrics):
            #         # writer.add_scalar(metric[0] + '/valid',
            #         #     val_metrics[j] / num_val_samples, steps)
            #         writer.add_scalar(metric[0] + '/valid',
            #                           val_metrics[j] / num_val_batches, steps)

            # if (tuner):
            #     with tune.checkpoint_dir(i+1) as checkpoint_dir:
            #         path = os.path.join(checkpoint_dir, "checkpoint")
            #         torch.save((net.state_dict(), optimizer.state_dict()), path)

            #     tune.report(loss=val_loss, accuracy=val_metrics[0] / num_val_samples, iters=i+1)

            if (verbose):
                pbar.update(num_batches, values=[("val_loss",val_loss.item())])
            net.train(True)  # back to training mode for the next epoch
        else:
            if (verbose):
                pbar.update(num_batches, values=None)
215
+
216
# Script entry point: currently only parses (empty) CLI arguments; the
# actual train() invocation is still TODO.
if __name__ == "__main__":
    args = parser() # get arguments

    # TODO: implement args such that we can train from the command line
    # NOTE(review): the commented call below predates the current train()
    # signature (it is missing target_names) -- update when implementing.
    #train(args.net, args.trainloader, args.criterion, args.batch_size,
    #    args.validationloader, args.optimizer,
    #    args.scheduler, args.epochs, args.logdir, args.metrics,
    #    args.verbose, args.tuner, args.checkpoint_dir):