import torch
from audiotools import AudioSignal
################################################################################
# Utilities for extracting rhythm feature representations
################################################################################
def _moving_average(x: torch.Tensor, window_length: int):
"""
Smooth features with moving average over frames.
Parameters
----------
x : torch.Tensor
Shape (n_batch, n_feats, n_frames)
window_length : int
Smoothing window length
"""
if window_length <= 1:
return x
n_feats = x.shape[1]
kernel = torch.ones(
(n_feats, 1, window_length),
device=x.device, dtype=x.dtype
) / window_length
pad_left = (window_length - 1) // 2
pad_right = window_length // 2
x_pad = torch.nn.functional.pad(x, (pad_left, pad_right), mode="reflect")
# Smooth separately over feature channels
return torch.nn.functional.conv1d(x_pad, kernel, groups=n_feats)
# The 'original' TRIA features can be recovered using:
# * `slow_ma_ms` = None
# * `post_smooth_ms` = None
# * `legacy_normalize` = True
def rhythm_features(
    signal: AudioSignal,
    sample_rate: int = 44_100,
    n_bands: int = 2,
    n_mels: int = 80,
    window_length: int = 1024,
    hop_length: int = 512,
    normalize_quantile: float = 0.98,
    quantization_levels: int = 33,
    clamp_max: float = 50.0,
    eps: float = 1e-8,
    slow_ma_ms: float = 100.0,
    post_smooth_ms: float = 10.0,
    legacy_normalize: bool = False,
) -> torch.Tensor:
    """
    Extract multi-band 'rhythm' features from audio by adaptively splitting
    spectrogram along frequency axis and applying normalization, quantization,
    and smoothing / sparsity filtering.

    Parameters
    ----------
    signal : AudioSignal
        Audio from which to extract features
    sample_rate : int
        Sample rate at which to extract features
    n_bands : int
        Number of frequency bands into which to adaptively divide spectrogram
    n_mels : int
        Number of base mel frequency bins in spectrogram
    window_length : int
        Spectrogram window length
    hop_length : int
        Spectrogram hop length
    normalize_quantile : float
        Optionally normalize each band relative to top-p largest magnitude
        rather than absolute max (used only when `legacy_normalize` is False)
    quantization_levels : int
        Number of bins into which feature magnitudes are quantized
    clamp_max : float
        Maximum allowed spectrogram magnitude
    eps : float
        For numerical stability
    slow_ma_ms : float
        Smoothing filter length in milliseconds for transient emphasis (smoothed
        features are subtracted); pass ``None`` to disable this step
    post_smooth_ms : float
        Smoothing filter length in milliseconds for transient smoothing; pass
        ``None`` to disable this step
    legacy_normalize : bool
        If `True`, use mean/std and sigmoid normalization as described in
        original TRIA paper

    Returns
    -------
    torch.Tensor
        Shape (n_batch, n_bands, n_frames). Values lie in [0, 1] and are
        rounded onto `quantization_levels` evenly spaced levels.
    """
    assert n_bands >= 1
    assert quantization_levels >= 2
    # Loudness normalization: mono, resample to target rate, normalize to
    # -16 dB, then guard against clipping.
    signal = signal.clone().to_mono().resample(sample_rate).normalize(-16.)
    signal.ensure_max_of_audio()
    # Clamped mel spectrogram. The .mean(1) reduces dim 1 — presumably the
    # AudioSignal channel axis (mono after to_mono); confirm against
    # audiotools' mel_spectrogram output layout.
    mel = signal.mel_spectrogram(
        n_mels=n_mels,
        hop_length=hop_length,
        window_length=window_length,
    ).mean(1)  # (n_batch, n_mels, n_frames)
    mel = torch.clamp(mel, 0.0, clamp_max)
    n_batch, _, n_frames = mel.shape
    if legacy_normalize:
        # Original normalization: divide by number of mels
        mel = mel / n_mels
    else:
        # Compress logarithmically; dividing by log1p(clamp_max) maps the
        # clamped maximum to exactly 1.
        mel = torch.log1p(mel) / torch.log1p(torch.tensor(clamp_max, device=mel.device, dtype=mel.dtype))
    # Split spectrogram into bands adaptively: place band edges so each band
    # carries roughly 1/n_bands of the total time-averaged energy.
    energy_per_bin = mel.mean(dim=-1)  # (n_batch, n_mels)
    cum = energy_per_bin.cumsum(dim=1)  # (n_batch, n_mels)
    total = cum[:, -1:]  # (n_batch, 1)
    if n_bands == 1:
        # Single band: just total energy per frame.
        bands = mel.sum(dim=1, keepdim=True)  # (n_batch, 1, n_frames)
    else:
        # Interior cumulative-energy targets at k/n_bands of the total,
        # k = 1 .. n_bands - 1.
        targets = torch.linspace(
            1.0 / n_bands, (n_bands - 1) / n_bands, n_bands - 1,
            device=mel.device, dtype=mel.dtype
        )[None, :] * total  # (n_batch, n_bands-1)
        # First mel bin whose cumulative energy reaches each target.
        edges = torch.searchsorted(cum, targets, right=False)  # (n_batch, n_bands-1)
        # Band boundaries as half-open index ranges:
        # [0, edge_1 + 1, ..., edge_{n-1} + 1, n_mels]
        cuts = torch.cat(
            [
                torch.zeros(n_batch, 1, dtype=torch.long, device=mel.device),
                edges + 1,
                torch.full((n_batch, 1), mel.size(1), dtype=torch.long, device=mel.device),
            ],
            dim=1
        )  # (n_batch, n_bands+1)
        # Per-band sums via prefix sums: sum over mel bins [a, b) equals
        # prefix_pad[b] - prefix_pad[a], gathered per batch element.
        prefix = mel.cumsum(dim=1)  # (n_batch, n_mels, n_frames)
        prefix_pad = torch.cat(
            [torch.zeros(n_batch, 1, n_frames, device=mel.device, dtype=mel.dtype), prefix],
            dim=1
        )
        a_idx = cuts[:, :-1].unsqueeze(-1).expand(n_batch, n_bands, n_frames)
        b_idx = cuts[:, 1: ].unsqueeze(-1).expand(n_batch, n_bands, n_frames)
        bands = prefix_pad.gather(1, b_idx) - prefix_pad.gather(1, a_idx)  # (n_batch, n_bands, n_frames)
    # Emphasize transients by subtracting smoothed features
    transient = bands.clone()
    # Convert a duration in milliseconds to a whole number of hop frames
    # (at least 1).
    to_frames = lambda ms: max(1, int(round((ms / 1000.0) * sample_rate / hop_length)))
    if slow_ma_ms is not None:
        slow_win = to_frames(slow_ma_ms)
        bands_slow = _moving_average(bands, slow_win)  # (n_batch, n_bands, n_frames)
        # Keep only energy above the slow-moving average (onset-like content).
        transient = torch.relu(bands - bands_slow)
    # Apply additional smoothing to transients
    if post_smooth_ms is not None:
        ps_win = to_frames(post_smooth_ms)
        if ps_win > 1:
            transient = _moving_average(transient, ps_win)
    # Normalize features across time per band
    if legacy_normalize:
        # Original normalization (mean/std with sigmoid compression)
        mean = transient.mean(dim=-1, keepdim=True)
        std = transient.std(dim=-1, keepdim=True).clamp_min(eps)
        transient = torch.sigmoid((transient - mean) / std)
    else:
        # Quantile-based normalization: divide by the per-band
        # `normalize_quantile` value so outlier frames don't dominate,
        # then clip to [0, 1].
        q = torch.quantile(
            transient.clamp_min(0.0),
            q=normalize_quantile,
            dim=-1,
            keepdim=True
        ).clamp_min(eps)
        transient = (transient / q).clamp(0.0, 1.0)
    # Quantize feature intensities into bins to ensure a tight information
    # bottleneck
    steps = quantization_levels - 1
    return torch.round(transient * steps) / steps