drvikasgaur commited on
Commit
a95eba6
·
verified ·
1 Parent(s): 8c9cbe1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -220
app.py CHANGED
@@ -1,22 +1,20 @@
1
  # app.py
2
- # Gradio — TBNet (CNN model) + Lung U-Net Auto Mask + Grad-CAM + RADIO (visual model)
3
  # + SAFER PHONE MODE + MASK POST-PROCESSING + MASK SANITY FAILSAFE
4
- # + 3-STATE AGREEMENT (LOW / SCREEN+ / TB+)
5
  #
6
- # Requirements (requirements.txt):
7
- # gradio
8
- # torch
9
- # torchvision
10
- # timm
11
- # opencv-python
12
- # pillow
13
- # transformers
14
- # einops
15
- # open_clip_torch
16
  #
17
- # HF Spaces notes:
18
- # - weights are expected in ./weights/
19
- # - app launches on 0.0.0.0:7860
 
 
 
 
 
 
 
20
 
21
  import os
22
  import cv2
@@ -30,22 +28,24 @@ import gradio as gr
30
  from torchvision import transforms
31
  from typing import List, Tuple, Dict, Any, Optional
32
 
33
- # RADIO deps (same env as TBNet)
34
  from transformers import AutoModel, CLIPImageProcessor
35
  from einops import rearrange
36
  from PIL import Image
37
 
38
 
39
  # ============================================================
40
- # USER CONFIG (HF Spaces friendly)
41
  # ============================================================
42
 
 
43
  MODEL_NAME_TBNET = "TBNet (CNN model)"
44
  MODEL_NAME_RADIO = "RADIO (visual model)"
45
 
 
46
  DEFAULT_TB_WEIGHTS = "weights/best.pt"
47
  DEFAULT_LUNG_WEIGHTS = "weights/lung_unet_mont_shenzhen.pt"
48
 
 
49
  RADIO_HF_REPO = "nvidia/C-RADIOv4-SO400M"
50
  RADIO_REVISION = "c0457f5dc26ca145f954cd4fc5bb6114e5705ad8"
51
 
@@ -59,13 +59,19 @@ RADIO_THR_RED = 0.23
59
  RADIO_MASKED_MIN_COV = 0.15
60
  RADIO_GATE_DEFAULT = 0.21
61
 
 
62
  TBNET_SCREEN_THR = 0.30
 
 
63
  RADIO_SCREEN_THR = RADIO_THR_SCREEN
 
64
 
 
65
  FAIL_COV = 0.10
66
  WARN_COV = 0.18
67
  FAILSAFE_ON_BAD_MASK = True
68
 
 
69
  FORCE_CPU = True
70
  DEVICE = torch.device("cpu" if FORCE_CPU else ("cuda" if torch.cuda.is_available() else "cpu"))
71
 
