mgbam commited on
Commit
2a34f6f
·
verified ·
1 Parent(s): 30bd152

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -875
app.py DELETED
@@ -1,875 +0,0 @@
1
- """
2
- 🫁 AST Chest X-Ray Lab
3
- Multi-Class Chest X-Ray Detection (Normal · TB · Pneumonia · COVID-19)
4
- with Adaptive Sparse Training & Explainable AI (Grad-CAM)
5
-
6
- This app is a research / screening tool – not a diagnostic device.
7
- """
8
-
9
- import io
10
- from pathlib import Path
11
-
12
- import cv2
13
- import gradio as gr
14
- import matplotlib
15
- matplotlib.use("Agg") # non-interactive backend for servers
16
- import matplotlib.pyplot as plt
17
- import numpy as np
18
- import torch
19
- import torch.nn as nn
20
- from PIL import Image
21
- from torchvision import models, transforms
22
-
23
- # ============================================================================
24
- # Model Setup
25
- # ============================================================================
26
-
27
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
-
29
- NUM_CLASSES = 4
30
-
31
- # Backbone: EfficientNet-B0 with 4-class head
32
- model = models.efficientnet_b0(weights=None)
33
- model.classifier[1] = nn.Linear(model.classifier[1].in_features, NUM_CLASSES)
34
-
35
- # Where we expect the (4-class) checkpoint to live
36
- checkpoint_candidates = [
37
- "checkpoints/best.pt", # main location (from your HF screenshot)
38
- "best.pt", # optional fallback in root
39
- ]
40
-
41
- MODEL_LOAD_INFO = ""
42
- loaded = False
43
-
44
-
45
- def extract_state_dict(ckpt):
46
- """
47
- Handle both:
48
- - plain state_dict (just param tensors)
49
- - training checkpoints: keys like 'model_state_dict', 'state_dict', 'model', etc.
50
- """
51
- if isinstance(ckpt, dict):
52
- for key in ["model_state_dict", "state_dict", "model"]:
53
- if key in ckpt and isinstance(ckpt[key], dict):
54
- return ckpt[key]
55
- return ckpt
56
-
57
-
58
- for ckpt_path in checkpoint_candidates:
59
- if Path(ckpt_path).is_file():
60
- try:
61
- print(f"🔍 Trying to load weights from: {ckpt_path}")
62
- raw_ckpt = torch.load(ckpt_path, map_location=device)
63
- state_dict = extract_state_dict(raw_ckpt)
64
-
65
- # Sanity check: classifier head must be 4-way
66
- if "classifier.1.weight" in state_dict:
67
- out_features = state_dict["classifier.1.weight"].shape[0]
68
- if out_features != NUM_CLASSES:
69
- raise ValueError(
70
- f"Checkpoint at {ckpt_path} has {out_features} output "
71
- f"classes, but this app expects {NUM_CLASSES}."
72
- )
73
-
74
- model.load_state_dict(state_dict, strict=True)
75
-
76
- MODEL_LOAD_INFO = (
77
- f"✅ Model loaded from <b>{ckpt_path}</b> on <b>{device.type.upper()}</b>."
78
- )
79
- loaded = True
80
- break
81
- except Exception as e:
82
- print(f"⚠️ Found {ckpt_path} but failed to load model_state_dict: {e}")
83
-
84
- if not loaded:
85
- raise RuntimeError(
86
- "Model file not found or could not be loaded.\n"
87
- "Expected a 4-class EfficientNet checkpoint at 'checkpoints/best.pt' or 'best.pt'.\n"
88
- "If you saved a training checkpoint, make sure it contains "
89
- "a 'model_state_dict' key with the 4-class EfficientNet weights."
90
- )
91
-
92
- model = model.to(device)
93
- model.eval()
94
-
95
- TOTAL_PARAMS = sum(p.numel() for p in model.parameters())
96
- TOTAL_PARAMS_M = TOTAL_PARAMS / 1e6
97
-
98
- # ============================================================================
99
- # Classes & Preprocessing
100
- # ============================================================================
101
-
102
- CLASSES = ["Normal", "Tuberculosis", "Pneumonia", "COVID-19"]
103
- CLASS_COLORS = {
104
- "Normal": "#22c55e", # Green
105
- "Tuberculosis": "#ef4444", # Red
106
- "Pneumonia": "#f97316", # Orange
107
- "COVID-19": "#a855f7", # Purple
108
- }
109
-
110
- transform = transforms.Compose(
111
- [
112
- transforms.Resize(256),
113
- transforms.CenterCrop(224),
114
- transforms.ToTensor(),
115
- transforms.Normalize(
116
- [0.485, 0.456, 0.406],
117
- [0.229, 0.224, 0.225],
118
- ),
119
- ]
120
- )
121
-
122
- # ============================================================================
123
- # Grad-CAM Implementation
124
- # ============================================================================
125
-
126
-
127
- class GradCAM:
128
- def __init__(self, model, target_layer):
129
- self.model = model
130
- self.target_layer = target_layer
131
- self.gradients = None
132
- self.activations = None
133
-
134
- def save_gradient(grad):
135
- self.gradients = grad
136
-
137
- def save_activation(module, input, output):
138
- self.activations = output.detach()
139
-
140
- target_layer.register_forward_hook(save_activation)
141
- target_layer.register_full_backward_hook(lambda m, gi, go: save_gradient(go[0]))
142
-
143
- def generate(self, input_image, target_class=None):
144
- output = self.model(input_image)
145
-
146
- if target_class is None:
147
- target_class = output.argmax(dim=1)
148
-
149
- self.model.zero_grad()
150
- one_hot = torch.zeros_like(output)
151
- one_hot[0, target_class] = 1
152
- output.backward(gradient=one_hot, retain_graph=True)
153
-
154
- if self.gradients is None:
155
- return None, output
156
-
157
- weights = self.gradients.mean(dim=(2, 3), keepdim=True)
158
- cam = (weights * self.activations).sum(dim=1, keepdim=True)
159
- cam = torch.relu(cam)
160
- cam = cam.squeeze().cpu().numpy()
161
- cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
162
-
163
- return cam, output
164
-
165
-
166
- target_layer = model.features[-1]
167
- grad_cam = GradCAM(model, target_layer)
168
-
169
- # ============================================================================
170
- # Visualization Helpers
171
- # ============================================================================
172
-
173
-
174
- def _figure_to_pil():
175
- buf = io.BytesIO()
176
- plt.savefig(buf, format="png", dpi=150, bbox_inches="tight", facecolor="white")
177
- plt.close()
178
- buf.seek(0)
179
- return Image.open(buf)
180
-
181
-
182
- def create_original_display(image, pred_label, confidence):
183
- fig, ax = plt.subplots(figsize=(7, 7))
184
- ax.imshow(image)
185
- ax.axis("off")
186
-
187
- color = CLASS_COLORS[pred_label]
188
- title = f"Prediction: {pred_label} • Confidence: {confidence:.1f}%"
189
- ax.set_title(
190
- title,
191
- fontsize=16,
192
- fontweight="bold",
193
- color=color,
194
- pad=20,
195
- )
196
- plt.tight_layout()
197
- return _figure_to_pil()
198
-
199
-
200
- def create_gradcam_visualization(image, cam):
201
- img_array = np.array(image.resize((224, 224)))
202
- cam_resized = cv2.resize(cam, (224, 224))
203
-
204
- heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
205
- heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
206
-
207
- fig, ax = plt.subplots(figsize=(7, 7))
208
- ax.imshow(heatmap)
209
- ax.axis("off")
210
- ax.set_title(
211
- "Attention Heatmap\n(Where the model is focusing)",
212
- fontsize=14,
213
- fontweight="bold",
214
- pad=20,
215
- )
216
- plt.tight_layout()
217
- return _figure_to_pil()
218
-
219
-
220
- def create_overlay_visualization(image, cam):
221
- img_array = np.array(image.resize((224, 224))) / 255.0
222
- cam_resized = cv2.resize(cam, (224, 224))
223
-
224
- heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
225
- heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0
226
-
227
- overlay = img_array * 0.5 + heatmap * 0.5
228
- overlay = np.clip(overlay, 0, 1)
229
-
230
- fig, ax = plt.subplots(figsize=(7, 7))
231
- ax.imshow(overlay)
232
- ax.axis("off")
233
- ax.set_title(
234
- "Explainable AI Overlay\n(Anatomy + Model Attention)",
235
- fontsize=14,
236
- fontweight="bold",
237
- pad=20,
238
- )
239
- plt.tight_layout()
240
- return _figure_to_pil()
241
-
242
-
243
- def create_probability_bar(results):
244
- """Horizontal bar chart of 4-class probabilities."""
245
- classes = list(results.keys())
246
- values = [results[c] for c in classes]
247
- y_pos = np.arange(len(classes))
248
-
249
- fig, ax = plt.subplots(figsize=(6.4, 3.5))
250
- ax.barh(y_pos, values)
251
- ax.set_yticks(y_pos)
252
- ax.set_yticklabels(classes)
253
- ax.invert_yaxis()
254
- ax.set_xlim(0, 100)
255
- ax.set_xlabel("Probability (%)")
256
- ax.set_title("Probability Profile by Class", fontsize=12, fontweight="bold")
257
- for i, v in enumerate(values):
258
- ax.text(v + 1, i, f"{v:.1f}%", va="center", fontsize=9)
259
- plt.tight_layout()
260
- return _figure_to_pil()
261
-
262
- # ============================================================================
263
- # Interpretation
264
- # ============================================================================
265
-
266
-
267
- def triage_label(pred_label, confidence):
268
- """
269
- Simple triage categorisation for clinicians / dashboards.
270
- """
271
- high = confidence >= 85
272
- moderate = 65 <= confidence < 85
273
-
274
- if pred_label == "Normal":
275
- if high:
276
- return "🟢 Low risk – no major abnormality detected (model view)"
277
- elif moderate:
278
- return "🟡 Likely normal, but low confidence – consider clinical context"
279
- else:
280
- return "🟡 Indeterminate – imaging looks close to normal, but model is uncertain"
281
- else:
282
- if high:
283
- return "🔴 High risk finding – prioritise expert review"
284
- elif moderate:
285
- return "🟠 Possible pathology – correlate with symptoms and labs"
286
- else:
287
- return "🟡 Weak signal – treat as a soft flag, not a diagnosis"
288
-
289
-
290
- def create_interpretation(pred_label, confidence, results, audience="Clinician"):
291
- header_note = {
292
- "Clinician": "Optimised for **clinical decision support** – not a replacement for your judgement.",
293
- "Researcher": "Optimised for **model behaviour analysis** and experimental workflows.",
294
- "Patient / Public": "Optimised for **patient-friendly language**. Always discuss results with a doctor.",
295
- }.get(audience, "Use this output as a **screening aid**, not a final diagnosis.")
296
-
297
- interpretation = f"""
298
- ## 🔬 Analysis Results ({audience} View)
299
-
300
- > {header_note}
301
-
302
- ### Primary Prediction: **{pred_label}**
303
- - Confidence: **{confidence:.1f}%**
304
- - Triage comment: {triage_label(pred_label, confidence)}
305
-
306
- ### Probability Breakdown
307
- - 🟢 Normal: **{results['Normal']:.1f}%**
308
- - 🔴 Tuberculosis: **{results['Tuberculosis']:.1f}%**
309
- - 🟠 Pneumonia: **{results['Pneumonia']:.1f}%**
310
- - 🟣 COVID-19: **{results['COVID-19']:.1f}%**
311
-
312
- ---
313
- """
314
-
315
- # Disease-specific narrative
316
- if pred_label == "Tuberculosis":
317
- if confidence >= 85:
318
- interpretation += """
319
- ### 🧫 TB Pattern – High Confidence
320
-
321
- The model has detected features strongly suggestive of **pulmonary tuberculosis**.
322
-
323
- **Suggested Clinical Pathway**
324
- 1. Prompt review by a clinician / chest physician
325
- 2. Sputum testing (AFB smear, GeneXpert MTB/RIF, or TB-PCR)
326
- 3. Correlate with symptoms:
327
- - Chronic cough (>2 weeks)
328
- - Weight loss, night sweats
329
- - Fever, fatigue
330
- - Haemoptysis
331
- 4. Consider CT or further imaging if discordant with clinical picture
332
- 5. Infection control and contact tracing as per TB guidelines
333
- """
334
- else:
335
- interpretation += """
336
- ### 🧫 TB Pattern – Possible
337
-
338
- There are features that **could** be compatible with TB, but the confidence is moderate.
339
-
340
- - Review history and risk factors
341
- - Consider sputum testing if suspicion is non-trivial
342
- - Follow-up imaging where indicated
343
- """
344
-
345
- elif pred_label == "Pneumonia":
346
- if confidence >= 85:
347
- interpretation += """
348
- ### 🌫 Pneumonia Pattern – High Confidence
349
-
350
- The model has detected an opacity pattern consistent with **pneumonia**.
351
-
352
- Typical clinical picture may include:
353
-
354
- - Fever, productive cough
355
- - Shortness of breath
356
- - Pleuritic chest pain
357
-
358
- Use in combination with examination, labs (WBC, CRP, cultures) and local treatment guidelines.
359
- """
360
- else:
361
- interpretation += """
362
- ### 🌫 Pneumonia Pattern – Possible
363
-
364
- Findings may be compatible with pneumonia, but alternative explanations exist.
365
-
366
- - Check vitals and respiratory exam
367
- - Labs and microbiology can support or refute the impression
368
- - Consider watchful follow-up or repeat imaging
369
- """
370
-
371
- elif pred_label == "COVID-19":
372
- if confidence >= 85:
373
- interpretation += """
374
- ### 🦠 COVID-19 Pattern – High Confidence
375
-
376
- The distribution and appearance of opacities are compatible with **COVID-19 pneumonia**.
377
-
378
- ⚠️ Imaging alone is **not diagnostic**. Key points:
379
-
380
- - Confirmation requires RT-PCR or validated antigen testing
381
- - Follow local isolation and infection-control policies
382
- - Monitor SpO₂ and work of breathing; escalate care if deterioration occurs
383
- """
384
- else:
385
- interpretation += """
386
- ### 🦠 COVID-19 Pattern – Possible
387
-
388
- There are features that could overlap with COVID-19, but uncertainty is substantial.
389
-
390
- - Testing (RT-PCR / antigen) is essential
391
- - Integrate exposure history, symptoms, and public health guidance
392
- """
393
-
394
- else: # Normal
395
- if confidence >= 85:
396
- interpretation += """
397
- ### ✅ No Major Abnormality Detected (Model View)
398
-
399
- The model did **not** detect strong features of TB, pneumonia, or COVID-19.
400
-
401
- **Important caveats**
402
-
403
- - Early disease or small lesions may be missed
404
- - Non-infective conditions (e.g. malignancy, ILD) are **not** specifically evaluated
405
- - Persistent or unexplained symptoms still require clinical review
406
- """
407
- else:
408
- interpretation += """
409
- ### ℹ️ Likely Normal, But Low Confidence
410
-
411
- The scan leans towards **normal**, but the model's confidence is limited.
412
-
413
- - Consider repeat imaging, further tests, or expert review if symptoms persist
414
- """
415
-
416
- interpretation += """
417
- ---
418
- ## ⚠️ CRITICAL MEDICAL DISCLAIMER
419
-
420
- - This AI system is a **screening / decision-support tool** only
421
- - It is **not FDA-approved**, CE-marked, or licensed as a medical device
422
- - It must **not** be used as a stand-alone diagnostic system
423
-
424
- Always integrate:
425
- - Clinical history and examination
426
- - Laboratory tests (e.g. sputum AFB / GeneXpert, PCR, cultures)
427
- - Radiologist / specialist interpretation
428
-
429
- **Gold Standards**
430
-
431
- - Tuberculosis: Sputum AFB / culture, GeneXpert MTB/RIF, TB-PCR
432
- - Pneumonia: Clinical diagnosis + labs / microbiology
433
- - COVID-19: RT-PCR or validated antigen tests
434
-
435
- When in doubt, consult a qualified healthcare professional.
436
- ---
437
- 🫁 **Powered by Adaptive Sparse Training (AST)**
438
- Energy-efficient deep learning to support lung health in both high-resource and low-resource settings.
439
-
440
- **Project Links**
441
-
442
- - GitHub: https://github.com/oluwafemidiakhoa/Tuberculosis
443
- - Hugging Face Space: https://huggingface.co/spaces/mgbam/Tuberculosis
444
- """
445
- return interpretation
446
-
447
- # ============================================================================
448
- # Prediction Pipeline
449
- # ============================================================================
450
-
451
-
452
- def predict_chest_xray(image, show_gradcam=True, audience="Clinician"):
453
- """
454
- Main inference function used by Gradio.
455
- Returns:
456
- - dict of class probabilities
457
- - annotated original
458
- - grad-cam heatmap
459
- - overlay
460
- - full markdown report
461
- - short textual snapshot
462
- - probability bar-chart image
463
- """
464
- if image is None:
465
- msg = "👋 Upload a chest X-ray (PNG/JPG) and click **Analyze** to generate a full AI report."
466
- return {}, None, None, None, msg, "Awaiting image upload…", None
467
-
468
- if isinstance(image, np.ndarray):
469
- image = Image.fromarray(image).convert("RGB")
470
- else:
471
- image = image.convert("RGB")
472
-
473
- original_img = image.copy()
474
- input_tensor = transform(image).unsqueeze(0).to(device)
475
-
476
- with torch.set_grad_enabled(show_gradcam):
477
- if show_gradcam:
478
- cam, output = grad_cam.generate(input_tensor)
479
- else:
480
- output = model(input_tensor)
481
- cam = None
482
-
483
- probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()
484
- prob_sum = float(np.sum(probs))
485
-
486
- if not (0.99 <= prob_sum <= 1.01):
487
- print(f"⚠️ WARNING: Probability sum is {prob_sum}, not ≈1.0 – check model weights.")
488
-
489
- pred_class = int(output.argmax(dim=1).item())
490
- pred_label = CLASSES[pred_class]
491
- confidence = float(probs[pred_class]) * 100.0
492
-
493
- results = {
494
- CLASSES[i]: float(min(100.0, max(0.0, probs[i] * 100.0)))
495
- for i in range(len(CLASSES))
496
- }
497
-
498
- original_pil = create_original_display(original_img, pred_label, confidence)
499
- gradcam_viz = create_gradcam_visualization(original_img, cam) if cam is not None else None
500
- overlay_viz = create_overlay_visualization(original_img, cam) if cam is not None else None
501
- prob_chart = create_probability_bar(results)
502
-
503
- interpretation = create_interpretation(pred_label, confidence, results, audience=audience)
504
-
505
- snapshot = (
506
- f"### 📝 Triage Snapshot\n\n"
507
- f"- **Finding:** {pred_label}\n"
508
- f"- **Model confidence:** {confidence:.1f}%\n"
509
- f"- **Triage comment:** {triage_label(pred_label, confidence)}\n"
510
- f"- **Probability sum (sanity check):** {prob_sum:.3f}"
511
- )
512
-
513
- return results, original_pil, gradcam_viz, overlay_viz, interpretation, snapshot, prob_chart
514
-
515
- # ============================================================================
516
- # WOW UI / UX – Gradio App
517
- # ============================================================================
518
-
519
- custom_css = """
520
- :root {
521
- --primary: #6366f1;
522
- --primary-soft: rgba(99, 102, 241, 0.12);
523
- --accent: #ec4899;
524
- }
525
-
526
- .gradio-container {
527
- font-family: system-ui, -apple-system, BlinkMacSystemFont, "Inter", sans-serif;
528
- background: radial-gradient(circle at top left, #111827 0, #020617 50%, #020617 100%);
529
- color: #e5e7eb;
530
- }
531
-
532
- #hero {
533
- padding: 24px 24px 8px 24px;
534
- border-radius: 24px;
535
- background: linear-gradient(120deg, rgba(99,102,241,0.22), rgba(236,72,153,0.18));
536
- border: 1px solid rgba(148, 163, 184, 0.45);
537
- box-shadow: 0 24px 60px rgba(15,23,42,0.9);
538
- backdrop-filter: blur(18px);
539
- }
540
-
541
- .hero-title {
542
- font-size: 2.4rem;
543
- font-weight: 800;
544
- letter-spacing: 0.04em;
545
- color: #f9fafb;
546
- margin-bottom: 6px;
547
- }
548
-
549
- .hero-subtitle {
550
- font-size: 0.98rem;
551
- color: #e5e7eb;
552
- }
553
-
554
- .hero-chip-row {
555
- display: flex;
556
- flex-wrap: wrap;
557
- gap: 8px;
558
- margin-top: 14px;
559
- }
560
-
561
- .hero-chip {
562
- padding: 4px 10px;
563
- border-radius: 999px;
564
- font-size: 0.78rem;
565
- background: rgba(15,23,42,0.8);
566
- border: 1px solid rgba(148,163,184,0.5);
567
- display: inline-flex;
568
- align-items: center;
569
- gap: 6px;
570
- color: #e5e7eb;
571
- }
572
-
573
- .pulse-dot {
574
- width: 8px;
575
- height: 8px;
576
- border-radius: 999px;
577
- background: #22c55e;
578
- box-shadow: 0 0 0 0 rgba(34,197,94,0.7);
579
- animation: pulse 1.4s infinite;
580
- }
581
-
582
- @keyframes pulse {
583
- 0% { box-shadow: 0 0 0 0 rgba(34,197,94,0.7); }
584
- 70% { box-shadow: 0 0 0 10px rgba(34,197,94,0); }
585
- 100% { box-shadow: 0 0 0 0 rgba(34,197,94,0); }
586
- }
587
-
588
- .glass-card {
589
- background: rgba(15,23,42,0.86);
590
- border-radius: 18px;
591
- border: 1px solid rgba(148,163,184,0.4);
592
- box-shadow: 0 18px 40px rgba(15,23,42,0.9);
593
- padding: 18px;
594
- backdrop-filter: blur(16px);
595
- }
596
-
597
- .glass-card-light {
598
- background: rgba(15,23,42,0.7);
599
- border-radius: 18px;
600
- border: 1px solid rgba(148,163,184,0.3);
601
- box-shadow: 0 12px 30px rgba(15,23,42,0.9);
602
- padding: 16px;
603
- backdrop-filter: blur(12px);
604
- }
605
-
606
- .stat-pill {
607
- padding: 10px 12px;
608
- border-radius: 14px;
609
- background: rgba(15,23,42,0.9);
610
- border: 1px solid rgba(148,163,184,0.5);
611
- font-size: 0.78rem;
612
- display: flex;
613
- flex-direction: column;
614
- gap: 2px;
615
- }
616
-
617
- .stat-pill-label {
618
- color: #9ca3af;
619
- text-transform: uppercase;
620
- font-size: 0.68rem;
621
- }
622
-
623
- .stat-pill-value {
624
- color: #e5e7eb;
625
- font-weight: 600;
626
- }
627
-
628
- .dropzone-image img,
629
- .output-image img {
630
- border-radius: 16px !important;
631
- }
632
-
633
- footer {
634
- text-align: center;
635
- margin-top: 24px;
636
- color: #9ca3af;
637
- font-size: 0.78rem;
638
- }
639
- """
640
-
641
- theme = gr.themes.Soft(
642
- primary_hue="indigo",
643
- secondary_hue="pink",
644
- neutral_hue="slate",
645
- ).set(
646
- button_primary_background_fill="linear-gradient(135deg,#4f46e5,#ec4899)",
647
- button_primary_background_fill_hover="linear-gradient(135deg,#6366f1,#f97316)",
648
- )
649
-
650
- with gr.Blocks(css=custom_css, theme=theme) as demo:
651
- # HERO
652
- gr.HTML(
653
- f"""
654
- <div id="hero">
655
- <div style="display:flex;justify-content:space-between;gap:16px;align-items:flex-start;">
656
- <div>
657
- <div class="hero-title">🫁 AST Chest X-Ray Lab</div>
658
- <div class="hero-subtitle">
659
- Multi-class chest X-ray analysis with <b>Explainable AI</b> and
660
- <b>Adaptive Sparse Training</b> – Normal, Tuberculosis, Pneumonia, COVID-19.
661
- Designed to support clinicians, researchers, and global health teams.
662
- </div>
663
- <div class="hero-chip-row">
664
- <div class="hero-chip">
665
- <span class="pulse-dot"></span>
666
- Live Inference (Research Prototype)
667
- </div>
668
- <div class="hero-chip">
669
- EfficientNet-B0 · ~{TOTAL_PARAMS_M:.1f}M parameters
670
- </div>
671
- <div class="hero-chip">
672
- 95–97% validation accuracy · ~89% training energy savings
673
- </div>
674
- <div class="hero-chip">
675
- {MODEL_LOAD_INFO}
676
- </div>
677
- </div>
678
- </div>
679
- <div style="min-width:220px;display:flex;flex-direction:column;gap:8px;">
680
- <div class="stat-pill">
681
- <div class="stat-pill-label">Compute</div>
682
- <div class="stat-pill-value">{device.type.upper()}</div>
683
- </div>
684
- <div class="stat-pill">
685
- <div class="stat-pill-label">Use Case</div>
686
- <div class="stat-pill-value">Triage & decision support (not diagnostic)</div>
687
- </div>
688
- </div>
689
- </div>
690
- </div>
691
- """
692
- )
693
-
694
- gr.Markdown(" ")
695
-
696
- with gr.Row(equal_height=True):
697
- # LEFT: INPUT PANEL
698
- with gr.Column(scale=1, elem_classes="glass-card"):
699
- gr.Markdown("### 1️⃣ Upload & Configure")
700
-
701
- image_input = gr.Image(
702
- type="pil",
703
- label="Drop a chest X-ray here",
704
- elem_classes=["dropzone-image"],
705
- )
706
-
707
- with gr.Row():
708
- show_gradcam = gr.Checkbox(
709
- value=True,
710
- label="Explainable AI (Grad-CAM)",
711
- info="Highlight regions that drive the prediction",
712
- )
713
- audience_select = gr.Radio(
714
- ["Clinician", "Researcher", "Patient / Public"],
715
- value="Clinician",
716
- label="Report Style",
717
- )
718
-
719
- with gr.Row():
720
- analyze_btn = gr.Button("🔬 Analyze X-Ray", variant="primary", scale=3)
721
- clear_btn = gr.Button("🧹 Reset", variant="secondary")
722
-
723
- gr.Markdown(
724
- """
725
- **Usage Notes**
726
-
727
- - Best for frontal (PA/AP) chest X-rays in PNG / JPG format
728
- - Intended for **triage, education, and research**, not final diagnosis
729
- - For off-axis, noisy, or portable images, interpret outputs with extra caution
730
- """
731
- )
732
-
733
- # RIGHT: RESULTS PANEL
734
- with gr.Column(scale=2, elem_classes="glass-card-light"):
735
- gr.Markdown("### 2️⃣ AI Dashboard")
736
-
737
- with gr.Tabs():
738
- with gr.Tab("Triage Snapshot"):
739
- snapshot_output = gr.Markdown(
740
- "No scan analysed yet. Upload an X-ray to get started."
741
- )
742
- with gr.Row():
743
- prob_output = gr.Label(
744
- label="Prediction Confidence (All Classes)",
745
- num_top_classes=4,
746
- )
747
- prob_chart_output = gr.Image(
748
- label="Probability Profile",
749
- elem_classes=["output-image"],
750
- )
751
-
752
- with gr.Tab("Visual Explanations"):
753
- with gr.Row():
754
- original_output = gr.Image(
755
- label="Annotated X-ray",
756
- elem_classes=["output-image"],
757
- )
758
- gradcam_output = gr.Image(
759
- label="Attention Heatmap",
760
- elem_classes=["output-image"],
761
- )
762
- overlay_output = gr.Image(
763
- label="Explainable Overlay",
764
- elem_classes=["output-image"],
765
- )
766
-
767
- with gr.Tab("Full Report"):
768
- interpretation_output = gr.Markdown(
769
- "The full clinical / research report will appear here after inference."
770
- )
771
-
772
- with gr.Tab("Model Card"):
773
- gr.Markdown(
774
- f"""
775
- ### 🧠 Model Card – AST Chest X-Ray
776
-
777
- - **Backbone**: EfficientNet-B0
778
- - **Task**: 4-way classification (Normal, Tuberculosis, Pneumonia, COVID-19)
779
- - **Optimisation**: Sample-based Adaptive Sparse Training (AST)
780
- - **Motivation**: Energy-efficient AI for global lung health
781
-
782
- **Intended Use**
783
-
784
- - Research and prototyping
785
- - Triage decision-support in pilot settings
786
- - Education (medical students, residents, data scientists)
787
-
788
- **Non-Intended Use**
789
-
790
- - Stand-alone diagnosis
791
- - Automated treatment decisions
792
- - Regulatory-grade clinical deployment
793
-
794
- > Always pair the model with local guidelines, expert radiology, and laboratory testing.
795
- """
796
- )
797
-
798
- gr.Markdown("---")
799
-
800
- gr.HTML(
801
- """
802
- <footer>
803
- <p>
804
- <b>AST Chest X-Ray Lab</b> · Normal · TB · Pneumonia · COVID-19 · Explainable AI<br/>
805
- Built to explore how energy-efficient AI can support clinicians and patients worldwide.
806
- </p>
807
- <p style="margin-top:6px;">
808
- ⚠️ <b>MEDICAL DISCLAIMER:</b> This tool is a research prototype and is not FDA-approved,
809
- CE-marked, or licensed as a medical device. All clinical decisions must be made by
810
- qualified healthcare professionals.
811
- </p>
812
- </footer>
813
- """
814
- )
815
-
816
- # Wiring – connect logic to UI
817
- analyze_btn.click(
818
- fn=predict_chest_xray,
819
- inputs=[image_input, show_gradcam, audience_select],
820
- outputs=[
821
- prob_output,
822
- original_output,
823
- gradcam_output,
824
- overlay_output,
825
- interpretation_output,
826
- snapshot_output,
827
- prob_chart_output,
828
- ],
829
- )
830
-
831
- clear_btn.click(
832
- fn=lambda: (
833
- {},
834
- None,
835
- None,
836
- None,
837
- "Awaiting image upload…",
838
- "Awaiting image upload…",
839
- None,
840
- ),
841
- inputs=None,
842
- outputs=[
843
- prob_output,
844
- original_output,
845
- gradcam_output,
846
- overlay_output,
847
- interpretation_output,
848
- snapshot_output,
849
- prob_chart_output,
850
- ],
851
- )
852
-
853
- # Example X-rays (optional – comment out if these paths don't exist)
854
- gr.Markdown("### 🔍 Try Example X-rays")
855
- gr.Examples(
856
- examples=[
857
- ["examples/normal.png"],
858
- ["examples/tb.png"],
859
- ["examples/pneumonia.png"],
860
- ["examples/covid.png"],
861
- ],
862
- inputs=image_input,
863
- )
864
-
865
- # ============================================================================
866
- # Launch
867
- # ============================================================================
868
-
869
- if __name__ == "__main__":
870
- demo.launch(
871
- share=False,
872
- server_name="0.0.0.0",
873
- server_port=7860,
874
- show_error=True,
875
- )