# Author: pranesh-sk - last commit: "Update app.py" (5b505ab, verified)
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image
import cv2
from tensorflow.keras.models import Model
import os
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib import utils
# Load every trained Keras model once, at import time, so all requests share them.
#   model_step1 -> 3-class classifier: Bacterial-Viral / COVID / Normal
#   model_step2 -> 2-class classifier: Bacterial vs Viral (single sigmoid output,
#                  as read by predict())
#   model_bin   -> binary COVID model; NOTE(review): not referenced by the code shown
#                  here - confirm whether it is still needed before loading it.
model_step1, model_step2, model_bin = (
    tf.keras.models.load_model(weights_file)
    for weights_file in ("modelDense.h5", "modelVGG16.h5", "modelCovid.h5")
)
# Function to preprocess and predict
def predict(img):
    """Classify a chest X-ray with the two-stage model cascade.

    Stage 1 (``model_step1``) picks one of "Bacterial-Viral", "COVID" or
    "Normal". Only when it picks "Bacterial-Viral" does stage 2
    (``model_step2``, one sigmoid unit) refine the result. A score above 0.5
    means "Viral"; otherwise the result is "Bacterial".

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` upload widget, or
            ``None`` when the user pressed Predict without uploading.

    Returns:
        A label with its confidence, e.g. ``"COVID (97.31% confidence)"``.
        When ``img`` is ``None``, returns a prompt to upload an image.
    """
    if img is None:
        # Without this guard analyze_image() would fail on None.resize(...).
        return "Please upload an X-ray image first."

    # Add the batch axis and scale pixel values from 0-255 to 0-1.
    batch = np.expand_dims(analyze_image(img), axis=0).astype(np.float32) / 255.0

    # Stage 1: index 0 = Bacterial-Viral, 1 = COVID, 2 = Normal.
    step1_labels = ["Bacterial-Viral", "COVID", "Normal"]
    step1_probs = model_step1.predict(batch)[0]
    class_idx = int(np.argmax(step1_probs))
    step1_label = step1_labels[class_idx]
    if step1_label != "Bacterial-Viral":
        return f"{step1_label} ({step1_probs[class_idx] * 100:.2f}% confidence)"

    # Stage 2: sigmoid score, 0 -> bacterial, 1 -> viral.
    viral_prob = model_step2.predict(batch)[0][0]
    if viral_prob > 0.5:
        step2_label, confidence = "Viral", viral_prob
    else:
        step2_label, confidence = "Bacterial", 1 - viral_prob
    return f"{step2_label} ({confidence * 100:.2f}% confidence)"
# Function to preprocess the image
def analyze_image(img, target_size=(224, 224)):
    """Turn a PIL image into an un-normalised array sized for the classifiers.

    Args:
        img: PIL image.
        target_size: ``(width, height)`` handed to ``PIL.Image.resize``. The
            default 224x224 matches the size this file feeds to every model.

    Returns:
        Array of shape ``(height, width, 3)`` from ``image.img_to_array``.
        Values stay in the 0-255 range; callers divide by 255 themselves.
    """
    # Force three channels. Grayscale ("L") or RGBA uploads would otherwise give
    # 1 or 4 channels. NOTE(review): assumes the models take RGB input; the
    # VGG16-style layer name 'block3_conv1' used for Grad-CAM suggests so.
    img = img.convert("RGB")
    img = img.resize(target_size)
    return image.img_to_array(img)
# Function for Grad-CAM visualization with center focus
def _gradcam_heatmap(model, last_conv_layer_name, img_batch):
    """Return the Grad-CAM map of ``model`` at the named conv layer, scaled to [0, 1]."""
    last_conv_layer = model.get_layer(last_conv_layer_name)
    # model.inputs is already a list; wrapping it in another list builds a nested
    # input structure that newer Keras versions reject or warn about.
    heatmap_model = Model(inputs=model.inputs, outputs=[last_conv_layer.output, model.output])
    with tf.GradientTape() as tape:
        conv_outputs, predictions = heatmap_model(img_batch)
        # Explain the first output unit (predict() reads it as the "Viral" score).
        loss = predictions[:, 0]
    grads = tape.gradient(loss, conv_outputs)
    # One importance weight per channel: the gradient averaged over batch, height, width.
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2)).numpy()
    # Broadcasting multiplies each channel by its weight (replaces the per-channel loop).
    weighted = conv_outputs[0].numpy() * pooled_grads
    heatmap = np.maximum(np.mean(weighted, axis=-1), 0)
    return heatmap / (np.max(heatmap) + 1e-10)


