CGAllenger commited on
Commit
ec4b97e
·
verified ·
1 Parent(s): 0f309cb

fixed the layers

Browse files
Files changed (1) hide show
  1. app.py +43 -33
app.py CHANGED
@@ -3,24 +3,26 @@ import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
 
6
- # --- 1. BUILD X-RAY MODEL (DenseNet121 in 3 Sequential Layers) ---
7
  def build_xray_model():
 
8
  base_model = tf.keras.applications.DenseNet121(
9
  input_shape=(320, 320, 3),
10
  include_top=False,
11
- weights=None
 
12
  )
13
 
14
- # Wrap in Sequential to match the "3 saved layers" format exactly
15
  model = tf.keras.Sequential([
16
- base_model,
17
- tf.keras.layers.GlobalAveragePooling2D(),
18
- tf.keras.layers.Dense(14, activation='sigmoid')
19
  ])
20
 
21
  try:
22
  model.load_weights("xray.h5")
23
- print("X-Ray weights loaded successfully into Sequential DenseNet121!")
24
  return model
25
  except Exception as e:
26
  print(f"Error loading X-Ray weights: {e}")
@@ -28,6 +30,7 @@ def build_xray_model():
28
 
29
  # --- 2. LOAD MODELS ---
30
  try:
 
31
  mri_model = tf.keras.models.load_model("mri.keras", compile=False)
32
  print("MRI model loaded successfully!")
33
  except Exception as e:
