Update modeling_mosnet.py
modeling_mosnet.py (+75 −83)
```diff
@@ -28,84 +28,6 @@ class TimeDistributed(nn.Module):
         return output
 
 
-class CnnBlstmMbnet2(nn.Module):
-    def __init__(self, dropout: float = 0.3) -> None:
-        super().__init__()
-        self.conv1 = nn.Sequential(
-            nn.Conv2d(1, 16, (3, 3), (1, 1), padding=1),
-            nn.ReLU(),
-            nn.Conv2d(16, 16, (3, 3), (1, 1), 1),
-            nn.ReLU(),
-            nn.Conv2d(16, 16, (3, 3), (1, 3), 1),
-            nn.ReLU(),
-            nn.BatchNorm2d(16),
-            nn.Dropout(dropout),
-        )
-        self.conv2 = nn.Sequential(
-            nn.Conv2d(16, 32, (3, 3), (1, 1), 1),
-            nn.ReLU(),
-            nn.Conv2d(32, 32, (3, 3), (1, 1), 1),
-            nn.ReLU(),
-            nn.Conv2d(32, 32, (3, 3), (1, 3), 1),
-            nn.ReLU(),
-            nn.BatchNorm2d(32),
-            nn.Dropout(dropout),
-        )
-        self.conv3 = nn.Sequential(
-            nn.Conv2d(32, 64, (3, 3), (1, 1), 1),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, (3, 3), (1, 1), 1),
-            nn.ReLU(),
-            nn.Conv2d(64, 64, (3, 3), (1, 3), 1),
-            nn.ReLU(),
-            nn.BatchNorm2d(64),
-            nn.Dropout(dropout),
-        )
-        self.conv4 = nn.Sequential(
-            nn.Conv2d(64, 128, (3, 3), (1, 1), 1),
-            nn.ReLU(),
-            nn.Conv2d(128, 128, (3, 3), (1, 1), 1),
-            nn.ReLU(),
-            nn.Conv2d(128, 128, (3, 3), (1, 3), 1),
-            nn.ReLU(),
-            nn.BatchNorm2d(128),
-            nn.Dropout(dropout),
-        )
-        self.blstm1 = nn.LSTM(512, 128, bidirectional=True, batch_first=True)
-        self.droupout = nn.Dropout(dropout)
-        self.flatten = TimeDistributed(nn.Flatten(), batch_first=True)
-        self.dense1 = nn.Sequential(
-            TimeDistributed(
-                nn.Sequential(
-                    nn.Linear(256, 128),
-                    nn.ReLU(),
-                ),
-                batch_first=True,
-            ),
-            nn.Dropout(dropout),
-        )
-        self.frame_layer = TimeDistributed(nn.Linear(128, 1), batch_first=True)
-        self.average_layer = nn.AdaptiveAvgPool1d(1)
-
-    def forward(self, forward_input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        conv1_output = self.conv1(forward_input)
-        conv2_output = self.conv2(conv1_output)
-        conv3_output = self.conv3(conv2_output)
-        conv4_output = self.conv4(conv3_output)
-        conv4_output = conv4_output.permute(0, 2, 1, 3)
-        conv4_output = torch.reshape(conv4_output, (conv4_output.shape[0], conv4_output.shape[1], 4 * 128))
-        blstm_output, _ = self.blstm1(conv4_output)
-        blstm_output = self.droupout(blstm_output)
-        flatten_output = self.flatten(blstm_output)
-        fc_output = self.dense1(flatten_output)
-        frame_score = self.frame_layer(fc_output)
-        frame_score = frame_score.squeeze(-1) * mask
-        valid_sum = torch.sum(frame_score, dim=1)
-        valid_count = torch.sum(mask, dim=1)
-        avg_score = valid_sum / (valid_count + 1e-8)
-        return avg_score.unsqueeze(-1), frame_score
-
-
 class SwiGLU(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         x_, gate = x.chunk(2, dim=-1)
```
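The removed `CnnBlstmMbnet2` module reappears almost verbatim inside `MosNet` in the third hunk below, so this is a flattening refactor rather than an architecture change. The detail worth checking is the `4 * 128` reshape feeding `blstm1`: each conv block ends with a stride-`(1, 3)` convolution, so only the frequency axis is downsampled. A minimal sketch of that shape arithmetic, assuming 257-bin input spectrograms (i.e. a 512-point FFT; the actual bin count comes from `config.fft_size`):

```python
# Sketch only: verify that four stride-(1, 3) convolutions with kernel 3 and
# padding 1 reduce 257 frequency bins to 4, matching the 4 * 128 = 512
# features per frame that blstm1 expects.
freq = 257  # assumed: fft_size // 2 + 1 for a 512-point FFT
for _ in range(4):
    freq = (freq + 2 * 1 - 3) // 3 + 1  # Conv2d output size: padding=1, kernel=3, stride=3
print(freq)  # -> 4
```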
```diff
@@ -263,7 +185,6 @@ class CrossAttentionModel(nn.Module):
         avg_score = self.average_layer(frame_score.permute(0, 2, 1))  # (B, 1, 1)
         return avg_score.reshape(avg_score.size(0), -1), frame_score.squeeze()
 
-
 class MosNet(PreTrainedModel):
     config_class = MosNetConfig
 
```
```diff
@@ -281,6 +202,81 @@ class MosNet(PreTrainedModel):
         self.fft_size = self.config.fft_size
         self.hop_length = self.config.hop_length
         self.win_length = self.config.win_length
+        self.dropout = self.config.dropout
+
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(1, 16, (3, 3), (1, 1), padding=1),
+            nn.ReLU(),
+            nn.Conv2d(16, 16, (3, 3), (1, 1), 1),
+            nn.ReLU(),
+            nn.Conv2d(16, 16, (3, 3), (1, 3), 1),
+            nn.ReLU(),
+            nn.BatchNorm2d(16),
+            nn.Dropout(self.dropout),
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(16, 32, (3, 3), (1, 1), 1),
+            nn.ReLU(),
+            nn.Conv2d(32, 32, (3, 3), (1, 1), 1),
+            nn.ReLU(),
+            nn.Conv2d(32, 32, (3, 3), (1, 3), 1),
+            nn.ReLU(),
+            nn.BatchNorm2d(32),
+            nn.Dropout(self.dropout),
+        )
+        self.conv3 = nn.Sequential(
+            nn.Conv2d(32, 64, (3, 3), (1, 1), 1),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, (3, 3), (1, 1), 1),
+            nn.ReLU(),
+            nn.Conv2d(64, 64, (3, 3), (1, 3), 1),
+            nn.ReLU(),
+            nn.BatchNorm2d(64),
+            nn.Dropout(self.dropout),
+        )
+        self.conv4 = nn.Sequential(
+            nn.Conv2d(64, 128, (3, 3), (1, 1), 1),
+            nn.ReLU(),
+            nn.Conv2d(128, 128, (3, 3), (1, 1), 1),
+            nn.ReLU(),
+            nn.Conv2d(128, 128, (3, 3), (1, 3), 1),
+            nn.ReLU(),
+            nn.BatchNorm2d(128),
+            nn.Dropout(self.dropout),
+        )
+        self.blstm1 = nn.LSTM(512, 128, bidirectional=True, batch_first=True)
+        self.dropout_layer = nn.Dropout(self.dropout)
+        self.flatten = TimeDistributed(nn.Flatten(), batch_first=True)
+        self.dense1 = nn.Sequential(
+            TimeDistributed(
+                nn.Sequential(
+                    nn.Linear(256, 128),
+                    nn.ReLU(),
+                ),
+                batch_first=True,
+            ),
+            nn.Dropout(self.dropout),
+        )
+        self.frame_layer = TimeDistributed(nn.Linear(128, 1), batch_first=True)
+        self.average_layer = nn.AdaptiveAvgPool1d(1)
+
+    def forward(self, forward_input: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+        conv1_output = self.conv1(forward_input)
+        conv2_output = self.conv2(conv1_output)
+        conv3_output = self.conv3(conv2_output)
+        conv4_output = self.conv4(conv3_output)
+        conv4_output = conv4_output.permute(0, 2, 1, 3)
+        conv4_output = torch.reshape(conv4_output, (conv4_output.shape[0], conv4_output.shape[1], 4 * 128))
+        blstm_output, _ = self.blstm1(conv4_output)
+        blstm_output = self.dropout_layer(blstm_output)
+        flatten_output = self.flatten(blstm_output)
+        fc_output = self.dense1(flatten_output)
+        frame_score = self.frame_layer(fc_output)
+        frame_score = frame_score.squeeze(-1) * mask
+        valid_sum = torch.sum(frame_score, dim=1)
+        valid_count = torch.sum(mask, dim=1)
+        avg_score = valid_sum / (valid_count + 1e-8)
+        return avg_score.unsqueeze(-1)
 
     def preprocess_audios(self, audios: List[Any]) -> Tuple[torch.Tensor, torch.Tensor]:
         spectrograms = []
```
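The inlined `forward` averages frame scores under the padding mask instead of pooling over every frame: padded positions are zeroed out, then the sum is divided by the count of valid frames (the `1e-8` guards against division by zero for an all-padding mask). A toy illustration of that arithmetic with made-up scores:

```python
import torch

frame_score = torch.tensor([[4.0, 3.0, 5.0, 0.7]])  # (B, T); last frame is padding
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])         # 1 = valid frame, 0 = padding
masked = frame_score * mask                          # zero out padded frames
avg = masked.sum(dim=1) / (mask.sum(dim=1) + 1e-8)   # mean over valid frames only
print(avg)  # tensor([4.0000]) -- the 0.7 padding frame is ignored
```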
```diff
@@ -304,10 +300,6 @@ class MosNet(PreTrainedModel):
             masks[i, :valid_len] = 1.0
         return padded, masks
 
-    def forward(self, audios: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
-        outputs, _ = self.model(audios.to(self.device), masks.to(self.device))
-        return outputs
-
     def predict(self, audios: List[Any]) -> List[float]:
         with torch.no_grad():
             padded, masks = self.preprocess_audios(audios)
```
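With the layers living directly on `MosNet`, the thin `forward` wrapper that unpacked the tuple from `self.model(...)` is gone; the model class itself defines `forward`. Inference still goes through `predict`, which handles spectrogram extraction and padding internally, roughly like this (the checkpoint path and 16 kHz sample rate are placeholders, and the import assumes the file is importable as a module):

```python
import librosa
from modeling_mosnet import MosNet

model = MosNet.from_pretrained("path/to/checkpoint")  # placeholder path
wav, _ = librosa.load("sample.wav", sr=16000)         # assumed sample rate
scores = model.predict([wav])                         # one MOS estimate per input clip
```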