# NOTE(review): the three lines here ("Spaces: / Sleeping / Sleeping") were a
# Hugging Face Spaces status banner captured by the page scrape, not app code.
"""
🫁 Multi-Class Chest X-Ray Detection with Adaptive Sparse Training
WOW UI/UX Edition – 4 Disease Classes
- Normal, Tuberculosis, Pneumonia, COVID-19
- Grad-CAM (Explainable AI)
- Energy-efficient Adaptive Sparse Training
"""
import io
from pathlib import Path

import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

# ============================================================================
# Model Setup
# ============================================================================

# Use the GPU when one is visible; everything below is moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# EfficientNet-B0 backbone with the classifier head replaced for 4 classes.
model = models.efficientnet_b0(weights=None)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 4)  # 4 classes

# Try a few reasonable checkpoint locations; the first loadable one wins.
checkpoint_candidates = [
    "best.pt",
    "checkpoints/best.pt",   # <-- your current file
    "checkpoints/lasttb.pt", # optional fallback
]

MODEL_LOAD_INFO = ""
loaded = False
for ckpt_path in checkpoint_candidates:
    if Path(ckpt_path).is_file():
        try:
            print(f"🔍 Trying to load weights from: {ckpt_path}")
            state_dict = torch.load(ckpt_path, map_location=device)
            # Training scripts often save a wrapper dict instead of a bare
            # state_dict; unwrap the common nesting keys before loading.
            if isinstance(state_dict, dict):
                for nested_key in ("model_state_dict", "state_dict", "model"):
                    if nested_key in state_dict and isinstance(state_dict[nested_key], dict):
                        state_dict = state_dict[nested_key]
                        break
            model.load_state_dict(state_dict)
            MODEL_LOAD_INFO = f"✅ Model loaded from **{ckpt_path}** on **{device.type.upper()}**."
            loaded = True
            break
        except Exception as e:
            # Don't abort yet — fall through to the remaining candidates.
            print(f"⚠️ Found {ckpt_path} but failed to load: {e}")

if not loaded:
    raise RuntimeError(
        "Model file not found or could not be loaded. "
        "Please upload 'checkpoints/best.pt' (or 'best.pt' in the repo root)."
    )

model = model.to(device)
model.eval()

# Parameter count displayed in the UI hero chip.
TOTAL_PARAMS = sum(p.numel() for p in model.parameters())
TOTAL_PARAMS_M = TOTAL_PARAMS / 1e6
# ============================================================================
# Classes & Preprocessing
# ============================================================================

# Label order must match the indices the classifier head was trained with.
CLASSES = ["Normal", "Tuberculosis", "Pneumonia", "COVID-19"]

# Accent color per class, used when annotating the prediction title.
CLASS_COLORS = {
    "Normal": "#2ecc71",        # Green
    "Tuberculosis": "#e74c3c",  # Red
    "Pneumonia": "#f39c12",     # Orange
    "COVID-19": "#9b59b6",      # Purple
}

