File size: 6,085 Bytes
5ff96b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
import os
import hydra
import logging

# Module-level logger; main() uses it to record the traceback of any failure.
logger = logging.getLogger(__name__)

def _chunk_layout(length_data, segment_size, overlapsize):
    """Plan the overlapping chunks used to process ``length_data`` samples.

    Consecutive chunks start ``segment_size - overlapsize`` samples apart. A chunk
    is marked as the last one as soon as it reaches the end of the signal, so every
    sample is covered by at least one chunk. At least one chunk is always returned,
    so an empty signal yields ``[(0, True)]``.

    Returns:
        list of ``(start, is_last)`` tuples in processing order.

    Raises:
        ValueError: if ``overlapsize >= segment_size`` (the chunks would never advance).
    """
    hop = segment_size - overlapsize
    if hop <= 0:
        raise ValueError("segment_size must be larger than overlapsize")
    layout = []
    start = 0
    while start + segment_size < length_data:
        layout.append((start, False))
        start += hop
    layout.append((start, True))
    return layout


def _fade_edges(signal, window_left, window_right, fade_in, fade_out):
    """Return a float64 copy of ``signal`` with optional tapers at its edges.

    If ``fade_in`` is true, the first ``len(window_left)`` samples are multiplied by
    ``window_left``. If ``fade_out`` is true, the last ``len(window_right)`` samples
    are multiplied by ``window_right``. The input is not modified.
    """
    import numpy as np

    out = np.array(signal, dtype=np.float64)
    if fade_in:
        out[:len(window_left)] *= window_left
    if fade_out:
        out[len(out) - len(window_right):] *= window_right
    return out


