Spaces:
Paused
Paused
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| from einops.layers.torch import Rearrange | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from pesq import pesq | |
| from joblib import Parallel, delayed | |
def phase_losses(phase_r, phase_g):
    """Compute three anti-wrapped phase losses between two phase tensors.

    Args:
        phase_r: reference phase tensor (needs at least three dimensions,
            since finite differences are taken along dims 1 and 2).
        phase_g: generated phase tensor with the same shape as ``phase_r``.

    Returns:
        Tuple ``(ip_loss, gd_loss, iaf_loss)`` of scalar tensors:
        ``ip_loss`` compares the phases directly, ``gd_loss`` compares their
        finite differences along dim 1 and ``iaf_loss`` along dim 2
        (presumably group delay / instantaneous angular frequency if the
        layout is (batch, freq, time) -- TODO confirm against callers).
        Every difference is passed through ``anti_wrapping_function`` before
        averaging, so offsets of whole 2*pi turns cost nothing.
    """
    def wrapped_mean(delta):
        # Mean absolute difference after removing whole 2*pi turns.
        return torch.mean(anti_wrapping_function(delta))

    ip_loss = wrapped_mean(phase_r - phase_g)
    gd_loss = wrapped_mean(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1))
    iaf_loss = wrapped_mean(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2))
    return ip_loss, gd_loss, iaf_loss
def anti_wrapping_function(x):
    """Return ``|x|`` after removing the nearest whole multiple of 2*pi.

    Each element is shifted by ``round(x / 2pi) * 2pi`` so that values
    differing only by full turns map to the same magnitude; the result
    lies in [0, pi].
    """
    two_pi = 2 * np.pi
    nearest_turns = torch.round(x / two_pi)
    return torch.abs(x - nearest_turns * two_pi)
def pesq_score(utts_r, utts_g, h, n_jobs=30):
    """Return the mean PESQ score over paired reference/generated utterances.

    Each pair is scored by ``eval_pesq`` in parallel with joblib.

    Args:
        utts_r: sequence of reference waveforms (torch tensors). Each one is
            squeezed, moved to the CPU and converted to a NumPy array.
        utts_g: sequence of generated waveforms, indexed in lockstep with
            ``utts_r`` (must be at least as long as ``utts_r``).
        h: config object; only ``h.sample_rate`` is read.
        n_jobs: number of joblib workers. Defaults to 30, the previously
            hard-coded value.

    Returns:
        The arithmetic mean of the per-utterance scores. Utterances that
        ``eval_pesq`` fails to score contribute its failure value (-1) to
        the mean. Returns ``nan`` when ``utts_r`` is empty.
    """
    if len(utts_r) == 0:
        # Nothing to score; skip the worker pool and the empty-mean warning.
        return float('nan')
    scores = Parallel(n_jobs=n_jobs)(
        delayed(eval_pesq)(
            utts_r[i].squeeze().cpu().numpy(),
            utts_g[i].squeeze().cpu().numpy(),
            h.sample_rate,
        )
        for i in range(len(utts_r))
    )
    return np.mean(scores)
def eval_pesq(clean_utt, esti_utt, sr):
    """Score one utterance with PESQ, returning -1 if scoring fails.

    Args:
        clean_utt: reference waveform (assumed to be a 1-D NumPy array, as
            produced by ``pesq_score`` -- confirm for other callers).
        esti_utt: estimated waveform in the same format.
        sr: sample rate, passed through unchanged as the first argument of
            ``pesq``.

    Returns:
        The value returned by ``pesq``, or -1 when ``pesq`` raises.
    """
    try:
        return pesq(sr, clean_utt, esti_utt)
    except Exception:
        # Best-effort: one unscorable utterance must not abort a whole
        # evaluation run. Catch ``Exception`` rather than using a bare
        # ``except`` so KeyboardInterrupt and SystemExit still propagate.
        return -1
