| | """ |
| | Class to construct the different type of models |
| | """ |
| |
|
| | |
| | import tensorflow as tf |
| | from tensorflow import keras |
| | from tensorflow.keras import layers, Sequential |
| | from tensorflow.keras.layers import Dense, Input, Rescaling |
| | from tensorflow.keras.applications import MobileNet, ResNet50 |
| |
|
| | |
| | from keras.saving import register_keras_serializable |
| |
|
| | |
| | from defs import ModelType as mt |
| |
|
| |
|
| | class ModelBuilder: |
| | |
| |
|
| | def __init__(self, model_type, **model_params): |
| |
|
| | self.model_type = model_type |
| | self.model_params = model_params |
| | self.model = None |
| | self.model_name = None |
| |
|
| | |
| | if self.model_type in (mt.MOBILENET, mt.RESNET50): |
| | self.base_model_params = self.model_params.pop("base_model") |
| | self.model_name = self.base_model_params["name"] |
| | self.input_shape = self.base_model_params["input_shape"] |
| | self.base_trainable = self.model_params.pop("base_trainable") |
| | self.base_model = None |
| |
|
| | elif self.model_type == mt.CAPSNET: |
| | self.model_name = model_params.pop("name") |
| | self.input_shape = model_params.pop("input_shape") |
| | self.prim_caps_params = model_params.pop("prim_caps") |
| | self.digit_caps_params = model_params.pop("digit_caps") |
| | self.routing_algo = model_params.pop("routing_algo") |
| |
|
| | |
| | if self.model_type in ( |
| | mt.MOBILENET, |
| | mt.RESNET50, |
| | ): |
| | if self.input_shape != (224, 224, 3): |
| | raise Exception( |
| | f"input shape for {self.model_name} model must be (224,224,3)" |
| | ) |
| | elif self.model_type == mt.CAPSNET: |
| | if self.input_shape != (256, 256, 3): |
| | raise Exception( |
| | f"input shape for {self.model_name} model must be (256,256,3)" |
| | ) |
| | else: |
| | raise Exception( |
| | f"Model not supported: {self.model_name}. The model name must contain one substring from {mt.MOBILENET, mt.RESNET50, mt.CAPSNET}" |
| | ) |
| |
|
| | def get_augmentation_pipe(self): |
| | |
| | |
| | return Sequential( |
| | [ |
| | layers.RandomRotation(0.1), |
| | layers.RandomTranslation(height_factor=0.1, width_factor=0.1), |
| | layers.RandomZoom(0.1), |
| | ], |
| | name="augmentation", |
| | ) |
| |
|
| | def get_compiled_model(self): |
| | |
| | compile_params = self.model_params.pop("compile_params") |
| |
|
| | |
| | inputs = Input(shape=self.input_shape, name="inputs") |
| |
|
| | |
| | x_aug = self.get_augmentation_pipe()(inputs) |
| | |
| |
|
| | |
| | x = Rescaling(1.0 / 255)(x_aug) |
| |
|
| | |
| | match self.model_type: |
| | case mt.RESNET50: |
| | self.base_model = ResNet50(input_tensor=x_aug, **self.base_model_params) |
| | self.base_model.trainable = self.base_trainable |
| |
|
| | case mt.MOBILENET: |
| | self.base_model = MobileNet( |
| | input_tensor=x_aug, **self.base_model_params |
| | ) |
| | self.base_model.trainable = self.base_trainable |
| |
|
| | case mt.CAPSNET: |
| | self.base_model = None |
| | x = Rescaling(1.0 / 255)(x) |
| | outputs = self.build_capsnet(inputs=x_aug, **self.model_params) |
| |
|
| | case _: |
| | raise Exception( |
| | f"Model type {self.model_type} not supported: {self.model_name}" |
| | ) |
| |
|
| | |
| | if self.model_type in (mt.RESNET50, mt.MOBILENET): |
| | x = self.base_model.output |
| | outputs = Dense(4, activation="softmax")(x) |
| | elif self.model_type == mt.CAPSNET: |
| | pass |
| | else: |
| | raise Exception(f"No classifier head defined for {self.model_type}") |
| |
|
| | |
| | self.model = keras.Model(name=self.model_name, inputs=inputs, outputs=outputs) |
| | self.model.compile(**compile_params) |
| |
|
| | print(f"The {self.model_name} model has been compiled successfully") |
| |
|
| | return self.base_model, self.model |
| |
|
| | def build_capsnet(self, inputs, **params): |
| | """ |
| | Build a Capsule Network model for four class lung iseases classification: COVID, Normal, Pneumonia and Opacity. |
| | The batch dimension is always None internally → full input shape is (None, 256, 256, 1). |
| | The output shape is (None, 4, 1) |
| | Args: |
| | name (_type_): _description_ |
| | first_Conv2DKernel_size (int, optional): _description_. Defaults to 10. |
| | input_shape (tuple, optional): _description_. Defaults to (256, 256, 3). |
| | n_class (int, optional): _description_. Defaults to 4. |
| | routing_iters (int, optional): _description_. Defaults to 3. |
| | routing_algo (str, optional): _description_. Defaults to "by_agreement". |
| | |
| | Returns: |
| | model: to be compiled |
| | """ |
| |
|
| | first_Conv2DKernel_size = params.pop("first_Conv2DKernel_size") |
| |
|
| | |
| | x = inputs |
| |
|
| | |
| | |
| | x = layers.Conv2D( |
| | filters=64, |
| | kernel_size=first_Conv2DKernel_size, |
| | strides=2, |
| | padding="valid", |
| | activation="relu", |
| | )( |
| | x |
| | ) |
| | x = layers.BatchNormalization()(x) |
| |
|
| | x = layers.Conv2D(128, 5, strides=2, padding="same", activation="relu")( |
| | x |
| | ) |
| | x = layers.BatchNormalization()(x) |
| | x = layers.Dropout(0.25)(x) |
| |
|
| | x = layers.Conv2D(128, 3, strides=1, padding="same", activation="relu")(x) |
| | x = layers.BatchNormalization()(x) |
| |
|
| | x = layers.Conv2D(256, 3, strides=1, padding="same", activation="relu")(x) |
| | x = layers.BatchNormalization()(x) |
| | x = layers.Dropout(0.3)(x) |
| |
|
| | x = layers.Conv2D(512, 3, strides=1, padding="same", activation="relu")( |
| | x |
| | ) |
| | x = layers.BatchNormalization()(x) |
| |
|
| | x = layers.Dropout(0.3)( |
| | x |
| | ) |
| |
|
| | |
| | primary_caps = PrimaryCaps(**self.prim_caps_params)( |
| | x |
| | ) |
| | |
| | |
| | |
| | |
| | |
| |
|
| | digit_caps = DigitCaps(**self.digit_caps_params)( |
| | primary_caps |
| | ) |
| | |
| | |
| | |
| |
|
| | outputs = Length()(digit_caps) |
| |
|
| | return outputs |
| |
|
| |
|
| | |
def squash(vectors, axis=-1):
    """Capsule squashing nonlinearity.

    Shrinks each vector along ``axis`` so its length lies in [0, 1) while
    preserving its direction; a small epsilon guards the division for
    zero-length vectors.
    """
    squared_norm = tf.reduce_sum(tf.square(vectors), axis=axis, keepdims=True)
    safe_norm = tf.sqrt(squared_norm + tf.keras.backend.epsilon())
    unit_scale = squared_norm / (1 + squared_norm)
    return (unit_scale / safe_norm) * vectors
