drvikasgaur commited on
Commit
42c02f1
·
verified ·
1 Parent(s): ae85b0c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +185 -109
app.py CHANGED
@@ -1,3 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import cv2
3
  import numpy as np
@@ -20,6 +32,10 @@ from PIL import Image
20
  # USER CONFIG (HF Spaces friendly)
21
  # ============================================================
22
 
 
 
 
 
23
  # ---- Default TB/Lung weights ----
24
  DEFAULT_TB_WEIGHTS = "weights/best.pt"
25
  DEFAULT_LUNG_WEIGHTS = "weights/lung_unet_mont_shenzhen.pt"
@@ -32,7 +48,6 @@ RADIO_REVISION = "c0457f5dc26ca145f954cd4fc5bb6114e5705ad8"
32
  RADIO_RAW_HEAD_PATH = "weights/best_raw.pt"
33
  RADIO_MASKED_HEAD_PATH = "weights/best_masked.pt"
34
 
35
-
36
  RADIO_IMG_SIZE = 320
37
  RADIO_PATCH_SIZE = 16
38
  RADIO_THR_SCREEN = 0.05
@@ -42,19 +57,15 @@ RADIO_GATE_DEFAULT = 0.21
42
 
43
  # ---- Consensus logic thresholds ----
44
  TBNET_SCREEN_THR = 0.30
45
- TBNET_MARGIN = 0.03 # 3% margin around threshold → INDET zone
46
-
47
  RADIO_SCREEN_THR = RADIO_THR_SCREEN
48
- RADIO_MARGIN = 0.02 # 2% margin around radio screen threshold
49
 
50
  # ---- Mask fail-safes ----
51
  FAIL_COV = 0.10 # <10% -> segmentation fail
52
  WARN_COV = 0.18 # <18% -> warn
53
- # if mask looks like a single lung / cropped, we fail-safe and do not output TB score
54
- FAILSAFE_ON_BAD_MASK = True
55
 
56
  # ---- Device policy ----
57
- FORCE_CPU = True # set True if you want TB+RADIO to always run CPU
58
  DEVICE = torch.device("cpu" if FORCE_CPU else ("cuda" if torch.cuda.is_available() else "cpu"))
59
 
60
 
