# ── angio_ai_app.py ─────────────────────────────────────────────────────────────────────────────
# Angio AI — Coronary Angiography Analysis System
# Muhammad Adil · MS Data Science · ITU Lahore
# Supervisor: Dr. Arif Mehmood
#
# ── DEPLOYMENT — HuggingFace Spaces ──────────────────────────────────────────
# Checkpoints are loaded from HF Hub: MuhammadAdil63/angio-ai-checkpoints
# Set HF_TOKEN as a Space Secret (Settings → Secrets) for private repo access.
# Files expected in that repo:
# mask2former_best.pth (Mask2Former Swin-Base fine-tuned)
# binary_best.pth (ResUNet binary vessel segmentation)
# best.pt (YOLOv8m-seg 26-class ARCADE syntax)
#
# ── DEPENDENCIES ────────────────────────────────────────────────────────────────────────────
# pip install gradio torch torchvision opencv-python-headless
# scikit-image scipy matplotlib transformers huggingface_hub pillow
# ultralytics (for YOLOv8m-seg inference)
# ─────────────────────────────────────────────────────────────────────────────
# ── Force PyTorch weights_only=False globally (must be before any imports) ───
import os as _os
# Clearing TORCH_FORCE_WEIGHTS_ONLY_LOAD only cancels a request to *force*
# weights-only loading; it does not turn the weights-only default off. The
# opt-out switch is TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD, so set that as well.
# NOTE: this disables torch's pickle safety for every torch.load in the
# process. That is acceptable only because the checkpoints come from the
# author's own HF Hub repo; never point these loaders at untrusted files.
_os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = "0"
_os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"
# ── audioop shim (removed from the stdlib in Python 3.13, needed by pydub/gradio) ──
# Prefer the real module: it exists on Python <= 3.12, or via a backport
# package on 3.13+. Only when it cannot be imported is an empty placeholder
# registered, so that `import audioop` in dependencies does not fail.
import sys, types, warnings
for _mod_name in ("audioop", "pyaudioop"):
    if _mod_name in sys.modules:
        continue
    try:
        with warnings.catch_warnings():
            # audioop emits a DeprecationWarning on import in 3.11-3.12
            warnings.simplefilter("ignore", DeprecationWarning)
            __import__(_mod_name)
    except ImportError:
        sys.modules[_mod_name] = types.ModuleType(_mod_name)
del _mod_name
# ── Patch gradio_client _json_schema_to_python_type (APIInfoParseError fix) ──
def _patch_gradio_client():
    """
    Make gradio_client's JSON-schema -> type-hint conversion failure-tolerant.

    Some gradio_client versions raise while building the API-info page when a
    schema node is not a dict (e.g. a boolean ``additionalProperties: true``).
    Both the private recursive helper ``_json_schema_to_python_type`` and the
    public ``json_schema_to_python_type`` are wrapped so that any such failure
    degrades to the string "Any" instead of crashing the app.

    Best-effort: if gradio_client is missing or its internals differ, a
    message is printed and nothing is patched. Safe to call more than once,
    because already-wrapped functions are detected and left alone.
    """
    try:
        import gradio_client.utils as _gcu
        _orig = _gcu._json_schema_to_python_type
        _orig_top = _gcu.json_schema_to_python_type
    except Exception as exc:
        # Deliberate best-effort: never block start-up over a docs helper.
        print(f"[PATCH] gradio_client schema patch skipped: {exc}")
        return

    if not getattr(_orig, "_angio_safe", False):
        def _safe(schema, *args, **kwargs):
            # Non-dict schema nodes (e.g. bool) are what trip gradio_client up.
            if not isinstance(schema, dict):
                return "Any"
            try:
                return _orig(schema, *args, **kwargs)
            except Exception:
                return "Any"
        _safe._angio_safe = True
        _gcu._json_schema_to_python_type = _safe

    if not getattr(_orig_top, "_angio_safe", False):
        def _safe_top(*args, **kwargs):
            # Forward exactly what the caller passed. The previous wrapper
            # always added a second positional `defs`, which the public
            # function may not accept; that TypeError then turned every
            # schema into "Any".
            try:
                return _orig_top(*args, **kwargs)
            except Exception:
                return "Any"
        _safe_top._angio_safe = True
        _gcu.json_schema_to_python_type = _safe_top


_patch_gradio_client()
# ── PyTorch 2.7 global weights_only patch ────────────────────────────────────
# YOLO and other libraries call torch.load internally without weights_only=False.
# Monkey-patch torch.load to default to weights_only=False for all calls.
import torch as _torch
_orig_torch_load = _torch.load
def _patched_torch_load(f, map_location=None, pickle_module=None,
weights_only=False, mmap=None, **kwargs):
return _orig_torch_load(f, map_location=map_location,
pickle_module=pickle_module,
weights_only=False, **kwargs)
_torch.load = _patched_torch_load
import os
import sys
import json
import base64
import tempfile
import warnings
warnings.filterwarnings("ignore")
import cv2
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from pathlib import Path
import gradio as gr
import torch
# ── Patch torch.load for PyTorch 2.7 (weights_only=True default breaks YOLO) ─
_orig_tload = torch.load
def _safe_tload(f, map_location=None, pickle_module=None,
weights_only=None, mmap=None, **kwargs):
return _orig_tload(f, map_location=map_location, weights_only=False, **kwargs)
torch.load = _safe_tload
import huggingface_hub # noqa: F401 (ensure installed; used in _hf_ckpt)
# ─────────────────────────────────────────────────────────────────────────────
# CONFIG — checkpoint sources (HF Hub), YOLO inference parameters, pixel scale and device
# ─────────────────────────────────────────────────────────────────────────────
# ── Checkpoint download from HuggingFace Hub ─────────────────────────────────
# Checkpoints live in: MuhammadAdil63/angio-ai-checkpoints (private repo)
# huggingface_hub reads HF_TOKEN from environment (set as a Space secret).
def _hf_ckpt(filename: str) -> str:
    """
    Fetch *filename* from the checkpoint repo on the HF Hub and return its local path.

    The download happens on the first call; the cached local path is returned
    afterwards. HF_TOKEN (a Space secret) is forwarded for private-repo
    access; when it is unset, None is passed.
    """
    from huggingface_hub import hf_hub_download
    download_kwargs = {
        "repo_id": "MuhammadAdil63/angio-ai-checkpoints",
        "filename": filename,
        "repo_type": "model",
        "token": os.environ.get("HF_TOKEN"),  # set in Space Secrets
    }
    return hf_hub_download(**download_kwargs)
CONFIG = dict(
    # ── Checkpoint paths, resolved lazily through the HF Hub ─────────────────
    # Each value is a zero-argument callable returning a local file path; the
    # matching _load_*() function calls it once, on first use.
    MASK2FORMER_CKPT=lambda: _hf_ckpt("mask2former_best.pth"),
    RESUNET_CKPT=lambda: _hf_ckpt("binary_best.pth"),
    YOLO_CKPT=lambda: _hf_ckpt("best.pt"),
    # ── YOLO inference parameters ────────────────────────────────────────────
    YOLO_CONF=0.25,    # default confidence threshold
    YOLO_IOU=0.70,     # IoU threshold passed to YOLO predict()
    YOLO_IMGSZ=512,    # inference image size passed to YOLO predict()
    # ── FFR pipeline scale (ARCADE, hard-coded) ──────────────────────────────
    PX_PER_MM=3.75,
    # ── Device: first CUDA GPU if present, else CPU (HF free tier) ───────────
    DEVICE="cuda:0" if torch.cuda.is_available() else "cpu",
)
# ── Demo videos (hosted in HF Hub model repo alongside checkpoints) ──────────
# Upload your two XCA demo videos to MuhammadAdil63/angio-ai-checkpoints as:
# demo_video_1.mp4
# demo_video_2.mp4
# The buttons below download and load them automatically.
DEMO_VIDEO_1_NAME = "demo_video_1.mp4"
DEMO_VIDEO_2_NAME = "demo_video_2.mp4"
def _get_demo_video(filename: str) -> "str | None":
    """
    Download a demo video from the checkpoint repo on the HF Hub.

    Best-effort: any failure (file missing from the repo, no network, bad or
    missing HF_TOKEN) is printed and None is returned, so callers must handle
    a None result.

    Args:
        filename: file name inside MuhammadAdil63/angio-ai-checkpoints,
            e.g. DEMO_VIDEO_1_NAME.
    Returns:
        Local cached path of the video, or None if it could not be fetched.
    """
    try:
        from huggingface_hub import hf_hub_download
        return hf_hub_download(
            repo_id="MuhammadAdil63/angio-ai-checkpoints",
            filename=filename,
            repo_type="model",
            token=os.environ.get("HF_TOKEN"),
        )
    except Exception as e:
        print(f"[DEMO] Could not load {filename}: {e}")
        return None
# ─────────────────────────────────────────────────────────────────────────────
# ITU LOGO — encode to base64 for embedding in HTML
# Place ITULog.png in the same directory as this script.
# ─────────────────────────────────────────────────────────────────────────────
def _encode_logo(path: str) -> str:
    """
    Read an image file and return it as a base64 ``data:image/png`` URI.

    The logo is purely cosmetic, so any OS-level read failure (missing file,
    path is a directory, permission denied) yields "" rather than aborting
    app start-up. The header falls back to a plain "ITU" label when the
    result is empty.

    Args:
        path: file to embed. Its bytes are labelled image/png whatever the
            actual format is.
    Returns:
        "data:image/png;base64,..." or "" if the file could not be read.
    """
    try:
        data = Path(path).read_bytes()
    except OSError:
        return ""
    return "data:image/png;base64," + base64.b64encode(data).decode()
