from typing import Optional

import torch
from audiotools import AudioSignal


################################################################################
# Utilities for extracting rhythm feature representations
################################################################################


def _moving_average(x: torch.Tensor, window_length: int):
    """
    Smooth features with moving average over frames.

    Parameters
    ----------
    x : torch.Tensor
        Shape (n_batch, n_feats, n_frames)
    window_length : int
        Smoothing window length
    """
    if window_length <= 1:
        return x
    # Reflect padding requires pad < n_frames, so clamp the window for very
    # short inputs
    window_length = min(window_length, 2 * x.shape[-1] - 1)
    n_feats = x.shape[1]
    kernel = torch.ones(
        (n_feats, 1, window_length),
        device=x.device, dtype=x.dtype
    ) / window_length

    pad_left = (window_length - 1) // 2
    pad_right = window_length // 2
    x_pad = torch.nn.functional.pad(x, (pad_left, pad_right), mode="reflect")

    # Smooth separately over feature channels
    return torch.nn.functional.conv1d(x_pad, kernel, groups=n_feats)
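
# A minimal doctest-style sketch of the smoothing behavior (illustration only):
#   >>> x = torch.tensor([[[0., 0., 3., 0., 0.]]])
#   >>> _moving_average(x, 3)
#   tensor([[[0., 1., 1., 1., 0.]]])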


# The 'original' TRIA features can be recovered using:
#  * `slow_ma_ms` = None
#  * `post_smooth_ms` = None
#  * `legacy_normalize` = True
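#
# For example (a sketch; `sig` is any AudioSignal):
#   legacy_feats = rhythm_features(
#       sig, slow_ma_ms=None, post_smooth_ms=None, legacy_normalize=True
#   )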
def rhythm_features(
    signal: AudioSignal,
    sample_rate: int = 44_100,
    n_bands: int = 2,
    n_mels: int = 80,
    window_length: int = 1024,
    hop_length: int = 512,
    normalize_quantile: float = 0.98,
    quantization_levels: int = 33,
    clamp_max: float = 50.0,
    eps: float = 1e-8,
    slow_ma_ms: Optional[float] = 100.0,
    post_smooth_ms: Optional[float] = 10.0,
    legacy_normalize: bool = False,
):
    """
    Extract multi-band 'rhythm' features from audio by adaptively splitting 
    spectrogram along frequency axis and applying normalization, quantization,
    and smoothing / sparsity filtering.
    
    Parameters
    ----------
    signal : AudioSignal
        Audio from which to extract features
    sample_rate : int
        Sample rate at which to extract features
    n_bands : int
        Number of frequency bands into which to adaptively divide spectrogram
    n_mels : int
        Number of base mel frequency bins in spectrogram
    window_length : int
        Spectrogram window length
    hop_length : int
        Spectrogram hop length
    normalize_quantile : float
        Optionally normalize each band relative to top-p largest magnitude 
        rather than absolute max
    quantization_levels : int
        Number of bins into which feature magnitudes are quantized
    clamp_max : float
        Maximum allowed spectrogram magnitude
    eps : float
        For numerical stability
    slow_ma_ms : float
        Smoothing filter length in milliseconds for transient emphasis (smoothed
        features are subtracted)
    post_smooth_ms : float
        Smoothing filter length in milliseconds for transient smoothing
    legacy_normalize : bool
        If `True`, use mean/std and sigmoid normalization as described in 
        original TRIA paper
    """

    assert n_bands >= 1
    assert quantization_levels >= 2

    # Loudness normalization
    signal = signal.clone().to_mono().resample(sample_rate).normalize(-16.)
    signal.ensure_max_of_audio()

    # Clamped mel spectrogram
    mel = signal.mel_spectrogram(
        n_mels=n_mels,
        hop_length=hop_length,
        window_length=window_length,
    ).mean(1)  # (n_batch, n_mels, n_frames)
    mel = torch.clamp(mel, 0.0, clamp_max)

    n_batch, _, n_frames = mel.shape

    if legacy_normalize:
        # Original normalization: divide by number of mels
        mel = mel / n_mels
    else:
        # Compress logarithmically, mapping magnitudes in [0, clamp_max] to
        # [0, 1]
        log_max = torch.log1p(
            torch.tensor(clamp_max, device=mel.device, dtype=mel.dtype)
        )
        mel = torch.log1p(mel) / log_max

    # Split spectrogram into bands adaptively
    energy_per_bin = mel.mean(dim=-1)          # (n_batch, n_mels)
    cum = energy_per_bin.cumsum(dim=1)         # (n_batch, n_mels)
    total = cum[:, -1:]                        # (n_batch, 1)

    if n_bands == 1:
        bands = mel.sum(dim=1, keepdim=True)                   # (n_batch, 1, n_frames)
    else:
        targets = torch.linspace(
            1.0 / n_bands, (n_bands - 1) / n_bands, n_bands - 1,
            device=mel.device, dtype=mel.dtype
        )[None, :] * total                                     # (n_batch, n_bands-1)

        edges = torch.searchsorted(cum, targets, right=False)  # (n_batch, n_bands-1)
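        # `edges[b, k]` is the first mel bin at which cumulative energy reaches
        # the k-th target, so each band receives a roughly equal share of the
        # total energy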

        cuts = torch.cat(
            [
                torch.zeros(n_batch, 1, dtype=torch.long, device=mel.device),
                edges + 1,
                torch.full((n_batch, 1), mel.size(1), dtype=torch.long, device=mel.device),
            ],
            dim=1
        )  # (n_batch, n_bands+1)

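        # Sum each variable-width band via a zero-padded prefix sum over mel
        # bins: the sum of bins [a, b) is prefix_pad[b] - prefix_pad[a],
        # gathered per band and frame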
        prefix = mel.cumsum(dim=1)  # (n_batch, n_mels, n_frames)
        prefix_pad = torch.cat(
            [torch.zeros(n_batch, 1, n_frames, device=mel.device, dtype=mel.dtype), prefix],
            dim=1
        )

        a_idx = cuts[:, :-1].unsqueeze(-1).expand(n_batch, n_bands, n_frames)
        b_idx = cuts[:, 1: ].unsqueeze(-1).expand(n_batch, n_bands, n_frames)
        bands = prefix_pad.gather(1, b_idx) - prefix_pad.gather(1, a_idx)  # (n_batch, n_bands, n_frames)

    # Emphasize transients by subtracting smoothed features
    transient = bands.clone()
    def to_frames(ms: float) -> int:
        # Convert a duration in milliseconds to a whole number of frames
        return max(1, int(round((ms / 1000.0) * sample_rate / hop_length)))
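    # e.g. at sample_rate=44100 with hop_length=512:
    #   to_frames(100.0) == round(0.1 * 44100 / 512) == 9 frames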

    if slow_ma_ms is not None:
        slow_win = to_frames(slow_ma_ms)
        bands_slow = _moving_average(bands, slow_win)  # (n_batch, n_bands, n_frames)
        transient = torch.relu(bands - bands_slow)

    # Apply additional smoothing to transients
    if post_smooth_ms is not None:
        ps_win = to_frames(post_smooth_ms)
        if ps_win > 1:
            transient = _moving_average(transient, ps_win)

    # Normalize features across time per band
    if legacy_normalize:
        # Original normalization (mean/std with sigmoid compression)
        mean = transient.mean(dim=-1, keepdim=True)
        std = transient.std(dim=-1, keepdim=True).clamp_min(eps)
        transient = torch.sigmoid((transient - mean) / std)
    else:
        # Quantile-based normalization: scale so the `normalize_quantile`
        # magnitude maps to 1.0, then clamp outliers into [0, 1]
        q = torch.quantile(
            transient.clamp_min(0.0),
            q=normalize_quantile,
            dim=-1,
            keepdim=True
        ).clamp_min(eps)
        transient = (transient / q).clamp(0.0, 1.0)

    # Quantize feature intensities into bins to ensure a tight information
    # bottleneck
    steps = quantization_levels - 1
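    # e.g. quantization_levels=33 snaps values onto multiples of 1/32 in [0, 1]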
    return torch.round(transient * steps) / steps
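

if __name__ == "__main__":
    # Minimal usage sketch; "example.wav" is a placeholder, substitute any
    # audio file readable by audiotools
    sig = AudioSignal("example.wav")
    feats = rhythm_features(sig, n_bands=2)
    print(feats.shape)     # (n_batch, n_bands, n_frames)
    print(feats.unique())  # quantized levels: multiples of 1/32 in [0, 1]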