# Standard ImageNet preprocessing: resize, 224px center crop, normalize.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])
# ============================================================================
# Grad-CAM Implementation
# ============================================================================
class GradCAM:
    """Minimal Grad-CAM for CNN classifiers.

    Hooks the given target layer to capture its forward activations and the
    gradient flowing back into it, then builds a class-discriminative
    localization map (Selvaraju et al., 2017).
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None    # grad of the class score w.r.t. layer output
        self.activations = None  # layer output from the forward pass

        def save_gradient(grad):
            self.gradients = grad

        def save_activation(module, input, output):
            self.activations = output.detach()

        # Keep the handles so the hooks can be detached later (the original
        # discarded them, making removal impossible).
        self._fwd_handle = target_layer.register_forward_hook(save_activation)
        self._bwd_handle = target_layer.register_full_backward_hook(
            lambda m, gi, go: save_gradient(go[0])
        )

    def remove_hooks(self):
        """Detach both hooks from the target layer."""
        self._fwd_handle.remove()
        self._bwd_handle.remove()

    def generate(self, input_image, target_class=None):
        """Return ``(cam, logits)`` for a batch-of-1 input tensor.

        Args:
            input_image: 4-D tensor of shape (1, C, H, W).
            target_class: Class index to explain; defaults to the argmax.

        Returns:
            cam: 2-D numpy array normalized to [0, 1], or None when no
                gradient reached the target layer.
            logits: Raw model output (grad-enabled).
        """
        output = self.model(input_image)
        if target_class is None:
            target_class = output.argmax(dim=1)
        self.model.zero_grad()
        # Backprop only the selected class score.
        one_hot = torch.zeros_like(output)
        one_hot[0, target_class] = 1
        output.backward(gradient=one_hot, retain_graph=True)
        if self.gradients is None:
            return None, output
        # Channel weights = global-average-pooled gradients; CAM = ReLU of
        # the weighted activation sum, min-max normalized for display.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.relu(cam)
        cam = cam.squeeze().cpu().numpy()
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam, output
# Hook Grad-CAM onto the last feature block of EfficientNet — the final
# convolutional stage before pooling/classification.
target_layer = model.features[-1]
grad_cam = GradCAM(model, target_layer)
# ============================================================================
# Visualization Helpers
# ============================================================================
def _figure_to_pil(fig=None):
    """Render a Matplotlib figure to a PIL image via an in-memory PNG.

    Args:
        fig: Figure to render and close. Defaults to the current pyplot
            figure, matching the original call sites.

    Returns:
        PIL.Image.Image decoded from the PNG buffer.
    """
    buf = io.BytesIO()
    if fig is None:
        fig = plt.gcf()
    fig.savefig(buf, format="png", dpi=150, bbox_inches="tight", facecolor="white")
    # Close the figure we actually rendered — bare plt.close() would close
    # whatever figure happened to be "current" at the time.
    plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
def create_original_display(image, pred_label, confidence):
    """Render the uploaded X-ray with a color-coded prediction title."""
    _, axis = plt.subplots(figsize=(7, 7))
    axis.imshow(image)
    axis.axis("off")
    axis.set_title(
        f"Prediction: {pred_label} • Confidence: {confidence:.1f}%",
        fontsize=16,
        fontweight="bold",
        color=CLASS_COLORS[pred_label],
        pad=20,
    )
    plt.tight_layout()
    return _figure_to_pil()
def create_gradcam_visualization(image, cam):
    """Render the raw Grad-CAM heatmap (JET colormap) as a PIL image.

    Args:
        image: Original PIL image — unused for the pure heatmap view but
            kept for signature symmetry with the other visualizers.
        cam: 2-D numpy array in [0, 1] from GradCAM.generate().
    """
    cam_resized = cv2.resize(cam, (224, 224))
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    # OpenCV emits BGR; Matplotlib expects RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    fig, ax = plt.subplots(figsize=(7, 7))
    ax.imshow(heatmap)
    ax.axis("off")
    ax.set_title(
        "Attention Heatmap\n(Where the model is looking)",
        fontsize=14,
        fontweight="bold",
        pad=20,
    )
    plt.tight_layout()
    return _figure_to_pil()
def create_overlay_visualization(image, cam):
    """Blend the X-ray with its Grad-CAM heatmap (50/50) and render it."""
    base = np.array(image.resize((224, 224))) / 255.0
    scaled_cam = cv2.resize(cam, (224, 224))
    jet = cv2.applyColorMap(np.uint8(255 * scaled_cam), cv2.COLORMAP_JET)
    jet = cv2.cvtColor(jet, cv2.COLOR_BGR2RGB) / 255.0
    blended = np.clip(base * 0.5 + jet * 0.5, 0, 1)
    fig, ax = plt.subplots(figsize=(7, 7))
    ax.imshow(blended)
    ax.axis("off")
    ax.set_title(
        "Explainable AI Overlay\n(Anatomy + Attention)",
        fontsize=14,
        fontweight="bold",
        pad=20,
    )
    plt.tight_layout()
    return _figure_to_pil()
# ============================================================================
# Interpretation
# ============================================================================
def create_interpretation(pred_label, confidence, results, audience="Clinician"):
    """Assemble the markdown report shown in the "Full Report" tab.

    An audience-specific header and probability breakdown come first, then a
    disease-specific section selected by predicted label and confidence
    (threshold: 85%), and finally the universal disclaimer and footer.
    """
    audience_notes = {
        "Clinician": "This view is tuned for **clinical decision support** (not a replacement for your judgement).",
        "Researcher": "This view is tuned for **model behavior understanding** and experimental workflows.",
        "Patient / Public": "This view is tuned for **patient-friendly language**. Always discuss results with a doctor.",
    }
    header_note = audience_notes.get(
        audience, "Use this output as a **screening aid**, not a final diagnosis."
    )
    report = f"""
