Tuberculosis / app.py
mgbam's picture
Update app.py
c79e22e verified
raw
history blame
25 kB
"""
🫁 Multi-Class Chest X-Ray Detection with Adaptive Sparse Training
WOW UI/UX Edition – 4 Disease Classes
- Normal, Tuberculosis, Pneumonia, COVID-19
- Grad-CAM (Explainable AI)
- Energy-efficient Adaptive Sparse Training
"""
import io
from pathlib import Path
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms
# ============================================================================
# Model Setup
# ============================================================================
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# EfficientNet backbone
# weights=None: no pretrained weights are requested here; the parameters are
# expected to come from one of the checkpoints below.
model = models.efficientnet_b0(weights=None)
model.classifier[1] = nn.Linear(model.classifier[1].in_features, 4)  # 4 classes

# Try a few reasonable checkpoint locations (first one that loads wins)
checkpoint_candidates = [
    "best.pt",
    "checkpoints/best.pt",  # <-- your current file
    "checkpoints/lasttb.pt",  # optional fallback
]

MODEL_LOAD_INFO = ""
loaded = False
load_errors = []  # one "path: error" entry per checkpoint that failed to load

for ckpt_path in checkpoint_candidates:
    if not Path(ckpt_path).is_file():
        continue
    try:
        print(f"🔍 Trying to load weights from: {ckpt_path}")
        # SECURITY NOTE: torch.load unpickles the file, which can execute
        # arbitrary code. Only ship checkpoints you trust; consider
        # `weights_only=True` if the installed PyTorch version supports it.
        checkpoint = torch.load(ckpt_path, map_location=device)

        # Training scripts often wrap the weights in a dict together with
        # optimizer state / epoch counters; unwrap the common key names.
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for wrapper_key in ("model_state_dict", "state_dict"):
                if isinstance(checkpoint.get(wrapper_key), dict):
                    state_dict = checkpoint[wrapper_key]
                    break

        model.load_state_dict(state_dict)
    except Exception as e:
        # Best-effort: report the failure and move on to the next candidate.
        print(f"⚠️ Found {ckpt_path} but failed to load: {e}")
        load_errors.append(f"{ckpt_path}: {e}")
    else:
        MODEL_LOAD_INFO = f"✅ Model loaded from **{ckpt_path}** on **{device.type.upper()}**."
        loaded = True
        break

if not loaded:
    # Surface the collected errors so a corrupt or mismatched checkpoint is
    # distinguishable from a missing file.
    error_details = f" Load errors: {'; '.join(load_errors)}" if load_errors else ""
    raise RuntimeError(
        "Model file not found or could not be loaded. "
        "Please upload 'checkpoints/best.pt' (or 'best.pt' in the repo root)."
        + error_details
    )

model = model.to(device)
model.eval()

# Parameter count (raw and in millions) shown in the UI header.
TOTAL_PARAMS = sum(p.numel() for p in model.parameters())
TOTAL_PARAMS_M = TOTAL_PARAMS / 1e6
# ============================================================================
# Classes & Preprocessing
# ============================================================================
# Class names, in the order of the classifier's four output logits.
CLASSES = ["Normal", "Tuberculosis", "Pneumonia", "COVID-19"]

# Hex colour used when a given class is the prediction (e.g. plot titles).
CLASS_COLORS = dict(
    zip(
        CLASSES,
        (
            "#2ecc71",  # Green  -> Normal
            "#e74c3c",  # Red    -> Tuberculosis
            "#f39c12",  # Orange -> Pneumonia
            "#9b59b6",  # Purple -> COVID-19
        ),
    )
)

