|
|
|
|
|
|
|
|
|
|
|
import math
from distutils.version import LooseVersion
from typing import Optional, Tuple, Union

import librosa.filters as filters
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as tf
|
|
|
|
|
# Machine epsilon of float32, used as the default numerical floor (eps) below
EPSILON = float(np.finfo(np.float32).eps)
# Version string of the installed PyTorch; compared against LooseVersion values
TORCH_VERSION = th.__version__

# torch.fft.fft is only imported for PyTorch >= 1.7; on older versions nothing
# is imported here and init_kernel falls back to the legacy th.fft call
if TORCH_VERSION >= LooseVersion("1.7"):
    from torch.fft import fft as fft_func
|
|
|
|
|
|
|
|
def export_jit(transform: nn.Module) -> nn.Module:
    """
    Build the inference-time version of a transform pipeline
    Args:
        transform: iterable container of sub-modules; each sub-module must
                   provide an exportable() method
    Return:
        nn.Sequential holding, in the original order, only the sub-modules
        whose exportable() returned a true value
    """
    kept = []
    for layer in transform:
        if layer.exportable():
            kept.append(layer)
    return nn.Sequential(*kept)
|
|
|
|
|
|
|
|
def init_window(wnd: str, frame_len: int, device: th.device = "cpu") -> th.Tensor:
    """
    Create the coefficients of an analysis/synthesis window
    Args:
        wnd: window name, one of "bartlett", "hann", "hamm", "blackman",
             "rect" or "sqrthann"
        frame_len: length of the frame (number of coefficients)
        device: device the returned tensor is moved to
    Return:
        window (Tensor): 1D tensor with frame_len coefficients
    Raises:
        RuntimeError: if the window name is not supported
    """
    supported = ("bartlett", "hann", "hamm", "blackman", "rect", "sqrthann")
    if wnd not in supported:
        raise RuntimeError(f"Unknown window type: {wnd}")

    if wnd == "rect":
        # rectangular window: all ones, there is no periodic flag to pass
        coeff = th.ones(frame_len)
    elif wnd == "sqrthann":
        # square root of the periodic hann window
        coeff = th.hann_window(frame_len, periodic=True) ** 0.5
    else:
        builders = {
            "hann": th.hann_window,
            "hamm": th.hamming_window,
            "blackman": th.blackman_window,
            "bartlett": th.bartlett_window,
        }
        # all remaining windows are generated in their periodic form
        coeff = builders[wnd](frame_len, periodic=True)
    return coeff.to(device)
|
|
|
|
|
|
|
|
def init_kernel(
    frame_len: int,
    frame_hop: int,
    window: th.Tensor,
    round_pow_of_two: bool = True,
    normalized: bool = False,
    inverse: bool = False,
    mode: str = "librosa",
) -> Tuple[th.Tensor, th.Tensor]:
    """
    Build the (i)DFT kernel used by the convolution based (i)STFT
    Args:
        frame_len: length of the frame
        frame_hop: hop size between frames (not used when building the kernel)
        window: window tensor
        round_pow_of_two: if true, use the next power of two of frame_len as
                          the FFT size (always done in kaldi mode)
        normalized: scale the DFT matrix by 1 / sqrt(fft_size)
        inverse: build the iDFT kernel (scaled by 1 / fft_size when the
                 kernel is not normalized)
        mode: framing mode, "librosa" or "kaldi"
    Return:
        kernel (Tensor): (2 * fft_size) x 1 x kernel_len, real rows first and
                         imaginary rows second; kernel_len is frame_len in
                         kaldi mode and fft_size otherwise
        window (Tensor): the window, zero padded to fft_size in librosa mode
    Raises:
        ValueError: if mode is neither "librosa" nor "kaldi"
    """
    if mode not in ["librosa", "kaldi"]:
        raise ValueError(f"Unsupported mode: {mode}")

    fft_size = frame_len
    if round_pow_of_two or mode == "kaldi":
        fft_size = 2 ** math.ceil(math.log2(frame_len))

    if mode == "librosa" and fft_size != frame_len:
        # librosa style: centre the window inside the (longer) FFT frame
        left = (fft_size - frame_len) // 2
        right = fft_size - frame_len - left
        window = tf.pad(window, (left, right))

    scale = fft_size ** 0.5 if normalized else 1

    # the DFT of the identity matrix gives the DFT matrix itself
    identity = th.eye(fft_size)
    if TORCH_VERSION >= LooseVersion("1.7"):
        spectrum = fft_func(identity / scale, dim=-1)
        kernel = th.stack([spectrum.real, spectrum.imag], dim=-1)
    else:
        packed = th.stack([identity, th.zeros(fft_size, fft_size)], dim=-1)
        kernel = th.fft(packed / scale, 1)

    if mode == "kaldi":
        # kaldi style: only the first frame_len samples of each FFT frame are used
        kernel = kernel[:frame_len]
    if inverse and not normalized:
        kernel = kernel / fft_size

    # kernel_len x fft_size x 2 => 2 x fft_size x kernel_len
    kernel = th.transpose(kernel, 0, 2)
    # 2 x fft_size x kernel_len => (2 * fft_size) x 1 x kernel_len
    kernel = th.reshape(kernel, (fft_size * 2, 1, kernel.shape[-1]))
    return kernel.to(window.device), window
