valste commited on
Commit
e9116b6
·
1 Parent(s): 99db4b5

addeded the missing custom_objects

Browse files
Files changed (3) hide show
  1. app.py +3 -2
  2. modelbuilder.py +313 -0
  3. requirements.txt +3 -1
app.py CHANGED
@@ -5,6 +5,7 @@ from datasets import load_dataset
5
  from PIL import Image
6
  import numpy as np
7
  from tensorflow.keras.preprocessing.image import img_to_array
 
8
 
9
  # ------------------------------------------------------------
10
  # 1️⃣ Load the models from Hugging Face Hub
@@ -20,7 +21,7 @@ capsnet_model_path = hf_hub_download(
20
  repo_id="valste/capsnet-4class-lung-disease-classifier",
21
  filename="model.keras"
22
  )
23
- capsnet_model = tf.keras.models.load_model(capsnet_model_path, compile=False)
24
  # ------------------------------------------------------------
25
  # 2️⃣ Load sample X-ray images from your dataset
26
  # ------------------------------------------------------------
@@ -38,7 +39,7 @@ for example in dataset:
38
  # ------------------------------------------------------------
39
  # 3️⃣ Define preprocessing and inference function
40
  # ------------------------------------------------------------
41
- class_labels = ["COVID-19", "Lung Opacity", "Normal", "Viral Pneumonia"]
42
 
43
 
44
  def preprocess_image(img: Image.Image):
 
5
  from PIL import Image
6
  import numpy as np
7
  from tensorflow.keras.preprocessing.image import img_to_array
8
+ from modelbuilder import capsnet_custom_objects
9
 
10
  # ------------------------------------------------------------
11
  # 1️⃣ Load the models from Hugging Face Hub
 
21
  repo_id="valste/capsnet-4class-lung-disease-classifier",
22
  filename="model.keras"
23
  )
24
+ capsnet_model = tf.keras.models.load_model(capsnet_model_path, custom_objects=capsnet_custom_objects, compile=False)
25
  # ------------------------------------------------------------
26
  # 2️⃣ Load sample X-ray images from your dataset
27
  # ------------------------------------------------------------
 
39
  # ------------------------------------------------------------
40
  # 3️⃣ Define preprocessing and inference function
41
  # ------------------------------------------------------------
42
+ class_labels = ['COVID', 'Lung_Opacity', 'Normal', 'Viral Pneumonia']
43
 
44
 
45
  def preprocess_image(img: Image.Image):