# Per-channel normalisation constants (the widely used ImageNet statistics).
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing applied to every uploaded image before inference:
# resize to 256, centre-crop to 224, convert to tensor, normalise.
_preprocessing_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
]
transform = transforms.Compose(_preprocessing_steps)
# ============================================================================
# Grad-CAM Implementation
# ============================================================================
class GradCAM:
    """Grad-CAM heatmap generator for a single layer of a model.

    A forward hook stores the target layer's output and a full backward hook
    stores the gradient flowing into that output. ``generate`` combines the
    two into a class-activation map.

    Attributes:
        model: The network that is run by ``generate``.
        target_layer: The module whose output and gradient are captured.
        gradients: Gradient w.r.t. the target layer output from the most
            recent ``generate`` call, or ``None``.
        activations: Detached target layer output from the most recent
            ``generate`` call, or ``None``.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Keep the handles so the hooks can be detached again (remove_hooks).
        self._hook_handles = [
            target_layer.register_forward_hook(self._save_activation),
            target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def _save_activation(self, module, inputs, output):
        """Forward hook: remember the layer output (detached from the graph)."""
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        """Backward hook: remember the gradient w.r.t. the layer output."""
        # Returns None implicitly, so the gradient itself is left unmodified.
        self.gradients = grad_output[0]

    def remove_hooks(self):
        """Detach the hooks registered on ``target_layer``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def generate(self, input_image, target_class=None):
        """Run the model and build a Grad-CAM map for ``target_class``.

        Args:
            input_image: Input batch passed straight to ``self.model``. The
                one-hot seed is only set for row 0, so a batch of one is
                assumed.
            target_class: Class index to explain. When ``None`` the argmax of
                the model output is used.

        Returns:
            ``(cam, output)`` where ``output`` is the raw model output and
            ``cam`` is a NumPy array min-max scaled to ``[0, 1]``, or ``None``
            if the hooks captured nothing during this call.
        """
        # Drop whatever the previous call captured so a stale gradient or
        # activation can never be mixed into this call's heatmap.
        self.gradients = None
        self.activations = None

        # Grad-CAM needs autograd even when the caller is inside no_grad().
        with torch.enable_grad():
            output = self.model(input_image)
            if target_class is None:
                target_class = output.argmax(dim=1)

            self.model.zero_grad()
            one_hot = torch.zeros_like(output)
            one_hot[0, target_class] = 1
            # retain_graph is kept so the returned output remains usable for
            # a further backward pass by the caller.
            output.backward(gradient=one_hot, retain_graph=True)

        if self.gradients is None or self.activations is None:
            return None, output

        # Channel weights = gradient averaged over dims 2 and 3, which assumes
        # (batch, channels, H, W) feature maps.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = torch.relu(cam)  # keep only positive evidence for the class
        cam = cam.squeeze().cpu().numpy()
        # Min-max normalise; the epsilon guards against an all-constant map.
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
        return cam, output
# Grad-CAM is attached to the last entry of the model's `features` stack.
target_layer = model.features[-1]
# Module-level Grad-CAM instance reused by the prediction pipeline.
grad_cam = GradCAM(model, target_layer)
# ============================================================================
# Visualization Helpers
# ============================================================================
def _figure_to_pil():
    """Render the current pyplot figure to an in-memory PNG and open it with PIL.

    The figure is closed after saving, so callers must finish drawing first.
    """
    png_buffer = io.BytesIO()
    save_options = {
        "format": "png",
        "dpi": 150,
        "bbox_inches": "tight",
        "facecolor": "white",
    }
    plt.savefig(png_buffer, **save_options)
    plt.close()
    # Rewind so PIL reads the PNG from its first byte.
    png_buffer.seek(0)
    return Image.open(png_buffer)
def create_original_display(image, pred_label, confidence):
    """Draw the input image with a coloured prediction title.

    Args:
        image: Image to display (anything ``Axes.imshow`` accepts).
        pred_label: Predicted class name; must be a key of ``CLASS_COLORS``.
        confidence: Confidence in percent, shown with one decimal place.

    Returns:
        The rendered figure as a PIL image.
    """
    figure, axes = plt.subplots(figsize=(7, 7))
    axes.imshow(image)
    axes.axis("off")

    title_style = {
        "fontsize": 16,
        "fontweight": "bold",
        "color": CLASS_COLORS[pred_label],
        "pad": 20,
    }
    axes.set_title(
        f"Prediction: {pred_label} • Confidence: {confidence:.1f}%",
        **title_style,
    )

    plt.tight_layout()
    return _figure_to_pil()
