File size: 4,795 Bytes
fe17ce1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
import torch
import librosa
import torch.nn as nn
import random
from torch import Tensor
from typing import Optional
from torchaudio.transforms import Spectrogram
from torchaudio.transforms import Spectrogram, MelScale
def soxnorm(wav: torch.Tensor, gain, factor=None):
    """Peak-normalize a waveform like ``sox norm`` (as used in the Vocos code).

    The waveform is first clipped to [-1, 1] and cast to float32. Then:

    * if ``factor`` is None, a scale factor is derived so that the peak
      absolute value of the clipped waveform equals ``10 ** (gain / 20)``
      (i.e. ``gain`` is interpreted in dB);
    * otherwise the given ``factor`` is applied as-is (e.g. clean speech
      scaled by the factor that was computed on the noisy signal) and
      ``gain`` is ignored.

    Args:
        wav: input waveform tensor.
        gain: target peak level in dB; only used when ``factor`` is None.
        factor: optional precomputed linear scale factor.

    Returns:
        A tuple ``(scaled_wav, factor)`` where ``factor`` is the scale that
        was applied (a Python float when computed here).

    Raises:
        AssertionError: if the scaled waveform leaves the [-1, 1] range.
    """
    wav = torch.clip(wav, max=1, min=-1).float()
    if factor is None:
        linear_gain = 10 ** (gain / 20)
        peak = torch.abs(wav).max().item()
        if peak > 0:
            factor = linear_gain / peak
        else:
            # Silent (all-zero) input: dividing by the zero peak would give an
            # infinite factor and turn the output into NaN. Any finite factor
            # leaves silence unchanged, so use the neutral one.
            factor = 1.0
    # With a caller-supplied factor (clean speech normalized by the noisy
    # factor) the waveform is simply rescaled.
    wav = wav * factor
    # NOTE: kept as an assert so callers still see AssertionError; be aware
    # that this check is skipped when Python runs with -O.
    assert torch.all(wav.abs() <= 1), f"out wavform is not in [-1, 1], {wav.abs().max()}"
    return wav, factor
class InputSTFT(nn.Module):
    """
    STFT front-end that produces the input of CleanMel (STFT coefficients).

    The transform is a ``Spectrogram`` built with ``power=None``; its output
    is converted with ``torch.view_as_real`` before being returned.

    Two normalization schemes are available:

    * offline (``online=False``): the waveform is normalized with ``soxnorm``
      before the STFT, using a random integer gain in [-6, -1] while the
      module is in training mode and -3 otherwise;
    * online (``online=True``): the STFT is divided by the result of
      ``recursive_normalization`` applied to its magnitude (clamped from
      below at 1e-8).

    NOTE(review): ``recursive_normalization`` is neither defined nor imported
    in the visible part of this file -- confirm where it is expected to come
    from before enabling the online mode.
    """

    def __init__(
        self,
        n_fft: int,
        n_win: int,
        n_hop: int,
        center: bool,
        normalize: bool,
        onesided: bool,
        online: bool = False,
    ):
        super().__init__()
        self.online = online
        stft_options = dict(
            n_fft=n_fft,
            win_length=n_win,
            hop_length=n_hop,
            normalized=normalize,
            center=center,
            onesided=onesided,
            power=None,
        )
        self.stft = Spectrogram(**stft_options)

    def forward(self, x):
        """Return ``(stft_as_real, x_norm)`` for the waveform ``x``.

        ``x_norm`` is the scale factor returned by ``soxnorm`` in offline
        mode, or the (unclamped) output of ``recursive_normalization`` in
        online mode.
        """
        if not self.online:
            # vocos dBFS normalization
            gain = random.randint(-6, -1) if self.training else -3
            normed_wav, x_norm = soxnorm(x, gain)
            coeffs = self.stft(normed_wav)
            return torch.view_as_real(coeffs), x_norm

        # recursive normalization
        coeffs = self.stft(x)
        x_norm = recursive_normalization(coeffs.abs())
        coeffs = coeffs / x_norm.clamp(min=1e-8)
        return torch.view_as_real(coeffs), x_norm
