CGAllenger commited on
Commit
0f309cb
·
verified ·
1 Parent(s): 7079f3c

rearranging the order of channels

Browse files
Files changed (1) hide show
  1. app.py +15 -20
app.py CHANGED
@@ -3,23 +3,24 @@ import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
 
6
- # --- 1. BUILD X-RAY MODEL (From redwankarimsony Kaggle Notebook) ---
7
  def build_xray_model():
8
- # The Kaggle notebook uses DenseNet121, not EfficientNet
9
  base_model = tf.keras.applications.DenseNet121(
10
  input_shape=(320, 320, 3),
11
  include_top=False,
12
- weights=None,
13
- pooling='avg' # Kaggle notebook uses global average pooling natively here
14
  )
15
 
16
- # The Kaggle notebook adds a single Dense layer for the 14 classes
17
- output = tf.keras.layers.Dense(14, activation='sigmoid')(base_model.output)
18
- model = tf.keras.Model(inputs=base_model.input, outputs=output)
 
 
 
19
 
20
  try:
21
  model.load_weights("xray.h5")
22
- print("X-Ray weights loaded successfully into DenseNet121!")
23
  return model
24
  except Exception as e:
25
  print(f"Error loading X-Ray weights: {e}")
@@ -27,7 +28,6 @@ def build_xray_model():
27
 
28
  # --- 2. LOAD MODELS ---
29
  try:
30
- # .keras files contain the whole model, so we just load it directly
31
  mri_model = tf.keras.models.load_model("mri.keras", compile=False)
32
  print("MRI model loaded successfully!")
33
  except Exception as e:
@@ -51,27 +51,22 @@ def predict(img, model_type):
51
  if model_type == "MRI":
52
  if mri_model is None: return {"MRI Model Error": 0.0}
53
 
54
- # MRI PREPROCESSING (Based on the Kaggle notebook style)
55
- img = img.resize((256, 256))
56
  img_array = np.array(img).astype('float32')
57
-
58
- # Ensure RGB (3 channels) for MRI, as most standard CNNs expect it
59
- if len(img_array.shape) == 2:
60
- img_array = np.stack((img_array,)*3, axis=-1)
61
-
62
- img_array = np.expand_dims(img_array, axis=0)
63
  model, labels = mri_model, mri_labels
64
 
65
  else:
66
  if xray_model is None: return {"X-Ray Model Error": 0.0}
67
 
68
- # X-RAY PREPROCESSING (DenseNet121 requires 320x320 and RGB)
69
  img = img.convert("RGB").resize((320, 320))
70
  img_array = np.array(img).astype('float32')
71
- img_array = np.expand_dims(img_array, axis=0)
72
  model, labels = xray_model, xray_labels
73
 
74
- # Normalize (standard across both notebooks)
75
  img_array /= 255.0
76
 
77
  # Predict
 
3
  import numpy as np
4
  from PIL import Image
5
 
6
+ # --- 1. BUILD X-RAY MODEL (DenseNet121 in 3 Sequential Layers) ---
7
  def build_xray_model():
 
8
  base_model = tf.keras.applications.DenseNet121(
9
  input_shape=(320, 320, 3),
10
  include_top=False,
11
+ weights=None
 
12
  )
13
 
14
+ # Wrap in Sequential to match the "3 saved layers" format exactly
15
+ model = tf.keras.Sequential([
16
+ base_model,
17
+ tf.keras.layers.GlobalAveragePooling2D(),
18
+ tf.keras.layers.Dense(14, activation='sigmoid')
19
+ ])
20
 
21
  try:
22
  model.load_weights("xray.h5")
23
+ print("X-Ray weights loaded successfully into Sequential DenseNet121!")
24
  return model
25
  except Exception as e:
26
  print(f"Error loading X-Ray weights: {e}")
 
28
 
29
  # --- 2. LOAD MODELS ---
30
  try:
 
31
  mri_model = tf.keras.models.load_model("mri.keras", compile=False)
32
  print("MRI model loaded successfully!")
33
  except Exception as e:
 
51
  if model_type == "MRI":
52
  if mri_model is None: return {"MRI Model Error": 0.0}
53
 
54
+ # MRI PREPROCESSING: Grayscale (1 channel) and 256x256
55
+ img = img.convert("L").resize((256, 256))
56
  img_array = np.array(img).astype('float32')
57
+ img_array = np.expand_dims(img_array, axis=(0, -1)) # Shape becomes (1, 256, 256, 1)
 
 
 
 
 
58
  model, labels = mri_model, mri_labels
59
 
60
  else:
61
  if xray_model is None: return {"X-Ray Model Error": 0.0}
62
 
63
+ # X-RAY PREPROCESSING: RGB (3 channels) and 320x320
64
  img = img.convert("RGB").resize((320, 320))
65
  img_array = np.array(img).astype('float32')
66
+ img_array = np.expand_dims(img_array, axis=0) # Shape is (1, 320, 320, 3)
67
  model, labels = xray_model, xray_labels
68
 
69
+ # Normalize pixel values
70
  img_array /= 255.0
71
 
72
  # Predict