Aloukik21 commited on
Commit
b710a14
·
verified ·
1 Parent(s): b896626

Upload audio/DF_Arena_1B_V_1/conformer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. audio/DF_Arena_1B_V_1/conformer.py +284 -0
audio/DF_Arena_1B_V_1/conformer.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn, einsum
4
+ import torch.nn.functional as F
5
+ import torch
6
+ import torch.nn as nn
7
+ from torch.nn.modules.transformer import _get_clones
8
+ from torch import Tensor
9
+ from einops import rearrange
10
+ from einops.layers.torch import Rearrange
11
+
12
+ # helper functions
13
+
14
def exists(val):
    """Return True when *val* is not None."""
    return not (val is None)
16
+
17
def default(val, d):
    """Return *val* unless it is None, in which case fall back to *d*."""
    return d if val is None else val
19
+
20
def calc_same_padding(kernel_size):
    """(left, right) padding that keeps a stride-1 Conv1d output the input length.

    Odd kernels pad symmetrically; even kernels pad one less on the right.
    """
    left = kernel_size // 2
    right = left if kernel_size % 2 else left - 1
    return (left, right)
23
+
24
+ # helper classes
25
+
26
class Swish(nn.Module):
    """Swish / SiLU activation: x * sigmoid(x)."""

    def forward(self, x):
        gate = x.sigmoid()
        return gate * x
29
+
30
class GLU(nn.Module):
    """Gated linear unit: split the input in half along `dim` and gate one
    half by the sigmoid of the other."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim  # axis along which the split happens

    def forward(self, x):
        value, gate = torch.chunk(x, 2, dim=self.dim)
        return value * torch.sigmoid(gate)
38
+
39
+ class DepthWiseConv1d(nn.Module):
40
+ def __init__(self, chan_in, chan_out, kernel_size, padding):
41
+ super().__init__()
42
+ self.padding = padding
43
+ self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)
44
+
45
+ def forward(self, x):
46
+ x = F.pad(x, self.padding)
47
+ return self.conv(x)
48
+
49
+ # attention, feedforward, and conv module
50
+
51
class Scale(nn.Module):
    """Wrap a module and multiply its output by a constant factor."""

    def __init__(self, scale, fn):
        super().__init__()
        self.fn = fn
        self.scale = scale

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        return out * self.scale
59
+
60
class PreNorm(nn.Module):
    """LayerNorm the input, then run the wrapped module on the result."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)
