from transformers import SequenceFeatureExtractor
from transformers.tokenization_utils_base import BatchEncoding
from transformers.feature_extraction_utils import BatchFeature
from torchaudio.compliance.kaldi import fbank
import torch
import numpy as np
import torch.nn.functional as F
from typing import Union, List
from transformers.utils import PaddingStrategy
class BirdMAEFeatureExtractor(SequenceFeatureExtractor):
    """Feature extractor for BirdMAE audio models.

    Pipeline: raw waveforms are padded/truncated to a fixed clip length,
    converted to Kaldi-compatible log-mel filterbank features, padded or
    truncated along the frame axis to ``target_length``, and normalized
    with dataset statistics. The output has an added channel dimension:
    ``[batch, 1, target_length, num_mel_bins]``.
    """

    _auto_class = "AutoFeatureExtractor"
    model_input_names = ["input_values"]

    # Clip duration (seconds) used during pretraining; determines the default
    # waveform max_length when the caller does not supply one.
    DEFAULT_CLIP_DURATION = 5

    def __init__(
        self,
        # waveform processing
        feature_size: int = 1,
        sampling_rate: int = 32_000,
        padding_value: float = 0.0,
        return_attention_mask: bool = True,
        # fbank computation
        htk_compat: bool = True,
        use_energy: bool = False,
        window_type: str = "hanning",
        num_mel_bins: int = 128,
        dither: float = 0.0,
        frame_shift: int = 10,
        # spectrogram padding and normalization
        target_length: int = 512,
        mean: float = -7.2,
        std: float = 4.43,
        **kwargs,
    ):
        """Initialize the extractor.

        Args:
            feature_size: Dimension of the extracted features (per frame).
            sampling_rate: Expected waveform sampling rate in Hz.
            padding_value: Value used to pad shorter waveforms.
            return_attention_mask: Whether ``pad`` should return a mask.
            htk_compat: HTK compatibility flag passed to Kaldi ``fbank``.
            use_energy: Whether ``fbank`` adds an energy coefficient.
            window_type: STFT window name for ``fbank`` (e.g. "hanning").
            num_mel_bins: Number of mel filterbank bins.
            dither: Dithering constant for ``fbank`` (0.0 disables it).
            frame_shift: Frame shift in milliseconds.
            target_length: Fixed number of output time frames.
            mean: Dataset mean used for normalization.
            std: Dataset standard deviation used for normalization.
        """
        super().__init__(feature_size, sampling_rate, padding_value, **kwargs)
        # SequenceFeatureExtractor settings
        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value
        self.return_attention_mask = return_attention_mask
        # fbank settings
        self.htk_compat = htk_compat
        self.use_energy = use_energy
        self.window_type = window_type
        self.num_mel_bins = num_mel_bins
        self.dither = dither
        self.frame_shift = frame_shift
        # spectrogram padding / normalization settings
        self.target_length = target_length
        self.mean = mean
        self.std = std

    def __call__(
        self,
        waveform_batch: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
        padding: Union[bool, str, PaddingStrategy] = "max_length",
        max_length: Union[int, None] = None,
        truncation: bool = True,
        return_tensors: str = "pt",
    ):
        """Convert raw waveforms into normalized fbank feature tensors.

        Args:
            waveform_batch: A single waveform (1-D) or a batch of waveforms
                (2-D), as a tensor, array, or (nested) list of floats.
            padding: Padding strategy forwarded to ``pad``.
            max_length: Target waveform length in samples. Defaults to
                ``sampling_rate * DEFAULT_CLIP_DURATION`` (the 5 s training
                clip length) when not given.
            truncation: Whether to truncate waveforms longer than
                ``max_length``.
            return_tensors: Kept for API compatibility; output is always a
                PyTorch tensor.

        Returns:
            Tensor of shape ``[batch, 1, target_length, num_mel_bins]``.

        Raises:
            ValueError: If the input has more than 2 dimensions.
        """
        if not torch.is_tensor(waveform_batch):
            waveform_batch = torch.from_numpy(np.asarray(waveform_batch))
        if waveform_batch.dim() == 1:
            # Promote a single waveform to a batch of one.
            waveform_batch = waveform_batch.unsqueeze(0)
        if waveform_batch.dim() != 2:
            raise ValueError("waveform_batch must have 1 or 2 dimensions")
        # Bug fix: max_length was previously accepted but silently ignored;
        # it is now forwarded to the waveform padding step.
        waveform_batch = self._process_waveforms(waveform_batch, padding, truncation, max_length)
        fbank_features = self._compute_fbank_features(waveform_batch["input_values"])
        fbank_features = self._pad_and_normalize(fbank_features)
        # Add a channel dimension for the (ViT-style) model input.
        return fbank_features.unsqueeze(1)

    def _process_waveforms(
        self,
        waveforms,
        padding: Union[bool, str],
        truncation: bool,
        max_length: Union[int, None] = None,
    ):
        """Pad/truncate waveforms to a fixed sample length via ``self.pad``.

        Args:
            waveforms: 2-D tensor ``[batch, samples]``.
            padding: Padding strategy forwarded to ``pad``.
            truncation: Whether to truncate over-long waveforms.
            max_length: Target length in samples; defaults to the 5 s
                training clip length.

        Returns:
            A ``BatchFeature`` with padded ``input_values`` (and an
            ``attention_mask`` when ``return_attention_mask`` is set).
        """
        if max_length is None:
            max_length = int(self.sampling_rate) * self.DEFAULT_CLIP_DURATION
        waveform_encoded = BatchFeature({"input_values": waveforms})
        return self.pad(
            waveform_encoded,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_attention_mask=self.return_attention_mask,
        )

    def _compute_fbank_features(self, waveforms):
        """Compute Kaldi log-mel filterbank features for each waveform.

        Args:
            waveforms: 2-D tensor ``[batch, samples]``; all rows have equal
                length, so the per-clip fbank outputs can be stacked.

        Returns:
            Tensor of shape ``[batch, time_frames, num_mel_bins]``.
        """
        fbank_features = [
            fbank(
                waveform.unsqueeze(0),  # fbank expects [channel, samples]
                htk_compat=self.htk_compat,
                sample_frequency=self.sampling_rate,
                use_energy=self.use_energy,
                window_type=self.window_type,
                num_mel_bins=self.num_mel_bins,
                dither=self.dither,
                frame_shift=self.frame_shift,
            )
            for waveform in waveforms
        ]
        return torch.stack(fbank_features)

    def _pad_and_normalize(self, fbank_features):
        """Fix the frame axis to ``target_length`` and normalize.

        Args:
            fbank_features: Tensor ``[batch, time_frames, mel_bins]``.

        Returns:
            Tensor ``[batch, target_length, mel_bins]`` normalized with the
            configured dataset statistics.
        """
        if fbank_features.shape[1] > self.target_length:
            # Truncate over-long spectrograms.
            fbank_features = fbank_features[:, :self.target_length, :]
        elif fbank_features.shape[1] < self.target_length:
            # Pad short spectrograms with the minimum value (log-mel floor)
            # rather than zeros, so padding resembles silence.
            difference = self.target_length - fbank_features.shape[1]
            min_value = fbank_features.min()
            padding = (0, 0, 0, difference)  # (mel_left, mel_right, time_top, time_bottom)
            fbank_features = F.pad(fbank_features, padding, value=min_value)
        # Normalize; the 2*std denominator follows the AudioMAE convention —
        # presumably matches the statistics used at pretraining. Do not change.
        return (fbank_features - self.mean) / (self.std * 2)