Spaces:
Build error
Build error
| import tensorflow as tf | |
| from tensorflow.keras import Model, Input | |
| from tensorflow.keras import layers | |
| from tensorflow.keras.initializers import TruncatedNormal | |
| import math as m | |
def build_model_denoise(unet_args=None):
    """Build and return a Keras model wrapping MultiStage_denoise.

    Args:
        unet_args: configuration object forwarded to MultiStage_denoise
            (expects attributes such as activation, depth, use_fencoding,
            f_dim, use_SAM, num_stages, num_tfc).

    Returns:
        A keras Model mapping a complex spectrogram [B, T, F, 2] to the
        two stage outputs [pred_stage_2, pred_stage_1].

    NOTE(review): the two-value unpacking assumes unet_args.num_stages > 1;
    with a single stage MultiStage_denoise returns one tensor — confirm.
    """
    inputs = Input(shape=(None, None, 2))
    outputs_stage_2, outputs_stage_1 = MultiStage_denoise(unet_args=unet_args)(inputs)
    # Encapsulate MultiStage_denoise in a keras Model object.  Use the
    # `Model` class imported at the top of the file instead of spelling
    # out tf.keras.Model, for consistency with the import list.
    model = Model(inputs=inputs, outputs=[outputs_stage_2, outputs_stage_1])
    return model
class DenseBlock(layers.Layer):
    '''
    [B, T, F, N] => [B, T, F, N]
    DenseNet-style block: a stack of "num_layers" convolutional layers in
    which each layer receives the previous layer's output concatenated
    with the accumulated input along the channel axis.
    '''
    def __init__(self, num_layers, N, ksize, activation):
        '''
        num_layers: number of densely connected conv. layers
        N: number of filters (identical for every layer)
        ksize: kernel size (identical for every layer)
        activation: activation applied by every conv. layer
        '''
        super(DenseBlock, self).__init__()
        self.activation = activation
        self.paddings_1 = get_paddings(ksize)
        self.num_layers = num_layers
        # One Conv2D per dense layer.  'VALID' padding is used because the
        # input is symmetrically padded by hand in call().
        self.H = [
            layers.Conv2D(filters=N,
                          kernel_size=ksize,
                          kernel_initializer=TruncatedNormal(),
                          strides=1,
                          padding='VALID',
                          activation=self.activation)
            for _ in range(num_layers)
        ]

    def call(self, x):
        # First layer: pad symmetrically, then convolve.
        out = self.H[0](tf.pad(x, self.paddings_1, mode='SYMMETRIC'))
        # Remaining layers (no-op when num_layers == 1): dense connectivity.
        for conv in self.H[1:]:
            x = tf.concat([out, x], axis=-1)
            out = conv(tf.pad(x, self.paddings_1, mode='SYMMETRIC'))
        return out
class FinalBlock(layers.Layer):
    '''
    [B, T, F, N] => [B, T, F, 2]
    Final block: a single 3x3 convolution mapping the output features to
    the two-channel output complex spectrogram.
    '''
    def __init__(self):
        super(FinalBlock, self).__init__()
        kernel = (3, 3)
        self.paddings_2 = get_paddings(kernel)
        # Linear output (activation=None): the spectrogram is unconstrained.
        self.conv2 = layers.Conv2D(filters=2,
                                   kernel_size=kernel,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)

    def call(self, inputs):
        padded = tf.pad(inputs, self.paddings_2, mode='SYMMETRIC')
        return self.conv2(padded)
