Spaces:
Running
Running
| # app.py | |
| # Gradio — TBNet + Lung U-Net Auto Mask + Grad-CAM + RADIO | |
| # + SAFER PHONE MODE + MASK POST-PROCESSING + MASK SANITY FAILSAFE | |
| # + 3-STATE CONSENSUS (LOW / INDET / SCREEN+ / TB+) | |
| # | |
| # UX UPDATE (this version): | |
| # - Removes the results table | |
| # - Shows per-image, easy-to-read "cards": | |
| # 1) TBNet result (alone) | |
| # 2) RADIO result (alone) | |
| # 3) Final consensus (comparison + next step) | |
| # - Adds collapsible detailed report per image | |
| # - Keeps gallery, adds legend, better labels | |
| # - Fixes welcome screen rendering (uses gr.HTML + inline styles; no raw HTML shown) | |
| # | |
| # HF Spaces: use relative weight paths (edit below if needed) | |
| import os | |
| import cv2 | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import timm | |
| import gradio as gr | |
| from torchvision import transforms | |
| from typing import List, Tuple, Dict, Any, Optional | |
| from transformers import AutoModel, CLIPImageProcessor | |
| from einops import rearrange | |
| from PIL import Image | |
| # ============================================================ | |
| # USER CONFIG | |
| # ============================================================ | |
# ---- Display names shown in the UI ----
MODEL_NAME_TBNET: str = "TBNet (CNN model)"
MODEL_NAME_RADIO: str = "Nvidia/C-RADIOv4-SO400M (visual model)"

# ---- Default TBNet / lung U-Net checkpoints (relative paths, HF Spaces friendly) ----
DEFAULT_TB_WEIGHTS: str = "weights/best.pt"
DEFAULT_LUNG_WEIGHTS: str = "weights/lung_unet_mont_shenzhen.pt"

# ---- RADIO backbone + classification heads (same environment as TBNet) ----
RADIO_HF_REPO: str = "nvidia/C-RADIOv4-SO400M"
RADIO_REVISION: str = "c0457f5dc26ca145f954cd4fc5bb6114e5705ad8"  # pinned Hub revision
RADIO_RAW_HEAD_PATH: str = "weights/best_raw.pt"
RADIO_MASKED_HEAD_PATH: str = "weights/best_masked.pt"
RADIO_IMG_SIZE: int = 320
RADIO_PATCH_SIZE: int = 16
RADIO_THR_SCREEN: float = 0.05  # RADIO probability at/above this -> SCREEN+ / YELLOW
RADIO_THR_RED: float = 0.23  # RADIO probability at/above this -> TB+ / RED
RADIO_MASKED_MIN_COV: float = 0.15  # minimum lung coverage before the masked head is run
RADIO_GATE_DEFAULT: float = 0.21  # NOTE(review): presumably the default gate_threshold; usage not visible here

# ---- Thresholds used by the consensus logic ----
TBNET_SCREEN_THR: float = 0.30
TBNET_MARGIN: float = 0.03  # kept for compatibility / future use
RADIO_SCREEN_THR: float = RADIO_THR_SCREEN
RADIO_MARGIN: float = 0.02  # kept for compatibility / future use

# ---- Lung-mask fail-safes ----
FAIL_COV: float = 0.10  # default `fail_cov`: below this coverage TB scoring is disabled
WARN_COV: float = 0.18  # default `warn_cov` of analyze_one_image
FAILSAFE_ON_BAD_MASK: bool = True  # any mask sanity warning also disables TB scoring

# ---- Device policy ----
FORCE_CPU: bool = True
# CUDA is only probed when the CPU override is off.
_USE_CUDA = (not FORCE_CPU) and torch.cuda.is_available()
DEVICE = torch.device("cuda" if _USE_CUDA else "cpu")
| # ============================================================ | |
| # CLINICAL DISCLAIMER / REPORT TEXT | |
| # ============================================================ | |
# Full disclaimer text (Markdown) shown alongside results.
CLINICAL_DISCLAIMER = """
⚠️ IMPORTANT CLINICAL NOTICE (Decision Support Only)
This AI system is for **research/decision support** and is NOT a diagnostic device.
It may NOT reliably detect early/subtle tuberculosis, including **MILIARY TB**,
which can appear near-normal or subtle on chest X-ray (especially on phone photos / WhatsApp images).
If clinical suspicion exists (fever, weight loss, immunosuppression, known exposure),
recommend **CBNAAT / GeneXpert**, sputum studies, and/or **CT chest** regardless of AI output.
"""

# Per-band report text. Each summary is a list of Markdown paragraphs joined
# by a blank line; the band keys match the "GREEN"/"YELLOW"/"RED" band names
# used elsewhere in this file.
REPORT_LABELS = {
    "GREEN": {
        "title": "LOW TB LIKELIHOOD / Pulmonary T.B not detected by A.I",
        "summary": "\n\n".join([
            f"✅ **{MODEL_NAME_TBNET}** did not find patterns that strongly suggest pulmonary tuberculosis.",
            "**What to do next:** If symptoms or TB risk factors are present, please seek clinician/radiologist review.",
        ]),
    },
    "YELLOW": {
        "title": "INDETERMINATE — REVIEW RECOMMENDED BY A RADIOLOGIST",
        "summary": "\n\n".join([
            f"⚠️ **{MODEL_NAME_TBNET}** result is **not definitive**.",
            "**Common reasons:** image quality limitations, non-standard/cropped view, or non-focal attention.",
            "**What to do next:** Radiologist/clinician review is recommended. "
            "If TB is clinically suspected, consider microbiological tests (CBNAAT/GeneXpert, sputum).",
        ]),
    },
    "RED": {
        "title": "TB FEATURES SUSPECTED",
        "summary": "\n\n".join([
            f"🚩 **{MODEL_NAME_TBNET}** detected lung patterns that can be seen with pulmonary tuberculosis.",
            "**Important:** This is not a diagnosis.",
            "**What to do next:** Urgent clinician/radiologist review and microbiological confirmation "
            "(CBNAAT/GeneXpert, sputum) are recommended.",
        ]),
    },
}

# One-sentence guidance that applies whatever the AI output is.
CLINICAL_GUIDANCE = " ".join([
    "If clinical suspicion for tuberculosis exists, further evaluation",
    "(e.g., CBNAAT / GeneXpert, sputum studies, CT chest) is recommended",
    "regardless of AI output.",
])
| # ============================================================ | |
| # WELCOME HTML (minimal + main features only) | |
| # IMPORTANT: rendered with gr.HTML (not gr.Markdown) | |
| # ============================================================ | |
# Inline CSS shared by every feature "pill" on the welcome card.
_WELCOME_PILL_STYLE = (
    "font-size:12px;padding:6px 10px;border-radius:999px;"
    "background:rgba(255,255,255,0.06);border:1px solid rgba(255,255,255,0.10);"
)

# Welcome card markup: title, one-line description, feature pills, a usage
# tip and a short clinical disclaimer. Uses inline styles only.
WELCOME_HTML = f"""
<div style="max-width:980px;margin:0 auto;">
<div style="padding:14px 16px;border-radius:14px;background:rgba(255,255,255,0.04);border:1px solid rgba(255,255,255,0.10);">
<div style="font-size:18px;font-weight:900;margin-bottom:6px;">
TB X-ray Assistant <span style="opacity:0.75;font-weight:700;font-size:13px;">(research / decision support)</span>
</div>
<div style="opacity:0.9;font-size:13px;line-height:1.35;">
Upload chest X-rays to get an AI screening score, heatmaps, and a simple consensus output.
</div>
<div style="margin-top:10px;display:flex;flex-wrap:wrap;gap:8px;">
<span style="{_WELCOME_PILL_STYLE}">
<b>{MODEL_NAME_TBNET}</b> + Grad-CAM
</span>
<span style="{_WELCOME_PILL_STYLE}">
Auto lung mask + fail-safe
</span>
<span style="{_WELCOME_PILL_STYLE}">
<b>{MODEL_NAME_RADIO}</b> (optional)
</span>
<span style="{_WELCOME_PILL_STYLE}">
Consensus: ✅ LOW · ⚠️ INDET · ⚠️ SCREEN+ · 🚩 TB+
</span>
<span style="{_WELCOME_PILL_STYLE}">
Phone/WhatsApp Mode
</span>
</div>
<div style="margin-top:10px;opacity:0.85;font-size:12.5px;">
<b>Tip:</b> Turn on Phone/WhatsApp Mode for phone photos, WhatsApp-forwards, or screenshots with borders.
</div>
</div>
<div style="margin-top:12px;padding:10px 12px;border-left:5px solid #f59e0b;border-radius:12px;background:rgba(245,158,11,0.10);font-size:12.5px;line-height:1.35;">
<b>Clinical disclaimer:</b> Not diagnostic. If TB is suspected clinically, pursue CBNAAT/GeneXpert/sputum and/or CT chest regardless of AI output.
</div>
</div>
"""
| # ============================================================ | |
| # UX HELPERS | |
| # ============================================================ | |
def pretty_state(s: str) -> str:
    """Return the emoji-prefixed display label for a state code.

    Known codes are LOW, INDET, SCREEN+, TB+ and N/A; any other value is
    shown as-is behind a warning sign.
    """
    display = {
        "LOW": "✅ LOW",
        "INDET": "⚠️ INDETERMINATE",
        "SCREEN+": "⚠️ SCREEN+",
        "TB+": "🚩 TB+",
        "N/A": "⚠️ N/A",
    }
    if s in display:
        return display[s]
    return f"⚠️ {s}"
