drvikasgaur commited on
Commit
8c9cbe1
·
verified ·
1 Parent(s): 27745a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +135 -88
app.py CHANGED
@@ -3,8 +3,16 @@
3
  # + SAFER PHONE MODE + MASK POST-PROCESSING + MASK SANITY FAILSAFE
4
  # + 3-STATE AGREEMENT (LOW / SCREEN+ / TB+)
5
  #
6
- # Requirements:
7
- # pip install gradio timm torchvision opencv-python pillow transformers einops torch
 
 
 
 
 
 
 
 
8
  #
9
  # HF Spaces notes:
10
  # - weights are expected in ./weights/
@@ -32,19 +40,15 @@ from PIL import Image
32
  # USER CONFIG (HF Spaces friendly)
33
  # ============================================================
34
 
35
- # ---- Friendly model names for UI ----
36
  MODEL_NAME_TBNET = "TBNet (CNN model)"
37
  MODEL_NAME_RADIO = "RADIO (visual model)"
38
 
39
- # ---- Default TB/Lung weights ----
40
  DEFAULT_TB_WEIGHTS = "weights/best.pt"
41
  DEFAULT_LUNG_WEIGHTS = "weights/lung_unet_mont_shenzhen.pt"
42
 
43
- # ---- RADIO config (same env as TB) ----
44
  RADIO_HF_REPO = "nvidia/C-RADIOv4-SO400M"
45
  RADIO_REVISION = "c0457f5dc26ca145f954cd4fc5bb6114e5705ad8"
46
 
47
- # Your trained heads stored in this Space repo
48
  RADIO_RAW_HEAD_PATH = "weights/best_raw.pt"
49
  RADIO_MASKED_HEAD_PATH = "weights/best_masked.pt"
50
 
@@ -55,17 +59,14 @@ RADIO_THR_RED = 0.23
55
  RADIO_MASKED_MIN_COV = 0.15
56
  RADIO_GATE_DEFAULT = 0.21
57
 
58
- # ---- Consensus logic thresholds ----
59
  TBNET_SCREEN_THR = 0.30
60
  RADIO_SCREEN_THR = RADIO_THR_SCREEN
61
 
62
- # ---- Mask fail-safes ----
63
- FAIL_COV = 0.10 # <10% -> segmentation fail
64
- WARN_COV = 0.18 # <18% -> warn
65
- FAILSAFE_ON_BAD_MASK = True # fail-safe on suspicious/cropped masks
66
 
67
- # ---- Device policy ----
68
- FORCE_CPU = True # HF CPU space: keep True
69
  DEVICE = torch.device("cpu" if FORCE_CPU else ("cuda" if torch.cuda.is_available() else "cpu"))
70
 
71
 
@@ -118,6 +119,63 @@ CLINICAL_GUIDANCE = (
118
  )
119
 
120
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  # ============================================================
122
  # LUNG U-NET (INFERENCE)
123
  # ============================================================
@@ -332,12 +390,6 @@ def apply_clahe(gray_u8: np.ndarray) -> np.ndarray:
332
  return clahe.apply(gray_u8)
333
 
334
  def phone_preprocess(gray_u8: np.ndarray) -> np.ndarray:
335
- """
336
- Safer phone preprocessing:
337
- - only crop if border artifacts suggest a framed/screenshot input
338
- - only apply CLAHE if underexposed or low sharpness
339
- - crop sanity check to avoid destroying clean digital CXRs
340
- """
341
  sharp = laplacian_sharpness(gray_u8)
342
  lo_clip, _hi_clip = exposure_scores(gray_u8)
343
  border = border_fraction(gray_u8)
@@ -362,23 +414,21 @@ def cam_entropy(cam: np.ndarray) -> float:
362
  def detect_diffuse_risk(prob_tb: float, cam_up: np.ndarray, quality_score: float) -> bool:
363
  if quality_score < 55:
364
  return False
365
- # Only apply diffuse-risk heuristic in near-threshold negatives
366
  if prob_tb < 0.05:
367
  return False
368
  ent = cam_entropy(cam_up)
369
  return (prob_tb < TBNET_SCREEN_THR) and (ent > 6.5)
370
 
371
  def confidence_band(prob_tb: float, quality_score: float, diffuse: bool):
372
- # Very low probability should not be forced to YELLOW just because attention is diffuse
373
  if prob_tb < 0.01 and quality_score >= 45:
374
- return ("GREEN", "✅ Very low TB signal detected by the CNN model.")
375
  if quality_score < 55:
376
- return ("YELLOW", "⚠️ Image quality is low; treat the result as indeterminate.")
377
  if diffuse:
378
- return ("YELLOW", "⚠️ Attention is non-focal; treat the result as indeterminate.")
379
  if prob_tb >= TBNET_SCREEN_THR:
380
  return ("YELLOW", "⚠️ Screening-positive range; review recommended.")
381
- return ("GREEN", "✅ No strong TB signal detected by the CNN model.")
382
 
