cwitkowitz committed on
Commit 5ff96b2 · 1 Parent(s): 7f1748e

initial commit
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ model/checkpoint.data-00000-of-00001 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
+ *__pycache__
+ _outputs
+ src
app.py ADDED
@@ -0,0 +1,50 @@
+ from pyharp import *
+
+ import gradio as gr
+ import os
+
+
+ model_card = ModelCard(
+     name='Denoising U-Net',
+     description='A two-stage U-Net for high-fidelity denoising of historical gramophone recordings.',
+     author='Eloi Moliner and Vesa Välimäki',
+     tags=['Music', 'Denoising', 'U-Net', 'High-Fidelity', 'Historical']
+ )
+
+
+ def process_fn(input_audio_path):
+     """
+     This function defines the audio processing steps.
+
+     Args:
+         input_audio_path (str): the filepath of the audio to be processed.
+
+         <YOUR_KWARGS>: additional keyword arguments necessary for processing.
+             NOTE: These should correspond to, and match the order of, the UI elements defined below.
+
+     Returns:
+         output_audio_path (str): the filepath of the processed audio.
+         output_labels (LabelList): any labels to display.
+     """
+     # Shell out to the Hydra-based inference script
+     os.system("python inference.py inference.audio=" + input_audio_path)
+
+     # inference.py writes the denoised file next to the input
+     output_audio_path = os.path.splitext(input_audio_path)[0] + "_denoised.wav"
+
+     # No output labels
+     output_labels = LabelList()
+
+     return output_audio_path, output_labels
+
+
+ # Build the Gradio endpoint
+ with gr.Blocks() as demo:
+     # Define Gradio components (none needed beyond the audio input)
+     components = []
+
+     app = build_endpoint(model_card=model_card,
+                          components=components,
+                          process_fn=process_fn)
+
+ demo.queue()
+ demo.launch(share=True, show_error=True)
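For a quick local check outside of HARP, the same processing chain that process_fn wraps can be run directly. This is a hedged sketch: the input path is hypothetical, and it assumes the LFS checkpoint has already been pulled into model/.

    # Hypothetical one-off run of the chain wrapped by process_fn
    import os

    input_audio_path = "samples/gramophone_78rpm.wav"  # illustrative path
    os.system("python inference.py inference.audio=" + input_audio_path)
    print(os.path.splitext(input_audio_path)[0] + "_denoised.wav")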
conf.yaml ADDED
@@ -0,0 +1,86 @@
+ path_experiment: "model"  # directory containing the checkpoint (there should be a better way to do this)
+ tensorboard_logs: "/scratch/work/molinee2/tensorboard_logs/unet_historical"  # path for tensorboard logs
+ # Dataset related
+ fs: 44100  # sample rate; the default is 44100, better NEVER change it
+ seg_len_s_train: 5  # length of the train (and val) segments in seconds
+ freq_inference: 10  # run inference every this many epochs
+ seg_len_s_test: 15  # length of the test segments in seconds
+ num_real_test_segments: 5  # number of real recordings inferenced every epoch
+ num_test_segments: 15  # number of test segments (inferenced every epoch)
+ buffer_size: 1000  # buffer size for shuffling datasets (train and val)
+ # Dataset augmentation
+ overlap: 0  # overlap when extracting audio segments; the default is 0, increase it if more data is needed
+
+ # Logging and printing; does not impact training
+ # device: cuda
+ verbose: 0
+ use_tensorboard: True
+ use_soft_denoising: False
+
+
+ num_workers: 10
+
+
+ # Checkpointing; by default, automatically load the last checkpoint
+ checkpoint: true
+ continue_from: ''  # Path to a checkpoint.th file to start from.
+                    # This is not used in the name of the experiment,
+                    # so use a dummy=something to avoid mixing up experiments.
+ continue_best: false  # continue from the best state, not the last one, if continue_from is set
+ only_inference: false
+
+ # Optimization related
+ optim: adam
+ lr: 1e-4  # used
+ variable_lr: True
+ beta1: 0.5  # used
+ beta2: 0.9  # used
+
+ loss: "mae"  # choice of loss
+
+ epochs: 73
+ batch_size: 16
+ val_take: -1
+ steps_per_epoch: 1000
+
+
+ sp:
+   method: "wiener"
+
+ # STFT parameters
+ stft:
+   win_size: 2048  # STFT window size
+   hop_size: 512
+
+ # Inference parameters
+ inference:
+   audio: None
+ # Models
+ model: unet  # either demucs or dwave
+ unet:
+   activation: "elu"
+   use_csff: False
+   use_SAM: True
+   use_cam: False
+   use_fam: False
+   use_fencoding: True
+   use_tdf: False
+   use_alttdfs: False
+   num_tfc: 3
+   num_stages: 3
+   depth: 6
+   f_dim: 1025  # hard-coded; depends on the STFT window size
+
+
+
+ # Hydra config
+ hydra:
+   job:
+     config:
+       # configuration for the ${hydra.job.override_dirname} runtime variable
+       override_dirname:
+         kv_sep: '='
+         item_sep: ','
+         # Remove all paths, as the / in them would mess things up
+         exclude_keys: ['path_experiment',
+                        'hydra.job_logging.handlers.file.filename']
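Since inference.py loads this file through Hydra, any of these keys can be overridden on the command line. For example (the input filename is hypothetical; the other values shown are the defaults above):

    python inference.py inference.audio=input.wav path_experiment=model stft.win_size=2048 stft.hop_size=512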
inference.py ADDED
@@ -0,0 +1,165 @@
+ import os
+ import hydra
+ import logging
+
+ logger = logging.getLogger(__name__)
+
+
+ def run(args):
+     import unet
+     import tensorflow as tf
+     import soundfile as sf
+     import numpy as np
+     from tqdm import tqdm
+     import scipy.signal
+
+     path_experiment = str(args.path_experiment)
+
+     unet_model = unet.build_model_denoise(unet_args=args.unet)
+
+     ckpt = os.path.join(os.path.dirname(os.path.abspath(__file__)), path_experiment, 'checkpoint')
+     unet_model.load_weights(ckpt)
+
+     def do_stft(noisy):
+         # Complex STFT stacked as two real channels [..., (real, imag)]
+         window_fn = tf.signal.hamming_window
+
+         win_size = args.stft.win_size
+         hop_size = args.stft.hop_size
+
+         stft_signal_noisy = tf.signal.stft(noisy, frame_length=win_size, window_fn=window_fn, frame_step=hop_size, pad_end=True)
+         stft_noisy_stacked = tf.stack(values=[tf.math.real(stft_signal_noisy), tf.math.imag(stft_signal_noisy)], axis=-1)
+
+         return stft_noisy_stacked
+
+     def do_istft(data):
+         window_fn = tf.signal.hamming_window
+
+         win_size = args.stft.win_size
+         hop_size = args.stft.hop_size
+
+         inv_window_fn = tf.signal.inverse_stft_window_fn(hop_size, forward_window_fn=window_fn)
+
+         pred_cpx = data[..., 0] + 1j * data[..., 1]
+         pred_time = tf.signal.inverse_stft(pred_cpx, win_size, hop_size, window_fn=inv_window_fn)
+         return pred_time
+
+     audio = str(args.inference.audio)
+     data, samplerate = sf.read(audio)
+
+     # Stereo to mono
+     if len(data.shape) > 1:
+         data = np.mean(data, axis=1)
+
+     if samplerate != 44100:
+         print("Resampling")
+         data = scipy.signal.resample(data, int((44100 / samplerate) * len(data)) + 1)
+
+     segment_size = 44100 * 5  # 5 s segments
+
+     length_data = len(data)
+     overlapsize = 2048  # samples (46 ms)
+     window = np.hanning(2 * overlapsize)
+     window_right = window[overlapsize:]
+     window_left = window[0:overlapsize]
+     audio_finished = False
+     pointer = 0
+     denoised_data = np.zeros(shape=(len(data),))
+     residual_noise = np.zeros(shape=(len(data),))
+     numchunks = int(np.ceil(length_data / segment_size))
+
+     for i in tqdm(range(numchunks)):
+         if pointer + segment_size < length_data:
+             segment = data[pointer:pointer + segment_size]
+             # STFT
+             segment_TF = do_stft(segment)
+             segment_TF_ds = tf.data.Dataset.from_tensors(segment_TF)
+             pred = unet_model.predict(segment_TF_ds.batch(1))
+             pred = pred[0]
+             residual = segment_TF - pred[0]
+             residual = np.array(residual)
+             pred_time = do_istft(pred[0])
+             residual_time = do_istft(residual)
+             residual_time = np.array(residual_time)
+
+             # Fade the segment edges so overlapping chunks crossfade smoothly
+             if pointer == 0:
+                 pred_time = np.concatenate((pred_time[0:int(segment_size - overlapsize)], np.multiply(pred_time[int(segment_size - overlapsize):segment_size], window_right)), axis=0)
+                 residual_time = np.concatenate((residual_time[0:int(segment_size - overlapsize)], np.multiply(residual_time[int(segment_size - overlapsize):segment_size], window_right)), axis=0)
+             else:
+                 pred_time = np.concatenate((np.multiply(pred_time[0:int(overlapsize)], window_left), pred_time[int(overlapsize):int(segment_size - overlapsize)], np.multiply(pred_time[int(segment_size - overlapsize):int(segment_size)], window_right)), axis=0)
+                 residual_time = np.concatenate((np.multiply(residual_time[0:int(overlapsize)], window_left), residual_time[int(overlapsize):int(segment_size - overlapsize)], np.multiply(residual_time[int(segment_size - overlapsize):int(segment_size)], window_right)), axis=0)
+
+             denoised_data[pointer:pointer + segment_size] = denoised_data[pointer:pointer + segment_size] + pred_time
+             residual_noise[pointer:pointer + segment_size] = residual_noise[pointer:pointer + segment_size] + residual_time
+
+             pointer = pointer + segment_size - overlapsize
+         else:
+             # Last (possibly shorter) chunk: zero-pad up to segment_size
+             segment = data[pointer:]
+             lensegment = len(segment)
+             segment = np.concatenate((segment, np.zeros(shape=(int(segment_size - len(segment)),))), axis=0)
+             audio_finished = True
+             # STFT
+             segment_TF = do_stft(segment)
+             segment_TF_ds = tf.data.Dataset.from_tensors(segment_TF)
+
+             pred = unet_model.predict(segment_TF_ds.batch(1))
+             pred = pred[0]
+             residual = segment_TF - pred[0]
+             residual = np.array(residual)
+             pred_time = do_istft(pred[0])
+             pred_time = np.array(pred_time)
+             pred_time = pred_time[0:segment_size]
+             residual_time = do_istft(residual)
+             residual_time = np.array(residual_time)
+             residual_time = residual_time[0:segment_size]
+             if pointer != 0:
+                 pred_time = np.concatenate((np.multiply(pred_time[0:int(overlapsize)], window_left), pred_time[int(overlapsize):int(segment_size)]), axis=0)
+                 residual_time = np.concatenate((np.multiply(residual_time[0:int(overlapsize)], window_left), residual_time[int(overlapsize):int(segment_size)]), axis=0)
+
+             denoised_data[pointer:] = denoised_data[pointer:] + pred_time[0:lensegment]
+             residual_noise[pointer:] = residual_noise[pointer:] + residual_time[0:lensegment]
+
+     basename = os.path.splitext(audio)[0]
+     wav_noisy_name = basename + "_noisy_input.wav"
+     sf.write(wav_noisy_name, data, 44100)
+     wav_output_name = basename + "_denoised.wav"
+     sf.write(wav_output_name, denoised_data, 44100)
+     wav_output_name = basename + "_residual.wav"
+     sf.write(wav_output_name, residual_noise, 44100)
+
+
+ def _main(args):
+     global __file__
+
+     __file__ = hydra.utils.to_absolute_path(__file__)
+
+     run(args)
+
+
+ @hydra.main(config_path=".", config_name="conf")
+ def main(args):
+     try:
+         _main(args)
+     except Exception:
+         logger.exception("Some error happened")
+         # Hydra intercepts the exit code; fixed in the beta, but I could not get the beta to work
+         os._exit(1)
+
+
+ if __name__ == "__main__":
+     main()
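The chunking loop above crossfades adjacent 5 s segments over a 2048-sample overlap using the two halves of a Hann window. A minimal self-contained sketch of the property this relies on (not part of the repo):

    import numpy as np

    # Fade-out of one chunk plus fade-in of the next should sum to ~1
    # across the whole overlap region, so overall levels stay constant.
    overlapsize = 2048
    window = np.hanning(2 * overlapsize)
    window_left = window[:overlapsize]    # fade-in, applied to the next chunk
    window_right = window[overlapsize:]   # fade-out, applied to the current chunk

    print(np.allclose(window_left + window_right, 1.0, atol=1e-3))  # True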
model/checkpoint ADDED
@@ -0,0 +1,2 @@
+ model_checkpoint_path: "checkpoint"
+ all_model_checkpoint_paths: "checkpoint"
model/checkpoint.data-00000-of-00001 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:40385267bb050426ed8a1384f983f19b1de18333fb6143689dd6f3bc5420aeaa
+ size 285671561
model/checkpoint.index ADDED
Binary file (20.3 kB)
model/copy_checkpoint_here ADDED
File without changes
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ -e git+https://github.com/TEAMuP-dev/pyharp.git#egg=pyharp
+ hydra-core
+ numpy==1.26.4
+ scipy
+ soundfile
+ tensorflow
+ tqdm
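To run the Space locally, a setup along these lines should work (hedged; it assumes git-lfs is installed and the repo is cloned, and exact environment details may differ):

    git lfs pull                       # fetch model/checkpoint.data-00000-of-00001
    pip install -r requirements.txt
    python app.py                      # launches the Gradio/HARP endpoint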
unet.py ADDED
@@ -0,0 +1,486 @@
+ import tensorflow as tf
+ from tensorflow.keras import Input
+ from tensorflow.keras import layers
+ from tensorflow.keras.initializers import TruncatedNormal
+ import math as m
+
+
+ def build_model_denoise(unet_args=None):
+
+     inputs = Input(shape=(None, None, 2))
+
+     outputs_stage_2, outputs_stage_1 = MultiStage_denoise(unet_args=unet_args)(inputs)
+
+     # Encapsulate MultiStage_denoise in a keras.Model object
+     model = tf.keras.Model(inputs=inputs, outputs=[outputs_stage_2, outputs_stage_1])
+
+     return model
+
+
+ class DenseBlock(layers.Layer):
+     '''
+     [B, T, F, N] => [B, T, F, N]
+     DenseNet block consisting of "num_layers" densely connected convolutional layers.
+     '''
+     def __init__(self, num_layers, N, ksize, activation):
+         '''
+         num_layers: number of densely connected conv. layers
+         N: number of filters (same in each layer)
+         ksize: kernel size (same in each layer)
+         '''
+         super(DenseBlock, self).__init__()
+         self.activation = activation
+
+         self.paddings_1 = get_paddings(ksize)
+         self.H = []
+         self.num_layers = num_layers
+
+         for i in range(num_layers):
+             self.H.append(layers.Conv2D(filters=N,
+                                         kernel_size=ksize,
+                                         kernel_initializer=TruncatedNormal(),
+                                         strides=1,
+                                         padding='VALID',
+                                         activation=self.activation))
+
+     def call(self, x):
+
+         x_ = tf.pad(x, self.paddings_1, mode='SYMMETRIC')
+         x_ = self.H[0](x_)
+         if self.num_layers > 1:
+             for h in self.H[1:]:
+                 x = tf.concat([x_, x], axis=-1)
+                 x_ = tf.pad(x, self.paddings_1, mode='SYMMETRIC')
+                 x_ = h(x_)
+
+         return x_
+
+
+ class FinalBlock(layers.Layer):
+     '''
+     [B, T, F, N] => [B, T, F, 2]
+     Final block: a 3x3 conv. layer mapping the output features to the output complex spectrogram.
+     '''
+     def __init__(self):
+         super(FinalBlock, self).__init__()
+         ksize = (3, 3)
+         self.paddings_2 = get_paddings(ksize)
+         self.conv2 = layers.Conv2D(filters=2,
+                                    kernel_size=ksize,
+                                    kernel_initializer=TruncatedNormal(),
+                                    strides=1,
+                                    padding='VALID',
+                                    activation=None)
+
+     def call(self, inputs):
+
+         x = tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
+         pred = self.conv2(x)
+
+         return pred
+
+
+ class SAM(layers.Layer):
+     '''
+     [B, T, F, N] => [B, T, F, N] , [B, T, F, N]
+     Supervised Attention Module:
+     The purpose of SAM is to make the network propagate only the most relevant features to the second stage, discarding the less useful ones.
+     The estimated residual noise signal is generated from the U-Net output features by means of a 3x3 convolutional layer.
+     The first-stage output is then calculated by adding the original input spectrogram to the residual noise.
+     The attention-guided features are computed using the attention masks M, which are directly calculated from the first-stage output with a convolutional layer and a sigmoid function.
+     '''
+     def __init__(self, n_feat):
+         super(SAM, self).__init__()
+
+         ksize = (3, 3)
+         self.paddings_1 = get_paddings(ksize)
+         self.conv1 = layers.Conv2D(filters=n_feat,
+                                    kernel_size=ksize,
+                                    kernel_initializer=TruncatedNormal(),
+                                    strides=1,
+                                    padding='VALID',
+                                    activation=None)
+         ksize = (3, 3)
+         self.paddings_2 = get_paddings(ksize)
+         self.conv2 = layers.Conv2D(filters=2,
+                                    kernel_size=ksize,
+                                    kernel_initializer=TruncatedNormal(),
+                                    strides=1,
+                                    padding='VALID',
+                                    activation=None)
+
+         ksize = (3, 3)
+         self.paddings_3 = get_paddings(ksize)
+         self.conv3 = layers.Conv2D(filters=n_feat,
+                                    kernel_size=ksize,
+                                    kernel_initializer=TruncatedNormal(),
+                                    strides=1,
+                                    padding='VALID',
+                                    activation=None)
+         self.cropadd = CropAddBlock()
+
+     def call(self, inputs, input_spectrogram):
+         x1 = tf.pad(inputs, self.paddings_1, mode='SYMMETRIC')
+         x1 = self.conv1(x1)
+
+         x = tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
+         x = self.conv2(x)
+
+         # residual prediction
+         pred = layers.Add()([x, input_spectrogram])  # first-stage output
+
+         x3 = tf.pad(pred, self.paddings_3, mode='SYMMETRIC')
+         M = self.conv3(x3)
+
+         M = tf.keras.activations.sigmoid(M)
+         x1 = layers.Multiply()([x1, M])
+         x1 = layers.Add()([x1, inputs])  # features to the next stage
+
+         return x1, pred
+
+
+ class AddFreqEncoding(layers.Layer):
+     '''
+     [B, T, F, 2] => [B, T, F, 12]
+     Generates frequency positional embeddings and concatenates them as 10 extra channels.
+     This function is optimized for F=1025.
+     '''
+     def __init__(self, f_dim):
+         super(AddFreqEncoding, self).__init__()
+         pi = tf.constant(m.pi)
+         pi = tf.cast(pi, 'float32')
+         self.f_dim = f_dim  # f_dim is fixed
+         n = tf.cast(tf.range(f_dim) / (f_dim - 1), 'float32')
+         coss = tf.math.cos(pi * n)
+         f_channel = tf.expand_dims(coss, -1)  # (1025, 1)
+         self.fembeddings = f_channel
+
+         for k in range(1, 10):
+             coss = tf.math.cos(2**k * pi * n)
+             f_channel = tf.expand_dims(coss, -1)  # (1025, 1)
+             self.fembeddings = tf.concat([self.fembeddings, f_channel], axis=-1)  # (1025, 10)
+
+     def call(self, input_tensor):
+
+         batch_size_tensor = tf.shape(input_tensor)[0]  # get batch size
+         time_dim = tf.shape(input_tensor)[1]  # get time dimension
+
+         fembeddings_2 = tf.broadcast_to(self.fembeddings, [batch_size_tensor, time_dim, self.f_dim, 10])
+
+         return tf.concat([input_tensor, fembeddings_2], axis=-1)  # (batch, T, 1025, 12)
+
+
+ def get_paddings(K):
+     # 'SAME'-style symmetric padding amounts for a kernel of size K
+     return tf.constant([[0, 0], [K[0]//2, K[0]//2 - (1 - K[0] % 2)], [K[1]//2, K[1]//2 - (1 - K[1] % 2)], [0, 0]])
+
+
+ class Decoder(layers.Layer):
+     '''
+     [B, T, F, N] , skip connections => [B, T, F, N]
+     Decoder side of the U-Net subnetwork.
+     '''
+     def __init__(self, Ns, Ss, unet_args):
+         super(Decoder, self).__init__()
+
+         self.Ns = Ns
+         self.Ss = Ss
+         self.activation = unet_args.activation
+         self.depth = unet_args.depth
+
+         ksize = (3, 3)
+         self.paddings_3 = get_paddings(ksize)
+         self.conv2d_3 = layers.Conv2D(filters=self.Ns[self.depth],
+                                       kernel_size=ksize,
+                                       kernel_initializer=TruncatedNormal(),
+                                       strides=1,
+                                       padding='VALID',
+                                       activation=self.activation)
+
+         self.cropadd = CropAddBlock()
+
+         self.dblocks = []
+         for i in range(self.depth):
+             self.dblocks.append(D_Block(layer_idx=i, N=self.Ns[i], S=self.Ss[i], activation=self.activation, num_tfc=unet_args.num_tfc))
+
+     def call(self, inputs, contracting_layers):
+         x = inputs
+         for i in range(self.depth, 0, -1):
+             x = self.dblocks[i-1](x, contracting_layers[i-1])
+         return x
+
+
+ class Encoder(tf.keras.Model):
+     '''
+     [B, T, F, N] => skip connections , [B, T, F, N_4]
+     Encoder side of the U-Net subnetwork.
+     '''
+     def __init__(self, Ns, Ss, unet_args):
+         super(Encoder, self).__init__()
+         self.Ns = Ns
+         self.Ss = Ss
+         self.activation = unet_args.activation
+         self.depth = unet_args.depth
+
+         self.contracting_layers = {}
+
+         self.eblocks = []
+         for i in range(self.depth):
+             self.eblocks.append(E_Block(layer_idx=i, N0=self.Ns[i], N=self.Ns[i+1], S=self.Ss[i], activation=self.activation, num_tfc=unet_args.num_tfc))
+
+         self.i_block = I_Block(self.Ns[self.depth], self.activation, unet_args.num_tfc)
+
+     def call(self, inputs):
+         x = inputs
+         for i in range(self.depth):
+             x, x_contract = self.eblocks[i](x)
+             self.contracting_layers[i] = x_contract  # if remove 0, correct this
+         x = self.i_block(x)
+
+         return x, self.contracting_layers
+
+
+ class MultiStage_denoise(tf.keras.Model):
+
+     def __init__(self, unet_args=None):
+         super(MultiStage_denoise, self).__init__()
+
+         self.activation = unet_args.activation
+         self.depth = unet_args.depth
+         if unet_args.use_fencoding:
+             self.freq_encoding = AddFreqEncoding(unet_args.f_dim)
+         self.use_sam = unet_args.use_SAM
+         self.use_fencoding = unet_args.use_fencoding
+         self.num_stages = unet_args.num_stages
+         # Encoder
+         self.Ns = [32, 64, 64, 128, 128, 256, 512]
+         self.Ss = [(2, 2), (2, 2), (2, 2), (2, 2), (2, 2), (2, 2)]
+
+         # initial feature extractor
+         ksize = (7, 7)
+         self.paddings_1 = get_paddings(ksize)
+         self.conv2d_1 = layers.Conv2D(filters=self.Ns[0],
+                                       kernel_size=ksize,
+                                       kernel_initializer=TruncatedNormal(),
+                                       strides=1,
+                                       padding='VALID',
+                                       activation=self.activation)
+
+         self.encoder_s1 = Encoder(self.Ns, self.Ss, unet_args)
+         self.decoder_s1 = Decoder(self.Ns, self.Ss, unet_args)
+
+         self.cropconcat = CropConcatBlock()
+         self.cropadd = CropAddBlock()
+
+         self.finalblock = FinalBlock()
+
+         if self.num_stages > 1:
+             self.sam_1 = SAM(self.Ns[0])
+
+             # initial feature extractor (second stage)
+             ksize = (7, 7)
+             self.paddings_2 = get_paddings(ksize)
+             self.conv2d_2 = layers.Conv2D(filters=self.Ns[0],
+                                           kernel_size=ksize,
+                                           kernel_initializer=TruncatedNormal(),
+                                           strides=1,
+                                           padding='VALID',
+                                           activation=self.activation)
+
+             self.encoder_s2 = Encoder(self.Ns, self.Ss, unet_args)
+             self.decoder_s2 = Decoder(self.Ns, self.Ss, unet_args)
+
+     @tf.function()
+     def call(self, inputs):
+
+         if self.use_fencoding:
+             x_w_freq = self.freq_encoding(inputs)  # [B, T, 1025, 12]
+         else:
+             x_w_freq = inputs
+
+         # initial feature extractor
+         x = tf.pad(x_w_freq, self.paddings_1, mode='SYMMETRIC')
+         x = self.conv2d_1(x)  # [B, T, 1025, 32]
+
+         x, contracting_layers_s1 = self.encoder_s1(x)
+         # decoder
+         feats_s1 = self.decoder_s1(x, contracting_layers_s1)  # [B, T, 1025, 32] features
+
+         if self.num_stages > 1:
+             # SAM module
+             Fout, pred_stage_1 = self.sam_1(feats_s1, inputs)
+
+             # initial feature extractor (second stage)
+             x = tf.pad(x_w_freq, self.paddings_2, mode='SYMMETRIC')
+             x = self.conv2d_2(x)
+
+             if self.use_sam:
+                 x = tf.concat([x, Fout], axis=-1)
+             else:
+                 x = tf.concat([x, feats_s1], axis=-1)
+
+             x, contracting_layers_s2 = self.encoder_s2(x)
+
+             feats_s2 = self.decoder_s2(x, contracting_layers_s2)  # [B, T, 1025, 32] features
+
+             # consider implementing a third stage?
+
+             pred_stage_2 = self.finalblock(feats_s2)
+             return pred_stage_2, pred_stage_1
+         else:
+             pred_stage_1 = self.finalblock(feats_s1)
+             return pred_stage_1
+
+
+ class I_Block(layers.Layer):
+     '''
+     [B, T, F, N] => [B, T, F, N]
+     Intermediate block: a DenseNet block with a residual connection.
+     '''
+     def __init__(self, N, activation, num_tfc, **kwargs):
+         super(I_Block, self).__init__(**kwargs)
+
+         ksize = (3, 3)
+         self.tfc = DenseBlock(num_tfc, N, ksize, activation)
+
+         self.conv2d_res = layers.Conv2D(filters=N,
+                                         kernel_size=(1, 1),
+                                         kernel_initializer=TruncatedNormal(),
+                                         strides=1,
+                                         padding='VALID')
+
+     def call(self, inputs):
+         x = self.tfc(inputs)
+
+         inputs_proj = self.conv2d_res(inputs)
+         return layers.Add()([x, inputs_proj])
+
+
+ class E_Block(layers.Layer):
+
+     def __init__(self, layer_idx, N0, N, S, activation, num_tfc, **kwargs):
+         super(E_Block, self).__init__(**kwargs)
+         self.layer_idx = layer_idx
+         self.N0 = N0
+         self.N = N
+         self.S = S
+         self.activation = activation
+         self.i_block = I_Block(N0, activation, num_tfc)
+
+         ksize = (S[0]+2, S[1]+2)
+         self.paddings_2 = get_paddings(ksize)
+         self.conv2d_2 = layers.Conv2D(filters=N,
+                                       kernel_size=ksize,
+                                       kernel_initializer=TruncatedNormal(),
+                                       strides=S,
+                                       padding='VALID',
+                                       activation=self.activation)
+
+     def call(self, inputs, training=None, **kwargs):
+         x = self.i_block(inputs)
+
+         x_down = tf.pad(x, self.paddings_2, mode='SYMMETRIC')
+         x_down = self.conv2d_2(x_down)
+
+         return x_down, x
+
+     def get_config(self):
+         return dict(layer_idx=self.layer_idx,
+                     N=self.N,
+                     S=self.S,
+                     **super(E_Block, self).get_config())
+
+
+ class D_Block(layers.Layer):
+
+     def __init__(self, layer_idx, N, S, activation, num_tfc, **kwargs):
+         super(D_Block, self).__init__(**kwargs)
+         self.layer_idx = layer_idx
+         self.N = N
+         self.S = S
+         self.activation = activation
+         ksize = (S[0]+2, S[1]+2)
+         self.paddings_1 = get_paddings(ksize)
+
+         self.tconv_1 = layers.Conv2DTranspose(filters=N,
+                                               kernel_size=ksize,
+                                               kernel_initializer=TruncatedNormal(),
+                                               strides=S,
+                                               activation=self.activation,
+                                               padding='VALID')
+
+         self.upsampling = layers.UpSampling2D(size=S, interpolation='nearest')
+
+         self.projection = layers.Conv2D(filters=N,
+                                         kernel_size=(1, 1),
+                                         kernel_initializer=TruncatedNormal(),
+                                         strides=1,
+                                         activation=self.activation,
+                                         padding='VALID')
+         self.cropadd = CropAddBlock()
+         self.cropconcat = CropConcatBlock()
+
+         self.i_block = I_Block(N, activation, num_tfc)
+
+     def call(self, inputs, bridge, previous_encoder=None, previous_decoder=None, **kwargs):
+         # transposed conv. on the symmetrically padded input
+         x = tf.pad(inputs, self.paddings_1, mode='SYMMETRIC')
+         x = self.tconv_1(x)
+
+         # nearest-neighbour upsampling path, projected if the channel counts differ
+         x2 = self.upsampling(inputs)
+         if x2.shape[-1] != x.shape[-1]:
+             x2 = self.projection(x2)
+
+         x = self.cropadd(x, x2)
+
+         x = self.cropconcat(x, bridge)
+
+         x = self.i_block(x)
+         return x
+
+     def get_config(self):
+         return dict(layer_idx=self.layer_idx,
+                     N=self.N,
+                     S=self.S,
+                     **super(D_Block, self).get_config())
+
+
+ class CropAddBlock(layers.Layer):
+
+     def call(self, down_layer, x, **kwargs):
+         x1_shape = tf.shape(down_layer)
+         x2_shape = tf.shape(x)
+
+         height_diff = (x1_shape[1] - x2_shape[1]) // 2
+         width_diff = (x1_shape[2] - x2_shape[2]) // 2
+
+         down_layer_cropped = down_layer[:,
+                                         height_diff: (x2_shape[1] + height_diff),
+                                         width_diff: (x2_shape[2] + width_diff),
+                                         :]
+
+         x = layers.Add()([down_layer_cropped, x])
+         return x
+
+
+ class CropConcatBlock(layers.Layer):
+
+     def call(self, down_layer, x, **kwargs):
+         x1_shape = tf.shape(down_layer)
+         x2_shape = tf.shape(x)
+
+         height_diff = (x1_shape[1] - x2_shape[1]) // 2
+         width_diff = (x1_shape[2] - x2_shape[2]) // 2
+
+         down_layer_cropped = down_layer[:,
+                                         height_diff: (x2_shape[1] + height_diff),
+                                         width_diff: (x2_shape[2] + width_diff),
+                                         :]
+
+         x = tf.concat([down_layer_cropped, x], axis=-1)
+         return x
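As a quick shape check, the model can be built from a conf.yaml-like fragment and run on a random batch. This is a hedged sketch, not part of the repo: OmegaConf (installed alongside hydra-core) stands in for the Hydra config object, and the input sizes are illustrative.

    # Hypothetical smoke test for unet.py
    import numpy as np
    from omegaconf import OmegaConf

    import unet

    unet_args = OmegaConf.create({
        'activation': 'elu', 'use_csff': False, 'use_SAM': True,
        'use_cam': False, 'use_fam': False, 'use_fencoding': True,
        'use_tdf': False, 'use_alttdfs': False,
        'num_tfc': 3, 'num_stages': 3, 'depth': 6, 'f_dim': 1025,
    })

    model = unet.build_model_denoise(unet_args=unet_args)
    x = np.random.randn(1, 64, 1025, 2).astype('float32')  # [B, T, F, (real, imag)]
    pred_stage_2, pred_stage_1 = model.predict(x)
    print(pred_stage_2.shape, pred_stage_1.shape)  # both (1, 64, 1025, 2)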