| | from transformers import SequenceFeatureExtractor
|
| | from transformers.tokenization_utils_base import BatchEncoding
|
| | from transformers.feature_extraction_utils import BatchFeature
|
| | from torchaudio.compliance.kaldi import fbank
|
| | import torch
|
| | import numpy as np
|
| | import torch.nn.functional as F
|
| |
|
| | from typing import Union, List
|
| | from transformers.utils import PaddingStrategy
|
| |
|
| |
|
class BirdMAEFeatureExtractor(SequenceFeatureExtractor):
    """Feature extractor for BirdMAE audio models.

    Pipeline: pad/truncate raw waveforms to a fixed length, compute
    Kaldi-compatible log-mel filterbank features, pad the feature maps to a
    fixed number of frames, and normalize them with dataset statistics.
    """

    _auto_class = "AutoFeatureExtractor"
    model_input_names = ["input_values"]

    # Default clip length in seconds, used when the caller does not supply
    # an explicit ``max_length`` (in samples).
    DEFAULT_CLIP_DURATION = 5

    def __init__(
        self,
        feature_size: int = 1,
        sampling_rate: int = 32_000,
        padding_value: float = 0.0,
        return_attention_mask: bool = True,
        htk_compat: bool = True,
        use_energy: bool = False,
        window_type: str = "hanning",
        num_mel_bins: int = 128,
        dither: float = 0.0,
        frame_shift: int = 10,
        target_length: int = 512,
        mean: float = -7.2,
        std: float = 4.43,
        **kwargs,
    ):
        """
        Args:
            feature_size: Dimension of the extracted features (passed to the base class).
            sampling_rate: Expected input sampling rate in Hz.
            padding_value: Value used to pad waveforms (passed to the base class).
            return_attention_mask: Whether ``pad`` should also return an attention mask.
            htk_compat / use_energy / window_type / num_mel_bins / dither / frame_shift:
                Forwarded verbatim to ``torchaudio.compliance.kaldi.fbank``.
            target_length: Number of frames every spectrogram is padded to.
            mean: Dataset mean used for normalization.
            std: Dataset standard deviation used for normalization.
        """
        super().__init__(feature_size, sampling_rate, padding_value, **kwargs)

        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value
        self.return_attention_mask = return_attention_mask

        self.htk_compat = htk_compat
        self.use_energy = use_energy
        self.window_type = window_type
        self.num_mel_bins = num_mel_bins
        self.dither = dither
        self.frame_shift = frame_shift

        self.target_length = target_length
        self.mean = mean
        self.std = std

    def __call__(
        self,
        waveform_batch: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        padding: Union[bool, str, PaddingStrategy] = "max_length",
        max_length: int | None = None,
        truncation: bool = True,
        return_tensors: str = "pt",
    ):
        """Convert raw waveforms into normalized filterbank features.

        Args:
            waveform_batch: A single waveform (1-D) or a batch of equal-length
                waveforms (2-D), as array-like or tensor.
            padding: Padding strategy forwarded to ``self.pad``.
            max_length: Maximum waveform length in samples. Defaults to
                ``sampling_rate * DEFAULT_CLIP_DURATION`` (a 5-second clip).
            truncation: Whether waveforms longer than ``max_length`` are cut.
            return_tensors: Kept for API compatibility; the output is always
                a ``torch.Tensor``.

        Returns:
            Tensor of shape ``(batch, 1, target_length, num_mel_bins)``.

        Raises:
            ValueError: If the input has more than 2 dimensions.
        """
        if not torch.is_tensor(waveform_batch):
            waveform_batch = torch.tensor(waveform_batch)

        # Promote a single waveform to a batch of one.
        if waveform_batch.dim() == 1:
            waveform_batch = waveform_batch.unsqueeze(0)

        if waveform_batch.dim() != 2:
            raise ValueError("waveform_batch must have 1 or 2 dimensions")

        waveform_batch = self._process_waveforms(waveform_batch, padding, truncation, max_length)
        fbank_features = self._compute_fbank_features(waveform_batch["input_values"])
        fbank_features = self._pad_and_normalize(fbank_features)

        # Add a channel dimension: (batch, 1, frames, mel_bins).
        return fbank_features.unsqueeze(1)

    def _process_waveforms(
        self,
        waveforms,
        padding: bool | str,
        truncation: bool,
        max_length: int | None = None,
    ):
        """Pad/truncate waveforms to ``max_length`` samples and remove DC offset."""
        if max_length is None:
            # Default clip budget: DEFAULT_CLIP_DURATION seconds of audio.
            max_length = int(self.sampling_rate) * self.DEFAULT_CLIP_DURATION

        waveform_encoded = BatchFeature({"input_values": waveforms})

        waveform_batch = self.pad(
            waveform_encoded,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_attention_mask=self.return_attention_mask,
        )

        # Subtract the per-clip mean so each waveform is zero-centered before
        # the filterbank computation. ``axis``/``keepdims`` spellings work for
        # both numpy arrays and torch tensors.
        waveform_batch["input_values"] = waveform_batch["input_values"] - waveform_batch[
            "input_values"
        ].mean(axis=1, keepdims=True)
        return waveform_batch

    def _compute_fbank_features(self, waveforms):
        """Compute Kaldi log-mel filterbank features for each waveform.

        Returns a tensor of shape ``(batch, frames, num_mel_bins)``.
        """
        fbank_features = [
            fbank(
                waveform.unsqueeze(0),
                htk_compat=self.htk_compat,
                sample_frequency=self.sampling_rate,
                use_energy=self.use_energy,
                window_type=self.window_type,
                num_mel_bins=self.num_mel_bins,
                dither=self.dither,
                frame_shift=self.frame_shift,
            )
            for waveform in waveforms
        ]
        return torch.stack(fbank_features)

    def _pad_and_normalize(self, fbank_features):
        """Pad the time axis to ``target_length`` frames and normalize.

        ``fbank_features`` has shape ``(batch, frames, mel_bins)``.

        NOTE(review): clips producing more than ``target_length`` frames are
        left untruncated, matching the original behavior — with the default
        5-second clip and 10 ms frame shift this cannot happen.
        """
        # BUGFIX: the guard previously compared ``target_length`` against the
        # BATCH dimension (``fbank_features.shape[0]``) instead of the time
        # dimension, so padding depended on batch size.
        num_frames = fbank_features.shape[1]
        difference = self.target_length - num_frames

        if difference > 0:
            # Pad the end of the time axis with the batch minimum, i.e. the
            # quietest observed filterbank value ("silence").
            padding = (0, 0, 0, difference)
            fbank_features = F.pad(fbank_features, padding, value=fbank_features.min().item())

        # AST-style normalization with doubled std.
        fbank_features = (fbank_features - self.mean) / (self.std * 2)
        return fbank_features