def html_escape(s: str) -> str:
    """Escape ``&``, ``<`` and ``>`` so *s* can be placed inside HTML text.

    Args:
        s: Text to escape. ``None`` or an empty string yields ``""``.

    Returns:
        The text with ``&`` -> ``&amp;``, ``<`` -> ``&lt;`` and
        ``>`` -> ``&gt;``. Quotes are left untouched, so the result is meant
        for element content rather than attribute values.
    """
    text = s or ""
    # "&" must go first: otherwise the ampersands introduced by "&lt;" and
    # "&gt;" would be escaped a second time.
    return text.replace("&", "&amp;").replace("<", "&lt;").replace(">", "&gt;")
def badge_color_for_state(state: str) -> str:
    """Return the translucent CSS ``rgba(...)`` badge background for a state.

    Unrecognised states get a neutral gray.
    """
    palette = (
        ("TB+", "rgba(239,68,68,0.18)"),      # red
        ("SCREEN+", "rgba(245,158,11,0.18)"),  # amber
        ("INDET", "rgba(245,158,11,0.12)"),    # amber lighter
        ("LOW", "rgba(34,197,94,0.14)"),       # green
    )
    for code, color in palette:
        if state == code:
            return color
    return "rgba(148,163,184,0.12)"  # gray
| # ============================================================ | |
| # LUNG U-NET (INFERENCE) | |
| # ============================================================ | |
class DoubleConv(nn.Module):
    """Two consecutive 3x3 conv -> batch-norm -> ReLU stages.

    Padding of 1 keeps the spatial size unchanged; only the channel count
    goes from ``in_c`` to ``out_c``.
    """

    def __init__(self, in_c, out_c):
        super().__init__()
        stages = []
        # First stage maps in_c -> out_c, second stage keeps out_c.
        for src_channels in (in_c, out_c):
            stages.append(nn.Conv2d(src_channels, out_c, 3, padding=1))
            stages.append(nn.BatchNorm2d(out_c))
            stages.append(nn.ReLU(inplace=True))
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)
class LungUNet(nn.Module):
    """U-Net for lung segmentation: 1-channel input, 1-channel logit output.

    Four encoder stages (64-512 channels), a 1024-channel bottleneck and four
    decoder stages. Each decoder stage upsamples with a transposed
    convolution and concatenates the matching encoder output.
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.d1 = DoubleConv(1, 64)
        self.d2 = DoubleConv(64, 128)
        self.d3 = DoubleConv(128, 256)
        self.d4 = DoubleConv(256, 512)
        self.pool = nn.MaxPool2d(2)
        # Bottleneck
        self.mid = DoubleConv(512, 1024)
        # Decoder: 2x upsampling layers ...
        self.u4 = nn.ConvTranspose2d(1024, 512, 2, 2)
        self.u3 = nn.ConvTranspose2d(512, 256, 2, 2)
        self.u2 = nn.ConvTranspose2d(256, 128, 2, 2)
        self.u1 = nn.ConvTranspose2d(128, 64, 2, 2)
        # ... and the convs applied after each skip concatenation.
        self.c4 = DoubleConv(1024, 512)
        self.c3 = DoubleConv(512, 256)
        self.c2 = DoubleConv(256, 128)
        self.c1 = DoubleConv(128, 64)
        # 1x1 projection to a single logit channel
        self.out = nn.Conv2d(64, 1, 1)

    def forward(self, x):
        # Encoder pass, remembering each stage's output for the skip links.
        feat = self.d1(x)
        skips = [feat]
        for encoder in (self.d2, self.d3, self.d4):
            feat = encoder(self.pool(feat))
            skips.append(feat)

        feat = self.mid(self.pool(feat))

        # Decoder pass, consuming the skips deepest-first.
        decoder = (
            (self.u4, self.c4),
            (self.u3, self.c3),
            (self.u2, self.c2),
            (self.u1, self.c1),
        )
        for upsample, fuse in decoder:
            feat = fuse(torch.cat([upsample(feat), skips.pop()], 1))
        return self.out(feat)
| # ============================================================ | |
| # TB MODEL + GRAD-CAM | |
| # ============================================================ | |
class TBNet(nn.Module):
    """Single-logit classifier: a timm backbone followed by one linear layer.

    The backbone is created without pretrained weights and without its own
    classifier (``num_classes=0``), using average global pooling.
    """

    def __init__(self, backbone="efficientnet_b0"):
        super().__init__()
        self.backbone = timm.create_model(
            backbone,
            pretrained=False,
            num_classes=0,
            global_pool="avg",
        )
        self.fc = nn.Linear(self.backbone.num_features, 1)

    def forward(self, x):
        pooled = self.backbone(x)
        # Flatten to a 1-D tensor of logits.
        return self.fc(pooled).view(-1)
def load_tb_weights(model: nn.Module, ckpt_path: str, device: torch.device):
    """Load a state dict from ``ckpt_path`` into ``model`` in place.

    The checkpoint must match the model exactly (``strict=True``).
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state, strict=True)
class GradCAM:
    """Grad-CAM helper bound to one model and one target layer.

    Hooks on the target layer record its forward output and the gradient
    with respect to that output; ``generate`` combines them into a map.
    """

    def __init__(self, model: nn.Module, target_layer: nn.Module):
        self.model = model
        self.activ = None  # latest forward output of the target layer
        self.grad = None   # latest gradient w.r.t. that output
        target_layer.register_forward_hook(self._fwd)
        target_layer.register_full_backward_hook(self._bwd)

    def _fwd(self, _module, _inputs, out):
        self.activ = out

    def _bwd(self, _module, grad_in, grad_out):
        self.grad = grad_out[0]

    def generate(self, x: torch.Tensor) -> Tuple[np.ndarray, float, float]:
        """Return ``(cam, prob, logit)`` for the first sample in ``x``.

        ``cam`` is the ReLU'd, gradient-weighted sum of the target layer's
        channels, rescaled to [0, 1]. ``prob`` is the sigmoid of ``logit``.
        """
        # Gradients are needed even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            self.model.zero_grad(set_to_none=True)
            logits = self.model(x)
            logits[0].backward()

        activations = self.activ[0]
        gradients = self.grad[0]
        # One weight per channel: the spatial mean of its gradient.
        channel_weights = gradients.mean(dim=(1, 2))
        cam = torch.relu((channel_weights[:, None, None] * activations).sum(dim=0))
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)

        first_logit = logits.detach().cpu()[0]
        logit = float(first_logit.item())
        prob = float(torch.sigmoid(first_logit).item())
        return cam.detach().cpu().numpy(), prob, logit
| # ============================================================ | |
| # PREPROCESS HELPERS + QUALITY | |
| # ============================================================ | |
def preprocess_for_lung_unet(gray_u8: np.ndarray) -> torch.Tensor:
    """Turn a grayscale image into a (1, 1, 256, 256) float tensor.

    The image is resized to 256x256, clipped to its 1st-99th percentile
    range and rescaled so that range maps to [0, 1].
    """
    resized = cv2.resize(gray_u8.astype(np.float32), (256, 256), interpolation=cv2.INTER_AREA)
    p_lo, p_hi = np.percentile(resized, (1, 99))
    stretched = np.clip(resized, p_lo, p_hi)
    stretched = (stretched - p_lo) / (p_hi - p_lo + 1e-8)
    # Add batch and channel axes.
    return torch.from_numpy(stretched)[None, None].float()
def tb_training_preprocess(gray_u8: np.ndarray) -> np.ndarray:
    """Percentile-stretch a grayscale image to the [0, 1] range.

    Values are clipped to the 1st-99th percentile window, then rescaled so
    the window spans 0..1. A constant image maps to all zeros.
    """
    pixels = gray_u8.astype(np.float32)
    p_lo, p_hi = np.percentile(pixels, (1, 99))
    pixels = np.clip(pixels, p_lo, p_hi)
    pixels = (pixels - p_lo) / (p_hi - p_lo + 1e-8)
    return pixels
def laplacian_sharpness(gray_u8: np.ndarray) -> float:
    """Return the variance of the Laplacian as a sharpness score.

    The image is first resized to 512x512 and lightly blurred (3x3
    Gaussian), so the score does not depend on the input resolution.
    """
    resized = cv2.resize(gray_u8, (512, 512), interpolation=cv2.INTER_AREA)
    smoothed = cv2.GaussianBlur(resized, (3, 3), 0)
    response = cv2.Laplacian(smoothed, cv2.CV_64F)
    return float(response.var())
def exposure_scores(gray_u8: np.ndarray) -> Tuple[float, float]:
    """Return ``(dark_fraction, bright_fraction)`` of the image.

    Dark pixels are those below 10, bright pixels those above 245.
    """
    dark_fraction = float((gray_u8 < 10).mean())
    bright_fraction = float((gray_u8 > 245).mean())
    return (dark_fraction, bright_fraction)
def border_fraction(gray_u8: np.ndarray) -> float:
    """Return how much of the image's edge strips is near-black or near-white.

    Four strips (top, bottom, left, right) are examined, each 6% of the
    shorter side thick (at least 5 px). Per strip, the fraction of pixels
    below 15 or above 240 is taken, and the four fractions are averaged.
    """
    height, width = gray_u8.shape
    thickness = max(5, int(0.06 * min(height, width)))
    strips = (
        gray_u8[:thickness, :],
        gray_u8[-thickness:, :],
        gray_u8[:, :thickness],
        gray_u8[:, -thickness:],
    )
    fractions = [float(((strip < 15) | (strip > 240)).mean()) for strip in strips]
    return float(np.mean(fractions))