@@ -48,30 +51,36 @@ xray_labels = [
48
  def predict(img, model_type):
49
  if img is None: return {"No image": 0.0}
50
 
51
- if model_type == "MRI":
52
- if mri_model is None: return {"MRI Model Error": 0.0}
53
-
54
- # MRI PREPROCESSING: Grayscale (1 channel) and 256x256
55
- img = img.convert("L").resize((256, 256))
56
- img_array = np.array(img).astype('float32')
57
- img_array = np.expand_dims(img_array, axis=(0, -1)) # Shape becomes (1, 256, 256, 1)
58
- model, labels = mri_model, mri_labels
59
-
60
- else:
61
- if xray_model is None: return {"X-Ray Model Error": 0.0}
62
-
63
- # X-RAY PREPROCESSING: RGB (3 channels) and 320x320
64
- img = img.convert("RGB").resize((320, 320))
65
- img_array = np.array(img).astype('float32')
66
- img_array = np.expand_dims(img_array, axis=0) # Shape is (1, 320, 320, 3)
67
- model, labels = xray_model, xray_labels
 
 
 
68
 
69
- # Normalize pixel values
70
- img_array /= 255.0
 
 
 
 
71
 
72
- # Predict
73
- preds = model.predict(img_array)[0]
74
- return {labels[i]: float(preds[i]) for i in range(len(labels))}
75
 
76
  # --- 5. GRADIO UI ---
77
  with gr.Blocks() as demo:
@@ -79,15 +88,16 @@ with gr.Blocks() as demo:
79
 
80
  with gr.Tabs():
81
  with gr.TabItem("Brain MRI (256x256)"):
82
- mri_in = gr.Image(type="pil")
83
- mri_out = gr.Label(num_top_classes=1)
84
  mri_btn = gr.Button("Analyze MRI")
85
  mri_btn.click(fn=lambda i: predict(i, "MRI"), inputs=mri_in, outputs=mri_out, api_name="predict_mri")
86
 
87
  with gr.TabItem("Chest X-Ray (320x320)"):
88
- xray_in = gr.Image(type="pil")
89
- xray_out = gr.Label(num_top_classes=1)
90
  xray_btn = gr.Button("Analyze X-Ray")
91
  xray_btn.click(fn=lambda i: predict(i, "X-Ray"), inputs=xray_in, outputs=xray_out, api_name="predict_xray")
92
 
 
93
  demo.launch(theme=gr.themes.Soft())
 
3
  import numpy as np
4
  from PIL import Image
5
 
6
+ # --- 1. BUILD X-RAY MODEL (Matching the 3-layer .h5 file) ---
7
  def build_xray_model():
8
+ # include_top=False + pooling=None = 1st layer
9
  base_model = tf.keras.applications.DenseNet121(
10
  input_shape=(320, 320, 3),
11
  include_top=False,
12
+ weights=None,
13
+ pooling=None
14
  )
15
 
16
+ # We build a Sequential model with exactly 3 layers to match your weights file
17
  model = tf.keras.Sequential([
18
+ base_model, # Layer 1
19
+ tf.keras.layers.GlobalAveragePooling2D(), # Layer 2
20
+ tf.keras.layers.Dense(14, activation='sigmoid') # Layer 3
21
  ])
22
 
23
  try:
24
  model.load_weights("xray.h5")
25
+ print("X-Ray weights loaded successfully (3/3 layers matched)!")
26
  return model
27
  except Exception as e:
28
  print(f"Error loading X-Ray weights: {e}")
 
30
 
31
  # --- 2. LOAD MODELS ---
32
  try:
33
+ # compile=False helps avoid versioning issues with the optimizer state
34
  mri_model = tf.keras.models.load_model("mri.keras", compile=False)
35
  print("MRI model loaded successfully!")
36
  except Exception as e:
 
51
  def predict(img, model_type):
52
  if img is None: return {"No image": 0.0}
53
 
54
+ try:
55
+ if model_type == "MRI":
56
+ if mri_model is None: return {"MRI Model Error": 0.0}
57
+
58
+ # STrictly convert to Grayscale (L) and resize
59
+ img = img.convert("L").resize((256, 256))
60
+ img_array = np.array(img).astype('float32')
61
+
62
+ # Reshape to (1, 256, 256, 1) to match "expected 1 channel"
63
+ img_array = img_array.reshape((1, 256, 256, 1))
64
+ model, labels = mri_model, mri_labels
65
+
66
+ else:
67
+ if xray_model is None: return {"X-Ray Model Error": 0.0}
68
+
69
+ # X-Ray expects RGB (3 channels) and 320x320
70
+ img = img.convert("RGB").resize((320, 320))
71
+ img_array = np.array(img).astype('float32')
72
+ img_array = np.expand_dims(img_array, axis=0)
73
+ model, labels = xray_model, xray_labels
74
 
75
+ # Normalize 0-255 to 0-1
76
+ img_array /= 255.0
77
+
78
+ # Predict
79
+ preds = model.predict(img_array)[0]
80
+ return {labels[i]: float(preds[i]) for i in range(len(labels))}
81
 
82
+ except Exception as e:
83
+ return {f"Error during prediction: {str(e)}": 0.0}
 
84
 
85
  # --- 5. GRADIO UI ---
86
  with gr.Blocks() as demo:
 
88
 
89
  with gr.Tabs():
90
  with gr.TabItem("Brain MRI (256x256)"):
91
+ mri_in = gr.Image(type="pil", label="Upload MRI")
92
+ mri_out = gr.Label(num_top_classes=1, label="Detection Result")
93
  mri_btn = gr.Button("Analyze MRI")
94
  mri_btn.click(fn=lambda i: predict(i, "MRI"), inputs=mri_in, outputs=mri_out, api_name="predict_mri")
95
 
96
  with gr.TabItem("Chest X-Ray (320x320)"):
97
+ xray_in = gr.Image(type="pil", label="Upload X-Ray")
98
+ xray_out = gr.Label(num_top_classes=1, label="Detection Result")
99
  xray_btn = gr.Button("Analyze X-Ray")
100
  xray_btn.click(fn=lambda i: predict(i, "X-Ray"), inputs=xray_in, outputs=xray_out, api_name="predict_xray")
101
 
102
+ # Launch with theme
103
  demo.launch(theme=gr.themes.Soft())