import tensorflow as tf
import tensorflow.keras.optimizers
from tensorflow.keras import callbacks
from tensorflow.keras.layers import (
    Activation,
    Add,
    AveragePooling2D,
    BatchNormalization,
    Concatenate,
    Conv2D,
    Conv2DTranspose,
    ConvLSTM2D,
    Input,
    Lambda,
    MaxPooling2D,
    SeparableConv2D,
    SpatialDropout2D,
    UpSampling2D,
)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2
from tensorflow.python import pywrap_tensorflow
|
|
|
|
class Thundernet:
    """ThunderNet-style semantic segmentation network.

    Architecture: a ResNet18-style encoder (stem + three residual stages),
    a pyramid pooling module over the deepest features, and a two-stage
    decoder with skip connections, followed by a 4x bilinear upsample and
    a per-pixel softmax classifier.

    NOTE(review): layer names ("conv2d_N", "batch_normalization_N") are
    load-bearing -- load_resnet_weights() maps checkpoint tensors onto the
    backbone by these exact names, so layers must not be renamed.
    """

    def __init__(
        self,
        input_shape=(512, 1024, 3),
        resnet_trainable=False,
        kernel_regularizer=0,
        n_classes=38,
    ):
        """Build the model graph and load pretrained backbone weights.

        Args:
            input_shape: (height, width, channels) of the input image.
            resnet_trainable: if False, backbone conv/BN layers are frozen.
            kernel_regularizer: L2 regularization factor for conv kernels.
            n_classes: number of output segmentation classes.
        """
        self.input_shape = input_shape
        self.resnet_trainable = resnet_trainable
        self.n_classes = n_classes
        self.model = self.thundernet(input_shape, resnet_trainable, kernel_regularizer)
        self.load_resnet_weights()

    def resnet_layer(
        self,
        inp,
        downsample_first=True,
        filters=64,
        first=False,
        number=0,
        resnet_trainable=False,
        kernel_regularizer=0,
    ):
        """Two stacked ResNet basic blocks (conv-BN-ReLU x2 + shortcut each).

        The first block uses a projection (1x1 conv) shortcut when it
        downsamples or when `first` is set; otherwise an identity shortcut.
        The second block always uses an identity shortcut. Layer names are
        derived from `number` so checkpoint weights can be matched by name.

        Args:
            inp: input feature map.
            downsample_first: if True, the first block halves the spatial
                resolution (stride 2 on both its first conv and its shortcut).
            filters: channel width of every conv in the stage.
            first: force a projection shortcut even without downsampling
                (used for the first stage, where channel count changes).
            number: 1-based stage index used to build unique layer names.
            resnet_trainable: whether conv/BN layers are trainable.
            kernel_regularizer: L2 factor applied to conv kernels.

        Returns:
            Output feature map of the second residual block.
        """
        # --- residual block 1 ---
        # Stride 2 on the first conv performs the stage's downsampling.
        conv_1 = Conv2D(
            filters,
            kernel_size=3,
            strides=2 if downsample_first else 1,
            padding="same",
            name="conv2d_" + str(2 + (number - 1) * 5),
            use_bias=False,
            trainable=resnet_trainable,
            kernel_regularizer=l2(kernel_regularizer),
        )(inp)
        bn_1 = BatchNormalization(
            axis=3,
            name="batch_normalization_" + str(1 + (number - 1) * 4),
            trainable=resnet_trainable,
        )(conv_1)
        relu_1 = Activation("relu")(bn_1)
        conv_2 = Conv2D(
            filters,
            kernel_size=3,
            strides=1,
            padding="same",
            name="conv2d_" + str(3 + (number - 1) * 5),
            use_bias=False,
            trainable=resnet_trainable,
            kernel_regularizer=l2(kernel_regularizer),
        )(relu_1)
        bn_2 = BatchNormalization(
            axis=3,
            name="batch_normalization_" + str(2 + (number - 1) * 4),
            trainable=resnet_trainable,
        )(conv_2)

        if downsample_first or first:
            # Projection shortcut: 1x1 conv matches the channel count (and
            # stride, when downsampling) so the Add shapes agree.
            shortcut_1 = Conv2D(
                filters,
                kernel_size=1,
                strides=2 if downsample_first else 1,
                padding="same",
                name="conv2d_" + str(1 + (number - 1) * 5),
                use_bias=False,
                trainable=resnet_trainable,
                kernel_regularizer=l2(kernel_regularizer),
            )(inp)
            joint = Add()([shortcut_1, bn_2])
        else:
            joint = Add()([inp, bn_2])
        block_1 = Activation("relu")(joint)

        # --- residual block 2 (identity shortcut) ---
        conv_3 = Conv2D(
            filters,
            kernel_size=3,
            strides=1,
            padding="same",
            name="conv2d_" + str(4 + (number - 1) * 5),
            use_bias=False,
            trainable=resnet_trainable,
            kernel_regularizer=l2(kernel_regularizer),
        )(block_1)
        bn_3 = BatchNormalization(
            axis=3,
            name="batch_normalization_" + str(3 + (number - 1) * 4),
            trainable=resnet_trainable,
        )(conv_3)
        relu_3 = Activation("relu")(bn_3)
        conv_4 = Conv2D(
            filters,
            kernel_size=3,
            strides=1,
            padding="same",
            name="conv2d_" + str(5 + (number - 1) * 5),
            use_bias=False,
            trainable=resnet_trainable,
            kernel_regularizer=l2(kernel_regularizer),
        )(relu_3)
        bn_4 = BatchNormalization(
            axis=3,
            name="batch_normalization_" + str(4 + (number - 1) * 4),
            trainable=resnet_trainable,
        )(conv_4)
        joint_2 = Add()([block_1, bn_4])
        out = Activation("relu")(joint_2)
        return out

    def pyramid_pooling_block(self, input_tensor, number=0, kernel_regularizer=0):
        """Pyramid pooling module over `input_tensor`.

        Pools the input at four scales, projects each branch with a 1x1 conv
        (512 filters for the coarse bins, 256 for the fine ones), resizes all
        branches back to the input's spatial size, concatenates them, and
        applies a final 1x1 conv + ReLU.

        Args:
            input_tensor: encoder feature map.
            number: base index used to build unique "conv2d_N" layer names.
            kernel_regularizer: L2 factor for the 1x1 convs.

        Returns:
            Fused pyramid feature map (256 channels).
        """
        concat_list = []

        w = input_tensor.shape[1]
        h = input_tensor.shape[2]
        # Dynamic-shape fallback: assumes 45x45 deep features -- TODO confirm
        # this matches the static size for the default 512x1024 input.
        if w is None:
            w = 45
        if h is None:
            h = 45

        # Four pyramid scales; coarser bins get wider 1x1 projections.
        for k, (bin_size, branch_filters) in enumerate(
            zip([6, 12, 18, 24], [512, 512, 256, 256])
        ):
            x = AveragePooling2D(
                pool_size=(w // bin_size, h // bin_size),
                strides=(w // bin_size, h // bin_size),
            )(input_tensor)
            x = Conv2D(
                branch_filters,
                kernel_size=1,
                strides=1,
                padding="same",
                name="conv2d_" + str(number + k),
                kernel_regularizer=l2(kernel_regularizer),
            )(x)
            # Resize every pooled branch back to the input's spatial size so
            # the branches can be concatenated channel-wise.
            x = Lambda(lambda t: tf.image.resize(t, (w, h)))(x)
            concat_list.append(x)

        ppm = Concatenate()(concat_list)
        conv = Conv2D(
            256,
            kernel_size=1,
            name="conv2d_" + str(number + 4),
            kernel_regularizer=l2(kernel_regularizer),
        )(ppm)
        out = Activation("relu")(conv)

        return out

    def decoder_block(self, inp, filters, number=0, kernel_regularizer=0):
        """Upsampling decoder block with a learned shortcut.

        Main path: 1x1 conv -> transposed-conv 2x upsample -> BN ->
        1x1 conv (filters // 2) -> BN. Shortcut path: transposed-conv 2x
        upsample of the raw input -> BN. Both paths are added and passed
        through ReLU, giving filters // 2 channels at twice the resolution.

        Args:
            inp: input feature map.
            filters: channel width of the main path (output has filters // 2).
            number: base index used to build unique layer names.
            kernel_regularizer: L2 factor for the 1x1 convs.
        """
        conv_1 = Conv2D(
            filters,
            kernel_size=1,
            name="conv2d_" + str(number),
            kernel_regularizer=l2(kernel_regularizer),
        )(inp)
        deconv = Conv2DTranspose(filters, kernel_size=3, strides=2, padding="same")(
            conv_1
        )
        bn_1 = BatchNormalization(axis=3, name="batch_normalization_" + str(number))(
            deconv
        )
        conv_2 = Conv2D(
            filters // 2,
            kernel_size=1,
            name="conv2d_" + str(number + 1),
            kernel_regularizer=l2(kernel_regularizer),
        )(bn_1)
        bn_2 = BatchNormalization(
            axis=3, name="batch_normalization_" + str(number + 1)
        )(conv_2)

        # Shortcut: upsample the raw input to the same shape as the main path.
        inp_deconv = Conv2DTranspose(
            filters // 2, kernel_size=3, strides=2, padding="same"
        )(inp)
        inp_bn = BatchNormalization(
            axis=3, name="batch_normalization_" + str(number + 2)
        )(inp_deconv)

        joint = Add()([inp_bn, bn_2])
        out = Activation("relu")(joint)
        return out

    def thundernet(
        self, input_shape=(512, 1024, 3), resnet_trainable=False, kernel_regularizer=0
    ):
        """Assemble the full network graph.

        Args:
            input_shape: (height, width, channels) of the input image.
            resnet_trainable: whether backbone layers are trainable.
            kernel_regularizer: L2 factor passed through to all sub-blocks.

        Returns:
            A compiled-free keras Model mapping images to per-pixel class
            probabilities (softmax over self.n_classes channels).
        """
        inputs = Input(shape=input_shape)

        # Stem: strided 3x3 conv + 3x3 maxpool -> 1/4 resolution.
        conv_1 = Conv2D(
            64,
            kernel_size=3,
            strides=2,
            padding="same",
            name="conv2d",
            use_bias=False,
            trainable=resnet_trainable,
            kernel_regularizer=l2(kernel_regularizer),
        )(inputs)
        bn_1 = BatchNormalization(
            axis=3, name="batch_normalization", trainable=resnet_trainable
        )(conv_1)
        relu_1 = Activation("relu")(bn_1)
        maxp_1 = MaxPooling2D(pool_size=(3, 3), strides=2, padding="same")(relu_1)

        # Encoder stages: 64 -> 128 -> 256 filters; stages 2-3 downsample.
        res1 = self.resnet_layer(
            maxp_1,
            downsample_first=False,
            filters=64,
            first=True,
            number=1,
            resnet_trainable=resnet_trainable,
            kernel_regularizer=kernel_regularizer,
        )
        res2 = self.resnet_layer(
            res1,
            downsample_first=True,
            filters=128,
            first=False,
            number=2,
            resnet_trainable=resnet_trainable,
            kernel_regularizer=kernel_regularizer,
        )
        res3 = self.resnet_layer(
            res2,
            downsample_first=True,
            filters=256,
            first=False,
            number=3,
            resnet_trainable=resnet_trainable,
            kernel_regularizer=kernel_regularizer,
        )

        # Pyramid pooling over the deepest features, fused with a skip.
        ppm = self.pyramid_pooling_block(
            res3, number=16, kernel_regularizer=kernel_regularizer
        )
        ppm = Concatenate()([ppm, res3])

        # Decoder: two 2x upsampling stages, each fused with an encoder skip.
        dec_1 = self.decoder_block(
            ppm, 256, number=21, kernel_regularizer=kernel_regularizer
        )
        dec_1 = Concatenate()([dec_1, res2])
        dec_2 = self.decoder_block(
            dec_1, 128, number=24, kernel_regularizer=kernel_regularizer
        )
        dec_2 = Concatenate()([dec_2, res1])

        # Final 4x bilinear upsample back to input resolution, then a
        # per-pixel softmax classifier.
        ups = UpSampling2D(size=(4, 4), interpolation="bilinear")(dec_2)
        out = Conv2D(
            filters=int(self.n_classes),
            kernel_size=1,
            activation="softmax",
            name="conv2d_out",
        )(ups)

        model = Model(inputs=inputs, outputs=out)
        return model

    def load_resnet_weights(self):
        """Copy pretrained ResNet18 weights from a TF checkpoint.

        Conv kernels and batch-norm statistics are matched to model layers
        purely by name ("conv2d_N" / "batch_normalization_N").
        """
        print("Loading weights for resnet18 backbone")
        checkpoint_path = "./resnet/resnet18/checkpoints/model/model.ckpt-5865"
        reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)

        # The backbone has 16 conv layers ("conv2d", "conv2d_1", ...,
        # "conv2d_15") but only 13 batch-norm layers (up to "_12"): the
        # 1x1 projection-shortcut convs have no BN partner.
        for k in range(16):
            layer_name = "conv2d" if k == 0 else "conv2d_" + str(k)
            weights = reader.get_tensor(layer_name + "/kernel")
            # Backbone convs are built with use_bias=False, so the kernel
            # is the layer's only weight tensor.
            self.model.get_layer(layer_name).set_weights([weights])

            if k < 13:
                layer_name = (
                    "batch_normalization"
                    if k == 0
                    else "batch_normalization_" + str(k)
                )
                gamma = reader.get_tensor(layer_name + "/gamma")
                beta = reader.get_tensor(layer_name + "/beta")
                mean = reader.get_tensor(layer_name + "/moving_mean")
                var = reader.get_tensor(layer_name + "/moving_variance")
                # Keras BN weight order: [gamma, beta, moving_mean, moving_var].
                self.model.get_layer(layer_name).set_weights([gamma, beta, mean, var])
        print("Weights for resnet18 backbone loaded!")
|
|