pranesh-sk commited on
Commit
5b505ab
·
verified ·
1 Parent(s): d4c3450

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -124
app.py CHANGED
@@ -7,53 +7,15 @@ from tensorflow.keras.models import Model
7
  import os
8
  from reportlab.lib.pagesizes import letter
9
  from reportlab.pdfgen import canvas
10
- import tempfile
11
-
12
- # Print TensorFlow version for debugging
13
- print(f"TensorFlow version: {tf.__version__}")
14
-
15
- # Model loading with error handling
16
- def load_models():
17
- try:
18
- # Check if model files exist
19
- if not os.path.exists("modelDense.h5"):
20
- print("Warning: modelDense.h5 not found")
21
- return None, None, None
22
-
23
- if not os.path.exists("modelVGG16.h5"):
24
- print("Warning: modelVGG16.h5 not found")
25
- return None, None, None
26
-
27
- if not os.path.exists("modelCovid.h5"):
28
- print("Warning: modelCovid.h5 not found")
29
- return None, None, None
30
-
31
- print("Loading models...")
32
- model_step1 = tf.keras.models.load_model("modelDense.h5")
33
- model_step2 = tf.keras.models.load_model("modelVGG16.h5")
34
- model_bin = tf.keras.models.load_model("modelCovid.h5")
35
- print("Models loaded successfully!")
36
- return model_step1, model_step2, model_bin
37
- except Exception as e:
38
- print(f"Error loading models: {e}")
39
- return None, None, None
40
-
41
- # Load models when app starts
42
- try:
43
- print("Attempting to load models...")
44
- model_step1, model_step2, model_bin = load_models()
45
- except Exception as e:
46
- print(f"Exception during model loading: {e}")
47
- model_step1, model_step2, model_bin = None, None, None
48
 
49
  # Function to preprocess and predict
50
  def predict(img):
51
- if img is None:
52
- return "Please upload an image."
53
-
54
- if model_step1 is None or model_step2 is None:
55
- return "Models could not be loaded. Please check the model files."
56
-
57
  img_array = analyze_image(img) # Pass the PIL image directly
58
  img_array_expanded = np.expand_dims(img_array, axis=0) # Add batch dimension
59
  img_array_expanded /= 255.0 # Normalize
@@ -86,12 +48,6 @@ def analyze_image(img):
86
 
87
  # Function for Grad-CAM visualization with center focus
88
  def generate_gradcam_heatmap_center_focus(img):
89
- if img is None:
90
- return None
91
-
92
- if model_step1 is None or model_step2 is None:
93
- return None
94
-
95
  img_array = analyze_image(img)
96
  img_array_expanded = np.expand_dims(img_array, axis=0) # Add batch dimension
97
  img_array_expanded /= 255.0 # Normalize image
@@ -99,7 +55,6 @@ def generate_gradcam_heatmap_center_focus(img):
99
  step1_prediction = model_step1.predict(img_array_expanded)
100
  class_idx2 = np.argmax(step1_prediction)
101
 
102
- # Select appropriate model for Grad-CAM
103
  if class_idx2 == 1 or class_idx2 == 2:
104
  last_conv_layer_name = 'block3_conv1'
105
  last_conv_layer = model_step2.get_layer(last_conv_layer_name)
@@ -139,26 +94,24 @@ def generate_gradcam_heatmap_center_focus(img):
139
  heatmap_resized /= np.max(heatmap_resized)
140
 
141
  heatmap_resized = np.uint8(255 * heatmap_resized)
 
142
  heatmap_colormap = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
143
 
144
  original_img = np.uint8(255 * img_array / np.max(img_array))
145
  superimposed_img = cv2.addWeighted(original_img, 0.6, heatmap_colormap, 0.4, 0)
146
-
147
- return superimposed_img # Return the image directly, not a file path
 
 
 
148
 
149
  # Function to generate a PDF report
150
  def generate_pdf_report(prediction, gradcam_img, patient_name, patient_id, notes):
151
- if gradcam_img is None:
152
- return None
153
 
