valste committed on
Commit
0d908d7
·
1 Parent(s): 71a6ddb

fixing preprocess imports

Browse files
Files changed (2) hide show
  1. defs.py +213 -0
  2. modelbuilder.py +165 -121
defs.py ADDED
@@ -0,0 +1,213 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from enum import Enum
3
+ from pprint import pprint
4
+ from pathlib import Path
5
+
6
+
7
+
8
+ def initDataPaths(project_dir=os.path.dirname(os.path.dirname(__file__))):
9
+ # initializes datapaths
10
+
11
+ global PROJECT_DIR
12
+ global METADATA_DIR
13
+ global IMAGE_DIRECTORIES
14
+ global TRAINIG_DATA_DIR_254_IMG_ORIENTATION
15
+ global TRULY_ROTATED_IMG_224
16
+ global TRAINIG_DATA_DIR_256_MASKED_IMBALANCED
17
+ global TRAINIG_DATA_DIR_256_MASKED_BALANCED
18
+ global MLRUNS_URI
19
+ global MLRUNS_DIR
20
+ global MODELS_DIR
21
+ global EXTERNAL_RAW_DEMO_DATA_DIR
22
+ global EXTERNAL_MASKED_DEMO_DATA_DIR
23
+
24
+ PROJECT_DIR = project_dir
25
+ MODELS_DIR = os.path.join(PROJECT_DIR, "models")
26
+ METADATA_DIR = os.path.join(PROJECT_DIR, r"metadata")
27
+ TRAINIG_DATA_DIR_254_IMG_ORIENTATION = os.path.join(
28
+ PROJECT_DIR, r"data_224x224\train_val_224x224"
29
+ )
30
+ TRULY_ROTATED_IMG_224 = os.path.join(PROJECT_DIR, r"224x224_truly_rotated")
31
+ TRAINIG_DATA_DIR_256_MASKED_IMBALANCED = os.path.join(
32
+ PROJECT_DIR, r"256x256_masked_images_imbalanced"
33
+ )
34
+ TRAINIG_DATA_DIR_256_MASKED_BALANCED = os.path.join(
35
+ PROJECT_DIR, r"256x256_masked_images_balanced"
36
+ )
37
+
38
+ MLRUNS_URI = Path(os.path.abspath(os.path.join(PROJECT_DIR, "mlruns_vst"))).as_uri()
39
+ MLRUNS_DIR = os.path.abspath(os.path.join(PROJECT_DIR, "mlruns_vst"))
40
+
41
+ EXTERNAL_RAW_DEMO_DATA_DIR = os.path.join(PROJECT_DIR, "src", "streamlit", "data", "data_for_product_demo", "unlabeled", "external", "external_raw_299x299")
42
+ EXTERNAL_MASKED_DEMO_DATA_DIR = os.path.join(PROJECT_DIR, "src", "streamlit", "data", "data_for_product_demo", "unlabeled", "external", "external_masked_256x256")
43
+
44
+ IMAGE_DIRECTORIES = {
45
+ "COVID": {
46
+ "images": os.path.join(
47
+ PROJECT_DIR, "data", "COVID-19_Radiography_Dataset", "COVID", "images"
48
+ ),
49
+ "masks": os.path.join(
50
+ PROJECT_DIR, "data", "COVID-19_Radiography_Dataset", "COVID", "masks"
51
+ ),
52
+ },
53
+ "Lung_Opacity": {
54
+ "images": os.path.join(
55
+ PROJECT_DIR,
56
+ "data",
57
+ "COVID-19_Radiography_Dataset",
58
+ "Lung_Opacity",
59
+ "images",
60
+ ),
61
+ "masks": os.path.join(
62
+ PROJECT_DIR,
63
+ "data",
64
+ "COVID-19_Radiography_Dataset",
65
+ "Lung_Opacity",
66
+ "masks",
67
+ ),
68
+ },
69
+ "Normal": {
70
+ "images": os.path.join(
71
+ PROJECT_DIR, "data", "COVID-19_Radiography_Dataset", "Normal", "images"
72
+ ),
73
+ "masks": os.path.join(
74
+ PROJECT_DIR, "data", "COVID-19_Radiography_Dataset", "Normal", "masks"
75
+ ),
76
+ },
77
+ "Viral Pneumonia": {
78
+ "images": os.path.join(
79
+ PROJECT_DIR,
80
+ "data",
81
+ "COVID-19_Radiography_Dataset",
82
+ "Viral Pneumonia",
83
+ "images",
84
+ ),
85
+ "masks": os.path.join(
86
+ PROJECT_DIR,
87
+ "data",
88
+ "COVID-19_Radiography_Dataset",
89
+ "Viral Pneumonia",
90
+ "masks",
91
+ ),
92
+ },
93
+ }
94
+
95
+
96
+ def checkPaths():
97
+ print(
98
+ "\nPROJECT_DIR: ",
99
+ PROJECT_DIR,
100
+ "\nMETADATA_DIR: ",
101
+ METADATA_DIR,
102
+ "\nIMAGE_DIRECTORIES: ",
103
+ IMAGE_DIRECTORIES,
104
+ "\nTRAINIG_DATA_DIR_254_IMG_ORIENTATION: ",
105
+ TRAINIG_DATA_DIR_254_IMG_ORIENTATION,
106
+ "\nTRULY_ROTATED_IMG_224: ",
107
+ TRULY_ROTATED_IMG_224,
108
+ "\nTRAINIG_DATA_DIR_256_MASKED_IMBALANCED: ",
109
+ TRAINIG_DATA_DIR_256_MASKED_IMBALANCED,
110
+ "\nTRAINIG_DATA_DIR_256_MASKED_BALANCED: ",
111
+ TRAINIG_DATA_DIR_256_MASKED_BALANCED,
112
+ "\nMLRUNS_URI: ",
113
+ MLRUNS_URI,
114
+ "\nMODELS_DIR: ",
115
+ MODELS_DIR,
116
+ )
117
+
118
+
119
+ #----setting paths----
120
+ initDataPaths()
121
+ #---and checking them----
122
+ checkPaths()
123
+
124
+
125
+
126
+ class _Base(str, Enum):
127
+ def __str__(self):
128
+ return self.value
129
+
130
+
131
+ class ModelPath(_Base):
132
+ CAPSNET = os.path.join(
133
+ MODELS_DIR, "capsnet-4class-disease-classifier", "model.keras"
134
+ )
135
+ COVID19 = os.path.join(MODELS_DIR, "ds-crx-covid19", "model.keras")
136
+ GAN = os.path.join(MODELS_DIR, "lung-segmentation-gan", "model.keras")
137
+ UNET = os.path.join(MODELS_DIR, "lung-segmentation-unet", "model.keras")
138
+ MOBNET = os.path.join(
139
+ MODELS_DIR, "orientation-classifier-224x224-aug-head1-mobnet", "model.keras"
140
+ )
141
+ RESNET = os.path.join(
142
+ MODELS_DIR, "orientation-classifier-224x224-aug-head2-resnet50", "model.keras"
143
+ )
144
+
145
+
146
+ class DiseaseCategory(_Base):
147
+ # Enum for the different disease categories
148
+ # aligned to file names without extension .png
149
+ VIRAL_PNEUMONIA = "Viral Pneumonia"
150
+ COVID = "COVID"
151
+ LUNG_OPACITY = "Lung_Opacity"
152
+ NORMAL = "Normal"
153
+
154
+
155
+ class ImageType(_Base):
156
+ IMAGES = "images"
157
+ MASKS = "masks"
158
+ MASKED = "masked"
159
+
160
+
161
+ class ModelType(_Base):
162
+ # Enum for the different model types
163
+ RESNET50 = "resnet50"
164
+ MOBILENET = "mobnet"
165
+ GAN = "gan"
166
+ UNET = "unet"
167
+ CUST_COVID_CNN = "cust_covid_cnn"
168
+ CAPSNET = "capsnet"
169
+
170
+
171
+ class ExperimentName(_Base):
172
+ # mlflow experiment names
173
+ ORIENTATION_CLASSIFIER = "orientation_classifier"
174
+ DESEASE_CLASSIFIER = "desease_classifier"
175
+
176
+
177
+ # >>>>>IMPORTANT: the mapping must be the same as for the training dataset!!!!<<<<<
178
+ # check loaded dataset
179
+ class_to_orientation_map = {
180
+ "long": {0: "rotated_0", 1: "rotated_180", 2: "rotated_90", 3: "rotated_minus_90"},
181
+ "short": {
182
+ 0: "0°",
183
+ 1: "180°",
184
+ 2: "90°",
185
+ 3: "-90°",
186
+ },
187
+ }
188
+
189
+ orientation_labels = {
190
+ "short": [
191
+ "0°",
192
+ "180°",
193
+ "90°",
194
+ "-90°",
195
+ ],
196
+ "long": ["rotated_0", "rotated_180", "rotated_90", "rotated_minus_90"],
197
+ }
198
+
199
+ class_to_disease_map = {
200
+ 0: "COVID",
201
+ 1: "Lung_Opacity",
202
+ 2: "Normal",
203
+ 3: "Viral Pneumonia",
204
+ }
205
+
206
+ disease_labels = ["COVID", "Lung_Opacity", "Normal", "Viral Pneumonia"]
207
+
208
+
209
+ class DatasetType(_Base):
210
+ TRAIN = "train"
211
+ TEST = "test"
212
+ PREDICT = "predict"
213
+
modelbuilder.py CHANGED
@@ -6,21 +6,17 @@ Class to construct the different type of models
6
  import tensorflow as tf