## 🔬 Analysis Results ({audience} View)
> {header_note}
### Primary Prediction: **{pred_label}**
- Confidence: **{confidence:.1f}%**
### Probability Breakdown
- 🟢 Normal: **{results['Normal']:.1f}%**
- 🔴 Tuberculosis: **{results['Tuberculosis']:.1f}%**
- 🟠 Pneumonia: **{results['Pneumonia']:.1f}%**
- 🟣 COVID-19: **{results['COVID-19']:.1f}%**
---
"""
    # (high-confidence text, moderate-confidence text) per class; labels
    # outside the table fall back to the "Normal" wording, like the original
    # if/elif chain's final else branch.
    sections = {
        "Tuberculosis": (
            """
### 🧫 TB Pattern – High Confidence
The model has detected features strongly suggestive of **pulmonary tuberculosis**.
**Recommended Clinical Pathway:**
1. ✅ Immediate medical review by a clinician or chest physician
2. ✅ **Sputum testing** (AFB smear, GeneXpert MTB/RIF, or TB-PCR)
3. ✅ Correlate with symptoms:
- Persistent cough > 2 weeks
- Weight loss, night sweats
- Fever, fatigue
- Hemoptysis (coughing blood)
4. ✅ Consider CT scan or additional imaging if uncertainty remains
5. ✅ Infection control and contact tracing if TB is confirmed
> This tool helps *flag* suspicious cases. TB diagnosis still requires **laboratory confirmation**.
""",
            """
### 🧫 TB Pattern – Possible
The scan shows features that **could** be compatible with tuberculosis, but confidence is moderate.
**Suggested Actions:**
- Clinical review and detailed history
- Consider sputum testing if symptoms or risk factors are present
- Follow-up imaging where clinically indicated
""",
        ),
        "Pneumonia": (
            """
### 🌫 Pneumonia Pattern – High Confidence
The model has detected an opacity pattern consistent with **pneumonia**.
**Typical Clinical Correlates:**
- Fever, productive cough
- Shortness of breath
- Pleuritic chest pain
**Next Steps (for clinicians):**
- Correlate with fever, auscultation, and lab results
- Consider antibiotics for bacterial pneumonia as per local guidelines
- Repeat imaging if clinical evolution is atypical
""",
            """
### 🌫 Pneumonia Pattern – Possible
Findings may be compatible with pneumonia, but alternative explanations exist.
**Recommended:**
- Clinical evaluation (vital signs, exam)
- Consider labs (WBC, CRP, cultures)
- Watchful follow-up or repeat imaging as appropriate
""",
        ),
        "COVID-19": (
            """