|
|
|
|
|
|
|
|
def mel_filter(
    frame_len: int,
    round_pow_of_two: bool = True,
    num_bins: Optional[int] = None,
    sr: int = 16000,
    num_mels: int = 80,
    fmin: float = 0.0,
    fmax: Optional[float] = None,
    norm: bool = False,
) -> th.Tensor:
    """
    Return mel filter coefficients
    Args:
        frame_len: length of the frame
        round_pow_of_two: if true, use the next power of two of frame_len as
                          the FFT size (only used when num_bins is None)
        num_bins: number of the frequency bins produced by STFT; when given,
                  the FFT size is derived as (num_bins - 1) * 2
        sr: sample rate of the signal
        num_mels: number of the mel bands
        fmin: lowest frequency (in Hz), negative values are clipped to 0
        fmax: highest frequency (in Hz). None means sr // 2, a negative value
              is taken as an offset from sr // 2, and the result never
              exceeds sr // 2
        norm: normalize the mel filter coefficients (librosa "slaney" norm)
    Return:
        mel (Tensor): float32 mel filter bank as produced by
                      librosa.filters.mel
    """
    if num_bins is None:
        N = 2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
    else:
        N = (num_bins - 1) * 2

    freq_upper = sr // 2
    if fmax is None:
        fmax = freq_upper
    else:
        fmax = min(fmax + freq_upper if fmax < 0 else fmax, freq_upper)
    fmin = max(0, fmin)

    # sr and n_fft are passed by keyword: they are keyword-only arguments in
    # recent librosa releases, and the keyword form is accepted by older ones
    mel = filters.mel(
        sr=sr,
        n_fft=N,
        n_mels=num_mels,
        fmax=fmax,
        fmin=fmin,
        htk=True,
        norm="slaney" if norm else None,
    )
    return th.tensor(mel, dtype=th.float32)