7
  from tensorflow import keras
8
  from tensorflow.keras import layers, Sequential
9
- from tensorflow.keras.layers import (
10
- Dense,
11
- Input,
12
- Rescaling
13
- )
14
  from tensorflow.keras.applications import MobileNet, ResNet50
15
 
16
  # --- CapsNet-specific
17
  from keras.saving import register_keras_serializable # For custom layer serialization
18
 
19
  # --- Project-specific
20
- from src.defs import ModelType as mt
21
 
22
 
23
- class ModelBuilder():
24
  # builds the models
25
 
26
  def __init__(self, model_type, **model_params):
@@ -32,40 +28,49 @@ class ModelBuilder():
32
 
33
  # config extractor and attributes adding by model type
34
  if self.model_type in (mt.MOBILENET, mt.RESNET50):
35
- self.base_model_params = self.model_params.pop("base_model")
36
- self.model_name = self.base_model_params["name"]
37
- self.input_shape = self.base_model_params["input_shape"]
38
- self.base_trainable = self.model_params.pop("base_trainable")
39
- self.base_model = None
40
-
41
  elif self.model_type == mt.CAPSNET:
42
- self.model_name = model_params.pop("name")
43
- self.input_shape = model_params.pop("input_shape")
44
- self.prim_caps_params = model_params.pop("prim_caps")
45
- self.digit_caps_params = model_params.pop("digit_caps")
46
- self.routing_algo = model_params.pop("routing_algo") # informative only
47
-
48
  # model_type vs input shape validation
