1NEYRON1 committed on
Commit 5c52273 · verified · 1 Parent(s): bd9de4f

Create modeling_mosnet.py

Files changed (1)
  1. modeling_mosnet.py +318 -0
modeling_mosnet.py ADDED
@@ -0,0 +1,318 @@
+ from typing import List, Tuple
+
+ from einops import rearrange
+ import numpy as np
+ import torch
+ import torch.nn.functional as f
+ from torch import nn
+
+ from transformers import AutoConfig, AutoModel, BertModel, BertTokenizer, PreTrainedModel
+
+ from .configuration_mosnet import MosNetConfig
+
+
+ class TimeDistributed(nn.Module):
+     """Applies `module` independently at every timestep by folding time into the batch dimension."""
+
+     def __init__(self, module: nn.Module, batch_first: bool) -> None:
+         super().__init__()
+         self.module = module
+         self.batch_first = batch_first
+
+     def forward(self, input_seq: torch.Tensor) -> torch.Tensor:
+         assert input_seq.dim() > 2
+         # (B, T, F) -> (B*T, F), apply the wrapped module, then restore the time axis.
+         reshaped_input = input_seq.contiguous().view(-1, input_seq.size(-1))
+         output = self.module(reshaped_input)
+         if self.batch_first:
+             output = output.contiguous().view(input_seq.size(0), -1, output.size(-1))
+         else:
+             output = output.contiguous().view(-1, input_seq.size(1), output.size(-1))
+         return output
+
+
+ class CnnBlstmMbnet2(nn.Module):
+     """CNN-BLSTM MOS predictor (MOSNet/MBNet-style): per-frame scores plus their masked average."""
+
+     def __init__(self, dropout: float = 0.3) -> None:
+         super().__init__()
+         self.conv1 = nn.Sequential(
+             nn.Conv2d(1, 16, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(16, 16, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(16, 16, (3, 3), (1, 3), padding=1),
+             nn.ReLU(),
+             nn.BatchNorm2d(16),
+             nn.Dropout(dropout),
+         )
+         self.conv2 = nn.Sequential(
+             nn.Conv2d(16, 32, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(32, 32, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(32, 32, (3, 3), (1, 3), padding=1),
+             nn.ReLU(),
+             nn.BatchNorm2d(32),
+             nn.Dropout(dropout),
+         )
+         self.conv3 = nn.Sequential(
+             nn.Conv2d(32, 64, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(64, 64, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(64, 64, (3, 3), (1, 3), padding=1),
+             nn.ReLU(),
+             nn.BatchNorm2d(64),
+             nn.Dropout(dropout),
+         )
+         self.conv4 = nn.Sequential(
+             nn.Conv2d(64, 128, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(128, 128, (3, 3), (1, 1), padding=1),
+             nn.ReLU(),
+             nn.Conv2d(128, 128, (3, 3), (1, 3), padding=1),
+             nn.ReLU(),
+             nn.BatchNorm2d(128),
+             nn.Dropout(dropout),
+         )
+         self.blstm1 = nn.LSTM(512, 128, bidirectional=True, batch_first=True)
+         self.dropout = nn.Dropout(dropout)
+         self.flatten = TimeDistributed(nn.Flatten(), batch_first=True)
+         self.dense1 = nn.Sequential(
+             TimeDistributed(
+                 nn.Sequential(
+                     nn.Linear(256, 128),
+                     nn.ReLU(),
+                 ),
+                 batch_first=True,
+             ),
+             nn.Dropout(dropout),
+         )
+         self.frame_layer = TimeDistributed(nn.Linear(128, 1), batch_first=True)
+         self.average_layer = nn.AdaptiveAvgPool1d(1)
+
+     def forward(self, forward_input: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+         conv1_output = self.conv1(forward_input)
+         conv2_output = self.conv2(conv1_output)
+         conv3_output = self.conv3(conv2_output)
+         conv4_output = self.conv4(conv3_output)
+         # (B, 128, T, 4) -> (B, T, 4 * 128): flatten channels x frequency per frame.
+         conv4_output = conv4_output.permute(0, 2, 1, 3)
+         conv4_output = torch.reshape(conv4_output, (conv4_output.shape[0], conv4_output.shape[1], 4 * 128))
+         blstm_output, _ = self.blstm1(conv4_output)
+         blstm_output = self.dropout(blstm_output)
+         flatten_output = self.flatten(blstm_output)
+         fc_output = self.dense1(flatten_output)
+         frame_score = self.frame_layer(fc_output)
+         # Masked mean: zero out padded frames, then average over the valid ones only.
+         frame_score = frame_score.squeeze(-1) * mask
+         valid_sum = torch.sum(frame_score, dim=1)
+         valid_count = torch.sum(mask, dim=1)
+         avg_score = valid_sum / (valid_count + 1e-8)
+         return avg_score.unsqueeze(-1), frame_score
+
+
+ class SwiGLU(nn.Module):
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Split the last dimension in half and gate one half with SiLU of the other.
+         x_, gate = x.chunk(2, dim=-1)
+         return f.silu(gate) * x_
+
+
+ class RotaryEmbedding(nn.Module):
+     """Rotary position embedding (RoPE) with optional xPos-style length scaling."""
+
+     def __init__(self, dim: int, scale_base: int = 512, use_xpos: bool = True) -> None:
+         super().__init__()
+         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+         self.register_buffer("inv_freq", inv_freq)
+         self.use_xpos = use_xpos
+         self.scale_base = scale_base
+         scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
+         self.register_buffer("scale", scale)
+
+     def forward(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
+         t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
+         freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
+         freqs = torch.cat((freqs, freqs), dim=-1)
+         if not self.use_xpos:
+             return freqs, torch.ones(1, device=device)
+         power = (t - (seq_len // 2)) / self.scale_base
+         scale = self.scale ** rearrange(power, "n -> n 1")
+         scale = torch.cat((scale, scale), dim=-1)
+         return freqs, scale
+
+
+ def rotate_half(x: torch.Tensor) -> torch.Tensor:
+     x1, x2 = x.chunk(2, dim=-1)
+     return torch.cat((-x2, x1), dim=-1)
+
+
+ def apply_rotary_pos_emb(pos: torch.Tensor, t: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
+     return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)
+
+
+ def l2norm(t: torch.Tensor) -> torch.Tensor:
+     return f.normalize(t, dim=-1)
+
+
+ class TransformerBlock(nn.Module):
+     """Attention block with QK normalization, rotary/xPos positions, and a SwiGLU feed-forward."""
+
+     def __init__(self, dim_head: int = 64, heads: int = 8, dropout: float = 0.2, forward_expansion: int = 2) -> None:
+         super().__init__()
+         self.heads = heads
+         self.dim_head = dim_head
+         self.embed_dim = heads * dim_head
+
+         self.qkv = nn.Linear(self.embed_dim, self.embed_dim * 3)
+         self.q_scale = nn.Parameter(torch.ones(dim_head))
+         self.k_scale = nn.Parameter(torch.ones(dim_head))
+         self.rotary_emb = RotaryEmbedding(dim_head)
+         self.norm = nn.LayerNorm(self.embed_dim)
+         self.feed_forward = nn.Sequential(
+             nn.Linear(self.embed_dim, forward_expansion * self.embed_dim * 2),  # doubled width for the SwiGLU gate
+             SwiGLU(),
+             nn.Dropout(dropout),
+             nn.Linear(forward_expansion * self.embed_dim, self.embed_dim),
+         )
+
+     def _split_heads(self, x: torch.Tensor) -> torch.Tensor:
+         n, seq_len, _ = x.shape
+         return x.reshape(n, seq_len, self.heads, self.dim_head).permute(0, 2, 1, 3)
+
+     def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+         n, q_len, _ = q.shape
+         # Split the fused projection so queries come from `q` while keys/values
+         # come from the (possibly different) `k`/`v` inputs.
+         w_q, w_k, w_v = self.qkv.weight.chunk(3, dim=0)
+         b_q, b_k, b_v = self.qkv.bias.chunk(3, dim=0)
+         q_ = self._split_heads(f.linear(q, w_q, b_q))
+         k_ = self._split_heads(f.linear(k, w_k, b_k))
+         v_ = self._split_heads(f.linear(v, w_v, b_v))
+         q_, k_ = map(l2norm, (q_, k_))
+         q_ = q_ * self.q_scale
+         k_ = k_ * self.k_scale
+         # Positions are resolved from the input tensors' device at forward time.
+         q_pos, q_xscale = self.rotary_emb(q_len, q.device)
+         k_pos, k_xscale = self.rotary_emb(k.shape[1], k.device)
+         q_ = apply_rotary_pos_emb(q_pos, q_, q_xscale)
+         k_ = apply_rotary_pos_emb(k_pos, k_, k_xscale ** -1)
+         attn_output = f.scaled_dot_product_attention(q_, k_, v_)
+         attn_output = attn_output.permute(0, 2, 1, 3).reshape(n, q_len, self.embed_dim)
+         attn_output = self.norm(attn_output)
+         forward_output = self.feed_forward(attn_output)
+         return attn_output + forward_output
+
+
+ class AudioFeatureExtractor(nn.Module):
+     """Four conv stacks, each downsampling frequency by 3x: 257 STFT bins become 4 bins x 128 channels = 512 features per frame."""
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.conv1 = nn.Sequential(
+             nn.Conv2d(1, 16, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(16, 16, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(16, 16, (3, 3), (1, 3), padding=1), nn.ReLU(),
+         )
+         self.conv2 = nn.Sequential(
+             nn.Conv2d(16, 32, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(32, 32, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(32, 32, (3, 3), (1, 3), padding=1), nn.ReLU(),
+         )
+         self.conv3 = nn.Sequential(
+             nn.Conv2d(32, 64, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(64, 64, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(64, 64, (3, 3), (1, 3), padding=1), nn.ReLU(),
+         )
+         self.conv4 = nn.Sequential(
+             nn.Conv2d(64, 128, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(128, 128, (3, 3), (1, 1), padding=1), nn.ReLU(),
+             nn.Conv2d(128, 128, (3, 3), (1, 3), padding=1), nn.ReLU(),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x = self.conv1(x)
+         x = self.conv2(x)
+         x = self.conv3(x)
+         x = self.conv4(x)
+         # (B, C, T, F) -> (B, T, C * F)
+         x = x.permute(0, 2, 1, 3)
+         x = torch.reshape(x, (x.shape[0], x.shape[1], -1))
+         return x
+
+
+ class MultiModalMosNet(PreTrainedModel):
+     """MOS predictor that fuses a spectrogram CNN with a BERT sentence embedding via cross-attention."""
+
+     config_class = MosNetConfig
+
+     def __init__(self, config: MosNetConfig) -> None:
+         super().__init__(config)
+         self.sample_rate = config.sample_rate
+         self.fft_size = config.fft_size
+         self.hop_length = config.hop_length
+         self.win_length = config.win_length
+         self.dropout = config.dropout
+
+         self.audio_extractor = AudioFeatureExtractor()
+
+         # 8 heads x 64 dims = 512, matching the per-frame feature width of AudioFeatureExtractor.
+         attn_dim = 8 * 64
+         self.text_projection = nn.Linear(768, attn_dim)
+         self.cross_attention = TransformerBlock(dim_head=64, heads=8)
+
+         self.fc1 = nn.Sequential(
+             nn.Linear(attn_dim, 128),
+             nn.ReLU(),
+             nn.Dropout(self.dropout),
+         )
+         self.frame_layer = nn.Linear(128, 1)
+         self.average_layer = nn.AdaptiveAvgPool1d(1)
+
+     def forward(
+         self,
+         audio_input: torch.Tensor,
+         text_embeddings: torch.Tensor,
+     ) -> torch.Tensor:
+         """audio_input shape: (B, 1, T, F)
+         text_embeddings shape: (B, 768)
+         """
+         # Audio branch
+         audio_features = self.audio_extractor(audio_input)  # (B, T, 512)
+
+         # Text branch
+         text_proj = self.text_projection(text_embeddings)  # (B, 512)
+         text_proj = text_proj.unsqueeze(1)  # (B, 1, 512)
+
+         # Cross-attention: audio frames query the text embedding.
+         cross_out = self.cross_attention(audio_features, text_proj, text_proj)  # (B, T, 512)
+
+         # Head
+         fc_out = self.fc1(cross_out)  # (B, T, 128)
+         frame_score = self.frame_layer(fc_out)  # (B, T, 1)
+
+         # Aggregate frame-level scores into one utterance-level score.
+         avg_score = self.average_layer(frame_score.permute(0, 2, 1))  # (B, 1, 1)
+         return avg_score.reshape(avg_score.size(0), -1)
+
+     def preprocess_audio(self, audios: List[np.ndarray]) -> torch.Tensor:
+         tensors = []
+         window = torch.hann_window(self.win_length, device=self.device)
+         for audio in audios:
+             y = torch.tensor(audio, dtype=torch.float32, device=self.device)
+             spec = torch.stft(y, n_fft=self.fft_size, hop_length=self.hop_length,
+                               win_length=self.win_length, window=window, return_complex=True)
+             mag = spec.abs()  # (F, T) magnitude spectrogram
+             mag = mag.permute(1, 0).unsqueeze(0)  # (1, T, F)
+             tensors.append(mag)
+         # Zero-pad all spectrograms to the longest one in the batch.
+         max_len = max(t.shape[1] for t in tensors)
+         padded = torch.zeros(len(tensors), 1, max_len, tensors[0].shape[2], device=self.device)
+         for i, t in enumerate(tensors):
+             padded[i, :, :t.shape[1], :] = t
+         return padded
+
+     def preprocess_text(self, texts: List[str]) -> dict:
+         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+         inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True, max_length=512)
+         return {k: v.to(self.device) for k, v in inputs.items()}
+
+     def predict(self, audios: List[np.ndarray], texts: List[str]) -> List[float]:
+         with torch.no_grad():
+             audios_tensor = self.preprocess_audio(audios)
+             inputs = self.preprocess_text(texts)
+             # Note: the BERT encoder is reloaded on every call; cache it externally for repeated use.
+             model = BertModel.from_pretrained("bert-base-uncased").to(self.device)
+             model.eval()
+             outputs = model(**inputs)
+             texts_tensor = outputs.last_hidden_state[:, 0, :]  # [CLS] embedding
+             preds = self.forward(audios_tensor, texts_tensor)
+         result = preds.squeeze().cpu().tolist()
+         if isinstance(result, float):
+             return [result]
+         return [float(x) for x in result]
+
+
+ AutoConfig.register("mosnet", MosNetConfig)
+ AutoModel.register(MosNetConfig, MultiModalMosNet)
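
For reference, a minimal usage sketch. This is an illustration only: the repository id is a placeholder, and it assumes the matching configuration_mosnet.py (defining MosNetConfig) ships alongside this file so that trust_remote_code can resolve the classes.

    import numpy as np
    from transformers import AutoModel

    # "user/mosnet" is a placeholder repo id, not the actual repository.
    model = AutoModel.from_pretrained("user/mosnet", trust_remote_code=True)
    model.eval()

    audio = np.zeros(16000, dtype=np.float32)  # 1 s of audio at 16 kHz
    scores = model.predict([audio], ["a transcript of the utterance"])
    print(scores)  # one MOS-style score per input

The predict helper handles STFT preprocessing and BERT text encoding internally, so callers pass raw waveforms and transcripts directly.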