|
|
|
|
|
|
|
|
def speed_perturb_filter(
    src_sr: int, dst_sr: int, cutoff_ratio: float = 0.95, num_zeros: int = 64
) -> th.Tensor:
    """
    Build the resampling filter bank used for speed perturbation, reference:
        https://github.com/danpovey/filtering/blob/master/lilfilter/resampler.py
    Args:
        src_sr: sample rate of the source signal
        dst_sr: sample rate of the target signal
        cutoff_ratio: factor applied to min(src_sr, dst_sr) (after reduction
                      by their gcd) to get the sinc rate of the filter
        num_zeros: controls the filter length, the padding is
                   1 + int(num_zeros / zeros_per_block)
    Return:
        weight (Tensor): float32 coefficients of the filter, with shape
                         dst_sr x src_sr x (2 * padding + 1), where both
                         rates are reduced by their greatest common divisor
    Raises:
        ValueError: if the two rates are equal, or if one reduced rate is 1
                    (integer down/up-sampling)
    """
    if src_sr == dst_sr:
        raise ValueError(f"src_sr should not be equal to dst_sr: {src_sr}/{dst_sr}")
    # reduce both rates by their greatest common divisor
    common = math.gcd(src_sr, dst_sr)
    src_sr, dst_sr = src_sr // common, dst_sr // common
    if 1 in (src_sr, dst_sr):
        raise ValueError("do not support integer downsample/upsample")

    zeros_per_block = min(src_sr, dst_sr) * cutoff_ratio
    padding = 1 + int(num_zeros / zeros_per_block)

    # time offsets for every (output phase, input phase, filter tap) triple
    dst_phase = np.arange(dst_sr)[:, None, None] / float(dst_sr)
    src_phase = np.arange(src_sr)[None, :, None] / float(src_sr)
    taps = np.arange(2 * padding + 1)[None, None, :]
    times = dst_phase - src_phase - taps + padding

    # raised-cosine window, forced to zero outside |times| < padding
    support = np.heaviside(1 - np.abs(times / padding), 0.0)
    window = support * (0.5 + 0.5 * np.cos(times / padding * math.pi))
    # windowed sinc low-pass filter
    weight = np.sinc(times * zeros_per_block) * window * zeros_per_block / float(src_sr)
    return th.tensor(weight, dtype=th.float32)
|
|
|
|
|
|
|
|
def splice_feature(
    feats: th.Tensor, lctx: int = 1, rctx: int = 1, op: str = "cat"
) -> th.Tensor:
    """
    Attach left/right context frames to each frame of the feature
    Args:
        feats (Tensor): N x ... x T x F, original feature
        lctx: number of the left context frames
        rctx: number of the right context frames
        op: how the context is combined on the last axis, "cat" or "stack"
    Return:
        splice (Tensor): feature with context, N x ... x T x FD for "cat" and
                         N x ... x T x F x D for "stack", D = lctx + rctx + 1.
                         The input itself is returned when lctx + rctx == 0
    Raises:
        ValueError: if op is unknown (only checked when context is requested)
    """
    if lctx + rctx == 0:
        return feats
    if op not in ["cat", "stack"]:
        raise ValueError(f"Unknown op for feature splicing: {op}")

    num_frames = feats.shape[-2]
    frame_index = th.arange(num_frames, device=feats.device, dtype=th.int64)
    # frames that fall outside [0, T - 1] are replaced by the edge frames
    shifted = [
        th.index_select(
            feats, -2, th.clamp(frame_index + offset, min=0, max=num_frames - 1)
        )
        for offset in range(-lctx, rctx + 1)
    ]
    combine = th.cat if op == "cat" else th.stack
    return combine(shifted, -1)