69
+
70
class Attention(nn.Module):
    """Multi-head self-attention with learnable per-head "head tokens".

    Head Token attention: https://arxiv.org/pdf/2210.05958.pdf

    One aggregate token per head is appended to the sequence before standard
    MHSA; afterwards the head tokens are averaged back into the first token.

    NOTE(review): the qkv reshape below splits C into num_heads chunks, which
    assumes dim == heads * dim_head — confirm callers always satisfy this.
    NOTE(review): `mask` is accepted but never applied to the attention
    scores; `attn_drop` is constructed but its use is commented out below.
    """
    def __init__(self, dim, heads=8, dim_head=64, qkv_bias=False, dropout=0., proj_drop=0.):
        super().__init__()
        self.num_heads = heads
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5  # 1/sqrt(d_head) attention temperature

        # fused query/key/value projection
        self.qkv = nn.Linear(dim, inner_dim * 3, bias=qkv_bias)

        self.attn_drop = nn.Dropout(dropout)  # unused: application is commented out in forward
        self.proj = nn.Linear(inner_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # head-token machinery: project each head's mean feature back to full dim
        self.act = nn.GELU()
        self.ht_proj = nn.Linear(dim_head, dim, bias=True)
        self.ht_norm = nn.LayerNorm(dim_head)
        # learned positional embedding, one slot per head token
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_heads, dim))

    def forward(self, x, mask=None):
        # x: [B, N, C]; returns (output [B, N, C], attention weights)
        B, N, C = x.shape

        # build head tokens: mean feature per head, projected + normed + GELU
        head_pos = self.pos_embed.expand(x.shape[0], -1, -1)
        x_ = x.reshape(B, -1, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
        x_ = x_.mean(dim=2)  # now the shape is [B, h, d//h]
        x_ = self.ht_proj(x_).reshape(B, -1, self.num_heads, C // self.num_heads)
        x_ = self.act(self.ht_norm(x_)).flatten(2)
        x_ = x_ + head_pos
        x = torch.cat([x, x_], dim=1)  # append h head tokens -> [B, N+h, C]

        # normal mhsa over the extended sequence
        qkv = self.qkv(x).reshape(B, N + self.num_heads, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        # attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N + self.num_heads, C)
        x = self.proj(x)

        # merge head tokens into cls token
        # NOTE(review): assumes x[:, 0] is a class token prepended by the caller — confirm
        cls, patch, ht = torch.split(x, [1, N - 1, self.num_heads], dim=1)
        cls = cls + torch.mean(ht, dim=1, keepdim=True) + torch.mean(patch, dim=1, keepdim=True)
        x = torch.cat([cls, patch], dim=1)  # head tokens dropped; length back to N

        x = self.proj_drop(x)

        return x, attn
120
+
121
+
122
class FeedForward(nn.Module):
    """Position-wise feed-forward block: expand by `mult`, Swish, dropout,
    project back to `dim`, dropout."""

    def __init__(self, dim, mult=4, dropout=0.):
        super().__init__()
        hidden = dim * mult
        self.net = nn.Sequential(
            nn.Linear(dim, hidden),
            Swish(),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
140
+
141
class ConformerConvModule(nn.Module):
    """Conformer convolution module.

    LayerNorm -> pointwise conv + GLU -> depth-wise conv -> BatchNorm
    (skipped when causal) -> Swish -> pointwise conv -> dropout, with the
    channel axis moved to dim 1 for the conv stack.
    """

    def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.):
        super().__init__()

        inner_dim = dim * expansion_factor
        if causal:
            padding = (kernel_size - 1, 0)  # left-pad only: no future lookahead
        else:
            padding = calc_same_padding(kernel_size)

        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n c -> b c n'),
            nn.Conv1d(dim, inner_dim * 2, 1),
            GLU(dim=1),
            DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=padding),
            nn.Identity() if causal else nn.BatchNorm1d(inner_dim),
            Swish(),
            nn.Conv1d(inner_dim, dim, 1),
            Rearrange('b c n -> b n c'),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.net(x)
170
+
171
+ # Conformer Block
172
+
173
class ConformerBlock(nn.Module):
    """One Conformer block.

    half-step FF -> self-attention -> conv module -> half-step FF, each with
    a residual connection, followed by a final LayerNorm. Returns both the
    features and the attention weights of this block.
    """

    def __init__(
        self,
        *,
        dim,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.,
        ff_dropout=0.,
        conv_dropout=0.,
        conv_causal=False
    ):
        super().__init__()
        # Feed-forward halves are scaled by 0.5 (Macaron-style), all three
        # sub-modules are pre-norm wrapped.
        self.ff1 = Scale(0.5, PreNorm(dim, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)))
        self.attn = PreNorm(dim, Attention(dim=dim, dim_head=dim_head, heads=heads, dropout=attn_dropout))
        self.conv = ConformerConvModule(
            dim=dim,
            causal=conv_causal,
            expansion_factor=conv_expansion_factor,
            kernel_size=conv_kernel_size,
            dropout=conv_dropout,
        )
        self.ff2 = Scale(0.5, PreNorm(dim, FeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)))

        self.post_norm = nn.LayerNorm(dim)

    def forward(self, x, mask=None):
        x = self.ff1(x) + x
        attn_out, attn_weight = self.attn(x, mask=mask)
        x = x + attn_out
        x = self.conv(x) + x
        x = self.ff2(x) + x
        return self.post_norm(x), attn_weight
