File size: 4,354 Bytes
58d3955
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Preprocessing and normalization to prepare audio for Kintsugi Depression and Anxiety model."""
from typing import Union, BinaryIO
import numpy as np
import os
import torch
import torchaudio
from transformers import AutoFeatureExtractor

from config import EXPECTED_SAMPLE_RATE, logmel_energies


def load_audio(source: Union[BinaryIO, str, os.PathLike]) -> torch.Tensor:
    """Read an audio file and return its mono samples at the expected rate.

    Parameters
    ----------
    source: open file or path to file

    Returns
    -------
    The samples decoded by ``torchaudio.load`` as a 1 x num_samples tensor,
    resampled to ``EXPECTED_SAMPLE_RATE`` when the file uses another rate.

    Raises
    ------
    ValueError: if the decoded audio does not have exactly one channel.

    """
    waveform, sample_rate = torchaudio.load(source)

    # Reject multi-channel input outright; no downmixing is attempted here.
    num_channels = waveform.shape[0]
    if num_channels != 1:
        raise ValueError(f"Provided audio has {num_channels} != 1 channels.")

    # Already at the target rate: hand the decoded samples back untouched.
    if sample_rate == EXPECTED_SAMPLE_RATE:
        return waveform

    return torchaudio.functional.resample(
        waveform, sample_rate, EXPECTED_SAMPLE_RATE
    )


class Preprocessor:
    """Convert raw audio into the normalized log-mel features fed to the model.

    Wraps the ``openai/whisper-small.en`` feature extractor and adds the audio
    and feature normalization performed by
    ``preprocess_with_audio_normalization``.
    """

    def __init__(self,
                 normalize_features: bool = True,
                 chunk_seconds: int = 30,
                 max_overlap_frac: float = 0.0,
                 pad_last_chunk_to_full: bool = True,
    ):
        """Create preprocessor object.

        Parameters
        ----------
        normalize_features: Whether the Whisper preprocessor should normalize features
        chunk_seconds: Size of model's receptive field in seconds
        max_overlap_frac: Fraction of each chunk allowed to overlap previous chunk for inputs longer than chunk_seconds
        pad_last_chunk_to_full: Whether to pad audio to an integer multiple of chunk_seconds

        """
        self.preprocessor = AutoFeatureExtractor.from_pretrained("openai/whisper-small.en")
        self.normalize_features = normalize_features
        self.chunk_seconds = chunk_seconds
        self.max_overlap_frac = max_overlap_frac
        self.pad_last_chunk_to_full = pad_last_chunk_to_full

    def preprocess_with_audio_normalization(
        self,
        audio: torch.Tensor,
    ) -> torch.Tensor:
        """Run Whisper preprocessor and normalization expected by the model.

        Note: some normalization steps can be avoided, but are included to match
        feature extraction used during training.

        Parameters
        ----------
        audio: Raw audio samples as a 1 x num_samples float tensor sampled at 16 kHz

        Returns
        -------
        Normalized mel filter bank features as a float tensor of shape
        num_chunks x 80 mel filter bands x 3000 time frames

        Raises
        ------
        ValueError: if ``audio`` contains no samples, if
            ``pad_last_chunk_to_full`` is combined with a positive
            ``max_overlap_frac``, or if ``max_overlap_frac`` leaves no
            positive hop between consecutive windows.

        """
        audio = torch.squeeze(audio, 0)
        if audio.numel() == 0:
            # Fail with a clear message instead of an opaque reduction error
            # from torch.max on an empty tensor.
            raise ValueError("Cannot preprocess empty audio.")

        # Remove DC offset and scale amplitude to [-1, 1]
        audio = audio - torch.mean(audio)
        peak = torch.max(torch.abs(audio))
        if peak > 0:
            # Silent / constant input has a zero peak; dividing by it would
            # turn every sample (and every downstream feature) into NaN.
            audio = audio / peak

        chunk_samples = EXPECTED_SAMPLE_RATE * self.chunk_seconds

        if self.pad_last_chunk_to_full:
            # pad audio so that the last chunk is not dropped
            if self.max_overlap_frac > 0:
                raise ValueError(
                    "pad_last_chunk_to_full is only supported for non-overlapping windows"
                )
            # Samples missing to reach the next multiple of chunk_samples
            # (0 when the length is already a multiple).
            pad_size = (-len(audio)) % chunk_samples
            audio = torch.nn.functional.pad(audio, (0, pad_size))

        # Samples beyond the first chunk; negative when audio is shorter than one chunk.
        overflow_len = len(audio) - chunk_samples

        min_hop_samples = int(
            (1 - self.max_overlap_frac) * chunk_samples
        )
        if min_hop_samples <= 0:
            # A zero hop would divide by zero below; a negative hop would
            # silently collapse the input to a single window.
            raise ValueError(
                "max_overlap_frac must be less than 1 so that consecutive windows advance"
            )

        # Spread the window starts evenly over [0, overflow_len]; always at least one window.
        n_windows = 1 + overflow_len // min_hop_samples
        window_starts = np.linspace(0, overflow_len, max(n_windows, 1)).astype(int)

        features = self.preprocessor(
            [
                audio[start : start + chunk_samples].numpy(force=True)
                for start in window_starts
            ],
            return_tensors="pt",
            sampling_rate=EXPECTED_SAMPLE_RATE,
            do_normalize=self.normalize_features,
        )
        # The extractor returns a container; pull out the feature tensor
        # under whichever of the two known keys is present.
        for key in ("input_features", "input_values"):
            if hasattr(features, key):
                features = getattr(features, key)
                break

        mean_features = torch.mean(features, dim=-1)
        # features are [batch, n_logmel_bins, n_frames]
        # Shift each mel bin so its time-average equals the matching entry of
        # logmel_energies (assumed to hold one value per mel bin - confirm in config).
        rescale_factor = logmel_energies.unsqueeze(0) - mean_features
        rescale_factor = rescale_factor.unsqueeze(2)
        features += rescale_factor
        return features