383
  def make_mask_overlay(gray_u8: np.ndarray, mask_u8: np.ndarray) -> np.ndarray:
384
  base = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB)
@@ -441,14 +491,14 @@ def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
441
 
442
  def recommendation_for_band(band: Optional[str]) -> str:
443
  if band in (None, "YELLOW"):
444
- return "✅ Recommendation: Radiologist/clinician review is recommended (result is indeterminate)."
445
  if band == "RED":
446
- return "✅ Recommendation: Urgent clinical review + microbiological confirmation (CBNAAT/GeneXpert, sputum)."
447
- return "✅ Recommendation: If symptoms/risk factors exist, clinical correlation is advised."
448
 
449
 
450
  # ============================================================
451
- # AGREEMENT LOGIC (TBNet vs RADIO) — 3-state
452
  # ============================================================
453
  def tbnet_state(tb_prob: float, tb_band: str) -> str:
454
  if tb_band == "RED":
@@ -475,7 +525,6 @@ def build_consensus(
475
  if tb_prob is None or tb_band is None:
476
  return ("N/A", f"{MODEL_NAME_TBNET} not available (lung segmentation failed / fail-safe).")
477
 
478
- # PRIMARY = masked if available else raw
479
  if radio_masked is not None:
480
  radio_primary = radio_masked
481
  radio_used = "MASKED"
@@ -484,12 +533,11 @@ def build_consensus(
484
  radio_used = "RAW"
485
 
486
  if radio_primary is None:
487
- return ("TBNet only", f"{MODEL_NAME_RADIO} not available → {MODEL_NAME_TBNET} probability={tb_prob:.4f} (band={tb_band}).")
488
 
489
  t = tbnet_state(tb_prob, tb_band)
490
  r = radio_state_from_prob(radio_primary)
491
-
492
- rb = f" (RADIO band={radio_band})" if radio_band else ""
493
 
494
  if t == r:
495
  return (
@@ -497,11 +545,10 @@ def build_consensus(
497
  f"Both models indicate **{t}**. {MODEL_NAME_TBNET}={tb_prob:.4f}, {MODEL_NAME_RADIO}({radio_used})={radio_primary:.4f}{rb}."
498
  )
499
 
500
- # strong disagreement: one says SCREEN+/TB+ and the other says LOW
501
  if (t in ("SCREEN+", "TB+") and r == "LOW") or (r in ("SCREEN+", "TB+") and t == "LOW"):
502
  return (
503
  "DISAGREE",
504
- f"Models disagree: {MODEL_NAME_TBNET} suggests **{t}** (band={tb_band}) vs {MODEL_NAME_RADIO} suggests **{r}** ({radio_used})={radio_primary:.4f}{rb}."
505
  )
506
 
507
  return (
@@ -537,7 +584,6 @@ class ModelBundle:
537
  load_tb_weights(tb, tb_weights, self.device)
538
  tb.eval()
539
  self.tb = tb
540
- # EfficientNet in timm has conv_head on effb0
541
  self.cammer = GradCAM(tb, tb.backbone.conv_head)
542
  self.tb_path = tb_weights
543
  self.backbone = backbone
@@ -634,7 +680,6 @@ def radio_overlay_heatmap(rgb_u8: np.ndarray, heatmap01: np.ndarray, alpha: floa
634
  img = rgb_u8.astype(np.float32) / 255.0
635
  hm = np.clip(heatmap01, 0, 1).astype(np.float32)
636
  out = img.copy()
637
- # subtle red overlay
638
  out[..., 0] = np.clip(out[..., 0] * (1 - alpha) + hm * alpha, 0, 1)
639
  return (out * 255).astype(np.uint8)
640
 
@@ -647,7 +692,6 @@ def radio_predict_from_arrays(gray_vis_u8: np.ndarray,
647
  RADIO_BUNDLE.load(device=device)
648
  dtype = torch.float16 if device.type == "cuda" else torch.float32
649
 
650
- # ---------- RAW ----------
651
  raw_rgb = cv2.cvtColor(gray_vis_u8, cv2.COLOR_GRAY2RGB)
652
  px = RADIO_BUNDLE.processor(
653
  images=Image.fromarray(raw_rgb),
@@ -668,7 +712,6 @@ def radio_predict_from_arrays(gray_vis_u8: np.ndarray,
668
  alpha=0.35
669
  )
670
 
671
- # ---------- MASKED (optional) ----------
672
  masked_prob = None
673
  masked_overlay = None
674
  masked_ran = False
@@ -697,7 +740,6 @@ def radio_predict_from_arrays(gray_vis_u8: np.ndarray,
697
  alpha=0.35
698
  )
699
 
700
- # ---------- PRIMARY = masked if available else raw ----------
701
  prob_primary = masked_prob if masked_prob is not None else prob_raw
702
 
703
  if prob_primary >= RADIO_THR_RED:
@@ -731,7 +773,7 @@ def analyze_one_image(
731
  tb_weights: str,
732
  lung_weights: str,
733
  backbone: str,
734
- threshold: float, # kept for UI compatibility; main logic uses TBNET_SCREEN_THR
735
  phone_mode: bool,
736
  img_size: int = 224,
737
  fail_cov: float = FAIL_COV,
@@ -754,8 +796,6 @@ def analyze_one_image(
754
  mask256 = torch.sigmoid(mask_logits)[0, 0].cpu().numpy()
755
 
756
  mask256_bin = (mask256 > 0.5).astype(np.uint8)
757
-
758
- # post-process: keep 2 lungs, close, fill holes
759
  mask256_bin = keep_top_k_components(mask256_bin, k=2)
760
  k = max(3, int(0.02 * 256))
761
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
@@ -765,7 +805,6 @@ def analyze_one_image(
765
  coverage = float(mask256_bin.mean())
766
  mask_full = cv2.resize(mask256_bin, (gray_vis.shape[1], gray_vis.shape[0]), interpolation=cv2.INTER_NEAREST)
767
 
768
- # fail-safe if coverage too low
769
  if coverage < fail_cov:
770
  overlay_rgb = cv2.cvtColor(cv2.resize(gray_vis, (img_size, img_size)), cv2.COLOR_GRAY2RGB)
771
  return {
@@ -774,13 +813,13 @@ def analyze_one_image(
774
  "pred": "INDETERMINATE",
775
  "band": "YELLOW",
776
  "band_text": (
777
- "⚠️ The lung segmentation looks unreliable, so the TB screening score is disabled for safety.\n\n"
778
- "Please use a clearer frontal chest X-ray (PA/AP) or seek radiologist review."
779
  ),
780
  "quality_score": float(q_score),
781
  "diffuse_risk": False,
782
  "warnings": (
783
- [f"Lung segmentation coverage is too small ({coverage*100:.1f}%)."]
784
  + (["Phone/WhatsApp mode enabled; artifacts possible."] if phone_mode else [])
785
  + q_warn
786
  ),
@@ -795,7 +834,6 @@ def analyze_one_image(
795
  "overlay_clean": overlay_rgb,
796
  }
797
 
798
- # fail-safe if mask looks like single lung / cropped
799
  sanity = mask_sanity_warnings(mask_full.astype(np.uint8))
800
  if FAILSAFE_ON_BAD_MASK and sanity:
801
  overlay_rgb = cv2.cvtColor(cv2.resize(gray_vis, (img_size, img_size)), cv2.COLOR_GRAY2RGB)
@@ -806,8 +844,8 @@ def analyze_one_image(
806
  "band": "YELLOW",
807
  "band_text": (
808
  "⚠️ The image appears cropped/non-standard (mask sanity check). "
809
- "TB screening score is disabled for safety.\n\n"
810
- "Please use a standard frontal chest X-ray (PA/AP) or seek radiologist review."
811
  ),
812
  "quality_score": float(q_score),
813
  "diffuse_risk": False,
@@ -843,7 +881,7 @@ def analyze_one_image(
843
  cam_up = cam_u8.astype(np.float32) / 255.0
844
 
845
  diffuse = detect_diffuse_risk(prob_tb, cam_up, q_score)
846
- band_base, _band_hint = confidence_band(prob_tb, q_score, diffuse)
847
 
848
  allow_red = (prob_tb >= 0.70 and q_score >= 55 and (not diffuse) and coverage >= warn_cov)
849
  band = "RED" if allow_red else band_base
@@ -851,14 +889,12 @@ def analyze_one_image(
851
  pred = REPORT_LABELS[band]["title"]
852
  band_text = REPORT_LABELS[band]["summary"]
853
 
854
- # Extra helpful line if YELLOW but probability is very low
855
  if band == "YELLOW" and prob_tb < 0.05:
856
  band_text = (
857
  "⚠️ TB probability is very low, but the result is marked **indeterminate** because reliability is limited.\n\n"
858
  + band_text
859
  )
860
 
861
- # Build Grad-CAM overlay
862
  heat = cv2.applyColorMap((cam_up * 255).astype(np.uint8), cv2.COLORMAP_JET)
863
  overlay_clean = cv2.addWeighted(rgb, 0.65, heat, 0.35, 0)
864
 
@@ -946,14 +982,13 @@ def run_analysis(
946
  img_size=224,
947
  )
948
 
949
- # -------------------------
950
  # RADIO (optional)
951
- # -------------------------
952
  radio_text = f"{MODEL_NAME_RADIO} is disabled."
953
  radio_raw_overlay = None
954
  radio_masked_overlay = None
955
  radio_raw_val: Optional[float] = None
956
  radio_masked_val: Optional[float] = None
 
957
  radio_band: Optional[str] = None
958
 
959
  radio_raw_str = ""
@@ -968,8 +1003,8 @@ def run_analysis(
968
  device=BUNDLE.device,
969
  gate_threshold=float(radio_gate),
970
  )
971
-
972
  radio_raw_val = float(r["prob_raw"])
 
973
  radio_masked_val = None if r["masked_prob"] is None else float(r["masked_prob"])
974
  radio_band = str(r["band"])
975
 
@@ -977,7 +1012,8 @@ def run_analysis(
977
  radio_masked_str = "" if radio_masked_val is None else f"{radio_masked_val:.4f}"
978
 
979
  radio_text = (
980
- f"**{MODEL_NAME_RADIO} result:** {r['pred']} | RAW={radio_raw_val:.4f}"
 
981
  + (f" | MASKED={radio_masked_val:.4f}" if radio_masked_val is not None else "")
982
  + f" | Band={radio_band}"
983
  )
@@ -985,15 +1021,11 @@ def run_analysis(
985
  radio_masked_overlay = r["masked_overlay"]
986
  except Exception as e:
987
  radio_text = f"{MODEL_NAME_RADIO} error: {type(e).__name__}: {e}"
988
- radio_raw_str = ""
989
- radio_masked_str = ""
990
  radio_raw_val = None
991
  radio_masked_val = None
 
992
  radio_band = None
993
 
994
- # -------------------------
995
- # Agreement between models
996
- # -------------------------
997
  consensus_label, consensus_detail = build_consensus(
998
  tb_prob=out["prob"],
999
  tb_band=out["band"],
@@ -1002,9 +1034,7 @@ def run_analysis(
1002
  radio_band=radio_band,
1003
  )
1004
 
1005
- # -------------------------
1006
  # Table row
1007
- # -------------------------
1008
  prob_str = "" if out["prob"] is None else f"{out['prob']:.4f}"
1009
  cov_str = f"{out.get('lung_coverage', 0.0) * 100:.1f}%"
1010
 
@@ -1021,9 +1051,7 @@ def run_analysis(
1021
  consensus_label,
1022
  ])
1023
 
1024
- # -------------------------
1025
- # Visual outputs
1026
- # -------------------------
1027
  orig_rgb = cv2.cvtColor(cv2.resize(out["orig_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
1028
  vis_rgb = cv2.cvtColor(cv2.resize(out["vis_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
1029
  mask_overlay = cv2.resize(out["mask_overlay"], (512, 512))
@@ -1044,36 +1072,56 @@ def run_analysis(
1044
  if radio_masked_overlay is not None:
1045
  gallery_items.append((cv2.resize(radio_masked_overlay, (512, 512)), f"{name} • RADIO MASKED heatmap"))
1046
 
1047
- # -------------------------
1048
- # Details panel (user-friendly)
1049
- # -------------------------
 
 
 
 
 
 
 
 
 
1050
  warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
1051
  tb_line = "N/A (disabled by fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
1052
  rec_line = recommendation_for_band(out.get("band"))
1053
 
1054
  details_md.append(
1055
- f"""### {name}
 
 
 
 
 
 
 
 
 
 
1056
 
1057
- **{MODEL_NAME_TBNET} result:** **{out['pred']}**
1058
- {rec_line}
1059
 
1060
- **What this means**
1061
- {out['band_text']}
1062
 
1063
- **Why it decided this**
1064
- - {MODEL_NAME_TBNET} probability: {tb_line}
1065
- - Image quality: {out['quality_score']:.0f}/100
1066
- - Lung mask coverage: {out.get('lung_coverage', 0.0) * 100:.1f}%
1067
- - Attention pattern (TBNet): {"Diffuse / non-focal" if out["diffuse_risk"] else "Focal / localized"}
 
 
 
 
1068
 
1069
  **Notes that may affect reliability**
1070
  {warn_txt}
1071
 
1072
- **{MODEL_NAME_RADIO} output**
1073
- {radio_text}
1074
 
1075
- **Agreement between models (TBNet vs RADIO):** **{consensus_label}**
1076
- - {consensus_detail}
1077
 
1078
  **Clinical guidance**
1079
  {CLINICAL_GUIDANCE}
@@ -1096,10 +1144,10 @@ def build_ui():
1096
  """
1097
 
1098
  with gr.Blocks(title="TB X-ray Assistant (TBNet + RADIO)", css=css) as demo:
1099
- gr.Markdown('<div class="title">TB X-ray Assistant (Auto Lung Mask • Research Use)</div>')
1100
  gr.Markdown(
1101
- f"<div class='subtitle'>Lung U-Net masking → <b>{MODEL_NAME_TBNET}</b> + Grad-CAM • "
1102
- f"Optional <b>{MODEL_NAME_RADIO}</b> (C-RADIOv4 + heads) • Agreement summary</div>"
1103
  )
1104
 
1105
  with gr.Row():
@@ -1121,7 +1169,6 @@ def build_ui():
1121
  label="Phone/WhatsApp Mode (safe: conditional crop + conditional CLAHE)"
1122
  )
1123
 
1124
- # RADIO
1125
  use_radio = gr.Checkbox(value=False, label=f"Enable {MODEL_NAME_RADIO}")
1126
  radio_gate = gr.Slider(
1127
  0.10, 0.40, value=RADIO_GATE_DEFAULT, step=0.01,
 
3
  # + SAFER PHONE MODE + MASK POST-PROCESSING + MASK SANITY FAILSAFE
4
  # + 3-STATE AGREEMENT (LOW / SCREEN+ / TB+)
5
  #
6
+ # Requirements (requirements.txt):
7
+ # gradio
8
+ # torch
9
+ # torchvision
10
+ # timm
11
+ # opencv-python
12
+ # pillow
13
+ # transformers
14
+ # einops
15
+ # open_clip_torch
16
  #
17
  # HF Spaces notes:
18
  # - weights are expected in ./weights/
 
40
  # USER CONFIG (HF Spaces friendly)
41
  # ============================================================
42
 
 
43
  MODEL_NAME_TBNET = "TBNet (CNN model)"
44
  MODEL_NAME_RADIO = "RADIO (visual model)"
45
 
 
46
  DEFAULT_TB_WEIGHTS = "weights/best.pt"
47
  DEFAULT_LUNG_WEIGHTS = "weights/lung_unet_mont_shenzhen.pt"
48
 
 
49
  RADIO_HF_REPO = "nvidia/C-RADIOv4-SO400M"
50
  RADIO_REVISION = "c0457f5dc26ca145f954cd4fc5bb6114e5705ad8"
51
 
 
52
  RADIO_RAW_HEAD_PATH = "weights/best_raw.pt"
53
  RADIO_MASKED_HEAD_PATH = "weights/best_masked.pt"
54
 
 
59
  RADIO_MASKED_MIN_COV = 0.15
60
  RADIO_GATE_DEFAULT = 0.21
61
 
 
62
  TBNET_SCREEN_THR = 0.30
63
  RADIO_SCREEN_THR = RADIO_THR_SCREEN
64
 
65
+ FAIL_COV = 0.10
66
+ WARN_COV = 0.18
67
+ FAILSAFE_ON_BAD_MASK = True
 
68
 
69
+ FORCE_CPU = True
 
70
  DEVICE = torch.device("cpu" if FORCE_CPU else ("cuda" if torch.cuda.is_available() else "cpu"))
71
 
72
 
 
119
  )
120
 
121
 
122
+ # ============================================================
123
+ # USER-FRIENDLY SUMMARY BUILDER
124
+ # ============================================================
125
+ def overall_summary(tb_band: Optional[str],
126
+ tb_prob: Optional[float],
127
+ radio_primary: Optional[float],
128
+ radio_band: Optional[str],
129
+ consensus_label: str,
130
+ q_score: float,
131
+ cov: float,
132
+ warnings: List[str]) -> str:
133
+ # Overall label from agreement (keeps it simple for users)
134
+ if tb_prob is None:
135
+ overall_title = "INDETERMINATE — NEEDS REVIEW"
136
+ icon = "⚠️"
137
+ else:
138
+ if "AGREE: LOW" in consensus_label:
139
+ overall_title = "LOW TB LIKELIHOOD"
140
+ icon = "✅"
141
+ elif "AGREE: TB+" in consensus_label:
142
+ overall_title = "TB FEATURES SUSPECTED"
143
+ icon = "🚩"
144
+ elif "AGREE: SCREEN+" in consensus_label:
145
+ overall_title = "SCREEN-POSITIVE — REVIEW RECOMMENDED"
146
+ icon = "⚠️"
147
+ else:
148
+ overall_title = "INDETERMINATE — REVIEW RECOMMENDED"
149
+ icon = "⚠️"
150
+
151
+ reliability = "Good" if (q_score >= 70 and cov >= WARN_COV) else "Limited"
152
+ rel_icon = "🟢" if reliability == "Good" else "🟡"
153
+
154
+ warn_line = "None" if not warnings else f"{len(warnings)} note(s) below"
155
+
156
+ tb_prob_str = "N/A" if tb_prob is None else f"{tb_prob:.4f}"
157
+ radio_str = "N/A" if radio_primary is None else f"{radio_primary:.4f}"
158
+
159
+ return f"""
160
+ ## {icon} Overall screening result: **{overall_title}**
161
+
162
+ **Reliability:** {rel_icon} **{reliability}** (Quality: {q_score:.0f}/100 • Lung coverage: {cov*100:.1f}% • Notes: {warn_line})
163
+
164
+ ### What this means
165
+ - This is a **screening support tool**, not a diagnosis.
166
+ - Two models analyze the same image: a **CNN model** (TBNet) and a **visual model** (RADIO).
167
+
168
+ ### Model agreement
169
+ - **{consensus_label}**
170
+ - {MODEL_NAME_TBNET} probability: **{tb_prob_str}**
171
+ - {MODEL_NAME_RADIO} probability: **{radio_str}** {f"(band={radio_band})" if radio_band else ""}
172
+
173
+ ### What to do next
174
+ - If you have symptoms/risk factors, seek clinician/radiologist review.
175
+ - If TB is clinically suspected, consider **CBNAAT/GeneXpert** and sputum testing regardless of AI output.
176
+ """
177
+
178
+
179
  # ============================================================
180
  # LUNG U-NET (INFERENCE)
181
  # ============================================================
 
390
  return clahe.apply(gray_u8)
391
 
392
  def phone_preprocess(gray_u8: np.ndarray) -> np.ndarray:
 
 
 
 
 
 
393
  sharp = laplacian_sharpness(gray_u8)
394
  lo_clip, _hi_clip = exposure_scores(gray_u8)
395
  border = border_fraction(gray_u8)
 
414
  def detect_diffuse_risk(prob_tb: float, cam_up: np.ndarray, quality_score: float) -> bool:
415
  if quality_score < 55:
416
  return False
 
417
  if prob_tb < 0.05:
418
  return False
419
  ent = cam_entropy(cam_up)
420
  return (prob_tb < TBNET_SCREEN_THR) and (ent > 6.5)
421
 
422
  def confidence_band(prob_tb: float, quality_score: float, diffuse: bool):
 
423
  if prob_tb < 0.01 and quality_score >= 45:
424
+ return ("GREEN", "✅ Very low TB signal detected.")
425
  if quality_score < 55:
426
+ return ("YELLOW", "⚠️ Image quality is low; treat as indeterminate.")
427
  if diffuse:
428
+ return ("YELLOW", "⚠️ Attention is non-focal; treat as indeterminate.")
429
  if prob_tb >= TBNET_SCREEN_THR:
430
  return ("YELLOW", "⚠️ Screening-positive range; review recommended.")
431
+ return ("GREEN", "✅ No strong TB signal detected.")
432
 
433
  def make_mask_overlay(gray_u8: np.ndarray, mask_u8: np.ndarray) -> np.ndarray:
434
  base = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB)
 
491
 
492
  def recommendation_for_band(band: Optional[str]) -> str:
493
  if band in (None, "YELLOW"):
494
+ return "Radiologist/clinician review is recommended (result is indeterminate)."
495
  if band == "RED":
496
+ return "Urgent clinical review + microbiological confirmation (CBNAAT/GeneXpert, sputum) recommended."
497
+ return "If symptoms/risk factors exist, clinical correlation is advised."
498
 
499
 
500
  # ============================================================
501
+ # AGREEMENT LOGIC (TBNet vs RADIO)
502
  # ============================================================
503
  def tbnet_state(tb_prob: float, tb_band: str) -> str:
504
  if tb_band == "RED":
 
525
  if tb_prob is None or tb_band is None:
526
  return ("N/A", f"{MODEL_NAME_TBNET} not available (lung segmentation failed / fail-safe).")
527
 
 
528
  if radio_masked is not None:
529
  radio_primary = radio_masked
530
  radio_used = "MASKED"
 
533
  radio_used = "RAW"
534
 
535
  if radio_primary is None:
536
+ return ("TBNet only", f"{MODEL_NAME_RADIO} not available → {MODEL_NAME_TBNET}={tb_prob:.4f} (band={tb_band}).")
537
 
538
  t = tbnet_state(tb_prob, tb_band)
539
  r = radio_state_from_prob(radio_primary)
540
+ rb = f" (band={radio_band})" if radio_band else ""
 
541
 
542
  if t == r:
543
  return (
 
545
  f"Both models indicate **{t}**. {MODEL_NAME_TBNET}={tb_prob:.4f}, {MODEL_NAME_RADIO}({radio_used})={radio_primary:.4f}{rb}."
546
  )
547
 
 
548
  if (t in ("SCREEN+", "TB+") and r == "LOW") or (r in ("SCREEN+", "TB+") and t == "LOW"):
549
  return (
550
  "DISAGREE",
551
+ f"Models disagree: {MODEL_NAME_TBNET} suggests **{t}** vs {MODEL_NAME_RADIO} suggests **{r}** ({radio_used})={radio_primary:.4f}{rb}."
552
  )
553
 
554
  return (
 
584
  load_tb_weights(tb, tb_weights, self.device)
585
  tb.eval()
586
  self.tb = tb
 
587
  self.cammer = GradCAM(tb, tb.backbone.conv_head)
588
  self.tb_path = tb_weights
589
  self.backbone = backbone
 
680
  img = rgb_u8.astype(np.float32) / 255.0
681
  hm = np.clip(heatmap01, 0, 1).astype(np.float32)
682
  out = img.copy()
 
683
  out[..., 0] = np.clip(out[..., 0] * (1 - alpha) + hm * alpha, 0, 1)
684
  return (out * 255).astype(np.uint8)
685
 
 
692
  RADIO_BUNDLE.load(device=device)
693
  dtype = torch.float16 if device.type == "cuda" else torch.float32
694
 
 
695
  raw_rgb = cv2.cvtColor(gray_vis_u8, cv2.COLOR_GRAY2RGB)
696
  px = RADIO_BUNDLE.processor(
697
  images=Image.fromarray(raw_rgb),
 
712
  alpha=0.35
713
  )
714
 
 
715
  masked_prob = None
716
  masked_overlay = None
717
  masked_ran = False
 
740
  alpha=0.35
741
  )
742
 
 
743
  prob_primary = masked_prob if masked_prob is not None else prob_raw
744
 
745
  if prob_primary >= RADIO_THR_RED:
 
773
  tb_weights: str,
774
  lung_weights: str,
775
  backbone: str,
776
+ threshold: float,
777
  phone_mode: bool,
778
  img_size: int = 224,
779
  fail_cov: float = FAIL_COV,
 
796
  mask256 = torch.sigmoid(mask_logits)[0, 0].cpu().numpy()
797
 
798
  mask256_bin = (mask256 > 0.5).astype(np.uint8)
 
 
799
  mask256_bin = keep_top_k_components(mask256_bin, k=2)
800
  k = max(3, int(0.02 * 256))
801
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
 
805
  coverage = float(mask256_bin.mean())
806
  mask_full = cv2.resize(mask256_bin, (gray_vis.shape[1], gray_vis.shape[0]), interpolation=cv2.INTER_NEAREST)
807
 
 
808
  if coverage < fail_cov:
809
  overlay_rgb = cv2.cvtColor(cv2.resize(gray_vis, (img_size, img_size)), cv2.COLOR_GRAY2RGB)
810
  return {
 
813
  "pred": "INDETERMINATE",
814
  "band": "YELLOW",
815
  "band_text": (
816
+ "⚠️ Lung segmentation looks unreliable, so the TBNet screening score is disabled for safety.\n\n"
817
+ "Please use a clearer standard frontal CXR (PA/AP) or seek radiologist review."
818
  ),
819
  "quality_score": float(q_score),
820
  "diffuse_risk": False,
821
  "warnings": (
822
+ [f"Lung segmentation coverage too small ({coverage*100:.1f}%)."]
823
  + (["Phone/WhatsApp mode enabled; artifacts possible."] if phone_mode else [])
824
  + q_warn
825
  ),
 
834
  "overlay_clean": overlay_rgb,
835
  }
836
 
 
837
  sanity = mask_sanity_warnings(mask_full.astype(np.uint8))
838
  if FAILSAFE_ON_BAD_MASK and sanity:
839
  overlay_rgb = cv2.cvtColor(cv2.resize(gray_vis, (img_size, img_size)), cv2.COLOR_GRAY2RGB)
 
844
  "band": "YELLOW",
845
  "band_text": (
846
  "⚠️ The image appears cropped/non-standard (mask sanity check). "
847
+ "TBNet screening score is disabled for safety.\n\n"
848
+ "Please use a standard frontal CXR (PA/AP) or seek radiologist review."
849
  ),
850
  "quality_score": float(q_score),
851
  "diffuse_risk": False,
 
881
  cam_up = cam_u8.astype(np.float32) / 255.0
882
 
883
  diffuse = detect_diffuse_risk(prob_tb, cam_up, q_score)
884
+ band_base, _ = confidence_band(prob_tb, q_score, diffuse)
885
 
886
  allow_red = (prob_tb >= 0.70 and q_score >= 55 and (not diffuse) and coverage >= warn_cov)
887
  band = "RED" if allow_red else band_base
 
889
  pred = REPORT_LABELS[band]["title"]
890
  band_text = REPORT_LABELS[band]["summary"]
891
 
 
892
  if band == "YELLOW" and prob_tb < 0.05:
893
  band_text = (
894
  "⚠️ TB probability is very low, but the result is marked **indeterminate** because reliability is limited.\n\n"
895
  + band_text
896
  )
897
 
 
898
  heat = cv2.applyColorMap((cam_up * 255).astype(np.uint8), cv2.COLORMAP_JET)
899
  overlay_clean = cv2.addWeighted(rgb, 0.65, heat, 0.35, 0)
900
 
 
982
  img_size=224,
983
  )
984
 
 
985
  # RADIO (optional)
 
986
  radio_text = f"{MODEL_NAME_RADIO} is disabled."
987
  radio_raw_overlay = None
988
  radio_masked_overlay = None
989
  radio_raw_val: Optional[float] = None
990
  radio_masked_val: Optional[float] = None
991
+ radio_primary_val: Optional[float] = None
992
  radio_band: Optional[str] = None
993
 
994
  radio_raw_str = ""
 
1003
  device=BUNDLE.device,
1004
  gate_threshold=float(radio_gate),
1005
  )
 
1006
  radio_raw_val = float(r["prob_raw"])
1007
+ radio_primary_val = float(r["prob_primary"])
1008
  radio_masked_val = None if r["masked_prob"] is None else float(r["masked_prob"])
1009
  radio_band = str(r["band"])
1010
 
 
1012
  radio_masked_str = "" if radio_masked_val is None else f"{radio_masked_val:.4f}"
1013
 
1014
  radio_text = (
1015
+ f"**{MODEL_NAME_RADIO} result:** {r['pred']} | "
1016
+ f"PRIMARY={radio_primary_val:.4f} | RAW={radio_raw_val:.4f}"
1017
  + (f" | MASKED={radio_masked_val:.4f}" if radio_masked_val is not None else "")
1018
  + f" | Band={radio_band}"
1019
  )
 
1021
  radio_masked_overlay = r["masked_overlay"]
1022
  except Exception as e:
1023
  radio_text = f"{MODEL_NAME_RADIO} error: {type(e).__name__}: {e}"
 
 
1024
  radio_raw_val = None
1025
  radio_masked_val = None
1026
+ radio_primary_val = None
1027
  radio_band = None
1028
 
 
 
 
1029
  consensus_label, consensus_detail = build_consensus(
1030
  tb_prob=out["prob"],
1031
  tb_band=out["band"],
 
1034
  radio_band=radio_band,
1035
  )
1036
 
 
1037
  # Table row
 
1038
  prob_str = "" if out["prob"] is None else f"{out['prob']:.4f}"
1039
  cov_str = f"{out.get('lung_coverage', 0.0) * 100:.1f}%"
1040
 
 
1051
  consensus_label,
1052
  ])
1053
 
1054
+ # Gallery
 
 
1055
  orig_rgb = cv2.cvtColor(cv2.resize(out["orig_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
1056
  vis_rgb = cv2.cvtColor(cv2.resize(out["vis_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
1057
  mask_overlay = cv2.resize(out["mask_overlay"], (512, 512))
 
1072
  if radio_masked_overlay is not None:
1073
  gallery_items.append((cv2.resize(radio_masked_overlay, (512, 512)), f"{name} • RADIO MASKED heatmap"))
1074
 
1075
+ # Details (dashboard style)
1076
+ summary_md = overall_summary(
1077
+ tb_band=out.get("band"),
1078
+ tb_prob=out.get("prob"),
1079
+ radio_primary=radio_primary_val,
1080
+ radio_band=radio_band,
1081
+ consensus_label=consensus_label,
1082
+ q_score=float(out["quality_score"]),
1083
+ cov=float(out.get("lung_coverage", 0.0)),
1084
+ warnings=out.get("warnings", []),
1085
+ )
1086
+
1087
  warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
1088
  tb_line = "N/A (disabled by fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
1089
  rec_line = recommendation_for_band(out.get("band"))
1090
 
1091
  details_md.append(
1092
+ f"""{summary_md}
1093
+
1094
+ ---
1095
+
1096
+ <details>
1097
+ <summary><b>{MODEL_NAME_TBNET} details</b></summary>
1098
+
1099
+ - **Result:** {out['pred']} ({out['band']})
1100
+ - **Recommendation:** {rec_line}
1101
+ - **Probability (screening score):** {tb_line}
1102
+ - **Attention pattern:** {"Diffuse / non-focal" if out["diffuse_risk"] else "Focal / localized"}
1103
 
1104
+ </details>
 
1105
 
1106
+ <details>
1107
+ <summary><b>{MODEL_NAME_RADIO} details</b></summary>
1108
 
1109
+ {radio_text}
1110
+
1111
+ </details>
1112
+
1113
+ <details>
1114
+ <summary><b>Image quality & segmentation</b></summary>
1115
+
1116
+ - **Quality score:** {out['quality_score']:.0f}/100
1117
+ - **Lung mask coverage:** {out.get('lung_coverage', 0.0) * 100:.1f}%
1118
 
1119
  **Notes that may affect reliability**
1120
  {warn_txt}
1121
 
1122
+ </details>
 
1123
 
1124
+ ---
 
1125
 
1126
  **Clinical guidance**
1127
  {CLINICAL_GUIDANCE}
 
1144
  """
1145
 
1146
  with gr.Blocks(title="TB X-ray Assistant (TBNet + RADIO)", css=css) as demo:
1147
+ gr.Markdown('<div class="title">TB X-ray Assistant (Research Use)</div>')
1148
  gr.Markdown(
1149
+ f"<div class='subtitle'>Auto lung mask → <b>{MODEL_NAME_TBNET}</b> + Grad-CAM • "
1150
+ f"Optional <b>{MODEL_NAME_RADIO}</b> (C-RADIOv4 + heads) • User-friendly summary</div>"
1151
  )
1152
 
1153
  with gr.Row():
 
1169
  label="Phone/WhatsApp Mode (safe: conditional crop + conditional CLAHE)"
1170
  )
1171
 
 
1172
  use_radio = gr.Checkbox(value=False, label=f"Enable {MODEL_NAME_RADIO}")
1173
  radio_gate = gr.Slider(
1174
  0.10, 0.40, value=RADIO_GATE_DEFAULT, step=0.01,