# Bird-MAE-Large / feature_extractor.py
from typing import List, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
from torchaudio.compliance.kaldi import fbank
from transformers import SequenceFeatureExtractor
from transformers.feature_extraction_utils import BatchFeature
from transformers.utils import PaddingStrategy


class BirdMAEFeatureExtractor(SequenceFeatureExtractor):
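    """Feature extractor for Bird-MAE audio models.

    Pads or truncates raw waveforms to a fixed clip length, computes
    Kaldi-compatible log-mel filterbank features, then pads the spectrogram
    to ``target_length`` frames and normalizes it with the dataset statistics.
    """
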
    _auto_class = "AutoFeatureExtractor"
    model_input_names = ["input_values"]

    def __init__(
        self,
        # waveform processing
        feature_size: int = 1,
        sampling_rate: int = 32_000,
        padding_value: float = 0.0,
        return_attention_mask: bool = True,
        # fbank
        htk_compat: bool = True,
        use_energy: bool = False,
        window_type: str = "hanning",
        num_mel_bins: int = 128,
        dither: float = 0.0,
        frame_shift: int = 10,
        # padding and normalization
        target_length: int = 512,
        mean: float = -7.2,
        std: float = 4.43,
        **kwargs,
    ):
        super().__init__(feature_size, sampling_rate, padding_value, **kwargs)
        # sequence feature extractor
        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value
        self.return_attention_mask = return_attention_mask
        # fbank
        self.htk_compat = htk_compat
        self.use_energy = use_energy
        self.window_type = window_type
        self.num_mel_bins = num_mel_bins
        self.dither = dither
        self.frame_shift = frame_shift
        # padding and normalization
        self.target_length = target_length
        self.mean = mean
        self.std = std

    def __call__(
        self,
        waveform_batch: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        padding: Union[bool, str, PaddingStrategy] = "max_length",
        max_length: Optional[int] = None,
        truncation: bool = True,
        return_tensors: str = "pt",
    ):
        if return_tensors != "pt":
            # this extractor always builds PyTorch tensors
            raise ValueError("BirdMAEFeatureExtractor only supports return_tensors='pt'.")
        if not torch.is_tensor(waveform_batch):
            waveform_batch = torch.from_numpy(np.array(waveform_batch))
        if waveform_batch.dim() == 1:
            waveform_batch = waveform_batch.unsqueeze(0)
        if waveform_batch.dim() != 2:
            raise ValueError("waveform_batch must have 1 or 2 dimensions.")
        waveform_batch = self._process_waveforms(waveform_batch, padding, truncation, max_length)
        fbank_features = self._compute_fbank_features(waveform_batch["input_values"])
        fbank_features = self._pad_and_normalize(fbank_features)
        # add the channel dimension: (batch, 1, target_length, num_mel_bins)
        return fbank_features.unsqueeze(1)

    def _process_waveforms(
        self,
        waveforms,
        padding: Union[bool, str, PaddingStrategy],
        truncation: bool,
        max_length: Optional[int] = None,
    ):
        if max_length is None:
            clip_duration = 5  # clip duration (in seconds) used during training
            max_length = int(self.sampling_rate) * clip_duration
        waveform_encoded = BatchFeature({"input_values": waveforms})
        waveform_batch = self.pad(
            waveform_encoded,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_attention_mask=self.return_attention_mask,
            return_tensors="pt",
        )
        # remove the per-clip DC offset before computing the filterbanks
        input_values = waveform_batch["input_values"]
        waveform_batch["input_values"] = input_values - input_values.mean(dim=1, keepdim=True)
        return waveform_batch

    def _compute_fbank_features(self, waveforms):
        # kaldi.fbank expects a (channels, num_samples) tensor and returns a
        # (num_frames, num_mel_bins) matrix per clip
        fbank_features = [
            fbank(
                waveform.unsqueeze(0),
                htk_compat=self.htk_compat,
                sample_frequency=self.sampling_rate,
                use_energy=self.use_energy,
                window_type=self.window_type,
                num_mel_bins=self.num_mel_bins,
                dither=self.dither,
                frame_shift=self.frame_shift,
            )
            for waveform in waveforms
        ]
        return torch.stack(fbank_features)

    def _pad_and_normalize(self, fbank_features):
        # fbank_features: (batch, num_frames, num_mel_bins)
        num_frames = fbank_features.shape[1]
        difference = self.target_length - num_frames
        if difference > 0:
            # pad the time axis up to target_length with the minimum value
            min_value = fbank_features.min()
            fbank_features = F.pad(fbank_features, (0, 0, 0, difference), value=min_value.item())
        elif difference < 0:
            # truncate clips that are longer than target_length frames
            fbank_features = fbank_features[:, : self.target_length, :]
        # normalize; the factor of 2 on std follows the upstream training setup
        fbank_features = (fbank_features - self.mean) / (self.std * 2)
        return fbank_features