Spaces:
Running on Zero
Running on Zero
File size: 4,429 Bytes
c7f3ffb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | import torch
import math
import numpy as np
from librosa.filters import mel as librosa_mel_fn
import torch.nn as nn
from typing import Any, Dict, Optional
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress a NumPy array after flooring it at ``clip_val``.

    Args:
        x: Input values (array-like).
        C: Multiplicative scale applied before the log.
        clip_val: Lower bound applied to ``x`` so the log never sees values
            below it.

    Returns:
        ``log(max(x, clip_val) * C)`` element-wise.
    """
    floored = np.clip(x, a_min=clip_val, a_max=None)
    scaled = floored * C
    return np.log(scaled)
def dynamic_range_decompression(x, C=1):
    """Invert ``dynamic_range_compression`` for a NumPy array (ignoring clipping).

    Args:
        x: Log-domain values.
        C: The scale that was applied before the log during compression.

    Returns:
        ``exp(x) / C`` element-wise.
    """
    expanded = np.exp(x)
    return expanded / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress a tensor after flooring it at ``clip_val`` (torch version).

    Args:
        x: Input tensor.
        C: Multiplicative scale applied before the log.
        clip_val: Lower bound applied to ``x`` so the log never sees values
            below it.

    Returns:
        ``log(clamp(x, min=clip_val) * C)`` element-wise.
    """
    floored = torch.clamp(x, min=clip_val)
    scaled = floored * C
    return torch.log(scaled)
def dynamic_range_decompression_torch(x, C=1):
    """Invert ``dynamic_range_compression_torch`` (ignoring clipping).

    Args:
        x: Log-domain tensor.
        C: The scale that was applied before the log during compression.

    Returns:
        ``exp(x) / C`` element-wise.
    """
    expanded = torch.exp(x)
    return expanded / C
def spectral_normalize_torch(magnitudes):
    """Log-compress spectrogram magnitudes using the default settings.

    Delegates to ``dynamic_range_compression_torch`` without overriding its
    ``C`` / ``clip_val`` arguments.
    """
    return dynamic_range_compression_torch(magnitudes)
def spectral_de_normalize_torch(magnitudes):
    """Undo ``spectral_normalize_torch`` (up to clipping).

    Delegates to ``dynamic_range_decompression_torch`` without overriding its
    ``C`` argument.
    """
    return dynamic_range_decompression_torch(magnitudes)
class MelSpectrogram(nn.Module):
    """Log-mel spectrogram front end built on ``torch.stft``.

    ``forward`` reflect-pads the waveform by ``(n_fft - hop_size) / 2`` samples
    per side, takes the STFT magnitude, projects it onto a librosa mel
    filterbank and log-compresses the result with ``spectral_normalize_torch``.

    Args:
        n_fft: FFT size (also used to build the mel filterbank).
        num_mels: Number of mel bands in the output.
        sampling_rate: Sample rate in Hz used to build the mel filterbank.
        hop_size: Hop length between frames, in samples.
        win_size: Hann window length, in samples.
        fmin: Lowest filterbank frequency in Hz.
        fmax: Highest filterbank frequency in Hz (passed to librosa as-is).
        center: Forwarded to ``torch.stft``. The manual reflect padding in
            ``forward`` is applied regardless, so ``center=True`` adds a second
            round of padding on top of it.
    """

    def __init__(
        self,
        n_fft,
        num_mels,
        sampling_rate,
        hop_size,
        win_size,
        fmin,
        fmax,
        center=False,
    ):
        super().__init__()
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.win_size = win_size
        self.sampling_rate = sampling_rate
        self.num_mels = num_mels
        self.fmin = fmin
        self.fmax = fmax
        self.center = center

        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        # Buffers follow the module across ``.to(...)`` and are saved in its
        # state dict; the names are kept so existing checkpoints still load.
        self.register_buffer("mel_basis", torch.from_numpy(mel).float())
        self.register_buffer("hann_window", torch.hann_window(win_size))

    def forward(self, y):
        """Compute log-mel features of a waveform.

        Args:
            y: Waveform tensor of shape ``(batch, time)``, or ``(time,)`` for a
                single unbatched signal.

        Returns:
            Log-mel tensor of shape ``(batch, num_mels, frames)``, or
            ``(num_mels, frames)`` when ``y`` was unbatched.
        """
        unbatched = y.dim() == 1
        if unbatched:
            # Reflect padding below needs a (batch, channel, time) layout; a
            # bare (time,) signal would otherwise be misread and fail to pad.
            y = y.unsqueeze(0)

        pad = int((self.n_fft - self.hop_size) / 2)
        y = torch.nn.functional.pad(y.unsqueeze(1), (pad, pad), mode="reflect")
        y = y.squeeze(1)

        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)
        # Magnitude; the small epsilon keeps sqrt and its gradient finite at 0.
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))
        spec = torch.matmul(self.mel_basis, spec)
        spec = spectral_normalize_torch(spec)
        return spec.squeeze(0) if unbatched else spec
