hp733 commited on
Commit
2f1c285
·
verified ·
1 Parent(s): 1c5d913

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -47
app.py CHANGED
@@ -137,68 +137,81 @@
137
 
138
 
139
 
 
140
  import numpy as np
141
- from tensorflow.keras.models import load_model
142
  from PIL import Image
143
- import gradio as gr
144
- from models import create_vgg19_model
 
145
  from gradcam_utils import generate_heatmap_tf_explain
 
146
 
147
- # Load models
148
  ensemble_model = load_model("ensemble_model_best(92.3).h5")
149
- vgg_model = create_vgg19_model() # Used only for Grad-CAM
150
 
151
  def get_class_name(class_id):
152
  return "Normal" if class_id == 0 else "Pneumonia"
153
 
154
  def predict_and_heatmap(image):
155
- # Resize and normalize image for prediction
156
  img = image.resize((224, 224))
157
  img_array = np.array(img) / 255.0
158
  img_array = np.expand_dims(img_array, axis=0)
159
 
160
- # Predict with ensemble model
161
  prediction = ensemble_model.predict(img_array)
162
  class_id = int(np.argmax(prediction[0]))
163
- result = get_class_name(class_id)
164
-
165
- # Generate heatmap with tf-explain using VGG19
166
- heatmap_img = generate_heatmap_tf_explain(img, vgg_model, class_index=class_id)
167
-
168
- return result, heatmap_img
169
-
170
- # 🎨 Custom CSS styling
171
- custom_css = """
172
- body {
173
- background-color: #1c1c1e;
174
- font-family: 'Segoe UI', sans-serif;
175
- }
176
- h1, h2, .output_class {
177
- color: #ffffff;
178
- text-align: center;
179
- }
180
- .gr-button {
181
- background-color: #007aff !important;
182
- color: white !important;
183
- }
184
- .gr-image-preview {
185
- box-shadow: 0 0 20px rgba(0,0,0,0.5);
186
- border-radius: 8px;
187
- }
188
- """
189
-
190
- # Launch Gradio Interface
191
- interface = gr.Interface(
192
- fn=predict_and_heatmap,
193
- inputs=gr.Image(type="pil", label="Upload Chest X-ray"),
194
- outputs=[
195
- gr.Label(label="Prediction"),
196
- gr.Image(label="Grad-CAM Heatmap")
197
- ],
198
- title="🩺 Pneumonia Detection Using Deep Learning",
199
- description="Upload a chest X-ray to detect Pneumonia and see the heatmap visualization (powered by tf-explain and VGG19).",
200
- css=custom_css
201
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
202
 
203
  if __name__ == "__main__":
204
- interface.launch()
 
137
 
138
 
139
 
140
+ import gradio as gr
141
  import numpy as np
 
142
  from PIL import Image
143
+ import cv2
144
+ import tensorflow as tf
145
+ from tensorflow.keras.models import load_model
146
  from gradcam_utils import generate_heatmap_tf_explain
147
+ from models import create_vgg19_model
148
 
149
+ # Load your trained model
150
  ensemble_model = load_model("ensemble_model_best(92.3).h5")
151
+ vgg_model = create_vgg19_model()
152
 
153
  def get_class_name(class_id):
154
  return "Normal" if class_id == 0 else "Pneumonia"
155
 
156
  def predict_and_heatmap(image):
 
157
  img = image.resize((224, 224))
158
  img_array = np.array(img) / 255.0
159
  img_array = np.expand_dims(img_array, axis=0)
160
 
 
161
  prediction = ensemble_model.predict(img_array)
162
  class_id = int(np.argmax(prediction[0]))
163
+ label = get_class_name(class_id)
164
+
165
+ # Styled HTML result
166
+ result_html = f"""
167
+ <div style='
168
+ text-align: center;
169
+ font-size: 1.5rem;
170
+ font-weight: bold;
171
+ color: {"green" if class_id == 0 else "red"};
172
+ background-color: #f0f8ff;
173
+ border: 2px solid {"green" if class_id == 0 else "red"};
174
+ padding: 10px;
175
+ border-radius: 10px;
176
+ width: fit-content;
177
+ margin: 0 auto;
178
+ '>
179
+ Result: {label}
180
+ </div>
181
+ """
182
+
183
+ # Generate heatmap from VGG19 using tf-explain
184
+ heatmap_img = generate_heatmap_tf_explain(image, vgg_model, class_index=class_id)
185
+ return result_html, heatmap_img
186
+
187
+ # Styled Gradio Interface
188
+ with gr.Blocks(theme="soft") as demo:
189
+ gr.Markdown("""
190
+ <div style="text-align: center; font-size: 2.5rem; font-weight: bold; color: #0b5394; margin-bottom: 1rem;">
191
+ 🩺 Pneumonia Detection from Chest X-rays
192
+ </div>
193
+ <div style="text-align: center; font-size: 1.1rem; margin-bottom: 2rem;">
194
+ Upload a chest X-ray image to predict if the lungs are Normal or show signs of Pneumonia.
195
+ </div>
196
+ """)
197
+
198
+ with gr.Row():
199
+ with gr.Column(scale=1, min_width=600):
200
+ image_input = gr.Image(type="pil", label="Upload Chest X-Ray", interactive=True, width=600, height=600)
201
+ prediction_output = gr.HTML(label="Prediction")
202
+ heatmap_output = gr.Image(label="Grad-CAM Heatmap")
203
+ submit_button = gr.Button("Predict")
204
+ clear_button = gr.Button("Clear")
205
+
206
+ submit_button.click(fn=predict_and_heatmap, inputs=image_input, outputs=[prediction_output, heatmap_output])
207
+ clear_button.click(fn=lambda: (None, "", None), inputs=[], outputs=[image_input, prediction_output, heatmap_output])
208
+
209
+ gr.Markdown("""
210
+ <div style="text-align: center; font-size: 0.95rem; color: #888; margin-top: 30px;">
211
+ Made with ❤️ by <a href="https://github.com/hruthik733" target="_blank" style="color: #0b5394; text-decoration: none; font-weight: bold;">
212
+ Hruthik Pavarala</a>
213
+ </div>
214
+ """)
215
 
216
  if __name__ == "__main__":
217
+ demo.launch()