pavankumarvk commited on
Commit
fd1c975
·
verified ·
1 Parent(s): 486f884

Update rawnet.py

Browse files
Files changed (1) hide show
  1. rawnet.py +240 -365
rawnet.py CHANGED
@@ -1,365 +1,240 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from torch import Tensor
5
- import numpy as np
6
- from torch.utils import data
7
- from collections import OrderedDict
8
- from torch.nn.parameter import Parameter
9
-
10
-
11
-
12
-
13
- ___author__ = "Hemlata Tak"
14
- __email__ = "tak@eurecom.fr"
15
-
16
-
17
- class SincConv(nn.Module):
18
- @staticmethod
19
- def to_mel(hz):
20
- return 2595 * np.log10(1 + hz / 700)
21
-
22
- @staticmethod
23
- def to_hz(mel):
24
- return 700 * (10 ** (mel / 2595) - 1)
25
-
26
-
27
- def __init__(self, device,out_channels, kernel_size,in_channels=1,sample_rate=16000,
28
- stride=1, padding=0, dilation=1, bias=False, groups=1):
29
-
30
- super(SincConv,self).__init__()
31
-
32
- if in_channels != 1:
33
-
34
- msg = "SincConv only support one input channel (here, in_channels = {%i})" % (in_channels)
35
- raise ValueError(msg)
36
-
37
- self.out_channels = out_channels
38
- self.kernel_size = kernel_size
39
- self.sample_rate=sample_rate
40
-
41
- # Forcing the filters to be odd (i.e, perfectly symmetrics)
42
- if kernel_size%2==0:
43
- self.kernel_size=self.kernel_size+1
44
-
45
- self.device=device
46
- self.stride = stride
47
- self.padding = padding
48
- self.dilation = dilation
49
-
50
- if bias:
51
- raise ValueError('SincConv does not support bias.')
52
- if groups > 1:
53
- raise ValueError('SincConv does not support groups.')
54
-
55
-
56
- # initialize filterbanks using Mel scale
57
- NFFT = 512
58
- f=int(self.sample_rate/2)*np.linspace(0,1,int(NFFT/2)+1)
59
- fmel=self.to_mel(f) # Hz to mel conversion
60
- fmelmax=np.max(fmel)
61
- fmelmin=np.min(fmel)
62
- filbandwidthsmel=np.linspace(fmelmin,fmelmax,self.out_channels+1)
63
- filbandwidthsf=self.to_hz(filbandwidthsmel) # Mel to Hz conversion
64
- self.mel=filbandwidthsf
65
- self.hsupp=torch.arange(-(self.kernel_size-1)/2, (self.kernel_size-1)/2+1)
66
- self.band_pass=torch.zeros(self.out_channels,self.kernel_size)
67
-
68
-
69
-
70
- def forward(self,x):
71
- for i in range(len(self.mel)-1):
72
- fmin=self.mel[i]
73
- fmax=self.mel[i+1]
74
- hHigh=(2*fmax/self.sample_rate)*np.sinc(2*fmax*self.hsupp/self.sample_rate)
75
- hLow=(2*fmin/self.sample_rate)*np.sinc(2*fmin*self.hsupp/self.sample_rate)
76
- hideal=hHigh-hLow
77
-
78
- self.band_pass[i,:]=Tensor(np.hamming(self.kernel_size))*Tensor(hideal)
79
-
80
- band_pass_filter=self.band_pass.to(self.device)
81
-
82
- self.filters = (band_pass_filter).view(self.out_channels, 1, self.kernel_size)
83
-
84
- return F.conv1d(x, self.filters, stride=self.stride,
85
- padding=self.padding, dilation=self.dilation,
86
- bias=None, groups=1)
87
-
88
-
89
-
90
- class Residual_block(nn.Module):
91
- def __init__(self, nb_filts, first = False):
92
- super(Residual_block, self).__init__()
93
- self.first = first
94
-
95
- if not self.first:
96
- self.bn1 = nn.BatchNorm1d(num_features = nb_filts[0])
97
-
98
- self.lrelu = nn.LeakyReLU(negative_slope=0.3)
99
-
100
- self.conv1 = nn.Conv1d(in_channels = nb_filts[0],
101
- out_channels = nb_filts[1],
102
- kernel_size = 3,
103
- padding = 1,
104
- stride = 1)
105
-
106
- self.bn2 = nn.BatchNorm1d(num_features = nb_filts[1])
107
- self.conv2 = nn.Conv1d(in_channels = nb_filts[1],
108
- out_channels = nb_filts[1],
109
- padding = 1,
110
- kernel_size = 3,
111
- stride = 1)
112
-
113
- if nb_filts[0] != nb_filts[1]:
114
- self.downsample = True
115
- self.conv_downsample = nn.Conv1d(in_channels = nb_filts[0],
116
- out_channels = nb_filts[1],
117
- padding = 0,
118
- kernel_size = 1,
119
- stride = 1)
120
-
121
- else:
122
- self.downsample = False
123
- self.mp = nn.MaxPool1d(3)
124
-
125
- def forward(self, x):
126
- identity = x
127
- if not self.first:
128
- out = self.bn1(x)
129
- out = self.lrelu(out)
130
- else:
131
- out = x
132
-
133
- out = self.conv1(x)
134
- out = self.bn2(out)
135
- out = self.lrelu(out)
136
- out = self.conv2(out)
137
-
138
- if self.downsample:
139
- identity = self.conv_downsample(identity)
140
-
141
- out += identity
142
- out = self.mp(out)
143
- return out
144
-
145
-
146
-
147
-
148
-
149
- class RawNet(nn.Module):
150
- def __init__(self, d_args, device):
151
- super(RawNet, self).__init__()
152
-
153
-
154
- self.device=device
155
-
156
- self.Sinc_conv=SincConv(device=self.device,
157
- out_channels = d_args['filts'][0],
158
- kernel_size = d_args['first_conv'],
159
- in_channels = d_args['in_channels']
160
- )
161
-
162
- self.first_bn = nn.BatchNorm1d(num_features = d_args['filts'][0])
163
- self.selu = nn.SELU(inplace=True)
164
- self.block0 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][1], first = True))
165
- self.block1 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][1]))
166
- self.block2 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][2]))
167
- d_args['filts'][2][0] = d_args['filts'][2][1]
168
- self.block3 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][2]))
169
- self.block4 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][2]))
170
- self.block5 = nn.Sequential(Residual_block(nb_filts = d_args['filts'][2]))
171
- self.avgpool = nn.AdaptiveAvgPool1d(1)
172
-
173
- self.fc_attention0 = self._make_attention_fc(in_features = d_args['filts'][1][-1],
174
- l_out_features = d_args['filts'][1][-1])
175
- self.fc_attention1 = self._make_attention_fc(in_features = d_args['filts'][1][-1],
176
- l_out_features = d_args['filts'][1][-1])
177
- self.fc_attention2 = self._make_attention_fc(in_features = d_args['filts'][2][-1],
178
- l_out_features = d_args['filts'][2][-1])
179
- self.fc_attention3 = self._make_attention_fc(in_features = d_args['filts'][2][-1],
180
- l_out_features = d_args['filts'][2][-1])
181
- self.fc_attention4 = self._make_attention_fc(in_features = d_args['filts'][2][-1],
182
- l_out_features = d_args['filts'][2][-1])
183
- self.fc_attention5 = self._make_attention_fc(in_features = d_args['filts'][2][-1],
184
- l_out_features = d_args['filts'][2][-1])
185
-
186
- self.bn_before_gru = nn.BatchNorm1d(num_features = d_args['filts'][2][-1])
187
- self.gru = nn.GRU(input_size = d_args['filts'][2][-1],
188
- hidden_size = d_args['gru_node'],
189
- num_layers = d_args['nb_gru_layer'],
190
- batch_first = True)
191
-
192
-
193
- self.fc1_gru = nn.Linear(in_features = d_args['gru_node'],
194
- out_features = d_args['nb_fc_node'])
195
-
196
- self.fc2_gru = nn.Linear(in_features = d_args['nb_fc_node'],
197
- out_features = d_args['nb_classes'],bias=True)
198
-
199
-
200
- self.sig = nn.Sigmoid()
201
- self.logsoftmax = nn.LogSoftmax(dim=1)
202
-
203
- def forward(self, x, y = None):
204
-
205
-
206
- nb_samp = x.shape[0]
207
- len_seq = x.shape[1]
208
- x=x.view(nb_samp,1,len_seq)
209
-
210
- x = self.Sinc_conv(x)
211
- x = F.max_pool1d(torch.abs(x), 3)
212
- x = self.first_bn(x)
213
- x = self.selu(x)
214
-
215
- x0 = self.block0(x)
216
- y0 = self.avgpool(x0).view(x0.size(0), -1) # torch.Size([batch, filter])
217
- y0 = self.fc_attention0(y0)
218
- y0 = self.sig(y0).view(y0.size(0), y0.size(1), -1) # torch.Size([batch, filter, 1])
219
- x = x0 * y0 + y0 # (batch, filter, time) x (batch, filter, 1)
220
-
221
-
222
- x1 = self.block1(x)
223
- y1 = self.avgpool(x1).view(x1.size(0), -1) # torch.Size([batch, filter])
224
- y1 = self.fc_attention1(y1)
225
- y1 = self.sig(y1).view(y1.size(0), y1.size(1), -1) # torch.Size([batch, filter, 1])
226
- x = x1 * y1 + y1 # (batch, filter, time) x (batch, filter, 1)
227
-
228
- x2 = self.block2(x)
229
- y2 = self.avgpool(x2).view(x2.size(0), -1) # torch.Size([batch, filter])
230
- y2 = self.fc_attention2(y2)
231
- y2 = self.sig(y2).view(y2.size(0), y2.size(1), -1) # torch.Size([batch, filter, 1])
232
- x = x2 * y2 + y2 # (batch, filter, time) x (batch, filter, 1)
233
-
234
- x3 = self.block3(x)
235
- y3 = self.avgpool(x3).view(x3.size(0), -1) # torch.Size([batch, filter])
236
- y3 = self.fc_attention3(y3)
237
- y3 = self.sig(y3).view(y3.size(0), y3.size(1), -1) # torch.Size([batch, filter, 1])
238
- x = x3 * y3 + y3 # (batch, filter, time) x (batch, filter, 1)
239
-
240
- x4 = self.block4(x)
241
- y4 = self.avgpool(x4).view(x4.size(0), -1) # torch.Size([batch, filter])
242
- y4 = self.fc_attention4(y4)
243
- y4 = self.sig(y4).view(y4.size(0), y4.size(1), -1) # torch.Size([batch, filter, 1])
244
- x = x4 * y4 + y4 # (batch, filter, time) x (batch, filter, 1)
245
-
246
- x5 = self.block5(x)
247
- y5 = self.avgpool(x5).view(x5.size(0), -1) # torch.Size([batch, filter])
248
- y5 = self.fc_attention5(y5)
249
- y5 = self.sig(y5).view(y5.size(0), y5.size(1), -1) # torch.Size([batch, filter, 1])
250
- x = x5 * y5 + y5 # (batch, filter, time) x (batch, filter, 1)
251
-
252
- x = self.bn_before_gru(x)
253
- x = self.selu(x)
254
- x = x.permute(0, 2, 1) #(batch, filt, time) >> (batch, time, filt)
255
- self.gru.flatten_parameters()
256
- x, _ = self.gru(x)
257
- x = x[:,-1,:]
258
- x = self.fc1_gru(x)
259
- x = self.fc2_gru(x)
260
- output=self.logsoftmax(x)
261
-
262
- return output
263
-
264
-
265
-
266
- def _make_attention_fc(self, in_features, l_out_features):
267
-
268
- l_fc = []
269
-
270
- l_fc.append(nn.Linear(in_features = in_features,
271
- out_features = l_out_features))
272
-
273
-
274
-
275
- return nn.Sequential(*l_fc)
276
-
277
-
278
- def _make_layer(self, nb_blocks, nb_filts, first = False):
279
- layers = []
280
- #def __init__(self, nb_filts, first = False):
281
- for i in range(nb_blocks):
282
- first = first if i == 0 else False
283
- layers.append(Residual_block(nb_filts = nb_filts,
284
- first = first))
285
- if i == 0: nb_filts[0] = nb_filts[1]
286
-
287
- return nn.Sequential(*layers)
288
-
289
- def summary(self, input_size, batch_size=-1, device="cuda", print_fn = None):
290
- if print_fn == None: printfn = print
291
- model = self
292
-
293
- def register_hook(module):
294
- def hook(module, input, output):
295
- class_name = str(module.__class__).split(".")[-1].split("'")[0]
296
- module_idx = len(summary)
297
-
298
- m_key = "%s-%i" % (class_name, module_idx + 1)
299
- summary[m_key] = OrderedDict()
300
- summary[m_key]["input_shape"] = list(input[0].size())
301
- summary[m_key]["input_shape"][0] = batch_size
302
- if isinstance(output, (list, tuple)):
303
- summary[m_key]["output_shape"] = [
304
- [-1] + list(o.size())[1:] for o in output
305
- ]
306
- else:
307
- summary[m_key]["output_shape"] = list(output.size())
308
- if len(summary[m_key]["output_shape"]) != 0:
309
- summary[m_key]["output_shape"][0] = batch_size
310
-
311
- params = 0
312
- if hasattr(module, "weight") and hasattr(module.weight, "size"):
313
- params += torch.prod(torch.LongTensor(list(module.weight.size())))
314
- summary[m_key]["trainable"] = module.weight.requires_grad
315
- if hasattr(module, "bias") and hasattr(module.bias, "size"):
316
- params += torch.prod(torch.LongTensor(list(module.bias.size())))
317
- summary[m_key]["nb_params"] = params
318
-
319
- if (
320
- not isinstance(module, nn.Sequential)
321
- and not isinstance(module, nn.ModuleList)
322
- and not (module == model)
323
- ):
324
- hooks.append(module.register_forward_hook(hook))
325
-
326
- device = device.lower()
327
- assert device in [
328
- "cuda",
329
- "cpu",
330
- ], "Input device is not valid, please specify 'cuda' or 'cpu'"
331
-
332
- if device == "cuda" and torch.cuda.is_available():
333
- dtype = torch.cuda.FloatTensor
334
- else:
335
- dtype = torch.FloatTensor
336
- if isinstance(input_size, tuple):
337
- input_size = [input_size]
338
- x = [torch.rand(2, *in_size).type(dtype) for in_size in input_size]
339
- summary = OrderedDict()
340
- hooks = []
341
- model.apply(register_hook)
342
- model(*x)
343
- for h in hooks:
344
- h.remove()
345
-
346
- print_fn("----------------------------------------------------------------")
347
- line_new = "{:>20} {:>25} {:>15}".format("Layer (type)", "Output Shape", "Param #")
348
- print_fn(line_new)
349
- print_fn("================================================================")
350
- total_params = 0
351
- total_output = 0
352
- trainable_params = 0
353
- for layer in summary:
354
- # input_shape, output_shape, trainable, nb_params
355
- line_new = "{:>20} {:>25} {:>15}".format(
356
- layer,
357
- str(summary[layer]["output_shape"]),
358
- "{0:,}".format(summary[layer]["nb_params"]),
359
- )
360
- total_params += summary[layer]["nb_params"]
361
- total_output += np.prod(summary[layer]["output_shape"])
362
- if "trainable" in summary[layer]:
363
- if summary[layer]["trainable"] == True:
364
- trainable_params += summary[layer]["nb_params"]
365
- print_fn(line_new)
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+
7
+ class SincConv(nn.Module):
8
+
9
+ @staticmethod
10
+ def to_mel(hz):
11
+ return 2595 * np.log10(1 + hz / 700)
12
+
13
+ @staticmethod
14
+ def to_hz(mel):
15
+ return 700 * (10 ** (mel / 2595) - 1)
16
+
17
+ def __init__(
18
+ self,
19
+ device,
20
+ out_channels,
21
+ kernel_size,
22
+ in_channels=1,
23
+ sample_rate=16000,
24
+ stride=1,
25
+ padding=0,
26
+ dilation=1
27
+ ):
28
+
29
+ super().__init__()
30
+
31
+ if in_channels != 1:
32
+ raise ValueError("SincConv only supports one input channel")
33
+
34
+ if kernel_size % 2 == 0:
35
+ kernel_size += 1
36
+
37
+ self.out_channels = out_channels
38
+ self.kernel_size = kernel_size
39
+ self.sample_rate = sample_rate
40
+ self.device = device
41
+
42
+ self.stride = stride
43
+ self.padding = padding
44
+ self.dilation = dilation
45
+
46
+ NFFT = 512
47
+ f = int(sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
48
+
49
+ fmel = self.to_mel(f)
50
+ filbandwidthsmel = np.linspace(min(fmel), max(fmel), out_channels + 1)
51
+ filbandwidthsf = self.to_hz(filbandwidthsmel)
52
+
53
+ self.mel = filbandwidthsf
54
+
55
+ self.hsupp = torch.arange(
56
+ -(kernel_size - 1) / 2,
57
+ (kernel_size - 1) / 2 + 1
58
+ )
59
+
60
+ self.band_pass = torch.zeros(out_channels, kernel_size)
61
+
62
+ def forward(self, x):
63
+
64
+ for i in range(len(self.mel) - 1):
65
+
66
+ fmin = self.mel[i]
67
+ fmax = self.mel[i + 1]
68
+
69
+ h_high = (2 * fmax / self.sample_rate) * np.sinc(
70
+ 2 * fmax * self.hsupp / self.sample_rate
71
+ )
72
+
73
+ h_low = (2 * fmin / self.sample_rate) * np.sinc(
74
+ 2 * fmin * self.hsupp / self.sample_rate
75
+ )
76
+
77
+ hideal = h_high - h_low
78
+
79
+ window = torch.tensor(np.hamming(self.kernel_size))
80
+ self.band_pass[i, :] = window * torch.tensor(hideal)
81
+
82
+ filters = self.band_pass.to(self.device).view(
83
+ self.out_channels, 1, self.kernel_size
84
+ )
85
+
86
+ return F.conv1d(
87
+ x,
88
+ filters,
89
+ stride=self.stride,
90
+ padding=self.padding,
91
+ dilation=self.dilation
92
+ )
93
+
94
+
95
+ class Residual_block(nn.Module):
96
+
97
+ def __init__(self, nb_filts, first=False):
98
+
99
+ super().__init__()
100
+
101
+ self.first = first
102
+
103
+ if not self.first:
104
+ self.bn1 = nn.BatchNorm1d(nb_filts[0])
105
+
106
+ self.lrelu = nn.LeakyReLU(0.3)
107
+
108
+ self.conv1 = nn.Conv1d(
109
+ nb_filts[0],
110
+ nb_filts[1],
111
+ kernel_size=3,
112
+ padding=1
113
+ )
114
+
115
+ self.bn2 = nn.BatchNorm1d(nb_filts[1])
116
+
117
+ self.conv2 = nn.Conv1d(
118
+ nb_filts[1],
119
+ nb_filts[1],
120
+ kernel_size=3,
121
+ padding=1
122
+ )
123
+
124
+ if nb_filts[0] != nb_filts[1]:
125
+
126
+ self.downsample = True
127
+
128
+ self.conv_downsample = nn.Conv1d(
129
+ nb_filts[0],
130
+ nb_filts[1],
131
+ kernel_size=1
132
+ )
133
+
134
+ else:
135
+ self.downsample = False
136
+
137
+ self.pool = nn.MaxPool1d(3)
138
+
139
+ def forward(self, x):
140
+
141
+ identity = x
142
+
143
+ if not self.first:
144
+ out = self.bn1(x)
145
+ out = self.lrelu(out)
146
+ else:
147
+ out = x
148
+
149
+ out = self.conv1(out)
150
+ out = self.bn2(out)
151
+ out = self.lrelu(out)
152
+ out = self.conv2(out)
153
+
154
+ if self.downsample:
155
+ identity = self.conv_downsample(identity)
156
+
157
+ out = out + identity
158
+ out = self.pool(out)
159
+
160
+ return out
161
+
162
+
163
+ class RawNet(nn.Module):
164
+
165
+ def __init__(self, d_args, device):
166
+
167
+ super().__init__()
168
+
169
+ self.device = device
170
+
171
+ self.sinc = SincConv(
172
+ device=device,
173
+ out_channels=d_args["filts"][0],
174
+ kernel_size=d_args["first_conv"],
175
+ in_channels=d_args["in_channels"]
176
+ )
177
+
178
+ self.first_bn = nn.BatchNorm1d(d_args["filts"][0])
179
+ self.selu = nn.SELU()
180
+
181
+ self.block0 = Residual_block(d_args["filts"][1], first=True)
182
+ self.block1 = Residual_block(d_args["filts"][1])
183
+ self.block2 = Residual_block(d_args["filts"][2])
184
+ self.block3 = Residual_block(d_args["filts"][3])
185
+
186
+ self.bn_gru = nn.BatchNorm1d(d_args["filts"][3][-1])
187
+
188
+ self.gru = nn.GRU(
189
+ input_size=d_args["filts"][3][-1],
190
+ hidden_size=d_args["gru_node"],
191
+ num_layers=d_args["nb_gru_layer"],
192
+ batch_first=True
193
+ )
194
+
195
+ self.fc1 = nn.Linear(
196
+ d_args["gru_node"],
197
+ d_args["nb_fc_node"]
198
+ )
199
+
200
+ self.fc2 = nn.Linear(
201
+ d_args["nb_fc_node"],
202
+ d_args["nb_classes"]
203
+ )
204
+
205
+ self.logsoftmax = nn.LogSoftmax(dim=1)
206
+
207
+ def forward(self, x):
208
+
209
+ batch = x.shape[0]
210
+ length = x.shape[1]
211
+
212
+ x = x.view(batch, 1, length)
213
+
214
+ x = self.sinc(x)
215
+
216
+ x = F.max_pool1d(torch.abs(x), 3)
217
+
218
+ x = self.first_bn(x)
219
+ x = self.selu(x)
220
+
221
+ x = self.block0(x)
222
+ x = self.block1(x)
223
+ x = self.block2(x)
224
+ x = self.block3(x)
225
+
226
+ x = self.bn_gru(x)
227
+ x = self.selu(x)
228
+
229
+ x = x.permute(0, 2, 1)
230
+
231
+ self.gru.flatten_parameters()
232
+
233
+ x, _ = self.gru(x)
234
+
235
+ x = x[:, -1, :]
236
+
237
+ x = self.fc1(x)
238
+ x = self.fc2(x)
239
+
240
+ return self.logsoftmax(x)