class LibrosaMelScale(nn.Module):
    r"""Mel-scale projection whose filter bank comes from ``librosa.filters.mel``.

    PyTorch module intended to align the mel features with common ESPNet ASR
    models. The filter bank is computed once in the constructor, transposed,
    converted to a float tensor and registered as the buffer ``fb``.

    Args:
        n_mels: number of mel bins.
        sample_rate: sampling rate passed to librosa as ``sr``.
        f_min: lowest filter frequency (``fmin``).
        f_max: highest filter frequency (``fmax``); when None,
            ``float(sample_rate // 2)`` is used.
        n_stft: number of STFT bins; the FFT size handed to librosa is
            ``(n_stft - 1) * 2``.
        norm: forwarded unchanged to librosa's ``norm`` argument.
        mel_scale: ``"htk"`` selects librosa's HTK formula; any other value
            leaves ``htk`` False.
    """

    def __init__(self, n_mels, sample_rate, f_min, f_max, n_stft, norm=None, mel_scale="slaney"):
        super().__init__()
        if f_max is None:
            f_max = float(sample_rate // 2)
        mel_matrix = librosa.filters.mel(
            sr=sample_rate,
            n_fft=(n_stft - 1) * 2,
            n_mels=n_mels,
            fmin=f_min,
            fmax=f_max,
            htk=(mel_scale == "htk"),
            norm=norm,
        )
        # Store the transposed filter bank so forward() can right-multiply.
        self.register_buffer("fb", torch.from_numpy(mel_matrix.T).float())

    def forward(self, specgram):
        """Project ``specgram`` onto the mel filter bank.

        The last two dimensions are swapped, multiplied by ``fb`` and swapped
        back (assumes the frequency axis is second to last -- TODO confirm
        against callers).
        """
        swapped = specgram.transpose(-1, -2)
        projected = torch.matmul(swapped, self.fb)
        return projected.transpose(-1, -2)
class TargetMel(nn.Module):
    """
    Generate the enhancement TARGET mel spectrogram.

    The waveform is transformed by a ``Spectrogram`` and then projected by a
    mel filter bank (``LibrosaMelScale`` when ``librosa_mel`` is True,
    torchaudio's ``MelScale`` otherwise).

    Normalization mirrors the two modes of the input front-end:

    * offline (``online=False``): the waveform is scaled with
      ``soxnorm(x, None, x_norm)``, i.e. ``x_norm`` is a precomputed scale
      factor, and the ``Spectrogram`` uses the configured ``power``;
    * online (``online=True``): the ``Spectrogram`` is built with
      ``power=None``, its output is divided by ``x_norm + 1e-8`` and then
      converted to a power spectrogram via ``abs().pow(2)``.
    """

    def __init__(
        self,
        sample_rate: int,
        n_fft: int,
        n_win: int,
        n_hop: int,
        n_mels: int,
        f_min: int,
        f_max: int,
        power: int,
        center: bool,
        normalize: bool,
        onesided: bool,
        # ``Optional[str]`` instead of ``str | None``: the annotation is
        # evaluated when the class is defined, and the union syntax raises
        # TypeError on Python < 3.10.
        mel_norm: Optional[str],
        mel_scale: str,
        librosa_mel: bool = True,
        online: bool = False,
    ):
        super().__init__()
        # This implementation vs torchaudio.transforms.MelSpectrogram: Add librosa melscale
        # librosa melscale is numerically different from the torchaudio melscale (x_diff > 1e-5)
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.online = online
        self.stft = Spectrogram(
            n_fft=n_fft,
            win_length=n_win,
            hop_length=n_hop,
            # Online mode needs the unreduced STFT so it can be normalized
            # before the power is taken in forward().
            power=None if online else power,
            normalized=normalize,
            center=center,
            onesided=onesided,
        )
        mel_method = LibrosaMelScale if librosa_mel else MelScale
        self.mel_scale = mel_method(
            n_mels=n_mels,
            sample_rate=sample_rate,
            f_min=f_min,
            f_max=f_max,
            n_stft=n_fft // 2 + 1,
            norm=mel_norm,
            mel_scale=mel_scale,
        )

    def forward(self, x: Tensor, x_norm=None):
        """Return the mel spectrogram of the target waveform ``x``.

        Args:
            x: target waveform.
            x_norm: normalization to apply. Required in practice: a scale
                factor for ``soxnorm`` in offline mode, or the divisor for
                the STFT in online mode (e.g. the second value returned by
                ``InputSTFT.forward`` for the matching input).

        Raises:
            TypeError: if ``x_norm`` is None.
        """
        if x_norm is None:
            # Both branches below would fail on None with an opaque TypeError
            # deep inside arithmetic; fail early with a clear message instead
            # (same exception type as before).
            raise TypeError(
                "TargetMel.forward() requires x_norm; pass the normalization "
                "value computed for the corresponding input"
            )
        if self.online:
            # apply recursive normalization to target waveform
            spectrogram = self.stft(x)
            spectrogram = spectrogram / (x_norm + 1e-8)
            # NOTE(review): the exponent is fixed to 2 here, independent of the
            # ``power`` constructor argument -- confirm this is intended.
            spectrogram = spectrogram.abs().pow(2)  # to power spectrogram
        else:
            # apply vocos dBFS normalization to target waveform
            x, _ = soxnorm(x, None, x_norm)
            spectrogram = self.stft(x)
        # mel spectrogram
        mel_specgram = self.mel_scale(spectrogram)
        return mel_specgram
|