49
- if self.model_type in (mt.MOBILENET, mt.RESNET50,):
50
- if self.input_shape != (224,224,3):
51
- raise Exception(f"input shape for {self.model_name} model must be (224,224,3)")
 
 
 
 
 
52
  elif self.model_type == mt.CAPSNET:
53
- if self.input_shape != (256,256,3):
54
- raise Exception(f"input shape for {self.model_name} model must be (256,256,3)")
 
 
55
  else:
56
- raise Exception(f"Model not supported: {self.model_name}. The model name must contain one substring from {mt.MOBILENET, mt.RESNET50, mt.CAPSNET}")
57
-
58
-
59
-
60
  def get_augmentation_pipe(self):
61
  # Random-* layers are stochastic only when training=True
62
  # disabled during inference/evaluation
63
- return Sequential([
64
- layers.RandomRotation(0.1),
65
- layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
66
- layers.RandomZoom(0.1),
67
- ], name="augmentation")
68
-
 
 
69
 
70
  def get_compiled_model(self):
71
  # Extract config
@@ -74,9 +79,11 @@ class ModelBuilder():
74
  # Define input layer
75
  inputs = Input(shape=self.input_shape, name="inputs")
76
  # Random-* layers are stochastic only when training=True
77
-
78
- x_aug = self.get_augmentation_pipe()(inputs) # stochastic only when training=True
79
- x = Rescaling(1./255)(x_aug) # disabled during inference/evaluation
 
 
80
 
81
  # Model selector
82
  match self.model_type:
@@ -85,21 +92,25 @@ class ModelBuilder():
85
  self.base_model.trainable = self.base_trainable
86
 
87
  case mt.MOBILENET:
88
- self.base_model = MobileNet(input_tensor=x_aug, **self.base_model_params)
 
 
89
  self.base_model.trainable = self.base_trainable
90
 
91
  case mt.CAPSNET:
92
  self.base_model = None
93
- x = Rescaling(1./255)(x)
94
- outputs = self.build_capsnet(inputs = x_aug, **self.model_params)
95
 
96
  case _:
97
- raise Exception(f"Model type {self.model_type} not supported: {self.model_name}")
 
 
98
 
99
  # Classification head
100
  if self.model_type in (mt.RESNET50, mt.MOBILENET):
101
  x = self.base_model.output
102
- outputs = Dense(4, activation='softmax')(x)
103
  elif self.model_type == mt.CAPSNET:
104
  pass
105
  else:
@@ -110,12 +121,9 @@ class ModelBuilder():
110
  self.model.compile(**compile_params)