modelbuilder.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Class to construct the different type of models
3
+ """
4
+
5
+ # --- Core TensorFlow/Keras
6
+ import tensorflow as tf
7
+ from tensorflow import keras
8
+ from tensorflow.keras import layers, Sequential
9
+ from tensorflow.keras.layers import (
10
+ Dense,
11
+ Input,
12
+ Rescaling
13
+ )
14
+ from tensorflow.keras.applications import MobileNet, ResNet50
15
+
16
+ # --- CapsNet-specific
17
+ from keras.saving import register_keras_serializable # For custom layer serialization
18
+
19
+ # --- Project-specific
20
+ from src.defs import ModelType as mt
21
+
22
+
23
+ class ModelBuilder():
24
+ # builds the models
25
+
26
+ def __init__(self, model_type, **model_params):
27
+
28
+ self.model_type = model_type
29
+ self.model_params = model_params
30
+ self.model = None
31
+ self.model_name = None
32
+
33
+ # config extractor and attributes adding by model type
34
+ if self.model_type in (mt.MOBILENET, mt.RESNET50):
35
+ self.base_model_params = self.model_params.pop("base_model")
36
+ self.model_name = self.base_model_params["name"]
37
+ self.input_shape = self.base_model_params["input_shape"]
38
+ self.base_trainable = self.model_params.pop("base_trainable")
39
+ self.base_model = None
40
+
41
+ elif self.model_type == mt.CAPSNET:
42
+ self.model_name = model_params.pop("name")
43
+ self.input_shape = model_params.pop("input_shape")
44
+ self.prim_caps_params = model_params.pop("prim_caps")
45
+ self.digit_caps_params = model_params.pop("digit_caps")
46
+ self.routing_algo = model_params.pop("routing_algo") # informative only
47
+
48
+ # model_type vs input shape validation
49
+ if self.model_type in (mt.MOBILENET, mt.RESNET50,):
50
+ if self.input_shape != (224,224,3):
51
+ raise Exception(f"input shape for {self.model_name} model must be (224,224,3)")
52
+ elif self.model_type == mt.CAPSNET:
53
+ if self.input_shape != (256,256,3):
54
+ raise Exception(f"input shape for {self.model_name} model must be (256,256,3)")
55
+ else:
56
+ raise Exception(f"Model not supported: {self.model_name}. The model name must contain one substring from {mt.MOBILENET, mt.RESNET50, mt.CAPSNET}")
57
+
58
+
59
+
60
+ def get_augmentation_pipe(self):
61
+ # Random-* layers are stochastic only when training=True
62
+ # disabled during inference/evaluation
63
+ return Sequential([
64
+ layers.RandomRotation(0.1),
65
+ layers.RandomTranslation(height_factor=0.1, width_factor=0.1),
66
+ layers.RandomZoom(0.1),
67
+ ], name="augmentation")
68
+
69
+
70
+ def get_compiled_model(self):
71
+ # Extract config
72
+ compile_params = self.model_params.pop("compile_params")
73
+
74
+ # Define input layer
75
+ inputs = Input(shape=self.input_shape, name="inputs")
76
+ # Random-* layers are stochastic only when training=True
77
+
78
+ x_aug = self.get_augmentation_pipe()(inputs) # stochastic only when training=True
79
+ x = Rescaling(1./255)(x_aug) # disabled during inference/evaluation
80
+
81
+ # Model selector
82
+ match self.model_type:
83
+ case mt.RESNET50:
84
+ self.base_model = ResNet50(input_tensor=x_aug, **self.base_model_params)
85
+ self.base_model.trainable = self.base_trainable
86
+
87
+ case mt.MOBILENET:
88
+ self.base_model = MobileNet(input_tensor=x_aug, **self.base_model_params)
89
+ self.base_model.trainable = self.base_trainable
90
+
91
+ case mt.CAPSNET:
92
+ self.base_model = None
93
+ x = Rescaling(1./255)(x)
94
+ outputs = self.build_capsnet(inputs = x_aug, **self.model_params)
95
+
96
+ case _:
97
+ raise Exception(f"Model type {self.model_type} not supported: {self.model_name}")
98
+
99
+ # Classification head
100
+ if self.model_type in (mt.RESNET50, mt.MOBILENET):
101
+ x = self.base_model.output
102
+ outputs = Dense(4, activation='softmax')(x)
103
+ elif self.model_type == mt.CAPSNET:
104
+ pass
105
+ else:
106
+ raise Exception(f"No classifier head defined for {self.model_type}")
107
+
108
+ # Final model
109
+ self.model = keras.Model(name=self.model_name, inputs=inputs, outputs=outputs)
110
+ self.model.compile(**compile_params)
111
+
112
+ print(f"The {self.model_name} model has been compiled successfully")
113
+
114
+ return self.base_model, self.model
115
+
116
+
117
+
118
+
119
+ def build_capsnet(self, inputs, **params):
120
+ """
121
+ Build a Capsule Network model for four class lung iseases classification: COVID, Normal, Pneumonia and Opacity.
122
+ Args:
123
+ name (_type_): _description_
124
+ first_Conv2DKernel_size (int, optional): _description_. Defaults to 10.
125
+ input_shape (tuple, optional): _description_. Defaults to (256, 256, 3).
126
+ n_class (int, optional): _description_. Defaults to 4.
127
+ routing_iters (int, optional): _description_. Defaults to 3.
128
+ routing_algo (str, optional): _description_. Defaults to "by_agreement".
129
+
130
+ Returns:
131
+ model: to be compiled
132
+ """
133
+
134
+ first_Conv2DKernel_size = params.pop("first_Conv2DKernel_size")
135
+
136
+ # --- Preprocessing Layers ---
137
+ x = inputs
138
+
139
+ # --- Feature Extraction ---
140
+ # learns 64 different 3x3 filters
141
+ x = layers.Conv2D(filters = 64, kernel_size=first_Conv2DKernel_size, strides=2, padding='valid', activation='relu')(x) # downsampling strides=2, no padding because only exposed lung area matters/contains features
142
+ x = layers.BatchNormalization()(x)
143
+
144
+ x = layers.Conv2D(128, 5, strides=2, padding='same', activation='relu')(x) # padding="same" because of transformed output of the 1rst conv2D-layer (None, 125, 125, 64) to not lose the spatial info
145
+ x = layers.BatchNormalization()(x)
146
+ x = layers.Dropout(0.25)(x) # Dropout after second block (early regularization)
147
+
148
+ x = layers.Conv2D(128, 3, strides=1, padding='same', activation='relu')(x)
149
+ x = layers.BatchNormalization()(x)
150
+
151
+ x = layers.Conv2D(256, 3, strides=1, padding='same', activation='relu')(x)
152
+ x = layers.BatchNormalization()(x)
153
+ x = layers.Dropout(0.3)(x) # Deeper regularization after more feature maps
154
+
155
+ x = layers.Conv2D(512, 3, strides=1, padding='same', activation='relu')(x) # out : (None, 64, 64, 512)
156
+ x = layers.BatchNormalization()(x) # out: (None, 64, 64, 512)
157
+
158
+ x = layers.Dropout(0.3)(x) # Final dropout before capsules, out : (None, 64, 64, 512)
159
+
160
+ # --- Capsule Layers for classification---
161
+ primary_caps = PrimaryCaps(**self.prim_caps_params)(x) #dim_capsule=8, # Each capsule is an 8D vector (i.e. each capsule outputs a vector of length 8)
162
+ #n_channels=32, # There are 32 capsule "types" per spatial location (like 32 different filters)
163
+ #kernel_size=9,
164
+ #strides=2, # Moves the 3×3 kernel with stride x → if x > 1 it reduces spatial size by x (downsampling)
165
+ # # stride=1 This means the kernel moves 1 pixel at a time, covering every possible position in the input.
166
+ #padding='same') # same: No padding → output size shrinks (no border pixels used)
167
+
168
+ digit_caps = DigitCaps( **self.digit_caps_params)(primary_caps) #num_capsule=n_class, # 1 capsule per class (e.g. 4 diseases = 4 capsules)
169
+ #dim_capsule=16, # Each output capsule is a 16D vector → captures pose info
170
+ #routing_iters=routing_iters # Use 3 iterations of dynamic routing (or EM routing) to refine capsule agreement
171
+ #) # out: (None, 4, 1, 16)
172
+
173
+ outputs = Length()(digit_caps)
174
+
175
+ return outputs
176
+
177
+
178
+
179
+
180
+ # Squash function: This function shrinks small vectors to zero and large vectors to unit vectors.
181
+ def squash(vectors, axis=-1):
182
+ s_squared_norm = tf.reduce_sum(tf.square(vectors), axis, keepdims=True)
183
+ # tf.keras.backend.epsilon() on google coalb with A100 GPU = 1e-07
184
+ scale = s_squared_norm / (1 + s_squared_norm) / tf.sqrt(s_squared_norm + tf.keras.backend.epsilon())
185
+ return scale * vectors
186
+
187
+
188
+
189
+ # PrimaryCaps Layer/ Lower-level capsules (e.g. detecting edges or textures)
190
+ @register_keras_serializable() #make it serializable to .keras format
191
+ class PrimaryCaps(layers.Layer):
192
+
193
+ def __init__(self, dim_capsule, n_channels, kernel_size, strides, padding, **kwargs):
194
+ super(PrimaryCaps, self).__init__(**kwargs)
195
+ self.conv = layers.Conv2D(filters=dim_capsule * n_channels,
196
+ kernel_size=kernel_size,
197
+ strides=strides,
198
+ padding=padding,
199
+ activation='relu')
200
+ self.dim_capsule = dim_capsule
201
+ self.n_channels = n_channels
202
+ self.kernel_size = kernel_size #
203
+ self.strides = strides #
204
+ self.padding = padding
205
+
206
+ def build(self, input_shape):
207
+ # Important: build the internal Conv2D layer using input shape
208
+ self.conv.build(input_shape)
209
+ super().build(input_shape) # Let Keras know the layer is built
210
+
211
+
212
+ def call(self, inputs):
213
+ outputs = self.conv(inputs)
214
+ outputs = tf.reshape(outputs, (-1, outputs.shape[1] * outputs.shape[2] * self.n_channels, self.dim_capsule))
215
+ return squash(outputs)
216
+
217
+
218
+ def get_config(self):
219
+ # hook in to keras Layer to modify layer's config on reload
220
+ config = super().get_config()
221
+ config.update({
222
+ "dim_capsule": self.dim_capsule,
223
+ "n_channels": self.n_channels,
224
+ "kernel_size": self.kernel_size,
225
+ "strides": self.strides,
226
+ "padding": self.padding
227
+ })
228
+ return config
229
+
230
+
231
+
232
+ @register_keras_serializable()
233
+ class DigitCaps(layers.Layer):
234
+ # DigitCaps Layer / Higher-level capsules (e.g. detecting objects like digits or lungs)
235
+
236
+ def __init__(self, num_capsule, dim_capsule, routing_iters=3, **kwargs):
237
+ super(DigitCaps, self).__init__(**kwargs)
238
+ self.num_capsule = num_capsule
239
+ self.dim_capsule = dim_capsule
240
+ self.routing_iters = routing_iters
241
+
242
+ def build(self, input_shape):
243
+ self.input_num_capsule = input_shape[1]
244
+ self.input_dim_capsule = input_shape[2]
245
+ self.W = self.add_weight(shape=[self.input_num_capsule, self.num_capsule,
246
+ self.input_dim_capsule, self.dim_capsule],
247
+ initializer='glorot_uniform',
248
+ trainable=True)
249
+
250
+ def call(self, inputs):
251
+ inputs_expand = tf.expand_dims(inputs, 2)
252
+ inputs_tiled = tf.expand_dims(inputs_expand, 3)
253
+ inputs_tiled = tf.tile(inputs_tiled, [1, 1, self.num_capsule, 1, 1])
254
+ inputs_hat = tf.matmul(inputs_tiled, self.W)
255
+
256
+ b = tf.zeros(shape=[tf.shape(inputs)[0], self.input_num_capsule, self.num_capsule, 1, 1])
257
+
258
+ # Dynamic Routing by Agreement algo
259
+ for i in range(self.routing_iters):
260
+ c = tf.nn.softmax(b, axis=2) # coupling coefficient, beacause of softmax(...) all c's connected to a single higher capsule sum to 1.
261
+ s = tf.reduce_sum(c * inputs_hat, axis=1, keepdims=True) # weighted sum along axis=1
262
+ v = squash(s, axis=-2) # shrinks small vectors to zero and large vectors to unit vectors
263
+ if i < self.routing_iters - 1:
264
+ b += tf.reduce_sum(inputs_hat * v, axis=-1, keepdims=True)
265
+
266
+ return tf.squeeze(v, axis=1)
267
+
268
+
269
+ def get_config(self):
270
+ # hook in to keras Layer to modify layer's config on reload
271
+ config = super().get_config()
272
+ config.update({
273
+ "num_capsule": self.num_capsule,
274
+ "dim_capsule": self.dim_capsule,
275
+ "routing_iters": self.routing_iters
276
+ })
277
+ return config
278
+
279
+
280
+
281
+ # Length Layer
282
+ @register_keras_serializable()
283
+ class Length(layers.Layer):
284
+ def call(self, inputs, **kwargs):
285
+ return tf.sqrt(tf.reduce_sum(tf.square(inputs), -1))
286
+
287
+
288
+
289
+ # Margin Loss for Capsule Networks
290
+ def margin_loss(y_true, y_pred):
291
+ # y_true is a one-hot vector
292
+ # y_pred is the Length() output: vector of shape [batch_size, num_classes] (each value ≈ class presence probability)
293
+ m_plus = 0.9
294
+ m_minus = 0.1
295
+ lambda_val = 0.5
296
+ L = y_true * tf.square(tf.maximum(0., m_plus - y_pred)) + \
297
+ lambda_val * (1 - y_true) * tf.square(tf.maximum(0., y_pred - m_minus))
298
+ return tf.reduce_mean(tf.reduce_sum(L, axis=1))
299
+
300
+
301
+ capsnet_custom_objects = {
302
+ 'PrimaryCaps': PrimaryCaps,
303
+ 'DigitCaps': DigitCaps,
304
+ 'Length': Length,
305
+ 'margin_loss': margin_loss
306
+ }
307
+
308
+
309
+
310
+
311
+
312
+
313
+
requirements.txt CHANGED
@@ -1,5 +1,7 @@
1
  gradio==5.49.1
2
- tensorflow
 
 
3
  huggingface_hub
4
  datasets
5
  Pillow
 
1
  gradio==5.49.1
2
+ # Pin TensorFlow/Keras to avoid Keras 3 deserialization issues
3
+ tensorflow==2.13.1
4
+ keras<3
5
  huggingface_hub
6
  datasets
7
  Pillow