Beijuka committed
Commit 047a111 · verified · Parent: 899d7fc

Upload model.py

Files changed (1): model.py (+174, -0)
model.py ADDED
@@ -0,0 +1,174 @@
## This script is based on https://github.com/TaoRuijie/ECAPA-TDNN/blob/main/model.py
## I made some changes to the original code to train a binary classifier.

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import torchaudio
from torchaudio.transforms import Resample

from huggingface_hub import PyTorchModelHubMixin

class SEModule(nn.Module):
    """Squeeze-and-Excitation block: rescales each channel by a learned gate."""

    def __init__(self, channels: int, bottleneck: int = 128) -> None:
        super(SEModule, self).__init__()
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, bottleneck, kernel_size=1, padding=0),
            nn.ReLU(),
            # nn.BatchNorm1d(bottleneck), # I removed this layer
            nn.Conv1d(bottleneck, channels, kernel_size=1, padding=0),
            nn.Sigmoid(),
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = self.se(input)
        return input * x

class Bottle2neck(nn.Module):
    """Res2Net-style bottleneck with dilated multi-scale convolutions and an SE module."""

    def __init__(self, inplanes: int, planes: int, kernel_size: int, dilation: int, scale: int = 8) -> None:
        super(Bottle2neck, self).__init__()
        width = int(math.floor(planes / scale))
        self.conv1 = nn.Conv1d(inplanes, width * scale, kernel_size=1)
        self.bn1 = nn.BatchNorm1d(width * scale)
        self.nums = scale - 1
        convs = []
        bns = []
        # "Same" padding for the dilated convolution
        num_pad = math.floor(kernel_size / 2) * dilation
        for i in range(self.nums):
            convs.append(nn.Conv1d(width, width, kernel_size=kernel_size, dilation=dilation, padding=num_pad))
            bns.append(nn.BatchNorm1d(width))
        self.convs = nn.ModuleList(convs)
        self.bns = nn.ModuleList(bns)
        self.conv3 = nn.Conv1d(width * scale, planes, kernel_size=1)
        self.bn3 = nn.BatchNorm1d(planes)
        self.relu = nn.ReLU()
        self.width = width
        self.se = SEModule(planes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        out = self.conv1(x)
        out = self.relu(out)
        out = self.bn1(out)

        # Split into `scale` groups and process them hierarchically (Res2Net)
        spx = torch.split(out, self.width, 1)
        for i in range(self.nums):
            if i == 0:
                sp = spx[i]
            else:
                sp = sp + spx[i]
            sp = self.convs[i](sp)
            sp = self.relu(sp)
            sp = self.bns[i](sp)
            if i == 0:
                out = sp
            else:
                out = torch.cat((out, sp), 1)
        out = torch.cat((out, spx[self.nums]), 1)

        out = self.conv3(out)
        out = self.relu(out)
        out = self.bn3(out)

        out = self.se(out)
        out += residual
        return out

class ECAPA_gender(nn.Module, PyTorchModelHubMixin):
    """ECAPA-TDNN backbone with a 2-way (male/female) classification head."""

    def __init__(self, C: int = 1024):
        super(ECAPA_gender, self).__init__()
        self.C = C
        self.conv1 = nn.Conv1d(80, C, kernel_size=5, stride=1, padding=2)
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(C)
        self.layer1 = Bottle2neck(C, C, kernel_size=3, dilation=2, scale=8)
        self.layer2 = Bottle2neck(C, C, kernel_size=3, dilation=3, scale=8)
        self.layer3 = Bottle2neck(C, C, kernel_size=3, dilation=4, scale=8)
        # I fixed the shape of the MFA layer output to be close to the setting in the ECAPA paper.
        self.layer4 = nn.Conv1d(3 * C, 1536, kernel_size=1)
        self.attention = nn.Sequential(
            nn.Conv1d(4608, 256, kernel_size=1),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Tanh(),  # I added this layer
            nn.Conv1d(256, 1536, kernel_size=1),
            nn.Softmax(dim=2),
        )
        self.bn5 = nn.BatchNorm1d(3072)
        self.fc6 = nn.Linear(3072, 192)
        self.bn6 = nn.BatchNorm1d(192)
        self.fc7 = nn.Linear(192, 2)
        self.pred2gender = {0: 'male', 1: 'female'}

    def logtorchfbank(self, x: torch.Tensor) -> torch.Tensor:
        """Compute mean-normalized log mel-filterbank features from a 16 kHz waveform."""
        # Pre-emphasis: y[t] = x[t] - 0.97 * x[t-1], implemented as a conv1d
        flipped_filter = torch.FloatTensor([-0.97, 1.]).unsqueeze(0).unsqueeze(0).to(x.device)
        x = x.unsqueeze(1)
        x = F.pad(x, (1, 0), 'reflect')
        x = F.conv1d(x, flipped_filter).squeeze(1)

        # Mel spectrogram (the transform is built per call, as in the original code)
        x = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_fft=512, win_length=400,
                                                 hop_length=160, f_min=20, f_max=7600,
                                                 window_fn=torch.hamming_window, n_mels=80).to(x.device)(x) + 1e-6

        # Log and mean-normalize along the time axis
        x = x.log()
        x = x - torch.mean(x, dim=-1, keepdim=True)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.logtorchfbank(x)

        x = self.conv1(x)
        x = self.relu(x)
        x = self.bn1(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x + x1)
        x3 = self.layer3(x + x1 + x2)

        x = self.layer4(torch.cat((x1, x2, x3), dim=1))
        x = self.relu(x)

        t = x.size()[-1]

        # Attentive statistics pooling: attention weights are computed from the
        # frame-level features concatenated with their global mean and std
        global_x = torch.cat((x,
                              torch.mean(x, dim=2, keepdim=True).repeat(1, 1, t),
                              torch.sqrt(torch.var(x, dim=2, keepdim=True).clamp(min=1e-4)).repeat(1, 1, t)),
                             dim=1)

        w = self.attention(global_x)

        mu = torch.sum(x * w, dim=2)
        sg = torch.sqrt((torch.sum((x ** 2) * w, dim=2) - mu ** 2).clamp(min=1e-4))

        x = torch.cat((mu, sg), 1)
        x = self.bn5(x)
        x = self.fc6(x)
        x = self.bn6(x)
        x = self.relu(x)
        x = self.fc7(x)

        return x

    def load_audio(self, path: str) -> torch.Tensor:
        audio, sr = torchaudio.load(path)
        if sr != 16000:
            resampler = Resample(orig_freq=sr, new_freq=16000)
            audio = resampler(audio)
        return audio.mean(dim=0, keepdim=True)  # Convert to mono if stereo

    def predict(self, path: str, device: torch.device) -> str:
        # The argument is a file path, not a waveform: load_audio reads and
        # resamples it. The return value is the predicted gender label string.
        audio = self.load_audio(path)
        audio = audio.to(device)
        self.eval()

        with torch.no_grad():
            output = self.forward(audio)
            _, pred = output.max(1)
            return self.pred2gender[pred.item()]
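
For context (not part of the commit itself), a minimal usage sketch: because ECAPA_gender mixes in PyTorchModelHubMixin, a checkpoint pushed to the Hub can be loaded with from_pretrained. The repository id "user/repo" and the file "sample.wav" below are placeholders, not values taken from this commit.

import torch
from model import ECAPA_gender

# "user/repo" is a placeholder: substitute the Hub repository that hosts
# this checkpoint. "sample.wav" is any audio file readable by torchaudio.
model = ECAPA_gender.from_pretrained("user/repo")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# predict() loads the file, resamples to 16 kHz, downmixes to mono,
# and returns the label string "male" or "female".
print(model.predict("sample.wav", device=device))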