| |
| import librosa |
| import numpy as np |
| import torch |
|
|
| from TorchJaekwon.Util.UtilData import UtilData |
| from TorchJaekwon.Util.UtilTorch import UtilTorch |
|
|
class UtilAudioSR:
    """Post-processing helpers for audio super-resolution.

    The model's prediction keeps its own high band; below an estimated cutoff
    its low band is replaced with the band-limited ground truth.
    """

    @staticmethod
    def mel_replace_ops(predict_mel: torch.Tensor,
                        gt_low_pass_mel: torch.Tensor,
                        debug_message: bool = False
                        ) -> torch.Tensor:
        """Overwrite the low mel bins of a prediction with the ground truth.

        For each batch item the cutoff bin is estimated from the ground-truth
        mel (see ``locate_cutoff_freq``). Every bin below it along the LAST
        axis of ``predict_mel`` is then replaced by the ground-truth value.

        NOTE: ``predict_mel`` is modified in place and also returned.

        Args:
            predict_mel: predicted mel, batch-first, with mel bins on the last
                axis. Assumed ``(batch, [1,] frames, n_mels)``; TODO confirm
                with callers.
            gt_low_pass_mel: ground-truth mel of the band-limited input, same
                shape as ``predict_mel``. It is passed through ``torch.exp``
                before the cutoff search, so it is presumably log-magnitude;
                confirm.
            debug_message: print mean/max/min of predict/gt over the bins
                that are about to be replaced.

        Returns:
            ``predict_mel`` (the same tensor object, modified in place).
        """
        for i in range(predict_mel.size(0)):
            gt_item = gt_low_pass_mel[i]
            # Collapse every leading axis into one "frames" axis. The cutoff
            # search then always accumulates energy along the last axis,
            # which is the axis the replacement below slices. A bare
            # squeeze() summed over the wrong axis when an extra
            # non-singleton axis was present, and it crashed on a single
            # frame.
            frames_by_bins = gt_item.reshape(-1, gt_item.shape[-1])
            cutoff_melbin = UtilAudioSR.locate_cutoff_freq(torch.exp(frames_by_bins))

            if debug_message:
                ratio = predict_mel[i][..., :cutoff_melbin] / gt_low_pass_mel[i][..., :cutoff_melbin]
                print(torch.mean(ratio), torch.max(ratio), torch.min(ratio))

            predict_mel[i][..., :cutoff_melbin] = gt_low_pass_mel[i][..., :cutoff_melbin]
        return predict_mel

    @staticmethod
    def locate_cutoff_freq(stft, percentile=0.985):
        """Return the cutoff bin of a ``(frames, bins)`` spectrum.

        Magnitudes (``abs``, so real or complex input both work) are summed
        over axis 0 and accumulated over the bins. The cumulative curve is
        then passed to ``find_cutoff`` with ``percentile``, a fraction of the
        total energy (not a 0-100 percentile).
        """
        magnitude = torch.abs(stft)
        energy = torch.cumsum(torch.sum(magnitude, dim=0), dim=0)
        return UtilAudioSR.find_cutoff(energy, percentile)

    @staticmethod
    def find_cutoff(x, percentile=0.95):
        """Return the largest index ``k >= 1`` with ``x[k] < x[-1] * percentile``.

        ``x`` is a 1-D NumPy array or torch tensor, normally a cumulative
        energy curve. ``percentile`` is a fraction of the last value. Index 0
        is never tested. Returns 0 when no index qualifies, when ``x`` has
        fewer than two elements, or when ``x`` is empty.
        """
        if x.shape[0] == 0:
            return 0
        threshold = x[-1] * percentile
        # Vectorised form of scanning x[-1], x[-2], ..., x[1] and stopping
        # at the first value below the threshold. This avoids one host sync
        # per element for GPU tensors.
        below = x[1:] < threshold
        if torch.is_tensor(below):
            hits = torch.nonzero(below).flatten()
        else:
            hits = np.flatnonzero(below)
        if len(hits) == 0:
            return 0
        # hits are offsets into x[1:], so shift back by one.
        return int(hits[-1]) + 1

    @staticmethod
    def wav_replace_ops(pred_wav: torch.Tensor,
                        gt_low_pass_wav: torch.Tensor
                        ) -> torch.Tensor:
        """Replace the low band of predicted waveforms with the ground-truth low band.

        Each item is processed on the CPU with librosa by
        ``_replace_low_band_np``.

        Args:
            pred_wav: predicted waveforms, reshaped to 2-D by
                ``UtilData.fit_shape_length(..., 2)``. Presumably
                ``(batch, num_samples)``; confirm.
            gt_low_pass_wav: band-limited ground-truth waveforms, in the same
                layout and (assumed) of the same length as ``pred_wav``.

        Returns:
            A new float32 tensor on ``pred_wav.device``, in the 2-D layout
            produced by ``fit_shape_length``. The inputs are not modified.
        """
        device = pred_wav.device
        # detach() first so autograd does not track the copy. float() lets
        # half/bfloat16 tensors convert to NumPy. copy() is needed because,
        # on CPU, Tensor.numpy() shares memory with the tensor; writing rows
        # below could otherwise alter the caller's data.
        pred_np = UtilData.fit_shape_length(pred_wav, 2).detach().float().cpu().numpy().copy()
        gt_np = UtilData.fit_shape_length(gt_low_pass_wav, 2).detach().float().cpu().numpy()
        for i in range(pred_np.shape[0]):
            pred_np[i] = UtilAudioSR._replace_low_band_np(pred_np[i], gt_np[i])
        return torch.from_numpy(pred_np).to(device)

    @staticmethod
    def _replace_low_band_np(out: np.ndarray, x: np.ndarray) -> np.ndarray:
        """Return ``out`` with its STFT bins below ``x``'s cutoff replaced by ``x``'s (1-D waveforms)."""
        stft_gt = librosa.stft(x)
        stft_out = librosa.stft(out)
        # Reuse the ground-truth STFT instead of recomputing it for the cutoff.
        cutoff = UtilAudioSR._cutoff_index_from_magnitude(np.abs(stft_gt))

        # Ratio of gt to prediction magnitude at the cutoff bin. Dividing the
        # pasted gt band by it matches the band to the prediction's level.
        with np.errstate(divide="ignore", invalid="ignore"):
            energy_ratio = np.sum(np.abs(stft_gt[cutoff])) / np.sum(np.abs(stft_out[cutoff]))
        if np.isnan(energy_ratio):
            # 0/0: both are silent at the cutoff bin, so there is no level
            # information. Without this, NaN slipped through min/max and
            # poisoned the output.
            energy_ratio = 1.0
        energy_ratio = min(max(float(energy_ratio), 0.8), 1.2)

        # Assumes x and out have the same length, so the STFT frame counts match.
        stft_out[:cutoff, ...] = stft_gt[:cutoff, ...] / energy_ratio
        return librosa.istft(stft_out, length=out.shape[0])

    @staticmethod
    def get_cutoff_index_np(x):
        """Return the cutoff STFT bin of waveform ``x``.

        Uses librosa's default STFT parameters and 98.5% of the cumulative
        magnitude.
        """
        return UtilAudioSR._cutoff_index_from_magnitude(np.abs(librosa.stft(x)))

    @staticmethod
    def _cutoff_index_from_magnitude(magnitude: np.ndarray, percentile: float = 0.985) -> int:
        """Find the cutoff bin of a librosa-layout ``(..., freq, frames)`` magnitude array."""
        # Reduce over every axis except frequency (axis -2). For a 1-D signal
        # this is the time axis. Multi-channel input was previously flattened
        # by np.cumsum.
        freq_axis = magnitude.ndim - 2
        reduce_axes = tuple(a for a in range(magnitude.ndim) if a != freq_axis)
        energy = np.cumsum(np.sum(magnitude, axis=reduce_axes))
        return UtilAudioSR.find_cutoff(energy, percentile)

    @staticmethod
    def find_cutoff_freq(audio: torch.Tensor, sample_rate: int = 48000) -> float:
        """Estimate the cutoff frequency (Hz) of band-limited content in ``audio``.

        Args:
            audio: waveform, ``(num_samples,)`` or ``(batch, num_samples)``.
                Only the first batch item is analysed. At least 2048 samples
                are needed, because the STFT uses ``center=False``.
            sample_rate: sampling rate of ``audio``. The default of 48000
                reproduces the previously hard-coded 24 kHz Nyquist frequency.

        Returns:
            The estimated cutoff in Hz. Estimates below 1 kHz are treated as
            "no band limit", and the Nyquist frequency ``sample_rate / 2`` is
            returned instead.
        """
        n_fft = 2048
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)
        stft_spec = torch.stft(
            input=audio,
            n_fft=n_fft,
            hop_length=480,
            win_length=n_fft,
            window=torch.hann_window(n_fft, device=audio.device, dtype=audio.dtype),
            center=False,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        # First item, (freq, frames) -> (frames, freq). Take |.| of the
        # complex STFT: the former `.float()` cast silently discarded the
        # imaginary part, so only the real component was measured.
        magnitude = stft_spec[0].abs().T
        cutoff_bin = UtilAudioSR.locate_cutoff_freq(magnitude, percentile=0.983)
        nyquist = sample_rate / 2
        cutoff_freq = cutoff_bin / (n_fft // 2) * nyquist
        if cutoff_freq < 1000:
            cutoff_freq = nyquist
        return cutoff_freq