def mag_pha_stft(y, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
    """Compute the compressed magnitude, phase and complex spectrum of ``y``.

    Args:
        y: waveform tensor, either unbatched ``(samples,)`` or batched
            ``(batch, samples)``; both forms are accepted by ``torch.stft``.
        n_fft: FFT size.
        hop_size: hop length between frames.
        win_size: length of the Hann analysis window.
        compress_factor: exponent applied to the magnitude (power-law
            compression); 1.0 leaves it unchanged.
        center: forwarded to ``torch.stft`` (reflect padding is used).

    Returns:
        ``(mag, pha, com)``:
        - ``mag``: compressed magnitude with the STFT's shape
          ``(..., n_fft // 2 + 1, frames)``.
        - ``pha``: phase with the same shape.
        - ``com``: real/imaginary parts rebuilt from the *compressed*
          magnitude and the phase, stacked on a new trailing axis of size 2.
    """
    hann_window = torch.hann_window(win_size).to(y.device)
    stft_spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window,
                           center=center, pad_mode='reflect', normalized=False, return_complex=True)
    stft_spec = torch.view_as_real(stft_spec)
    # Index the trailing real/imag axis with "..." so unbatched input works
    # too; a fixed [:, :, :, k] only fits a batched 4-D tensor.
    real, imag = stft_spec[..., 0], stft_spec[..., 1]
    # The small epsilons presumably keep sqrt/atan2 off their singular point
    # at 0 so that gradients stay finite; they are part of the numerics.
    mag = torch.sqrt(stft_spec.pow(2).sum(-1) + 1e-9)
    pha = torch.atan2(imag + 1e-10, real + 1e-5)
    # Magnitude compression
    mag = torch.pow(mag, compress_factor)
    com = torch.stack((mag * torch.cos(pha), mag * torch.sin(pha)), dim=-1)
    return mag, pha, com
def mag_pha_istft(mag, pha, n_fft, hop_size, win_size, compress_factor=1.0, center=True):
    """Invert a compressed magnitude/phase pair back to a waveform.

    The power-law compression is undone by raising ``mag`` to
    ``1 / compress_factor``. The complex spectrum is then rebuilt from polar
    form and passed to ``torch.istft`` with a Hann window of length
    ``win_size`` placed on the spectrum's device.
    """
    linear_mag = mag.pow(1.0 / compress_factor)
    spectrum = torch.complex(linear_mag * torch.cos(pha), linear_mag * torch.sin(pha))
    window = torch.hann_window(win_size).to(spectrum.device)
    return torch.istft(
        spectrum,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=window,
        center=center,
    )
class LearnableSigmoid1d(nn.Module):
    """Sigmoid with a learnable per-feature slope, scaled by ``beta``.

    Computes ``beta * sigmoid(slope * x)``. ``slope`` has shape
    ``(in_features,)``, so it broadcasts against the last axis of ``x``.
    """

    def __init__(self, in_features, beta=1):
        super().__init__()
        self.beta = beta
        # nn.Parameter defaults to requires_grad=True, so no extra flag is
        # needed (PyTorch's attribute is ``requires_grad``, not ``requiresGrad``).
        self.slope = nn.Parameter(torch.ones(in_features))

    def forward(self, x):
        # x shape: [batch_size, time_steps, spec_bins]; slope broadcasts
        # along the last axis.
        return self.beta * torch.sigmoid(self.slope * x)
class LearnableSigmoid2d(nn.Module):
    """Sigmoid with a learnable slope per row, scaled by ``beta``.

    Computes ``beta * sigmoid(slope * x)``. ``slope`` has shape
    ``(in_features, 1)``, so it broadcasts against the second-to-last axis
    of ``x`` and is shared along the last axis.
    """

    def __init__(self, in_features, beta=1):
        super().__init__()
        self.beta = beta
        # nn.Parameter defaults to requires_grad=True, so no extra flag is
        # needed (PyTorch's attribute is ``requires_grad``, not ``requiresGrad``).
        self.slope = nn.Parameter(torch.ones(in_features, 1))

    def forward(self, x):
        return self.beta * torch.sigmoid(self.slope * x)
def main():
    """Run LearnableSigmoid1d on random input and print the output shape."""
    learnable_sigmoid = LearnableSigmoid1d(201)
    a = torch.randn(4, 100, 201)
    # Call the module itself rather than .forward() so any registered
    # forward hooks run, as in normal use.
    result = learnable_sigmoid(a)
    print(result.shape)
if '__main__' == __name__:
    # Run the demo only when executed as a script, not when imported.
    main()