### 🦠 COVID-19 Pattern – High Confidence
Distribution and appearance of opacities are compatible with **COVID-19 pneumonia**.
**Critical Points:**
- Imaging is **not** diagnostic by itself
- **RT-PCR / rapid antigen testing** is mandatory for confirmation
**If clinically suspected:**
- Isolate per local infection-control policies
- Monitor SpO₂ and respiratory status
- Escalate care if:
- SpO₂ < 94% on room air
- Increasing work of breathing
- Hemodynamic instability
""",
            """
### 🦠 COVID-19 Pattern – Possible
Some features may overlap with COVID-19, but there is **significant uncertainty**.
**Do not rely on imaging alone.**
- Obtain RT-PCR / rapid antigen testing
- Use clinical context and epidemiology to guide decisions
""",
        ),
        "Normal": (
            """
### ✅ No Major Abnormality Detected
The model did **not** detect features suggestive of TB, pneumonia, or COVID-19.
**Important Caveats:**
- Early disease or small lesions may be missed
- Non-infective conditions (e.g., cancer, ILD) are **not** specifically evaluated
- If symptoms are present, further workup may still be required
""",
            """
### ℹ️ Likely Normal, But Low Confidence
The scan leans towards **normal**, but the model is not highly confident.
**If symptoms persist:**
- Consider follow-up imaging
- Seek a clinician’s interpretation
""",
        ),
    }
    high_text, low_text = sections.get(pred_label, sections["Normal"])
    report += high_text if confidence >= 85 else low_text
    # Universal disclaimer
    report += """
---
## ⚠️ CRITICAL MEDICAL DISCLAIMER
- This AI model is a **screening / decision-support tool only**
- It is **not FDA-approved** and **must not** be used as a stand-alone diagnostic device
- Always integrate:
- Clinical history and examination
- Laboratory tests (e.g., sputum, PCR, cultures)
- Expert radiologist review
**Gold Standards:**
- TB: Sputum AFB / culture, GeneXpert MTB/RIF, TB-PCR
- Pneumonia: Clinical diagnosis + labs / microbiology
- COVID-19: RT-PCR or validated antigen tests
When in doubt, consult a qualified healthcare professional.
"""
    report += """
