# Uploaded by JacobLinCool using huggingface_hub (commit 707cbac, unverified).
import torch
import torch.nn as nn
import torchaudio.transforms as T
import numpy as np
class MultiViewSpectrogram(nn.Module):
    """Log-magnitude mel spectrograms computed at several window lengths.

    One ``torchaudio.transforms.MelSpectrogram`` is built per window length.
    All of them share ``hop_length`` and ``n_mels`` so their frames line up,
    and ``forward`` stacks the log-scaled outputs along a new "view" axis.

    Args:
        sample_rate: Sample rate of the incoming waveform, in Hz.
        n_mels: Number of mel bands in each view.
        hop_length: Hop between frames, in samples (160 is 10 ms at 16 kHz).
        win_lengths: Window length of each view, in samples. The defaults
            are about 23, 46 and 93 ms at 16 kHz only; they are NOT rescaled
            when ``sample_rate`` changes.
        f_min: Lower frequency edge of the mel filterbank, in Hz.
        f_max: Upper frequency edge of the mel filterbank, in Hz.
            NOTE(review): the default of 16 kHz is above the 8 kHz Nyquist
            limit of the default 16 kHz sample rate. The uppermost mel
            filters then cover no FFT bins, torchaudio warns about all-zero
            filterbanks, and those bands come out as the constant
            ``log(1e-9)``. The default is kept so features stay identical to
            those produced before; pass ``f_max=sample_rate / 2`` (or lower)
            for new setups.
    """

    def __init__(
        self,
        sample_rate=16000,
        n_mels=80,
        hop_length=160,
        win_lengths=(368, 736, 1488),
        f_min=27.5,
        f_max=16000.0,
    ):
        super().__init__()
        # Kept as a list attribute, matching the original public attribute.
        self.win_lengths = list(win_lengths)
        self.transforms = nn.ModuleList()
        for win_len in self.win_lengths:
            self.transforms.append(
                T.MelSpectrogram(
                    sample_rate=sample_rate,
                    # FFT size: smallest power of two that holds the window.
                    n_fft=self._next_power_of_two(win_len),
                    win_length=win_len,
                    hop_length=hop_length,
                    f_min=f_min,
                    f_max=f_max,
                    n_mels=n_mels,
                    power=1.0,  # magnitude, not power, spectrogram
                    center=True,
                )
            )

    @staticmethod
    def _next_power_of_two(n):
        """Return the smallest power of two >= ``n`` (for ``n >= 1``)."""
        # Exact integer arithmetic; avoids a float log2/ceil round-trip.
        return 1 << (n - 1).bit_length()

    def forward(self, waveform):
        """Return the stacked log-mel views of ``waveform``.

        Each view is ``log(mel + 1e-9)``; the small offset keeps silent
        frames finite. The views are stacked along dim 1, so a batched
        ``(batch, time)`` waveform gives ``(batch, n_views, n_mels, frames)``.

        NOTE(review): for an unbatched 1-D waveform, dim 1 of each view is
        the frame axis. The result is then ``(n_mels, n_views, frames)``,
        not ``(n_views, n_mels, frames)``. Confirm this is what callers
        expect.
        """
        views = [torch.log(mel(waveform) + 1e-9) for mel in self.transforms]
        return torch.stack(views, dim=1)
def extract_context(spec, center_frame, context=7):
    """Return a fixed-width window of ``spec`` centred on ``center_frame``.

    The window covers frames ``center_frame - context`` through
    ``center_frame + context`` (inclusive) along the last (time) axis.
    Frames outside ``[0, spec.shape[-1])`` are zero-filled, so for
    ``context >= 0`` the result always has ``2 * context + 1`` frames,
    wherever the centre lies.

    Args:
        spec: Tensor whose last dimension is time, e.g. the
            ``(channels, n_mels, time)`` layout this helper was written for.
            Any number of leading dimensions is accepted.
        center_frame: Index of the centre frame. It may lie outside the
            tensor, in which case the window is mostly or entirely zeros.
        context: Number of frames taken on each side of the centre. The
            "+/- 70 ms" reading of the default 7 assumes a 10 ms hop
            (presumably hop_length=160 at 16 kHz; confirm with the caller).

    Returns:
        Tensor with the same leading dimensions as ``spec`` and a last
        dimension of size ``2 * context + 1``.
    """
    total_time = spec.shape[-1]
    start = center_frame - context
    end = center_frame + context + 1
    # Pad only as far as the window overruns each edge of the time axis.
    pad_left = max(0, -start)
    pad_right = max(0, end - total_time)
    if pad_left > 0 or pad_right > 0:
        # A 2-tuple pads only the last dimension, with zeros: (left, right).
        spec = torch.nn.functional.pad(spec, (pad_left, pad_right))
        # Left padding shifts every original frame index to the right.
        start += pad_left
        end += pad_left
    return spec[..., start:end]