| |
| from typing import Union,Dict |
| from numpy import ndarray |
| from torch import Tensor |
|
|
| import matplotlib.pyplot as plt |
| import numpy as np |
| import torch |
| import librosa |
| import librosa.display |
|
|
| from TorchJaekwon.Util.UtilAudio import UtilAudio |
|
|
class UtilAudioSTFT(UtilAudio):
    """Short-time Fourier transform helpers backed by librosa (NumPy) and torch.

    The NumPy helpers work on arrays via ``librosa.stft`` / ``librosa.istft``;
    the torch helpers work on tensors via ``torch.stft`` / ``torch.istft`` with
    a Hann window of length ``nfft``.
    """

    def __init__(self,nfft:int, hop_size:int):
        """Store the FFT size and hop size and build the Hann analysis window.

        Args:
            nfft: FFT size; also used as the Hann window length.
            hop_size: Number of samples between successive frames.
        """
        super().__init__()
        self.nfft = nfft
        self.hop_size = hop_size
        # Built once on the default device; moved to the input's device on use.
        self.hann_window = torch.hann_window(self.nfft)

    def get_mag_phase_stft_np(self, audio):
        """Compute the librosa STFT of ``audio`` and split it into two parts.

        Returns:
            A dict with ``"mag"`` (the absolute value of the STFT) and
            ``"phase"`` (the unit-modulus complex factor ``exp(1j * angle)``),
            so that ``mag * phase`` reproduces the complex STFT.
        """
        stft = librosa.stft(audio, n_fft=self.nfft, hop_length=self.hop_size)
        return {"mag": np.abs(stft), "phase": np.exp(1.j * np.angle(stft))}

    def get_mag_phase_stft_np_mono(self, audio):
        """Like ``get_mag_phase_stft_np`` but averages a two-channel input first.

        An input whose first axis has length 2 is averaged over that axis
        before the STFT. Any other input is forwarded unchanged.
        """
        # The ndim guard keeps a 1-D signal that happens to hold exactly two
        # samples from being mistaken for a stereo pair and collapsed to a scalar.
        if audio.ndim > 1 and audio.shape[0] == 2:
            audio = np.mean(audio, axis=0)
        return self.get_mag_phase_stft_np(audio)

    def stft_torch(self,
                   audio:Union[ndarray,Tensor]
                   ) -> Dict[str,Tensor]:
        """Compute a one-sided complex STFT with torch.

        The signal is reflect-padded by ``int((nfft - hop_size) / 2)`` samples
        on each side and then passed to ``torch.stft`` with ``center=False``.
        Reflect padding requires the pad amount to be smaller than the signal
        length.

        Args:
            audio: NumPy array or tensor with 1, 2 or 3 dimensions. A 1-D input
                gets a leading dimension added. A 3-D input is unpacked as
                ``(batch, channels, samples)``.

        Returns:
            A dict with keys ``'stft'`` (complex), ``'mag'`` (``abs``) and
            ``'angle'`` (``angle``). For 1-D/2-D input each entry has the
            three dimensions returned by ``torch.stft`` (batch, frequency
            bins, frames); for 3-D input each entry is reshaped to
            ``(batch, channels, frequency bins, frames)``.
        """
        # isinstance (rather than an exact type match) also converts ndarray
        # subclasses such as np.memmap.
        audio_torch:Tensor = torch.from_numpy(audio) if isinstance(audio, np.ndarray) else audio

        assert audio_torch.dim() <= 3, f'Error: stft_torch() audio torch shape is {audio_torch.shape}'

        if audio_torch.dim() == 1:
            audio_torch = audio_torch.unsqueeze(0)

        # torch.stft only accepts 1-D/2-D input, so (batch, channels, samples)
        # is flattened to (batch * channels, samples) and restored afterwards.
        is_batched_multichannel = audio_torch.dim() == 3
        if is_batched_multichannel:
            batch_size, channels_num, segment_samples = audio_torch.shape
            audio_torch = audio_torch.reshape(batch_size * channels_num, segment_samples)

        # Reflect padding on a 2-D tensor needs a channel axis, hence the
        # temporary unsqueeze/squeeze around the pad call.
        pad = int((self.nfft - self.hop_size) / 2)
        audio_torch = torch.nn.functional.pad(audio_torch.unsqueeze(1), (pad, pad), mode='reflect').squeeze(1)

        stft = torch.stft(audio_torch,
                          self.nfft,
                          hop_length=self.hop_size,
                          window=self.hann_window.to(audio_torch.device),
                          center=False,
                          pad_mode='reflect',
                          normalized=False,
                          onesided=True,
                          return_complex=True)

        spec_dict:Dict[str,Tensor] = {
            'stft': stft,
            'mag': stft.abs(),
            'angle': stft.angle(),
        }

        if is_batched_multichannel:
            # torch.stft lays its output out as (batch, frequency bins, frames).
            _, freq_bins, frames = stft.shape
            spec_dict = {
                feature_name: feature.reshape(batch_size, channels_num, freq_bins, frames)
                for feature_name, feature in spec_dict.items()
            }

        return spec_dict

    def istft_torch_from_mag_and_angle(self,
                                       mag:Tensor,
                                       angel:Tensor):
        """Rebuild a waveform from a magnitude and a phase-angle spectrogram.

        Args:
            mag: Magnitude spectrogram whose last two dimensions are
                (frequency bins, frames).
            angel: Phase angle in radians, same shape as ``mag``. The spelling
                of the parameter name is kept for backward compatibility.

        Returns:
            The output of ``torch.istft`` (called with ``center=True``). Inputs
            with more than one leading dimension, e.g. the
            ``(batch, channels, freq, frames)`` tensors that ``stft_torch``
            returns for 3-D audio, are flattened for the call and the leading
            dimensions are restored on the result.
        """
        stft_complex:Tensor = torch.polar(abs = mag, angle = angel)

        # torch.istft accepts at most one batch dimension.
        leading_shape = tuple(stft_complex.shape[:-2])
        if len(leading_shape) > 1:
            stft_complex = stft_complex.reshape(-1, *stft_complex.shape[-2:])

        audio = torch.istft(stft_complex,
                            self.nfft,
                            hop_length=self.hop_size,
                            window=self.hann_window.to(stft_complex.device),
                            center=True,
                            onesided=True)

        if len(leading_shape) > 1:
            audio = audio.reshape(*leading_shape, audio.shape[-1])
        return audio

    def get_pred_accom_by_subtract_pred_vocal_audio(self, pred_vocal, mix_audio):
        """Estimate the accompaniment by magnitude-domain spectral subtraction.

        The magnitude of ``pred_vocal`` is subtracted from the magnitude of
        ``mix_audio``, negative values are clamped to zero, and the result is
        inverted with the phase of the mixture. The output length is
        ``mix_audio.shape[-1]``.
        """
        pred_vocal_mag = self.get_mag_phase_stft_np_mono(pred_vocal)["mag"]
        mix_stft = self.get_mag_phase_stft_np_mono(mix_audio)

        # A magnitude cannot be negative; clamp where the vocal estimate
        # exceeds the mixture.
        pred_accom_mag = np.maximum(mix_stft["mag"] - pred_vocal_mag, 0)

        return librosa.istft(pred_accom_mag * mix_stft["phase"],
                             hop_length=self.hop_size,
                             length=mix_audio.shape[-1])

    def stft_plot_from_audio_path(self,audio_path:str,save_path:Union[str,None] = None, dpi:int = 500) -> None:
        """Plot the dB-scaled magnitude spectrogram of an audio file.

        The file is loaded with ``librosa.load`` and transformed with
        ``librosa.stft`` using librosa's default arguments; ``self.nfft`` and
        ``self.hop_size`` are not used here.

        Args:
            audio_path: Path of the audio file to load.
            save_path: If given, the figure is written there and then closed.
                If ``None`` the figure stays open as the current pyplot figure.
            dpi: Resolution used for the figure and for the saved image.
        """
        audio, _ = librosa.load(audio_path)
        stft_audio:ndarray = librosa.stft(audio)
        spectrogram_db_scale:ndarray = librosa.amplitude_to_db(np.abs(stft_audio), ref=np.max)

        fig = plt.figure(dpi=dpi)
        librosa.display.specshow(spectrogram_db_scale)
        plt.colorbar()
        if save_path is not None:
            plt.savefig(save_path,dpi=dpi)
            # Once written to disk the figure is released, so repeated calls
            # do not accumulate open high-dpi figures.
            plt.close(fig)

    @staticmethod
    def spec_to_figure(spec,
                       vmin:float = -6.0,
                       vmax:float = 1.5,
                       fig_size:tuple = (12,6),
                       dpi = 400,
                       transposed=False,
                       save_path=None):
        """Render a 2-D spectrogram with ``plt.pcolor`` and return the figure.

        Args:
            spec: Tensor or array; singleton dimensions are squeezed away.
            vmin: Lower bound of the colour scale.
            vmax: Upper bound of the colour scale.
            fig_size: Figure size passed to ``plt.figure``.
            dpi: Resolution used for the figure and for the saved image.
            transposed: If true, ``spec.T`` is plotted instead of ``spec``.
            save_path: If given, the figure is also written to this path.

        Returns:
            The figure, already closed in pyplot's figure registry.
        """
        if isinstance(spec, torch.Tensor):
            # detach() is required for tensors that are part of an autograd
            # graph; Tensor.numpy() refuses those.
            spec = spec.detach().cpu().numpy()
        spec = np.squeeze(spec)

        fig = plt.figure(figsize=fig_size, dpi = dpi)
        plt.pcolor(spec.T if transposed else spec, vmin=vmin, vmax=vmax)
        if save_path is not None:
            plt.savefig(save_path,dpi=dpi)
        # Close this specific figure rather than whichever one is current.
        plt.close(fig)
        return fig