111
 
112
  print(f"The {self.model_name} model has been compiled successfully")
113
-
114
- return self.base_model, self.model
115
 
 
116
 
117
-
118
-
119
  def build_capsnet(self, inputs, **params):
120
  """
121
  Build a Capsule Network model for four-class lung disease classification: COVID, Normal, Pneumonia and Opacity.
@@ -130,77 +138,100 @@ class ModelBuilder():
130
  Returns:
131
  model: to be compiled
132
  """
133
-
134
- first_Conv2DKernel_size = params.pop("first_Conv2DKernel_size")
135
-
136
  # --- Preprocessing Layers ---
137
  x = inputs
138
 
139
  # --- Feature Extraction ---
140
  # learns 64 different 3x3 filters
141
- x = layers.Conv2D(filters = 64, kernel_size=first_Conv2DKernel_size, strides=2, padding='valid', activation='relu')(x) # downsampling strides=2, no padding because only exposed lung area matters/contains features
 
 
 
 
 
 
 
 
142
  x = layers.BatchNormalization()(x)
143
 
144
- x = layers.Conv2D(128, 5, strides=2, padding='same', activation='relu')(x) # padding="same" because of transformed output of the 1rst conv2D-layer (None, 125, 125, 64) to not lose the spatial info
 
 
145
  x = layers.BatchNormalization()(x)
146
  x = layers.Dropout(0.25)(x) # Dropout after second block (early regularization)
147
 
148
- x = layers.Conv2D(128, 3, strides=1, padding='same', activation='relu')(x)
149
  x = layers.BatchNormalization()(x)
150
 
151
- x = layers.Conv2D(256, 3, strides=1, padding='same', activation='relu')(x)
152
  x = layers.BatchNormalization()(x)
153
  x = layers.Dropout(0.3)(x) # Deeper regularization after more feature maps
154
 
155
- x = layers.Conv2D(512, 3, strides=1, padding='same', activation='relu')(x) # out : (None, 64, 64, 512)
156
- x = layers.BatchNormalization()(x) # out: (None, 64, 64, 512)
 
 
157
 
158
- x = layers.Dropout(0.3)(x) # Final dropout before capsules, out : (None, 64, 64, 512)
 
 
159
 
160
  # --- Capsule Layers for classification---
161
- primary_caps = PrimaryCaps(**self.prim_caps_params)(x) #dim_capsule=8, # Each capsule is an 8D vector (i.e. each capsule outputs a vector of length 8)
162
- #n_channels=32, # There are 32 capsule "types" per spatial location (like 32 different filters)
163
- #kernel_size=9,
164
- #strides=2, # Moves the 3×3 kernel with stride x → if x > 1 it reduces spatial size by x (downsampling)
165
- # # stride=1 This means the kernel moves 1 pixel at a time, covering every possible position in the input.
166
- #padding='same') # same: No paddingoutput size shrinks (no border pixels used)
167
-
168
- digit_caps = DigitCaps( **self.digit_caps_params)(primary_caps) #num_capsule=n_class, # 1 capsule per class (e.g. 4 diseases = 4 capsules)
169
- #dim_capsule=16, # Each output capsule is a 16D vector → captures pose info
170
- #routing_iters=routing_iters # Use 3 iterations of dynamic routing (or EM routing) to refine capsule agreement
171
- #) # out: (None, 4, 1, 16)
 
 
 
 
172
 
173
  outputs = Length()(digit_caps)
174
-
175
- return outputs
176
-
177
 
 
178
 
179
 
180
  # Squash function: This function shrinks small vectors to zero and large vectors to unit vectors.
181
  def squash(vectors, axis=-1):
182
  s_squared_norm = tf.reduce_sum(tf.square(vectors), axis, keepdims=True)
183
  # tf.keras.backend.epsilon() on Google Colab with A100 GPU = 1e-07
184
- scale = s_squared_norm / (1 + s_squared_norm) / tf.sqrt(s_squared_norm + tf.keras.backend.epsilon())
 
 
 
 
185
  return scale * vectors
186
 
187
 
188
-
189
  # PrimaryCaps Layer/ Lower-level capsules (e.g. detecting edges or textures)
190
- @register_keras_serializable() #make it serializable to .keras format
191
  class PrimaryCaps(layers.Layer):
192
 
193
- def __init__(self, dim_capsule, n_channels, kernel_size, strides, padding, **kwargs):
 
 
194
  super(PrimaryCaps, self).__init__(**kwargs)