---
🫁 **Powered by Adaptive Sparse Training (AST)**
Energy-efficient deep learning – designed to make advanced chest X-ray screening more accessible.
**Links:**
- GitHub: https://github.com/oluwafemidiakhoa/Tuberculosis
- Hugging Face Space: https://huggingface.co/spaces/mgbam/Tuberculosis
"""
    return report
# ============================================================================
# Prediction Pipeline
# ============================================================================
def predict_chest_xray(image, show_gradcam=True, audience="Clinician"):
    """
    Main inference function used by Gradio.

    Args:
        image: Uploaded chest X-ray (PIL image or numpy array), or None.
        show_gradcam: When True, also compute a Grad-CAM heatmap/overlay.
        audience: Report style passed through to create_interpretation.

    Returns:
        - dict of class probabilities (percent, clamped to [0, 100])
        - annotated original
        - grad-cam heatmap (None when Grad-CAM is off or unavailable)
        - overlay (None when Grad-CAM is off or unavailable)
        - full markdown report
        - short textual snapshot
    """
    # Guard: nothing uploaded yet — return placeholders for all six outputs.
    if image is None:
        msg = "👋 Upload a chest X-ray (PNG/JPG) and click **Analyze** to generate a full AI report."
        return {}, None, None, None, msg, "Awaiting image upload…"
    # Gradio may hand over a numpy array or a PIL image; normalize to RGB PIL.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert("RGB")
    else:
        image = image.convert("RGB")
    original_img = image.copy()
    input_tensor = transform(image).unsqueeze(0).to(device)
    # Inference with optional Grad-CAM — gradients are only enabled when the
    # CAM is requested; otherwise this behaves like torch.no_grad().
    with torch.set_grad_enabled(show_gradcam):
        if show_gradcam:
            cam, output = grad_cam.generate(input_tensor)
        else:
            output = model(input_tensor)
            cam = None
    probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()
    # Sanity check: softmax outputs should sum to ~1; a large deviation hints
    # at corrupted weights or a mismatched classifier head.
    prob_sum = float(np.sum(probs))
    if not (0.99 <= prob_sum <= 1.01):
        print(f"⚠️ WARNING: Probability sum is {prob_sum}, not ≈1.0 – check model weights.")
    pred_class = int(output.argmax(dim=1).item())
    pred_label = CLASSES[pred_class]
    confidence = float(probs[pred_class]) * 100.0
    # Per-class percentages, clamped to [0, 100] for the gr.Label widget.
    results = {
        CLASSES[i]: float(min(100.0, max(0.0, probs[i] * 100.0)))
        for i in range(len(CLASSES))
    }
    # Visuals
    original_pil = create_original_display(original_img, pred_label, confidence)
    gradcam_viz = create_gradcam_visualization(original_img, cam) if cam is not None else None
    overlay_viz = create_overlay_visualization(original_img, cam) if cam is not None else None
    interpretation = create_interpretation(pred_label, confidence, results, audience=audience)
    snapshot = f"**{pred_label}** · {confidence:.1f}% confidence • Sum of probabilities: {prob_sum:.3f}"
    return results, original_pil, gradcam_viz, overlay_viz, interpretation, snapshot
# ============================================================================
# WOW UI / UX – Gradio App
# ============================================================================
# Custom CSS injected into gr.Blocks: dark glassmorphism cards, hero banner
# with animated status dot, stat pills, and rounded image previews.
custom_css = """
:root {
--primary: #6366f1;
--primary-soft: rgba(99, 102, 241, 0.12);
--accent: #ec4899;
}
.gradio-container {
font-family: system-ui, -apple-system, BlinkMacSystemFont, "Inter", sans-serif;
background: radial-gradient(circle at top left, #111827 0, #020617 50%, #020617 100%);
color: #e5e7eb;
}
#hero {
padding: 24px 24px 8px 24px;
border-radius: 24px;
background: linear-gradient(120deg, rgba(99,102,241,0.18), rgba(236,72,153,0.14));
border: 1px solid rgba(148, 163, 184, 0.4);
box-shadow: 0 24px 60px rgba(15,23,42,0.85);
backdrop-filter: blur(18px);
}
.hero-title {
font-size: 2.4rem;
font-weight: 800;
letter-spacing: 0.04em;
color: #f9fafb;
margin-bottom: 6px;
}
.hero-subtitle {
font-size: 0.98rem;
color: #e5e7eb;
}
.hero-chip-row {
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-top: 14px;
}
.hero-chip {
padding: 4px 10px;
border-radius: 999px;
font-size: 0.78rem;
background: rgba(15,23,42,0.8);
border: 1px solid rgba(148,163,184,0.5);
display: inline-flex;
align-items: center;
gap: 6px;
color: #e5e7eb;
}
.pulse-dot {
width: 8px;
height: 8px;
border-radius: 999px;
background: #22c55e;
box-shadow: 0 0 0 0 rgba(34,197,94,0.7);
animation: pulse 1.4s infinite;
}
@keyframes pulse {
0% { box-shadow: 0 0 0 0 rgba(34,197,94,0.7); }
70% { box-shadow: 0 0 0 10px rgba(34,197,94,0); }
100% { box-shadow: 0 0 0 0 rgba(34,197,94,0); }
}
.glass-card {
background: rgba(15,23,42,0.82);
border-radius: 18px;
border: 1px solid rgba(148,163,184,0.4);
box-shadow: 0 18px 40px rgba(15,23,42,0.85);
padding: 18px;
backdrop-filter: blur(16px);
}
.glass-card-light {
background: rgba(15,23,42,0.65);
border-radius: 18px;
border: 1px solid rgba(148,163,184,0.3);
box-shadow: 0 12px 24px rgba(15,23,42,0.85);
padding: 16px;
backdrop-filter: blur(12px);
}
.stat-pill {
padding: 10px 12px;
border-radius: 14px;
background: rgba(15,23,42,0.9);
border: 1px solid rgba(148,163,184,0.5);
font-size: 0.78rem;
display: flex;
flex-direction: column;
gap: 2px;
}
.stat-pill-label {
color: #9ca3af;
text-transform: uppercase;
font-size: 0.68rem;
}
.stat-pill-value {
color: #e5e7eb;
font-weight: 600;
}
.dropzone-image img {
border-radius: 16px !important;
}
.output-image img {
border-radius: 16px !important;
}
footer {
text-align: center;
margin-top: 24px;
color: #9ca3af;
font-size: 0.78rem;
}
"""
# Gradio theme: indigo/pink "Soft" base with gradient primary buttons to
# match the hero banner's accent colors.
theme = gr.themes.Soft(
    primary_hue="indigo",
    secondary_hue="pink",
    neutral_hue="slate",
).set(
    button_primary_background_fill="linear-gradient(135deg,#4f46e5,#ec4899)",
    button_primary_background_fill_hover="linear-gradient(135deg,#6366f1,#f97316)",
)
# Build the Gradio UI: hero banner, input panel (left), tabbed results
# dashboard (right), footer, event wiring, and example images.
with gr.Blocks(css=custom_css, theme=theme) as demo:
    # HERO — title, feature chips, and device/model stat pills.
    gr.HTML(
        f"""
