# Thundernet / model / model_ppm_factors.py
from tensorflow.keras.layers import (
Input,
Lambda,
Concatenate,
Conv2D,
Conv2DTranspose,
MaxPooling2D,
BatchNormalization,
Activation,
Add,
AveragePooling2D,
UpSampling2D,
SeparableConv2D,
SpatialDropout2D,
)
from tensorflow.keras.models import Model
from keras import callbacks
import keras.optimizers
from tensorflow.keras.regularizers import l2
import tensorflow as tf
from tensorflow.python import pywrap_tensorflow
class Thundernet:
    """ThunderNet-style semantic-segmentation network.

    ResNet-18 backbone + pyramid pooling module (PPM) + lightweight
    transposed-convolution decoder with encoder skip connections.  The
    Keras model is assembled in ``__init__`` and pre-trained backbone
    weights are immediately loaded from a local TF checkpoint.
    """

    def __init__(
        self,
        input_shape=(512, 1024, 3),
        resnet_trainable=False,
        kernel_regularizer=0,
        n_classes=2,
        add_2x1up_layer=False,
        add_2up_layer=False,
        resize_first=False,
    ):
        """Build the model and load the ResNet-18 backbone weights.

        Args:
            input_shape: (height, width, channels) of the input image.
            resnet_trainable: if True, backbone conv/BN layers are trainable.
            kernel_regularizer: L2 weight-decay factor for conv kernels
                (0 disables regularization).
            n_classes: number of classes in the per-pixel softmax output.
            add_2x1up_layer: append an extra (1, 2) bilinear upsampling stage.
            add_2up_layer: perform the final 4x upsampling as two 2x stages.
            resize_first: halve the input spatially before the backbone.
        """
        self.input_shape = input_shape
        self.resnet_trainable = resnet_trainable
        self.n_classes = n_classes
        self.model = self.thundernet(
            input_shape,
            resnet_trainable,
            kernel_regularizer,
            add_2x1up_layer,
            add_2up_layer,
            resize_first,
        )
        self.load_resnet_weights()
        self.add_2x1up_layer = add_2x1up_layer
        self.add_2up_layer = add_2up_layer
        self.resize_first = resize_first

    def resnet_layer(
        self,
        inp,
        downsample_first=True,
        filters=64,
        first=False,
        number=0,
        resnet_trainable=False,
        kernel_regularizer=0,
    ):
        """One ResNet-18 stage: two basic blocks (four 3x3 convs).

        Layer names ("conv2d_N" / "batch_normalization_N") are derived from
        ``number`` so they line up with the tensor names in the checkpoint
        read by ``load_resnet_weights`` -- do not rename them.

        Args:
            inp: input feature tensor.
            downsample_first: stride-2 first conv plus a stride-2 1x1
                projection shortcut (standard downsampling stage).
            filters: filter count for every conv in the stage.
            first: when not downsampling, still use a 1x1 projection
                shortcut (used for the first stage after the stem).
            number: 1-based stage index used to generate layer names.
            resnet_trainable: trainable flag for conv/BN layers.
            kernel_regularizer: L2 factor for conv kernels.

        Returns:
            Output tensor of the stage.
        """
        # ---- first basic block ------------------------------------------
        conv_1 = Conv2D(
            filters,
            kernel_size=3,
            strides=2 if downsample_first else 1,
            padding="same",
            name="conv2d_" + str(2 + (number - 1) * 5),
            use_bias=False,
            trainable=resnet_trainable,
            kernel_regularizer=l2(kernel_regularizer),
        )(inp)
        bn_1 = BatchNormalization(
            axis=3,
            name="batch_normalization_" + str(1 + (number - 1) * 4),
            trainable=resnet_trainable,
        )(conv_1)
        relu_1 = Activation("relu")(bn_1)
        conv_2 = Conv2D(
            filters,
            kernel_size=3,
            strides=1,
            padding="same",
            name="conv2d_" + str(3 + (number - 1) * 5),
            use_bias=False,
            trainable=resnet_trainable,
            kernel_regularizer=l2(kernel_regularizer),
        )(relu_1)
        bn_2 = BatchNormalization(
            axis=3,
            name="batch_normalization_" + str(2 + (number - 1) * 4),
            trainable=resnet_trainable,
        )(conv_2)
        if downsample_first or first:
            # Projection shortcut (1x1 conv, no BN -- matching the
            # checkpoint); stride 2 when downsampling so the spatial size
            # matches the main branch.
            shortcut_1 = Conv2D(
                filters,
                kernel_size=1,
                strides=2 if downsample_first else 1,
                padding="same",
                name="conv2d_" + str(1 + (number - 1) * 5),
                use_bias=False,
                trainable=resnet_trainable,
                kernel_regularizer=l2(kernel_regularizer),
            )(inp)
            joint = Add()([shortcut_1, bn_2])
        else:
            # Identity shortcut.
            joint = Add()([inp, bn_2])
        block_1 = Activation("relu")(joint)
        # ---- second basic block (identity shortcut) ---------------------
        conv_3 = Conv2D(
            filters,
            kernel_size=3,
            strides=1,
            padding="same",
            name="conv2d_" + str(4 + (number - 1) * 5),
            use_bias=False,
            trainable=resnet_trainable,
            kernel_regularizer=l2(kernel_regularizer),
        )(block_1)
        bn_3 = BatchNormalization(
            axis=3,
            name="batch_normalization_" + str(3 + (number - 1) * 4),
            trainable=resnet_trainable,
        )(conv_3)
        relu_3 = Activation("relu")(bn_3)
        conv_4 = Conv2D(
            filters,
            kernel_size=3,
            strides=1,
            padding="same",
            name="conv2d_" + str(5 + (number - 1) * 5),
            use_bias=False,
            trainable=resnet_trainable,
            kernel_regularizer=l2(kernel_regularizer),
        )(relu_3)
        bn_4 = BatchNormalization(
            axis=3,
            name="batch_normalization_" + str(4 + (number - 1) * 4),
            trainable=resnet_trainable,
        )(conv_4)
        joint_2 = Add()([block_1, bn_4])
        return Activation("relu")(joint_2)

    def pyramid_pooling_block(self, input_tensor, number=0, kernel_regularizer=0):
        """Pyramid pooling module (PSPNet-style) with extra fine bins.

        Average-pools the input at bin counts 1/3/6 (512-filter branches)
        and 12/18/24 (256-filter branches), resizes every branch back to
        the input resolution, concatenates them and fuses with a 1x1 conv.

        Args:
            input_tensor: encoder feature map, shape (batch, w, h, c).
            number: starting index for the generated "conv2d_N" names.
            kernel_regularizer: L2 factor for conv kernels.

        Returns:
            Fused feature tensor with 256 channels.
        """
        w = input_tensor.shape[1]
        h = input_tensor.shape[2]
        # With a dynamic input the spatial dims are unknown at graph-build
        # time; fall back to a fixed 45x45 grid.
        # NOTE(review): 45 looks tied to one specific input size -- confirm
        # against the intended deployment resolution.
        if w is None:
            w = 45
        if h is None:
            h = 45
        concat_list = []
        k = 0  # running offset so conv names stay number, number+1, ...
        for branch_filters, bin_sizes in ((512, (1, 3, 6)), (256, (12, 18, 24))):
            for bin_size in bin_sizes:
                x = AveragePooling2D(
                    pool_size=(w // bin_size, h // bin_size),
                    strides=(w // bin_size, h // bin_size),
                )(input_tensor)
                x = Conv2D(
                    branch_filters,
                    kernel_size=1,
                    strides=1,
                    padding="same",
                    name="conv2d_" + str(number + k),
                    kernel_regularizer=l2(kernel_regularizer),
                )(x)
                # Upsample the pooled branch back to the input resolution.
                x = Lambda(lambda x: tf.image.resize(x, (w, h)))(x)
                concat_list.append(x)
                k += 1
        ppm = Concatenate()(concat_list)
        conv = Conv2D(
            256,
            kernel_size=1,
            name="conv2d_" + str(number + k),
            kernel_regularizer=l2(kernel_regularizer),
        )(ppm)
        return Activation("relu")(conv)

    def decoder_block(self, inp, filters, number=0, kernel_regularizer=0):
        """Residual 2x-upsampling decoder stage.

        Main branch: 1x1 conv -> 3x3/2 transposed conv -> BN -> 1x1 conv
        (``filters // 2``) -> BN.  Shortcut branch: 3x3/2 transposed conv
        (``filters // 2``) -> BN.  The branches are added and ReLU'd.

        Args:
            inp: input tensor.
            filters: intermediate width; the output has ``filters // 2``
                channels.
            number: starting index for the generated layer names.
            kernel_regularizer: L2 factor for conv kernels.

        Returns:
            2x-upsampled tensor with ``filters // 2`` channels.
        """
        conv_1 = Conv2D(
            filters,
            kernel_size=1,
            name="conv2d_" + str(number),
            kernel_regularizer=l2(kernel_regularizer),
        )(inp)
        deconv = Conv2DTranspose(filters, kernel_size=3, strides=2, padding="same")(
            conv_1
        )
        bn_1 = BatchNormalization(axis=3, name="batch_normalization_" + str(number))(
            deconv
        )
        conv_2 = Conv2D(
            filters // 2,
            kernel_size=1,
            name="conv2d_" + str(number + 1),
            kernel_regularizer=l2(kernel_regularizer),
        )(bn_1)
        bn_2 = BatchNormalization(
            axis=3, name="batch_normalization_" + str(number + 1)
        )(conv_2)
        # Shortcut: upsample the input directly so the Add() shapes match.
        inp_deconv = Conv2DTranspose(
            filters // 2, kernel_size=3, strides=2, padding="same"
        )(inp)
        inp_bn = BatchNormalization(
            axis=3, name="batch_normalization_" + str(number + 2)
        )(inp_deconv)
        joint = Add()([inp_bn, bn_2])
        return Activation("relu")(joint)

    def thundernet(
        self,
        input_shape=(512, 1024, 3),
        resnet_trainable=False,
        kernel_regularizer=0,
        add_2x1up_layer=False,
        add_2up_layer=False,
        resize_first=False,
    ):
        """Assemble the full encoder / PPM / decoder network.

        Args:
            input_shape: (height, width, channels) of the input image.
            resnet_trainable: trainable flag forwarded to backbone layers.
            kernel_regularizer: L2 factor for conv kernels.
            add_2x1up_layer: append an extra (1, 2) bilinear upsampling.
            add_2up_layer: split the final 4x upsampling into two 2x stages.
            resize_first: halve the input spatially before the backbone.

        Returns:
            ``tf.keras`` ``Model`` mapping an image to a per-pixel softmax
            over ``self.n_classes`` classes.
        """
        inputs = Input(shape=input_shape)
        if resize_first:
            # BUGFIX: the original used inputs.shape[0] and inputs.shape[1],
            # i.e. the batch dimension (usually None -- `None // 2` raises a
            # TypeError) and the height.  The target size is computed from
            # the static (H, W, C) input_shape instead.
            target_size = (input_shape[0] // 2, input_shape[1] // 2)
            aux = Lambda(lambda x: tf.image.resize(x, target_size))(inputs)
        else:
            aux = inputs
        # Stem: 3x3/2 conv + BN + ReLU + 3x3/2 max-pool (4x downsampling).
        conv_1 = Conv2D(
            64,
            kernel_size=3,
            strides=2,
            padding="same",
            name="conv2d",
            use_bias=False,
            trainable=resnet_trainable,
            kernel_regularizer=l2(kernel_regularizer),
        )(aux)
        bn_1 = BatchNormalization(
            axis=3, name="batch_normalization", trainable=resnet_trainable
        )(conv_1)
        relu_1 = Activation("relu")(bn_1)
        maxp_1 = MaxPooling2D(pool_size=(3, 3), strides=2, padding="same")(relu_1)
        # ResNet-18 stages; `number` drives the checkpoint layer numbering.
        res1 = self.resnet_layer(
            maxp_1,
            downsample_first=False,
            filters=64,
            first=True,
            number=1,
            resnet_trainable=resnet_trainable,
            kernel_regularizer=kernel_regularizer,
        )
        res2 = self.resnet_layer(
            res1,
            downsample_first=True,
            filters=128,
            first=False,
            number=2,
            resnet_trainable=resnet_trainable,
            kernel_regularizer=kernel_regularizer,
        )
        res3 = self.resnet_layer(
            res2,
            downsample_first=True,
            filters=256,
            first=False,
            number=3,
            resnet_trainable=resnet_trainable,
            kernel_regularizer=kernel_regularizer,
        )
        # Pyramid pooling on the deepest features, concatenated with a skip.
        ppm = self.pyramid_pooling_block(
            res3, number=16, kernel_regularizer=kernel_regularizer
        )
        ppm = Concatenate()([ppm, res3])
        # Decoder with skip connections from the matching encoder stages.
        dec_1 = self.decoder_block(
            ppm, 256, number=30, kernel_regularizer=kernel_regularizer
        )
        dec_1 = Concatenate()([dec_1, res2])
        dec_2 = self.decoder_block(
            dec_1, 128, number=33, kernel_regularizer=kernel_regularizer
        )
        dec_2 = Concatenate()([dec_2, res1])
        # Final bilinear upsampling back toward the input resolution.
        if add_2x1up_layer:
            if add_2up_layer:
                dec_3 = UpSampling2D(size=(2, 2), interpolation="bilinear")(dec_2)
                ups = UpSampling2D(size=(2, 2), interpolation="bilinear")(dec_3)
            else:
                ups = UpSampling2D(size=(4, 4), interpolation="bilinear")(dec_2)
            print("adding the new upsampling")
            ups_2 = UpSampling2D(size=(1, 2), interpolation="bilinear")(ups)
        else:
            if add_2up_layer:
                dec_3 = UpSampling2D(size=(2, 2), interpolation="bilinear")(dec_2)
                ups_2 = UpSampling2D(size=(2, 2), interpolation="bilinear")(dec_3)
            else:
                ups_2 = UpSampling2D(size=(4, 4), interpolation="bilinear")(dec_2)
        out = Conv2D(
            filters=int(self.n_classes),
            kernel_size=1,
            activation="softmax",
            name="conv2d_out",
        )(ups_2)
        return Model(inputs=inputs, outputs=out)

    def load_resnet_weights(self):
        """Copy ResNet-18 backbone weights from a TF checkpoint.

        Reads the checkpoint at a hard-coded relative path and assigns the
        kernels of "conv2d".."conv2d_15" and the gamma/beta/moving stats of
        "batch_normalization".."batch_normalization_12" to the Keras layers
        of the same name (the naming scheme in ``resnet_layer`` guarantees
        the match).

        Raises:
            Whatever ``NewCheckpointReader`` raises for a missing
            checkpoint, and ``ValueError`` if a layer name is not found.
        """
        print("Loading weights for resnet18 backbone")
        checkpoint_path = "./resnet/resnet18/checkpoints/model/model.ckpt-5865"
        reader = tf.compat.v1.train.NewCheckpointReader(checkpoint_path)  # for tf 2.0
        for k in range(16):
            suffix = "" if k == 0 else "_" + str(k)
            conv_name = "conv2d" + suffix
            kernel = reader.get_tensor(conv_name + "/kernel")
            self.model.get_layer(conv_name).set_weights([kernel])
            # Only the first 13 BN layers belong to the backbone; the convs
            # after that (PPM/decoder) have no pre-trained BN stats.
            if k < 13:
                bn_name = "batch_normalization" + suffix
                gamma = reader.get_tensor(bn_name + "/gamma")
                beta = reader.get_tensor(bn_name + "/beta")
                mean = reader.get_tensor(bn_name + "/moving_mean")
                var = reader.get_tensor(bn_name + "/moving_variance")
                self.model.get_layer(bn_name).set_weights([gamma, beta, mean, var])
        print("Weights for resnet18 backbone loaded!")