import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import AutoModel, AutoConfig, AutoFeatureExtractor
from safetensors import safe_open
from typing import List, Dict

# Allow TF32 matmuls on Ampere+ GPUs and prefer the flash / memory-efficient
# scaled-dot-product-attention kernels over the math fallback.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)


class WavLMForMusicDetection(nn.Module):
    """
    Music detection model based on WavLM.

    An attention-pooling layer summarizes the encoder outputs and a small
    classification head maps the pooled representation to the probability
    that the input audio contains music. Supports batched inference with
    automatic preprocessing and batching.

    EER: 2.5-3%.
    """
    def __init__(
        self,
        base_model_name: str = 'microsoft/wavlm-base-plus',
        batch_size: int = 32,
        device: str = 'cuda'
    ) -> None:
        super().__init__()
        self.config = AutoConfig.from_pretrained(base_model_name)
        self.wavlm = AutoModel.from_pretrained(base_model_name, config=self.config)
        self.processor = AutoFeatureExtractor.from_pretrained(base_model_name)

        self.batch_size = batch_size
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        # Default parameter dtype; convert_to_bf16() switches it to bfloat16.
        # (forward() reads self.dtype, so it must be set here.)
        self.dtype = torch.float32

        self.target_sample_rate = self.processor.sampling_rate

        # Attention pooling: a two-layer scorer produces one weight per frame.
        self.pool_attention = nn.Sequential(
            nn.Linear(self.config.hidden_size, 256),
            nn.Tanh(),
            nn.Linear(256, 1)
        )

        # Classification head: pooled representation -> single music logit.
        self.classifier = nn.Sequential(
            nn.Linear(self.config.hidden_size, 256),
            nn.LayerNorm(256),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(256, 64),
            nn.LayerNorm(64),
            nn.GELU(),
            nn.Linear(64, 1)
        )

        self.to(self.device)

    def _attention_pool(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply attention-based pooling over the time dimension.

        Args:
            hidden_states (torch.Tensor): [batch_size, seq_len, hidden_size]
            attention_mask (torch.Tensor): [batch_size, seq_len], mask to ignore padding

        Returns:
            torch.Tensor: [batch_size, hidden_size], context vector
        """
        # Per-frame scores: [batch_size, seq_len, 1].
        attention_weights = self.pool_attention(hidden_states)

        # Push padded positions toward -inf so they vanish after the softmax.
        attention_weights = attention_weights + (
            (1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9
        )

        # Normalize over time and take the weighted sum of frames.
        attention_weights = F.softmax(attention_weights, dim=1)
        weighted_sum = torch.sum(hidden_states * attention_weights, dim=1)
        return weighted_sum

    def forward(
        self,
        input_values: torch.Tensor,
        attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Forward pass for inference.

        Args:
            input_values (torch.Tensor): [batch_size, audio_seq_len], raw audio waveform
            attention_mask (torch.Tensor): [batch_size, audio_seq_len], input mask (1 = real, 0 = pad)

        Returns:
            torch.Tensor: [batch_size, 1], probability that the audio contains music
        """
        assert isinstance(input_values, torch.Tensor), f"Expected torch.Tensor, got {type(input_values)}"
        assert isinstance(attention_mask, torch.Tensor), f"Expected torch.Tensor, got {type(attention_mask)}"

        input_values = input_values.to(dtype=self.dtype, device=self.device)
        # Keep the mask in its integer dtype: casting it to bf16 would lose
        # precision in the length computation inside WavLM, since bf16 cannot
        # represent large integers exactly.
        attention_mask = attention_mask.to(device=self.device)

        outputs = self.wavlm(input_values, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state

        # The convolutional feature extractor downsamples the waveform, so the
        # sample-level mask is subsampled to frame level by nearest index.
        input_length = attention_mask.size(1)
        hidden_length = hidden_states.size(1)
        ratio = input_length / hidden_length
        indices = (torch.arange(hidden_length, device=attention_mask.device) * ratio).long()
        attention_mask = attention_mask[:, indices]
        attention_mask = attention_mask.bool()

        pooled = self._attention_pool(hidden_states, attention_mask)
        logits = self.classifier(pooled)

        probs = torch.sigmoid(logits)
        return probs

    def _prepare_batches(self, audio_paths: List[str]) -> List[List[str]]:
        """
        Split a list of audio paths into batches of size `self.batch_size`.

        Args:
            audio_paths (List[str]): List of paths to audio files.

        Returns:
            List[List[str]]: List of batches, each batch a list of paths.
        """
        return [
            audio_paths[i:i + self.batch_size]
            for i in range(0, len(audio_paths), self.batch_size)
        ]

    def _preprocess_audio_batch(self, audio_paths: List[str]) -> Dict[str, torch.Tensor]:
        """
        Load and preprocess a batch of audio files.

        Args:
            audio_paths (List[str]): List of file paths.

        Returns:
            Dict with keys:
                "input_values": tensor [B, T]
                "attention_mask": tensor [B, T]
        """
        waveforms = []
        for audio_path in audio_paths:
            waveform, sample_rate = torchaudio.load(audio_path)

            # Resample to the rate the feature extractor expects.
            if sample_rate != self.target_sample_rate:
                resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.target_sample_rate)
                waveform = resampler(waveform)

            # Downmix multi-channel audio to mono.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            waveforms.append(waveform.squeeze(0))

        # Pad to the longest clip in the batch; the extractor returns the matching mask.
        inputs = self.processor(
            [w.numpy() for w in waveforms],
            sampling_rate=self.target_sample_rate,
            return_tensors="pt",
            padding=True,
            truncation=False
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        return inputs

    def predict_proba(self, audio_paths: List[str]) -> torch.Tensor:
        """
        Predict music probability for a list of audio files.

        Args:
            audio_paths (List[str]): List of audio file paths.

        Returns:
            torch.Tensor: [N], probability for each audio file.
        """
        all_probs = []
        batches = self._prepare_batches(audio_paths)

        for batch in batches:
            # Tensors are already moved to self.device during preprocessing.
            inputs = self._preprocess_audio_batch(batch)

            with torch.no_grad():
                probs = self.forward(**inputs).squeeze(-1)
            all_probs.append(probs)

        return torch.cat(all_probs, dim=0)

    def convert_to_bf16(self):
        """Cast all submodules to bfloat16 and record the new dtype."""
        self.wavlm = self.wavlm.to(torch.bfloat16)
        self.pool_attention = self.pool_attention.to(torch.bfloat16)
        self.classifier = self.classifier.to(torch.bfloat16)
        self.dtype = torch.bfloat16
        return self

    def predict_proba_smart_batching(
        self,
        audio_paths: List[str],
        audio_lengths: List[float]
    ) -> torch.Tensor:
        """
        Predict music probability with length-sorted ("smart") batching.

        Sorting clips by duration groups similar lengths together, which
        minimizes padding inside each batch. Results are returned in the
        original input order.

        Args:
            audio_paths (List[str]): List of audio file paths.
            audio_lengths (List[float]): Duration of each file, same order as `audio_paths`.

        Returns:
            torch.Tensor: [N], probability for each audio file.
        """
        assert len(audio_paths) == len(audio_lengths), \
            f"Mismatch: {len(audio_paths)} paths vs {len(audio_lengths)} lengths"

        was_training = self.training
        self.eval()

        try:
            indexed_audios = [
                (i, path, length)
                for i, (path, length) in enumerate(zip(audio_paths, audio_lengths))
            ]

            # Sort by duration so each batch contains similarly sized clips.
            sorted_audios = sorted(indexed_audios, key=lambda x: x[2])
            batches = [
                sorted_audios[i:i + self.batch_size]
                for i in range(0, len(sorted_audios), self.batch_size)
            ]

            results = {}
            for batch in batches:
                batch_paths = [item[1] for item in batch]
                batch_indices = [item[0] for item in batch]

                # Tensors are already moved to self.device during preprocessing.
                inputs = self._preprocess_audio_batch(batch_paths)

                with torch.no_grad():
                    probs = self.forward(**inputs).squeeze(-1)

                # A single-item batch can collapse to a scalar; restore the batch dim.
                if probs.dim() == 0:
                    probs = probs.unsqueeze(0)

                for idx, prob in zip(batch_indices, probs):
                    results[idx] = prob.cpu()

            # Restore the original input order.
            all_probs = [results[i] for i in range(len(audio_paths))]
            return torch.stack(all_probs)
        finally:
            if was_training:
                self.train()


if __name__ == "__main__":
    device = 'cuda:0'
    checkpoint_path = './music_detection.safetensors'

    model = WavLMForMusicDetection('microsoft/wavlm-base-plus', batch_size=8, device=device)
    # Run inference in bfloat16 for speed and memory savings.
    model.convert_to_bf16()
    model.eval()

    # Load the fine-tuned weights directly onto the target device.
    with safe_open(checkpoint_path, framework="pt", device=device) as f:
        state_dict = {key: f.get_tensor(key) for key in f.keys()}
    model.load_state_dict(state_dict)
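
    # Minimal inference sketch. The file paths and durations below are
    # hypothetical placeholders, not part of the original script; durations
    # (in seconds) would normally come from metadata or torchaudio.info.
    audio_paths = ['./samples/clip_with_music.wav', './samples/speech_only.wav']
    probs = model.predict_proba(audio_paths)
    print(probs)  # one music probability per clip, e.g. values near 1.0 for music

    # When durations are known, length-sorted batching reduces padding overhead:
    audio_lengths = [12.3, 7.8]
    probs_sorted = model.predict_proba_smart_batching(audio_paths, audio_lengths)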