def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
    """Score image quality on a 0-100 scale and list the problems found.

    Starts from 100 and subtracts a penalty for low resolution, blur,
    clipped exposure and large borders. Blur thresholds are stricter when
    the image looks like a phone capture (big border or clipped exposure).

    Returns:
        ``(score, warnings)`` with the score clipped to [0, 100].
    """
    notes: List[str] = []
    score = 100.0

    def penalize(points: float, message: str) -> None:
        nonlocal score
        score -= points
        notes.append(message)

    height, width = gray_u8.shape
    if min(height, width) < 400:
        penalize(8, "Low resolution (image may be downsampled).")

    sharpness = laplacian_sharpness(gray_u8)
    dark_clip, bright_clip = exposure_scores(gray_u8)
    border = border_fraction(gray_u8)

    looks_like_phone = (border > 0.35) or (dark_clip > 0.10) or (bright_clip > 0.05)
    if looks_like_phone:
        if sharpness < 40:
            penalize(25, "Blurry / motion blur detected (likely phone capture).")
        elif sharpness < 80:
            penalize(12, "Slight blur detected.")
    elif sharpness < 30:
        penalize(8, "Low fine detail (possible downsampling).")

    if bright_clip > 0.05:
        penalize(15, "Overexposed highlights (washed-out areas).")
    if dark_clip > 0.10:
        penalize(12, "Underexposed shadows (very dark areas).")

    if border > 0.55:
        penalize(18, "Large border/margins detected (possible screenshot/phone framing).")
    elif border > 0.35:
        penalize(10, "Some border/margins detected.")

    return float(np.clip(score, 0, 100)), notes
def auto_border_crop(gray_u8: np.ndarray) -> np.ndarray:
    """Crop to the largest Otsu-thresholded region, padded by 3%.

    The input is returned unchanged when no contour is found or when the
    largest contour's bounding box covers less than 20% of the image.
    """
    img_h, img_w = gray_u8.shape

    blurred = cv2.GaussianBlur(gray_u8, (5, 5), 0)
    _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    # Flip polarity when the thresholded image is mostly white.
    if binary.mean() > 127:
        binary = 255 - binary

    ksize = max(3, int(0.01 * min(gray_u8.shape)))
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
    binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, ellipse, iterations=2)

    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        return gray_u8

    largest = max(contours, key=cv2.contourArea)
    box_x, box_y, box_w, box_h = cv2.boundingRect(largest)
    if box_w * box_h < 0.20 * (img_h * img_w):
        return gray_u8

    margin = int(0.03 * min(img_h, img_w))
    top = max(0, box_y - margin)
    left = max(0, box_x - margin)
    bottom = min(img_h, box_y + box_h + margin)
    right = min(img_w, box_x + box_w + margin)
    return gray_u8[top:bottom, left:right]
def apply_clahe(gray_u8: np.ndarray) -> np.ndarray:
    """Apply CLAHE (clip limit 2.0, 8x8 tiles) to a grayscale image."""
    return cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(gray_u8)
def phone_preprocess(gray_u8: np.ndarray) -> np.ndarray:
    """Clean up a phone/screenshot capture before analysis.

    Borders are cropped when the border fraction exceeds 0.35, but only if
    the crop keeps at least 70% of the pixels. CLAHE is then applied when
    the image is dark-clipped or not sharp. All measurements are taken on
    the original input, not on the cropped image.
    """
    sharpness = laplacian_sharpness(gray_u8)
    dark_clip, _ = exposure_scores(gray_u8)
    border = border_fraction(gray_u8)

    result = gray_u8
    if border > 0.35:
        candidate = auto_border_crop(result)
        # Reject crops that throw away more than 30% of the image.
        if candidate.size >= 0.70 * result.size:
            result = candidate

    needs_contrast = dark_clip > 0.10 or sharpness < 80
    return apply_clahe(result) if needs_contrast else result
def cam_entropy(cam: np.ndarray) -> float:
    """Return the Shannon entropy (in nats) of a map normalised to sum 1.

    Small epsilons guard against division by zero and ``log(0)``.
    """
    weights = cam.astype(np.float32)
    weights = weights / (weights.sum() + 1e-8)
    return float(-(weights * np.log(weights + 1e-8)).sum())
def detect_diffuse_risk(prob_tb: float, cam_up: np.ndarray, quality_score: float) -> bool:
    """Flag a sub-threshold score whose attention map is spread out.

    Never flags when quality is below 55 or the probability is below 0.05.
    Otherwise flags when the probability is under the screening threshold
    and the map's entropy exceeds 6.5.
    """
    if quality_score < 55 or prob_tb < 0.05:
        return False
    spread = cam_entropy(cam_up)
    below_screening = prob_tb < TBNET_SCREEN_THR
    return below_screening and (spread > 6.5)
def confidence_band(prob_tb: float, quality_score: float, diffuse: bool):
    """Map a TB probability plus quality signals to ``(band, message)``.

    Checked in order: very low probability on acceptable quality (GREEN),
    low quality (YELLOW), non-focal attention (YELLOW), probability at or
    above the screening threshold (YELLOW), otherwise GREEN.
    """
    if prob_tb < 0.01 and quality_score >= 45:
        verdict = ("GREEN", "✅ Very low TB signal detected.")
    elif quality_score < 55:
        verdict = ("YELLOW", "⚠️ Image quality is low; treat as indeterminate.")
    elif diffuse:
        verdict = ("YELLOW", "⚠️ Attention is non-focal; treat as indeterminate.")
    elif prob_tb >= TBNET_SCREEN_THR:
        verdict = ("YELLOW", "⚠️ Screening-positive range; review recommended.")
    else:
        verdict = ("GREEN", "✅ No strong TB signal detected.")
    return verdict
def make_mask_overlay(gray_u8: np.ndarray, mask_u8: np.ndarray) -> np.ndarray:
    """Blend a JET-coloured mask (25%) over the grayscale image (75%)."""
    background = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB)
    # Scale the mask by 255 before colour-mapping it.
    scaled_mask = (mask_u8 * 255).astype(np.uint8)
    tinted = cv2.applyColorMap(scaled_mask, cv2.COLORMAP_JET)
    return cv2.addWeighted(background, 0.75, tinted, 0.25, 0)
def fill_holes(binary_u8: np.ndarray) -> np.ndarray:
    """Fill background regions that are fully enclosed by foreground.

    Args:
        binary_u8: 2-D mask; any non-zero pixel counts as foreground.

    Returns:
        A ``uint8`` mask of the same shape with values 0/1, where enclosed
        holes have been set to 1. Background that reaches the image border
        is not a hole and stays 0.
    """
    foreground = (binary_u8 > 0).astype(np.uint8) * 255

    # Flood the outside background from a corner. The mask is first padded
    # with a 1-px background frame so the seed is guaranteed to be
    # background: seeding inside the image would, for a mask covering
    # (0, 0), flood nothing and report all background as "holes".
    padded = np.pad(foreground, 1, mode="constant", constant_values=0)
    pad_h, pad_w = padded.shape
    # floodFill requires an auxiliary mask 2 px larger than the image.
    flood_mask = np.zeros((pad_h + 2, pad_w + 2), np.uint8)
    cv2.floodFill(padded, flood_mask, (0, 0), 255)

    # Whatever the flood did not reach (and is not foreground) is a hole.
    holes = cv2.bitwise_not(padded)[1:-1, 1:-1]
    filled = cv2.bitwise_or(foreground, holes)
    return (filled > 0).astype(np.uint8)
def keep_top_k_components(binary_u8: np.ndarray, k: int = 2) -> np.ndarray:
    """Keep only the ``k`` largest connected components of a binary mask.

    Returns a 0/1 ``uint8`` mask. A mask without any foreground is returned
    (binarised) as-is. Equal-sized components are ranked by label order.
    """
    fg = (binary_u8 > 0).astype(np.uint8)
    n_labels, label_map = cv2.connectedComponents(fg)
    if n_labels <= 1:
        return fg

    # Pixel count per label; label 0 is the background and is dropped.
    sizes = np.bincount(label_map.ravel(), minlength=n_labels)[1:]
    # Stable sort on the negated sizes: largest first, ties by label order.
    ranked_labels = np.argsort(-sizes, kind="stable") + 1
    kept = ranked_labels[:k]
    return np.isin(label_map, kept).astype(np.uint8)
def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
    """Return human-readable warnings about an implausible lung mask.

    With fewer than two connected regions only the "one lung region"
    warning is returned. Otherwise the mask is checked for one region
    holding more than 80% of the area, for more than 5% of the outer pixel
    frame being mask, and for the two largest regions holding less than 90%
    of the area.
    """
    fg = (mask_full_u8 > 0).astype(np.uint8)
    n_labels, label_map = cv2.connectedComponents(fg)
    # n_labels includes the background, so <= 2 means at most one region.
    if n_labels <= 2:
        return ["Only one lung region detected (possible crop/segmentation failure)."]

    sizes = sorted(
        np.bincount(label_map.ravel(), minlength=n_labels)[1:].tolist(),
        reverse=True,
    )
    total = int(fg.sum())
    largest = sizes[0]
    runner_up = sizes[1] if len(sizes) > 1 else 0

    found: List[str] = []
    if total > 0 and largest / total > 0.80:
        found.append("Mask dominated by a single region (possible cropped/partial lung view).")

    frame = np.concatenate([fg[0, :], fg[-1, :], fg[:, 0], fg[:, -1]])
    if frame.mean() > 0.05:
        found.append("Lung mask touches image border (possible cropped/non-standard CXR).")

    if total > 0 and (largest + runner_up) / total < 0.90:
        found.append("Mask appears fragmented (may reduce reliability).")
    return found
