mgbam commited on
Commit
30bd152
·
verified ·
1 Parent(s): c53fdfe

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +875 -0
app.py ADDED
@@ -0,0 +1,875 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 🫁 AST Chest X-Ray Lab
3
+ Multi-Class Chest X-Ray Detection (Normal · TB · Pneumonia · COVID-19)
4
+ with Adaptive Sparse Training & Explainable AI (Grad-CAM)
5
+
6
+ This app is a research / screening tool – not a diagnostic device.
7
+ """
8
+
9
+ import io
10
+ from pathlib import Path
11
+
12
+ import cv2
13
+ import gradio as gr
14
+ import matplotlib
15
+ matplotlib.use("Agg") # non-interactive backend for servers
16
+ import matplotlib.pyplot as plt
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ from PIL import Image
21
+ from torchvision import models, transforms
22
+
23
+ # ============================================================================
24
+ # Model Setup
25
+ # ============================================================================
26
+
27
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
28
+
29
+ NUM_CLASSES = 4
30
+
31
+ # Backbone: EfficientNet-B0 with 4-class head
32
+ model = models.efficientnet_b0(weights=None)
33
+ model.classifier[1] = nn.Linear(model.classifier[1].in_features, NUM_CLASSES)
34
+
35
+ # Where we expect the (4-class) checkpoint to live
36
+ checkpoint_candidates = [
37
+ "checkpoints/best.pt", # main location (from your HF screenshot)
38
+ "best.pt", # optional fallback in root
39
+ ]
40
+
41
+ MODEL_LOAD_INFO = ""
42
+ loaded = False
43
+
44
+
45
+ def extract_state_dict(ckpt):
46
+ """
47
+ Handle both:
48
+ - plain state_dict (just param tensors)
49
+ - training checkpoints: keys like 'model_state_dict', 'state_dict', 'model', etc.
50
+ """
51
+ if isinstance(ckpt, dict):
52
+ for key in ["model_state_dict", "state_dict", "model"]:
53
+ if key in ckpt and isinstance(ckpt[key], dict):
54
+ return ckpt[key]
55
+ return ckpt
56
+
57
+
58
+ for ckpt_path in checkpoint_candidates:
59
+ if Path(ckpt_path).is_file():
60
+ try:
61
+ print(f"🔍 Trying to load weights from: {ckpt_path}")
62
+ raw_ckpt = torch.load(ckpt_path, map_location=device)
63
+ state_dict = extract_state_dict(raw_ckpt)
64
+
65
+ # Sanity check: classifier head must be 4-way
66
+ if "classifier.1.weight" in state_dict:
67
+ out_features = state_dict["classifier.1.weight"].shape[0]
68
+ if out_features != NUM_CLASSES:
69
+ raise ValueError(
70
+ f"Checkpoint at {ckpt_path} has {out_features} output "
71
+ f"classes, but this app expects {NUM_CLASSES}."
72
+ )
73
+
74
+ model.load_state_dict(state_dict, strict=True)
75
+
76
+ MODEL_LOAD_INFO = (
77
+ f"✅ Model loaded from <b>{ckpt_path}</b> on <b>{device.type.upper()}</b>."
78
+ )
79
+ loaded = True
80
+ break
81
+ except Exception as e:
82
+ print(f"⚠️ Found {ckpt_path} but failed to load model_state_dict: {e}")
83
+
84
+ if not loaded:
85
+ raise RuntimeError(
86
+ "Model file not found or could not be loaded.\n"
87
+ "Expected a 4-class EfficientNet checkpoint at 'checkpoints/best.pt' or 'best.pt'.\n"
88
+ "If you saved a training checkpoint, make sure it contains "
89
+ "a 'model_state_dict' key with the 4-class EfficientNet weights."
90
+ )
91
+
92
+ model = model.to(device)
93
+ model.eval()
94
+
95
+ TOTAL_PARAMS = sum(p.numel() for p in model.parameters())
96
+ TOTAL_PARAMS_M = TOTAL_PARAMS / 1e6
97
+
98
+ # ============================================================================
99
+ # Classes & Preprocessing
100
+ # ============================================================================
101
+
102
+ CLASSES = ["Normal", "Tuberculosis", "Pneumonia", "COVID-19"]
103
+ CLASS_COLORS = {
104
+ "Normal": "#22c55e", # Green
105
+ "Tuberculosis": "#ef4444", # Red
106
+ "Pneumonia": "#f97316", # Orange
107
+ "COVID-19": "#a855f7", # Purple
108
+ }
109
+
110
+ transform = transforms.Compose(
111
+ [
112
+ transforms.Resize(256),
113
+ transforms.CenterCrop(224),
114
+ transforms.ToTensor(),
115
+ transforms.Normalize(
116
+ [0.485, 0.456, 0.406],
117
+ [0.229, 0.224, 0.225],
118
+ ),
119
+ ]
120
+ )
121
+
122
+ # ============================================================================
123
+ # Grad-CAM Implementation
124
+ # ============================================================================
125
+
126
+
127
+ class GradCAM:
128
+ def __init__(self, model, target_layer):
129
+ self.model = model
130
+ self.target_layer = target_layer
131
+ self.gradients = None
132
+ self.activations = None
133
+
134
+ def save_gradient(grad):
135
+ self.gradients = grad
136
+
137
+ def save_activation(module, input, output):
138
+ self.activations = output.detach()
139
+
140
+ target_layer.register_forward_hook(save_activation)
141
+ target_layer.register_full_backward_hook(lambda m, gi, go: save_gradient(go[0]))
142
+
143
+ def generate(self, input_image, target_class=None):
144
+ output = self.model(input_image)
145
+
146
+ if target_class is None:
147
+ target_class = output.argmax(dim=1)
148
+
149
+ self.model.zero_grad()
150
+ one_hot = torch.zeros_like(output)
151
+ one_hot[0, target_class] = 1
152
+ output.backward(gradient=one_hot, retain_graph=True)
153
+
154
+ if self.gradients is None:
155
+ return None, output
156
+
157
+ weights = self.gradients.mean(dim=(2, 3), keepdim=True)
158
+ cam = (weights * self.activations).sum(dim=1, keepdim=True)
159
+ cam = torch.relu(cam)
160
+ cam = cam.squeeze().cpu().numpy()
161
+ cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)
162
+
163
+ return cam, output
164
+
165
+
166
+ target_layer = model.features[-1]
167
+ grad_cam = GradCAM(model, target_layer)
168
+
169
+ # ============================================================================
170
+ # Visualization Helpers
171
+ # ============================================================================
172
+
173
+
174
+ def _figure_to_pil():
175
+ buf = io.BytesIO()
176
+ plt.savefig(buf, format="png", dpi=150, bbox_inches="tight", facecolor="white")
177
+ plt.close()
178
+ buf.seek(0)
179
+ return Image.open(buf)
180
+
181
+
182
+ def create_original_display(image, pred_label, confidence):
183
+ fig, ax = plt.subplots(figsize=(7, 7))
184
+ ax.imshow(image)
185
+ ax.axis("off")
186
+
187
+ color = CLASS_COLORS[pred_label]
188
+ title = f"Prediction: {pred_label} • Confidence: {confidence:.1f}%"
189
+ ax.set_title(
190
+ title,
191
+ fontsize=16,
192
+ fontweight="bold",
193
+ color=color,
194
+ pad=20,
195
+ )
196
+ plt.tight_layout()
197
+ return _figure_to_pil()
198
+
199
+
200
+ def create_gradcam_visualization(image, cam):
201
+ img_array = np.array(image.resize((224, 224)))
202
+ cam_resized = cv2.resize(cam, (224, 224))
203
+
204
+ heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
205
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
206
+
207
+ fig, ax = plt.subplots(figsize=(7, 7))
208
+ ax.imshow(heatmap)
209
+ ax.axis("off")
210
+ ax.set_title(
211
+ "Attention Heatmap\n(Where the model is focusing)",
212
+ fontsize=14,
213
+ fontweight="bold",
214
+ pad=20,
215
+ )
216
+ plt.tight_layout()
217
+ return _figure_to_pil()
218
+
219
+
220
+ def create_overlay_visualization(image, cam):
221
+ img_array = np.array(image.resize((224, 224))) / 255.0
222
+ cam_resized = cv2.resize(cam, (224, 224))
223
+
224
+ heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
225
+ heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.0
226
+
227
+ overlay = img_array * 0.5 + heatmap * 0.5
228
+ overlay = np.clip(overlay, 0, 1)
229
+
230
+ fig, ax = plt.subplots(figsize=(7, 7))
231
+ ax.imshow(overlay)
232
+ ax.axis("off")
233
+ ax.set_title(
234
+ "Explainable AI Overlay\n(Anatomy + Model Attention)",
235
+ fontsize=14,
236
+ fontweight="bold",
237
+ pad=20,
238
+ )
239
+ plt.tight_layout()
240
+ return _figure_to_pil()
241
+
242
+
243
+ def create_probability_bar(results):
244
+ """Horizontal bar chart of 4-class probabilities."""
245
+ classes = list(results.keys())
246
+ values = [results[c] for c in classes]
247
+ y_pos = np.arange(len(classes))
248
+
249
+ fig, ax = plt.subplots(figsize=(6.4, 3.5))
250
+ ax.barh(y_pos, values)
251
+ ax.set_yticks(y_pos)
252
+ ax.set_yticklabels(classes)
253
+ ax.invert_yaxis()
254
+ ax.set_xlim(0, 100)
255
+ ax.set_xlabel("Probability (%)")
256
+ ax.set_title("Probability Profile by Class", fontsize=12, fontweight="bold")
257
+ for i, v in enumerate(values):
258
+ ax.text(v + 1, i, f"{v:.1f}%", va="center", fontsize=9)
259
+ plt.tight_layout()
260
+ return _figure_to_pil()
261
+
262
+ # ============================================================================
263
+ # Interpretation
264
+ # ============================================================================
265
+
266
+
267
+ def triage_label(pred_label, confidence):
268
+ """
269
+ Simple triage categorisation for clinicians / dashboards.
270
+ """
271
+ high = confidence >= 85
272
+ moderate = 65 <= confidence < 85
273
+
274
+ if pred_label == "Normal":
275
+ if high:
276
+ return "🟢 Low risk – no major abnormality detected (model view)"
277
+ elif moderate:
278
+ return "🟡 Likely normal, but low confidence – consider clinical context"
279
+ else:
280
+ return "🟡 Indeterminate – imaging looks close to normal, but model is uncertain"
281
+ else:
282
+ if high:
283
+ return "🔴 High risk finding – prioritise expert review"
284
+ elif moderate:
285
+ return "🟠 Possible pathology – correlate with symptoms and labs"
286
+ else:
287
+ return "🟡 Weak signal – treat as a soft flag, not a diagnosis"
288
+
289
+
290
+ def create_interpretation(pred_label, confidence, results, audience="Clinician"):
291
+ header_note = {
292
+ "Clinician": "Optimised for **clinical decision support** – not a replacement for your judgement.",
293
+ "Researcher": "Optimised for **model behaviour analysis** and experimental workflows.",
294
+ "Patient / Public": "Optimised for **patient-friendly language**. Always discuss results with a doctor.",
295
+ }.get(audience, "Use this output as a **screening aid**, not a final diagnosis.")
296
+
297
+ interpretation = f"""
298
+ ## 🔬 Analysis Results ({audience} View)
299
+
300
+ > {header_note}
301
+
302
+ ### Primary Prediction: **{pred_label}**
303
+ - Confidence: **{confidence:.1f}%**
304
+ - Triage comment: {triage_label(pred_label, confidence)}
305
+
306
+ ### Probability Breakdown
307
+ - 🟢 Normal: **{results['Normal']:.1f}%**
308
+ - 🔴 Tuberculosis: **{results['Tuberculosis']:.1f}%**
309
+ - 🟠 Pneumonia: **{results['Pneumonia']:.1f}%**
310
+ - 🟣 COVID-19: **{results['COVID-19']:.1f}%**
311
+
312
+ ---
313
+ """
314
+
315
+ # Disease-specific narrative
316
+ if pred_label == "Tuberculosis":
317
+ if confidence >= 85:
318
+ interpretation += """
319
+ ### 🧫 TB Pattern – High Confidence
320
+
321
+ The model has detected features strongly suggestive of **pulmonary tuberculosis**.
322
+
323
+ **Suggested Clinical Pathway**
324
+ 1. Prompt review by a clinician / chest physician
325
+ 2. Sputum testing (AFB smear, GeneXpert MTB/RIF, or TB-PCR)
326
+ 3. Correlate with symptoms:
327
+ - Chronic cough (>2 weeks)
328
+ - Weight loss, night sweats
329
+ - Fever, fatigue
330
+ - Haemoptysis
331
+ 4. Consider CT or further imaging if discordant with clinical picture
332
+ 5. Infection control and contact tracing as per TB guidelines
333
+ """
334
+ else:
335
+ interpretation += """
336
+ ### 🧫 TB Pattern – Possible
337
+
338
+ There are features that **could** be compatible with TB, but the confidence is moderate.
339
+
340
+ - Review history and risk factors
341
+ - Consider sputum testing if suspicion is non-trivial
342
+ - Follow-up imaging where indicated
343
+ """
344
+
345
+ elif pred_label == "Pneumonia":
346
+ if confidence >= 85:
347
+ interpretation += """
348
+ ### 🌫 Pneumonia Pattern – High Confidence
349
+
350
+ The model has detected an opacity pattern consistent with **pneumonia**.
351
+
352
+ Typical clinical picture may include:
353
+
354
+ - Fever, productive cough
355
+ - Shortness of breath
356
+ - Pleuritic chest pain
357
+
358
+ Use in combination with examination, labs (WBC, CRP, cultures) and local treatment guidelines.
359
+ """
360
+ else:
361
+ interpretation += """
362
+ ### 🌫 Pneumonia Pattern – Possible
363
+
364
+ Findings may be compatible with pneumonia, but alternative explanations exist.
365
+
366
+ - Check vitals and respiratory exam
367
+ - Labs and microbiology can support or refute the impression
368
+ - Consider watchful follow-up or repeat imaging
369
+ """
370
+
371
+ elif pred_label == "COVID-19":
372
+ if confidence >= 85:
373
+ interpretation += """
374
+ ### 🦠 COVID-19 Pattern – High Confidence
375
+
376
+ The distribution and appearance of opacities are compatible with **COVID-19 pneumonia**.
377
+
378
+ ⚠️ Imaging alone is **not diagnostic**. Key points:
379
+
380
+ - Confirmation requires RT-PCR or validated antigen testing
381
+ - Follow local isolation and infection-control policies
382
+ - Monitor SpO₂ and work of breathing; escalate care if deterioration occurs
383
+ """
384
+ else:
385
+ interpretation += """
386
+ ### 🦠 COVID-19 Pattern – Possible
387
+
388
+ There are features that could overlap with COVID-19, but uncertainty is substantial.
389
+
390
+ - Testing (RT-PCR / antigen) is essential
391
+ - Integrate exposure history, symptoms, and public health guidance
392
+ """
393
+
394
+ else: # Normal
395
+ if confidence >= 85:
396
+ interpretation += """
397
+ ### ✅ No Major Abnormality Detected (Model View)
398
+
399
+ The model did **not** detect strong features of TB, pneumonia, or COVID-19.
400
+
401
+ **Important caveats**
402
+
403
+ - Early disease or small lesions may be missed
404
+ - Non-infective conditions (e.g. malignancy, ILD) are **not** specifically evaluated
405
+ - Persistent or unexplained symptoms still require clinical review
406
+ """
407
+ else:
408
+ interpretation += """
409
+ ### ℹ️ Likely Normal, But Low Confidence
410
+
411
+ The scan leans towards **normal**, but the model's confidence is limited.
412
+
413
+ - Consider repeat imaging, further tests, or expert review if symptoms persist
414
+ """
415
+
416
+ interpretation += """
417
+ ---
418
+ ## ⚠️ CRITICAL MEDICAL DISCLAIMER
419
+
420
+ - This AI system is a **screening / decision-support tool** only
421
+ - It is **not FDA-approved**, CE-marked, or licensed as a medical device
422
+ - It must **not** be used as a stand-alone diagnostic system
423
+
424
+ Always integrate:
425
+ - Clinical history and examination
426
+ - Laboratory tests (e.g. sputum AFB / GeneXpert, PCR, cultures)
427
+ - Radiologist / specialist interpretation
428
+
429
+ **Gold Standards**
430
+
431
+ - Tuberculosis: Sputum AFB / culture, GeneXpert MTB/RIF, TB-PCR
432
+ - Pneumonia: Clinical diagnosis + labs / microbiology
433
+ - COVID-19: RT-PCR or validated antigen tests
434
+
435
+ When in doubt, consult a qualified healthcare professional.
436
+ ---
437
+ 🫁 **Powered by Adaptive Sparse Training (AST)**
438
+ Energy-efficient deep learning to support lung health in both high-resource and low-resource settings.
439
+
440
+ **Project Links**
441
+
442
+ - GitHub: https://github.com/oluwafemidiakhoa/Tuberculosis
443
+ - Hugging Face Space: https://huggingface.co/spaces/mgbam/Tuberculosis
444
+ """
445
+ return interpretation
446
+
447
+ # ============================================================================
448
+ # Prediction Pipeline
449
+ # ============================================================================
450
+
451
+
452
+ def predict_chest_xray(image, show_gradcam=True, audience="Clinician"):
453
+ """
454
+ Main inference function used by Gradio.
455
+ Returns:
456
+ - dict of class probabilities
457
+ - annotated original
458
+ - grad-cam heatmap
459
+ - overlay
460
+ - full markdown report
461
+ - short textual snapshot
462
+ - probability bar-chart image
463
+ """
464
+ if image is None:
465
+ msg = "👋 Upload a chest X-ray (PNG/JPG) and click **Analyze** to generate a full AI report."
466
+ return {}, None, None, None, msg, "Awaiting image upload…", None
467
+
468
+ if isinstance(image, np.ndarray):
469
+ image = Image.fromarray(image).convert("RGB")
470
+ else:
471
+ image = image.convert("RGB")
472
+
473
+ original_img = image.copy()
474
+ input_tensor = transform(image).unsqueeze(0).to(device)
475
+
476
+ with torch.set_grad_enabled(show_gradcam):
477
+ if show_gradcam:
478
+ cam, output = grad_cam.generate(input_tensor)
479
+ else:
480
+ output = model(input_tensor)
481
+ cam = None
482
+
483
+ probs = torch.softmax(output, dim=1)[0].cpu().detach().numpy()
484
+ prob_sum = float(np.sum(probs))
485
+
486
+ if not (0.99 <= prob_sum <= 1.01):
487
+ print(f"⚠️ WARNING: Probability sum is {prob_sum}, not ≈1.0 – check model weights.")
488
+
489
+ pred_class = int(output.argmax(dim=1).item())
490
+ pred_label = CLASSES[pred_class]
491
+ confidence = float(probs[pred_class]) * 100.0
492
+
493
+ results = {
494
+ CLASSES[i]: float(min(100.0, max(0.0, probs[i] * 100.0)))
495
+ for i in range(len(CLASSES))
496
+ }
497
+
498
+ original_pil = create_original_display(original_img, pred_label, confidence)
499
+ gradcam_viz = create_gradcam_visualization(original_img, cam) if cam is not None else None
500
+ overlay_viz = create_overlay_visualization(original_img, cam) if cam is not None else None
501
+ prob_chart = create_probability_bar(results)
502
+
503
+ interpretation = create_interpretation(pred_label, confidence, results, audience=audience)
504
+
505
+ snapshot = (
506
+ f"### 📝 Triage Snapshot\n\n"
507
+ f"- **Finding:** {pred_label}\n"
508
+ f"- **Model confidence:** {confidence:.1f}%\n"
509
+ f"- **Triage comment:** {triage_label(pred_label, confidence)}\n"
510
+ f"- **Probability sum (sanity check):** {prob_sum:.3f}"
511
+ )
512
+
513
+ return results, original_pil, gradcam_viz, overlay_viz, interpretation, snapshot, prob_chart
514
+
515
+ # ============================================================================
516
+ # WOW UI / UX – Gradio App
517
+ # ============================================================================
518
+
519
+ custom_css = """
520
+ :root {
521
+ --primary: #6366f1;
522
+ --primary-soft: rgba(99, 102, 241, 0.12);
523
+ --accent: #ec4899;
524
+ }
525
+
526
+ .gradio-container {
527
+ font-family: system-ui, -apple-system, BlinkMacSystemFont, "Inter", sans-serif;
528
+ background: radial-gradient(circle at top left, #111827 0, #020617 50%, #020617 100%);
529
+ color: #e5e7eb;
530
+ }
531
+
532
+ #hero {
533
+ padding: 24px 24px 8px 24px;
534
+ border-radius: 24px;
535
+ background: linear-gradient(120deg, rgba(99,102,241,0.22), rgba(236,72,153,0.18));
536
+ border: 1px solid rgba(148, 163, 184, 0.45);
537
+ box-shadow: 0 24px 60px rgba(15,23,42,0.9);
538
+ backdrop-filter: blur(18px);
539
+ }
540
+
541
+ .hero-title {
542
+ font-size: 2.4rem;
543
+ font-weight: 800;
544
+ letter-spacing: 0.04em;
545
+ color: #f9fafb;
546
+ margin-bottom: 6px;
547
+ }
548
+
549
+ .hero-subtitle {
550
+ font-size: 0.98rem;
551
+ color: #e5e7eb;
552
+ }
553
+
554
+ .hero-chip-row {
555
+ display: flex;
556
+ flex-wrap: wrap;
557
+ gap: 8px;
558
+ margin-top: 14px;
559
+ }
560
+
561
+ .hero-chip {
562
+ padding: 4px 10px;
563
+ border-radius: 999px;
564
+ font-size: 0.78rem;
565
+ background: rgba(15,23,42,0.8);
566
+ border: 1px solid rgba(148,163,184,0.5);
567
+ display: inline-flex;
568
+ align-items: center;
569
+ gap: 6px;
570
+ color: #e5e7eb;
571
+ }
572
+
573
+ .pulse-dot {
574
+ width: 8px;
575
+ height: 8px;
576
+ border-radius: 999px;
577
+ background: #22c55e;
578
+ box-shadow: 0 0 0 0 rgba(34,197,94,0.7);
579
+ animation: pulse 1.4s infinite;
580
+ }
581
+
582
+ @keyframes pulse {
583
+ 0% { box-shadow: 0 0 0 0 rgba(34,197,94,0.7); }
584
+ 70% { box-shadow: 0 0 0 10px rgba(34,197,94,0); }
585
+ 100% { box-shadow: 0 0 0 0 rgba(34,197,94,0); }
586
+ }
587
+
588
+ .glass-card {
589
+ background: rgba(15,23,42,0.86);
590
+ border-radius: 18px;
591
+ border: 1px solid rgba(148,163,184,0.4);
592
+ box-shadow: 0 18px 40px rgba(15,23,42,0.9);
593
+ padding: 18px;
594
+ backdrop-filter: blur(16px);
595
+ }
596
+
597
+ .glass-card-light {
598
+ background: rgba(15,23,42,0.7);
599
+ border-radius: 18px;
600
+ border: 1px solid rgba(148,163,184,0.3);
601
+ box-shadow: 0 12px 30px rgba(15,23,42,0.9);
602
+ padding: 16px;
603
+ backdrop-filter: blur(12px);
604
+ }
605
+
606
+ .stat-pill {
607
+ padding: 10px 12px;
608
+ border-radius: 14px;
609
+ background: rgba(15,23,42,0.9);
610
+ border: 1px solid rgba(148,163,184,0.5);
611
+ font-size: 0.78rem;
612
+ display: flex;
613
+ flex-direction: column;
614
+ gap: 2px;
615
+ }
616
+
617
+ .stat-pill-label {
618
+ color: #9ca3af;
619
+ text-transform: uppercase;
620
+ font-size: 0.68rem;
621
+ }
622
+
623
+ .stat-pill-value {
624
+ color: #e5e7eb;
625
+ font-weight: 600;
626
+ }
627
+
628
+ .dropzone-image img,
629
+ .output-image img {
630
+ border-radius: 16px !important;
631
+ }
632
+
633
+ footer {
634
+ text-align: center;
635
+ margin-top: 24px;
636
+ color: #9ca3af;
637
+ font-size: 0.78rem;
638
+ }
639
+ """
640
+
641
+ theme = gr.themes.Soft(
642
+ primary_hue="indigo",
643
+ secondary_hue="pink",
644
+ neutral_hue="slate",
645
+ ).set(
646
+ button_primary_background_fill="linear-gradient(135deg,#4f46e5,#ec4899)",
647
+ button_primary_background_fill_hover="linear-gradient(135deg,#6366f1,#f97316)",
648
+ )
649
+
650
+ with gr.Blocks(css=custom_css, theme=theme) as demo:
651
+ # HERO
652
+ gr.HTML(
653
+ f"""
654
+ <div id="hero">
655
+ <div style="display:flex;justify-content:space-between;gap:16px;align-items:flex-start;">
656
+ <div>
657
+ <div class="hero-title">🫁 AST Chest X-Ray Lab</div>
658
+ <div class="hero-subtitle">
659
+ Multi-class chest X-ray analysis with <b>Explainable AI</b> and
660
+ <b>Adaptive Sparse Training</b> – Normal, Tuberculosis, Pneumonia, COVID-19.
661
+ Designed to support clinicians, researchers, and global health teams.
662
+ </div>
663
+ <div class="hero-chip-row">
664
+ <div class="hero-chip">
665
+ <span class="pulse-dot"></span>
666
+ Live Inference (Research Prototype)
667
+ </div>
668
+ <div class="hero-chip">
669
+ EfficientNet-B0 · ~{TOTAL_PARAMS_M:.1f}M parameters
670
+ </div>
671
+ <div class="hero-chip">
672
+ 95–97% validation accuracy · ~89% training energy savings
673
+ </div>
674
+ <div class="hero-chip">
675
+ {MODEL_LOAD_INFO}
676
+ </div>
677
+ </div>
678
+ </div>
679
+ <div style="min-width:220px;display:flex;flex-direction:column;gap:8px;">
680
+ <div class="stat-pill">
681
+ <div class="stat-pill-label">Compute</div>
682
+ <div class="stat-pill-value">{device.type.upper()}</div>
683
+ </div>
684
+ <div class="stat-pill">
685
+ <div class="stat-pill-label">Use Case</div>
686
+ <div class="stat-pill-value">Triage & decision support (not diagnostic)</div>
687
+ </div>
688
+ </div>
689
+ </div>
690
+ </div>
691
+ """
692
+ )
693
+
694
+ gr.Markdown(" ")
695
+
696
+ with gr.Row(equal_height=True):
697
+ # LEFT: INPUT PANEL
698
+ with gr.Column(scale=1, elem_classes="glass-card"):
699
+ gr.Markdown("### 1️⃣ Upload & Configure")
700
+
701
+ image_input = gr.Image(
702
+ type="pil",
703
+ label="Drop a chest X-ray here",
704
+ elem_classes=["dropzone-image"],
705
+ )
706
+
707
+ with gr.Row():
708
+ show_gradcam = gr.Checkbox(
709
+ value=True,
710
+ label="Explainable AI (Grad-CAM)",
711
+ info="Highlight regions that drive the prediction",
712
+ )
713
+ audience_select = gr.Radio(
714
+ ["Clinician", "Researcher", "Patient / Public"],
715
+ value="Clinician",
716
+ label="Report Style",
717
+ )
718
+
719
+ with gr.Row():
720
+ analyze_btn = gr.Button("🔬 Analyze X-Ray", variant="primary", scale=3)
721
+ clear_btn = gr.Button("🧹 Reset", variant="secondary")
722
+
723
+ gr.Markdown(
724
+ """
725
+ **Usage Notes**
726
+
727
+ - Best for frontal (PA/AP) chest X-rays in PNG / JPG format
728
+ - Intended for **triage, education, and research**, not final diagnosis
729
+ - For off-axis, noisy, or portable images, interpret outputs with extra caution
730
+ """
731
+ )
732
+
733
+ # RIGHT: RESULTS PANEL
734
+ with gr.Column(scale=2, elem_classes="glass-card-light"):
735
+ gr.Markdown("### 2️⃣ AI Dashboard")
736
+
737
+ with gr.Tabs():
738
+ with gr.Tab("Triage Snapshot"):
739
+ snapshot_output = gr.Markdown(
740
+ "No scan analysed yet. Upload an X-ray to get started."
741
+ )
742
+ with gr.Row():
743
+ prob_output = gr.Label(
744
+ label="Prediction Confidence (All Classes)",
745
+ num_top_classes=4,
746
+ )
747
+ prob_chart_output = gr.Image(
748
+ label="Probability Profile",
749
+ elem_classes=["output-image"],
750
+ )
751
+
752
+ with gr.Tab("Visual Explanations"):
753
+ with gr.Row():
754
+ original_output = gr.Image(
755
+ label="Annotated X-ray",
756
+ elem_classes=["output-image"],
757
+ )
758
+ gradcam_output = gr.Image(
759
+ label="Attention Heatmap",
760
+ elem_classes=["output-image"],
761
+ )
762
+ overlay_output = gr.Image(
763
+ label="Explainable Overlay",
764
+ elem_classes=["output-image"],
765
+ )
766
+
767
+ with gr.Tab("Full Report"):
768
+ interpretation_output = gr.Markdown(
769
+ "The full clinical / research report will appear here after inference."
770
+ )
771
+
772
+ with gr.Tab("Model Card"):
773
+ gr.Markdown(
774
+ f"""
775
+ ### 🧠 Model Card – AST Chest X-Ray
776
+
777
+ - **Backbone**: EfficientNet-B0
778
+ - **Task**: 4-way classification (Normal, Tuberculosis, Pneumonia, COVID-19)
779
+ - **Optimisation**: Sample-based Adaptive Sparse Training (AST)
780
+ - **Motivation**: Energy-efficient AI for global lung health
781
+
782
+ **Intended Use**
783
+
784
+ - Research and prototyping
785
+ - Triage decision-support in pilot settings
786
+ - Education (medical students, residents, data scientists)
787
+
788
+ **Non-Intended Use**
789
+
790
+ - Stand-alone diagnosis
791
+ - Automated treatment decisions
792
+ - Regulatory-grade clinical deployment
793
+
794
+ > Always pair the model with local guidelines, expert radiology, and laboratory testing.
795
+ """
796
+ )
797
+
798
+ gr.Markdown("---")
799
+
800
+ gr.HTML(
801
+ """
802
+ <footer>
803
+ <p>
804
+ <b>AST Chest X-Ray Lab</b> · Normal · TB · Pneumonia · COVID-19 · Explainable AI<br/>
805
+ Built to explore how energy-efficient AI can support clinicians and patients worldwide.
806
+ </p>
807
+ <p style="margin-top:6px;">
808
+ ⚠️ <b>MEDICAL DISCLAIMER:</b> This tool is a research prototype and is not FDA-approved,
809
+ CE-marked, or licensed as a medical device. All clinical decisions must be made by
810
+ qualified healthcare professionals.
811
+ </p>
812
+ </footer>
813
+ """
814
+ )
815
+
816
+ # Wiring – connect logic to UI
817
+ analyze_btn.click(
818
+ fn=predict_chest_xray,
819
+ inputs=[image_input, show_gradcam, audience_select],
820
+ outputs=[
821
+ prob_output,
822
+ original_output,
823
+ gradcam_output,
824
+ overlay_output,
825
+ interpretation_output,
826
+ snapshot_output,
827
+ prob_chart_output,
828
+ ],
829
+ )
830
+
831
+ clear_btn.click(
832
+ fn=lambda: (
833
+ {},
834
+ None,
835
+ None,
836
+ None,
837
+ "Awaiting image upload…",
838
+ "Awaiting image upload…",
839
+ None,
840
+ ),
841
+ inputs=None,
842
+ outputs=[
843
+ prob_output,
844
+ original_output,
845
+ gradcam_output,
846
+ overlay_output,
847
+ interpretation_output,
848
+ snapshot_output,
849
+ prob_chart_output,
850
+ ],
851
+ )
852
+
853
+ # Example X-rays (optional – comment out if these paths don't exist)
854
+ gr.Markdown("### 🔍 Try Example X-rays")
855
+ gr.Examples(
856
+ examples=[
857
+ ["examples/normal.png"],
858
+ ["examples/tb.png"],
859
+ ["examples/pneumonia.png"],
860
+ ["examples/covid.png"],
861
+ ],
862
+ inputs=image_input,
863
+ )
864
+
865
+ # ============================================================================
866
+ # Launch
867
+ # ============================================================================
868
+
869
+ if __name__ == "__main__":
870
+ demo.launch(
871
+ share=False,
872
+ server_name="0.0.0.0",
873
+ server_port=7860,
874
+ show_error=True,
875
+ )