pavankumarvk commited on
Commit
67492f2
·
verified ·
1 Parent(s): 9c5fca6

Upload rawnet.py

Browse files
Files changed (1) hide show
  1. rawnet.py +365 -0
rawnet.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch import Tensor
5
+ import numpy as np
6
+ from torch.utils import data
7
+ from collections import OrderedDict
8
+ from torch.nn.parameter import Parameter
9
+
10
+
11
+
12
+
13
+ ___author__ = "Hemlata Tak"
14
+ __email__ = "tak@eurecom.fr"
15
+
16
+
17
+ class SincConv(nn.Module):
18
+ @staticmethod
19
+ def to_mel(hz):
20
+ return 2595 * np.log10(1 + hz / 700)
21
+
22
+ @staticmethod
23
+ def to_hz(mel):
24
+ return 700 * (10 ** (mel / 2595) - 1)
25
+
26
+
27
+ def __init__(self, device,out_channels, kernel_size,in_channels=1,sample_rate=16000,
28
+ stride=1, padding=0, dilation=1, bias=False, groups=1):
29
+
30
+ super(SincConv,self).__init__()
31
+
32
+ if in_channels != 1:
33
+
34
+ msg = "SincConv only support one input channel (here, in_channels = {%i})" % (in_channels)
35
+ raise ValueError(msg)
36
+
37
+ self.out_channels = out_channels
38
+ self.kernel_size = kernel_size
39
+ self.sample_rate=sample_rate
40
+
41
+ # Forcing the filters to be odd (i.e, perfectly symmetrics)
42
+ if kernel_size%2==0:
43
+ self.kernel_size=self.kernel_size+1
44
+
45
+ self.device=device
46
+ self.stride = stride
47
+ self.padding = padding
48
+ self.dilation = dilation
49
+
50
+ if bias:
51
+ raise ValueError('SincConv does not support bias.')
52
+ if groups > 1:
53
+ raise ValueError('SincConv does not support groups.')
54
+
55
+
56
+ # initialize filterbanks using Mel scale
57
+ NFFT = 512
58
+ f=int(self.sample_rate/2)*np.linspace(0,1,int(NFFT/2)+1)
59
+ fmel=self.to_mel(f) # Hz to mel conversion
60
+ fmelmax=np.max(fmel)
61
+ fmelmin=np.min(fmel)
62
+ filbandwidthsmel=np.linspace(fmelmin,fmelmax,self.out_channels+1)
63
+ filbandwidthsf=self.to_hz(filbandwidthsmel) # Mel to Hz conversion
64
+ self.mel=filbandwidthsf
65
+ self.hsupp=torch.arange(-(self.kernel_size-1)/2, (self.kernel_size-1)/2+1)
66
+ self.band_pass=torch.zeros(self.out_channels,self.kernel_size)
67
+
68
+
69
+
70
+ def forward(self,x):
71
+ for i in range(len(self.mel)-1):
72
+ fmin=self.mel[i]
73
+ fmax=self.mel[i+1]
74
+ hHigh=(2*fmax/self.sample_rate)*np.sinc(2*fmax*self.hsupp/self.sample_rate)
75
+ hLow=(2*fmin/self.sample_rate)*np.sinc(2*fmin*self.hsupp/self.sample_rate)
76
+ hideal=hHigh-hLow
77
+
78
+ self.band_pass[i,:]=Tensor(np.hamming(self.kernel_size))*Tensor(hideal)
79
+
80
+ band_pass_filter=self.band_pass.to(self.device)
81
+
82
+ self.filters = (band_pass_filter).view(self.out_channels, 1, self.kernel_size)
83
+
84
+ return F.conv1d(x, self.filters, stride=self.stride,
85
+ padding=self.padding, dilation=self.dilation,
86
+ bias=None, groups=1)
87
+
88
+
89
+
90
+ class Residual_block(nn.Module):
91
+ def __init__(self, nb_filts, first = False):
92
+ super(Residual_block, self).__init__()
93
+ self.first = first
94
+
95
+ if not self.first:
96
+ self.bn1 = nn.BatchNorm1d(num_features = nb_filts[0])
97
+
98
+ self.lrelu = nn.LeakyReLU(negative_slope=0.3)
99
+
100
+ self.conv1 = nn.Conv1d(in_channels = nb_filts[0],
101
+ out_channels = nb_filts[1],
102
+ kernel_size = 3,
103
+ padding = 1,
104
+ stride = 1)
105
+
106
+ self.bn2 = nn.BatchNorm1d(num_features = nb_filts[1])
107
+ self.conv2 = nn.Conv1d(in_channels = nb_filts[1],
108
+ out_channels = nb_filts[1],
109
+ padding = 1,
110
+ kernel_size = 3,
111
+ stride = 1)
112
+
113
+ if nb_filts[0] != nb_filts[1]:
114
+ self.downsample = True
115
+ self.conv_downsample = nn.Conv1d(in_channels = nb_filts[0],
116
+ out_channels = nb_filts[1],
117
+ padding = 0,
118
+ kernel_size = 1,
119
+ stride = 1)
120
+
121
+ else:
122
+ self.downsample = False
123
+ self.mp = nn.MaxPool1d(3)
124
+
125
+ def forward(self, x):
126
+ identity = x
127
+ if not self.first:
128
+ out = self.bn1(x)
129
+ out = self.lrelu(out)
130
+ else:
131
+ out = x
132
+
133
+ out = self.conv1(x)
134
+ out = self.bn2(out)
135
+ out = self.lrelu(out)
136
+ out = self.conv2(out)
137
+
138
+ if self.downsample:
139
+ identity = self.conv_downsample(identity)
140
+
141
+ out += identity
142
+ out = self.mp(out)
143
+ return out
144
+
145
+
146
+
147
+
148
+
149
+ class RawNet(nn.Module):
150
+ def __init__(self, d_args, device):
151
+ super(RawNet, self).__init__()
152
+
153
+
154
+ self.device=device
155
+
156
+ self.Sinc_conv=SincConv(device=self.device,
157
+ out_channels = d_args['filts'][0],
158
+ kernel_size = d_args['first_conv'],
159
+ in_channels = d_args['in_channels']
160
+ )
161
+
162
+ self.first_bn = nn.BatchNorm1d(num_features = d_args['filts'][0])
163
+ self.selu = nn.SELU(inplace=True)
164
+ self.block0 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][1], first = True))
165
+ self.block1 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][1]))
166
+ self.block2 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][2]))
167
+ d_args['filts'][2][0] = d_args['filts'][2][1]
168
+ self.block3 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][2]))
169
+ self.block4 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][2]))
170
+ self.block5 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][2]))
171
+ self.avgpool = nn.AdaptiveAvgPool1d(1)
172
+
173
+ self.fc_attention0 = self._make_attention_fc(in_features = d_args['filts'][1][-1],
174
+ l_out_features = d_args['filts'][1][-1])
175
+ self.fc_attention1 = self._make_attention_fc(in_features = d_args['filts'][1][-1],
176
+ l_out_features = d_args['filts'][1][-1])
177
+ self.fc_attention2 = self._make_attention_fc(in_features = d_args['filts'][2][-1],
178
+ l_out_features = d_args['filts'][2][-1])
179
+ self.fc_attention3 = self._make_attention_fc(in_features = d_args['filts'][2][-1],
180
+ l_out_features = d_args['filts'][2][-1])
181
+ self.fc_attention4 = self._make_attention_fc(in_features = d_args['filts'][2][-1],
182
+ l_out_features = d_args['filts'][2][-1])
183
+ self.fc_attention5 = self._make_attention_fc(in_features = d_args['filts'][2][-1],
184
+ l_out_features = d_args['filts'][2][-1])
185
+
186
+ self.bn_before_gru = nn.BatchNorm1d(num_features = d_args['filts'][2][-1])
187
+ self.gru = nn.GRU(input_size = d_args['filts'][2][-1],
188
+ hidden_size = d_args['gru_node'],
189
+ num_layers = d_args['nb_gru_layer'],
190
+ batch_first = True)
191
+
192
+
193
+ self.fc1_gru = nn.Linear(in_features = d_args['gru_node'],
194
+ out_features = d_args['nb_fc_node'])
195
+
196
+ self.fc2_gru = nn.Linear(in_features = d_args['nb_fc_node'],
197
+ out_features = d_args['nb_classes'],bias=True)
198
+
199
+
200
+ self.sig = nn.Sigmoid()
201
+ self.logsoftmax = nn.LogSoftmax(dim=1)
202
+
203
+ def forward(self, x, y = None):
204
+
205
+
206
+ nb_samp = x.shape[0]
207
+ len_seq = x.shape[1]
208
+ x=x.view(nb_samp,1,len_seq)
209
+
210
+ x = self.Sinc_conv(x)
211
+ x = F.max_pool1d(torch.abs(x), 3)
212
+ x = self.first_bn(x)
213
+ x = self.selu(x)
214
+
215
+ x0 = self.block0(x)
216
+ y0 = self.avgpool(x0).view(x0.size(0), -1) # torch.Size([batch, filter])
217
+ y0 = self.fc_attention0(y0)
218
+ y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1) # torch.Size([batch, filter, 1])
219
+ x = x0 * y0 + y0 # (batch, filter, time) x (batch, filter, 1)
220
+
221
+
222
+ x1 = self.block1(x)
223
+ y1 = self.avgpool(x1).view(x1.size(0), -1) # torch.Size([batch, filter])
224
+ y1 = self.fc_attention1(y1)
225
+ y1 = self.sig(y1).view(y1.size(0), y1.size(1), -1) # torch.Size([batch, filter, 1])
226
+ x = x1 * y1 + y1 # (batch, filter, time) x (batch, filter, 1)
227
+
228
+ x2 = self.block2(x)
229
+ y2 = self.avgpool(x2).view(x2.size(0), -1) # torch.Size([batch, filter])
230
+ y2 = self.fc_attention2(y2)
231
+ y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1) # torch.Size([batch, filter, 1])
232
+ x = x2 * y2 + y2 # (batch, filter, time) x (batch, filter, 1)
233
+
234
+ x3 = self.block3(x)
235
+ y3 = self.avgpool(x3).view(x3.size(0), -1) # torch.Size([batch, filter])
236
+ y3 = self.fc_attention3(y3)
237
+ y3 = self.sig(y3).view(y3.size(0), y3.size(1), -1) # torch.Size([batch, filter, 1])
238
+ x = x3 * y3 + y3 # (batch, filter, time) x (batch, filter, 1)
239
+
240
+ x4 = self.block4(x)
241
+ y4 = self.avgpool(x4).view(x4.size(0), -1) # torch.Size([batch, filter])
242
+ y4 = self.fc_attention4(y4)
243
+ y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1) # torch.Size([batch, filter, 1])
244
+ x = x4 * y4 + y4 # (batch, filter, time) x (batch, filter, 1)
245
+
246
+ x5 = self.block5(x)
247
+ y5 = self.avgpool(x5).view(x5.size(0), -1) # torch.Size([batch, filter])
248
+ y5 = self.fc_attention5(y5)
249
+ y5 = self.sig(y5).view(y5.size(0), y5.size(1), -1) # torch.Size([batch, filter, 1])
250
+ x = x5 * y5 + y5 # (batch, filter, time) x (batch, filter, 1)
251
+
252
+ x = self.bn_before_gru(x)
253
+ x = self.selu(x)
254
+ x = x.permute(0, 2, 1) #(batch, filt, time) >> (batch, time, filt)
255
+ self.gru.flatten_parameters()
256
+ x, _ = self.gru(x)
257
+ x = x[:,-1,:]
258
+ x = self.fc1_gru(x)
259
+ x = self.fc2_gru(x)
260
+ output=self.logsoftmax(x)
261
+
262
+ return output
263
+
264
+
265
+
266
+ def _make_attention_fc(self, in_features, l_out_features):
267
+
268
+ l_fc = []
269
+
270
+ l_fc.append(nn.Linear(in_features = in_features,
271
+ out_features = l_out_features))
272
+
273
+
274
+
275
+ return nn.Sequential(*l_fc)
276
+
277
+
278
+ def _make_layer(self, nb_blocks, nb_filts, first = False):
279
+ layers = []
280
+ #def __init__(self, nb_filts, first = False):
281
+ for i in range(nb_blocks):
282
+ first = first if i == 0 else False
283
+ layers.append(Residual_block(nb_filts = nb_filts,
284
+ first = first))
285
+ if i == 0: nb_filts[0] = nb_filts[1]
286
+
287
+ return nn.Sequential(*layers)
288
+
289
+ def summary(self, input_size, batch_size=-1, device="cuda", print_fn = None):
290
+ if print_fn == None: printfn = print
291
+ model = self
292
+
293
+ def register_hook(module):
294
+ def hook(module, input, output):
295
+ class_name = str(module.__class__).split(".")[-1].split("'")[0]
296
+ module_idx = len(summary)
297
+
298
+ m_key = "%s-%i" % (class_name, module_idx + 1)
299
+ summary[m_key] = OrderedDict()
300
+ summary[m_key]["input_shape"] = list(input[0].size())
301
+ summary[m_key]["input_shape"][0] = batch_size
302
+ if isinstance(output, (list, tuple)):
303
+ summary[m_key]["output_shape"] = [
304
+ [-1] + list(o.size())[1:] for o in output
305
+ ]
306
+ else:
307
+ summary[m_key]["output_shape"] = list(output.size())
308
+ if len(summary[m_key]["output_shape"]) != 0:
309
+ summary[m_key]["output_shape"][0] = batch_size
310
+
311
+ params = 0
312
+ if hasattr(module, "weight") and hasattr(module.weight, "size"):
313
+ params += torch.prod(torch.LongTensor(list(module.weight.size())))
314
+ summary[m_key]["trainable"] = module.weight.requires_grad
315
+ if hasattr(module, "bias") and hasattr(module.bias, "size"):
316
+ params += torch.prod(torch.LongTensor(list(module.bias.size())))
317
+ summary[m_key]["nb_params"] = params
318
+
319
+ if (
320
+ not isinstance(module, nn.Sequential)
321
+ and not isinstance(module, nn.ModuleList)
322
+ and not (module == model)
323
+ ):
324
+ hooks.append(module.register_forward_hook(hook))
325
+
326
+ device = device.lower()
327
+ assert device in [
328
+ "cuda",
329
+ "cpu",
330
+ ], "Input device is not valid, please specify 'cuda' or 'cpu'"
331
+
332
+ if device == "cuda" and torch.cuda.is_available():
333
+ dtype = torch.cuda.FloatTensor
334
+ else:
335
+ dtype = torch.FloatTensor
336
+ if isinstance(input_size, tuple):
337
+ input_size = [input_size]
338
+ x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
339
+ summary = OrderedDict()
340
+ hooks = []
341
+ model.apply(register_hook)
342
+ model(*x)
343
+ for h in hooks:
344
+ h.remove()
345
+
346
+ print_fn("----------------------------------------------------------------")
347
+ line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
348
+ print_fn(line_new)
349
+ print_fn("================================================================")
350
+ total_params = 0
351
+ total_output = 0
352
+ trainable_params = 0
353
+ for layer in summary:
354
+ # input_shape, output_shape, trainable, nb_params
355
+ line_new = "{:>20} {:>25} {:>15}".format(
356
+ layer,
357
+ str(summary[layer]["output_shape"]),
358
+ "{0:,}".format(summary[layer]["nb_params"]),
359
+ )
360
+ total_params += summary[layer]["nb_params"]
361
+ total_output += np.prod(summary[layer]["output_shape"])
362
+ if "trainable" in summary[layer]:
363
+ if summary[layer]["trainable"] == True:
364
+ trainable_params += summary[layer]["nb_params"]
365
+ print_fn(line_new)