195
- self.conv = layers.Conv2D(filters=dim_capsule * n_channels,
196
- kernel_size=kernel_size,
197
- strides=strides,
198
- padding=padding,
199
- activation='relu')
 
 
200
  self.dim_capsule = dim_capsule
201
  self.n_channels = n_channels
202
- self.kernel_size = kernel_size #
203
- self.strides = strides #
204
  self.padding = padding
205
 
206
  def build(self, input_shape):
@@ -208,27 +239,33 @@ class PrimaryCaps(layers.Layer):
208
  self.conv.build(input_shape)
209
  super().build(input_shape) # Let Keras know the layer is built
210
 
211
-
212
  def call(self, inputs):
213
  outputs = self.conv(inputs)
214
- outputs = tf.reshape(outputs, (-1, outputs.shape[1] * outputs.shape[2] * self.n_channels, self.dim_capsule))
 
 
 
 
 
 
 
215
  return squash(outputs)
216
 
217
-
218
  def get_config(self):
219
  # hook in to keras Layer to modify layer's config on reload
220
  config = super().get_config()
221
- config.update({
222
- "dim_capsule": self.dim_capsule,
223
- "n_channels": self.n_channels,
224
- "kernel_size": self.kernel_size,
225
- "strides": self.strides,
226
- "padding": self.padding
227
- })
 
 
228
  return config
229
 
230
 
231
-
232
  @register_keras_serializable()
233
  class DigitCaps(layers.Layer):
234
  # DigitCaps Layer / Higher-level capsules (e.g. detecting objects like digits or lungs)
@@ -242,10 +279,16 @@ class DigitCaps(layers.Layer):
242
  def build(self, input_shape):
243
  self.input_num_capsule = input_shape[1]
244
  self.input_dim_capsule = input_shape[2]
245
- self.W = self.add_weight(shape=[self.input_num_capsule, self.num_capsule,
246
- self.input_dim_capsule, self.dim_capsule],
247
- initializer='glorot_uniform',
248
- trainable=True)
 
 
 
 
 
 
249
 
250
  def call(self, inputs):
251
  inputs_expand = tf.expand_dims(inputs, 2)
@@ -253,31 +296,39 @@ class DigitCaps(layers.Layer):
253
  inputs_tiled = tf.tile(inputs_tiled, [1, 1, self.num_capsule, 1, 1])
254
  inputs_hat = tf.matmul(inputs_tiled, self.W)
255
 
256
- b = tf.zeros(shape=[tf.shape(inputs)[0], self.input_num_capsule, self.num_capsule, 1, 1])
 
 
257
 
258
  # Dynamic Routing by Agreement algo
259
  for i in range(self.routing_iters):
260
- c = tf.nn.softmax(b, axis=2) # coupling coefficient, beacause of softmax(...) all c's connected to a single higher capsule sum to 1.
261
- s = tf.reduce_sum(c * inputs_hat, axis=1, keepdims=True) # weighted sum along axis=1
262
- v = squash(s, axis=-2) # shrinks small vectors to zero and large vectors to unit vectors
 
 
 
 
 
 
263
  if i < self.routing_iters - 1:
264
  b += tf.reduce_sum(inputs_hat * v, axis=-1, keepdims=True)
265
 
266
  return tf.squeeze(v, axis=1)
267
 
268
-
269
  def get_config(self):
270
  # hook in to keras Layer to modify layer's config on reload
271
  config = super().get_config()
272
- config.update({
273
- "num_capsule": self.num_capsule,
274
- "dim_capsule": self.dim_capsule,
275
- "routing_iters": self.routing_iters
276
- })
 
 
277
  return config
278
 
279
 
280
-
281
  # Length Layer
282
  @register_keras_serializable()
283
  class Length(layers.Layer):
@@ -285,7 +336,6 @@ class Length(layers.Layer):
285
  return tf.sqrt(tf.reduce_sum(tf.square(inputs), -1))
286
 
287
 
288
-
289
  # Margin Loss for Capsule Networks
290
  def margin_loss(y_true, y_pred):
291
  # y_true is a one-hot vector
@@ -293,21 +343,15 @@ def margin_loss(y_true, y_pred):
293
  m_plus = 0.9
294
  m_minus = 0.1
295
  lambda_val = 0.5
296
- L = y_true * tf.square(tf.maximum(0., m_plus - y_pred)) + \
297
- lambda_val * (1 - y_true) * tf.square(tf.maximum(0., y_pred - m_minus))
 
298
  return tf.reduce_mean(tf.reduce_sum(L, axis=1))
299
 
300
 