def load_mel_spectrogram():
    """Build an eval-mode ``MelSpectrogram`` using only the built-in defaults."""
    return load_mel_spectrogram_from_cfg(audio_cfg=None)
def _get_from_mapping(cfg: Any, key: str, default: Any = None) -> Any:
"""Safely read a field from a dict/OmegaConf-like object."""
if cfg is None:
return default
if isinstance(cfg, dict):
return cfg.get(key, default)
return getattr(cfg, key, default)
def load_mel_spectrogram_from_cfg(audio_cfg: Optional[Any] = None) -> MelSpectrogram:
    """Build an eval-mode ``MelSpectrogram`` from an ``audio_config``-like object.

    ``audio_cfg`` may be ``None``, a dict, or a Hydra/OmegaConf-style object;
    every field is read through ``_get_from_mapping``. Recognized fields and the
    fallback used when a field is missing:

    - ``hop_size`` (480), ``n_fft`` (1920), ``num_mels`` (128),
      ``win_size`` (1920), ``fmin`` (0), ``fmax`` (12000).
    - ``sampling_rate``; when absent, ``sample_rate``; when both are absent,
      24000. ``sampling_rate`` wins if both are present.

    Note that the ``fmax`` fallback is a fixed 12000 and does not follow the
    configured sample rate.
    """

    def field(name, fallback):
        # Single point of access so every key is read the same way.
        return _get_from_mapping(audio_cfg, name, fallback)

    rate_fallback = field("sample_rate", 24000)
    mel_model = MelSpectrogram(
        n_fft=field("n_fft", 1920),
        num_mels=field("num_mels", 128),
        sampling_rate=field("sampling_rate", rate_fallback),
        hop_size=field("hop_size", 480),
        win_size=field("win_size", 1920),
        fmin=field("fmin", 0),
        fmax=field("fmax", 12000),
    )
    mel_model.eval()
    return mel_model
class MelSpectrogramEncoder(nn.Module):
    """Log-mel feature extractor with fixed mean/variance standardization.

    Args:
        audio_config: ``None``, a dict, or an attribute-style config object.
            The mel parameters are read by ``load_mel_spectrogram_from_cfg``;
            the optional ``mel_mean`` (fallback -4.92) and ``mel_var``
            (fallback 8.14) fields set the normalization statistics.
    """

    def __init__(self, audio_config: dict | None = None):
        super().__init__()
        self.model = load_mel_spectrogram_from_cfg(audio_config)
        # Read through the same helper as the mel parameters: a plain ``.get``
        # would crash on attribute-style configs that the loader accepts.
        self.mel_mean = _get_from_mapping(audio_config, "mel_mean", -4.92)
        self.mel_var = _get_from_mapping(audio_config, "mel_var", 8.14)

    def forward(self, x):
        """Return standardized log-mel features.

        Args:
            x: Waveform batch accepted by the wrapped ``MelSpectrogram``.

        Returns:
            ``(mel - mel_mean) / sqrt(mel_var)`` with the mel and frame axes
            swapped, i.e. ``(batch, frames, num_mels)``.
        """
        x = self.model(x).transpose(1, 2)
        return (x - self.mel_mean) / math.sqrt(self.mel_var)