import torch
import torch.nn as nn
import torchaudio.transforms as T
import numpy as np


class MultiViewSpectrogram(nn.Module):
    def __init__(self, sample_rate=16000, n_mels=80, hop_length=160):
        super().__init__()
        # Analysis windows of roughly 23 ms, 46 ms and 93 ms, derived from the
        # configured sample rate (368, 736 and 1488 samples at 16 kHz).
        self.win_lengths = [int(round(d * sample_rate)) for d in (0.023, 0.046, 0.093)]
        self.transforms = nn.ModuleList()

        for win_len in self.win_lengths:
            n_fft = 2 ** int(np.ceil(np.log2(win_len)))
            mel = T.MelSpectrogram(
                sample_rate=sample_rate,
                n_fft=n_fft,
                win_length=win_len,
                hop_length=hop_length,
                f_min=27.5,
                f_max=sample_rate / 2.0,  # Nyquist limit; mel filters above it would be empty
                n_mels=n_mels,
                power=1.0,
                center=True,
            )
            self.transforms.append(mel)

    def forward(self, waveform):
        # waveform: (batch, time). All views share the same hop_length and
        # center=True, so every spectrogram has the same number of frames.
        specs = []
        for transform in self.transforms:
            s = transform(waveform)
            # Log-compress the magnitudes; the epsilon avoids log(0).
            s = torch.log(s + 1e-9)
            specs.append(s)
        # (batch, n_views, n_mels, frames): one channel per window size.
        return torch.stack(specs, dim=1)


def extract_context(spec, center_frame, context=7):
    """Slice a (2 * context + 1)-frame window centred on center_frame.

    With a 10 ms hop (hop_length=160 at 16 kHz) the default context of 7
    frames corresponds to +/- 70 ms. Frames that fall outside the
    spectrogram are zero-padded.
    """
    channels, n_mels, total_time = spec.shape
    start = center_frame - context
    end = center_frame + context + 1

    # Zero-pad the time axis when the window runs past either edge.
    pad_left = max(0, -start)
    pad_right = max(0, end - total_time)

    if pad_left > 0 or pad_right > 0:
        spec = torch.nn.functional.pad(spec, (pad_left, pad_right))
        start += pad_left
        end += pad_left

    return spec[:, :, start:end]
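

# A minimal usage sketch (not part of the original module): it assumes a
# single-utterance batch of 16 kHz audio and shows how the multi-view
# spectrogram and the per-frame context extraction fit together.
if __name__ == "__main__":
    frontend = MultiViewSpectrogram(sample_rate=16000, n_mels=80, hop_length=160)

    # One second of silence stands in for real audio, shaped (batch, time).
    waveform = torch.zeros(1, 16000)

    # (batch, 3 views, n_mels, frames); with a 160-sample hop and center=True
    # a 16000-sample input yields 1 + 16000 // 160 = 101 frames.
    multi_view = frontend(waveform)
    print(multi_view.shape)

    # Drop the batch dimension to get (3, n_mels, frames), then pull a
    # +/- 7-frame context window around frame 50.
    window = extract_context(multi_view[0], center_frame=50, context=7)
    print(window.shape)  # (3, 80, 15)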