# app.py
# Gradio — TBNet + Lung U-Net Auto Mask + Grad-CAM + RADIO
# + SAFER PHONE MODE + MASK POST-PROCESSING + MASK SANITY FAILSAFE
# + 3-STATE CONSENSUS (LOW / INDET / SCREEN+ / TB+)
#
# UX UPDATE (this version):
# - Removes the results table
# - Shows per-image, easy-to-read "cards":
# 1) TBNet result (alone)
# 2) RADIO result (alone)
# 3) Final consensus (comparison + next step)
# - Adds collapsible detailed report per image
# - Keeps gallery, adds legend, better labels
# - Fixes welcome screen rendering (uses gr.HTML + inline styles; no raw HTML shown)
#
# HF Spaces: use relative weight paths (edit below if needed)
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import timm
import gradio as gr
from torchvision import transforms
from typing import List, Tuple, Dict, Any, Optional
from transformers import AutoModel, CLIPImageProcessor
from einops import rearrange
from PIL import Image
# ============================================================
# USER CONFIG
# ============================================================
# ---- Friendly names (UI) ----
MODEL_NAME_TBNET = "TBNet (CNN model)"
MODEL_NAME_RADIO = "Nvidia/C-RADIOv4-SO400M (visual model)"
# ---- Default TB/Lung weights (HF-friendly relative paths) ----
DEFAULT_TB_WEIGHTS = "weights/best.pt"
DEFAULT_LUNG_WEIGHTS = "weights/lung_unet_mont_shenzhen.pt"
# ---- RADIO config (same env as TB) ----
# Pinned HF repo + revision so the foundation model cannot silently change.
RADIO_HF_REPO = "nvidia/C-RADIOv4-SO400M"
RADIO_REVISION = "c0457f5dc26ca145f954cd4fc5bb6114e5705ad8"
# Two classifier heads: one trained on raw CXRs, one on lung-masked CXRs.
RADIO_RAW_HEAD_PATH = "weights/best_raw.pt"
RADIO_MASKED_HEAD_PATH = "weights/best_masked.pt"
RADIO_IMG_SIZE = 320
RADIO_PATCH_SIZE = 16
# Probability cut-points for the RADIO banding (see radio_state_from_prob).
RADIO_THR_SCREEN = 0.05
RADIO_THR_RED = 0.23
# Minimum lung coverage before the masked RADIO head is trusted at all.
RADIO_MASKED_MIN_COV = 0.15
RADIO_GATE_DEFAULT = 0.21
# ---- Consensus logic thresholds ----
TBNET_SCREEN_THR = 0.30
TBNET_MARGIN = 0.03 # kept for compatibility / future use
RADIO_SCREEN_THR = RADIO_THR_SCREEN
RADIO_MARGIN = 0.02 # kept for compatibility / future use
# ---- Mask fail-safes ----
# Below FAIL_COV lung coverage, TB scoring is disabled entirely;
# between FAIL_COV and WARN_COV a warning is attached.
FAIL_COV = 0.10
WARN_COV = 0.18
FAILSAFE_ON_BAD_MASK = True
# ---- Device policy ----
# FORCE_CPU=True pins everything to CPU (HF Spaces free tier has no GPU).
FORCE_CPU = True
DEVICE = torch.device("cpu" if FORCE_CPU else ("cuda" if torch.cuda.is_available() else "cpu"))
# ============================================================
# CLINICAL DISCLAIMER / REPORT TEXT
# ============================================================
# Shown verbatim in the UI; markdown formatting intended.
CLINICAL_DISCLAIMER = """
⚠️ IMPORTANT CLINICAL NOTICE (Decision Support Only)
This AI system is for **research/decision support** and is NOT a diagnostic device.
It may NOT reliably detect early/subtle tuberculosis, including **MILIARY TB**,
which can appear near-normal or subtle on chest X-ray (especially on phone photos / WhatsApp images).
If clinical suspicion exists (fever, weight loss, immunosuppression, known exposure),
recommend **CBNAAT / GeneXpert**, sputum studies, and/or **CT chest** regardless of AI output.
"""
# Per-band report card text, keyed by the TBNet band ("GREEN"/"YELLOW"/"RED").
# Each entry has a "title" (headline) and a markdown "summary" (next steps).
REPORT_LABELS = {
    "GREEN": {
        "title": "LOW TB LIKELIHOOD / Pulmonary T.B not detected by A.I",
        "summary": (
            f"✅ **{MODEL_NAME_TBNET}** did not find patterns that strongly suggest pulmonary tuberculosis.\n\n"
            "**What to do next:** If symptoms or TB risk factors are present, please seek clinician/radiologist review."
        ),
    },
    "YELLOW": {
        "title": "INDETERMINATE — REVIEW RECOMMENDED BY A RADIOLOGIST",
        "summary": (
            f"⚠️ **{MODEL_NAME_TBNET}** result is **not definitive**.\n\n"
            "**Common reasons:** image quality limitations, non-standard/cropped view, or non-focal attention.\n\n"
            "**What to do next:** Radiologist/clinician review is recommended. "
            "If TB is clinically suspected, consider microbiological tests (CBNAAT/GeneXpert, sputum)."
        ),
    },
    "RED": {
        "title": "TB FEATURES SUSPECTED",
        "summary": (
            f"🚩 **{MODEL_NAME_TBNET}** detected lung patterns that can be seen with pulmonary tuberculosis.\n\n"
            "**Important:** This is not a diagnosis.\n\n"
            "**What to do next:** Urgent clinician/radiologist review and microbiological confirmation "
            "(CBNAAT/GeneXpert, sputum) are recommended."
        ),
    },
}
# Generic advisory appended to reports regardless of band.
CLINICAL_GUIDANCE = (
    "If clinical suspicion for tuberculosis exists, further evaluation "
    "(e.g., CBNAAT / GeneXpert, sputum studies, CT chest) is recommended "
    "regardless of AI output."
)
# ============================================================
# WELCOME HTML (minimal + main features only)
# IMPORTANT: rendered with gr.HTML (not gr.Markdown)
# ============================================================
# NOTE(review): despite the name, this string contains no HTML tags in the
# current source — it will render as unstyled run-together text in gr.HTML.
# Confirm whether the markup was stripped and restore tags if so.
WELCOME_HTML = f"""
TB X-ray Assistant (research / decision support)
Upload chest X-rays to get an AI screening score, heatmaps, and a simple consensus output.
{MODEL_NAME_TBNET} + Grad-CAM
Auto lung mask + fail-safe
{MODEL_NAME_RADIO} (optional)
Consensus: ✅ LOW · ⚠️ INDET · ⚠️ SCREEN+ · 🚩 TB+
Phone/WhatsApp Mode
Tip: Turn on Phone/WhatsApp Mode for phone photos, WhatsApp-forwards, or screenshots with borders.
Clinical disclaimer: Not diagnostic. If TB is suspected clinically, pursue CBNAAT/GeneXpert/sputum and/or CT chest regardless of AI output.
"""
# ============================================================
# UX HELPERS
# ============================================================
def pretty_state(s: str) -> str:
    """Map an internal consensus state code to its emoji-decorated UI label."""
    labels = {
        "LOW": "✅ LOW",
        "INDET": "⚠️ INDETERMINATE",
        "SCREEN+": "⚠️ SCREEN+",
        "TB+": "🚩 TB+",
        "N/A": "⚠️ N/A",
    }
    # Unknown states fall back to a generic warning-flavored label.
    return labels.get(s, f"⚠️ {s}")
def html_escape(s: str) -> str:
    """Escape ``&``, ``<`` and ``>`` for safe embedding in HTML.

    Bug fix: the original replaced each character with itself (a no-op),
    so file names / arbitrary text flowed unescaped into gr.HTML output.
    Ampersand is escaped first so already-produced entities are not
    double-escaped. ``None`` is treated as the empty string.
    """
    return (s or "").replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
def badge_color_for_state(state: str) -> str:
    """Return the rgba() badge background for a consensus state string."""
    palette = {
        "TB+": "rgba(239,68,68,0.18)",       # red
        "SCREEN+": "rgba(245,158,11,0.18)",  # amber
        "INDET": "rgba(245,158,11,0.12)",    # amber lighter
        "LOW": "rgba(34,197,94,0.14)",       # green
    }
    # Anything unrecognized (including "N/A") gets the neutral gray.
    return palette.get(state, "rgba(148,163,184,0.12)")