301
  capsnet_custom_objects = {
302
- 'PrimaryCaps': PrimaryCaps,
303
- 'DigitCaps': DigitCaps,
304
- 'Length': Length,
305
- 'margin_loss': margin_loss
306
  }
307
-
308
-
309
-
310
-
311
-
312
-
313
-
 
6
  import tensorflow as tf
7
  from tensorflow import keras
8
  from tensorflow.keras import layers, Sequential
9
+ from tensorflow.keras.layers import Dense, Input, Rescaling
 
 
 
 
10
  from tensorflow.keras.applications import MobileNet, ResNet50
11
 
12
  # --- CapsNet-specific
13
  from keras.saving import register_keras_serializable # For custom layer serialization
14
 
15
  # --- Project-specific
16
+ from defs import ModelType as mt
17
 
18
 
19
+ class ModelBuilder:
20
  # builds the models
21
 
22
  def __init__(self, model_type, **model_params):
 
28
 
29
  # config extractor and attributes adding by model type
30
  if self.model_type in (mt.MOBILENET, mt.RESNET50):
31
+ self.base_model_params = self.model_params.pop("base_model")
32
+ self.model_name = self.base_model_params["name"]
33
+ self.input_shape = self.base_model_params["input_shape"]
34
+ self.base_trainable = self.model_params.pop("base_trainable")
35
+ self.base_model = None
36
+
37
  elif self.model_type == mt.CAPSNET:
38
+ self.model_name = model_params.pop("name")
39
+ self.input_shape = model_params.pop("input_shape")
40
+ self.prim_caps_params = model_params.pop("prim_caps")
41
+ self.digit_caps_params = model_params.pop("digit_caps")
42
+ self.routing_algo = model_params.pop("routing_algo") # informative only
43
+
44
  # model_type vs input shape validation
45
+ if self.model_type in (
46
+ mt.MOBILENET,
47
+ mt.RESNET50,
48
+ ):
49
+ if self.input_shape != (224, 224, 3):
50
+ raise Exception(
51
+ f"input shape for {self.model_name} model must be (224,224,3)"
52
+ )
53
  elif self.model_type == mt.CAPSNET:
54
+ if self.input_shape != (256, 256, 3):
55
+ raise Exception(
56
+ f"input shape for {self.model_name} model must be (256,256,3)"
57
+ )
58
  else:
59
+ raise Exception(
60
+ f"Model not supported: {self.model_name}. The model name must contain one substring from {mt.MOBILENET, mt.RESNET50, mt.CAPSNET}"
61
+ )
62
+
63
  def get_augmentation_pipe(self):
64
  # Random-* layers are stochastic only when training=True
65
  # disabled during inference/evaluation
66
+ return Sequential(
67
+ [
68
+ layers.RandomRotation(0.1),
69
+ layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
70
+ layers.RandomZoom(0.1),
71
+ ],
72
+ name="augmentation",
73
+ )
74
 
75
  def get_compiled_model(self):
76
  # Extract config
 
79
  # Define input layer
80
  inputs = Input(shape=self.input_shape, name="inputs")
81
  # Random-* layers are stochastic only when training=True
82
+
83
+ x_aug = self.get_augmentation_pipe()(
84
+ inputs
85
+ ) # stochastic only when training=True
86
+ x = Rescaling(1.0 / 255)(x_aug) # disabled during inference/evaluation
87
 
88
  # Model selector
89
  match self.model_type:
 
92
  self.base_model.trainable = self.base_trainable
93
 
94
  case mt.MOBILENET:
95
+ self.base_model = MobileNet(
96
+ input_tensor=x_aug, **self.base_model_params
97
+ )
98
  self.base_model.trainable = self.base_trainable
99
 
100
  case mt.CAPSNET:
101
  self.base_model = None
102
+ x = Rescaling(1.0 / 255)(x)
103
+ outputs = self.build_capsnet(inputs=x_aug, **self.model_params)
104
 
105
  case _:
106
+ raise Exception(
107
+ f"Model type {self.model_type} not supported: {self.model_name}"
108
+ )
109
 
110
  # Classification head
111
  if self.model_type in (mt.RESNET50, mt.MOBILENET):
112
  x = self.base_model.output
113
+ outputs = Dense(4, activation="softmax")(x)
114
  elif self.model_type == mt.CAPSNET:
115
  pass
116
  else:
 
121
  self.model.compile(**compile_params)
122
 
123
  print(f"The {self.model_name} model has been compiled successfully")
 
 
124
 
125
+ return self.base_model, self.model
126
 
 
 
127
  def build_capsnet(self, inputs, **params):