| |
|
| |
|
| | |
@register_keras_serializable()
class PrimaryCaps(layers.Layer):
    """Primary capsule layer.

    Runs one Conv2D producing ``dim_capsule * n_channels`` feature maps,
    reshapes the result into capsule vectors of length ``dim_capsule``, and
    applies the squash nonlinearity.
    """

    def __init__(
        self, dim_capsule, n_channels, kernel_size, strides, padding, **kwargs
    ):
        super().__init__(**kwargs)
        # Keep every constructor argument so get_config() can serialize them.
        self.dim_capsule = dim_capsule
        self.n_channels = n_channels
        self.kernel_size = kernel_size
        self.strides = strides
        self.padding = padding
        # A single convolution emits all capsule components at once.
        self.conv = layers.Conv2D(
            filters=dim_capsule * n_channels,
            kernel_size=kernel_size,
            strides=strides,
            padding=padding,
            activation="relu",
        )

    def build(self, input_shape):
        # Build the wrapped convolution eagerly so its weights exist before
        # the first call.
        self.conv.build(input_shape)
        super().build(input_shape)

    def call(self, inputs):
        feature_maps = self.conv(inputs)
        rows, cols = feature_maps.shape[1], feature_maps.shape[2]
        n_capsules = rows * cols * self.n_channels
        capsules = tf.reshape(feature_maps, (-1, n_capsules, self.dim_capsule))
        return squash(capsules)

    def get_config(self):
        """Serialize the layer's constructor arguments."""
        base_config = super().get_config()
        return {
            **base_config,
            "dim_capsule": self.dim_capsule,
            "n_channels": self.n_channels,
            "kernel_size": self.kernel_size,
            "strides": self.strides,
            "padding": self.padding,
        }
