from __future__ import annotations import torch import torch.nn as nn from nnAudio import features class MelSpectrogram(nn.Module): def __init__( self, sample_rate: int = 16000, n_ftt: int = 2048, n_mels: int = 512, hop_length: int = 128, ): """ Melspectrogram transformation layer, supports on-the-fly processing on GPU. Attributes: sample_rate: The sampling rate for the input audio. n_ftt: The window size for the STFT. n_mels: The number of Mel filter banks. hop_length: The hop (or stride) size. """ super().__init__() self.transform = features.MelSpectrogram( sr=sample_rate, n_fft=n_ftt, n_mels=n_mels, hop_length=hop_length, center=True, fmin=0, fmax=sample_rate // 2, pad_mode="constant", ) def forward(self, samples: torch.tensor) -> torch.tensor: """ Convert a batch of audio frames into a batch of Mel spectrogram frames. For each item in the batch: 1. pad left and right ends of audio by n_fft // 2. 2. run STFT with window size of |n_ftt| and stride of |hop_length|. 3. convert result into mel-scale. 4. therefore, n_frames = n_samples // hop_length + 1. Args: samples: Audio time-series (batch size, n_samples). Returns: A batch of Mel spectrograms of size (batch size, n_frames, n_mels). """ spectrogram = self.transform(samples) spectrogram = spectrogram.permute(0, 2, 1) return spectrogram