Spaces:
Sleeping
Sleeping
| import torch | |
| from audiotools import AudioSignal | |
| ################################################################################ | |
| # Utilities for extracting rhythm feature representations | |
| ################################################################################ | |
| def _moving_average(x: torch.Tensor, window_length: int): | |
| """ | |
| Smooth features with moving average over frames. | |
| Parameters | |
| ---------- | |
| x : torch.Tensor | |
| Shape (n_batch, n_feats, n_frames) | |
| window_length : int | |
| Smoothing window length | |
| """ | |
| if window_length <= 1: | |
| return x | |
| n_feats = x.shape[1] | |
| kernel = torch.ones( | |
| (n_feats, 1, window_length), | |
| device=x.device, dtype=x.dtype | |
| ) / window_length | |
| pad_left = (window_length - 1) // 2 | |
| pad_right = window_length // 2 | |
| x_pad = torch.nn.functional.pad(x, (pad_left, pad_right), mode="reflect") | |
| # Smooth separately over feature channels | |
| return torch.nn.functional.conv1d(x_pad, kernel, groups=n_feats) | |
def _ms_to_frames(ms: float, sample_rate: int, hop_length: int) -> int:
    """Convert a duration in milliseconds to a frame count (at least 1)."""
    return max(1, int(round((ms / 1000.0) * sample_rate / hop_length)))


def _split_bands(mel: torch.Tensor, n_bands: int) -> torch.Tensor:
    """
    Sum mel bins into `n_bands` contiguous frequency bands whose boundaries
    are placed at equal fractions of each item's time-averaged energy.

    Parameters
    ----------
    mel : torch.Tensor
        Shape (n_batch, n_mels, n_frames), non-negative
    n_bands : int
        Number of bands (>= 1)

    Returns
    -------
    torch.Tensor
        Shape (n_batch, n_bands, n_frames). Some bands may be empty (all
        zeros), e.g. when energy is concentrated in few bins.
    """
    if n_bands == 1:
        return mel.sum(dim=1, keepdim=True)

    n_batch, n_mels, n_frames = mel.shape
    # Cumulative time-averaged energy along the frequency axis
    cum = mel.mean(dim=-1).cumsum(dim=1)  # (n_batch, n_mels)
    total = cum[:, -1:]  # (n_batch, 1)
    # Energy targets at fractions 1/n_bands, ..., (n_bands-1)/n_bands
    targets = torch.linspace(
        1.0 / n_bands, (n_bands - 1) / n_bands, n_bands - 1,
        device=mel.device, dtype=mel.dtype,
    )[None, :] * total  # (n_batch, n_bands-1)
    # First bin whose cumulative energy reaches each target closes a band
    edges = torch.searchsorted(cum, targets, right=False)  # (n_batch, n_bands-1)
    cuts = torch.cat(
        [
            torch.zeros(n_batch, 1, dtype=torch.long, device=mel.device),
            # Defensive clamp keeps gather indices in range (e.g. NaN input)
            (edges + 1).clamp(max=n_mels),
            torch.full((n_batch, 1), n_mels, dtype=torch.long, device=mel.device),
        ],
        dim=1,
    )  # (n_batch, n_bands+1)
    # Band sums via prefix sums over frequency:
    # band k = prefix[cuts[k+1]] - prefix[cuts[k]]
    prefix = torch.cat(
        [
            torch.zeros(n_batch, 1, n_frames, device=mel.device, dtype=mel.dtype),
            mel.cumsum(dim=1),
        ],
        dim=1,
    )  # (n_batch, n_mels+1, n_frames)
    lo = cuts[:, :-1].unsqueeze(-1).expand(n_batch, n_bands, n_frames)
    hi = cuts[:, 1:].unsqueeze(-1).expand(n_batch, n_bands, n_frames)
    return prefix.gather(1, hi) - prefix.gather(1, lo)