def _apply_center_focus(heatmap, power):
    """Attenuate ``heatmap`` away from its centre and return it as uint8 in [0, 255]."""
    h, w = heatmap.shape
    y, x = np.ogrid[:h, :w]
    distance = np.sqrt((y - h / 2) ** 2 + (x - w / 2) ** 2)
    max_distance = np.max(distance) or 1.0  # avoid 0/0 for a 1x1 map
    # The weight is largest at the centre and reaches 0 at the farthest pixel.
    focused = heatmap * (1 - (distance / max_distance) ** power)
    peak = np.max(focused)
    if peak > 0:
        # The original divided unconditionally, so an all-zero map became NaN.
        focused = focused / peak
    return np.uint8(255 * focused)


def generate_gradcam_heatmap_center_focus(img, last_conv_layer_name="block3_conv1", center_power=0.3):
    """Render a centre-weighted Grad-CAM overlay for ``img`` and save it as a PNG.

    The explanation always comes from ``model_step2``. The previous version ran
    ``model_step1`` first to pick a branch, but every branch used the same model
    and layer, so that extra inference pass has been removed.

    Args:
        img: PIL image from the upload widget, or ``None``.
        last_conv_layer_name: layer of ``model_step2`` whose activations are explained.
        center_power: exponent ``p`` in the centre weight ``1 - (d / d_max) ** p``.
            Smaller values suppress off-centre regions more aggressively.

    Returns:
        The path ``"gradcam_output.png"`` in the working directory, or ``None``
        when no image was supplied.
    """
    if img is None:
        return None

    img_array = analyze_image(img)
    img_batch = np.expand_dims(img_array, axis=0) / 255.0

    heatmap = _gradcam_heatmap(model_step2, last_conv_layer_name, img_batch)
    # cv2.resize takes (width, height).
    heatmap = cv2.resize(heatmap, (img_array.shape[1], img_array.shape[0]))
    heatmap_colormap = cv2.applyColorMap(_apply_center_focus(heatmap, center_power), cv2.COLORMAP_JET)

    # Stretch the X-ray to the full 0-255 range; an all-black image stays black.
    peak = np.max(img_array)
    if peak > 0:
        original_img = np.uint8(255 * img_array / peak)
    else:
        original_img = np.zeros(img_array.shape, dtype=np.uint8)
    # The PIL-derived array is RGB, while applyColorMap/imwrite work in BGR.
    original_bgr = cv2.cvtColor(original_img, cv2.COLOR_RGB2BGR)

    superimposed_img = cv2.addWeighted(original_bgr, 0.6, heatmap_colormap, 0.4, 0)
    output_path = "gradcam_output.png"
    cv2.imwrite(output_path, superimposed_img)
    return output_path
