Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,22 +1,20 @@
|
|
| 1 |
# app.py
|
| 2 |
-
# Gradio — TBNet
|
| 3 |
# + SAFER PHONE MODE + MASK POST-PROCESSING + MASK SANITY FAILSAFE
|
| 4 |
-
# + 3-STATE
|
| 5 |
#
|
| 6 |
-
#
|
| 7 |
-
# gradio
|
| 8 |
-
# torch
|
| 9 |
-
# torchvision
|
| 10 |
-
# timm
|
| 11 |
-
# opencv-python
|
| 12 |
-
# pillow
|
| 13 |
-
# transformers
|
| 14 |
-
# einops
|
| 15 |
-
# open_clip_torch
|
| 16 |
#
|
| 17 |
-
#
|
| 18 |
-
#
|
| 19 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
import os
|
| 22 |
import cv2
|
|
@@ -30,22 +28,24 @@ import gradio as gr
|
|
| 30 |
from torchvision import transforms
|
| 31 |
from typing import List, Tuple, Dict, Any, Optional
|
| 32 |
|
| 33 |
-
# RADIO deps (same env as TBNet)
|
| 34 |
from transformers import AutoModel, CLIPImageProcessor
|
| 35 |
from einops import rearrange
|
| 36 |
from PIL import Image
|
| 37 |
|
| 38 |
|
| 39 |
# ============================================================
|
| 40 |
-
# USER CONFIG
|
| 41 |
# ============================================================
|
| 42 |
|
|
|
|
| 43 |
MODEL_NAME_TBNET = "TBNet (CNN model)"
|
| 44 |
MODEL_NAME_RADIO = "RADIO (visual model)"
|
| 45 |
|
|
|
|
| 46 |
DEFAULT_TB_WEIGHTS = "weights/best.pt"
|
| 47 |
DEFAULT_LUNG_WEIGHTS = "weights/lung_unet_mont_shenzhen.pt"
|
| 48 |
|
|
|
|
| 49 |
RADIO_HF_REPO = "nvidia/C-RADIOv4-SO400M"
|
| 50 |
RADIO_REVISION = "c0457f5dc26ca145f954cd4fc5bb6114e5705ad8"
|
| 51 |
|
|
@@ -59,13 +59,19 @@ RADIO_THR_RED = 0.23
|
|
| 59 |
RADIO_MASKED_MIN_COV = 0.15
|
| 60 |
RADIO_GATE_DEFAULT = 0.21
|
| 61 |
|
|
|
|
| 62 |
TBNET_SCREEN_THR = 0.30
|
|
|
|
|
|
|
| 63 |
RADIO_SCREEN_THR = RADIO_THR_SCREEN
|
|
|
|
| 64 |
|
|
|
|
| 65 |
FAIL_COV = 0.10
|
| 66 |
WARN_COV = 0.18
|
| 67 |
FAILSAFE_ON_BAD_MASK = True
|
| 68 |
|
|
|
|
| 69 |
FORCE_CPU = True
|
| 70 |
DEVICE = torch.device("cpu" if FORCE_CPU else ("cuda" if torch.cuda.is_available() else "cpu"))
|
| 71 |
|
|
@@ -75,15 +81,15 @@ DEVICE = torch.device("cpu" if FORCE_CPU else ("cuda" if torch.cuda.is_available
|
|
| 75 |
# ============================================================
|
| 76 |
CLINICAL_DISCLAIMER = """
|
| 77 |
⚠️ IMPORTANT CLINICAL NOTICE (Decision Support Only)
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
Phone photos / screenshots / downsampled images can reduce reliability.
|
| 82 |
|
| 83 |
If clinical suspicion exists (fever, weight loss, immunosuppression, known exposure),
|
| 84 |
recommend **CBNAAT / GeneXpert**, sputum studies, and/or **CT chest** regardless of AI output.
|
| 85 |
"""
|
| 86 |
|
|
|
|
| 87 |
REPORT_LABELS = {
|
| 88 |
"GREEN": {
|
| 89 |
"title": "LOW TB LIKELIHOOD",
|
|
@@ -120,60 +126,20 @@ CLINICAL_GUIDANCE = (
|
|
| 120 |
|
| 121 |
|
| 122 |
# ============================================================
|
| 123 |
-
#
|
| 124 |
# ============================================================
|
| 125 |
-
def
|
| 126 |
-
tb_prob: Optional[float],
|
| 127 |
-
radio_primary: Optional[float],
|
| 128 |
-
radio_band: Optional[str],
|
| 129 |
-
consensus_label: str,
|
| 130 |
-
q_score: float,
|
| 131 |
-
cov: float,
|
| 132 |
-
warnings: List[str]) -> str:
|
| 133 |
-
# Overall label from agreement (keeps it simple for users)
|
| 134 |
if tb_prob is None:
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
overall_title = "SCREEN-POSITIVE — REVIEW RECOMMENDED"
|
| 146 |
-
icon = "⚠️"
|
| 147 |
-
else:
|
| 148 |
-
overall_title = "INDETERMINATE — REVIEW RECOMMENDED"
|
| 149 |
-
icon = "⚠️"
|
| 150 |
-
|
| 151 |
-
reliability = "Good" if (q_score >= 70 and cov >= WARN_COV) else "Limited"
|
| 152 |
-
rel_icon = "🟢" if reliability == "Good" else "🟡"
|
| 153 |
-
|
| 154 |
-
warn_line = "None" if not warnings else f"{len(warnings)} note(s) below"
|
| 155 |
-
|
| 156 |
-
tb_prob_str = "N/A" if tb_prob is None else f"{tb_prob:.4f}"
|
| 157 |
-
radio_str = "N/A" if radio_primary is None else f"{radio_primary:.4f}"
|
| 158 |
-
|
| 159 |
-
return f"""
|
| 160 |
-
## {icon} Overall screening result: **{overall_title}**
|
| 161 |
-
|
| 162 |
-
**Reliability:** {rel_icon} **{reliability}** (Quality: {q_score:.0f}/100 • Lung coverage: {cov*100:.1f}% • Notes: {warn_line})
|
| 163 |
-
|
| 164 |
-
### What this means
|
| 165 |
-
- This is a **screening support tool**, not a diagnosis.
|
| 166 |
-
- Two models analyze the same image: a **CNN model** (TBNet) and a **visual model** (RADIO).
|
| 167 |
-
|
| 168 |
-
### Model agreement
|
| 169 |
-
- **{consensus_label}**
|
| 170 |
-
- {MODEL_NAME_TBNET} probability: **{tb_prob_str}**
|
| 171 |
-
- {MODEL_NAME_RADIO} probability: **{radio_str}** {f"(band={radio_band})" if radio_band else ""}
|
| 172 |
-
|
| 173 |
-
### What to do next
|
| 174 |
-
- If you have symptoms/risk factors, seek clinician/radiologist review.
|
| 175 |
-
- If TB is clinically suspected, consider **CBNAAT/GeneXpert** and sputum testing regardless of AI output.
|
| 176 |
-
"""
|
| 177 |
|
| 178 |
|
| 179 |
# ============================================================
|
|
@@ -190,9 +156,7 @@ class DoubleConv(nn.Module):
|
|
| 190 |
nn.BatchNorm2d(out_c),
|
| 191 |
nn.ReLU(inplace=True),
|
| 192 |
)
|
| 193 |
-
|
| 194 |
-
def forward(self, x):
|
| 195 |
-
return self.net(x)
|
| 196 |
|
| 197 |
class LungUNet(nn.Module):
|
| 198 |
def __init__(self):
|
|
@@ -234,9 +198,7 @@ class TBNet(nn.Module):
|
|
| 234 |
super().__init__()
|
| 235 |
self.backbone = timm.create_model(backbone, pretrained=False, num_classes=0, global_pool="avg")
|
| 236 |
self.fc = nn.Linear(self.backbone.num_features, 1)
|
| 237 |
-
|
| 238 |
-
def forward(self, x):
|
| 239 |
-
return self.fc(self.backbone(x)).view(-1)
|
| 240 |
|
| 241 |
def load_tb_weights(model: nn.Module, ckpt_path: str, device: torch.device):
|
| 242 |
sd = torch.load(ckpt_path, map_location=device)
|
|
@@ -250,11 +212,8 @@ class GradCAM:
|
|
| 250 |
target_layer.register_forward_hook(self._fwd)
|
| 251 |
target_layer.register_full_backward_hook(self._bwd)
|
| 252 |
|
| 253 |
-
def _fwd(self, _, __, out):
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
def _bwd(self, _, grad_in, grad_out):
|
| 257 |
-
self.grad = grad_out[0]
|
| 258 |
|
| 259 |
def generate(self, x: torch.Tensor) -> Tuple[np.ndarray, float, float]:
|
| 260 |
with torch.enable_grad():
|
|
@@ -312,8 +271,7 @@ def border_fraction(gray_u8: np.ndarray) -> float:
|
|
| 312 |
bot = gray_u8[-b:, :]
|
| 313 |
left = gray_u8[:, :b]
|
| 314 |
right = gray_u8[:, -b:]
|
| 315 |
-
def frac_border(x):
|
| 316 |
-
return float(((x < 15) | (x > 240)).mean())
|
| 317 |
return float(np.mean([frac_border(top), frac_border(bot), frac_border(left), frac_border(right)]))
|
| 318 |
|
| 319 |
def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
|
|
@@ -328,34 +286,26 @@ def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
|
|
| 328 |
sharp = laplacian_sharpness(gray_u8)
|
| 329 |
lo_clip, hi_clip = exposure_scores(gray_u8)
|
| 330 |
border = border_fraction(gray_u8)
|
| 331 |
-
|
| 332 |
likely_phone = (border > 0.35) or (lo_clip > 0.10) or (hi_clip > 0.05)
|
| 333 |
|
| 334 |
if likely_phone:
|
| 335 |
if sharp < 40:
|
| 336 |
-
score -= 25
|
| 337 |
-
warnings.append("Blurry / motion blur detected (likely phone capture).")
|
| 338 |
elif sharp < 80:
|
| 339 |
-
score -= 12
|
| 340 |
-
warnings.append("Slight blur detected.")
|
| 341 |
else:
|
| 342 |
if sharp < 30:
|
| 343 |
-
score -= 8
|
| 344 |
-
warnings.append("Low fine detail (possible downsampling).")
|
| 345 |
|
| 346 |
if hi_clip > 0.05:
|
| 347 |
-
score -= 15
|
| 348 |
-
warnings.append("Overexposed highlights (washed-out areas).")
|
| 349 |
if lo_clip > 0.10:
|
| 350 |
-
score -= 12
|
| 351 |
-
warnings.append("Underexposed shadows (very dark areas).")
|
| 352 |
|
| 353 |
if border > 0.55:
|
| 354 |
-
score -= 18
|
| 355 |
-
warnings.append("Large border/margins detected (possible screenshot/phone framing).")
|
| 356 |
elif border > 0.35:
|
| 357 |
-
score -= 10
|
| 358 |
-
warnings.append("Some border/margins detected.")
|
| 359 |
|
| 360 |
return float(np.clip(score, 0, 100)), warnings
|
| 361 |
|
|
@@ -363,22 +313,19 @@ def auto_border_crop(gray_u8: np.ndarray) -> np.ndarray:
|
|
| 363 |
g = gray_u8.copy()
|
| 364 |
g_blur = cv2.GaussianBlur(g, (5, 5), 0)
|
| 365 |
_, th = cv2.threshold(g_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
| 366 |
-
if th.mean() > 127:
|
| 367 |
-
th = 255 - th
|
| 368 |
|
| 369 |
k = max(3, int(0.01 * min(g.shape)))
|
| 370 |
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
|
| 371 |
th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
|
| 372 |
|
| 373 |
contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 374 |
-
if not contours:
|
| 375 |
-
return gray_u8
|
| 376 |
|
| 377 |
c = max(contours, key=cv2.contourArea)
|
| 378 |
x, y, w, h = cv2.boundingRect(c)
|
| 379 |
H, W = gray_u8.shape
|
| 380 |
-
if w * h < 0.20 * (H * W):
|
| 381 |
-
return gray_u8
|
| 382 |
|
| 383 |
pad = int(0.03 * min(H, W))
|
| 384 |
x1 = max(0, x - pad); y1 = max(0, y - pad)
|
|
@@ -395,7 +342,6 @@ def phone_preprocess(gray_u8: np.ndarray) -> np.ndarray:
|
|
| 395 |
border = border_fraction(gray_u8)
|
| 396 |
|
| 397 |
g = gray_u8
|
| 398 |
-
|
| 399 |
if border > 0.35:
|
| 400 |
cropped = auto_border_crop(g)
|
| 401 |
if cropped.size >= 0.70 * g.size:
|
|
@@ -439,7 +385,7 @@ def fill_holes(binary_u8: np.ndarray) -> np.ndarray:
|
|
| 439 |
m = (binary_u8 * 255).astype(np.uint8)
|
| 440 |
h, w = m.shape
|
| 441 |
flood = m.copy()
|
| 442 |
-
mask = np.zeros((h
|
| 443 |
cv2.floodFill(flood, mask, (0, 0), 255)
|
| 444 |
holes = cv2.bitwise_not(flood)
|
| 445 |
filled = cv2.bitwise_or(m, holes)
|
|
@@ -482,7 +428,7 @@ def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
|
|
| 482 |
|
| 483 |
border = np.concatenate([m[0, :], m[-1, :], m[:, 0], m[:, -1]])
|
| 484 |
if border.mean() > 0.05:
|
| 485 |
-
warns.append("Lung mask touches
|
| 486 |
|
| 487 |
if total > 0 and (top1 + top2) / total < 0.90:
|
| 488 |
warns.append("Mask appears fragmented (may reduce reliability).")
|
|
@@ -491,14 +437,14 @@ def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
|
|
| 491 |
|
| 492 |
def recommendation_for_band(band: Optional[str]) -> str:
|
| 493 |
if band in (None, "YELLOW"):
|
| 494 |
-
return "Radiologist/clinician review is recommended (result is indeterminate)."
|
| 495 |
if band == "RED":
|
| 496 |
-
return "Urgent
|
| 497 |
-
return "If symptoms/risk factors exist,
|
| 498 |
|
| 499 |
|
| 500 |
# ============================================================
|
| 501 |
-
#
|
| 502 |
# ============================================================
|
| 503 |
def tbnet_state(tb_prob: float, tb_band: str) -> str:
|
| 504 |
if tb_band == "RED":
|
|
@@ -523,7 +469,7 @@ def build_consensus(
|
|
| 523 |
) -> Tuple[str, str]:
|
| 524 |
|
| 525 |
if tb_prob is None or tb_band is None:
|
| 526 |
-
return ("N/A", f"{MODEL_NAME_TBNET}
|
| 527 |
|
| 528 |
if radio_masked is not None:
|
| 529 |
radio_primary = radio_masked
|
|
@@ -533,27 +479,27 @@ def build_consensus(
|
|
| 533 |
radio_used = "RAW"
|
| 534 |
|
| 535 |
if radio_primary is None:
|
| 536 |
-
return ("TBNet only", f"{MODEL_NAME_RADIO}
|
| 537 |
|
| 538 |
t = tbnet_state(tb_prob, tb_band)
|
| 539 |
r = radio_state_from_prob(radio_primary)
|
| 540 |
-
rb = f" (band={radio_band})" if radio_band else ""
|
| 541 |
|
| 542 |
if t == r:
|
| 543 |
return (
|
| 544 |
f"AGREE: {t}",
|
| 545 |
-
f"Both
|
| 546 |
)
|
| 547 |
|
| 548 |
if (t in ("SCREEN+", "TB+") and r == "LOW") or (r in ("SCREEN+", "TB+") and t == "LOW"):
|
| 549 |
return (
|
| 550 |
"DISAGREE",
|
| 551 |
-
f"
|
| 552 |
)
|
| 553 |
|
| 554 |
return (
|
| 555 |
"MIXED/INDET",
|
| 556 |
-
f"Mixed
|
| 557 |
)
|
| 558 |
|
| 559 |
|
|
@@ -612,7 +558,6 @@ class RadioMLPHead(nn.Module):
|
|
| 612 |
nn.Dropout(dropout),
|
| 613 |
nn.Linear(hidden, 1),
|
| 614 |
)
|
| 615 |
-
|
| 616 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 617 |
return self.net(x).squeeze(1)
|
| 618 |
|
|
@@ -796,6 +741,7 @@ def analyze_one_image(
|
|
| 796 |
mask256 = torch.sigmoid(mask_logits)[0, 0].cpu().numpy()
|
| 797 |
|
| 798 |
mask256_bin = (mask256 > 0.5).astype(np.uint8)
|
|
|
|
| 799 |
mask256_bin = keep_top_k_components(mask256_bin, k=2)
|
| 800 |
k = max(3, int(0.02 * 256))
|
| 801 |
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
|
|
@@ -812,14 +758,11 @@ def analyze_one_image(
|
|
| 812 |
"logit": None,
|
| 813 |
"pred": "INDETERMINATE",
|
| 814 |
"band": "YELLOW",
|
| 815 |
-
"band_text": (
|
| 816 |
-
"⚠️ Lung segmentation looks unreliable, so the TBNet screening score is disabled for safety.\n\n"
|
| 817 |
-
"Please use a clearer standard frontal CXR (PA/AP) or seek radiologist review."
|
| 818 |
-
),
|
| 819 |
"quality_score": float(q_score),
|
| 820 |
"diffuse_risk": False,
|
| 821 |
"warnings": (
|
| 822 |
-
[
|
| 823 |
+ (["Phone/WhatsApp mode enabled; artifacts possible."] if phone_mode else [])
|
| 824 |
+ q_warn
|
| 825 |
),
|
|
@@ -842,11 +785,7 @@ def analyze_one_image(
|
|
| 842 |
"logit": None,
|
| 843 |
"pred": "INDETERMINATE",
|
| 844 |
"band": "YELLOW",
|
| 845 |
-
"band_text": (
|
| 846 |
-
"⚠️ The image appears cropped/non-standard (mask sanity check). "
|
| 847 |
-
"TBNet screening score is disabled for safety.\n\n"
|
| 848 |
-
"Please use a standard frontal CXR (PA/AP) or seek radiologist review."
|
| 849 |
-
),
|
| 850 |
"quality_score": float(q_score),
|
| 851 |
"diffuse_risk": False,
|
| 852 |
"warnings": (
|
|
@@ -883,38 +822,28 @@ def analyze_one_image(
|
|
| 883 |
diffuse = detect_diffuse_risk(prob_tb, cam_up, q_score)
|
| 884 |
band_base, _ = confidence_band(prob_tb, q_score, diffuse)
|
| 885 |
|
| 886 |
-
allow_red = (prob_tb >= 0.70 and q_score >= 55 and
|
| 887 |
band = "RED" if allow_red else band_base
|
| 888 |
|
| 889 |
pred = REPORT_LABELS[band]["title"]
|
| 890 |
band_text = REPORT_LABELS[band]["summary"]
|
| 891 |
|
| 892 |
-
if band == "YELLOW" and prob_tb < 0.05:
|
| 893 |
-
band_text = (
|
| 894 |
-
"⚠️ TB probability is very low, but the result is marked **indeterminate** because reliability is limited.\n\n"
|
| 895 |
-
+ band_text
|
| 896 |
-
)
|
| 897 |
-
|
| 898 |
heat = cv2.applyColorMap((cam_up * 255).astype(np.uint8), cv2.COLORMAP_JET)
|
| 899 |
overlay_clean = cv2.addWeighted(rgb, 0.65, heat, 0.35, 0)
|
| 900 |
|
| 901 |
overlay_annotated = overlay_clean.copy()
|
| 902 |
text1 = f"{band}: {pred}"
|
| 903 |
-
text2 = f"
|
| 904 |
cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (255, 255, 255), 2)
|
| 905 |
cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (0, 0, 0), 1)
|
| 906 |
cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (255, 255, 255), 2)
|
| 907 |
cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (0, 0, 0), 1)
|
| 908 |
|
| 909 |
warnings = []
|
| 910 |
-
if phone_mode:
|
| 911 |
-
|
| 912 |
-
if
|
| 913 |
-
|
| 914 |
-
if coverage < warn_cov:
|
| 915 |
-
warnings.append(f"Partial lung segmentation ({coverage*100:.1f}% coverage).")
|
| 916 |
-
if diffuse:
|
| 917 |
-
warnings.append("Non-focal attention pattern; result treated cautiously.")
|
| 918 |
warnings.extend(q_warn)
|
| 919 |
|
| 920 |
return {
|
|
@@ -969,7 +898,7 @@ def run_analysis(
|
|
| 969 |
|
| 970 |
img = cv2.imread(path, cv2.IMREAD_COLOR)
|
| 971 |
if img is None:
|
| 972 |
-
rows.append([name, "", "SKIP", "", "Unreadable image", "", "", "", ""
|
| 973 |
continue
|
| 974 |
|
| 975 |
out = analyze_one_image(
|
|
@@ -983,16 +912,16 @@ def run_analysis(
|
|
| 983 |
)
|
| 984 |
|
| 985 |
# RADIO (optional)
|
| 986 |
-
radio_text = f"{MODEL_NAME_RADIO}
|
| 987 |
radio_raw_overlay = None
|
| 988 |
radio_masked_overlay = None
|
|
|
|
| 989 |
radio_raw_val: Optional[float] = None
|
| 990 |
radio_masked_val: Optional[float] = None
|
| 991 |
radio_primary_val: Optional[float] = None
|
| 992 |
radio_band: Optional[str] = None
|
| 993 |
|
| 994 |
-
|
| 995 |
-
radio_masked_str = ""
|
| 996 |
|
| 997 |
if use_radio and out["prob"] is not None:
|
| 998 |
try:
|
|
@@ -1007,20 +936,18 @@ def run_analysis(
|
|
| 1007 |
radio_primary_val = float(r["prob_primary"])
|
| 1008 |
radio_masked_val = None if r["masked_prob"] is None else float(r["masked_prob"])
|
| 1009 |
radio_band = str(r["band"])
|
| 1010 |
-
|
| 1011 |
-
radio_raw_str = f"{radio_raw_val:.4f}"
|
| 1012 |
-
radio_masked_str = "" if radio_masked_val is None else f"{radio_masked_val:.4f}"
|
| 1013 |
|
| 1014 |
radio_text = (
|
| 1015 |
-
f"**{MODEL_NAME_RADIO}
|
| 1016 |
-
f"PRIMARY={radio_primary_val:.4f} | RAW={radio_raw_val:.4f}"
|
| 1017 |
+ (f" | MASKED={radio_masked_val:.4f}" if radio_masked_val is not None else "")
|
| 1018 |
-
+ f" | Band={radio_band}"
|
| 1019 |
)
|
| 1020 |
radio_raw_overlay = r["raw_overlay"]
|
| 1021 |
radio_masked_overlay = r["masked_overlay"]
|
| 1022 |
except Exception as e:
|
| 1023 |
radio_text = f"{MODEL_NAME_RADIO} error: {type(e).__name__}: {e}"
|
|
|
|
| 1024 |
radio_raw_val = None
|
| 1025 |
radio_masked_val = None
|
| 1026 |
radio_primary_val = None
|
|
@@ -1034,24 +961,24 @@ def run_analysis(
|
|
| 1034 |
radio_band=radio_band,
|
| 1035 |
)
|
| 1036 |
|
| 1037 |
-
|
| 1038 |
-
|
| 1039 |
-
|
|
|
|
| 1040 |
|
| 1041 |
rows.append([
|
| 1042 |
name,
|
| 1043 |
-
|
| 1044 |
out["pred"],
|
| 1045 |
-
|
| 1046 |
-
|
|
|
|
| 1047 |
f"{out['quality_score']:.0f}",
|
| 1048 |
-
|
| 1049 |
-
radio_raw_str,
|
| 1050 |
-
radio_masked_str,
|
| 1051 |
consensus_label,
|
| 1052 |
])
|
| 1053 |
|
| 1054 |
-
#
|
| 1055 |
orig_rgb = cv2.cvtColor(cv2.resize(out["orig_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
|
| 1056 |
vis_rgb = cv2.cvtColor(cv2.resize(out["vis_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
|
| 1057 |
mask_overlay = cv2.resize(out["mask_overlay"], (512, 512))
|
|
@@ -1060,11 +987,9 @@ def run_analysis(
|
|
| 1060 |
gallery_items.append((orig_rgb, f"{name} • ORIGINAL"))
|
| 1061 |
gallery_items.append((vis_rgb, f"{name} • PHONE-PROC" if phone_mode else f"{name} • INPUT"))
|
| 1062 |
gallery_items.append((mask_overlay, f"{name} • Lung mask overlay"))
|
| 1063 |
-
|
| 1064 |
if out["proc_gray"] is not None:
|
| 1065 |
proc_rgb = cv2.cvtColor(cv2.resize(out["proc_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
|
| 1066 |
gallery_items.append((proc_rgb, f"{name} • Masked model input (224x224)"))
|
| 1067 |
-
|
| 1068 |
gallery_items.append((overlay_big, f"{name} • Grad-CAM overlay ({MODEL_NAME_TBNET})"))
|
| 1069 |
|
| 1070 |
if radio_raw_overlay is not None:
|
|
@@ -1072,56 +997,35 @@ def run_analysis(
|
|
| 1072 |
if radio_masked_overlay is not None:
|
| 1073 |
gallery_items.append((cv2.resize(radio_masked_overlay, (512, 512)), f"{name} • RADIO MASKED heatmap"))
|
| 1074 |
|
| 1075 |
-
# Details
|
| 1076 |
-
summary_md = overall_summary(
|
| 1077 |
-
tb_band=out.get("band"),
|
| 1078 |
-
tb_prob=out.get("prob"),
|
| 1079 |
-
radio_primary=radio_primary_val,
|
| 1080 |
-
radio_band=radio_band,
|
| 1081 |
-
consensus_label=consensus_label,
|
| 1082 |
-
q_score=float(out["quality_score"]),
|
| 1083 |
-
cov=float(out.get("lung_coverage", 0.0)),
|
| 1084 |
-
warnings=out.get("warnings", []),
|
| 1085 |
-
)
|
| 1086 |
-
|
| 1087 |
warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
|
| 1088 |
-
tb_line = "N/A (
|
| 1089 |
rec_line = recommendation_for_band(out.get("band"))
|
| 1090 |
|
| 1091 |
details_md.append(
|
| 1092 |
-
f"""{
|
| 1093 |
|
| 1094 |
-
|
|
|
|
|
|
|
| 1095 |
|
| 1096 |
-
|
| 1097 |
-
<summary><b>{MODEL_NAME_TBNET} details</b></summary>
|
| 1098 |
|
| 1099 |
-
|
| 1100 |
-
|
| 1101 |
-
- **Probability (screening score):** {tb_line}
|
| 1102 |
-
- **Attention pattern:** {"Diffuse / non-focal" if out["diffuse_risk"] else "Focal / localized"}
|
| 1103 |
|
| 1104 |
-
|
| 1105 |
-
|
| 1106 |
-
|
| 1107 |
-
<summary><b>{MODEL_NAME_RADIO} details</b></summary>
|
| 1108 |
-
|
| 1109 |
-
{radio_text}
|
| 1110 |
-
|
| 1111 |
-
</details>
|
| 1112 |
-
|
| 1113 |
-
<details>
|
| 1114 |
-
<summary><b>Image quality & segmentation</b></summary>
|
| 1115 |
-
|
| 1116 |
-
- **Quality score:** {out['quality_score']:.0f}/100
|
| 1117 |
-
- **Lung mask coverage:** {out.get('lung_coverage', 0.0) * 100:.1f}%
|
| 1118 |
|
| 1119 |
**Notes that may affect reliability**
|
| 1120 |
{warn_txt}
|
| 1121 |
|
| 1122 |
-
|
|
|
|
| 1123 |
|
| 1124 |
-
|
|
|
|
| 1125 |
|
| 1126 |
**Clinical guidance**
|
| 1127 |
{CLINICAL_GUIDANCE}
|
|
@@ -1144,10 +1048,10 @@ def build_ui():
|
|
| 1144 |
"""
|
| 1145 |
|
| 1146 |
with gr.Blocks(title="TB X-ray Assistant (TBNet + RADIO)", css=css) as demo:
|
| 1147 |
-
gr.Markdown('<div class="title">TB X-ray Assistant (Research Use)</div>')
|
| 1148 |
gr.Markdown(
|
| 1149 |
f"<div class='subtitle'>Auto lung mask → <b>{MODEL_NAME_TBNET}</b> + Grad-CAM • "
|
| 1150 |
-
f"Optional <b>{MODEL_NAME_RADIO}</b> (C-RADIOv4 + heads) •
|
| 1151 |
)
|
| 1152 |
|
| 1153 |
with gr.Row():
|
|
@@ -1161,7 +1065,7 @@ def build_ui():
|
|
| 1161 |
|
| 1162 |
threshold = gr.Slider(
|
| 1163 |
0.01, 0.99, value=TBNET_SCREEN_THR, step=0.01,
|
| 1164 |
-
label=f"Reference threshold (TBNet
|
| 1165 |
)
|
| 1166 |
|
| 1167 |
phone_mode = gr.Checkbox(
|
|
@@ -1186,10 +1090,8 @@ def build_ui():
|
|
| 1186 |
|
| 1187 |
with gr.Column(scale=2):
|
| 1188 |
gr.Markdown("#### Upload images")
|
| 1189 |
-
files = gr.Files(
|
| 1190 |
-
|
| 1191 |
-
file_types=[".png", ".jpg", ".jpeg", ".bmp"]
|
| 1192 |
-
)
|
| 1193 |
run_btn = gr.Button("Run Analysis", variant="primary")
|
| 1194 |
status = gr.Textbox(label="Status", value="Ready.", interactive=False)
|
| 1195 |
|
|
@@ -1197,17 +1099,16 @@ def build_ui():
|
|
| 1197 |
table = gr.Dataframe(
|
| 1198 |
headers=[
|
| 1199 |
"Image",
|
| 1200 |
-
"
|
| 1201 |
"TBNet Result",
|
| 1202 |
-
"
|
| 1203 |
-
"
|
|
|
|
| 1204 |
"Quality",
|
| 1205 |
"LungCov",
|
| 1206 |
-
"
|
| 1207 |
-
"RADIO MASKED",
|
| 1208 |
-
"AGREEMENT",
|
| 1209 |
],
|
| 1210 |
-
datatype=["str","str","str","str","str","str","str","str","str"
|
| 1211 |
interactive=False,
|
| 1212 |
label="Results"
|
| 1213 |
)
|
|
|
|
| 1 |
# app.py
|
| 2 |
+
# Gradio — TBNet + Lung U-Net Auto Mask + Grad-CAM + RADIO
|
| 3 |
# + SAFER PHONE MODE + MASK POST-PROCESSING + MASK SANITY FAILSAFE
|
| 4 |
+
# + 3-STATE CONSENSUS (LOW / INDET / SCREEN+)
|
| 5 |
#
|
| 6 |
+
# HF Spaces: use relative weight paths (edit below if needed)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
#
|
| 8 |
+
# Requirements (requirements.txt):
|
| 9 |
+
# gradio
|
| 10 |
+
# torch
|
| 11 |
+
# torchvision
|
| 12 |
+
# timm
|
| 13 |
+
# opencv-python
|
| 14 |
+
# pillow
|
| 15 |
+
# transformers
|
| 16 |
+
# einops
|
| 17 |
+
# open_clip_torch
|
| 18 |
|
| 19 |
import os
|
| 20 |
import cv2
|
|
|
|
| 28 |
from torchvision import transforms
|
| 29 |
from typing import List, Tuple, Dict, Any, Optional
|
| 30 |
|
|
|
|
| 31 |
from transformers import AutoModel, CLIPImageProcessor
|
| 32 |
from einops import rearrange
|
| 33 |
from PIL import Image
|
| 34 |
|
| 35 |
|
| 36 |
# ============================================================
|
| 37 |
+
# USER CONFIG
|
| 38 |
# ============================================================
|
| 39 |
|
| 40 |
+
# ---- Friendly names (UI) ----
|
| 41 |
MODEL_NAME_TBNET = "TBNet (CNN model)"
|
| 42 |
MODEL_NAME_RADIO = "RADIO (visual model)"
|
| 43 |
|
| 44 |
+
# ---- Default TB/Lung weights (HF-friendly relative paths) ----
|
| 45 |
DEFAULT_TB_WEIGHTS = "weights/best.pt"
|
| 46 |
DEFAULT_LUNG_WEIGHTS = "weights/lung_unet_mont_shenzhen.pt"
|
| 47 |
|
| 48 |
+
# ---- RADIO config (same env as TB) ----
|
| 49 |
RADIO_HF_REPO = "nvidia/C-RADIOv4-SO400M"
|
| 50 |
RADIO_REVISION = "c0457f5dc26ca145f954cd4fc5bb6114e5705ad8"
|
| 51 |
|
|
|
|
| 59 |
RADIO_MASKED_MIN_COV = 0.15
|
| 60 |
RADIO_GATE_DEFAULT = 0.21
|
| 61 |
|
| 62 |
+
# ---- Consensus logic thresholds ----
|
| 63 |
TBNET_SCREEN_THR = 0.30
|
| 64 |
+
TBNET_MARGIN = 0.03 # (kept for compatibility / future use)
|
| 65 |
+
|
| 66 |
RADIO_SCREEN_THR = RADIO_THR_SCREEN
|
| 67 |
+
RADIO_MARGIN = 0.02 # (kept for compatibility / future use)
|
| 68 |
|
| 69 |
+
# ---- Mask fail-safes ----
|
| 70 |
FAIL_COV = 0.10
|
| 71 |
WARN_COV = 0.18
|
| 72 |
FAILSAFE_ON_BAD_MASK = True
|
| 73 |
|
| 74 |
+
# ---- Device policy ----
|
| 75 |
FORCE_CPU = True
|
| 76 |
DEVICE = torch.device("cpu" if FORCE_CPU else ("cuda" if torch.cuda.is_available() else "cpu"))
|
| 77 |
|
|
|
|
| 81 |
# ============================================================
|
| 82 |
CLINICAL_DISCLAIMER = """
|
| 83 |
⚠️ IMPORTANT CLINICAL NOTICE (Decision Support Only)
|
| 84 |
+
This AI system is for **research/decision support** and is NOT a diagnostic device.
|
| 85 |
+
It may NOT reliably detect early/subtle tuberculosis, including **MILIARY TB**,
|
| 86 |
+
which can appear near-normal or subtle on chest X-ray (especially on phone photos / WhatsApp images).
|
|
|
|
| 87 |
|
| 88 |
If clinical suspicion exists (fever, weight loss, immunosuppression, known exposure),
|
| 89 |
recommend **CBNAAT / GeneXpert**, sputum studies, and/or **CT chest** regardless of AI output.
|
| 90 |
"""
|
| 91 |
|
| 92 |
+
# Friendly labels still map to GREEN/YELLOW/RED logic
|
| 93 |
REPORT_LABELS = {
|
| 94 |
"GREEN": {
|
| 95 |
"title": "LOW TB LIKELIHOOD",
|
|
|
|
| 126 |
|
| 127 |
|
| 128 |
# ============================================================
|
| 129 |
+
# NEW: OVERALL LABEL FOR TABLE (user-friendly)
|
| 130 |
# ============================================================
|
| 131 |
+
def overall_label_from_consensus(consensus_label: str, tb_prob: Optional[float]) -> str:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
if tb_prob is None:
|
| 133 |
+
return "⚠️ INDETERMINATE"
|
| 134 |
+
if "AGREE: LOW" in consensus_label:
|
| 135 |
+
return "✅ LOW"
|
| 136 |
+
if "AGREE: SCREEN+" in consensus_label:
|
| 137 |
+
return "⚠️ SCREEN+"
|
| 138 |
+
if "AGREE: TB+" in consensus_label:
|
| 139 |
+
return "🚩 TB+"
|
| 140 |
+
if "DISAGREE" in consensus_label:
|
| 141 |
+
return "⚠️ DISAGREE"
|
| 142 |
+
return "⚠️ INDET"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
|
| 145 |
# ============================================================
|
|
|
|
| 156 |
nn.BatchNorm2d(out_c),
|
| 157 |
nn.ReLU(inplace=True),
|
| 158 |
)
|
| 159 |
+
def forward(self, x): return self.net(x)
|
|
|
|
|
|
|
| 160 |
|
| 161 |
class LungUNet(nn.Module):
|
| 162 |
def __init__(self):
|
|
|
|
| 198 |
super().__init__()
|
| 199 |
self.backbone = timm.create_model(backbone, pretrained=False, num_classes=0, global_pool="avg")
|
| 200 |
self.fc = nn.Linear(self.backbone.num_features, 1)
|
| 201 |
+
def forward(self, x): return self.fc(self.backbone(x)).view(-1)
|
|
|
|
|
|
|
| 202 |
|
| 203 |
def load_tb_weights(model: nn.Module, ckpt_path: str, device: torch.device):
|
| 204 |
sd = torch.load(ckpt_path, map_location=device)
|
|
|
|
| 212 |
target_layer.register_forward_hook(self._fwd)
|
| 213 |
target_layer.register_full_backward_hook(self._bwd)
|
| 214 |
|
| 215 |
+
def _fwd(self, _, __, out): self.activ = out
|
| 216 |
+
def _bwd(self, _, grad_in, grad_out): self.grad = grad_out[0]
|
|
|
|
|
|
|
|
|
|
| 217 |
|
| 218 |
def generate(self, x: torch.Tensor) -> Tuple[np.ndarray, float, float]:
|
| 219 |
with torch.enable_grad():
|
|
|
|
| 271 |
bot = gray_u8[-b:, :]
|
| 272 |
left = gray_u8[:, :b]
|
| 273 |
right = gray_u8[:, -b:]
|
| 274 |
+
def frac_border(x): return float(((x < 15) | (x > 240)).mean())
|
|
|
|
| 275 |
return float(np.mean([frac_border(top), frac_border(bot), frac_border(left), frac_border(right)]))
|
| 276 |
|
| 277 |
def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
|
|
|
|
| 286 |
sharp = laplacian_sharpness(gray_u8)
|
| 287 |
lo_clip, hi_clip = exposure_scores(gray_u8)
|
| 288 |
border = border_fraction(gray_u8)
|
|
|
|
| 289 |
likely_phone = (border > 0.35) or (lo_clip > 0.10) or (hi_clip > 0.05)
|
| 290 |
|
| 291 |
if likely_phone:
|
| 292 |
if sharp < 40:
|
| 293 |
+
score -= 25; warnings.append("Blurry / motion blur detected (likely phone capture).")
|
|
|
|
| 294 |
elif sharp < 80:
|
| 295 |
+
score -= 12; warnings.append("Slight blur detected.")
|
|
|
|
| 296 |
else:
|
| 297 |
if sharp < 30:
|
| 298 |
+
score -= 8; warnings.append("Low fine detail (possible downsampling).")
|
|
|
|
| 299 |
|
| 300 |
if hi_clip > 0.05:
|
| 301 |
+
score -= 15; warnings.append("Overexposed highlights (washed-out areas).")
|
|
|
|
| 302 |
if lo_clip > 0.10:
|
| 303 |
+
score -= 12; warnings.append("Underexposed shadows (very dark areas).")
|
|
|
|
| 304 |
|
| 305 |
if border > 0.55:
|
| 306 |
+
score -= 18; warnings.append("Large border/margins detected (possible screenshot/phone framing).")
|
|
|
|
| 307 |
elif border > 0.35:
|
| 308 |
+
score -= 10; warnings.append("Some border/margins detected.")
|
|
|
|
| 309 |
|
| 310 |
return float(np.clip(score, 0, 100)), warnings
|
| 311 |
|
|
|
|
| 313 |
g = gray_u8.copy()
|
| 314 |
g_blur = cv2.GaussianBlur(g, (5, 5), 0)
|
| 315 |
_, th = cv2.threshold(g_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
| 316 |
+
if th.mean() > 127: th = 255 - th
|
|
|
|
| 317 |
|
| 318 |
k = max(3, int(0.01 * min(g.shape)))
|
| 319 |
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
|
| 320 |
th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
|
| 321 |
|
| 322 |
contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
| 323 |
+
if not contours: return gray_u8
|
|
|
|
| 324 |
|
| 325 |
c = max(contours, key=cv2.contourArea)
|
| 326 |
x, y, w, h = cv2.boundingRect(c)
|
| 327 |
H, W = gray_u8.shape
|
| 328 |
+
if w * h < 0.20 * (H * W): return gray_u8
|
|
|
|
| 329 |
|
| 330 |
pad = int(0.03 * min(H, W))
|
| 331 |
x1 = max(0, x - pad); y1 = max(0, y - pad)
|
|
|
|
| 342 |
border = border_fraction(gray_u8)
|
| 343 |
|
| 344 |
g = gray_u8
|
|
|
|
| 345 |
if border > 0.35:
|
| 346 |
cropped = auto_border_crop(g)
|
| 347 |
if cropped.size >= 0.70 * g.size:
|
|
|
|
| 385 |
m = (binary_u8 * 255).astype(np.uint8)
|
| 386 |
h, w = m.shape
|
| 387 |
flood = m.copy()
|
| 388 |
+
mask = np.zeros((h+2, w+2), np.uint8)
|
| 389 |
cv2.floodFill(flood, mask, (0, 0), 255)
|
| 390 |
holes = cv2.bitwise_not(flood)
|
| 391 |
filled = cv2.bitwise_or(m, holes)
|
|
|
|
| 428 |
|
| 429 |
border = np.concatenate([m[0, :], m[-1, :], m[:, 0], m[:, -1]])
|
| 430 |
if border.mean() > 0.05:
|
| 431 |
+
warns.append("Lung mask touches image border (possible cropped/non-standard CXR).")
|
| 432 |
|
| 433 |
if total > 0 and (top1 + top2) / total < 0.90:
|
| 434 |
warns.append("Mask appears fragmented (may reduce reliability).")
|
|
|
|
| 437 |
|
| 438 |
def recommendation_for_band(band: Optional[str]) -> str:
|
| 439 |
if band in (None, "YELLOW"):
|
| 440 |
+
return "✅ Recommendation: Radiologist/clinician review is recommended (result is indeterminate)."
|
| 441 |
if band == "RED":
|
| 442 |
+
return "✅ Recommendation: Urgent clinician/radiologist review + microbiological confirmation (CBNAAT/GeneXpert, sputum)."
|
| 443 |
+
return "✅ Recommendation: If symptoms/risk factors exist, clinician/radiologist correlation is advised."
|
| 444 |
|
| 445 |
|
| 446 |
# ============================================================
|
| 447 |
+
# CONSENSUS LOGIC (TBNet vs RADIO) — 3-state
|
| 448 |
# ============================================================
|
| 449 |
def tbnet_state(tb_prob: float, tb_band: str) -> str:
|
| 450 |
if tb_band == "RED":
|
|
|
|
| 469 |
) -> Tuple[str, str]:
|
| 470 |
|
| 471 |
if tb_prob is None or tb_band is None:
|
| 472 |
+
return ("N/A", f"{MODEL_NAME_TBNET} unavailable (lung segmentation failed / fail-safe).")
|
| 473 |
|
| 474 |
if radio_masked is not None:
|
| 475 |
radio_primary = radio_masked
|
|
|
|
| 479 |
radio_used = "RAW"
|
| 480 |
|
| 481 |
if radio_primary is None:
|
| 482 |
+
return ("TBNet only", f"{MODEL_NAME_RADIO} unavailable → {MODEL_NAME_TBNET}={tb_prob:.4f} (band={tb_band}).")
|
| 483 |
|
| 484 |
t = tbnet_state(tb_prob, tb_band)
|
| 485 |
r = radio_state_from_prob(radio_primary)
|
| 486 |
+
rb = f" (RADIO band={radio_band})" if radio_band else ""
|
| 487 |
|
| 488 |
if t == r:
|
| 489 |
return (
|
| 490 |
f"AGREE: {t}",
|
| 491 |
+
f"Both: {t}. {MODEL_NAME_TBNET}={tb_prob:.4f}, {MODEL_NAME_RADIO}({radio_used})={radio_primary:.4f}{rb}."
|
| 492 |
)
|
| 493 |
|
| 494 |
if (t in ("SCREEN+", "TB+") and r == "LOW") or (r in ("SCREEN+", "TB+") and t == "LOW"):
|
| 495 |
return (
|
| 496 |
"DISAGREE",
|
| 497 |
+
f"Strong disagreement: {MODEL_NAME_TBNET}={t} (band={tb_band}) vs {MODEL_NAME_RADIO}={r} ({radio_used})={radio_primary:.4f}{rb}."
|
| 498 |
)
|
| 499 |
|
| 500 |
return (
|
| 501 |
"MIXED/INDET",
|
| 502 |
+
f"Mixed/uncertain: {MODEL_NAME_TBNET}={t} (band={tb_band}) vs {MODEL_NAME_RADIO}={r} ({radio_used})={radio_primary:.4f}{rb}."
|
| 503 |
)
|
| 504 |
|
| 505 |
|
|
|
|
| 558 |
nn.Dropout(dropout),
|
| 559 |
nn.Linear(hidden, 1),
|
| 560 |
)
|
|
|
|
| 561 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 562 |
return self.net(x).squeeze(1)
|
| 563 |
|
|
|
|
| 741 |
mask256 = torch.sigmoid(mask_logits)[0, 0].cpu().numpy()
|
| 742 |
|
| 743 |
mask256_bin = (mask256 > 0.5).astype(np.uint8)
|
| 744 |
+
|
| 745 |
mask256_bin = keep_top_k_components(mask256_bin, k=2)
|
| 746 |
k = max(3, int(0.02 * 256))
|
| 747 |
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
|
|
|
|
| 758 |
"logit": None,
|
| 759 |
"pred": "INDETERMINATE",
|
| 760 |
"band": "YELLOW",
|
| 761 |
+
"band_text": "Lung segmentation failed. TB scoring disabled (fail-safe).",
|
|
|
|
|
|
|
|
|
|
| 762 |
"quality_score": float(q_score),
|
| 763 |
"diffuse_risk": False,
|
| 764 |
"warnings": (
|
| 765 |
+
["Lung segmentation failed (<10% lung area).", f"Lung coverage: {coverage*100:.1f}%"]
|
| 766 |
+ (["Phone/WhatsApp mode enabled; artifacts possible."] if phone_mode else [])
|
| 767 |
+ q_warn
|
| 768 |
),
|
|
|
|
| 785 |
"logit": None,
|
| 786 |
"pred": "INDETERMINATE",
|
| 787 |
"band": "YELLOW",
|
| 788 |
+
"band_text": "Non-standard/cropped view or unreliable lung segmentation. TB scoring disabled (fail-safe).",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 789 |
"quality_score": float(q_score),
|
| 790 |
"diffuse_risk": False,
|
| 791 |
"warnings": (
|
|
|
|
| 822 |
diffuse = detect_diffuse_risk(prob_tb, cam_up, q_score)
|
| 823 |
band_base, _ = confidence_band(prob_tb, q_score, diffuse)
|
| 824 |
|
| 825 |
+
allow_red = (prob_tb >= 0.70 and q_score >= 55 and not diffuse and coverage >= warn_cov)
|
| 826 |
band = "RED" if allow_red else band_base
|
| 827 |
|
| 828 |
pred = REPORT_LABELS[band]["title"]
|
| 829 |
band_text = REPORT_LABELS[band]["summary"]
|
| 830 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 831 |
heat = cv2.applyColorMap((cam_up * 255).astype(np.uint8), cv2.COLORMAP_JET)
|
| 832 |
overlay_clean = cv2.addWeighted(rgb, 0.65, heat, 0.35, 0)
|
| 833 |
|
| 834 |
overlay_annotated = overlay_clean.copy()
|
| 835 |
text1 = f"{band}: {pred}"
|
| 836 |
+
text2 = f"TB prob={prob_tb:.3f} | Quality={q_score:.0f}/100 | Lung coverage={coverage*100:.1f}%"
|
| 837 |
cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (255, 255, 255), 2)
|
| 838 |
cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (0, 0, 0), 1)
|
| 839 |
cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (255, 255, 255), 2)
|
| 840 |
cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (0, 0, 0), 1)
|
| 841 |
|
| 842 |
warnings = []
|
| 843 |
+
if phone_mode: warnings.append("Phone/WhatsApp mode enabled; artifacts possible.")
|
| 844 |
+
if q_score < 55: warnings.append("Suboptimal image quality limits AI reliability.")
|
| 845 |
+
if coverage < warn_cov: warnings.append(f"Partial lung segmentation ({coverage*100:.1f}% coverage).")
|
| 846 |
+
if diffuse: warnings.append("Diffuse, non-focal AI attention pattern; TB-specific features not identified.")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 847 |
warnings.extend(q_warn)
|
| 848 |
|
| 849 |
return {
|
|
|
|
| 898 |
|
| 899 |
img = cv2.imread(path, cv2.IMREAD_COLOR)
|
| 900 |
if img is None:
|
| 901 |
+
rows.append([name, "⚠️", "SKIP", "", "Unreadable image", "", "", "", ""])
|
| 902 |
continue
|
| 903 |
|
| 904 |
out = analyze_one_image(
|
|
|
|
| 912 |
)
|
| 913 |
|
| 914 |
# RADIO (optional)
|
| 915 |
+
radio_text = f"{MODEL_NAME_RADIO} disabled."
|
| 916 |
radio_raw_overlay = None
|
| 917 |
radio_masked_overlay = None
|
| 918 |
+
|
| 919 |
radio_raw_val: Optional[float] = None
|
| 920 |
radio_masked_val: Optional[float] = None
|
| 921 |
radio_primary_val: Optional[float] = None
|
| 922 |
radio_band: Optional[str] = None
|
| 923 |
|
| 924 |
+
radio_result_short = "Disabled"
|
|
|
|
| 925 |
|
| 926 |
if use_radio and out["prob"] is not None:
|
| 927 |
try:
|
|
|
|
| 936 |
radio_primary_val = float(r["prob_primary"])
|
| 937 |
radio_masked_val = None if r["masked_prob"] is None else float(r["masked_prob"])
|
| 938 |
radio_band = str(r["band"])
|
| 939 |
+
radio_result_short = str(r["pred"])
|
|
|
|
|
|
|
| 940 |
|
| 941 |
radio_text = (
|
| 942 |
+
f"**{MODEL_NAME_RADIO}:** {r['pred']} | PRIMARY={radio_primary_val:.4f} | RAW={radio_raw_val:.4f}"
|
|
|
|
| 943 |
+ (f" | MASKED={radio_masked_val:.4f}" if radio_masked_val is not None else "")
|
| 944 |
+
+ (f" | Band={radio_band}" if radio_band else "")
|
| 945 |
)
|
| 946 |
radio_raw_overlay = r["raw_overlay"]
|
| 947 |
radio_masked_overlay = r["masked_overlay"]
|
| 948 |
except Exception as e:
|
| 949 |
radio_text = f"{MODEL_NAME_RADIO} error: {type(e).__name__}: {e}"
|
| 950 |
+
radio_result_short = "Error"
|
| 951 |
radio_raw_val = None
|
| 952 |
radio_masked_val = None
|
| 953 |
radio_primary_val = None
|
|
|
|
| 961 |
radio_band=radio_band,
|
| 962 |
)
|
| 963 |
|
| 964 |
+
overall = overall_label_from_consensus(consensus_label, out["prob"])
|
| 965 |
+
|
| 966 |
+
tb_prob_str = "" if out["prob"] is None else f"{out['prob']:.4f}"
|
| 967 |
+
radio_prob_primary_str = "" if radio_primary_val is None else f"{radio_primary_val:.4f}"
|
| 968 |
|
| 969 |
rows.append([
|
| 970 |
name,
|
| 971 |
+
overall,
|
| 972 |
out["pred"],
|
| 973 |
+
tb_prob_str,
|
| 974 |
+
radio_result_short,
|
| 975 |
+
radio_prob_primary_str,
|
| 976 |
f"{out['quality_score']:.0f}",
|
| 977 |
+
f"{out.get('lung_coverage', 0.0) * 100:.1f}%",
|
|
|
|
|
|
|
| 978 |
consensus_label,
|
| 979 |
])
|
| 980 |
|
| 981 |
+
# Visual outputs
|
| 982 |
orig_rgb = cv2.cvtColor(cv2.resize(out["orig_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
|
| 983 |
vis_rgb = cv2.cvtColor(cv2.resize(out["vis_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
|
| 984 |
mask_overlay = cv2.resize(out["mask_overlay"], (512, 512))
|
|
|
|
| 987 |
gallery_items.append((orig_rgb, f"{name} • ORIGINAL"))
|
| 988 |
gallery_items.append((vis_rgb, f"{name} • PHONE-PROC" if phone_mode else f"{name} • INPUT"))
|
| 989 |
gallery_items.append((mask_overlay, f"{name} • Lung mask overlay"))
|
|
|
|
| 990 |
if out["proc_gray"] is not None:
|
| 991 |
proc_rgb = cv2.cvtColor(cv2.resize(out["proc_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
|
| 992 |
gallery_items.append((proc_rgb, f"{name} • Masked model input (224x224)"))
|
|
|
|
| 993 |
gallery_items.append((overlay_big, f"{name} • Grad-CAM overlay ({MODEL_NAME_TBNET})"))
|
| 994 |
|
| 995 |
if radio_raw_overlay is not None:
|
|
|
|
| 997 |
if radio_masked_overlay is not None:
|
| 998 |
gallery_items.append((cv2.resize(radio_masked_overlay, (512, 512)), f"{name} • RADIO MASKED heatmap"))
|
| 999 |
|
| 1000 |
+
# Details panel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1001 |
warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
|
| 1002 |
+
tb_line = "N/A (fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
|
| 1003 |
rec_line = recommendation_for_band(out.get("band"))
|
| 1004 |
|
| 1005 |
details_md.append(
|
| 1006 |
+
f"""### {name}
|
| 1007 |
|
| 1008 |
+
**Overall:** {overall}
|
| 1009 |
+
**{MODEL_NAME_TBNET} result:** **{out['pred']}**
|
| 1010 |
+
{rec_line}
|
| 1011 |
|
| 1012 |
+
**{MODEL_NAME_TBNET} probability:** {tb_line}
|
|
|
|
| 1013 |
|
| 1014 |
+
**Interpretation**
|
| 1015 |
+
{out['band_text']}
|
|
|
|
|
|
|
| 1016 |
|
| 1017 |
+
**Image quality:** {out['quality_score']:.0f}/100
|
| 1018 |
+
**Lung mask coverage:** {out.get('lung_coverage', 0.0) * 100:.1f}%
|
| 1019 |
+
**Attention pattern (TBNet):** {"Diffuse / non-focal" if out["diffuse_risk"] else "Focal / localized"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1020 |
|
| 1021 |
**Notes that may affect reliability**
|
| 1022 |
{warn_txt}
|
| 1023 |
|
| 1024 |
+
**{MODEL_NAME_RADIO} output**
|
| 1025 |
+
{radio_text}
|
| 1026 |
|
| 1027 |
+
**Agreement between models:** **{consensus_label}**
|
| 1028 |
+
- {consensus_detail}
|
| 1029 |
|
| 1030 |
**Clinical guidance**
|
| 1031 |
{CLINICAL_GUIDANCE}
|
|
|
|
| 1048 |
"""
|
| 1049 |
|
| 1050 |
with gr.Blocks(title="TB X-ray Assistant (TBNet + RADIO)", css=css) as demo:
|
| 1051 |
+
gr.Markdown('<div class="title">TB X-ray Assistant (Auto Lung Mask • Research Use)</div>')
|
| 1052 |
gr.Markdown(
|
| 1053 |
f"<div class='subtitle'>Auto lung mask → <b>{MODEL_NAME_TBNET}</b> + Grad-CAM • "
|
| 1054 |
+
f"Optional <b>{MODEL_NAME_RADIO}</b> (C-RADIOv4 + heads) • Agreement summary</div>"
|
| 1055 |
)
|
| 1056 |
|
| 1057 |
with gr.Row():
|
|
|
|
| 1065 |
|
| 1066 |
threshold = gr.Slider(
|
| 1067 |
0.01, 0.99, value=TBNET_SCREEN_THR, step=0.01,
|
| 1068 |
+
label=f"Reference threshold (TBNet screen+) = {TBNET_SCREEN_THR:.2f}"
|
| 1069 |
)
|
| 1070 |
|
| 1071 |
phone_mode = gr.Checkbox(
|
|
|
|
| 1090 |
|
| 1091 |
with gr.Column(scale=2):
|
| 1092 |
gr.Markdown("#### Upload images")
|
| 1093 |
+
files = gr.Files(label="Upload one or multiple X-ray images",
|
| 1094 |
+
file_types=[".png", ".jpg", ".jpeg", ".bmp"])
|
|
|
|
|
|
|
| 1095 |
run_btn = gr.Button("Run Analysis", variant="primary")
|
| 1096 |
status = gr.Textbox(label="Status", value="Ready.", interactive=False)
|
| 1097 |
|
|
|
|
| 1099 |
table = gr.Dataframe(
|
| 1100 |
headers=[
|
| 1101 |
"Image",
|
| 1102 |
+
"OVERALL",
|
| 1103 |
"TBNet Result",
|
| 1104 |
+
"TBNet Prob",
|
| 1105 |
+
"RADIO Result",
|
| 1106 |
+
"RADIO Prob (Primary)",
|
| 1107 |
"Quality",
|
| 1108 |
"LungCov",
|
| 1109 |
+
"Agreement",
|
|
|
|
|
|
|
| 1110 |
],
|
| 1111 |
+
datatype=["str","str","str","str","str","str","str","str","str"],
|
| 1112 |
interactive=False,
|
| 1113 |
label="Results"
|
| 1114 |
)
|