|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchaudio.transforms as T |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
class MultiViewSpectrogram(nn.Module):
    """Compute log-mel spectrograms of a waveform at several analysis-window
    lengths and stack them into a multi-view tensor.

    All views share the same ``hop_length``, so every view has the same number
    of time frames and the per-view spectrograms can be stacked directly.

    Args:
        sample_rate: Input waveform sample rate in Hz.
        n_mels: Number of mel bands per view.
        hop_length: Hop size in samples, shared by all views.
        win_lengths: Analysis window lengths in samples, one per view
            (default matches the original three-view configuration).
    """

    def __init__(self, sample_rate=16000, n_mels=80, hop_length=160,
                 win_lengths=(368, 736, 1488)):
        super().__init__()

        self.win_lengths = list(win_lengths)
        self.transforms = nn.ModuleList()

        # BUG FIX: the original hard-coded f_max=16000.0, which equals the
        # default sample rate and is twice the Nyquist frequency (8000 Hz).
        # Mel filters placed above Nyquist are degenerate (all-zero), so the
        # top mel bands would be constant log(1e-9).  Clamp to Nyquist.
        f_max = min(16000.0, sample_rate / 2.0)

        for win_len in self.win_lengths:
            # Round the FFT size up to the next power of two >= win_len.
            n_fft = 2 ** int(np.ceil(np.log2(win_len)))
            self.transforms.append(
                T.MelSpectrogram(
                    sample_rate=sample_rate,
                    n_fft=n_fft,
                    win_length=win_len,
                    hop_length=hop_length,
                    f_min=27.5,
                    f_max=f_max,
                    n_mels=n_mels,
                    power=1.0,   # magnitude (not power) mel spectrogram
                    center=True,
                )
            )

    def forward(self, waveform):
        """Return stacked log-mel views of ``waveform``.

        For a batched input of shape (batch, time) the result is
        (batch, n_views, n_mels, frames); the views are stacked on dim=1.
        """
        # 1e-9 floor guards log(0) on silent frames.
        views = [torch.log(t(waveform) + 1e-9) for t in self.transforms]
        return torch.stack(views, dim=1)
|
|
|
|
|
|
|
|
def extract_context(spec, center_frame, context=7):
    """Slice a fixed-size context window out of a spectrogram.

    Returns the ``2 * context + 1`` frames of ``spec`` centered on
    ``center_frame`` along the last (time) axis.  Frames that would fall
    before the start or past the end of ``spec`` are zero-filled, so the
    output always has exactly ``2 * context + 1`` time steps.

    Args:
        spec: Tensor of shape (channels, n_mels, time).
        center_frame: Index of the center time frame.
        context: Number of frames taken on each side of the center.
    """
    lo = center_frame - context
    hi = center_frame + context + 1

    n_frames = spec.shape[-1]
    fill_left = -lo if lo < 0 else 0
    fill_right = hi - n_frames if hi > n_frames else 0

    if fill_left or fill_right:
        # Constant (zero) padding on the time axis only; F.pad returns a
        # new tensor, so the caller's spec is never mutated.
        spec = torch.nn.functional.pad(spec, (fill_left, fill_right))
        # Padding on the left shifts every original frame index right.
        lo += fill_left
        hi += fill_left

    return spec[:, :, lo:hi]
|
|
|