|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
from matplotlib import pyplot as plt |
|
|
|
|
|
from typing import Optional |
|
|
|
|
|
class logMelSpectrogram:
    """Compute a log-mel spectrogram (in dB) of a 1-D audio signal with PyTorch.

    The pipeline applied by :meth:`transform` is: pre-emphasis, framing,
    Hann windowing, real FFT, power spectrum, triangular mel filter bank,
    conversion to decibels and clipping of the dynamic range to ``top_db``.

    Note:
        Constructing an instance calls ``torch.set_default_device`` and
        ``torch.set_default_dtype(torch.float32)``, which are process-wide
        settings.
    """

    def __init__(
        self,
        frame_rate_s: int = 30,
        stride_s: int = 10,
        n_fft: Optional[int] = None,
        n_mels: Optional[int] = 40,
        top_db: int = 80,
        pre_emph_coef: float = 0.95,
        device: Optional[str] = None,
    ):
        """Store the analysis parameters and select the torch device.

        Args:
            frame_rate_s: Frame (window) length. Despite the ``_s`` suffix the
                value is used as milliseconds (``sr * frame_rate_s // 1000``).
            stride_s: Hop between consecutive frames, also used as
                milliseconds.
            n_fft: FFT size. When falsy, each ``transform`` call uses the
                smallest power of two that is >= the frame length in samples.
            n_mels: Number of mel filters (rows of the returned spectrogram).
            top_db: Dynamic range kept below the spectrogram maximum, in dB.
            pre_emph_coef: Coefficient ``a`` of the pre-emphasis filter
                ``y[t] = x[t] - a * x[t - 1]``.
            device: Torch device string. When falsy, "cuda" is used if
                available, otherwise "mps" if available, otherwise "cpu".
        """
        self.frame_rate_s = frame_rate_s
        self.stride_s = stride_s
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.log_mel_spec_is_computed = False
        self.top_db = top_db
        self.pre_emph_coef = pre_emph_coef

        # Resolve the device once, so that the auto-detected value is the one
        # that is actually stored and installed as the torch default.
        if not device:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"

        self.device = device
        torch.set_default_device(device)
        torch.set_default_dtype(torch.float32)

    def _mel_filterbank(self, n_fft: int, sr: int) -> torch.Tensor:
        """Build the triangular mel filter bank, shape (n_mels, n_fft // 2 + 1)."""

        def hz_to_mel(hz):
            return 2595 * torch.log10(1 + hz / 700)

        def mel_to_hz(m):
            return 700 * (10 ** (m / 2595) - 1)

        fmin = 0.0
        fmax = sr / 2

        # n_mels + 2 points equally spaced on the mel scale: every filter
        # needs a left edge, a centre and a right edge.
        mels = torch.linspace(
            hz_to_mel(torch.tensor(fmin)).item(),
            hz_to_mel(torch.tensor(fmax)).item(),
            self.n_mels + 2,
        )
        hz_points = mel_to_hz(mels)

        # FFT-bin index of every edge; kept as floats so the divisions and
        # clamps below do not depend on integer/float type promotion.
        bins = torch.floor((n_fft + 1) * hz_points / sr)

        k = torch.arange(n_fft // 2 + 1, dtype=bins.dtype).unsqueeze(0)
        f_left = bins[:-2].unsqueeze(1)
        f_center = bins[1:-1].unsqueeze(1)
        f_right = bins[2:].unsqueeze(1)

        # Rising and falling slopes of each triangle; the clamp avoids a
        # division by zero when two neighbouring edges land on the same bin.
        up = (k - f_left) / torch.clamp(f_center - f_left, min=1e-8)
        down = (f_right - k) / torch.clamp(f_right - f_center, min=1e-8)

        return torch.clamp(torch.minimum(up, down), min=0.0)

    def transform(
        self,
        samples: np.ndarray,
        sr: int,
    ):
        """Compute the log-mel spectrogram of ``samples``.

        The input is converted to a float32 tensor on ``self.device`` and kept
        in ``self.samples``; the result is kept in ``self.log_mel_spec``.

        Args:
            samples: NumPy array holding the signal (framing is done along
                axis 0; a 1-D array is expected).
            sr: Sampling rate in Hz.

        Returns:
            Tensor of shape ``(n_mels, num_frames)`` in dB, clipped to at most
            ``top_db`` below its maximum.

        Raises:
            ValueError: If fewer than two samples are given, if the frame or
                hop length is below one sample, or if the signal is shorter
                than one frame.
        """
        # float32 keeps every later operand (window, filter bank) in the same
        # dtype, and .to(device) keeps them on the same device as the tensors
        # created through the default device.
        self.samples = torch.from_numpy(samples).to(
            device=self.device, dtype=torch.float32
        )
        self.sr = sr

        if self.samples.shape[0] < 2:
            raise ValueError("Samples should be longer than two")

        # Pre-emphasis: y[0] = x[0], y[t] = x[t] - a * x[t - 1].
        pre_emph_samples = torch.cat([
            self.samples[0:1],
            self.samples[1:] - self.pre_emph_coef * self.samples[:-1],
        ], dim=0)

        # Hop and frame lengths converted from milliseconds to samples.
        stride = self.sr * self.stride_s // 1000
        frame_rate = self.sr * self.frame_rate_s // 1000

        if frame_rate < 1 or stride < 1:
            raise ValueError(
                "Frame length and stride must be at least one sample "
                f"(got frame={frame_rate}, stride={stride})"
            )
        if pre_emph_samples.shape[0] < frame_rate:
            raise ValueError(
                f"Need at least {frame_rate} samples for one frame, "
                f"got {pre_emph_samples.shape[0]}"
            )

        # Overlapping frames, shape (num_frames, frame_rate).
        chunks = pre_emph_samples.unfold(0, frame_rate, stride).contiguous()

        # Symmetric Hann window applied to every frame.
        n = torch.arange(frame_rate)
        hanning_weights = 0.5 - 0.5 * torch.cos(2 * torch.pi * n / (frame_rate - 1))
        weighted_chunks = chunks * hanning_weights

        # Use a local FFT size so an automatically derived value is recomputed
        # on every call instead of being stored on the instance.
        if self.n_fft:
            n_fft = int(self.n_fft)
        else:
            # Smallest power of two >= frame_rate.
            n_fft = 1 << (frame_rate - 1).bit_length()

        fft_chunks = torch.fft.rfft(weighted_chunks, n=n_fft)
        power_spec = (2 / n_fft ** 2) * torch.abs(fft_chunks) ** 2

        filters = self._mel_filterbank(n_fft, self.sr)

        # (n_mels, bins) @ (bins, num_frames) -> (n_mels, num_frames)
        mel_spec = torch.matmul(filters, power_spec.T)

        # Floor the energies before the logarithm to avoid log10(0).
        mel_spec = torch.clamp(mel_spec, min=1e-10)
        log_mel_spec = 10 * torch.log10(mel_spec)

        # Limit the dynamic range to top_db below the global maximum.
        log_mel_spec = torch.clamp(
            log_mel_spec,
            min=torch.max(log_mel_spec) - self.top_db,
        )

        self.log_mel_spec = log_mel_spec
        self.log_mel_spec_is_computed = True

        return log_mel_spec

    def plot_waveform(self):
        """Plot the samples stored by the last ``transform`` call against time."""
        plt.figure(figsize=(10, 4))
        cpu_samples = self.samples.cpu().numpy()
        plt.plot(np.arange(cpu_samples.shape[0]) / self.sr, cpu_samples)
        plt.title("Waveform")
        plt.xlabel("Time (s)")
        plt.ylabel("Amplitude")
        plt.show()

    def plot_log_mel_spec(self, cmap="magma_r"):
        """Show the spectrogram computed by the last ``transform`` call.

        Args:
            cmap: Matplotlib colormap name passed to ``imshow``.

        Raises:
            ValueError: If ``transform`` has not been called yet.
        """
        if not self.log_mel_spec_is_computed:
            raise ValueError("run transform() before plotting log mel spectrogram")

        plt.figure(figsize=(10, 4))
        spec_to_plot = self.log_mel_spec.cpu().numpy()
        plt.imshow(spec_to_plot, origin="lower", aspect="auto", cmap=cmap)
        plt.title("Log-Mel Spectrogram (dB)")
        plt.xlabel("Time frames")
        plt.ylabel("Mel bins")
        plt.colorbar()
        plt.show()
|
|
|