# Function to generate a PDF report
def generate_pdf_report(prediction, gradcam_img, patient_name, patient_id, notes):
    """Write a one-page PDF diagnosis report and return its path.

    Args:
        prediction: text from the "Prediction Result" box.
        gradcam_img: value of the Grad-CAM ``gr.Image`` when used as an input.
            This is presumably an RGB(A) uint8 numpy array (gr.Image's default
            ``type="numpy"``), or ``None`` if no heatmap was generated yet.
        patient_name: free text printed as-is.
        patient_id: free text printed as-is. A sanitised copy is used in the
            file name.
        notes: free text, possibly multi-line; wrapped to fit the page.

    Returns:
        Path of the PDF, ``patient_pneumonia_report_<sanitised id>.pdf``.
    """
    import re
    import textwrap

    # patient_id comes from a publicly shared form. Keep only safe characters so
    # it cannot inject path separators (e.g. "../") into the output file name.
    safe_id = re.sub(r"[^A-Za-z0-9_-]", "_", str(patient_id or ""))
    pdf_path = f"patient_pneumonia_report_{safe_id}.pdf"

    gradcam_image_path = "gradcam_output.png"
    has_gradcam = gradcam_img is not None
    if has_gradcam:
        img_bgr = np.asarray(gradcam_img)
        # cv2.imwrite expects BGR(A). Without the swap, the JET colours come out
        # inverted in the PDF (red and blue exchanged).
        if img_bgr.ndim == 3 and img_bgr.shape[2] == 3:
            img_bgr = cv2.cvtColor(img_bgr, cv2.COLOR_RGB2BGR)
        elif img_bgr.ndim == 3 and img_bgr.shape[2] == 4:
            img_bgr = cv2.cvtColor(img_bgr, cv2.COLOR_RGBA2BGRA)
        cv2.imwrite(gradcam_image_path, img_bgr)

    c = canvas.Canvas(pdf_path, pagesize=letter)

    # Title
    c.setFont("Helvetica-Bold", 16)
    c.drawString(100, 770, "Pneumonia Diagnosis with DL Tool")

    # Patient details
    c.setFont("Helvetica", 12)
    c.drawString(100, 735, f"Patient Name: {patient_name}")
    c.drawString(100, 720, f"Patient ID: {patient_id}")

    # drawString does not break lines, so wrap the notes (the input box allows
    # several lines) by hand.
    note_lines = []
    for paragraph in f"Notes: {notes}".splitlines() or [""]:
        note_lines.extend(textwrap.wrap(paragraph, width=80) or [""])
    max_note_lines = 20  # keeps the 224-pt image above the bottom page margin
    if len(note_lines) > max_note_lines:
        note_lines = note_lines[: max_note_lines - 1] + ["..."]
    for i, line in enumerate(note_lines):
        c.drawString(100, 705 - 15 * i, line)

    # Each following section moves down with the notes. With a one-line note the
    # positions match the original layout: 680, 640, 350.
    prediction_y = 705 - 15 * (len(note_lines) - 1) - 25
    heading_y = prediction_y - 40
    image_y = heading_y - 290

    c.drawString(100, prediction_y, f"Prediction: {prediction}")

    # Grad-CAM image
    c.setFont("Helvetica-Bold", 12)
    c.drawString(100, heading_y, "Grad-CAM Visualization:")
    if has_gradcam:
        c.drawImage(gradcam_image_path, 100, image_y, width=224, height=224)
    else:
        c.setFont("Helvetica", 12)
        c.drawString(100, heading_y - 20, "Not generated.")
    c.save()
    return pdf_path
# Create Gradio Blocks for the prediction function and Grad-CAM functionality.
# NOTE(review): the original file's indentation was lost. The Row membership
# below is a reconstruction (form inputs in the first row, buttons in the
# second, outputs stacked underneath) - confirm it matches the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("# Pneumonia Detection Model (2-step classification with Grad-CAM)")
    gr.Markdown("Upload a Chest X-ray to detect Pneumonia and classify bacterial/viral vs normal/covid.")

    # Inputs: the X-ray itself plus the patient details printed in the report.
    with gr.Row():
        img_input = gr.Image(label="Upload X-ray Image", type="pil")
        name_input = gr.Textbox(label="Patient Name")
        id_input = gr.Textbox(label="Patient ID")
        notes_input = gr.Textbox(label="Additional Notes", lines=5)

    # One button per action.
    with gr.Row():
        predict_btn = gr.Button("Predict")
        gradcam_btn = gr.Button("Generate Grad-CAM Heatmap")
        report_btn = gr.Button("Generate Report")

    # Outputs. The report button also reads the first two back as inputs.
    output_text = gr.Textbox(label="Prediction Result")
    gradcam_output = gr.Image(label="Grad-CAM Heatmap")
    report_output = gr.File(label="Download Report")

    # Event wiring, registered in the same order as before: (button, handler, inputs, outputs).
    event_table = (
        (predict_btn, predict, img_input, output_text),
        (gradcam_btn, generate_gradcam_heatmap_center_focus, img_input, gradcam_output),
        (
            report_btn,
            generate_pdf_report,
            [output_text, gradcam_output, name_input, id_input, notes_input],
            report_output,
        ),
    )
    for button, handler, sources, target in event_table:
        button.click(handler, inputs=sources, outputs=target)

# Launch the app; share=True asks Gradio for a public link to this server.
demo.launch(share=True)