Audio Classification
Russian
NikiPshg commited on
Commit
141bc61
·
verified ·
1 Parent(s): 419429d

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +232 -0
model.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from transformers import AutoModel, AutoConfig, AutoFeatureExtractor
5
+ import torchaudio
6
+ from safetensors import safe_open
7
+ from typing import List, Dict
8
+ import time
9
+
10
+ torch.backends.cuda.matmul.allow_tf32 = True
11
+ torch.backends.cuda.enable_flash_sdp(True)
12
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
13
+ torch.backends.cuda.enable_math_sdp(False)
14
+
15
+
16
+ class WavLMForMusicDetection(nn.Module):
17
+ """
18
+ Music detection model based on WavLM.
19
+ Uses attention pooling + classification head.
20
+ Outputs probability that input audio contains music.
21
+ Supports batched inference with automatic batching and preprocessing.
22
+ EER - 2.5-3 %
23
+ """
24
+ def __init__(
25
+ self,
26
+ base_model_name: str = 'microsoft/wavlm-base-plus',
27
+ batch_size: int = 32,
28
+ device: str = 'cuda'
29
+ ) -> None:
30
+ super().__init__()
31
+ self.config = AutoConfig.from_pretrained(base_model_name)
32
+ self.wavlm = AutoModel.from_pretrained(base_model_name, config=self.config)
33
+ self.processor = AutoFeatureExtractor.from_pretrained(base_model_name)
34
+
35
+ self.batch_size = batch_size
36
+ self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
37
+
38
+ self.target_sample_rate = self.processor.sampling_rate
39
+
40
+ # Attention-based pooling head
41
+ self.pool_attention = nn.Sequential(
42
+ nn.Linear(self.config.hidden_size, 256),
43
+ nn.Tanh(),
44
+ nn.Linear(256, 1)
45
+ )
46
+
47
+ # Classification head
48
+ self.classifier = nn.Sequential(
49
+ nn.Linear(self.config.hidden_size, 256),
50
+ nn.LayerNorm(256),
51
+ nn.GELU(),
52
+ nn.Dropout(0.1),
53
+ nn.Linear(256, 64),
54
+ nn.LayerNorm(64),
55
+ nn.GELU(),
56
+ nn.Linear(64, 1)
57
+ )
58
+
59
+ # to device
60
+ self.to(self.device)
61
+
62
+ def _attention_pool(
63
+ self,
64
+ hidden_states: torch.Tensor,
65
+ attention_mask: torch.Tensor
66
+ ) -> torch.Tensor:
67
+ """
68
+ Apply attention-based pooling over time dimension.
69
+
70
+ Args:
71
+ hidden_states (torch.Tensor): [batch_size, seq_len, hidden_size]
72
+ attention_mask (torch.Tensor): [batch_size, seq_len] — mask to ignore padding
73
+
74
+ Returns:
75
+ torch.Tensor: [batch_size, hidden_size] — context vector
76
+ """
77
+
78
+ attention_weights = self.pool_attention(hidden_states) # [B, T, 1]
79
+ # Mask out padded positions
80
+ attention_weights = attention_weights + (
81
+ (1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9
82
+ )
83
+
84
+ attention_weights = F.softmax(attention_weights, dim=1) # [B, T, 1]
85
+
86
+ # Weighted sum over time
87
+ weighted_sum = torch.sum(hidden_states * attention_weights, dim=1) # [B, D]
88
+ return weighted_sum
89
+
90
+ def forward(
91
+ self,
92
+ input_values: torch.Tensor,
93
+ attention_mask: torch.Tensor
94
+ ) -> torch.Tensor:
95
+ """
96
+ Forward pass for inference.
97
+
98
+ Args:
99
+ input_values (torch.Tensor): [batch_size, audio_seq_len] — raw audio waveform
100
+ attention_mask (torch.Tensor): [batch_size, audio_seq_len] — input mask (1 = real, 0 = pad)
101
+
102
+ Returns:
103
+ torch.Tensor: [batch_size, 1] — probability that audio contains music
104
+ """
105
+ assert isinstance(input_values, torch.Tensor), f"Expected torch.Tensor, got {type(input_values)}"
106
+ assert isinstance(attention_mask, torch.Tensor), f"Expected torch.Tensor, got {type(attention_mask)}"
107
+
108
+ outputs = self.wavlm(input_values.to(self.device), attention_mask=attention_mask.to(self.device))
109
+ hidden_states = outputs.last_hidden_state # [B, T', D]
110
+
111
+ # Align attention mask with downsampled hidden states
112
+ input_length = attention_mask.size(1)
113
+ hidden_length = hidden_states.size(1)
114
+ ratio = input_length / hidden_length
115
+ indices = (torch.arange(hidden_length, device=attention_mask.device) * ratio).long()
116
+ attention_mask = attention_mask[:, indices] # [B, T']
117
+ attention_mask = attention_mask.bool()
118
+
119
+ pooled = self._attention_pool(hidden_states, attention_mask)
120
+ logits = self.classifier(pooled) # [B, 1]
121
+
122
+ probs = torch.sigmoid(logits) # [B, 1] → probability of MUSIC
123
+ return probs
124
+
125
+ def _prepare_batches(self, audio_paths: List[str]) -> List[List[str]]:
126
+ """
127
+ Split list of audio paths into batches of size `self.batch_size`.
128
+
129
+ Args:
130
+ audio_paths (List[str]): List of paths to audio files.
131
+
132
+ Returns:
133
+ List[List[str]]: List of batches, each batch is a list of paths.
134
+ """
135
+ batches = []
136
+ current_batch = []
137
+ counter = 0
138
+
139
+ while counter < len(audio_paths):
140
+ if len(current_batch) == self.batch_size:
141
+ batches.append(current_batch)
142
+ current_batch = []
143
+ current_batch.append(audio_paths[counter])
144
+ counter += 1
145
+
146
+ if current_batch:
147
+ batches.append(current_batch)
148
+
149
+ return batches
150
+
151
+ def _preprocess_audio_batch(self, audio_paths: List[str]) -> Dict[str, torch.Tensor]:
152
+ """
153
+ Load and preprocess a batch of audio files.
154
+
155
+ Args:
156
+ audio_paths (List[str]): List of file paths.
157
+
158
+ Returns:
159
+ Dict with keys:
160
+ "input_values": tensor [B, T]
161
+ "attention_mask": tensor [B, T]
162
+ """
163
+ waveforms = []
164
+
165
+ for audio_path in audio_paths:
166
+ waveform, sample_rate = torchaudio.load(audio_path)
167
+
168
+ # Resample if needed
169
+ if sample_rate != self.target_sample_rate:
170
+ resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.target_sample_rate)
171
+ waveform = resampler(waveform)
172
+
173
+ # Convert to mono
174
+ if waveform.shape[0] > 1:
175
+ waveform = waveform.mean(dim=0, keepdim=True)
176
+
177
+ waveforms.append(waveform.squeeze())
178
+
179
+ # Extract features
180
+ inputs = self.processor(
181
+ [w.numpy() for w in waveforms],
182
+ sampling_rate=self.target_sample_rate,
183
+ return_tensors="pt",
184
+ padding=True,
185
+ truncation=False
186
+ )
187
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
188
+
189
+ return inputs
190
+
191
+ def predict_proba(self, audio_paths: List[str]) -> torch.Tensor:
192
+ """
193
+ Predict music probability for a list of audio files.
194
+
195
+ Args:
196
+ audio_paths (List[str]): List of audio file paths.
197
+
198
+ Returns:
199
+ torch.Tensor: [N] — probabilities for each audio file.
200
+ """
201
+
202
+ all_probs = []
203
+
204
+ batches = self._prepare_batches(audio_paths)
205
+
206
+ for batch in batches:
207
+ inputs = self._preprocess_audio_batch(batch)
208
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
209
+
210
+ with torch.no_grad():
211
+ probs = self.forward(**inputs).squeeze(-1) # [B]
212
+ all_probs.append(probs)
213
+
214
+ return torch.cat(all_probs, dim=0)
215
+
216
+ if __name__ == "__main__":
217
+ device = 'cuda:0'
218
+ checkpoint_path = './music_detection.safetensors'
219
+ model = WavLMForMusicDetection('microsoft/wavlm-base-plus', batch_size=32, device=device)
220
+
221
+ with safe_open(checkpoint_path, framework="pt", device=device) as f:
222
+ state_dict = {key: f.get_tensor(key) for key in f.keys()}
223
+ model.load_state_dict(state_dict)
224
+ global_start = time.time()
225
+ paths = [
226
+ '/92.mp3',
227
+ '133.mp3',
228
+ '113.mp3',
229
+ '30.mp3'
230
+ ]
231
+ print(model.predict_proba(paths))
232
+