"""Multi-resolution log-mel spectrogram front-end and frame-context slicing."""

import torch
import torch.nn as nn
import torchaudio.transforms as T
import numpy as np


class MultiViewSpectrogram(nn.Module):
    """Compute log-mel spectrograms of one waveform at three window sizes.

    Each "view" uses a different analysis window (short -> fine time
    resolution, long -> fine frequency resolution) but shares the same
    hop length and mel-bin count, so all views have identical output
    shape and can be stacked along a channel dimension.

    Args:
        sample_rate: Input sampling rate in Hz.
        n_mels: Number of mel filterbank bins per view.
        hop_length: Hop between successive frames, in samples.
    """

    def __init__(self, sample_rate=16000, n_mels=80, hop_length=160):
        super().__init__()
        # Window sizes: 23 ms, 46 ms, 93 ms (in samples at 16 kHz).
        # NOTE(review): these are fixed sample counts, so for other
        # sample rates the physical window durations differ — confirm
        # whether non-16 kHz use is intended.
        self.win_lengths = [368, 736, 1488]
        self.transforms = nn.ModuleList()
        # BUG FIX: f_max must not exceed the Nyquist frequency
        # (sample_rate / 2). The original hard-coded f_max=16000.0, which
        # at the default 16 kHz sample rate placed half the mel filterbank
        # above Nyquist, leaving those filters all-zero (torchaudio warns
        # about this) and wasting roughly half of n_mels on constant rows.
        # min() keeps the intended 16 kHz ceiling for higher sample rates.
        f_max = min(sample_rate / 2.0, 16000.0)
        for win_len in self.win_lengths:
            # Round the FFT size up to the next power of two >= win_length.
            n_fft = 2 ** int(np.ceil(np.log2(win_len)))
            mel = T.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=n_fft,
                win_length=win_len,
                hop_length=hop_length,
                f_min=27.5,  # lowest piano key A0; drops sub-bass/DC bins
                f_max=f_max,
                n_mels=n_mels,
                power=1.0,  # magnitude (not power) spectrogram
                center=True,  # all views frame-aligned regardless of win_len
            )
            self.transforms.append(mel)

    def forward(self, waveform):
        """Return stacked log-mel views of ``waveform``.

        Args:
            waveform: Audio tensor; trailing dimension is time
                (e.g. ``(batch, samples)``).

        Returns:
            Tensor with a new view dimension inserted at dim 1, i.e.
            ``(batch, n_views, n_mels, n_frames)`` for batched input.
        """
        specs = []
        for transform in self.transforms:
            s = transform(waveform)
            # Log compression; epsilon guards log(0) for silent frames.
            s = torch.log(s + 1e-9)
            specs.append(s)
        return torch.stack(specs, dim=1)


def extract_context(spec, center_frame, context=7):
    """Slice a fixed-width window of frames centered on ``center_frame``.

    With the default hop of 10 ms, ``context=7`` gives +/- 70 ms of
    context (2 * context + 1 = 15 frames total). Frames that fall outside
    the spectrogram are zero-padded.

    Args:
        spec: Tensor of shape ``(channels, n_mels, time)``.
        center_frame: Index of the center frame along the time axis.
        context: Number of frames to include on each side.

    Returns:
        Tensor of shape ``(channels, n_mels, 2 * context + 1)``.
    """
    channels, n_mels, total_time = spec.shape
    start = center_frame - context
    end = center_frame + context + 1
    # Amount of zero padding needed when the window overruns either edge.
    pad_left = max(0, -start)
    pad_right = max(0, end - total_time)
    if pad_left > 0 or pad_right > 0:
        # F.pad with a 2-tuple pads only the last (time) dimension.
        # NOTE(review): pads with 0.0, not log(1e-9) — so padded frames
        # are *louder* than digital silence in log-mel space; confirm
        # this is the intended edge behavior.
        spec = torch.nn.functional.pad(spec, (pad_left, pad_right))
        # Re-index into the padded tensor.
        start += pad_left
        end += pad_left
    return spec[:, :, start:end]