| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| | import torchaudio |
| | import numpy as np |
| | from librosa.filters import mel as librosa_mel_fn |
| |
|
| | from .base import BaseModule |
| |
|
| |
|
def mse_loss(x, y, mask, n_feats):
    """Masked mean-squared error.

    Sums squared differences over the positions where ``mask`` is nonzero,
    then normalizes by the number of valid positions times ``n_feats`` so the
    value is comparable across batches with different amounts of padding.
    """
    squared_error = (x - y) ** 2
    masked_sum = (squared_error * mask).sum()
    normalizer = mask.sum() * n_feats
    return masked_sum / normalizer
| |
|
| |
|
def sequence_mask(length, max_length=None):
    """Build a boolean padding mask from per-sequence lengths.

    Returns a (batch, max_length) tensor where entry [b, t] is True
    iff t < length[b]. When ``max_length`` is omitted, the longest
    sequence in the batch determines the width.
    """
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(int(max_length), dtype=length.dtype,
                             device=length.device)
    # Broadcast (1, T) positions against (B, 1) lengths.
    return positions[None, :] < length[:, None]
| |
|
| |
|
def convert_pad_shape(pad_shape):
    """Flatten a per-dimension pad spec into the flat list F.pad expects.

    ``torch.nn.functional.pad`` takes padding for the *last* dimension first,
    so the per-dimension [[before, after], ...] list is reversed and then
    flattened.
    """
    reversed_spec = pad_shape[::-1]
    return [amount for pair in reversed_spec for amount in pair]
| |
|
| |
|
def fix_len_compatibility(length, num_downsamplings_in_unet=2):
    """Round ``length`` up to the nearest multiple of 2**num_downsamplings_in_unet.

    A U-Net with k downsampling stages halves the time axis k times, so the
    input length must be divisible by 2**k. Computed in closed form instead
    of the original increment-by-one loop (O(1) vs O(length)); results are
    identical for non-negative lengths.
    """
    factor = 2 ** num_downsamplings_in_unet
    return ((length + factor - 1) // factor) * factor
| |
|
| |
|
class PseudoInversion(BaseModule):
    """Approximately invert a log-mel spectrogram to a linear-frequency STFT
    magnitude using the Moore-Penrose pseudo-inverse of the mel filterbank.
    """

    def __init__(self, n_mels, sampling_rate, n_fft):
        super(PseudoInversion, self).__init__()
        self.n_mels = n_mels
        self.sampling_rate = sampling_rate
        self.n_fft = n_fft
        # Build the mel filterbank once (fmin=0, fmax=8000 Hz) and precompute
        # its pseudo-inverse. Registered as a buffer so it follows .to(device)
        # and is saved in state_dict without being a trainable parameter.
        basis = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=n_mels,
                               fmin=0, fmax=8000)
        inverse = torch.from_numpy(np.linalg.pinv(basis)).float()
        self.register_buffer("mel_basis_inverse", inverse)

    def forward(self, log_mel_spectrogram):
        """Map a natural-log mel spectrogram to an estimated STFT magnitude."""
        amplitude = log_mel_spectrogram.exp()
        return torch.matmul(self.mel_basis_inverse, amplitude)
| |
|
| |
|
class InitialReconstruction(BaseModule):
    """Build an initial waveform estimate from an STFT magnitude by assuming
    zero phase everywhere and applying the inverse STFT.
    """

    def __init__(self, n_fft, hop_size):
        super(InitialReconstruction, self).__init__()
        self.n_fft = n_fft
        self.hop_size = hop_size
        # Hann analysis window, registered as a buffer so it follows .to(device).
        window = torch.hann_window(n_fft).float()
        self.register_buffer("window", window)

    def forward(self, stftm):
        """Invert an STFT magnitude ``stftm`` to a waveform of shape
        (batch, 1, samples), using phase 0 for every bin.
        """
        # Zero-phase complex spectrum: magnitude as the real part, 0 imaginary.
        # torch.istft dropped support for real (..., 2) stacked inputs in
        # torch 2.x, so build a complex tensor directly (mathematically the
        # same as the old ones/zeros real-imag stack).
        spec = torch.complex(stftm, torch.zeros_like(stftm))
        signal = torch.istft(spec, n_fft=self.n_fft,
                             hop_length=self.hop_size, win_length=self.n_fft,
                             window=self.window, center=True)
        return signal.unsqueeze(1)
| |
|
| |
|
| | |
class FastGL(BaseModule):
    """Fast Griffin-Lim vocoder.

    Recovers a waveform from a log-mel spectrogram: pseudo-invert the mel
    filterbank to get a target STFT magnitude, start from a zero-phase
    reconstruction, then iteratively re-estimate phase with a momentum-
    accelerated Griffin-Lim update.
    """

    def __init__(self, n_mels, sampling_rate, n_fft, hop_size, momentum=0.99):
        super(FastGL, self).__init__()
        self.n_mels = n_mels
        self.sampling_rate = sampling_rate
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.momentum = momentum
        self.pi = PseudoInversion(n_mels, sampling_rate, n_fft)
        self.ir = InitialReconstruction(n_fft, hop_size)
        window = torch.hann_window(n_fft).float()
        self.register_buffer("window", window)

    @torch.no_grad()
    def forward(self, s, n_iters=32):
        """Vocode log-mel spectrogram ``s`` into a waveform of shape
        (batch, 1, samples) using ``n_iters`` Griffin-Lim iterations.
        """
        c = self.pi(s)             # target STFT magnitude (real-valued)
        x = self.ir(c).squeeze(1)  # zero-phase initial waveform estimate
        prev_angles = torch.zeros_like(c, dtype=torch.complex64)
        for _ in range(n_iters):
            # torch.stft/istft with real (..., 2) tensors was removed in
            # torch 2.x; use the complex-valued API (return_complex=True).
            spec = torch.stft(x, n_fft=self.n_fft, hop_length=self.hop_size,
                              win_length=self.n_fft, window=self.window,
                              center=True, return_complex=True)
            # Magnitude with a floor (1e-8 on the squared magnitude, as in
            # the original) to avoid division by zero in the phase estimate.
            magnitude = torch.sqrt(torch.clamp(spec.real ** 2 + spec.imag ** 2,
                                               min=1e-8))
            angles = spec / magnitude
            # Fast Griffin-Lim momentum step: extrapolate the phase update,
            # then re-impose the target magnitude.
            spec = c * (angles + self.momentum * (angles - prev_angles))
            x = torch.istft(spec, n_fft=self.n_fft, hop_length=self.hop_size,
                            win_length=self.n_fft, window=self.window,
                            center=True)
            prev_angles = angles
        return x.unsqueeze(1)
| |
|