File size: 2,564 Bytes
dfd1909
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from torch import Tensor
from numpy import ndarray

import librosa
import soundfile as sf
import torch
import torch.nn as nn

class ConvReverb(nn.Module):
    """Convolution reverb: convolve audio with an impulse response via FFT.

    The module has no learnable parameters. The convolution is a full linear
    convolution, so the output has ``signal_length + ir_length - 1`` samples
    along the time (last) axis.
    """

    def __init__(self):
        super().__init__()

    def conv_reverb_by_one_ir(
        self,
        input_signal: Tensor,  # (batch, sampletime_length)
        input_ir: Tensor,  # (batch, sampletime_length)
    ) -> Tensor:
        """Linearly convolve ``input_signal`` with ``input_ir`` along the last axis.

        Args:
            input_signal: Dry signal, time on the last axis.
            input_ir: Impulse response, time on the last axis. Its leading
                dimensions must match (or broadcast with) those of
                ``input_signal``.

        Returns:
            The convolved signal with
            ``input_signal.shape[-1] + input_ir.shape[-1] - 1`` samples on the
            last axis.
        """
        # Length of the full linear convolution. Both inputs are zero-padded
        # to this length so the circular convolution computed by the FFT
        # product equals the linear convolution.
        output_length = input_signal.shape[-1] + input_ir.shape[-1] - 1

        # `n=` zero-pads the input up to `output_length` before transforming.
        input_signal_fft = torch.fft.rfft(input_signal, n=output_length, dim=-1)
        fir_fft = torch.fft.rfft(input_ir, n=output_length, dim=-1)

        output_signal_fft: Tensor = fir_fft * input_signal_fft

        # The one-sided spectrum does not encode whether the original length
        # was odd or even; without `n=` irfft assumes an even length, which
        # drops a sample and corrupts the values when `output_length` is odd.
        output_signal = torch.fft.irfft(output_signal_fft, n=output_length, dim=-1)

        return output_signal

    def forward(
        self,
        input_signal: Tensor,  # (batch, sampletime_length) or (batch, channel, sampletime_length)
        input_ir: Tensor,  # same layout as input_signal
    ) -> Tensor:
        """Apply the impulse response to the signal.

        Args:
            input_signal: Either ``(batch, time)`` or ``(batch, channel, time)``.
            input_ir: Impulse response in the same layout as ``input_signal``.
                For 3-D input, channel ``c`` of the signal is convolved with
                channel ``c`` of the impulse response.

        Returns:
            A tensor with the same leading dimensions as ``input_signal`` and
            ``signal_length + ir_length - 1`` samples on the time axis.

        Raises:
            AssertionError: If ``input_signal`` is neither 2-D nor 3-D.
        """
        assert input_signal.dim() in (2, 3), "input shape is wrong"

        if input_signal.dim() == 2:
            return self.conv_reverb_by_one_ir(input_signal, input_ir)

        # Convolve each channel with its matching impulse-response channel and
        # re-assemble along the channel axis (works for mono, stereo, or more).
        per_channel_audio = [
            self.conv_reverb_by_one_ir(input_signal[:, channel, :], input_ir[:, channel, :])
            for channel in range(input_signal.shape[1])
        ]
        reverb_audio: Tensor = torch.stack(per_channel_audio, dim=1)
        return reverb_audio

if __name__ == "__main__":
    # Demo: reverberate a vocal recording with a room impulse response and
    # write the result to a WAV file in the current directory.
    vocal_path: str = "/home/jakeoneijk/220101_data/MusDBMainVocal/train/A Classic Education - NightOwl/A Classic Education - NightOwl_Main Vocal.wav"
    impulse_response_path: str = "/home/jakeoneijk/220101_data/DetmoldSRIRStereo/SetB_LSandWFSOrchestra/Data/OpenArray/wfs_R1/S1.wav"

    # sr=None keeps the vocal file's own sample rate; the impulse response is
    # then loaded at that same rate so both signals share one time base.
    dry_samples, sample_rate = librosa.load(vocal_path, sr=None, mono=False)
    ir_samples, sample_rate = librosa.load(impulse_response_path, sr=sample_rate, mono=False)

    # Prepend a batch axis of size one to each array.
    dry_batch: Tensor = torch.unsqueeze(torch.from_numpy(dry_samples), 0)
    ir_batch: Tensor = torch.unsqueeze(torch.from_numpy(ir_samples), 0)

    reverb_module = ConvReverb()
    wet_batch: Tensor = reverb_module(dry_batch, ir_batch)

    # Drop size-one axes, then transpose before writing.
    # presumably this yields (frames, channels) for multi-channel audio, as soundfile expects — verify.
    wet_samples: ndarray = torch.squeeze(wet_batch).numpy()
    sf.write("./reverb_audio.wav", data=wet_samples.T, samplerate=sample_rate)