File size: 3,665 Bytes
7b5f1b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from transformers import SequenceFeatureExtractor
from transformers.utils import PaddingStrategy
from transformers.feature_extraction_utils import BatchFeature
from torchaudio import transforms
from typing import Union
import numpy as np
import torch


class AudioProtoNetFeatureExtractor(SequenceFeatureExtractor):
    """Convert raw waveforms into normalized log-mel spectrograms for AudioProtoNet.

    ``__call__`` applies, in order:

    1. padding/truncation of the batch via ``SequenceFeatureExtractor.pad``;
    2. ``torchaudio.transforms.Spectrogram``;
    3. ``torchaudio.transforms.MelScale``;
    4. ``torchaudio.transforms.AmplitudeToDB``;
    5. standardization ``(x - mean) / std``.

    The torchaudio transforms are built lazily on the first call. They are
    runtime state rather than configuration, so ``to_dict`` leaves them out of
    the serialized config.
    """

    _auto_class = "AutoFeatureExtractor"
    model_input_names = ["input_values"]

    # Instance attributes that hold lazily-built torchaudio modules. They are
    # not JSON-serializable and must never reach the saved configuration.
    _runtime_attributes = ("spec_transform", "mel_scale", "db_scale")

    def __init__(
        self,
        # spectrogram
        n_fft: int = 2048,
        feature_size: int = 1,
        hop_length: int = 256,
        power: float = 2.0,
        # mel scale
        n_mels: int = 256,
        sampling_rate: int = 32_000,
        n_stft: int = 1025,
        # power to db
        stype: str = "power",
        top_db: int = 80,
        # normalization
        mean: float = -13.369,
        std: float = 13.162,
        padding_value: float = 0.0,
        return_attention_mask: bool = True,
        **kwargs,
    ):
        """Store the spectrogram / mel / dB / normalization configuration.

        Args:
            n_fft: FFT size for ``torchaudio.transforms.Spectrogram``.
            feature_size: Feature dimension forwarded to ``SequenceFeatureExtractor``.
            hop_length: Hop length (in samples) for the spectrogram.
            power: Exponent for the magnitude spectrogram (2.0 gives power).
            n_mels: Number of mel bands produced by ``MelScale``.
            sampling_rate: Expected input sampling rate in Hz. It is also
                forwarded to ``MelScale``.
            n_stft: Frequency-bin count given to ``MelScale``. It should match
                the spectrogram's one-sided bin count, ``n_fft // 2 + 1``
                (1025 for the default ``n_fft``).
            stype: Scale type for ``AmplitudeToDB`` (``"power"`` or ``"magnitude"``).
            top_db: Dynamic-range floor for ``AmplitudeToDB``.
            mean: Offset subtracted from the dB spectrogram. This is presumably
                a training-set statistic (TODO: confirm).
            std: Divisor applied after subtracting ``mean``. This is presumably
                a training-set statistic (TODO: confirm).
            padding_value: Value used when padding waveforms.
            return_attention_mask: Forwarded to ``pad``. The mask is currently
                not returned by ``__call__``.
            **kwargs: Forwarded to ``SequenceFeatureExtractor.__init__``.
        """
        super().__init__(feature_size=feature_size, sampling_rate=sampling_rate, padding_value=padding_value, **kwargs)

        # Store parameters as plain attributes so they round-trip through the
        # JSON config (``to_dict`` / ``from_dict``).
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.power = power
        self.n_mels = n_mels
        self.sampling_rate = sampling_rate
        self.n_stft = n_stft
        self.stype = stype
        self.top_db = top_db
        self.mean = mean
        self.std = std
        self.padding_value = padding_value
        self.return_attention_mask = return_attention_mask

        # Built lazily by ``_init_transforms`` on the first ``__call__``.
        self.spec_transform = None
        self.mel_scale = None
        self.db_scale = None

    def _init_transforms(self):  # TODO post init method?
        """Build the torchaudio transforms from the stored configuration."""
        self.spec_transform = transforms.Spectrogram(n_fft=self.n_fft, hop_length=self.hop_length, power=self.power)
        self.mel_scale = transforms.MelScale(n_mels=self.n_mels, sample_rate=self.sampling_rate, n_stft=self.n_stft)
        self.db_scale = transforms.AmplitudeToDB(stype=self.stype, top_db=self.top_db)

    def to_dict(self):
        """Return the serializable configuration, without the runtime torchaudio modules.

        Once ``__call__`` has run, the instance holds ``nn.Module`` objects.
        The inherited implementation copies ``self.__dict__`` wholesale, so
        ``to_json_string`` (used by ``save_pretrained`` and ``__repr__``) would
        then fail with a ``TypeError``. Dropping the keys is safe because
        ``__init__`` resets them to ``None`` on reload.
        """
        output = super().to_dict()
        for key in self._runtime_attributes:
            output.pop(key, None)
        return output

    def __call__(
        self,
        waveform_batch: Union[np.ndarray, list[float], list[np.ndarray], list[list[float]]],
        padding: Union[bool, str, PaddingStrategy] = "longest",
        max_length: int | None = None,
        truncation: bool = True,
        return_tensors: str = "pt",
        sampling_rate: int | None = None,
    ):
        """Compute normalized log-mel spectrograms for one waveform or a batch.

        Args:
            waveform_batch: A single waveform (1-D array or list of floats) or
                a batch of waveforms. A single waveform is wrapped into a
                batch of one.
            padding: Padding strategy forwarded to ``pad``.
            max_length: Target length in samples. A falsy value (``None`` or
                ``0``) uses the training clip duration of 5 s at
                ``self.sampling_rate``.
            truncation: Whether ``pad`` truncates waveforms longer than
                ``max_length``.
            return_tensors: Accepted for API compatibility only. The result is
                always a ``torch.Tensor``.
            sampling_rate: Optional sampling rate of the input, in Hz. No
                resampling is done. If given, it must equal
                ``self.sampling_rate``.

        Returns:
            ``torch.Tensor`` of normalized dB mel spectrograms, with a channel
            axis inserted at dim 1. Presumably the shape is
            ``(batch, 1, n_mels, frames)``; confirm against the torchaudio
            versions in use.

        Raises:
            ValueError: If ``sampling_rate`` is given and differs from
                ``self.sampling_rate``.
        """
        if sampling_rate is not None and sampling_rate != self.sampling_rate:
            raise ValueError(
                f"The model corresponding to this feature extractor was trained using a sampling rate of "
                f"{self.sampling_rate}. Please make sure that the provided audio input was sampled with "
                f"{self.sampling_rate} and not {sampling_rate}."
            )

        if self.spec_transform is None:
            self._init_transforms()

        clip_duration = 5  # TODO this is the clip duration used in training
        max_length = max_length or int(int(self.sampling_rate) * clip_duration)

        # A flat sequence of samples is a single waveform: wrap it into a batch of one.
        if isinstance(waveform_batch, (list, np.ndarray)) and not isinstance(waveform_batch[0], (list, np.ndarray)):
            waveform_batch = [waveform_batch]

        padded = self.pad(
            BatchFeature({"input_values": waveform_batch}),
            padding=padding,
            max_length=max_length,
            truncation=truncation,
            return_attention_mask=self.return_attention_mask,
        )
        # The attention mask computed by ``pad`` is not used downstream.
        # Stack into one ndarray first: ``torch.as_tensor`` on a list of
        # ndarrays takes torch's slow element-wise path.
        audio_tensor = torch.as_tensor(np.asarray(padded["input_values"]))

        spec_gram = self.spec_transform(audio_tensor)
        mel_spec = self.db_scale(self.mel_scale(spec_gram))
        mel_spec_norm = (mel_spec - self.mean) / self.std

        return mel_spec_norm.unsqueeze(1)