| from torch import Tensor |
| from numpy import ndarray |
|
|
| import librosa |
| import soundfile as sf |
| import torch |
| import torch.nn as nn |
|
|
class ConvReverb(nn.Module):
    """Convolution reverb via FFT-based linear convolution.

    Accepts batched mono signals of shape (batch, samples) or batched
    stereo signals of shape (batch, 2, samples); the impulse response
    tensor must have the same layout.
    """

    def __init__(self):
        super(ConvReverb, self).__init__()

    def conv_reverb_by_one_ir(
        self,
        input_signal: Tensor,
        input_ir: Tensor
    ) -> Tensor:
        """Linearly convolve each signal in the batch with its impulse response.

        Args:
            input_signal: dry signal, shape (batch, num_samples).
            input_ir: impulse response, shape (batch, ir_length).

        Returns:
            Wet signal of shape (batch, num_samples + ir_length - 1).
        """
        # Full linear-convolution length. Padding both operands to this
        # length makes the frequency-domain (circular) convolution equal
        # to linear convolution.
        conv_length = input_signal.shape[-1] + input_ir.shape[-1] - 1

        zero_padded_input_signal = nn.functional.pad(input_signal, (0, input_ir.shape[-1] - 1))
        input_signal_fft = torch.fft.rfft(zero_padded_input_signal, dim=-1)

        zero_pad_final_fir = nn.functional.pad(input_ir, (0, input_signal.shape[-1] - 1))
        fir_fft = torch.fft.rfft(zero_pad_final_fir, dim=-1)

        # Multiplication in the frequency domain == convolution in time.
        output_signal_fft: Tensor = fir_fft * input_signal_fft

        # BUG FIX: pass n explicitly. irfft's default output length is
        # 2 * (fft_bins - 1), which silently drops the last sample
        # whenever conv_length is odd (the typical case).
        output_signal = torch.fft.irfft(output_signal_fft, n=conv_length, dim=-1)

        return output_signal

    def forward(
        self,
        input_signal: Tensor,
        input_ir: Tensor
    ) -> Tensor:
        """Apply the reverb to a mono (2-D) or stereo (3-D) batch.

        Args:
            input_signal: (batch, samples) or (batch, 2, samples).
            input_ir: impulse response with the same layout as input_signal.

        Returns:
            Convolved signal; stereo input yields (batch, 2, samples + ir_len - 1).
        """
        assert ((len(input_signal.shape) == 2) or (len(input_signal.shape) == 3)), "input shape is wrong"
        if len(input_signal.shape) == 2:
            return self.conv_reverb_by_one_ir(input_signal, input_ir)
        # Stereo: convolve each channel with its matching IR channel
        # independently, then re-stack along the channel axis.
        left_reverb_audio: Tensor = self.conv_reverb_by_one_ir(input_signal[:, 0, :], input_ir[:, 0, :]).unsqueeze(1)
        right_reverb_audio: Tensor = self.conv_reverb_by_one_ir(input_signal[:, 1, :], input_ir[:, 1, :]).unsqueeze(1)
        reverb_audio: Tensor = torch.cat([left_reverb_audio, right_reverb_audio], dim=1)
        return reverb_audio
|
|
if __name__ == "__main__":
    # Demo: wet a dry vocal recording with a measured stereo room impulse response.
    vocal_audio_dir: str = "/home/jakeoneijk/220101_data/MusDBMainVocal/train/A Classic Education - NightOwl/A Classic Education - NightOwl_Main Vocal.wav"
    ir_dir: str = "/home/jakeoneijk/220101_data/DetmoldSRIRStereo/SetB_LSandWFSOrchestra/Data/OpenArray/wfs_R1/S1.wav"

    # Load the vocal at its native rate, then the IR resampled to match;
    # mono=False keeps the channel axis.
    dry_signal, sample_rate = librosa.load(vocal_audio_dir, sr=None, mono=False)
    impulse_response, sample_rate = librosa.load(ir_dir, sr=sample_rate, mono=False)

    # Prepend a batch axis: (channels, samples) -> (1, channels, samples).
    dry_tensor: Tensor = torch.from_numpy(dry_signal).unsqueeze(0)
    impulse_tensor: Tensor = torch.from_numpy(impulse_response).unsqueeze(0)

    reverb_module = ConvReverb()
    wet_tensor: Tensor = reverb_module(dry_tensor, impulse_tensor)
    wet_numpy: ndarray = wet_tensor.squeeze().numpy()
    # soundfile expects (samples, channels), hence the transpose.
    sf.write("./reverb_audio.wav", data=wet_numpy.T, samplerate=sample_rate)