def create_gradcam_visualization(image, cam):
    """Render the Grad-CAM map on its own as a JET-coloured heatmap.

    Args:
        image: Unused; kept so the signature matches
            ``create_overlay_visualization`` and existing callers.
        cam: 2-D attention map with values in ``[0, 1]``.

    Returns:
        The rendered heatmap figure as a PIL image.
    """
    # The previous version also resized `image` into an array that was never
    # read; that dead work has been removed.
    cam_resized = cv2.resize(cam, (224, 224))
    heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
    # OpenCV colour maps are BGR; matplotlib expects RGB.
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    fig, ax = plt.subplots(figsize=(7, 7))
    ax.imshow(heatmap)
    ax.axis("off")
    ax.set_title(
        "Attention Heatmap\n(Where the model is looking)",
        fontsize=14,
        fontweight="bold",
        pad=20,
    )
    plt.tight_layout()
    return _figure_to_pil()
def create_overlay_visualization(image, cam):
    """Blend the Grad-CAM heatmap 50/50 over the input image.

    Args:
        image: PIL image (it is resized with ``image.resize``).
        cam: 2-D attention map with values in ``[0, 1]``.

    Returns:
        The rendered overlay figure as a PIL image.
    """
    side = 224

    # Both layers are brought to the same size and to the 0..1 range.
    base_rgb = np.array(image.resize((side, side))) / 255.0
    attention = cv2.resize(cam, (side, side))
    colored = cv2.applyColorMap(np.uint8(255 * attention), cv2.COLORMAP_JET)
    # OpenCV colour maps are BGR; convert before mixing with the RGB image.
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB) / 255.0

    blended = np.clip(base_rgb * 0.5 + colored * 0.5, 0, 1)

    figure, axes = plt.subplots(figsize=(7, 7))
    axes.imshow(blended)
    axes.axis("off")
    title_style = {"fontsize": 14, "fontweight": "bold", "pad": 20}
    axes.set_title("Explainable AI Overlay\n(Anatomy + Attention)", **title_style)
    plt.tight_layout()
    return _figure_to_pil()
# ============================================================================
# Interpretation
# ============================================================================
# Report sections keyed by finding. Each finding has a "confident" text
# (confidence >= 85) and a "tentative" text (anything lower).
_TB_CONFIDENT = """
### 🧫 TB Pattern – High Confidence
The model has detected features strongly suggestive of **pulmonary tuberculosis**.
**Recommended Clinical Pathway:**
1. ✅ Immediate medical review by a clinician or chest physician
2. ✅ **Sputum testing** (AFB smear, GeneXpert MTB/RIF, or TB-PCR)
3. ✅ Correlate with symptoms:
- Persistent cough > 2 weeks
- Weight loss, night sweats
- Fever, fatigue
- Hemoptysis (coughing blood)
4. ✅ Consider CT scan or additional imaging if uncertainty remains
5. ✅ Infection control and contact tracing if TB is confirmed
> This tool helps *flag* suspicious cases. TB diagnosis still requires **laboratory confirmation**.
"""

_TB_TENTATIVE = """
### 🧫 TB Pattern – Possible
The scan shows features that **could** be compatible with tuberculosis, but confidence is moderate.
**Suggested Actions:**
- Clinical review and detailed history
- Consider sputum testing if symptoms or risk factors are present
- Follow-up imaging where clinically indicated
"""

_PNEUMONIA_CONFIDENT = """
### 🌫 Pneumonia Pattern – High Confidence
The model has detected an opacity pattern consistent with **pneumonia**.
**Typical Clinical Correlates:**
- Fever, productive cough
- Shortness of breath
- Pleuritic chest pain
**Next Steps (for clinicians):**
- Correlate with fever, auscultation, and lab results
- Consider antibiotics for bacterial pneumonia as per local guidelines
- Repeat imaging if clinical evolution is atypical
"""

