File size: 5,144 Bytes
cb47fe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90b7a5f
9686f9f
cb47fe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6ff7503
cb47fe9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
05e019e
 
 
 
 
 
 
 
 
6cc416d
05e019e
 
 
 
6cc416d
05e019e
cb47fe9
05e019e
cb47fe9
acc98a7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from transformers import SequenceFeatureExtractor
from transformers.tokenization_utils_base import BatchEncoding
from transformers.feature_extraction_utils import BatchFeature
from torchaudio.compliance.kaldi import fbank
import torch
import numpy as np
import torch.nn.functional as F

from typing import Union, List
from transformers.utils import PaddingStrategy


class BirdMAEFeatureExtractor(SequenceFeatureExtractor):
    """Feature extractor for BirdMAE-style audio models.

    Pipeline: raw waveforms are padded/truncated to a fixed clip length,
    converted to Kaldi-compatible log-mel filterbank (fbank) features via
    ``torchaudio.compliance.kaldi.fbank``, padded/cropped along the time
    axis to ``target_length`` frames, and normalized with the dataset
    ``mean``/``std``.
    """

    _auto_class = "AutoFeatureExtractor"
    model_input_names = ["input_values"]

    def __init__(
            self,
            # waveform processing (SequenceFeatureExtractor base args)
            feature_size: int = 1,
            sampling_rate: int = 32_000,
            padding_value: float = 0.0,
            return_attention_mask: bool = True,

            # fbank computation
            htk_compat: bool = True,
            use_energy: bool = False,
            window_type: str = "hanning",
            num_mel_bins: int = 128,
            dither: float = 0.0,
            frame_shift: int = 10,

            # spectrogram padding and normalization
            target_length: int = 512,
            mean: float = -7.2,
            std: float = 4.43,

            **kwargs
    ):
        """Create the extractor.

        Args:
            feature_size: Feature dimension of the extracted features
                (raw waveform -> 1).
            sampling_rate: Expected sampling rate of incoming audio in Hz.
            padding_value: Value used when padding waveforms.
            return_attention_mask: Whether ``pad`` should also return an
                attention mask for the waveform batch.
            htk_compat: Passed to ``fbank``; HTK-compatible feature layout.
            use_energy: Passed to ``fbank``; include an energy coefficient.
            window_type: STFT window type for ``fbank``.
            num_mel_bins: Number of mel filterbank bins.
            dither: Dithering constant (0.0 disables dithering).
            frame_shift: Frame shift in milliseconds.
            target_length: Fixed number of time frames in the output.
            mean: Dataset mean used for normalization.
            std: Dataset standard deviation used for normalization.
        """
        super().__init__(feature_size, sampling_rate, padding_value, **kwargs)
        # Base-class args are also kept as attributes so they serialize
        # with the extractor config.
        self.feature_size = feature_size
        self.sampling_rate = sampling_rate
        self.padding_value = padding_value
        self.return_attention_mask = return_attention_mask

        # fbank
        self.htk_compat = htk_compat
        self.use_energy = use_energy
        self.window_type = window_type
        self.num_mel_bins = num_mel_bins
        self.dither = dither
        self.frame_shift = frame_shift

        # pad and normalize
        self.target_length = target_length
        self.mean = mean
        self.std = std

    def __call__(self,
                 waveform_batch: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
                 padding: Union[bool, str, PaddingStrategy] = "max_length",
                 max_length: int | None = None,
                 truncation: bool = True,
                 return_tensors: str = "pt"
                 ) -> torch.Tensor:
        """Convert a batch of waveforms into normalized fbank features.

        Args:
            waveform_batch: A single waveform (1-D) or a batch of
                waveforms (2-D), as tensors, arrays, or (nested) lists.
            padding: Padding strategy forwarded to ``SequenceFeatureExtractor.pad``.
            max_length: Maximum waveform length in samples. ``None`` keeps
                the default 5-second clip at ``self.sampling_rate``.
            truncation: Whether waveforms longer than ``max_length`` are cut.
            return_tensors: Accepted for API compatibility; the output is
                always a ``torch.Tensor``.

        Returns:
            Tensor of shape ``[batch, 1, target_length, num_mel_bins]``.

        Raises:
            ValueError: If the input has more than 2 dimensions.
        """
        if not torch.is_tensor(waveform_batch):
            waveform_batch = torch.from_numpy(np.array(waveform_batch))

        # Promote a single waveform to a batch of one.
        if len(waveform_batch.shape) == 1:
            waveform_batch = waveform_batch.unsqueeze(0)

        if len(waveform_batch.shape) != 2:
            raise ValueError("waveform_batch must have 1 or 2 dimensions")

        waveform_batch = self._process_waveforms(waveform_batch, padding, truncation, max_length)

        fbank_features = self._compute_fbank_features(waveform_batch["input_values"])

        fbank_features = self._pad_and_normalize(fbank_features)

        # Add a channel dimension: [batch, 1, time, mel_bins].
        return fbank_features.unsqueeze(1)

    def _process_waveforms(self,
                           waveforms,
                           padding: bool | str,
                           truncation: bool,
                           max_length: int | None = None):
        """Pad/truncate raw waveforms to a fixed sample length.

        Args:
            waveforms: 2-D tensor ``[batch, samples]``.
            padding: Padding strategy for ``SequenceFeatureExtractor.pad``.
            truncation: Whether to cut waveforms longer than ``max_length``.
            max_length: Target length in samples; defaults to a 5-second
                clip (the clip duration used during pre-training).

        Returns:
            A ``BatchFeature`` with padded ``input_values`` (and
            ``attention_mask`` when ``self.return_attention_mask`` is set).
        """
        if max_length is None:
            clip_duration = 5  # seconds; the clip duration used in training
            max_length = int(self.sampling_rate) * clip_duration

        waveform_encoded = BatchFeature({"input_values": waveforms})

        waveform_batch = self.pad(
            waveform_encoded,
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_attention_mask=self.return_attention_mask
        )

        return waveform_batch

    def _compute_fbank_features(self, waveforms) -> torch.Tensor:
        """Compute Kaldi-style fbank features for each waveform.

        Args:
            waveforms: 2-D tensor ``[batch, samples]``; all rows must have
                equal length so the per-item features can be stacked.

        Returns:
            Tensor of shape ``[batch, time_frames, num_mel_bins]``.
        """
        fbank_features = [
            fbank(
                waveform.unsqueeze(0),  # fbank expects [channel, samples]
                htk_compat=self.htk_compat,
                sample_frequency=self.sampling_rate,
                use_energy=self.use_energy,
                window_type=self.window_type,
                num_mel_bins=self.num_mel_bins,
                dither=self.dither,
                frame_shift=self.frame_shift
            )
            for waveform in waveforms
        ]
        return torch.stack(fbank_features)

    def _pad_and_normalize(self, fbank_features: torch.Tensor) -> torch.Tensor:
        """Fix the time dimension to ``target_length`` and normalize.

        Args:
            fbank_features: Tensor ``[batch, time_frames, mel_bins]``.

        Returns:
            Tensor ``[batch, target_length, mel_bins]`` normalized as
            ``(x - mean) / (2 * std)``.
        """
        # Truncate if too long
        if fbank_features.shape[1] > self.target_length:
            fbank_features = fbank_features[:, :self.target_length, :]

        # Pad if too short — fill with the batch minimum so padding looks
        # like silence rather than an arbitrary constant.
        elif fbank_features.shape[1] < self.target_length:
            difference = self.target_length - fbank_features.shape[1]
            min_value = fbank_features.min()
            padding = (0, 0, 0, difference)  # (mel-left, mel-right, time-top, time-bottom)
            fbank_features = F.pad(
                fbank_features,
                padding,
                value=min_value
            )

        # Normalize with dataset statistics (2*std scaling per the
        # AudioMAE/BirdMAE convention).
        fbank_features = (fbank_features - self.mean) / (self.std * 2)
        return fbank_features