Audio Classification
Russian
Files changed (1) hide show
  1. model.py +63 -24
model.py CHANGED
@@ -5,7 +5,6 @@ from transformers import AutoModel, AutoConfig, AutoFeatureExtractor
5
  import torchaudio
6
  from safetensors import safe_open
7
  from typing import List, Dict
8
- import time
9
 
10
  torch.backends.cuda.matmul.allow_tf32 = True
11
  torch.backends.cuda.enable_flash_sdp(True)
@@ -66,11 +65,9 @@ class WavLMForMusicDetection(nn.Module):
66
  ) -> torch.Tensor:
67
  """
68
  Apply attention-based pooling over time dimension.
69
-
70
  Args:
71
  hidden_states (torch.Tensor): [batch_size, seq_len, hidden_size]
72
  attention_mask (torch.Tensor): [batch_size, seq_len] — mask to ignore padding
73
-
74
  Returns:
75
  torch.Tensor: [batch_size, hidden_size] — context vector
76
  """
@@ -94,21 +91,22 @@ class WavLMForMusicDetection(nn.Module):
94
  ) -> torch.Tensor:
95
  """
96
  Forward pass for inference.
97
-
98
  Args:
99
  input_values (torch.Tensor): [batch_size, audio_seq_len] — raw audio waveform
100
  attention_mask (torch.Tensor): [batch_size, audio_seq_len] — input mask (1 = real, 0 = pad)
101
-
102
  Returns:
103
  torch.Tensor: [batch_size, 1] — probability that audio contains music
104
  """
105
  assert isinstance(input_values, torch.Tensor), f"Expected torch.Tensor, got {type(input_values)}"
106
  assert isinstance(attention_mask, torch.Tensor), f"Expected torch.Tensor, got {type(attention_mask)}"
107
 
108
- outputs = self.wavlm(input_values.to(self.device), attention_mask=attention_mask.to(self.device))
 
 
 
 
109
  hidden_states = outputs.last_hidden_state # [B, T', D]
110
 
111
- # Align attention mask with downsampled hidden states
112
  input_length = attention_mask.size(1)
113
  hidden_length = hidden_states.size(1)
114
  ratio = input_length / hidden_length
@@ -125,10 +123,8 @@ class WavLMForMusicDetection(nn.Module):
125
  def _prepare_batches(self, audio_paths: List[str]) -> List[List[str]]:
126
  """
127
  Split list of audio paths into batches of size `self.batch_size`.
128
-
129
  Args:
130
  audio_paths (List[str]): List of paths to audio files.
131
-
132
  Returns:
133
  List[List[str]]: List of batches, each batch is a list of paths.
134
  """
@@ -151,10 +147,8 @@ class WavLMForMusicDetection(nn.Module):
151
  def _preprocess_audio_batch(self, audio_paths: List[str]) -> Dict[str, torch.Tensor]:
152
  """
153
  Load and preprocess a batch of audio files.
154
-
155
  Args:
156
  audio_paths (List[str]): List of file paths.
157
-
158
  Returns:
159
  Dict with keys:
160
  "input_values": tensor [B, T]
@@ -191,10 +185,8 @@ class WavLMForMusicDetection(nn.Module):
191
  def predict_proba(self, audio_paths: List[str]) -> torch.Tensor:
192
  """
193
  Predict music probability for a list of audio files.
194
-
195
  Args:
196
  audio_paths (List[str]): List of audio file paths.
197
-
198
  Returns:
199
  torch.Tensor: [N] — probabilities for each audio file.
200
  """
@@ -212,21 +204,68 @@ class WavLMForMusicDetection(nn.Module):
212
  all_probs.append(probs)
213
 
214
  return torch.cat(all_probs, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
  if __name__ == "__main__":
217
  device = 'cuda:0'
218
  checkpoint_path = './music_detection.safetensors'
219
- model = WavLMForMusicDetection('microsoft/wavlm-base-plus', batch_size=32, device=device)
220
-
 
221
  with safe_open(checkpoint_path, framework="pt", device=device) as f:
222
  state_dict = {key: f.get_tensor(key) for key in f.keys()}
223
  model.load_state_dict(state_dict)
224
- global_start = time.time()
225
- paths = [
226
- '/92.mp3',
227
- '133.mp3',
228
- '113.mp3',
229
- '30.mp3'
230
- ]
231
- print(model.predict_proba(paths))
232
-
 
5
  import torchaudio
6
  from safetensors import safe_open
7
  from typing import List, Dict
 
8
 
9
  torch.backends.cuda.matmul.allow_tf32 = True
10
  torch.backends.cuda.enable_flash_sdp(True)
 
65
  ) -> torch.Tensor:
66
  """
67
  Apply attention-based pooling over time dimension.
 
68
  Args:
69
  hidden_states (torch.Tensor): [batch_size, seq_len, hidden_size]
70
  attention_mask (torch.Tensor): [batch_size, seq_len] — mask to ignore padding
 
71
  Returns:
72
  torch.Tensor: [batch_size, hidden_size] — context vector
73
  """
 
91
  ) -> torch.Tensor:
92
  """
93
  Forward pass for inference.
 
94
  Args:
95
  input_values (torch.Tensor): [batch_size, audio_seq_len] — raw audio waveform
96
  attention_mask (torch.Tensor): [batch_size, audio_seq_len] — input mask (1 = real, 0 = pad)
 
97
  Returns:
98
  torch.Tensor: [batch_size, 1] — probability that audio contains music