_PNEUMONIA_TENTATIVE = """
### 🌫 Pneumonia Pattern – Possible
Findings may be compatible with pneumonia, but alternative explanations exist.
**Recommended:**
- Clinical evaluation (vital signs, exam)
- Consider labs (WBC, CRP, cultures)
- Watchful follow-up or repeat imaging as appropriate
"""

_COVID_CONFIDENT = """
### 🦠 COVID-19 Pattern – High Confidence
Distribution and appearance of opacities are compatible with **COVID-19 pneumonia**.
**Critical Points:**
- Imaging is **not** diagnostic by itself
- **RT-PCR / rapid antigen testing** is mandatory for confirmation
**If clinically suspected:**
- Isolate per local infection-control policies
- Monitor SpO₂ and respiratory status
- Escalate care if:
- SpO₂ < 94% on room air
- Increasing work of breathing
- Hemodynamic instability
"""

_COVID_TENTATIVE = """
### 🦠 COVID-19 Pattern – Possible
Some features may overlap with COVID-19, but there is **significant uncertainty**.
**Do not rely on imaging alone.**
- Obtain RT-PCR / rapid antigen testing
- Use clinical context and epidemiology to guide decisions
"""

_NORMAL_CONFIDENT = """
### ✅ No Major Abnormality Detected
The model did **not** detect features suggestive of TB, pneumonia, or COVID-19.
**Important Caveats:**
- Early disease or small lesions may be missed
- Non-infective conditions (e.g., cancer, ILD) are **not** specifically evaluated
- If symptoms are present, further workup may still be required
"""

_NORMAL_TENTATIVE = """
### ℹ️ Likely Normal, But Low Confidence
The scan leans towards **normal**, but the model is not highly confident.
**If symptoms persist:**
- Consider follow-up imaging
- Seek a clinician’s interpretation
"""

# Appended to every report regardless of the finding.
_REPORT_DISCLAIMER = """
---
## ⚠️ CRITICAL MEDICAL DISCLAIMER
- This AI model is a **screening / decision-support tool only**
- It is **not FDA-approved** and **must not** be used as a stand-alone diagnostic device
- Always integrate:
- Clinical history and examination
- Laboratory tests (e.g., sputum, PCR, cultures)
- Expert radiologist review
**Gold Standards:**
- TB: Sputum AFB / culture, GeneXpert MTB/RIF, TB-PCR
- Pneumonia: Clinical diagnosis + labs / microbiology
- COVID-19: RT-PCR or validated antigen tests
When in doubt, consult a qualified healthcare professional.
"""

_REPORT_FOOTER = """
---
🫁 **Powered by Adaptive Sparse Training (AST)**
Energy-efficient deep learning – designed to make advanced chest X-ray screening more accessible.
**Links:**
- GitHub: https://github.com/oluwafemidiakhoa/Tuberculosis
- Hugging Face Space: https://huggingface.co/spaces/mgbam/Tuberculosis
"""