LOGO_ITU = _encode_logo(str(Path(__file__).parent / "ITULog.png"))
# ─────────────────────────────────────────────────────────────────────────────
# CSS THEME — white / bluish, minimalist, medical
# ─────────────────────────────────────────────────────────────────────────────
# Custom stylesheet for the Gradio UI (presumably passed as `css=` to
# gr.Blocks — the call is outside this view). The id selectors (#header,
# #upload-zone, #btn-analyse, #btn-reset, #result-images, #status-box,
# #footer, ...) only take effect if the matching components use these
# elem_id values — TODO confirm against the UI layout code.
CSS = """
/* ── Base ── */
:root {
--primary: #1a6fa8;
--primary-lt: #e8f3fb;
--primary-mid: #4a9fd4;
--accent: #0d9e6e; /* Angio AI green */
--bg: #f7fafd;
--surface: #ffffff;
--border: #cce0f0;
--text: #1a2533;
--text-sec: #5a7a96;
--danger: #d64040;
--warn: #d98b00;
--radius: 10px;
--shadow: 0 2px 12px rgba(26,111,168,0.10);
}
body, .gradio-container {
background: var(--bg) !important;
font-family: 'Segoe UI', system-ui, sans-serif !important;
color: var(--text) !important;
}
/* ── Header bar ── */
#header {
background: var(--surface);
border-bottom: 2px solid var(--border);
padding: 14px 28px;
display: flex;
align-items: center;
justify-content: space-between;
box-shadow: var(--shadow);
margin-bottom: 0 !important;
}
#logo-itu img { height: 48px; }
#logo-angio {
font-size: 26px;
font-weight: 700;
color: var(--accent);
letter-spacing: 1px;
}
#logo-angio span {
font-size: 13px;
color: var(--text-sec);
font-weight: 400;
display: block;
letter-spacing: 0;
}
/* ── Tab bar ── */
.tab-nav button {
color: var(--text-sec) !important;
border-bottom: 3px solid transparent !important;
font-size: 15px !important;
font-weight: 500 !important;
padding: 10px 24px !important;
background: transparent !important;
}
.tab-nav button.selected {
color: var(--primary) !important;
border-bottom-color: var(--primary) !important;
}
/* ── Cards ── */
.card {
background: var(--surface);
border: 1px solid var(--border);
border-radius: var(--radius);
padding: 22px;
box-shadow: var(--shadow);
margin-bottom: 16px;
}
.card-title {
font-size: 14px;
font-weight: 600;
color: var(--primary);
text-transform: uppercase;
letter-spacing: 0.5px;
margin-bottom: 12px;
padding-bottom: 8px;
border-bottom: 1px solid var(--primary-lt);
}
/* ── Upload zone ── */
#upload-zone {
border: 2px dashed var(--primary-mid) !important;
border-radius: var(--radius) !important;
background: var(--primary-lt) !important;
min-height: 200px !important;
}
/* ── Sliders ── */
input[type=range] { accent-color: var(--primary); }
.gradio-slider label { color: var(--text-sec) !important; font-size: 13px !important; }
/* ── Buttons ── */
#btn-analyse {
background: var(--primary) !important;
color: white !important;
border: none !important;
border-radius: var(--radius) !important;
font-size: 15px !important;
font-weight: 600 !important;
padding: 12px 32px !important;
cursor: pointer !important;
width: 100% !important;
transition: background 0.2s !important;
}
#btn-analyse:hover { background: #155d90 !important; }
#btn-reset {
background: transparent !important;
color: var(--text-sec) !important;
border: 1px solid var(--border) !important;
border-radius: var(--radius) !important;
font-size: 13px !important;
padding: 8px 20px !important;
cursor: pointer !important;
width: 100% !important;
}
/* ── Metric badges ── */
.metric-row {
display: flex;
gap: 12px;
flex-wrap: wrap;
margin-bottom: 12px;
}
.metric-badge {
flex: 1;
min-width: 110px;
background: var(--primary-lt);
border: 1px solid var(--border);
border-radius: 8px;
padding: 10px 14px;
text-align: center;
}
.metric-badge .val {
font-size: 22px;
font-weight: 700;
color: var(--primary);
}
.metric-badge .val.ischemic { color: var(--danger); }
.metric-badge .val.borderline { color: var(--warn); }
.metric-badge .val.ok { color: var(--accent); }
.metric-badge .lbl {
font-size: 11px;
color: var(--text-sec);
text-transform: uppercase;
letter-spacing: 0.4px;
}
/* ── Result images ── */
#result-images img {
border-radius: var(--radius);
border: 1px solid var(--border);
}
/* ── Status ── */
#status-box {
padding: 10px 16px;
border-radius: 8px;
font-size: 13px;
font-weight: 500;
border: 1px solid var(--border);
background: var(--primary-lt);
color: var(--primary);
}
/* ── Footer ── */
#footer {
text-align: center;
font-size: 12px;
color: var(--text-sec);
padding: 20px;
border-top: 1px solid var(--border);
margin-top: 24px;
}
"""
# ─────────────────────────────────────────────────────────────────────────────
# HEADER HTML
# ─────────────────────────────────────────────────────────────────────────────
# NOTE(review): in this view the HTML markup of this function appears to be
# missing. The single-quoted f-string opened on the `itu_tag` line is not
# closed on that line (a SyntaxError as shown), and the returned f-string is
# empty. Confirm against the original source before relying on this block.
def _header_html() -> str:
    """Return the header-bar HTML: the embedded ITU logo, or an 'ITU' text fallback."""
    # Use the base64 logo when _encode_logo found the file; otherwise plain text.
    itu_tag = (f'
'
               if LOGO_ITU else
               'ITU')
    return f"""
"""
# ─────────────────────────────────────────────────────────────────────────────
# MODULE 1 — KEYFRAME EXTRACTION (from app.py)
# ─────────────────────────────────────────────────────────────────────────────
def extract_best_frame(video_path: str, n: int = 1) -> tuple:
    """
    Pick the highest-scoring frame of a video and return it resized to 512×512.

    Every `step`-th frame is sampled (step = total // 120, or every frame for
    videos of 120 frames or fewer). Each sample is scored on the central 60 %
    of the image as

        0.6 * mean grey level + 0.4 * (min(Laplacian variance, 5000) / 5000 * 100)

    and the best frame wins (the earliest one on ties). Only the current best
    frame is kept in memory rather than every sampled frame.

    Args:
        video_path: path to a video OpenCV can decode.
        n: unused; kept for backward compatibility (one frame is returned).
    Returns:
        (frame_rgb_512, frame_index, score, total_frames). frame_rgb_512 is a
        (512, 512, 3) uint8 RGB array; total_frames is OpenCV's
        CAP_PROP_FRAME_COUNT.
    Raises:
        ValueError: if the video reports 0 frames or no frame could be read.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total == 0:
            raise ValueError("Could not read video — check format/codec.")
        step = max(1, total // min(120, total))
        best = None  # (score, frame_index, frame_bgr)
        fidx = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if fidx % step == 0:
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                h, w = gray.shape
                roi = gray[int(h * 0.2):int(h * 0.8), int(w * 0.2):int(w * 0.8)]
                contrast = float(roi.mean())
                sharpness = float(cv2.Laplacian(roi, cv2.CV_64F).var())
                score = contrast * 0.6 + min(sharpness, 5000) / 5000 * 100 * 0.4
                # Strict '>' keeps the earliest frame on ties, matching the
                # previous stable descending sort.
                if best is None or score > best[0]:
                    best = (score, fidx, frame.copy())
            fidx += 1
    finally:
        cap.release()
    if best is None:
        raise ValueError("No frames could be sampled from video.")
    best_score, best_idx, best_frame = best
    rgb = cv2.cvtColor(best_frame, cv2.COLOR_BGR2RGB)
    rgb = cv2.resize(rgb, (512, 512))
    return rgb, best_idx, best_score, total
# ─────────────────────────────────────────────────────────────────────────────
# MODULE 2 — MASK2FORMER INFERENCE
# ─────────────────────────────────────────────────────────────────────────────
_m2f_model = None  # cached fine-tuned Mask2Former (eval mode), set by _load_mask2former()
_m2f_proc = None   # cached matching image processor, set by _load_mask2former()
def _load_mask2former():
    """
    Lazily build the fine-tuned Mask2Former and its image processor (once).

    The Swin-Base COCO-instance config and weights are loaded from the HF
    Hub with the class head resized to 2 labels. The fine-tuned checkpoint
    (mask2former_best.pth) is then loaded on top with strict=False, and the
    missing/unexpected key counts are printed so a mostly-ignored checkpoint
    shows up in the logs.

    The module-level cache is filled only after every step has succeeded. A
    failed download or load therefore leaves it empty and the next call
    retries, instead of silently serving the un-fine-tuned COCO model.

    Raises:
        FileNotFoundError: if the resolved checkpoint path does not exist.
    """
    global _m2f_model, _m2f_proc
    if _m2f_model is not None:
        return
    from transformers import (AutoImageProcessor,
                              Mask2FormerForUniversalSegmentation)
    ckpt = CONFIG["MASK2FORMER_CKPT"]()  # resolves via HF Hub download
    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"Mask2Former checkpoint not found:\n{ckpt}")
    proc = AutoImageProcessor.from_pretrained(
        "facebook/mask2former-swin-base-coco-instance")
    model = Mask2FormerForUniversalSegmentation.from_pretrained(
        "facebook/mask2former-swin-base-coco-instance",
        num_labels=2, ignore_mismatched_sizes=True)
    state = torch.load(ckpt, map_location="cpu", weights_only=False)
    # Training checkpoints wrap the weights under "model_state"; a bare
    # state_dict is used as-is.
    weights = state.get("model_state", state)
    result = model.load_state_dict(weights, strict=False)
    print(f"[Mask2Former] strict=False — "
          f"missing={len(result.missing_keys)} "
          f"unexpected={len(result.unexpected_keys)}")
    model.to(CONFIG["DEVICE"]).eval()
    _m2f_proc, _m2f_model = proc, model
def run_mask2former(frame_rgb: np.ndarray, conf_thr: float) -> list:
    """
    Run the fine-tuned Mask2Former on one RGB frame.

    Args:
        frame_rgb: (H, W, 3) uint8 RGB image (512×512 in this pipeline).
        conf_thr: minimum instance score to keep. It is passed to the
            processor's post-processing as `threshold`, whose own default
            (0.5 in transformers — verify for the installed version) would
            otherwise drop lower-scoring instances before this function sees
            them. It is also re-checked per segment.
    Returns:
        List of dicts {"bbox": [x1, y1, x2, y2], "score": float,
        "mask": (H, W) uint8 0/1}, in frame pixel coordinates.
    """
    _load_mask2former()
    from PIL import Image
    h, w = frame_rgb.shape[:2]
    inputs = _m2f_proc(images=Image.fromarray(frame_rgb), return_tensors="pt")
    inputs = {k: v.to(CONFIG["DEVICE"]) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = _m2f_model(**inputs)
    results = _m2f_proc.post_process_instance_segmentation(
        outputs, threshold=float(conf_thr), target_sizes=[(h, w)])[0]
    segments = results["segments_info"]
    if results["segmentation"] is None or not segments:
        return []
    seg_map = results["segmentation"].cpu().numpy()  # hoisted out of the loop
    detections = []
    for seg_info in segments:
        score = float(seg_info["score"])
        if score < conf_thr:
            continue
        mask = (seg_map == seg_info["id"]).astype(np.uint8)
        ys, xs = np.nonzero(mask)
        if ys.size == 0:
            continue
        detections.append({
            "bbox": [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())],
            "score": score,
            "mask": mask,
        })
    return detections
# ─────────────────────────────────────────────────────────────────────────────
# MODULE 3 — RESUNET BINARY SEGMENTATION
# ─────────────────────────────────────────────────────────────────────────────
# Module-level cache for the ResUNet instance; populated by _load_resunet() on first use.
_resunet_model = None
# ── ResUNet architecture — exactly matching binary_best.pth ──────────────────
#
# Decoded from checkpoint key names via dump_keys.py:
#
# enc0 Conv2d(1, 16, 3, 3, bias=True) plain entry conv
# down1..4 Conv2d (kernel 2, stride 2) downsampler 16→32→64→128→256
# res1..4 _ResBlock(ch) encoder residual blocks ch=16,32,64,128
# bot.0..2 _ResBlock(256) × 3 bottleneck
# up3..0 ConvTranspose2d ×2 upsampler 256→128→64→32→16
# res12..15 _ResBlock(ch) decoder residual blocks ch=128,64,32,16
# head Sequential[Conv3×3, ReLU, BN, output head
# Conv3×3, ReLU, BN, Conv1×1]
#
# _ResBlock structure (from key pattern res*.bn_in, res*.units.N.*, res*.bn_out):
# bn_in → BN
# units → ModuleList of N sub-units, each: [Conv3×3, ReLU, Conv3×3, BN]
# bn_out → BN
# forward: x → bn_in → for each unit: residual(x + unit(x)) → bn_out
# ─────────────────────────────────────────────────────────────────────────────
class _ResBlock(torch.nn.Module):
    """
    Residual block whose parameter names match binary_best.pth.

    State-dict layout (per block):
        bn_in.*                    BatchNorm2d on the block input
        units.<k>.0.{weight,bias}  Conv2d 3×3
        units.<k>.2.{weight,bias}  Conv2d 3×3
        units.<k>.3.*              BatchNorm2d
        bn_out.*                   BatchNorm2d on the block output
    Index 1 inside each unit is a ReLU, which has no parameters and so does
    not appear in the state dict.

    forward: x → bn_in → x = x + unit(x) for each unit → bn_out
    (no activation is applied to the residual sum).

    Args:
        ch: channel count; input and output both have `ch` channels.
        n_units: number of residual units (the checkpoint uses 3).
    """
    def __init__(self, ch, n_units=3):
        super().__init__()
        self.bn_in = torch.nn.BatchNorm2d(ch)
        self.units = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Conv2d(ch, ch, 3, padding=1, bias=True),  # .0
                torch.nn.ReLU(inplace=True),                       # .1 (no params)
                torch.nn.Conv2d(ch, ch, 3, padding=1, bias=True),  # .2
                torch.nn.BatchNorm2d(ch),                          # .3
            )
            for _ in range(n_units)
        ])
        self.bn_out = torch.nn.BatchNorm2d(ch)

    def forward(self, x):
        x = self.bn_in(x)
        for unit in self.units:
            x = x + unit(x)  # no extra ReLU on residual sum
        return self.bn_out(x)


class ResUNet(torch.nn.Module):
    """
    Residual U-Net; exact architecture matching binary_best.pth (Dice = 0.8015).

    Input : (B, 1, H, W) float tensor (see _preprocess_resunet).
    Output: (B, 2, H, W) logits. run_resunet takes the argmax over dim 1 and
            treats class 1 as vessel.

    The encoder halves the resolution four times, so the skip connections
    only line up when H and W are multiples of 16. Other sizes are
    replicate-padded on the bottom/right up to the next multiple of 16, and
    the logits are cropped back. 512×512 inputs need no padding and are
    processed unchanged.

    Key name mapping:
      enc0      → self.enc0          (plain Conv2d 3×3, 1→16)
      down1..4  → self.down1..down4  (strided Conv2d, kernel 2, stride 2)
      res1..4   → self.res1..res4    (encoder ResBlocks)
      bot.0..2  → self.bot[0..2]     (bottleneck ResBlocks)
      up3..0    → self.up3..up0      (ConvTranspose2d, kernel 2, stride 2)
      res12..15 → self.res12..res15  (decoder ResBlocks)
      head      → self.head          (Sequential output head)
    """
    _DOWNSAMPLE = 16  # 2 ** (number of down* stages)

    def __init__(self):
        super().__init__()
        # ── Entry conv ────────────────────────────────────────────────────────
        self.enc0 = torch.nn.Conv2d(1, 16, 3, padding=1, bias=True)
        # ── Encoder downsamples — strided Conv2d (not ConvTranspose2d) ─────────
        # Checkpoint shape (out, in, kH, kW): down1=(32,16,2,2) → Conv2d(16→32).
        # stride=2, kernel=2 halves the spatial size while increasing channels.
        self.down1 = torch.nn.Conv2d(16, 32, 2, stride=2)
        self.down2 = torch.nn.Conv2d(32, 64, 2, stride=2)
        self.down3 = torch.nn.Conv2d(64, 128, 2, stride=2)
        self.down4 = torch.nn.Conv2d(128, 256, 2, stride=2)
        # ── Encoder residual blocks ───────────────────────────────────────────
        self.res1 = _ResBlock(16)
        self.res2 = _ResBlock(32)
        self.res3 = _ResBlock(64)
        self.res4 = _ResBlock(128)
        # ── Bottleneck ────────────────────────────────────────────────────────
        self.bot = torch.nn.ModuleList([
            _ResBlock(256),
            _ResBlock(256),
            _ResBlock(256),
        ])
        # ── Decoder upsamplers ────────────────────────────────────────────────
        # ConvTranspose2d weight shape is (in, out, kH, kW):
        #   up3=(256,128,2,2) bias=(128,)   up2=(128,64,2,2) bias=(64,)
        #   up1=(64,32,2,2)   bias=(32,)    up0=(32,16,2,2)  bias=(16,)
        self.up3 = torch.nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up2 = torch.nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up1 = torch.nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.up0 = torch.nn.ConvTranspose2d(32, 16, 2, stride=2)
        # ── Decoder residual blocks ───────────────────────────────────────────
        self.res12 = _ResBlock(128)
        self.res13 = _ResBlock(64)
        self.res14 = _ResBlock(32)
        self.res15 = _ResBlock(16)
        # ── Output head ───────────────────────────────────────────────────────
        # From keys: head.0=Conv(16,16,3,3), head.2=BN(16), head.3=Conv(16,16,3,3),
        # head.5=BN(16), head.6=Conv(2,16,1,1). Indices 1 and 4 are ReLU
        # (no state-dict entries).
        self.head = torch.nn.Sequential(
            torch.nn.Conv2d(16, 16, 3, padding=1, bias=True),  # 0
            torch.nn.ReLU(inplace=True),                       # 1
            torch.nn.BatchNorm2d(16),                          # 2
            torch.nn.Conv2d(16, 16, 3, padding=1, bias=True),  # 3
            torch.nn.ReLU(inplace=True),                       # 4
            torch.nn.BatchNorm2d(16),                          # 5
            torch.nn.Conv2d(16, 2, 1, bias=True),              # 6
        )

    def forward(self, x):
        # Pad to a multiple of 16 so every skip connection matches in size.
        h, w = x.shape[-2:]
        pad_h, pad_w = (-h) % self._DOWNSAMPLE, (-w) % self._DOWNSAMPLE
        if pad_h or pad_w:
            x = torch.nn.functional.pad(x, (0, pad_w, 0, pad_h), mode="replicate")
        # ── Encoder: entry conv → ResBlock → down. Skip taken AFTER the ResBlock.
        # Shapes in comments assume a 512×512 input.
        x = self.enc0(x)                        # (B,16,512,512)
        r1 = self.res1(x); x = self.down1(r1)   # skip r1, down → (B,32,256,256)
        r2 = self.res2(x); x = self.down2(r2)   # skip r2, down → (B,64,128,128)
        r3 = self.res3(x); x = self.down3(r3)   # skip r3, down → (B,128,64,64)
        r4 = self.res4(x); x = self.down4(r4)   # skip r4, down → (B,256,32,32)
        # ── Bottleneck ───────────────────────────────────────────────────────
        for blk in self.bot:
            x = blk(x)
        # ── Decoder: upsample → ResBlock, then add the skip ──────────────────
        x = self.res12(self.up3(x)) + r4        # (B,128,64,64)
        x = self.res13(self.up2(x)) + r3        # (B,64,128,128)
        x = self.res14(self.up1(x)) + r2        # (B,32,256,256)
        x = self.res15(self.up0(x)) + r1        # (B,16,512,512)
        logits = self.head(x)
        if pad_h or pad_w:
            logits = logits[..., :h, :w]
        return logits
def _load_resunet():
    """
    Lazily build ResUNet, load binary_best.pth into it and cache it (once).

    A strict load is tried first. On a RuntimeError the mismatching keys
    (first 10 of each kind) are printed and a strict=False load is attempted.

    The module-level `_resunet_model` is assigned only after the weights are
    in place. A failed download or load therefore leaves it None and the
    next call retries, instead of silently running a randomly initialised
    network.

    Raises:
        FileNotFoundError: if the resolved checkpoint path does not exist.
        RuntimeError: if even the strict=False load fails (e.g. shape mismatch).
    """
    global _resunet_model
    if _resunet_model is not None:
        return
    ckpt = CONFIG["RESUNET_CKPT"]()  # resolves via HF Hub download
    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"ResUNet checkpoint not found:\n{ckpt}")
    model = ResUNet()
    state = torch.load(ckpt, map_location="cpu", weights_only=False)
    # Training checkpoints wrap the weights under "model_state"; a bare
    # state_dict is used as-is.
    weights = state.get("model_state", state)
    # Try strict first; if it fails print exact mismatches and retry strict=False
    try:
        model.load_state_dict(weights, strict=True)
        print(f"[ResUNet] ✓ Perfect load (strict=True) — "
              f"epoch={state.get('epoch','?')} "
              f"dice={state.get('best_val_dice',0):.4f}")
    except RuntimeError as e:
        print(f"[ResUNet] strict=True failed:\n {e}")
        result = model.load_state_dict(weights, strict=False)
        print(f"[ResUNet] strict=False — "
              f"missing={len(result.missing_keys)} "
              f"unexpected={len(result.unexpected_keys)}")
        for k in result.missing_keys[:10]:
            print(f" MISSING {k}")
        for k in result.unexpected_keys[:10]:
            print(f" UNEXPECTED {k}")
    model.to(CONFIG["DEVICE"]).eval()
    _resunet_model = model
    print(f"[ResUNet] Loaded (eval mode) — "
          f"epoch={state.get('epoch','?')} "
          f"dice={state.get('best_val_dice',0):.4f}")
def _preprocess_resunet(rgb: np.ndarray) -> torch.Tensor:
    """
    Apply the ARCADE preprocessing chain and return a (1, 1, H, W) float32 tensor.

    Steps: grayscale → invert → top-hat (50×50 rectangle) of the inverted
    image → subtract that top-hat from the grayscale image → CLAHE (clip 2.0,
    8×8 tiles) → scale to [0, 1].
    """
    grey = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY).astype(np.uint8)
    kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (50, 50))
    tophat = cv2.morphologyEx(cv2.bitwise_not(grey), cv2.MORPH_TOPHAT, kernel)
    equaliser = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = equaliser.apply(cv2.subtract(grey, tophat))
    scaled = enhanced.astype(np.float32) / 255.0
    return torch.from_numpy(scaled)[None, None]  # add batch and channel dims
def run_resunet(frame_rgb: np.ndarray):
    """
    Segment vessels with ResUNet and build a display overlay.

    Args:
        frame_rgb: uint8 RGB frame (extract_best_frame supplies 512×512).
    Returns:
        binary_mask    : uint8 map, 1 = vessel, 0 = background (argmax over
                         the model's two output channels).
        binary_overlay : uint8 RGB copy of the frame with vessel pixels
                         blended 65 % toward green, vessel contours outlined
                         and a coverage label in the top-left corner.
    """
    _load_resunet()
    _resunet_model.eval()
    batch = _preprocess_resunet(frame_rgb).to(CONFIG["DEVICE"])
    with torch.no_grad():
        scores = _resunet_model(batch)
    mask = scores.argmax(dim=1).squeeze().cpu().numpy().astype(np.uint8)

    is_vessel = mask == 1
    n_vessel = int(is_vessel.sum())
    print(f"[ResUNet] vessel px={n_vessel} ({100*is_vessel.mean():.1f}%)")

    # Green blend over vessel pixels for the Results tab.
    highlight = np.array([0, 255, 80], dtype=np.float32)
    blended = frame_rgb.copy().astype(np.float32)
    blended[is_vessel] = blended[is_vessel] * 0.35 + highlight * 0.65
    blended = blended.astype(np.uint8)

    # Outer contours sharpen the vessel edges.
    outlines, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(blended, outlines, -1, (0, 255, 80), 1)

    coverage = 100 * n_vessel / mask.size
    cv2.putText(blended, f"Vessel coverage: {coverage:.1f}%",
                (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 80), 1, cv2.LINE_AA)
    return mask, blended
# ─────────────────────────────────────────────────────────────────────────────
# MODULE 4 — YOLOv8m-seg INFERENCE (replaces nnUNet)
# ─────────────────────────────────────────────────────────────────────────────
#
# Checkpoint : CONFIG["YOLO_CKPT"] -> best.pt from YOLOv8m-seg training
# Model : YOLOv8m-seg, nc=26, trained on ARCADE syntax dataset
#
# Input : frame_rgb -- (512,512,3) uint8 RGB
# seg_conf -- confidence threshold from UI slider
# Output: seg_map -- (512,512) uint8 label map, values 0..25
# 0 = background (no detection)
# 1..25 = ARCADE class index
#
# Output format is identical to the old run_nnunet() so that
# render_nnunet_overlay() and compute_syntax_score() work unchanged.
# ─────────────────────────────────────────────────────────────────────────────
# ARCADE 26-class name list (index = YOLO cls_id + 1)
# Index = label value in the seg_map (YOLO cls_id + 1); 0 is background and
# is never predicted by YOLO.
# NOTE(review): confirm this order against the training dataset YAML
# (`names:`) or the loaded model's `.names` mapping.
_YOLO_CLASS_NAMES = (
    ["background"]                                                  # 0
    + ["LMCA"]                                                      # 1
    + ["LAD prox", "LAD mid", "LAD dist", "D1", "D2", "D3",
       "Septal", "LAD var"]                                         # 2-9
    + ["LCX prox", "LCX mid", "LCX dist", "OM1", "OM2", "OM3",
       "LCX var"]                                                   # 10-16
    + ["RCA prox", "RCA mid", "RCA dist", "PDA", "PLV", "AM",
       "Conus", "SAN"]                                              # 17-24
    + ["Other"]                                                     # 25
)
_yolo_model = None  # singleton -- loaded once on first call, reused every frame
def _load_yolo():
    """
    Lazy-load the YOLOv8m-seg checkpoint into the module-level singleton.

    Raises:
        ImportError: if ultralytics is not installed (the original import
            error is chained as __cause__).
        FileNotFoundError: if the resolved checkpoint path does not exist.
    """
    global _yolo_model
    if _yolo_model is not None:
        return
    try:
        from ultralytics import YOLO
    except ImportError as err:
        raise ImportError(
            "ultralytics not installed. Run: pip install ultralytics"
        ) from err
    ckpt = CONFIG["YOLO_CKPT"]()  # resolves via HF Hub download
    if not os.path.exists(ckpt):
        raise FileNotFoundError(
            f"YOLOv8 checkpoint not found:\n{ckpt}\n"
            f"Check HF_TOKEN secret and repo MuhammadAdil63/angio-ai-checkpoints"
        )
    print(f"[SEG] Loading checkpoint: {ckpt}")
    # Allow-list the model class for weights-only unpickling (best-effort;
    # older torch has no add_safe_globals). The function is imported directly
    # rather than via `import torch.serialization`, which would rebind `torch`
    # as a local name inside this function.
    try:
        from torch.serialization import add_safe_globals
        from ultralytics.nn.tasks import SegmentationModel
        add_safe_globals([SegmentationModel])
    except Exception as _e:
        print(f"[SEG] safe_globals warning: {_e}")
    _yolo_model = YOLO(ckpt)
    print("[SEG] Model loaded OK")
def run_yolo_seg(frame_rgb: np.ndarray, seg_conf: float) -> np.ndarray:
    """
    Run YOLOv8m-seg on a single RGB frame and rasterise it into a label map.

    Args:
        frame_rgb: (H, W, 3) uint8 RGB frame (512×512 in this pipeline).
        seg_conf: confidence threshold from the UI slider. None falls back
            to CONFIG["YOLO_CONF"]; 0.0 is honoured as a real threshold.
    Returns:
        (H, W) uint8 label map with values 0..25:
          0      = background (no detection; also the whole map if predict()
                   fails or returns no masks)
          1..25  = YOLO cls_id + 1, clamped to 25, so that 0 stays background
                   to match ARCADE_COLOURS / ARCADE_LABELS keys 1..25.

    Instances are painted in ascending confidence order, so higher-confidence
    masks overwrite lower-confidence ones where they overlap.

    NOTE(review): the module header says nc=26, yet labels are clamped to
    1..25, so cls_ids 24 and 25 (if the model has them) would both become
    label 25 ("Other"). Confirm nc and the class order against the training
    YAML or `_yolo_model.names`.
    """
    _load_yolo()
    conf = CONFIG["YOLO_CONF"] if seg_conf is None else float(seg_conf)
    h, w = frame_rgb.shape[:2]
    seg_map = np.zeros((h, w), dtype=np.uint8)

    # YOLO expects BGR numpy input.
    frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR)
    try:
        results = _yolo_model.predict(
            source       = frame_bgr,
            conf         = conf,
            iou          = CONFIG["YOLO_IOU"],
            imgsz        = CONFIG["YOLO_IMGSZ"],
            device       = CONFIG["DEVICE"],
            retina_masks = False,
            verbose      = False,
        )
    except Exception as e:
        print(f"[SEG] predict() failed: {e}")
        return seg_map  # all background, same size as the frame

    result = results[0]
    if result.masks is None or result.boxes is None:
        print("[SEG] No masks returned -- check conf threshold or checkpoint.")
        return seg_map

    masks = result.masks.data.cpu().numpy()          # (N, Hm, Wm) float32
    cls_ids = result.boxes.cls.cpu().int().numpy()   # (N,) 0-based
    confs = result.boxes.conf.cpu().numpy()          # (N,)

    def _label(cls_id):
        # 0-based YOLO class -> 1-based label, clamped to the valid 1..25 range
        return min(int(cls_id) + 1, 25)

    for i in np.argsort(confs):  # ascending: higher-conf masks painted last
        # NOTE(review): masks come at the inference resolution; a plain
        # resize assumes they cover the whole frame with no letterbox
        # padding (expected for square 512×512 frames at imgsz=512).
        mask_resized = cv2.resize(masks[i], (w, h), interpolation=cv2.INTER_LINEAR)
        seg_map[mask_resized > 0.5] = _label(cls_ids[i])

    # Count labelled pixels once, so overlapping instances are not double-counted.
    vessel_px = int(np.count_nonzero(seg_map))
    detected_cls = [_YOLO_CLASS_NAMES[_label(c)] for c in cls_ids]
    print(f"[SEG] Detections: {len(cls_ids)} | classes: {detected_cls} | "
          f"vessel px: {vessel_px} ({100 * vessel_px / (h * w):.1f}%)")
    print(f"[SEG] seg_map unique labels: {np.unique(seg_map)}")
    return seg_map
# ─────────────────────────────────────────────────────────────────────────────
# MODULE 5 — FFR PIPELINE (inline, from ffr_pipeline_v4.py)
# ─────────────────────────────────────────────────────────────────────────────
from skimage.morphology import skeletonize, remove_small_objects, disk
from skimage.measure import label as sk_label
from scipy.ndimage import distance_transform_edt
from scipy.signal import argrelmin, savgol_filter
from skimage.filters import threshold_otsu
FFR_CFG = dict(
    # ── Mask clean-up (_morph_clean) ──
    close_radius=3,            # disk radius for morphological closing
    open_radius=1,             # disk radius for morphological opening
    min_vessel_px=100,         # connected components smaller than this are dropped
    # ── Detection thresholds ──
    stenosis_overlap_thr=0.05,
    # Lowered from 0.20 → 0.05 for video frames, which have sparser masks
    # than the training PNGs — prevents all detections being unlocalized.
    min_overlap_for_ffr=0.05,
    min_ds_fraction=0.10,
    # ── Skeleton / centreline ──
    min_branch_len=15,         # number of spur-pruning passes in run_ffr
    use_dominant_branch=True,
    min_coverage_warn=0.50,
    # ── Diameter sampling (_measure_diameters) ──
    sample_step=3,             # centreline points between diameter samples
    perp_half_len=40,          # max probe length per side; reduced from 50 to avoid hitting background
    smooth_window=11,
    smooth_poly=2,
    reference_pct=75,
    # Physiological cap: coronary vessels are 1.5–5 mm in diameter.
    # If the auto-computed reference exceeds this, it is clamped.
    ref_diam_max_mm=5.0,
    ref_diam_min_mm=1.0,
    px_per_mm=3.75,
)
def _refine_mask(img_gray, binary_mask):
    """
    Tighten a binary vessel mask with an Otsu split of the grey levels it covers.

    Pixels inside `binary_mask` (== 1) are kept only if they are darker than
    the Otsu threshold computed over those same pixels. An empty mask is
    returned as an unchanged copy.

    Returns:
        uint8 mask (0/1) with the same shape as `binary_mask`.
    """
    inside = binary_mask == 1
    covered = img_gray[inside]
    if covered.size == 0:
        return binary_mask.copy()
    cutoff = threshold_otsu(covered)
    return (inside & (img_gray < cutoff)).astype(np.uint8)
def _morph_clean(mask):
    """
    Clean a uint8 0/1 vessel mask.

    Steps: morphological closing (disk of radius close_radius), then opening
    (disk of radius open_radius), then removal of connected components
    smaller than min_vessel_px pixels.

    Returns:
        uint8 mask (0/1).
    """
    closing_kernel = disk(FFR_CFG["close_radius"]).astype(np.uint8)
    opening_kernel = disk(FFR_CFG["open_radius"]).astype(np.uint8)
    cleaned = cv2.morphologyEx(mask.copy(), cv2.MORPH_CLOSE, closing_kernel)
    cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_OPEN, opening_kernel)
    kept = remove_small_objects(cleaned.astype(bool),
                                min_size=FFR_CFG["min_vessel_px"])
    return kept.astype(np.uint8)
def _find_branch_pts(skel):
    """
    Return a boolean mask of skeleton junction pixels.

    A skeleton pixel is a junction when its 3×3 neighbourhood (itself
    included) contains at least 4 skeleton pixels, i.e. it has at least 3
    skeleton neighbours under 8-connectivity.

    Pixels outside the image count as background. The earlier cv2.filter2D
    version used OpenCV's default reflected border, which double-counted the
    neighbours of pixels on the image edge, so a simple bend touching the
    border could be flagged as a junction.

    Args:
        skel: 2-D array; non-zero pixels are skeleton.
    Returns:
        bool array with the same shape as `skel`.
    """
    on = (np.asarray(skel) != 0).astype(np.int32)
    rows, cols = on.shape
    padded = np.pad(on, 1)  # zero border: outside pixels are background
    counts = sum(padded[dr:dr + rows, dc:dc + cols]
                 for dr in range(3) for dc in range(3))
    return (counts >= 4) & (on == 1)
def _dominant_branch(skel):
    """
    Keep only the largest branch-free piece of a skeleton.

    Junction pixels (see _find_branch_pts) are removed so the skeleton splits
    into simple segments. These are labelled, and the segment with the most
    pixels is returned as a bool mask. If nothing remains once the junctions
    are removed, the input skeleton is returned unchanged.
    """
    junctions = _find_branch_pts(skel)
    segments = skel.copy().astype(np.uint8)
    segments[junctions] = 0
    labels = sk_label(segments)
    if labels.max() == 0:
        return skel
    counts = np.bincount(labels.ravel())
    counts[0] = 0  # ignore the background label
    return labels == counts.argmax()
def _order_centerline(skel):
    """
    Order skeleton pixels into a single (row, col) path.

    A greedy 8-connected walk is traced from up to 20 endpoints (pixels with
    exactly one skeleton neighbour), and the longest resulting path is kept.
    If the skeleton has no endpoints (e.g. a closed loop), the first pixel
    from np.argwhere is used as the only start.

    Args:
        skel: 2-D array; non-zero pixels are skeleton.
    Returns:
        List of (row, col) tuples along the chosen path; [] if `skel` is empty.
    """
    pts = np.argwhere(skel)
    if len(pts) == 0: return []
    sk_set = set(map(tuple, pts))
    def nbrs(p):
        # Skeleton pixels among the 8 neighbours of p, in fixed scan order;
        # this order is the tie-breaker for the stable sort in trace().
        r,c = p
        return [(r+dr,c+dc) for dr in [-1,0,1] for dc in [-1,0,1]
                if (dr,dc)!=(0,0) and (r+dr,c+dc) in sk_set]
    # NOTE(review): endpoints follow set iteration order, so which 20 are
    # tried when a skeleton has more than 20 endpoints is arbitrary.
    endpoints = [p for p in sk_set if len(nbrs(p))==1] or [tuple(pts[0])]
    def trace(start):
        # Greedy walk: step to the unvisited neighbour that best continues the
        # previous step direction; stop when no unvisited neighbour remains.
        visited={tuple(start)}; path=[tuple(start)]; cur=tuple(start)
        while True:
            cands=[n for n in nbrs(cur) if n not in visited]
            if not cands: break
            if len(path)>1:
                pr=path[-2]; dr0=cur[0]-pr[0]; dc0=cur[1]-pr[1]
                # L1 distance between the candidate step and the previous step
                cands.sort(key=lambda n:abs(n[0]-cur[0]-dr0)+abs(n[1]-cur[1]-dc0))
            cur=cands[0]; visited.add(cur); path.append(cur)
        return path
    best=[]
    for ep in endpoints[:20]:
        p=trace(ep)
        if len(p)>len(best): best=p
    return best
def _measure_diameters(cl_pts, refined_mask):
    """
    Sample vessel diameters along an ordered centreline.

    At every `sample_step`-th centreline point, the local tangent is taken
    toward the point `sample_step` ahead. Probes then walk outward along both
    normals in 0.5 px increments until they leave the image or reach a
    background pixel of `refined_mask`, up to `perp_half_len` px per side.
    The diameter is the sum of the two probe distances. Samples below 1 px,
    or above 1.8 × perp_half_len, are discarded.

    Returns:
        arc_pos    : float64 array of cumulative arc length (px) per kept
                     sample. Each increment is the distance from the previous
                     kept centre, capped at 1.5 × sample_step.
        diams      : float32 array of diameters (px).
        wall_pairs : list of (left_wall_pt, right_wall_pt) (row, col) float arrays.
        centers    : list of kept centreline points as (row, col) float arrays.
    """
    rows, cols = refined_mask.shape
    stride = FFR_CFG["sample_step"]
    reach = FFR_CFG["perp_half_len"]
    max_arc_step = FFR_CFG["sample_step"] * 1.5
    probe_offsets = np.arange(0.5, reach, 0.5)

    def _walk_to_wall(origin, direction):
        # Distance and position of the first probe that is off-image or on background.
        for dist in probe_offsets:
            probe = origin + direction * dist
            r = int(round(probe[0]))
            c = int(round(probe[1]))
            if r < 0 or r >= rows or c < 0 or c >= cols or refined_mask[r, c] == 0:
                return dist, probe
        return reach, origin + direction * reach

    arc_positions, widths, walls, centres = [], [], [], []
    running_arc = 0.
    last = len(cl_pts) - 1
    for idx in range(0, len(cl_pts) - stride, stride):
        start = np.array(cl_pts[idx], dtype=float)
        ahead = np.array(cl_pts[min(idx + stride, last)], dtype=float)
        tangent = ahead - start
        length = np.linalg.norm(tangent)
        if length < 1e-9:
            continue
        tangent /= length
        normal = np.array([-tangent[1], tangent[0]])
        left_d, left_pt = _walk_to_wall(start, normal)
        right_d, right_pt = _walk_to_wall(start, -normal)
        width = left_d + right_d
        if width < 1. or width > reach * 1.8:
            continue
        if centres:
            hop = float(np.linalg.norm(start - np.array(centres[-1])))
        else:
            hop = 0.
        running_arc += min(hop, max_arc_step)
        arc_positions.append(running_arc)
        widths.append(width)
        walls.append((left_pt, right_pt))
        centres.append(start)
    return (np.array(arc_positions), np.array(widths, dtype=np.float32),
            walls, centres)
def _ds_to_ffr(ds):
    """
    Map a diameter-stenosis value to an FFR estimate clipped to [0, 1].

    Empirical quadratic: FFR = 1 - (0.33·ds + 0.60·ds²).
    `ds` is presumably a fraction (0–1); confirm against the caller.
    """
    drop = 0.33 * ds + 0.60 * ds ** 2
    return max(0., min(1., 1. - drop))
def _prune_spurs(skel, n_iter):
    """Strip skeleton end-points (pixels with one 8-neighbour) for up to *n_iter* passes.

    Stops early once a pass removes nothing. Returns a boolean array shaped like *skel*.
    """
    sk = skel.astype(np.uint8)
    kernel = np.ones((3, 3), np.uint8)
    for _ in range(n_iter):
        # 3x3 box sum counts the pixel itself, so 2 means exactly one neighbour.
        nb = cv2.filter2D(sk, -1, kernel) * sk
        ep = (nb == 2).astype(np.uint8)
        sk = cv2.subtract(sk, ep)
        if ep.sum() == 0:
            break
    return sk.astype(bool)


def _smooth_profile(diams, win):
    """Centred moving-average smoothing of a 1-D diameter profile.

    The window is clamped to the profile length and made odd so the filter
    stays centred. Edges are padded with the end values, so the output has the
    same length and dtype as the input. Profiles shorter than 3 samples (or a
    window < 3) are returned as an unmodified copy.
    """
    n = len(diams)
    if n < 3:
        return diams.copy()
    win = max(1, min(int(win), n))
    if win % 2 == 0:
        win -= 1
    if win < 3:
        return diams.copy()
    half = win // 2
    padded = np.pad(diams.astype(np.float64), half, mode="edge")
    kernel = np.full(win, 1.0 / win)
    return np.convolve(padded, kernel, mode="valid").astype(diams.dtype)


def _figure_to_rgb(fig):
    """Render *fig* and return its pixels as an owned H×W×3 uint8 array."""
    fig.canvas.draw()
    try:
        buf = fig.canvas.buffer_rgba()
    except AttributeError:
        # Very old matplotlib without buffer_rgba().
        w, h = fig.canvas.get_width_height()
        rgb = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
        return rgb.reshape(h, w, 3).copy()
    rgba = np.asarray(buf)
    if rgba.ndim != 3:
        # Some backends expose a flat buffer; reshape using the canvas size.
        w, h = fig.canvas.get_width_height()
        rgba = np.frombuffer(buf, dtype=np.uint8).reshape(h, w, 4)
    # Copy so the array stays valid after the figure is closed.
    return np.ascontiguousarray(rgba[:, :, :3])


def _render_ffr_figure(frame_rgb, detections, skel_dom, wall_pairs, centers,
                       stenoses, arc_pos, diams_px, diams_sm, ref_mm, px_mm,
                       global_ffr):
    """Draw the three-panel FFR figure and return it as an RGB uint8 array.

    The figure is always closed, even if drawing fails.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    try:
        fig.patch.set_facecolor("white")

        # Panel 1: stenosis detections (red if score >= 0.5, orange otherwise).
        axes[0].imshow(frame_rgb, cmap="gray")
        for det in detections:
            b = det.get("bbox", [])
            sc = det.get("score", 0)
            if len(b) < 4:
                continue
            x1, y1, x2, y2 = b[:4]
            col = "red" if sc >= 0.5 else "orange"
            axes[0].add_patch(mpatches.Rectangle(
                (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor=col, linewidth=2))
            axes[0].text(x1, y1 - 5, f"{sc:.2f}", color=col, fontsize=8, fontweight="bold")
        axes[0].axis("off")
        axes[0].set_title("Stenosis detections", fontsize=11, color="#1a6fa8")

        # Panel 2: dominant-branch skeleton, diameter chords, stenosis markers.
        overlay = frame_rgb.copy()
        sk_pts = np.argwhere(skel_dom)
        in_frame = (sk_pts[:, 0] < overlay.shape[0]) & (sk_pts[:, 1] < overlay.shape[1])
        sk_pts = sk_pts[in_frame]
        overlay[sk_pts[:, 0], sk_pts[:, 1]] = (0, 206, 209)
        axes[1].imshow(overlay)
        for wl, wr in wall_pairs:
            axes[1].plot([wl[1], wr[1]], [wl[0], wr[0]],
                         color="lime", alpha=0.4, linewidth=0.6)
        for st in stenoses:
            idx = st["index"]
            if idx < len(centers):
                cy, cx = centers[idx]
                axes[1].plot(cx, cy, "o", color="red", markersize=6)
                axes[1].annotate(f'FFR≈{st["ffr"]:.2f}', xy=(cx, cy),
                                 xytext=(cx + 15, cy - 15),
                                 color="red", fontsize=8, fontweight="bold",
                                 arrowprops=dict(arrowstyle="->", color="red", lw=0.8))
        axes[1].axis("off")
        axes[1].set_title("Centerline + FFR markers", fontsize=11, color="#1a6fa8")

        # Panel 3: raw and smoothed lumen diameter along the centreline.
        if len(arc_pos) > 0:
            pos_mm = arc_pos / px_mm
            axes[2].plot(pos_mm, diams_px / px_mm, color="#cce0f0", lw=0.8, label="Raw")
            axes[2].plot(pos_mm, diams_sm / px_mm, color="#1a6fa8", lw=1.8, label="Smoothed")
            axes[2].axhline(ref_mm, color="#0d9e6e", ls="--", lw=1.2, label=f"Ref {ref_mm:.2f}mm")
            for st in stenoses:
                axes[2].axvline(st["position_mm"], color="red", ls=":", lw=1., alpha=0.8)
                axes[2].annotate(f'{st["pct_ds"]:.0f}%DS\nFFR≈{st["ffr"]:.2f}',
                                 xy=(st["position_mm"], st["d_mm"]),
                                 xytext=(st["position_mm"] + 2, st["d_mm"] + ref_mm * 0.15),
                                 color="red", fontsize=7, fontweight="bold",
                                 arrowprops=dict(arrowstyle="->", color="red", lw=0.7))
            axes[2].legend(fontsize=8)
        axes[2].set_xlabel("Position (mm)", fontsize=9)
        axes[2].set_ylabel("Diameter (mm)", fontsize=9)
        axes[2].grid(alpha=0.3)
        axes[2].set_ylim(0)
        axes[2].set_title("Lumen diameter profile", fontsize=11, color="#1a6fa8")

        ffr_label = ("⚠ ISCHEMIC" if global_ffr <= 0.80 else "✓ Non-ischemic")
        fig.suptitle(f"FFR Analysis — Global FFR: {global_ffr:.3f} {ffr_label} | "
                     f"Stenoses: {len(stenoses)} | Ref: {ref_mm:.2f} mm",
                     fontsize=12, fontweight="bold", color="#1a2533", y=1.01)
        plt.tight_layout()
        return _figure_to_rgb(fig)
    finally:
        plt.close(fig)


def run_ffr(frame_rgb, binary_mask, detections):
    """Full FFR pipeline v4: estimate FFR from one frame's vessel mask.

    Steps: morphological cleanup → skeletonise → prune spurs → keep the
    dominant branch → order it into a centreline → sample lumen diameters →
    smooth → derive a reference diameter (clamped to the physiological range
    in FFR_CFG) → map each detection box to its narrowest centreline sample →
    convert %DS to FFR. Global FFR is the product of the per-stenosis values.

    Args:
        frame_rgb: RGB frame, used only for the figure panels.
        binary_mask: vessel mask with the same height/width as *frame_rgb*
            (presumably 0/1 — its mean is reported as a fraction).
        detections: list of dicts with "bbox" [x1, y1, x2, y2] and optional
            "score" and "overlap" keys.

    Returns:
        (result_dict, figure_rgb). On failure the dict is {"error": msg} and
        the figure is None.
    """
    px_mm = FFR_CFG["px_per_mm"]

    # ── Debug: report mask quality before refinement ──────────────────────────
    bm_sum = int(binary_mask.sum())
    bm_pct = 100 * binary_mask.mean()
    print(f"[FFR] binary_mask entering FFR: sum={bm_sum} {bm_pct:.1f}% of image")

    # Use the ResUNet mask directly (no Otsu refinement). Otsu was tuned on
    # ARCADE ground-truth PNGs; on video frames the vessel/background intensity
    # distribution differs and Otsu removes real vessel pixels.
    refined = binary_mask.copy()
    print(f"[FFR] Using binary mask directly (skip Otsu): sum={refined.sum()}")
    refined = _morph_clean(refined)
    print(f"[FFR] After morph cleanup: sum={refined.sum()}")

    skel_raw = skeletonize(refined.astype(bool))
    print(f"[FFR] Skeleton pixels: {skel_raw.sum()}")
    skel_pruned = _prune_spurs(skel_raw, FFR_CFG["min_branch_len"])
    print(f"[FFR] Pruned skeleton: {skel_pruned.sum()}")
    skel_dom = _dominant_branch(skel_pruned)
    print(f"[FFR] Dominant branch: {skel_dom.sum()}")
    cl_pts = _order_centerline(skel_dom)
    print(f"[FFR] Ordered centerline: {len(cl_pts)} pts")

    # Fall back to the full pruned skeleton only when the dominant branch is
    # very short, and only if that yields a longer path.
    if len(cl_pts) < 30:
        print(f"[FFR] Dominant branch short ({len(cl_pts)} pts) — trying full pruned skeleton")
        cl_pts_full = _order_centerline(skel_pruned)
        if len(cl_pts_full) > len(cl_pts):
            cl_pts = cl_pts_full
            print(f"[FFR] Full skeleton centerline: {len(cl_pts)} pts")

    if len(cl_pts) < 5:
        return {
            "error": (
                f"Centerline too short ({len(cl_pts)} pts). "
                f"binary_mask={bm_sum}px ({bm_pct:.1f}%), "
                f"skeleton={int(skel_raw.sum())}px."
            )
        }, None

    arc_pos, diams_px, wall_pairs, centers = _measure_diameters(cl_pts, refined)
    if len(diams_px) == 0:
        return {"error": "No diameter measurements obtained."}, None

    # NOTE(review): smoothing reconstructed as a centred moving average over
    # FFR_CFG["smooth_window"] samples — confirm it matches the intended filter.
    diams_sm = _smooth_profile(diams_px, FFR_CFG["smooth_window"])
    ref_px = np.percentile(diams_sm, FFR_CFG["reference_pct"])
    ref_mm = ref_px / px_mm

    # ── Physiological sanity clamp ────────────────────────────────────────────
    # A reference outside [ref_diam_min_mm, ref_diam_max_mm] suggests the
    # skeleton traced noise/background, so clamp it and recompute ref_px.
    ref_mm_raw = ref_mm
    ref_mm = float(np.clip(ref_mm,
                           FFR_CFG["ref_diam_min_mm"],
                           FFR_CFG["ref_diam_max_mm"]))
    ref_px = ref_mm * px_mm
    if abs(ref_mm_raw - ref_mm) > 0.1:
        print(f"[FFR] ref_diam clamped: {ref_mm_raw:.2f}mm → {ref_mm:.2f}mm "
              f"(physiological range {FFR_CFG['ref_diam_min_mm']}–"
              f"{FFR_CFG['ref_diam_max_mm']}mm)")

    # Drop samples wider than 2.5× the reference (noise), but only when more
    # than 5 samples would remain.
    valid_mask = diams_sm <= (ref_px * 2.5)
    if valid_mask.sum() > 5:
        arc_pos = arc_pos[valid_mask]
        diams_px = diams_px[valid_mask]
        diams_sm = diams_sm[valid_mask]
        wall_pairs = [wp for wp, v in zip(wall_pairs, valid_mask) if v]
        centers = [c for c, v in zip(centers, valid_mask) if v]
    centers_arr = np.array(centers)
    print(f"[FFR] After noise filter: {len(diams_sm)} measurements kept")

    stenoses, unloc = [], []
    for det in detections:
        bbox = det.get("bbox", [])
        score = det.get("score", 1.)
        overlap = det.get("overlap", 0.)
        if len(bbox) < 4:
            continue
        if overlap < FFR_CFG["min_overlap_for_ffr"]:
            unloc.append({**det, "reason": "low-overlap"})
            continue
        x1, y1, x2, y2 = bbox[:4]
        # centers are (row, col) = (y, x)
        inside = np.where((centers_arr[:, 0] >= y1) & (centers_arr[:, 0] <= y2) &
                          (centers_arr[:, 1] >= x1) & (centers_arr[:, 1] <= x2))[0]
        if len(inside) > 0:
            # Narrowest sample inside the box.
            idx = int(inside[np.argmin(diams_sm[inside])])
        else:
            # No sample inside: use the one nearest the box centre.
            box_centre = np.array([(y1 + y2) / 2, (x1 + x2) / 2])
            idx = int(np.argmin(np.linalg.norm(centers_arr - box_centre, axis=1)))
        d_sten = float(diams_sm[idx])
        ds = max(0., 1. - d_sten / ref_px)
        if ds < FFR_CFG["min_ds_fraction"]:
            unloc.append({**det, "reason": "no-narrowing"})
            continue
        ffr_est = _ds_to_ffr(ds)
        stenoses.append({"index": idx, "position_mm": float(arc_pos[idx] / px_mm),
                         "d_mm": float(d_sten / px_mm), "ref_mm": float(ref_mm),
                         "pct_ds": float(ds * 100), "ffr": ffr_est,
                         "score": score, "significant": ffr_est <= 0.80})

    global_ffr = 1.
    for s in stenoses:
        global_ffr *= s["ffr"]
    global_ffr = max(0., global_ffr)

    buf = _render_ffr_figure(frame_rgb, detections, skel_dom, wall_pairs, centers,
                             stenoses, arc_pos, diams_px, diams_sm, ref_mm, px_mm,
                             global_ffr)

    result = {
        "global_ffr": round(global_ffr, 3),
        "ischemic": global_ffr <= 0.80,
        "n_stenoses": len(stenoses),
        "stenoses": stenoses,
        "ref_diam_mm": round(ref_mm, 2),
        "vessel_len_mm": round(float(arc_pos[-1] / px_mm), 1) if len(arc_pos) > 0 else 0,
        "n_unlocalized": len(unloc),
    }
    return result, buf
# ─────────────────────────────────────────────────────────────────────────────
# MODULE 6 — nnUNet OVERLAY
# ─────────────────────────────────────────────────────────────────────────────
# Overlay colour for each ARCADE class id, as (R, G, B) tuples blended into an
# RGB frame by render_nnunet_overlay.
# 26 entries: id 0 is background (black, never drawn); ids 1–25 each have
# their own colour, and no two entries share the same RGB value.
ARCADE_COLOURS = dict(enumerate([
    (  0,   0,   0),  #  0 background — black (not rendered)
    (255,   0,   0),  #  1 LMCA — red
    (  0, 120, 255),  #  2 LAD proximal — azure blue
    (  0, 220,   0),  #  3 LAD mid — lime green
    (255, 165,   0),  #  4 LAD distal — orange
    (180,   0, 255),  #  5 D1 — violet
    (  0, 230, 230),  #  6 D2 — cyan
    (255, 230,   0),  #  7 D3 — yellow
    (255,   0, 180),  #  8 Septal — hot pink
    (  0, 180,  80),  #  9 LAD var — jade green
    (255, 100,   0),  # 10 LCX proximal — deep orange
    ( 40,  40, 255),  # 11 LCX mid — royal blue
    (  0, 255, 140),  # 12 LCX distal — spring green
    (220,   0,  80),  # 13 OM1 — crimson
    (  0, 200, 255),  # 14 OM2 — sky blue
    (200, 255,   0),  # 15 OM3 — chartreuse
    (140,   0, 200),  # 16 LCX var — purple
    (255, 200, 120),  # 17 RCA proximal — peach
    ( 80, 255,  80),  # 18 RCA mid — bright green
    (255,  60, 200),  # 19 RCA distal — rose pink
    (  0,  80, 200),  # 20 PDA — cobalt blue
    (200, 140,   0),  # 21 PLV — dark amber
    ( 80, 200, 200),  # 22 AM branch — teal
    (255, 255, 120),  # 23 Conus — light yellow
    (160,  60,   0),  # 24 SAN artery — brown
    (180, 180, 180),  # 25 Other — silver grey
]))
# Short display name per ARCADE class id (1–25); background (0) has no label.
ARCADE_LABELS = dict(enumerate([
    "LMCA", "LAD prox", "LAD mid", "LAD dist",      # 1–4
    "D1", "D2", "D3", "Septal",                     # 5–8
    "LAD var", "LCX prox", "LCX mid", "LCX dist",   # 9–12
    "OM1", "OM2", "OM3", "LCX var",                 # 13–16
    "RCA prox", "RCA mid", "RCA dist", "PDA",       # 17–20
    "PLV", "AM", "Conus", "SAN",                    # 21–24
    "Other",                                        # 25
], start=1))
def render_nnunet_overlay(frame_rgb, seg_map, alpha=0.55):
    """Blend a segmentation label map onto an RGB frame.

    The mode depends on the foreground labels (> 0) present in *seg_map*:
      - only label 1 → treated as a binary vessel mask: one cyan tint plus a
        "Vessel mask (binary)" caption;
      - otherwise → each ARCADE class 1–25 is blended with its ARCADE_COLOURS
        colour, and a legend of up to 13 present classes is drawn bottom-left.
    A multi-class map whose only foreground label is 1 cannot be told apart
    from a binary mask, so it is rendered in binary mode.

    Args:
        frame_rgb: H×W×3 RGB image (uint8 presumably).
        seg_map: H×W integer label map, 0 = background.
        alpha: weight of the class colour in the blend (0 = frame only).

    Returns:
        uint8 RGB image with the overlay (and caption/legend) drawn on it.
    """
    unique_cls = [c for c in np.unique(seg_map) if c > 0]
    overlay = frame_rgb.copy().astype(np.float32)

    # ── Binary case (0/1 only) ────────────────────────────────────────────────
    if len(unique_cls) == 1 and unique_cls[0] == 1:
        mask = (seg_map == 1)
        col = np.array([0, 200, 255], dtype=np.float32)  # cyan for vessel
        overlay[mask] = overlay[mask] * (1 - alpha) + col * alpha
        result = overlay.astype(np.uint8)
        cv2.putText(result, "Vessel mask (binary)",
                    (8, 20), cv2.FONT_HERSHEY_SIMPLEX,
                    0.5, (0, 200, 255), 1, cv2.LINE_AA)
        return result

    # ── 26-class case ─────────────────────────────────────────────────────────
    present_classes = []
    for cls_id, rgb_col in ARCADE_COLOURS.items():
        if cls_id == 0:
            continue
        mask = (seg_map == cls_id)
        if not mask.any():
            continue
        col = np.array(rgb_col, dtype=np.float32)
        overlay[mask] = overlay[mask] * (1 - alpha) + col * alpha
        present_classes.append(cls_id)
    result = overlay.astype(np.uint8)
    if not present_classes:
        return result

    # Compact legend — bottom-left, only present classes, max 13 rows.
    # Best-effort: a drawing failure is logged, and the un-annotated overlay is returned.
    try:
        shown = present_classes[:13]
        row_h = 18
        legend_h = len(shown) * row_h + 10
        legend_w = 148
        H = result.shape[0]
        x0 = 6
        # Clamp so frames shorter than the legend don't yield a negative
        # (wrap-around) slice start.
        y0 = max(0, H - legend_h - 6)
        region = result[y0:y0 + legend_h, x0:x0 + legend_w]
        region[...] = (region.astype(np.float32) * 0.4).astype(np.uint8)
        for i, cls_id in enumerate(shown):
            r, g, b = ARCADE_COLOURS[cls_id]
            lbl = ARCADE_LABELS.get(cls_id, f"cls{cls_id}")
            yy = y0 + 6 + i * row_h
            # `result` is RGB, so the swatch must be given in RGB order to
            # match the overlay colour.
            cv2.rectangle(result, (x0 + 4, yy), (x0 + 16, yy + 11), (r, g, b), -1)
            cv2.rectangle(result, (x0 + 4, yy), (x0 + 16, yy + 11), (255, 255, 255), 1)
            cv2.putText(result, lbl, (x0 + 20, yy + 9),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.35,
                        (255, 255, 255), 1, cv2.LINE_AA)
    except Exception as exc:
        print(f"[Overlay] legend drawing skipped: {exc}")
    return result
# ─────────────────────────────────────────────────────────────────────────────
# MODULE 7 — SYNTAX SCORE (simplified anatomical estimate)
# ─────────────────────────────────────────────────────────────────────────────
def compute_syntax_score(stenoses, seg_map):
    """Simplified, automated SYNTAX-like score (display surrogate only).

    Each stenosis contributes ``pos_w * occ_w * (1 + ds)``, where:
      - ``ds = pct_ds / 100``;
      - ``pos_w = max(0.5, 1.5 - position_mm / 100)`` — samples nearer the
        start of the traced centreline weigh more. This is a rough stand-in
        for "proximal": the trace direction is not guaranteed to run
        proximal→distal;
      - ``occ_w`` is 5.0 for ds ≥ 0.99, 2.5 for ds ≥ 0.70, 1.5 for ds ≥ 0.50,
        and 1.0 otherwise.
    A further 0.3 is added per distinct foreground label (> 0) in *seg_map*.

    Reference: Sianos et al., EuroIntervention 2005. The true SYNTAX score
    requires expert angiographic reading; this is not a substitute.

    Args:
        stenoses: list of dicts with "pct_ds" and "position_mm" (as built by run_ffr).
        seg_map: integer label map, 0 = background.

    Returns:
        (score, tier). The score is rounded to 0.1. The tier is "Low (<23)",
        "Intermediate (23–32)" or "High (≥33)", or "Low (0)" when there are
        no stenoses.
    """
    if not stenoses:
        return 0.0, "Low (0)"
    base_score = 0.
    for st in stenoses:
        ds = st["pct_ds"] / 100
        pos_w = max(0.5, 1.5 - st["position_mm"] / 100)
        if ds >= 0.99:
            occ_w = 5.0
        elif ds >= 0.70:
            occ_w = 2.5
        elif ds >= 0.50:
            occ_w = 1.5
        else:
            occ_w = 1.0
        base_score += pos_w * occ_w * (1 + ds)
    # Count labels > 0 directly, rather than "unique - 1", so the count is
    # right even when the map contains no background pixels.
    n_segs = int(np.count_nonzero(np.unique(seg_map) > 0))
    base_score += n_segs * 0.3
    score = round(base_score, 1)
    if score < 23:
        tier = "Low (<23)"
    elif score < 33:
        tier = "Intermediate (23–32)"
    else:
        tier = "High (≥33)"
    return score, tier
# ─────────────────────────────────────────────────────────────────────────────
# JSON ENCODER — handles numpy float32/int64 etc.
# ─────────────────────────────────────────────────────────────────────────────
class _NpEncoder(json.JSONEncoder):
    """JSON encoder that turns NumPy scalars/arrays into native Python values.

    np.integer → int, np.floating → float, np.ndarray → nested list,
    np.bool_ → bool. Anything else goes to the base class, which raises TypeError.
    """

    # Checked in order; the first matching NumPy type decides the conversion.
    _CONVERTERS = (
        (np.integer, int),
        (np.floating, float),
        (np.ndarray, np.ndarray.tolist),
        (np.bool_, bool),
    )

    def default(self, obj):
        for np_type, convert in self._CONVERTERS:
            if isinstance(obj, np_type):
                return convert(obj)
        return super().default(obj)
def _to_json(obj) -> str:
    """Serialise *obj* as 2-space-indented JSON; NumPy values go through _NpEncoder."""
    text = json.dumps(obj, cls=_NpEncoder, indent=2)
    return text
# ─────────────────────────────────────────────────────────────────────────────
# MAIN PIPELINE FUNCTION
# ─────────────────────────────────────────────────────────────────────────────
def _bbox_vessel_overlap(bbox, mask):
    """Fraction (0–1) of an [x1, y1, x2, y2] box's area covered by *mask*.

    The denominator is the box's nominal area (min 1 px). The summed region
    is clipped to the mask bounds, so negative or out-of-frame coordinates
    cannot wrap around through Python's negative slicing.
    """
    H, W = mask.shape[:2]
    x1, y1, x2, y2 = (int(v) for v in bbox[:4])
    area = max((x2 - x1) * (y2 - y1), 1)
    r0, r1 = max(0, y1), max(0, min(H, y2))
    c0, c1 = max(0, x1), max(0, min(W, x2))
    covered = float(mask[r0:r1, c0:c1].sum())
    return min(covered / area, 1.)


def analyse(video_path, sten_conf, seg_conf, px_per_mm_override, progress=gr.Progress()):
    """Run the full pipeline on one video; Gradio callback for "Run analysis".

    Args:
        video_path: path of the uploaded video, or None.
        sten_conf: confidence threshold passed to run_mask2former.
        seg_conf: confidence threshold passed to run_yolo_seg.
        px_per_mm_override: image scale, written into FFR_CFG["px_per_mm"].
        progress: Gradio progress tracker.

    Returns:
        A 7-tuple matching the UI outputs:
        (status_html, frame_img, binary_overlay, seg_overlay, ffr_img,
         metrics_html, json_str). Errors are reported via the status HTML
        and JSON, not raised.
    """
    if video_path is None:
        return (_status("error", "No video uploaded."), None, None, None, None, "", "")

    # NOTE(review): FFR_CFG is module-global, so concurrent sessions using
    # different scales may overwrite each other's value.
    FFR_CFG["px_per_mm"] = px_per_mm_override
    try:
        # Step 1 — keyframe
        progress(0.10, desc="Extracting best frame…")
        frame_rgb, frame_idx, frame_score, total_frames = extract_best_frame(video_path)

        # Step 2 — stenosis detection
        progress(0.30, desc="Running Mask2Former…")
        detections = run_mask2former(frame_rgb, sten_conf)

        # Step 3a — binary segmentation (ResUNet)
        progress(0.50, desc="Running ResUNet…")
        binary_mask, binary_overlay = run_resunet(frame_rgb)
        print(f"[ResUNet] binary_mask sum={binary_mask.sum()} "
              f"unique={np.unique(binary_mask)} "
              f"vessel%={100*binary_mask.mean():.1f}")

        # Vessel coverage of each detection box; run_ffr skips boxes below
        # FFR_CFG["min_overlap_for_ffr"].
        for det in detections:
            b = det.get("bbox", [])
            det["overlap"] = _bbox_vessel_overlap(b, binary_mask) if len(b) >= 4 else 0.

        # Step 3b — YOLOv8m-seg 26-class segmentation
        progress(0.65, desc="Running seg…")
        seg_map = run_yolo_seg(frame_rgb, seg_conf)
        seg_overlay = render_nnunet_overlay(frame_rgb, seg_map)

        # Step 4 — FFR
        progress(0.80, desc="Computing FFR…")
        ffr_result, ffr_fig = run_ffr(frame_rgb, binary_mask, detections)
        if "error" in ffr_result:
            return (_status("warn", ffr_result["error"]),
                    frame_rgb, binary_overlay, seg_overlay, None,
                    _metrics_html(ffr_result, {}, frame_idx, total_frames),
                    _to_json(ffr_result))

        # Step 5 — SYNTAX score
        progress(0.92, desc="Computing SYNTAX score…")
        syntax_score, syntax_tier = compute_syntax_score(ffr_result["stenoses"], seg_map)
        ffr_result["syntax_score"] = syntax_score
        ffr_result["syntax_tier"] = syntax_tier

        progress(1.00, desc="Done.")
        status = _status("ok",
                         f"Analysis complete — Frame {frame_idx}/{total_frames} "
                         f"(quality score {frame_score:.1f})")
        metrics = _metrics_html(ffr_result, {"score": frame_score, "idx": frame_idx},
                                frame_idx, total_frames)
        return (status, frame_rgb, binary_overlay, seg_overlay, ffr_fig, metrics,
                _to_json(ffr_result))

    except FileNotFoundError as e:
        return (_status("error", str(e)), None, None, None, None,
                _metrics_html({}, {}, 0, 0),
                _to_json({"error": str(e)}))
    except Exception as e:
        import traceback
        tb = traceback.format_exc()
        print(tb)  # keep the full trace in the server log as well
        return (_status("error", f"{type(e).__name__}: {e}"), None, None, None, None,
                _metrics_html({}, {}, 0, 0),
                _to_json({"error": str(e), "traceback": tb}))
# ─────────────────────────────────────────────────────────────────────────────
# HTML HELPERS
# ─────────────────────────────────────────────────────────────────────────────
def _status(kind: str, msg: str) -> str:
    """Return a coloured one-line status banner as an HTML snippet.

    Args:
        kind: "ok", "warn" or "error". Any other value falls back to a
            neutral blue with a bullet icon.
        msg: plain-text message. It is HTML-escaped because callers pass
            exception text, which may contain "<" or "&".
    """
    import html

    colours = {"ok": "#0d9e6e", "warn": "#d98b00", "error": "#d64040"}
    icons = {"ok": "✓", "warn": "⚠", "error": "✗"}
    col = colours.get(kind, "#1a6fa8")
    ico = icons.get(kind, "•")
    safe_msg = html.escape(str(msg), quote=False)
    return (f'<div class="status-box" style="border-left:4px solid {col};'
            f'color:{col};padding:8px 12px;">'
            f'{ico} {safe_msg}</div>')
def _metrics_html(ffr_result, frame_info, frame_idx, total) -> str:
    """Render the "Key metrics" card (and stenosis table) for the Results tab.

    Args:
        ffr_result: dict from run_ffr, optionally enriched with
            "syntax_score"/"syntax_tier". Empty, or containing "error",
            yields a placeholder card.
        frame_info: unused; kept for call-site compatibility.
        frame_idx: index of the extracted keyframe.
        total: total frame count of the video.
    """
    if not ffr_result or "error" in ffr_result:
        return ('<div class="card"><div class="card-title">Metrics</div>'
                '<p class="hint">Run analysis to see results.</p></div>')

    gffr = ffr_result.get("global_ffr", "—")
    isch = ffr_result.get("ischemic", None)
    ns = ffr_result.get("n_stenoses", 0)
    ref = ffr_result.get("ref_diam_mm", "—")
    vlen = ffr_result.get("vessel_len_mm", "—")
    syn = ffr_result.get("syntax_score", "—")
    stier = ffr_result.get("syntax_tier", "—")

    if isch is True:
        ffr_cls, ffr_lbl = "ischemic", "⚠ ISCHEMIC"
    elif isch is False:
        ffr_cls, ffr_lbl = "ok", "✓ NON-ISCHEMIC"
    else:
        ffr_cls, ffr_lbl = "", "—"

    # Same 23 / 33 cut-offs as the SYNTAX tiers.
    is_number = isinstance(syn, (int, float)) and not isinstance(syn, bool)
    syn_val = float(syn) if is_number else 0.
    syn_cls = "ischemic" if syn_val >= 33 else ("borderline" if syn_val >= 23 else "ok")

    def tile(value, label, css_cls=""):
        return (f'<div class="metric {css_cls}">'
                f'<div class="metric-value">{value}</div>'
                f'<div class="metric-label">{label}</div></div>')

    tiles = "".join([
        tile(gffr, f"Global FFR · {ffr_lbl}", ffr_cls),
        tile(ns, "Stenoses"),
        tile(f"{ref} mm" if ref != "—" else "—", "Reference diameter"),
        tile(f"{vlen} mm" if vlen != "—" else "—", "Vessel length"),
        tile(syn, "SYNTAX score", syn_cls),
    ])

    rows = []
    for i, st in enumerate(ffr_result.get("stenoses", []), start=1):
        sig = bool(st.get("significant"))
        sig_col = "#d64040" if sig else "#0d9e6e"
        rows.append(
            f'<tr style="color:{sig_col}">'
            f'<td>#{i}</td>'
            f'<td>{st["position_mm"]:.1f} mm</td>'
            f'<td>{st["pct_ds"]:.1f}%</td>'
            f'<td>{st["d_mm"]:.2f} mm</td>'
            f'<td>{st["ffr"]:.3f}</td>'
            f'<td>{"Significant" if sig else "Non-significant"}</td>'
            '</tr>'
        )
    table = ""
    if rows:
        table = ('<div class="card-title">Stenosis table</div>'
                 '<table class="sten-table"><thead><tr>'
                 '<th>#</th><th>Position</th><th>DS%</th><th>Min diam</th>'
                 '<th>FFR</th><th>Status</th>'
                 '</tr></thead><tbody>' + "".join(rows) + '</tbody></table>')

    return (
        '<div class="card">'
        '<div class="card-title">Key metrics</div>'
        f'<div class="metric-row">{tiles}</div>'
        f'<p class="hint">SYNTAX tier: {stier} | '
        f'Frame {frame_idx}/{total} | '
        f'Scale: {FFR_CFG["px_per_mm"]} px/mm</p>'
        f'{table}'
        '</div>'
    )
# ─────────────────────────────────────────────────────────────────────────────
# GRADIO APP
# ─────────────────────────────────────────────────────────────────────────────
# Placeholder shown in the upload-page status box before / after a run.
_IDLE_STATUS_HTML = '<div class="status-box">Upload a video and press Run analysis.</div>'

with gr.Blocks(css=CSS, title="Angio AI") as demo:
    # Header
    gr.HTML(_header_html())

    with gr.Tabs() as tabs:
        # ── PAGE 1: UPLOAD ────────────────────────────────────────────────
        with gr.TabItem("📤 Upload & Configure", id="tab-upload"):
            gr.HTML('<div class="page-spacer"></div>')
            with gr.Row():
                # Left — upload + controls
                with gr.Column(scale=2):
                    gr.HTML('<div class="section-label">Angiographic video</div>')
                    video_input = gr.Video(
                        label="Upload XCA video (mp4 / avi / dicom-wrapped)",
                        elem_id="upload-zone",
                        height=220,
                    )
                    btn_demo1 = gr.Button("▶ XCA Video Run", variant="secondary", size="sm")
                    gr.HTML('<p class="hint">Click to load the demo XCA video, '
                            'or upload your own above.</p>')
                    gr.HTML('<hr class="divider">')
                    gr.HTML('<div class="section-label">Model controls</div>')
                    sten_conf = gr.Slider(
                        minimum=0.05, maximum=0.95, value=0.30, step=0.05,
                        label="Stenosis confidence threshold (Mask2Former)",
                        info="Detections below this confidence are discarded",
                    )
                    seg_conf = gr.Slider(
                        minimum=0.05, maximum=0.95, value=0.25, step=0.05,
                        label="Segmentation confidence threshold",
                        info="Detections below this confidence are discarded ",
                    )
                    # Passed to analyse(), which writes it into FFR_CFG["px_per_mm"].
                    px_per_mm = gr.Slider(
                        minimum=2.0, maximum=6.0, value=3.75, step=0.25,
                        label="Scale (px / mm)",
                        info="ARCADE default = 3.75 px/mm (0.267 mm/px). Adjust if using non-ARCADE data.",
                    )
                    gr.HTML('<hr class="divider">')
                # Right — info card
                with gr.Column(scale=1):
                    gr.HTML("""
<div class="card">
  <div class="card-title">Pipeline overview</div>
  <ol>
    <li>Best frame extracted from video</li>
    <li>Stenosis detection — Mask2Former</li>
    <li>Coronary segmentation — 26-class anatomy labelling</li>
    <li>Binary vessel mask — ResUNet</li>
    <li>FFR estimation — QFR v4</li>
    <li>SYNTAX score computation</li>
  </ol>
  <p>FFR threshold: ≤ 0.80 → ischemic<br>
     Formula: 1 − (0.33·DS + 0.60·DS²)<br>
     Reference: Tu et al. JACC 2016</p>
  <div class="card-title">Model checkpoints</div>
  <p>Auto-downloaded from<br>
     <code>MuhammadAdil63/<br>angio-ai-checkpoints</code></p>
  <ul>
    <li><code>mask2former_best.pth</code></li>
    <li><code>binary_best.pth</code></li>
    <li><code>best.pt</code></li>
  </ul>
</div>
""")
            with gr.Row():
                with gr.Column(scale=2):
                    btn_analyse = gr.Button("▶ Run analysis",
                                            elem_id="btn-analyse", variant="primary")
                with gr.Column(scale=1):
                    btn_reset = gr.Button("↺ Clear", elem_id="btn-reset")
            status_out = gr.HTML(_IDLE_STATUS_HTML)

        # ── PAGE 2: RESULTS ───────────────────────────────────────────────
        with gr.TabItem("📊 Results", id="tab-results"):
            gr.HTML('<div class="page-spacer"></div>')
            status_out2 = gr.HTML()
            metrics_out = gr.HTML()
            with gr.Row():
                with gr.Column():
                    gr.HTML('<div class="section-label">Best frame extracted</div>')
                    frame_out = gr.Image(label="Best keyframe", height=300)
                with gr.Column():
                    gr.HTML('<div class="section-label">Binary vessel mask (ResUNet)</div>')
                    binary_out = gr.Image(label="Green overlay — vessel pixels", height=300)
            with gr.Row():
                with gr.Column():
                    gr.HTML('<div class="section-label">Coronary segmentation (26-class)</div>')
                    seg_out = gr.Image(label="26-class coronary overlay", height=300)
                with gr.Column():
                    gr.HTML('<div class="section-label">FFR analysis</div>')
                    ffr_out = gr.Image(label="FFR pipeline output", height=300)
            with gr.Row():
                with gr.Column():
                    gr.HTML('<div class="section-label">Full results (JSON)</div>')
                    json_out = gr.Code(label="", language="json", lines=12)

    # Footer
    gr.HTML("""
""")

    # ── WIRING ────────────────────────────────────────────────────────────
    def run_and_switch(video, sc, sg, px, progress=gr.Progress()):
        """Callback for "Run analysis": forward the inputs to analyse().

        The gr.Progress default lets Gradio attach a progress tracker to this
        event. Passing it on makes analyse()'s progress(...) calls visible.
        """
        return analyse(video, sc, sg, px, progress=progress)

    def load_demo1():
        """Return the bundled demo video for the video input."""
        return _get_demo_video(DEMO_VIDEO_1_NAME)

    btn_demo1.click(fn=load_demo1, inputs=[], outputs=[video_input])

    # Outputs must match the 7-tuple returned by analyse().
    btn_analyse.click(
        fn=run_and_switch,
        inputs=[video_input, sten_conf, seg_conf, px_per_mm],
        outputs=[status_out, frame_out, binary_out, seg_out, ffr_out, metrics_out, json_out],
    ).then(
        # Copy status HTML to results page status box
        fn=lambda s: s,
        inputs=[status_out],
        outputs=[status_out2],
    ).then(
        # Switch to Results tab via JavaScript.
        # NOTE(review): the '.tab-nav button' selector depends on Gradio's DOM
        # and may differ between Gradio versions.
        fn=None,
        js="""
() => {
    const tabs = document.querySelectorAll('.tab-nav button');
    if (tabs.length >= 2) tabs[1].click();
}
""",
    )

    btn_reset.click(
        fn=lambda: (
            None, None, None, None, None,
            _IDLE_STATUS_HTML,
            "", "",
        ),
        outputs=[video_input, frame_out, binary_out, seg_out, ffr_out,
                 status_out, metrics_out, json_out],
    )
# ─────────────────────────────────────────────────────────────────────────────
# ENTRY POINT
# ─────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Listen on every interface at port 7860 and ask Gradio for a public share link.
    launch_options = dict(server_name="0.0.0.0", server_port=7860, share=True)
    demo.launch(**launch_options)