Aloukik21 committed on
Commit
b896626
·
verified ·
1 Parent(s): 9aaf10c

Upload audio/DF_Arena_1B_V_1/backbone.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. audio/DF_Arena_1B_V_1/backbone.py +62 -0
audio/DF_Arena_1B_V_1/backbone.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torch import Tensor
6
+ from transformers import Wav2Vec2Model, Wav2Vec2Config
7
+ from .conformer import FinalConformer
8
+
9
class DF_Arena_1B(nn.Module):
    """Audio deepfake detector built on a wav2vec2 XLS-R-1B backbone.

    Hidden states from every SSL layer are attention-pooled over time,
    gated by a learned per-layer score, fused into a single feature map,
    and classified by a conformer head.
    """

    def __init__(self):
        super().__init__()
        # SSL backbone; all intermediate hidden states are required for fusion.
        self.ssl_model = Wav2Vec2Model(Wav2Vec2Config.from_pretrained("facebook/wav2vec2-xls-r-1b"))
        self.ssl_model.config.output_hidden_states = True
        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.selu = nn.SELU(inplace=True)
        self.fc0 = nn.Linear(1280, 1)  # 1280 for 1b, 1920 for 2b
        self.sig = nn.Sigmoid()
        self.conformer = FinalConformer(
            emb_size=1280, heads=4, ffmult=4, exp_fac=2, kernel_size=31, n_encoders=4
        )
        # Learnable attention scorer used for temporal pooling.
        self.attn_scores = nn.Linear(1280, 1, bias=False)

    def get_attenF1Dpooling(self, x):
        """Attention-pool `x` over its second (time) axis.

        Scores each frame with `attn_scores`, softmax-normalizes along
        time, and returns the weighted sum with the time axis kept as a
        singleton: (B, T, D) -> (B, 1, D).
        """
        frame_scores = self.attn_scores(x)                     # (B, T, 1)
        frame_weights = torch.softmax(frame_scores, dim=1)     # (B, T, 1)
        return (frame_weights * x).sum(dim=1, keepdim=True)    # (B, 1, D)

    def get_attenF1D(self, layerResult):
        """Pool and stack the per-layer SSL hidden states.

        Returns a pair: the attention-pooled embedding per layer,
        concatenated to (B, L, D), and the raw per-layer features stacked
        along a new layer axis.
        """
        pooled_per_layer = [self.get_attenF1Dpooling(hidden) for hidden in layerResult]
        stacked_raw = [hidden.unsqueeze(1) for hidden in layerResult]
        return torch.cat(pooled_per_layer, dim=1), torch.cat(stacked_raw, dim=1)

    def forward(self, x):
        """Score the input waveform batch; returns the conformer head output.

        NOTE(review): `x.unsqueeze(0)` prepends a batch axis — this looks
        intended for a single un-batched waveform; confirm against callers.
        """
        ssl_out = self.ssl_model(x.unsqueeze(0))
        layer_emb, fused = self.get_attenF1D(ssl_out.hidden_states)
        # One sigmoid gate per SSL layer, broadcast over the feature map.
        layer_gate = self.sig(self.fc0(layer_emb))
        layer_gate = layer_gate.view(layer_gate.shape[0], layer_gate.shape[1], layer_gate.shape[2], -1)
        fused = (fused * layer_gate).sum(dim=1)
        fused = self.selu(self.first_bn(fused.unsqueeze(dim=1)))
        output, _ = self.conformer(fused.squeeze(1))
        return output