| """Preprocessing and normalization to prepare audio for Kintsugi Depression and Anxiety model.""" |
| from typing import Union, BinaryIO |
| import numpy as np |
| import os |
| import torch |
| import torchaudio |
| from transformers import AutoFeatureExtractor |
|
|
| from config import EXPECTED_SAMPLE_RATE, logmel_energies |
|
|
|
|
def load_audio(source: Union[BinaryIO, str, os.PathLike]) -> torch.Tensor:
    """Read a mono audio file and return it at the model's expected sample rate.

    Parameters
    ----------
    source: open file or path to file, passed straight to ``torchaudio.load``

    Returns
    -------
    Time domain audio samples as a 1 x num_samples tensor sampled at
    ``EXPECTED_SAMPLE_RATE`` (16 kHz according to the model's requirements).

    Raises
    ------
    ValueError: if the decoded audio does not have exactly one channel.

    """
    waveform, sample_rate = torchaudio.load(source)

    # torchaudio returns (channels, samples); only single-channel input is accepted.
    num_channels = waveform.shape[0]
    if num_channels != 1:
        raise ValueError(f"Provided audio has {num_channels} != 1 channels.")

    # Nothing to do when the file is already at the target rate.
    if sample_rate == EXPECTED_SAMPLE_RATE:
        return waveform

    return torchaudio.functional.resample(waveform, sample_rate, EXPECTED_SAMPLE_RATE)
|
|
|
|
class Preprocessor:
    """Convert raw audio into the normalized Whisper mel features fed to the model."""

    def __init__(
        self,
        normalize_features: bool = True,
        chunk_seconds: int = 30,
        max_overlap_frac: float = 0.0,
        pad_last_chunk_to_full: bool = True,
    ):
        """Create preprocessor object.

        Parameters
        ----------
        normalize_features: Whether the Whisper preprocessor should normalize features
        chunk_seconds: Size of model's receptive field in seconds
        max_overlap_frac: Fraction of each chunk allowed to overlap previous chunk for inputs longer than chunk_seconds
        pad_last_chunk_to_full: Whether to pad audio to an integer multiple of chunk_seconds

        """
        self.preprocessor = AutoFeatureExtractor.from_pretrained("openai/whisper-small.en")
        self.normalize_features = normalize_features
        self.chunk_seconds = chunk_seconds
        self.max_overlap_frac = max_overlap_frac
        self.pad_last_chunk_to_full = pad_last_chunk_to_full

    def preprocess_with_audio_normalization(
        self,
        audio: torch.Tensor,
    ) -> torch.Tensor:
        """Run Whisper preprocessor and normalization expected by the model.

        Note: some normalization steps can be avoided, but are included to match
        feature extraction used during training.

        Parameters
        ----------
        audio: Raw audio samples as a 1 x num_samples float tensor sampled at 16 kHz

        Returns
        -------
        Normalized mel filter bank features as a float tensor of shape
        num_chunks x 80 mel filter bands x 3000 time frames

        Raises
        ------
        ValueError: if ``audio`` has no samples, if ``chunk_seconds`` is not
            positive, if ``max_overlap_frac`` leaves no positive hop between
            windows, or if ``pad_last_chunk_to_full`` is combined with a
            positive ``max_overlap_frac``.

        """
        # Drop the leading channel dimension so len(audio) is the sample count.
        audio = torch.squeeze(audio, 0)
        if audio.numel() == 0:
            # torch.max on an empty tensor fails with an opaque error; be explicit.
            raise ValueError("Provided audio contains no samples.")

        # Remove the DC offset, then scale to unit peak amplitude.
        audio = audio - torch.mean(audio)
        peak = torch.max(torch.abs(audio))
        if peak > 0:
            # Silent/constant input has a zero peak; dividing would turn every
            # sample (and every downstream feature) into NaN, so leave it as zeros.
            audio = audio / peak

        # Number of samples covered by one model input window.
        chunk_samples = int(EXPECTED_SAMPLE_RATE * self.chunk_seconds)
        if chunk_samples < 1:
            raise ValueError(
                f"chunk_seconds must be positive, got {self.chunk_seconds}."
            )

        if self.pad_last_chunk_to_full:
            if self.max_overlap_frac > 0:
                raise ValueError(
                    "pad_last_chunk_to_full is only supported for non-overlapping windows"
                )
            # Zero-pad on the right up to the next whole multiple of chunk_samples
            # (integer ceiling division avoids a float round-trip).
            num_chunks = -(-len(audio) // chunk_samples)
            pad_size = num_chunks * chunk_samples - len(audio)
            audio = torch.nn.functional.pad(audio, (0, pad_size))

        # Samples beyond the first window; negative when audio is shorter than a chunk.
        overflow_len = len(audio) - chunk_samples

        # Smallest allowed distance between consecutive window starts.
        min_hop_samples = int((1 - self.max_overlap_frac) * chunk_samples)
        if min_hop_samples < 1:
            # A hop of zero (or less) would divide by zero / yield nonsense counts.
            raise ValueError(
                f"max_overlap_frac must be less than 1, got {self.max_overlap_frac}."
            )

        # Spread the window starts evenly from 0 to overflow_len. Using the floor
        # keeps the actual hop >= min_hop_samples, i.e. overlap never exceeds
        # max_overlap_frac. Short audio always gets a single window at 0.
        n_windows = 1 + overflow_len // min_hop_samples
        window_starts = np.linspace(0, overflow_len, max(n_windows, 1)).astype(int)

        features = self.preprocessor(
            [
                audio[start : start + chunk_samples].numpy(force=True)
                for start in window_starts
            ],
            return_tensors="pt",
            sampling_rate=EXPECTED_SAMPLE_RATE,
            do_normalize=self.normalize_features,
        )
        # Pull the feature tensor out of the extractor output, whichever of the
        # two known keys it was stored under.
        for key in ("input_features", "input_values"):
            if hasattr(features, key):
                features = getattr(features, key)
                break

        # Shift every band of every chunk so that its mean over the last (time)
        # axis equals the matching entry of logmel_energies.
        # NOTE(review): logmel_energies comes from config and is presumably a
        # per-mel-band vector of training-set means -- confirm.
        mean_features = torch.mean(features, dim=-1)
        rescale_factor = logmel_energies.unsqueeze(0) - mean_features
        features += rescale_factor.unsqueeze(2)
        return features
|
|