@@ -63,18 +74,41 @@ DEVICE = torch.device("cpu" if FORCE_CPU else ("cuda" if torch.cuda.is_available
63
  # ============================================================
64
  CLINICAL_DISCLAIMER = """
65
  ⚠️ IMPORTANT CLINICAL NOTICE (Decision Support Only)
66
- This AI system is for **research/decision support** and is NOT a diagnostic device.
67
- It may NOT reliably detect early/subtle tuberculosis, including **MILIARY TB**,
68
- which can appear near-normal or subtle on chest X-ray (especially on phone photos / WhatsApp images).
 
69
 
70
  If clinical suspicion exists (fever, weight loss, immunosuppression, known exposure),
71
  recommend **CBNAAT / GeneXpert**, sputum studies, and/or **CT chest** regardless of AI output.
72
  """
73
 
74
  REPORT_LABELS = {
75
- "GREEN": {"title": "LIKELY NORMAL", "summary": "No radiographic features suggestive of pulmonary tuberculosis detected by AI."},
76
- "YELLOW": {"title": "RADIOLOGIST INTERPRETATION RECOMMENDED", "summary": "This AI output is indeterminate / not definitive.A qualified radiologist review is recommended to confirm findings and correlate clinically."},
77
- "RED": {"title": "LIKELY TB", "summary": "AI detected focal lung patterns commonly associated with pulmonary tuberculosis. This is not a diagnosis; microbiological confirmation is required."},
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  }
79
 
80
  CLINICAL_GUIDANCE = (
@@ -98,7 +132,9 @@ class DoubleConv(nn.Module):
98
  nn.BatchNorm2d(out_c),
99
  nn.ReLU(inplace=True),
100
  )
101
- def forward(self, x): return self.net(x)
 
 
102
 
103
  class LungUNet(nn.Module):
104
  def __init__(self):
@@ -140,7 +176,9 @@ class TBNet(nn.Module):
140
  super().__init__()
141
  self.backbone = timm.create_model(backbone, pretrained=False, num_classes=0, global_pool="avg")
142
  self.fc = nn.Linear(self.backbone.num_features, 1)
143
- def forward(self, x): return self.fc(self.backbone(x)).view(-1)
 
 
144
 
145
  def load_tb_weights(model: nn.Module, ckpt_path: str, device: torch.device):
146
  sd = torch.load(ckpt_path, map_location=device)
@@ -154,8 +192,11 @@ class GradCAM:
154
  target_layer.register_forward_hook(self._fwd)
155
  target_layer.register_full_backward_hook(self._bwd)
156
 
157
- def _fwd(self, _, __, out): self.activ = out
158
- def _bwd(self, _, grad_in, grad_out): self.grad = grad_out[0]
 
 
 
159
 
160
  def generate(self, x: torch.Tensor) -> Tuple[np.ndarray, float, float]:
161
  with torch.enable_grad():
@@ -213,17 +254,17 @@ def border_fraction(gray_u8: np.ndarray) -> float:
213
  bot = gray_u8[-b:, :]
214
  left = gray_u8[:, :b]
215
  right = gray_u8[:, -b:]
216
- def frac_border(x): return float(((x < 15) | (x > 240)).mean())
 
217
  return float(np.mean([frac_border(top), frac_border(bot), frac_border(left), frac_border(right)]))
218
 
219
  def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
220
  warnings: List[str] = []
221
  h, w = gray_u8.shape
222
-
223
  score = 100.0
224
 
225
  if min(h, w) < 400:
226
- warnings.append("Low resolution (may reduce detection reliability).")
227
  score -= 8
228
 
229
  sharp = laplacian_sharpness(gray_u8)
@@ -234,22 +275,29 @@ def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
234
 
235
  if likely_phone:
236
  if sharp < 40:
237
- score -= 25; warnings.append("Blurry / motion blur detected (phone capture).")
 
238
  elif sharp < 80:
239
- score -= 12; warnings.append("Slight blur detected.")
 
240
  else:
241
  if sharp < 30:
242
- score -= 8; warnings.append("Low fine-detail / mild blur (digital CXR or downsample).")
 
243
 
244
  if hi_clip > 0.05:
245
- score -= 15; warnings.append("Overexposed highlights (washed out areas).")
 
246
  if lo_clip > 0.10:
247
- score -= 12; warnings.append("Underexposed shadows (very dark areas).")
 
248
 
249
  if border > 0.55:
250
- score -= 18; warnings.append("Large border/margins detected (screenshot/phone framing).")
 
251
  elif border > 0.35:
252
- score -= 10; warnings.append("Some border/margins detected.")
 
253
 
254
  return float(np.clip(score, 0, 100)), warnings
255
 
@@ -257,19 +305,22 @@ def auto_border_crop(gray_u8: np.ndarray) -> np.ndarray:
257
  g = gray_u8.copy()
258
  g_blur = cv2.GaussianBlur(g, (5, 5), 0)
259
  _, th = cv2.threshold(g_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
260
- if th.mean() > 127: th = 255 - th
 
261
 
262
  k = max(3, int(0.01 * min(g.shape)))
263
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
264
  th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
265
 
266
  contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
267
- if not contours: return gray_u8
 
268
 
269
  c = max(contours, key=cv2.contourArea)
270
  x, y, w, h = cv2.boundingRect(c)
271
  H, W = gray_u8.shape
272
- if w * h < 0.20 * (H * W): return gray_u8
 
273
 
274
  pad = int(0.03 * min(H, W))
275
  x1 = max(0, x - pad); y1 = max(0, y - pad)
@@ -311,25 +362,23 @@ def cam_entropy(cam: np.ndarray) -> float:
311
  def detect_diffuse_risk(prob_tb: float, cam_up: np.ndarray, quality_score: float) -> bool:
312
  if quality_score < 55:
313
  return False
314
-
315
- # Only apply diffuse-risk heuristic in the "near-threshold negative" zone
316
  if prob_tb < 0.05:
317
  return False
318
-
319
  ent = cam_entropy(cam_up)
320
  return (prob_tb < TBNET_SCREEN_THR) and (ent > 6.5)
321
 
322
-
323
  def confidence_band(prob_tb: float, quality_score: float, diffuse: bool):
324
- if prob_tb < 0.05 and quality_score >= 45 and not diffuse:
325
- return ("GREEN", "LIKELY NORMAL very low AI TB signal (image quality suboptimal)")
 
326
  if quality_score < 55:
327
- return ("YELLOW", "NO DEFINITE TB FEATURES low image quality, treat as indeterminate")
328
  if diffuse:
329
- return ("YELLOW", "NO DEFINITE TB FEATURES — non-focal/diffuse attention pattern")
330
  if prob_tb >= TBNET_SCREEN_THR:
331
- return ("YELLOW", "NO DEFINITE TB FEATURES — AI confidence limited")
332
- return ("GREEN", "LIKELY NORMAL — no strong AI signal for TB")
333
 
334
  def make_mask_overlay(gray_u8: np.ndarray, mask_u8: np.ndarray) -> np.ndarray:
335
  base = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB)
@@ -340,7 +389,7 @@ def fill_holes(binary_u8: np.ndarray) -> np.ndarray:
340
  m = (binary_u8 * 255).astype(np.uint8)
341
  h, w = m.shape
342
  flood = m.copy()
343
- mask = np.zeros((h+2, w+2), np.uint8)
344
  cv2.floodFill(flood, mask, (0, 0), 255)
345
  holes = cv2.bitwise_not(flood)
346
  filled = cv2.bitwise_or(m, holes)
@@ -367,7 +416,7 @@ def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
367
  warns = []
368
 
369
  if n <= 2:
370
- warns.append("Only one lung component detected (possible crop/segmentation failure).")
371
  return warns
372
 
373
  areas = []
@@ -379,27 +428,27 @@ def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
379
  top2 = areas[1] if len(areas) > 1 else 0
380
 
381
  if total > 0 and top1 / total > 0.80:
382
- warns.append("Mask dominated by a single component (likely one lung / cropped view).")
383
 
384
  border = np.concatenate([m[0, :], m[-1, :], m[:, 0], m[:, -1]])
385
  if border.mean() > 0.05:
386
- warns.append("Lung mask touches image border (possible cropped/non-standard CXR).")
387
 
388
  if total > 0 and (top1 + top2) / total < 0.90:
389
- warns.append("Significant mask fragmentation/holes (post-processing may be insufficient).")
390
 
391
  return warns
392
 
393
  def recommendation_for_band(band: Optional[str]) -> str:
394
  if band in (None, "YELLOW"):
395
- return "✅ Recommendation: Radiologist interpretation recommended (AI result is indeterminate / not definitive)."
396
  if band == "RED":
397
- return "✅ Recommendation: Urgent clinician/radiologist review + microbiological confirmation (CBNAAT/GeneXpert, sputum)."
398
- return "✅ Recommendation: If symptoms/risk factors exist, clinician/radiologist correlation is still advised."
399
 
400
 
401
  # ============================================================
402
- # CONSENSUS LOGIC (TBNet vs RADIO) — 3-state
403
  # ============================================================
404
  def tbnet_state(tb_prob: float, tb_band: str) -> str:
405
  if tb_band == "RED":
@@ -424,7 +473,7 @@ def build_consensus(
424
  ) -> Tuple[str, str]:
425
 
426
  if tb_prob is None or tb_band is None:
427
- return ("N/A", "TBNet unavailable (lung segmentation failed / fail-safe).")
428
 
429
  # PRIMARY = masked if available else raw
430
  if radio_masked is not None:
@@ -435,7 +484,7 @@ def build_consensus(
435
  radio_used = "RAW"
436
 
437
  if radio_primary is None:
438
- return ("TBNet only", f"RADIO unavailableTBNet={tb_prob:.4f} (band={tb_band}).")
439
 
440
  t = tbnet_state(tb_prob, tb_band)
441
  r = radio_state_from_prob(radio_primary)
@@ -445,18 +494,19 @@ def build_consensus(
445
  if t == r:
446
  return (
447
  f"AGREE: {t}",
448
- f"Both: {t}. TBNet={tb_prob:.4f}, RADIO({radio_used})={radio_primary:.4f}{rb}."
449
  )
450
 
 
451
  if (t in ("SCREEN+", "TB+") and r == "LOW") or (r in ("SCREEN+", "TB+") and t == "LOW"):
452
  return (
453
  "DISAGREE",
454
- f"Strong disagreement: TBNet={t} (band={tb_band}) vs RADIO={r} ({radio_used})={radio_primary:.4f}{rb}."
455
  )
456
 
457
  return (
458
  "MIXED/INDET",
459
- f"Mixed/uncertain: TBNet={t} (band={tb_band}) vs RADIO={r} ({radio_used})={radio_primary:.4f}{rb}."
460
  )
461
 
462
 
@@ -487,6 +537,7 @@ class ModelBundle:
487
  load_tb_weights(tb, tb_weights, self.device)
488
  tb.eval()
489
  self.tb = tb
 
490
  self.cammer = GradCAM(tb, tb.backbone.conv_head)
491
  self.tb_path = tb_weights
492
  self.backbone = backbone
@@ -515,6 +566,7 @@ class RadioMLPHead(nn.Module):
515
  nn.Dropout(dropout),
516
  nn.Linear(hidden, 1),
517
  )
 
518
  def forward(self, x: torch.Tensor) -> torch.Tensor:
519
  return self.net(x).squeeze(1)
520
 
@@ -582,6 +634,7 @@ def radio_overlay_heatmap(rgb_u8: np.ndarray, heatmap01: np.ndarray, alpha: floa
582
  img = rgb_u8.astype(np.float32) / 255.0
583
  hm = np.clip(heatmap01, 0, 1).astype(np.float32)
584
  out = img.copy()
 
585
  out[..., 0] = np.clip(out[..., 0] * (1 - alpha) + hm * alpha, 0, 1)
586
  return (out * 255).astype(np.uint8)
587
 
@@ -649,10 +702,10 @@ def radio_predict_from_arrays(gray_vis_u8: np.ndarray,
649
 
650
  if prob_primary >= RADIO_THR_RED:
651
  band = "RED"
652
- pred = "LIKELY TB (RADIO)"
653
  elif prob_primary >= RADIO_THR_SCREEN:
654
  band = "YELLOW"
655
- pred = "SCREEN-POSITIVE / INDETERMINATE (RADIO)"
656
  else:
657
  band = "GREEN"
658
  pred = "LOW TB LIKELIHOOD (RADIO)"
@@ -678,7 +731,7 @@ def analyze_one_image(
678
  tb_weights: str,
679
  lung_weights: str,
680
  backbone: str,
681
- threshold: float,
682
  phone_mode: bool,
683
  img_size: int = 224,
684
  fail_cov: float = FAIL_COV,
@@ -720,11 +773,14 @@ def analyze_one_image(
720
  "logit": None,
721
  "pred": "INDETERMINATE",
722
  "band": "YELLOW",
723
- "band_text": "Lung segmentation failed. AI TB assessment cannot be performed reliably on this image.",
 
 
 
724
  "quality_score": float(q_score),
725
  "diffuse_risk": False,
726
  "warnings": (
727
- ["Lung segmentation failed (<10% lung area).", f"Lung coverage: {coverage*100:.1f}%"]
728
  + (["Phone/WhatsApp mode enabled; artifacts possible."] if phone_mode else [])
729
  + q_warn
730
  ),
@@ -748,7 +804,11 @@ def analyze_one_image(
748
  "logit": None,
749
  "pred": "INDETERMINATE",
750
  "band": "YELLOW",
751
- "band_text": "Non-standard/cropped view or unreliable lung segmentation. TB scoring disabled (fail-safe).",
 
 
 
 
752
  "quality_score": float(q_score),
753
  "diffuse_risk": False,
754
  "warnings": (
@@ -783,39 +843,42 @@ def analyze_one_image(
783
  cam_up = cam_u8.astype(np.float32) / 255.0
784
 
785
  diffuse = detect_diffuse_risk(prob_tb, cam_up, q_score)
786
- band_base, _ = confidence_band(prob_tb, q_score, diffuse)
787
 
788
- allow_red = (prob_tb >= 0.70 and q_score >= 55 and not diffuse and coverage >= warn_cov)
789
  band = "RED" if allow_red else band_base
790
 
791
  pred = REPORT_LABELS[band]["title"]
792
  band_text = REPORT_LABELS[band]["summary"]
793
 
794
- abnormal_non_tb = (prob_tb >= 0.60 and q_score < 55 and band != "RED")
795
- if abnormal_non_tb:
796
  band_text = (
797
- "Significant abnormal lung findings detected. "
798
- "Findings are non-specific and not characteristic of pulmonary tuberculosis. "
799
- "Image quality may affect AI reliability."
800
  )
801
 
 
802
  heat = cv2.applyColorMap((cam_up * 255).astype(np.uint8), cv2.COLORMAP_JET)
803
  overlay_clean = cv2.addWeighted(rgb, 0.65, heat, 0.35, 0)
804
 
805
  overlay_annotated = overlay_clean.copy()
806
  text1 = f"{band}: {pred}"
807
- text2 = f"TB prob={prob_tb:.3f} | Quality={q_score:.0f}/100 | Lung coverage={coverage*100:.1f}%"
808
  cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (255, 255, 255), 2)
809
  cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (0, 0, 0), 1)
810
  cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (255, 255, 255), 2)
811
  cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (0, 0, 0), 1)
812
 
813
  warnings = []
814
- if phone_mode: warnings.append("Phone/WhatsApp mode enabled; artifacts possible.")
815
- if q_score < 55: warnings.append("Suboptimal image quality limits AI reliability.")
816
- if coverage < warn_cov: warnings.append(f"Partial lung segmentation ({coverage*100:.1f}% coverage).")
817
- if diffuse: warnings.append("Diffuse, non-focal AI attention pattern; TB-specific features not identified.")
818
- if abnormal_non_tb: warnings.append("Abnormal lung findings detected; pattern not specific for tuberculosis.")
 
 
 
819
  warnings.extend(q_warn)
820
 
821
  return {
@@ -886,7 +949,7 @@ def run_analysis(
886
  # -------------------------
887
  # RADIO (optional)
888
  # -------------------------
889
- radio_text = "RADIO disabled."
890
  radio_raw_overlay = None
891
  radio_masked_overlay = None
892
  radio_raw_val: Optional[float] = None
@@ -914,14 +977,14 @@ def run_analysis(
914
  radio_masked_str = "" if radio_masked_val is None else f"{radio_masked_val:.4f}"
915
 
916
  radio_text = (
917
- f"**RADIO:** {r['pred']} | RAW={radio_raw_val:.4f}"
918
  + (f" | MASKED={radio_masked_val:.4f}" if radio_masked_val is not None else "")
919
  + f" | Band={radio_band}"
920
  )
921
  radio_raw_overlay = r["raw_overlay"]
922
  radio_masked_overlay = r["masked_overlay"]
923
  except Exception as e:
924
- radio_text = f"RADIO error: {type(e).__name__}: {e}"
925
  radio_raw_str = ""
926
  radio_masked_str = ""
927
  radio_raw_val = None
@@ -929,7 +992,7 @@ def run_analysis(
929
  radio_band = None
930
 
931
  # -------------------------
932
- # Consensus
933
  # -------------------------
934
  consensus_label, consensus_detail = build_consensus(
935
  tb_prob=out["prob"],
@@ -974,7 +1037,7 @@ def run_analysis(
974
  proc_rgb = cv2.cvtColor(cv2.resize(out["proc_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
975
  gallery_items.append((proc_rgb, f"{name} • Masked model input (224x224)"))
976
 
977
- gallery_items.append((overlay_big, f"{name} • Grad-CAM overlay (TBNet)"))
978
 
979
  if radio_raw_overlay is not None:
980
  gallery_items.append((cv2.resize(radio_raw_overlay, (512, 512)), f"{name} • RADIO RAW heatmap"))
@@ -982,37 +1045,37 @@ def run_analysis(
982
  gallery_items.append((cv2.resize(radio_masked_overlay, (512, 512)), f"{name} • RADIO MASKED heatmap"))
983
 
984
  # -------------------------
985
- # Details panel
986
  # -------------------------
987
  warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
988
- tb_line = "N/A (segmentation failed / fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
989
  rec_line = recommendation_for_band(out.get("band"))
990
 
991
  details_md.append(
992
  f"""### {name}
993
 
994
- **AI Assessment (TBNet):** **{out['pred']}**
995
  {rec_line}
996
 
997
- **TB Probability (screening model):** {tb_line}
998
-
999
- **Interpretation**
1000
  {out['band_text']}
1001
 
1002
- **Image Quality:** {out['quality_score']:.0f}/100
1003
- **Lung Mask Coverage:** {out.get('lung_coverage', 0.0) * 100:.1f}%
1004
- **AI Attention Pattern (TBNet):** {"Diffuse / non-focal" if out["diffuse_risk"] else "Focal / localized"}
 
 
1005
 
1006
- **Warnings**
1007
  {warn_txt}
1008
 
1009
- **RADIO Output**
1010
  {radio_text}
1011
 
1012
- **Final consensus (TBNet vs RADIO):** **{consensus_label}**
1013
  - {consensus_detail}
1014
 
1015
- **Clinical Guidance**
1016
  {CLINICAL_GUIDANCE}
1017
 
1018
  ---
@@ -1034,39 +1097,52 @@ def build_ui():
1034
 
1035
  with gr.Blocks(title="TB X-ray Assistant (TBNet + RADIO)", css=css) as demo:
1036
  gr.Markdown('<div class="title">TB X-ray Assistant (Auto Lung Mask • Research Use)</div>')
1037
- gr.Markdown('<div class="subtitle">Lung U-Net masking → EfficientNet TBNet + Grad-CAM • Optional RADIO (C-RADIOv4 + heads) • 3-state consensus</div>')
 
 
 
1038
 
1039
  with gr.Row():
1040
  with gr.Column(scale=1):
1041
  gr.Markdown("#### Model settings")
1042
 
1043
- tb_weights = gr.Textbox(label="TB Weights (.pt)", value=DEFAULT_TB_WEIGHTS)
1044
- lung_weights = gr.Textbox(label="Lung U-Net Weights (.pt)", value=DEFAULT_LUNG_WEIGHTS)
1045
 
1046
- backbone = gr.Dropdown(choices=["efficientnet_b0"], value="efficientnet_b0", label="Backbone")
1047
 
1048
- threshold = gr.Slider(0.01, 0.99, value=TBNET_SCREEN_THR, step=0.01,
1049
- label=f"Reference threshold (TBNet screen+) = {TBNET_SCREEN_THR:.2f}")
 
 
1050
 
1051
- phone_mode = gr.Checkbox(value=False,
1052
- label="Phone/WhatsApp Mode (SAFE: conditional crop + conditional CLAHE)")
 
 
1053
 
1054
  # RADIO
1055
- use_radio = gr.Checkbox(value=False, label="Enable RADIO layer (C-RADIOv4 + heads)")
1056
- radio_gate = gr.Slider(0.10, 0.40, value=RADIO_GATE_DEFAULT, step=0.01,
1057
- label="RADIO masked gate (run masked head if lung coverage ≥ gate)")
 
 
1058
 
1059
  gr.Markdown(
1060
- '<div class="warnbox"><b>Fail-safe:</b> If lung segmentation is too small or looks like a cropped/single-lung mask, TB scoring is disabled to avoid false positives.</div>'
 
1061
  )
1062
 
1063
  gr.Markdown(
1064
- f"<div class='subtitle'>Device for TB+RADIO: <b>{DEVICE}</b> (set FORCE_CPU=True to force CPU)</div>"
1065
  )
1066
 
1067
  with gr.Column(scale=2):
1068
  gr.Markdown("#### Upload images")
1069
- files = gr.Files(label="Upload one or multiple X-ray images", file_types=[".png", ".jpg", ".jpeg", ".bmp"])
 
 
 
1070
  run_btn = gr.Button("Run Analysis", variant="primary")
1071
  status = gr.Textbox(label="Status", value="Ready.", interactive=False)
1072
 
@@ -1074,15 +1150,15 @@ def build_ui():
1074
  table = gr.Dataframe(
1075
  headers=[
1076
  "Image",
1077
- "TB Probability",
1078
- "AI Assessment",
1079
  "Band",
1080
- "Band meaning",
1081
  "Quality",
1082
  "LungCov",
1083
  "RADIO RAW",
1084
  "RADIO MASKED",
1085
- "CONSENSUS",
1086
  ],
1087
  datatype=["str","str","str","str","str","str","str","str","str","str"],
1088
  interactive=False,
 
1
+ # app.py
2
+ # Gradio — TBNet (CNN model) + Lung U-Net Auto Mask + Grad-CAM + RADIO (visual model)
3
+ # + SAFER PHONE MODE + MASK POST-PROCESSING + MASK SANITY FAILSAFE
4
+ # + 3-STATE AGREEMENT (LOW / SCREEN+ / TB+)
5
+ #
6
+ # Requirements:
7
+ # pip install gradio timm torchvision opencv-python pillow transformers einops torch
8
+ #
9
+ # HF Spaces notes:
10
+ # - weights are expected in ./weights/
11
+ # - app launches on 0.0.0.0:7860
12
+
13
  import os
14
  import cv2
15
  import numpy as np
 
32
  # USER CONFIG (HF Spaces friendly)
33
  # ============================================================
34
 
35
+ # ---- Friendly model names for UI ----
36
+ MODEL_NAME_TBNET = "TBNet (CNN model)"
37
+ MODEL_NAME_RADIO = "RADIO (visual model)"
38
+
39
  # ---- Default TB/Lung weights ----
40
  DEFAULT_TB_WEIGHTS = "weights/best.pt"
41
  DEFAULT_LUNG_WEIGHTS = "weights/lung_unet_mont_shenzhen.pt"
 
48
  RADIO_RAW_HEAD_PATH = "weights/best_raw.pt"
49
  RADIO_MASKED_HEAD_PATH = "weights/best_masked.pt"
50
 
 
51
  RADIO_IMG_SIZE = 320
52
  RADIO_PATCH_SIZE = 16
53
  RADIO_THR_SCREEN = 0.05
 
57
 
58
  # ---- Consensus logic thresholds ----
59
  TBNET_SCREEN_THR = 0.30
 
 
60
  RADIO_SCREEN_THR = RADIO_THR_SCREEN
 
61
 
62
  # ---- Mask fail-safes ----
63
  FAIL_COV = 0.10 # <10% -> segmentation fail
64
  WARN_COV = 0.18 # <18% -> warn
65
+ FAILSAFE_ON_BAD_MASK = True # fail-safe on suspicious/cropped masks
 
66
 
67
  # ---- Device policy ----
68
+ FORCE_CPU = True # HF CPU space: keep True
69
  DEVICE = torch.device("cpu" if FORCE_CPU else ("cuda" if torch.cuda.is_available() else "cpu"))
70
 
71
 
 
74
  # ============================================================
75
  CLINICAL_DISCLAIMER = """
76
  ⚠️ IMPORTANT CLINICAL NOTICE (Decision Support Only)
77
+
78
+ This system is for **research/decision support** and is **NOT** a diagnostic device.
79
+ It may **miss early/subtle tuberculosis**, including **miliary TB**.
80
+ Phone photos / screenshots / downsampled images can reduce reliability.
81
 
82
  If clinical suspicion exists (fever, weight loss, immunosuppression, known exposure),
83
  recommend **CBNAAT / GeneXpert**, sputum studies, and/or **CT chest** regardless of AI output.
84
  """
85
 
86
  REPORT_LABELS = {
87
+ "GREEN": {
88
+ "title": "LOW TB LIKELIHOOD",
89
+ "summary": (
90
+ f"✅ **{MODEL_NAME_TBNET}** did not find patterns that strongly suggest pulmonary tuberculosis.\n\n"
91
+ "**What to do next:** If symptoms or TB risk factors are present, please seek clinician/radiologist review."
92
+ ),
93
+ },
94
+ "YELLOW": {
95
+ "title": "INDETERMINATE — REVIEW RECOMMENDED",
96
+ "summary": (
97
+ f"⚠️ **{MODEL_NAME_TBNET}** result is **not definitive**.\n\n"
98
+ "**Common reasons:** image quality limitations, non-standard/cropped view, or non-focal attention.\n\n"
99
+ "**What to do next:** Radiologist/clinician review is recommended. "
100
+ "If TB is clinically suspected, consider microbiological tests (CBNAAT/GeneXpert, sputum)."
101
+ ),
102
+ },
103
+ "RED": {
104
+ "title": "TB FEATURES SUSPECTED",
105
+ "summary": (
106
+ f"🚩 **{MODEL_NAME_TBNET}** detected lung patterns that can be seen with pulmonary tuberculosis.\n\n"
107
+ "**Important:** This is not a diagnosis.\n\n"
108
+ "**What to do next:** Urgent clinician/radiologist review and microbiological confirmation "
109
+ "(CBNAAT/GeneXpert, sputum) are recommended."
110
+ ),
111
+ },
112
  }
113
 
114
  CLINICAL_GUIDANCE = (
 
132
  nn.BatchNorm2d(out_c),
133
  nn.ReLU(inplace=True),
134
  )
135
+
136
+ def forward(self, x):
137
+ return self.net(x)
138
 
139
  class LungUNet(nn.Module):
140
  def __init__(self):
 
176
  super().__init__()
177
  self.backbone = timm.create_model(backbone, pretrained=False, num_classes=0, global_pool="avg")
178
  self.fc = nn.Linear(self.backbone.num_features, 1)
179
+
180
+ def forward(self, x):
181
+ return self.fc(self.backbone(x)).view(-1)
182
 
183
  def load_tb_weights(model: nn.Module, ckpt_path: str, device: torch.device):
184
  sd = torch.load(ckpt_path, map_location=device)
 
192
  target_layer.register_forward_hook(self._fwd)
193
  target_layer.register_full_backward_hook(self._bwd)
194
 
195
+ def _fwd(self, _, __, out):
196
+ self.activ = out
197
+
198
+ def _bwd(self, _, grad_in, grad_out):
199
+ self.grad = grad_out[0]
200
 
201
  def generate(self, x: torch.Tensor) -> Tuple[np.ndarray, float, float]:
202
  with torch.enable_grad():
 
254
  bot = gray_u8[-b:, :]
255
  left = gray_u8[:, :b]
256
  right = gray_u8[:, -b:]
257
+ def frac_border(x):
258
+ return float(((x < 15) | (x > 240)).mean())
259
  return float(np.mean([frac_border(top), frac_border(bot), frac_border(left), frac_border(right)]))
260
 
261
  def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
262
  warnings: List[str] = []
263
  h, w = gray_u8.shape
 
264
  score = 100.0
265
 
266
  if min(h, w) < 400:
267
+ warnings.append("Low resolution (image may be downsampled).")
268
  score -= 8
269
 
270
  sharp = laplacian_sharpness(gray_u8)
 
275
 
276
  if likely_phone:
277
  if sharp < 40:
278
+ score -= 25
279
+ warnings.append("Blurry / motion blur detected (likely phone capture).")
280
  elif sharp < 80:
281
+ score -= 12
282
+ warnings.append("Slight blur detected.")
283
  else:
284
  if sharp < 30:
285
+ score -= 8
286
+ warnings.append("Low fine detail (possible downsampling).")
287
 
288
  if hi_clip > 0.05:
289
+ score -= 15
290
+ warnings.append("Overexposed highlights (washed-out areas).")
291
  if lo_clip > 0.10:
292
+ score -= 12
293
+ warnings.append("Underexposed shadows (very dark areas).")
294
 
295
  if border > 0.55:
296
+ score -= 18
297
+ warnings.append("Large border/margins detected (possible screenshot/phone framing).")
298
  elif border > 0.35:
299
+ score -= 10
300
+ warnings.append("Some border/margins detected.")
301
 
302
  return float(np.clip(score, 0, 100)), warnings
303
 
 
305
  g = gray_u8.copy()
306
  g_blur = cv2.GaussianBlur(g, (5, 5), 0)
307
  _, th = cv2.threshold(g_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
308
+ if th.mean() > 127:
309
+ th = 255 - th
310
 
311
  k = max(3, int(0.01 * min(g.shape)))
312
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
313
  th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
314
 
315
  contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
316
+ if not contours:
317
+ return gray_u8
318
 
319
  c = max(contours, key=cv2.contourArea)
320
  x, y, w, h = cv2.boundingRect(c)
321
  H, W = gray_u8.shape
322
+ if w * h < 0.20 * (H * W):
323
+ return gray_u8
324
 
325
  pad = int(0.03 * min(H, W))
326
  x1 = max(0, x - pad); y1 = max(0, y - pad)
 
362
  def detect_diffuse_risk(prob_tb: float, cam_up: np.ndarray, quality_score: float) -> bool:
363
  if quality_score < 55:
364
  return False
365
+ # Only apply diffuse-risk heuristic in near-threshold negatives
 
366
  if prob_tb < 0.05:
367
  return False
 
368
  ent = cam_entropy(cam_up)
369
  return (prob_tb < TBNET_SCREEN_THR) and (ent > 6.5)
370
 
 
371
  def confidence_band(prob_tb: float, quality_score: float, diffuse: bool):
372
+ # Very low probability should not be forced to YELLOW just because attention is diffuse
373
+ if prob_tb < 0.01 and quality_score >= 45:
374
+ return ("GREEN", "✅ Very low TB signal detected by the CNN model.")
375
  if quality_score < 55:
376
+ return ("YELLOW", "⚠️ Image quality is low; treat the result as indeterminate.")
377
  if diffuse:
378
+ return ("YELLOW", "⚠️ Attention is non-focal; treat the result as indeterminate.")
379
  if prob_tb >= TBNET_SCREEN_THR:
380
+ return ("YELLOW", "⚠️ Screening-positive range; review recommended.")
381
+ return ("GREEN", " No strong TB signal detected by the CNN model.")
382
 
383
  def make_mask_overlay(gray_u8: np.ndarray, mask_u8: np.ndarray) -> np.ndarray:
384
  base = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB)
 
389
  m = (binary_u8 * 255).astype(np.uint8)
390
  h, w = m.shape
391
  flood = m.copy()
392
+ mask = np.zeros((h + 2, w + 2), np.uint8)
393
  cv2.floodFill(flood, mask, (0, 0), 255)
394
  holes = cv2.bitwise_not(flood)
395
  filled = cv2.bitwise_or(m, holes)
 
416
  warns = []
417
 
418
  if n <= 2:
419
+ warns.append("Only one lung region detected (possible crop/segmentation failure).")
420
  return warns
421
 
422
  areas = []
 
428
  top2 = areas[1] if len(areas) > 1 else 0
429
 
430
  if total > 0 and top1 / total > 0.80:
431
+ warns.append("Mask dominated by a single region (possible cropped/partial lung view).")
432
 
433
  border = np.concatenate([m[0, :], m[-1, :], m[:, 0], m[:, -1]])
434
  if border.mean() > 0.05:
435
+ warns.append("Lung mask touches the image border (possible cropped/non-standard CXR).")
436
 
437
  if total > 0 and (top1 + top2) / total < 0.90:
438
+ warns.append("Mask appears fragmented (may reduce reliability).")
439
 
440
  return warns
441
 
442
  def recommendation_for_band(band: Optional[str]) -> str:
443
  if band in (None, "YELLOW"):
444
+ return "✅ Recommendation: Radiologist/clinician review is recommended (result is indeterminate)."
445
  if band == "RED":
446
+ return "✅ Recommendation: Urgent clinical review + microbiological confirmation (CBNAAT/GeneXpert, sputum)."
447
+ return "✅ Recommendation: If symptoms/risk factors exist, clinical correlation is advised."
448
 
449
 
450
  # ============================================================
451
+ # AGREEMENT LOGIC (TBNet vs RADIO) — 3-state
452
  # ============================================================
453
  def tbnet_state(tb_prob: float, tb_band: str) -> str:
454
  if tb_band == "RED":
 
473
  ) -> Tuple[str, str]:
474
 
475
  if tb_prob is None or tb_band is None:
476
+ return ("N/A", f"{MODEL_NAME_TBNET} not available (lung segmentation failed / fail-safe).")
477
 
478
  # PRIMARY = masked if available else raw
479
  if radio_masked is not None:
 
484
  radio_used = "RAW"
485
 
486
  if radio_primary is None:
487
+ return ("TBNet only", f"{MODEL_NAME_RADIO} not available {MODEL_NAME_TBNET} probability={tb_prob:.4f} (band={tb_band}).")
488
 
489
  t = tbnet_state(tb_prob, tb_band)
490
  r = radio_state_from_prob(radio_primary)
 
494
  if t == r:
495
  return (
496
  f"AGREE: {t}",
497
+ f"Both models indicate **{t}**. {MODEL_NAME_TBNET}={tb_prob:.4f}, {MODEL_NAME_RADIO}({radio_used})={radio_primary:.4f}{rb}."
498
  )
499
 
500
+ # strong disagreement: one says SCREEN+/TB+ and the other says LOW
501
  if (t in ("SCREEN+", "TB+") and r == "LOW") or (r in ("SCREEN+", "TB+") and t == "LOW"):
502
  return (
503
  "DISAGREE",
504
+ f"Models disagree: {MODEL_NAME_TBNET} suggests **{t}** (band={tb_band}) vs {MODEL_NAME_RADIO} suggests **{r}** ({radio_used})={radio_primary:.4f}{rb}."
505
  )
506
 
507
  return (
508
  "MIXED/INDET",
509
+ f"Mixed signals: {MODEL_NAME_TBNET} suggests **{t}** vs {MODEL_NAME_RADIO} suggests **{r}** ({radio_used})={radio_primary:.4f}{rb}."
510
  )
511
 
512
 
 
537
  load_tb_weights(tb, tb_weights, self.device)
538
  tb.eval()
539
  self.tb = tb
540
+ # EfficientNet in timm has conv_head on effb0
541
  self.cammer = GradCAM(tb, tb.backbone.conv_head)
542
  self.tb_path = tb_weights
543
  self.backbone = backbone
 
566
  nn.Dropout(dropout),
567
  nn.Linear(hidden, 1),
568
  )
569
+
570
  def forward(self, x: torch.Tensor) -> torch.Tensor:
571
  return self.net(x).squeeze(1)
572
 
 
634
  img = rgb_u8.astype(np.float32) / 255.0
635
  hm = np.clip(heatmap01, 0, 1).astype(np.float32)
636
  out = img.copy()
637
+ # subtle red overlay
638
  out[..., 0] = np.clip(out[..., 0] * (1 - alpha) + hm * alpha, 0, 1)
639
  return (out * 255).astype(np.uint8)
640
 
 
702
 
703
  if prob_primary >= RADIO_THR_RED:
704
  band = "RED"
705
+ pred = "TB FEATURES SUSPECTED (RADIO)"
706
  elif prob_primary >= RADIO_THR_SCREEN:
707
  band = "YELLOW"
708
+ pred = "SCREEN-POSITIVE RANGE (RADIO)"
709
  else:
710
  band = "GREEN"
711
  pred = "LOW TB LIKELIHOOD (RADIO)"
 
731
  tb_weights: str,
732
  lung_weights: str,
733
  backbone: str,
734
+ threshold: float, # kept for UI compatibility; main logic uses TBNET_SCREEN_THR
735
  phone_mode: bool,
736
  img_size: int = 224,
737
  fail_cov: float = FAIL_COV,
 
773
  "logit": None,
774
  "pred": "INDETERMINATE",
775
  "band": "YELLOW",
776
+ "band_text": (
777
+ "⚠️ The lung segmentation looks unreliable, so the TB screening score is disabled for safety.\n\n"
778
+ "Please use a clearer frontal chest X-ray (PA/AP) or seek radiologist review."
779
+ ),
780
  "quality_score": float(q_score),
781
  "diffuse_risk": False,
782
  "warnings": (
783
+ [f"Lung segmentation coverage is too small ({coverage*100:.1f}%)."]
784
  + (["Phone/WhatsApp mode enabled; artifacts possible."] if phone_mode else [])
785
  + q_warn
786
  ),
 
804
  "logit": None,
805
  "pred": "INDETERMINATE",
806
  "band": "YELLOW",
807
+ "band_text": (
808
+ "⚠️ The image appears cropped/non-standard (mask sanity check). "
809
+ "TB screening score is disabled for safety.\n\n"
810
+ "Please use a standard frontal chest X-ray (PA/AP) or seek radiologist review."
811
+ ),
812
  "quality_score": float(q_score),
813
  "diffuse_risk": False,
814
  "warnings": (
 
843
  cam_up = cam_u8.astype(np.float32) / 255.0
844
 
845
  diffuse = detect_diffuse_risk(prob_tb, cam_up, q_score)
846
+ band_base, _band_hint = confidence_band(prob_tb, q_score, diffuse)
847
 
848
+ allow_red = (prob_tb >= 0.70 and q_score >= 55 and (not diffuse) and coverage >= warn_cov)
849
  band = "RED" if allow_red else band_base
850
 
851
  pred = REPORT_LABELS[band]["title"]
852
  band_text = REPORT_LABELS[band]["summary"]
853
 
854
+ # Extra helpful line if YELLOW but probability is very low
855
+ if band == "YELLOW" and prob_tb < 0.05:
856
  band_text = (
857
+ "���️ TB probability is very low, but the result is marked **indeterminate** because reliability is limited.\n\n"
858
+ + band_text
 
859
  )
860
 
861
+ # Build Grad-CAM overlay
862
  heat = cv2.applyColorMap((cam_up * 255).astype(np.uint8), cv2.COLORMAP_JET)
863
  overlay_clean = cv2.addWeighted(rgb, 0.65, heat, 0.35, 0)
864
 
865
  overlay_annotated = overlay_clean.copy()
866
  text1 = f"{band}: {pred}"
867
+ text2 = f"TBNet prob={prob_tb:.3f} | Quality={q_score:.0f}/100 | Lung coverage={coverage*100:.1f}%"
868
  cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (255, 255, 255), 2)
869
  cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (0, 0, 0), 1)
870
  cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (255, 255, 255), 2)
871
  cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (0, 0, 0), 1)
872
 
873
  warnings = []
874
+ if phone_mode:
875
+ warnings.append("Phone/WhatsApp mode enabled; artifacts possible.")
876
+ if q_score < 55:
877
+ warnings.append("Image quality is low; reliability may be reduced.")
878
+ if coverage < warn_cov:
879
+ warnings.append(f"Partial lung segmentation ({coverage*100:.1f}% coverage).")
880
+ if diffuse:
881
+ warnings.append("Non-focal attention pattern; result treated cautiously.")
882
  warnings.extend(q_warn)
883
 
884
  return {
 
949
  # -------------------------
950
  # RADIO (optional)
951
  # -------------------------
952
+ radio_text = f"{MODEL_NAME_RADIO} is disabled."
953
  radio_raw_overlay = None
954
  radio_masked_overlay = None
955
  radio_raw_val: Optional[float] = None
 
977
  radio_masked_str = "" if radio_masked_val is None else f"{radio_masked_val:.4f}"
978
 
979
  radio_text = (
980
+ f"**{MODEL_NAME_RADIO} result:** {r['pred']} | RAW={radio_raw_val:.4f}"
981
  + (f" | MASKED={radio_masked_val:.4f}" if radio_masked_val is not None else "")
982
  + f" | Band={radio_band}"
983
  )
984
  radio_raw_overlay = r["raw_overlay"]
985
  radio_masked_overlay = r["masked_overlay"]
986
  except Exception as e:
987
+ radio_text = f"{MODEL_NAME_RADIO} error: {type(e).__name__}: {e}"
988
  radio_raw_str = ""
989
  radio_masked_str = ""
990
  radio_raw_val = None
 
992
  radio_band = None
993
 
994
  # -------------------------
995
+ # Agreement between models
996
  # -------------------------
997
  consensus_label, consensus_detail = build_consensus(
998
  tb_prob=out["prob"],
 
1037
  proc_rgb = cv2.cvtColor(cv2.resize(out["proc_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
1038
  gallery_items.append((proc_rgb, f"{name} • Masked model input (224x224)"))
1039
 
1040
+ gallery_items.append((overlay_big, f"{name} • Grad-CAM overlay ({MODEL_NAME_TBNET})"))
1041
 
1042
  if radio_raw_overlay is not None:
1043
  gallery_items.append((cv2.resize(radio_raw_overlay, (512, 512)), f"{name} • RADIO RAW heatmap"))
 
1045
  gallery_items.append((cv2.resize(radio_masked_overlay, (512, 512)), f"{name} • RADIO MASKED heatmap"))
1046
 
1047
  # -------------------------
1048
+ # Details panel (user-friendly)
1049
  # -------------------------
1050
  warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
1051
+ tb_line = "N/A (disabled by fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
1052
  rec_line = recommendation_for_band(out.get("band"))
1053
 
1054
  details_md.append(
1055
  f"""### {name}
1056
 
1057
+ **{MODEL_NAME_TBNET} result:** **{out['pred']}**
1058
  {rec_line}
1059
 
1060
+ **What this means**
 
 
1061
  {out['band_text']}
1062
 
1063
+ **Why it decided this**
1064
+ - {MODEL_NAME_TBNET} probability: {tb_line}
1065
+ - Image quality: {out['quality_score']:.0f}/100
1066
+ - Lung mask coverage: {out.get('lung_coverage', 0.0) * 100:.1f}%
1067
+ - Attention pattern (TBNet): {"Diffuse / non-focal" if out["diffuse_risk"] else "Focal / localized"}
1068
 
1069
+ **Notes that may affect reliability**
1070
  {warn_txt}
1071
 
1072
+ **{MODEL_NAME_RADIO} output**
1073
  {radio_text}
1074
 
1075
+ **Agreement between models (TBNet vs RADIO):** **{consensus_label}**
1076
  - {consensus_detail}
1077
 
1078
+ **Clinical guidance**
1079
  {CLINICAL_GUIDANCE}
1080
 
1081
  ---
 
1097
 
1098
  with gr.Blocks(title="TB X-ray Assistant (TBNet + RADIO)", css=css) as demo:
1099
  gr.Markdown('<div class="title">TB X-ray Assistant (Auto Lung Mask • Research Use)</div>')
1100
+ gr.Markdown(
1101
+ f"<div class='subtitle'>Lung U-Net masking → <b>{MODEL_NAME_TBNET}</b> + Grad-CAM • "
1102
+ f"Optional <b>{MODEL_NAME_RADIO}</b> (C-RADIOv4 + heads) • Agreement summary</div>"
1103
+ )
1104
 
1105
  with gr.Row():
1106
  with gr.Column(scale=1):
1107
  gr.Markdown("#### Model settings")
1108
 
1109
+ tb_weights = gr.Textbox(label="TBNet weights (.pt)", value=DEFAULT_TB_WEIGHTS)
1110
+ lung_weights = gr.Textbox(label="Lung U-Net weights (.pt)", value=DEFAULT_LUNG_WEIGHTS)
1111
 
1112
+ backbone = gr.Dropdown(choices=["efficientnet_b0"], value="efficientnet_b0", label="TBNet backbone")
1113
 
1114
+ threshold = gr.Slider(
1115
+ 0.01, 0.99, value=TBNET_SCREEN_THR, step=0.01,
1116
+ label=f"Reference threshold (TBNet SCREEN+) = {TBNET_SCREEN_THR:.2f}"
1117
+ )
1118
 
1119
+ phone_mode = gr.Checkbox(
1120
+ value=False,
1121
+ label="Phone/WhatsApp Mode (safe: conditional crop + conditional CLAHE)"
1122
+ )
1123
 
1124
  # RADIO
1125
+ use_radio = gr.Checkbox(value=False, label=f"Enable {MODEL_NAME_RADIO}")
1126
+ radio_gate = gr.Slider(
1127
+ 0.10, 0.40, value=RADIO_GATE_DEFAULT, step=0.01,
1128
+ label="RADIO masked gate (run masked head if lung coverage ≥ gate)"
1129
+ )
1130
 
1131
  gr.Markdown(
1132
+ "<div class='warnbox'><b>Fail-safe:</b> If lung segmentation is too small or looks unreliable, "
1133
+ "TBNet scoring is disabled to avoid unsafe outputs.</div>"
1134
  )
1135
 
1136
  gr.Markdown(
1137
+ f"<div class='subtitle'>Device: <b>{DEVICE}</b> (FORCE_CPU={FORCE_CPU})</div>"
1138
  )
1139
 
1140
  with gr.Column(scale=2):
1141
  gr.Markdown("#### Upload images")
1142
+ files = gr.Files(
1143
+ label="Upload one or multiple chest X-ray images",
1144
+ file_types=[".png", ".jpg", ".jpeg", ".bmp"]
1145
+ )
1146
  run_btn = gr.Button("Run Analysis", variant="primary")
1147
  status = gr.Textbox(label="Status", value="Ready.", interactive=False)
1148
 
 
1150
  table = gr.Dataframe(
1151
  headers=[
1152
  "Image",
1153
+ "TBNet Probability",
1154
+ "TBNet Result",
1155
  "Band",
1156
+ "Meaning",
1157
  "Quality",
1158
  "LungCov",
1159
  "RADIO RAW",
1160
  "RADIO MASKED",
1161
+ "AGREEMENT",
1162
  ],
1163
  datatype=["str","str","str","str","str","str","str","str","str","str"],
1164
  interactive=False,