208
+
209
+ # Conformer
210
+
211
class Conformer(nn.Module):
    """A stack of `depth` ConformerBlocks applied sequentially.

    Args:
        dim: model/embedding dimension.
        depth: number of ConformerBlocks (keyword-only).
        dim_head, heads, ff_mult, conv_expansion_factor, conv_kernel_size,
        attn_dropout, ff_dropout, conv_dropout, conv_causal: forwarded to
        each ConformerBlock.

    Returns (from forward): the final feature tensor [bs, n, dim]; per-block
    attention weights are discarded.
    """

    def __init__(
        self,
        dim,
        *,
        depth,
        dim_head=64,
        heads=8,
        ff_mult=4,
        conv_expansion_factor=2,
        conv_kernel_size=31,
        attn_dropout=0.,
        ff_dropout=0.,
        conv_dropout=0.,
        conv_causal=False
    ):
        super().__init__()
        self.dim = dim
        # BUG FIX: the dropout arguments were previously accepted but never
        # forwarded to the blocks; they are now passed through.
        self.layers = nn.ModuleList([
            ConformerBlock(
                dim=dim,
                dim_head=dim_head,
                heads=heads,
                ff_mult=ff_mult,
                conv_expansion_factor=conv_expansion_factor,
                conv_kernel_size=conv_kernel_size,
                attn_dropout=attn_dropout,
                ff_dropout=ff_dropout,
                conv_dropout=conv_dropout,
                conv_causal=conv_causal,
            )
            for _ in range(depth)
        ])

    def forward(self, x):
        # BUG FIX: ConformerBlock.forward returns (features, attn_weights);
        # the old `x = block(x)` rebound x to that tuple, crashing for
        # depth > 1 and returning a tuple for depth == 1. Unpack instead.
        for block in self.layers:
            x, _ = block(x)
        return x
249
+
250
+
251
+
252
def sinusoidal_embedding(n_channels, dim):
    """Standard transformer sinusoidal positional table, shape (1, n_channels, dim).

    Even feature indices carry sin, odd indices carry cos, with wavelengths
    10000^(2*(i//2)/dim).
    """
    rows = []
    for pos in range(n_channels):
        rows.append([pos / (10000 ** (2 * (i // 2) / dim)) for i in range(dim)])
    table = torch.FloatTensor(rows)
    table[:, 0::2] = torch.sin(table[:, 0::2])
    table[:, 1::2] = torch.cos(table[:, 1::2])
    return table.unsqueeze(0)
258
+
259
class FinalConformer(nn.Module):
    """Conformer encoder stack with a learned class token and a 2-way head.

    Adds fixed sinusoidal positional encodings (max length 10000), prepends
    a learned class token, runs `n_encoders` ConformerBlocks, and classifies
    from the class-token embedding.

    Args:
        emb_size: feature dimension of the input and the model.
        heads: attention heads (emb_size must be divisible by heads).
        ffmult: feed-forward expansion factor.
        exp_fac: conv-module expansion factor.
        kernel_size: conv-module kernel size.
        n_encoders: number of stacked ConformerBlocks.

    forward(x) with x of shape [bs, time, emb_size] returns
    (logits [bs, 2], list of per-block attention weights).
    """

    def __init__(self, emb_size=128, heads=4, ffmult=4, exp_fac=2, kernel_size=16, n_encoders=1):
        super(FinalConformer, self).__init__()
        self.dim_head = int(emb_size / heads)
        self.dim = emb_size
        self.heads = heads
        self.kernel_size = kernel_size
        self.n_encoders = n_encoders
        # Fixed (non-trainable) sinusoidal table; caps sequence length at 10000.
        self.positional_emb = nn.Parameter(sinusoidal_embedding(10000, emb_size), requires_grad=False)
        # NOTE: _get_clones deep-copies one block, so all encoders start
        # from identical initial weights (they diverge during training).
        self.encoder_blocks = _get_clones(
            ConformerBlock(dim=emb_size, dim_head=self.dim_head, heads=heads,
                           ff_mult=ffmult, conv_expansion_factor=exp_fac,
                           conv_kernel_size=kernel_size),
            n_encoders)
        self.class_token = nn.Parameter(torch.rand(1, emb_size))
        self.fc5 = nn.Linear(emb_size, 2)

    def forward(self, x):  # x shape [bs, time, emb_size]
        x = x + self.positional_emb[:, :x.size(1), :]
        # Prepend the class token to every sequence in one batched op
        # (replaces the old per-sample Python vstack loop; same values).
        cls = self.class_token.unsqueeze(0).expand(x.size(0), -1, -1)
        x = torch.cat([cls, x], dim=1)  # [bs, 1+time, emb_size]
        list_attn_weight = []
        for layer in self.encoder_blocks:
            x, attn_weight = layer(x)  # [bs, 1+time, emb_size]
            list_attn_weight.append(attn_weight)
        embedding = x[:, 0, :]  # class-token embedding [bs, emb_size]
        out = self.fc5(embedding)  # [bs, 2]
        return out, list_attn_weight
284
+