<div id="hero">
<div style="display:flex;justify-content:space-between;gap:16px;align-items:flex-start;">
<div>
<div class="hero-title">🫁 AST Chest X-Ray Lab</div>
<div class="hero-subtitle">
Multi-class chest X-ray analysis with <b>Explainable AI</b> and
<b>Adaptive Sparse Training</b>.
Designed for TB, Pneumonia, COVID-19 and Normal scans.
</div>
<div class="hero-chip-row">
<div class="hero-chip">
<span class="pulse-dot"></span>
Live Inference
</div>
<div class="hero-chip">
4-class EfficientNet · ~{TOTAL_PARAMS_M:.1f}M params
</div>
<div class="hero-chip">
95–97% validation accuracy · ~89% energy savings
</div>
<div class="hero-chip">
{MODEL_LOAD_INFO}
</div>
</div>
</div>
<div style="min-width:210px;display:flex;flex-direction:column;gap:8px;">
<div class="stat-pill">
<div class="stat-pill-label">Device</div>
<div class="stat-pill-value">{device.type.upper()}</div>
</div>
<div class="stat-pill">
<div class="stat-pill-label">Model</div>
<div class="stat-pill-value">EfficientNet-B0 · 4-way classifier</div>
</div>
</div>
</div>
</div>
"""
    )
    gr.Markdown(" ")
    with gr.Row(equal_height=True):
        # ----------------------------------
        # LEFT: INPUT PANEL
        # ----------------------------------
        with gr.Column(scale=1, elem_classes="glass-card"):
            gr.Markdown("### 1️⃣ Upload & Configure")
            image_input = gr.Image(
                type="pil",
                label="Drop a chest X-ray here",
                elem_classes=["dropzone-image"],
            )
            with gr.Row():
                # Toggle for the Grad-CAM heatmap/overlay computation.
                show_gradcam = gr.Checkbox(
                    value=True,
                    label="Explainable AI (Grad-CAM)",
                    info="Highlight regions that drive the prediction",
                )
                # Audience selector feeding create_interpretation().
                audience_select = gr.Radio(
                    ["Clinician", "Researcher", "Patient / Public"],
                    value="Clinician",
                    label="Report Style",
                )
            with gr.Row():
                analyze_btn = gr.Button("🔬 Analyze X-Ray", variant="primary", scale=3)
                clear_btn = gr.Button("🧹 Reset", variant="secondary")
            gr.Markdown(
                """