99
  """
100
  assert isinstance(input_values, torch.Tensor), f"Expected torch.Tensor, got {type(input_values)}"
101
  assert isinstance(attention_mask, torch.Tensor), f"Expected torch.Tensor, got {type(attention_mask)}"
102
 
103
+
104
+ input_values = input_values.to(dtype=self.dtype, device=self.device)
105
+ attention_mask = attention_mask.to(device=self.device, dtype=self.dtype)
106
+
107
+ outputs = self.wavlm(input_values, attention_mask=attention_mask)
108
  hidden_states = outputs.last_hidden_state # [B, T', D]
109
 
 
110
  input_length = attention_mask.size(1)
111
  hidden_length = hidden_states.size(1)
112
  ratio = input_length / hidden_length
 
123
  def _prepare_batches(self, audio_paths: List[str]) -> List[List[str]]:
124
  """
125
  Split list of audio paths into batches of size `self.batch_size`.
 
126
  Args:
127
  audio_paths (List[str]): List of paths to audio files.
 
128
  Returns:
129
  List[List[str]]: List of batches, each batch is a list of paths.
130
  """
 
147
  def _preprocess_audio_batch(self, audio_paths: List[str]) -> Dict[str, torch.Tensor]:
148
  """
149
  Load and preprocess a batch of audio files.
 
150
  Args:
151
  audio_paths (List[str]): List of file paths.
 
152
  Returns:
153
  Dict with keys:
154
  "input_values": tensor [B, T]
 
185
  def predict_proba(self, audio_paths: List[str]) -> torch.Tensor:
186
  """
187
  Predict music probability for a list of audio files.
 
188
  Args:
189
  audio_paths (List[str]): List of audio file paths.
 
190
  Returns:
191
  torch.Tensor: [N] — probabilities for each audio file.
192
  """
 
204
  all_probs.append(probs)
205
 
206
  return torch.cat(all_probs, dim=0)
207
+
208
def convert_to_bf16(self):
    """
    Cast the model's sub-modules to bfloat16 and record the new dtype.

    The `wavlm`, `pool_attention` and `classifier` attributes are each
    replaced by the result of `.to(torch.bfloat16)`, and `self.dtype` is
    set to `torch.bfloat16`.

    Returns:
        The same instance, so the call can be chained.
    """
    target_dtype = torch.bfloat16

    # Same three sub-modules, cast in the same order as before.
    for attr_name in ("wavlm", "pool_attention", "classifier"):
        submodule = getattr(self, attr_name)
        setattr(self, attr_name, submodule.to(target_dtype))

    self.dtype = target_dtype
    return self
214
+
215
def predict_proba_smart_batching(
    self,
    audio_paths: List[str],
    audio_lengths: List[float]
) -> torch.Tensor:
    """
    Predict music probability using length-sorted batching.

    Files are sorted by `audio_lengths` (ascending, stable) and then split
    into consecutive batches of `self.batch_size`, so each batch contains
    files of similar length. Results are returned in the original order of
    `audio_paths`. The module is switched to eval mode for the duration of
    the call and put back into training mode afterwards if it was training.

    Args:
        audio_paths (List[str]): List of audio file paths.
        audio_lengths (List[float]): Length of each file, aligned by index
            with `audio_paths`. Only used as a sort key.

    Returns:
        torch.Tensor: [N] — probabilities for each audio file, on CPU, in
        the order of `audio_paths`. An empty tensor of shape [0] is
        returned when `audio_paths` is empty.

    Raises:
        AssertionError: if `audio_paths` and `audio_lengths` differ in length.
    """
    assert len(audio_paths) == len(audio_lengths), \
        f"Mismatch: {len(audio_paths)} paths vs {len(audio_lengths)} lengths"

    # torch.stack() raises on an empty list, so handle "nothing to do" here.
    if not audio_paths:
        return torch.empty(0)

    was_training = self.training
    self.eval()

    try:
        # Original positions ordered by length; sorted() is stable, so ties
        # keep their input order.
        order = sorted(range(len(audio_paths)), key=lambda i: audio_lengths[i])

        # Filled by original index, so the output order matches the input.
        results = [None] * len(audio_paths)

        for start in range(0, len(order), self.batch_size):
            batch_indices = order[start:start + self.batch_size]
            batch_paths = [audio_paths[i] for i in batch_indices]

            inputs = self._preprocess_audio_batch(batch_paths)
            inputs = {k: v.to(self.device) for k, v in inputs.items()}

            with torch.no_grad():
                probs = self.forward(**inputs).squeeze(-1)

            # Guard against a scalar result so iteration below still works.
            if probs.dim() == 0:
                probs = probs.unsqueeze(0)

            # One device-to-host transfer per batch instead of one per item.
            probs = probs.cpu()

            for idx, prob in zip(batch_indices, probs):
                results[idx] = prob

        return torch.stack(results)
    finally:
        if was_training:
            self.train()
262
 
263
  if __name__ == "__main__":
264
  device = 'cuda:0'
265
  checkpoint_path = './music_detection.safetensors'
266
+ model = WavLMForMusicDetection('microsoft/wavlm-base-plus', batch_size=8, device=device)
267
+ model.convert_to_bf16()
268
+ model.eval()
269
  with safe_open(checkpoint_path, framework="pt", device=device) as f:
270
  state_dict = {key: f.get_tensor(key) for key in f.keys()}
271
  model.load_state_dict(state_dict)