from tensorflow.keras.layers import (
    Input,
    Lambda,
    Concatenate,
    Conv2D,
    Conv2DTranspose,
    MaxPooling2D,
    BatchNormalization,
    Activation,
    Add,
    AveragePooling2D,
    UpSampling2D,
    SeparableConv2D,
    SpatialDropout2D,
)
from tensorflow.keras.models import Model
from tensorflow.keras.layers import ConvLSTM2D
from tensorflow.keras import callbacks
import tensorflow.keras.optimizers
from tensorflow.keras.regularizers import l2
from tensorflow.python import pywrap_tensorflow
import tensorflow as tf


class Thundernet:
    """ThunderNet-style segmentation network with a ResNet-18 backbone.

    The graph is: conv/BN/ReLU stem + max-pool, three residual stages
    (``resnet_layer``), a pyramid pooling module, two decoder blocks with
    concatenated encoder skips, a x4 bilinear upsampling and a 1x1 softmax
    classifier. After building the graph, pretrained backbone weights are
    copied in from a TensorFlow checkpoint by matching layer names.
    """

    # Checkpoint used by ``load_resnet_weights`` when no path is supplied.
    DEFAULT_RESNET_CHECKPOINT = "./resnet/resnet18/checkpoints/model/model.ckpt-5865"

    # Backbone layers whose weights come from the checkpoint: the stem conv
    # ("conv2d") plus 5 convs per residual stage x 3 stages, and the stem BN
    # plus 4 BNs per stage x 3 stages. Names must match ``resnet_layer``.
    _NUM_BACKBONE_CONVS = 1 + 3 * 5
    _NUM_BACKBONE_BNS = 1 + 3 * 4

    # (bin_size, filters) for each pyramid-pooling branch, in creation order.
    _PPM_BRANCHES = ((6, 512), (12, 512), (18, 256), (24, 256))

    def __init__(
        self,
        input_shape=(512, 1024, 3),
        resnet_trainable=False,
        kernel_regularizer=0,
        n_classes=38,
        resnet_checkpoint_path=None,
    ):
        """Build the Keras model and load the pretrained backbone weights.

        Args:
            input_shape: ``(height, width, channels)`` of the model input.
                The encoder downsamples by 16, so height and width are assumed
                to be divisible by 16 for the skip concatenations to line up.
            resnet_trainable: Whether backbone conv/BN layers are trainable.
            kernel_regularizer: L2 factor applied to conv kernels.
            n_classes: Number of output channels of the softmax classifier.
            resnet_checkpoint_path: Checkpoint to load backbone weights from;
                ``None`` means ``DEFAULT_RESNET_CHECKPOINT``.
        """
        self.input_shape = input_shape
        self.resnet_trainable = resnet_trainable
        self.kernel_regularizer = kernel_regularizer
        # Must be set before ``thundernet`` runs: the classifier reads it.
        self.n_classes = n_classes
        self.model = self.thundernet(input_shape, resnet_trainable, kernel_regularizer)
        self.load_resnet_weights(resnet_checkpoint_path)

    @staticmethod
    def _backbone_conv(filters, kernel_size, strides, name, resnet_trainable, kernel_regularizer):
        """Return a bias-free, 'same'-padded backbone Conv2D layer."""
        return Conv2D(
            filters,
            kernel_size=kernel_size,
            strides=strides,
            padding="same",
            name=name,
            use_bias=False,
            trainable=resnet_trainable,
            kernel_regularizer=l2(kernel_regularizer),
        )

    @staticmethod
    def _backbone_bn(name, resnet_trainable):
        """Return a channels-last backbone BatchNormalization layer."""
        return BatchNormalization(axis=3, name=name, trainable=resnet_trainable)

    def resnet_layer(
        self,
        inp,
        downsample_first=True,
        filters=64,
        first=False,
        number=0,
        resnet_trainable=False,
        kernel_regularizer=0,
    ):
        """Build one residual stage made of two basic blocks.

        Layer names follow the checkpoint's numbering, so stage ``number``
        owns ``conv2d_{5(number-1)+1 .. 5(number-1)+5}`` and
        ``batch_normalization_{4(number-1)+1 .. 4(number-1)+4}``.
        ``load_resnet_weights`` depends on these exact names.

        Args:
            inp: Input tensor of the stage.
            downsample_first: If True, the first conv and a 1x1 projection
                shortcut use stride 2.
            filters: Channel count of every conv in the stage.
            first: If True (and not downsampling), use a stride-1 1x1
                projection shortcut instead of the identity.
            number: 1-based stage index used to derive layer names.
            resnet_trainable: Whether the stage's layers are trainable.
            kernel_regularizer: L2 factor applied to conv kernels.

        Returns:
            Output tensor of the stage.
        """
        conv_base = (number - 1) * 5
        bn_base = (number - 1) * 4
        stride = 2 if downsample_first else 1

        def conv(kernel_size, strides, index):
            return self._backbone_conv(
                filters,
                kernel_size,
                strides,
                "conv2d_" + str(conv_base + index),
                resnet_trainable,
                kernel_regularizer,
            )

        def bn(index):
            return self._backbone_bn(
                "batch_normalization_" + str(bn_base + index), resnet_trainable
            )

        # Layers are created in the same order as before so that auto-named
        # layers (Activation/Add) and ``model.layers`` ordering stay stable.
        # --- basic block 1 ---
        x = conv(3, stride, 2)(inp)
        x = bn(1)(x)
        x = Activation("relu")(x)
        x = conv(3, 1, 3)(x)
        x = bn(2)(x)
        if downsample_first or first:
            # 1x1 projection shortcut; note it is not batch-normalised.
            shortcut = conv(1, stride, 1)(inp)
        else:
            shortcut = inp
        joint = Add()([shortcut, x])
        block_1 = Activation("relu")(joint)

        # --- basic block 2 (identity shortcut) ---
        x = conv(3, 1, 4)(block_1)
        x = bn(3)(x)
        x = Activation("relu")(x)
        x = conv(3, 1, 5)(x)
        x = bn(4)(x)
        joint_2 = Add()([block_1, x])
        return Activation("relu")(joint_2)

    def pyramid_pooling_block(
        self, input_tensor, number=0, kernel_regularizer=0, fallback_size=45
    ):
        """Pyramid pooling module: multi-scale average pooling + fusion.

        Each branch average-pools with window ``side // bin_size``, so a
        larger ``bin_size`` means a finer pooling grid. It then applies a 1x1
        conv and resizes back to the input's spatial size. Branch outputs are
        concatenated and fused by a 1x1 conv with 256 filters + ReLU.

        Args:
            input_tensor: Channels-last feature map.
            number: Index of the first conv name; branches use
                ``conv2d_{number}..{number+3}`` and the fusion conv uses
                ``conv2d_{number+4}``.
            kernel_regularizer: L2 factor applied to conv kernels.
            fallback_size: Spatial size assumed for a dimension that is not
                statically known. The pool windows need static sizes, so this
                must match the real runtime size.

        Returns:
            Fused feature map with 256 channels.
        """
        rows = input_tensor.shape[1]
        cols = input_tensor.shape[2]
        if rows is None:
            rows = fallback_size
        if cols is None:
            cols = fallback_size
        target_size = (rows, cols)

        branches = []
        for k, (bin_size, branch_filters) in enumerate(self._PPM_BRANCHES):
            # Clamp to 1 so feature maps smaller than ``bin_size`` do not
            # produce an invalid zero-size pooling window.
            window = (max(1, rows // bin_size), max(1, cols // bin_size))
            x = AveragePooling2D(pool_size=window, strides=window)(input_tensor)
            x = Conv2D(
                branch_filters,
                kernel_size=1,
                strides=1,
                padding="same",
                name="conv2d_" + str(number + k),
                kernel_regularizer=l2(kernel_regularizer),
            )(x)
            x = Lambda(lambda t: tf.image.resize(t, target_size))(x)
            branches.append(x)

        ppm = Concatenate()(branches)
        conv = Conv2D(
            256,
            kernel_size=1,
            name="conv2d_" + str(number + len(branches)),
            kernel_regularizer=l2(kernel_regularizer),
        )(ppm)
        return Activation("relu")(conv)

    def decoder_block(self, inp, filters, number=0, kernel_regularizer=0):
        """Upsample ``inp`` by 2 through a residual decoder block.

        Main path: 1x1 conv -> 3x3 stride-2 transposed conv -> BN ->
        1x1 conv to ``filters // 2`` -> BN. Skip path: 3x3 stride-2
        transposed conv of ``inp`` to ``filters // 2`` -> BN. The two paths
        are summed and passed through ReLU.

        Args:
            inp: Channels-last input tensor.
            filters: Width of the main path; the output has ``filters // 2``
                channels.
            number: Base index for ``conv2d_{number}``, ``conv2d_{number+1}``
                and ``batch_normalization_{number}..{number+2}``.
            kernel_regularizer: L2 factor applied to the 1x1 conv kernels.

        Returns:
            Tensor with twice the spatial size of ``inp``.
        """
        main = Conv2D(
            filters,
            kernel_size=1,
            name="conv2d_" + str(number),
            kernel_regularizer=l2(kernel_regularizer),
        )(inp)
        main = Conv2DTranspose(filters, kernel_size=3, strides=2, padding="same")(main)
        main = BatchNormalization(axis=3, name="batch_normalization_" + str(number))(main)
        main = Conv2D(
            filters // 2,
            kernel_size=1,
            name="conv2d_" + str(number + 1),
            kernel_regularizer=l2(kernel_regularizer),
        )(main)
        main = BatchNormalization(
            axis=3, name="batch_normalization_" + str(number + 1)
        )(main)

        skip = Conv2DTranspose(filters // 2, kernel_size=3, strides=2, padding="same")(inp)
        skip = BatchNormalization(
            axis=3, name="batch_normalization_" + str(number + 2)
        )(skip)

        joint = Add()([skip, main])
        return Activation("relu")(joint)

    def thundernet(
        self, input_shape=(512, 1024, 3), resnet_trainable=False, kernel_regularizer=0
    ):
        """Assemble and return the full Keras ``Model``.

        The encoder downsamples by 16: stem stride 2, max-pool stride 2, and
        stride 2 in stages 2 and 3. The two decoder blocks upsample by 4 and
        the bilinear ``UpSampling2D`` by another 4. The softmax output
        therefore has the input's spatial size when height and width are
        divisible by 16.

        Reads ``self.n_classes`` for the classifier width.
        """
        inputs = Input(shape=input_shape)

        # Stem.
        x = self._backbone_conv(
            64, 3, 2, "conv2d", resnet_trainable, kernel_regularizer
        )(inputs)
        x = self._backbone_bn("batch_normalization", resnet_trainable)(x)
        x = Activation("relu")(x)
        x = MaxPooling2D(pool_size=(3, 3), strides=2, padding="same")(x)

        # Encoder stages.
        res1 = self.resnet_layer(
            x,
            downsample_first=False,
            filters=64,
            first=True,
            number=1,
            resnet_trainable=resnet_trainable,
            kernel_regularizer=kernel_regularizer,
        )
        res2 = self.resnet_layer(
            res1,
            downsample_first=True,
            filters=128,
            first=False,
            number=2,
            resnet_trainable=resnet_trainable,
            kernel_regularizer=kernel_regularizer,
        )
        res3 = self.resnet_layer(
            res2,
            downsample_first=True,
            filters=256,
            first=False,
            number=3,
            resnet_trainable=resnet_trainable,
            kernel_regularizer=kernel_regularizer,
        )

        # Context module + decoder with concatenated encoder skips.
        ppm = self.pyramid_pooling_block(
            res3, number=16, kernel_regularizer=kernel_regularizer
        )
        ppm = Concatenate()([ppm, res3])
        dec_1 = self.decoder_block(
            ppm, 256, number=21, kernel_regularizer=kernel_regularizer
        )
        dec_1 = Concatenate()([dec_1, res2])
        dec_2 = self.decoder_block(
            dec_1, 128, number=24, kernel_regularizer=kernel_regularizer
        )
        dec_2 = Concatenate()([dec_2, res1])

        # Head.
        ups = UpSampling2D(size=(4, 4), interpolation="bilinear")(dec_2)
        out = Conv2D(
            filters=int(self.n_classes),
            kernel_size=1,
            activation="softmax",
            name="conv2d_out",
        )(ups)
        return Model(inputs=inputs, outputs=out)

    def load_resnet_weights(self, checkpoint_path=None):
        """Copy backbone conv and BN weights from a TensorFlow checkpoint.

        Checkpoint variables are looked up by the Keras layer names:
        ``<conv>/kernel`` for the stem and stage convs, and
        ``<bn>/{gamma,beta,moving_mean,moving_variance}`` for their BNs.

        Args:
            checkpoint_path: Checkpoint prefix to read; ``None`` means
                ``DEFAULT_RESNET_CHECKPOINT``.

        Raises:
            Whatever the checkpoint reader raises for a missing file or key,
            and ``ValueError`` from ``set_weights`` on a shape mismatch.
        """
        if checkpoint_path is None:
            checkpoint_path = self.DEFAULT_RESNET_CHECKPOINT
        print("Loading weights for resnet18 backbone")
        reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)

        for k in range(self._NUM_BACKBONE_CONVS):
            suffix = "" if k == 0 else "_" + str(k)

            conv_name = "conv2d" + suffix
            kernel = reader.get_tensor(conv_name + "/kernel")
            # Backbone convs are bias-free, so the kernel is the only weight.
            self.model.get_layer(conv_name).set_weights([kernel])

            if k < self._NUM_BACKBONE_BNS:
                bn_name = "batch_normalization" + suffix
                # Keras BatchNormalization weight order.
                bn_weights = [
                    reader.get_tensor(bn_name + "/" + var)
                    for var in ("gamma", "beta", "moving_mean", "moving_variance")
                ]
                self.model.get_layer(bn_name).set_weights(bn_weights)

        print("Weights for resnet18 backbone loaded!")