**Tips**
- Use frontal (PA/AP) chest X-rays in PNG / JPG format
- This tool is best used as a **triage / screening assistant**
- For noisy images or rotated scans, consider preprocessing before upload
"""
            )
        # ----------------------------------
        # RIGHT: RESULTS PANEL
        # ----------------------------------
        with gr.Column(scale=2, elem_classes="glass-card-light"):
            gr.Markdown("### 2️⃣ AI Dashboard")
            with gr.Tabs():
                with gr.Tab("Snapshot"):
                    snapshot_output = gr.Markdown(
                        "No scan analyzed yet. Upload an X-ray to get started."
                    )
                    prob_output = gr.Label(
                        label="Prediction Confidence (All Classes)",
                        num_top_classes=4,
                    )
                with gr.Tab("Visual Explanations"):
                    with gr.Row():
                        original_output = gr.Image(
                            label="Annotated X-ray",
                            elem_classes=["output-image"],
                        )
                        gradcam_output = gr.Image(
                            label="Attention Heatmap",
                            elem_classes=["output-image"],
                        )
                        overlay_output = gr.Image(
                            label="Explainable Overlay",
                            elem_classes=["output-image"],
                        )
                with gr.Tab("Full Report"):
                    interpretation_output = gr.Markdown(
                        "The full clinical / research report will appear here after inference."
                    )
                with gr.Tab("Model Card"):
                    gr.Markdown(
                        f"""
### 🧠 Model Card – AST Chest X-Ray
- **Backbone**: EfficientNet-B0
- **Task**: 4-way classification (Normal, Tuberculosis, Pneumonia, COVID-19)
- **Optimization**: Sample-based Adaptive Sparse Training (AST)
- **Energy Profile**: ~89% training energy reduction vs dense baseline
**Design Goals**
1. Provide **fast, explainable triage** support for TB & pneumonia
2. Maintain **high specificity**, especially differentiating TB from pneumonia
3. Be lightweight enough for **deployment in resource-constrained settings**
> This model is a research prototype. Do **not** use it as a stand-alone clinical device.
"""
                    )
    gr.Markdown("---")
    gr.HTML(
        """
<footer>
<p>
<b>AST Chest X-Ray Lab</b> · Normal · TB · Pneumonia · COVID-19 · Explainable AI<br/>
Built for research, education, and early-stage screening support.
</p>
<p style="margin-top:6px;">
⚠️ <b>MEDICAL DISCLAIMER:</b> This tool is not FDA-approved and cannot replace a clinician
or radiologist. All decisions must be made by qualified healthcare professionals.
</p>
</footer>
"""
    )
    # ----------------------------------------------------------------------
    # Wiring — analyze runs the full pipeline; reset restores placeholders.
    # Output order must match predict_chest_xray's 6-tuple return.
    # ----------------------------------------------------------------------
    analyze_btn.click(
        fn=predict_chest_xray,
        inputs=[image_input, show_gradcam, audience_select],
        outputs=[
            prob_output,
            original_output,
            gradcam_output,
            overlay_output,
            interpretation_output,
            snapshot_output,
        ],
    )
    clear_btn.click(
        fn=lambda: ({}, None, None, None, "Awaiting image upload…", "Awaiting image upload…"),
        inputs=None,
        outputs=[
            prob_output,
            original_output,
            gradcam_output,
            overlay_output,
            interpretation_output,
            snapshot_output,
        ],
    )
    # Example X-rays section (optional – remove if you don't have these paths)
    gr.Markdown("### 🔍 Try Example X-rays")
    gr.Examples(
        examples=[
            ["examples/normal.png"],
            ["examples/tb.png"],
            ["examples/pneumonia.png"],
            ["examples/covid.png"],
        ],
        inputs=image_input,
    )
# ============================================================================
# Launch
# ============================================================================
if __name__ == "__main__":
    # Bind on all interfaces (needed inside containers such as Hugging Face
    # Spaces); public share links are disabled.
    demo.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )