mgbam commited on
Commit
c79e22e
·
verified ·
1 Parent(s): f577d0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +591 -419
app.py CHANGED
@@ -1,76 +1,97 @@
1
  """
2
  🫁 Multi-Class Chest X-Ray Detection with Adaptive Sparse Training
3
- Advanced Gradio Interface - 4 Disease Classes
4
- Features:
5
- - Real-time detection: Normal, TB, Pneumonia, COVID-19
6
- - Grad-CAM visualization (explainable AI)
7
- - Improved specificity - distinguishes TB from pneumonia
8
- - Confidence scores with visual indicators
9
- - Clinical interpretation and recommendations
10
- - Mobile-responsive design
11
  """
12
 
 
 
 
 
13
  import gradio as gr
 
 
14
  import torch
15
  import torch.nn as nn
16
- from torchvision import models, transforms
17
  from PIL import Image
18
- import numpy as np
19
- import cv2
20
- import matplotlib.pyplot as plt
21
- from pathlib import Path
22
- import io
23
 
24
  # ============================================================================
25
  # Model Setup
26
  # ============================================================================
27
 
28
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
29
 
30
- # Load model
31
  model = models.efficientnet_b0(weights=None)
32
  model.classifier[1] = nn.Linear(model.classifier[1].in_features, 4) # 4 classes
33
 
34
- try:
35
- # Try loading best.pt from root directory (HuggingFace Spaces location)
36
- model.load_state_dict(torch.load('best.pt', map_location=device))
37
- print(" Multi-class model loaded successfully from best.pt!")
38
- except Exception as e:
39
- print(f"⚠️ Error loading model from best.pt: {e}")
40
- try:
41
- # Fallback to checkpoints directory
42
- model.load_state_dict(torch.load('checkpoints/best_multiclass.pt', map_location=device))
43
- print("✅ Multi-class model loaded successfully from checkpoints/best_multiclass.pt!")
44
- except Exception as e2:
45
- print(f"❌ CRITICAL ERROR: Could not load model from any location!")
46
- print(f" - best.pt error: {e}")
47
- print(f" - checkpoints/best_multiclass.pt error: {e2}")
48
- raise RuntimeError("Model file not found! Please ensure best.pt is uploaded to the Space.")
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
  model = model.to(device)
51
  model.eval()
52
 
53
- # Classes
54
- CLASSES = ['Normal', 'Tuberculosis', 'Pneumonia', 'COVID-19']
 
 
 
 
 
 
55
  CLASS_COLORS = {
56
- 'Normal': '#2ecc71', # Green
57
- 'Tuberculosis': '#e74c3c', # Red
58
- 'Pneumonia': '#f39c12', # Orange
59
- 'COVID-19': '#9b59b6' # Purple
60
  }
61
 
62
- # Image preprocessing
63
- transform = transforms.Compose([
64
- transforms.Resize(256),
65
- transforms.CenterCrop(224),
66
- transforms.ToTensor(),
67
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
68
- ])
 
 
 
 
69
 
70
  # ============================================================================
71
  # Grad-CAM Implementation
72
  # ============================================================================
73
 
 
74
  class GradCAM:
75
  def __init__(self, model, target_layer):
76
  self.model = model
@@ -95,7 +116,7 @@ class GradCAM:
95
 
96
  self.model.zero_grad()
97
  one_hot = torch.zeros_like(output)
98
- one_hot[0][target_class] = 1
99
  output.backward(gradient=one_hot, retain_graph=True)
100
 
101
  if self.gradients is None:
@@ -109,471 +130,650 @@ class GradCAM:
109
 
110
  return cam, output
111
 
112
- # Setup Grad-CAM
113
  target_layer = model.features[-1]
114
  grad_cam = GradCAM(model, target_layer)
115
 
116
  # ============================================================================
117
- # Prediction Functions
118
  # ============================================================================
119
 
120
- def predict_chest_xray(image, show_gradcam=True):
121
- """
122
- Predict disease class from chest X-ray with Grad-CAM visualization
123
- """
124
- if image is None:
125
- return None, None, None, None
126
-
127
- # Convert to PIL if needed
128
- if isinstance(image, np.ndarray):
129
- image = Image.fromarray(image).convert('RGB')
130
- else:
131
- image = image.convert('RGB')
132
-
133
- # Store original for display
134
- original_img = image.copy()
135
 
136
- # Preprocess
137
- input_tensor = transform(image).unsqueeze(0).to(device)
138
-
139
- # Get prediction with Grad-CAM
140
- with torch.set_grad_enabled(show_gradcam):
141
- if show_gradcam:
142
- cam, output = grad_cam.generate(input_tensor)
143
- else:
144
- output = model(input_tensor)
145
- cam = None
146
-
147
- # Get probabilities
148
- probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()
149
-
150
- # Safety check: ensure probabilities sum to ~1.0
151
- prob_sum = np.sum(probs)
152
- if not (0.99 <= prob_sum <= 1.01):
153
- print(f"⚠️ WARNING: Probability sum is {prob_sum}, not 1.0. Model may not be loaded correctly!")
154
-
155
- pred_class = int(output.argmax(dim=1).item())
156
- pred_label = CLASSES[pred_class]
157
- confidence = float(probs[pred_class]) * 100
158
-
159
- # Create results - ensure values are between 0-100
160
- results = {
161
- CLASSES[i]: float(min(100.0, max(0.0, probs[i] * 100))) for i in range(len(CLASSES))
162
- }
163
-
164
- # Generate visualizations
165
- original_pil = create_original_display(original_img, pred_label, confidence)
166
-
167
- if cam is not None and show_gradcam:
168
- gradcam_viz = create_gradcam_visualization(original_img, cam, pred_label, confidence)
169
- overlay_viz = create_overlay_visualization(original_img, cam)
170
- else:
171
- gradcam_viz = None
172
- overlay_viz = None
173
-
174
- # Create interpretation text
175
- interpretation = create_interpretation(pred_label, confidence, results)
176
 
177
- return results, original_pil, gradcam_viz, overlay_viz, interpretation
178
 
179
  def create_original_display(image, pred_label, confidence):
180
- """Create annotated original image"""
181
- fig, ax = plt.subplots(figsize=(8, 8))
182
  ax.imshow(image)
183
- ax.axis('off')
184
 
185
- # Add prediction box
186
  color = CLASS_COLORS[pred_label]
187
- title = f'Prediction: {pred_label}\nConfidence: {confidence:.1f}%'
188
- ax.set_title(title, fontsize=16, fontweight='bold', color=color, pad=20)
189
-
 
 
 
 
 
190
  plt.tight_layout()
 
191
 
192
- # Convert to PIL
193
- buf = io.BytesIO()
194
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
195
- plt.close()
196
- buf.seek(0)
197
-
198
- return Image.open(buf)
199
 
200
- def create_gradcam_visualization(image, cam, pred_label, confidence):
201
- """Create Grad-CAM heatmap"""
202
- # Resize CAM to image size
203
  img_array = np.array(image.resize((224, 224)))
204
  cam_resized = cv2.resize(cam, (224, 224))
205
 
206
- # Create heatmap
207
  heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
208
  heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
209
 
210
- fig, ax = plt.subplots(figsize=(8, 8))
211
  ax.imshow(heatmap)
212
- ax.axis('off')
213
- ax.set_title('Attention Heatmap\n(Areas the model focuses on)',
214
- fontsize=14, fontweight='bold', pad=20)
215
-
 
 
 
216
  plt.tight_layout()
 
217
 
218
- buf = io.BytesIO()
219
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
220
- plt.close()
221
- buf.seek(0)
222
-
223
- return Image.open(buf)
224
 
225
  def create_overlay_visualization(image, cam):
226
- """Create overlay of image and heatmap"""
227
  img_array = np.array(image.resize((224, 224))) / 255.0
228
  cam_resized = cv2.resize(cam, (224, 224))
229
 
230
- # Create heatmap
231
  heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
232
  heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0
233
 
234
- # Overlay
235
  overlay = img_array * 0.5 + heatmap * 0.5
236
  overlay = np.clip(overlay, 0, 1)
237
 
238
- fig, ax = plt.subplots(figsize=(8, 8))
239
  ax.imshow(overlay)
240
- ax.axis('off')
241
- ax.set_title('Explainable AI Visualization\n(Original + Heatmap)',
242
- fontsize=14, fontweight='bold', pad=20)
243
-
 
 
 
244
  plt.tight_layout()
 
245
 
246
- buf = io.BytesIO()
247
- plt.savefig(buf, format='png', dpi=150, bbox_inches='tight', facecolor='white')
248
- plt.close()
249
- buf.seek(0)
250
 
251
- return Image.open(buf)
 
 
 
252
 
253
- def create_interpretation(pred_label, confidence, results):
254
- """Create interpretation text with improved medical disclaimers"""
 
 
 
 
255
 
256
  interpretation = f"""
257
- ## 🔬 Analysis Results
258
- ### Prediction: **{pred_label}**
 
 
 
259
  - Confidence: **{confidence:.1f}%**
260
- ### Probability Breakdown:
 
261
  - 🟢 Normal: **{results['Normal']:.1f}%**
262
  - 🔴 Tuberculosis: **{results['Tuberculosis']:.1f}%**
263
  - 🟠 Pneumonia: **{results['Pneumonia']:.1f}%**
264
  - 🟣 COVID-19: **{results['COVID-19']:.1f}%**
 
265
  ---
266
  """
267
 
268
- # Disease-specific interpretations
269
- if pred_label == 'Tuberculosis':
270
  if confidence >= 85:
271
  interpretation += """
272
- **⚠️ High Confidence TB Detection**
273
- The model has detected features highly consistent with tuberculosis infection.
274
- **CRITICAL - Immediate Actions Required:**
275
- 1. ✅ **Immediate consultation** with a healthcare provider
276
- 2. ✅ **Confirmatory sputum test** (AFB smear or GeneXpert MTB/RIF)
277
- 3. ✅ **Clinical correlation** with symptoms:
278
- - Persistent cough (>2 weeks)
279
- - Fever, especially night sweats
280
- - Unexplained weight loss
281
- - Hemoptysis (coughing blood)
282
- 4. **Isolation** and contact tracing if confirmed
283
- 5. **Chest CT scan** if needed for further evaluation
284
- **⚠️ IMPORTANT**: This is a SCREENING tool, not a diagnostic tool.
285
- Clinical diagnosis of TB requires laboratory confirmation (sputum test).
 
 
286
  """
287
  else:
288
  interpretation += """
289
- **⚠️ Possible TB Detection**
290
- The model has detected features suggestive of tuberculosis, but confidence is moderate.
291
- **Recommended Actions:**
292
- 1. Consult healthcare provider for clinical evaluation
293
- 2. Consider confirmatory sputum testing
294
- 3. Evaluate clinical symptoms
295
- 4. Follow-up imaging may be recommended
296
- **Note**: Moderate confidence requires professional medical evaluation.
297
  """
298
 
299
- elif pred_label == 'Pneumonia':
300
  if confidence >= 85:
301
  interpretation += """
302
- **⚠️ High Confidence Pneumonia Detection**
303
- The model has detected features consistent with pneumonia (bacterial or viral).
304
- **Recommended Actions:**
305
- 1. ✅ **Medical evaluation** for pneumonia diagnosis
306
- 2. ✅ **Possible confirmatory tests**:
307
- - Sputum culture
308
- - Blood tests (WBC count, CRP)
309
- - Additional chest imaging if needed
310
- 3. ✅ **Clinical correlation** with symptoms:
311
- - Cough with sputum production
312
- - Fever and chills
313
- - Shortness of breath
314
- - Chest pain with breathing
315
- 4. ✅ **Treatment**: Antibiotics (bacterial) or supportive care (viral)
316
- **Note**: Pneumonia can present similarly to other lung diseases.
317
- Professional diagnosis is essential for appropriate treatment.
318
  """
319
  else:
320
  interpretation += """
321
- **⚠️ Possible Pneumonia**
322
- Features suggest possible pneumonia, but further evaluation is needed.
323
- **Recommended Actions:**
324
- 1. Seek medical evaluation
325
- 2. Clinical symptom assessment
326
- 3. Consider additional diagnostic tests
327
- **Note**: Requires professional medical evaluation for confirmation.
 
328
  """
329
 
330
- elif pred_label == 'COVID-19':
331
  if confidence >= 85:
332
  interpretation += """
333
- **⚠️ High Confidence COVID-19 Detection**
334
- The model has detected features consistent with COVID-19 pneumonia.
335
- **URGENT - Immediate Actions:**
336
- 1. ✅ **COVID-19 RT-PCR test** for confirmation
337
- 2. ✅ **Isolation** to prevent transmission
338
- 3. **Monitor oxygen saturation** (SpO2 levels)
339
- 4. **Seek immediate medical care** if:
340
- - Difficulty breathing
341
- - SpO2 < 94%
342
- - Persistent chest pain
343
- - Confusion or inability to stay awake
344
- 5. **Contact tracing** if positive
345
- **Clinical Symptoms to Monitor:**
346
- - Fever, cough, shortness of breath
347
- - Loss of taste/smell
348
- - Fatigue, body aches
349
- - Gastrointestinal symptoms
350
- **⚠️ IMPORTANT**: Imaging findings alone cannot confirm COVID-19.
351
- RT-PCR or antigen testing is required for diagnosis.
352
  """
353
  else:
354
  interpretation += """
355
- **⚠️ Possible COVID-19**
356
- Features suggest possible COVID-19, but confirmation testing is essential.
357
- **Recommended Actions:**
358
- 1. Get RT-PCR or rapid antigen test
359
- 2. Self-isolate pending test results
360
- 3. Monitor symptoms
361
- 4. Seek medical care if symptoms worsen
362
- **Note**: COVID-19 diagnosis requires laboratory confirmation.
363
  """
364
 
365
  else: # Normal
366
  if confidence >= 85:
367
  interpretation += """
368
- **✅ High Confidence Normal Result**
369
- The model has not detected significant abnormalities consistent with TB, pneumonia, or COVID-19.
370
- **Interpretation:**
371
- - Chest X-ray appears within normal limits
372
- - No features of active tuberculosis detected
373
- - No signs of pneumonia or COVID-19
374
- **Important Notes:**
375
- - This does NOT rule out all lung diseases
376
- - Early-stage diseases may not show on X-ray
377
- - If you have symptoms, seek medical evaluation
378
- - Regular health screenings are recommended
379
- **When to still see a doctor:**
380
- - Persistent cough, fever, or respiratory symptoms
381
- - Unexplained weight loss or night sweats
382
- - Shortness of breath or chest pain
383
- - Known exposure to TB or COVID-19
384
  """
385
  else:
386
  interpretation += """
387
- **⚠️ Likely Normal, Low Confidence**
388
- The model suggests a normal chest X-ray, but confidence is not high.
389
- **Recommended Actions:**
390
- 1. If symptomatic, seek medical evaluation
391
- 2. Consider repeat imaging if concerns persist
392
- 3. Clinical correlation is important
393
- **Note**: Low confidence results should be reviewed by healthcare professionals.
394
  """
395
 
396
- # Add universal disclaimer
397
  interpretation += """
398
  ---
399
  ## ⚠️ CRITICAL MEDICAL DISCLAIMER
400
- ### Model Capabilities:
401
- - Trained on 4 disease classes: Normal, TB, Pneumonia, COVID-19
402
- - Can distinguish between different lung diseases
403
- - ~95-97% accuracy in validation testing
404
- - Powered by Adaptive Sparse Training (89% energy efficient)
405
- ### Important Limitations:
406
- - ⚠️ This is a **SCREENING tool**, not a diagnostic device
407
- - ⚠️ **NOT FDA-approved** for clinical diagnosis
408
- - ⚠️ Cannot detect: lung cancer, pulmonary fibrosis, bronchiectasis, other rare diseases
409
- - ⚠️ Cannot replace: professional radiologist review
410
- - ⚠️ Cannot confirm: laboratory diagnosis (sputum tests, PCR, cultures)
411
- ### Clinical Use Guidelines:
412
- 1. ✅ Use as a **preliminary screening** tool only
413
- 2. ALL positive results require **confirmatory laboratory testing**
414
- 3. ✅ ALL cases require **clinical correlation** with symptoms and history
415
- 4. ✅ Expert radiologist review is recommended for clinical decisions
416
- 5. ✅ Do NOT initiate treatment based solely on AI predictions
417
- ### Diagnostic Gold Standards:
418
- - **TB**: Sputum AFB smear/culture, GeneXpert MTB/RIF, TB-PCR
419
- - **Pneumonia**: Clinical diagnosis + sputum culture + blood tests
420
- - **COVID-19**: RT-PCR, rapid antigen test
421
- **When in doubt, always consult a qualified healthcare professional.**
422
- ---
423
- 🫁 **Powered by Adaptive Sparse Training**
424
- Energy-efficient AI for accessible healthcare
425
- **Learn more:**
426
- - GitHub: https://github.com/oluwafemidiakhoa/Tuberculosis
427
- - Research: Sample-based Adaptive Sparse Training for deep learning
428
  """
429
 
 
 
 
 
 
 
 
 
 
430
  return interpretation
431
 
 
432
  # ============================================================================
433
- # Gradio Interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
  # ============================================================================
435
 
436
- # Custom CSS
437
  custom_css = """
438
- #main-container {
439
- background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
440
- padding: 20px;
 
441
  }
442
- #title {
443
- text-align: center;
444
- color: white;
445
- font-size: 2.5em;
446
- font-weight: bold;
447
- margin-bottom: 10px;
448
- text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
449
  }
450
- #subtitle {
451
- text-align: center;
452
- color: #f0f0f0;
453
- font-size: 1.2em;
454
- margin-bottom: 20px;
 
 
 
455
  }
456
- #stats {
457
- text-align: center;
458
- color: #fff;
459
- font-size: 0.95em;
460
- margin-bottom: 30px;
461
- padding: 15px;
462
- background: rgba(255,255,255,0.1);
463
- border-radius: 10px;
464
- backdrop-filter: blur(10px);
465
  }
466
- .gradio-container {
467
- font-family: 'Inter', sans-serif;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
468
  }
469
- #upload-box {
470
- border: 3px dashed #667eea;
471
- border-radius: 15px;
472
- padding: 20px;
473
- background: rgba(255,255,255,0.95);
 
 
 
 
 
 
 
 
 
 
 
 
 
 
474
  }
475
- #results-box {
476
- background: white;
477
- border-radius: 15px;
478
- padding: 20px;
479
- box-shadow: 0 4px 6px rgba(0,0,0,0.1);
480
  }
481
- .output-image {
482
- border-radius: 10px;
483
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
 
 
 
 
 
484
  }
 
 
 
 
 
485
  footer {
486
  text-align: center;
487
- margin-top: 30px;
488
- color: white;
489
- font-size: 0.9em;
490
  }
491
  """
492
 
493
- # Create interface
494
- with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
495
- gr.HTML("""
496
- <div id="main-container">
497
- <div id="title">🫁 Multi-Class Chest X-Ray Detection AI</div>
498
- <div id="subtitle">Advanced chest X-ray analysis with Explainable AI</div>
499
- <div id="stats">
500
- <b>95-97% Accuracy</b> across 4 disease classes |
501
- <b>89% Energy Efficient</b> |
502
- Powered by Adaptive Sparse Training
503
- <br><br>
504
- <b>Detects:</b> Normal • Tuberculosis • Pneumonia • COVID-19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
505
  </div>
506
  </div>
507
- """)
 
 
 
 
 
 
 
 
 
 
508
 
509
- with gr.Row():
510
- with gr.Column(scale=1, elem_id="upload-box"):
511
- gr.Markdown("## 📤 Upload Chest X-Ray")
512
  image_input = gr.Image(
513
  type="pil",
514
- label="Upload X-Ray Image",
515
- elem_classes="output-image"
516
  )
517
 
518
- show_gradcam = gr.Checkbox(
519
- value=True,
520
- label="Enable Grad-CAM Visualization (Explainable AI)",
521
- info="Shows which areas the model focuses on"
522
- )
 
 
 
 
 
 
523
 
524
- analyze_btn = gr.Button(
525
- "🔬 Analyze X-Ray",
526
- variant="primary",
527
- size="lg"
 
 
 
 
 
 
 
 
528
  )
529
 
530
- gr.Markdown("""
531
- ### 📋 Supported Images:
532
- - Chest X-rays (PA or AP view)
533
- - PNG, JPG, JPEG formats
534
- - Grayscale or RGB
535
- ### ⚡ What's New:
536
- - ✅ **Improved Specificity**: Can distinguish TB from Pneumonia
537
- - ✅ **4 Disease Classes**: Normal, TB, Pneumonia, COVID-19
538
- - ✅ **Fewer False Positives**: <5% on pneumonia cases
539
- - ✅ **Same Energy Efficiency**: 89% savings with AST
540
- """)
541
-
542
- with gr.Column(scale=2, elem_id="results-box"):
543
- gr.Markdown("## 📊 Analysis Results")
544
-
545
- # Results display
546
- with gr.Row():
547
- prob_output = gr.Label(
548
- label="Prediction Confidence",
549
- num_top_classes=4
550
- )
551
 
552
  with gr.Tabs():
553
- with gr.Tab("Original"):
554
- original_output = gr.Image(
555
- label="Annotated X-Ray",
556
- elem_classes="output-image"
557
  )
558
-
559
- with gr.Tab("Grad-CAM Heatmap"):
560
- gradcam_output = gr.Image(
561
- label="Attention Heatmap",
562
- elem_classes="output-image"
563
  )
564
 
565
- with gr.Tab("Overlay"):
 
 
 
 
 
 
 
 
 
566
  overlay_output = gr.Image(
567
- label="Explainable AI Visualization",
568
- elem_classes="output-image"
569
  )
570
 
571
- interpretation_output = gr.Markdown(
572
- label="Clinical Interpretation"
573
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
574
 
575
- # Example images
576
- gr.Markdown("## 📁 Example X-Rays")
577
  gr.Examples(
578
  examples=[
579
  ["examples/normal.png"],
@@ -582,44 +782,16 @@ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
582
  ["examples/covid.png"],
583
  ],
584
  inputs=image_input,
585
- label="Click to load example"
586
  )
587
 
588
- # Connect components
589
- analyze_btn.click(
590
- fn=predict_chest_xray,
591
- inputs=[image_input, show_gradcam],
592
- outputs=[prob_output, original_output, gradcam_output, overlay_output, interpretation_output]
593
- )
594
-
595
- # Footer
596
- gr.HTML("""
597
- <footer>
598
- <p>
599
- <b>🫁 Multi-Class Chest X-Ray Detection with AST</b><br>
600
- Trained on Normal, Tuberculosis, Pneumonia, and COVID-19 cases<br>
601
- 95-97% Accuracy | 89% Energy Savings | Explainable AI<br><br>
602
- <a href="https://github.com/oluwafemidiakhoa/Tuberculosis" target="_blank" style="color: white;">
603
- 📂 GitHub Repository
604
- </a> |
605
- <a href="https://huggingface.co/spaces/mgbam/Tuberculosis" target="_blank" style="color: white;">
606
- 🤗 Hugging Face Space
607
- </a>
608
- </p>
609
- <p style="font-size: 0.8em; margin-top: 15px;">
610
- ⚠️ <b>MEDICAL DISCLAIMER</b>: This is a screening tool, not a diagnostic device.
611
- All predictions require professional medical evaluation and laboratory confirmation.
612
- Not FDA-approved for clinical use.
613
- </p>
614
- </footer>
615
- """)
616
-
617
  # Launch
 
 
618
  if __name__ == "__main__":
619
  demo.launch(
620
  share=False,
621
  server_name="0.0.0.0",
622
  server_port=7860,
623
- show_error=True
624
  )
625
-
 
1
  """
2
  🫁 Multi-Class Chest X-Ray Detection with Adaptive Sparse Training
3
+ WOW UI/UX Edition 4 Disease Classes
4
+
5
+ - Normal, Tuberculosis, Pneumonia, COVID-19
6
+ - Grad-CAM (Explainable AI)
7
+ - Energy-efficient Adaptive Sparse Training
 
 
 
8
  """
9
 
10
+ import io
11
+ from pathlib import Path
12
+
13
+ import cv2
14
  import gradio as gr
15
+ import matplotlib.pyplot as plt
16
+ import numpy as np
17
  import torch
18
  import torch.nn as nn
 
19
  from PIL import Image
20
+ from torchvision import models, transforms
 
 
 
 
21
 
22
  # ============================================================================
23
  # Model Setup
24
  # ============================================================================
25
 
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
 
28
+ # EfficientNet backbone
29
  model = models.efficientnet_b0(weights=None)
30
  model.classifier[1] = nn.Linear(model.classifier[1].in_features, 4) # 4 classes
31
 
32
+ # Try a few reasonable checkpoint locations
33
+ checkpoint_candidates = [
34
+ "best.pt",
35
+ "checkpoints/best.pt", # <-- your current file
36
+ "checkpoints/lasttb.pt", # optional fallback
37
+ ]
38
+
39
+ MODEL_LOAD_INFO = ""
40
+ loaded = False
41
+
42
+ for ckpt_path in checkpoint_candidates:
43
+ if Path(ckpt_path).is_file():
44
+ try:
45
+ print(f"🔍 Trying to load weights from: {ckpt_path}")
46
+ state_dict = torch.load(ckpt_path, map_location=device)
47
+ model.load_state_dict(state_dict)
48
+ MODEL_LOAD_INFO = f"✅ Model loaded from **{ckpt_path}** on **{device.type.upper()}**."
49
+ loaded = True
50
+ break
51
+ except Exception as e:
52
+ print(f"⚠️ Found {ckpt_path} but failed to load: {e}")
53
+
54
+ if not loaded:
55
+ raise RuntimeError(
56
+ "Model file not found or could not be loaded. "
57
+ "Please upload 'checkpoints/best.pt' (or 'best.pt' in the repo root)."
58
+ )
59
 
60
  model = model.to(device)
61
  model.eval()
62
 
63
+ TOTAL_PARAMS = sum(p.numel() for p in model.parameters())
64
+ TOTAL_PARAMS_M = TOTAL_PARAMS / 1e6
65
+
66
+ # ============================================================================
67
+ # Classes & Preprocessing
68
+ # ============================================================================
69
+
70
+ CLASSES = ["Normal", "Tuberculosis", "Pneumonia", "COVID-19"]
71
  CLASS_COLORS = {
72
+ "Normal": "#2ecc71", # Green
73
+ "Tuberculosis": "#e74c3c", # Red
74
+ "Pneumonia": "#f39c12", # Orange
75
+ "COVID-19": "#9b59b6", # Purple
76
  }
77
 
78
+ transform = transforms.Compose(
79
+ [
80
+ transforms.Resize(256),
81
+ transforms.CenterCrop(224),
82
+ transforms.ToTensor(),
83
+ transforms.Normalize(
84
+ [0.485, 0.456, 0.406],
85
+ [0.229, 0.224, 0.225],
86
+ ),
87
+ ]
88
+ )
89
 
90
  # ============================================================================
91
  # Grad-CAM Implementation
92
  # ============================================================================
93
 
94
+
95
  class GradCAM:
96
  def __init__(self, model, target_layer):
97
  self.model = model
 
116
 
117
  self.model.zero_grad()
118
  one_hot = torch.zeros_like(output)
119
+ one_hot[0, target_class] = 1
120
  output.backward(gradient=one_hot, retain_graph=True)
121
 
122
  if self.gradients is None:
 
130
 
131
  return cam, output
132
 
133
+
134
  target_layer = model.features[-1]
135
  grad_cam = GradCAM(model, target_layer)
136
 
137
  # ============================================================================
138
+ # Visualization Helpers
139
  # ============================================================================
140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
141
 
142
+ def _figure_to_pil():
143
+ buf = io.BytesIO()
144
+ plt.savefig(buf, format="png", dpi=150, bbox_inches="tight", facecolor="white")
145
+ plt.close()
146
+ buf.seek(0)
147
+ return Image.open(buf)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
 
 
149
 
150
  def create_original_display(image, pred_label, confidence):
151
+ fig, ax = plt.subplots(figsize=(7, 7))
 
152
  ax.imshow(image)
153
+ ax.axis("off")
154
 
 
155
  color = CLASS_COLORS[pred_label]
156
+ title = f"Prediction: {pred_label} • Confidence: {confidence:.1f}%"
157
+ ax.set_title(
158
+ title,
159
+ fontsize=16,
160
+ fontweight="bold",
161
+ color=color,
162
+ pad=20,
163
+ )
164
  plt.tight_layout()
165
+ return _figure_to_pil()
166
 
 
 
 
 
 
 
 
167
 
168
+ def create_gradcam_visualization(image, cam):
 
 
169
  img_array = np.array(image.resize((224, 224)))
170
  cam_resized = cv2.resize(cam, (224, 224))
171
 
 
172
  heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
173
  heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
174
 
175
+ fig, ax = plt.subplots(figsize=(7, 7))
176
  ax.imshow(heatmap)
177
+ ax.axis("off")
178
+ ax.set_title(
179
+ "Attention Heatmap\n(Where the model is looking)",
180
+ fontsize=14,
181
+ fontweight="bold",
182
+ pad=20,
183
+ )
184
  plt.tight_layout()
185
+ return _figure_to_pil()
186
 
 
 
 
 
 
 
187
 
188
  def create_overlay_visualization(image, cam):
 
189
  img_array = np.array(image.resize((224, 224))) / 255.0
190
  cam_resized = cv2.resize(cam, (224, 224))
191
 
 
192
  heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
193
  heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0
194
 
 
195
  overlay = img_array * 0.5 + heatmap * 0.5
196
  overlay = np.clip(overlay, 0, 1)
197
 
198
+ fig, ax = plt.subplots(figsize=(7, 7))
199
  ax.imshow(overlay)
200
+ ax.axis("off")
201
+ ax.set_title(
202
+ "Explainable AI Overlay\n(Anatomy + Attention)",
203
+ fontsize=14,
204
+ fontweight="bold",
205
+ pad=20,
206
+ )
207
  plt.tight_layout()
208
+ return _figure_to_pil()
209
 
 
 
 
 
210
 
211
+ # ============================================================================
212
+ # Interpretation
213
+ # ============================================================================
214
+
215
 
216
+ def create_interpretation(pred_label, confidence, results, audience="Clinician"):
217
+ header_note = {
218
+ "Clinician": "This view is tuned for **clinical decision support** (not a replacement for your judgement).",
219
+ "Researcher": "This view is tuned for **model behavior understanding** and experimental workflows.",
220
+ "Patient / Public": "This view is tuned for **patient-friendly language**. Always discuss results with a doctor.",
221
+ }.get(audience, "Use this output as a **screening aid**, not a final diagnosis.")
222
 
223
  interpretation = f"""
224
+ ## 🔬 Analysis Results ({audience} View)
225
+
226
+ > {header_note}
227
+
228
+ ### Primary Prediction: **{pred_label}**
229
  - Confidence: **{confidence:.1f}%**
230
+
231
+ ### Probability Breakdown
232
  - 🟢 Normal: **{results['Normal']:.1f}%**
233
  - 🔴 Tuberculosis: **{results['Tuberculosis']:.1f}%**
234
  - 🟠 Pneumonia: **{results['Pneumonia']:.1f}%**
235
  - 🟣 COVID-19: **{results['COVID-19']:.1f}%**
236
+
237
  ---
238
  """
239
 
240
+ # Disease-specific sections (same logic, slightly formatted)
241
+ if pred_label == "Tuberculosis":
242
  if confidence >= 85:
243
  interpretation += """
244
+ ### 🧫 TB Pattern – High Confidence
245
+
246
+ The model has detected features strongly suggestive of **pulmonary tuberculosis**.
247
+
248
+ **Recommended Clinical Pathway:**
249
+ 1. ✅ Immediate medical review by a clinician or chest physician
250
+ 2. **Sputum testing** (AFB smear, GeneXpert MTB/RIF, or TB-PCR)
251
+ 3. Correlate with symptoms:
252
+ - Persistent cough > 2 weeks
253
+ - Weight loss, night sweats
254
+ - Fever, fatigue
255
+ - Hemoptysis (coughing blood)
256
+ 4. Consider CT scan or additional imaging if uncertainty remains
257
+ 5. Infection control and contact tracing if TB is confirmed
258
+
259
+ > This tool helps *flag* suspicious cases. TB diagnosis still requires **laboratory confirmation**.
260
  """
261
  else:
262
  interpretation += """
263
+ ### 🧫 TB Pattern – Possible
264
+
265
+ The scan shows features that **could** be compatible with tuberculosis, but confidence is moderate.
266
+
267
+ **Suggested Actions:**
268
+ - Clinical review and detailed history
269
+ - Consider sputum testing if symptoms or risk factors are present
270
+ - Follow-up imaging where clinically indicated
271
  """
272
 
273
+ elif pred_label == "Pneumonia":
274
  if confidence >= 85:
275
  interpretation += """
276
+ ### 🌫 Pneumonia Pattern – High Confidence
277
+
278
+ The model has detected an opacity pattern consistent with **pneumonia**.
279
+
280
+ **Typical Clinical Correlates:**
281
+ - Fever, productive cough
282
+ - Shortness of breath
283
+ - Pleuritic chest pain
284
+
285
+ **Next Steps (for clinicians):**
286
+ - Correlate with fever, auscultation, and lab results
287
+ - Consider antibiotics for bacterial pneumonia as per local guidelines
288
+ - Repeat imaging if clinical evolution is atypical
 
 
 
289
  """
290
  else:
291
  interpretation += """
292
+ ### 🌫 Pneumonia Pattern – Possible
293
+
294
+ Findings may be compatible with pneumonia, but alternative explanations exist.
295
+
296
+ **Recommended:**
297
+ - Clinical evaluation (vital signs, exam)
298
+ - Consider labs (WBC, CRP, cultures)
299
+ - Watchful follow-up or repeat imaging as appropriate
300
  """
301
 
302
+ elif pred_label == "COVID-19":
303
  if confidence >= 85:
304
  interpretation += """
305
+ ### 🦠 COVID-19 Pattern – High Confidence
306
+
307
+ Distribution and appearance of opacities are compatible with **COVID-19 pneumonia**.
308
+
309
+ **Critical Points:**
310
+ - Imaging is **not** diagnostic by itself
311
+ - **RT-PCR / rapid antigen testing** is mandatory for confirmation
312
+
313
+ **If clinically suspected:**
314
+ - Isolate per local infection-control policies
315
+ - Monitor SpO₂ and respiratory status
316
+ - Escalate care if:
317
+ - SpO₂ < 94% on room air
318
+ - Increasing work of breathing
319
+ - Hemodynamic instability
 
 
 
 
320
  """
321
  else:
322
  interpretation += """
323
+ ### 🦠 COVID-19 Pattern – Possible
324
+
325
+ Some features may overlap with COVID-19, but there is **significant uncertainty**.
326
+
327
+ **Do not rely on imaging alone.**
328
+ - Obtain RT-PCR / rapid antigen testing
329
+ - Use clinical context and epidemiology to guide decisions
 
330
  """
331
 
332
  else: # Normal
333
  if confidence >= 85:
334
  interpretation += """
335
+ ### No Major Abnormality Detected
336
+
337
+ The model did **not** detect features suggestive of TB, pneumonia, or COVID-19.
338
+
339
+ **Important Caveats:**
340
+ - Early disease or small lesions may be missed
341
+ - Non-infective conditions (e.g., cancer, ILD) are **not** specifically evaluated
342
+ - If symptoms are present, further workup may still be required
 
 
 
 
 
 
 
 
343
  """
344
  else:
345
  interpretation += """
346
+ ### ℹ️ Likely Normal, But Low Confidence
347
+
348
+ The scan leans towards **normal**, but the model is not highly confident.
349
+
350
+ **If symptoms persist:**
351
+ - Consider follow-up imaging
352
+ - Seek a clinician’s interpretation
353
  """
354
 
355
+ # Universal disclaimer
356
  interpretation += """
357
  ---
358
  ## ⚠️ CRITICAL MEDICAL DISCLAIMER
359
+
360
+ - This AI model is a **screening / decision-support tool only**
361
+ - It is **not FDA-approved** and **must not** be used as a stand-alone diagnostic device
362
+ - Always integrate:
363
+ - Clinical history and examination
364
+ - Laboratory tests (e.g., sputum, PCR, cultures)
365
+ - Expert radiologist review
366
+
367
+ **Gold Standards:**
368
+ - TB: Sputum AFB / culture, GeneXpert MTB/RIF, TB-PCR
369
+ - Pneumonia: Clinical diagnosis + labs / microbiology
370
+ - COVID-19: RT-PCR or validated antigen tests
371
+
372
+ When in doubt, consult a qualified healthcare professional.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  """
374
 
375
+ interpretation += """
376
+ ---
377
+ 🫁 **Powered by Adaptive Sparse Training (AST)**
378
+ Energy-efficient deep learning – designed to make advanced chest X-ray screening more accessible.
379
+
380
+ **Links:**
381
+ - GitHub: https://github.com/oluwafemidiakhoa/Tuberculosis
382
+ - Hugging Face Space: https://huggingface.co/spaces/mgbam/Tuberculosis
383
+ """
384
  return interpretation
385
 
386
+
387
  # ============================================================================
388
+ # Prediction Pipeline
389
+ # ============================================================================
390
+
391
+
392
+ def predict_chest_xray(image, show_gradcam=True, audience="Clinician"):
393
+ """
394
+ Main inference function used by Gradio.
395
+ Returns:
396
+ - dict of class probabilities
397
+ - annotated original
398
+ - grad-cam heatmap
399
+ - overlay
400
+ - full markdown report
401
+ - short textual snapshot
402
+ """
403
+ if image is None:
404
+ msg = "👋 Upload a chest X-ray (PNG/JPG) and click **Analyze** to generate a full AI report."
405
+ return {}, None, None, None, msg, "Awaiting image upload…"
406
+
407
+ if isinstance(image, np.ndarray):
408
+ image = Image.fromarray(image).convert("RGB")
409
+ else:
410
+ image = image.convert("RGB")
411
+
412
+ original_img = image.copy()
413
+ input_tensor = transform(image).unsqueeze(0).to(device)
414
+
415
+ # Inference with optional Grad-CAM
416
+ with torch.set_grad_enabled(show_gradcam):
417
+ if show_gradcam:
418
+ cam, output = grad_cam.generate(input_tensor)
419
+ else:
420
+ output = model(input_tensor)
421
+ cam = None
422
+
423
+ probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()
424
+ prob_sum = float(np.sum(probs))
425
+
426
+ if not (0.99 <= prob_sum <= 1.01):
427
+ print(f"⚠️ WARNING: Probability sum is {prob_sum}, not ≈1.0 – check model weights.")
428
+
429
+ pred_class = int(output.argmax(dim=1).item())
430
+ pred_label = CLASSES[pred_class]
431
+ confidence = float(probs[pred_class]) * 100.0
432
+
433
+ results = {
434
+ CLASSES[i]: float(min(100.0, max(0.0, probs[i] * 100.0)))
435
+ for i in range(len(CLASSES))
436
+ }
437
+
438
+ # Visuals
439
+ original_pil = create_original_display(original_img, pred_label, confidence)
440
+ gradcam_viz = create_gradcam_visualization(original_img, cam) if cam is not None else None
441
+ overlay_viz = create_overlay_visualization(original_img, cam) if cam is not None else None
442
+
443
+ interpretation = create_interpretation(pred_label, confidence, results, audience=audience)
444
+
445
+ snapshot = f"**{pred_label}** · {confidence:.1f}% confidence • Sum of probabilities: {prob_sum:.3f}"
446
+
447
+ return results, original_pil, gradcam_viz, overlay_viz, interpretation, snapshot
448
+
449
+
450
+ # ============================================================================
451
+ # WOW UI / UX – Gradio App
452
  # ============================================================================
453
 
 
454
  custom_css = """
455
+ :root {
456
+ --primary: #6366f1;
457
+ --primary-soft: rgba(99, 102, 241, 0.12);
458
+ --accent: #ec4899;
459
  }
460
+
461
+ .gradio-container {
462
+ font-family: system-ui, -apple-system, BlinkMacSystemFont, "Inter", sans-serif;
463
+ background: radial-gradient(circle at top left, #111827 0, #020617 50%, #020617 100%);
464
+ color: #e5e7eb;
 
 
465
  }
466
+
467
+ #hero {
468
+ padding: 24px 24px 8px 24px;
469
+ border-radius: 24px;
470
+ background: linear-gradient(120deg, rgba(99,102,241,0.18), rgba(236,72,153,0.14));
471
+ border: 1px solid rgba(148, 163, 184, 0.4);
472
+ box-shadow: 0 24px 60px rgba(15,23,42,0.85);
473
+ backdrop-filter: blur(18px);
474
  }
475
+
476
+ .hero-title {
477
+ font-size: 2.4rem;
478
+ font-weight: 800;
479
+ letter-spacing: 0.04em;
480
+ color: #f9fafb;
481
+ margin-bottom: 6px;
 
 
482
  }
483
+
484
+ .hero-subtitle {
485
+ font-size: 0.98rem;
486
+ color: #e5e7eb;
487
+ }
488
+
489
+ .hero-chip-row {
490
+ display: flex;
491
+ flex-wrap: wrap;
492
+ gap: 8px;
493
+ margin-top: 14px;
494
+ }
495
+
496
+ .hero-chip {
497
+ padding: 4px 10px;
498
+ border-radius: 999px;
499
+ font-size: 0.78rem;
500
+ background: rgba(15,23,42,0.8);
501
+ border: 1px solid rgba(148,163,184,0.5);
502
+ display: inline-flex;
503
+ align-items: center;
504
+ gap: 6px;
505
+ color: #e5e7eb;
506
+ }
507
+
508
+ .pulse-dot {
509
+ width: 8px;
510
+ height: 8px;
511
+ border-radius: 999px;
512
+ background: #22c55e;
513
+ box-shadow: 0 0 0 0 rgba(34,197,94,0.7);
514
+ animation: pulse 1.4s infinite;
515
+ }
516
+
517
+ @keyframes pulse {
518
+ 0% { box-shadow: 0 0 0 0 rgba(34,197,94,0.7); }
519
+ 70% { box-shadow: 0 0 0 10px rgba(34,197,94,0); }
520
+ 100% { box-shadow: 0 0 0 0 rgba(34,197,94,0); }
521
+ }
522
+
523
+ .glass-card {
524
+ background: rgba(15,23,42,0.82);
525
+ border-radius: 18px;
526
+ border: 1px solid rgba(148,163,184,0.4);
527
+ box-shadow: 0 18px 40px rgba(15,23,42,0.85);
528
+ padding: 18px;
529
+ backdrop-filter: blur(16px);
530
  }
531
+
532
+ .glass-card-light {
533
+ background: rgba(15,23,42,0.65);
534
+ border-radius: 18px;
535
+ border: 1px solid rgba(148,163,184,0.3);
536
+ box-shadow: 0 12px 24px rgba(15,23,42,0.85);
537
+ padding: 16px;
538
+ backdrop-filter: blur(12px);
539
+ }
540
+
541
+ .stat-pill {
542
+ padding: 10px 12px;
543
+ border-radius: 14px;
544
+ background: rgba(15,23,42,0.9);
545
+ border: 1px solid rgba(148,163,184,0.5);
546
+ font-size: 0.78rem;
547
+ display: flex;
548
+ flex-direction: column;
549
+ gap: 2px;
550
  }
551
+
552
+ .stat-pill-label {
553
+ color: #9ca3af;
554
+ text-transform: uppercase;
555
+ font-size: 0.68rem;
556
  }
557
+
558
+ .stat-pill-value {
559
+ color: #e5e7eb;
560
+ font-weight: 600;
561
+ }
562
+
563
+ .dropzone-image img {
564
+ border-radius: 16px !important;
565
  }
566
+
567
+ .output-image img {
568
+ border-radius: 16px !important;
569
+ }
570
+
571
  footer {
572
  text-align: center;
573
+ margin-top: 24px;
574
+ color: #9ca3af;
575
+ font-size: 0.78rem;
576
  }
577
  """
578
 
579
+ theme = gr.themes.Soft(
580
+ primary_hue="indigo",
581
+ secondary_hue="pink",
582
+ neutral_hue="slate",
583
+ ).set(
584
+ button_primary_background_fill="linear-gradient(135deg,#4f46e5,#ec4899)",
585
+ button_primary_background_fill_hover="linear-gradient(135deg,#6366f1,#f97316)",
586
+ )
587
+
588
+ with gr.Blocks(css=custom_css, theme=theme) as demo:
589
+ # HERO
590
+ gr.HTML(
591
+ f"""
592
+ <div id="hero">
593
+ <div style="display:flex;justify-content:space-between;gap:16px;align-items:flex-start;">
594
+ <div>
595
+ <div class="hero-title">🫁 AST Chest X-Ray Lab</div>
596
+ <div class="hero-subtitle">
597
+ Multi-class chest X-ray analysis with <b>Explainable AI</b> and
598
+ <b>Adaptive Sparse Training</b>.
599
+ Designed for TB, Pneumonia, COVID-19 and Normal scans.
600
+ </div>
601
+ <div class="hero-chip-row">
602
+ <div class="hero-chip">
603
+ <span class="pulse-dot"></span>
604
+ Live Inference
605
+ </div>
606
+ <div class="hero-chip">
607
+ 4-class EfficientNet · ~{TOTAL_PARAMS_M:.1f}M params
608
+ </div>
609
+ <div class="hero-chip">
610
+ 95–97% validation accuracy · ~89% energy savings
611
+ </div>
612
+ <div class="hero-chip">
613
+ {MODEL_LOAD_INFO}
614
+ </div>
615
+ </div>
616
+ </div>
617
+ <div style="min-width:210px;display:flex;flex-direction:column;gap:8px;">
618
+ <div class="stat-pill">
619
+ <div class="stat-pill-label">Device</div>
620
+ <div class="stat-pill-value">{device.type.upper()}</div>
621
+ </div>
622
+ <div class="stat-pill">
623
+ <div class="stat-pill-label">Model</div>
624
+ <div class="stat-pill-value">EfficientNet-B0 · 4-way classifier</div>
625
+ </div>
626
+ </div>
627
  </div>
628
  </div>
629
+ """
630
+ )
631
+
632
+ gr.Markdown(" ")
633
+
634
+ with gr.Row(equal_height=True):
635
+ # ----------------------------------
636
+ # LEFT: INPUT PANEL
637
+ # ----------------------------------
638
+ with gr.Column(scale=1, elem_classes="glass-card"):
639
+ gr.Markdown("### 1️⃣ Upload & Configure")
640
 
 
 
 
641
  image_input = gr.Image(
642
  type="pil",
643
+ label="Drop a chest X-ray here",
644
+ elem_classes=["dropzone-image"],
645
  )
646
 
647
+ with gr.Row():
648
+ show_gradcam = gr.Checkbox(
649
+ value=True,
650
+ label="Explainable AI (Grad-CAM)",
651
+ info="Highlight regions that drive the prediction",
652
+ )
653
+ audience_select = gr.Radio(
654
+ ["Clinician", "Researcher", "Patient / Public"],
655
+ value="Clinician",
656
+ label="Report Style",
657
+ )
658
 
659
+ with gr.Row():
660
+ analyze_btn = gr.Button("🔬 Analyze X-Ray", variant="primary", scale=3)
661
+ clear_btn = gr.Button("🧹 Reset", variant="secondary")
662
+
663
+ gr.Markdown(
664
+ """
665
+ **Tips**
666
+
667
+ - Use frontal (PA/AP) chest X-rays in PNG / JPG format
668
+ - This tool is best used as a **triage / screening assistant**
669
+ - For noisy images or rotated scans, consider preprocessing before upload
670
+ """
671
  )
672
 
673
+ # ----------------------------------
674
+ # RIGHT: RESULTS PANEL
675
+ # ----------------------------------
676
+ with gr.Column(scale=2, elem_classes="glass-card-light"):
677
+ gr.Markdown("### 2️⃣ AI Dashboard")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
678
 
679
  with gr.Tabs():
680
+ with gr.Tab("Snapshot"):
681
+ snapshot_output = gr.Markdown(
682
+ "No scan analyzed yet. Upload an X-ray to get started."
 
683
  )
684
+ prob_output = gr.Label(
685
+ label="Prediction Confidence (All Classes)",
686
+ num_top_classes=4,
 
 
687
  )
688
 
689
+ with gr.Tab("Visual Explanations"):
690
+ with gr.Row():
691
+ original_output = gr.Image(
692
+ label="Annotated X-ray",
693
+ elem_classes=["output-image"],
694
+ )
695
+ gradcam_output = gr.Image(
696
+ label="Attention Heatmap",
697
+ elem_classes=["output-image"],
698
+ )
699
  overlay_output = gr.Image(
700
+ label="Explainable Overlay",
701
+ elem_classes=["output-image"],
702
  )
703
 
704
+ with gr.Tab("Full Report"):
705
+ interpretation_output = gr.Markdown(
706
+ "The full clinical / research report will appear here after inference."
707
+ )
708
+
709
+ with gr.Tab("Model Card"):
710
+ gr.Markdown(
711
+ f"""
712
+ ### 🧠 Model Card – AST Chest X-Ray
713
+
714
+ - **Backbone**: EfficientNet-B0
715
+ - **Task**: 4-way classification (Normal, Tuberculosis, Pneumonia, COVID-19)
716
+ - **Optimization**: Sample-based Adaptive Sparse Training (AST)
717
+ - **Energy Profile**: ~89% training energy reduction vs dense baseline
718
+
719
+ **Design Goals**
720
+
721
+ 1. Provide **fast, explainable triage** support for TB & pneumonia
722
+ 2. Maintain **high specificity**, especially differentiating TB from pneumonia
723
+ 3. Be lightweight enough for **deployment in resource-constrained settings**
724
+
725
+ > This model is a research prototype. Do **not** use it as a stand-alone clinical device.
726
+ """
727
+ )
728
+
729
+ gr.Markdown("---")
730
+
731
+ gr.HTML(
732
+ """
733
+ <footer>
734
+ <p>
735
+ <b>AST Chest X-Ray Lab</b> · Normal · TB · Pneumonia · COVID-19 · Explainable AI<br/>
736
+ Built for research, education, and early-stage screening support.
737
+ </p>
738
+ <p style="margin-top:6px;">
739
+ ⚠️ <b>MEDICAL DISCLAIMER:</b> This tool is not FDA-approved and cannot replace a clinician
740
+ or radiologist. All decisions must be made by qualified healthcare professionals.
741
+ </p>
742
+ </footer>
743
+ """
744
+ )
745
+
746
+ # ----------------------------------------------------------------------
747
+ # Wiring
748
+ # ----------------------------------------------------------------------
749
+ analyze_btn.click(
750
+ fn=predict_chest_xray,
751
+ inputs=[image_input, show_gradcam, audience_select],
752
+ outputs=[
753
+ prob_output,
754
+ original_output,
755
+ gradcam_output,
756
+ overlay_output,
757
+ interpretation_output,
758
+ snapshot_output,
759
+ ],
760
+ )
761
+
762
+ clear_btn.click(
763
+ fn=lambda: ({}, None, None, None, "Awaiting image upload…", "Awaiting image upload…"),
764
+ inputs=None,
765
+ outputs=[
766
+ prob_output,
767
+ original_output,
768
+ gradcam_output,
769
+ overlay_output,
770
+ interpretation_output,
771
+ snapshot_output,
772
+ ],
773
+ )
774
 
775
+ # Example X-rays section (optional – remove if you don't have these paths)
776
+ gr.Markdown("### 🔍 Try Example X-rays")
777
  gr.Examples(
778
  examples=[
779
  ["examples/normal.png"],
 
782
  ["examples/covid.png"],
783
  ],
784
  inputs=image_input,
 
785
  )
786
 
787
+ # ============================================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
788
  # Launch
789
+ # ============================================================================
790
+
791
  if __name__ == "__main__":
792
  demo.launch(
793
  share=False,
794
  server_name="0.0.0.0",
795
  server_port=7860,
796
+ show_error=True,
797
  )