osu_mapper2 / osuT5 /model /spectrogram.py
Tiger14n's picture
Upload folder using huggingface_hub
7ef7abb verified
from __future__ import annotations
import torch
import torch.nn as nn
from nnAudio import features
class MelSpectrogram(nn.Module):
def __init__(
self,
sample_rate: int = 16000,
n_ftt: int = 2048,
n_mels: int = 512,
hop_length: int = 128,
):
"""
Melspectrogram transformation layer, supports on-the-fly processing on GPU.
Attributes:
sample_rate: The sampling rate for the input audio.
n_ftt: The window size for the STFT.
n_mels: The number of Mel filter banks.
hop_length: The hop (or stride) size.
"""
super().__init__()
self.transform = features.MelSpectrogram(
sr=sample_rate,
n_fft=n_ftt,
n_mels=n_mels,
hop_length=hop_length,
center=True,
fmin=0,
fmax=sample_rate // 2,
pad_mode="constant",
)
def forward(self, samples: torch.tensor) -> torch.tensor:
"""
Convert a batch of audio frames into a batch of Mel spectrogram frames.
For each item in the batch:
1. pad left and right ends of audio by n_fft // 2.
2. run STFT with window size of |n_ftt| and stride of |hop_length|.
3. convert result into mel-scale.
4. therefore, n_frames = n_samples // hop_length + 1.
Args:
samples: Audio time-series (batch size, n_samples).
Returns:
A batch of Mel spectrograms of size (batch size, n_frames, n_mels).
"""
spectrogram = self.transform(samples)
spectrogram = spectrogram.permute(0, 2, 1)
return spectrogram