|
|
|
|
|
|
|
|
def _forward_stft(
    wav: th.Tensor,
    kernel: th.Tensor,
    window: th.Tensor,
    return_polar: bool = False,
    pre_emphasis: float = 0,
    frame_hop: int = 256,
    onesided: bool = False,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    STFT computed with a 1D convolution against the DFT kernel
    Args:
        wav (Tensor): N x (C) x S
        kernel (Tensor): STFT transform kernels, from init_kernel(...)
        window (Tensor): window coefficients, multiplied onto the kernel
        return_polar: return [magnitude; phase] Tensor or [real; imag] Tensor
        pre_emphasis: factor of preemphasis, applied inside each frame
        frame_hop: frame hop size in number samples
        onesided: return half FFT bins
        center: if true, reflect-pad half a frame on both sides of the signal
        eps: small value added under the square root of the magnitude
    Return:
        transform (Tensor): STFT transform results, N x (C) x F x T x 2
    Raises:
        RuntimeError: if wav is neither a 2D nor a 3D tensor
    """
    num_dims = wav.dim()
    if num_dims not in [2, 3]:
        raise RuntimeError(f"STFT expect 2D/3D tensor, but got {num_dims:d}D")

    batch_size = wav.shape[0]
    # fold the optional channel axis into the batch: NC x 1 x S
    flat = wav.view(-1, 1, wav.shape[-1])
    if center:
        half = kernel.shape[-1] // 2
        flat = tf.pad(flat, (half, half), mode="reflect")

    # apply the window to every row of the DFT kernel
    windowed = kernel * window
    if pre_emphasis > 0:
        # NC x 1 x 1 x S => NC x K x T, one column per frame
        frames = tf.unfold(
            flat[:, None], (1, windowed.shape[-1]), stride=frame_hop, padding=0
        )
        # pre-emphasis along the in-frame sample axis; the first sample of
        # each frame has no predecessor and is scaled by (1 - pre_emphasis)
        frames[:, 1:] = frames[:, 1:] - pre_emphasis * frames[:, :-1]
        frames[:, 0] *= 1 - pre_emphasis
        # 1 x 2B x K @ NC x K x T => NC x 2B x T
        packed = th.matmul(windowed[:, 0][None, ...], frames)
    else:
        # NC x 2B x T
        packed = tf.conv1d(flat, windowed, stride=frame_hop, padding=0)

    if num_dims == 3:
        # restore the channel axis: N x C x 2B x T
        packed = packed.view(batch_size, -1, packed.shape[-2], packed.shape[-1])

    # the first half of the rows is the real part, the second the imaginary
    real, imag = th.chunk(packed, 2, dim=-2)
    if onesided:
        keep = windowed.shape[0] // 4 + 1
        real, imag = real[..., :keep, :], imag[..., :keep, :]

    if not return_polar:
        return th.stack([real, imag], dim=-1)
    mag = (real ** 2 + imag ** 2 + eps) ** 0.5
    pha = th.atan2(imag, real)
    return th.stack([mag, pha], dim=-1)
|
|
|
|
|
|
|
|
def _inverse_stft(
    transform: th.Tensor,
    kernel: th.Tensor,
    window: th.Tensor,
    return_polar: bool = False,
    frame_hop: int = 256,
    onesided: bool = False,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    iSTFT function implemented by conv1d
    Args:
        transform (Tensor): STFT transform results, (N) x F x T x 2
        kernel (Tensor): STFT transform kernels, from init_kernel(...)
        window (Tensor): window coefficients, multiplied onto the kernel
        return_polar (bool): keep same with the one in _forward_stft
        frame_hop: frame hop size in number samples
        onesided: the transform only holds half of the FFT bins
        center: used in _forward_stft
        eps: small value added to the window normalization term
    Return:
        wav (Tensor), N x S (N is 1 when a 3D transform is given)
    Raises:
        RuntimeError: if transform is neither a 3D nor a 4D tensor
    """
    transform_dim = transform.dim()
    # F x T x 2 => 1 x F x T x 2
    if transform_dim == 3:
        transform = th.unsqueeze(transform, 0)
    # check the (possibly unsqueezed) tensor, so that 3D input is accepted;
    # the message still reports the number of dimensions the caller passed
    if transform.dim() != 4:
        raise RuntimeError(f"Expect 4D tensor, but got {transform_dim}D")

    if return_polar:
        # [magnitude; phase] => [real; imag]
        real = transform[..., 0] * th.cos(transform[..., 1])
        imag = transform[..., 0] * th.sin(transform[..., 1])
    else:
        real, imag = transform[..., 0], transform[..., 1]

    if onesided:
        # rebuild the missing bins using the conjugate symmetry
        reverse = range(kernel.shape[0] // 4 - 1, 0, -1)
        real = th.cat([real, real[:, reverse]], 1)
        imag = th.cat([imag, -imag[:, reverse]], 1)

    # N x 2B x T
    packed = th.cat([real, imag], dim=1)
    # overlap and add of the windowed iDFT frames: N x 1 x S
    wav = tf.conv_transpose1d(packed, kernel * window, stride=frame_hop, padding=0)

    # overlap and add of the squared window, used to normalize the signal
    num_frames = packed.shape[-1]
    win_length = window.shape[0]
    # W x T
    win = th.repeat_interleave(window[..., None] ** 2, num_frames, dim=-1)
    # W x 1 x W
    I = th.eye(win_length, device=win.device)[:, None]
    # 1 x 1 x S
    denorm = tf.conv_transpose1d(win[None, ...], I, stride=frame_hop, padding=0)

    if center:
        # drop the padding added by _forward_stft
        pad = kernel.shape[-1] // 2
        wav = wav[..., pad:-pad]
        denorm = denorm[..., pad:-pad]
    wav = wav / (denorm + eps)
    # N x S
    return wav.squeeze(1)
|
|
|
|
|
|
|
|
def _pytorch_stft(
    wav: th.Tensor,
    frame_len: int,
    frame_hop: int,
    n_fft: int = 512,
    return_polar: bool = False,
    window: Union[str, th.Tensor] = "sqrthann",
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    Wrapper of PyTorch STFT function
    Args:
        wav (Tensor): source audio signal, N x (C) x S
        frame_len: length of the frame (only used to build the window when
                   a window name is given)
        frame_hop: hop size between frames
        n_fft: number of the FFT size
        return_polar: return the results in polar coordinate
        window: window tensor, or a window name accepted by init_window(...)
        center: same definition with the parameter in librosa.stft
        normalized: use normalized DFT kernel
        onesided: output onesided STFT
        eps: small value added under the square root of the magnitude
    Return:
        transform (Tensor), STFT transform results, N x (C) x F x T x 2
    Raises:
        RuntimeError: if PyTorch is older than 1.7 or wav is not a 2D/3D tensor
    """
    if TORCH_VERSION < LooseVersion("1.7"):
        raise RuntimeError("Can not use this function as TORCH_VERSION < 1.7")
    wav_dim = wav.dim()
    if wav_dim not in [2, 3]:
        raise RuntimeError(f"STFT expect 2D/3D tensor, but got {wav_dim:d}D")

    # the default value is a window name, so build the coefficients from it
    if isinstance(window, str):
        window = init_window(window, frame_len, device=wav.device)

    # remember the batch size before the channel axis is folded into it
    batch_size = wav.shape[0]
    # N x (C) x S => NC x S
    wav = wav.view(-1, wav.shape[-1])
    # return_complex=False is deprecated for real input, so request the
    # complex result and convert it back to the [real; imag] layout
    stft = th.view_as_real(
        th.stft(
            wav,
            n_fft,
            hop_length=frame_hop,
            win_length=window.shape[-1],
            window=window,
            center=center,
            normalized=normalized,
            onesided=onesided,
            return_complex=True,
        )
    )
    if wav_dim == 3:
        _, F, T, _ = stft.shape
        # NC x F x T x 2 => N x C x F x T x 2
        stft = stft.reshape(batch_size, -1, F, T, 2)

    if not return_polar:
        return stft

    # [real; imag] => [magnitude; phase]
    real, imag = stft[..., 0], stft[..., 1]
    mag = (real ** 2 + imag ** 2 + eps) ** 0.5
    pha = th.atan2(imag, real)
    return th.stack([mag, pha], dim=-1)
|
|
|
|
|
|
|
|
def _pytorch_istft(
    transform: th.Tensor,
    frame_len: int,
    frame_hop: int,
    window: th.Tensor,
    n_fft: int = 512,
    return_polar: bool = False,
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    eps: float = EPSILON,
) -> th.Tensor:
    """
    Wrapper of PyTorch iSTFT function
    Args:
        transform (Tensor): results of STFT, (N) x F x T x 2
        frame_len: length of the frame (not used, the window length is used)
        frame_hop: hop size between frames
        window: window tensor
        n_fft: number of the FFT size
        return_polar: keep same with _pytorch_stft
        center: same definition with the parameter in librosa.stft
        normalized: use normalized DFT kernel
        onesided: the transform is an onesided STFT
        eps: not used, kept to mirror the other (i)STFT functions
    Return:
        wav (Tensor): synthetic audio
    Raises:
        RuntimeError: if PyTorch is older than 1.7 or transform is neither a
                      3D nor a 4D tensor
    """
    if TORCH_VERSION < LooseVersion("1.7"):
        raise RuntimeError("Can not use this function as TORCH_VERSION < 1.7")

    transform_dim = transform.dim()
    # F x T x 2 => 1 x F x T x 2
    if transform_dim == 3:
        transform = th.unsqueeze(transform, 0)
    # check the (possibly unsqueezed) tensor, so that 3D input is accepted;
    # the message still reports the number of dimensions the caller passed
    if transform.dim() != 4:
        raise RuntimeError(f"Expect 4D tensor, but got {transform_dim}D")

    if return_polar:
        # [magnitude; phase] => [real; imag]
        real = transform[..., 0] * th.cos(transform[..., 1])
        imag = transform[..., 0] * th.sin(transform[..., 1])
        transform = th.stack([real, imag], -1)

    # view_as_complex needs a dense [real; imag] layout, which a sliced or
    # permuted input does not guarantee
    stft = th.view_as_complex(transform.contiguous())
    wav = th.istft(
        stft,
        n_fft,
        hop_length=frame_hop,
        win_length=window.shape[-1],
        window=window,
        center=center,
        normalized=normalized,
        onesided=onesided,
        return_complex=False,
    )
    return wav
|
|
|
|
|
|
|
|
def forward_stft(
    wav: th.Tensor,
    frame_len: int,
    frame_hop: int,
    window: str = "sqrthann",
    round_pow_of_two: bool = True,
    return_polar: bool = False,
    pre_emphasis: float = 0,
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    mode: str = "librosa",
    eps: float = EPSILON,
) -> th.Tensor:
    """
    Functional form of the STFT layer
    Args:
        wav: source audio signal
        frame_len: length of the frame
        frame_hop: hop size between frames
        window: window name
        round_pow_of_two: if true, use the next power of two of frame_len as
                          the FFT size
        return_polar: return [magnitude; phase] Tensor or [real; imag] Tensor
        pre_emphasis: factor of preemphasis (not passed on in "torch" mode)
        normalized: use normalized DFT kernel
        onesided: output onesided STFT
        center: center flag (similar with that in librosa.stft)
        mode: STFT mode, "kaldi" or "librosa" or "torch"
        eps: small value added under the square root of the magnitude
    Return:
        transform: results of STFT
    """
    win = init_window(window, frame_len, device=wav.device)
    if mode != "torch":
        # convolution based implementation ("librosa" or "kaldi" framing)
        kernel, win = init_kernel(
            frame_len,
            frame_hop,
            window=win,
            round_pow_of_two=round_pow_of_two,
            normalized=normalized,
            inverse=False,
            mode=mode,
        )
        return _forward_stft(
            wav,
            kernel,
            win,
            return_polar=return_polar,
            frame_hop=frame_hop,
            pre_emphasis=pre_emphasis,
            onesided=onesided,
            center=center,
            eps=eps,
        )

    # th.stft based implementation
    fft_size = frame_len
    if round_pow_of_two:
        fft_size = 2 ** math.ceil(math.log2(frame_len))
    return _pytorch_stft(
        wav,
        frame_len,
        frame_hop,
        n_fft=fft_size,
        return_polar=return_polar,
        window=win,
        normalized=normalized,
        onesided=onesided,
        center=center,
        eps=eps,
    )
|
|
|
|
|
|
|
|
def inverse_stft(
    transform: th.Tensor,
    frame_len: int,
    frame_hop: int,
    return_polar: bool = False,
    window: str = "sqrthann",
    round_pow_of_two: bool = True,
    normalized: bool = False,
    onesided: bool = True,
    center: bool = False,
    mode: str = "librosa",
    eps: float = EPSILON,
) -> th.Tensor:
    """
    Functional form of the iSTFT layer
    Args:
        transform: results of STFT
        frame_len: length of the frame
        frame_hop: hop size between frames
        return_polar: keep same with function forward_stft(...)
        window: window name
        round_pow_of_two: if true, use the next power of two of frame_len as
                          the FFT size
        normalized: use normalized DFT kernel
        onesided: the transform is an onesided STFT
        center: center flag (similar with that in librosa.stft)
        mode: STFT mode, "kaldi" or "librosa" or "torch"
        eps: small value added to the window normalization term
    Return:
        wav: synthetic signals
    """
    win = init_window(window, frame_len, device=transform.device)
    if mode != "torch":
        # convolution based implementation ("librosa" or "kaldi" framing)
        kernel, win = init_kernel(
            frame_len,
            frame_hop,
            win,
            round_pow_of_two=round_pow_of_two,
            normalized=normalized,
            inverse=True,
            mode=mode,
        )
        return _inverse_stft(
            transform,
            kernel,
            win,
            return_polar=return_polar,
            frame_hop=frame_hop,
            onesided=onesided,
            center=center,
            eps=eps,
        )

    # th.istft based implementation
    fft_size = frame_len
    if round_pow_of_two:
        fft_size = 2 ** math.ceil(math.log2(frame_len))
    return _pytorch_istft(
        transform,
        frame_len,
        frame_hop,
        n_fft=fft_size,
        return_polar=return_polar,
        window=win,
        normalized=normalized,
        onesided=onesided,
        center=center,
        eps=eps,
    )
|
|
|
|
|
|
|
|
class STFTBase(nn.Module):
    """
    Base layer for (i)STFT
    Args:
        frame_len: length of the frame
        frame_hop: hop size between frames
        window: window name
        center: center flag (similar with that in librosa.stft)
        round_pow_of_two: if true, use the next power of two of frame_len as
                          the FFT size
        normalized: use normalized DFT kernel
        pre_emphasis: factor of preemphasis (forced to 0 in "torch" mode)
        mode: STFT mode, "kaldi" or "librosa" or "torch"
        onesided: output onesided STFT
        inverse: using iDFT kernel (for iSTFT)
    """

    def __init__(
        self,
        frame_len: int,
        frame_hop: int,
        window: str = "sqrthann",
        round_pow_of_two: bool = True,
        normalized: bool = False,
        pre_emphasis: float = 0,
        onesided: bool = True,
        inverse: bool = False,
        center: bool = False,
        mode: str = "librosa",
    ) -> None:
        super(STFTBase, self).__init__()
        if mode != "torch":
            # convolution based (i)STFT: keep the kernel and the window as
            # frozen parameters
            K, w = init_kernel(
                frame_len,
                frame_hop,
                init_window(window, frame_len),
                round_pow_of_two=round_pow_of_two,
                normalized=normalized,
                inverse=inverse,
                mode=mode,
            )
            self.K = nn.Parameter(K, requires_grad=False)
            self.w = nn.Parameter(w, requires_grad=False)
            # the kernel has 2 * fft_size rows (real rows + imaginary rows)
            self.num_bins = self.K.shape[0] // 4 + 1
            self.pre_emphasis = pre_emphasis
            self.win_length = self.K.shape[2]
        else:
            # th.stft / th.istft based: only the window is needed
            self.K = None
            w = init_window(window, frame_len)
            self.w = nn.Parameter(w, requires_grad=False)
            fft_size = (
                2 ** math.ceil(math.log2(frame_len)) if round_pow_of_two else frame_len
            )
            self.num_bins = fft_size // 2 + 1
            self.pre_emphasis = 0
            self.win_length = fft_size
        self.frame_len = frame_len
        self.frame_hop = frame_hop
        self.window = window
        self.normalized = normalized
        self.onesided = onesided
        self.center = center
        self.mode = mode

    def num_frames(self, wav_len: th.Tensor) -> th.Tensor:
        """
        Compute number of the frames
        Args:
            wav_len (Tensor): number of samples of each signal; every value
                              must be larger than the frame size. The tensor
                              is not modified
        Return:
            num_frames (Tensor): number of frames produced for each signal
        """
        assert th.sum(wav_len <= self.win_length) == 0
        if self.center:
            # half a frame is padded on each side, i.e. 2 * (win_length // 2)
            # samples in total; a new tensor is created so that the lengths
            # owned by the caller stay untouched
            wav_len = wav_len + 2 * (self.win_length // 2)
        return (
            th.div(wav_len - self.win_length, self.frame_hop, rounding_mode="trunc") + 1
        )

    def extra_repr(self) -> str:
        """
        Return the configuration shown in the printed form of the layer;
        onesided, pre_emphasis and normalized are listed only when they are
        not at their usual values
        """
        str_repr = (
            f"num_bins={self.num_bins}, win_length={self.win_length}, "
            + f"stride={self.frame_hop}, window={self.window}, "
            + f"center={self.center}, mode={self.mode}"
        )
        if not self.onesided:
            str_repr += f", onesided={self.onesided}"
        if self.pre_emphasis > 0:
            str_repr += f", pre_emphasis={self.pre_emphasis}"
        if self.normalized:
            str_repr += f", normalized={self.normalized}"
        return str_repr
|
|
|
|
|
|
|
|
class STFT(STFTBase):
    """
    Short-time Fourier Transform as a Layer
    """

    def __init__(self, *args, **kwargs):
        # same arguments as STFTBase, with the forward DFT kernel selected
        super(STFT, self).__init__(*args, inverse=False, **kwargs)

    def forward(
        self, wav: th.Tensor, return_polar: bool = False, eps: float = EPSILON
    ) -> th.Tensor:
        """
        Accept (single or multiple channel) raw waveform and output the STFT
        Args
            wav (Tensor) input signal, N x (C) x S
            return_polar: return [magnitude; phase] instead of [real; imag]
            eps: small value added under the square root of the magnitude
        Return
            transform (Tensor), N x (C) x F x T x 2
        """
        if self.mode == "torch":
            return _pytorch_stft(
                wav,
                self.frame_len,
                self.frame_hop,
                # in torch mode win_length stores the FFT size; deriving it
                # from num_bins would lose one sample for an odd FFT size
                n_fft=self.win_length,
                return_polar=return_polar,
                window=self.w,
                normalized=self.normalized,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )
        else:
            return _forward_stft(
                wav,
                self.K,
                self.w,
                return_polar=return_polar,
                frame_hop=self.frame_hop,
                pre_emphasis=self.pre_emphasis,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )
|
|
|
|
|
|
|
|
class iSTFT(STFTBase):
    """
    Inverse Short-time Fourier Transform as a Layer
    """

    def __init__(self, *args, **kwargs):
        # same arguments as STFTBase, with the inverse DFT kernel selected
        super(iSTFT, self).__init__(*args, inverse=True, **kwargs)

    def forward(
        self, transform: th.Tensor, return_polar: bool = False, eps: float = EPSILON
    ) -> th.Tensor:
        """
        Accept the STFT results and output raw waveform
        Args
            transform (Tensor): STFT output, N x F x T x 2
            return_polar: the transform holds [magnitude; phase] instead of
                          [real; imag]
            eps: small value added to the window normalization term
        Return
            s (Tensor): N x S
        """
        if self.mode == "torch":
            return _pytorch_istft(
                transform,
                self.frame_len,
                self.frame_hop,
                # in torch mode win_length stores the FFT size; deriving it
                # from num_bins would lose one sample for an odd FFT size
                n_fft=self.win_length,
                return_polar=return_polar,
                window=self.w,
                normalized=self.normalized,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )
        else:
            return _inverse_stft(
                transform,
                self.K,
                self.w,
                return_polar=return_polar,
                frame_hop=self.frame_hop,
                onesided=self.onesided,
                center=self.center,
                eps=eps,
            )
|
|
|