hp733 commited on
Commit
46c9dec
·
verified ·
1 Parent(s): 695fb8b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -10
app.py CHANGED
@@ -139,20 +139,23 @@
139
 
140
  import gradio as gr
141
  import numpy as np
142
- from PIL import Image
143
- import cv2
144
  import tensorflow as tf
145
  from tensorflow.keras.models import load_model
146
- from gradcam_utils import generate_heatmap_tf_explain
 
 
147
  from models import create_vgg19_model
 
148
 
149
  # Load your trained model
150
  ensemble_model = load_model("ensemble_model_best(92.3).h5")
151
- vgg_model = create_vgg19_model()
152
 
 
153
  def get_class_name(class_id):
154
  return "Normal" if class_id == 0 else "Pneumonia"
155
 
 
156
  def predict_and_heatmap(image):
157
  img = image.resize((224, 224))
158
  img_array = np.array(img) / 255.0
@@ -162,7 +165,6 @@ def predict_and_heatmap(image):
162
  class_id = int(np.argmax(prediction[0]))
163
  label = get_class_name(class_id)
164
 
165
- # Styled HTML result
166
  result_html = f"""
167
  <div style='
168
  text-align: center;
@@ -180,11 +182,15 @@ def predict_and_heatmap(image):
180
  </div>
181
  """
182
 
183
- # Generate heatmap from VGG19 using tf-explain
184
  heatmap_img = generate_heatmap_tf_explain(image, vgg_model, class_index=class_id)
185
  return result_html, heatmap_img
186
 
187
- # Styled Gradio Interface
 
 
 
 
188
  with gr.Blocks(theme="soft") as demo:
189
  gr.Markdown("""
190
  <div style="text-align: center; font-size: 2.5rem; font-weight: bold; color: #0b5394; margin-bottom: 1rem;">
@@ -199,12 +205,16 @@ with gr.Blocks(theme="soft") as demo:
199
  with gr.Column(scale=1, min_width=600):
200
  image_input = gr.Image(type="pil", label="Upload Chest X-Ray", interactive=True, width=600, height=600)
201
  prediction_output = gr.HTML(label="Prediction")
202
- heatmap_output = gr.Image(label="Grad-CAM Heatmap")
203
- submit_button = gr.Button("Predict")
204
- clear_button = gr.Button("Clear")
 
 
 
205
 
206
  submit_button.click(fn=predict_and_heatmap, inputs=image_input, outputs=[prediction_output, heatmap_output])
207
  clear_button.click(fn=lambda: (None, "", None), inputs=[], outputs=[image_input, prediction_output, heatmap_output])
 
208
 
209
  gr.Markdown("""
210
  <div style="text-align: center; font-size: 0.95rem; color: #888; margin-top: 30px;">
 
139
 
140
  import gradio as gr
141
  import numpy as np
 
 
142
  import tensorflow as tf
143
  from tensorflow.keras.models import load_model
144
+ from PIL import Image
145
+ import os
146
+
147
  from models import create_vgg19_model
148
+ from gradcam_utils import generate_heatmap_tf_explain
149
 
150
  # Load your trained model
151
  ensemble_model = load_model("ensemble_model_best(92.3).h5")
152
+ vgg_model = create_vgg19_model() # Only used for Grad-CAM (tf-explain)
153
 
154
+ # Label names
155
  def get_class_name(class_id):
156
  return "Normal" if class_id == 0 else "Pneumonia"
157
 
158
+ # Prediction + Heatmap generation
159
  def predict_and_heatmap(image):
160
  img = image.resize((224, 224))
161
  img_array = np.array(img) / 255.0
 
165
  class_id = int(np.argmax(prediction[0]))
166
  label = get_class_name(class_id)
167
 
 
168
  result_html = f"""
169
  <div style='
170
  text-align: center;
 
182
  </div>
183
  """
184
 
185
+ # Generate Grad-CAM heatmap using tf-explain (on VGG19)
186
  heatmap_img = generate_heatmap_tf_explain(image, vgg_model, class_index=class_id)
187
  return result_html, heatmap_img
188
 
189
+ # Function to load sample image
190
+ def load_sample():
191
+ return Image.open("sample_pneumonia.jpeg")
192
+
193
+ # Gradio interface
194
  with gr.Blocks(theme="soft") as demo:
195
  gr.Markdown("""
196
  <div style="text-align: center; font-size: 2.5rem; font-weight: bold; color: #0b5394; margin-bottom: 1rem;">
 
205
  with gr.Column(scale=1, min_width=600):
206
  image_input = gr.Image(type="pil", label="Upload Chest X-Ray", interactive=True, width=600, height=600)
207
  prediction_output = gr.HTML(label="Prediction")
208
+ heatmap_output = gr.Image(label="Grad-CAM Heatmap", width=600, height=600)
209
+
210
+ with gr.Row():
211
+ submit_button = gr.Button("Predict")
212
+ clear_button = gr.Button("Clear")
213
+ sample_button = gr.Button("Load Sample X-ray")
214
 
215
  submit_button.click(fn=predict_and_heatmap, inputs=image_input, outputs=[prediction_output, heatmap_output])
216
  clear_button.click(fn=lambda: (None, "", None), inputs=[], outputs=[image_input, prediction_output, heatmap_output])
217
+ sample_button.click(fn=load_sample, inputs=[], outputs=[image_input])
218
 
219
  gr.Markdown("""
220
  <div style="text-align: center; font-size: 0.95rem; color: #888; margin-top: 30px;">