class SAM(layers.Layer):
    '''
    [B, T, F, N] => [B, T, F, N] , [B, T, F, 2]
    Supervised Attention Module.
    SAM propagates only the most relevant features to the second stage,
    discarding the less useful ones.  A 3x3 convolution turns the U-Net
    output features into an estimated residual noise signal; the stage-1
    output is the input spectrogram plus that residual.  Attention masks
    M are computed from the stage-1 output with another convolution and a
    sigmoid, and are used to gate the features passed onward.
    '''
    def __init__(self, n_feat):
        super(SAM, self).__init__()
        kernel = (3, 3)
        # All three convolutions use the same 3x3 kernel, hence the same
        # symmetric padding specification.
        self.paddings_1 = get_paddings(kernel)
        # Identity-path features (n_feat channels, linear).
        self.conv1 = layers.Conv2D(filters=n_feat,
                                   kernel_size=kernel,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)
        self.paddings_2 = get_paddings(kernel)
        # Residual noise estimate (2 channels: real/imag).
        self.conv2 = layers.Conv2D(filters=2,
                                   kernel_size=kernel,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)
        self.paddings_3 = get_paddings(kernel)
        # Attention-mask convolution (sigmoid applied in call()).
        self.conv3 = layers.Conv2D(filters=n_feat,
                                   kernel_size=kernel,
                                   kernel_initializer=TruncatedNormal(),
                                   strides=1,
                                   padding='VALID',
                                   activation=None)
        self.cropadd = CropAddBlock()

    def call(self, inputs, input_spectrogram):
        # Identity branch: features to be gated by the attention mask.
        feats = self.conv1(tf.pad(inputs, self.paddings_1, mode='SYMMETRIC'))
        # Residual branch: estimate the residual noise spectrogram...
        residual = self.conv2(tf.pad(inputs, self.paddings_2, mode='SYMMETRIC'))
        # ...and add it to the original input to form the stage-1 prediction.
        pred = layers.Add()([residual, input_spectrogram])
        # Attention mask derived from the stage-1 prediction.
        M = self.conv3(tf.pad(pred, self.paddings_3, mode='SYMMETRIC'))
        M = tf.keras.activations.sigmoid(M)
        # Gate the features and re-inject the block input (residual link).
        gated = layers.Multiply()([feats, M])
        out = layers.Add()([gated, inputs])
        return out, pred
class AddFreqEncoding(layers.Layer):
    '''
    [B, T, F, 2] => [B, T, F, 2 + n_channels]
    Generates fixed sinusoidal frequency positional embeddings and
    concatenates them as extra channels (10 by default, which reproduces
    the original behavior for F = 1025, giving 12 output channels).
    '''
    def __init__(self, f_dim, n_channels=10):
        '''
        f_dim: size of the frequency axis (fixed, e.g. 1025)
        n_channels: number of encoding channels; channel k holds
            cos(2**k * pi * n) with n in [0, 1] along the frequency axis.
        '''
        super(AddFreqEncoding, self).__init__()
        pi = tf.cast(tf.constant(m.pi), 'float32')
        self.f_dim = f_dim  # f_dim is fixed
        self.n_channels = n_channels
        n = tf.cast(tf.range(f_dim) / (f_dim - 1), 'float32')
        # Build all channels at once instead of concatenating inside a
        # Python loop (same values: k = 0 gives cos(pi * n), then doubling
        # frequencies up to 2**(n_channels - 1)).
        channels = [tf.expand_dims(tf.math.cos(2 ** k * pi * n), -1)  # (f_dim, 1)
                    for k in range(n_channels)]
        self.fembeddings = tf.concat(channels, axis=-1)  # (f_dim, n_channels)

    def call(self, input_tensor):
        batch_size_tensor = tf.shape(input_tensor)[0]  # dynamic batch size
        time_dim = tf.shape(input_tensor)[1]           # dynamic time dimension
        # Broadcast the (f_dim, n_channels) table over batch and time.
        fembeddings_2 = tf.broadcast_to(
            self.fembeddings,
            [batch_size_tensor, time_dim, self.f_dim, self.n_channels])
        return tf.concat([input_tensor, fembeddings_2], axis=-1)
def get_paddings(K):
    """Return a tf 'paddings' spec for symmetric padding with kernel K.

    For kernel size K = (kh, kw), each spatial dimension is padded with
    k//2 values before and (k-1)//2 after — symmetric for odd kernels,
    one less on the trailing side for even kernels — so a subsequent
    'VALID' convolution preserves the spatial size.  Batch and channel
    dimensions are left unpadded.
    """
    before_h, after_h = K[0] // 2, (K[0] - 1) // 2
    before_w, after_w = K[1] // 2, (K[1] - 1) // 2
    return tf.constant([[0, 0],
                        [before_h, after_h],
                        [before_w, after_w],
                        [0, 0]])
class Decoder(layers.Layer):
    '''
    [B, T, F, N] , skip connections => [B, T, F, N]
    Decoder side of the U-Net subnetwork: a chain of D_Blocks that
    upsample the bottleneck features while merging the encoder's skip
    connections, from the deepest level back up to level 0.
    '''
    def __init__(self, Ns, Ss, unet_args):
        '''
        Ns: filter counts per depth level
        Ss: (time, freq) strides per depth level
        unet_args: config providing .activation, .depth and .num_tfc
        '''
        super(Decoder, self).__init__()
        self.Ns = Ns
        self.Ss = Ss
        self.activation = unet_args.activation
        self.depth = unet_args.depth
        # NOTE(review): the original also built a 3x3 Conv2D (conv2d_3,
        # with paddings_3) and a CropAddBlock here that were never used in
        # call(); being uncalled they created no weights, so removing them
        # is checkpoint-compatible dead-code cleanup.
        self.dblocks = []
        for i in range(self.depth):
            self.dblocks.append(D_Block(layer_idx=i,
                                        N=self.Ns[i],
                                        S=self.Ss[i],
                                        activation=self.activation,
                                        num_tfc=unet_args.num_tfc))

    def call(self, inputs, contracting_layers):
        # Walk from the deepest level (depth-1) up to level 0, feeding each
        # D_Block the matching encoder skip connection.
        x = inputs
        for i in range(self.depth, 0, -1):
            x = self.dblocks[i - 1](x, contracting_layers[i - 1])
        return x
class Encoder(tf.keras.Model):
    '''
    [B, T, F, N] => skip connections , [B, T, F, N_4]
    Encoder side of the U-Net subnetwork: a chain of downsampling
    E_Blocks followed by an intermediate I_Block at the bottleneck.
    Returns the bottleneck features and a dict of per-level skip
    connections (level index -> pre-downsampling features).
    '''
    def __init__(self, Ns, Ss, unet_args):
        # Ns: filter counts per depth level; Ss: strides per level.
        # unet_args provides .activation, .depth and .num_tfc.
        super(Encoder, self).__init__()
        self.Ns=Ns
        self.Ss=Ss
        self.activation=unet_args.activation
        self.depth=unet_args.depth
        # NOTE(review): this dict is instance state that call() mutates on
        # every invocation — not safe for concurrent calls; confirm intent.
        self.contracting_layers = {}
        self.eblocks=[]
        # E_Block i maps Ns[i] channels to Ns[i+1] with stride Ss[i].
        for i in range(self.depth):
            self.eblocks.append(E_Block(layer_idx=i,N0=self.Ns[i],N=self.Ns[i+1],S=self.Ss[i], activation=self.activation , num_tfc=unet_args.num_tfc))
        # Bottleneck block at the deepest level.
        self.i_block=I_Block(self.Ns[self.depth],self.activation,unet_args.num_tfc)
    def call(self, inputs):
        x=inputs
        for i in range(self.depth):
            # Each E_Block returns (downsampled features, skip connection).
            x, x_contract=self.eblocks[i](x)
            self.contracting_layers[i] = x_contract #if remove 0, correct this
        x=self.i_block(x)
        return x, self.contracting_layers
class MultiStage_denoise(tf.keras.Model):
    '''
    [B, T, F, 2] => [B, T, F, 2] (, [B, T, F, 2])
    Multi-stage denoising network: one or two U-Net stages with optional
    frequency positional encoding and an optional Supervised Attention
    Module (SAM) linking stage 1 to stage 2.
    With num_stages > 1 call() returns (pred_stage_2, pred_stage_1);
    otherwise it returns pred_stage_1 only.
    '''
    def __init__(self, unet_args=None):
        # unet_args must provide: activation, depth, use_fencoding, f_dim,
        # use_SAM, num_stages and num_tfc.
        super(MultiStage_denoise, self).__init__()
        self.activation=unet_args.activation
        self.depth=unet_args.depth
        if unet_args.use_fencoding:
            self.freq_encoding=AddFreqEncoding(unet_args.f_dim)
        self.use_sam=unet_args.use_SAM
        self.use_fencoding=unet_args.use_fencoding
        self.num_stages=unet_args.num_stages
        # Channel/stride schedule shared by both encoder/decoder stages.
        self.Ns= [32,64,64,128,128,256,512]
        self.Ss= [(2,2),(2,2),(2,2),(2,2),(2,2),(2,2)]
        # Stage 1: initial feature extractor (7x7 conv).
        ksize=(7,7)
        self.paddings_1=get_paddings(ksize)
        self.conv2d_1 = layers.Conv2D(filters=self.Ns[0],
                              kernel_size=ksize,
                              kernel_initializer=TruncatedNormal(),
                              strides=1,
                              padding='VALID',
                              activation=self.activation)
        self.encoder_s1=Encoder(self.Ns, self.Ss, unet_args)
        self.decoder_s1=Decoder(self.Ns, self.Ss, unet_args)
        self.cropconcat = CropConcatBlock()
        self.cropadd = CropAddBlock()
        # Final 3x3 conv mapping features to the output spectrogram
        # (used by whichever stage runs last).
        self.finalblock=FinalBlock()
        if self.num_stages>1:
            # Supervised Attention Module bridging stage 1 and stage 2.
            self.sam_1=SAM(self.Ns[0])
            # Stage 2: its own initial feature extractor and U-Net.
            ksize=(7,7)
            self.paddings_2=get_paddings(ksize)
            self.conv2d_2 = layers.Conv2D(filters=self.Ns[0],
                              kernel_size=ksize,
                              kernel_initializer=TruncatedNormal(),
                              strides=1,
                              padding='VALID',
                              activation=self.activation)
            self.encoder_s2=Encoder(self.Ns, self.Ss, unet_args)
            self.decoder_s2=Decoder(self.Ns, self.Ss, unet_args)
    def call(self, inputs):
        # Optionally append fixed frequency positional-encoding channels.
        if self.use_fencoding:
            x_w_freq=self.freq_encoding(inputs) #None, None, 1025, 12
        else:
            x_w_freq=inputs
        # Stage 1: initial feature extractor + U-Net.
        x=tf.pad(x_w_freq, self.paddings_1, mode='SYMMETRIC')
        x=self.conv2d_1(x) #None, None, 1025, 32
        x, contracting_layers_s1= self.encoder_s1(x)
        # Stage-1 decoder output features.
        feats_s1 =self.decoder_s1(x, contracting_layers_s1) #None, None, 1025, 32 features
        if self.num_stages>1:
            # SAM produces gated features for stage 2 and the stage-1
            # spectrogram prediction (input + estimated residual).
            Fout, pred_stage_1=self.sam_1(feats_s1,inputs)
            # Stage 2: fresh feature extraction from the (encoded) input...
            x=tf.pad(x_w_freq, self.paddings_2, mode='SYMMETRIC')
            x=self.conv2d_2(x)
            # ...concatenated with either the SAM-gated features or the raw
            # stage-1 features, depending on use_SAM.
            if self.use_sam:
                x = tf.concat([x, Fout], axis=-1)
            else:
                x = tf.concat([x,feats_s1], axis=-1)
            x, contracting_layers_s2= self.encoder_s2(x)
            feats_s2=self.decoder_s2(x, contracting_layers_s2) #None, None, 1025, 32 features
            #consider implementing a third stage?
            pred_stage_2=self.finalblock(feats_s2)
            return pred_stage_2, pred_stage_1
        else:
            # Single-stage variant: map the features straight to the output.
            pred_stage_1=self.finalblock(feats_s1)
            return pred_stage_1
