""" 🫁 Multi-Class Chest X-Ray Detection with Adaptive Sparse Training Advanced Gradio Interface - 4 Disease Classes Features: - Real-time detection: Normal, TB, Pneumonia, COVID-19 - Grad-CAM visualization (explainable AI) - Improved specificity - distinguishes TB from pneumonia - Confidence scores with visual indicators - Clinical interpretation and recommendations - Mobile-responsive design """ import os from pathlib import Path import io import gradio as gr import torch import torch.nn as nn from torchvision import models, transforms from PIL import Image import numpy as np import cv2 import matplotlib.pyplot as plt # ============================================================================ # Device # ============================================================================ device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"βœ… Using device: {device}") # ============================================================================ # Model Setup & Robust Loader # ============================================================================ NUM_CLASSES = 4 CLASSES = ["Normal", "Tuberculosis", "Pneumonia", "COVID-19"] CLASS_COLORS = { "Normal": "#2ecc71", # Green "Tuberculosis": "#e74c3c", # Red "Pneumonia": "#f39c12", # Orange "COVID-19": "#9b59b6", # Purple } def build_base_model(num_classes: int = NUM_CLASSES) -> nn.Module: """ Build the base EfficientNet-B2 model with a 4-class classifier head. This matches the architecture used during training. """ # πŸ”’ Do NOT change this to efficientnet_b0 – your checkpoint is B2 (1408 features) model = models.efficientnet_b2(weights=None) in_features = model.classifier[1].in_features model.classifier[1] = nn.Linear(in_features, num_classes) return model def load_trained_model() -> nn.Module: """ Load EfficientNet-B2 4-class checkpoint from: - 'best.pt' OR - 'checkpoints/best.pt' Supports both: - Plain state_dict - Training checkpoint with 'model_state_dict' or 'state_dict' keys """ model = build_base_model().to(device) search_paths = [ Path("best.pt"), Path("checkpoints/best.pt"), ] ckpt_path = None for p in search_paths: if p.exists(): ckpt_path = p break if ckpt_path is None: raise RuntimeError( "❌ Could not find model checkpoint.\n" "Expected 'best.pt' in the project root OR 'checkpoints/best.pt'.\n" "Please upload your 4-class EfficientNet-B2 weights as 'best.pt' or 'checkpoints/best.pt'." ) print(f"πŸ” Loading weights from: {ckpt_path}") ckpt = torch.load(ckpt_path, map_location=device) # Try to extract the actual state_dict if isinstance(ckpt, dict): if "model_state_dict" in ckpt: state_dict = ckpt["model_state_dict"] elif "state_dict" in ckpt: state_dict = ckpt["state_dict"] else: # Assume it's already a plain state_dict state_dict = ckpt else: # Definitely just a state_dict state_dict = ckpt # Now load strictly – if this fails, the checkpoint truly doesn't match the architecture try: missing, unexpected = model.load_state_dict(state_dict, strict=True) if missing or unexpected: # This branch rarely happens with strict=True, but keep for clarity print(f"⚠️ Missing keys in state_dict: {missing}") print(f"⚠️ Unexpected keys in state_dict: {unexpected}") except RuntimeError as e: # Most common cause: trying to load B2 checkpoint into B0/B1 or wrong architecture raise RuntimeError( f"❌ Failed to load weights from {ckpt_path}.\n" "Most likely cause: the checkpoint was trained with a different architecture.\n" "This app expects an EfficientNet-B2 checkpoint with 4 output classes.\n\n" f"PyTorch error:\n{e}" ) print("βœ… Model weights loaded successfully!") model.eval() return model model = load_trained_model() # ============================================================================ # Preprocessing # ============================================================================ transform = transforms.Compose( [ transforms.Resize(256), transforms.CenterCrop(224), transforms.ToTensor(), transforms.Normalize( [0.485, 0.456, 0.406], # ImageNet mean [0.229, 0.224, 0.225], # ImageNet std ), ] ) # ============================================================================ # Grad-CAM Implementation # ============================================================================ class GradCAM: def __init__(self, model: nn.Module, target_layer: nn.Module): self.model = model self.target_layer = target_layer self.gradients = None self.activations = None def save_activation(module, input, output): self.activations = output.detach() def save_gradient(module, grad_input, grad_output): # grad_output is a tuple; take the gradient wrt output activations self.gradients = grad_output[0].detach() target_layer.register_forward_hook(save_activation) target_layer.register_full_backward_hook(save_gradient) def generate(self, input_image: torch.Tensor, target_class=None): """ Generate CAM for a single image batch (1, C, H, W). """ output = self.model(input_image) if target_class is None: target_class = output.argmax(dim=1) self.model.zero_grad() one_hot = torch.zeros_like(output) one_hot[0][target_class] = 1 output.backward(gradient=one_hot, retain_graph=True) if self.gradients is None or self.activations is None: return None, output # Global average pooling over H, W weights = self.gradients.mean(dim=(2, 3), keepdim=True) cam = (weights * self.activations).sum(dim=1, keepdim=True) cam = torch.relu(cam) cam = cam.squeeze().cpu().numpy() cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8) return cam, output # Target the last feature block for Grad-CAM target_layer = model.features[-1] grad_cam = GradCAM(model, target_layer) # ============================================================================ # Visualization Helpers # ============================================================================ def create_original_display(image, pred_label, confidence): """Create annotated original image""" fig, ax = plt.subplots(figsize=(8, 8)) ax.imshow(image) ax.axis("off") color = CLASS_COLORS[pred_label] title = f"Prediction: {pred_label}\nConfidence: {confidence:.1f}%" ax.set_title(title, fontsize=16, fontweight="bold", color=color, pad=20) plt.tight_layout() buf = io.BytesIO() plt.savefig(buf, format="png", dpi=150, bbox_inches="tight", facecolor="white") plt.close(fig) buf.seek(0) return Image.open(buf) def create_gradcam_visualization(image, cam, pred_label, confidence): """Create Grad-CAM heatmap""" img_array = np.array(image.resize((224, 224))) cam_resized = cv2.resize(cam, (224, 224)) heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET) heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) fig, ax = plt.subplots(figsize=(8, 8)) ax.imshow(heatmap) ax.axis("off") ax.set_title( "Attention Heatmap\n(Areas the model focuses on)", fontsize=14, fontweight="bold", pad=20, ) plt.tight_layout() buf = io.BytesIO() plt.savefig(buf, format="png", dpi=150, bbox_inches="tight", facecolor="white") plt.close(fig) buf.seek(0) return Image.open(buf) def create_overlay_visualization(image, cam): """Create overlay of image and heatmap""" img_array = np.array(image.resize((224, 224))) / 255.0 cam_resized = cv2.resize(cam, (224, 224)) heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET) heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0 overlay = img_array * 0.5 + heatmap * 0.5 overlay = np.clip(overlay, 0, 1) fig, ax = plt.subplots(figsize=(8, 8)) ax.imshow(overlay) ax.axis("off") ax.set_title( "Explainable AI Visualization\n(Original + Heatmap)", fontsize=14, fontweight="bold", pad=20, ) plt.tight_layout() buf = io.BytesIO() plt.savefig(buf, format="png", dpi=150, bbox_inches="tight", facecolor="white") plt.close(fig) buf.seek(0) return Image.open(buf) def create_interpretation(pred_label, confidence, results): """Create interpretation text with medical-style narrative + disclaimers""" interpretation = f""" ## πŸ”¬ Analysis Results ### Prediction: **{pred_label}** - Confidence: **{confidence:.1f}%** ### Probability Breakdown: - 🟒 Normal: **{results['Normal']:.1f}%** - πŸ”΄ Tuberculosis: **{results['Tuberculosis']:.1f}%** - 🟠 Pneumonia: **{results['Pneumonia']:.1f}%** - 🟣 COVID-19: **{results['COVID-19']:.1f}%** --- """ # Disease-specific sections if pred_label == "Tuberculosis": if confidence >= 85: interpretation += """ **⚠️ High Confidence TB Detection** The model has detected features highly consistent with pulmonary tuberculosis. **CRITICAL – Suggested next clinical steps (for clinicians):** 1. **Immediate clinical review** of the patient (history + physical exam) 2. **Confirmatory tests**: - Sputum smear microscopy and/or GeneXpert MTB/RIF - TB culture where available 3. **Correlate with symptoms**: - Cough > 2 weeks - Fever, night sweats - Weight loss, hemoptysis 4. **Consider isolation** and contact tracing if active TB is suspected 5. **Additional imaging** (e.g., CT chest) if diagnosis remains uncertain > This tool is **screening-only** and cannot replace microbiological confirmation. """ else: interpretation += """ **⚠️ Possible Tuberculosis** There are radiographic features that *may* be compatible with TB, but the model's confidence is moderate. **Recommended actions (for clinicians):** 1. Perform focused clinical assessment 2. Consider sputum testing (smear / GeneXpert) 3. Review prior imaging for evolution of disease 4. Use this result as a **second reader**, not definitive evidence > Moderate probability predictions always require clinical judgment. """ elif pred_label == "Pneumonia": if confidence >= 85: interpretation += """ **⚠️ High Confidence Pneumonia Detection** Findings are strongly suggestive of pneumonia (bacterial or viral). **Suggested steps:** 1. Clinical evaluation for pneumonia severity 2. Laboratory assessment: - CBC, CRP/ESR - Blood cultures if severely unwell 3. Consider empiric antibiotics (if bacterial suspected) per local guidelines 4. Repeat imaging if no improvement or worsening > Classic pneumonia patterns can overlap with other diseases – interpretation must remain clinical. """ else: interpretation += """ **⚠️ Possible Pneumonia** The X-ray may show early or subtle changes of pneumonia. **Suggested steps:** 1. Correlate with respiratory symptoms (cough, fever, dyspnea) 2. Consider repeat imaging in 24–72 hours if clinically indicated 3. Use this AI opinion as supportive, not definitive """ elif pred_label == "COVID-19": if confidence >= 85: interpretation += """ **⚠️ High Confidence COVID-19 Pattern** Pattern is compatible with COVID-19 pneumonia. **Suggested next steps:** 1. **Confirmatory testing** with RT-PCR or validated antigen test 2. **Infection control**: - Isolation according to institutional policy - Appropriate PPE for staff 3. **Clinical monitoring**: - Oxygen saturation (SpOβ‚‚) - Respiratory rate, hemodynamics 4. **Escalation** if: - SpOβ‚‚ < 94% - Increased work of breathing - Hemodynamic instability > Radiology alone cannot confirm COVID-19 – virological testing is mandatory. """ else: interpretation += """ **⚠️ Possible COVID-19** Some features overlap with COVID-19, but the model is not highly confident. **Suggested steps:** 1. Test with RT-PCR or validated antigen assay 2. Assess epidemiologic risk and exposure history 3. Follow local protocols for isolation and monitoring """ else: # Normal if confidence >= 85: interpretation += """ **βœ… High Confidence β€œNormal” Chest X-Ray (for the 4 modeled diseases)** Within the limits of this model: - No strong evidence of **TB**, **pneumonia**, or **COVID-19** is detected. - Lung fields appear within normal limits on this projection. **Important caveats:** - A β€œnormal” AI result does **not** exclude all lung disease. - Early or subtle TB/pneumonia/COVID-19 may still be radiographically occult. - Other conditions (PE, asthma, COPD, malignancy, etc.) are **outside the scope** of this model. Clinical review remains essential, especially if symptoms persist. """ else: interpretation += """ **⚠️ Likely Normal, but with Lower Confidence** The model leans towards a normal study, but with limited confidence. **Suggested steps:** 1. If the patient is symptomatic, clinical evaluation is still required. 2. Consider repeat imaging if symptoms evolve. 3. Use this output as an adjunct, not reassurance in isolation. """ # Global disclaimer and technical note interpretation += """ --- ## ⚠️ CRITICAL MEDICAL DISCLAIMER ### What this model *can* do: - βœ… Screen for 4 specific classes: **Normal**, **Tuberculosis**, **Pneumonia**, **COVID-19** - βœ… Provide **explainable heatmaps** (Grad-CAM) to highlight regions of interest - βœ… Offer **probabilistic support** to human readers - βœ… Leverage **Adaptive Sparse Training (AST)** for ~89% energy savings vs dense baselines ### What this model *cannot* do: - ❌ It is **not** FDA/EMA-approved – research / educational use only - ❌ It does **not** replace radiologists, pulmonologists, or infectious disease specialists - ❌ It does **not** detect many other thoracic pathologies (e.g., cancer, fibrosis, PE) - ❌ It does **not** provide a microbiological diagnosis ### Clinical usage guidance: 1. Use as a **second reader** or screening tool. 2. Always **correlate with clinical history, examination, and lab tests**. 3. Never start, stop, or change treatment **solely** based on this AI prediction. 4. Follow your local and international guidelines for TB, pneumonia, and COVID-19 management. ### Diagnostic gold standards: - **TB**: Sputum AFB, culture, GeneXpert MTB/RIF, TB-PCR - **Pneumonia**: Clinical + imaging + microbiology - **COVID-19**: RT-PCR / validated antigen testing > When in doubt, a qualified healthcare professional’s judgment takes absolute precedence. --- 🫁 **Powered by Adaptive Sparse Training (AST)** Energy-efficient AI for accessible lung disease screening. **Project links:** - GitHub: https://github.com/oluwafemidiakhoa/Tuberculosis - Hugging Face Space: https://huggingface.co/spaces/mgbam/Tuberculosis """ return interpretation # ============================================================================ # Prediction Function # ============================================================================ def predict_chest_xray(image, show_gradcam=True): """ Main prediction function used by Gradio. Returns: - dict of class probabilities - Annotated original image - Grad-CAM heatmap - Overlay image - Markdown interpretation """ if image is None: return None, None, None, None, "Please upload a chest X-ray image." # Ensure PIL RGB if isinstance(image, np.ndarray): image = Image.fromarray(image).convert("RGB") else: image = image.convert("RGB") original_img = image.copy() input_tensor = transform(image).unsqueeze(0).to(device) with torch.set_grad_enabled(show_gradcam): if show_gradcam: cam, output = grad_cam.generate(input_tensor) else: output = model(input_tensor) cam = None probs = torch.softmax(output, dim=1)[0].detach().cpu().numpy() prob_sum = float(np.sum(probs)) if not (0.98 <= prob_sum <= 1.02): print(f"⚠️ Probability sum = {prob_sum:.4f} (should be ~1.0). Check model/weights.") pred_idx = int(output.argmax(dim=1).item()) pred_label = CLASSES[pred_idx] confidence = float(probs[pred_idx]) * 100.0 results = { CLASSES[i]: float(np.clip(probs[i] * 100.0, 0.0, 100.0)) for i in range(len(CLASSES)) } original_pil = create_original_display(original_img, pred_label, confidence) if cam is not None and show_gradcam: gradcam_viz = create_gradcam_visualization( original_img, cam, pred_label, confidence ) overlay_viz = create_overlay_visualization(original_img, cam) else: gradcam_viz = None overlay_viz = None interpretation = create_interpretation(pred_label, confidence, results) return results, original_pil, gradcam_viz, overlay_viz, interpretation # ============================================================================ # Gradio Interface # ============================================================================ custom_css = """ #main-container { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 20px; } #title { text-align: center; color: white; font-size: 2.5em; font-weight: bold; margin-bottom: 10px; text-shadow: 2px 2px 4px rgba(0,0,0,0.3); } #subtitle { text-align: center; color: #f0f0f0; font-size: 1.2em; margin-bottom: 20px; } #stats { text-align: center; color: #fff; font-size: 0.95em; margin-bottom: 30px; padding: 15px; background: rgba(255,255,255,0.1); border-radius: 10px; backdrop-filter: blur(10px); } .gradio-container { font-family: 'Inter', sans-serif; } #upload-box { border: 3px dashed #667eea; border-radius: 15px; padding: 20px; background: rgba(255,255,255,0.95); } #results-box { background: white; border-radius: 15px; padding: 20px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); } .output-image { border-radius: 10px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); } footer { text-align: center; margin-top: 30px; color: white; font-size: 0.9em; } """ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo: gr.HTML( """
🫁 Multi-Class Chest X-Ray Detection AI
Advanced chest X-ray analysis with Explainable AI
95–97% Accuracy across 4 disease classes | 89% Energy Efficient | Powered by Adaptive Sparse Training (AST)

