Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -136,7 +136,7 @@ CLINICAL_GUIDANCE = (
|
|
| 136 |
|
| 137 |
|
| 138 |
# ============================================================
|
| 139 |
-
# UX HELPERS
|
| 140 |
# ============================================================
|
| 141 |
def pretty_state(s: str) -> str:
|
| 142 |
return {
|
|
@@ -147,10 +147,12 @@ def pretty_state(s: str) -> str:
|
|
| 147 |
"N/A": "⚠️ N/A",
|
| 148 |
}.get(s, f"⚠️ {s}")
|
| 149 |
|
|
|
|
| 150 |
def html_escape(s: str) -> str:
|
| 151 |
# minimal escaping for safety in HTML blocks
|
| 152 |
return (s or "").replace("&", "&").replace("<", "<").replace(">", ">")
|
| 153 |
|
|
|
|
| 154 |
def badge_color_for_state(state: str) -> str:
|
| 155 |
# soft visual cue; works in both dark/light
|
| 156 |
if state == "TB+":
|
|
@@ -181,6 +183,7 @@ class DoubleConv(nn.Module):
|
|
| 181 |
|
| 182 |
def forward(self, x): return self.net(x)
|
| 183 |
|
|
|
|
| 184 |
class LungUNet(nn.Module):
|
| 185 |
def __init__(self):
|
| 186 |
super().__init__()
|
|
@@ -224,10 +227,12 @@ class TBNet(nn.Module):
|
|
| 224 |
|
| 225 |
def forward(self, x): return self.fc(self.backbone(x)).view(-1)
|
| 226 |
|
|
|
|
| 227 |
def load_tb_weights(model: nn.Module, ckpt_path: str, device: torch.device):
|
| 228 |
sd = torch.load(ckpt_path, map_location=device)
|
| 229 |
model.load_state_dict(sd, strict=True)
|
| 230 |
|
|
|
|
| 231 |
class GradCAM:
|
| 232 |
def __init__(self, model: nn.Module, target_layer: nn.Module):
|
| 233 |
self.model = model
|
|
@@ -271,6 +276,7 @@ def preprocess_for_lung_unet(gray_u8: np.ndarray) -> torch.Tensor:
|
|
| 271 |
g = (g - lo) / (hi - lo + 1e-8)
|
| 272 |
return torch.from_numpy(g).unsqueeze(0).unsqueeze(0).float()
|
| 273 |
|
|
|
|
| 274 |
def tb_training_preprocess(gray_u8: np.ndarray) -> np.ndarray:
|
| 275 |
gray = gray_u8.astype(np.float32)
|
| 276 |
lo, hi = np.percentile(gray, (1, 99))
|
|
@@ -278,16 +284,19 @@ def tb_training_preprocess(gray_u8: np.ndarray) -> np.ndarray:
|
|
| 278 |
gray = (gray - lo) / (hi - lo + 1e-8)
|
| 279 |
return gray
|
| 280 |
|
|
|
|
| 281 |
def laplacian_sharpness(gray_u8: np.ndarray) -> float:
|
| 282 |
g = cv2.resize(gray_u8, (512, 512), interpolation=cv2.INTER_AREA)
|
| 283 |
g = cv2.GaussianBlur(g, (3, 3), 0)
|
| 284 |
return float(cv2.Laplacian(g, cv2.CV_64F).var())
|
| 285 |
|
|
|
|
| 286 |
def exposure_scores(gray_u8: np.ndarray) -> Tuple[float, float]:
|
| 287 |
lo = float((gray_u8 < 10).mean())
|
| 288 |
hi = float((gray_u8 > 245).mean())
|
| 289 |
return lo, hi
|
| 290 |
|
|
|
|
| 291 |
def border_fraction(gray_u8: np.ndarray) -> float:
|
| 292 |
h, w = gray_u8.shape
|
| 293 |
b = max(5, int(0.06 * min(h, w)))
|
|
@@ -299,6 +308,7 @@ def border_fraction(gray_u8: np.ndarray) -> float:
|
|
| 299 |
def frac_border(x): return float(((x < 15) | (x > 240)).mean())
|
| 300 |
return float(np.mean([frac_border(top), frac_border(bot), frac_border(left), frac_border(right)]))
|
| 301 |
|
|
|
|
| 302 |
def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
|
| 303 |
warnings: List[str] = []
|
| 304 |
h, w = gray_u8.shape
|
|
@@ -334,6 +344,7 @@ def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
|
|
| 334 |
|
| 335 |
return float(np.clip(score, 0, 100)), warnings
|
| 336 |
|
|
|
|
| 337 |
def auto_border_crop(gray_u8: np.ndarray) -> np.ndarray:
|
| 338 |
g = gray_u8.copy()
|
| 339 |
g_blur = cv2.GaussianBlur(g, (5, 5), 0)
|
|
@@ -357,10 +368,12 @@ def auto_border_crop(gray_u8: np.ndarray) -> np.ndarray:
|
|
| 357 |
x2 = min(W, x + w + pad); y2 = min(H, y + h + pad)
|
| 358 |
return gray_u8[y1:y2, x1:x2]
|
| 359 |
|
|
|
|
| 360 |
def apply_clahe(gray_u8: np.ndarray) -> np.ndarray:
|
| 361 |
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
| 362 |
return clahe.apply(gray_u8)
|
| 363 |
|
|
|
|
| 364 |
def phone_preprocess(gray_u8: np.ndarray) -> np.ndarray:
|
| 365 |
sharp = laplacian_sharpness(gray_u8)
|
| 366 |
lo_clip, _hi_clip = exposure_scores(gray_u8)
|
|
@@ -377,11 +390,13 @@ def phone_preprocess(gray_u8: np.ndarray) -> np.ndarray:
|
|
| 377 |
|
| 378 |
return g
|
| 379 |
|
|
|
|
| 380 |
def cam_entropy(cam: np.ndarray) -> float:
|
| 381 |
cam = cam.astype(np.float32)
|
| 382 |
cam = cam / (cam.sum() + 1e-8)
|
| 383 |
return float(-np.sum(cam * np.log(cam + 1e-8)))
|
| 384 |
|
|
|
|
| 385 |
def detect_diffuse_risk(prob_tb: float, cam_up: np.ndarray, quality_score: float) -> bool:
|
| 386 |
if quality_score < 55:
|
| 387 |
return False
|
|
@@ -390,6 +405,7 @@ def detect_diffuse_risk(prob_tb: float, cam_up: np.ndarray, quality_score: float
|
|
| 390 |
ent = cam_entropy(cam_up)
|
| 391 |
return (prob_tb < TBNET_SCREEN_THR) and (ent > 6.5)
|
| 392 |
|
|
|
|
| 393 |
def confidence_band(prob_tb: float, quality_score: float, diffuse: bool):
|
| 394 |
if prob_tb < 0.01 and quality_score >= 45:
|
| 395 |
return ("GREEN", "✅ Very low TB signal detected.")
|
|
@@ -401,21 +417,24 @@ def confidence_band(prob_tb: float, quality_score: float, diffuse: bool):
|
|
| 401 |
return ("YELLOW", "⚠️ Screening-positive range; review recommended.")
|
| 402 |
return ("GREEN", "✅ No strong TB signal detected.")
|
| 403 |
|
|
|
|
| 404 |
def make_mask_overlay(gray_u8: np.ndarray, mask_u8: np.ndarray) -> np.ndarray:
|
| 405 |
base = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB)
|
| 406 |
mask_color = cv2.applyColorMap((mask_u8 * 255).astype(np.uint8), cv2.COLORMAP_JET)
|
| 407 |
return cv2.addWeighted(base, 0.75, mask_color, 0.25, 0)
|
| 408 |
|
|
|
|
| 409 |
def fill_holes(binary_u8: np.ndarray) -> np.ndarray:
|
| 410 |
m = (binary_u8 * 255).astype(np.uint8)
|
| 411 |
h, w = m.shape
|
| 412 |
flood = m.copy()
|
| 413 |
-
mask = np.zeros((h+2, w+2), np.uint8)
|
| 414 |
cv2.floodFill(flood, mask, (0, 0), 255)
|
| 415 |
holes = cv2.bitwise_not(flood)
|
| 416 |
filled = cv2.bitwise_or(m, holes)
|
| 417 |
return (filled > 0).astype(np.uint8)
|
| 418 |
|
|
|
|
| 419 |
def keep_top_k_components(binary_u8: np.ndarray, k: int = 2) -> np.ndarray:
|
| 420 |
m = (binary_u8 > 0).astype(np.uint8)
|
| 421 |
n, labels = cv2.connectedComponents(m)
|
|
@@ -431,6 +450,7 @@ def keep_top_k_components(binary_u8: np.ndarray, k: int = 2) -> np.ndarray:
|
|
| 431 |
out[labels == i] = 1
|
| 432 |
return out
|
| 433 |
|
|
|
|
| 434 |
def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
|
| 435 |
m = (mask_full_u8 > 0).astype(np.uint8)
|
| 436 |
n, labels = cv2.connectedComponents(m)
|
|
@@ -460,6 +480,7 @@ def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
|
|
| 460 |
|
| 461 |
return warns
|
| 462 |
|
|
|
|
| 463 |
def recommendation_for_band(band: Optional[str]) -> str:
|
| 464 |
if band in (None, "YELLOW"):
|
| 465 |
return "✅ Recommendation: Radiologist/clinician review is recommended (**indeterminate**)."
|
|
@@ -482,6 +503,7 @@ def tbnet_state(tb_prob: float, tb_band: str) -> str:
|
|
| 482 |
return "SCREEN+"
|
| 483 |
return "LOW"
|
| 484 |
|
|
|
|
| 485 |
def radio_state_from_prob(radio_prob: float) -> str:
|
| 486 |
if radio_prob >= RADIO_THR_RED:
|
| 487 |
return "TB+"
|
|
@@ -489,6 +511,7 @@ def radio_state_from_prob(radio_prob: float) -> str:
|
|
| 489 |
return "SCREEN+"
|
| 490 |
return "LOW"
|
| 491 |
|
|
|
|
| 492 |
def build_consensus(
|
| 493 |
tb_prob: Optional[float],
|
| 494 |
tb_band: Optional[str],
|
|
@@ -622,6 +645,7 @@ class ModelBundle:
|
|
| 622 |
self.lung = lung
|
| 623 |
self.lung_path = lung_weights
|
| 624 |
|
|
|
|
| 625 |
BUNDLE = ModelBundle()
|
| 626 |
|
| 627 |
|
|
@@ -643,6 +667,7 @@ class RadioMLPHead(nn.Module):
|
|
| 643 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 644 |
return self.net(x).squeeze(1)
|
| 645 |
|
|
|
|
| 646 |
class RadioBundle:
|
| 647 |
def __init__(self):
|
| 648 |
self.loaded = False
|
|
@@ -691,8 +716,10 @@ class RadioBundle:
|
|
| 691 |
self.device_str = dev_str
|
| 692 |
self.loaded = True
|
| 693 |
|
|
|
|
| 694 |
RADIO_BUNDLE = RadioBundle()
|
| 695 |
|
|
|
|
| 696 |
def radio_heatmap_from_spatial(spatial_tokens: torch.Tensor, in_h: int, in_w: int, patch_size: int = 16) -> np.ndarray:
|
| 697 |
ht = in_h // patch_size
|
| 698 |
wt = in_w // patch_size
|
|
@@ -703,6 +730,7 @@ def radio_heatmap_from_spatial(spatial_tokens: torch.Tensor, in_h: int, in_w: in
|
|
| 703 |
hm_img = Image.fromarray((hm * 255).astype(np.uint8)).resize((in_w, in_h), resample=Image.BILINEAR)
|
| 704 |
return np.array(hm_img, dtype=np.float32) / 255.0
|
| 705 |
|
|
|
|
| 706 |
def radio_overlay_heatmap(rgb_u8: np.ndarray, heatmap01: np.ndarray, alpha: float = 0.35) -> np.ndarray:
|
| 707 |
img = rgb_u8.astype(np.float32) / 255.0
|
| 708 |
hm = np.clip(heatmap01, 0, 1).astype(np.float32)
|
|
@@ -710,6 +738,7 @@ def radio_overlay_heatmap(rgb_u8: np.ndarray, heatmap01: np.ndarray, alpha: floa
|
|
| 710 |
out[..., 0] = np.clip(out[..., 0] * (1 - alpha) + hm * alpha, 0, 1)
|
| 711 |
return (out * 255).astype(np.uint8)
|
| 712 |
|
|
|
|
| 713 |
@torch.inference_mode()
|
| 714 |
def radio_predict_from_arrays(gray_vis_u8: np.ndarray,
|
| 715 |
lung_mask_u8: np.ndarray,
|
|
@@ -1076,7 +1105,6 @@ def run_analysis(
|
|
| 1076 |
# ------------------------
|
| 1077 |
# Build descriptive cards (TBNet, RADIO, Consensus)
|
| 1078 |
# ------------------------
|
| 1079 |
-
# TBNet block values
|
| 1080 |
tb_prob_line = "N/A (fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
|
| 1081 |
tb_band = out.get("band", "YELLOW")
|
| 1082 |
tb_label = out.get("pred", "INDETERMINATE")
|
|
@@ -1089,12 +1117,10 @@ def run_analysis(
|
|
| 1089 |
|
| 1090 |
attention = "Diffuse / non-focal" if out.get("diffuse_risk", False) else "Focal / localized"
|
| 1091 |
|
| 1092 |
-
# RADIO block values
|
| 1093 |
radio_primary_line = "N/A" if radio_primary_val is None else f"{radio_primary_val:.4f}"
|
| 1094 |
radio_raw_line = "N/A" if radio_raw_val is None else f"{radio_raw_val:.4f}"
|
| 1095 |
radio_masked_line = "Not run" if radio_masked_val is None else f"{radio_masked_val:.4f}"
|
| 1096 |
|
| 1097 |
-
# Recommended next step (consensus-aware)
|
| 1098 |
if consensus_label == "DISAGREE":
|
| 1099 |
next_step = "✅ Next step: Treat as **indeterminate** → radiologist review + microbiology if clinically suspected."
|
| 1100 |
elif consensus_label in ("AGREE: TB+", "AGREE: SCREEN+"):
|
|
@@ -1104,7 +1130,6 @@ def run_analysis(
|
|
| 1104 |
else:
|
| 1105 |
next_step = "✅ Next step: Correlate clinically; radiologist review recommended if uncertainty or symptoms present."
|
| 1106 |
|
| 1107 |
-
# Card containers (HTML)
|
| 1108 |
state_badge_tb = f"""
|
| 1109 |
<span style="padding:4px 10px; border-radius:999px; background:{badge_color_for_state(tb_state)}; font-weight:800;">
|
| 1110 |
{pretty_state(tb_state)}
|
|
@@ -1196,7 +1221,6 @@ def run_analysis(
|
|
| 1196 |
# Detailed collapsible report (per image)
|
| 1197 |
# ------------------------
|
| 1198 |
warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
|
| 1199 |
-
rec_line = recommendation_for_band(out.get("band"))
|
| 1200 |
|
| 1201 |
details_md.append(
|
| 1202 |
f"""
|
|
@@ -1243,7 +1267,7 @@ def run_analysis(
|
|
| 1243 |
|
| 1244 |
|
| 1245 |
# ============================================================
|
| 1246 |
-
# UI
|
| 1247 |
# ============================================================
|
| 1248 |
def build_ui():
|
| 1249 |
css = """
|
|
@@ -1251,101 +1275,205 @@ def build_ui():
|
|
| 1251 |
.subtitle {font-size: 14px; opacity: 0.88; margin-bottom: 14px;}
|
| 1252 |
.warnbox {border-left: 6px solid #f59e0b; padding: 10px 12px; background: rgba(245,158,11,0.08); border-radius: 10px;}
|
| 1253 |
.legend {border-left: 6px solid rgba(148,163,184,0.7); padding: 10px 12px; background: rgba(148,163,184,0.08); border-radius: 10px;}
|
|
|
|
| 1254 |
"""
|
| 1255 |
|
| 1256 |
with gr.Blocks(title="TB X-ray Assistant (TBNet + RADIO)", css=css) as demo:
|
| 1257 |
-
gr.Markdown('<div class="title">TB X-ray Assistant (Auto Lung Mask • Research Use)</div>')
|
| 1258 |
-
gr.Markdown(
|
| 1259 |
-
f"<div class='subtitle'>Auto lung mask → <b>{MODEL_NAME_TBNET}</b> + Grad-CAM • "
|
| 1260 |
-
f"Optional <b>{MODEL_NAME_RADIO}</b> (C-RADIOv4 + heads) • Clear per-model results + consensus</div>"
|
| 1261 |
-
)
|
| 1262 |
-
|
| 1263 |
-
# Disclaimer shown early (so users see it before running)
|
| 1264 |
-
gr.Markdown(
|
| 1265 |
-
"<div class='warnbox'><b>Clinical disclaimer:</b> Decision support only (not diagnostic). "
|
| 1266 |
-
"If TB is clinically suspected, pursue microbiology / CT as appropriate regardless of AI output.</div>"
|
| 1267 |
-
)
|
| 1268 |
|
| 1269 |
-
|
| 1270 |
-
|
| 1271 |
-
|
| 1272 |
-
|
| 1273 |
-
|
| 1274 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1275 |
|
| 1276 |
-
|
|
|
|
|
|
|
|
|
|
| 1277 |
|
| 1278 |
-
|
| 1279 |
-
|
| 1280 |
-
|
| 1281 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1282 |
|
| 1283 |
-
|
| 1284 |
-
|
| 1285 |
-
|
| 1286 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1287 |
|
| 1288 |
-
|
| 1289 |
-
|
| 1290 |
-
|
| 1291 |
-
|
| 1292 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1293 |
|
| 1294 |
-
|
| 1295 |
-
|
| 1296 |
-
|
| 1297 |
-
|
|
|
|
|
|
|
|
|
|
| 1298 |
|
| 1299 |
-
|
| 1300 |
-
|
| 1301 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1302 |
|
| 1303 |
-
|
| 1304 |
-
|
| 1305 |
-
|
| 1306 |
-
|
| 1307 |
-
file_types=[".png", ".jpg", ".jpeg", ".bmp"]
|
| 1308 |
-
)
|
| 1309 |
-
run_btn = gr.Button("Run Analysis", variant="primary")
|
| 1310 |
-
status = gr.Textbox(label="Status", value="Ready.", interactive=False)
|
| 1311 |
|
| 1312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1313 |
<div class='legend'><b>Gallery legend:</b><br/>
|
| 1314 |
1) ORIGINAL • 2) INPUT / PHONE-PROC • 3) Lung mask overlay •
|
| 1315 |
4) Masked model input • 5) TBNet Grad-CAM • 6) RADIO heatmaps</div>
|
| 1316 |
""")
|
| 1317 |
|
| 1318 |
-
|
| 1319 |
-
|
| 1320 |
-
|
| 1321 |
-
|
| 1322 |
-
|
| 1323 |
-
|
| 1324 |
-
|
| 1325 |
-
|
| 1326 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1327 |
)
|
| 1328 |
|
| 1329 |
-
|
| 1330 |
-
|
| 1331 |
-
|
| 1332 |
-
|
| 1333 |
-
|
| 1334 |
-
|
| 1335 |
-
|
| 1336 |
-
|
| 1337 |
-
|
| 1338 |
-
|
| 1339 |
-
|
| 1340 |
-
|
| 1341 |
-
|
| 1342 |
-
backbone,
|
| 1343 |
-
threshold,
|
| 1344 |
-
phone_mode,
|
| 1345 |
-
use_radio,
|
| 1346 |
-
radio_gate,
|
| 1347 |
-
],
|
| 1348 |
-
outputs=[summary, gallery, details, disclaimer_box, status]
|
| 1349 |
)
|
| 1350 |
|
| 1351 |
return demo
|
|
@@ -1353,4 +1481,6 @@ def build_ui():
|
|
| 1353 |
|
| 1354 |
if __name__ == "__main__":
|
| 1355 |
demo = build_ui()
|
| 1356 |
-
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
|
| 138 |
# ============================================================
|
| 139 |
+
# UX HELPERS
|
| 140 |
# ============================================================
|
| 141 |
def pretty_state(s: str) -> str:
|
| 142 |
return {
|
|
|
|
| 147 |
"N/A": "⚠️ N/A",
|
| 148 |
}.get(s, f"⚠️ {s}")
|
| 149 |
|
| 150 |
+
|
| 151 |
def html_escape(s: str) -> str:
|
| 152 |
# minimal escaping for safety in HTML blocks
|
| 153 |
return (s or "").replace("&", "&").replace("<", "<").replace(">", ">")
|
| 154 |
|
| 155 |
+
|
| 156 |
def badge_color_for_state(state: str) -> str:
|
| 157 |
# soft visual cue; works in both dark/light
|
| 158 |
if state == "TB+":
|
|
|
|
| 183 |
|
| 184 |
def forward(self, x): return self.net(x)
|
| 185 |
|
| 186 |
+
|
| 187 |
class LungUNet(nn.Module):
|
| 188 |
def __init__(self):
|
| 189 |
super().__init__()
|
|
|
|
| 227 |
|
| 228 |
def forward(self, x): return self.fc(self.backbone(x)).view(-1)
|
| 229 |
|
| 230 |
+
|
| 231 |
def load_tb_weights(model: nn.Module, ckpt_path: str, device: torch.device):
|
| 232 |
sd = torch.load(ckpt_path, map_location=device)
|
| 233 |
model.load_state_dict(sd, strict=True)
|
| 234 |
|
| 235 |
+
|
| 236 |
class GradCAM:
|
| 237 |
def __init__(self, model: nn.Module, target_layer: nn.Module):
|
| 238 |
self.model = model
|
|
|
|
| 276 |
g = (g - lo) / (hi - lo + 1e-8)
|
| 277 |
return torch.from_numpy(g).unsqueeze(0).unsqueeze(0).float()
|
| 278 |
|
| 279 |
+
|
| 280 |
def tb_training_preprocess(gray_u8: np.ndarray) -> np.ndarray:
|
| 281 |
gray = gray_u8.astype(np.float32)
|
| 282 |
lo, hi = np.percentile(gray, (1, 99))
|
|
|
|
| 284 |
gray = (gray - lo) / (hi - lo + 1e-8)
|
| 285 |
return gray
|
| 286 |
|
| 287 |
+
|
| 288 |
def laplacian_sharpness(gray_u8: np.ndarray) -> float:
|
| 289 |
g = cv2.resize(gray_u8, (512, 512), interpolation=cv2.INTER_AREA)
|
| 290 |
g = cv2.GaussianBlur(g, (3, 3), 0)
|
| 291 |
return float(cv2.Laplacian(g, cv2.CV_64F).var())
|
| 292 |
|
| 293 |
+
|
| 294 |
def exposure_scores(gray_u8: np.ndarray) -> Tuple[float, float]:
|
| 295 |
lo = float((gray_u8 < 10).mean())
|
| 296 |
hi = float((gray_u8 > 245).mean())
|
| 297 |
return lo, hi
|
| 298 |
|
| 299 |
+
|
| 300 |
def border_fraction(gray_u8: np.ndarray) -> float:
|
| 301 |
h, w = gray_u8.shape
|
| 302 |
b = max(5, int(0.06 * min(h, w)))
|
|
|
|
| 308 |
def frac_border(x): return float(((x < 15) | (x > 240)).mean())
|
| 309 |
return float(np.mean([frac_border(top), frac_border(bot), frac_border(left), frac_border(right)]))
|
| 310 |
|
| 311 |
+
|
| 312 |
def phone_quality_report(gray_u8: np.ndarray) -> Tuple[float, List[str]]:
|
| 313 |
warnings: List[str] = []
|
| 314 |
h, w = gray_u8.shape
|
|
|
|
| 344 |
|
| 345 |
return float(np.clip(score, 0, 100)), warnings
|
| 346 |
|
| 347 |
+
|
| 348 |
def auto_border_crop(gray_u8: np.ndarray) -> np.ndarray:
|
| 349 |
g = gray_u8.copy()
|
| 350 |
g_blur = cv2.GaussianBlur(g, (5, 5), 0)
|
|
|
|
| 368 |
x2 = min(W, x + w + pad); y2 = min(H, y + h + pad)
|
| 369 |
return gray_u8[y1:y2, x1:x2]
|
| 370 |
|
| 371 |
+
|
| 372 |
def apply_clahe(gray_u8: np.ndarray) -> np.ndarray:
|
| 373 |
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
| 374 |
return clahe.apply(gray_u8)
|
| 375 |
|
| 376 |
+
|
| 377 |
def phone_preprocess(gray_u8: np.ndarray) -> np.ndarray:
|
| 378 |
sharp = laplacian_sharpness(gray_u8)
|
| 379 |
lo_clip, _hi_clip = exposure_scores(gray_u8)
|
|
|
|
| 390 |
|
| 391 |
return g
|
| 392 |
|
| 393 |
+
|
| 394 |
def cam_entropy(cam: np.ndarray) -> float:
|
| 395 |
cam = cam.astype(np.float32)
|
| 396 |
cam = cam / (cam.sum() + 1e-8)
|
| 397 |
return float(-np.sum(cam * np.log(cam + 1e-8)))
|
| 398 |
|
| 399 |
+
|
| 400 |
def detect_diffuse_risk(prob_tb: float, cam_up: np.ndarray, quality_score: float) -> bool:
|
| 401 |
if quality_score < 55:
|
| 402 |
return False
|
|
|
|
| 405 |
ent = cam_entropy(cam_up)
|
| 406 |
return (prob_tb < TBNET_SCREEN_THR) and (ent > 6.5)
|
| 407 |
|
| 408 |
+
|
| 409 |
def confidence_band(prob_tb: float, quality_score: float, diffuse: bool):
|
| 410 |
if prob_tb < 0.01 and quality_score >= 45:
|
| 411 |
return ("GREEN", "✅ Very low TB signal detected.")
|
|
|
|
| 417 |
return ("YELLOW", "⚠️ Screening-positive range; review recommended.")
|
| 418 |
return ("GREEN", "✅ No strong TB signal detected.")
|
| 419 |
|
| 420 |
+
|
| 421 |
def make_mask_overlay(gray_u8: np.ndarray, mask_u8: np.ndarray) -> np.ndarray:
|
| 422 |
base = cv2.cvtColor(gray_u8, cv2.COLOR_GRAY2RGB)
|
| 423 |
mask_color = cv2.applyColorMap((mask_u8 * 255).astype(np.uint8), cv2.COLORMAP_JET)
|
| 424 |
return cv2.addWeighted(base, 0.75, mask_color, 0.25, 0)
|
| 425 |
|
| 426 |
+
|
| 427 |
def fill_holes(binary_u8: np.ndarray) -> np.ndarray:
|
| 428 |
m = (binary_u8 * 255).astype(np.uint8)
|
| 429 |
h, w = m.shape
|
| 430 |
flood = m.copy()
|
| 431 |
+
mask = np.zeros((h + 2, w + 2), np.uint8)
|
| 432 |
cv2.floodFill(flood, mask, (0, 0), 255)
|
| 433 |
holes = cv2.bitwise_not(flood)
|
| 434 |
filled = cv2.bitwise_or(m, holes)
|
| 435 |
return (filled > 0).astype(np.uint8)
|
| 436 |
|
| 437 |
+
|
| 438 |
def keep_top_k_components(binary_u8: np.ndarray, k: int = 2) -> np.ndarray:
|
| 439 |
m = (binary_u8 > 0).astype(np.uint8)
|
| 440 |
n, labels = cv2.connectedComponents(m)
|
|
|
|
| 450 |
out[labels == i] = 1
|
| 451 |
return out
|
| 452 |
|
| 453 |
+
|
| 454 |
def mask_sanity_warnings(mask_full_u8: np.ndarray) -> List[str]:
|
| 455 |
m = (mask_full_u8 > 0).astype(np.uint8)
|
| 456 |
n, labels = cv2.connectedComponents(m)
|
|
|
|
| 480 |
|
| 481 |
return warns
|
| 482 |
|
| 483 |
+
|
| 484 |
def recommendation_for_band(band: Optional[str]) -> str:
|
| 485 |
if band in (None, "YELLOW"):
|
| 486 |
return "✅ Recommendation: Radiologist/clinician review is recommended (**indeterminate**)."
|
|
|
|
| 503 |
return "SCREEN+"
|
| 504 |
return "LOW"
|
| 505 |
|
| 506 |
+
|
| 507 |
def radio_state_from_prob(radio_prob: float) -> str:
|
| 508 |
if radio_prob >= RADIO_THR_RED:
|
| 509 |
return "TB+"
|
|
|
|
| 511 |
return "SCREEN+"
|
| 512 |
return "LOW"
|
| 513 |
|
| 514 |
+
|
| 515 |
def build_consensus(
|
| 516 |
tb_prob: Optional[float],
|
| 517 |
tb_band: Optional[str],
|
|
|
|
| 645 |
self.lung = lung
|
| 646 |
self.lung_path = lung_weights
|
| 647 |
|
| 648 |
+
|
| 649 |
BUNDLE = ModelBundle()
|
| 650 |
|
| 651 |
|
|
|
|
| 667 |
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 668 |
return self.net(x).squeeze(1)
|
| 669 |
|
| 670 |
+
|
| 671 |
class RadioBundle:
|
| 672 |
def __init__(self):
|
| 673 |
self.loaded = False
|
|
|
|
| 716 |
self.device_str = dev_str
|
| 717 |
self.loaded = True
|
| 718 |
|
| 719 |
+
|
| 720 |
RADIO_BUNDLE = RadioBundle()
|
| 721 |
|
| 722 |
+
|
| 723 |
def radio_heatmap_from_spatial(spatial_tokens: torch.Tensor, in_h: int, in_w: int, patch_size: int = 16) -> np.ndarray:
|
| 724 |
ht = in_h // patch_size
|
| 725 |
wt = in_w // patch_size
|
|
|
|
| 730 |
hm_img = Image.fromarray((hm * 255).astype(np.uint8)).resize((in_w, in_h), resample=Image.BILINEAR)
|
| 731 |
return np.array(hm_img, dtype=np.float32) / 255.0
|
| 732 |
|
| 733 |
+
|
| 734 |
def radio_overlay_heatmap(rgb_u8: np.ndarray, heatmap01: np.ndarray, alpha: float = 0.35) -> np.ndarray:
|
| 735 |
img = rgb_u8.astype(np.float32) / 255.0
|
| 736 |
hm = np.clip(heatmap01, 0, 1).astype(np.float32)
|
|
|
|
| 738 |
out[..., 0] = np.clip(out[..., 0] * (1 - alpha) + hm * alpha, 0, 1)
|
| 739 |
return (out * 255).astype(np.uint8)
|
| 740 |
|
| 741 |
+
|
| 742 |
@torch.inference_mode()
|
| 743 |
def radio_predict_from_arrays(gray_vis_u8: np.ndarray,
|
| 744 |
lung_mask_u8: np.ndarray,
|
|
|
|
| 1105 |
# ------------------------
|
| 1106 |
# Build descriptive cards (TBNet, RADIO, Consensus)
|
| 1107 |
# ------------------------
|
|
|
|
| 1108 |
tb_prob_line = "N/A (fail-safe)" if out["prob"] is None else f"{out['prob']:.4f}"
|
| 1109 |
tb_band = out.get("band", "YELLOW")
|
| 1110 |
tb_label = out.get("pred", "INDETERMINATE")
|
|
|
|
| 1117 |
|
| 1118 |
attention = "Diffuse / non-focal" if out.get("diffuse_risk", False) else "Focal / localized"
|
| 1119 |
|
|
|
|
| 1120 |
radio_primary_line = "N/A" if radio_primary_val is None else f"{radio_primary_val:.4f}"
|
| 1121 |
radio_raw_line = "N/A" if radio_raw_val is None else f"{radio_raw_val:.4f}"
|
| 1122 |
radio_masked_line = "Not run" if radio_masked_val is None else f"{radio_masked_val:.4f}"
|
| 1123 |
|
|
|
|
| 1124 |
if consensus_label == "DISAGREE":
|
| 1125 |
next_step = "✅ Next step: Treat as **indeterminate** → radiologist review + microbiology if clinically suspected."
|
| 1126 |
elif consensus_label in ("AGREE: TB+", "AGREE: SCREEN+"):
|
|
|
|
| 1130 |
else:
|
| 1131 |
next_step = "✅ Next step: Correlate clinically; radiologist review recommended if uncertainty or symptoms present."
|
| 1132 |
|
|
|
|
| 1133 |
state_badge_tb = f"""
|
| 1134 |
<span style="padding:4px 10px; border-radius:999px; background:{badge_color_for_state(tb_state)}; font-weight:800;">
|
| 1135 |
{pretty_state(tb_state)}
|
|
|
|
| 1221 |
# Detailed collapsible report (per image)
|
| 1222 |
# ------------------------
|
| 1223 |
warn_txt = "\n".join([f"- {w}" for w in out["warnings"]]) if out["warnings"] else "- None"
|
|
|
|
| 1224 |
|
| 1225 |
details_md.append(
|
| 1226 |
f"""
|
|
|
|
| 1267 |
|
| 1268 |
|
| 1269 |
# ============================================================
|
| 1270 |
+
# UI (HF Spaces Welcome Screen + Main App)
|
| 1271 |
# ============================================================
|
| 1272 |
def build_ui():
|
| 1273 |
css = """
|
|
|
|
| 1275 |
.subtitle {font-size: 14px; opacity: 0.88; margin-bottom: 14px;}
|
| 1276 |
.warnbox {border-left: 6px solid #f59e0b; padding: 10px 12px; background: rgba(245,158,11,0.08); border-radius: 10px;}
|
| 1277 |
.legend {border-left: 6px solid rgba(148,163,184,0.7); padding: 10px 12px; background: rgba(148,163,184,0.08); border-radius: 10px;}
|
| 1278 |
+
.card {border:1px solid rgba(255,255,255,0.12); border-radius:14px; padding:14px; margin:10px 0;}
|
| 1279 |
"""
|
| 1280 |
|
| 1281 |
with gr.Blocks(title="TB X-ray Assistant (TBNet + RADIO)", css=css) as demo:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1282 |
|
| 1283 |
+
# ---------------------------
|
| 1284 |
+
# Welcome screen (shown first)
|
| 1285 |
+
# ---------------------------
|
| 1286 |
+
with gr.Column(visible=True) as welcome_screen:
|
| 1287 |
+
gr.Markdown('<div class="title">Welcome — TB X-ray Assistant (HF Spaces)</div>')
|
| 1288 |
+
|
| 1289 |
+
gr.Markdown(
|
| 1290 |
+
f"""
|
| 1291 |
+
<div class="card">
|
| 1292 |
+
<div style="font-size:16px; font-weight:900; margin-bottom:8px;">What this Space does</div>
|
| 1293 |
+
<div style="opacity:0.92;">
|
| 1294 |
+
This app analyzes chest X-rays for <b>TB-like patterns</b> and shows results from:
|
| 1295 |
+
<ul>
|
| 1296 |
+
<li><b>{MODEL_NAME_TBNET}</b> with <b>Grad-CAM</b> explainability</li>
|
| 1297 |
+
<li><b>Auto Lung Mask</b> (Lung U-Net) + a <b>fail-safe</b> to prevent unsafe scoring on bad masks</li>
|
| 1298 |
+
<li>Optional <b>{MODEL_NAME_RADIO}</b> with <b>RAW / MASKED</b> scoring and heatmaps</li>
|
| 1299 |
+
<li>A final <b>consensus</b>: ✅ LOW / ⚠️ INDET / ⚠️ SCREEN+ / 🚩 TB+</li>
|
| 1300 |
+
</ul>
|
| 1301 |
+
</div>
|
| 1302 |
+
</div>
|
| 1303 |
|
| 1304 |
+
<div class="warnbox">
|
| 1305 |
+
<b>Clinical disclaimer:</b> Decision support only (not diagnostic). TB can be subtle (including miliary TB).
|
| 1306 |
+
If TB is clinically suspected, pursue microbiology (CBNAAT/GeneXpert, sputum) and/or CT chest regardless of AI output.
|
| 1307 |
+
</div>
|
| 1308 |
|
| 1309 |
+
<div class="card">
|
| 1310 |
+
<div style="font-size:16px; font-weight:900; margin-bottom:8px;">Special feature: Phone / WhatsApp Mode</div>
|
| 1311 |
+
<div style="opacity:0.92;">
|
| 1312 |
+
Many users upload:
|
| 1313 |
+
<ul>
|
| 1314 |
+
<li><b>WhatsApp-forwarded X-rays</b> (compressed, low contrast)</li>
|
| 1315 |
+
<li><b>Phone photos</b> of printed films or monitor screens (borders, glare, blur)</li>
|
| 1316 |
+
<li><b>Screenshots</b> with large margins / UI elements</li>
|
| 1317 |
+
</ul>
|
| 1318 |
+
<b>Phone / WhatsApp Mode</b> is designed for these cases. When enabled, it applies:
|
| 1319 |
+
<ul>
|
| 1320 |
+
<li><b>Safe border crop</b> (reduces margins / screenshot framing)</li>
|
| 1321 |
+
<li><b>Conditional CLAHE</b> (boosts local contrast when underexposed / low-detail)</li>
|
| 1322 |
+
<li><b>Quality warnings</b> (blur, over/underexposure, heavy borders) to flag reduced reliability</li>
|
| 1323 |
+
</ul>
|
| 1324 |
+
|
| 1325 |
+
<div style="margin-top:10px; padding:10px 12px; border-left:6px solid rgba(96,165,250,0.9); background: rgba(96,165,250,0.10); border-radius:12px;">
|
| 1326 |
+
<b>Tip:</b> Enable Phone/WhatsApp Mode if your image is a phone photo, WhatsApp-forwarded, or has big borders / low contrast.
|
| 1327 |
+
Keep it OFF for clean digital exports to avoid unnecessary preprocessing.
|
| 1328 |
+
</div>
|
| 1329 |
+
</div>
|
| 1330 |
+
</div>
|
| 1331 |
|
| 1332 |
+
<div class="card">
|
| 1333 |
+
<div style="font-size:16px; font-weight:900; margin-bottom:8px;">Explainability & reliability</div>
|
| 1334 |
+
<ul>
|
| 1335 |
+
<li><b>Grad-CAM</b> (TBNet) highlights regions that influenced the TB score.</li>
|
| 1336 |
+
<li><b>RADIO heatmaps</b> show where the visual model is focusing (RAW and sometimes MASKED).</li>
|
| 1337 |
+
<li><b>Fail-safe</b>: if lung segmentation looks unreliable, TBNet scoring is disabled (shown as indeterminate).</li>
|
| 1338 |
+
<li><b>Quality scoring</b> warns when results may be less reliable (blur, compression, non-standard view).</li>
|
| 1339 |
+
</ul>
|
| 1340 |
+
</div>
|
| 1341 |
|
| 1342 |
+
<div class="card">
|
| 1343 |
+
<div style="font-size:16px; font-weight:900; margin-bottom:8px;">How to use</div>
|
| 1344 |
+
<ol>
|
| 1345 |
+
<li>Click <b>Continue</b> to open the interface.</li>
|
| 1346 |
+
<li>Upload one or multiple X-ray images.</li>
|
| 1347 |
+
<li>If your images come from <b>WhatsApp / phone camera / screenshots</b>, enable <b>Phone/WhatsApp Mode</b>.</li>
|
| 1348 |
+
<li>(Optional) Enable <b>{MODEL_NAME_RADIO}</b> for a second independent model + heatmaps.</li>
|
| 1349 |
+
<li>Click <b>Run Analysis</b>.</li>
|
| 1350 |
+
</ol>
|
| 1351 |
+
</div>
|
| 1352 |
|
| 1353 |
+
<div class="card">
|
| 1354 |
+
<div style="font-size:16px; font-weight:900; margin-bottom:8px;">Privacy / processing note (HF Spaces)</div>
|
| 1355 |
+
<div style="opacity:0.92;">
|
| 1356 |
+
Images are processed by this Space runtime. Avoid uploading personally identifiable medical data.
|
| 1357 |
+
Use anonymized images when possible.
|
| 1358 |
+
</div>
|
| 1359 |
+
</div>
|
| 1360 |
|
| 1361 |
+
<div class="subtitle">
|
| 1362 |
+
Device policy: <b>{DEVICE}</b> (FORCE_CPU={FORCE_CPU})
|
| 1363 |
+
</div>
|
| 1364 |
+
"""
|
| 1365 |
+
)
|
| 1366 |
+
continue_btn = gr.Button("Continue →", variant="primary")
|
| 1367 |
+
|
| 1368 |
+
# ---------------------------
|
| 1369 |
+
# Main app UI (hidden initially)
|
| 1370 |
+
# ---------------------------
|
| 1371 |
+
with gr.Column(visible=False) as main_app:
|
| 1372 |
+
gr.Markdown('<div class="title">TB X-ray Assistant (Auto Lung Mask • Research Use)</div>')
|
| 1373 |
+
gr.Markdown(
|
| 1374 |
+
f"<div class='subtitle'>Auto lung mask → <b>{MODEL_NAME_TBNET}</b> + Grad-CAM • "
|
| 1375 |
+
f"Optional <b>{MODEL_NAME_RADIO}</b> (C-RADIOv4 + heads) • Clear per-model results + consensus</div>"
|
| 1376 |
+
)
|
| 1377 |
|
| 1378 |
+
gr.Markdown(
|
| 1379 |
+
"<div class='warnbox'><b>Clinical disclaimer:</b> Decision support only (not diagnostic). "
|
| 1380 |
+
"If TB is clinically suspected, pursue microbiology / CT as appropriate regardless of AI output.</div>"
|
| 1381 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1382 |
|
| 1383 |
+
with gr.Row():
|
| 1384 |
+
with gr.Column(scale=1):
|
| 1385 |
+
gr.Markdown("#### Model settings")
|
| 1386 |
+
|
| 1387 |
+
tb_weights = gr.Textbox(label="TBNet weights (.pt)", value=DEFAULT_TB_WEIGHTS)
|
| 1388 |
+
lung_weights = gr.Textbox(label="Lung U-Net weights (.pt)", value=DEFAULT_LUNG_WEIGHTS)
|
| 1389 |
+
|
| 1390 |
+
backbone = gr.Dropdown(choices=["efficientnet_b0"], value="efficientnet_b0", label="TBNet backbone")
|
| 1391 |
+
|
| 1392 |
+
threshold = gr.Slider(
|
| 1393 |
+
0.01, 0.99, value=TBNET_SCREEN_THR, step=0.01,
|
| 1394 |
+
label=f"Reference threshold (TBNet screen+) = {TBNET_SCREEN_THR:.2f}"
|
| 1395 |
+
)
|
| 1396 |
+
|
| 1397 |
+
phone_mode = gr.Checkbox(
|
| 1398 |
+
value=False,
|
| 1399 |
+
label="Phone/WhatsApp Mode (safe: conditional crop + conditional CLAHE)"
|
| 1400 |
+
)
|
| 1401 |
+
gr.Markdown(
|
| 1402 |
+
"<div class='subtitle'>Enable for WhatsApp images, phone photos, or screenshots. "
|
| 1403 |
+
"Leave off for clean digital exports.</div>"
|
| 1404 |
+
)
|
| 1405 |
+
|
| 1406 |
+
use_radio = gr.Checkbox(value=False, label=f"Enable {MODEL_NAME_RADIO}")
|
| 1407 |
+
radio_gate = gr.Slider(
|
| 1408 |
+
0.10, 0.40, value=RADIO_GATE_DEFAULT, step=0.01,
|
| 1409 |
+
label="RADIO masked gate (run masked head if lung coverage ≥ gate)"
|
| 1410 |
+
)
|
| 1411 |
+
|
| 1412 |
+
gr.Markdown(
|
| 1413 |
+
"<div class='warnbox'><b>Fail-safe:</b> If lung segmentation is too small or looks unreliable, "
|
| 1414 |
+
f"{MODEL_NAME_TBNET} scoring is disabled to avoid unsafe outputs.</div>"
|
| 1415 |
+
)
|
| 1416 |
+
|
| 1417 |
+
gr.Markdown(
|
| 1418 |
+
f"<div class='subtitle'>Device: <b>{DEVICE}</b> (FORCE_CPU={FORCE_CPU})</div>"
|
| 1419 |
+
)
|
| 1420 |
+
|
| 1421 |
+
back_btn = gr.Button("← Back to Welcome", variant="secondary")
|
| 1422 |
+
|
| 1423 |
+
with gr.Column(scale=2):
|
| 1424 |
+
gr.Markdown("#### Upload images")
|
| 1425 |
+
files = gr.Files(
|
| 1426 |
+
label="Upload one or multiple X-ray images",
|
| 1427 |
+
file_types=[".png", ".jpg", ".jpeg", ".bmp"]
|
| 1428 |
+
)
|
| 1429 |
+
run_btn = gr.Button("Run Analysis", variant="primary")
|
| 1430 |
+
status = gr.Textbox(label="Status", value="Ready.", interactive=False)
|
| 1431 |
+
|
| 1432 |
+
gr.Markdown("""
|
| 1433 |
<div class='legend'><b>Gallery legend:</b><br/>
|
| 1434 |
1) ORIGINAL • 2) INPUT / PHONE-PROC • 3) Lung mask overlay •
|
| 1435 |
4) Masked model input • 5) TBNet Grad-CAM • 6) RADIO heatmaps</div>
|
| 1436 |
""")
|
| 1437 |
|
| 1438 |
+
gr.Markdown("#### Summary (per image)")
|
| 1439 |
+
summary = gr.Markdown("Upload images and click <b>Run Analysis</b>.")
|
| 1440 |
+
gallery = gr.Gallery(label="Visual outputs", columns=3, height=560)
|
| 1441 |
+
|
| 1442 |
+
with gr.Row():
|
| 1443 |
+
with gr.Column(scale=1):
|
| 1444 |
+
disclaimer_box = gr.Markdown(CLINICAL_DISCLAIMER)
|
| 1445 |
+
with gr.Column(scale=2):
|
| 1446 |
+
gr.Markdown("#### Detailed report (expand per image)")
|
| 1447 |
+
details = gr.Markdown("")
|
| 1448 |
+
|
| 1449 |
+
run_btn.click(
|
| 1450 |
+
fn=run_analysis,
|
| 1451 |
+
inputs=[
|
| 1452 |
+
files,
|
| 1453 |
+
tb_weights,
|
| 1454 |
+
lung_weights,
|
| 1455 |
+
backbone,
|
| 1456 |
+
threshold,
|
| 1457 |
+
phone_mode,
|
| 1458 |
+
use_radio,
|
| 1459 |
+
radio_gate,
|
| 1460 |
+
],
|
| 1461 |
+
outputs=[summary, gallery, details, disclaimer_box, status]
|
| 1462 |
)
|
| 1463 |
|
| 1464 |
+
# ---------------------------
|
| 1465 |
+
# Transitions
|
| 1466 |
+
# ---------------------------
|
| 1467 |
+
continue_btn.click(
|
| 1468 |
+
fn=lambda: (gr.update(visible=False), gr.update(visible=True)),
|
| 1469 |
+
inputs=[],
|
| 1470 |
+
outputs=[welcome_screen, main_app],
|
| 1471 |
+
)
|
| 1472 |
+
|
| 1473 |
+
back_btn.click(
|
| 1474 |
+
fn=lambda: (gr.update(visible=True), gr.update(visible=False)),
|
| 1475 |
+
inputs=[],
|
| 1476 |
+
outputs=[welcome_screen, main_app],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1477 |
)
|
| 1478 |
|
| 1479 |
return demo
|
|
|
|
| 1481 |
|
| 1482 |
if __name__ == "__main__":
|
| 1483 |
demo = build_ui()
|
| 1484 |
+
# HF Spaces: let the platform manage host/port. queue() helps stability for longer runs.
|
| 1485 |
+
demo.queue(concurrency_count=1)
|
| 1486 |
+
demo.launch(show_error=True)
|