drvikasgaur commited on
Commit
c88beff
·
verified ·
1 Parent(s): 01f4f67

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +215 -85
app.py CHANGED
@@ -136,7 +136,7 @@ CLINICAL_GUIDANCE = (
136
 
137
 
138
  # ============================================================
139
- # UX HELPERS (NEW)
140
  # ============================================================
141
  def pretty_state(s: str) -> str:
142
  return {
@@ -147,10 +147,12 @@ def pretty_state(s: str) -> str:
147
  "N/A": "⚠️ N/A",
148
  }.get(s, f"⚠️ {s}")
149
 
 
150
  def html_escape(s: str) -> str:
151
  # minimal escaping for safety in HTML blocks
152
  return (s or "").replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
153
 
 
154
  def badge_color_for_state(state: str) -> str:
155
  # soft visual cue; works in both dark/light
156
  if state == "TB+":
@@ -181,6 +183,7 @@ class DoubleConv(nn.Module):
181
 
182
  def forward(self, x): return self.net(x)
183
 
 
184
  class LungUNet(nn.Module):
185
  def __init__(self):
186
  super().__init__()
@@ -224,10 +227,12 @@ class TBNet(nn.Module):
224
 
225
  def forward(self, x): return self.fc(self.backbone(x)).view(-1)
226
 
 
227
  def load_tb_weights(model: nn.Module, ckpt_path: str, device: torch.device):
228
  sd = torch.load(ckpt_path, map_location=device)
229
  model.load_state_dict(sd, strict=True)
230
 
 
231
  class GradCAM:
232
  def __init__(self, model: nn.Module, target_layer: nn.Module):
233
  self.model = model
@@ -271,6 +276,7 @@ def preprocess_for_lung_unet(gray_u8: np.ndarray) -> torch.Tensor:
271
  g = (g - lo) / (hi - lo + 1e-8)
272
  return torch.from_numpy(g).unsqueeze(0).unsqueeze(0).float()
273
 
 
274
  def tb_training_preprocess(gray_u8: np.ndarray) -> np.ndarray:
275
  gray = gray_u8.astype(np.float32)
276
  lo, hi = np.percentile(gray, (1, 99))
@@ -278,16 +284,19 @@ def tb_training_preprocess(gray_u8: np.ndarray) -> np.ndarray:
278
  gray = (gray - lo) / (hi - lo + 1e-8)
279
  return gray
280
 
 
281
  def laplacian_sharpness(gray_u8: np.ndarray) -> float:
282
  g = cv2.resize(gray_u8, (512, 512), interpolation=cv2.INTER_AREA)
283
  g = cv2.GaussianBlur(g, (3, 3), 0)
284
  return float(cv2.Laplacian(g, cv2.CV_64F).var())
285
 
 
286
  def exposure_scores(gray_u8: np.ndarray) -> Tuple[float, float]:
287
  lo = float((gray_u8 < 10).mean())
288
  hi = float((gray_u8 > 245).mean())
289
  return lo, hi
290
 
 
291
  def border_fraction(gray_u8: np.ndarray) -> float:
292
  h, w = gray_u8.shape
293
  b = max(5, int(0.06 * min(h, w)))
@@ -299,6 +308,7 @@ def border_fraction(gray_u8: np.ndarray) -> float:
299
  def frac_border(x): return float(((x < 15) | (x > 240)).mean())
300
  return float(np.mean([frac_border(top), frac_border(bot), frac_border(left), frac_border(right)]))
301
 
 
302
  def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
303
  warnings: List[str] = []
304
  h, w = gray_u8.shape
@@ -334,6 +344,7 @@ def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
334
 
335
  return float(np.clip(score, 0, 100)), warnings
336
 
 
337
  def auto_border_crop(gray_u8: np.ndarray) -> np.ndarray:
338
  g = gray_u8.copy()
339
  g_blur = cv2.GaussianBlur(g, (5, 5), 0)
@@ -357,10 +368,12 @@ def auto_border_crop(gray_u8: np.ndarray) -> np.ndarray:
357
  x2 = min(W, x + w + pad); y2 = min(H, y + h + pad)
358
  return gray_u8[y1:y2, x1:x2]
359
 
 
360
  def apply_clahe(gray_u8: np.ndarray) -> np.ndarray:
361
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
362
  return clahe.apply(gray_u8)
363
 
 
364
  def phone_preprocess(gray_u8: np.ndarray) -> np.ndarray:
365
  sharp = laplacian_sharpness(gray_u8)
366
  lo_clip, _hi_clip = exposure_scores(gray_u8)
@@ -377,11 +390,13 @@ def phone_preprocess(gray_u8: np.ndarray) -> np.ndarray:
377
 
378
  return g
379
 
 
380
  def cam_entropy(cam: np.ndarray) -> float:
381
  cam = cam.astype(np.float32)
382
  cam = cam / (cam.sum() + 1e-8)
383
  return float(-np.sum(cam * np.log(cam + 1e-8)))
384
 
 
385
  def detect_diffuse_risk(prob_tb: float, cam_up: np.ndarray, quality_score: float) -> bool:
386
  if quality_score < 55:
387
  return False
@@ -390,6 +405,7 @@ def detect_diffuse_risk(prob_tb: float, cam_up: np.ndarray, quality_score: float
390
  ent = cam_entropy(cam_up)
391
  return (prob_tb < TBNET_SCREEN_THR) and (ent > 6.5)
392
 
 
393
  def confidence_band(prob_tb: float, quality_score: float, diffuse: bool):
394
  if prob_tb < 0.01 and quality_score >= 45:
395
  return ("GREEN", "✅ Very low TB signal detected.")
@@ -401,21 +417,24 @@ def confidence_band(prob_tb: float, quality_score: float, diffuse: bool):
401
  return ("YELLOW", "⚠️ Screening-positive range; review recommended.")
402
  return ("GREEN", "✅ No strong TB signal detected.")
403
 
 
404
  def make_mask_overlay(gray_u8: np.ndarray, mask_u8: np.ndarray) -> np.ndarray:
405
  base = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB)
406
  mask_color = cv2.applyColorMap((mask_u8 * 255).astype(np.uint8), cv2.COLORMAP_JET)
407
  return cv2.addWeighted(base, 0.75, mask_color, 0.25, 0)
408
 
 
409
  def fill_holes(binary_u8: np.ndarray) -> np.ndarray:
410
  m = (binary_u8 * 255).astype(np.uint8)
411
  h, w = m.shape
412
  flood = m.copy()
413
- mask = np.zeros((h+2, w+2), np.uint8)
414
  cv2.floodFill(flood, mask, (0, 0), 255)
415
  holes = cv2.bitwise_not(flood)
416
  filled = cv2.bitwise_or(m, holes)
417
  return (filled > 0).astype(np.uint8)
418
 
 
419
  def keep_top_k_components(binary_u8: np.ndarray, k: int = 2) -> np.ndarray:
420
  m = (binary_u8 > 0).astype(np.uint8)
421
  n, labels = cv2.connectedComponents(m)
@@ -431,6 +450,7 @@ def keep_top_k_components(binary_u8: np.ndarray, k: int = 2) -> np.ndarray:
431
  out[labels == i] = 1
432
  return out
433
 
 
434
  def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
435
  m = (mask_full_u8 > 0).astype(np.uint8)
436
  n, labels = cv2.connectedComponents(m)
@@ -460,6 +480,7 @@ def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
460
 
461
  return warns
462
 
 
463
  def recommendation_for_band(band: Optional[str]) -> str:
464
  if band in (None, "YELLOW"):
465
  return "✅ Recommendation: Radiologist/clinician review is recommended (**indeterminate**)."
@@ -482,6 +503,7 @@ def tbnet_state(tb_prob: float, tb_band: str) -> str:
482
  return "SCREEN+"
483
  return "LOW"
484
 
 
485
  def radio_state_from_prob(radio_prob: float) -> str:
486
  if radio_prob >= RADIO_THR_RED:
487
  return "TB+"
@@ -489,6 +511,7 @@ def radio_state_from_prob(radio_prob: float) -> str:
489
  return "SCREEN+"
490
  return "LOW"
491
 
 
492
  def build_consensus(
493
  tb_prob: Optional[float],
494
  tb_band: Optional[str],
@@ -622,6 +645,7 @@ class ModelBundle:
622
  self.lung = lung
623
  self.lung_path = lung_weights
624
 
 
625
  BUNDLE = ModelBundle()
626
 
627
 
@@ -643,6 +667,7 @@ class RadioMLPHead(nn.Module):
643
  def forward(self, x: torch.Tensor) -> torch.Tensor:
644
  return self.net(x).squeeze(1)
645
 
 
646
  class RadioBundle:
647
  def __init__(self):
648
  self.loaded = False
@@ -691,8 +716,10 @@ class RadioBundle:
691
  self.device_str = dev_str
692
  self.loaded = True
693
 
 
694
  RADIO_BUNDLE = RadioBundle()
695
 
 
696
  def radio_heatmap_from_spatial(spatial_tokens: torch.Tensor, in_h: int, in_w: int, patch_size: int = 16) -> np.ndarray:
697
  ht = in_h // patch_size
698
  wt = in_w // patch_size
@@ -703,6 +730,7 @@ def radio_heatmap_from_spatial(spatial_tokens: torch.Tensor, in_h: int, in_w: in
703
  hm_img = Image.fromarray((hm * 255).astype(np.uint8)).resize((in_w, in_h), resample=Image.BILINEAR)
704
  return np.array(hm_img, dtype=np.float32) / 255.0
705
 
 
706
  def radio_overlay_heatmap(rgb_u8: np.ndarray, heatmap01: np.ndarray, alpha: float = 0.35) -> np.ndarray:
707
  img = rgb_u8.astype(np.float32) / 255.0
708
  hm = np.clip(heatmap01, 0, 1).astype(np.float32)
@@ -710,6 +738,7 @@ def radio_overlay_heatmap(rgb_u8: np.ndarray, heatmap01: np.ndarray, alpha: floa
710
  out[..., 0] = np.clip(out[..., 0] * (1 - alpha) + hm * alpha, 0, 1)
711
  return (out * 255).astype(np.uint8)
712
 
 
713
  @torch.inference_mode()
714
  def radio_predict_from_arrays(gray_vis_u8: np.ndarray,
715
  lung_mask_u8: np.ndarray,
@@ -1076,7 +1105,6 @@ def run_analysis(
1076
  # ------------------------
1077
  # Build descriptive cards (TBNet, RADIO, Consensus)
1078
  # ------------------------
1079
- # TBNet block values
1080
  tb_prob_line = "N/A (fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
1081
  tb_band = out.get("band", "YELLOW")
1082
  tb_label = out.get("pred", "INDETERMINATE")
@@ -1089,12 +1117,10 @@ def run_analysis(
1089
 
1090
  attention = "Diffuse / non-focal" if out.get("diffuse_risk", False) else "Focal / localized"
1091
 
1092
- # RADIO block values
1093
  radio_primary_line = "N/A" if radio_primary_val is None else f"{radio_primary_val:.4f}"
1094
  radio_raw_line = "N/A" if radio_raw_val is None else f"{radio_raw_val:.4f}"
1095
  radio_masked_line = "Not run" if radio_masked_val is None else f"{radio_masked_val:.4f}"
1096
 
1097
- # Recommended next step (consensus-aware)
1098
  if consensus_label == "DISAGREE":
1099
  next_step = "✅ Next step: Treat as **indeterminate** → radiologist review + microbiology if clinically suspected."
1100
  elif consensus_label in ("AGREE: TB+", "AGREE: SCREEN+"):
@@ -1104,7 +1130,6 @@ def run_analysis(
1104
  else:
1105
  next_step = "✅ Next step: Correlate clinically; radiologist review recommended if uncertainty or symptoms present."
1106
 
1107
- # Card containers (HTML)
1108
  state_badge_tb = f"""
1109
  <span style="padding:4px 10px; border-radius:999px; background:{badge_color_for_state(tb_state)}; font-weight:800;">
1110
  {pretty_state(tb_state)}
@@ -1196,7 +1221,6 @@ def run_analysis(
1196
  # Detailed collapsible report (per image)
1197
  # ------------------------
1198
  warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
1199
- rec_line = recommendation_for_band(out.get("band"))
1200
 
1201
  details_md.append(
1202
  f"""
@@ -1243,7 +1267,7 @@ def run_analysis(
1243
 
1244
 
1245
  # ============================================================
1246
- # UI
1247
  # ============================================================
1248
  def build_ui():
1249
  css = """
@@ -1251,101 +1275,205 @@ def build_ui():
1251
  .subtitle {font-size: 14px; opacity: 0.88; margin-bottom: 14px;}
1252
  .warnbox {border-left: 6px solid #f59e0b; padding: 10px 12px; background: rgba(245,158,11,0.08); border-radius: 10px;}
1253
  .legend {border-left: 6px solid rgba(148,163,184,0.7); padding: 10px 12px; background: rgba(148,163,184,0.08); border-radius: 10px;}
 
1254
  """
1255
 
1256
  with gr.Blocks(title="TB X-ray Assistant (TBNet + RADIO)", css=css) as demo:
1257
- gr.Markdown('<div class="title">TB X-ray Assistant (Auto Lung Mask • Research Use)</div>')
1258
- gr.Markdown(
1259
- f"<div class='subtitle'>Auto lung mask → <b>{MODEL_NAME_TBNET}</b> + Grad-CAM • "
1260
- f"Optional <b>{MODEL_NAME_RADIO}</b> (C-RADIOv4 + heads) • Clear per-model results + consensus</div>"
1261
- )
1262
-
1263
- # Disclaimer shown early (so users see it before running)
1264
- gr.Markdown(
1265
- "<div class='warnbox'><b>Clinical disclaimer:</b> Decision support only (not diagnostic). "
1266
- "If TB is clinically suspected, pursue microbiology / CT as appropriate regardless of AI output.</div>"
1267
- )
1268
 
1269
- with gr.Row():
1270
- with gr.Column(scale=1):
1271
- gr.Markdown("#### Model settings")
1272
-
1273
- tb_weights = gr.Textbox(label="TBNet weights (.pt)", value=DEFAULT_TB_WEIGHTS)
1274
- lung_weights = gr.Textbox(label="Lung U-Net weights (.pt)", value=DEFAULT_LUNG_WEIGHTS)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1275
 
1276
- backbone = gr.Dropdown(choices=["efficientnet_b0"], value="efficientnet_b0", label="TBNet backbone")
 
 
 
1277
 
1278
- threshold = gr.Slider(
1279
- 0.01, 0.99, value=TBNET_SCREEN_THR, step=0.01,
1280
- label=f"Reference threshold (TBNet screen+) = {TBNET_SCREEN_THR:.2f}"
1281
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1282
 
1283
- phone_mode = gr.Checkbox(
1284
- value=False,
1285
- label="Phone/WhatsApp Mode (safe: conditional crop + conditional CLAHE)"
1286
- )
 
 
 
 
 
1287
 
1288
- use_radio = gr.Checkbox(value=False, label=f"Enable {MODEL_NAME_RADIO}")
1289
- radio_gate = gr.Slider(
1290
- 0.10, 0.40, value=RADIO_GATE_DEFAULT, step=0.01,
1291
- label="RADIO masked gate (run masked head if lung coverage ≥ gate)"
1292
- )
 
 
 
 
 
1293
 
1294
- gr.Markdown(
1295
- "<div class='warnbox'><b>Fail-safe:</b> If lung segmentation is too small or looks unreliable, "
1296
- f"{MODEL_NAME_TBNET} scoring is disabled to avoid unsafe outputs.</div>"
1297
- )
 
 
 
1298
 
1299
- gr.Markdown(
1300
- f"<div class='subtitle'>Device: <b>{DEVICE}</b> (FORCE_CPU={FORCE_CPU})</div>"
1301
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
1302
 
1303
- with gr.Column(scale=2):
1304
- gr.Markdown("#### Upload images")
1305
- files = gr.Files(
1306
- label="Upload one or multiple X-ray images",
1307
- file_types=[".png", ".jpg", ".jpeg", ".bmp"]
1308
- )
1309
- run_btn = gr.Button("Run Analysis", variant="primary")
1310
- status = gr.Textbox(label="Status", value="Ready.", interactive=False)
1311
 
1312
- gr.Markdown("""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1313
  <div class='legend'><b>Gallery legend:</b><br/>
1314
  1) ORIGINAL &nbsp;•&nbsp; 2) INPUT / PHONE-PROC &nbsp;•&nbsp; 3) Lung mask overlay &nbsp;•&nbsp;
1315
  4) Masked model input &nbsp;•&nbsp; 5) TBNet Grad-CAM &nbsp;•&nbsp; 6) RADIO heatmaps</div>
1316
  """)
1317
 
1318
- # Summary cards (new)
1319
- gr.Markdown("#### Summary (per image)")
1320
- summary = gr.Markdown("Upload images and click <b>Run Analysis</b>.")
1321
-
1322
- with gr.Row():
1323
- gallery = gr.Gallery(
1324
- label="Visual outputs",
1325
- columns=3,
1326
- height=560
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1327
  )
1328
 
1329
- with gr.Row():
1330
- with gr.Column(scale=1):
1331
- disclaimer_box = gr.Markdown(CLINICAL_DISCLAIMER)
1332
- with gr.Column(scale=2):
1333
- gr.Markdown("#### Detailed report (expand per image)")
1334
- details = gr.Markdown("")
1335
-
1336
- run_btn.click(
1337
- fn=run_analysis,
1338
- inputs=[
1339
- files,
1340
- tb_weights,
1341
- lung_weights,
1342
- backbone,
1343
- threshold,
1344
- phone_mode,
1345
- use_radio,
1346
- radio_gate,
1347
- ],
1348
- outputs=[summary, gallery, details, disclaimer_box, status]
1349
  )
1350
 
1351
  return demo
@@ -1353,4 +1481,6 @@ def build_ui():
1353
 
1354
  if __name__ == "__main__":
1355
  demo = build_ui()
1356
- demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
 
 
136
 
137
 
138
  # ============================================================
139
+ # UX HELPERS
140
  # ============================================================
141
  def pretty_state(s: str) -> str:
142
  return {
 
147
  "N/A": "⚠️ N/A",
148
  }.get(s, f"⚠️ {s}")
149
 
150
+
151
  def html_escape(s: str) -> str:
152
  # minimal escaping for safety in HTML blocks
153
  return (s or "").replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
154
 
155
+
156
  def badge_color_for_state(state: str) -> str:
157
  # soft visual cue; works in both dark/light
158
  if state == "TB+":
 
183
 
184
  def forward(self, x): return self.net(x)
185
 
186
+
187
  class LungUNet(nn.Module):
188
  def __init__(self):
189
  super().__init__()
 
227
 
228
  def forward(self, x): return self.fc(self.backbone(x)).view(-1)
229
 
230
+
231
  def load_tb_weights(model: nn.Module, ckpt_path: str, device: torch.device):
232
  sd = torch.load(ckpt_path, map_location=device)
233
  model.load_state_dict(sd, strict=True)
234
 
235
+
236
  class GradCAM:
237
  def __init__(self, model: nn.Module, target_layer: nn.Module):
238
  self.model = model
 
276
  g = (g - lo) / (hi - lo + 1e-8)
277
  return torch.from_numpy(g).unsqueeze(0).unsqueeze(0).float()
278
 
279
+
280
  def tb_training_preprocess(gray_u8: np.ndarray) -> np.ndarray:
281
  gray = gray_u8.astype(np.float32)
282
  lo, hi = np.percentile(gray, (1, 99))
 
284
  gray = (gray - lo) / (hi - lo + 1e-8)
285
  return gray
286
 
287
+
288
  def laplacian_sharpness(gray_u8: np.ndarray) -> float:
289
  g = cv2.resize(gray_u8, (512, 512), interpolation=cv2.INTER_AREA)
290
  g = cv2.GaussianBlur(g, (3, 3), 0)
291
  return float(cv2.Laplacian(g, cv2.CV_64F).var())
292
 
293
+
294
  def exposure_scores(gray_u8: np.ndarray) -> Tuple[float, float]:
295
  lo = float((gray_u8 < 10).mean())
296
  hi = float((gray_u8 > 245).mean())
297
  return lo, hi
298
 
299
+
300
  def border_fraction(gray_u8: np.ndarray) -> float:
301
  h, w = gray_u8.shape
302
  b = max(5, int(0.06 * min(h, w)))
 
308
  def frac_border(x): return float(((x < 15) | (x > 240)).mean())
309
  return float(np.mean([frac_border(top), frac_border(bot), frac_border(left), frac_border(right)]))
310
 
311
+
312
  def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
313
  warnings: List[str] = []
314
  h, w = gray_u8.shape
 
344
 
345
  return float(np.clip(score, 0, 100)), warnings
346
 
347
+
348
  def auto_border_crop(gray_u8: np.ndarray) -> np.ndarray:
349
  g = gray_u8.copy()
350
  g_blur = cv2.GaussianBlur(g, (5, 5), 0)
 
368
  x2 = min(W, x + w + pad); y2 = min(H, y + h + pad)
369
  return gray_u8[y1:y2, x1:x2]
370
 
371
+
372
  def apply_clahe(gray_u8: np.ndarray) -> np.ndarray:
373
  clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
374
  return clahe.apply(gray_u8)
375
 
376
+
377
  def phone_preprocess(gray_u8: np.ndarray) -> np.ndarray:
378
  sharp = laplacian_sharpness(gray_u8)
379
  lo_clip, _hi_clip = exposure_scores(gray_u8)
 
390
 
391
  return g
392
 
393
+
394
  def cam_entropy(cam: np.ndarray) -> float:
395
  cam = cam.astype(np.float32)
396
  cam = cam / (cam.sum() + 1e-8)
397
  return float(-np.sum(cam * np.log(cam + 1e-8)))
398
 
399
+
400
  def detect_diffuse_risk(prob_tb: float, cam_up: np.ndarray, quality_score: float) -> bool:
401
  if quality_score < 55:
402
  return False
 
405
  ent = cam_entropy(cam_up)
406
  return (prob_tb < TBNET_SCREEN_THR) and (ent > 6.5)
407
 
408
+
409
  def confidence_band(prob_tb: float, quality_score: float, diffuse: bool):
410
  if prob_tb < 0.01 and quality_score >= 45:
411
  return ("GREEN", "✅ Very low TB signal detected.")
 
417
  return ("YELLOW", "⚠️ Screening-positive range; review recommended.")
418
  return ("GREEN", "✅ No strong TB signal detected.")
419
 
420
+
421
  def make_mask_overlay(gray_u8: np.ndarray, mask_u8: np.ndarray) -> np.ndarray:
422
  base = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB)
423
  mask_color = cv2.applyColorMap((mask_u8 * 255).astype(np.uint8), cv2.COLORMAP_JET)
424
  return cv2.addWeighted(base, 0.75, mask_color, 0.25, 0)
425
 
426
+
427
  def fill_holes(binary_u8: np.ndarray) -> np.ndarray:
428
  m = (binary_u8 * 255).astype(np.uint8)
429
  h, w = m.shape
430
  flood = m.copy()
431
+ mask = np.zeros((h + 2, w + 2), np.uint8)
432
  cv2.floodFill(flood, mask, (0, 0), 255)
433
  holes = cv2.bitwise_not(flood)
434
  filled = cv2.bitwise_or(m, holes)
435
  return (filled > 0).astype(np.uint8)
436
 
437
+
438
  def keep_top_k_components(binary_u8: np.ndarray, k: int = 2) -> np.ndarray:
439
  m = (binary_u8 > 0).astype(np.uint8)
440
  n, labels = cv2.connectedComponents(m)
 
450
  out[labels == i] = 1
451
  return out
452
 
453
+
454
  def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
455
  m = (mask_full_u8 > 0).astype(np.uint8)
456
  n, labels = cv2.connectedComponents(m)
 
480
 
481
  return warns
482
 
483
+
484
  def recommendation_for_band(band: Optional[str]) -> str:
485
  if band in (None, "YELLOW"):
486
  return "✅ Recommendation: Radiologist/clinician review is recommended (**indeterminate**)."
 
503
  return "SCREEN+"
504
  return "LOW"
505
 
506
+
507
  def radio_state_from_prob(radio_prob: float) -> str:
508
  if radio_prob >= RADIO_THR_RED:
509
  return "TB+"
 
511
  return "SCREEN+"
512
  return "LOW"
513
 
514
+
515
  def build_consensus(
516
  tb_prob: Optional[float],
517
  tb_band: Optional[str],
 
645
  self.lung = lung
646
  self.lung_path = lung_weights
647
 
648
+
649
  BUNDLE = ModelBundle()
650
 
651
 
 
667
  def forward(self, x: torch.Tensor) -> torch.Tensor:
668
  return self.net(x).squeeze(1)
669
 
670
+
671
  class RadioBundle:
672
  def __init__(self):
673
  self.loaded = False
 
716
  self.device_str = dev_str
717
  self.loaded = True
718
 
719
+
720
  RADIO_BUNDLE = RadioBundle()
721
 
722
+
723
  def radio_heatmap_from_spatial(spatial_tokens: torch.Tensor, in_h: int, in_w: int, patch_size: int = 16) -> np.ndarray:
724
  ht = in_h // patch_size
725
  wt = in_w // patch_size
 
730
  hm_img = Image.fromarray((hm * 255).astype(np.uint8)).resize((in_w, in_h), resample=Image.BILINEAR)
731
  return np.array(hm_img, dtype=np.float32) / 255.0
732
 
733
+
734
  def radio_overlay_heatmap(rgb_u8: np.ndarray, heatmap01: np.ndarray, alpha: float = 0.35) -> np.ndarray:
735
  img = rgb_u8.astype(np.float32) / 255.0
736
  hm = np.clip(heatmap01, 0, 1).astype(np.float32)
 
738
  out[..., 0] = np.clip(out[..., 0] * (1 - alpha) + hm * alpha, 0, 1)
739
  return (out * 255).astype(np.uint8)
740
 
741
+
742
  @torch.inference_mode()
743
  def radio_predict_from_arrays(gray_vis_u8: np.ndarray,
744
  lung_mask_u8: np.ndarray,
 
1105
  # ------------------------
1106
  # Build descriptive cards (TBNet, RADIO, Consensus)
1107
  # ------------------------
 
1108
  tb_prob_line = "N/A (fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
1109
  tb_band = out.get("band", "YELLOW")
1110
  tb_label = out.get("pred", "INDETERMINATE")
 
1117
 
1118
  attention = "Diffuse / non-focal" if out.get("diffuse_risk", False) else "Focal / localized"
1119
 
 
1120
  radio_primary_line = "N/A" if radio_primary_val is None else f"{radio_primary_val:.4f}"
1121
  radio_raw_line = "N/A" if radio_raw_val is None else f"{radio_raw_val:.4f}"
1122
  radio_masked_line = "Not run" if radio_masked_val is None else f"{radio_masked_val:.4f}"
1123
 
 
1124
  if consensus_label == "DISAGREE":
1125
  next_step = "✅ Next step: Treat as **indeterminate** → radiologist review + microbiology if clinically suspected."
1126
  elif consensus_label in ("AGREE: TB+", "AGREE: SCREEN+"):
 
1130
  else:
1131
  next_step = "✅ Next step: Correlate clinically; radiologist review recommended if uncertainty or symptoms present."
1132
 
 
1133
  state_badge_tb = f"""
1134
  <span style="padding:4px 10px; border-radius:999px; background:{badge_color_for_state(tb_state)}; font-weight:800;">
1135
  {pretty_state(tb_state)}
 
1221
  # Detailed collapsible report (per image)
1222
  # ------------------------
1223
  warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
 
1224
 
1225
  details_md.append(
1226
  f"""
 
1267
 
1268
 
1269
  # ============================================================
1270
+ # UI (HF Spaces Welcome Screen + Main App)
1271
  # ============================================================
1272
  def build_ui():
1273
  css = """
 
1275
  .subtitle {font-size: 14px; opacity: 0.88; margin-bottom: 14px;}
1276
  .warnbox {border-left: 6px solid #f59e0b; padding: 10px 12px; background: rgba(245,158,11,0.08); border-radius: 10px;}
1277
  .legend {border-left: 6px solid rgba(148,163,184,0.7); padding: 10px 12px; background: rgba(148,163,184,0.08); border-radius: 10px;}
1278
+ .card {border:1px solid rgba(255,255,255,0.12); border-radius:14px; padding:14px; margin:10px 0;}
1279
  """
1280
 
1281
  with gr.Blocks(title="TB X-ray Assistant (TBNet + RADIO)", css=css) as demo:
 
 
 
 
 
 
 
 
 
 
 
1282
 
1283
+ # ---------------------------
1284
+ # Welcome screen (shown first)
1285
+ # ---------------------------
1286
+ with gr.Column(visible=True) as welcome_screen:
1287
+ gr.Markdown('<div class="title">Welcome TB X-ray Assistant (HF Spaces)</div>')
1288
+
1289
+ gr.Markdown(
1290
+ f"""
1291
+ <div class="card">
1292
+ <div style="font-size:16px; font-weight:900; margin-bottom:8px;">What this Space does</div>
1293
+ <div style="opacity:0.92;">
1294
+ This app analyzes chest X-rays for <b>TB-like patterns</b> and shows results from:
1295
+ <ul>
1296
+ <li><b>{MODEL_NAME_TBNET}</b> with <b>Grad-CAM</b> explainability</li>
1297
+ <li><b>Auto Lung Mask</b> (Lung U-Net) + a <b>fail-safe</b> to prevent unsafe scoring on bad masks</li>
1298
+ <li>Optional <b>{MODEL_NAME_RADIO}</b> with <b>RAW / MASKED</b> scoring and heatmaps</li>
1299
+ <li>A final <b>consensus</b>: ✅ LOW / ⚠️ INDET / ⚠️ SCREEN+ / 🚩 TB+</li>
1300
+ </ul>
1301
+ </div>
1302
+ </div>
1303
 
1304
+ <div class="warnbox">
1305
+ <b>Clinical disclaimer:</b> Decision support only (not diagnostic). TB can be subtle (including miliary TB).
1306
+ If TB is clinically suspected, pursue microbiology (CBNAAT/GeneXpert, sputum) and/or CT chest regardless of AI output.
1307
+ </div>
1308
 
1309
+ <div class="card">
1310
+ <div style="font-size:16px; font-weight:900; margin-bottom:8px;">Special feature: Phone / WhatsApp Mode</div>
1311
+ <div style="opacity:0.92;">
1312
+ Many users upload:
1313
+ <ul>
1314
+ <li><b>WhatsApp-forwarded X-rays</b> (compressed, low contrast)</li>
1315
+ <li><b>Phone photos</b> of printed films or monitor screens (borders, glare, blur)</li>
1316
+ <li><b>Screenshots</b> with large margins / UI elements</li>
1317
+ </ul>
1318
+ <b>Phone / WhatsApp Mode</b> is designed for these cases. When enabled, it applies:
1319
+ <ul>
1320
+ <li><b>Safe border crop</b> (reduces margins / screenshot framing)</li>
1321
+ <li><b>Conditional CLAHE</b> (boosts local contrast when underexposed / low-detail)</li>
1322
+ <li><b>Quality warnings</b> (blur, over/underexposure, heavy borders) to flag reduced reliability</li>
1323
+ </ul>
1324
+
1325
+ <div style="margin-top:10px; padding:10px 12px; border-left:6px solid rgba(96,165,250,0.9); background: rgba(96,165,250,0.10); border-radius:12px;">
1326
+ <b>Tip:</b> Enable Phone/WhatsApp Mode if your image is a phone photo, WhatsApp-forwarded, or has big borders / low contrast.
1327
+ Keep it OFF for clean digital exports to avoid unnecessary preprocessing.
1328
+ </div>
1329
+ </div>
1330
+ </div>
1331
 
1332
+ <div class="card">
1333
+ <div style="font-size:16px; font-weight:900; margin-bottom:8px;">Explainability & reliability</div>
1334
+ <ul>
1335
+ <li><b>Grad-CAM</b> (TBNet) highlights regions that influenced the TB score.</li>
1336
+ <li><b>RADIO heatmaps</b> show where the visual model is focusing (RAW and sometimes MASKED).</li>
1337
+ <li><b>Fail-safe</b>: if lung segmentation looks unreliable, TBNet scoring is disabled (shown as indeterminate).</li>
1338
+ <li><b>Quality scoring</b> warns when results may be less reliable (blur, compression, non-standard view).</li>
1339
+ </ul>
1340
+ </div>
1341
 
1342
+ <div class="card">
1343
+ <div style="font-size:16px; font-weight:900; margin-bottom:8px;">How to use</div>
1344
+ <ol>
1345
+ <li>Click <b>Continue</b> to open the interface.</li>
1346
+ <li>Upload one or multiple X-ray images.</li>
1347
+ <li>If your images come from <b>WhatsApp / phone camera / screenshots</b>, enable <b>Phone/WhatsApp Mode</b>.</li>
1348
+ <li>(Optional) Enable <b>{MODEL_NAME_RADIO}</b> for a second independent model + heatmaps.</li>
1349
+ <li>Click <b>Run Analysis</b>.</li>
1350
+ </ol>
1351
+ </div>
1352
 
1353
+ <div class="card">
1354
+ <div style="font-size:16px; font-weight:900; margin-bottom:8px;">Privacy / processing note (HF Spaces)</div>
1355
+ <div style="opacity:0.92;">
1356
+ Images are processed by this Space runtime. Avoid uploading personally identifiable medical data.
1357
+ Use anonymized images when possible.
1358
+ </div>
1359
+ </div>
1360
 
1361
+ <div class="subtitle">
1362
+ Device policy: <b>{DEVICE}</b> (FORCE_CPU={FORCE_CPU})
1363
+ </div>
1364
+ """
1365
+ )
1366
+ continue_btn = gr.Button("Continue →", variant="primary")
1367
+
1368
+ # ---------------------------
1369
+ # Main app UI (hidden initially)
1370
+ # ---------------------------
1371
+ with gr.Column(visible=False) as main_app:
1372
+ gr.Markdown('<div class="title">TB X-ray Assistant (Auto Lung Mask • Research Use)</div>')
1373
+ gr.Markdown(
1374
+ f"<div class='subtitle'>Auto lung mask → <b>{MODEL_NAME_TBNET}</b> + Grad-CAM • "
1375
+ f"Optional <b>{MODEL_NAME_RADIO}</b> (C-RADIOv4 + heads) • Clear per-model results + consensus</div>"
1376
+ )
1377
 
1378
+ gr.Markdown(
1379
+ "<div class='warnbox'><b>Clinical disclaimer:</b> Decision support only (not diagnostic). "
1380
+ "If TB is clinically suspected, pursue microbiology / CT as appropriate regardless of AI output.</div>"
1381
+ )
 
 
 
 
1382
 
1383
+ with gr.Row():
1384
+ with gr.Column(scale=1):
1385
+ gr.Markdown("#### Model settings")
1386
+
1387
+ tb_weights = gr.Textbox(label="TBNet weights (.pt)", value=DEFAULT_TB_WEIGHTS)
1388
+ lung_weights = gr.Textbox(label="Lung U-Net weights (.pt)", value=DEFAULT_LUNG_WEIGHTS)
1389
+
1390
+ backbone = gr.Dropdown(choices=["efficientnet_b0"], value="efficientnet_b0", label="TBNet backbone")
1391
+
1392
+ threshold = gr.Slider(
1393
+ 0.01, 0.99, value=TBNET_SCREEN_THR, step=0.01,
1394
+ label=f"Reference threshold (TBNet screen+) = {TBNET_SCREEN_THR:.2f}"
1395
+ )
1396
+
1397
+ phone_mode = gr.Checkbox(
1398
+ value=False,
1399
+ label="Phone/WhatsApp Mode (safe: conditional crop + conditional CLAHE)"
1400
+ )
1401
+ gr.Markdown(
1402
+ "<div class='subtitle'>Enable for WhatsApp images, phone photos, or screenshots. "
1403
+ "Leave off for clean digital exports.</div>"
1404
+ )
1405
+
1406
+ use_radio = gr.Checkbox(value=False, label=f"Enable {MODEL_NAME_RADIO}")
1407
+ radio_gate = gr.Slider(
1408
+ 0.10, 0.40, value=RADIO_GATE_DEFAULT, step=0.01,
1409
+ label="RADIO masked gate (run masked head if lung coverage ≥ gate)"
1410
+ )
1411
+
1412
+ gr.Markdown(
1413
+ "<div class='warnbox'><b>Fail-safe:</b> If lung segmentation is too small or looks unreliable, "
1414
+ f"{MODEL_NAME_TBNET} scoring is disabled to avoid unsafe outputs.</div>"
1415
+ )
1416
+
1417
+ gr.Markdown(
1418
+ f"<div class='subtitle'>Device: <b>{DEVICE}</b> (FORCE_CPU={FORCE_CPU})</div>"
1419
+ )
1420
+
1421
+ back_btn = gr.Button("← Back to Welcome", variant="secondary")
1422
+
1423
+ with gr.Column(scale=2):
1424
+ gr.Markdown("#### Upload images")
1425
+ files = gr.Files(
1426
+ label="Upload one or multiple X-ray images",
1427
+ file_types=[".png", ".jpg", ".jpeg", ".bmp"]
1428
+ )
1429
+ run_btn = gr.Button("Run Analysis", variant="primary")
1430
+ status = gr.Textbox(label="Status", value="Ready.", interactive=False)
1431
+
1432
+ gr.Markdown("""
1433
  <div class='legend'><b>Gallery legend:</b><br/>
1434
  1) ORIGINAL &nbsp;•&nbsp; 2) INPUT / PHONE-PROC &nbsp;•&nbsp; 3) Lung mask overlay &nbsp;•&nbsp;
1435
  4) Masked model input &nbsp;•&nbsp; 5) TBNet Grad-CAM &nbsp;•&nbsp; 6) RADIO heatmaps</div>
1436
  """)
1437
 
1438
+ gr.Markdown("#### Summary (per image)")
1439
+ summary = gr.Markdown("Upload images and click <b>Run Analysis</b>.")
1440
+ gallery = gr.Gallery(label="Visual outputs", columns=3, height=560)
1441
+
1442
+ with gr.Row():
1443
+ with gr.Column(scale=1):
1444
+ disclaimer_box = gr.Markdown(CLINICAL_DISCLAIMER)
1445
+ with gr.Column(scale=2):
1446
+ gr.Markdown("#### Detailed report (expand per image)")
1447
+ details = gr.Markdown("")
1448
+
1449
+ run_btn.click(
1450
+ fn=run_analysis,
1451
+ inputs=[
1452
+ files,
1453
+ tb_weights,
1454
+ lung_weights,
1455
+ backbone,
1456
+ threshold,
1457
+ phone_mode,
1458
+ use_radio,
1459
+ radio_gate,
1460
+ ],
1461
+ outputs=[summary, gallery, details, disclaimer_box, status]
1462
  )
1463
 
1464
+ # ---------------------------
1465
+ # Transitions
1466
+ # ---------------------------
1467
+ continue_btn.click(
1468
+ fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
1469
+ inputs=[],
1470
+ outputs=[welcome_screen, main_app],
1471
+ )
1472
+
1473
+ back_btn.click(
1474
+ fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
1475
+ inputs=[],
1476
+ outputs=[welcome_screen, main_app],
 
 
 
 
 
 
 
1477
  )
1478
 
1479
  return demo
 
1481
 
1482
  if __name__ == "__main__":
1483
  demo = build_ui()
1484
+ # HF Spaces: let the platform manage host/port. queue() helps stability for longer runs.
1485
+ demo.queue(concurrency_count=1)
1486
+ demo.launch(show_error=True)