Spaces:
Sleeping
Sleeping
# Third-party dependencies for the UI, numerics, vision, and PDF output.
import gradio as gr
import numpy as np
import tensorflow as tf
from tensorflow.keras.preprocessing import image
import cv2
from tensorflow.keras.models import Model
import os
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib import utils

# Trained Keras models used by the two-step pipeline:
#   model_step1 -- 3-way classifier: bacterial-viral / covid / normal
#   model_step2 -- binary classifier: bacterial vs. viral (sigmoid output)
#   model_bin   -- loaded but not referenced below; kept for compatibility
model_step1 = tf.keras.models.load_model("modelDense.h5")
model_step2 = tf.keras.models.load_model("modelVGG16.h5")
model_bin = tf.keras.models.load_model("modelCovid.h5")
# Two-step inference entry point used by the "Predict" button.
def predict(img):
    """Classify a chest X-ray (PIL image) and return a labelled confidence string.

    Step 1 runs the 3-way classifier (Bacterial-Viral / COVID / Normal).
    If step 1 picks "Bacterial-Viral", step 2 refines it to Bacterial or
    Viral using the binary (sigmoid) model.
    """
    batch = np.expand_dims(analyze_image(img), axis=0)  # add batch dimension
    batch /= 255.0  # scale pixels into [0, 1]

    # Step 1: three-way classification.
    step1_probs = model_step1.predict(batch)
    idx = np.argmax(step1_probs)  # 0: bacterial-viral, 1: covid, 2: normal
    label = ["Bacterial-Viral", "COVID", "Normal"][idx]

    # COVID / Normal need no refinement — report step-1 confidence directly.
    if label != "Bacterial-Viral":
        return f"{label} ({step1_probs[0][idx] * 100:.2f}% confidence)"

    # Step 2: sigmoid output — values above 0.5 mean Viral, below mean Bacterial.
    score = model_step2.predict(batch)[0][0]
    if score > 0.5:
        sub_label, confidence = "Viral", score
    else:
        sub_label, confidence = "Bacterial", 1 - score
    return f"{sub_label} ({confidence * 100:.2f}% confidence)"
# Shared preprocessing helper for both prediction and Grad-CAM paths.
def analyze_image(img):
    """Resize a PIL image to the 224x224 model input size and return it as an array."""
    resized = img.resize((224, 224))
    return image.img_to_array(resized)
# Grad-CAM visualization with a radial center-focus weighting.
def generate_gradcam_heatmap_center_focus(img):
    """Render a center-weighted Grad-CAM overlay for the uploaded X-ray.

    Builds a Grad-CAM heatmap from model_step2's 'block3_conv1' activations,
    down-weights activations far from the image centre (the lungs sit
    centrally in a chest X-ray), blends the colour-mapped heatmap over the
    original image, and saves the result to "gradcam_output.png".

    Returns the path to the saved overlay image.

    NOTE(review): the original code branched on model_step1's predicted class
    but both branches built the identical Grad-CAM model from model_step2 —
    the duplicate branches (and the now-unused step-1 prediction) have been
    collapsed. If the intent was to visualize model_step1 for the covid/normal
    classes, that branch needs a different model — confirm with the author.
    """
    img_array = analyze_image(img)
    batch = np.expand_dims(img_array, axis=0)  # add batch dimension
    batch /= 255.0  # normalize to [0, 1]

    # Model that exposes both the last chosen conv layer and the final output.
    last_conv_layer = model_step2.get_layer('block3_conv1')
    heatmap_model = Model([model_step2.inputs],
                          [last_conv_layer.output, model_step2.output])

    # Gradient of the (single sigmoid) output w.r.t. the conv feature maps.
    with tf.GradientTape() as tape:
        conv_outputs, predictions = heatmap_model(batch)
        loss = predictions[:, 0]
    grads = tape.gradient(loss, conv_outputs)
    pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))  # one weight per channel

    # Channel-wise weighting via broadcasting (replaces the per-channel loop).
    weighted = conv_outputs[0].numpy() * pooled_grads.numpy()
    heatmap = np.mean(weighted, axis=-1)
    heatmap = np.maximum(heatmap, 0)           # ReLU: keep positive influence only
    heatmap /= (np.max(heatmap) + 1e-10)       # epsilon avoids 0/0 on flat maps

    heatmap_resized = cv2.resize(heatmap, (img_array.shape[1], img_array.shape[0]))

    # Radial weight: 1 at the centre, falling to 0 at the farthest corner.
    h, w = heatmap_resized.shape
    y, x = np.ogrid[:h, :w]
    distance_from_center = np.sqrt((y - h / 2) ** 2 + (x - w / 2) ** 2)
    p = 0.3  # exponent < 1 keeps a broad central plateau
    center_weight = 1 - (distance_from_center / np.max(distance_from_center)) ** p
    heatmap_resized *= center_weight
    # Epsilon guard added: the original divided by the raw max and could
    # produce NaNs when the weighted heatmap is all zeros.
    heatmap_resized /= (np.max(heatmap_resized) + 1e-10)
    heatmap_uint8 = np.uint8(255 * heatmap_resized)

    heatmap_colormap = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)  # BGR

    # img comes from a PIL image (RGB); OpenCV writes BGR. Convert so the
    # saved overlay keeps true colours (the original swapped red and blue).
    original_img = np.uint8(255 * img_array / (np.max(img_array) + 1e-10))
    original_bgr = cv2.cvtColor(original_img, cv2.COLOR_RGB2BGR)
    superimposed_img = cv2.addWeighted(original_bgr, 0.6, heatmap_colormap, 0.4, 0)

    output_path = "gradcam_output.png"
    cv2.imwrite(output_path, superimposed_img)
    return output_path  # Return the path to the image file
