ak36 commited on
Commit
58920bc
·
verified ·
1 Parent(s): 5619629

Upload folder using huggingface_hub

Browse files
Utils/ASR/.ipynb_checkpoints/__init__-checkpoint.py ADDED
@@ -0,0 +1 @@
 
 
1
+
Utils/ASR/.ipynb_checkpoints/layers-checkpoint.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from typing import Optional, Any
5
+ from torch import Tensor
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ import torchaudio.functional as audio_F
9
+
10
+ import random
11
+ random.seed(0)
12
+
13
+
14
+ def _get_activation_fn(activ):
15
+ if activ == 'relu':
16
+ return nn.ReLU()
17
+ elif activ == 'lrelu':
18
+ return nn.LeakyReLU(0.2)
19
+ elif activ == 'swish':
20
+ return lambda x: x*torch.sigmoid(x)
21
+ else:
22
+ raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
23
+
24
+ class LinearNorm(torch.nn.Module):
25
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
26
+ super(LinearNorm, self).__init__()
27
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
28
+
29
+ torch.nn.init.xavier_uniform_(
30
+ self.linear_layer.weight,
31
+ gain=torch.nn.init.calculate_gain(w_init_gain))
32
+
33
+ def forward(self, x):
34
+ return self.linear_layer(x)
35
+
36
+
37
+ class ConvNorm(torch.nn.Module):
38
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
39
+ padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
40
+ super(ConvNorm, self).__init__()
41
+ if padding is None:
42
+ assert(kernel_size % 2 == 1)
43
+ padding = int(dilation * (kernel_size - 1) / 2)
44
+
45
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
46
+ kernel_size=kernel_size, stride=stride,
47
+ padding=padding, dilation=dilation,
48
+ bias=bias)
49
+
50
+ torch.nn.init.xavier_uniform_(
51
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
52
+
53
+ def forward(self, signal):
54
+ conv_signal = self.conv(signal)
55
+ return conv_signal
56
+
57
+ class CausualConv(nn.Module):
58
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None):
59
+ super(CausualConv, self).__init__()
60
+ if padding is None:
61
+ assert(kernel_size % 2 == 1)
62
+ padding = int(dilation * (kernel_size - 1) / 2) * 2
63
+ else:
64
+ self.padding = padding * 2
65
+ self.conv = nn.Conv1d(in_channels, out_channels,
66
+ kernel_size=kernel_size, stride=stride,
67
+ padding=self.padding,
68
+ dilation=dilation,
69
+ bias=bias)
70
+
71
+ torch.nn.init.xavier_uniform_(
72
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
73
+
74
+ def forward(self, x):
75
+ x = self.conv(x)
76
+ x = x[:, :, :-self.padding]
77
+ return x
78
+
79
+ class CausualBlock(nn.Module):
80
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
81
+ super(CausualBlock, self).__init__()
82
+ self.blocks = nn.ModuleList([
83
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
84
+ for i in range(n_conv)])
85
+
86
+ def forward(self, x):
87
+ for block in self.blocks:
88
+ res = x
89
+ x = block(x)
90
+ x += res
91
+ return x
92
+
93
+ def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
94
+ layers = [
95
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
96
+ _get_activation_fn(activ),
97
+ nn.BatchNorm1d(hidden_dim),
98
+ nn.Dropout(p=dropout_p),
99
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
100
+ _get_activation_fn(activ),
101
+ nn.Dropout(p=dropout_p)
102
+ ]
103
+ return nn.Sequential(*layers)
104
+
105
+ class ConvBlock(nn.Module):
106
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
107
+ super().__init__()
108
+ self._n_groups = 8
109
+ self.blocks = nn.ModuleList([
110
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
111
+ for i in range(n_conv)])
112
+
113
+
114
+ def forward(self, x):
115
+ for block in self.blocks:
116
+ res = x
117
+ x = block(x)
118
+ x += res
119
+ return x
120
+
121
+ def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
122
+ layers = [
123
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
124
+ _get_activation_fn(activ),
125
+ nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
126
+ nn.Dropout(p=dropout_p),
127
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
128
+ _get_activation_fn(activ),
129
+ nn.Dropout(p=dropout_p)
130
+ ]
131
+ return nn.Sequential(*layers)
132
+
133
+ class LocationLayer(nn.Module):
134
+ def __init__(self, attention_n_filters, attention_kernel_size,
135
+ attention_dim):
136
+ super(LocationLayer, self).__init__()
137
+ padding = int((attention_kernel_size - 1) / 2)
138
+ self.location_conv = ConvNorm(2, attention_n_filters,
139
+ kernel_size=attention_kernel_size,
140
+ padding=padding, bias=False, stride=1,
141
+ dilation=1)
142
+ self.location_dense = LinearNorm(attention_n_filters, attention_dim,
143
+ bias=False, w_init_gain='tanh')
144
+
145
+ def forward(self, attention_weights_cat):
146
+ processed_attention = self.location_conv(attention_weights_cat)
147
+ processed_attention = processed_attention.transpose(1, 2)
148
+ processed_attention = self.location_dense(processed_attention)
149
+ return processed_attention
150
+
151
+
152
+ class Attention(nn.Module):
153
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
154
+ attention_location_n_filters, attention_location_kernel_size):
155
+ super(Attention, self).__init__()
156
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
157
+ bias=False, w_init_gain='tanh')
158
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
159
+ w_init_gain='tanh')
160
+ self.v = LinearNorm(attention_dim, 1, bias=False)
161
+ self.location_layer = LocationLayer(attention_location_n_filters,
162
+ attention_location_kernel_size,
163
+ attention_dim)
164
+ self.score_mask_value = -float("inf")
165
+
166
+ def get_alignment_energies(self, query, processed_memory,
167
+ attention_weights_cat):
168
+ """
169
+ PARAMS
170
+ ------
171
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
172
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
173
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
174
+ RETURNS
175
+ -------
176
+ alignment (batch, max_time)
177
+ """
178
+
179
+ processed_query = self.query_layer(query.unsqueeze(1))
180
+ processed_attention_weights = self.location_layer(attention_weights_cat)
181
+ energies = self.v(torch.tanh(
182
+ processed_query + processed_attention_weights + processed_memory))
183
+
184
+ energies = energies.squeeze(-1)
185
+ return energies
186
+
187
+ def forward(self, attention_hidden_state, memory, processed_memory,
188
+ attention_weights_cat, mask):
189
+ """
190
+ PARAMS
191
+ ------
192
+ attention_hidden_state: attention rnn last output
193
+ memory: encoder outputs
194
+ processed_memory: processed encoder outputs
195
+ attention_weights_cat: previous and cummulative attention weights
196
+ mask: binary mask for padded data
197
+ """
198
+ alignment = self.get_alignment_energies(
199
+ attention_hidden_state, processed_memory, attention_weights_cat)
200
+
201
+ if mask is not None:
202
+ alignment.data.masked_fill_(mask, self.score_mask_value)
203
+
204
+ attention_weights = F.softmax(alignment, dim=1)
205
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
206
+ attention_context = attention_context.squeeze(1)
207
+
208
+ return attention_context, attention_weights
209
+
210
+
211
+ class ForwardAttentionV2(nn.Module):
212
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
213
+ attention_location_n_filters, attention_location_kernel_size):
214
+ super(ForwardAttentionV2, self).__init__()
215
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
216
+ bias=False, w_init_gain='tanh')
217
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
218
+ w_init_gain='tanh')
219
+ self.v = LinearNorm(attention_dim, 1, bias=False)
220
+ self.location_layer = LocationLayer(attention_location_n_filters,
221
+ attention_location_kernel_size,
222
+ attention_dim)
223
+ self.score_mask_value = -float(1e20)
224
+
225
+ def get_alignment_energies(self, query, processed_memory,
226
+ attention_weights_cat):
227
+ """
228
+ PARAMS
229
+ ------
230
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
231
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
232
+ attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
233
+ RETURNS
234
+ -------
235
+ alignment (batch, max_time)
236
+ """
237
+
238
+ processed_query = self.query_layer(query.unsqueeze(1))
239
+ processed_attention_weights = self.location_layer(attention_weights_cat)
240
+ energies = self.v(torch.tanh(
241
+ processed_query + processed_attention_weights + processed_memory))
242
+
243
+ energies = energies.squeeze(-1)
244
+ return energies
245
+
246
+ def forward(self, attention_hidden_state, memory, processed_memory,
247
+ attention_weights_cat, mask, log_alpha):
248
+ """
249
+ PARAMS
250
+ ------
251
+ attention_hidden_state: attention rnn last output
252
+ memory: encoder outputs
253
+ processed_memory: processed encoder outputs
254
+ attention_weights_cat: previous and cummulative attention weights
255
+ mask: binary mask for padded data
256
+ """
257
+ log_energy = self.get_alignment_energies(
258
+ attention_hidden_state, processed_memory, attention_weights_cat)
259
+
260
+ #log_energy =
261
+
262
+ if mask is not None:
263
+ log_energy.data.masked_fill_(mask, self.score_mask_value)
264
+
265
+ #attention_weights = F.softmax(alignment, dim=1)
266
+
267
+ #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
268
+ #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
269
+
270
+ #log_total_score = log_alpha + content_score
271
+
272
+ #previous_attention_weights = attention_weights_cat[:,0,:]
273
+
274
+ log_alpha_shift_padded = []
275
+ max_time = log_energy.size(1)
276
+ for sft in range(2):
277
+ shifted = log_alpha[:,:max_time-sft]
278
+ shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value)
279
+ log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
280
+
281
+ biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2)
282
+
283
+ log_alpha_new = biased + log_energy
284
+
285
+ attention_weights = F.softmax(log_alpha_new, dim=1)
286
+
287
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
288
+ attention_context = attention_context.squeeze(1)
289
+
290
+ return attention_context, attention_weights, log_alpha_new
291
+
292
+
293
+ class PhaseShuffle2d(nn.Module):
294
+ def __init__(self, n=2):
295
+ super(PhaseShuffle2d, self).__init__()
296
+ self.n = n
297
+ self.random = random.Random(1)
298
+
299
+ def forward(self, x, move=None):
300
+ # x.size = (B, C, M, L)
301
+ if move is None:
302
+ move = self.random.randint(-self.n, self.n)
303
+
304
+ if move == 0:
305
+ return x
306
+ else:
307
+ left = x[:, :, :, :move]
308
+ right = x[:, :, :, move:]
309
+ shuffled = torch.cat([right, left], dim=3)
310
+ return shuffled
311
+
312
+ class PhaseShuffle1d(nn.Module):
313
+ def __init__(self, n=2):
314
+ super(PhaseShuffle1d, self).__init__()
315
+ self.n = n
316
+ self.random = random.Random(1)
317
+
318
+ def forward(self, x, move=None):
319
+ # x.size = (B, C, M, L)
320
+ if move is None:
321
+ move = self.random.randint(-self.n, self.n)
322
+
323
+ if move == 0:
324
+ return x
325
+ else:
326
+ left = x[:, :, :move]
327
+ right = x[:, :, move:]
328
+ shuffled = torch.cat([right, left], dim=2)
329
+
330
+ return shuffled
331
+
332
+ class MFCC(nn.Module):
333
+ def __init__(self, n_mfcc=40, n_mels=80):
334
+ super(MFCC, self).__init__()
335
+ self.n_mfcc = n_mfcc
336
+ self.n_mels = n_mels
337
+ self.norm = 'ortho'
338
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
339
+ self.register_buffer('dct_mat', dct_mat)
340
+
341
+ def forward(self, mel_specgram):
342
+ if len(mel_specgram.shape) == 2:
343
+ mel_specgram = mel_specgram.unsqueeze(0)
344
+ unsqueezed = True
345
+ else:
346
+ unsqueezed = False
347
+ # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
348
+ # -> (channel, time, n_mfcc).tranpose(...)
349
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
350
+
351
+ # unpack batch
352
+ if unsqueezed:
353
+ mfcc = mfcc.squeeze(0)
354
+ return mfcc
Utils/ASR/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+
Utils/ASR/config.yml ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ log_dir: "logs/20201006"
2
+ save_freq: 5
3
+ device: "cuda"
4
+ epochs: 180
5
+ batch_size: 64
6
+ pretrained_model: ""
7
+ train_data: "ASRDataset/train_list.txt"
8
+ val_data: "ASRDataset/val_list.txt"
9
+
10
+ dataset_params:
11
+ data_augmentation: false
12
+
13
+ preprocess_parasm:
14
+ sr: 24000
15
+ spect_params:
16
+ n_fft: 2048
17
+ win_length: 1200
18
+ hop_length: 300
19
+ mel_params:
20
+ n_mels: 80
21
+
22
+ model_params:
23
+ input_dim: 80
24
+ hidden_dim: 256
25
+ n_token: 178
26
+ token_embedding_dim: 512
27
+
28
+ optimizer_params:
29
+ lr: 0.0005
Utils/ASR/epoch_00080.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fedd55a1234b0c56e1e8b509c74edf3a5e2f27106a66038a4a946047a775bd6c
3
+ size 94552811
Utils/ASR/layers.py ADDED
@@ -0,0 +1,354 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from typing import Optional, Any
5
+ from torch import Tensor
6
+ import torch.nn.functional as F
7
+ import torchaudio
8
+ import torchaudio.functional as audio_F
9
+
10
+ import random
11
+ random.seed(0)
12
+
13
+
14
+ def _get_activation_fn(activ):
15
+ if activ == 'relu':
16
+ return nn.ReLU()
17
+ elif activ == 'lrelu':
18
+ return nn.LeakyReLU(0.2)
19
+ elif activ == 'swish':
20
+ return lambda x: x*torch.sigmoid(x)
21
+ else:
22
+ raise RuntimeError('Unexpected activ type %s, expected [relu, lrelu, swish]' % activ)
23
+
24
+ class LinearNorm(torch.nn.Module):
25
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
26
+ super(LinearNorm, self).__init__()
27
+ self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)
28
+
29
+ torch.nn.init.xavier_uniform_(
30
+ self.linear_layer.weight,
31
+ gain=torch.nn.init.calculate_gain(w_init_gain))
32
+
33
+ def forward(self, x):
34
+ return self.linear_layer(x)
35
+
36
+
37
+ class ConvNorm(torch.nn.Module):
38
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
39
+ padding=None, dilation=1, bias=True, w_init_gain='linear', param=None):
40
+ super(ConvNorm, self).__init__()
41
+ if padding is None:
42
+ assert(kernel_size % 2 == 1)
43
+ padding = int(dilation * (kernel_size - 1) / 2)
44
+
45
+ self.conv = torch.nn.Conv1d(in_channels, out_channels,
46
+ kernel_size=kernel_size, stride=stride,
47
+ padding=padding, dilation=dilation,
48
+ bias=bias)
49
+
50
+ torch.nn.init.xavier_uniform_(
51
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
52
+
53
+ def forward(self, signal):
54
+ conv_signal = self.conv(signal)
55
+ return conv_signal
56
+
57
+ class CausualConv(nn.Module):
58
+ def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=1, dilation=1, bias=True, w_init_gain='linear', param=None):
59
+ super(CausualConv, self).__init__()
60
+ if padding is None:
61
+ assert(kernel_size % 2 == 1)
62
+ padding = int(dilation * (kernel_size - 1) / 2) * 2
63
+ else:
64
+ self.padding = padding * 2
65
+ self.conv = nn.Conv1d(in_channels, out_channels,
66
+ kernel_size=kernel_size, stride=stride,
67
+ padding=self.padding,
68
+ dilation=dilation,
69
+ bias=bias)
70
+
71
+ torch.nn.init.xavier_uniform_(
72
+ self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain, param=param))
73
+
74
+ def forward(self, x):
75
+ x = self.conv(x)
76
+ x = x[:, :, :-self.padding]
77
+ return x
78
+
79
+ class CausualBlock(nn.Module):
80
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='lrelu'):
81
+ super(CausualBlock, self).__init__()
82
+ self.blocks = nn.ModuleList([
83
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
84
+ for i in range(n_conv)])
85
+
86
+ def forward(self, x):
87
+ for block in self.blocks:
88
+ res = x
89
+ x = block(x)
90
+ x += res
91
+ return x
92
+
93
+ def _get_conv(self, hidden_dim, dilation, activ='lrelu', dropout_p=0.2):
94
+ layers = [
95
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
96
+ _get_activation_fn(activ),
97
+ nn.BatchNorm1d(hidden_dim),
98
+ nn.Dropout(p=dropout_p),
99
+ CausualConv(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
100
+ _get_activation_fn(activ),
101
+ nn.Dropout(p=dropout_p)
102
+ ]
103
+ return nn.Sequential(*layers)
104
+
105
+ class ConvBlock(nn.Module):
106
+ def __init__(self, hidden_dim, n_conv=3, dropout_p=0.2, activ='relu'):
107
+ super().__init__()
108
+ self._n_groups = 8
109
+ self.blocks = nn.ModuleList([
110
+ self._get_conv(hidden_dim, dilation=3**i, activ=activ, dropout_p=dropout_p)
111
+ for i in range(n_conv)])
112
+
113
+
114
+ def forward(self, x):
115
+ for block in self.blocks:
116
+ res = x
117
+ x = block(x)
118
+ x += res
119
+ return x
120
+
121
+ def _get_conv(self, hidden_dim, dilation, activ='relu', dropout_p=0.2):
122
+ layers = [
123
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=dilation, dilation=dilation),
124
+ _get_activation_fn(activ),
125
+ nn.GroupNorm(num_groups=self._n_groups, num_channels=hidden_dim),
126
+ nn.Dropout(p=dropout_p),
127
+ ConvNorm(hidden_dim, hidden_dim, kernel_size=3, padding=1, dilation=1),
128
+ _get_activation_fn(activ),
129
+ nn.Dropout(p=dropout_p)
130
+ ]
131
+ return nn.Sequential(*layers)
132
+
133
+ class LocationLayer(nn.Module):
134
+ def __init__(self, attention_n_filters, attention_kernel_size,
135
+ attention_dim):
136
+ super(LocationLayer, self).__init__()
137
+ padding = int((attention_kernel_size - 1) / 2)
138
+ self.location_conv = ConvNorm(2, attention_n_filters,
139
+ kernel_size=attention_kernel_size,
140
+ padding=padding, bias=False, stride=1,
141
+ dilation=1)
142
+ self.location_dense = LinearNorm(attention_n_filters, attention_dim,
143
+ bias=False, w_init_gain='tanh')
144
+
145
+ def forward(self, attention_weights_cat):
146
+ processed_attention = self.location_conv(attention_weights_cat)
147
+ processed_attention = processed_attention.transpose(1, 2)
148
+ processed_attention = self.location_dense(processed_attention)
149
+ return processed_attention
150
+
151
+
152
+ class Attention(nn.Module):
153
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
154
+ attention_location_n_filters, attention_location_kernel_size):
155
+ super(Attention, self).__init__()
156
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
157
+ bias=False, w_init_gain='tanh')
158
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
159
+ w_init_gain='tanh')
160
+ self.v = LinearNorm(attention_dim, 1, bias=False)
161
+ self.location_layer = LocationLayer(attention_location_n_filters,
162
+ attention_location_kernel_size,
163
+ attention_dim)
164
+ self.score_mask_value = -float("inf")
165
+
166
+ def get_alignment_energies(self, query, processed_memory,
167
+ attention_weights_cat):
168
+ """
169
+ PARAMS
170
+ ------
171
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
172
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
173
+ attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)
174
+ RETURNS
175
+ -------
176
+ alignment (batch, max_time)
177
+ """
178
+
179
+ processed_query = self.query_layer(query.unsqueeze(1))
180
+ processed_attention_weights = self.location_layer(attention_weights_cat)
181
+ energies = self.v(torch.tanh(
182
+ processed_query + processed_attention_weights + processed_memory))
183
+
184
+ energies = energies.squeeze(-1)
185
+ return energies
186
+
187
+ def forward(self, attention_hidden_state, memory, processed_memory,
188
+ attention_weights_cat, mask):
189
+ """
190
+ PARAMS
191
+ ------
192
+ attention_hidden_state: attention rnn last output
193
+ memory: encoder outputs
194
+ processed_memory: processed encoder outputs
195
+ attention_weights_cat: previous and cummulative attention weights
196
+ mask: binary mask for padded data
197
+ """
198
+ alignment = self.get_alignment_energies(
199
+ attention_hidden_state, processed_memory, attention_weights_cat)
200
+
201
+ if mask is not None:
202
+ alignment.data.masked_fill_(mask, self.score_mask_value)
203
+
204
+ attention_weights = F.softmax(alignment, dim=1)
205
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
206
+ attention_context = attention_context.squeeze(1)
207
+
208
+ return attention_context, attention_weights
209
+
210
+
211
+ class ForwardAttentionV2(nn.Module):
212
+ def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
213
+ attention_location_n_filters, attention_location_kernel_size):
214
+ super(ForwardAttentionV2, self).__init__()
215
+ self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
216
+ bias=False, w_init_gain='tanh')
217
+ self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
218
+ w_init_gain='tanh')
219
+ self.v = LinearNorm(attention_dim, 1, bias=False)
220
+ self.location_layer = LocationLayer(attention_location_n_filters,
221
+ attention_location_kernel_size,
222
+ attention_dim)
223
+ self.score_mask_value = -float(1e20)
224
+
225
+ def get_alignment_energies(self, query, processed_memory,
226
+ attention_weights_cat):
227
+ """
228
+ PARAMS
229
+ ------
230
+ query: decoder output (batch, n_mel_channels * n_frames_per_step)
231
+ processed_memory: processed encoder outputs (B, T_in, attention_dim)
232
+ attention_weights_cat: prev. and cumulative att weights (B, 2, max_time)
233
+ RETURNS
234
+ -------
235
+ alignment (batch, max_time)
236
+ """
237
+
238
+ processed_query = self.query_layer(query.unsqueeze(1))
239
+ processed_attention_weights = self.location_layer(attention_weights_cat)
240
+ energies = self.v(torch.tanh(
241
+ processed_query + processed_attention_weights + processed_memory))
242
+
243
+ energies = energies.squeeze(-1)
244
+ return energies
245
+
246
+ def forward(self, attention_hidden_state, memory, processed_memory,
247
+ attention_weights_cat, mask, log_alpha):
248
+ """
249
+ PARAMS
250
+ ------
251
+ attention_hidden_state: attention rnn last output
252
+ memory: encoder outputs
253
+ processed_memory: processed encoder outputs
254
+ attention_weights_cat: previous and cummulative attention weights
255
+ mask: binary mask for padded data
256
+ """
257
+ log_energy = self.get_alignment_energies(
258
+ attention_hidden_state, processed_memory, attention_weights_cat)
259
+
260
+ #log_energy =
261
+
262
+ if mask is not None:
263
+ log_energy.data.masked_fill_(mask, self.score_mask_value)
264
+
265
+ #attention_weights = F.softmax(alignment, dim=1)
266
+
267
+ #content_score = log_energy.unsqueeze(1) #[B, MAX_TIME] -> [B, 1, MAX_TIME]
268
+ #log_alpha = log_alpha.unsqueeze(2) #[B, MAX_TIME] -> [B, MAX_TIME, 1]
269
+
270
+ #log_total_score = log_alpha + content_score
271
+
272
+ #previous_attention_weights = attention_weights_cat[:,0,:]
273
+
274
+ log_alpha_shift_padded = []
275
+ max_time = log_energy.size(1)
276
+ for sft in range(2):
277
+ shifted = log_alpha[:,:max_time-sft]
278
+ shift_padded = F.pad(shifted, (sft,0), 'constant', self.score_mask_value)
279
+ log_alpha_shift_padded.append(shift_padded.unsqueeze(2))
280
+
281
+ biased = torch.logsumexp(torch.cat(log_alpha_shift_padded,2), 2)
282
+
283
+ log_alpha_new = biased + log_energy
284
+
285
+ attention_weights = F.softmax(log_alpha_new, dim=1)
286
+
287
+ attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
288
+ attention_context = attention_context.squeeze(1)
289
+
290
+ return attention_context, attention_weights, log_alpha_new
291
+
292
+
293
+ class PhaseShuffle2d(nn.Module):
294
+ def __init__(self, n=2):
295
+ super(PhaseShuffle2d, self).__init__()
296
+ self.n = n
297
+ self.random = random.Random(1)
298
+
299
+ def forward(self, x, move=None):
300
+ # x.size = (B, C, M, L)
301
+ if move is None:
302
+ move = self.random.randint(-self.n, self.n)
303
+
304
+ if move == 0:
305
+ return x
306
+ else:
307
+ left = x[:, :, :, :move]
308
+ right = x[:, :, :, move:]
309
+ shuffled = torch.cat([right, left], dim=3)
310
+ return shuffled
311
+
312
+ class PhaseShuffle1d(nn.Module):
313
+ def __init__(self, n=2):
314
+ super(PhaseShuffle1d, self).__init__()
315
+ self.n = n
316
+ self.random = random.Random(1)
317
+
318
+ def forward(self, x, move=None):
319
+ # x.size = (B, C, M, L)
320
+ if move is None:
321
+ move = self.random.randint(-self.n, self.n)
322
+
323
+ if move == 0:
324
+ return x
325
+ else:
326
+ left = x[:, :, :move]
327
+ right = x[:, :, move:]
328
+ shuffled = torch.cat([right, left], dim=2)
329
+
330
+ return shuffled
331
+
332
+ class MFCC(nn.Module):
333
+ def __init__(self, n_mfcc=40, n_mels=80):
334
+ super(MFCC, self).__init__()
335
+ self.n_mfcc = n_mfcc
336
+ self.n_mels = n_mels
337
+ self.norm = 'ortho'
338
+ dct_mat = audio_F.create_dct(self.n_mfcc, self.n_mels, self.norm)
339
+ self.register_buffer('dct_mat', dct_mat)
340
+
341
+ def forward(self, mel_specgram):
342
+ if len(mel_specgram.shape) == 2:
343
+ mel_specgram = mel_specgram.unsqueeze(0)
344
+ unsqueezed = True
345
+ else:
346
+ unsqueezed = False
347
+ # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
348
+ # -> (channel, time, n_mfcc).tranpose(...)
349
+ mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
350
+
351
+ # unpack batch
352
+ if unsqueezed:
353
+ mfcc = mfcc.squeeze(0)
354
+ return mfcc
Utils/ASR/models.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import TransformerEncoder
5
+ import torch.nn.functional as F
6
+ from .layers import MFCC, Attention, LinearNorm, ConvNorm, ConvBlock
7
+
8
+ class ASRCNN(nn.Module):
9
+ def __init__(self,
10
+ input_dim=80,
11
+ hidden_dim=256,
12
+ n_token=35,
13
+ n_layers=6,
14
+ token_embedding_dim=256,
15
+
16
+ ):
17
+ super().__init__()
18
+ self.n_token = n_token
19
+ self.n_down = 1
20
+ self.to_mfcc = MFCC()
21
+ self.init_cnn = ConvNorm(input_dim//2, hidden_dim, kernel_size=7, padding=3, stride=2)
22
+ self.cnns = nn.Sequential(
23
+ *[nn.Sequential(
24
+ ConvBlock(hidden_dim),
25
+ nn.GroupNorm(num_groups=1, num_channels=hidden_dim)
26
+ ) for n in range(n_layers)])
27
+ self.projection = ConvNorm(hidden_dim, hidden_dim // 2)
28
+ self.ctc_linear = nn.Sequential(
29
+ LinearNorm(hidden_dim//2, hidden_dim),
30
+ nn.ReLU(),
31
+ LinearNorm(hidden_dim, n_token))
32
+ self.asr_s2s = ASRS2S(
33
+ embedding_dim=token_embedding_dim,
34
+ hidden_dim=hidden_dim//2,
35
+ n_token=n_token)
36
+
37
+ def forward(self, x, src_key_padding_mask=None, text_input=None):
38
+ x = self.to_mfcc(x)
39
+ x = self.init_cnn(x)
40
+ x = self.cnns(x)
41
+ x = self.projection(x)
42
+ x = x.transpose(1, 2)
43
+ ctc_logit = self.ctc_linear(x)
44
+ if text_input is not None:
45
+ _, s2s_logit, s2s_attn = self.asr_s2s(x, src_key_padding_mask, text_input)
46
+ return ctc_logit, s2s_logit, s2s_attn
47
+ else:
48
+ return ctc_logit
49
+
50
+ def get_feature(self, x):
51
+ x = self.to_mfcc(x.squeeze(1))
52
+ x = self.init_cnn(x)
53
+ x = self.cnns(x)
54
+ x = self.projection(x)
55
+ return x
56
+
57
+ def length_to_mask(self, lengths):
58
+ mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
59
+ mask = torch.gt(mask+1, lengths.unsqueeze(1)).to(lengths.device)
60
+ return mask
61
+
62
+ def get_future_mask(self, out_length, unmask_future_steps=0):
63
+ """
64
+ Args:
65
+ out_length (int): returned mask shape is (out_length, out_length).
66
+ unmask_futre_steps (int): unmasking future step size.
67
+ Return:
68
+ mask (torch.BoolTensor): mask future timesteps mask[i, j] = True if i > j + unmask_future_steps else False
69
+ """
70
+ index_tensor = torch.arange(out_length).unsqueeze(0).expand(out_length, -1)
71
+ mask = torch.gt(index_tensor, index_tensor.T + unmask_future_steps)
72
+ return mask
73
+
74
+ class ASRS2S(nn.Module):
75
+ def __init__(self,
76
+ embedding_dim=256,
77
+ hidden_dim=512,
78
+ n_location_filters=32,
79
+ location_kernel_size=63,
80
+ n_token=40):
81
+ super(ASRS2S, self).__init__()
82
+ self.embedding = nn.Embedding(n_token, embedding_dim)
83
+ val_range = math.sqrt(6 / hidden_dim)
84
+ self.embedding.weight.data.uniform_(-val_range, val_range)
85
+
86
+ self.decoder_rnn_dim = hidden_dim
87
+ self.project_to_n_symbols = nn.Linear(self.decoder_rnn_dim, n_token)
88
+ self.attention_layer = Attention(
89
+ self.decoder_rnn_dim,
90
+ hidden_dim,
91
+ hidden_dim,
92
+ n_location_filters,
93
+ location_kernel_size
94
+ )
95
+ self.decoder_rnn = nn.LSTMCell(self.decoder_rnn_dim + embedding_dim, self.decoder_rnn_dim)
96
+ self.project_to_hidden = nn.Sequential(
97
+ LinearNorm(self.decoder_rnn_dim * 2, hidden_dim),
98
+ nn.Tanh())
99
+ self.sos = 1
100
+ self.eos = 2
101
+
102
+ def initialize_decoder_states(self, memory, mask):
103
+ """
104
+ moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
105
+ """
106
+ B, L, H = memory.shape
107
+ self.decoder_hidden = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
108
+ self.decoder_cell = torch.zeros((B, self.decoder_rnn_dim)).type_as(memory)
109
+ self.attention_weights = torch.zeros((B, L)).type_as(memory)
110
+ self.attention_weights_cum = torch.zeros((B, L)).type_as(memory)
111
+ self.attention_context = torch.zeros((B, H)).type_as(memory)
112
+ self.memory = memory
113
+ self.processed_memory = self.attention_layer.memory_layer(memory)
114
+ self.mask = mask
115
+ self.unk_index = 3
116
+ self.random_mask = 0.1
117
+
118
+ def forward(self, memory, memory_mask, text_input):
119
+ """
120
+ moemory.shape = (B, L, H) = (Batchsize, Maxtimestep, Hiddendim)
121
+ moemory_mask.shape = (B, L, )
122
+ texts_input.shape = (B, T)
123
+ """
124
+ self.initialize_decoder_states(memory, memory_mask)
125
+ # text random mask
126
+ random_mask = (torch.rand(text_input.shape) < self.random_mask).to(text_input.device)
127
+ _text_input = text_input.clone()
128
+ _text_input.masked_fill_(random_mask, self.unk_index)
129
+ decoder_inputs = self.embedding(_text_input).transpose(0, 1) # -> [T, B, channel]
130
+ start_embedding = self.embedding(
131
+ torch.LongTensor([self.sos]*decoder_inputs.size(1)).to(decoder_inputs.device))
132
+ decoder_inputs = torch.cat((start_embedding.unsqueeze(0), decoder_inputs), dim=0)
133
+
134
+ hidden_outputs, logit_outputs, alignments = [], [], []
135
+ while len(hidden_outputs) < decoder_inputs.size(0):
136
+
137
+ decoder_input = decoder_inputs[len(hidden_outputs)]
138
+ hidden, logit, attention_weights = self.decode(decoder_input)
139
+ hidden_outputs += [hidden]
140
+ logit_outputs += [logit]
141
+ alignments += [attention_weights]
142
+
143
+ hidden_outputs, logit_outputs, alignments = \
144
+ self.parse_decoder_outputs(
145
+ hidden_outputs, logit_outputs, alignments)
146
+
147
+ return hidden_outputs, logit_outputs, alignments
148
+
149
+
150
+ def decode(self, decoder_input):
151
+
152
+ cell_input = torch.cat((decoder_input, self.attention_context), -1)
153
+ self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
154
+ cell_input,
155
+ (self.decoder_hidden, self.decoder_cell))
156
+
157
+ attention_weights_cat = torch.cat(
158
+ (self.attention_weights.unsqueeze(1),
159
+ self.attention_weights_cum.unsqueeze(1)),dim=1)
160
+
161
+ self.attention_context, self.attention_weights = self.attention_layer(
162
+ self.decoder_hidden,
163
+ self.memory,
164
+ self.processed_memory,
165
+ attention_weights_cat,
166
+ self.mask)
167
+
168
+ self.attention_weights_cum += self.attention_weights
169
+
170
+ hidden_and_context = torch.cat((self.decoder_hidden, self.attention_context), -1)
171
+ hidden = self.project_to_hidden(hidden_and_context)
172
+
173
+ # dropout to increasing g
174
+ logit = self.project_to_n_symbols(F.dropout(hidden, 0.5, self.training))
175
+
176
+ return hidden, logit, self.attention_weights
177
+
178
+ def parse_decoder_outputs(self, hidden, logit, alignments):
179
+
180
+ # -> [B, T_out + 1, max_time]
181
+ alignments = torch.stack(alignments).transpose(0,1)
182
+ # [T_out + 1, B, n_symbols] -> [B, T_out + 1, n_symbols]
183
+ logit = torch.stack(logit).transpose(0, 1).contiguous()
184
+ hidden = torch.stack(hidden).transpose(0, 1).contiguous()
185
+
186
+ return hidden, logit, alignments