CGAllenger commited on
Commit
5876089
·
verified ·
1 Parent(s): ec4b97e

Modified the wrapper

Browse files
Files changed (1) hide show
  1. app.py +33 -37
app.py CHANGED
@@ -3,34 +3,33 @@ import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
 
6
- # --- 1. BUILD X-RAY MODEL (Matching the 3-layer .h5 file) ---
 
7
  def build_xray_model():
8
- # include_top=False + pooling=None = 1st layer
9
- base_model = tf.keras.applications.DenseNet121(
10
- input_shape=(320, 320, 3),
11
  include_top=False,
12
- weights=None,
13
- pooling=None
14
  )
15
 
16
- # We build a Sequential model with exactly 3 layers to match your weights file
17
  model = tf.keras.Sequential([
18
- base_model, # Layer 1
19
- tf.keras.layers.GlobalAveragePooling2D(), # Layer 2
20
- tf.keras.layers.Dense(14, activation='sigmoid') # Layer 3
21
  ])
22
 
23
  try:
24
  model.load_weights("xray.h5")
25
- print("X-Ray weights loaded successfully (3/3 layers matched)!")
26
  return model
27
  except Exception as e:
28
  print(f"Error loading X-Ray weights: {e}")
29
  return None
30
 
31
- # --- 2. LOAD MODELS ---
 
32
  try:
33
- # compile=False helps avoid versioning issues with the optimizer state
34
  mri_model = tf.keras.models.load_model("mri.keras", compile=False)
35
  print("MRI model loaded successfully!")
36
  except Exception as e:
@@ -39,7 +38,7 @@ except Exception as e:
39
 
40
  xray_model = build_xray_model()
41
 
42
- # --- 3. CONFIGURATION ---
43
  mri_labels = ['Glioma', 'Meningioma', 'Pituitary tumor', 'no tumor']