# Builds the downloadable one-page PDF for the "Generate Report" button.
def generate_pdf_report(prediction, gradcam_img, patient_name, patient_id, notes):
    """Create a one-page PDF report with patient details, prediction and Grad-CAM.

    Parameters:
        prediction: text from the prediction output box.
        gradcam_img: RGB array from the Grad-CAM image component, or None if
            the user has not generated a heatmap yet.
        patient_name, patient_id, notes: free-text form fields.

    Returns the path of the written PDF file.
    """
    pdf_path = f"patient_pneumonia_report_{patient_id}.pdf"
    gradcam_image_path = "gradcam_output.png"

    # Only (re)write the heatmap file when the component actually holds an
    # image — the original crashed here if the report was requested first.
    if gradcam_img is not None:
        # Gradio delivers RGB arrays; cv2.imwrite expects BGR, so convert
        # to keep the report colours faithful.
        cv2.imwrite(gradcam_image_path,
                    cv2.cvtColor(np.asarray(gradcam_img), cv2.COLOR_RGB2BGR))

    # Create the PDF
    c = canvas.Canvas(pdf_path, pagesize=letter)
    c.setFont("Helvetica", 12)

    # Title
    c.setFont("Helvetica-Bold", 16)
    c.drawString(100, 770, "Pneumonia Diagnosis with DL Tool")

    # Patient details
    c.setFont("Helvetica", 12)
    c.drawString(100, 735, f"Patient Name: {patient_name}")
    c.drawString(100, 720, f"Patient ID: {patient_id}")
    c.drawString(100, 705, f"Notes: {notes}")

    # Prediction
    c.drawString(100, 680, f"Prediction: {prediction}")

    # Grad-CAM image — skip gracefully when no heatmap file exists yet.
    if os.path.exists(gradcam_image_path):
        c.setFont("Helvetica-Bold", 12)
        c.drawString(100, 640, "Grad-CAM Visualization:")
        c.drawImage(gradcam_image_path, 100, 350, width=224, height=224)

    c.save()
    return pdf_path
# Gradio UI: wires the prediction, Grad-CAM and PDF-report functions together.
with gr.Blocks() as demo:
    gr.Markdown("# Pneumonia Detection Model (2-step classification with Grad-CAM)")
    gr.Markdown("Upload a Chest X-ray to detect Pneumonia and classify bacterial/viral vs normal/covid.")

    # Input row: the X-ray plus patient metadata for the report.
    with gr.Row():
        img_input = gr.Image(type="pil", label="Upload X-ray Image")
        name_input = gr.Textbox(label="Patient Name")
        id_input = gr.Textbox(label="Patient ID")
        notes_input = gr.Textbox(label="Additional Notes", lines=5)

    # Action row: one button per pipeline stage.
    with gr.Row():
        predict_btn = gr.Button("Predict")
        gradcam_btn = gr.Button("Generate Grad-CAM Heatmap")
        report_btn = gr.Button("Generate Report")

    # Output widgets.
    output_text = gr.Textbox(label="Prediction Result")
    gradcam_output = gr.Image(label="Grad-CAM Heatmap")
    report_output = gr.File(label="Download Report")

    # Event wiring: each button drives its corresponding function.
    predict_btn.click(predict, inputs=img_input, outputs=output_text)
    gradcam_btn.click(generate_gradcam_heatmap_center_focus,
                      inputs=img_input, outputs=gradcam_output)
    report_btn.click(generate_pdf_report,
                     inputs=[output_text, gradcam_output, name_input,
                             id_input, notes_input],
                     outputs=report_output)

# Launch the app with a public share link.
demo.launch(share=True)