128
  """
129
  Build a Capsule Network model for four-class lung disease classification: COVID, Normal, Pneumonia and Opacity.
 
138
  Returns:
139
  model: to be compiled
140
  """
141
+
142
+ first_Conv2DKernel_size = params.pop("first_Conv2DKernel_size")
143
+
144
  # --- Preprocessing Layers ---
145
  x = inputs
146
 
147
  # --- Feature Extraction ---
148
  # learns 64 different 3x3 filters
149
+ x = layers.Conv2D(
150
+ filters=64,
151
+ kernel_size=first_Conv2DKernel_size,
152
+ strides=2,
153
+ padding="valid",
154
+ activation="relu",
155
+ )(
156
+ x
157
+ ) # downsampling strides=2, no padding because only exposed lung area matters/contains features
158
  x = layers.BatchNormalization()(x)
159
 
160
+ x = layers.Conv2D(128, 5, strides=2, padding="same", activation="relu")(
161
+ x
162
+ ) # padding="same" because of the transformed output of the first Conv2D layer (None, 125, 125, 64), so as not to lose the spatial info
163
  x = layers.BatchNormalization()(x)
164
  x = layers.Dropout(0.25)(x) # Dropout after second block (early regularization)
165
 
166
+ x = layers.Conv2D(128, 3, strides=1, padding="same", activation="relu")(x)
167
  x = layers.BatchNormalization()(x)
168
 
169
+ x = layers.Conv2D(256, 3, strides=1, padding="same", activation="relu")(x)
170
  x = layers.BatchNormalization()(x)
171
  x = layers.Dropout(0.3)(x) # Deeper regularization after more feature maps
172
 
173
+ x = layers.Conv2D(512, 3, strides=1, padding="same", activation="relu")(
174
+ x
175
+ ) # out : (None, 64, 64, 512)
176
+ x = layers.BatchNormalization()(x) # out: (None, 64, 64, 512)
177
 
178
+ x = layers.Dropout(0.3)(
179
+ x
180
+ ) # Final dropout before capsules, out : (None, 64, 64, 512)
181
 
182
  # --- Capsule Layers for classification---
183
+ primary_caps = PrimaryCaps(**self.prim_caps_params)(
184
+ x
185
+ ) # dim_capsule=8, # Each capsule is an 8D vector (i.e. each capsule outputs a vector of length 8)
186
+ # n_channels=32, # There are 32 capsule "types" per spatial location (like 32 different filters)
187
+ # kernel_size=9,
188
+ # strides=2, # Moves the kernel with stride x; if x > 1 it reduces the spatial size by a factor of x (downsampling)
189
+ # # stride=1 This means the kernel moves 1 pixel at a time, covering every possible position in the input.
190
+ # padding='same') # 'same' pads the borders so border pixels are used too; 'valid' would mean no padding (output size shrinks)
191
+
192
+ digit_caps = DigitCaps(**self.digit_caps_params)(
193
+ primary_caps
194
+ ) # num_capsule=n_class, # 1 capsule per class (e.g. 4 diseases = 4 capsules)
195
+ # dim_capsule=16, # Each output capsule is a 16D vector → captures pose info
196
+ # routing_iters=routing_iters # Use 3 iterations of dynamic routing (or EM routing) to refine capsule agreement
197
+ # ) # out: (None, 4, 1, 16)
198
 
199
  outputs = Length()(digit_caps)
 
 
 
200
 
201
+ return outputs
202
 
203
 
204
  # Squash function: This function shrinks small vectors to zero and large vectors to unit vectors.
205
  def squash(vectors, axis=-1):
206
  s_squared_norm = tf.reduce_sum(tf.square(vectors), axis, keepdims=True)
207
  # tf.keras.backend.epsilon() on Google Colab with A100 GPU = 1e-07
208
+ scale = (
209
+ s_squared_norm
210
+ / (1 + s_squared_norm)
211
+ / tf.sqrt(s_squared_norm + tf.keras.backend.epsilon())
212
+ )
213
  return scale * vectors
214
 
215
 
 
216
  # PrimaryCaps Layer/ Lower-level capsules (e.g. detecting edges or textures)
217
+ @register_keras_serializable() # make it serializable to .keras format
218
  class PrimaryCaps(layers.Layer):
219
 
220
+ def __init__(
221
+ self, dim_capsule, n_channels, kernel_size, strides, padding, **kwargs
222
+ ):
223
  super(PrimaryCaps, self).__init__(**kwargs)
224
+ self.conv = layers.Conv2D(
225
+ filters=dim_capsule * n_channels,
226
+ kernel_size=kernel_size,
227
+ strides=strides,
228
+ padding=padding,
229
+ activation="relu",
230
+ )
231
  self.dim_capsule = dim_capsule
