mgbam commited on
Commit
005b870
Β·
verified Β·
1 Parent(s): 1100eb0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +723 -0
app.py ADDED
@@ -0,0 +1,723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 🫁 Multi-Class Chest X-Ray Detection with Adaptive Sparse Training
3
+ Advanced Gradio Interface - 4 Disease Classes
4
+
5
+ Features:
6
+ - Real-time detection: Normal, TB, Pneumonia, COVID-19
7
+ - Grad-CAM visualization (explainable AI)
8
+ - Improved specificity - distinguishes TB from pneumonia
9
+ - Confidence scores with visual indicators
10
+ - Clinical interpretation and recommendations
11
+ - Mobile-responsive design
12
+ """
13
+
14
+ import os
15
+ from pathlib import Path
16
+ import io
17
+
18
+ import gradio as gr
19
+ import torch
20
+ import torch.nn as nn
21
+ from torchvision import models, transforms
22
+ from PIL import Image
23
+ import numpy as np
24
+ import cv2
25
+ import matplotlib.pyplot as plt
26
+
27
+ # ============================================================================
28
+ # Device
29
+ # ============================================================================
30
+
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ print(f"βœ… Using device: {device}")
33
+
34
+ # ============================================================================
35
+ # Model Setup & Robust Loader
36
+ # ============================================================================
37
+
38
+ NUM_CLASSES = 4
39
+ CLASSES = ["Normal", "Tuberculosis", "Pneumonia", "COVID-19"]
40
+ CLASS_COLORS = {
41
+ "Normal": "#2ecc71", # Green
42
+ "Tuberculosis": "#e74c3c", # Red
43
+ "Pneumonia": "#f39c12", # Orange
44
+ "COVID-19": "#9b59b6", # Purple
45
+ }
46
+
47
+
48
+ def build_base_model(num_classes: int = NUM_CLASSES) -> nn.Module:
49
+ """
50
+ Build the base EfficientNet-B2 model with a 4-class classifier head.
51
+ This matches the architecture used during training.
52
+ """
53
+ # πŸ”’ Do NOT change this to efficientnet_b0 – your checkpoint is B2 (1408 features)
54
+ model = models.efficientnet_b2(weights=None)
55
+ in_features = model.classifier[1].in_features
56
+ model.classifier[1] = nn.Linear(in_features, num_classes)
57
+ return model
58
+
59
+
60
+ def load_trained_model() -> nn.Module:
61
+ """
62
+ Load EfficientNet-B2 4-class checkpoint from:
63
+ - 'best.pt' OR
64
+ - 'checkpoints/best.pt'
65
+
66
+ Supports both:
67
+ - Plain state_dict
68
+ - Training checkpoint with 'model_state_dict' or 'state_dict' keys
69
+ """
70
+ model = build_base_model().to(device)
71
+
72
+ search_paths = [
73
+ Path("best.pt"),
74
+ Path("checkpoints/best.pt"),
75
+ ]
76
+
77
+ ckpt_path = None
78
+ for p in search_paths:
79
+ if p.exists():
80
+ ckpt_path = p
81
+ break
82
+
83
+ if ckpt_path is None:
84
+ raise RuntimeError(
85
+ "❌ Could not find model checkpoint.\n"
86
+ "Expected 'best.pt' in the project root OR 'checkpoints/best.pt'.\n"
87
+ "Please upload your 4-class EfficientNet-B2 weights as 'best.pt' or 'checkpoints/best.pt'."
88
+ )
89
+
90
+ print(f"πŸ” Loading weights from: {ckpt_path}")
91
+
92
+ ckpt = torch.load(ckpt_path, map_location=device)
93
+
94
+ # Try to extract the actual state_dict
95
+ if isinstance(ckpt, dict):
96
+ if "model_state_dict" in ckpt:
97
+ state_dict = ckpt["model_state_dict"]
98
+ elif "state_dict" in ckpt:
99
+ state_dict = ckpt["state_dict"]
100
+ else:
101
+ # Assume it's already a plain state_dict
102
+ state_dict = ckpt
103
+ else:
104
+ # Definitely just a state_dict
105
+ state_dict = ckpt
106
+
107
+ # Now load strictly – if this fails, the checkpoint truly doesn't match the architecture
108
+ try:
109
+ missing, unexpected = model.load_state_dict(state_dict, strict=True)
110
+ if missing or unexpected:
111
+ # This branch rarely happens with strict=True, but keep for clarity
112
+ print(f"⚠️ Missing keys in state_dict: {missing}")
113
+ print(f"⚠️ Unexpected keys in state_dict: {unexpected}")
114
+ except RuntimeError as e:
115
+ # Most common cause: trying to load B2 checkpoint into B0/B1 or wrong architecture
116
+ raise RuntimeError(
117
+ f"❌ Failed to load weights from {ckpt_path}.\n"
118
+ "Most likely cause: the checkpoint was trained with a different architecture.\n"
119
+ "This app expects an EfficientNet-B2 checkpoint with 4 output classes.\n\n"
120
+ f"PyTorch error:\n{e}"
121
+ )
122
+
123
+ print("βœ… Model weights loaded successfully!")
124
+ model.eval()
125
+ return model
126
+
127
+
128
+ model = load_trained_model()
129
+
130
+ # ============================================================================
131
+ # Preprocessing
132
+ # ============================================================================
133
+
134
+ transform = transforms.Compose(
135
+ [
136
+ transforms.Resize(256),
137
+ transforms.CenterCrop(224),
138
+ transforms.ToTensor(),
139
+ transforms.Normalize(
140
+ [0.485, 0.456, 0.406], # ImageNet mean
141
+ [0.229, 0.224, 0.225], # ImageNet std
142
+ ),
143
+ ]
144
+ )
145
+
146
+ # ============================================================================
147
+ # Grad-CAM Implementation
148
+ # ============================================================================
149
+
150
+
151
+ class GradCAM:
152
+ def __init__(self, model: nn.Module, target_layer: nn.Module):
153
+ self.model = model
154
+ self.target_layer = target_layer
155
+ self.gradients = None
156
+ self.activations = None
157
+
158
+ def save_activation(module, input, output):
159
+ self.activations = output.detach()
160
+
161
+ def save_gradient(module, grad_input, grad_output):
162
+ # grad_output is a tuple; take the gradient wrt output activations
163
+ self.gradients = grad_output[0].detach()
164
+
165
+ target_layer.register_forward_hook(save_activation)
166
+ target_layer.register_full_backward_hook(save_gradient)
167
+
168
+ def generate(self, input_image: torch.Tensor, target_class=None):
169
+ """
170
+ Generate CAM for a single image batch (1, C, H, W).
171
+ """
172
+ output = self.model(input_image)
173
+
174
+ if target_class is None:
175
+ target_class = output.argmax(dim=1)
176
+
177
+ self.model.zero_grad()
178
+ one_hot = torch.zeros_like(output)
179
+ one_hot[0][target_class] = 1
180
+ output.backward(gradient=one_hot, retain_graph=True)
181
+
182
+ if self.gradients is None or self.activations is None:
183
+ return None, output
184
+
185
+ # Global average pooling over H, W
186
+ weights = self.gradients.mean(dim=(2, 3), keepdim=True)
187
+ cam = (weights * self.activations).sum(dim=1, keepdim=True)
188
+ cam = torch.relu(cam)
189
+ cam = cam.squeeze().cpu().numpy()
190
+ cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
191
+
192
+ return cam, output
193
+
194
+
195
+ # Target the last feature block for Grad-CAM
196
+ target_layer = model.features[-1]
197
+ grad_cam = GradCAM(model, target_layer)
198
+
199
+ # ============================================================================
200
+ # Visualization Helpers
201
+ # ============================================================================
202
+
203
+
204
+ def create_original_display(image, pred_label, confidence):
205
+ """Create annotated original image"""
206
+ fig, ax = plt.subplots(figsize=(8, 8))
207
+ ax.imshow(image)
208
+ ax.axis("off")
209
+
210
+ color = CLASS_COLORS[pred_label]
211
+ title = f"Prediction: {pred_label}\nConfidence: {confidence:.1f}%"
212
+ ax.set_title(title, fontsize=16, fontweight="bold", color=color, pad=20)
213
+
214
+ plt.tight_layout()
215
+ buf = io.BytesIO()
216
+ plt.savefig(buf, format="png", dpi=150, bbox_inches="tight", facecolor="white")
217
+ plt.close(fig)
218
+ buf.seek(0)
219
+ return Image.open(buf)
220
+
221
+
222
+ def create_gradcam_visualization(image, cam, pred_label, confidence):
223
+ """Create Grad-CAM heatmap"""
224
+ img_array = np.array(image.resize((224, 224)))
225
+ cam_resized = cv2.resize(cam, (224, 224))
226
+
227
+ heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
228
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
229
+
230
+ fig, ax = plt.subplots(figsize=(8, 8))
231
+ ax.imshow(heatmap)
232
+ ax.axis("off")
233
+ ax.set_title(
234
+ "Attention Heatmap\n(Areas the model focuses on)",
235
+ fontsize=14,
236
+ fontweight="bold",
237
+ pad=20,
238
+ )
239
+
240
+ plt.tight_layout()
241
+ buf = io.BytesIO()
242
+ plt.savefig(buf, format="png", dpi=150, bbox_inches="tight", facecolor="white")
243
+ plt.close(fig)
244
+ buf.seek(0)
245
+ return Image.open(buf)
246
+
247
+
248
+ def create_overlay_visualization(image, cam):
249
+ """Create overlay of image and heatmap"""
250
+ img_array = np.array(image.resize((224, 224))) / 255.0
251
+ cam_resized = cv2.resize(cam, (224, 224))
252
+
253
+ heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
254
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0
255
+
256
+ overlay = img_array * 0.5 + heatmap * 0.5
257
+ overlay = np.clip(overlay, 0, 1)
258
+
259
+ fig, ax = plt.subplots(figsize=(8, 8))
260
+ ax.imshow(overlay)
261
+ ax.axis("off")
262
+ ax.set_title(
263
+ "Explainable AI Visualization\n(Original + Heatmap)",
264
+ fontsize=14,
265
+ fontweight="bold",
266
+ pad=20,
267
+ )
268
+
269
+ plt.tight_layout()
270
+ buf = io.BytesIO()
271
+ plt.savefig(buf, format="png", dpi=150, bbox_inches="tight", facecolor="white")
272
+ plt.close(fig)
273
+ buf.seek(0)
274
+ return Image.open(buf)
275
+
276
+
277
+ def create_interpretation(pred_label, confidence, results):
278
+ """Create interpretation text with medical-style narrative + disclaimers"""
279
+
280
+ interpretation = f"""
281
+ ## πŸ”¬ Analysis Results
282
+
283
+ ### Prediction: **{pred_label}**
284
+ - Confidence: **{confidence:.1f}%**
285
+
286
+ ### Probability Breakdown:
287
+ - 🟒 Normal: **{results['Normal']:.1f}%**
288
+ - πŸ”΄ Tuberculosis: **{results['Tuberculosis']:.1f}%**
289
+ - 🟠 Pneumonia: **{results['Pneumonia']:.1f}%**
290
+ - 🟣 COVID-19: **{results['COVID-19']:.1f}%**
291
+
292
+ ---
293
+ """
294
+
295
+ # Disease-specific sections
296
+ if pred_label == "Tuberculosis":
297
+ if confidence >= 85:
298
+ interpretation += """
299
+ **⚠️ High Confidence TB Detection**
300
+
301
+ The model has detected features highly consistent with pulmonary tuberculosis.
302
+
303
+ **CRITICAL – Suggested next clinical steps (for clinicians):**
304
+ 1. **Immediate clinical review** of the patient (history + physical exam)
305
+ 2. **Confirmatory tests**:
306
+ - Sputum smear microscopy and/or GeneXpert MTB/RIF
307
+ - TB culture where available
308
+ 3. **Correlate with symptoms**:
309
+ - Cough > 2 weeks
310
+ - Fever, night sweats
311
+ - Weight loss, hemoptysis
312
+ 4. **Consider isolation** and contact tracing if active TB is suspected
313
+ 5. **Additional imaging** (e.g., CT chest) if diagnosis remains uncertain
314
+
315
+ > This tool is **screening-only** and cannot replace microbiological confirmation.
316
+ """
317
+ else:
318
+ interpretation += """
319
+ **⚠️ Possible Tuberculosis**
320
+
321
+ There are radiographic features that *may* be compatible with TB, but the model's confidence is moderate.
322
+
323
+ **Recommended actions (for clinicians):**
324
+ 1. Perform focused clinical assessment
325
+ 2. Consider sputum testing (smear / GeneXpert)
326
+ 3. Review prior imaging for evolution of disease
327
+ 4. Use this result as a **second reader**, not definitive evidence
328
+
329
+ > Moderate probability predictions always require clinical judgment.
330
+ """
331
+
332
+ elif pred_label == "Pneumonia":
333
+ if confidence >= 85:
334
+ interpretation += """
335
+ **⚠️ High Confidence Pneumonia Detection**
336
+
337
+ Findings are strongly suggestive of pneumonia (bacterial or viral).
338
+
339
+ **Suggested steps:**
340
+ 1. Clinical evaluation for pneumonia severity
341
+ 2. Laboratory assessment:
342
+ - CBC, CRP/ESR
343
+ - Blood cultures if severely unwell
344
+ 3. Consider empiric antibiotics (if bacterial suspected) per local guidelines
345
+ 4. Repeat imaging if no improvement or worsening
346
+
347
+ > Classic pneumonia patterns can overlap with other diseases – interpretation must remain clinical.
348
+ """
349
+ else:
350
+ interpretation += """
351
+ **⚠️ Possible Pneumonia**
352
+
353
+ The X-ray may show early or subtle changes of pneumonia.
354
+
355
+ **Suggested steps:**
356
+ 1. Correlate with respiratory symptoms (cough, fever, dyspnea)
357
+ 2. Consider repeat imaging in 24–72 hours if clinically indicated
358
+ 3. Use this AI opinion as supportive, not definitive
359
+ """
360
+
361
+ elif pred_label == "COVID-19":
362
+ if confidence >= 85:
363
+ interpretation += """
364
+ **⚠️ High Confidence COVID-19 Pattern**
365
+
366
+ Pattern is compatible with COVID-19 pneumonia.
367
+
368
+ **Suggested next steps:**
369
+ 1. **Confirmatory testing** with RT-PCR or validated antigen test
370
+ 2. **Infection control**:
371
+ - Isolation according to institutional policy
372
+ - Appropriate PPE for staff
373
+ 3. **Clinical monitoring**:
374
+ - Oxygen saturation (SpOβ‚‚)
375
+ - Respiratory rate, hemodynamics
376
+ 4. **Escalation** if:
377
+ - SpOβ‚‚ < 94%
378
+ - Increased work of breathing
379
+ - Hemodynamic instability
380
+
381
+ > Radiology alone cannot confirm COVID-19 – virological testing is mandatory.
382
+ """
383
+ else:
384
+ interpretation += """
385
+ **⚠️ Possible COVID-19**
386
+
387
+ Some features overlap with COVID-19, but the model is not highly confident.
388
+
389
+ **Suggested steps:**
390
+ 1. Test with RT-PCR or validated antigen assay
391
+ 2. Assess epidemiologic risk and exposure history
392
+ 3. Follow local protocols for isolation and monitoring
393
+ """
394
+
395
+ else: # Normal
396
+ if confidence >= 85:
397
+ interpretation += """
398
+ **βœ… High Confidence β€œNormal” Chest X-Ray (for the 4 modeled diseases)**
399
+
400
+ Within the limits of this model:
401
+ - No strong evidence of **TB**, **pneumonia**, or **COVID-19** is detected.
402
+ - Lung fields appear within normal limits on this projection.
403
+
404
+ **Important caveats:**
405
+ - A β€œnormal” AI result does **not** exclude all lung disease.
406
+ - Early or subtle TB/pneumonia/COVID-19 may still be radiographically occult.
407
+ - Other conditions (PE, asthma, COPD, malignancy, etc.) are **outside the scope** of this model.
408
+
409
+ Clinical review remains essential, especially if symptoms persist.
410
+ """
411
+ else:
412
+ interpretation += """
413
+ **⚠️ Likely Normal, but with Lower Confidence**
414
+
415
+ The model leans towards a normal study, but with limited confidence.
416
+
417
+ **Suggested steps:**
418
+ 1. If the patient is symptomatic, clinical evaluation is still required.
419
+ 2. Consider repeat imaging if symptoms evolve.
420
+ 3. Use this output as an adjunct, not reassurance in isolation.
421
+ """
422
+
423
+ # Global disclaimer and technical note
424
+ interpretation += """
425
+ ---
426
+ ## ⚠️ CRITICAL MEDICAL DISCLAIMER
427
+
428
+ ### What this model *can* do:
429
+ - βœ… Screen for 4 specific classes: **Normal**, **Tuberculosis**, **Pneumonia**, **COVID-19**
430
+ - βœ… Provide **explainable heatmaps** (Grad-CAM) to highlight regions of interest
431
+ - βœ… Offer **probabilistic support** to human readers
432
+ - βœ… Leverage **Adaptive Sparse Training (AST)** for ~89% energy savings vs dense baselines
433
+
434
+ ### What this model *cannot* do:
435
+ - ❌ It is **not** FDA/EMA-approved – research / educational use only
436
+ - ❌ It does **not** replace radiologists, pulmonologists, or infectious disease specialists
437
+ - ❌ It does **not** detect many other thoracic pathologies (e.g., cancer, fibrosis, PE)
438
+ - ❌ It does **not** provide a microbiological diagnosis
439
+
440
+ ### Clinical usage guidance:
441
+ 1. Use as a **second reader** or screening tool.
442
+ 2. Always **correlate with clinical history, examination, and lab tests**.
443
+ 3. Never start, stop, or change treatment **solely** based on this AI prediction.
444
+ 4. Follow your local and international guidelines for TB, pneumonia, and COVID-19 management.
445
+
446
+ ### Diagnostic gold standards:
447
+ - **TB**: Sputum AFB, culture, GeneXpert MTB/RIF, TB-PCR
448
+ - **Pneumonia**: Clinical + imaging + microbiology
449
+ - **COVID-19**: RT-PCR / validated antigen testing
450
+
451
+ > When in doubt, a qualified healthcare professional’s judgment takes absolute precedence.
452
+
453
+ ---
454
+ 🫁 **Powered by Adaptive Sparse Training (AST)**
455
+ Energy-efficient AI for accessible lung disease screening.
456
+
457
+ **Project links:**
458
+ - GitHub: https://github.com/oluwafemidiakhoa/Tuberculosis
459
+ - Hugging Face Space: https://huggingface.co/spaces/mgbam/Tuberculosis
460
+ """
461
+
462
+ return interpretation
463
+
464
+
465
+ # ============================================================================
466
+ # Prediction Function
467
+ # ============================================================================
468
+
469
+
470
+ def predict_chest_xray(image, show_gradcam=True):
471
+ """
472
+ Main prediction function used by Gradio.
473
+ Returns:
474
+ - dict of class probabilities
475
+ - Annotated original image
476
+ - Grad-CAM heatmap
477
+ - Overlay image
478
+ - Markdown interpretation
479
+ """
480
+ if image is None:
481
+ return None, None, None, None, "Please upload a chest X-ray image."
482
+
483
+ # Ensure PIL RGB
484
+ if isinstance(image, np.ndarray):
485
+ image = Image.fromarray(image).convert("RGB")
486
+ else:
487
+ image = image.convert("RGB")
488
+
489
+ original_img = image.copy()
490
+
491
+ input_tensor = transform(image).unsqueeze(0).to(device)
492
+
493
+ with torch.set_grad_enabled(show_gradcam):
494
+ if show_gradcam:
495
+ cam, output = grad_cam.generate(input_tensor)
496
+ else:
497
+ output = model(input_tensor)
498
+ cam = None
499
+
500
+ probs = torch.softmax(output, dim=1)[0].detach().cpu().numpy()
501
+ prob_sum = float(np.sum(probs))
502
+
503
+ if not (0.98 <= prob_sum <= 1.02):
504
+ print(f"⚠️ Probability sum = {prob_sum:.4f} (should be ~1.0). Check model/weights.")
505
+
506
+ pred_idx = int(output.argmax(dim=1).item())
507
+ pred_label = CLASSES[pred_idx]
508
+ confidence = float(probs[pred_idx]) * 100.0
509
+
510
+ results = {
511
+ CLASSES[i]: float(np.clip(probs[i] * 100.0, 0.0, 100.0))
512
+ for i in range(len(CLASSES))
513
+ }
514
+
515
+ original_pil = create_original_display(original_img, pred_label, confidence)
516
+
517
+ if cam is not None and show_gradcam:
518
+ gradcam_viz = create_gradcam_visualization(
519
+ original_img, cam, pred_label, confidence
520
+ )
521
+ overlay_viz = create_overlay_visualization(original_img, cam)
522
+ else:
523
+ gradcam_viz = None
524
+ overlay_viz = None
525
+
526
+ interpretation = create_interpretation(pred_label, confidence, results)
527
+
528
+ return results, original_pil, gradcam_viz, overlay_viz, interpretation
529
+
530
+
531
+ # ============================================================================
532
+ # Gradio Interface
533
+ # ============================================================================
534
+
535
+ custom_css = """
536
+ #main-container {
537
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
538
+ padding: 20px;
539
+ }
540
+ #title {
541
+ text-align: center;
542
+ color: white;
543
+ font-size: 2.5em;
544
+ font-weight: bold;
545
+ margin-bottom: 10px;
546
+ text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
547
+ }
548
+ #subtitle {
549
+ text-align: center;
550
+ color: #f0f0f0;
551
+ font-size: 1.2em;
552
+ margin-bottom: 20px;
553
+ }
554
+ #stats {
555
+ text-align: center;
556
+ color: #fff;
557
+ font-size: 0.95em;
558
+ margin-bottom: 30px;
559
+ padding: 15px;
560
+ background: rgba(255,255,255,0.1);
561
+ border-radius: 10px;
562
+ backdrop-filter: blur(10px);
563
+ }
564
+ .gradio-container {
565
+ font-family: 'Inter', sans-serif;
566
+ }
567
+ #upload-box {
568
+ border: 3px dashed #667eea;
569
+ border-radius: 15px;
570
+ padding: 20px;
571
+ background: rgba(255,255,255,0.95);
572
+ }
573
+ #results-box {
574
+ background: white;
575
+ border-radius: 15px;
576
+ padding: 20px;
577
+ box-shadow: 0 4px 6px rgba(0,0,0,0.1);
578
+ }
579
+ .output-image {
580
+ border-radius: 10px;
581
+ box-shadow: 0 2px 4px rgba(0,0,0,0.1);
582
+ }
583
+ footer {
584
+ text-align: center;
585
+ margin-top: 30px;
586
+ color: white;
587
+ font-size: 0.9em;
588
+ }
589
+ """
590
+
591
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
592
+ gr.HTML(
593
+ """
594
+ <div id="main-container">
595
+ <div id="title">🫁 Multi-Class Chest X-Ray Detection AI</div>
596
+ <div id="subtitle">Advanced chest X-ray analysis with Explainable AI</div>
597
+ <div id="stats">
598
+ <b>95–97% Accuracy</b> across 4 disease classes |
599
+ <b>89% Energy Efficient</b> |
600
+ Powered by Adaptive Sparse Training (AST)
601
+ <br><br>
602
+ <b>Detects:</b> Normal β€’ Tuberculosis β€’ Pneumonia β€’ COVID-19
603
+ </div>
604
+ </div>
605
+ """
606
+ )
607
+
608
+ with gr.Row():
609
+ with gr.Column(scale=1, elem_id="upload-box"):
610
+ gr.Markdown("## πŸ“€ Upload Chest X-Ray")
611
+ image_input = gr.Image(
612
+ type="pil",
613
+ label="Upload X-Ray Image",
614
+ elem_classes="output-image",
615
+ )
616
+
617
+ show_gradcam = gr.Checkbox(
618
+ value=True,
619
+ label="Enable Grad-CAM Visualization (Explainable AI)",
620
+ info="Shows which areas the model focuses on",
621
+ )
622
+
623
+ analyze_btn = gr.Button("πŸ”¬ Analyze X-Ray", variant="primary", size="lg")
624
+
625
+ gr.Markdown(
626
+ """
627
+ ### πŸ“‹ Supported Images:
628
+ - Chest X-rays (PA or AP view)
629
+ - PNG, JPG, JPEG formats
630
+ - Grayscale or RGB
631
+
632
+ ### ⚑ Model Highlights:
633
+ - βœ… **Improved Specificity**: Better separation of TB vs Pneumonia
634
+ - βœ… **4 Disease Classes**: Normal, TB, Pneumonia, COVID-19
635
+ - βœ… **Energy-Aware**: ~89% energy savings with AST
636
+ - βœ… **Explainable**: Grad-CAM heatmaps for clinical teams
637
+ """
638
+ )
639
+
640
+ with gr.Column(scale=2, elem_id="results-box"):
641
+ gr.Markdown("## πŸ“Š Analysis Results")
642
+
643
+ prob_output = gr.Label(
644
+ label="Prediction Confidence", num_top_classes=4
645
+ )
646
+
647
+ with gr.Tabs():
648
+ with gr.Tab("Original"):
649
+ original_output = gr.Image(
650
+ label="Annotated X-Ray", elem_classes="output-image"
651
+ )
652
+
653
+ with gr.Tab("Grad-CAM Heatmap"):
654
+ gradcam_output = gr.Image(
655
+ label="Attention Heatmap", elem_classes="output-image"
656
+ )
657
+
658
+ with gr.Tab("Overlay"):
659
+ overlay_output = gr.Image(
660
+ label="Explainable AI Visualization",
661
+ elem_classes="output-image",
662
+ )
663
+
664
+ interpretation_output = gr.Markdown(label="Clinical Interpretation")
665
+
666
+ gr.Markdown("## πŸ“ Example X-Rays (Demo Only)")
667
+ gr.Examples(
668
+ examples=[
669
+ ["examples/normal.png"],
670
+ ["examples/tb.png"],
671
+ ["examples/pneumonia.png"],
672
+ ["examples/covid.png"],
673
+ ],
674
+ inputs=image_input,
675
+ label="Click an example to load",
676
+ )
677
+
678
+ analyze_btn.click(
679
+ fn=predict_chest_xray,
680
+ inputs=[image_input, show_gradcam],
681
+ outputs=[
682
+ prob_output,
683
+ original_output,
684
+ gradcam_output,
685
+ overlay_output,
686
+ interpretation_output,
687
+ ],
688
+ )
689
+
690
+ gr.HTML(
691
+ """
692
+ <footer>
693
+ <p>
694
+ <b>🫁 Multi-Class Chest X-Ray Detection with AST</b><br>
695
+ Trained on Normal, Tuberculosis, Pneumonia, and COVID-19 cases<br>
696
+ 95–97% Accuracy | 89% Energy Savings | Explainable AI<br><br>
697
+ <a href="https://github.com/oluwafemidiakhoa/Tuberculosis" target="_blank" style="color: white;">
698
+ πŸ“‚ GitHub Repository
699
+ </a> |
700
+ <a href="https://huggingface.co/spaces/mgbam/Tuberculosis" target="_blank" style="color: white;">
701
+ πŸ€— Hugging Face Space
702
+ </a>
703
+ </p>
704
+ <p style="font-size: 0.8em; margin-top: 15px;">
705
+ ⚠️ <b>MEDICAL DISCLAIMER</b>: This is a screening / research tool, not a diagnostic device.
706
+ All predictions require professional medical evaluation and laboratory confirmation.
707
+ Not FDA-approved for clinical use.
708
+ </p>
709
+ </footer>
710
+ """
711
+ )
712
+
713
+ # ============================================================================
714
+ # Launch
715
+ # ============================================================================
716
+
717
+ if __name__ == "__main__":
718
+ demo.launch(
719
+ share=False,
720
+ server_name="0.0.0.0",
721
+ server_port=7860,
722
+ show_error=True,
723
+ )