# a_sad_model/logMelSpectrogram.py
import torch
import numpy as np
from matplotlib import pyplot as plt
from typing import Optional
class logMelSpectrogram:
    def __init__(
        self,
        frame_rate_s: int = 30,  # frame length, in milliseconds
        stride_s: int = 10,  # hop between frames, in milliseconds
        n_fft: Optional[int] = None,
        n_mels: Optional[int] = 40,
        top_db: int = 80,
        pre_emph_coef: float = 0.95,
        device: Optional[str] = None
    ):
        self.frame_rate_s = frame_rate_s
        self.stride_s = stride_s
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.log_mel_spec_is_computed = False
        self.top_db = top_db
        self.pre_emph_coef = pre_emph_coef
        # fall back to the best available backend when no device is given
        if device is None:
            device = "cuda" if torch.cuda.is_available() else (
                "mps" if torch.backends.mps.is_available() else "cpu"
            )
        self.device = device
        torch.set_default_device(self.device)
        torch.set_default_dtype(torch.float32)
    def transform(
        self,
        samples: np.ndarray,
        sr: int,
    ):
        self.samples = torch.from_numpy(samples).to(self.device)
        self.sr = sr
        if self.samples.shape[0] < 2:
            raise ValueError("Samples must contain at least two values")
        # pre-emphasis
        # compensates for the high-frequency roll-off of the audio
        # by amplifying the difference between the current sample
        # and the previous one
        pre_emph_samples = torch.cat([
            self.samples[0:1],
            self.samples[1:] - self.pre_emph_coef * self.samples[:-1]
        ], dim=0)
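        # i.e. y[0] = x[0] and y[t] = x[t] - pre_emph_coef * x[t - 1] for t >= 1,
        # a simple first-order high-pass filter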
        # framing
        # turns the audio into discrete overlapping chunks
        stride = self.sr * self.stride_s // 1000
        frame_rate = self.sr * self.frame_rate_s // 1000
        chunks = pre_emph_samples.unfold(0, frame_rate, stride).contiguous()
        num_of_frames = chunks.shape[0]
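        # e.g. at sr = 16000 with the defaults (30 ms frames, 10 ms stride):
        # frame_rate = 480 samples, stride = 160 samples, and
        # num_of_frames = (len(samples) - 480) // 160 + 1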
        # Hann window to smooth out the edges of each chunk
        # so the frames don't start and stop abruptly,
        # which would smear energy across frequencies
        n = torch.arange(frame_rate)
        hanning_weights = 0.5 - 0.5 * torch.cos(2 * torch.pi * n / (frame_rate - 1))
        weighted_chunks = chunks * hanning_weights
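        # (this is the symmetric Hann window, the same thing
        # torch.hann_window(frame_rate, periodic=False) returns)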
        # applying the fast Fourier transform
        # to decompose the "raw" audio into its underlying frequencies.
        # only the non-negative frequencies are kept, because the negative ones
        # carry no new information for a real-valued signal,
        # so roughly "half" of the bins (n_fft / 2 + 1) are extracted
        if not self.n_fft:
            # default to the next power of two above the frame length
            self.n_fft = int(2 ** torch.ceil(torch.log2(torch.tensor(frame_rate, dtype=torch.float32))).item())
        fft_chunks = torch.fft.rfft(weighted_chunks, n=self.n_fft)
        power_spec = (2 / self.n_fft ** 2) * torch.abs(fft_chunks) ** 2
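        # e.g. a 480-sample frame gives n_fft = 2 ** ceil(log2(480)) = 512,
        # i.e. 512 // 2 + 1 = 257 frequency bins per frame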
        # hertz-to-mel converter and vice versa
        def hz_to_mel(hz):
            return 2595 * torch.log10(1 + hz / 700)
        def mel_to_hz(m):
            return 700 * (10 ** (m / 2595) - 1)
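        # the mel scale is roughly linear below ~1 kHz and logarithmic above;
        # with this formula hz_to_mel(1000) comes out to roughly 1000 mel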
        fmax = self.sr / 2
        fmin = 0
        # here we create the mel scale
        mels = torch.linspace(
            hz_to_mel(torch.tensor(fmin)),
            hz_to_mel(torch.tensor(fmax)),
            self.n_mels + 2
        )
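        # n_mels + 2 points are needed because the two extra points become
        # the left and right edges of the triangular filters built below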
        # converting the evenly spaced mel points back to Hz,
        # which makes the spacing non-linear in frequency
        hz_points = mel_to_hz(mels)
        bins = torch.floor((self.n_fft + 1) * hz_points / self.sr).to(torch.int32)
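        # bins maps each of the n_mels + 2 frequencies to an rFFT bin index,
        # bin = floor((n_fft + 1) * f / sr), giving a tensor of shape (n_mels + 2,)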
        # building triangular filters that overlap and grow wider
        # as the frequency increases, mimicking human hearing,
        # which is better at distinguishing between lower
        # frequencies than higher ones:
        # the higher the frequency, the wider
        # (and, one might say, less selective) the filter becomes
        k = torch.arange(self.n_fft // 2 + 1).unsqueeze(0)
        f_left = bins[:-2].unsqueeze(1)
        f_center = bins[1:-1].unsqueeze(1)
        f_right = bins[2:].unsqueeze(1)
        up = (k - f_left) / torch.clamp(f_center - f_left, min=1e-8)  # (n_mels, bins)
        down = (f_right - k) / torch.clamp(f_right - f_center, min=1e-8)  # (n_mels, bins)
        filters = torch.clamp(torch.minimum(up, down), min=0.0)
        mel_spec = torch.matmul(filters, power_spec.T)
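        # filters: (n_mels, n_fft // 2 + 1) x power_spec.T: (n_fft // 2 + 1, num_of_frames)
        # -> mel_spec: (n_mels, num_of_frames)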
        # converting the mel spectrogram to a log (decibel) scale
        mel_spec = torch.clamp(mel_spec, min=1e-10)  # avoid log(0)
        log_mel_spec = 10 * torch.log10(mel_spec)
        # normalising: keep everything within top_db dB of the maximum
        log_mel_spec = torch.clamp(
            log_mel_spec,
            min=torch.max(log_mel_spec) - self.top_db
        )
        self.log_mel_spec = log_mel_spec
        self.log_mel_spec_is_computed = True
        return log_mel_spec
    def plot_waveform(self):
        plt.figure(figsize=(10, 4))
        cpu_samples = self.samples.cpu().numpy()
        plt.plot(np.arange(cpu_samples.shape[0]) / self.sr, cpu_samples)
        plt.title("Waveform")
        plt.xlabel("Time (s)")
        plt.ylabel("Amplitude")
        plt.show()
    def plot_log_mel_spec(self, cmap="magma_r"):
        if not self.log_mel_spec_is_computed:
            raise ValueError("run transform() before plotting the log-mel spectrogram")
        plt.figure(figsize=(10, 4))
        spec_to_plot = self.log_mel_spec.cpu().numpy()
        plt.imshow(spec_to_plot, origin="lower", aspect="auto", cmap=cmap)
        plt.title("Log-Mel Spectrogram (dB)")
        plt.xlabel("Time frames")
        plt.ylabel("Mel bins")
        plt.colorbar()
        plt.show()
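
# Minimal usage sketch: runs a synthetic 440 Hz sine wave through the pipeline
# just to show the expected call order and output shape; the 16 kHz sample rate
# and the test tone are arbitrary choices, and any mono waveform loaded as a
# NumPy array works the same way.
if __name__ == "__main__":
    sr = 16000
    t = np.arange(sr, dtype=np.float32) / sr  # one second of audio
    sine = 0.5 * np.sin(2 * np.pi * 440 * t)  # 440 Hz test tone

    extractor = logMelSpectrogram(device="cpu")
    log_mel = extractor.transform(sine, sr)
    # with the defaults: 40 mel bins and (16000 - 480) // 160 + 1 = 98 frames
    print(log_mel.shape)  # torch.Size([40, 98])

    extractor.plot_waveform()
    extractor.plot_log_mel_spec()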