Audio Classification
Russian
MusicDetection / model.py
NikiPshg's picture
new batching (#4)
340b6f7 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, AutoFeatureExtractor
import torchaudio
from safetensors import safe_open
from typing import List, Dict
# Process-wide CUDA backend switches, applied as a side effect of importing
# this module (they affect every model in the process, not just this one).
# Allow TF32 for CUDA matmuls.
torch.backends.cuda.matmul.allow_tf32 = True
# Scaled-dot-product attention backends: turn on the flash and
# memory-efficient ones, turn off the math one.
# NOTE(review): with the math backend disabled there is presumably no fallback
# when neither enabled backend supports the inputs - confirm on the target GPUs.
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)
class WavLMForMusicDetection(nn.Module):
    """
    Music detection model based on WavLM.

    A WavLM backbone is followed by attention pooling over time and a small
    classification head. The output is the probability that the input audio
    contains music. Supports batched inference on audio files with automatic
    batching and preprocessing (loading, mono down-mix, resampling, padding).

    EER - 2.5-3 %
    """

    def __init__(
        self,
        base_model_name: str = 'microsoft/wavlm-base-plus',
        batch_size: int = 32,
        device: str = 'cuda'
    ) -> None:
        """
        Build the backbone, the pooling/classification heads and move them to `device`.

        Args:
            base_model_name (str): Hugging Face name or path of the WavLM checkpoint.
            batch_size (int): Number of audio files processed per inference batch.
            device (str): Requested device; falls back to CPU when CUDA is unavailable.
        """
        super().__init__()
        self.config = AutoConfig.from_pretrained(base_model_name)
        self.wavlm = AutoModel.from_pretrained(base_model_name, config=self.config)
        self.processor = AutoFeatureExtractor.from_pretrained(base_model_name)
        self.batch_size = batch_size
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        self.target_sample_rate = self.processor.sampling_rate

        # Attention-based pooling head
        self.pool_attention = nn.Sequential(
            nn.Linear(self.config.hidden_size, 256),
            nn.Tanh(),
            nn.Linear(256, 1)
        )

        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(self.config.hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, 64),
            nn.LayerNorm(64),
            nn.GELU(),
            nn.Linear(64, 1)
        )

        # `forward` casts its input to `self.dtype`, so the attribute must exist
        # even when `convert_to_bf16` is never called.
        self.dtype = next(self.classifier.parameters()).dtype

        # to device
        self.to(self.device)

    def _attention_pool(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply attention-based pooling over the time dimension.

        Args:
            hidden_states (torch.Tensor): [batch_size, seq_len, hidden_size]
            attention_mask (torch.Tensor): [batch_size, seq_len] - nonzero/True marks
                real frames, zero/False marks padding to ignore.

        Returns:
            torch.Tensor: [batch_size, hidden_size] - context vector
        """
        scores = self.pool_attention(hidden_states)  # [B, T, 1]
        # Push padded positions to the most negative representable value so they
        # receive zero weight after the softmax, whatever the working dtype is.
        padding = ~attention_mask.bool().unsqueeze(-1)  # [B, T, 1]
        scores = scores.masked_fill(padding, torch.finfo(scores.dtype).min)
        weights = F.softmax(scores, dim=1)  # [B, T, 1]
        # Weighted sum over time
        return torch.sum(hidden_states * weights, dim=1)  # [B, D]

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass for inference.

        Args:
            input_values (torch.Tensor): [batch_size, audio_seq_len] - raw audio waveform
            attention_mask (torch.Tensor): [batch_size, audio_seq_len] - input mask
                (1 = real, 0 = pad)

        Returns:
            torch.Tensor: [batch_size, 1] - probability that audio contains music

        Raises:
            TypeError: If either argument is not a torch.Tensor.
        """
        if not isinstance(input_values, torch.Tensor):
            raise TypeError(f"Expected torch.Tensor, got {type(input_values)}")
        if not isinstance(attention_mask, torch.Tensor):
            raise TypeError(f"Expected torch.Tensor, got {type(attention_mask)}")

        input_values = input_values.to(dtype=self.dtype, device=self.device)
        # Keep the mask integral: the backbone derives per-item lengths from it,
        # and counting samples in bfloat16 rounds them to the wrong value.
        attention_mask = attention_mask.to(device=self.device, dtype=torch.long)

        outputs = self.wavlm(input_values, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state  # [B, T', D]

        # Downsample the sample-level mask to frame level by picking, for every
        # output frame, the mask value at the proportionally matching sample.
        input_length = attention_mask.size(1)
        hidden_length = hidden_states.size(1)
        ratio = input_length / hidden_length
        indices = (torch.arange(hidden_length, device=attention_mask.device) * ratio).long()
        frame_mask = attention_mask[:, indices].bool()  # [B, T']

        pooled = self._attention_pool(hidden_states, frame_mask)
        logits = self.classifier(pooled)  # [B, 1]
        return torch.sigmoid(logits)  # [B, 1] -> probability of MUSIC

    def _prepare_batches(self, audio_paths: List[str]) -> List[List[str]]:
        """
        Split a list of audio paths into batches of size `self.batch_size`.

        Args:
            audio_paths (List[str]): List of paths to audio files.

        Returns:
            List[List[str]]: List of batches, each batch is a list of paths.
                The last batch may be shorter; an empty input yields no batches.

        Raises:
            ValueError: If `self.batch_size` is smaller than 1.
        """
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {self.batch_size}")
        return [
            list(audio_paths[start:start + self.batch_size])
            for start in range(0, len(audio_paths), self.batch_size)
        ]

    def _preprocess_audio_batch(self, audio_paths: List[str]) -> Dict[str, torch.Tensor]:
        """
        Load and preprocess a batch of audio files.

        Args:
            audio_paths (List[str]): List of file paths.

        Returns:
            Dict with keys:
                "input_values": tensor [B, T]
                "attention_mask": tensor [B, T]
            Both tensors are already on `self.device`.
        """
        waveforms = []
        for audio_path in audio_paths:
            waveform, sample_rate = torchaudio.load(audio_path)  # [channels, samples]

            # Convert to mono first so that only one channel has to be resampled.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            # Resample if needed
            if sample_rate != self.target_sample_rate:
                resampler = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=self.target_sample_rate
                )
                waveform = resampler(waveform)

            # Drop only the channel axis; a bare squeeze() would also collapse
            # a one-sample clip into a 0-dim tensor.
            waveforms.append(waveform.squeeze(0))

        # Extract features (pads every waveform to the longest one in the batch)
        inputs = self.processor(
            [w.numpy() for w in waveforms],
            sampling_rate=self.target_sample_rate,
            return_tensors="pt",
            padding=True,
            truncation=False,
            # `forward` requires the mask, so request it explicitly.
            return_attention_mask=True
        )
        return {k: v.to(self.device) for k, v in inputs.items()}

    def _enter_eval(self) -> List[bool]:
        """Switch every submodule to eval mode and return the previous training flags."""
        previous_modes = [module.training for module in self.modules()]
        self.eval()
        return previous_modes

    def _restore_modes(self, previous_modes: List[bool]) -> None:
        """Restore the per-module training flags captured by `_enter_eval`."""
        for module, mode in zip(self.modules(), previous_modes):
            module.training = mode

    def _predict_batch(self, audio_paths: List[str]) -> torch.Tensor:
        """Preprocess one batch of files and return its probabilities as a [B] tensor."""
        inputs = self._preprocess_audio_batch(audio_paths)
        with torch.no_grad():
            return self.forward(**inputs).squeeze(-1)  # [B]

    def predict_proba(self, audio_paths: List[str]) -> torch.Tensor:
        """
        Predict music probability for a list of audio files.

        Inference runs in eval mode; the previous training/eval state of every
        submodule is restored afterwards.

        Args:
            audio_paths (List[str]): List of audio file paths.

        Returns:
            torch.Tensor: [N] - probabilities for each audio file, in input order,
                on `self.device`. Empty input yields an empty tensor.
        """
        if not audio_paths:
            return torch.empty(0, dtype=self.dtype, device=self.device)

        batches = self._prepare_batches(audio_paths)
        previous_modes = self._enter_eval()
        try:
            all_probs = [self._predict_batch(batch) for batch in batches]
        finally:
            self._restore_modes(previous_modes)
        return torch.cat(all_probs, dim=0)

    def convert_to_bf16(self) -> "WavLMForMusicDetection":
        """
        Cast the backbone and both heads to bfloat16 and record the new dtype.

        Returns:
            WavLMForMusicDetection: `self`, to allow call chaining.
        """
        self.wavlm = self.wavlm.to(torch.bfloat16)
        self.pool_attention = self.pool_attention.to(torch.bfloat16)
        self.classifier = self.classifier.to(torch.bfloat16)
        self.dtype = torch.bfloat16
        return self

    def predict_proba_smart_batching(
        self,
        audio_paths: List[str],
        audio_lengths: List[float]
    ) -> torch.Tensor:
        """
        Predict music probability, batching files of similar length together.

        Files are sorted by `audio_lengths` before batching so that each batch
        needs little padding; results are returned in the original input order.
        Inference runs in eval mode and the previous training/eval state of
        every submodule is restored afterwards.

        Args:
            audio_paths (List[str]): List of audio file paths.
            audio_lengths (List[float]): Length of each file, used only as the
                sort key (any consistent unit works).

        Returns:
            torch.Tensor: [N] - probabilities for each audio file, on CPU.
                Empty input yields an empty tensor.

        Raises:
            ValueError: If the two lists differ in length.
        """
        if len(audio_paths) != len(audio_lengths):
            raise ValueError(
                f"Mismatch: {len(audio_paths)} paths vs {len(audio_lengths)} lengths"
            )
        if not audio_paths:
            return torch.empty(0, dtype=self.dtype)

        # Stable sort of the original positions by length (ties keep input order).
        order = sorted(range(len(audio_paths)), key=lambda i: audio_lengths[i])
        sorted_paths = [audio_paths[i] for i in order]

        previous_modes = self._enter_eval()
        try:
            # One device-to-host transfer per batch instead of one per element.
            sorted_probs = torch.cat(
                [self._predict_batch(batch).cpu() for batch in self._prepare_batches(sorted_paths)],
                dim=0
            )
        finally:
            self._restore_modes(previous_modes)

        # Undo the sort: element k of `sorted_probs` belongs to input position order[k].
        probs = torch.empty_like(sorted_probs)
        probs[torch.tensor(order)] = sorted_probs
        return probs
if __name__ == "__main__":
    # Example: build the model in bfloat16 and load fine-tuned weights.
    device = 'cuda:0'
    checkpoint_path = './music_detection.safetensors'

    model = WavLMForMusicDetection('microsoft/wavlm-base-plus', batch_size=8, device=device)
    model.convert_to_bf16()
    model.eval()

    # The model falls back to CPU when CUDA is unavailable, so open the
    # checkpoint on the device the model actually ended up on rather than on
    # the requested one (which would fail on a CPU-only machine).
    with safe_open(checkpoint_path, framework="pt", device=str(model.device)) as f:
        state_dict = {key: f.get_tensor(key) for key in f.keys()}
    model.load_state_dict(state_dict)