| from __future__ import annotations |
|
|
| import torch |
| import torch.nn as nn |
| from nnAudio import features |
|
|
|
|
class MelSpectrogram(nn.Module):
    """Mel spectrogram layer that wraps ``nnAudio.features.MelSpectrogram``.

    The wrapped transform is stored as ``self.transform``; ``forward`` runs it
    and swaps the last two axes of the result so that the frame axis comes
    before the Mel-bin axis.
    """

    def __init__(
        self,
        sample_rate: int = 16000,
        n_ftt: int = 2048,
        n_mels: int = 512,
        hop_length: int = 128,
    ) -> None:
        """
        Melspectrogram transformation layer, supports on-the-fly processing on GPU.

        Args:
            sample_rate: The sampling rate for the input audio.
            n_ftt: The window size for the STFT. The spelling is kept as-is
                for backward compatibility; it is forwarded to nnAudio as
                ``n_fft``.
            n_mels: The number of Mel filter banks.
            hop_length: The hop (or stride) size.
        """
        super().__init__()
        self.transform = features.MelSpectrogram(
            sr=sample_rate,
            n_fft=n_ftt,
            n_mels=n_mels,
            hop_length=hop_length,
            center=True,
            fmin=0,
            # Upper edge of the Mel filter bank: half the sampling rate.
            fmax=sample_rate // 2,
            pad_mode="constant",
        )

    def forward(self, samples: torch.Tensor) -> torch.Tensor:
        """
        Convert a batch of audio frames into a batch of Mel spectrogram frames.

        For each item in the batch:
            1. pad left and right ends of audio by n_ftt // 2.
            2. run STFT with window size of |n_ftt| and stride of |hop_length|.
            3. convert result into mel-scale.
            4. therefore, n_frames = n_samples // hop_length + 1.

        Args:
            samples: Audio time-series (batch size, n_samples).

        Returns:
            A batch of Mel spectrograms of size (batch size, n_frames, n_mels).
        """
        spectrogram = self.transform(samples)
        # Swap the last two axes so frames come before Mel bins, matching the
        # documented (batch size, n_frames, n_mels) layout.
        return spectrogram.permute(0, 2, 1)
|
|