# The 'original' TRIA features can be recovered using:
# * `slow_ma_ms` = None
# * `post_smooth_ms` = None
# * `legacy_normalize` = True
def rhythm_features(
    signal: AudioSignal,
    sample_rate: int = 44_100,
    n_bands: int = 2,
    n_mels: int = 80,
    window_length: int = 1024,
    hop_length: int = 512,
    normalize_quantile: float = 0.98,
    quantization_levels: int = 33,
    clamp_max: float = 50.0,
    eps: float = 1e-8,
    slow_ma_ms: float = 100.0,
    post_smooth_ms: float = 10.0,
    legacy_normalize: bool = False,
) -> torch.Tensor:
    """
    Extract multi-band 'rhythm' features from audio by adaptively splitting
    the mel spectrogram along the frequency axis and applying normalization,
    quantization, and transient emphasis / smoothing.

    The input signal is not modified; features are computed on a mono,
    resampled, loudness-normalized copy.

    Parameters
    ----------
    signal : AudioSignal
        Audio from which to extract features
    sample_rate : int
        Sample rate at which to extract features
    n_bands : int
        Number of frequency bands into which to adaptively divide the
        spectrogram (>= 1)
    n_mels : int
        Number of base mel frequency bins in spectrogram
    window_length : int
        Spectrogram window length
    hop_length : int
        Spectrogram hop length
    normalize_quantile : float
        Normalize each band relative to this quantile of its values over time
        rather than its absolute max. Ignored when `legacy_normalize` is True.
    quantization_levels : int
        Number of bins into which feature magnitudes are quantized (>= 2)
    clamp_max : float
        Maximum allowed mel spectrogram magnitude; also the reference level
        for logarithmic compression when `legacy_normalize` is False
    eps : float
        For numerical stability
    slow_ma_ms : float or None
        Smoothing filter length in milliseconds for transient emphasis
        (smoothed features are subtracted). None disables transient emphasis.
    post_smooth_ms : float or None
        Smoothing filter length in milliseconds for transient smoothing.
        None disables it.
    legacy_normalize : bool
        If `True`, use mean/std and sigmoid normalization as described in
        original TRIA paper

    Returns
    -------
    torch.Tensor
        Shape (n_batch, n_bands, n_frames); values lie in [0, 1] and are
        multiples of 1 / (quantization_levels - 1).
    """
    assert n_bands >= 1
    assert quantization_levels >= 2

    # Loudness normalization on a mono copy at the target sample rate
    signal = signal.clone().to_mono().resample(sample_rate).normalize(-16.)
    signal.ensure_max_of_audio()

    # Clamped mel spectrogram; mean(1) collapses the (mono) channel axis
    mel = signal.mel_spectrogram(
        n_mels=n_mels,
        hop_length=hop_length,
        window_length=window_length,
    ).mean(1)  # (n_batch, n_mels, n_frames)
    mel = torch.clamp(mel, 0.0, clamp_max)

    if legacy_normalize:
        # Original normalization: divide by number of mels
        mel = mel / n_mels
    else:
        # Compress logarithmically so that `clamp_max` maps to 1
        mel = torch.log1p(mel) / torch.log1p(
            torch.tensor(clamp_max, device=mel.device, dtype=mel.dtype)
        )

    # Split spectrogram into bands adaptively
    bands = _split_bands(mel, n_bands)  # (n_batch, n_bands, n_frames)

    # Emphasize transients by subtracting smoothed features
    transient = bands
    if slow_ma_ms is not None:
        slow_win = _ms_to_frames(slow_ma_ms, sample_rate, hop_length)
        bands_slow = _moving_average(bands, slow_win)
        transient = torch.relu(bands - bands_slow)

    # Apply additional smoothing to transients (a 1-frame window is a no-op)
    if post_smooth_ms is not None:
        ps_win = _ms_to_frames(post_smooth_ms, sample_rate, hop_length)
        transient = _moving_average(transient, ps_win)

    # Normalize features across time per band
    if legacy_normalize:
        # Original normalization (mean/std with sigmoid compression). The
        # unbiased std of a single frame is NaN, which clamp_min() passes
        # through; map it to eps so one-frame inputs do not produce NaN.
        mean = transient.mean(dim=-1, keepdim=True)
        std = transient.std(dim=-1, keepdim=True)
        std = torch.where(torch.isnan(std), torch.full_like(std, eps), std)
        std = std.clamp_min(eps)
        transient = torch.sigmoid((transient - mean) / std)
    else:
        # Quantile-based normalization, robust to isolated peaks
        q = torch.quantile(
            transient.clamp_min(0.0),
            q=normalize_quantile,
            dim=-1,
            keepdim=True,
        ).clamp_min(eps)
        transient = (transient / q).clamp(0.0, 1.0)

    # Quantize feature intensities into bins to ensure a tight information
    # bottleneck
    steps = quantization_levels - 1
    return torch.round(transient * steps) / steps