# NeoPy's picture
# EXP
# 30f8290 verified
import os
import sys
import torch
import numpy as np
import torch.nn as nn
sys.path.append(os.getcwd())
class Spectrogram(nn.Module):
    """Magnitude STFT spectrogram for batched multi-channel audio.

    Args:
        hop_length: Number of samples between successive STFT frames.
        win_length: Length of the Hann analysis window, in samples.
        n_fft: FFT size; defaults to ``win_length`` when ``None``.
        clamp: Lower bound applied to every magnitude value in the output.
    """

    def __init__(
        self,
        hop_length,
        win_length,
        n_fft=None,
        clamp=1e-10
    ):
        super(Spectrogram, self).__init__()
        self.n_fft = win_length if n_fft is None else n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.clamp = clamp
        # Non-persistent buffer: moves with the module on .to()/.cuda(), but is
        # not saved in state_dict (it is rebuilt here on construction anyway).
        self.register_buffer("window", torch.hann_window(win_length), persistent=False)

    def forward(self, audio, center=True):
        """Compute the clamped magnitude spectrogram of ``audio``.

        Args:
            audio: Tensor of shape ``(batch, channels, samples)``.
            center: Passed to ``torch.stft``; when True each frame is centred
                on its sample using reflect padding. NOTE(review): this flag is
                not forwarded on the ocl/privateuseone backend path.

        Returns:
            Tensor of shape ``(batch, channels, frames, freq_bins)`` with every
            value >= ``self.clamp``. On the ``torch.stft`` path
            ``freq_bins == n_fft // 2 + 1``.
        """
        bs, c, segment_samples = audio.shape
        # Fold channels into the batch dimension: the STFT calls below take 2-D input.
        audio = audio.reshape(bs * c, segment_samples)

        if str(audio.device).startswith(("ocl", "privateuseone")):
            # Custom STFT implementation for these device types, presumably
            # because torch.stft is unsupported there — confirm. Built lazily on
            # first use, then cached on the module.
            if not hasattr(self, "stft"):
                from main.library.backends.utils import STFT
                self.stft = STFT(
                    filter_length=self.n_fft,
                    hop_length=self.hop_length,
                    win_length=self.win_length
                ).to(audio.device)
            # NOTE(review): uses a hard-coded 1e-9 rather than self.clamp, and
            # the output layout is assumed to match torch.stft's
            # (N, freq_bins, frames) given the transpose below — confirm.
            magnitude = self.stft.transform(audio, 1e-9)
        else:
            # Match the window to the input so an un-moved module or
            # non-float32 audio does not hit a device/dtype mismatch.
            # (No copy is made when they already match.)
            window = self.window.to(device=audio.device, dtype=audio.dtype)
            fft = torch.stft(
                audio,
                n_fft=self.n_fft,
                hop_length=self.hop_length,
                win_length=self.win_length,
                window=window,
                center=center,
                pad_mode="reflect",
                return_complex=True
            )
            # |z| via torch.abs instead of sqrt(re^2 + im^2). At exactly-zero
            # bins (e.g. digital silence) sqrt's backward is 0/0 = NaN, even
            # though the clamp below zeroes the incoming gradient.
            # Complex abs has gradient 0 there.
            magnitude = fft.abs()

        # (N, freq_bins, frames) -> (N, frames, freq_bins), floored at self.clamp.
        mag = magnitude.transpose(1, 2).clamp(min=self.clamp)
        # Unfold the batch/channel dimensions folded together above.
        return mag.reshape(bs, c, mag.shape[1], mag.shape[2])