def run(args):
    """Denoise one audio file with a pretrained U-Net and write the results beside it.

    Processing steps:

    1. Read ``args.inference.audio``.
    2. Average the channels down to mono.
    3. Resample to 44.1 kHz if needed.
    4. Process the signal in overlapping 5 s chunks:

       - transform each chunk with an STFT (Hamming window,
         ``args.stft.win_size`` / ``args.stft.hop_size``);
       - pass it through the model built from ``args.unet``, with weights loaded
         from ``<this file's dir>/<args.path_experiment>/checkpoint``;
       - invert it back to the time domain.

    5. Cross-fade the chunks with the two halves of a Hann window and overlap-add them.

    Three 44.1 kHz files are written next to the input, named after its base name:

    - ``*_noisy_input.wav``: the mono, resampled input;
    - ``*_denoised.wav``: the model output;
    - ``*_residual.wav``: the input STFT minus the model output, inverted.
    """
    import unet
    import tensorflow as tf
    import soundfile as sf
    import numpy as np
    from tqdm import tqdm
    import scipy.signal

    fs = 44100  # rate the model runs at and the outputs are written with

    path_experiment = str(args.path_experiment)

    unet_model = unet.build_model_denoise(unet_args=args.unet)

    ckpt = os.path.join(os.path.dirname(os.path.abspath(__file__)), path_experiment, 'checkpoint')
    unet_model.load_weights(ckpt)

    win_size = args.stft.win_size
    hop_size = args.stft.hop_size
    window_fn = tf.signal.hamming_window

    def do_stft(noisy):
        """STFT of a 1-D signal with real and imaginary parts stacked on a new last axis."""
        stft_signal_noisy = tf.signal.stft(noisy, frame_length=win_size, window_fn=window_fn,
                                           frame_step=hop_size, pad_end=True)
        return tf.stack(values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)

    def do_istft(data):
        """Inverse of ``do_stft``: rebuild the complex STFT from the last axis and invert it."""
        inv_window_fn = tf.signal.inverse_stft_window_fn(hop_size, forward_window_fn=window_fn)
        pred_cpx = data[..., 0] + 1j * data[..., 1]
        return tf.signal.inverse_stft(pred_cpx, win_size, hop_size, window_fn=inv_window_fn)

    audio = str(args.inference.audio)
    data, samplerate = sf.read(audio)
    logger.debug("Input dtype: %s", data.dtype)
    # Stereo to mono
    if data.ndim > 1:
        data = np.mean(data, axis=1)

    if samplerate != fs:
        print("Resampling")
        # NOTE(review): the "+1" yields one sample more than the truncated ratio; kept as-is.
        data = scipy.signal.resample(data, int((fs / samplerate) * len(data)) + 1)

    segment_size = fs * 5  # 5 s chunks
    overlapsize = 2048  # samples (~46 ms at 44.1 kHz) cross-faded between chunks
    window = np.hanning(2 * overlapsize)
    window_left = window[:overlapsize]  # rising half: fade-in
    window_right = window[overlapsize:]  # falling half: fade-out

    length_data = len(data)
    denoised_data = np.zeros(shape=(length_data,))
    residual_noise = np.zeros(shape=(length_data,))

    # The chunk plan accounts for the overlap. A plain ceil(length / segment_size)
    # iteration count stops short of the end and leaves the tail of the output silent.
    for pointer, is_last in tqdm(_chunk_layout(length_data, segment_size, overlapsize)):
        segment = data[pointer:pointer + segment_size]
        lensegment = len(segment)
        if lensegment < segment_size:
            # Zero-pad the final short chunk so the model always gets a full-length input.
            segment = np.concatenate((segment, np.zeros(shape=(segment_size - lensegment,))), axis=0)

        segment_TF = do_stft(segment)
        segment_TF_ds = tf.data.Dataset.from_tensors(segment_TF)
        pred = unet_model.predict(segment_TF_ds.batch(1))
        # NOTE(review): assumes predict() returns a list of outputs whose first entry
        # is the (batch of one) denoised spectrogram; confirm against unet.py.
        pred = pred[0]
        residual = np.array(segment_TF - pred[0])

        # Every chunk but the first fades in and every chunk but the last fades out,
        # so the overlapping regions are cross-faded when summed.
        fade_in = pointer > 0
        fade_out = not is_last
        pred_time = _fade_edges(np.asarray(do_istft(pred[0]))[:segment_size],
                                window_left, window_right, fade_in, fade_out)
        residual_time = _fade_edges(np.asarray(do_istft(residual))[:segment_size],
                                    window_left, window_right, fade_in, fade_out)

        # Drop the zero-padded part of the final chunk when overlap-adding.
        denoised_data[pointer:pointer + lensegment] += pred_time[:lensegment]
        residual_noise[pointer:pointer + lensegment] += residual_time[:lensegment]

    basename = os.path.splitext(audio)[0]
    wav_noisy_name = basename + "_noisy_input" + ".wav"
    sf.write(wav_noisy_name, data, fs)
    wav_output_name = basename + "_denoised" + ".wav"
    sf.write(wav_output_name, denoised_data, fs)
    wav_output_name = basename + "_residual" + ".wav"
    sf.write(wav_output_name, residual_noise, fs)
    

def _main(args):
    """Make this module's ``__file__`` absolute, then run inference with ``args``.

    ``run`` locates the checkpoint relative to ``os.path.dirname(__file__)``. The path
    is therefore resolved with ``hydra.utils.to_absolute_path``, which interprets
    relative paths against the original working directory (Hydra may change the
    current one).
    """
    global __file__
    resolved_path = hydra.utils.to_absolute_path(__file__)
    __file__ = resolved_path
    run(args)


@hydra.main(config_path=".", config_name="conf")
def main(args):
    """Hydra entry point: run inference and hard-exit with status 1 on any failure.

    ``args`` is the Hydra config composed from ``conf`` in this file's directory
    (presumably ``conf.yaml``; confirm).
    """
    failed = False
    try:
        _main(args)
    except Exception:
        # Record the full traceback before terminating.
        logger.exception("Some error happened")
        failed = True
    if failed:
        # Hydra intercepts the normal exit code (fixed in a beta the author could not
        # get working), so terminate the process directly with a non-zero status.
        os._exit(1)


# Script entry point. main() is called without arguments; the hydra.main decorator
# builds its config argument.
if __name__ == "__main__":
    main()