def recommendation_for_band(band: Optional[str]) -> str:
    """Return the one-line recommendation text for a band.

    ``None`` and "YELLOW" share the indeterminate text, "RED" gets the
    urgent text, and everything else gets the generic correlation advice.
    """
    if band == "RED":
        return "✅ Recommendation: **Urgent** clinician/radiologist review + microbiological confirmation (CBNAAT/GeneXpert, sputum)."
    if band is None or band == "YELLOW":
        return "✅ Recommendation: Radiologist/clinician review is recommended (**indeterminate**)."
    return "✅ Recommendation: If symptoms/risk factors exist, clinician/radiologist correlation is advised."
| # ============================================================ | |
| # CONSENSUS LOGIC (TBNet vs RADIO) | |
| # ============================================================ | |
def tbnet_state(tb_prob: float, tb_band: str) -> str:
    """Translate TBNet's band and probability into a consensus state.

    The band wins when it is "RED" (TB+) or "YELLOW" (INDET); otherwise the
    probability is compared with the screening threshold.
    """
    band_states = {"RED": "TB+", "YELLOW": "INDET"}
    if tb_band in band_states:
        return band_states[tb_band]
    return "SCREEN+" if tb_prob >= TBNET_SCREEN_THR else "LOW"
def radio_state_from_prob(radio_prob: float) -> str:
    """Translate a RADIO probability into "TB+", "SCREEN+" or "LOW"."""
    # Checked from the strictest cut-off down; the first match wins.
    cutoffs = (
        (RADIO_THR_RED, "TB+"),
        (RADIO_THR_SCREEN, "SCREEN+"),
    )
    for cutoff, state in cutoffs:
        if radio_prob >= cutoff:
            return state
    return "LOW"
def build_consensus(
    tb_prob: Optional[float],
    tb_band: Optional[str],
    radio_raw: Optional[float],
    radio_masked: Optional[float],
    radio_band: Optional[str] = None
) -> Tuple[str, str, str, str]:
    """Compare the TBNet and RADIO results and summarise how they relate.

    The masked RADIO score is preferred over the raw one when available.
    Without a TBNet result everything is "N/A"; without any RADIO score the
    result is labelled "TBNet only".

    Returns:
        consensus_label, consensus_detail, tb_state, radio_state
    """
    if tb_prob is None or tb_band is None:
        return (
            "N/A",
            f"{MODEL_NAME_TBNET} unavailable (lung segmentation failed / fail-safe).",
            "N/A",
            "N/A",
        )

    radio_primary, radio_used = (
        (radio_masked, "MASKED") if radio_masked is not None else (radio_raw, "RAW")
    )
    tb_state = tbnet_state(tb_prob, tb_band)

    if radio_primary is None:
        return (
            "TBNet only",
            f"{MODEL_NAME_RADIO} unavailable → TBNet state={tb_state}, p={tb_prob:.4f} (band={tb_band}).",
            tb_state,
            "N/A",
        )

    radio_state = radio_state_from_prob(radio_primary)
    rb = f" (RADIO band={radio_band})" if radio_band else ""

    positive = ("SCREEN+", "TB+")
    definite = ("LOW", "SCREEN+", "TB+")

    if tb_state == "INDET" and radio_state == "INDET":
        label = "AGREE: INDET"
        detail = f"Both models are indeterminate. TBNet p={tb_prob:.4f} (band={tb_band}), RADIO({radio_used})={radio_primary:.4f}{rb}."
    elif tb_state == "INDET" and radio_state in definite:
        label = "MIXED/INDET"
        detail = f"TBNet is indeterminate (band={tb_band}); RADIO suggests {radio_state} ({radio_used})={radio_primary:.4f}{rb}."
    elif radio_state == "INDET" and tb_state in definite:
        label = "MIXED/INDET"
        detail = f"RADIO is indeterminate; TBNet suggests {tb_state} (band={tb_band}) p={tb_prob:.4f}."
    elif tb_state == radio_state:
        label = f"AGREE: {tb_state}"
        detail = f"Both models agree: {tb_state}. TBNet p={tb_prob:.4f}, RADIO({radio_used})={radio_primary:.4f}{rb}."
    elif (tb_state in positive and radio_state == "LOW") or (radio_state in positive and tb_state == "LOW"):
        # One model is positive while the other sees nothing.
        label = "DISAGREE"
        detail = f"Strong disagreement: TBNet={tb_state} (band={tb_band}, p={tb_prob:.4f}) vs RADIO={radio_state} ({radio_used})={radio_primary:.4f}{rb}."
    else:
        label = "MIXED"
        detail = f"Mixed: TBNet={tb_state} (band={tb_band}, p={tb_prob:.4f}) vs RADIO={radio_state} ({radio_used})={radio_primary:.4f}{rb}."

    return (label, detail, tb_state, radio_state)
| # ============================================================ | |
| # TB + LUNG MODEL BUNDLE (cached) | |
| # ============================================================ | |
class ModelBundle:
    """Lazily loaded, cached TBNet (+ Grad-CAM) and lung U-Net models.

    ``load`` only rebuilds a model when its weights path (or, for TBNet,
    the backbone name) differs from what is already cached.
    """

    def __init__(self):
        self.device = DEVICE
        self.tb = None
        self.cammer = None
        self.lung = None
        self.tb_path = None
        self.lung_path = None
        self.backbone = "efficientnet_b0"
        # TBNet input pipeline: 224x224 with ImageNet mean/std normalisation.
        self.tfm = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225],
            ),
        ])

    def load(self, tb_weights: str, lung_weights: str, backbone: str = "efficientnet_b0"):
        """Ensure both models are loaded for the given weights/backbone."""
        tb_stale = (
            self.tb is None
            or self.cammer is None
            or self.tb_path != tb_weights
            or self.backbone != backbone
        )
        if tb_stale:
            tb_model = TBNet(backbone=backbone).to(self.device)
            load_tb_weights(tb_model, tb_weights, self.device)
            tb_model.eval()
            self.tb = tb_model
            # Grad-CAM is attached to the backbone's conv_head layer.
            self.cammer = GradCAM(tb_model, tb_model.backbone.conv_head)
            self.tb_path = tb_weights
            self.backbone = backbone

        lung_stale = self.lung is None or self.lung_path != lung_weights
        if lung_stale:
            lung_model = LungUNet().to(self.device)
            lung_state = torch.load(lung_weights, map_location=self.device)
            lung_model.load_state_dict(lung_state)
            lung_model.eval()
            self.lung = lung_model
            self.lung_path = lung_weights


