# AudioProtoPNet-5-BirdSet-XCL / processing_protonet.py
# Uploaded by mwirth7 (commit 14de4a2, verified)
from transformers import SequenceFeatureExtractor
from transformers.utils import PaddingStrategy
from transformers.feature_extraction_utils import BatchFeature
from torchaudio import transforms
from typing import Union
import numpy as np
import torch
class AudioProtoNetFeatureExtractor(SequenceFeatureExtractor):
    """Feature extractor for AudioProtoPNet bird-sound models.

    Converts raw audio waveforms into normalized log-mel spectrograms:
    STFT power spectrogram -> mel filterbank -> dB conversion -> z-score
    normalization with precomputed dataset statistics (``mean``/``std``).
    """

    _auto_class = "AutoFeatureExtractor"
    model_input_names = ["input_values"]

    def __init__(self,
                 # spectrogram
                 n_fft: int = 2048,
                 feature_size: int = 1,
                 hop_length: int = 256,
                 power: float = 2.0,
                 # mel scale
                 n_mels: int = 256,
                 sampling_rate: int = 32_000,
                 n_stft: int = 1025,
                 # power to db
                 stype: str = "power",
                 top_db: int = 80,
                 # normalization
                 mean: float = -13.369,
                 std: float = 13.162,
                 padding_value: float = 0.0,
                 return_attention_mask: bool = True,
                 # clip length in seconds used to derive the default max_length;
                 # 5 s matches the clip duration used during training
                 clip_duration: float = 5.0,
                 **kwargs,
                 ):
        """Create the feature extractor.

        Args:
            n_fft: FFT size of the STFT.
            feature_size: Feature dimension reported to the base class.
            hop_length: STFT hop length in samples.
            power: Exponent of the magnitude spectrogram (2.0 = power).
            n_mels: Number of mel filterbank channels.
            sampling_rate: Expected input sampling rate in Hz.
            n_stft: Number of STFT frequency bins (n_fft // 2 + 1).
            stype: Input scale for AmplitudeToDB ("power" or "magnitude").
            top_db: Minimum negative cut-off in decibels.
            mean: Dataset mean of the dB-scaled mel spectrogram.
            std: Dataset standard deviation of the dB-scaled mel spectrogram.
            padding_value: Value used to pad shorter waveforms.
            return_attention_mask: Whether ``pad`` should compute an attention mask.
            clip_duration: Clip length in seconds used as the default
                ``max_length`` (``sampling_rate * clip_duration`` samples).
            **kwargs: Forwarded to ``SequenceFeatureExtractor``.
        """
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate,
                         padding_value=padding_value, **kwargs)
        # Store parameters as attributes so they are serialized with the
        # feature-extractor config.
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.power = power
        self.n_mels = n_mels
        self.sampling_rate = sampling_rate
        self.n_stft = n_stft
        self.stype = stype
        self.top_db = top_db
        self.mean = mean
        self.std = std
        self.padding_value = padding_value
        self.return_attention_mask = return_attention_mask
        self.clip_duration = clip_duration
        # torchaudio transform objects are not JSON-serializable, so they are
        # built lazily on first use instead of here.
        self.spec_transform = None
        self.mel_scale = None
        self.db_scale = None

    def _init_transforms(self):
        """Lazily build the torchaudio transforms from the stored parameters."""
        self.spec_transform = transforms.Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length, power=self.power)
        self.mel_scale = transforms.MelScale(n_mels=self.n_mels, sample_rate=self.sampling_rate, n_stft=self.n_stft)
        self.db_scale = transforms.AmplitudeToDB(stype=self.stype, top_db=self.top_db)

    def __call__(self,
                 waveform_batch: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
                 padding: Union[bool, str, PaddingStrategy] = "longest",
                 max_length: int | None = None,
                 truncation: bool = True,
                 return_tensors: str = "pt"
                 ):
        """Pad/truncate waveforms and return normalized log-mel spectrograms.

        Args:
            waveform_batch: A single waveform or a batch of waveforms.
            padding: Padding strategy forwarded to ``pad``.
            max_length: Target length in samples; defaults to
                ``sampling_rate * clip_duration``.
            truncation: Whether to truncate waveforms longer than ``max_length``.
            return_tensors: Accepted for API compatibility; the output is
                always a ``torch.Tensor`` regardless of this value.

        Returns:
            ``torch.Tensor`` of shape ``(batch, 1, n_mels, frames)`` holding
            z-normalized dB-scaled mel spectrograms.
        """
        if self.spec_transform is None:
            self._init_transforms()
        # Default max_length corresponds to clip_duration seconds of audio.
        # Use an explicit None check so a caller-supplied 0 is not silently
        # replaced by the default.
        if max_length is None:
            max_length = int(int(self.sampling_rate) * self.clip_duration)
        # Wrap a single (unbatched) waveform into a batch of one; the length
        # guard avoids an IndexError on an empty batch.
        if (isinstance(waveform_batch, (list, np.ndarray))
                and len(waveform_batch) > 0
                and not isinstance(waveform_batch[0], (list, np.ndarray))):
            waveform_batch = [waveform_batch]
        waveform_batch = BatchFeature({"input_values": waveform_batch})
        waveform_batch = self.pad(
            waveform_batch,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_attention_mask=self.return_attention_mask
        )
        # The attention mask produced by pad() is intentionally discarded; the
        # model consumes only the spectrogram.
        waveform_batch = waveform_batch["input_values"]
        audio_tensor = torch.as_tensor(waveform_batch)
        spec_gram = self.spec_transform(audio_tensor)
        mel_spec = self.mel_scale(spec_gram)
        mel_spec = self.db_scale(mel_spec)
        # z-normalize with dataset statistics.
        mel_spec_norm = (mel_spec - self.mean) / self.std
        # Add a channel dimension: (batch, 1, n_mels, frames).
        return mel_spec_norm.unsqueeze(1)