# ============================================================
# LUNG U-NET (INFERENCE)
# ============================================================
class DoubleConv(nn.Module):
    """(Conv3x3 → BatchNorm → ReLU) × 2 — the standard U-Net building block.

    Spatial size is preserved (padding=1); only channel count changes.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        stages = [
            nn.Conv2d(in_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_c, out_c, 3, padding=1),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both conv-BN-ReLU stages to a (B, in_c, H, W) tensor."""
        return self.net(x)
class LungUNet(nn.Module):
    """Vanilla U-Net (4 encoder / 4 decoder levels) for 1-channel lung masks.

    Attribute names (d1..d4, mid, u4..u1, c4..c1, out) are part of the
    checkpoint's state_dict contract and must not be renamed.
    Input H and W must be divisible by 16 (four 2× poolings).
    Returns raw (pre-sigmoid) logits of shape (B, 1, H, W).
    """

    def __init__(self):
        super().__init__()
        # Encoder blocks (channel doubling each level).
        self.d1 = DoubleConv(1, 64)
        self.d2 = DoubleConv(64, 128)
        self.d3 = DoubleConv(128, 256)
        self.d4 = DoubleConv(256, 512)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck.
        self.mid = DoubleConv(512, 1024)
        # Decoder: transposed convs for 2x upsampling.
        self.u4 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.u3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.u2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.u1 = nn.ConvTranspose2d(128, 64, 2, 2)
        # Post-concat double convs (input = upsampled + skip channels).
        self.c4 = DoubleConv(1024, 512)
        self.c3 = DoubleConv(512, 256)
        self.c2 = DoubleConv(256, 128)
        self.c1 = DoubleConv(128, 64)
        self.out = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        """Encode with skip connections, then decode concatenating skips."""
        enc1 = self.d1(x)
        enc2 = self.d2(self.pool(enc1))
        enc3 = self.d3(self.pool(enc2))
        enc4 = self.d4(self.pool(enc3))
        bottleneck = self.mid(self.pool(enc4))
        dec = self.c4(torch.cat([self.u4(bottleneck), enc4], 1))
        dec = self.c3(torch.cat([self.u3(dec), enc3], 1))
        dec = self.c2(torch.cat([self.u2(dec), enc2], 1))
        dec = self.c1(torch.cat([self.u1(dec), enc1], 1))
        return self.out(dec)
# ============================================================
# TB MODEL + GRAD-CAM
# ============================================================
class TBNet(nn.Module):
    """timm backbone (default EfficientNet-B0) plus a single-logit TB head.

    num_classes=0 with global_pool="avg" makes the backbone emit pooled
    feature vectors, which self.fc maps to one TB logit per sample.
    """

    def __init__(self, backbone="efficientnet_b0"):
        super().__init__()
        self.backbone = timm.create_model(backbone, pretrained=False, num_classes=0, global_pool="avg")
        self.fc = nn.Linear(self.backbone.num_features, 1)

    def forward(self, x):
        """Return a flat (batch,) tensor of raw TB logits."""
        features = self.backbone(x)
        return self.fc(features).view(-1)
def load_tb_weights(model: nn.Module, ckpt_path: str, device: torch.device):
    """Load a state dict from ckpt_path into model, mapped onto device.

    Uses strict key matching so architecture/checkpoint mismatches fail
    loudly. NOTE: torch.load unpickles arbitrary objects — only load
    trusted checkpoint files.
    """
    model.load_state_dict(torch.load(ckpt_path, map_location=device), strict=True)
class GradCAM:
    """Minimal Grad-CAM for a single-logit classifier.

    Hooks `target_layer` to capture its forward activations and backward
    gradients, then weights activation channels by spatially-averaged
    gradients to build a class-activation map.

    Fix vs. original: the hook handles are now stored so they can be
    detached via remove(). Previously each GradCAM constructed on the same
    layer registered an extra, irremovable pair of hooks (a leak whenever
    the owner rebuilds its Grad-CAM, and stale hooks kept firing).
    """

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model
        self.activ = None  # forward activations of target_layer (set by hook)
        self.grad = None   # gradient w.r.t. those activations (set by hook)
        # Keep handles so the hooks can be removed later.
        self._handles = [
            target_layer.register_forward_hook(self._fwd),
            target_layer.register_full_backward_hook(self._bwd),
        ]

    def remove(self):
        """Detach both hooks from the target layer (idempotent)."""
        for handle in self._handles:
            handle.remove()
        self._handles = []

    def _fwd(self, _, __, out):
        self.activ = out

    def _bwd(self, _, grad_in, grad_out):
        self.grad = grad_out[0]

    def generate(self, x: torch.Tensor) -> Tuple[np.ndarray, float, float]:
        """Run forward + backward for input batch `x` (sample 0 is scored).

        Returns:
            (cam, prob, logit): cam is a [0, 1] float array at the target
            layer's spatial resolution; prob is sigmoid(logit) for sample 0.
        """
        with torch.enable_grad():
            self.model.zero_grad(set_to_none=True)
            logits = self.model(x)
            score = logits[0]
            score.backward()
            A = self.activ[0]
            G = self.grad[0]
            # Channel weights: mean gradient over the spatial dims.
            w = G.mean(dim=(1, 2))
            cam = (w[:, None, None] * A).sum(dim=0)
            cam = torch.relu(cam)
            # Min-max normalize to [0, 1]; epsilon guards an all-zero map.
            cam = cam - cam.min()
            cam = cam / (cam.max() + 1e-8)
            logit = float(logits.detach().cpu()[0].item())
            prob = float(torch.sigmoid(logits.detach().cpu())[0].item())
            return cam.detach().cpu().numpy(), prob, logit
# ============================================================
# PREPROCESS HELPERS + QUALITY
# ============================================================
def preprocess_for_lung_unet(gray_u8: np.ndarray) -> torch.Tensor:
    """Resize a grayscale image to 256×256 and percentile-normalize to [0, 1].

    Returns a float tensor of shape (1, 1, 256, 256) ready for LungUNet.
    The 1st/99th percentiles (computed after resizing) define the window.
    """
    resized = cv2.resize(gray_u8.astype(np.float32), (256, 256), interpolation=cv2.INTER_AREA)
    p1, p99 = np.percentile(resized, (1, 99))
    windowed = np.clip(resized, p1, p99)
    norm = (windowed - p1) / (p99 - p1 + 1e-8)
    return torch.from_numpy(norm).unsqueeze(0).unsqueeze(0).float()
def tb_training_preprocess(gray_u8: np.ndarray) -> np.ndarray:
    """Replicate TBNet training-time normalization.

    Clips to the 1st/99th percentile window and rescales to [0, 1].
    Shape is preserved; output is float32-like (numpy float array).
    """
    img = gray_u8.astype(np.float32)
    p1, p99 = np.percentile(img, (1, 99))
    clipped = np.clip(img, p1, p99)
    return (clipped - p1) / (p99 - p1 + 1e-8)
def laplacian_sharpness(gray_u8: np.ndarray) -> float:
    """Sharpness proxy: variance of the Laplacian at a fixed 512×512 size.

    A light Gaussian blur is applied first to suppress sensor noise;
    lower values indicate blurrier images.
    """
    resized = cv2.resize(gray_u8, (512, 512), interpolation=cv2.INTER_AREA)
    smoothed = cv2.GaussianBlur(resized, (3, 3), 0)
    lap = cv2.Laplacian(smoothed, cv2.CV_64F)
    return float(lap.var())
def exposure_scores(gray_u8: np.ndarray) -> Tuple[float, float]:
    """Return (underexposed, overexposed) pixel fractions.

    Underexposed = fraction of pixels < 10; overexposed = fraction > 245.
    """
    dark_frac = float(np.mean(gray_u8 < 10))
    bright_frac = float(np.mean(gray_u8 > 245))
    return dark_frac, bright_frac
def border_fraction(gray_u8: np.ndarray) -> float:
    """Mean fraction of near-black (<15) or near-white (>240) border pixels.

    Four strips (top/bottom/left/right), each ~6% of the short side wide
    (minimum 5 px), are scored separately and averaged. High values
    suggest a screenshot frame or phone-photo margins.
    """
    h, w = gray_u8.shape
    band = max(5, int(0.06 * min(h, w)))

    def _clipped_frac(strip: np.ndarray) -> float:
        return float(((strip < 15) | (strip > 240)).mean())

    strips = (
        gray_u8[:band, :],
        gray_u8[-band:, :],
        gray_u8[:, :band],
        gray_u8[:, -band:],
    )
    return float(np.mean([_clipped_frac(s) for s in strips]))
def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
    """Heuristic capture-quality score (0–100) plus human-readable warnings.

    Penalizes low resolution, blur (judged more strictly when the image
    looks like a phone capture), exposure clipping, and large dark/bright
    borders. Warning order matches the penalty order.
    """
    warnings: List[str] = []
    h, w = gray_u8.shape
    score = 100.0

    if min(h, w) < 400:
        warnings.append("Low resolution (image may be downsampled).")
        score -= 8

    sharp = laplacian_sharpness(gray_u8)
    lo_clip, hi_clip = exposure_scores(gray_u8)
    border = border_fraction(gray_u8)

    # Signals that this is probably a phone photo / screenshot.
    likely_phone = (border > 0.35) or (lo_clip > 0.10) or (hi_clip > 0.05)
    if likely_phone:
        if sharp < 40:
            score -= 25
            warnings.append("Blurry / motion blur detected (likely phone capture).")
        elif sharp < 80:
            score -= 12
            warnings.append("Slight blur detected.")
    elif sharp < 30:
        score -= 8
        warnings.append("Low fine detail (possible downsampling).")

    if hi_clip > 0.05:
        score -= 15
        warnings.append("Overexposed highlights (washed-out areas).")
    if lo_clip > 0.10:
        score -= 12
        warnings.append("Underexposed shadows (very dark areas).")

    if border > 0.55:
        score -= 18
        warnings.append("Large border/margins detected (possible screenshot/phone framing).")
    elif border > 0.35:
        score -= 10
        warnings.append("Some border/margins detected.")

    return float(np.clip(score, 0, 100)), warnings
def auto_border_crop(gray_u8: np.ndarray) -> np.ndarray:
    """Crop away dark screenshot/phone borders around the radiograph.

    Otsu-thresholds a blurred copy, closes small gaps, keeps the largest
    external contour, and returns a padded crop of its bounding box.
    Falls back to the input unchanged when no contour is found or the
    detected box covers less than 20% of the frame.
    """
    g = gray_u8.copy()
    g_blur = cv2.GaussianBlur(g, (5, 5), 0)
    _, th = cv2.threshold(g_blur, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # If the binary image is mostly white, invert it — presumably so the
    # radiograph ends up as the white (foreground) class for contouring.
    if th.mean() > 127:
        th = 255 - th
    # Kernel ~1% of the short side; close gaps so the film is one blob.
    k = max(3, int(0.01 * min(g.shape)))
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
    th = cv2.morphologyEx(th, cv2.MORPH_CLOSE, kernel, iterations=2)
    contours, _ = cv2.findContours(th, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return gray_u8
    c = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(c)
    H, W = gray_u8.shape
    # Reject tiny detections (<20% of frame) — probably not the film.
    if w * h < 0.20 * (H * W):
        return gray_u8
    # Pad the crop by 3% of the short side, clamped to image bounds.
    pad = int(0.03 * min(H, W))
    x1 = max(0, x - pad)
    y1 = max(0, y - pad)
    x2 = min(W, x + w + pad)
    y2 = min(H, y + h + pad)
    return gray_u8[y1:y2, x1:x2]
def apply_clahe(gray_u8: np.ndarray) -> np.ndarray:
    """Boost local contrast with CLAHE (clip limit 2.0, 8×8 tile grid)."""
    return cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray_u8)
def phone_preprocess(gray_u8: np.ndarray) -> np.ndarray:
    """Cleanup pipeline for phone/WhatsApp images.

    Steps (both optional, judged on the ORIGINAL image):
    1. If the border looks framed (>0.35), auto-crop — accepted only when
       the crop keeps at least 70% of the pixels.
    2. If dark-clipped or soft, apply CLAHE contrast enhancement.
    """
    sharp = laplacian_sharpness(gray_u8)
    lo_clip, _ = exposure_scores(gray_u8)
    framed = border_fraction(gray_u8) > 0.35

    out = gray_u8
    if framed:
        cropped = auto_border_crop(out)
        if cropped.size >= 0.70 * out.size:
            out = cropped
    if lo_clip > 0.10 or sharp < 80:
        out = apply_clahe(out)
    return out
def cam_entropy(cam: np.ndarray) -> float:
    """Shannon entropy (nats) of a CAM treated as a probability map.

    Higher entropy means attention is spread over the whole map rather
    than focused. Epsilons guard against empty / all-zero maps.
    """
    p = cam.astype(np.float32)
    p = p / (p.sum() + 1e-8)
    return float(-(p * np.log(p + 1e-8)).sum())
def detect_diffuse_risk(prob_tb: float, cam_up: np.ndarray, quality_score: float) -> bool:
    """Flag a diffuse (non-focal) attention pattern worth treating as INDET.

    Returns True only when the image quality is decent (>=55), the TB
    probability is non-trivial (>=0.05) but below the screening threshold,
    and the CAM entropy exceeds 6.5 nats.
    """
    if quality_score < 55 or prob_tb < 0.05:
        return False
    return prob_tb < TBNET_SCREEN_THR and cam_entropy(cam_up) > 6.5
def confidence_band(prob_tb: float, quality_score: float, diffuse: bool):
    """Map (probability, quality, diffuse flag) to a (band, message) pair.

    Only "GREEN" or "YELLOW" is produced here; escalation to "RED" happens
    in the caller. Rules are evaluated in priority order.
    """
    rules = (
        (prob_tb < 0.01 and quality_score >= 45,
         ("GREEN", "✅ Very low TB signal detected.")),
        (quality_score < 55,
         ("YELLOW", "⚠️ Image quality is low; treat as indeterminate.")),
        (diffuse,
         ("YELLOW", "⚠️ Attention is non-focal; treat as indeterminate.")),
        (prob_tb >= TBNET_SCREEN_THR,
         ("YELLOW", "⚠️ Screening-positive range; review recommended.")),
    )
    for hit, outcome in rules:
        if hit:
            return outcome
    return ("GREEN", "✅ No strong TB signal detected.")
def make_mask_overlay(gray_u8: np.ndarray, mask_u8: np.ndarray) -> np.ndarray:
    """Blend a JET-colored 0/1 lung mask over the grayscale image.

    Weights: 75% image, 25% colored mask. Returns an RGB uint8 array.
    """
    base_rgb = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB)
    colored_mask = cv2.applyColorMap((mask_u8 * 255).astype(np.uint8), cv2.COLORMAP_JET)
    return cv2.addWeighted(base_rgb, 0.75, colored_mask, 0.25, 0)
def fill_holes(binary_u8: np.ndarray) -> np.ndarray:
    """Fill enclosed holes in a 0/1 mask via flood fill from the corner.

    Returns a 0/1 uint8 mask. NOTE(review): the flood fill is seeded at
    (0, 0), which assumes the top-left pixel is background; a mask that
    touches that corner would be mishandled — confirm upstream masks can't.
    """
    m = (binary_u8 * 255).astype(np.uint8)
    h, w = m.shape
    flood = m.copy()
    # floodFill requires a scratch mask 2 px larger than the image.
    mask = np.zeros((h + 2, w + 2), np.uint8)
    cv2.floodFill(flood, mask, (0, 0), 255)
    # Pixels NOT reachable from outside are the enclosed holes.
    holes = cv2.bitwise_not(flood)
    filled = cv2.bitwise_or(m, holes)
    return (filled > 0).astype(np.uint8)
def keep_top_k_components(binary_u8: np.ndarray, k: int = 2) -> np.ndarray:
    """Keep only the k largest connected components of a binary mask.

    Returns a 0/1 uint8 mask. Ties are broken by label order (stable sort),
    matching cv2.connectedComponents label numbering.
    """
    m = (binary_u8 > 0).astype(np.uint8)
    n_labels, labels = cv2.connectedComponents(m)
    if n_labels <= 1:  # background only — nothing to filter
        return m
    sizes = [(i, int((labels == i).sum())) for i in range(1, n_labels)]
    sizes.sort(key=lambda item: item[1], reverse=True)
    keep = {i for i, _ in sizes[:k]}
    out = np.zeros_like(m)
    for i in keep:
        out[labels == i] = 1
    return out