Detects: Normal β€’ Tuberculosis β€’ Pneumonia β€’ COVID-19
""" ) with gr.Row(): with gr.Column(scale=1, elem_id="upload-box"): gr.Markdown("## πŸ“€ Upload Chest X-Ray") image_input = gr.Image( type="pil", label="Upload X-Ray Image", elem_classes="output-image", ) show_gradcam = gr.Checkbox( value=True, label="Enable Grad-CAM Visualization (Explainable AI)", info="Shows which areas the model focuses on", ) analyze_btn = gr.Button("πŸ”¬ Analyze X-Ray", variant="primary", size="lg") gr.Markdown( """ ### πŸ“‹ Supported Images: - Chest X-rays (PA or AP view) - PNG, JPG, JPEG formats - Grayscale or RGB ### ⚑ Model Highlights: - βœ… **Improved Specificity**: Better separation of TB vs Pneumonia - βœ… **4 Disease Classes**: Normal, TB, Pneumonia, COVID-19 - βœ… **Energy-Aware**: ~89% energy savings with AST - βœ… **Explainable**: Grad-CAM heatmaps for clinical teams """ ) with gr.Column(scale=2, elem_id="results-box"): gr.Markdown("## πŸ“Š Analysis Results") prob_output = gr.Label( label="Prediction Confidence", num_top_classes=4 ) with gr.Tabs(): with gr.Tab("Original"): original_output = gr.Image( label="Annotated X-Ray", elem_classes="output-image" ) with gr.Tab("Grad-CAM Heatmap"): gradcam_output = gr.Image( label="Attention Heatmap", elem_classes="output-image" ) with gr.Tab("Overlay"): overlay_output = gr.Image( label="Explainable AI Visualization", elem_classes="output-image", ) interpretation_output = gr.Markdown(label="Clinical Interpretation") gr.Markdown("## πŸ“ Example X-Rays (Demo Only)") gr.Examples( examples=[ ["examples/normal.png"], ["examples/tb.png"], ["examples/pneumonia.png"], ["examples/covid.png"], ], inputs=image_input, label="Click an example to load", ) analyze_btn.click( fn=predict_chest_xray, inputs=[image_input, show_gradcam], outputs=[ prob_output, original_output, gradcam_output, overlay_output, interpretation_output, ], ) gr.HTML( """ """ ) # ============================================================================ # Launch # ============================================================================ if __name__ == "__main__": demo.launch( share=False, server_name="0.0.0.0", server_port=7860, show_error=True, )