1NEYRON1 commited on
Commit
c8282f5
·
verified ·
1 Parent(s): c5de25b

Update modeling_mosnet.py

Browse files
Files changed (1) hide show
  1. modeling_mosnet.py +75 -83
modeling_mosnet.py CHANGED
@@ -28,84 +28,6 @@ class TimeDistributed(nn.Module):
28
  return output
29
 
30
 
31
- class CnnBlstmMbnet2(nn.Module):
32
- def __init__(self, dropout: float = 0.3) -> None:
33
- super().__init__()
34
- self.conv1 = nn.Sequential(
35
- nn.Conv2d(1, 16, (3, 3), (1, 1), padding=1),
36
- nn.ReLU(),
37
- nn.Conv2d(16, 16, (3, 3), (1, 1), 1),
38
- nn.ReLU(),
39
- nn.Conv2d(16, 16, (3, 3), (1, 3), 1),
40
- nn.ReLU(),
41
- nn.BatchNorm2d(16),
42
- nn.Dropout(dropout),
43
- )
44
- self.conv2 = nn.Sequential(
45
- nn.Conv2d(16, 32, (3, 3), (1, 1), 1),
46
- nn.ReLU(),
47
- nn.Conv2d(32, 32, (3, 3), (1, 1), 1),
48
- nn.ReLU(),
49
- nn.Conv2d(32, 32, (3, 3), (1, 3), 1),
50
- nn.ReLU(),
51
- nn.BatchNorm2d(32),
52
- nn.Dropout(dropout),
53
- )
54
- self.conv3 = nn.Sequential(
55
- nn.Conv2d(32, 64, (3, 3), (1, 1), 1),
56
- nn.ReLU(),
57
- nn.Conv2d(64, 64, (3, 3), (1, 1), 1),
58
- nn.ReLU(),
59
- nn.Conv2d(64, 64, (3, 3), (1, 3), 1),
60
- nn.ReLU(),
61
- nn.BatchNorm2d(64),
62
- nn.Dropout(dropout),
63
- )
64
- self.conv4 = nn.Sequential(
65
- nn.Conv2d(64, 128, (3, 3), (1, 1), 1),
66
- nn.ReLU(),
67
- nn.Conv2d(128, 128, (3, 3), (1, 1), 1),
68
- nn.ReLU(),
69
- nn.Conv2d(128, 128, (3, 3), (1, 3), 1),
70
- nn.ReLU(),
71
- nn.BatchNorm2d(128),
72
- nn.Dropout(dropout),
73
- )
74
- self.blstm1 = nn.LSTM(512, 128, bidirectional=True, batch_first=True)
75
- self.droupout = nn.Dropout(dropout)
76
- self.flatten = TimeDistributed(nn.Flatten(), batch_first=True)
77
- self.dense1 = nn.Sequential(
78
- TimeDistributed(
79
- nn.Sequential(
80
- nn.Linear(256, 128),
81
- nn.ReLU(),
82
- ),
83
- batch_first=True,
84
- ),
85
- nn.Dropout(dropout),
86
- )
87
- self.frame_layer = TimeDistributed(nn.Linear(128, 1), batch_first=True)
88
- self.average_layer = nn.AdaptiveAvgPool1d(1)
89
-
90
- def forward(self, forward_input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
91
- conv1_output = self.conv1(forward_input)
92
- conv2_output = self.conv2(conv1_output)
93
- conv3_output = self.conv3(conv2_output)
94
- conv4_output = self.conv4(conv3_output)
95
- conv4_output = conv4_output.permute(0, 2, 1, 3)
96
- conv4_output = torch.reshape(conv4_output, (conv4_output.shape[0], conv4_output.shape[1], 4 * 128))
97
- blstm_output, _ = self.blstm1(conv4_output)
98
- blstm_output = self.droupout(blstm_output)
99
- flatten_output = self.flatten(blstm_output)
100
- fc_output = self.dense1(flatten_output)
101
- frame_score = self.frame_layer(fc_output)
102
- frame_score = frame_score.squeeze(-1) * mask
103
- valid_sum = torch.sum(frame_score, dim=1)
104
- valid_count = torch.sum(mask, dim=1)
105
- avg_score = valid_sum / (valid_count + 1e-8)
106
- return avg_score.unsqueeze(-1), frame_score
107
-
108
-
109
  class SwiGLU(nn.Module):
110
  def forward(self, x: torch.Tensor) -> torch.Tensor:
111
  x_, gate = x.chunk(2, dim=-1)
@@ -263,7 +185,6 @@ class CrossAttentionModel(nn.Module):
263
  avg_score = self.average_layer(frame_score.permute(0, 2, 1)) # (B, 1, 1)
264
  return avg_score.reshape(avg_score.size(0), -1), frame_score.squeeze()
265
 
266
-
267
  class MosNet(PreTrainedModel):
268
  config_class = MosNetConfig
269
 
@@ -281,6 +202,81 @@ class MosNet(PreTrainedModel):
281
  self.fft_size = self.config.fft_size
282
  self.hop_length = self.config.hop_length
283
  self.win_length = self.config.win_length
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284
 
285
  def preprocess_audios(self, audios: List[Any]) -> Tuple[torch.Tensor, torch.Tensor]:
286
  spectrograms = []
@@ -304,10 +300,6 @@ class MosNet(PreTrainedModel):
304
  masks[i, :valid_len] = 1.0
305
  return padded, masks
306
 
307
- def forward(self, audios: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
308
- outputs, _ = self.model(audios.to(self.device), masks.to(self.device))
309
- return outputs
310
-
311
  def predict(self, audios: List[Any]) -> List[float]:
312
  with torch.no_grad():
313
  padded, masks = self.preprocess_audios(audios)
 
28
  return output
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  class SwiGLU(nn.Module):
32
  def forward(self, x: torch.Tensor) -> torch.Tensor:
33
  x_, gate = x.chunk(2, dim=-1)
 
185
  avg_score = self.average_layer(frame_score.permute(0, 2, 1)) # (B, 1, 1)
186
  return avg_score.reshape(avg_score.size(0), -1), frame_score.squeeze()
187
 
 
188
  class MosNet(PreTrainedModel):
189
  config_class = MosNetConfig
190
 
 
202
  self.fft_size = self.config.fft_size
203
  self.hop_length = self.config.hop_length
204
  self.win_length = self.config.win_length
205
+ self.dropout = self.config.dropout
206
+
207
+ self.conv1 = nn.Sequential(
208
+ nn.Conv2d(1, 16, (3, 3), (1, 1), padding=1),
209
+ nn.ReLU(),
210
+ nn.Conv2d(16, 16, (3, 3), (1, 1), 1),
211
+ nn.ReLU(),
212
+ nn.Conv2d(16, 16, (3, 3), (1, 3), 1),
213
+ nn.ReLU(),
214
+ nn.BatchNorm2d(16),
215
+ nn.Dropout(self.dropout),
216
+ )
217
+ self.conv2 = nn.Sequential(
218
+ nn.Conv2d(16, 32, (3, 3), (1, 1), 1),
219
+ nn.ReLU(),
220
+ nn.Conv2d(32, 32, (3, 3), (1, 1), 1),
221
+ nn.ReLU(),
222
+ nn.Conv2d(32, 32, (3, 3), (1, 3), 1),
223
+ nn.ReLU(),
224
+ nn.BatchNorm2d(32),
225
+ nn.Dropout(self.dropout),
226
+ )
227
+ self.conv3 = nn.Sequential(
228
+ nn.Conv2d(32, 64, (3, 3), (1, 1), 1),
229
+ nn.ReLU(),
230
+ nn.Conv2d(64, 64, (3, 3), (1, 1), 1),
231
+ nn.ReLU(),
232
+ nn.Conv2d(64, 64, (3, 3), (1, 3), 1),
233
+ nn.ReLU(),
234
+ nn.BatchNorm2d(64),
235
+ nn.Dropout(self.dropout),
236
+ )
237
+ self.conv4 = nn.Sequential(
238
+ nn.Conv2d(64, 128, (3, 3), (1, 1), 1),
239
+ nn.ReLU(),
240
+ nn.Conv2d(128, 128, (3, 3), (1, 1), 1),
241
+ nn.ReLU(),
242
+ nn.Conv2d(128, 128, (3, 3), (1, 3), 1),
243
+ nn.ReLU(),
244
+ nn.BatchNorm2d(128),
245
+ nn.Dropout(self.dropout),
246
+ )
247
+ self.blstm1 = nn.LSTM(512, 128, bidirectional=True, batch_first=True)
248
+ self.droupout = nn.Dropout(self.dropout)
249
+ self.flatten = TimeDistributed(nn.Flatten(), batch_first=True)
250
+ self.dense1 = nn.Sequential(
251
+ TimeDistributed(
252
+ nn.Sequential(
253
+ nn.Linear(256, 128),
254
+ nn.ReLU(),
255
+ ),
256
+ batch_first=True,
257
+ ),
258
+ nn.Dropout(self.dropout),
259
+ )
260
+ self.frame_layer = TimeDistributed(nn.Linear(128, 1), batch_first=True)
261
+ self.average_layer = nn.AdaptiveAvgPool1d(1)
262
+
263
+ def forward(self, forward_input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
264
+ conv1_output = self.conv1(forward_input)
265
+ conv2_output = self.conv2(conv1_output)
266
+ conv3_output = self.conv3(conv2_output)
267
+ conv4_output = self.conv4(conv3_output)
268
+ conv4_output = conv4_output.permute(0, 2, 1, 3)
269
+ conv4_output = torch.reshape(conv4_output, (conv4_output.shape[0], conv4_output.shape[1], 4 * 128))
270
+ blstm_output, _ = self.blstm1(conv4_output)
271
+ blstm_output = self.droupout(blstm_output)
272
+ flatten_output = self.flatten(blstm_output)
273
+ fc_output = self.dense1(flatten_output)
274
+ frame_score = self.frame_layer(fc_output)
275
+ frame_score = frame_score.squeeze(-1) * mask
276
+ valid_sum = torch.sum(frame_score, dim=1)
277
+ valid_count = torch.sum(mask, dim=1)
278
+ avg_score = valid_sum / (valid_count + 1e-8)
279
+ return avg_score.unsqueeze(-1)
280
 
281
  def preprocess_audios(self, audios: List[Any]) -> Tuple[torch.Tensor, torch.Tensor]:
282
  spectrograms = []
 
300
  masks[i, :valid_len] = 1.0
301
  return padded, masks
302
 
 
 
 
 
303
  def predict(self, audios: List[Any]) -> List[float]:
304
  with torch.no_grad():
305
  padded, masks = self.preprocess_audios(audios)