# Source: Hugging Face repository by ChristophSchuhmann
# Commit: "Add model code, inference script, and examples" (dfd1909, verified)
from torch import Tensor
from numpy import ndarray
import librosa
import soundfile as sf
import torch
import torch.nn as nn
class ConvReverb(nn.Module):
    """Convolution reverb: FFT-based linear convolution of a dry signal
    with an impulse response (IR).

    Accepts mono input of shape (batch, time) or stereo input of shape
    (batch, 2, time); the IR tensor must have the same rank, and in the
    stereo case each channel is convolved with its own IR channel.
    """
    def __init__(self):
        super(ConvReverb,self).__init__()

    def conv_reverb_by_one_ir(
        self,
        input_signal:Tensor, #(batch,sampletime_length)
        input_ir:Tensor #(batch,sampletime_length)
    ) -> Tensor:
        """Full linear convolution of each batch signal with its IR.

        Returns a tensor of shape (batch, signal_len + ir_len - 1).
        """
        signal_len = input_signal.shape[-1]
        ir_len = input_ir.shape[-1]
        # Zero-pad both operands to the full-convolution length so that the
        # circular convolution implied by the FFT product equals linear
        # convolution (no wrap-around aliasing).
        full_len = signal_len + ir_len - 1
        zero_padded_input_signal = nn.functional.pad(input_signal, (0, ir_len - 1))
        input_signal_fft = torch.fft.rfft(zero_padded_input_signal, dim=1)
        zero_pad_final_fir = nn.functional.pad(input_ir, (0, signal_len - 1))
        fir_fft = torch.fft.rfft(zero_pad_final_fir, dim=1)
        # Frequency-domain multiplication == time-domain convolution.
        output_signal_fft:Tensor = fir_fft * input_signal_fft
        # Pass n explicitly: without it, irfft returns 2*(n_bins - 1) samples,
        # which silently drops the last output sample whenever full_len is odd.
        output_signal = torch.fft.irfft(output_signal_fft, n=full_len, dim=1)
        return output_signal

    def forward(
        self,
        input_signal:Tensor, #(batch,sampletime_length) or (batch,2,sampletime_length)
        input_ir:Tensor #same rank as input_signal
    ) -> Tensor:
        """Apply the reverb; stereo channels are processed independently
        (left signal with left IR, right signal with right IR)."""
        assert ((len(input_signal.shape) == 2) or (len(input_signal.shape) == 3)), "input shape is wrong"
        if len(input_signal.shape) == 2:
            return self.conv_reverb_by_one_ir(input_signal,input_ir)
        else:
            left_reverb_audio:Tensor = self.conv_reverb_by_one_ir(input_signal[:,0,:],input_ir[:,0,:]).unsqueeze(1)
            right_reverb_audio:Tensor = self.conv_reverb_by_one_ir(input_signal[:,1,:],input_ir[:,1,:]).unsqueeze(1)
            # dim= is the canonical torch keyword (axis= is only an alias).
            reverb_audio:Tensor = torch.cat([left_reverb_audio,right_reverb_audio],dim=1)
            return reverb_audio
if __name__ == "__main__":
    # Demo: convolve a (stereo) vocal recording with a room impulse response
    # and write the reverberated result to disk. Paths may be overridden on
    # the command line; the defaults preserve the original hard-coded setup.
    import argparse

    parser = argparse.ArgumentParser(description="Apply convolution reverb to an audio file.")
    parser.add_argument(
        "--audio",
        default="/home/jakeoneijk/220101_data/MusDBMainVocal/train/A Classic Education - NightOwl/A Classic Education - NightOwl_Main Vocal.wav",
        help="path to the dry input audio file",
    )
    parser.add_argument(
        "--ir",
        default="/home/jakeoneijk/220101_data/DetmoldSRIRStereo/SetB_LSandWFSOrchestra/Data/OpenArray/wfs_R1/S1.wav",
        help="path to the impulse-response audio file",
    )
    parser.add_argument(
        "--output",
        default="./reverb_audio.wav",
        help="path of the reverberated output file",
    )
    args = parser.parse_args()

    # sr=None keeps the audio's native sample rate; the IR is then loaded at
    # that same rate so the two are consistent. mono=False keeps stereo files
    # as (channels, time).
    vocal_audio, sr = librosa.load(args.audio, sr=None, mono=False)
    ir_audio, sr = librosa.load(args.ir, sr=sr, mono=False)
    # unsqueeze(0) adds the batch dimension ConvReverb expects.
    vocal_tensor: Tensor = torch.from_numpy(vocal_audio).unsqueeze(0)
    ir_tensor: Tensor = torch.from_numpy(ir_audio).unsqueeze(0)
    conv_reverb = ConvReverb()
    reverberated_audio: Tensor = conv_reverb(vocal_tensor, ir_tensor)
    # Drop the batch dim, then transpose (channels, time) -> (time, channels)
    # as soundfile expects.
    reverberated_audio_numpy: ndarray = reverberated_audio.squeeze().numpy()
    sf.write(args.output, data=reverberated_audio_numpy.T, samplerate=sr)