|
|
from transformers import SequenceFeatureExtractor |
|
|
from transformers.tokenization_utils_base import BatchEncoding |
|
|
from transformers.feature_extraction_utils import BatchFeature |
|
|
from torchaudio.compliance.kaldi import fbank |
|
|
import torch |
|
|
import numpy as np |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from typing import Union, List |
|
|
from transformers.utils import PaddingStrategy |
|
|
|
|
|
|
|
|
class BirdMAEFeatureExtractor(SequenceFeatureExtractor):
    """Turn raw waveforms into normalized filter-bank inputs for Bird-MAE.

    Pipeline (see ``__call__``):

    1. pad/truncate every waveform to ``max_length`` samples (default: five seconds);
    2. subtract each clip's mean;
    3. compute Kaldi-compatible filter-bank features with
       ``torchaudio.compliance.kaldi.fbank``;
    4. pad/crop the frame axis to ``target_length`` and normalize with
       ``(x - mean) / (2 * std)``.
    """

    _auto_class = "AutoFeatureExtractor"
    model_input_names = ["input_values"]

    # Clip length in seconds used to derive the default ``max_length`` (in samples).
    _DEFAULT_CLIP_DURATION_S = 5

    def __init__(
        self,
        feature_size: int = 1,
        sampling_rate: int = 32_000,
        padding_value: float = 0.0,
        return_attention_mask: bool = True,
        htk_compat: bool = True,
        use_energy: bool = False,
        window_type: str = "hanning",
        num_mel_bins: int = 128,
        dither: float = 0.0,
        frame_shift: int = 10,
        target_length: int = 512,
        mean: float = -7.2,
        std: float = 4.43,
        **kwargs,
    ):
        """Store the feature-extraction configuration.

        Args:
            feature_size: Forwarded to ``SequenceFeatureExtractor``.
            sampling_rate: Expected input sampling rate in Hz; also passed to
                ``fbank`` as ``sample_frequency``.
            padding_value: Value ``pad`` uses for padded samples.
            return_attention_mask: Whether ``pad`` also produces an attention mask.
            htk_compat, use_energy, window_type, num_mel_bins, dither, frame_shift:
                Forwarded unchanged to ``torchaudio.compliance.kaldi.fbank``
                (there ``frame_shift`` is expressed in milliseconds).
            target_length: Number of filter-bank frames in the output.
            mean, std: Normalization statistics; features become
                ``(x - mean) / (2 * std)``.
            **kwargs: Forwarded to ``SequenceFeatureExtractor``.
        """
        super().__init__(feature_size, sampling_rate, padding_value, **kwargs)

        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value
        self.return_attention_mask = return_attention_mask

        self.htk_compat = htk_compat
        self.use_energy = use_energy
        self.window_type = window_type
        self.num_mel_bins = num_mel_bins
        self.dither = dither
        self.frame_shift = frame_shift

        self.target_length = target_length
        self.mean = mean
        self.std = std

    def __call__(
        self,
        waveform_batch: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        padding: Union[bool, str, PaddingStrategy] = "max_length",
        max_length: int | None = None,
        truncation: bool = True,
        return_tensors: str = "pt",
    ):
        """Extract normalized filter-bank features for a batch of waveforms.

        Args:
            waveform_batch: A single 1-D waveform, a 2-D ``(batch, samples)``
                array/tensor, or a list of 1-D clips (which may differ in length).
            padding: Padding strategy forwarded to ``SequenceFeatureExtractor.pad``.
            max_length: Length in samples used for padding/truncation. Defaults to
                five seconds of audio (``5 * sampling_rate``).
            truncation: Whether clips longer than ``max_length`` are truncated.
            return_tensors: Accepted for API compatibility only; the result is
                always a ``torch.Tensor``.

        Returns:
            A tensor of shape ``(batch, 1, target_length, num_mel_bins)``.

        Raises:
            ValueError: If ``waveform_batch`` does not have 1 or 2 dimensions.
        """
        waveforms = self._to_waveform_batch(waveform_batch)
        processed = self._process_waveforms(
            waveforms, padding, truncation, max_length=max_length
        )
        fbank_features = self._compute_fbank_features(processed["input_values"])
        fbank_features = self._pad_and_normalize(fbank_features)
        # Insert a singleton channel axis after the batch axis.
        return fbank_features.unsqueeze(1)

    @staticmethod
    def _to_waveform_batch(waveform_batch):
        """Coerce raw input into a 2-D tensor, or a list of 1-D clips if ragged.

        Equal-length input is stacked into a ``(batch, samples)`` tensor (a single
        1-D clip becomes a batch of one). A list of 1-D clips with different
        lengths cannot be stacked by ``np.array``, so it is returned as a list and
        ``self.pad`` brings the clips to a common length.

        Raises:
            ValueError: If the stacked input does not have 1 or 2 dimensions.
        """
        is_ragged_list = (
            isinstance(waveform_batch, (list, tuple))
            and len(waveform_batch) > 0
            and all(np.ndim(clip) == 1 for clip in waveform_batch)
            and len({len(clip) for clip in waveform_batch}) > 1
        )
        if is_ragged_list:
            return [
                clip if torch.is_tensor(clip) else np.asarray(clip)
                for clip in waveform_batch
            ]

        if torch.is_tensor(waveform_batch):
            batch = waveform_batch
        else:
            batch = torch.from_numpy(np.array(waveform_batch))

        if batch.dim() == 1:
            batch = batch.unsqueeze(0)
        if batch.dim() != 2:
            raise ValueError("waveform_batch must have 1 or 2 dimensions")
        return batch

    def _process_waveforms(
        self,
        waveforms,
        padding: bool | str | PaddingStrategy,
        truncation: bool,
        max_length: int | None = None,
    ):
        """Pad/truncate waveforms to ``max_length`` samples and remove each clip's mean.

        Args:
            waveforms: ``(batch, samples)`` tensor or list of 1-D clips.
            padding: Padding strategy forwarded to ``self.pad``.
            truncation: Whether to truncate clips longer than ``max_length``.
            max_length: Target length in samples; defaults to
                ``_DEFAULT_CLIP_DURATION_S`` seconds at ``self.sampling_rate``.

        Returns:
            The ``BatchFeature`` produced by ``self.pad``, with mean-centred torch
            ``input_values`` (plus ``attention_mask`` when enabled).
        """
        if max_length is None:
            max_length = int(self.sampling_rate) * self._DEFAULT_CLIP_DURATION_S

        padded = self.pad(
            BatchFeature({"input_values": waveforms}),
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_attention_mask=self.return_attention_mask,
            # The fbank step operates on torch tensors, so do not rely on inference.
            return_tensors="pt",
        )

        input_values = padded["input_values"]
        # Per-clip DC removal. The mean is taken after padding, so padded samples
        # are included in it.
        padded["input_values"] = input_values - input_values.mean(dim=1, keepdim=True)
        return padded

    def _compute_fbank_features(self, waveforms):
        """Compute Kaldi-compatible filter-bank features for every clip in the batch.

        Each row of ``waveforms`` is passed to ``torchaudio.compliance.kaldi.fbank``
        as a single-channel ``(1, samples)`` signal. The per-clip
        ``(frames, num_mel_bins)`` results are stacked into a
        ``(batch, frames, num_mel_bins)`` tensor.
        """
        features = [
            fbank(
                waveform.unsqueeze(0),
                htk_compat=self.htk_compat,
                sample_frequency=self.sampling_rate,
                use_energy=self.use_energy,
                window_type=self.window_type,
                num_mel_bins=self.num_mel_bins,
                dither=self.dither,
                frame_shift=self.frame_shift,
            )
            for waveform in waveforms
        ]
        return torch.stack(features)

    def _pad_and_normalize(self, fbank_features):
        """Bring the frame axis to ``target_length`` and normalize the features.

        Args:
            fbank_features: ``(batch, frames, num_mel_bins)`` tensor.

        Returns:
            ``(batch, target_length, num_mel_bins)`` tensor normalized as
            ``(x - mean) / (2 * std)``. Short inputs are padded at the end of the
            frame axis; long inputs are cropped to the first ``target_length``
            frames.
        """
        # Axis 1 is the frame axis; axis 0 is the batch and must not be used here.
        num_frames = fbank_features.shape[1]
        difference = self.target_length - num_frames

        if difference > 0:
            # NOTE(review): the fill value is the minimum over the *whole batch*, so a
            # clip's padding depends on which clips it is batched with.
            min_value = fbank_features.min().item()
            # F.pad lists padding from the last dim backwards:
            # (mel_left, mel_right, frames_left, frames_right).
            fbank_features = F.pad(fbank_features, (0, 0, 0, difference), value=min_value)
        elif difference < 0:
            fbank_features = fbank_features[:, : self.target_length, :]

        return (fbank_features - self.mean) / (self.std * 2)