File size: 3,336 Bytes
df0ae2d
 
 
13b9d6a
df0ae2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
from torch import nn
import torch.nn.functional as F
from typing import List


class LogSpectralCentroid(nn.Module):
    """Log of the spectral centroid of a magnitude spectrogram.

    The centroid is the spectrum-weighted mean of normalised frequency
    (frequencies are mapped onto [0, 1] regardless of sample rate).

    Input:  spec of shape (..., freq, time), assumed non-negative magnitudes.
    Output: shape (..., 1, time), values in log space.
    """

    def forward(self, spec):
        # Build the frequency axis in the input's dtype: the original
        # float32-only linspace made the computation fail for e.g. float64
        # spectrograms.
        freqs = torch.linspace(
            0, 1, spec.size(-2), device=spec.device, dtype=spec.dtype
        )
        # Weighted mean along the freq axis, computed directly with a
        # broadcasted multiply-and-sum instead of the original
        # transpose -> matmul -> unsqueeze round-trip (same result, fewer
        # intermediate tensors).
        weights = spec / spec.sum(-2, keepdim=True).clamp_min(1e-8)
        centroid = (weights * freqs[:, None]).sum(-2, keepdim=True)
        # Epsilon keeps log finite when all mass sits at frequency 0.
        return torch.log(centroid + 1e-8)


class LogSpectralFlatness(nn.Module):
    """Log spectral flatness (Wiener entropy) of a magnitude spectrogram.

    Flatness is the ratio of geometric to arithmetic mean of the power
    spectrum; here both means are taken in log space, so the result is
    log(gmean) - log(amean), which is <= 0 and equals 0 for a flat spectrum.

    Input:  spec of shape (..., freq, time).
    Output: shape (..., 1, time).
    """

    def forward(self, spec):
        # Floor before squaring so the log of the geometric mean stays finite.
        power = spec.clamp(1e-8).square()
        log_geometric = power.log().mean(-2, keepdim=True)
        log_arithmetic = power.mean(-2, keepdim=True).log()
        return log_geometric - log_arithmetic


class LogSpectralBandwidth(nn.Module):
    """Log spectral bandwidth: log of the spectrum-weighted standard
    deviation of normalised frequency around the spectral centroid.

    Input:  spec of shape (..., freq, time), assumed non-negative magnitudes.
    Output: shape (..., 1, time).
    """

    def __init__(self):
        super().__init__()
        self.centroid = LogSpectralCentroid()

    def forward(self, spec):
        freqs = torch.linspace(0, 1, spec.size(-2), device=spec.device)
        # Centroid is stored in log space; undo the log before comparing.
        mu = self.centroid(spec).exp()
        weights = spec / spec.sum(-2, keepdim=True).clamp_min(1e-8)
        # Squared distance of each bin's frequency from the centroid,
        # broadcast to (..., freq, time).
        squared_dev = (freqs[:, None] - mu).square()
        variance = (weights * squared_dev).sum(-2, keepdim=True)
        # 0.5 * log(var) == log(std); epsilon keeps it finite at zero variance.
        return torch.log(variance + 1e-8) * 0.5


class LogRMS(nn.Module):
    """Log root-mean-square amplitude of each frame.

    Input:  frame of shape (..., frame_length, time).
    Output: shape (..., 1, time).
    """

    def forward(self, frame):
        rms = frame.square().mean(-2, keepdim=True).sqrt()
        # Epsilon keeps the log finite for silent frames.
        return torch.log(rms + 1e-8)


class LogCrest(nn.Module):
    """Log crest factor: log(peak) - log(rms) per frame.

    Input:  frame of shape (..., frame_length, time).
    Output: shape (..., 1, time).
    """

    def __init__(self):
        super().__init__()
        self.rms = LogRMS()

    def forward(self, frame):
        peak = frame.abs().amax(-2, keepdim=True)
        # Both terms are in log space, so subtraction is the log ratio.
        return torch.log(peak + 1e-8) - self.rms(frame)


class LogSpread(nn.Module):
    """Mean log deviation of per-sample magnitude from the frame's RMS.

    Input:  frame of shape (..., frame_length, time).
    Output: shape (..., 1, time).
    """

    def __init__(self):
        super().__init__()
        self.rms = LogRMS()

    def forward(self, frame):
        log_magnitude = torch.log(frame.abs() + 1e-8)
        # Average the per-sample log ratio |x| / rms over the frame axis.
        return (log_magnitude - self.rms(frame)).mean(-2, keepdim=True)


class MapAndMerge(nn.Module):
    """Apply each module in `funcs` to the same input and concatenate the
    results along `dim`.

    Registering the callables in an nn.ModuleList keeps any parameters they
    hold visible to the parent module.
    """

    def __init__(self, funcs: List[nn.Module], dim=-1):
        super().__init__()
        self.funcs = nn.ModuleList(funcs)
        self.dim = dim

    def forward(self, frame):
        outputs = tuple(fn(frame) for fn in self.funcs)
        return torch.cat(outputs, dim=self.dim)


class Frame(nn.Module):
    """Slice a waveform into overlapping frames along its last axis.

    Output shape is (..., frame_length, n_frames). With center=True the
    signal is zero-padded by frame_length // 2 on each side so frames are
    roughly centred on their timestamps.
    """

    def __init__(self, frame_length, hop_length, center=False):
        super().__init__()
        self.frame_length = frame_length
        self.hop_length = hop_length
        self.center = center

    def forward(self, waveform):
        if self.center:
            half = self.frame_length // 2
            waveform = F.pad(waveform, (half, half))
        windows = waveform.unfold(-1, self.frame_length, self.hop_length)
        # unfold yields (..., n_frames, frame_length); put frame_length first.
        return windows.transpose(-1, -2)


class StatisticReduction(nn.Module):
    """Reduce an axis to its first four standardised moments.

    Concatenates mean, (biased) standard deviation, skewness, and excess
    kurtosis along `dim`, so that axis shrinks from its original size to 4.
    """

    def __init__(self, dim=-1):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        axis = self.dim
        mean = x.mean(axis, keepdim=True)
        centered = x - mean
        # Population (biased) std, matching the z-scores below.
        std = centered.square().mean(axis, keepdim=True).sqrt()
        # Clamp only for the division so the reported std stays exact.
        z = centered / std.clamp_min(1e-8)
        skewness = z.pow(3).mean(axis, keepdim=True)
        # Subtract 3 to report excess kurtosis (0 for a normal distribution).
        excess_kurtosis = z.pow(4).mean(axis, keepdim=True) - 3
        return torch.cat([mean, std, skewness, excess_kurtosis], dim=axis)