drvikasgaur commited on
Commit
6251f25
·
verified ·
1 Parent(s): 09e7319

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +131 -174
app.py CHANGED
@@ -11,19 +11,9 @@
11
  # 3) Final consensus (comparison + next step)
12
  # - Adds collapsible detailed report per image
13
  # - Keeps gallery, adds legend, better labels
 
14
  #
15
  # HF Spaces: use relative weight paths (edit below if needed)
16
- #
17
- # Requirements (requirements.txt):
18
- # gradio
19
- # torch
20
- # torchvision
21
- # timm
22
- # opencv-python
23
- # pillow
24
- # transformers
25
- # einops
26
- # open_clip_torch
27
 
28
  import os
29
  import cv2
@@ -70,10 +60,10 @@ RADIO_GATE_DEFAULT = 0.21
70
 
71
  # ---- Consensus logic thresholds ----
72
  TBNET_SCREEN_THR = 0.30
73
- TBNET_MARGIN = 0.03 # (kept for compatibility / future use)
74
 
75
  RADIO_SCREEN_THR = RADIO_THR_SCREEN
76
- RADIO_MARGIN = 0.02 # (kept for compatibility / future use)
77
 
78
  # ---- Mask fail-safes ----
79
  FAIL_COV = 0.10
@@ -99,7 +89,6 @@ If clinical suspicion exists (fever, weight loss, immunosuppression, known expos
99
  recommend **CBNAAT / GeneXpert**, sputum studies, and/or **CT chest** regardless of AI output.
100
  """
101
 
102
- # Friendly labels still map to GREEN/YELLOW/RED logic
103
  REPORT_LABELS = {
104
  "GREEN": {
105
  "title": "LOW TB LIKELIHOOD / Pulmonary T.B not detected by A.I",
@@ -135,6 +124,51 @@ CLINICAL_GUIDANCE = (
135
  )
136
 
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  # ============================================================
139
  # UX HELPERS
140
  # ============================================================
@@ -149,12 +183,10 @@ def pretty_state(s: str) -> str:
149
 
150
 
151
  def html_escape(s: str) -> str:
152
- # minimal escaping for safety in HTML blocks
153
  return (s or "").replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
154
 
155
 
156
  def badge_color_for_state(state: str) -> str:
157
- # soft visual cue; works in both dark/light
158
  if state == "TB+":
159
  return "rgba(239,68,68,0.18)" # red
160
  if state == "SCREEN+":
@@ -181,7 +213,8 @@ class DoubleConv(nn.Module):
181
  nn.ReLU(inplace=True),
182
  )
183
 
184
- def forward(self, x): return self.net(x)
 
185
 
186
 
187
  class LungUNet(nn.Module):
@@ -225,7 +258,8 @@ class TBNet(nn.Module):
225
  self.backbone = timm.create_model(backbone, pretrained=False, num_classes=0, global_pool="avg")
226
  self.fc = nn.Linear(self.backbone.num_features, 1)
227
 
228
- def forward(self, x): return self.fc(self.backbone(x)).view(-1)
 
229
 
230
 
231
  def load_tb_weights(model: nn.Module, ckpt_path: str, device: torch.device):
@@ -241,8 +275,11 @@ class GradCAM:
241
  target_layer.register_forward_hook(self._fwd)
242
  target_layer.register_full_backward_hook(self._bwd)
243
 
244
- def _fwd(self, _, __, out): self.activ = out
245
- def _bwd(self, _, grad_in, grad_out): self.grad = grad_out[0]
 
 
 
246
 
247
  def generate(self, x: torch.Tensor) -> Tuple[np.ndarray, float, float]:
248
  with torch.enable_grad():
@@ -305,7 +342,9 @@ def border_fraction(gray_u8: np.ndarray) -> float:
305
  left = gray_u8[:, :b]
306
  right = gray_u8[:, -b:]
307
 
308
- def frac_border(x): return float(((x < 15) | (x > 240)).mean())
 
 
309
  return float(np.mean([frac_border(top), frac_border(bot), frac_border(left), frac_border(right)]))
310
 
311
 
@@ -325,22 +364,29 @@ def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
325
 
326
  if likely_phone:
327
  if sharp < 40:
328
- score -= 25; warnings.append("Blurry / motion blur detected (likely phone capture).")
 
329
  elif sharp < 80:
330
- score -= 12; warnings.append("Slight blur detected.")
 
331
  else:
332
  if sharp < 30:
333
- score -= 8; warnings.append("Low fine detail (possible downsampling).")
 
334
 
335
  if hi_clip > 0.05:
336
- score -= 15; warnings.append("Overexposed highlights (washed-out areas).")
 
337
  if lo_clip > 0.10:
338
- score -= 12; warnings.append("Underexposed shadows (very dark areas).")
 
339
 
340
  if border > 0.55:
341
- score -= 18; warnings.append("Large border/margins detected (possible screenshot/phone framing).")
 
342
  elif border > 0.35:
343
- score -= 10; warnings.append("Some border/margins detected.")
 
344
 
345
  return float(np.clip(score, 0, 100)), warnings
346
 
@@ -349,23 +395,28 @@ def auto_border_crop(gray_u8: np.ndarray) -> np.ndarray:
349
  g = gray_u8.copy()
350
  g_blur = cv2.GaussianBlur(g, (5, 5), 0)
351
  _, th = cv2.threshold(g_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
352
- if th.mean() > 127: th = 255 - th
 
353
 
354
  k = max(3, int(0.01 * min(g.shape)))
355
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
356
  th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
357
 
358
  contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
359
- if not contours: return gray_u8
 
360
 
361
  c = max(contours, key=cv2.contourArea)
362
  x, y, w, h = cv2.boundingRect(c)
363
  H, W = gray_u8.shape
364
- if w * h < 0.20 * (H * W): return gray_u8
 
365
 
366
  pad = int(0.03 * min(H, W))
367
- x1 = max(0, x - pad); y1 = max(0, y - pad)
368
- x2 = min(W, x + w + pad); y2 = min(H, y + h + pad)
 
 
369
  return gray_u8[y1:y2, x1:x2]
370
 
371
 
@@ -493,8 +544,6 @@ def recommendation_for_band(band: Optional[str]) -> str:
493
  # CONSENSUS LOGIC (TBNet vs RADIO)
494
  # ============================================================
495
  def tbnet_state(tb_prob: float, tb_band: str) -> str:
496
- # IMPORTANT UX FIX:
497
- # If TBNet band is YELLOW, treat it as INDET (not LOW).
498
  if tb_band == "RED":
499
  return "TB+"
500
  if tb_band == "YELLOW":
@@ -523,8 +572,6 @@ def build_consensus(
523
  Returns:
524
  consensus_label, consensus_detail, tb_state, radio_state
525
  """
526
-
527
- # TBNet unavailable => consensus is N/A (fail-safe)
528
  if tb_prob is None or tb_band is None:
529
  return (
530
  "N/A",
@@ -533,7 +580,6 @@ def build_consensus(
533
  "N/A",
534
  )
535
 
536
- # Choose RADIO primary score
537
  if radio_masked is not None:
538
  radio_primary = radio_masked
539
  radio_used = "MASKED"
@@ -544,7 +590,6 @@ def build_consensus(
544
  tb_state = tbnet_state(tb_prob, tb_band)
545
 
546
  if radio_primary is None:
547
- # RADIO disabled or errored; TBNet-only
548
  return (
549
  "TBNet only",
550
  f"{MODEL_NAME_RADIO} unavailable → TBNet state={tb_state}, p={tb_prob:.4f} (band={tb_band}).",
@@ -555,7 +600,6 @@ def build_consensus(
555
  radio_state = radio_state_from_prob(radio_primary)
556
  rb = f" (RADIO band={radio_band})" if radio_band else ""
557
 
558
- # If either is INDET, handle explicitly
559
  if tb_state == "INDET" and radio_state == "INDET":
560
  return (
561
  "AGREE: INDET",
@@ -580,7 +624,6 @@ def build_consensus(
580
  radio_state,
581
  )
582
 
583
- # Perfect agreement
584
  if tb_state == radio_state:
585
  return (
586
  f"AGREE: {tb_state}",
@@ -589,7 +632,6 @@ def build_consensus(
589
  radio_state,
590
  )
591
 
592
- # Strong disagreement when one is LOW and other is SCREEN+/TB+
593
  if (tb_state in ("SCREEN+", "TB+") and radio_state == "LOW") or (radio_state in ("SCREEN+", "TB+") and tb_state == "LOW"):
594
  return (
595
  "DISAGREE",
@@ -598,7 +640,6 @@ def build_consensus(
598
  radio_state,
599
  )
600
 
601
- # Otherwise mixed (e.g., SCREEN+ vs TB+)
602
  return (
603
  "MIXED",
604
  f"Mixed: TBNet={tb_state} (band={tb_band}, p={tb_prob:.4f}) vs RADIO={radio_state} ({radio_used})={radio_primary:.4f}{rb}.",
@@ -624,8 +665,10 @@ class ModelBundle:
624
  transforms.ToPILImage(),
625
  transforms.Resize((224, 224)),
626
  transforms.ToTensor(),
627
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
628
- std=[0.229, 0.224, 0.225]),
 
 
629
  ])
630
 
631
  def load(self, tb_weights: str, lung_weights: str, backbone: str = "efficientnet_b0"):
@@ -740,11 +783,13 @@ def radio_overlay_heatmap(rgb_u8: np.ndarray, heatmap01: np.ndarray, alpha: floa
740
 
741
 
742
  @torch.inference_mode()
743
- def radio_predict_from_arrays(gray_vis_u8: np.ndarray,
744
- lung_mask_u8: np.ndarray,
745
- coverage: float,
746
- device: torch.device,
747
- gate_threshold: float) -> Dict[str, Any]:
 
 
748
  RADIO_BUNDLE.load(device=device)
749
  dtype = torch.float16 if device.type == "cuda" else torch.float32
750
 
@@ -835,7 +880,6 @@ def analyze_one_image(
835
  fail_cov: float = FAIL_COV,
836
  warn_cov: float = WARN_COV,
837
  ) -> Dict[str, Any]:
838
-
839
  BUNDLE.load(tb_weights, lung_weights, backbone)
840
  device = BUNDLE.device
841
 
@@ -852,8 +896,8 @@ def analyze_one_image(
852
  mask256 = torch.sigmoid(mask_logits)[0, 0].cpu().numpy()
853
 
854
  mask256_bin = (mask256 > 0.5).astype(np.uint8)
855
-
856
  mask256_bin = keep_top_k_components(mask256_bin, k=2)
 
857
  k = max(3, int(0.02 * 256))
858
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
859
  mask256_bin = cv2.morphologyEx(mask256_bin, cv2.MORPH_CLOSE, kernel, iterations=1)
@@ -951,10 +995,14 @@ def analyze_one_image(
951
  cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (0, 0, 0), 1)
952
 
953
  warnings = []
954
- if phone_mode: warnings.append("Phone/WhatsApp mode enabled; artifacts possible.")
955
- if q_score < 55: warnings.append("Suboptimal image quality limits AI reliability.")
956
- if coverage < warn_cov: warnings.append(f"Partial lung segmentation ({coverage*100:.1f}% coverage).")
957
- if diffuse: warnings.append("Diffuse, non-focal AI attention pattern; TB-specific features not identified.")
 
 
 
 
958
  warnings.extend(q_warn)
959
 
960
  return {
@@ -1005,7 +1053,7 @@ def run_analysis(
1005
  gallery_items = []
1006
  details_md: List[str] = []
1007
 
1008
- # Top banner (legend / how to read)
1009
  summary_md.append(f"""
1010
  <div style="border:1px solid rgba(255,255,255,0.10); border-radius:14px; padding:12px; margin:10px 0;">
1011
  <div style="font-size:18px; font-weight:900;">Results</div>
@@ -1041,9 +1089,7 @@ def run_analysis(
1041
  img_size=224,
1042
  )
1043
 
1044
- # ------------------------
1045
  # RADIO (optional)
1046
- # ------------------------
1047
  radio_text_long = f"{MODEL_NAME_RADIO} disabled."
1048
  radio_raw_overlay = None
1049
  radio_masked_overlay = None
@@ -1091,9 +1137,7 @@ def run_analysis(
1091
  radio_band = None
1092
  radio_masked_ran = False
1093
 
1094
- # ------------------------
1095
- # Consensus (returns also states)
1096
- # ------------------------
1097
  consensus_label, consensus_detail, tb_state, radio_state = build_consensus(
1098
  tb_prob=out["prob"],
1099
  tb_band=out["band"],
@@ -1102,27 +1146,22 @@ def run_analysis(
1102
  radio_band=radio_band,
1103
  )
1104
 
1105
- # ------------------------
1106
- # Build descriptive cards (TBNet, RADIO, Consensus)
1107
- # ------------------------
1108
  tb_prob_line = "N/A (fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
1109
- tb_band = out.get("band", "YELLOW")
1110
  tb_label = out.get("pred", "INDETERMINATE")
1111
  q = float(out.get("quality_score", 0.0))
1112
  cov = float(out.get("lung_coverage", 0.0))
 
1113
 
1114
  warns = out.get("warnings", [])
1115
  top_warns = warns[:3] if warns else []
1116
  top_warn_line = " • ".join([html_escape(w) for w in top_warns]) if top_warns else "None"
1117
 
1118
- attention = "Diffuse / non-focal" if out.get("diffuse_risk", False) else "Focal / localized"
1119
-
1120
  radio_primary_line = "N/A" if radio_primary_val is None else f"{radio_primary_val:.4f}"
1121
  radio_raw_line = "N/A" if radio_raw_val is None else f"{radio_raw_val:.4f}"
1122
  radio_masked_line = "Not run" if radio_masked_val is None else f"{radio_masked_val:.4f}"
1123
 
1124
  if consensus_label == "DISAGREE":
1125
- next_step = "✅ Next step: Treat as **indeterminate** → radiologist review + microbiology if clinically suspected."
1126
  elif consensus_label in ("AGREE: TB+", "AGREE: SCREEN+"):
1127
  next_step = "✅ Next step: Prompt clinician/radiologist review; consider microbiological confirmation if clinically suspected."
1128
  elif consensus_label in ("AGREE: LOW",):
@@ -1172,9 +1211,6 @@ def run_analysis(
1172
  <b>Scores:</b> PRIMARY <b>{radio_primary_line}</b> &nbsp; | &nbsp; RAW <b>{radio_raw_line}</b> &nbsp; | &nbsp; MASKED <b>{radio_masked_line}</b>
1173
  </div>
1174
  <div style="margin-top:6px; opacity:0.85;">{html_escape(gate_info)}</div>
1175
- <div style="margin-top:10px; padding:10px 12px; border-left:6px solid rgba(52,211,153,0.9); background: rgba(52,211,153,0.10); border-radius:12px;">
1176
- ✅ Recommendation: Use RADIO as decision-support only; correlate clinically and consider radiologist review.
1177
- </div>
1178
  </div>
1179
  """
1180
 
@@ -1196,9 +1232,7 @@ def run_analysis(
1196
  summary_md.append(radio_card)
1197
  summary_md.append(consensus_card)
1198
 
1199
- # ------------------------
1200
- # Visual outputs (gallery)
1201
- # ------------------------
1202
  orig_rgb = cv2.cvtColor(cv2.resize(out["orig_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
1203
  vis_rgb = cv2.cvtColor(cv2.resize(out["vis_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
1204
  mask_overlay = cv2.resize(out["mask_overlay"], (512, 512))
@@ -1217,49 +1251,36 @@ def run_analysis(
1217
  if radio_masked_overlay is not None:
1218
  gallery_items.append((cv2.resize(radio_masked_overlay, (512, 512)), f"{name} • RADIO MASKED heatmap"))
1219
 
1220
- # ------------------------
1221
- # Detailed collapsible report (per image)
1222
- # ------------------------
1223
  warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
1224
-
1225
  details_md.append(
1226
  f"""
1227
- <details style="margin: 10px 0;">
1228
- <summary style="cursor:pointer; font-weight:900; font-size:16px;">
1229
- {html_escape(name)} — Details (expand)
1230
- </summary>
1231
-
1232
- <div style="margin-top:10px;">
1233
-
1234
- **{MODEL_NAME_TBNET}**
1235
- - State: **{pretty_state(tb_state)}**
1236
- - Result: **{out['pred']}**
1237
- - Probability: **{tb_prob_line}**
1238
- - Interpretation:
1239
- {out['band_text']}
1240
-
1241
- **Reliability**
1242
- - Image quality: **{out['quality_score']:.0f}/100**
1243
- - Lung mask coverage: **{out.get('lung_coverage', 0.0) * 100:.1f}%**
1244
- - Attention pattern: **{"Diffuse / non-focal" if out["diffuse_risk"] else "Focal / localized"}**
1245
-
1246
- **Notes that may affect reliability**
1247
- {warn_txt}
1248
 
1249
- **{MODEL_NAME_RADIO}**
1250
- {radio_text_long}
 
 
 
 
 
1251
 
1252
  **Consensus**
1253
- - TBNet: **{pretty_state(tb_state)}**
1254
- - RADIO: **{pretty_state(radio_state)}**
1255
- - Label: **{consensus_label}**
1256
- - Rationale: {consensus_detail}
1257
 
1258
- **Clinical guidance**
1259
- {CLINICAL_GUIDANCE}
 
 
 
1260
 
1261
- </div>
1262
  </details>
 
 
1263
  """
1264
  )
1265
 
@@ -1285,71 +1306,7 @@ def build_ui():
1285
  # ---------------------------
1286
  with gr.Column(visible=True) as welcome_screen:
1287
  gr.Markdown('<div class="title">Welcome — TB X-ray Assistant (HF Spaces)</div>')
1288
-
1289
- gr.Markdown(
1290
- f"""
1291
- <style>
1292
- .wrap {{ max-width: 980px; margin: 0 auto; }}
1293
- .hero {{
1294
- padding: 14px 16px;
1295
- border-radius: 14px;
1296
- background: rgba(255,255,255,0.04);
1297
- border: 1px solid rgba(255,255,255,0.10);
1298
- }}
1299
- .hero h2 {{ margin: 0 0 6px 0; font-size: 18px; font-weight: 900; }}
1300
- .hero p {{ margin: 0; opacity: 0.9; font-size: 13px; line-height: 1.35; }}
1301
- .chips {{ margin-top: 10px; display: flex; flex-wrap: wrap; gap: 8px; }}
1302
- .chip {{
1303
- font-size: 12px; padding: 6px 10px; border-radius: 999px;
1304
- background: rgba(255,255,255,0.06);
1305
- border: 1px solid rgba(255,255,255,0.10);
1306
- opacity: 0.95;
1307
- }}
1308
- .warn {{
1309
- margin-top: 12px;
1310
- padding: 10px 12px;
1311
- border-left: 5px solid #f59e0b;
1312
- border-radius: 12px;
1313
- background: rgba(245,158,11,0.10);
1314
- font-size: 12.5px;
1315
- line-height: 1.35;
1316
- }}
1317
- .tip {{
1318
- margin-top: 10px;
1319
- opacity: 0.85;
1320
- font-size: 12.5px;
1321
- }}
1322
- .footer {{ margin-top: 10px; opacity: 0.65; font-size: 12px; }}
1323
- </style>
1324
-
1325
- <div class="wrap">
1326
- <div class="hero">
1327
- <h2>TB X-ray screening assistant (research / decision support)</h2>
1328
- <p>Upload chest X-rays to get an AI screening score, heatmaps, and a simple consensus output.</p>
1329
-
1330
- <div class="chips">
1331
- <span class="chip"><b>{MODEL_NAME_TBNET}</b> + Grad-CAM</span>
1332
- <span class="chip">Auto Lung Mask + fail-safe</span>
1333
- <span class="chip"><b>{MODEL_NAME_RADIO}</b> RAW/MASKED</span>
1334
- <span class="chip">Consensus: ✅ LOW · ⚠️ INDET · ⚠️ SCREEN+ · 🚩 TB+</span>
1335
- <span class="chip">Phone/WhatsApp Mode (safe preprocessing)</span>
1336
- </div>
1337
-
1338
- <div class="tip"><b>Tip:</b> Enable Phone/WhatsApp Mode for phone photos, WhatsApp-forwarded images, or screenshots with borders.</div>
1339
- </div>
1340
-
1341
- <div class="warn">
1342
- <b>Clinical disclaimer:</b> Not diagnostic. TB can be subtle (incl. miliary TB).
1343
- If TB is clinically suspected, pursue CBNAAT/GeneXpert/sputum and/or CT chest regardless of AI output.
1344
- </div>
1345
-
1346
- <div class="footer">
1347
- Device: <b>{DEVICE}</b> (FORCE_CPU={FORCE_CPU})
1348
- </div>
1349
- </div>
1350
- """
1351
- )
1352
-
1353
  continue_btn = gr.Button("Continue →", variant="primary")
1354
 
1355
  # ---------------------------
 
11
  # 3) Final consensus (comparison + next step)
12
  # - Adds collapsible detailed report per image
13
  # - Keeps gallery, adds legend, better labels
14
+ # - Fixes welcome screen rendering (uses gr.HTML + inline styles; no raw HTML shown)
15
  #
16
  # HF Spaces: use relative weight paths (edit below if needed)
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  import os
19
  import cv2
 
60
 
61
  # ---- Consensus logic thresholds ----
62
  TBNET_SCREEN_THR = 0.30
63
+ TBNET_MARGIN = 0.03 # kept for compatibility / future use
64
 
65
  RADIO_SCREEN_THR = RADIO_THR_SCREEN
66
+ RADIO_MARGIN = 0.02 # kept for compatibility / future use
67
 
68
  # ---- Mask fail-safes ----
69
  FAIL_COV = 0.10
 
89
  recommend **CBNAAT / GeneXpert**, sputum studies, and/or **CT chest** regardless of AI output.
90
  """
91
 
 
92
  REPORT_LABELS = {
93
  "GREEN": {
94
  "title": "LOW TB LIKELIHOOD / Pulmonary T.B not detected by A.I",
 
124
  )
125
 
126
 
127
+ # ============================================================
128
+ # WELCOME HTML (minimal + main features only)
129
+ # IMPORTANT: rendered with gr.HTML (not gr.Markdown)
130
+ # ============================================================
131
+ WELCOME_HTML = f"""
132
+ <div style="max-width:980px;margin:0 auto;">
133
+ <div style="padding:14px 16px;border-radius:14px;background:rgba(255,255,255,0.04);border:1px solid rgba(255,255,255,0.10);">
134
+ <div style="font-size:18px;font-weight:900;margin-bottom:6px;">
135
+ TB X-ray Assistant <span style="opacity:0.75;font-weight:700;font-size:13px;">(research / decision support)</span>
136
+ </div>
137
+
138
+ <div style="opacity:0.9;font-size:13px;line-height:1.35;">
139
+ Upload chest X-rays to get an AI screening score, heatmaps, and a simple consensus output.
140
+ </div>
141
+
142
+ <div style="margin-top:10px;display:flex;flex-wrap:wrap;gap:8px;">
143
+ <span style="font-size:12px;padding:6px 10px;border-radius:999px;background:rgba(255,255,255,0.06);border:1px solid rgba(255,255,255,0.10);">
144
+ <b>{MODEL_NAME_TBNET}</b> + Grad-CAM
145
+ </span>
146
+ <span style="font-size:12px;padding:6px 10px;border-radius:999px;background:rgba(255,255,255,0.06);border:1px solid rgba(255,255,255,0.10);">
147
+ Auto lung mask + fail-safe
148
+ </span>
149
+ <span style="font-size:12px;padding:6px 10px;border-radius:999px;background:rgba(255,255,255,0.06);border:1px solid rgba(255,255,255,0.10);">
150
+ <b>{MODEL_NAME_RADIO}</b> (optional)
151
+ </span>
152
+ <span style="font-size:12px;padding:6px 10px;border-radius:999px;background:rgba(255,255,255,0.06);border:1px solid rgba(255,255,255,0.10);">
153
+ Consensus: ✅ LOW · ⚠️ INDET · ⚠️ SCREEN+ · 🚩 TB+
154
+ </span>
155
+ <span style="font-size:12px;padding:6px 10px;border-radius:999px;background:rgba(255,255,255,0.06);border:1px solid rgba(255,255,255,0.10);">
156
+ Phone/WhatsApp Mode
157
+ </span>
158
+ </div>
159
+
160
+ <div style="margin-top:10px;opacity:0.85;font-size:12.5px;">
161
+ <b>Tip:</b> Turn on Phone/WhatsApp Mode for phone photos, WhatsApp-forwards, or screenshots with borders.
162
+ </div>
163
+ </div>
164
+
165
+ <div style="margin-top:12px;padding:10px 12px;border-left:5px solid #f59e0b;border-radius:12px;background:rgba(245,158,11,0.10);font-size:12.5px;line-height:1.35;">
166
+ <b>Clinical disclaimer:</b> Not diagnostic. If TB is suspected clinically, pursue CBNAAT/GeneXpert/sputum and/or CT chest regardless of AI output.
167
+ </div>
168
+ </div>
169
+ """
170
+
171
+
172
  # ============================================================
173
  # UX HELPERS
174
  # ============================================================
 
183
 
184
 
185
  def html_escape(s: str) -> str:
 
186
  return (s or "").replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
187
 
188
 
189
  def badge_color_for_state(state: str) -> str:
 
190
  if state == "TB+":
191
  return "rgba(239,68,68,0.18)" # red
192
  if state == "SCREEN+":
 
213
  nn.ReLU(inplace=True),
214
  )
215
 
216
+ def forward(self, x):
217
+ return self.net(x)
218
 
219
 
220
  class LungUNet(nn.Module):
 
258
  self.backbone = timm.create_model(backbone, pretrained=False, num_classes=0, global_pool="avg")
259
  self.fc = nn.Linear(self.backbone.num_features, 1)
260
 
261
+ def forward(self, x):
262
+ return self.fc(self.backbone(x)).view(-1)
263
 
264
 
265
  def load_tb_weights(model: nn.Module, ckpt_path: str, device: torch.device):
 
275
  target_layer.register_forward_hook(self._fwd)
276
  target_layer.register_full_backward_hook(self._bwd)
277
 
278
+ def _fwd(self, _, __, out):
279
+ self.activ = out
280
+
281
+ def _bwd(self, _, grad_in, grad_out):
282
+ self.grad = grad_out[0]
283
 
284
  def generate(self, x: torch.Tensor) -> Tuple[np.ndarray, float, float]:
285
  with torch.enable_grad():
 
342
  left = gray_u8[:, :b]
343
  right = gray_u8[:, -b:]
344
 
345
+ def frac_border(x):
346
+ return float(((x < 15) | (x > 240)).mean())
347
+
348
  return float(np.mean([frac_border(top), frac_border(bot), frac_border(left), frac_border(right)]))
349
 
350
 
 
364
 
365
  if likely_phone:
366
  if sharp < 40:
367
+ score -= 25
368
+ warnings.append("Blurry / motion blur detected (likely phone capture).")
369
  elif sharp < 80:
370
+ score -= 12
371
+ warnings.append("Slight blur detected.")
372
  else:
373
  if sharp < 30:
374
+ score -= 8
375
+ warnings.append("Low fine detail (possible downsampling).")
376
 
377
  if hi_clip > 0.05:
378
+ score -= 15
379
+ warnings.append("Overexposed highlights (washed-out areas).")
380
  if lo_clip > 0.10:
381
+ score -= 12
382
+ warnings.append("Underexposed shadows (very dark areas).")
383
 
384
  if border > 0.55:
385
+ score -= 18
386
+ warnings.append("Large border/margins detected (possible screenshot/phone framing).")
387
  elif border > 0.35:
388
+ score -= 10
389
+ warnings.append("Some border/margins detected.")
390
 
391
  return float(np.clip(score, 0, 100)), warnings
392
 
 
395
  g = gray_u8.copy()
396
  g_blur = cv2.GaussianBlur(g, (5, 5), 0)
397
  _, th = cv2.threshold(g_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
398
+ if th.mean() > 127:
399
+ th = 255 - th
400
 
401
  k = max(3, int(0.01 * min(g.shape)))
402
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
403
  th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
404
 
405
  contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
406
+ if not contours:
407
+ return gray_u8
408
 
409
  c = max(contours, key=cv2.contourArea)
410
  x, y, w, h = cv2.boundingRect(c)
411
  H, W = gray_u8.shape
412
+ if w * h < 0.20 * (H * W):
413
+ return gray_u8
414
 
415
  pad = int(0.03 * min(H, W))
416
+ x1 = max(0, x - pad)
417
+ y1 = max(0, y - pad)
418
+ x2 = min(W, x + w + pad)
419
+ y2 = min(H, y + h + pad)
420
  return gray_u8[y1:y2, x1:x2]
421
 
422
 
 
544
  # CONSENSUS LOGIC (TBNet vs RADIO)
545
  # ============================================================
546
  def tbnet_state(tb_prob: float, tb_band: str) -> str:
 
 
547
  if tb_band == "RED":
548
  return "TB+"
549
  if tb_band == "YELLOW":
 
572
  Returns:
573
  consensus_label, consensus_detail, tb_state, radio_state
574
  """
 
 
575
  if tb_prob is None or tb_band is None:
576
  return (
577
  "N/A",
 
580
  "N/A",
581
  )
582
 
 
583
  if radio_masked is not None:
584
  radio_primary = radio_masked
585
  radio_used = "MASKED"
 
590
  tb_state = tbnet_state(tb_prob, tb_band)
591
 
592
  if radio_primary is None:
 
593
  return (
594
  "TBNet only",
595
  f"{MODEL_NAME_RADIO} unavailable → TBNet state={tb_state}, p={tb_prob:.4f} (band={tb_band}).",
 
600
  radio_state = radio_state_from_prob(radio_primary)
601
  rb = f" (RADIO band={radio_band})" if radio_band else ""
602
 
 
603
  if tb_state == "INDET" and radio_state == "INDET":
604
  return (
605
  "AGREE: INDET",
 
624
  radio_state,
625
  )
626
 
 
627
  if tb_state == radio_state:
628
  return (
629
  f"AGREE: {tb_state}",
 
632
  radio_state,
633
  )
634
 
 
635
  if (tb_state in ("SCREEN+", "TB+") and radio_state == "LOW") or (radio_state in ("SCREEN+", "TB+") and tb_state == "LOW"):
636
  return (
637
  "DISAGREE",
 
640
  radio_state,
641
  )
642
 
 
643
  return (
644
  "MIXED",
645
  f"Mixed: TBNet={tb_state} (band={tb_band}, p={tb_prob:.4f}) vs RADIO={radio_state} ({radio_used})={radio_primary:.4f}{rb}.",
 
665
  transforms.ToPILImage(),
666
  transforms.Resize((224, 224)),
667
  transforms.ToTensor(),
668
+ transforms.Normalize(
669
+ mean=[0.485, 0.456, 0.406],
670
+ std=[0.229, 0.224, 0.225]
671
+ ),
672
  ])
673
 
674
  def load(self, tb_weights: str, lung_weights: str, backbone: str = "efficientnet_b0"):
 
783
 
784
 
785
  @torch.inference_mode()
786
+ def radio_predict_from_arrays(
787
+ gray_vis_u8: np.ndarray,
788
+ lung_mask_u8: np.ndarray,
789
+ coverage: float,
790
+ device: torch.device,
791
+ gate_threshold: float
792
+ ) -> Dict[str, Any]:
793
  RADIO_BUNDLE.load(device=device)
794
  dtype = torch.float16 if device.type == "cuda" else torch.float32
795
 
 
880
  fail_cov: float = FAIL_COV,
881
  warn_cov: float = WARN_COV,
882
  ) -> Dict[str, Any]:
 
883
  BUNDLE.load(tb_weights, lung_weights, backbone)
884
  device = BUNDLE.device
885
 
 
896
  mask256 = torch.sigmoid(mask_logits)[0, 0].cpu().numpy()
897
 
898
  mask256_bin = (mask256 > 0.5).astype(np.uint8)
 
899
  mask256_bin = keep_top_k_components(mask256_bin, k=2)
900
+
901
  k = max(3, int(0.02 * 256))
902
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
903
  mask256_bin = cv2.morphologyEx(mask256_bin, cv2.MORPH_CLOSE, kernel, iterations=1)
 
995
  cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (0, 0, 0), 1)
996
 
997
  warnings = []
998
+ if phone_mode:
999
+ warnings.append("Phone/WhatsApp mode enabled; artifacts possible.")
1000
+ if q_score < 55:
1001
+ warnings.append("Suboptimal image quality limits AI reliability.")
1002
+ if coverage < warn_cov:
1003
+ warnings.append(f"Partial lung segmentation ({coverage*100:.1f}% coverage).")
1004
+ if diffuse:
1005
+ warnings.append("Diffuse, non-focal AI attention pattern; TB-specific features not identified.")
1006
  warnings.extend(q_warn)
1007
 
1008
  return {
 
1053
  gallery_items = []
1054
  details_md: List[str] = []
1055
 
1056
+ # Top banner
1057
  summary_md.append(f"""
1058
  <div style="border:1px solid rgba(255,255,255,0.10); border-radius:14px; padding:12px; margin:10px 0;">
1059
  <div style="font-size:18px; font-weight:900;">Results</div>
 
1089
  img_size=224,
1090
  )
1091
 
 
1092
  # RADIO (optional)
 
1093
  radio_text_long = f"{MODEL_NAME_RADIO} disabled."
1094
  radio_raw_overlay = None
1095
  radio_masked_overlay = None
 
1137
  radio_band = None
1138
  radio_masked_ran = False
1139
 
1140
+ # Consensus
 
 
1141
  consensus_label, consensus_detail, tb_state, radio_state = build_consensus(
1142
  tb_prob=out["prob"],
1143
  tb_band=out["band"],
 
1146
  radio_band=radio_band,
1147
  )
1148
 
 
 
 
1149
  tb_prob_line = "N/A (fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
 
1150
  tb_label = out.get("pred", "INDETERMINATE")
1151
  q = float(out.get("quality_score", 0.0))
1152
  cov = float(out.get("lung_coverage", 0.0))
1153
+ attention = "Diffuse / non-focal" if out.get("diffuse_risk", False) else "Focal / localized"
1154
 
1155
  warns = out.get("warnings", [])
1156
  top_warns = warns[:3] if warns else []
1157
  top_warn_line = " • ".join([html_escape(w) for w in top_warns]) if top_warns else "None"
1158
 
 
 
1159
  radio_primary_line = "N/A" if radio_primary_val is None else f"{radio_primary_val:.4f}"
1160
  radio_raw_line = "N/A" if radio_raw_val is None else f"{radio_raw_val:.4f}"
1161
  radio_masked_line = "Not run" if radio_masked_val is None else f"{radio_masked_val:.4f}"
1162
 
1163
  if consensus_label == "DISAGREE":
1164
+ next_step = "✅ Next step: Treat as <b>indeterminate</b> → radiologist review + microbiology if clinically suspected."
1165
  elif consensus_label in ("AGREE: TB+", "AGREE: SCREEN+"):
1166
  next_step = "✅ Next step: Prompt clinician/radiologist review; consider microbiological confirmation if clinically suspected."
1167
  elif consensus_label in ("AGREE: LOW",):
 
1211
  <b>Scores:</b> PRIMARY <b>{radio_primary_line}</b> &nbsp; | &nbsp; RAW <b>{radio_raw_line}</b> &nbsp; | &nbsp; MASKED <b>{radio_masked_line}</b>
1212
  </div>
1213
  <div style="margin-top:6px; opacity:0.85;">{html_escape(gate_info)}</div>
 
 
 
1214
  </div>
1215
  """
1216
 
 
1232
  summary_md.append(radio_card)
1233
  summary_md.append(consensus_card)
1234
 
1235
+ # Gallery
 
 
1236
  orig_rgb = cv2.cvtColor(cv2.resize(out["orig_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
1237
  vis_rgb = cv2.cvtColor(cv2.resize(out["vis_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
1238
  mask_overlay = cv2.resize(out["mask_overlay"], (512, 512))
 
1251
  if radio_masked_overlay is not None:
1252
  gallery_items.append((cv2.resize(radio_masked_overlay, (512, 512)), f"{name} • RADIO MASKED heatmap"))
1253
 
1254
+ # Details (collapsible per image) — FIXED (no welcome HTML here)
 
 
1255
  warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
 
1256
  details_md.append(
1257
  f"""
1258
+ <details>
1259
+ <summary><b>{html_escape(name)}</b> detailed report</summary>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1260
 
1261
+ **TBNet**
1262
+ - Result: **{html_escape(tb_label)}**
1263
+ - Probability: {tb_prob_line}
1264
+ - Band: {out.get("band", "YELLOW")}
1265
+ - Quality: {q:.0f}/100
1266
+ - Lung mask coverage: {cov*100:.1f}%
1267
+ - Attention: {attention}
1268
 
1269
  **Consensus**
1270
+ - TBNet state: {pretty_state(tb_state)}
1271
+ - RADIO state: {pretty_state(radio_state)}
1272
+ - Consensus label: **{html_escape(consensus_label)}**
1273
+ - Detail: {html_escape(consensus_detail)}
1274
 
1275
+ **Warnings**
1276
+ {warn_txt}
1277
+
1278
+ **RADIO (full)**
1279
+ {radio_text_long}
1280
 
 
1281
  </details>
1282
+
1283
+ ---
1284
  """
1285
  )
1286
 
 
1306
  # ---------------------------
1307
  with gr.Column(visible=True) as welcome_screen:
1308
  gr.Markdown('<div class="title">Welcome — TB X-ray Assistant (HF Spaces)</div>')
1309
+ gr.HTML(WELCOME_HTML)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1310
  continue_btn = gr.Button("Continue →", variant="primary")
1311
 
1312
  # ---------------------------