class I_Block(layers.Layer):
    '''
    [B, T, F, N] => [B, T, F, N]
    Intermediate block: a DenseBlock with a residual connection.  The
    residual path is projected with a 1x1 convolution so the addition is
    valid even when the input channel count differs from N.
    '''
    def __init__(self, N, activation, num_tfc, **kwargs):
        super(I_Block, self).__init__(**kwargs)
        self.tfc = DenseBlock(num_tfc, N, (3, 3), activation)
        # 1x1 channel projection for the residual path (no activation).
        self.conv2d_res = layers.Conv2D(filters=N,
                                        kernel_size=(1, 1),
                                        kernel_initializer=TruncatedNormal(),
                                        strides=1,
                                        padding='VALID')

    def call(self, inputs):
        dense_out = self.tfc(inputs)
        shortcut = self.conv2d_res(inputs)
        return layers.Add()([dense_out, shortcut])
class E_Block(layers.Layer):
    '''
    [B, T, F, N0] => ([B, T', F', N] , [B, T, F, N0])
    Encoder block: an I_Block followed by a strided convolution that
    downsamples time/frequency.  Returns the downsampled features and the
    pre-downsampling features (used as the U-Net skip connection).
    '''
    def __init__(self, layer_idx, N0, N, S, activation, num_tfc, **kwargs):
        '''
        layer_idx: depth index of this block
        N0: input/skip channel count; N: output channel count
        S: (time, freq) downsampling stride
        num_tfc: number of layers inside the I_Block's DenseBlock
        '''
        super(E_Block, self).__init__(**kwargs)
        self.layer_idx = layer_idx
        self.N0 = N0
        self.N = N
        self.S = S
        self.activation = activation
        self.num_tfc = num_tfc  # stored so get_config() can round-trip
        self.i_block = I_Block(N0, activation, num_tfc)
        # Kernel is stride + 2 in each dimension; symmetric padding in
        # call() stands in for 'SAME' padding.
        ksize = (S[0] + 2, S[1] + 2)
        self.paddings_2 = get_paddings(ksize)
        self.conv2d_2 = layers.Conv2D(filters=N,
                                      kernel_size=ksize,
                                      kernel_initializer=TruncatedNormal(),
                                      strides=S,
                                      padding='VALID',
                                      activation=self.activation)

    def call(self, inputs, training=None, **kwargs):
        x = self.i_block(inputs)
        x_down = tf.pad(x, self.paddings_2, mode='SYMMETRIC')
        x_down = self.conv2d_2(x_down)
        return x_down, x

    def get_config(self):
        # Bug fix: the original config omitted N0, activation and num_tfc,
        # all required by __init__, so a layer could not be reconstructed
        # from its config.
        return dict(layer_idx=self.layer_idx,
                    N0=self.N0,
                    N=self.N,
                    S=self.S,
                    activation=self.activation,
                    num_tfc=self.num_tfc,
                    **super(E_Block, self).get_config()
                    )