def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
    """Heuristic checks that a lung mask looks like two plausible lungs.

    Returns a list of warning strings (empty means the mask looks sane).
    Checks: single/zero component, one component dominating >80% of the
    mask area, the mask touching the image border, and fragmentation
    (top-two components holding <90% of the area).
    """
    m = (mask_full_u8 > 0).astype(np.uint8)
    n, labels = cv2.connectedComponents(m)
    warns = []
    # n includes the background label, so n <= 2 means 0 or 1 lung regions.
    if n <= 2:
        warns.append("Only one lung region detected (possible crop/segmentation failure).")
        return warns
    areas = []
    for i in range(1, n):
        areas.append(int((labels == i).sum()))
    areas.sort(reverse=True)
    total = int(m.sum())
    top1 = areas[0]
    top2 = areas[1] if len(areas) > 1 else 0
    if total > 0 and top1 / total > 0.80:
        warns.append("Mask dominated by a single region (possible cropped/partial lung view).")
    # One-pixel frame around the image; mask pixels there suggest cropping.
    border = np.concatenate([m[0, :], m[-1, :], m[:, 0], m[:, -1]])
    if border.mean() > 0.05:
        warns.append("Lung mask touches image border (possible cropped/non-standard CXR).")
    # If the two biggest regions hold <90% of the area, the rest is noise.
    if total > 0 and (top1 + top2) / total < 0.90:
        warns.append("Mask appears fragmented (may reduce reliability).")
    return warns
def recommendation_for_band(band: Optional[str]) -> str:
    """Next-step recommendation text for a TBNet band (None == unknown)."""
    if band == "RED":
        return "✅ Recommendation: **Urgent** clinician/radiologist review + microbiological confirmation (CBNAAT/GeneXpert, sputum)."
    if band in (None, "YELLOW"):
        return "✅ Recommendation: Radiologist/clinician review is recommended (**indeterminate**)."
    # GREEN (or any other band): soft advisory only.
    return "✅ Recommendation: If symptoms/risk factors exist, clinician/radiologist correlation is advised."
# ============================================================
# CONSENSUS LOGIC (TBNet vs RADIO)
# ============================================================
def tbnet_state(tb_prob: float, tb_band: str) -> str:
    """Collapse a TBNet band + probability into LOW/INDET/SCREEN+/TB+.

    Band takes precedence (RED→TB+, YELLOW→INDET); otherwise the
    screening threshold splits SCREEN+ from LOW.
    """
    band_states = {"RED": "TB+", "YELLOW": "INDET"}
    if tb_band in band_states:
        return band_states[tb_band]
    return "SCREEN+" if tb_prob >= TBNET_SCREEN_THR else "LOW"
def radio_state_from_prob(radio_prob: float) -> str:
    """Bucket a RADIO probability into TB+/SCREEN+/LOW (never INDET)."""
    if radio_prob >= RADIO_THR_RED:
        return "TB+"
    return "SCREEN+" if radio_prob >= RADIO_THR_SCREEN else "LOW"
def build_consensus(
    tb_prob: Optional[float],
    tb_band: Optional[str],
    radio_raw: Optional[float],
    radio_masked: Optional[float],
    radio_band: Optional[str] = None
) -> Tuple[str, str, str, str]:
    """Combine TBNet and RADIO outputs into a single consensus verdict.

    Preference order for the RADIO score: masked head if available,
    otherwise the raw head. Either model being unavailable short-circuits
    to a degraded label ("N/A" / "TBNet only").

    NOTE(review): radio_state_from_prob never returns "INDET", so the
    radio_state == "INDET" branches below are currently unreachable —
    presumably kept for a future RADIO indeterminate band; confirm.

    Returns:
        consensus_label, consensus_detail, tb_state, radio_state
    """
    # Fail-safe path: TBNet was disabled (e.g. lung segmentation failed).
    if tb_prob is None or tb_band is None:
        return (
            "N/A",
            f"{MODEL_NAME_TBNET} unavailable (lung segmentation failed / fail-safe).",
            "N/A",
            "N/A",
        )
    # Prefer the lung-masked RADIO score when it was computed.
    if radio_masked is not None:
        radio_primary = radio_masked
        radio_used = "MASKED"
    else:
        radio_primary = radio_raw
        radio_used = "RAW"
    tb_state = tbnet_state(tb_prob, tb_band)
    # RADIO entirely unavailable → report TBNet's verdict alone.
    if radio_primary is None:
        return (
            "TBNet only",
            f"{MODEL_NAME_RADIO} unavailable → TBNet state={tb_state}, p={tb_prob:.4f} (band={tb_band}).",
            tb_state,
            "N/A",
        )
    radio_state = radio_state_from_prob(radio_primary)
    rb = f" (RADIO band={radio_band})" if radio_band else ""
    if tb_state == "INDET" and radio_state == "INDET":
        return (
            "AGREE: INDET",
            f"Both models are indeterminate. TBNet p={tb_prob:.4f} (band={tb_band}), RADIO({radio_used})={radio_primary:.4f}{rb}.",
            tb_state,
            radio_state,
        )
    if tb_state == "INDET" and radio_state in ("LOW", "SCREEN+", "TB+"):
        return (
            "MIXED/INDET",
            f"TBNet is indeterminate (band={tb_band}); RADIO suggests {radio_state} ({radio_used})={radio_primary:.4f}{rb}.",
            tb_state,
            radio_state,
        )
    if radio_state == "INDET" and tb_state in ("LOW", "SCREEN+", "TB+"):
        return (
            "MIXED/INDET",
            f"RADIO is indeterminate; TBNet suggests {tb_state} (band={tb_band}) p={tb_prob:.4f}.",
            tb_state,
            radio_state,
        )
    # Full agreement on a concrete state.
    if tb_state == radio_state:
        return (
            f"AGREE: {tb_state}",
            f"Both models agree: {tb_state}. TBNet p={tb_prob:.4f}, RADIO({radio_used})={radio_primary:.4f}{rb}.",
            tb_state,
            radio_state,
        )
    # One model positive-leaning while the other says LOW → hard disagreement.
    if (tb_state in ("SCREEN+", "TB+") and radio_state == "LOW") or (radio_state in ("SCREEN+", "TB+") and tb_state == "LOW"):
        return (
            "DISAGREE",
            f"Strong disagreement: TBNet={tb_state} (band={tb_band}, p={tb_prob:.4f}) vs RADIO={radio_state} ({radio_used})={radio_primary:.4f}{rb}.",
            tb_state,
            radio_state,
        )
    # Remaining combinations (e.g. SCREEN+ vs TB+) are merely "mixed".
    return (
        "MIXED",
        f"Mixed: TBNet={tb_state} (band={tb_band}, p={tb_prob:.4f}) vs RADIO={radio_state} ({radio_used})={radio_primary:.4f}{rb}.",
        tb_state,
        radio_state,
    )
# ============================================================
# TB + LUNG MODEL BUNDLE (cached)
# ============================================================
class ModelBundle:
    """Lazy, cached holder for TBNet (+ Grad-CAM) and the lung U-Net.

    load() only rebuilds a model when its weights path (or the TB backbone)
    differs from what is already cached, so repeated Gradio callbacks reuse
    the loaded models.
    """

    def __init__(self):
        self.device = DEVICE
        self.tb = None        # TBNet instance (None until load())
        self.cammer = None    # GradCAM bound to the TBNet conv head
        self.lung = None      # LungUNet instance
        self.tb_path = None   # path of currently-loaded TB weights
        self.lung_path = None # path of currently-loaded lung weights
        self.backbone = "efficientnet_b0"
        # ImageNet-statistics normalization for the 3-channel TB input.
        self.tfm = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]
            ),
        ])

    def load(self, tb_weights: str, lung_weights: str, backbone: str = "efficientnet_b0"):
        """(Re)load TBNet and/or the lung U-Net if the cache is stale."""
        if (self.tb_path != tb_weights) or (self.tb is None) or (self.cammer is None) or (self.backbone != backbone):
            tb = TBNet(backbone=backbone).to(self.device)
            load_tb_weights(tb, tb_weights, self.device)
            tb.eval()
            self.tb = tb
            # Grad-CAM hooks the EfficientNet conv head (last conv layer).
            self.cammer = GradCAM(tb, tb.backbone.conv_head)
            self.tb_path = tb_weights
            self.backbone = backbone
        if (self.lung_path != lung_weights) or (self.lung is None):
            lung = LungUNet().to(self.device)
            lung.load_state_dict(torch.load(lung_weights, map_location=self.device))
            lung.eval()
            self.lung = lung
            self.lung_path = lung_weights


# Process-wide singleton used by the Gradio callbacks.
BUNDLE = ModelBundle()
# ============================================================
# RADIO BUNDLE (cached)
# ============================================================
class RadioMLPHead(nn.Module):
    """MLP head mapping a RADIO summary embedding to a single TB logit.

    Layout: LayerNorm → Dropout → Linear(dim→hidden) → GELU → Dropout →
    Linear(hidden→1). The Sequential layout is part of the checkpoint's
    state_dict contract (net.0, net.2, net.5) and must not change.
    """

    def __init__(self, dim: int, hidden: int = 512, dropout: float = 0.2):
        super().__init__()
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Dropout(dropout),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """(batch, dim) → (batch,) raw logits."""
        logits = self.net(x)
        return logits.squeeze(1)