def create_interpretation(pred_label, confidence, results, audience="Clinician"):
    """Build the Markdown report for one prediction.

    Args:
        pred_label: Predicted class name. "Tuberculosis", "Pneumonia" and
            "COVID-19" get their own section; any other value is reported
            with the "Normal" wording.
        confidence: Confidence of the prediction in percent. Values of 85 or
            more select the high-confidence wording.
        results: Mapping with the keys "Normal", "Tuberculosis", "Pneumonia"
            and "COVID-19", each a percentage.
        audience: Report style. Known values are "Clinician", "Researcher"
            and "Patient / Public"; anything else gets a generic note.

    Returns:
        The report as a single Markdown string: header with probabilities,
        finding-specific section, disclaimer, footer.
    """
    audience_notes = {
        "Clinician": "This view is tuned for **clinical decision support** (not a replacement for your judgement).",
        "Researcher": "This view is tuned for **model behavior understanding** and experimental workflows.",
        "Patient / Public": "This view is tuned for **patient-friendly language**. Always discuss results with a doctor.",
    }
    generic_note = "Use this output as a **screening aid**, not a final diagnosis."
    header_note = audience_notes.get(audience, generic_note)

    report_parts = [
        f"""
## 🔬 Analysis Results ({audience} View)
> {header_note}
### Primary Prediction: **{pred_label}**
- Confidence: **{confidence:.1f}%**
### Probability Breakdown
- 🟢 Normal: **{results['Normal']:.1f}%**
- 🔴 Tuberculosis: **{results['Tuberculosis']:.1f}%**
- 🟠 Pneumonia: **{results['Pneumonia']:.1f}%**
- 🟣 COVID-19: **{results['COVID-19']:.1f}%**
---
"""
    ]

    # Disease-specific sections: walk the table and fall back to the
    # "Normal" wording when no disease label matches.
    disease_sections = (
        ("Tuberculosis", _TB_CONFIDENT, _TB_TENTATIVE),
        ("Pneumonia", _PNEUMONIA_CONFIDENT, _PNEUMONIA_TENTATIVE),
        ("COVID-19", _COVID_CONFIDENT, _COVID_TENTATIVE),
    )
    for label, confident_text, tentative_text in disease_sections:
        if pred_label == label:
            break
    else:
        confident_text, tentative_text = _NORMAL_CONFIDENT, _NORMAL_TENTATIVE

    if confidence >= 85:
        report_parts.append(confident_text)
    else:
        report_parts.append(tentative_text)

    # Universal disclaimer and project footer
    report_parts.append(_REPORT_DISCLAIMER)
    report_parts.append(_REPORT_FOOTER)
    return "".join(report_parts)
# ============================================================================
# Prediction Pipeline
# ============================================================================
def predict_chest_xray(image, show_gradcam=True, audience="Clinician"):
    """Main inference function used by Gradio.

    Args:
        image: Uploaded image as a PIL image or NumPy array, or ``None`` when
            nothing has been uploaded yet.
        show_gradcam: When truthy, also compute the Grad-CAM heatmap and
            overlay; otherwise only the classification is run.
        audience: Report style forwarded to ``create_interpretation``.

    Returns:
        A 6-tuple:
        - dict of class name -> probability in the 0..1 range (the scale
          ``gr.Label`` expects for confidences)
        - annotated original
        - grad-cam heatmap (``None`` when Grad-CAM is off or unavailable)
        - overlay (``None`` when Grad-CAM is off or unavailable)
        - full markdown report
        - short textual snapshot
    """
    if image is None:
        msg = "👋 Upload a chest X-ray (PNG/JPG) and click **Analyze** to generate a full AI report."
        return {}, None, None, None, msg, "Awaiting image upload…"

    # Normalise the input to a 3-channel PIL image.
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image).convert("RGB")
    else:
        image = image.convert("RGB")
    original_img = image.copy()
    input_tensor = transform(image).unsqueeze(0).to(device)

    # Inference with optional Grad-CAM. Autograd is only switched on when the
    # heatmap is requested; bool() guards against non-bool UI values.
    use_gradcam = bool(show_gradcam)
    with torch.set_grad_enabled(use_gradcam):
        if use_gradcam:
            cam, output = grad_cam.generate(input_tensor)
        else:
            output = model(input_tensor)
            cam = None

    probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()

    # Sanity check: a softmax that does not sum to ~1 points at broken weights.
    prob_sum = float(np.sum(probs))
    if not (0.99 <= prob_sum <= 1.01):
        print(f"⚠️ WARNING: Probability sum is {prob_sum}, not ≈1.0 – check model weights.")

    pred_class = int(output.argmax(dim=1).item())
    pred_label = CLASSES[pred_class]
    confidence = float(probs[pred_class]) * 100.0

    # Percentages (clamped to 0..100) for the written report.
    results = {
        class_name: float(min(100.0, max(0.0, prob * 100.0)))
        for class_name, prob in zip(CLASSES, probs)
    }
    # gr.Label renders confidences as percentages itself, so it must receive
    # fractions in 0..1; passing the 0..100 values would display e.g. 9500%.
    label_scores = {class_name: pct / 100.0 for class_name, pct in results.items()}

    # Visuals
    original_pil = create_original_display(original_img, pred_label, confidence)
    gradcam_viz = create_gradcam_visualization(original_img, cam) if cam is not None else None
    overlay_viz = create_overlay_visualization(original_img, cam) if cam is not None else None

    interpretation = create_interpretation(pred_label, confidence, results, audience=audience)
    snapshot = f"**{pred_label}** · {confidence:.1f}% confidence • Sum of probabilities: {prob_sum:.3f}"
    return label_scores, original_pil, gradcam_viz, overlay_viz, interpretation, snapshot
# ============================================================================
# WOW UI / UX – Gradio App
# ============================================================================
# Stylesheet handed to gr.Blocks(css=...). It defines the dark gradient
# background, the "#hero" banner, the glass-style cards used via elem_classes,
# the stat pills, rounded image corners and the footer.
custom_css = """
:root {
--primary: #6366f1;
--primary-soft: rgba(99, 102, 241, 0.12);
--accent: #ec4899;
}
.gradio-container {
font-family: system-ui, -apple-system, BlinkMacSystemFont, "Inter", sans-serif;
background: radial-gradient(circle at top left, #111827 0, #020617 50%, #020617 100%);
color: #e5e7eb;
}
#hero {
padding: 24px 24px 8px 24px;
border-radius: 24px;
background: linear-gradient(120deg, rgba(99,102,241,0.18), rgba(236,72,153,0.14));
border: 1px solid rgba(148, 163, 184, 0.4);
box-shadow: 0 24px 60px rgba(15,23,42,0.85);
backdrop-filter: blur(18px);
}
.hero-title {
font-size: 2.4rem;
font-weight: 800;
letter-spacing: 0.04em;
color: #f9fafb;
margin-bottom: 6px;
}
.hero-subtitle {
font-size: 0.98rem;
color: #e5e7eb;
}
.hero-chip-row {
display: flex;
flex-wrap: wrap;
gap: 8px;
margin-top: 14px;
}
.hero-chip {
padding: 4px 10px;
border-radius: 999px;
font-size: 0.78rem;
background: rgba(15,23,42,0.8);
border: 1px solid rgba(148,163,184,0.5);
display: inline-flex;
align-items: center;
gap: 6px;
color: #e5e7eb;
}
.pulse-dot {
width: 8px;
height: 8px;
border-radius: 999px;
background: #22c55e;
box-shadow: 0 0 0 0 rgba(34,197,94,0.7);
animation: pulse 1.4s infinite;
}
@keyframes pulse {
0% { box-shadow: 0 0 0 0 rgba(34,197,94,0.7); }
70% { box-shadow: 0 0 0 10px rgba(34,197,94,0); }
100% { box-shadow: 0 0 0 0 rgba(34,197,94,0); }
}
.glass-card {
background: rgba(15,23,42,0.82);
border-radius: 18px;
border: 1px solid rgba(148,163,184,0.4);
box-shadow: 0 18px 40px rgba(15,23,42,0.85);
padding: 18px;
backdrop-filter: blur(16px);
}
.glass-card-light {
background: rgba(15,23,42,0.65);
border-radius: 18px;
border: 1px solid rgba(148,163,184,0.3);
box-shadow: 0 12px 24px rgba(15,23,42,0.85);
padding: 16px;
backdrop-filter: blur(12px);
}
.stat-pill {
padding: 10px 12px;
border-radius: 14px;
background: rgba(15,23,42,0.9);
border: 1px solid rgba(148,163,184,0.5);
font-size: 0.78rem;
display: flex;
flex-direction: column;
gap: 2px;
}
.stat-pill-label {
color: #9ca3af;
text-transform: uppercase;
font-size: 0.68rem;
}
.stat-pill-value {
color: #e5e7eb;
font-weight: 600;
}
.dropzone-image img {
border-radius: 16px !important;
}
.output-image img {
border-radius: 16px !important;
}
footer {
text-align: center;
margin-top: 24px;
color: #9ca3af;
font-size: 0.78rem;
}
"""
# Gradio "Soft" theme with indigo/pink/slate hues ...
theme = gr.themes.Soft(primary_hue="indigo", secondary_hue="pink", neutral_hue="slate")
# ... and gradient fills for the primary button (normal and hover states).
theme = theme.set(
    button_primary_background_fill="linear-gradient(135deg,#4f46e5,#ec4899)",
    button_primary_background_fill_hover="linear-gradient(135deg,#6366f1,#f97316)",
)
# MODEL_LOAD_INFO carries Markdown bold markers, but the hero banner is raw
# HTML where "**" would be shown literally, so strip them for that chip.
_model_load_chip = MODEL_LOAD_INFO.replace("**", "")