# Module-level cache shared by all requests.
BUNDLE = ModelBundle()
| # ============================================================ | |
| # RADIO BUNDLE (cached) | |
| # ============================================================ | |
class RadioMLPHead(nn.Module):
    """Small MLP head producing one logit per input feature vector.

    Layout: LayerNorm -> Dropout -> Linear(dim, hidden) -> GELU -> Dropout
    -> Linear(hidden, 1).
    """

    def __init__(self, dim: int, hidden: int = 512, dropout: float = 0.2):
        super().__init__()
        layers = [
            nn.LayerNorm(dim),
            nn.Dropout(dropout),
            nn.Linear(dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.net(x)
        # Drop the trailing size-1 axis: (batch, 1) -> (batch,)
        return logits.squeeze(1)
class RadioBundle:
    """Lazily loaded, cached RADIO backbone, image processor and two heads."""

    def __init__(self):
        self.loaded = False
        self.processor = None
        self.radio = None
        self.raw_head = None
        self.masked_head = None
        self.summary_dim = None
        self.device_str = None

    def _load_head(self, path: str, device: torch.device) -> nn.Module:
        """Build a RadioMLPHead from the checkpoint at ``path``."""
        ckpt = torch.load(path, map_location="cpu")
        # Use the checkpoint's own "dim" if present, else the probed width.
        dim = int(ckpt.get("dim", self.summary_dim))
        head = RadioMLPHead(dim=dim).to(device).eval()
        head.load_state_dict(ckpt["head_state"], strict=True)
        return head

    def load(self, device: torch.device):
        """Load everything onto ``device``; no-op if already loaded there.

        Raises:
            FileNotFoundError: if either head checkpoint file is missing.
        """
        dev_str = str(device)
        if self.loaded and self.device_str == dev_str:
            return

        # Check the local head files before the backbone is fetched.
        if not os.path.exists(RADIO_RAW_HEAD_PATH):
            raise FileNotFoundError(f"RADIO raw head not found: {RADIO_RAW_HEAD_PATH}")
        if not os.path.exists(RADIO_MASKED_HEAD_PATH):
            raise FileNotFoundError(f"RADIO masked head not found: {RADIO_MASKED_HEAD_PATH}")

        self.processor = CLIPImageProcessor.from_pretrained(RADIO_HF_REPO, revision=RADIO_REVISION)

        # Half precision on CUDA, full precision otherwise.
        dtype = torch.float16 if device.type == "cuda" else torch.float32
        backbone = AutoModel.from_pretrained(
            RADIO_HF_REPO,
            revision=RADIO_REVISION,
            trust_remote_code=True,
            dtype=dtype,
        )
        self.radio = backbone.eval().to(device)

        # Probe the summary feature width with a dummy forward pass.
        with torch.no_grad():
            probe = torch.zeros((1, 3, RADIO_IMG_SIZE, RADIO_IMG_SIZE), device=device, dtype=dtype)
            summary, _ = self.radio(probe)
            self.summary_dim = int(summary.shape[-1])

        self.raw_head = self._load_head(RADIO_RAW_HEAD_PATH, device)
        self.masked_head = self._load_head(RADIO_MASKED_HEAD_PATH, device)
        self.device_str = dev_str
        self.loaded = True


# Module-level cache shared by all requests.
RADIO_BUNDLE = RadioBundle()
def radio_heatmap_from_spatial(spatial_tokens: torch.Tensor, in_h: int, in_w: int, patch_size: int = 16) -> np.ndarray:
    """Build an ``(in_h, in_w)`` heatmap in [0, 1] from patch-token energy.

    Tokens are arranged on the patch grid, the L2 norm of each token is
    min-max normalised, and the first batch item's map is bilinearly
    upsampled to the input size.
    """
    grid_h, grid_w = in_h // patch_size, in_w // patch_size
    grid = rearrange(spatial_tokens, "b (h w) d -> b d h w", h=grid_h, w=grid_w)

    # Per-token L2 norm, clamped before the sqrt; first batch item only.
    squared = (grid ** 2).sum(dim=1)
    energy = torch.sqrt(torch.clamp(squared, min=1e-8))[0]
    e_min, e_max = energy.min(), energy.max()
    energy = (energy - e_min) / (e_max - e_min + 1e-8)

    coarse = energy.detach().float().cpu().numpy().astype(np.float32)
    # Upsample via an 8-bit PIL image.
    upsampled = Image.fromarray((coarse * 255).astype(np.uint8)).resize(
        (in_w, in_h), resample=Image.BILINEAR
    )
    return np.array(upsampled, dtype=np.float32) / 255.0
def radio_overlay_heatmap(rgb_u8: np.ndarray, heatmap01: np.ndarray, alpha: float = 0.35) -> np.ndarray:
    """Blend a [0, 1] heatmap into channel 0 of a ``uint8`` RGB image.

    Only the first channel is changed; the input array is not modified.
    """
    blended = rgb_u8.astype(np.float32) / 255.0
    heat = np.clip(heatmap01, 0, 1).astype(np.float32)
    first_channel = blended[..., 0] * (1 - alpha) + heat * alpha
    blended[..., 0] = np.clip(first_channel, 0, 1)
    return (blended * 255).astype(np.uint8)
def _radio_score_and_overlay(
    gray_u8: np.ndarray,
    head: nn.Module,
    device: torch.device,
    dtype: torch.dtype,
) -> Tuple[float, np.ndarray]:
    """Run RADIO plus one head on a grayscale image; return (prob, overlay)."""
    rgb = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB)
    px = RADIO_BUNDLE.processor(
        images=Image.fromarray(rgb),
        return_tensors="pt",
        do_resize=True,
        size={"shortest_edge": RADIO_IMG_SIZE},
        do_center_crop=True,
    ).pixel_values.to(device).to(dtype)

    # Pure inference: no autograd graph is needed for the score or heatmap.
    with torch.no_grad():
        summary, spatial = RADIO_BUNDLE.radio(px)
        logit = head(summary)
    prob = float(torch.sigmoid(logit)[0].item())

    in_h, in_w = px.shape[-2], px.shape[-1]
    heatmap = radio_heatmap_from_spatial(spatial, in_h, in_w, RADIO_PATCH_SIZE)
    # NOTE(review): the backdrop is a plain resize, while the processor is
    # asked for a shortest-edge resize + center crop; for non-square images
    # the heatmap may not line up exactly with the backdrop.
    overlay = radio_overlay_heatmap(cv2.resize(rgb, (in_w, in_h)), heatmap, alpha=0.35)
    return prob, overlay


def radio_predict_from_arrays(
    gray_vis_u8: np.ndarray,
    lung_mask_u8: np.ndarray,
    coverage: float,
    device: torch.device,
    gate_threshold: float
) -> Dict[str, Any]:
    """Score an image with RADIO on the raw and (optionally) lung-masked view.

    The raw image is always scored. The masked image is scored only when a
    mask is given and ``coverage`` reaches both ``RADIO_MASKED_MIN_COV`` and
    ``gate_threshold``; when it runs, its probability is the primary one.

    Args:
        gray_vis_u8: Grayscale image (uint8).
        lung_mask_u8: Mask multiplied element-wise with the image, or None.
        coverage: Lung coverage fraction used for gating the masked pass.
        device: Device on which the RADIO bundle is loaded and run.
        gate_threshold: Extra minimum coverage for the masked pass.

    Returns:
        Dict with ``prob_raw``, ``prob_primary``, ``pred``, ``band``
        (GREEN/YELLOW/RED), ``raw_overlay``, ``masked_prob``,
        ``masked_overlay``, ``masked_ran`` and ``gate_threshold``. The
        masked entries are None when the masked pass was skipped.

    Raises:
        FileNotFoundError: propagated from the bundle if a head file is missing.
    """
    RADIO_BUNDLE.load(device=device)
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    prob_raw, raw_overlay = _radio_score_and_overlay(
        gray_vis_u8, RADIO_BUNDLE.raw_head, device, dtype
    )

    masked_prob = None
    masked_overlay = None
    masked_ran = bool(
        lung_mask_u8 is not None
        and coverage >= RADIO_MASKED_MIN_COV
        and coverage >= gate_threshold
    )
    if masked_ran:
        masked_u8 = (gray_vis_u8 * lung_mask_u8).astype(np.uint8)
        masked_prob, masked_overlay = _radio_score_and_overlay(
            masked_u8, RADIO_BUNDLE.masked_head, device, dtype
        )

    prob_primary = masked_prob if masked_prob is not None else prob_raw
    if prob_primary >= RADIO_THR_RED:
        band = "RED"
        pred = "HIGH TB-LIKE PATTERN SCORE (RADIO)"
    elif prob_primary >= RADIO_THR_SCREEN:
        band = "YELLOW"
        pred = "SCREEN-POSITIVE RANGE (RADIO)"
    else:
        band = "GREEN"
        pred = "LOW TB-LIKE PATTERN SCORE (RADIO)"

    return {
        "prob_raw": prob_raw,
        "prob_primary": prob_primary,
        "pred": pred,
        "band": band,
        "raw_overlay": raw_overlay,
        "masked_prob": masked_prob,
        "masked_overlay": masked_overlay,
        "masked_ran": masked_ran,
        "gate_threshold": float(gate_threshold),
    }
| # ============================================================ | |
| # TB CORE ANALYSIS | |
| # ============================================================ | |
def _tb_failsafe_result(
    *,
    band_text: str,
    lead_warnings: List[str],
    coverage: float,
    phone_mode: bool,
    q_score: float,
    q_warn: List[str],
    gray: np.ndarray,
    gray_vis: np.ndarray,
    mask_full: np.ndarray,
    img_size: int,
) -> Dict[str, Any]:
    """Build the INDETERMINATE/YELLOW result returned when TB scoring is disabled.

    Both fail-safe exits of ``analyze_one_image`` share this shape: no
    probability/logit, no masked/processed image, and the plain (un-heatmapped)
    input resized to ``img_size`` used for both overlay slots.
    """
    overlay_rgb = cv2.cvtColor(
        cv2.resize(gray_vis, (img_size, img_size)), cv2.COLOR_GRAY2RGB
    )
    notes = (
        list(lead_warnings)
        + [f"Lung coverage: {coverage*100:.1f}%"]
        + (["Phone/WhatsApp mode enabled; artifacts possible."] if phone_mode else [])
        + q_warn
    )
    return {
        "prob": None,
        "logit": None,
        "pred": "INDETERMINATE",
        "band": "YELLOW",
        "band_text": band_text,
        "quality_score": float(q_score),
        "diffuse_risk": False,
        "warnings": notes,
        "lung_coverage": coverage,
        "orig_gray": gray,
        "vis_gray": gray_vis,
        "masked_gray": None,
        "proc_gray": None,
        "lung_mask": mask_full,
        "mask_overlay": make_mask_overlay(gray_vis, mask_full),
        "overlay": overlay_rgb,
        "overlay_clean": overlay_rgb,
    }


def analyze_one_image(
    img_bgr: np.ndarray,
    tb_weights: str,
    lung_weights: str,
    backbone: str,
    threshold: float,
    phone_mode: bool,
    img_size: int = 224,
    fail_cov: float = FAIL_COV,
    warn_cov: float = WARN_COV,
) -> Dict[str, Any]:
    """Run lung segmentation, TBNet scoring and Grad-CAM on a single image.

    Pipeline: load models via ``BUNDLE`` -> grayscale -> optional phone
    preprocessing -> lung U-Net mask (binarised at 0.5, top-2 components,
    morphological close, hole fill) -> fail-safe checks -> masked TBNet
    inference with Grad-CAM -> band/label selection -> annotated overlay.

    Args:
        img_bgr: Input image; 2-D arrays are used as grayscale, anything else
            is converted with ``cv2.COLOR_BGR2GRAY``.
        tb_weights: Path forwarded to ``BUNDLE.load``.
        lung_weights: Path forwarded to ``BUNDLE.load``.
        backbone: Backbone name forwarded to ``BUNDLE.load``.
        threshold: Accepted for interface compatibility; this function does
            not read it.
        phone_mode: When true, ``phone_preprocess`` is applied before
            segmentation and a phone-mode warning is added.
        img_size: Side length of the square TBNet input and overlays.
        fail_cov: Lung-mask coverage (fraction of mask pixels set) below which
            scoring is disabled.
        warn_cov: Coverage below which a partial-segmentation warning is added
            and the RED band is not allowed.

    Returns:
        Dict with keys ``prob``, ``logit``, ``pred``, ``band``, ``band_text``,
        ``quality_score``, ``diffuse_risk``, ``warnings``, ``lung_coverage``,
        ``orig_gray``, ``vis_gray``, ``masked_gray``, ``proc_gray``,
        ``lung_mask``, ``mask_overlay``, ``overlay``, ``overlay_clean``.
        On fail-safe, ``prob``/``logit``/``masked_gray``/``proc_gray`` are None.
    """
    # NOTE(review): the extent of the torch.no_grad() scope below was
    # reconstructed from whitespace-stripped source -- confirm against the
    # original file.
    BUNDLE.load(tb_weights, lung_weights, backbone)
    device = BUNDLE.device

    gray = img_bgr if img_bgr.ndim == 2 else cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY)
    q_score, q_warn = phone_quality_report(gray)

    gray_vis = phone_preprocess(gray) if phone_mode else gray
    if gray_vis.dtype != np.uint8:
        gray_vis = np.clip(gray_vis, 0, 255).astype(np.uint8)

    # Segmentation only needs a forward pass, so gradients are disabled here.
    with torch.no_grad():
        x_lung = preprocess_for_lung_unet(gray_vis).to(device)
        mask_logits = BUNDLE.lung(x_lung)
        mask256 = torch.sigmoid(mask_logits)[0, 0].cpu().numpy()

    # Mask post-processing: binarise, keep the two largest components,
    # close small gaps, then fill interior holes.
    mask256_bin = (mask256 > 0.5).astype(np.uint8)
    mask256_bin = keep_top_k_components(mask256_bin, k=2)
    k = max(3, int(0.02 * 256))  # kernel side is 2% of 256 (presumably the mask side), min 3
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
    mask256_bin = cv2.morphologyEx(mask256_bin, cv2.MORPH_CLOSE, kernel, iterations=1)
    mask256_bin = fill_holes(mask256_bin)

    coverage = float(mask256_bin.mean())
    mask_full = cv2.resize(
        mask256_bin,
        (gray_vis.shape[1], gray_vis.shape[0]),
        interpolation=cv2.INTER_NEAREST,
    )

    # Fail-safe 1: mask too small to trust any score.
    if coverage < fail_cov:
        return _tb_failsafe_result(
            band_text="Lung segmentation failed. TB scoring disabled (fail-safe).",
            # Report the threshold actually in force instead of a fixed "10%".
            lead_warnings=[
                f"Lung segmentation failed (<{fail_cov * 100:.0f}% lung area)."
            ],
            coverage=coverage,
            phone_mode=phone_mode,
            q_score=q_score,
            q_warn=q_warn,
            gray=gray,
            gray_vis=gray_vis,
            mask_full=mask_full,
            img_size=img_size,
        )

    # Fail-safe 2: mask shape looks implausible.
    sanity = mask_sanity_warnings(mask_full.astype(np.uint8))
    if FAILSAFE_ON_BAD_MASK and sanity:
        return _tb_failsafe_result(
            band_text=(
                "Non-standard/cropped view or unreliable lung segmentation. "
                "TB scoring disabled (fail-safe)."
            ),
            lead_warnings=sanity,
            coverage=coverage,
            phone_mode=phone_mode,
            q_score=q_score,
            q_warn=q_warn,
            gray=gray,
            gray_vis=gray_vis,
            mask_full=mask_full,
            img_size=img_size,
        )

    # Build the masked TBNet input.
    masked = (gray_vis * mask_full).astype(np.uint8)
    masked_f01 = tb_training_preprocess(masked)
    masked_u8 = (masked_f01 * 255).astype(np.uint8)
    masked_u8_rs = cv2.resize(masked_u8, (img_size, img_size), interpolation=cv2.INTER_AREA)
    rgb = cv2.cvtColor(masked_u8_rs, cv2.COLOR_GRAY2RGB)
    x = BUNDLE.tfm(rgb).unsqueeze(0).to(device)

    # Runs outside no_grad(); Grad-CAM presumably needs gradients.
    cam, prob_tb, logit = BUNDLE.cammer.generate(x)
    cam_u8 = (np.clip(cam, 0, 1) * 255).astype(np.uint8)
    cam_u8 = cv2.resize(cam_u8, (img_size, img_size), interpolation=cv2.INTER_CUBIC)
    cam_up = cam_u8.astype(np.float32) / 255.0

    diffuse = detect_diffuse_risk(prob_tb, cam_up, q_score)
    band_base, _ = confidence_band(prob_tb, q_score, diffuse)

    # Same quality floor gates both the RED band and the quality warning.
    min_quality = 55
    allow_red = (
        prob_tb >= 0.70
        and q_score >= min_quality
        and not diffuse
        and coverage >= warn_cov
    )
    band = "RED" if allow_red else band_base
    pred = REPORT_LABELS[band]["title"]
    band_text = REPORT_LABELS[band]["summary"]

    heat = cv2.applyColorMap((cam_up * 255).astype(np.uint8), cv2.COLORMAP_JET)
    overlay_clean = cv2.addWeighted(rgb, 0.65, heat, 0.35, 0)
    overlay_annotated = overlay_clean.copy()

    # Each caption is drawn twice: thick white first, thin black on top,
    # so the text stays legible over any heatmap colour.
    text1 = f"{band}: {pred}"
    text2 = f"TB p={prob_tb:.3f} | Quality={q_score:.0f}/100 | LungCov={coverage*100:.1f}%"
    cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (255, 255, 255), 2)
    cv2.putText(overlay_annotated, text1, (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.52, (0, 0, 0), 1)
    cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (255, 255, 255), 2)
    cv2.putText(overlay_annotated, text2, (8, 42), cv2.FONT_HERSHEY_SIMPLEX, 0.50, (0, 0, 0), 1)

    notes: List[str] = []
    if phone_mode:
        notes.append("Phone/WhatsApp mode enabled; artifacts possible.")
    if q_score < min_quality:
        notes.append("Suboptimal image quality limits AI reliability.")
    if coverage < warn_cov:
        notes.append(f"Partial lung segmentation ({coverage*100:.1f}% coverage).")
    if diffuse:
        notes.append("Diffuse, non-focal AI attention pattern; TB-specific features not identified.")
    notes.extend(q_warn)

    return {
        "prob": float(prob_tb),
        "logit": float(logit),
        "pred": pred,
        "band": band,
        "band_text": band_text,
        "quality_score": float(q_score),
        "diffuse_risk": bool(diffuse),
        "warnings": notes,
        "lung_coverage": coverage,
        "orig_gray": gray,
        "vis_gray": gray_vis,
        "masked_gray": masked,
        "proc_gray": masked_u8_rs,
        "lung_mask": mask_full,
        "mask_overlay": make_mask_overlay(gray_vis, mask_full),
        "overlay": overlay_annotated,
        "overlay_clean": overlay_clean,
    }