class RadioBundle:
    """Lazy, cached holder for the RADIO foundation model and its two heads.

    load() is idempotent per device string: subsequent calls on the same
    device are no-ops. The raw head scores unmasked CXRs; the masked head
    scores lung-masked CXRs.
    """

    def __init__(self):
        self.loaded = False
        self.processor = None    # CLIPImageProcessor for RADIO inputs
        self.radio = None        # the RADIO backbone (AutoModel)
        self.raw_head = None     # RadioMLPHead for raw images
        self.masked_head = None  # RadioMLPHead for lung-masked images
        self.summary_dim = None  # embedding dim probed via a dummy forward
        self.device_str = None   # device the bundle was loaded on

    def load(self, device: torch.device):
        """Load processor, backbone and both heads onto `device` (cached)."""
        dev_str = str(device)
        if self.loaded and self.device_str == dev_str:
            return
        # Fail fast if the classifier head checkpoints are missing.
        if not os.path.exists(RADIO_RAW_HEAD_PATH):
            raise FileNotFoundError(f"RADIO raw head not found: {RADIO_RAW_HEAD_PATH}")
        if not os.path.exists(RADIO_MASKED_HEAD_PATH):
            raise FileNotFoundError(f"RADIO masked head not found: {RADIO_MASKED_HEAD_PATH}")
        self.processor = CLIPImageProcessor.from_pretrained(RADIO_HF_REPO, revision=RADIO_REVISION)
        # fp16 on GPU, fp32 on CPU (CPU fp16 matmul is poorly supported).
        dtype = torch.float16 if device.type == "cuda" else torch.float32
        # NOTE(review): `dtype=` is the newer transformers kwarg (older
        # releases use `torch_dtype=`) — confirm against the pinned version.
        self.radio = AutoModel.from_pretrained(
            RADIO_HF_REPO,
            revision=RADIO_REVISION,
            trust_remote_code=True,
            dtype=dtype,
        ).eval().to(device)
        # Probe the summary-embedding width with a dummy forward pass.
        with torch.no_grad():
            dummy = torch.zeros((1, 3, RADIO_IMG_SIZE, RADIO_IMG_SIZE), device=device, dtype=dtype)
            summary, _ = self.radio(dummy)
            self.summary_dim = int(summary.shape[-1])

        def _load_head(path: str) -> nn.Module:
            # Head checkpoints store {"dim": ..., "head_state": ...};
            # fall back to the probed summary_dim when "dim" is absent.
            ckpt = torch.load(path, map_location="cpu")
            dim = int(ckpt.get("dim", self.summary_dim))
            head = RadioMLPHead(dim=dim).to(device).eval()
            head.load_state_dict(ckpt["head_state"], strict=True)
            return head

        self.raw_head = _load_head(RADIO_RAW_HEAD_PATH)
        self.masked_head = _load_head(RADIO_MASKED_HEAD_PATH)
        self.device_str = dev_str
        self.loaded = True


# Process-wide singleton used by radio_predict_from_arrays().
RADIO_BUNDLE = RadioBundle()
def radio_heatmap_from_spatial(spatial_tokens: torch.Tensor, in_h: int, in_w: int, patch_size: int = 16) -> np.ndarray:
    """Turn RADIO spatial tokens into a per-pixel [0, 1] energy heatmap.

    Token L2 norms are laid out on the patch grid, min-max normalized, and
    bilinearly resized back to (in_h, in_w).
    """
    grid_h = in_h // patch_size
    grid_w = in_w // patch_size
    b, _, d = spatial_tokens.shape
    # Equivalent to einops rearrange "b (h w) d -> b d h w".
    feat = spatial_tokens.reshape(b, grid_h, grid_w, d).permute(0, 3, 1, 2)
    energy = torch.sqrt(torch.clamp((feat ** 2).sum(dim=1), min=1e-8))[0]
    energy = (energy - energy.min()) / (energy.max() - energy.min() + 1e-8)
    hm = energy.detach().float().cpu().numpy().astype(np.float32)
    hm_img = Image.fromarray((hm * 255).astype(np.uint8)).resize((in_w, in_h), resample=Image.BILINEAR)
    return np.array(hm_img, dtype=np.float32) / 255.0
def radio_overlay_heatmap(rgb_u8: np.ndarray, heatmap01: np.ndarray, alpha: float = 0.35) -> np.ndarray:
    """Tint the RED channel of an RGB image by a [0, 1] heatmap.

    Only channel 0 is blended (r*(1-alpha) + heat*alpha); green and blue
    pass through untouched. Returns a uint8 array of the same shape.
    """
    base = rgb_u8.astype(np.float32) / 255.0
    heat = np.clip(heatmap01, 0, 1).astype(np.float32)
    blended = base.copy()
    blended[..., 0] = np.clip(blended[..., 0] * (1 - alpha) + heat * alpha, 0, 1)
    return (blended * 255).astype(np.uint8)
@torch.inference_mode()
def radio_predict_from_arrays(
    gray_vis_u8: np.ndarray,
    lung_mask_u8: np.ndarray,
    coverage: float,
    device: torch.device,
    gate_threshold: float
) -> Dict[str, Any]:
    """Score a CXR with the RADIO backbone + raw/masked classifier heads.

    Always runs the raw head on the unmasked image; additionally runs the
    masked head when a lung mask is available and its coverage clears both
    RADIO_MASKED_MIN_COV and `gate_threshold`. The masked probability, when
    computed, becomes the primary score used for banding.

    Returns a dict with raw/primary/masked probabilities, band ("GREEN"/
    "YELLOW"/"RED"), overlay images, and whether the masked branch ran.
    """
    RADIO_BUNDLE.load(device=device)
    # Match the dtype the backbone was loaded with (fp16 on CUDA).
    dtype = torch.float16 if device.type == "cuda" else torch.float32
    raw_rgb = cv2.cvtColor(gray_vis_u8, cv2.COLOR_GRAY2RGB)
    px = RADIO_BUNDLE.processor(
        images=Image.fromarray(raw_rgb),
        return_tensors="pt",
        do_resize=True,
        size={"shortest_edge": RADIO_IMG_SIZE},
        do_center_crop=True,
    ).pixel_values.to(device).to(dtype)
    # Raw branch: backbone summary embedding → raw head → probability.
    summary, spatial = RADIO_BUNDLE.radio(px)
    logit_raw = RADIO_BUNDLE.raw_head(summary)
    prob_raw = float(torch.sigmoid(logit_raw)[0].item())
    hm_raw = radio_heatmap_from_spatial(spatial, px.shape[-2], px.shape[-1], RADIO_PATCH_SIZE)
    raw_overlay = radio_overlay_heatmap(
        cv2.resize(raw_rgb, (px.shape[-1], px.shape[-2])),
        hm_raw,
        alpha=0.35
    )
    masked_prob = None
    masked_overlay = None
    masked_ran = False
    # Masked branch: only when the lung mask covers enough of the frame.
    if lung_mask_u8 is not None and coverage >= RADIO_MASKED_MIN_COV and coverage >= gate_threshold:
        masked_ran = True
        masked_u8 = (gray_vis_u8 * lung_mask_u8).astype(np.uint8)
        masked_rgb = cv2.cvtColor(masked_u8, cv2.COLOR_GRAY2RGB)
        pxm = RADIO_BUNDLE.processor(
            images=Image.fromarray(masked_rgb),
            return_tensors="pt",
            do_resize=True,
            size={"shortest_edge": RADIO_IMG_SIZE},
            do_center_crop=True,
        ).pixel_values.to(device).to(dtype)
        summary_m, spatial_m = RADIO_BUNDLE.radio(pxm)
        logit_m = RADIO_BUNDLE.masked_head(summary_m)
        masked_prob = float(torch.sigmoid(logit_m)[0].item())
        hm_m = radio_heatmap_from_spatial(spatial_m, pxm.shape[-2], pxm.shape[-1], RADIO_PATCH_SIZE)
        masked_overlay = radio_overlay_heatmap(
            cv2.resize(masked_rgb, (pxm.shape[-1], pxm.shape[-2])),
            hm_m,
            alpha=0.35
        )
    # Masked score, when available, outranks the raw score for banding.
    prob_primary = masked_prob if masked_prob is not None else prob_raw
    if prob_primary >= RADIO_THR_RED:
        band = "RED"
        pred = "HIGH TB-LIKE PATTERN SCORE (RADIO)"
    elif prob_primary >= RADIO_THR_SCREEN:
        band = "YELLOW"
        pred = "SCREEN-POSITIVE RANGE (RADIO)"
    else:
        band = "GREEN"
        pred = "LOW TB-LIKE PATTERN SCORE (RADIO)"
    return {
        "prob_raw": prob_raw,
        "prob_primary": prob_primary,
        "pred": pred,
        "band": band,
        "raw_overlay": raw_overlay,
        "masked_prob": masked_prob,
        "masked_overlay": masked_overlay,
        "masked_ran": masked_ran,
        "gate_threshold": float(gate_threshold),
    }