232
  self.n_channels = n_channels
233
+ self.kernel_size = kernel_size #
234
+ self.strides = strides #
235
  self.padding = padding
236
 
237
  def build(self, input_shape):
 
239
  self.conv.build(input_shape)
240
  super().build(input_shape) # Let Keras know the layer is built
241
 
 
242
  def call(self, inputs):
243
  outputs = self.conv(inputs)
244
+ outputs = tf.reshape(
245
+ outputs,
246
+ (
247
+ -1,
248
+ outputs.shape[1] * outputs.shape[2] * self.n_channels,
249
+ self.dim_capsule,
250
+ ),
251
+ )
252
  return squash(outputs)
253
 
 
254
  def get_config(self):
255
  # hook in to keras Layer to modify layer's config on reload
256
  config = super().get_config()
257
+ config.update(
258
+ {
259
+ "dim_capsule": self.dim_capsule,
260
+ "n_channels": self.n_channels,
261
+ "kernel_size": self.kernel_size,
262
+ "strides": self.strides,
263
+ "padding": self.padding,
264
+ }
265
+ )
266
  return config
267
 
268
 
 
269
  @register_keras_serializable()
270
  class DigitCaps(layers.Layer):
271
  # DigitCaps Layer / Higher-level capsules (e.g. detecting objects like digits or lungs)
 
279
  def build(self, input_shape):
280
  self.input_num_capsule = input_shape[1]
281
  self.input_dim_capsule = input_shape[2]
282
+ self.W = self.add_weight(
283
+ shape=[
284
+ self.input_num_capsule,
285
+ self.num_capsule,
286
+ self.input_dim_capsule,
287
+ self.dim_capsule,
288
+ ],
289
+ initializer="glorot_uniform",
290
+ trainable=True,
291
+ )
292
 
293
  def call(self, inputs):
294
  inputs_expand = tf.expand_dims(inputs, 2)
 
296
  inputs_tiled = tf.tile(inputs_tiled, [1, 1, self.num_capsule, 1, 1])
297
  inputs_hat = tf.matmul(inputs_tiled, self.W)
298
 
299
+ b = tf.zeros(
300
+ shape=[tf.shape(inputs)[0], self.input_num_capsule, self.num_capsule, 1, 1]
301
+ )
302
 
303
  # Dynamic Routing by Agreement algo
304
  for i in range(self.routing_iters):
305
+ c = tf.nn.softmax(
306
+ b, axis=2
307
+ ) # coupling coefficients: because of the softmax over axis=2 (the num_capsule axis), the c's of each lower-level capsule sum to 1 across the higher-level capsules.
308
+ s = tf.reduce_sum(
309
+ c * inputs_hat, axis=1, keepdims=True
310
+ ) # weighted sum along axis=1
311
+ v = squash(
312
+ s, axis=-2
313
+ ) # shrinks small vectors to zero and large vectors to unit vectors
314
  if i < self.routing_iters - 1:
315
  b += tf.reduce_sum(inputs_hat * v, axis=-1, keepdims=True)
316
 
317
  return tf.squeeze(v, axis=1)
318
 
 
319
  def get_config(self):
320
  # hook in to keras Layer to modify layer's config on reload
321
  config = super().get_config()
322
+ config.update(
323
+ {
324
+ "num_capsule": self.num_capsule,
325
+ "dim_capsule": self.dim_capsule,
326
+ "routing_iters": self.routing_iters,
327
+ }
328
+ )
329
  return config
330
 
331
 
 
332
  # Length Layer
333
  @register_keras_serializable()
334
  class Length(layers.Layer):
 
336
  return tf.sqrt(tf.reduce_sum(tf.square(inputs), -1))
337
 
338
 
 
339
  # Margin Loss for Capsule Networks
340
  def margin_loss(y_true, y_pred):
341
  # y_true is a one-hot vector
 
343
  m_plus = 0.9
344
  m_minus = 0.1
345
  lambda_val = 0.5
346
+ L = y_true * tf.square(tf.maximum(0.0, m_plus - y_pred)) + lambda_val * (
347
+ 1 - y_true
348
+ ) * tf.square(tf.maximum(0.0, y_pred - m_minus))
349
  return tf.reduce_mean(tf.reduce_sum(L, axis=1))
350
 
351
 
352
  capsnet_custom_objects = {
353
+ "PrimaryCaps": PrimaryCaps,
354
+ "DigitCaps": DigitCaps,
355
+ "Length": Length,
356
+ "margin_loss": margin_loss,
357
  }