import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoConfig, AutoFeatureExtractor
import torchaudio
from safetensors import safe_open
from typing import List, Dict
# Enable TF32 matmuls and the flash / memory-efficient SDP attention kernels;
# disable the slower math fallback.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)
class WavLMForMusicDetection(nn.Module):
"""
Music detection model based on WavLM.
Uses attention pooling + classification head.
Outputs probability that input audio contains music.
    Handles batching and audio preprocessing internally for multi-file inference.
    Reported equal error rate (EER): 2.5-3%.
"""
def __init__(
self,
base_model_name: str = 'microsoft/wavlm-base-plus',
batch_size: int = 32,
device: str = 'cuda'
) -> None:
super().__init__()
self.config = AutoConfig.from_pretrained(base_model_name)
self.wavlm = AutoModel.from_pretrained(base_model_name, config=self.config)
self.processor = AutoFeatureExtractor.from_pretrained(base_model_name)
self.batch_size = batch_size
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        # Default dtype; updated by convert_to_bf16(). Without this, forward()
        # would raise AttributeError when the model is kept in fp32.
        self.dtype = torch.float32
        self.target_sample_rate = self.processor.sampling_rate
# Attention-based pooling head
self.pool_attention = nn.Sequential(
nn.Linear(self.config.hidden_size, 256),
nn.Tanh(),
nn.Linear(256, 1)
)
# Classification head
self.classifier = nn.Sequential(
nn.Linear(self.config.hidden_size, 256),
nn.LayerNorm(256),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(256, 64),
nn.LayerNorm(64),
nn.GELU(),
nn.Linear(64, 1)
)
        # Move all parameters and buffers to the target device
        self.to(self.device)
def _attention_pool(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor
) -> torch.Tensor:
"""
Apply attention-based pooling over time dimension.
Args:
            hidden_states (torch.Tensor): [batch_size, seq_len, hidden_size]
            attention_mask (torch.Tensor): [batch_size, seq_len] - mask to ignore padding
        Returns:
            torch.Tensor: [batch_size, hidden_size] - context vector
"""
attention_weights = self.pool_attention(hidden_states) # [B, T, 1]
# Mask out padded positions
attention_weights = attention_weights + (
(1.0 - attention_mask.unsqueeze(-1).to(attention_weights.dtype)) * -1e9
)
attention_weights = F.softmax(attention_weights, dim=1) # [B, T, 1]
# Weighted sum over time
weighted_sum = torch.sum(hidden_states * attention_weights, dim=1) # [B, D]
return weighted_sum
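    # In effect: alpha_t = softmax_t(w2' * tanh(W1 @ h_t)) and pooled = sum_t alpha_t * h_t,
    # with padded frames driven to near-zero weight by the -1e9 additive mask.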
def forward(
self,
input_values: torch.Tensor,
attention_mask: torch.Tensor
) -> torch.Tensor:
"""
Forward pass for inference.
Args:
            input_values (torch.Tensor): [batch_size, audio_seq_len] - raw audio waveform
            attention_mask (torch.Tensor): [batch_size, audio_seq_len] - input mask (1 = real, 0 = pad)
        Returns:
            torch.Tensor: [batch_size, 1] - probability that audio contains music
"""
assert isinstance(input_values, torch.Tensor), f"Expected torch.Tensor, got {type(input_values)}"
assert isinstance(attention_mask, torch.Tensor), f"Expected torch.Tensor, got {type(attention_mask)}"
        input_values = input_values.to(dtype=self.dtype, device=self.device)
        # Keep the mask integral: casting it to a low-precision float would
        # corrupt the length sums WavLM computes from it internally.
        attention_mask = attention_mask.to(device=self.device, dtype=torch.long)
outputs = self.wavlm(input_values, attention_mask=attention_mask)
hidden_states = outputs.last_hidden_state # [B, T', D]
        # The conv feature extractor downsamples time (~320x for WavLM), so the
        # sample-level mask must be resampled to the encoder's frame rate.
        # Nearest-neighbour indexing approximates the exact stride arithmetic.
        input_length = attention_mask.size(1)
        hidden_length = hidden_states.size(1)
        ratio = input_length / hidden_length
        indices = (torch.arange(hidden_length, device=attention_mask.device) * ratio).long()
        attention_mask = attention_mask[:, indices].bool()  # [B, T']
pooled = self._attention_pool(hidden_states, attention_mask)
logits = self.classifier(pooled) # [B, 1]
        probs = torch.sigmoid(logits)  # [B, 1] - probability that the clip contains music
return probs
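    # A likely exact alternative to the indexing above (assumption: relies on a
    # private helper of the HF WavLM implementation) would be:
    #   attention_mask = self.wavlm._get_feature_vector_attention_mask(
    #       hidden_states.shape[1], attention_mask)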
def _prepare_batches(self, audio_paths: List[str]) -> List[List[str]]:
"""
Split list of audio paths into batches of size `self.batch_size`.
Args:
audio_paths (List[str]): List of paths to audio files.
Returns:
List[List[str]]: List of batches, each batch is a list of paths.
"""
        return [
            audio_paths[i:i + self.batch_size]
            for i in range(0, len(audio_paths), self.batch_size)
        ]
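    # Example: 70 paths with batch_size=32 yield batches of sizes [32, 32, 6].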
def _preprocess_audio_batch(self, audio_paths: List[str]) -> Dict[str, torch.Tensor]:
"""
Load and preprocess a batch of audio files.
Args:
audio_paths (List[str]): List of file paths.
Returns:
Dict with keys:
"input_values": tensor [B, T]
"attention_mask": tensor [B, T]
"""
waveforms = []
for audio_path in audio_paths:
waveform, sample_rate = torchaudio.load(audio_path)
# Resample if needed
if sample_rate != self.target_sample_rate:
resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=self.target_sample_rate)
waveform = resampler(waveform)
# Convert to mono
if waveform.shape[0] > 1:
waveform = waveform.mean(dim=0, keepdim=True)
waveforms.append(waveform.squeeze())
# Extract features
        inputs = self.processor(
            [w.numpy() for w in waveforms],
            sampling_rate=self.target_sample_rate,
            return_tensors="pt",
            padding=True,
            truncation=False,
            return_attention_mask=True  # forward() requires the mask explicitly
        )
inputs = {k: v.to(self.device) for k, v in inputs.items()}
return inputs
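    # Example: a 3 s and a 5 s clip at 16 kHz pad to input_values of shape
    # [2, 80000]; the first mask row is 48000 ones then zeros, the second all ones.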
def predict_proba(self, audio_paths: List[str]) -> torch.Tensor:
"""
Predict music probability for a list of audio files.
Args:
audio_paths (List[str]): List of audio file paths.
Returns:
            torch.Tensor: [N] - probabilities for each audio file.
"""
all_probs = []
batches = self._prepare_batches(audio_paths)
for batch in batches:
            inputs = self._preprocess_audio_batch(batch)  # tensors already on self.device
with torch.no_grad():
probs = self.forward(**inputs).squeeze(-1) # [B]
all_probs.append(probs)
return torch.cat(all_probs, dim=0)
    def convert_to_bf16(self):
        """Cast all weights to bfloat16 for faster inference on Ampere+ GPUs."""
self.wavlm = self.wavlm.to(torch.bfloat16)
self.pool_attention = self.pool_attention.to(torch.bfloat16)
self.classifier = self.classifier.to(torch.bfloat16)
self.dtype = torch.bfloat16
return self
def predict_proba_smart_batching(
self,
audio_paths: List[str],
audio_lengths: List[float]
    ) -> torch.Tensor:
        """
        Predict music probability with length-sorted ("smart") batching.
        Sorting by duration groups similar-length files together, which
        minimizes padding per batch; results are restored to input order.
        Args:
            audio_paths (List[str]): List of audio file paths.
            audio_lengths (List[float]): Duration of each file (e.g. in seconds).
        Returns:
            torch.Tensor: [N] - probabilities for each audio file, in input order.
        """
        assert len(audio_paths) == len(audio_lengths), \
            f"Mismatch: {len(audio_paths)} paths vs {len(audio_lengths)} lengths"
was_training = self.training
self.eval()
try:
indexed_audios = [
(i, path, length)
for i, (path, length) in enumerate(zip(audio_paths, audio_lengths))
]
sorted_audios = sorted(indexed_audios, key=lambda x: x[2])
batches = []
for i in range(0, len(sorted_audios), self.batch_size):
batch = sorted_audios[i:i + self.batch_size]
batches.append(batch)
results = {}
for batch in batches:
batch_paths = [item[1] for item in batch]
batch_indices = [item[0] for item in batch]
                inputs = self._preprocess_audio_batch(batch_paths)  # already on self.device
with torch.no_grad():
probs = self.forward(**inputs).squeeze(-1)
if probs.dim() == 0:
probs = probs.unsqueeze(0)
for idx, prob in zip(batch_indices, probs):
results[idx] = prob.cpu()
all_probs = [results[i] for i in range(len(audio_paths))]
return torch.stack(all_probs)
finally:
if was_training:
self.train()
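    # Usage sketch (hypothetical values): durations must align with paths, e.g.
    #   model.predict_proba_smart_batching(['a.wav', 'b.wav'], [3.2, 7.5])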
if __name__ == "__main__":
device = 'cuda:0'
checkpoint_path = './music_detection.safetensors'
model = WavLMForMusicDetection('microsoft/wavlm-base-plus', batch_size=8, device=device)
model.convert_to_bf16()
model.eval()
with safe_open(checkpoint_path, framework="pt", device=device) as f:
state_dict = {key: f.get_tensor(key) for key in f.keys()}
model.load_state_dict(state_dict)
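    # Minimal usage sketch: the paths below are hypothetical placeholders,
    # not files shipped with this repo; point them at your own audio.
    example_paths = ['./some_music.wav', './some_speech.wav']
    probs = model.predict_proba(example_paths)
    for path, p in zip(example_paths, probs):
        print(f'{path}: p(music) = {p.item():.3f}')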