# ============================================================
# TB CORE ANALYSIS
# ============================================================
def analyze_one_image(
    img_bgr: np.ndarray,
    tb_weights: str,
    lung_weights: str,
    backbone: str,
    threshold: float,
    phone_mode: bool,
    img_size: int = 224,
    fail_cov: float = FAIL_COV,
    warn_cov: float = WARN_COV,
) -> Dict[str, Any]:
    """Full single-image TB pipeline: quality check, lung segmentation,
    fail-safes, TBNet scoring and Grad-CAM overlay.

    Returns a result dict; on fail-safe paths "prob"/"logit" are None and
    band is "YELLOW" with TB scoring disabled.

    NOTE(review): `threshold` is accepted but never read in this body —
    presumably superseded by the band logic; confirm before removing.
    """
    BUNDLE.load(tb_weights, lung_weights, backbone)
    device = BUNDLE.device
    # Accept BGR or already-grayscale input.
    gray = img_bgr if img_bgr.ndim == 2 else cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    # Quality is always judged on the ORIGINAL image, pre-enhancement.
    q_score, q_warn = phone_quality_report(gray)
    gray_vis = phone_preprocess(gray) if phone_mode else gray
    if gray_vis.dtype != np.uint8:
        gray_vis = np.clip(gray_vis, 0, 255).astype(np.uint8)
    # --- Lung segmentation at 256x256 ---
    with torch.no_grad():
        x_lung = preprocess_for_lung_unet(gray_vis).to(device)
        mask_logits = BUNDLE.lung(x_lung)
        mask256 = torch.sigmoid(mask_logits)[0, 0].cpu().numpy()
    # Post-process: binarize, keep 2 largest components, close, fill holes.
    mask256_bin = (mask256 > 0.5).astype(np.uint8)
    mask256_bin = keep_top_k_components(mask256_bin, k=2)
    k = max(3, int(0.02 * 256))
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
    mask256_bin = cv2.morphologyEx(mask256_bin, cv2.MORPH_CLOSE, kernel, iterations=1)
    mask256_bin = fill_holes(mask256_bin)
    coverage = float(mask256_bin.mean())
    # Nearest-neighbor upscale keeps the mask binary at full resolution.
    mask_full = cv2.resize(mask256_bin, (gray_vis.shape[1], gray_vis.shape[0]), interpolation=cv2.INTER_NEAREST)
    # --- Fail-safe 1: segmentation found almost no lung ---
    if coverage < fail_cov:
        overlay_rgb = cv2.cvtColor(cv2.resize(gray_vis, (img_size, img_size)), cv2.COLOR_GRAY2RGB)
        return {
            "prob": None,
            "logit": None,
            "pred": "INDETERMINATE",
            "band": "YELLOW",
            "band_text": "Lung segmentation failed. TB scoring disabled (fail-safe).",
            "quality_score": float(q_score),
            "diffuse_risk": False,
            "warnings": (
                ["Lung segmentation failed (<10% lung area).", f"Lung coverage: {coverage*100:.1f}%"]
                + (["Phone/WhatsApp mode enabled; artifacts possible."] if phone_mode else [])
                + q_warn
            ),
            "lung_coverage": coverage,
            "orig_gray": gray,
            "vis_gray": gray_vis,
            "masked_gray": None,
            "proc_gray": None,
            "lung_mask": mask_full,
            "mask_overlay": make_mask_overlay(gray_vis, mask_full),
            "overlay": overlay_rgb,
            "overlay_clean": overlay_rgb,
        }
    # --- Fail-safe 2: mask exists but looks implausible ---
    sanity = mask_sanity_warnings(mask_full.astype(np.uint8))
    if FAILSAFE_ON_BAD_MASK and sanity:
        overlay_rgb = cv2.cvtColor(cv2.resize(gray_vis, (img_size, img_size)), cv2.COLOR_GRAY2RGB)
        return {
            "prob": None,
            "logit": None,
            "pred": "INDETERMINATE",
            "band": "YELLOW",
            "band_text": "Non-standard/cropped view or unreliable lung segmentation. TB scoring disabled (fail-safe).",
            "quality_score": float(q_score),
            "diffuse_risk": False,
            "warnings": (
                sanity
                + [f"Lung coverage: {coverage*100:.1f}%"]
                + (["Phone/WhatsApp mode enabled; artifacts possible."] if phone_mode else [])
                + q_warn
            ),
            "lung_coverage": coverage,
            "orig_gray": gray,
            "vis_gray": gray_vis,
            "masked_gray": None,
            "proc_gray": None,
            "lung_mask": mask_full,
            "mask_overlay": make_mask_overlay(gray_vis, mask_full),
            "overlay": overlay_rgb,
            "overlay_clean": overlay_rgb,
        }
    # --- TBNet scoring on the lung-masked, training-normalized image ---
    masked = (gray_vis * mask_full).astype(np.uint8)
    masked_f01 = tb_training_preprocess(masked)
    masked_u8 = (masked_f01 * 255).astype(np.uint8)
    masked_u8_rs = cv2.resize(masked_u8, (img_size, img_size), interpolation=cv2.INTER_AREA)
    rgb = cv2.cvtColor(masked_u8_rs, cv2.COLOR_GRAY2RGB)
    x = BUNDLE.tfm(rgb).unsqueeze(0).to(device)
    cam, prob_tb, logit = BUNDLE.cammer.generate(x)
    # Upscale the CAM to display resolution (as a [0,1] float map).
    cam_u8 = (np.clip(cam, 0, 1) * 255).astype(np.uint8)
    cam_u8 = cv2.resize(cam_u8, (img_size, img_size), interpolation=cv2.INTER_CUBIC)
    cam_up = cam_u8.astype(np.float32) / 255.0
    diffuse = detect_diffuse_risk(prob_tb, cam_up, q_score)
    band_base, _ = confidence_band(prob_tb, q_score, diffuse)
    # RED escalation: very high prob + decent quality + focal CAM + coverage.
    allow_red = (prob_tb >= 0.70 and q_score >= 55 and not diffuse and coverage >= warn_cov)
    band = "RED" if allow_red else band_base
    pred = REPORT_LABELS[band]["title"]
    band_text = REPORT_LABELS[band]["summary"]
    # --- Build annotated + clean Grad-CAM overlays ---
    heat = cv2.applyColorMap((cam_up * 255).astype(np.uint8), cv2.COLORMAP_JET)
    overlay_clean = cv2.addWeighted(rgb, 0.65, heat, 0.35, 0)
    overlay_annotated = overlay_clean.copy()
    text1 = f"{band}: {pred}"
    text2 = f"TB p={prob_tb:.3f} | Quality={q_score:.0f}/100 | LungCov={coverage*100:.1f}%"
    # White text with a thin black re-stroke for legibility on any background.
    cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (255, 255, 255), 2)
    cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (0, 0, 0), 1)
    cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (255, 255, 255), 2)
    cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (0, 0, 0), 1)
    warnings = []
    if phone_mode:
        warnings.append("Phone/WhatsApp mode enabled; artifacts possible.")
    if q_score < 55:
        warnings.append("Suboptimal image quality limits AI reliability.")
    if coverage < warn_cov:
        warnings.append(f"Partial lung segmentation ({coverage*100:.1f}% coverage).")
    if diffuse:
        warnings.append("Diffuse, non-focal AI attention pattern; TB-specific features not identified.")
    warnings.extend(q_warn)
    return {
        "prob": float(prob_tb),
        "logit": float(logit),
        "pred": pred,
        "band": band,
        "band_text": band_text,
        "quality_score": float(q_score),
        "diffuse_risk": bool(diffuse),
        "warnings": warnings,
        "lung_coverage": coverage,
        "orig_gray": gray,
        "vis_gray": gray_vis,
        "masked_gray": masked,
        "proc_gray": masked_u8_rs,
        "lung_mask": mask_full,
        "mask_overlay": make_mask_overlay(gray_vis, mask_full),
        "overlay": overlay_annotated,
        "overlay_clean": overlay_clean,
    }
