1NEYRON1 committed on
Commit e2389da · verified · 1 Parent(s): 1700325

Create modeling_mosnet.py

Files changed (1)
  1. modeling_mosnet.py +317 -0
modeling_mosnet.py ADDED
@@ -0,0 +1,317 @@
+ from typing import Any, List, Tuple
+
+ from einops import rearrange
+ import librosa
+ import numpy as np
+ import torch
+ import torch.nn.functional as f
+ from torch import nn
+ from transformers import AutoConfig, AutoModel, PreTrainedModel
+
+ from .configuration_mosnet import MosNetConfig
+
+
+ class TimeDistributed(nn.Module):
+     """Applies a module independently to every timestep of a (batch, time, features) sequence."""
+
+     def __init__(self, module: nn.Module, batch_first: bool) -> None:
+         super().__init__()
+         self.module = module
+         self.batch_first = batch_first
+
+     def forward(self, input_seq: torch.Tensor) -> torch.Tensor:
+         assert input_seq.dim() > 2
+         # collapse all leading dims so the wrapped module sees 2-D input
+         reshaped_input = input_seq.contiguous().view(-1, input_seq.size(-1))
+         output = self.module(reshaped_input)
+         if self.batch_first:
+             output = output.contiguous().view(input_seq.size(0), -1, output.size(-1))
+         else:
+             output = output.contiguous().view(-1, input_seq.size(1), output.size(-1))
+         return output
+
+
+ class CnnBlstmMbnet2(nn.Module):
+     def __init__(self, dropout: float = 0.3) -> None:
+         super().__init__()
+         self.conv1 = nn.Sequential(
+             nn.Conv2d(1, 16, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(16, 16, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(16, 16, (3, 3), (1, 3), padding=1),
+             nn.ReLU(),
+             nn.BatchNorm2d(16),
+             nn.Dropout(dropout),
+         )
+         self.conv2 = nn.Sequential(
+             nn.Conv2d(16, 32, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(32, 32, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(32, 32, (3, 3), (1, 3), padding=1),
+             nn.ReLU(),
+             nn.BatchNorm2d(32),
+             nn.Dropout(dropout),
+         )
+         self.conv3 = nn.Sequential(
+             nn.Conv2d(32, 64, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(64, 64, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(64, 64, (3, 3), (1, 3), padding=1),
+             nn.ReLU(),
+             nn.BatchNorm2d(64),
+             nn.Dropout(dropout),
+         )
+         self.conv4 = nn.Sequential(
+             nn.Conv2d(64, 128, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(128, 128, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(128, 128, (3, 3), (1, 3), padding=1),
+             nn.ReLU(),
+             nn.BatchNorm2d(128),
+             nn.Dropout(dropout),
+         )
+         self.blstm1 = nn.LSTM(512, 128, bidirectional=True, batch_first=True)
+         self.dropout = nn.Dropout(dropout)
+         self.flatten = TimeDistributed(nn.Flatten(), batch_first=True)
+         self.dense1 = nn.Sequential(
+             TimeDistributed(
+                 nn.Sequential(
+                     nn.Linear(256, 128),
+                     nn.ReLU(),
+                 ),
+                 batch_first=True,
+             ),
+             nn.Dropout(dropout),
+         )
+         self.frame_layer = TimeDistributed(nn.Linear(128, 1), batch_first=True)
+         self.average_layer = nn.AdaptiveAvgPool1d(1)
+
+     def forward(self, forward_input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         conv1_output = self.conv1(forward_input)
+         conv2_output = self.conv2(conv1_output)
+         conv3_output = self.conv3(conv2_output)
+         conv4_output = self.conv4(conv3_output)
+         conv4_output = conv4_output.permute(0, 2, 1, 3)
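+         # with n_fft=512 the STFT has 257 frequency bins; the four stride-(1, 3)
+         # convolutions shrink that axis 257 -> 86 -> 29 -> 10 -> 4, so each frame
+         # carries 4 bins x 128 channels = 512 features for the BLSTM below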
+         conv4_output = torch.reshape(conv4_output, (conv4_output.shape[0], conv4_output.shape[1], 4 * 128))
+         blstm_output, _ = self.blstm1(conv4_output)
+         blstm_output = self.dropout(blstm_output)
+         flatten_output = self.flatten(blstm_output)
+         fc_output = self.dense1(flatten_output)
+         frame_score = self.frame_layer(fc_output)
+         frame_score = frame_score.squeeze(-1) * mask
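+         # masked mean: padded frames were zeroed above, so dividing the summed
+         # frame scores by the number of valid frames ignores padding entirely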
+         valid_sum = torch.sum(frame_score, dim=1)
+         valid_count = torch.sum(mask, dim=1)
+         avg_score = valid_sum / (valid_count + 1e-8)
+         return avg_score.unsqueeze(-1), frame_score
+
+
+ class SwiGLU(nn.Module):
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # split the projection in half: one part is the value, the other the gate
+         x_, gate = x.chunk(2, dim=-1)
+         return f.silu(gate) * x_
+
+
+ class RotaryEmbedding(nn.Module):
+     def __init__(self, dim: int, scale_base: int = 512, use_xpos: bool = True) -> None:
+         super().__init__()
+         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+         self.use_xpos = use_xpos
+         self.scale_base = scale_base
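+         # xPos-style per-dimension decay base: raised to a power proportional to
+         # distance from the sequence center, it damps query-key products between
+         # far-apart positions (see the forward pass below)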
+         scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+         self.register_buffer("scale", scale)
+
+     def forward(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+         t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
+         freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
+         freqs = torch.cat((freqs, freqs), dim=-1)
+         if not self.use_xpos:
+             return freqs, torch.ones(1, device=device)
+         power = (t - (seq_len // 2)) / self.scale_base
+         scale = self.scale ** rearrange(power, "n -> n 1")
+         scale = torch.cat((scale, scale), dim=-1)
+         return freqs, scale
+
+
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
+     x1, x2 = x.chunk(2, dim=-1)
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(pos: torch.Tensor, t: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+     return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)
+
+
+ def l2norm(t: torch.Tensor) -> torch.Tensor:
+     return f.normalize(t, dim=-1)
+
+
+ class TransformerBlock(nn.Module):
+     def __init__(self, dim_head: int = 64, heads: int = 8, dropout: float = 0.2, forward_expansion: int = 2, device: str = "cpu") -> None:
+         super().__init__()
+         self.heads = heads
+         self.dim_head = dim_head
+         self.embed_dim = heads * dim_head
+         self.device = device
+
+         self.qkv = nn.Linear(dim_head * heads, dim_head * heads * 3)
+         self.q_scale = nn.Parameter(torch.ones(dim_head))
+         self.k_scale = nn.Parameter(torch.ones(dim_head))
+         self.rotary_emb = RotaryEmbedding(dim_head)
+         self.norm = nn.LayerNorm(dim_head * heads)
+         self.feed_forward = nn.Sequential(
+             nn.Linear(dim_head * heads, forward_expansion * dim_head * heads * 2),  # *2 for the SwiGLU gate
+             SwiGLU(),
+             nn.Dropout(dropout),
+             nn.Linear(forward_expansion * dim_head * heads, dim_head * heads),
+         )
+
+     def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
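+         # note: k and v are accepted but never used; query, key and value are all
+         # projected from q, so this block performs self-attention over q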
+         n, seq_length, _ = q.shape
+         qkv_proj = self.qkv(q)
+         qkv_proj = qkv_proj.reshape(n, seq_length, self.heads, 3 * self.dim_head)
+         qkv = qkv_proj.permute(0, 2, 1, 3)
+         q_, k_, v_ = qkv.chunk(3, dim=-1)
+         # QK normalisation: unit-norm queries/keys with learned per-dimension scales
+         q_, k_ = map(l2norm, (q_, k_))
+         q_ = q_ * self.q_scale
+         k_ = k_ * self.k_scale
+         # use the input's device rather than the stored string, so the block
+         # still works after the model is moved to another device
+         positions, scale = self.rotary_emb(seq_length, q.device)
+         # xPos: scale queries and inverse-scale keys so their product decays with distance
+         q_ = apply_rotary_pos_emb(positions, q_, scale)
+         k_ = apply_rotary_pos_emb(positions, k_, scale ** -1)
+         attn_output = f.scaled_dot_product_attention(q_, k_, v_)
+         attn_output = attn_output.permute(0, 2, 1, 3).reshape(n, seq_length, self.embed_dim)
+         attn_output = self.norm(attn_output)
+         forward_output = self.feed_forward(attn_output)
+         return attn_output + forward_output
+
+
+ class AudioFeatureExtractor(nn.Module):
+     """Same convolutional trunk as CnnBlstmMbnet2, without batch norm or dropout."""
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.conv1 = nn.Sequential(
+             nn.Conv2d(1, 16, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(16, 16, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(16, 16, (3, 3), (1, 3), padding=1), nn.ReLU(),
+         )
+         self.conv2 = nn.Sequential(
+             nn.Conv2d(16, 32, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(32, 32, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(32, 32, (3, 3), (1, 3), padding=1), nn.ReLU(),
+         )
+         self.conv3 = nn.Sequential(
+             nn.Conv2d(32, 64, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(64, 64, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(64, 64, (3, 3), (1, 3), padding=1), nn.ReLU(),
+         )
+         self.conv4 = nn.Sequential(
+             nn.Conv2d(64, 128, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(128, 128, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(128, 128, (3, 3), (1, 3), padding=1), nn.ReLU(),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.conv1(x)
+         x = self.conv2(x)
+         x = self.conv3(x)
+         x = self.conv4(x)
+         x = x.permute(0, 2, 1, 3)
+         # flatten (channels, freq) per frame: (B, T, 128 * remaining freq bins)
+         x = torch.reshape(x, (x.shape[0], x.shape[1], -1))
+         return x
+
+
+ class CrossAttentionModel(nn.Module):
+     def __init__(self, device: str = "cpu") -> None:
+         super().__init__()
+         self.audio_extractor = AudioFeatureExtractor()
+
+         self.text_projection = nn.Linear(768, 512)
+         # pass the device through to TransformerBlock
+         self.cross_attention = TransformerBlock(dim_head=64, heads=8, device=device)
+
+         self.fc1 = nn.Sequential(
+             nn.Linear(512, 128),
+             nn.ReLU(),
+             nn.Dropout(0.3),
+         )
+         self.frame_layer = nn.Linear(128, 1)
+         self.average_layer = nn.AdaptiveAvgPool1d(1)
+
+     def forward(
+         self,
+         audio_input: torch.Tensor,
+         text_embeddings: torch.Tensor,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         """audio_input shape: (B, 1, T, F)
+         text_embeddings shape: (B, 768)
+         """
+         # ↳ Audio branch
+         audio_features = self.audio_extractor(audio_input)  # (B, T, 512)
+
+         # ↳ Text branch
+         text_proj = self.text_projection(text_embeddings)  # (B, 512)
+         text_proj = text_proj.unsqueeze(1)  # (B, 1, 512)
+
+         # Cross-attention
+         cross_out = self.cross_attention(audio_features, text_proj, text_proj)  # (B, T, 512)
+
+         # Head
+         fc_out = self.fc1(cross_out)  # (B, T, 128)
+         frame_score = self.frame_layer(fc_out)  # (B, T, 1)
+
+         # aggregate over time
+         avg_score = self.average_layer(frame_score.permute(0, 2, 1))  # (B, 1, 1)
+         # squeeze only the last dim so a batch of size 1 keeps its batch axis
+         return avg_score.reshape(avg_score.size(0), -1), frame_score.squeeze(-1)
+
+
+ class MosNet(PreTrainedModel):
+     config_class = MosNetConfig
+
+     def __init__(self, config: MosNetConfig) -> None:
+         super().__init__(config)
+         self.model = CnnBlstmMbnet2()
+         self.sample_rate = self.config.sample_rate
+         self.fft_size = self.config.fft_size
+         self.hop_length = self.config.hop_length
+         self.win_length = self.config.win_length
+
+     def preprocess_audios(self, audios: List[Any]) -> Tuple[torch.Tensor, torch.Tensor]:
+         spectrograms = []
+         for audio in audios:
+             # accept numpy arrays or tensors; librosa needs a numpy array on CPU
+             if isinstance(audio, np.ndarray):
+                 audio_np = audio.astype(np.float32)
+             else:
+                 audio_np = audio.float().cpu().numpy()
+             spec = librosa.stft(audio_np, n_fft=self.fft_size, hop_length=self.hop_length, win_length=self.win_length)
+             mag = np.abs(spec).astype(np.float32).T  # (T, F)
+             mag_tensor = torch.tensor(mag, device=self.device).unsqueeze(0)  # (1, T, F)
+             spectrograms.append(mag_tensor)
+         # zero-pad every clip to the longest one and build a validity mask
+         max_len = max(spec.shape[1] for spec in spectrograms)
+         batch_size, feat_dim = len(spectrograms), spectrograms[0].shape[2]
+         padded = torch.zeros(batch_size, 1, max_len, feat_dim, device=self.device)
+         masks = torch.zeros(batch_size, max_len, device=self.device)
+         for i, spec in enumerate(spectrograms):
+             valid_len = spec.shape[1]
+             padded[i, :, :valid_len, :] = spec
+             masks[i, :valid_len] = 1.0
+         return padded, masks
+
+     def forward(self, audios: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
+         outputs, _ = self.model(audios.to(self.device), masks.to(self.device))
+         return outputs
+
+     def predict(self, audios: List[Any]) -> List[float]:
+         with torch.no_grad():
+             padded, masks = self.preprocess_audios(audios)
+             scores = self.forward(padded, masks)
+             return scores.squeeze(-1).cpu().tolist()
+
+
+ AutoConfig.register("mosnet", MosNetConfig)
+ AutoModel.register(MosNetConfig, MosNet)
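
Usage sketch: a minimal example of loading the model through the Auto classes and scoring raw waveforms with predict(). The repo id below is a placeholder, and it assumes the companion configuration_mosnet.py ships alongside this file, so the checkpoint must be loaded with trust_remote_code=True.

import numpy as np
from transformers import AutoModel

# hypothetical repo id; replace with the actual Hub path
model = AutoModel.from_pretrained("1NEYRON1/mosnet", trust_remote_code=True)
model.eval()

# two fake 16 kHz clips of different lengths; preprocess_audios pads and masks them
clips = [
    np.random.randn(16000).astype(np.float32),
    np.random.randn(24000).astype(np.float32),
]
scores = model.predict(clips)  # one MOS-style score per clip
print(scores)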