from torch import Tensor
from numpy import ndarray
import librosa
import soundfile as sf
import torch
import torch.nn as nn
class ConvReverb(nn.Module):
    """FFT-based convolution reverb.

    Linearly convolves each signal in a batch with a (per-item) impulse
    response, either mono ``(batch, time)`` or stereo ``(batch, 2, time)``.
    """

    def __init__(self):
        super().__init__()

    def conv_reverb_by_one_ir(
        self,
        input_signal: Tensor,  # (batch, time)
        input_ir: Tensor,      # (batch, time)
    ) -> Tensor:
        """Linear convolution of ``input_signal`` with ``input_ir`` via rFFT.

        Returns a tensor of shape ``(batch, signal_len + ir_len - 1)``.
        """
        # Full linear-convolution length.
        out_length = input_signal.shape[-1] + input_ir.shape[-1] - 1
        # Zero-pad both operands to out_length so the circular convolution
        # implied by pointwise FFT multiplication equals linear convolution.
        padded_signal = nn.functional.pad(input_signal, (0, input_ir.shape[-1] - 1))
        padded_ir = nn.functional.pad(input_ir, (0, input_signal.shape[-1] - 1))
        signal_fft = torch.fft.rfft(padded_signal, dim=1)
        ir_fft = torch.fft.rfft(padded_ir, dim=1)
        output_fft: Tensor = ir_fft * signal_fft
        # BUG FIX: n=out_length is required — irfft's default output length is
        # 2*(bins-1), which drops a sample (and corrupts every value) whenever
        # out_length is odd.
        return torch.fft.irfft(output_fft, n=out_length, dim=1)

    def forward(
        self,
        input_signal: Tensor,  # (batch, time) or (batch, 2, time)
        input_ir: Tensor,      # same rank as input_signal
    ) -> Tensor:
        """Apply the reverb; stereo inputs are convolved channel-by-channel."""
        assert ((len(input_signal.shape) == 2) or (len(input_signal.shape) == 3)), "input shape is wrong"
        if len(input_signal.shape) == 2:
            return self.conv_reverb_by_one_ir(input_signal, input_ir)
        left_reverb_audio: Tensor = self.conv_reverb_by_one_ir(
            input_signal[:, 0, :], input_ir[:, 0, :]
        ).unsqueeze(1)
        right_reverb_audio: Tensor = self.conv_reverb_by_one_ir(
            input_signal[:, 1, :], input_ir[:, 1, :]
        ).unsqueeze(1)
        return torch.cat([left_reverb_audio, right_reverb_audio], dim=1)
if __name__ == "__main__":
    # Demo: reverberate a stereo vocal stem with a measured stereo room IR.
    vocal_path: str = "/home/jakeoneijk/220101_data/MusDBMainVocal/train/A Classic Education - NightOwl/A Classic Education - NightOwl_Main Vocal.wav"
    ir_path: str = "/home/jakeoneijk/220101_data/DetmoldSRIRStereo/SetB_LSandWFSOrchestra/Data/OpenArray/wfs_R1/S1.wav"

    # mono=False keeps the channel axis; the IR is resampled to the vocal's sr.
    vocal_np, sr = librosa.load(vocal_path, sr=None, mono=False)
    ir_np, sr = librosa.load(ir_path, sr=sr, mono=False)

    # Prepend a batch axis -> (1, channels, time), as ConvReverb.forward expects.
    vocal: Tensor = torch.from_numpy(vocal_np).unsqueeze(0)
    ir: Tensor = torch.from_numpy(ir_np).unsqueeze(0)

    wet: Tensor = ConvReverb()(vocal, ir)
    wet_np: ndarray = wet.squeeze().numpy()
    # soundfile expects (samples, channels), hence the transpose.
    sf.write("./reverb_audio.wav", data=wet_np.T, samplerate=sr)