# ============================================================
# GRADIO CALLBACK
# ============================================================
def run_analysis(
    files: List[gr.File],
    tb_weights: str,
    lung_weights: str,
    backbone: str,
    threshold: float,
    phone_mode: bool,
    use_radio: bool,
    radio_gate: float,
):
    """Gradio callback: analyze each uploaded X-ray with TBNet and (optionally) RADIO.

    For every readable image this builds three markdown "cards" (TBNet result,
    RADIO result, final consensus), gallery tiles (original / input / lung mask /
    Grad-CAM / RADIO heatmaps) and a per-image detailed report.

    Returns a 5-tuple matching the outputs wired in build_ui():
    (summary markdown, gallery items, details markdown, disclaimer text, status).
    """
    # Guard clauses: empty upload or missing weight files abort early with the
    # same human-readable message in both the summary and the status box.
    if not files:
        return "Please upload at least one image.", [], "", CLINICAL_DISCLAIMER, "Please upload at least one image."
    if not os.path.exists(tb_weights):
        msg = f"TB weights not found: {tb_weights}"
        return msg, [], "", CLINICAL_DISCLAIMER, msg
    if not os.path.exists(lung_weights):
        msg = f"Lung U-Net weights not found: {lung_weights}"
        return msg, [], "", CLINICAL_DISCLAIMER, msg
    summary_md: List[str] = []  # per-image markdown cards (joined at return)
    gallery_items = []  # (rgb ndarray, caption) pairs for gr.Gallery
    details_md: List[str] = []  # per-image detailed report sections
    # Top banner
    summary_md.append(f"""
Results
Each image shows: {MODEL_NAME_TBNET} result, {MODEL_NAME_RADIO} result (optional), then a final consensus.
Key: ✅ LOW | ⚠️ INDETERMINATE | ⚠️ SCREEN+ | 🚩 TB+
""")
    for f in files:
        # gr.Files may yield tempfile wrappers (with .name) or plain path strings.
        path = f.name if hasattr(f, "name") else str(f)
        name = os.path.basename(path)
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            # Unreadable/corrupt file: report it and keep processing the rest.
            summary_md.append(f"""
{html_escape(name)}
⚠️ Could not read this image. Please re-export it as PNG/JPG.
""")
            continue
        # TBNet pipeline (lung mask -> quality -> classification -> Grad-CAM).
        out = analyze_one_image(
            img_bgr=img,
            tb_weights=tb_weights,
            lung_weights=lung_weights,
            backbone=backbone,
            threshold=threshold,
            phone_mode=phone_mode,
            img_size=224,
        )
        # RADIO (optional)
        # Defaults used when RADIO is disabled, skipped, or errors out.
        radio_text_long = f"{MODEL_NAME_RADIO} disabled."
        radio_raw_overlay = None
        radio_masked_overlay = None
        radio_raw_val: Optional[float] = None
        radio_masked_val: Optional[float] = None
        radio_primary_val: Optional[float] = None
        radio_band: Optional[str] = None
        radio_result_short = "Disabled"
        radio_masked_ran = False
        # RADIO only runs when enabled AND TBNet did not trip its fail-safe
        # (out["prob"] is None when TBNet scoring was disabled upstream).
        if use_radio and out["prob"] is not None:
            try:
                r = radio_predict_from_arrays(
                    gray_vis_u8=out["vis_gray"],
                    lung_mask_u8=out["lung_mask"].astype(np.uint8),
                    coverage=float(out["lung_coverage"]),
                    device=BUNDLE.device,
                    gate_threshold=float(radio_gate),
                )
                radio_raw_val = float(r["prob_raw"])
                radio_primary_val = float(r["prob_primary"])
                # Masked head may not have run (lung coverage below the gate).
                radio_masked_val = None if r["masked_prob"] is None else float(r["masked_prob"])
                radio_band = str(r["band"])
                radio_result_short = str(r["pred"])
                radio_masked_ran = bool(r["masked_ran"])
                radio_text_long = (
                    f"**{MODEL_NAME_RADIO}:** {r['pred']} \n"
                    f"- PRIMARY={radio_primary_val:.4f} \n"
                    f"- RAW={radio_raw_val:.4f} \n"
                    + (f"- MASKED={radio_masked_val:.4f} \n" if radio_masked_val is not None else "- MASKED=Not run \n")
                    + (f"- Band={radio_band} \n" if radio_band else "")
                    + (f"- Masked gate={float(radio_gate):.2f} | LungCov={float(out['lung_coverage']):.2f} | Masked ran={radio_masked_ran}\n")
                )
                radio_raw_overlay = r["raw_overlay"]
                radio_masked_overlay = r["masked_overlay"]
            except Exception as e:
                # A RADIO failure must not kill the whole batch: surface the
                # error text and reset every RADIO output so the consensus
                # treats RADIO as absent for this image.
                radio_text_long = f"{MODEL_NAME_RADIO} error: {type(e).__name__}: {e}"
                radio_result_short = "Error"
                radio_raw_val = None
                radio_masked_val = None
                radio_primary_val = None
                radio_band = None
                radio_masked_ran = False
        # Consensus
        consensus_label, consensus_detail, tb_state, radio_state = build_consensus(
            tb_prob=out["prob"],
            tb_band=out["band"],
            radio_raw=radio_raw_val,
            radio_masked=radio_masked_val,
            radio_band=radio_band,
        )
        # Human-readable fragments reused by the cards and detailed report.
        tb_prob_line = "N/A (fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
        tb_label = out.get("pred", "INDETERMINATE")
        q = float(out.get("quality_score", 0.0))
        cov = float(out.get("lung_coverage", 0.0))
        attention = "Diffuse / non-focal" if out.get("diffuse_risk", False) else "Focal / localized"
        warns = out.get("warnings", [])
        top_warns = warns[:3] if warns else []  # show at most the first three warnings on the card
        top_warn_line = " • ".join([html_escape(w) for w in top_warns]) if top_warns else "None"
        radio_primary_line = "N/A" if radio_primary_val is None else f"{radio_primary_val:.4f}"
        radio_raw_line = "N/A" if radio_raw_val is None else f"{radio_raw_val:.4f}"
        radio_masked_line = "Not run" if radio_masked_val is None else f"{radio_masked_val:.4f}"
        # Map the consensus label to a suggested next step for the reader.
        if consensus_label == "DISAGREE":
            next_step = "✅ Next step: Treat as indeterminate → radiologist review + microbiology if clinically suspected."
        elif consensus_label in ("AGREE: TB+", "AGREE: SCREEN+"):
            next_step = "✅ Next step: Prompt clinician/radiologist review; consider microbiological confirmation if clinically suspected."
        elif consensus_label in ("AGREE: LOW",):
            next_step = "✅ Next step: If symptoms/risk factors exist, still correlate clinically and consider further testing."
        else:
            next_step = "✅ Next step: Correlate clinically; radiologist review recommended if uncertainty or symptoms present."
        state_badge_tb = f"""
{pretty_state(tb_state)}
"""
        state_badge_radio = f"""
{pretty_state(radio_state)}
"""
        # Card 1: TBNet alone.
        tb_card = f"""
{html_escape(name)} — {MODEL_NAME_TBNET}
State: {state_badge_tb}
Result: {html_escape(tb_label)} (p={tb_prob_line})
Reliability: Quality {q:.0f}/100 | Lung mask {cov*100:.1f}% | Attention {attention}
Top notes: {top_warn_line}
{recommendation_for_band(out.get("band"))}
"""
        # Card 2: RADIO alone (or a disabled-notice placeholder).
        if not use_radio:
            radio_card = f"""
{html_escape(name)} — {MODEL_NAME_RADIO}
RADIO is disabled. Enable it to view its independent output and heatmaps.
"""
        else:
            gate_info = f"Masked gate={float(radio_gate):.2f} | LungCov={cov:.2f} | Masked ran={radio_masked_ran}"
            radio_card = f"""
{html_escape(name)} — {MODEL_NAME_RADIO}
State: {state_badge_radio}
Result: {html_escape(radio_result_short)}
Scores: PRIMARY {radio_primary_line} | RAW {radio_raw_line} | MASKED {radio_masked_line}
{html_escape(gate_info)}
"""
        # Card 3: final consensus (comparison + next step).
        consensus_card = f"""
{html_escape(name)} — Final consensus
Comparison: TBNet {pretty_state(tb_state)} vs RADIO {pretty_state(radio_state)}
Consensus: {html_escape(consensus_label)}
{html_escape(consensus_detail)}
{next_step}
"""
        summary_md.append(tb_card)
        summary_md.append(radio_card)
        summary_md.append(consensus_card)
        # Gallery
        # All tiles are resized to 512x512; grayscale arrays converted to RGB.
        orig_rgb = cv2.cvtColor(cv2.resize(out["orig_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
        vis_rgb = cv2.cvtColor(cv2.resize(out["vis_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
        mask_overlay = cv2.resize(out["mask_overlay"], (512, 512))
        overlay_big = cv2.resize(out["overlay"], (512, 512))
        gallery_items.append((orig_rgb, f"{name} • ORIGINAL"))
        gallery_items.append((vis_rgb, f"{name} • PHONE-PROC" if phone_mode else f"{name} • INPUT"))
        gallery_items.append((mask_overlay, f"{name} • Lung mask overlay"))
        if out["proc_gray"] is not None:
            proc_rgb = cv2.cvtColor(cv2.resize(out["proc_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
            gallery_items.append((proc_rgb, f"{name} • Masked model input (224x224)"))
        gallery_items.append((overlay_big, f"{name} • TBNet Grad-CAM overlay"))
        if radio_raw_overlay is not None:
            gallery_items.append((cv2.resize(radio_raw_overlay, (512, 512)), f"{name} • RADIO RAW heatmap"))
        if radio_masked_overlay is not None:
            gallery_items.append((cv2.resize(radio_masked_overlay, (512, 512)), f"{name} • RADIO MASKED heatmap"))
        # Details (collapsible per image) — FIXED (no welcome HTML here)
        warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
        details_md.append(
            f"""
{html_escape(name)} — detailed report
**TBNet**
- Result: **{html_escape(tb_label)}**
- Probability: {tb_prob_line}
- Band: {out.get("band", "YELLOW")}
- Quality: {q:.0f}/100
- Lung mask coverage: {cov*100:.1f}%
- Attention: {attention}
**Consensus**
- TBNet state: {pretty_state(tb_state)}
- RADIO state: {pretty_state(radio_state)}
- Consensus label: **{html_escape(consensus_label)}**
- Detail: {html_escape(consensus_detail)}
**Warnings**
{warn_txt}
**RADIO (full)**
{radio_text_long}
---
"""
        )
    return "\n".join(summary_md), gallery_items, "\n".join(details_md), CLINICAL_DISCLAIMER, "Done."
# ============================================================
# UI (HF Spaces Welcome Screen + Main App)
# ============================================================
def build_ui():
    """Assemble the Gradio Blocks app.

    Two top-level columns share the page: a welcome screen (visible first)
    and the main analysis UI (revealed by the Continue button). Returns the
    gr.Blocks object so the caller can ``.launch()`` it.

    Fix: several single-/double-quoted string literals in the original spanned
    a physical line break (a SyntaxError in Python); each has been joined into
    a valid literal (using implicit adjacent-string concatenation where long).
    """
    # Inline CSS; class names are referenced by the markdown/HTML content.
    css = """
    .title {font-size: 28px; font-weight: 900; margin-bottom: 6px;}
    .subtitle {font-size: 14px; opacity: 0.88; margin-bottom: 14px;}
    .warnbox {border-left: 6px solid #f59e0b; padding: 10px 12px; background: rgba(245,158,11,0.08); border-radius: 10px;}
    .legend {border-left: 6px solid rgba(148,163,184,0.7); padding: 10px 12px; background: rgba(148,163,184,0.08); border-radius: 10px;}
    .card {border:1px solid rgba(255,255,255,0.12); border-radius:14px; padding:14px; margin:10px 0;}
    """
    with gr.Blocks(title="TB X-ray Assistant (TBNet + RADIO)", css=css) as demo:
        # ---------------------------
        # Welcome screen (shown first)
        # ---------------------------
        with gr.Column(visible=True) as welcome_screen:
            gr.Markdown('Welcome — TB X-ray Assistant (HF Spaces)')
            gr.HTML(WELCOME_HTML)  # pre-rendered welcome body (defined elsewhere in this file)
            continue_btn = gr.Button("Continue →", variant="primary")
        # ---------------------------
        # Main app UI (hidden initially)
        # ---------------------------
        with gr.Column(visible=False) as main_app:
            gr.Markdown('TB X-ray Assistant (Auto Lung Mask • Research Use)')
            gr.Markdown(
                f"Auto lung mask → {MODEL_NAME_TBNET} + Grad-CAM • "
                f"Optional {MODEL_NAME_RADIO} (C-RADIOv4 + heads) • "
                "Clear per-model results + consensus"
            )
            gr.Markdown(
                "Clinical disclaimer: Decision support only (not diagnostic). "
                "If TB is clinically suspected, pursue microbiology / CT as "
                "appropriate regardless of AI output."
            )
            with gr.Row():
                # Left column: model/weight settings and processing toggles.
                with gr.Column(scale=1):
                    gr.Markdown("#### Model settings")
                    tb_weights = gr.Textbox(label="TBNet weights (.pt)", value=DEFAULT_TB_WEIGHTS)
                    lung_weights = gr.Textbox(label="Lung U-Net weights (.pt)", value=DEFAULT_LUNG_WEIGHTS)
                    backbone = gr.Dropdown(
                        choices=["efficientnet_b0"],
                        value="efficientnet_b0",
                        label="TBNet backbone",
                    )
                    threshold = gr.Slider(
                        0.01, 0.99, value=TBNET_SCREEN_THR, step=0.01,
                        label=f"Reference threshold (TBNet screen+) = {TBNET_SCREEN_THR:.2f}",
                    )
                    phone_mode = gr.Checkbox(
                        value=False,
                        label="Phone/WhatsApp Mode (safe: conditional crop + conditional CLAHE)",
                    )
                    gr.Markdown(
                        "Enable for WhatsApp images, phone photos, or screenshots. "
                        "Leave off for clean digital exports."
                    )
                    use_radio = gr.Checkbox(value=True, label=f"Enable {MODEL_NAME_RADIO}")
                    radio_gate = gr.Slider(
                        0.10, 0.40, value=RADIO_GATE_DEFAULT, step=0.01,
                        label="RADIO masked gate (run masked head if lung coverage ≥ gate)",
                    )
                    gr.Markdown(
                        "Fail-safe: If lung segmentation is too small or looks unreliable, "
                        f"{MODEL_NAME_TBNET} scoring is disabled to avoid unsafe outputs."
                    )
                    gr.Markdown(f"Device: {DEVICE} (FORCE_CPU={FORCE_CPU})")
                    back_btn = gr.Button("← Back to Welcome", variant="secondary")
                # Right column: uploads, run controls, and result surfaces.
                with gr.Column(scale=2):
                    gr.Markdown("#### Upload images")
                    files = gr.Files(
                        label="Upload one or multiple X-ray images",
                        file_types=[".png", ".jpg", ".jpeg", ".bmp"],
                    )
                    run_btn = gr.Button("Run Analysis", variant="primary")
                    status = gr.Textbox(label="Status", value="Ready.", interactive=False)
                    gr.Markdown("""
Gallery legend:
1) ORIGINAL • 2) INPUT / PHONE-PROC • 3) Lung mask overlay •
4) Masked model input • 5) TBNet Grad-CAM • 6) RADIO heatmaps
""")
                    gr.Markdown("#### Summary (per image)")
                    summary = gr.Markdown("Upload images and click Run Analysis.")
                    gallery = gr.Gallery(label="Visual outputs", columns=3, height=560)
                    with gr.Row():
                        with gr.Column(scale=1):
                            disclaimer_box = gr.Markdown(CLINICAL_DISCLAIMER)
                        with gr.Column(scale=2):
                            gr.Markdown("#### Detailed report (expand per image)")
                            details = gr.Markdown("")
                    # Wire the run button to the analysis callback; the output
                    # order must match run_analysis()'s 5-tuple return.
                    run_btn.click(
                        fn=run_analysis,
                        inputs=[
                            files,
                            tb_weights,
                            lung_weights,
                            backbone,
                            threshold,
                            phone_mode,
                            use_radio,
                            radio_gate,
                        ],
                        outputs=[summary, gallery, details, disclaimer_box, status],
                    )
        # ---------------------------
        # Transitions (toggle which of the two screens is visible)
        # ---------------------------
        continue_btn.click(
            fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
            inputs=[],
            outputs=[welcome_screen, main_app],
        )
        back_btn.click(
            fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
            inputs=[],
            outputs=[welcome_screen, main_app],
        )
    return demo
if __name__ == "__main__":
    # Direct-execution entry point: bind on all interfaces (container-friendly)
    # and surface server-side errors in the browser.
    app = build_ui()
    app.launch(server_name="0.0.0.0", server_port=7860, show_error=True)