class D_Block(layers.Layer):
    '''
    [B, T', F', N_in] , skip => [B, T, F, N]
    Decoder block: upsamples via a transposed convolution plus a
    nearest-neighbour upsampling shortcut (1x1-projected when the channel
    counts differ), adds the two paths, concatenates the encoder skip
    connection, and refines the result with an I_Block.
    '''
    def __init__(self, layer_idx, N, S, activation, num_tfc, **kwargs):
        '''
        layer_idx: depth index of this block
        N: output channel count; S: (time, freq) upsampling stride
        num_tfc: number of layers inside the I_Block's DenseBlock
        '''
        super(D_Block, self).__init__(**kwargs)
        self.layer_idx = layer_idx
        self.N = N
        self.S = S
        self.activation = activation
        self.num_tfc = num_tfc  # stored so get_config() can round-trip
        self.tconv_1 = layers.Conv2DTranspose(filters=N,
                                              kernel_size=(S[0] + 2, S[1] + 2),
                                              kernel_initializer=TruncatedNormal(),
                                              strides=S,
                                              activation=self.activation,
                                              padding='VALID')
        # Parameter-free upsampling shortcut, projected only when needed.
        self.upsampling = layers.UpSampling2D(size=S, interpolation='nearest')
        self.projection = layers.Conv2D(filters=N,
                                        kernel_size=(1, 1),
                                        kernel_initializer=TruncatedNormal(),
                                        strides=1,
                                        activation=self.activation,
                                        padding='VALID')
        self.cropadd = CropAddBlock()
        self.cropconcat = CropConcatBlock()
        self.i_block = I_Block(N, activation, num_tfc)

    def call(self, inputs, bridge, previous_encoder=None, previous_decoder=None, **kwargs):
        # NOTE(review): the original padded `inputs` symmetrically and then
        # discarded the result, applying tconv_1 to the raw inputs; that
        # dead pad (and the matching paddings_1 attribute) is removed here.
        # If pre-padding was actually intended the output shapes would
        # change — confirm against training checkpoints before "fixing".
        x = self.tconv_1(inputs)
        x2 = self.upsampling(inputs)
        # Match channel counts between the two upsampling paths.
        if x2.shape[-1] != x.shape[-1]:
            x2 = self.projection(x2)
        x = self.cropadd(x, x2)
        x = self.cropconcat(x, bridge)
        x = self.i_block(x)
        return x

    def get_config(self):
        # Bug fix: the original config omitted activation and num_tfc,
        # both required by __init__, so a layer could not be reconstructed
        # from its config.
        return dict(layer_idx=self.layer_idx,
                    N=self.N,
                    S=self.S,
                    activation=self.activation,
                    num_tfc=self.num_tfc,
                    **super(D_Block, self).get_config()
                    )
class CropAddBlock(layers.Layer):
    '''Center-crop the first tensor to the spatial size of the second,
    then add them element-wise ([B, H1, W1, C], [B, H2, W2, C] => [B, H2, W2, C]).'''
    def call(self, down_layer, x, **kwargs):
        big = tf.shape(down_layer)
        small = tf.shape(x)
        # Offsets that center the crop window inside the larger tensor.
        off_h = (big[1] - small[1]) // 2
        off_w = (big[2] - small[2]) // 2
        cropped = down_layer[:,
                             off_h:off_h + small[1],
                             off_w:off_w + small[2],
                             :]
        return layers.Add()([cropped, x])
class CropConcatBlock(layers.Layer):
    '''Center-crop the first tensor to the spatial size of the second,
    then concatenate them along the channel axis.'''
    def call(self, down_layer, x, **kwargs):
        big = tf.shape(down_layer)
        small = tf.shape(x)
        # Offsets that center the crop window inside the larger tensor.
        off_h = (big[1] - small[1]) // 2
        off_w = (big[2] - small[2]) // 2
        cropped = down_layer[:,
                             off_h:off_h + small[1],
                             off_w:off_w + small[2],
                             :]
        return tf.concat([cropped, x], axis=-1)