154
- # Create a temporary file for the PDF
155
- with tempfile.NamedTemporaryFile(suffix='.pdf', delete=False) as tmp:
156
- pdf_path = tmp.name
157
-
158
- # Create a temporary file for the Grad-CAM image
159
- with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as img_tmp:
160
- gradcam_path = img_tmp.name
161
- cv2.imwrite(gradcam_path, gradcam_img)
162
 
163
  # Create the PDF
164
  c = canvas.Canvas(pdf_path, pagesize=letter)
@@ -181,69 +134,38 @@ def generate_pdf_report(prediction, gradcam_img, patient_name, patient_id, notes
181
  c.setFont("Helvetica-Bold", 12)
182
  c.drawString(100, 640, "Grad-CAM Visualization:")
183
 
184
- c.drawImage(gradcam_path, 100, 350, width=224, height=224)
 
 
185
  c.save()
186
 
187
- # Clean up the temporary image file
188
- os.unlink(gradcam_path)
189
-
190
  return pdf_path
191
 
 
 
 
192
 
193
- # Define what happens when models aren't loaded
194
- if model_step1 is None or model_step2 is None or model_bin is None:
195
- gr.Markdown("# ⚠️ Error: Models could not be loaded")
196
- gr.Markdown("""
197
- This application requires the following model files to be present:
198
- - modelDense.h5
199
- - modelVGG16.h5
200
- - modelCovid.h5
201
-
202
- Please make sure these files are uploaded to the Hugging Face Space.
203
- """)
204
- else:
205
- # Create Gradio Blocks for the app
206
- with gr.Blocks(theme=gr.themes.Default()) as demo:
207
-
208
- gr.Markdown("# Pneumonia Detection Model")
209
- gr.Markdown("Upload a Chest X-ray to detect Pneumonia and classify bacterial/viral vs normal/covid.")
210
-
211
- with gr.Row():
212
- with gr.Column(scale=1):
213
- img_input = gr.Image(type="pil", label="Upload X-ray Image")
214
- predict_btn = gr.Button("Predict", variant="primary")
215
- gradcam_btn = gr.Button("Generate Grad-CAM Heatmap")
216
-
217
- with gr.Column(scale=1):
218
- name_input = gr.Textbox(label="Patient Name")
219
- id_input = gr.Textbox(label="Patient ID")
220
- notes_input = gr.Textbox(label="Additional Notes", lines=3)
221
- report_btn = gr.Button("Generate Report")
222
-
223
- with gr.Row():
224
- with gr.Column(scale=1):
225
- output_text = gr.Textbox(label="Prediction Result")
226
- gradcam_output = gr.Image(label="Grad-CAM Heatmap")
227
-
228
- with gr.Column(scale=1):
229
- report_output = gr.File(label="Download Report")
230
-
231
- # Set up event handlers
232
- predict_btn.click(predict, inputs=img_input, outputs=output_text)
233
- gradcam_btn.click(generate_gradcam_heatmap_center_focus, inputs=img_input, outputs=gradcam_output)
234
- report_btn.click(
235
- generate_pdf_report,
236
- inputs=[output_text, gradcam_output, name_input, id_input, notes_input],
237
- outputs=report_output
238
- )
239
-
240
- # Add example images for demo purposes
241
- # gr.Examples(
242
- # examples=["example1.jpg", "example2.jpg"],
243
- # inputs=img_input
244
- # )
245
-
246
- # Launch the app - for Hugging Face, we don't need to specify share=True
247
- # Use queue=False to avoid the validation error
248
- if model_step1 is not None and model_step2 is not None and model_bin is not None:
249
- demo.launch(debug=True, show_error=True, queue=False)
 
7
  import os
8
  from reportlab.lib.pagesizes import letter
9
  from reportlab.pdfgen import canvas
10
+ from reportlab.lib import utils
11
+
12
+ # Load the trained models
13
+ model_step1 = tf.keras.models.load_model("modelDense.h5") # Model for 3-class classification (bacterial-viral, covid, normal)
14
+ model_step2 = tf.keras.models.load_model("modelVGG16.h5") # Model for 2-class classification (bacterial, viral)
15
+ model_bin = tf.keras.models.load_model("modelCovid.h5")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Function to preprocess and predict
18
  def predict(img):
 
 
 
 
 
 
19
  img_array = analyze_image(img) # Pass the PIL image directly
20
  img_array_expanded = np.expand_dims(img_array, axis=0) # Add batch dimension
21
  img_array_expanded /= 255.0 # Normalize
 
48
 
49
  # Function for Grad-CAM visualization with center focus
50
  def generate_gradcam_heatmap_center_focus(img):
 
 
 
 
 
 
51
  img_array = analyze_image(img)
52
  img_array_expanded = np.expand_dims(img_array, axis=0) # Add batch dimension
53
  img_array_expanded /= 255.0 # Normalize image
 
55
  step1_prediction = model_step1.predict(img_array_expanded)
56
  class_idx2 = np.argmax(step1_prediction)
57
 
 
58
  if class_idx2 == 1 or class_idx2 == 2:
59
  last_conv_layer_name = 'block3_conv1'
60
  last_conv_layer = model_step2.get_layer(last_conv_layer_name)
 
94
  heatmap_resized /= np.max(heatmap_resized)
95
 
96
  heatmap_resized = np.uint8(255 * heatmap_resized)
97
+ #heatmap_resized = 255 - heatmap_resized
98
  heatmap_colormap = cv2.applyColorMap(heatmap_resized, cv2.COLORMAP_JET)
99
 
100
  original_img = np.uint8(255 * img_array / np.max(img_array))
101
  superimposed_img = cv2.addWeighted(original_img, 0.6, heatmap_colormap, 0.4, 0)
102
+
103
+ output_path = "gradcam_output.png"
104
+ cv2.imwrite(output_path, superimposed_img)
105
+
106
+ return output_path # Return the path to the image file
107
 
108
  # Function to generate a PDF report
109
  def generate_pdf_report(prediction, gradcam_img, patient_name, patient_id, notes):
110
+ pdf_path = f"patient_pneumonia_report_{patient_id}.pdf"
111
+ gradcam_image_path = "gradcam_output.png"
112
 
113
+ # Save the Grad-CAM image
114
+ cv2.imwrite(gradcam_image_path, gradcam_img) # Saving the image file
 
 
 
 
 
 
115
 
116
  # Create the PDF
117
  c = canvas.Canvas(pdf_path, pagesize=letter)
 
134
  c.setFont("Helvetica-Bold", 12)
135
  c.drawString(100, 640, "Grad-CAM Visualization:")
136
 
137
+ c.drawImage(gradcam_image_path, 100, 350, width=224, height=224) # Adjust image size slightly smaller
138
+
139
+
140
  c.save()
141
 
 
 
 
142
  return pdf_path
143
 
144
+
145
+ # Create Gradio Blocks for the prediction function and Grad-CAM functionality
146
+ with gr.Blocks() as demo:
147
 
148
+ gr.Markdown("# Pneumonia Detection Model (2-step classification with Grad-CAM)")
149
+ gr.Markdown("Upload a Chest X-ray to detect Pneumonia and classify bacterial/viral vs normal/covid.")
150
+
151
+ with gr.Row():
152
+ img_input = gr.Image(type="pil", label="Upload X-ray Image")
153
+ name_input = gr.Textbox(label="Patient Name")
154
+ id_input = gr.Textbox(label="Patient ID")
155
+ notes_input = gr.Textbox(label="Additional Notes", lines=5)
156
+
157
+ with gr.Row():
158
+ predict_btn = gr.Button("Predict")
159
+ gradcam_btn = gr.Button("Generate Grad-CAM Heatmap")
160
+ report_btn = gr.Button("Generate Report")
161
+
162
+ output_text = gr.Textbox(label="Prediction Result")
163
+ gradcam_output = gr.Image(label="Grad-CAM Heatmap")
164
+ report_output = gr.File(label="Download Report")
165
+
166
+ predict_btn.click(predict, inputs=img_input, outputs=output_text)
167
+ gradcam_btn.click(generate_gradcam_heatmap_center_focus, inputs=img_input, outputs=gradcam_output)
168
+ report_btn.click(generate_pdf_report, inputs=[output_text, gradcam_output, name_input, id_input, notes_input], outputs=report_output)
169
+
170
+ # Launch the app
171
+ demo.launch(share=True) # Makes it accessible online