| |
|
| |
|
@register_keras_serializable()
class DigitCaps(layers.Layer):
    """Class (digit) capsule layer with dynamic routing-by-agreement.

    Maps (batch, input_num_capsule, input_dim_capsule) primary-capsule
    vectors to (batch, num_capsule, 1, dim_capsule) output capsules, one
    per class.
    """

    def __init__(self, num_capsule, dim_capsule, routing_iters=3, **kwargs):
        # num_capsule: number of output capsules (one per class).
        # dim_capsule: dimensionality of each output capsule vector.
        # routing_iters: number of routing-by-agreement iterations.
        super(DigitCaps, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routing_iters = routing_iters

    def build(self, input_shape):
        # input_shape: (batch, input_num_capsule, input_dim_capsule).
        self.input_num_capsule = input_shape[1]
        self.input_dim_capsule = input_shape[2]
        # W holds one (input_dim_capsule x dim_capsule) transformation matrix
        # per (input capsule, output capsule) pair.
        self.W = self.add_weight(
            shape=[
                self.input_num_capsule,
                self.num_capsule,
                self.input_dim_capsule,
                self.dim_capsule,
            ],
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs):
        # Expand to (batch, in_caps, 1, 1, in_dim) and tile across output
        # capsules so each input capsule is transformed once per class.
        inputs_expand = tf.expand_dims(inputs, 2)
        inputs_tiled = tf.expand_dims(inputs_expand, 3)
        inputs_tiled = tf.tile(inputs_tiled, [1, 1, self.num_capsule, 1, 1])
        # Prediction vectors u_hat: (batch, in_caps, out_caps, 1, out_dim).
        inputs_hat = tf.matmul(inputs_tiled, self.W)

        # Routing logits start at zero, i.e. uniform coupling after softmax.
        b = tf.zeros(
            shape=[tf.shape(inputs)[0], self.input_num_capsule, self.num_capsule, 1, 1]
        )

        # Dynamic routing-by-agreement.
        for i in range(self.routing_iters):
            c = tf.nn.softmax(
                b, axis=2
            )  # coupling coefficients over output capsules
            s = tf.reduce_sum(
                c * inputs_hat, axis=1, keepdims=True
            )  # weighted sum over input capsules: (batch, 1, out_caps, 1, out_dim)
            v = squash(
                s, axis=-2
            )
            # NOTE(review): axis=-2 squashes along a size-1 axis, so this is an
            # element-wise nonlinearity rather than a vector squash over the
            # capsule components (axis=-1) -- confirm the intended axis.
            if i < self.routing_iters - 1:
                # Raise the logits of predictions that agree with the output.
                b += tf.reduce_sum(inputs_hat * v, axis=-1, keepdims=True)

        # Drop the kept input-capsule axis: (batch, out_caps, 1, out_dim).
        return tf.squeeze(v, axis=1)

    def get_config(self):
        """Serialize the layer's constructor arguments."""
        config = super().get_config()
        config.update(
            {
                "num_capsule": self.num_capsule,
                "dim_capsule": self.dim_capsule,
                "routing_iters": self.routing_iters,
            }
        )
        return config
| |
|
| |
|
| | |
@register_keras_serializable()
class Length(layers.Layer):
    """Replaces each capsule vector with its Euclidean length (last axis)."""

    def call(self, inputs, **kwargs):
        squared_sum = tf.reduce_sum(tf.square(inputs), -1)
        return tf.sqrt(squared_sum)
| |
|
| |
|
| | |
def margin_loss(y_true, y_pred):
    """Capsule-network margin loss.

    Penalizes present classes whose capsule length falls below ``m_plus``
    and absent classes whose length exceeds ``m_minus``; the absent-class
    term is down-weighted by ``lambda_val``. Per-class losses are summed,
    then averaged over the batch.
    """
    m_plus = 0.9
    m_minus = 0.1
    lambda_val = 0.5
    present_term = y_true * tf.square(tf.maximum(0.0, m_plus - y_pred))
    absent_term = (
        lambda_val * (1 - y_true) * tf.square(tf.maximum(0.0, y_pred - m_minus))
    )
    per_class = present_term + absent_term
    return tf.reduce_mean(tf.reduce_sum(per_class, axis=1))
| |
|
| |
|
| | capsnet_custom_objects = { |
| | "PrimaryCaps": PrimaryCaps, |
| | "DigitCaps": DigitCaps, |
| | "Length": Length, |
| | "margin_loss": margin_loss, |
| | } |
| |
|