@@ -75,15 +81,15 @@ DEVICE = torch.device("cpu" if FORCE_CPU else ("cuda" if torch.cuda.is_available
75
  # ============================================================
76
  CLINICAL_DISCLAIMER = """
77
  ⚠️ IMPORTANT CLINICAL NOTICE (Decision Support Only)
78
-
79
- This system is for **research/decision support** and is **NOT** a diagnostic device.
80
- It may **miss early/subtle tuberculosis**, including **miliary TB**.
81
- Phone photos / screenshots / downsampled images can reduce reliability.
82
 
83
  If clinical suspicion exists (fever, weight loss, immunosuppression, known exposure),
84
  recommend **CBNAAT / GeneXpert**, sputum studies, and/or **CT chest** regardless of AI output.
85
  """
86
 
 
87
  REPORT_LABELS = {
88
  "GREEN": {
89
  "title": "LOW TB LIKELIHOOD",
@@ -120,60 +126,20 @@ CLINICAL_GUIDANCE = (
120
 
121
 
122
  # ============================================================
123
- # USER-FRIENDLY SUMMARY BUILDER
124
  # ============================================================
125
- def overall_summary(tb_band: Optional[str],
126
- tb_prob: Optional[float],
127
- radio_primary: Optional[float],
128
- radio_band: Optional[str],
129
- consensus_label: str,
130
- q_score: float,
131
- cov: float,
132
- warnings: List[str]) -> str:
133
- # Overall label from agreement (keeps it simple for users)
134
  if tb_prob is None:
135
- overall_title = "INDETERMINATE — NEEDS REVIEW"
136
- icon = "⚠️"
137
- else:
138
- if "AGREE: LOW" in consensus_label:
139
- overall_title = "LOW TB LIKELIHOOD"
140
- icon = "✅"
141
- elif "AGREE: TB+" in consensus_label:
142
- overall_title = "TB FEATURES SUSPECTED"
143
- icon = "🚩"
144
- elif "AGREE: SCREEN+" in consensus_label:
145
- overall_title = "SCREEN-POSITIVE — REVIEW RECOMMENDED"
146
- icon = "⚠️"
147
- else:
148
- overall_title = "INDETERMINATE — REVIEW RECOMMENDED"
149
- icon = "⚠️"
150
-
151
- reliability = "Good" if (q_score >= 70 and cov >= WARN_COV) else "Limited"
152
- rel_icon = "🟢" if reliability == "Good" else "🟡"
153
-
154
- warn_line = "None" if not warnings else f"{len(warnings)} note(s) below"
155
-
156
- tb_prob_str = "N/A" if tb_prob is None else f"{tb_prob:.4f}"
157
- radio_str = "N/A" if radio_primary is None else f"{radio_primary:.4f}"
158
-
159
- return f"""
160
- ## {icon} Overall screening result: **{overall_title}**
161
-
162
- **Reliability:** {rel_icon} **{reliability}** (Quality: {q_score:.0f}/100 • Lung coverage: {cov*100:.1f}% • Notes: {warn_line})
163
-
164
- ### What this means
165
- - This is a **screening support tool**, not a diagnosis.
166
- - Two models analyze the same image: a **CNN model** (TBNet) and a **visual model** (RADIO).
167
-
168
- ### Model agreement
169
- - **{consensus_label}**
170
- - {MODEL_NAME_TBNET} probability: **{tb_prob_str}**
171
- - {MODEL_NAME_RADIO} probability: **{radio_str}** {f"(band={radio_band})" if radio_band else ""}
172
-
173
- ### What to do next
174
- - If you have symptoms/risk factors, seek clinician/radiologist review.
175
- - If TB is clinically suspected, consider **CBNAAT/GeneXpert** and sputum testing regardless of AI output.
176
- """
177
 
178
 
179
  # ============================================================
@@ -190,9 +156,7 @@ class DoubleConv(nn.Module):
190
  nn.BatchNorm2d(out_c),
191
  nn.ReLU(inplace=True),
192
  )
193
-
194
- def forward(self, x):
195
- return self.net(x)
196
 
197
  class LungUNet(nn.Module):
198
  def __init__(self):
@@ -234,9 +198,7 @@ class TBNet(nn.Module):
234
  super().__init__()
235
  self.backbone = timm.create_model(backbone, pretrained=False, num_classes=0, global_pool="avg")
236
  self.fc = nn.Linear(self.backbone.num_features, 1)
237
-
238
- def forward(self, x):
239
- return self.fc(self.backbone(x)).view(-1)
240
 
241
  def load_tb_weights(model: nn.Module, ckpt_path: str, device: torch.device):
242
  sd = torch.load(ckpt_path, map_location=device)
@@ -250,11 +212,8 @@ class GradCAM:
250
  target_layer.register_forward_hook(self._fwd)
251
  target_layer.register_full_backward_hook(self._bwd)
252
 
253
- def _fwd(self, _, __, out):
254
- self.activ = out
255
-
256
- def _bwd(self, _, grad_in, grad_out):
257
- self.grad = grad_out[0]
258
 
259
  def generate(self, x: torch.Tensor) -> Tuple[np.ndarray, float, float]:
260
  with torch.enable_grad():
@@ -312,8 +271,7 @@ def border_fraction(gray_u8: np.ndarray) -> float:
312
  bot = gray_u8[-b:, :]
313
  left = gray_u8[:, :b]
314
  right = gray_u8[:, -b:]
315
- def frac_border(x):
316
- return float(((x < 15) | (x > 240)).mean())
317
  return float(np.mean([frac_border(top), frac_border(bot), frac_border(left), frac_border(right)]))
318
 
319
  def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
@@ -328,34 +286,26 @@ def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
328
  sharp = laplacian_sharpness(gray_u8)
329
  lo_clip, hi_clip = exposure_scores(gray_u8)
330
  border = border_fraction(gray_u8)
331
-
332
  likely_phone = (border > 0.35) or (lo_clip > 0.10) or (hi_clip > 0.05)
333
 
334
  if likely_phone:
335
  if sharp < 40:
336
- score -= 25
337
- warnings.append("Blurry / motion blur detected (likely phone capture).")
338
  elif sharp < 80:
339
- score -= 12
340
- warnings.append("Slight blur detected.")
341
  else:
342
  if sharp < 30:
343
- score -= 8
344
- warnings.append("Low fine detail (possible downsampling).")
345
 
346
  if hi_clip > 0.05:
347
- score -= 15
348
- warnings.append("Overexposed highlights (washed-out areas).")
349
  if lo_clip > 0.10:
350
- score -= 12
351
- warnings.append("Underexposed shadows (very dark areas).")
352
 
353
  if border > 0.55:
354
- score -= 18
355
- warnings.append("Large border/margins detected (possible screenshot/phone framing).")
356
  elif border > 0.35:
357
- score -= 10
358
- warnings.append("Some border/margins detected.")
359
 
360
  return float(np.clip(score, 0, 100)), warnings
361
 
@@ -363,22 +313,19 @@ def auto_border_crop(gray_u8: np.ndarray) -> np.ndarray:
363
  g = gray_u8.copy()
364
  g_blur = cv2.GaussianBlur(g, (5, 5), 0)
365
  _, th = cv2.threshold(g_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
366
- if th.mean() > 127:
367
- th = 255 - th
368
 
369
  k = max(3, int(0.01 * min(g.shape)))
370
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
371
  th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
372
 
373
  contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
374
- if not contours:
375
- return gray_u8
376
 
377
  c = max(contours, key=cv2.contourArea)
378
  x, y, w, h = cv2.boundingRect(c)
379
  H, W = gray_u8.shape
380
- if w * h < 0.20 * (H * W):
381
- return gray_u8
382
 
383
  pad = int(0.03 * min(H, W))
384
  x1 = max(0, x - pad); y1 = max(0, y - pad)
@@ -395,7 +342,6 @@ def phone_preprocess(gray_u8: np.ndarray) -> np.ndarray:
395
  border = border_fraction(gray_u8)
396
 
397
  g = gray_u8
398
-
399
  if border > 0.35:
400
  cropped = auto_border_crop(g)
401
  if cropped.size >= 0.70 * g.size:
@@ -439,7 +385,7 @@ def fill_holes(binary_u8: np.ndarray) -> np.ndarray:
439
  m = (binary_u8 * 255).astype(np.uint8)
440
  h, w = m.shape
441
  flood = m.copy()
442
- mask = np.zeros((h + 2, w + 2), np.uint8)
443
  cv2.floodFill(flood, mask, (0, 0), 255)
444
  holes = cv2.bitwise_not(flood)
445
  filled = cv2.bitwise_or(m, holes)
@@ -482,7 +428,7 @@ def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
482
 
483
  border = np.concatenate([m[0, :], m[-1, :], m[:, 0], m[:, -1]])
484
  if border.mean() > 0.05:
485
- warns.append("Lung mask touches the image border (possible cropped/non-standard CXR).")
486
 
487
  if total > 0 and (top1 + top2) / total < 0.90:
488
  warns.append("Mask appears fragmented (may reduce reliability).")
@@ -491,14 +437,14 @@ def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
491
 
492
  def recommendation_for_band(band: Optional[str]) -> str:
493
  if band in (None, "YELLOW"):
494
- return "Radiologist/clinician review is recommended (result is indeterminate)."
495
  if band == "RED":
496
- return "Urgent clinical review + microbiological confirmation (CBNAAT/GeneXpert, sputum) recommended."
497
- return "If symptoms/risk factors exist, clinical correlation is advised."
498
 
499
 
500
  # ============================================================
501
- # AGREEMENT LOGIC (TBNet vs RADIO)
502
  # ============================================================
503
  def tbnet_state(tb_prob: float, tb_band: str) -> str:
504
  if tb_band == "RED":
@@ -523,7 +469,7 @@ def build_consensus(
523
  ) -> Tuple[str, str]:
524
 
525
  if tb_prob is None or tb_band is None:
526
- return ("N/A", f"{MODEL_NAME_TBNET} not available (lung segmentation failed / fail-safe).")
527
 
528
  if radio_masked is not None:
529
  radio_primary = radio_masked
@@ -533,27 +479,27 @@ def build_consensus(
533
  radio_used = "RAW"
534
 
535
  if radio_primary is None:
536
- return ("TBNet only", f"{MODEL_NAME_RADIO} not available → {MODEL_NAME_TBNET}={tb_prob:.4f} (band={tb_band}).")
537
 
538
  t = tbnet_state(tb_prob, tb_band)
539
  r = radio_state_from_prob(radio_primary)
540
- rb = f" (band={radio_band})" if radio_band else ""
541
 
542
  if t == r:
543
  return (
544
  f"AGREE: {t}",
545
- f"Both models indicate **{t}**. {MODEL_NAME_TBNET}={tb_prob:.4f}, {MODEL_NAME_RADIO}({radio_used})={radio_primary:.4f}{rb}."
546
  )
547
 
548
  if (t in ("SCREEN+", "TB+") and r == "LOW") or (r in ("SCREEN+", "TB+") and t == "LOW"):
549
  return (
550
  "DISAGREE",
551
- f"Models disagree: {MODEL_NAME_TBNET} suggests **{t}** vs {MODEL_NAME_RADIO} suggests **{r}** ({radio_used})={radio_primary:.4f}{rb}."
552
  )
553
 
554
  return (
555
  "MIXED/INDET",
556
- f"Mixed signals: {MODEL_NAME_TBNET} suggests **{t}** vs {MODEL_NAME_RADIO} suggests **{r}** ({radio_used})={radio_primary:.4f}{rb}."
557
  )
558
 
559
 
@@ -612,7 +558,6 @@ class RadioMLPHead(nn.Module):
612
  nn.Dropout(dropout),
613
  nn.Linear(hidden, 1),
614
  )
615
-
616
  def forward(self, x: torch.Tensor) -> torch.Tensor:
617
  return self.net(x).squeeze(1)
618
 
@@ -796,6 +741,7 @@ def analyze_one_image(
796
  mask256 = torch.sigmoid(mask_logits)[0, 0].cpu().numpy()
797
 
798
  mask256_bin = (mask256 > 0.5).astype(np.uint8)
 
799
  mask256_bin = keep_top_k_components(mask256_bin, k=2)
800
  k = max(3, int(0.02 * 256))
801
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
@@ -812,14 +758,11 @@ def analyze_one_image(
812
  "logit": None,
813
  "pred": "INDETERMINATE",
814
  "band": "YELLOW",
815
- "band_text": (
816
- "⚠️ Lung segmentation looks unreliable, so the TBNet screening score is disabled for safety.\n\n"
817
- "Please use a clearer standard frontal CXR (PA/AP) or seek radiologist review."
818
- ),
819
  "quality_score": float(q_score),
820
  "diffuse_risk": False,
821
  "warnings": (
822
- [f"Lung segmentation coverage too small ({coverage*100:.1f}%)."]
823
  + (["Phone/WhatsApp mode enabled; artifacts possible."] if phone_mode else [])
824
  + q_warn
825
  ),
@@ -842,11 +785,7 @@ def analyze_one_image(
842
  "logit": None,
843
  "pred": "INDETERMINATE",
844
  "band": "YELLOW",
845
- "band_text": (
846
- "⚠️ The image appears cropped/non-standard (mask sanity check). "
847
- "TBNet screening score is disabled for safety.\n\n"
848
- "Please use a standard frontal CXR (PA/AP) or seek radiologist review."
849
- ),
850
  "quality_score": float(q_score),
851
  "diffuse_risk": False,
852
  "warnings": (
@@ -883,38 +822,28 @@ def analyze_one_image(
883
  diffuse = detect_diffuse_risk(prob_tb, cam_up, q_score)
884
  band_base, _ = confidence_band(prob_tb, q_score, diffuse)
885
 
886
- allow_red = (prob_tb >= 0.70 and q_score >= 55 and (not diffuse) and coverage >= warn_cov)
887
  band = "RED" if allow_red else band_base
888
 
889
  pred = REPORT_LABELS[band]["title"]
890
  band_text = REPORT_LABELS[band]["summary"]
891
 
892
- if band == "YELLOW" and prob_tb < 0.05:
893
- band_text = (
894
- "⚠️ TB probability is very low, but the result is marked **indeterminate** because reliability is limited.\n\n"
895
- + band_text
896
- )
897
-
898
  heat = cv2.applyColorMap((cam_up * 255).astype(np.uint8), cv2.COLORMAP_JET)
899
  overlay_clean = cv2.addWeighted(rgb, 0.65, heat, 0.35, 0)
900
 
901
  overlay_annotated = overlay_clean.copy()
902
  text1 = f"{band}: {pred}"
903
- text2 = f"TBNet prob={prob_tb:.3f} | Quality={q_score:.0f}/100 | Lung coverage={coverage*100:.1f}%"
904
  cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (255, 255, 255), 2)
905
  cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (0, 0, 0), 1)
906
  cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (255, 255, 255), 2)
907
  cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (0, 0, 0), 1)
908
 
909
  warnings = []
910
- if phone_mode:
911
- warnings.append("Phone/WhatsApp mode enabled; artifacts possible.")
912
- if q_score < 55:
913
- warnings.append("Image quality is low; reliability may be reduced.")
914
- if coverage < warn_cov:
915
- warnings.append(f"Partial lung segmentation ({coverage*100:.1f}% coverage).")
916
- if diffuse:
917
- warnings.append("Non-focal attention pattern; result treated cautiously.")
918
  warnings.extend(q_warn)
919
 
920
  return {
@@ -969,7 +898,7 @@ def run_analysis(
969
 
970
  img = cv2.imread(path, cv2.IMREAD_COLOR)
971
  if img is None:
972
- rows.append([name, "", "SKIP", "", "Unreadable image", "", "", "", "", ""])
973
  continue
974
 
975
  out = analyze_one_image(
@@ -983,16 +912,16 @@ def run_analysis(
983
  )
984
 
985
  # RADIO (optional)
986
- radio_text = f"{MODEL_NAME_RADIO} is disabled."
987
  radio_raw_overlay = None
988
  radio_masked_overlay = None
 
989
  radio_raw_val: Optional[float] = None
990
  radio_masked_val: Optional[float] = None
991
  radio_primary_val: Optional[float] = None
992
  radio_band: Optional[str] = None
993
 
994
- radio_raw_str = ""
995
- radio_masked_str = ""
996
 
997
  if use_radio and out["prob"] is not None:
998
  try:
@@ -1007,20 +936,18 @@ def run_analysis(
1007
  radio_primary_val = float(r["prob_primary"])
1008
  radio_masked_val = None if r["masked_prob"] is None else float(r["masked_prob"])
1009
  radio_band = str(r["band"])
1010
-
1011
- radio_raw_str = f"{radio_raw_val:.4f}"
1012
- radio_masked_str = "" if radio_masked_val is None else f"{radio_masked_val:.4f}"
1013
 
1014
  radio_text = (
1015
- f"**{MODEL_NAME_RADIO} result:** {r['pred']} | "
1016
- f"PRIMARY={radio_primary_val:.4f} | RAW={radio_raw_val:.4f}"
1017
  + (f" | MASKED={radio_masked_val:.4f}" if radio_masked_val is not None else "")
1018
- + f" | Band={radio_band}"
1019
  )
1020
  radio_raw_overlay = r["raw_overlay"]
1021
  radio_masked_overlay = r["masked_overlay"]
1022
  except Exception as e:
1023
  radio_text = f"{MODEL_NAME_RADIO} error: {type(e).__name__}: {e}"
 
1024
  radio_raw_val = None
1025
  radio_masked_val = None
1026
  radio_primary_val = None
@@ -1034,24 +961,24 @@ def run_analysis(
1034
  radio_band=radio_band,
1035
  )
1036
 
1037
- # Table row
1038
- prob_str = "" if out["prob"] is None else f"{out['prob']:.4f}"
1039
- cov_str = f"{out.get('lung_coverage', 0.0) * 100:.1f}%"
 
1040
 
1041
  rows.append([
1042
  name,
1043
- prob_str,
1044
  out["pred"],
1045
- out["band"],
1046
- out["band_text"],
 
1047
  f"{out['quality_score']:.0f}",
1048
- cov_str,
1049
- radio_raw_str,
1050
- radio_masked_str,
1051
  consensus_label,
1052
  ])
1053
 
1054
- # Gallery
1055
  orig_rgb = cv2.cvtColor(cv2.resize(out["orig_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
1056
  vis_rgb = cv2.cvtColor(cv2.resize(out["vis_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
1057
  mask_overlay = cv2.resize(out["mask_overlay"], (512, 512))
@@ -1060,11 +987,9 @@ def run_analysis(
1060
  gallery_items.append((orig_rgb, f"{name} • ORIGINAL"))
1061
  gallery_items.append((vis_rgb, f"{name} • PHONE-PROC" if phone_mode else f"{name} • INPUT"))
1062
  gallery_items.append((mask_overlay, f"{name} • Lung mask overlay"))
1063
-
1064
  if out["proc_gray"] is not None:
1065
  proc_rgb = cv2.cvtColor(cv2.resize(out["proc_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
1066
  gallery_items.append((proc_rgb, f"{name} • Masked model input (224x224)"))
1067
-
1068
  gallery_items.append((overlay_big, f"{name} • Grad-CAM overlay ({MODEL_NAME_TBNET})"))
1069
 
1070
  if radio_raw_overlay is not None:
@@ -1072,56 +997,35 @@ def run_analysis(
1072
  if radio_masked_overlay is not None:
1073
  gallery_items.append((cv2.resize(radio_masked_overlay, (512, 512)), f"{name} • RADIO MASKED heatmap"))
1074
 
1075
- # Details (dashboard style)
1076
- summary_md = overall_summary(
1077
- tb_band=out.get("band"),
1078
- tb_prob=out.get("prob"),
1079
- radio_primary=radio_primary_val,
1080
- radio_band=radio_band,
1081
- consensus_label=consensus_label,
1082
- q_score=float(out["quality_score"]),
1083
- cov=float(out.get("lung_coverage", 0.0)),
1084
- warnings=out.get("warnings", []),
1085
- )
1086
-
1087
  warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
1088
- tb_line = "N/A (disabled by fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
1089
  rec_line = recommendation_for_band(out.get("band"))
1090
 
1091
  details_md.append(
1092
- f"""{summary_md}
1093
 
1094
- ---
 
 
1095
 
1096
- <details>
1097
- <summary><b>{MODEL_NAME_TBNET} details</b></summary>
1098
 
1099
- - **Result:** {out['pred']} ({out['band']})
1100
- - **Recommendation:** {rec_line}
1101
- - **Probability (screening score):** {tb_line}
1102
- - **Attention pattern:** {"Diffuse / non-focal" if out["diffuse_risk"] else "Focal / localized"}
1103
 
1104
- </details>
1105
-
1106
- <details>
1107
- <summary><b>{MODEL_NAME_RADIO} details</b></summary>
1108
-
1109
- {radio_text}
1110
-
1111
- </details>
1112
-
1113
- <details>
1114
- <summary><b>Image quality & segmentation</b></summary>
1115
-
1116
- - **Quality score:** {out['quality_score']:.0f}/100
1117
- - **Lung mask coverage:** {out.get('lung_coverage', 0.0) * 100:.1f}%
1118
 
1119
  **Notes that may affect reliability**
1120
  {warn_txt}
1121
 
1122
- </details>
 
1123
 
1124
- ---
 
1125
 
1126
  **Clinical guidance**
1127
  {CLINICAL_GUIDANCE}
@@ -1144,10 +1048,10 @@ def build_ui():
1144
  """
1145
 
1146
  with gr.Blocks(title="TB X-ray Assistant (TBNet + RADIO)", css=css) as demo:
1147
- gr.Markdown('<div class="title">TB X-ray Assistant (Research Use)</div>')
1148
  gr.Markdown(
1149
  f"<div class='subtitle'>Auto lung mask → <b>{MODEL_NAME_TBNET}</b> + Grad-CAM • "
1150
- f"Optional <b>{MODEL_NAME_RADIO}</b> (C-RADIOv4 + heads) • User-friendly summary</div>"
1151
  )
1152
 
1153
  with gr.Row():
@@ -1161,7 +1065,7 @@ def build_ui():
1161
 
1162
  threshold = gr.Slider(
1163
  0.01, 0.99, value=TBNET_SCREEN_THR, step=0.01,
1164
- label=f"Reference threshold (TBNet SCREEN+) = {TBNET_SCREEN_THR:.2f}"
1165
  )
1166
 
1167
  phone_mode = gr.Checkbox(
@@ -1186,10 +1090,8 @@ def build_ui():
1186
 
1187
  with gr.Column(scale=2):
1188
  gr.Markdown("#### Upload images")
1189
- files = gr.Files(
1190
- label="Upload one or multiple chest X-ray images",
1191
- file_types=[".png", ".jpg", ".jpeg", ".bmp"]
1192
- )
1193
  run_btn = gr.Button("Run Analysis", variant="primary")
1194
  status = gr.Textbox(label="Status", value="Ready.", interactive=False)
1195
 
@@ -1197,17 +1099,16 @@ def build_ui():
1197
  table = gr.Dataframe(
1198
  headers=[
1199
  "Image",
1200
- "TBNet Probability",
1201
  "TBNet Result",
1202
- "Band",
1203
- "Meaning",
 
1204
  "Quality",
1205
  "LungCov",
1206
- "RADIO RAW",
1207
- "RADIO MASKED",
1208
- "AGREEMENT",
1209
  ],
1210
- datatype=["str","str","str","str","str","str","str","str","str","str"],
1211
  interactive=False,
1212
  label="Results"
1213
  )
 
1
  # app.py
2
+ # Gradio — TBNet + Lung U-Net Auto Mask + Grad-CAM + RADIO
3
  # + SAFER PHONE MODE + MASK POST-PROCESSING + MASK SANITY FAILSAFE
4
+ # + 3-STATE CONSENSUS (LOW / INDET / SCREEN+)
5
  #
6
+ # HF Spaces: use relative weight paths (edit below if needed)
 
 
 
 
 
 
 
 
 
7
  #
8
+ # Requirements (requirements.txt):
9
+ # gradio
10
+ # torch
11
+ # torchvision
12
+ # timm
13
+ # opencv-python
14
+ # pillow
15
+ # transformers
16
+ # einops
17
+ # open_clip_torch
18
 
19
  import os
20
  import cv2
 
28
  from torchvision import transforms
29
  from typing import List, Tuple, Dict, Any, Optional
30
 
 
31
  from transformers import AutoModel, CLIPImageProcessor
32
  from einops import rearrange
33
  from PIL import Image
34
 
35
 
36
  # ============================================================
37
+ # USER CONFIG
38
  # ============================================================
39
 
40
+ # ---- Friendly names (UI) ----
41
  MODEL_NAME_TBNET = "TBNet (CNN model)"
42
  MODEL_NAME_RADIO = "RADIO (visual model)"
43
 
44
+ # ---- Default TB/Lung weights (HF-friendly relative paths) ----
45
  DEFAULT_TB_WEIGHTS = "weights/best.pt"
46
  DEFAULT_LUNG_WEIGHTS = "weights/lung_unet_mont_shenzhen.pt"
47
 
48
+ # ---- RADIO config (same env as TB) ----
49
  RADIO_HF_REPO = "nvidia/C-RADIOv4-SO400M"
50
  RADIO_REVISION = "c0457f5dc26ca145f954cd4fc5bb6114e5705ad8"
51
 
 
59
  RADIO_MASKED_MIN_COV = 0.15
60
  RADIO_GATE_DEFAULT = 0.21
61
 
62
+ # ---- Consensus logic thresholds ----
63
  TBNET_SCREEN_THR = 0.30
64
+ TBNET_MARGIN = 0.03 # (kept for compatibility / future use)
65
+
66
  RADIO_SCREEN_THR = RADIO_THR_SCREEN
67
+ RADIO_MARGIN = 0.02 # (kept for compatibility / future use)
68
 
69
+ # ---- Mask fail-safes ----
70
  FAIL_COV = 0.10
71
  WARN_COV = 0.18
72
  FAILSAFE_ON_BAD_MASK = True
73
 
74
+ # ---- Device policy ----
75
  FORCE_CPU = True
76
  DEVICE = torch.device("cpu" if FORCE_CPU else ("cuda" if torch.cuda.is_available() else "cpu"))
77
 
 
81
  # ============================================================
82
  CLINICAL_DISCLAIMER = """
83
  ⚠️ IMPORTANT CLINICAL NOTICE (Decision Support Only)
84
+ This AI system is for **research/decision support** and is NOT a diagnostic device.
85
+ It may NOT reliably detect early/subtle tuberculosis, including **MILIARY TB**,
86
+ which can appear near-normal or subtle on chest X-ray (especially on phone photos / WhatsApp images).
 
87
 
88
  If clinical suspicion exists (fever, weight loss, immunosuppression, known exposure),
89
  recommend **CBNAAT / GeneXpert**, sputum studies, and/or **CT chest** regardless of AI output.
90
  """
91
 
92
+ # Friendly labels still map to GREEN/YELLOW/RED logic
93
  REPORT_LABELS = {
94
  "GREEN": {
95
  "title": "LOW TB LIKELIHOOD",
 
126
 
127
 
128
  # ============================================================
129
+ # NEW: OVERALL LABEL FOR TABLE (user-friendly)
130
  # ============================================================
131
+ def overall_label_from_consensus(consensus_label: str, tb_prob: Optional[float]) -> str:
 
 
 
 
 
 
 
 
132
  if tb_prob is None:
133
+ return "⚠️ INDETERMINATE"
134
+ if "AGREE: LOW" in consensus_label:
135
+ return "✅ LOW"
136
+ if "AGREE: SCREEN+" in consensus_label:
137
+ return "⚠️ SCREEN+"
138
+ if "AGREE: TB+" in consensus_label:
139
+ return "🚩 TB+"
140
+ if "DISAGREE" in consensus_label:
141
+ return "⚠️ DISAGREE"
142
+ return "⚠️ INDET"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
 
145
  # ============================================================
 
156
  nn.BatchNorm2d(out_c),
157
  nn.ReLU(inplace=True),
158
  )
159
+ def forward(self, x): return self.net(x)
 
 
160
 
161
  class LungUNet(nn.Module):
162
  def __init__(self):
 
198
  super().__init__()
199
  self.backbone = timm.create_model(backbone, pretrained=False, num_classes=0, global_pool="avg")
200
  self.fc = nn.Linear(self.backbone.num_features, 1)
201
+ def forward(self, x): return self.fc(self.backbone(x)).view(-1)
 
 
202
 
203
  def load_tb_weights(model: nn.Module, ckpt_path: str, device: torch.device):
204
  sd = torch.load(ckpt_path, map_location=device)
 
212
  target_layer.register_forward_hook(self._fwd)
213
  target_layer.register_full_backward_hook(self._bwd)
214
 
215
+ def _fwd(self, _, __, out): self.activ = out
216
+ def _bwd(self, _, grad_in, grad_out): self.grad = grad_out[0]
 
 
 
217
 
218
  def generate(self, x: torch.Tensor) -> Tuple[np.ndarray, float, float]:
219
  with torch.enable_grad():
 
271
  bot = gray_u8[-b:, :]
272
  left = gray_u8[:, :b]
273
  right = gray_u8[:, -b:]
274
+ def frac_border(x): return float(((x < 15) | (x > 240)).mean())
 
275
  return float(np.mean([frac_border(top), frac_border(bot), frac_border(left), frac_border(right)]))
276
 
277
  def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
 
286
  sharp = laplacian_sharpness(gray_u8)
287
  lo_clip, hi_clip = exposure_scores(gray_u8)
288
  border = border_fraction(gray_u8)
 
289
  likely_phone = (border > 0.35) or (lo_clip > 0.10) or (hi_clip > 0.05)
290
 
291
  if likely_phone:
292
  if sharp < 40:
293
+ score -= 25; warnings.append("Blurry / motion blur detected (likely phone capture).")
 
294
  elif sharp < 80:
295
+ score -= 12; warnings.append("Slight blur detected.")
 
296
  else:
297
  if sharp < 30:
298
+ score -= 8; warnings.append("Low fine detail (possible downsampling).")
 
299
 
300
  if hi_clip > 0.05:
301
+ score -= 15; warnings.append("Overexposed highlights (washed-out areas).")
 
302
  if lo_clip > 0.10:
303
+ score -= 12; warnings.append("Underexposed shadows (very dark areas).")
 
304
 
305
  if border > 0.55:
306
+ score -= 18; warnings.append("Large border/margins detected (possible screenshot/phone framing).")
 
307
  elif border > 0.35:
308
+ score -= 10; warnings.append("Some border/margins detected.")
 
309
 
310
  return float(np.clip(score, 0, 100)), warnings
311
 
 
313
  g = gray_u8.copy()
314
  g_blur = cv2.GaussianBlur(g, (5, 5), 0)
315
  _, th = cv2.threshold(g_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
316
+ if th.mean() > 127: th = 255 - th
 
317
 
318
  k = max(3, int(0.01 * min(g.shape)))
319
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
320
  th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
321
 
322
  contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
323
+ if not contours: return gray_u8
 
324
 
325
  c = max(contours, key=cv2.contourArea)
326
  x, y, w, h = cv2.boundingRect(c)
327
  H, W = gray_u8.shape
328
+ if w * h < 0.20 * (H * W): return gray_u8
 
329
 
330
  pad = int(0.03 * min(H, W))
331
  x1 = max(0, x - pad); y1 = max(0, y - pad)
 
342
  border = border_fraction(gray_u8)
343
 
344
  g = gray_u8
 
345
  if border > 0.35:
346
  cropped = auto_border_crop(g)
347
  if cropped.size >= 0.70 * g.size:
 
385
  m = (binary_u8 * 255).astype(np.uint8)
386
  h, w = m.shape
387
  flood = m.copy()
388
+ mask = np.zeros((h+2, w+2), np.uint8)
389
  cv2.floodFill(flood, mask, (0, 0), 255)
390
  holes = cv2.bitwise_not(flood)
391
  filled = cv2.bitwise_or(m, holes)
 
428
 
429
  border = np.concatenate([m[0, :], m[-1, :], m[:, 0], m[:, -1]])
430
  if border.mean() > 0.05:
431
+ warns.append("Lung mask touches image border (possible cropped/non-standard CXR).")
432
 
433
  if total > 0 and (top1 + top2) / total < 0.90:
434
  warns.append("Mask appears fragmented (may reduce reliability).")
 
437
 
438
  def recommendation_for_band(band: Optional[str]) -> str:
439
  if band in (None, "YELLOW"):
440
+ return "✅ Recommendation: Radiologist/clinician review is recommended (result is indeterminate)."
441
  if band == "RED":
442
+ return "✅ Recommendation: Urgent clinician/radiologist review + microbiological confirmation (CBNAAT/GeneXpert, sputum)."
443
+ return "✅ Recommendation: If symptoms/risk factors exist, clinician/radiologist correlation is advised."
444
 
445
 
446
  # ============================================================
447
+ # CONSENSUS LOGIC (TBNet vs RADIO) — 3-state
448
  # ============================================================
449
  def tbnet_state(tb_prob: float, tb_band: str) -> str:
450
  if tb_band == "RED":
 
469
  ) -> Tuple[str, str]:
470
 
471
  if tb_prob is None or tb_band is None:
472
+ return ("N/A", f"{MODEL_NAME_TBNET} unavailable (lung segmentation failed / fail-safe).")
473
 
474
  if radio_masked is not None:
475
  radio_primary = radio_masked
 
479
  radio_used = "RAW"
480
 
481
  if radio_primary is None:
482
+ return ("TBNet only", f"{MODEL_NAME_RADIO} unavailable → {MODEL_NAME_TBNET}={tb_prob:.4f} (band={tb_band}).")
483
 
484
  t = tbnet_state(tb_prob, tb_band)
485
  r = radio_state_from_prob(radio_primary)
486
+ rb = f" (RADIO band={radio_band})" if radio_band else ""
487
 
488
  if t == r:
489
  return (
490
  f"AGREE: {t}",
491
+ f"Both: {t}. {MODEL_NAME_TBNET}={tb_prob:.4f}, {MODEL_NAME_RADIO}({radio_used})={radio_primary:.4f}{rb}."
492
  )
493
 
494
  if (t in ("SCREEN+", "TB+") and r == "LOW") or (r in ("SCREEN+", "TB+") and t == "LOW"):
495
  return (
496
  "DISAGREE",
497
+ f"Strong disagreement: {MODEL_NAME_TBNET}={t} (band={tb_band}) vs {MODEL_NAME_RADIO}={r} ({radio_used})={radio_primary:.4f}{rb}."
498
  )
499
 
500
  return (
501
  "MIXED/INDET",
502
+ f"Mixed/uncertain: {MODEL_NAME_TBNET}={t} (band={tb_band}) vs {MODEL_NAME_RADIO}={r} ({radio_used})={radio_primary:.4f}{rb}."
503
  )
504
 
505
 
 
558
  nn.Dropout(dropout),
559
  nn.Linear(hidden, 1),
560
  )
 
561
  def forward(self, x: torch.Tensor) -> torch.Tensor:
562
  return self.net(x).squeeze(1)
563
 
 
741
  mask256 = torch.sigmoid(mask_logits)[0, 0].cpu().numpy()
742
 
743
  mask256_bin = (mask256 > 0.5).astype(np.uint8)
744
+
745
  mask256_bin = keep_top_k_components(mask256_bin, k=2)
746
  k = max(3, int(0.02 * 256))
747
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
 
758
  "logit": None,
759
  "pred": "INDETERMINATE",
760
  "band": "YELLOW",
761
+ "band_text": "Lung segmentation failed. TB scoring disabled (fail-safe).",
 
 
 
762
  "quality_score": float(q_score),
763
  "diffuse_risk": False,
764
  "warnings": (
765
+ ["Lung segmentation failed (<10% lung area).", f"Lung coverage: {coverage*100:.1f}%"]
766
  + (["Phone/WhatsApp mode enabled; artifacts possible."] if phone_mode else [])
767
  + q_warn
768
  ),
 
785
  "logit": None,
786
  "pred": "INDETERMINATE",
787
  "band": "YELLOW",
788
+ "band_text": "Non-standard/cropped view or unreliable lung segmentation. TB scoring disabled (fail-safe).",
 
 
 
 
789
  "quality_score": float(q_score),
790
  "diffuse_risk": False,
791
  "warnings": (
 
822
  diffuse = detect_diffuse_risk(prob_tb, cam_up, q_score)
823
  band_base, _ = confidence_band(prob_tb, q_score, diffuse)
824
 
825
+ allow_red = (prob_tb >= 0.70 and q_score >= 55 and not diffuse and coverage >= warn_cov)
826
  band = "RED" if allow_red else band_base
827
 
828
  pred = REPORT_LABELS[band]["title"]
829
  band_text = REPORT_LABELS[band]["summary"]
830
 
 
 
 
 
 
 
831
  heat = cv2.applyColorMap((cam_up * 255).astype(np.uint8), cv2.COLORMAP_JET)
832
  overlay_clean = cv2.addWeighted(rgb, 0.65, heat, 0.35, 0)
833
 
834
  overlay_annotated = overlay_clean.copy()
835
  text1 = f"{band}: {pred}"
836
+ text2 = f"TB prob={prob_tb:.3f} | Quality={q_score:.0f}/100 | Lung coverage={coverage*100:.1f}%"
837
  cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (255, 255, 255), 2)
838
  cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (0, 0, 0), 1)
839
  cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (255, 255, 255), 2)
840
  cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (0, 0, 0), 1)
841
 
842
  warnings = []
843
+ if phone_mode: warnings.append("Phone/WhatsApp mode enabled; artifacts possible.")
844
+ if q_score < 55: warnings.append("Suboptimal image quality limits AI reliability.")
845
+ if coverage < warn_cov: warnings.append(f"Partial lung segmentation ({coverage*100:.1f}% coverage).")
846
+ if diffuse: warnings.append("Diffuse, non-focal AI attention pattern; TB-specific features not identified.")
 
 
 
 
847
  warnings.extend(q_warn)
848
 
849
  return {
 
898
 
899
  img = cv2.imread(path, cv2.IMREAD_COLOR)
900
  if img is None:
901
+ rows.append([name, "⚠️", "SKIP", "", "Unreadable image", "", "", "", ""])
902
  continue
903
 
904
  out = analyze_one_image(
 
912
  )
913
 
914
  # RADIO (optional)
915
+ radio_text = f"{MODEL_NAME_RADIO} disabled."
916
  radio_raw_overlay = None
917
  radio_masked_overlay = None
918
+
919
  radio_raw_val: Optional[float] = None
920
  radio_masked_val: Optional[float] = None
921
  radio_primary_val: Optional[float] = None
922
  radio_band: Optional[str] = None
923
 
924
+ radio_result_short = "Disabled"
 
925
 
926
  if use_radio and out["prob"] is not None:
927
  try:
 
936
  radio_primary_val = float(r["prob_primary"])
937
  radio_masked_val = None if r["masked_prob"] is None else float(r["masked_prob"])
938
  radio_band = str(r["band"])
939
+ radio_result_short = str(r["pred"])
 
 
940
 
941
  radio_text = (
942
+ f"**{MODEL_NAME_RADIO}:** {r['pred']} | PRIMARY={radio_primary_val:.4f} | RAW={radio_raw_val:.4f}"
 
943
  + (f" | MASKED={radio_masked_val:.4f}" if radio_masked_val is not None else "")
944
+ + (f" | Band={radio_band}" if radio_band else "")
945
  )
946
  radio_raw_overlay = r["raw_overlay"]
947
  radio_masked_overlay = r["masked_overlay"]
948
  except Exception as e:
949
  radio_text = f"{MODEL_NAME_RADIO} error: {type(e).__name__}: {e}"
950
+ radio_result_short = "Error"
951
  radio_raw_val = None
952
  radio_masked_val = None
953
  radio_primary_val = None
 
961
  radio_band=radio_band,
962
  )
963
 
964
+ overall = overall_label_from_consensus(consensus_label, out["prob"])
965
+
966
+ tb_prob_str = "" if out["prob"] is None else f"{out['prob']:.4f}"
967
+ radio_prob_primary_str = "" if radio_primary_val is None else f"{radio_primary_val:.4f}"
968
 
969
  rows.append([
970
  name,
971
+ overall,
972
  out["pred"],
973
+ tb_prob_str,
974
+ radio_result_short,
975
+ radio_prob_primary_str,
976
  f"{out['quality_score']:.0f}",
977
+ f"{out.get('lung_coverage', 0.0) * 100:.1f}%",
 
 
978
  consensus_label,
979
  ])
980
 
981
+ # Visual outputs
982
  orig_rgb = cv2.cvtColor(cv2.resize(out["orig_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
983
  vis_rgb = cv2.cvtColor(cv2.resize(out["vis_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
984
  mask_overlay = cv2.resize(out["mask_overlay"], (512, 512))
 
987
  gallery_items.append((orig_rgb, f"{name} • ORIGINAL"))
988
  gallery_items.append((vis_rgb, f"{name} • PHONE-PROC" if phone_mode else f"{name} • INPUT"))
989
  gallery_items.append((mask_overlay, f"{name} • Lung mask overlay"))
 
990
  if out["proc_gray"] is not None:
991
  proc_rgb = cv2.cvtColor(cv2.resize(out["proc_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
992
  gallery_items.append((proc_rgb, f"{name} • Masked model input (224x224)"))
 
993
  gallery_items.append((overlay_big, f"{name} • Grad-CAM overlay ({MODEL_NAME_TBNET})"))
994
 
995
  if radio_raw_overlay is not None:
 
997
  if radio_masked_overlay is not None:
998
  gallery_items.append((cv2.resize(radio_masked_overlay, (512, 512)), f"{name} • RADIO MASKED heatmap"))
999
 
1000
+ # Details panel
 
 
 
 
 
 
 
 
 
 
 
1001
  warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
1002
+ tb_line = "N/A (fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
1003
  rec_line = recommendation_for_band(out.get("band"))
1004
 
1005
  details_md.append(
1006
+ f"""### {name}
1007
 
1008
+ **Overall:** {overall}
1009
+ **{MODEL_NAME_TBNET} result:** **{out['pred']}**
1010
+ {rec_line}
1011
 
1012
+ **{MODEL_NAME_TBNET} probability:** {tb_line}
 
1013
 
1014
+ **Interpretation**
1015
+ {out['band_text']}
 
 
1016
 
1017
+ **Image quality:** {out['quality_score']:.0f}/100
1018
+ **Lung mask coverage:** {out.get('lung_coverage', 0.0) * 100:.1f}%
1019
+ **Attention pattern (TBNet):** {"Diffuse / non-focal" if out["diffuse_risk"] else "Focal / localized"}
 
 
 
 
 
 
 
 
 
 
 
1020
 
1021
  **Notes that may affect reliability**
1022
  {warn_txt}
1023
 
1024
+ **{MODEL_NAME_RADIO} output**
1025
+ {radio_text}
1026
 
1027
+ **Agreement between models:** **{consensus_label}**
1028
+ - {consensus_detail}
1029
 
1030
  **Clinical guidance**
1031
  {CLINICAL_GUIDANCE}
 
1048
  """
1049
 
1050
  with gr.Blocks(title="TB X-ray Assistant (TBNet + RADIO)", css=css) as demo:
1051
+ gr.Markdown('<div class="title">TB X-ray Assistant (Auto Lung Mask • Research Use)</div>')
1052
  gr.Markdown(
1053
  f"<div class='subtitle'>Auto lung mask → <b>{MODEL_NAME_TBNET}</b> + Grad-CAM • "
1054
+ f"Optional <b>{MODEL_NAME_RADIO}</b> (C-RADIOv4 + heads) • Agreement summary</div>"
1055
  )
1056
 
1057
  with gr.Row():
 
1065
 
1066
  threshold = gr.Slider(
1067
  0.01, 0.99, value=TBNET_SCREEN_THR, step=0.01,
1068
+ label=f"Reference threshold (TBNet screen+) = {TBNET_SCREEN_THR:.2f}"
1069
  )
1070
 
1071
  phone_mode = gr.Checkbox(
 
1090
 
1091
  with gr.Column(scale=2):
1092
  gr.Markdown("#### Upload images")
1093
+ files = gr.Files(label="Upload one or multiple X-ray images",
1094
+ file_types=[".png", ".jpg", ".jpeg", ".bmp"])
 
 
1095
  run_btn = gr.Button("Run Analysis", variant="primary")
1096
  status = gr.Textbox(label="Status", value="Ready.", interactive=False)
1097
 
 
1099
  table = gr.Dataframe(
1100
  headers=[
1101
  "Image",
1102
+ "OVERALL",
1103
  "TBNet Result",
1104
+ "TBNet Prob",
1105
+ "RADIO Result",
1106
+ "RADIO Prob (Primary)",
1107
  "Quality",
1108
  "LungCov",
1109
+ "Agreement",
 
 
1110
  ],
1111
+ datatype=["str","str","str","str","str","str","str","str","str"],
1112
  interactive=False,
1113
  label="Results"
1114
  )