# NOTE(review): the layout nesting below (which components share a Row/Column)
# was reconstructed from statement order; confirm against the deployed UI.
with gr.Blocks(css=custom_css, theme=theme) as demo:
    # HERO
    gr.HTML(
        f"""
<div id="hero">
<div style="display:flex;justify-content:space-between;gap:16px;align-items:flex-start;">
<div>
<div class="hero-title">🫁 AST Chest X-Ray Lab</div>
<div class="hero-subtitle">
Multi-class chest X-ray analysis with <b>Explainable AI</b> and
<b>Adaptive Sparse Training</b>.
Designed for TB, Pneumonia, COVID-19 and Normal scans.
</div>
<div class="hero-chip-row">
<div class="hero-chip">
<span class="pulse-dot"></span>
Live Inference
</div>
<div class="hero-chip">
4-class EfficientNet · ~{TOTAL_PARAMS_M:.1f}M params
</div>
<div class="hero-chip">
95–97% validation accuracy · ~89% energy savings
</div>
<div class="hero-chip">
{_model_load_chip}
</div>
</div>
</div>
<div style="min-width:210px;display:flex;flex-direction:column;gap:8px;">
<div class="stat-pill">
<div class="stat-pill-label">Device</div>
<div class="stat-pill-value">{device.type.upper()}</div>
</div>
<div class="stat-pill">
<div class="stat-pill-label">Model</div>
<div class="stat-pill-value">EfficientNet-B0 · 4-way classifier</div>
</div>
</div>
</div>
</div>
"""
    )
    gr.Markdown(" ")

    with gr.Row(equal_height=True):
        # ----------------------------------
        # LEFT: INPUT PANEL
        # ----------------------------------
        with gr.Column(scale=1, elem_classes="glass-card"):
            gr.Markdown("### 1️⃣ Upload & Configure")
            image_input = gr.Image(
                type="pil",
                label="Drop a chest X-ray here",
                elem_classes=["dropzone-image"],
            )
            with gr.Row():
                show_gradcam = gr.Checkbox(
                    value=True,
                    label="Explainable AI (Grad-CAM)",
                    info="Highlight regions that drive the prediction",
                )
                audience_select = gr.Radio(
                    ["Clinician", "Researcher", "Patient / Public"],
                    value="Clinician",
                    label="Report Style",
                )
            with gr.Row():
                analyze_btn = gr.Button("🔬 Analyze X-Ray", variant="primary", scale=3)
                clear_btn = gr.Button("🧹 Reset", variant="secondary")
            gr.Markdown(
                """
**Tips**
- Use frontal (PA/AP) chest X-rays in PNG / JPG format
- This tool is best used as a **triage / screening assistant**
- For noisy images or rotated scans, consider preprocessing before upload
"""
            )

        # ----------------------------------
        # RIGHT: RESULTS PANEL
        # ----------------------------------
        with gr.Column(scale=2, elem_classes="glass-card-light"):
            gr.Markdown("### 2️⃣ AI Dashboard")
            with gr.Tabs():
                with gr.Tab("Snapshot"):
                    snapshot_output = gr.Markdown(
                        "No scan analyzed yet. Upload an X-ray to get started."
                    )
                    prob_output = gr.Label(
                        label="Prediction Confidence (All Classes)",
                        num_top_classes=4,
                    )
                with gr.Tab("Visual Explanations"):
                    with gr.Row():
                        original_output = gr.Image(
                            label="Annotated X-ray",
                            elem_classes=["output-image"],
                        )
                        gradcam_output = gr.Image(
                            label="Attention Heatmap",
                            elem_classes=["output-image"],
                        )
                        overlay_output = gr.Image(
                            label="Explainable Overlay",
                            elem_classes=["output-image"],
                        )
                with gr.Tab("Full Report"):
                    interpretation_output = gr.Markdown(
                        "The full clinical / research report will appear here after inference."
                    )
                with gr.Tab("Model Card"):
                    gr.Markdown(
                        """
### 🧠 Model Card – AST Chest X-Ray
- **Backbone**: EfficientNet-B0
- **Task**: 4-way classification (Normal, Tuberculosis, Pneumonia, COVID-19)
- **Optimization**: Sample-based Adaptive Sparse Training (AST)
- **Energy Profile**: ~89% training energy reduction vs dense baseline
**Design Goals**
1. Provide **fast, explainable triage** support for TB & pneumonia
2. Maintain **high specificity**, especially differentiating TB from pneumonia
3. Be lightweight enough for **deployment in resource-constrained settings**
> This model is a research prototype. Do **not** use it as a stand-alone clinical device.
"""
                    )

    gr.Markdown("---")
    gr.HTML(
        """
<footer>
<p>
<b>AST Chest X-Ray Lab</b> · Normal · TB · Pneumonia · COVID-19 · Explainable AI<br/>
Built for research, education, and early-stage screening support.
</p>
<p style="margin-top:6px;">
⚠️ <b>MEDICAL DISCLAIMER:</b> This tool is not FDA-approved and cannot replace a clinician
or radiologist. All decisions must be made by qualified healthcare professionals.
</p>
</footer>
"""
    )

    # ----------------------------------------------------------------------
    # Wiring
    # ----------------------------------------------------------------------
    # Both buttons write to the same six result components, in the order
    # predict_chest_xray returns them.
    result_components = [
        prob_output,
        original_output,
        gradcam_output,
        overlay_output,
        interpretation_output,
        snapshot_output,
    ]
    analyze_btn.click(
        fn=predict_chest_xray,
        inputs=[image_input, show_gradcam, audience_select],
        outputs=result_components,
    )
    clear_btn.click(
        fn=lambda: ({}, None, None, None, "Awaiting image upload…", "Awaiting image upload…"),
        inputs=None,
        outputs=result_components,
    )

    # Example X-rays section. The example files are optional, so only the
    # ones actually present are offered and the section is skipped entirely
    # when none exist (instead of pointing Gradio at missing files).
    example_paths = [
        "examples/normal.png",
        "examples/tb.png",
        "examples/pneumonia.png",
        "examples/covid.png",
    ]
    available_examples = [[path] for path in example_paths if Path(path).is_file()]
    if available_examples:
        gr.Markdown("### 🔍 Try Example X-rays")
        gr.Examples(
            examples=available_examples,
            inputs=image_input,
        )
# ============================================================================
# Launch
# ============================================================================
if __name__ == "__main__":
    # Serve on all network interfaces at port 7860, without a public share
    # link, and surface Python errors in the browser.
    launch_options = {
        "share": False,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    demo.launch(**launch_options)