| # ============================================================ | |
| # GRADIO CALLBACK | |
| # ============================================================ | |
def _state_badge_html(state) -> str:
    """Return the rounded pill ``<span>`` used to display a model state."""
    return f"""
<span style="padding:4px 10px; border-radius:999px; background:{badge_color_for_state(state)}; font-weight:800;">
{pretty_state(state)}
</span>"""


def _consensus_next_step(consensus_label: str) -> str:
    """Return the HTML 'next step' advice line for a consensus label."""
    if consensus_label == "DISAGREE":
        return "✅ Next step: Treat as <b>indeterminate</b> → radiologist review + microbiology if clinically suspected."
    if consensus_label in ("AGREE: TB+", "AGREE: SCREEN+"):
        return "✅ Next step: Prompt clinician/radiologist review; consider microbiological confirmation if clinically suspected."
    if consensus_label == "AGREE: LOW":
        return "✅ Next step: If symptoms/risk factors exist, still correlate clinically and consider further testing."
    return "✅ Next step: Correlate clinically; radiologist review recommended if uncertainty or symptoms present."


def run_analysis(
    files: List[gr.File],
    tb_weights: str,
    lung_weights: str,
    backbone: str,
    threshold: float,
    phone_mode: bool,
    use_radio: bool,
    radio_gate: float,
):
    """Gradio callback: analyse every uploaded image and build the UI outputs.

    For each readable file this runs ``analyze_one_image``, optionally
    ``radio_predict_from_arrays``, then ``build_consensus``, and renders three
    HTML cards (TBNet, RADIO, consensus), gallery images and a collapsible
    Markdown report.

    Args:
        files: Uploaded files; each item's ``.name`` is used as the path when
            present, otherwise ``str(item)``.
        tb_weights: Path to the TBNet weights; must exist on disk.
        lung_weights: Path to the lung U-Net weights; must exist on disk.
        backbone: Backbone name forwarded to ``analyze_one_image``.
        threshold: Forwarded to ``analyze_one_image``.
        phone_mode: Forwarded to ``analyze_one_image``; also selects the
            gallery caption for the processed input.
        use_radio: Whether to run the RADIO model.
        radio_gate: Gate threshold forwarded to ``radio_predict_from_arrays``.

    Returns:
        A 5-tuple ``(summary_html, gallery_items, details_markdown,
        CLINICAL_DISCLAIMER, status_text)`` matching the outputs wired up in
        ``build_ui``.
    """
    if not files:
        return "Please upload at least one image.", [], "", CLINICAL_DISCLAIMER, "Please upload at least one image."
    if not os.path.exists(tb_weights):
        msg = f"TB weights not found: {tb_weights}"
        return msg, [], "", CLINICAL_DISCLAIMER, msg
    if not os.path.exists(lung_weights):
        msg = f"Lung U-Net weights not found: {lung_weights}"
        return msg, [], "", CLINICAL_DISCLAIMER, msg

    summary_md: List[str] = []
    gallery_items = []
    details_md: List[str] = []

    # Top banner
    summary_md.append(f"""
<div style="border:1px solid rgba(255,255,255,0.10); border-radius:14px; padding:12px; margin:10px 0;">
<div style="font-size:18px; font-weight:900;">Results</div>
<div style="margin-top:6px; opacity:0.92;">
Each image shows: <b>{MODEL_NAME_TBNET}</b> result, <b>{MODEL_NAME_RADIO}</b> result (optional), then a <b>final consensus</b>.
<br/>
<b>Key:</b> ✅ LOW | ⚠️ INDETERMINATE | ⚠️ SCREEN+ | 🚩 TB+
</div>
</div>
""")

    for f in files:
        path = f.name if hasattr(f, "name") else str(f)
        name = os.path.basename(path)
        img = cv2.imread(path, cv2.IMREAD_COLOR)
        if img is None:
            summary_md.append(f"""
<div style="border:1px solid rgba(255,255,255,0.12); border-radius:12px; padding:12px; margin:10px 0;">
<div style="font-weight:800;">{html_escape(name)}</div>
<div style="margin-top:6px;">⚠️ Could not read this image. Please re-export it as PNG/JPG.</div>
</div>
""")
            continue

        out = analyze_one_image(
            img_bgr=img,
            tb_weights=tb_weights,
            lung_weights=lung_weights,
            backbone=backbone,
            threshold=threshold,
            phone_mode=phone_mode,
            img_size=224,
        )

        # RADIO (optional)
        radio_text_long = f"{MODEL_NAME_RADIO} disabled."
        radio_raw_overlay = None
        radio_masked_overlay = None
        radio_raw_val: Optional[float] = None
        radio_masked_val: Optional[float] = None
        radio_primary_val: Optional[float] = None
        radio_band: Optional[str] = None
        radio_result_short = "Disabled"
        radio_masked_ran = False

        if use_radio and out["prob"] is None:
            # RADIO is enabled but not run because TBNet hit its fail-safe;
            # say so instead of reporting the model as "Disabled".
            radio_text_long = (
                f"{MODEL_NAME_RADIO} skipped: TBNet scoring was disabled by the fail-safe for this image."
            )
            radio_result_short = "Skipped (TBNet fail-safe)"
        elif use_radio:
            try:
                r = radio_predict_from_arrays(
                    gray_vis_u8=out["vis_gray"],
                    lung_mask_u8=out["lung_mask"].astype(np.uint8),
                    coverage=float(out["lung_coverage"]),
                    device=BUNDLE.device,
                    gate_threshold=float(radio_gate),
                )
                radio_raw_val = float(r["prob_raw"])
                radio_primary_val = float(r["prob_primary"])
                radio_masked_val = None if r["masked_prob"] is None else float(r["masked_prob"])
                radio_band = str(r["band"])
                radio_result_short = str(r["pred"])
                radio_masked_ran = bool(r["masked_ran"])
                radio_text_long = (
                    f"**{MODEL_NAME_RADIO}:** {r['pred']} \n"
                    f"- PRIMARY={radio_primary_val:.4f} \n"
                    f"- RAW={radio_raw_val:.4f} \n"
                    + (f"- MASKED={radio_masked_val:.4f} \n" if radio_masked_val is not None else "- MASKED=Not run \n")
                    + (f"- Band={radio_band} \n" if radio_band else "")
                    + (f"- Masked gate={float(radio_gate):.2f} | LungCov={float(out['lung_coverage']):.2f} | Masked ran={radio_masked_ran}\n")
                )
                radio_raw_overlay = r["raw_overlay"]
                radio_masked_overlay = r["masked_overlay"]
            except Exception as e:
                # Best-effort: a RADIO failure must not abort the TBNet report.
                # The message is escaped because it is embedded in rendered
                # Markdown/HTML.
                radio_text_long = f"{MODEL_NAME_RADIO} error: {type(e).__name__}: {html_escape(str(e))}"
                radio_result_short = "Error"
                radio_raw_val = None
                radio_masked_val = None
                radio_primary_val = None
                radio_band = None
                radio_masked_ran = False

        # Consensus
        consensus_label, consensus_detail, tb_state, radio_state = build_consensus(
            tb_prob=out["prob"],
            tb_band=out["band"],
            radio_raw=radio_raw_val,
            radio_masked=radio_masked_val,
            radio_band=radio_band,
        )

        tb_prob_line = "N/A (fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
        tb_label = out.get("pred", "INDETERMINATE")
        q = float(out.get("quality_score", 0.0))
        cov = float(out.get("lung_coverage", 0.0))
        attention = "Diffuse / non-focal" if out.get("diffuse_risk", False) else "Focal / localized"
        warns = out.get("warnings", [])
        top_warns = warns[:3] if warns else []
        top_warn_line = " • ".join([html_escape(w) for w in top_warns]) if top_warns else "None"

        radio_primary_line = "N/A" if radio_primary_val is None else f"{radio_primary_val:.4f}"
        radio_raw_line = "N/A" if radio_raw_val is None else f"{radio_raw_val:.4f}"
        radio_masked_line = "Not run" if radio_masked_val is None else f"{radio_masked_val:.4f}"

        next_step = _consensus_next_step(consensus_label)
        state_badge_tb = _state_badge_html(tb_state)
        state_badge_radio = _state_badge_html(radio_state)
        tb_recommendation = recommendation_for_band(out.get("band"))

        tb_card = f"""
<div style="border:1px solid rgba(255,255,255,0.12); border-radius:14px; padding:12px; margin:10px 0;">
<div style="font-size:16px; font-weight:900; margin-bottom:6px;">{html_escape(name)} — {MODEL_NAME_TBNET}</div>
<div style="margin-bottom:6px;"><b>State:</b> {state_badge_tb}</div>
<div><b>Result:</b> {html_escape(tb_label)} <span style="opacity:0.9;">(p={tb_prob_line})</span></div>
<div style="margin-top:6px; opacity:0.92;">
<b>Reliability:</b> Quality <b>{q:.0f}/100</b> | Lung mask <b>{cov*100:.1f}%</b> | Attention <b>{attention}</b>
</div>
<div style="margin-top:6px; opacity:0.92;"><b>Top notes:</b> {top_warn_line}</div>
<div style="margin-top:10px; padding:10px 12px; border-left:6px solid rgba(96,165,250,0.9); background: rgba(96,165,250,0.10); border-radius:12px;">
{tb_recommendation}
</div>
</div>
"""

        if not use_radio:
            radio_card = f"""
<div style="border:1px solid rgba(255,255,255,0.10); border-radius:14px; padding:12px; margin:10px 0; opacity:0.9;">
<div style="font-size:16px; font-weight:900; margin-bottom:6px;">{html_escape(name)} — {MODEL_NAME_RADIO}</div>
<div>RADIO is disabled. Enable it to view its independent output and heatmaps.</div>
</div>
"""
        else:
            gate_info = f"Masked gate={float(radio_gate):.2f} | LungCov={cov:.2f} | Masked ran={radio_masked_ran}"
            radio_card = f"""
<div style="border:1px solid rgba(255,255,255,0.12); border-radius:14px; padding:12px; margin:10px 0;">
<div style="font-size:16px; font-weight:900; margin-bottom:6px;">{html_escape(name)} — {MODEL_NAME_RADIO}</div>
<div style="margin-bottom:6px;"><b>State:</b> {state_badge_radio}</div>
<div><b>Result:</b> {html_escape(radio_result_short)}</div>
<div style="margin-top:6px; opacity:0.92;">
<b>Scores:</b> PRIMARY <b>{radio_primary_line}</b> | RAW <b>{radio_raw_line}</b> | MASKED <b>{radio_masked_line}</b>
</div>
<div style="margin-top:6px; opacity:0.85;">{html_escape(gate_info)}</div>
</div>
"""

        consensus_card = f"""
<div style="border:1px solid rgba(255,255,255,0.14); border-radius:16px; padding:12px; margin:10px 0;">
<div style="font-size:16px; font-weight:950; margin-bottom:6px;">{html_escape(name)} — Final consensus</div>
<div style="margin-bottom:6px;">
<b>Comparison:</b> TBNet {pretty_state(tb_state)} vs RADIO {pretty_state(radio_state)}
</div>
<div style="margin-bottom:6px;"><b>Consensus:</b> {html_escape(consensus_label)}</div>
<div style="opacity:0.9; margin-bottom:10px;">{html_escape(consensus_detail)}</div>
<div style="padding:10px 12px; border-left:6px solid rgba(245,158,11,0.95); background: rgba(245,158,11,0.12); border-radius:12px;">
{next_step}
</div>
</div>
"""

        summary_md.append(tb_card)
        summary_md.append(radio_card)
        summary_md.append(consensus_card)

        # Gallery
        orig_rgb = cv2.cvtColor(cv2.resize(out["orig_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
        vis_rgb = cv2.cvtColor(cv2.resize(out["vis_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
        mask_overlay = cv2.resize(out["mask_overlay"], (512, 512))
        overlay_big = cv2.resize(out["overlay"], (512, 512))

        gallery_items.append((orig_rgb, f"{name} • ORIGINAL"))
        gallery_items.append((vis_rgb, f"{name} • PHONE-PROC" if phone_mode else f"{name} • INPUT"))
        gallery_items.append((mask_overlay, f"{name} • Lung mask overlay"))
        if out["proc_gray"] is not None:
            proc_rgb = cv2.cvtColor(cv2.resize(out["proc_gray"], (512, 512)), cv2.COLOR_GRAY2RGB)
            gallery_items.append((proc_rgb, f"{name} • Masked model input (224x224)"))
        gallery_items.append((overlay_big, f"{name} • TBNet Grad-CAM overlay"))
        if radio_raw_overlay is not None:
            gallery_items.append((cv2.resize(radio_raw_overlay, (512, 512)), f"{name} • RADIO RAW heatmap"))
        if radio_masked_overlay is not None:
            gallery_items.append((cv2.resize(radio_masked_overlay, (512, 512)), f"{name} • RADIO MASKED heatmap"))

        # Details (collapsible per image).
        # Blank lines separate the raw HTML tags from the Markdown content:
        # without them the whole section is one HTML block and the bold text,
        # lists and the trailing rule are not parsed as Markdown.
        warn_txt = "\n".join([f"- {html_escape(w)}" for w in out["warnings"]]) if out["warnings"] else "- None"
        tb_band_line = out.get("band", "YELLOW")
        details_md.append(
            f"""
<details>
<summary><b>{html_escape(name)}</b> — detailed report</summary>

**TBNet**

- Result: **{html_escape(tb_label)}**
- Probability: {tb_prob_line}
- Band: {tb_band_line}
- Quality: {q:.0f}/100
- Lung mask coverage: {cov*100:.1f}%
- Attention: {attention}

**Consensus**

- TBNet state: {pretty_state(tb_state)}
- RADIO state: {pretty_state(radio_state)}
- Consensus label: **{html_escape(consensus_label)}**
- Detail: {html_escape(consensus_detail)}

**Warnings**

{warn_txt}

**RADIO (full)**

{radio_text_long}

</details>

---
"""
        )

    return "\n".join(summary_md), gallery_items, "\n".join(details_md), CLINICAL_DISCLAIMER, "Done."
| # ============================================================ | |
| # UI (HF Spaces Welcome Screen + Main App) | |
| # ============================================================ | |
def build_ui():
    """Assemble the Gradio Blocks app: a welcome screen plus the main analysis UI.

    The welcome column is visible at start and the main column hidden; the
    "Continue" and "Back" buttons swap their visibility. "Run Analysis" calls
    ``run_analysis`` with the settings widgets and fills the summary, gallery,
    details, disclaimer and status outputs.

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    # NOTE(review): container nesting was reconstructed from a
    # whitespace-stripped copy of this file -- confirm the placement of the
    # gallery and of the event bindings against the original layout.
    css = """
.title {font-size: 28px; font-weight: 900; margin-bottom: 6px;}
.subtitle {font-size: 14px; opacity: 0.88; margin-bottom: 14px;}
.warnbox {border-left: 6px solid #f59e0b; padding: 10px 12px; background: rgba(245,158,11,0.08); border-radius: 10px;}
.legend {border-left: 6px solid rgba(148,163,184,0.7); padding: 10px 12px; background: rgba(148,163,184,0.08); border-radius: 10px;}
.card {border:1px solid rgba(255,255,255,0.12); border-radius:14px; padding:14px; margin:10px 0;}
"""

    # Static snippets, built once up front so the layout code below reads
    # as structure only.
    pipeline_blurb = (
        f"<div class='subtitle'>Auto lung mask → <b>{MODEL_NAME_TBNET}</b> + Grad-CAM • "
        f"Optional <b>{MODEL_NAME_RADIO}</b> (C-RADIOv4 + heads) • Clear per-model results + consensus</div>"
    )
    clinical_note = (
        "<div class='warnbox'><b>Clinical disclaimer:</b> Decision support only (not diagnostic). "
        "If TB is clinically suspected, pursue microbiology / CT as appropriate regardless of AI output.</div>"
    )
    phone_hint = (
        "<div class='subtitle'>Enable for WhatsApp images, phone photos, or screenshots. "
        "Leave off for clean digital exports.</div>"
    )
    failsafe_note = (
        "<div class='warnbox'><b>Fail-safe:</b> If lung segmentation is too small or looks unreliable, "
        f"{MODEL_NAME_TBNET} scoring is disabled to avoid unsafe outputs.</div>"
    )
    device_note = f"<div class='subtitle'>Device: <b>{DEVICE}</b> (FORCE_CPU={FORCE_CPU})</div>"
    legend_note = """
<div class='legend'><b>Gallery legend:</b><br/>
1) ORIGINAL • 2) INPUT / PHONE-PROC • 3) Lung mask overlay •
4) Masked model input • 5) TBNet Grad-CAM • 6) RADIO heatmaps</div>
"""

    def _enter_main_app():
        # Output order is (welcome_screen, main_app).
        return gr.update(visible=False), gr.update(visible=True)

    def _return_to_welcome():
        # Output order is (welcome_screen, main_app).
        return gr.update(visible=True), gr.update(visible=False)

    with gr.Blocks(title="TB X-ray Assistant (TBNet + RADIO)", css=css) as demo:
        # ---------------------------
        # Welcome screen (shown first)
        # ---------------------------
        with gr.Column(visible=True) as welcome_screen:
            gr.Markdown('<div class="title">Welcome — TB X-ray Assistant (HF Spaces)</div>')
            gr.HTML(WELCOME_HTML)
            continue_btn = gr.Button("Continue →", variant="primary")

        # ---------------------------
        # Main app UI (hidden initially)
        # ---------------------------
        with gr.Column(visible=False) as main_app:
            gr.Markdown('<div class="title">TB X-ray Assistant (Auto Lung Mask • Research Use)</div>')
            gr.Markdown(pipeline_blurb)
            gr.Markdown(clinical_note)

            with gr.Row():
                # Left column: model settings.
                with gr.Column(scale=1):
                    gr.Markdown("#### Model settings")
                    tb_weights = gr.Textbox(label="TBNet weights (.pt)", value=DEFAULT_TB_WEIGHTS)
                    lung_weights = gr.Textbox(label="Lung U-Net weights (.pt)", value=DEFAULT_LUNG_WEIGHTS)
                    backbone = gr.Dropdown(
                        choices=["efficientnet_b0"],
                        value="efficientnet_b0",
                        label="TBNet backbone",
                    )
                    threshold = gr.Slider(
                        0.01,
                        0.99,
                        value=TBNET_SCREEN_THR,
                        step=0.01,
                        label=f"Reference threshold (TBNet screen+) = {TBNET_SCREEN_THR:.2f}",
                    )
                    phone_mode = gr.Checkbox(
                        value=False,
                        label="Phone/WhatsApp Mode (safe: conditional crop + conditional CLAHE)",
                    )
                    gr.Markdown(phone_hint)
                    use_radio = gr.Checkbox(value=True, label=f"Enable {MODEL_NAME_RADIO}")
                    radio_gate = gr.Slider(
                        0.10,
                        0.40,
                        value=RADIO_GATE_DEFAULT,
                        step=0.01,
                        label="RADIO masked gate (run masked head if lung coverage ≥ gate)",
                    )
                    gr.Markdown(failsafe_note)
                    gr.Markdown(device_note)
                    back_btn = gr.Button("← Back to Welcome", variant="secondary")

                # Right column: upload, run, and per-image results.
                with gr.Column(scale=2):
                    gr.Markdown("#### Upload images")
                    files = gr.Files(
                        label="Upload one or multiple X-ray images",
                        file_types=[".png", ".jpg", ".jpeg", ".bmp"],
                    )
                    run_btn = gr.Button("Run Analysis", variant="primary")
                    status = gr.Textbox(label="Status", value="Ready.", interactive=False)
                    gr.Markdown(legend_note)
                    gr.Markdown("#### Summary (per image)")
                    summary = gr.Markdown("Upload images and click <b>Run Analysis</b>.")
                    gallery = gr.Gallery(label="Visual outputs", columns=3, height=560)

            with gr.Row():
                with gr.Column(scale=1):
                    disclaimer_box = gr.Markdown(CLINICAL_DISCLAIMER)
                with gr.Column(scale=2):
                    gr.Markdown("#### Detailed report (expand per image)")
                    details = gr.Markdown("")

        # Input order must match run_analysis's positional parameters;
        # output order must match its returned 5-tuple.
        analysis_inputs = [
            files,
            tb_weights,
            lung_weights,
            backbone,
            threshold,
            phone_mode,
            use_radio,
            radio_gate,
        ]
        analysis_outputs = [summary, gallery, details, disclaimer_box, status]
        run_btn.click(fn=run_analysis, inputs=analysis_inputs, outputs=analysis_outputs)

        # ---------------------------
        # Transitions
        # ---------------------------
        screens = [welcome_screen, main_app]
        continue_btn.click(fn=_enter_main_app, inputs=[], outputs=screens)
        back_btn.click(fn=_return_to_welcome, inputs=[], outputs=screens)

    return demo
# Script entry point: build the UI and start the Gradio server when this file
# is executed directly (not when it is imported).
if __name__ == "__main__":
    demo = build_ui()
    # Serve on all network interfaces ("0.0.0.0") at port 7860, with
    # show_error=True passed through to Gradio.
    demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)