44
  xray_labels = [
45
  'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
@@ -49,55 +48,52 @@ xray_labels = [
49
 
50
  # --- 4. PREDICTION LOGIC ---
51
  def predict(img, model_type):
52
- if img is None: return {"No image": 0.0}
53
 
54
  try:
55
  if model_type == "MRI":
56
- if mri_model is None: return {"MRI Model Error": 0.0}
57
 
58
- # STrictly convert to Grayscale (L) and resize
59
  img = img.convert("L").resize((256, 256))
60
  img_array = np.array(img).astype('float32')
61
-
62
- # Reshape to (1, 256, 256, 1) to match "expected 1 channel"
63
- img_array = img_array.reshape((1, 256, 256, 1))
64
  model, labels = mri_model, mri_labels
65
 
66
  else:
67
- if xray_model is None: return {"X-Ray Model Error": 0.0}
68
 
69
- # X-Ray expects RGB (3 channels) and 320x320
70
- img = img.convert("RGB").resize((320, 320))
71
  img_array = np.array(img).astype('float32')
72
- img_array = np.expand_dims(img_array, axis=0)
73
  model, labels = xray_model, xray_labels
74
 
75
- # Normalize 0-255 to 0-1
76
  img_array /= 255.0
77
 
78
- # Predict
79
  preds = model.predict(img_array)[0]
80
  return {labels[i]: float(preds[i]) for i in range(len(labels))}
81
 
82
  except Exception as e:
83
- return {f"Error during prediction: {str(e)}": 0.0}
84
 
85
- # --- 5. GRADIO UI ---
86
  with gr.Blocks() as demo:
87
  gr.Markdown("# 🏥 BTech Medical Diagnostic API")
 
88
 
89
  with gr.Tabs():
90
- with gr.TabItem("Brain MRI (256x256)"):
91
- mri_in = gr.Image(type="pil", label="Upload MRI")
92
- mri_out = gr.Label(num_top_classes=1, label="Detection Result")
93
- mri_btn = gr.Button("Analyze MRI")
94
  mri_btn.click(fn=lambda i: predict(i, "MRI"), inputs=mri_in, outputs=mri_out, api_name="predict_mri")
95
 
96
- with gr.TabItem("Chest X-Ray (320x320)"):
97
- xray_in = gr.Image(type="pil", label="Upload X-Ray")
98
- xray_out = gr.Label(num_top_classes=1, label="Detection Result")
99
- xray_btn = gr.Button("Analyze X-Ray")
100
  xray_btn.click(fn=lambda i: predict(i, "X-Ray"), inputs=xray_in, outputs=xray_out, api_name="predict_xray")
101
 
102
- # Launch with theme
103
  demo.launch(theme=gr.themes.Soft())
 
3
  import numpy as np
4
  from PIL import Image
5
 
6
+ # --- 1. X-RAY MODEL RECONSTRUCTION ---
7
+ # Rebuilding exactly based on EfficientNetB1 and 128x128
8
  def build_xray_model():
9
+ base_model = tf.keras.applications.EfficientNetB1(
10
+ input_shape=(128, 128, 3),
 
11
  include_top=False,
12
+ weights=None
 
13
  )
14
 
15
+ # Built as a Sequential model to perfectly match the 3 saved layers in your .h5 file
16
  model = tf.keras.Sequential([
17
+ base_model, # Layer 1: Backbone
18
+ tf.keras.layers.GlobalAveragePooling2D(), # Layer 2: Pooling
19
+ tf.keras.layers.Dense(14, activation='sigmoid') # Layer 3: Output head
20
  ])
21
 
22
  try:
23
  model.load_weights("xray.h5")
24
+ print("X-Ray weights (EfficientNetB1) loaded successfully!")
25
  return model
26
  except Exception as e:
27
  print(f"Error loading X-Ray weights: {e}")
28
  return None
29
 
30
+ # --- 2. LOAD MRI MODEL ---
31
+ # The zaahaa notebook outputs a standard .keras file, so we load it whole
32
  try:
 
33
  mri_model = tf.keras.models.load_model("mri.keras", compile=False)
34
  print("MRI model loaded successfully!")
35
  except Exception as e:
 
38
 
39
  xray_model = build_xray_model()
40
 
41
+ # --- 3. LABELS ---
42
  mri_labels = ['Glioma', 'Meningioma', 'Pituitary tumor', 'no tumor']
43
  xray_labels = [
44
  'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
 
48
 
49
  # --- 4. PREDICTION LOGIC ---
50
  def predict(img, model_type):
51
+ if img is None: return {"No image provided": 0.0}
52
 
53
  try:
54
  if model_type == "MRI":
55
+ if mri_model is None: return {"MRI Model Error - Check Logs": 0.0}
56
 
57
+ # The zaahaa notebook uses Grayscale images. We force Grayscale (L) and 256x256.
58
  img = img.convert("L").resize((256, 256))
59
  img_array = np.array(img).astype('float32')
60
+ img_array = img_array.reshape((1, 256, 256, 1)) # Explicitly format to 1 channel
 
 
61
  model, labels = mri_model, mri_labels
62
 
63
  else:
64
+ if xray_model is None: return {"X-Ray Model Error - Check Logs": 0.0}
65
 
66
+ # Your EfficientNetB1 code used RGB images at 128x128
67
+ img = img.convert("RGB").resize((128, 128))
68
  img_array = np.array(img).astype('float32')
69
+ img_array = np.expand_dims(img_array, axis=0) # Format to 3 channels
70
  model, labels = xray_model, xray_labels
71
 
72
+ # Standard normalization used in both Kaggle notebooks
73
  img_array /= 255.0
74
 
 
75
  preds = model.predict(img_array)[0]
76
  return {labels[i]: float(preds[i]) for i in range(len(labels))}
77
 
78
  except Exception as e:
79
+ return {f"Prediction Error: {str(e)}": 0.0}
80
 
81
+ # --- 5. UI APP ---
82
  with gr.Blocks() as demo:
83
  gr.Markdown("# 🏥 BTech Medical Diagnostic API")
84
+ gr.Markdown("Upload an image to get a diagnostic prediction.")
85
 
86
  with gr.Tabs():
87
+ with gr.TabItem("Brain MRI Classifier"):
88
+ mri_in = gr.Image(type="pil", label="Upload Brain MRI")
89
+ mri_out = gr.Label(num_top_classes=1, label="Result")
90
+ mri_btn = gr.Button("Analyze MRI", variant="primary")
91
  mri_btn.click(fn=lambda i: predict(i, "MRI"), inputs=mri_in, outputs=mri_out, api_name="predict_mri")
92
 
93
+ with gr.TabItem("Chest X-Ray Classifier"):
94
+ xray_in = gr.Image(type="pil", label="Upload Chest X-Ray")
95
+ xray_out = gr.Label(num_top_classes=1, label="Result")
96
+ xray_btn = gr.Button("Analyze X-Ray", variant="primary")
97
  xray_btn.click(fn=lambda i: predict(i, "X-Ray"), inputs=xray_in, outputs=xray_out, api_name="predict_xray")
98
 
 
99
  demo.launch(theme=gr.themes.Soft())