# ── angio_ai_app.py ──────────────────────────────────────────────────────────
# Angio AI — Coronary Angiography Analysis System
# Muhammad Adil · MS Data Science · ITU Lahore
# Supervisor: Dr. Arif Mehmood
#
# ── DEPLOYMENT — HuggingFace Spaces ──────────────────────────────────────────
# Checkpoints are loaded from HF Hub: MuhammadAdil63/angio-ai-checkpoints
# Set HF_TOKEN as a Space Secret (Settings → Secrets) for private repo access.
# Files expected in that repo:
#   mask2former_best.pth   (Mask2Former Swin-Base fine-tuned)
#   binary_best.pth        (ResUNet binary vessel segmentation)
#   best.pt                (YOLOv8m-seg 26-class ARCADE syntax)
#
# ── DEPENDENCIES ─────────────────────────────────────────────────────────────
# pip install gradio torch torchvision opencv-python-headless
#             scikit-image scipy matplotlib transformers
#             ultralytics   (for YOLOv8m-seg inference)
# ─────────────────────────────────────────────────────────────────────────────

# ── Force PyTorch weights_only=False globally (must be before any imports) ───
import os as _os
_os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = "0"
# Setting the variable above to "0" only declines the *force-on* switch.
# The documented opt-out that makes torch.load default to weights_only=False
# is TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD=1.
_os.environ["TORCH_FORCE_NO_WEIGHTS_ONLY_LOAD"] = "1"

# ── Python 3.13 audioop shim (removed from stdlib, needed by pydub/gradio) ───
import sys, types
# Try the real modules first: checking `sys.modules` alone would install an
# empty stub on Python < 3.13 too (the module simply has not been imported
# yet) and thereby shadow the working stdlib implementation.
try:
    import audioop  # noqa: F401
except ImportError:
    sys.modules["audioop"] = types.ModuleType("audioop")
try:
    import pyaudioop  # noqa: F401
except ImportError:
    sys.modules["pyaudioop"] = types.ModuleType("pyaudioop")


# ── Patch gradio_client _json_schema_to_python_type (APIInfoParseError fix) ──
def _patch_gradio_client():
    """Make gradio_client's JSON-schema → type-string helpers failure-tolerant.

    Wraps ``gradio_client.utils._json_schema_to_python_type`` and
    ``gradio_client.utils.json_schema_to_python_type`` so that a non-dict
    schema or any exception raised while parsing yields the string ``"Any"``
    instead of propagating.

    The patch is best-effort: if gradio_client cannot be imported or lacks
    the expected attributes, a message is printed and nothing is changed.
    Calling the function more than once does not stack wrappers.
    """
    try:
        import gradio_client.utils as _gcu

        if getattr(_gcu._json_schema_to_python_type, "_angio_patched", False):
            return  # already wrapped by an earlier call

        _orig = _gcu._json_schema_to_python_type

        def _safe(schema, defs=None):
            # Non-dict schemas (e.g. bare booleans) cannot be parsed at all.
            if not isinstance(schema, dict):
                return "Any"
            try:
                return _orig(schema, defs)
            except Exception:
                return "Any"

        _safe._angio_patched = True
        _gcu._json_schema_to_python_type = _safe

        _orig_top = _gcu.json_schema_to_python_type

        def _safe_top(schema, *args, **kwargs):
            # Forward exactly what the caller supplied: the public helper is
            # called with the schema alone, so injecting an extra `defs`
            # argument would raise TypeError on every call and turn every
            # type into "Any".
            try:
                return _orig_top(schema, *args, **kwargs)
            except Exception:
                return "Any"

        _safe_top._angio_patched = True
        _gcu.json_schema_to_python_type = _safe_top
    except Exception as exc:
        # Deliberately non-fatal: the app still runs without the patch.
        print(f"[PATCH] gradio_client patch skipped: {exc}")
_patch_gradio_client()

# ── PyTorch 2.7 global weights_only patch ────────────────────────────────────
# YOLO and other libraries call torch.load internally without weights_only=False.
# Monkey-patch torch.load to default to weights_only=False for all calls.
#
# SECURITY NOTE: weights_only=False means checkpoints are unpickled with full
# pickle semantics, i.e. loading a file can execute arbitrary code. Only load
# checkpoints from sources you control.
import torch as _torch
_orig_torch_load = _torch.load


def _patched_torch_load(f, map_location=None, pickle_module=None,
                        weights_only=False, mmap=None, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to False.

    The caller's ``weights_only`` value is intentionally ignored. All other
    arguments are forwarded; ``pickle_module`` and ``mmap`` are only passed
    on when the caller set them, so the original defaults stay in effect
    otherwise.
    """
    if pickle_module is not None:
        kwargs["pickle_module"] = pickle_module
    if mmap is not None:
        kwargs["mmap"] = mmap
    return _orig_torch_load(f, map_location=map_location,
                            weights_only=False, **kwargs)


_torch.load = _patched_torch_load

import os
import sys
import json
import base64
import tempfile
import warnings
warnings.filterwarnings("ignore")

import cv2
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from pathlib import Path

import gradio as gr
import torch

# ── Patch torch.load for PyTorch 2.7 (weights_only=True default breaks YOLO) ─
# `torch` and `_torch` are the same module object, so torch.load is already
# the wrapper installed above. Re-wrapping it would only add a second layer
# that drops `pickle_module`/`mmap`; the names are kept as aliases instead.
_orig_tload = _orig_torch_load
_safe_tload = _patched_torch_load
torch.load = _safe_tload

import huggingface_hub  # noqa: F401 (ensure installed; used in _hf_ckpt)

# ─────────────────────────────────────────────────────────────────────────────
# CONFIG — edit these paths before running
# ─────────────────────────────────────────────────────────────────────────────
# ── Checkpoint download from HuggingFace Hub ─────────────────────────────────
# Checkpoints live in: MuhammadAdil63/angio-ai-checkpoints (private repo)
# huggingface_hub reads HF_TOKEN from environment (set as a Space secret).
# Single source of truth for the Hub repository that stores checkpoints and
# demo videos.
_HF_REPO_ID = "MuhammadAdil63/angio-ai-checkpoints"


def _hf_ckpt(filename: str) -> str:
    """Download checkpoint from HF Hub on first call, return local cache path.

    Args:
        filename: Name of the file inside the Hub model repository.

    Returns:
        The local path returned by ``hf_hub_download``.

    Raises:
        ImportError: if ``huggingface_hub`` is not installed.
        Exception: whatever ``hf_hub_download`` raises when the download fails.
    """
    from huggingface_hub import hf_hub_download
    return hf_hub_download(
        repo_id=_HF_REPO_ID,
        filename=filename,
        repo_type="model",
        token=os.environ.get("HF_TOKEN"),   # set in Space Secrets
    )


CONFIG = {
    # ── Checkpoint paths (resolved lazily via HF Hub) ─────────────────────────
    # These callables are invoked once inside each model's _load_*() function.
    "MASK2FORMER_CKPT" : lambda: _hf_ckpt("mask2former_best.pth"),
    "RESUNET_CKPT"     : lambda: _hf_ckpt("binary_best.pth"),
    "YOLO_CKPT"        : lambda: _hf_ckpt("best.pt"),

    # ── YOLO inference params ─────────────────────────────────────────────────
    "YOLO_CONF"        : 0.25,
    "YOLO_IOU"         : 0.70,
    "YOLO_IMGSZ"       : 512,

    # ── FFR pipeline scale (ARCADE hardcoded) ────────────────────────────────
    "PX_PER_MM"        : 3.75,

    # ── Device — CPU on HF free tier ─────────────────────────────────────────
    "DEVICE"           : "cuda:0" if torch.cuda.is_available() else "cpu",
}

# ── Demo videos (hosted in HF Hub model repo alongside checkpoints) ──────────
# Upload your two XCA demo videos to MuhammadAdil63/angio-ai-checkpoints as:
#   demo_video_1.mp4
#   demo_video_2.mp4
# The buttons below download and load them automatically.
DEMO_VIDEO_1_NAME = "demo_video_1.mp4"
DEMO_VIDEO_2_NAME = "demo_video_2.mp4"


def _get_demo_video(filename: str) -> "str | None":
    """Download demo video from HF Hub and return local path.

    Uses the same download helper as the checkpoints. This is best-effort:
    any failure is printed and ``None`` is returned instead of raising, so
    callers must handle a ``None`` result.
    """
    try:
        return _hf_ckpt(filename)
    except Exception as e:
        print(f"[DEMO] Could not load {filename}: {e}")
        return None


# ─────────────────────────────────────────────────────────────────────────────
# ITU LOGO — encode to base64 for embedding in HTML
# Place ITULog.png in the same directory as this script.
# ─────────────────────────────────────────────────────────────────────────────
def _encode_logo(path: str) -> str:
    """Return the file at *path* as a ``data:image/png;base64,...`` URI.

    Returns an empty string when the file cannot be read (missing file,
    directory, permission problem) so the header can fall back to text.
    """
    try:
        with open(path, "rb") as f:
            return "data:image/png;base64," + base64.b64encode(f.read()).decode()
    except OSError:
        # FileNotFoundError is a subclass; other read failures are treated the
        # same way because the logo is purely decorative.
        return ""


LOGO_ITU = _encode_logo(str(Path(__file__).parent / "ITULog.png"))

# ─────────────────────────────────────────────────────────────────────────────
# CSS THEME — white / bluish, minimalist, medical
# ─────────────────────────────────────────────────────────────────────────────
CSS = """
/* ── Base ── */
:root {
  --primary: #1a6fa8;
  --primary-lt: #e8f3fb;
  --primary-mid: #4a9fd4;
  --accent: #0d9e6e; /* Angio AI green */
  --bg: #f7fafd;
  --surface: #ffffff;
  --border: #cce0f0;
  --text: #1a2533;
  --text-sec: #5a7a96;
  --danger: #d64040;
  --warn: #d98b00;
  --radius: 10px;
  --shadow: 0 2px 12px rgba(26,111,168,0.10);
}
body, .gradio-container {
  background: var(--bg) !important;
  font-family: 'Segoe UI', system-ui, sans-serif !important;
  color: var(--text) !important;
}
/* ── Header bar ── */
#header {
  background: var(--surface);
  border-bottom: 2px solid var(--border);
  padding: 14px 28px;
  display: flex;
  align-items: center;
  justify-content: space-between;
  box-shadow: var(--shadow);
  margin-bottom: 0 !important;
}
#logo-itu img { height: 48px; }
#logo-angio { font-size: 26px; font-weight: 700; color: var(--accent); letter-spacing: 1px; }
#logo-angio span {
  font-size: 13px; color: var(--text-sec); font-weight: 400;
  display: block; letter-spacing: 0;
}
/* ── Tab bar ── */
.tab-nav button {
  color: var(--text-sec) !important;
  border-bottom: 3px solid transparent !important;
  font-size: 15px !important;
  font-weight: 500 !important;
  padding: 10px 24px !important;
  background: transparent !important;
}
.tab-nav button.selected {
  color: var(--primary) !important;
  border-bottom-color: var(--primary) !important;
}
/* ── Cards ── */
.card {
  background: var(--surface);
  border: 1px solid var(--border);
  border-radius: var(--radius);
  padding: 22px;
  box-shadow: var(--shadow);
  margin-bottom: 16px;
}
.card-title {
  font-size: 14px;
  font-weight: 600;
  color: var(--primary);
  text-transform: uppercase;
  letter-spacing: 0.5px;
  margin-bottom: 12px;
  padding-bottom: 8px;
  border-bottom: 1px solid var(--primary-lt);
}
/* ── Upload zone ── */
#upload-zone {
  border: 2px dashed var(--primary-mid) !important;
  border-radius: var(--radius) !important;
  background: var(--primary-lt) !important;
  min-height: 200px !important;
}
/* ── Sliders ── */
input[type=range] { accent-color: var(--primary); }
.gradio-slider label { color: var(--text-sec) !important; font-size: 13px !important; }
/* ── Buttons ── */
#btn-analyse {
  background: var(--primary) !important;
  color: white !important;
  border: none !important;
  border-radius: var(--radius) !important;
  font-size: 15px !important;
  font-weight: 600 !important;
  padding: 12px 32px !important;
  cursor: pointer !important;
  width: 100% !important;
  transition: background 0.2s !important;
}
#btn-analyse:hover { background: #155d90 !important; }
#btn-reset {
  background: transparent !important;
  color: var(--text-sec) !important;
  border: 1px solid var(--border) !important;
  border-radius: var(--radius) !important;
  font-size: 13px !important;
  padding: 8px 20px !important;
  cursor: pointer !important;
  width: 100% !important;
}
/* ── Metric badges ── */
.metric-row { display: flex; gap: 12px; flex-wrap: wrap; margin-bottom: 12px; }
.metric-badge {
  flex: 1;
  min-width: 110px;
  background: var(--primary-lt);
  border: 1px solid var(--border);
  border-radius: 8px;
  padding: 10px 14px;
  text-align: center;
}
.metric-badge .val { font-size: 22px; font-weight: 700; color: var(--primary); }
.metric-badge .val.ischemic { color: var(--danger); }
.metric-badge .val.borderline { color: var(--warn); }
.metric-badge .val.ok { color: var(--accent); }
.metric-badge .lbl {
  font-size: 11px; color: var(--text-sec);
  text-transform: uppercase; letter-spacing: 0.4px;
}
/* ── Result images ── */
#result-images img { border-radius: var(--radius); border: 1px solid var(--border); }
/* ── Status ── */
#status-box {
  padding: 10px 16px;
  border-radius: 8px;
  font-size: 13px;
  font-weight: 500;
  border: 1px solid var(--border);
  background: var(--primary-lt);
  color: var(--primary);
}
/* ── Footer ── */
#footer {
  text-align: center;
  font-size: 12px;
  color: var(--text-sec);
  padding: 20px;
  border-top: 1px solid var(--border);
  margin-top: 24px;
}
"""


# ─────────────────────────────────────────────────────────────────────────────
# HEADER HTML
# ─────────────────────────────────────────────────────────────────────────────
def _header_html() -> str:
    """Return the HTML snippet for the page header."""
    # NOTE(review): both string literals below look as if their HTML markup
    # was lost (no tags, and `itu_tag` is never interpolated). They are kept
    # exactly as found — confirm against the deployed file.
    itu_tag = (f'ITU Logo' if LOGO_ITU else 'ITU')
    return f""" """


# ─────────────────────────────────────────────────────────────────────────────
# MODULE 1 — KEYFRAME EXTRACTION (from app.py)
# ─────────────────────────────────────────────────────────────────────────────
def extract_best_frame(video_path: str, n: int = 1) -> tuple:
    """Extract the single best-scoring frame from a video.

    Roughly every ``total // 120``-th frame is sampled. Each sample is scored
    on the central 60 % × 60 % region of its grayscale image:
    ``0.6 * mean_intensity + 0.4 * 100 * min(laplacian_variance, 5000) / 5000``.
    The first frame with the highest score wins.

    Args:
        video_path: Path to a video file readable by OpenCV.
        n: Currently unused; kept for interface compatibility.

    Returns:
        ``(frame_rgb_512, frame_index, score, total_frames)`` where
        ``frame_rgb_512`` is the winning frame converted to RGB and resized
        to 512×512.

    Raises:
        ValueError: if the container reports zero frames or no frame could
            be decoded.
    """
    cap = cv2.VideoCapture(video_path)
    best = None          # (score, frame_index, frame_bgr) of the best sample
    try:
        total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total == 0:
            raise ValueError("Could not read video — check format/codec.")

        step = max(1, total // min(120, total))
        fidx = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if fidx % step == 0:
                gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
                h, w = gray.shape
                roi = gray[int(h*0.2):int(h*0.8), int(w*0.2):int(w*0.8)]
                contrast = float(roi.mean())
                sharpness = float(cv2.Laplacian(roi, cv2.CV_64F).var())
                score = contrast * 0.6 + min(sharpness, 5000) / 5000 * 100 * 0.4
                # Keep only the current winner instead of every sampled frame;
                # the strict '>' keeps the earliest frame on ties.
                if best is None or score > best[0]:
                    best = (score, fidx, frame.copy())
            fidx += 1
    finally:
        # Release the capture even when decoding or scoring raises.
        cap.release()

    if best is None:
        raise ValueError("No frames could be sampled from video.")

    best_score, best_idx, best_frame = best
    rgb = cv2.cvtColor(best_frame, cv2.COLOR_BGR2RGB)
    rgb = cv2.resize(rgb, (512, 512))
    return rgb, best_idx, best_score, total


# ─────────────────────────────────────────────────────────────────────────────
# MODULE 2 — MASK2FORMER INFERENCE
# ─────────────────────────────────────────────────────────────────────────────
_m2f_model = None
_m2f_proc = None


def _load_mask2former():
    """Lazy-load the Mask2Former processor and fine-tuned model (once).

    The module-level singletons are only published after the checkpoint has
    been loaded successfully, so a failed attempt is retried on the next call
    instead of leaving a model without the fine-tuned weights in place.

    Raises:
        FileNotFoundError: if the resolved checkpoint path does not exist.
    """
    global _m2f_model, _m2f_proc
    if _m2f_model is not None:
        return
    from transformers import (AutoImageProcessor,
                              Mask2FormerForUniversalSegmentation)

    ckpt = CONFIG["MASK2FORMER_CKPT"]()   # resolves via HF Hub download
    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"Mask2Former checkpoint not found:\n{ckpt}")

    proc = AutoImageProcessor.from_pretrained(
        "facebook/mask2former-swin-base-coco-instance")
    model = Mask2FormerForUniversalSegmentation.from_pretrained(
        "facebook/mask2former-swin-base-coco-instance",
        num_labels=2, ignore_mismatched_sizes=True)

    state = torch.load(ckpt, map_location="cpu", weights_only=False)
    weights = state.get("model_state", state)
    model.load_state_dict(weights, strict=False)
    model.to(CONFIG["DEVICE"]).eval()

    _m2f_proc = proc
    _m2f_model = model


def run_mask2former(frame_rgb: np.ndarray, conf_thr: float) -> list:
    """Run Mask2Former instance segmentation on one RGB frame.

    Args:
        frame_rgb: RGB image array (the app passes 512×512×3 uint8).
        conf_thr: Minimum instance score to keep.

    Returns:
        List of dicts:
        ``[{bbox:[x1,y1,x2,y2], score:float, mask:np.ndarray}, ...]``
        where ``mask`` is a uint8 0/1 array with the frame's height/width.
    """
    _load_mask2former()
    from PIL import Image

    pil = Image.fromarray(frame_rgb)
    inputs = _m2f_proc(images=pil, return_tensors="pt")
    inputs = {k: v.to(CONFIG["DEVICE"]) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = _m2f_model(**inputs)

    # The post-processor drops instances below its own `threshold` (default
    # 0.5) before we ever see them. Lower it when the requested confidence is
    # below 0.5 so the caller's threshold actually takes effect; for
    # conf_thr >= 0.5 the behaviour is unchanged.
    results = _m2f_proc.post_process_instance_segmentation(
        outputs,
        threshold=min(float(conf_thr), 0.5),
        target_sizes=[tuple(frame_rgb.shape[:2])])[0]

    detections = []
    for seg_info in results["segments_info"]:
        score = float(seg_info["score"])
        if score < conf_thr:
            continue
        mask = (results["segmentation"] == seg_info["id"]).cpu().numpy().astype(np.uint8)
        ys, xs = np.where(mask)
        if len(ys) == 0:
            continue
        x1, y1, x2, y2 = int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())
        detections.append({"bbox": [x1, y1, x2, y2], "score": score, "mask": mask})
    return detections


# ─────────────────────────────────────────────────────────────────────────────
# MODULE 3 — RESUNET BINARY SEGMENTATION
# ─────────────────────────────────────────────────────────────────────────────
_resunet_model = None

# ── ResUNet architecture — exactly matching binary_best.pth ──────────────────
#
# Decoded from checkpoint key names via dump_keys.py:
#
#   enc0       Conv2d(1, 16, 3×3, bias=True)     plain entry conv
#   down1..4   Conv2d (kernel 2, stride 2)       ×2 downsampler 16→32→64→128→256
#   res1..4    _ResBlock(ch)                     encoder residual blocks ch=16,32,64,128
#   bot.0..2   _ResBlock(256) × 3                bottleneck
#   up3..0     ConvTranspose2d                   ×2 upsampler 256→128→64→32→16
#   res12..15  _ResBlock(ch)                     decoder residual blocks ch=128,64,32,16
#   head       Sequential[Conv3×3, ReLU, BN,     output head
#                         Conv3×3, ReLU, BN, Conv1×1]
#
# _ResBlock structure (from key pattern res*.bn_in, res*.units.N.*, res*.bn_out):
#   bn_in  → BN
#   units  → ModuleList of N sub-units, each: [Conv3×3, ReLU, Conv3×3, BN]
#   bn_out → BN
#   forward: x → bn_in → for each unit: residual(x + unit(x)) → bn_out
#
# ─────────────────────────────────────────────────────────────────────────────
class _ResBlock(torch.nn.Module):
    """Residual block whose parameter names match the checkpoint.

    Matches checkpoint pattern:
        res*.bn_in.*
        res*.units.0.0.weight  (Conv3×3)
        res*.units.0.0.bias
        res*.units.0.2.weight  (Conv3×3)
        res*.units.0.2.bias
        res*.units.0.3.*       (BN)
        ... (3 units total)
        res*.bn_out.*

    Sequential index mapping inside each unit:
        0 = Conv2d
        1 = ReLU   (no params — not in state dict)
        2 = Conv2d
        3 = BatchNorm2d
    """

    def __init__(self, ch, n_units=3):
        super().__init__()
        self.bn_in = torch.nn.BatchNorm2d(ch)
        self.units = torch.nn.ModuleList([
            torch.nn.Sequential(
                torch.nn.Conv2d(ch, ch, 3, padding=1, bias=True),
                torch.nn.ReLU(inplace=True),
                torch.nn.Conv2d(ch, ch, 3, padding=1, bias=True),
                torch.nn.BatchNorm2d(ch),
            )
            for _ in range(n_units)
        ])
        self.bn_out = torch.nn.BatchNorm2d(ch)

    def forward(self, x):
        x = self.bn_in(x)
        for unit in self.units:
            x = x + unit(x)   # no extra ReLU on residual sum
        return self.bn_out(x)


class ResUNet(torch.nn.Module):
    """Residual U-Net for binary vessel segmentation.

    Exact architecture matching binary_best.pth (Dice = 0.8015).

    Key name mapping:
        enc0       → self.enc0            (plain Conv2d)
        down1..4   → self.down1..down4    (strided Conv2d downsampler)
        res1..4    → self.res1..res4      (encoder ResBlocks)
        bot.0..2   → self.bot[0..2]       (bottleneck ResBlocks)
        up3..0     → self.up3..up0        (ConvTranspose2d upsampler)
        res12..15  → self.res12..res15    (decoder ResBlocks)
        head       → self.head            (Sequential output head)

    Input is a single-channel tensor (B, 1, H, W); output is (B, 2, H, W)
    logits. H and W must be divisible by 16 because of the four ×2
    down/up-sampling stages.
    """

    def __init__(self):
        super().__init__()
        # ── Entry conv ────────────────────────────────────────────────────────
        self.enc0 = torch.nn.Conv2d(1, 16, 3, padding=1, bias=True)

        # ── Encoder downsamples — strided Conv2d (not ConvTranspose2d) ─────────
        # Checkpoint shape (out, in, kH, kW): down1=(32,16,2,2) → Conv2d(16→32)
        # stride=2 kernel=2 halves spatial dimensions while increasing channels
        self.down1 = torch.nn.Conv2d(16, 32, 2, stride=2)
        self.down2 = torch.nn.Conv2d(32, 64, 2, stride=2)
        self.down3 = torch.nn.Conv2d(64, 128, 2, stride=2)
        self.down4 = torch.nn.Conv2d(128, 256, 2, stride=2)

        # ── Encoder residual blocks ───────────────────────────────────────────
        self.res1 = _ResBlock(16)
        self.res2 = _ResBlock(32)
        self.res3 = _ResBlock(64)
        self.res4 = _ResBlock(128)

        # ── Bottleneck ────────────────────────────────────────────────────────
        self.bot = torch.nn.ModuleList([
            _ResBlock(256),
            _ResBlock(256),
            _ResBlock(256),
        ])

        # ── Decoder upsamplers ────────────────────────────────────────────────
        # Shape from keys: up3=(256,128,2,2) bias=(128,)
        #                  up2=(128,64,2,2)  bias=(64,)
        #                  up1=(64,32,2,2)   bias=(32,)
        #                  up0=(32,16,2,2)   bias=(16,)
        self.up3 = torch.nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up2 = torch.nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.up1 = torch.nn.ConvTranspose2d(64, 32, 2, stride=2)
        self.up0 = torch.nn.ConvTranspose2d(32, 16, 2, stride=2)

        # ── Decoder residual blocks ───────────────────────────────────────────
        self.res12 = _ResBlock(128)
        self.res13 = _ResBlock(64)
        self.res14 = _ResBlock(32)
        self.res15 = _ResBlock(16)

        # ── Output head ───────────────────────────────────────────────────────
        # From keys: head.0=Conv(16,16,3,3), head.2=BN(16), head.3=Conv(16,16,3,3),
        #            head.5=BN(16), head.6=Conv(2,16,1,1)
        # Index 1 and 4 are ReLU (no state dict entries)
        self.head = torch.nn.Sequential(
            torch.nn.Conv2d(16, 16, 3, padding=1, bias=True),   # 0
            torch.nn.ReLU(inplace=True),                        # 1
            torch.nn.BatchNorm2d(16),                           # 2
            torch.nn.Conv2d(16, 16, 3, padding=1, bias=True),   # 3
            torch.nn.ReLU(inplace=True),                        # 4
            torch.nn.BatchNorm2d(16),                           # 5
            torch.nn.Conv2d(16, 2, 1, bias=True),               # 6
        )

    def forward(self, x):
        # ── Encoder ──────────────────────────────────────────────────────────
        # Entry conv → ResBlock → down. Skip taken AFTER ResBlock.
        # Shape comments assume a 512×512 input.
        x = self.enc0(x)                         # (B,16,512,512)
        r1 = self.res1(x); x = self.down1(r1)    # skip r1, down → (B,32,256,256)
        r2 = self.res2(x); x = self.down2(r2)    # skip r2, down → (B,64,128,128)
        r3 = self.res3(x); x = self.down3(r3)    # skip r3, down → (B,128,64,64)
        r4 = self.res4(x); x = self.down4(r4)    # skip r4, down → (B,256,32,32)

        # ── Bottleneck ───────────────────────────────────────────────────────
        for blk in self.bot:
            x = blk(x)

        # ── Decoder — ResBlock first, then add skip ───────────────────────────
        x = self.res12(self.up3(x)) + r4         # (B,128,64,64)
        x = self.res13(self.up2(x)) + r3         # (B,64,128,128)
        x = self.res14(self.up1(x)) + r2         # (B,32,256,256)
        x = self.res15(self.up0(x)) + r1         # (B,16,512,512)
        return self.head(x)


def _load_resunet():
    """Lazy-load the ResUNet checkpoint into the module-level singleton.

    A strict state-dict load is attempted first; on mismatch the differences
    are printed and the load is retried with ``strict=False``.

    The singleton is only assigned after the weights were loaded, so a failed
    attempt (e.g. unreadable checkpoint) is retried on the next call instead
    of leaving a randomly initialised network in place.

    Raises:
        FileNotFoundError: if the resolved checkpoint path does not exist.
    """
    global _resunet_model
    if _resunet_model is not None:
        return
    ckpt = CONFIG["RESUNET_CKPT"]()   # resolves via HF Hub download
    if not os.path.exists(ckpt):
        raise FileNotFoundError(f"ResUNet checkpoint not found:\n{ckpt}")

    model = ResUNet()
    state = torch.load(ckpt, map_location="cpu", weights_only=False)
    weights = state.get("model_state", state)

    # Try strict first; if it fails print exact mismatches and retry strict=False
    try:
        model.load_state_dict(weights, strict=True)
        print(f"[ResUNet] ✓ Perfect load (strict=True) — "
              f"epoch={state.get('epoch','?')} "
              f"dice={state.get('best_val_dice',0):.4f}")
    except RuntimeError as e:
        print(f"[ResUNet] strict=True failed:\n {e}")
        result = model.load_state_dict(weights, strict=False)
        print(f"[ResUNet] strict=False — "
              f"missing={len(result.missing_keys)} "
              f"unexpected={len(result.unexpected_keys)}")
        for k in result.missing_keys[:10]:
            print(f" MISSING {k}")
        for k in result.unexpected_keys[:10]:
            print(f" UNEXPECTED {k}")

    model.to(CONFIG["DEVICE"]).eval()
    print(f"[ResUNet] Loaded (eval mode) — "
          f"epoch={state.get('epoch','?')} "
          f"dice={state.get('best_val_dice',0):.4f}")
    _resunet_model = model


def _preprocess_resunet(rgb: np.ndarray) -> torch.Tensor:
    """ARCADE preprocessing: invert → top-hat → subtract → CLAHE → normalise.

    Returns a float32 tensor of shape (1, 1, H, W) scaled to [0, 1]; the app
    feeds 512×512 frames, so this is (1, 1, 512, 512) in practice.
    """
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY).astype(np.uint8)
    inv = cv2.bitwise_not(gray)
    se = cv2.getStructuringElement(cv2.MORPH_RECT, (50, 50))
    th = cv2.morphologyEx(inv, cv2.MORPH_TOPHAT, se)
    sub = cv2.subtract(gray, th)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    proc = clahe.apply(sub).astype(np.float32) / 255.0
    t = torch.from_numpy(proc).unsqueeze(0).unsqueeze(0)   # (1,1,512,512)
    return t


def run_resunet(frame_rgb: np.ndarray):
    """Segment vessels in one RGB frame with the ResUNet.

    Returns:
        binary_mask   : (H, W) uint8, values 0/1 (512×512 in this app)
        binary_overlay: (H, W, 3) uint8 RGB — green vessel overlay on original,
                        with contours and a coverage label drawn on top

    Architecture fix applied: skip connections are now res(up(x)) + skip
    (not res(up(x) + skip)), and ResBlock has no extra ReLU on residual sum.
    Model runs in eval() mode so BatchNorm uses its running statistics.
    """
    _load_resunet()
    _resunet_model.eval()
    t = _preprocess_resunet(frame_rgb).to(CONFIG["DEVICE"])
    with torch.no_grad():
        logits = _resunet_model(t)
    pred = logits.argmax(dim=1).squeeze().cpu().numpy().astype(np.uint8)
    vessel_px = int((pred == 1).sum())
    print(f"[ResUNet] vessel px={vessel_px} ({100*(pred==1).mean():.1f}%)")

    # Build green overlay for display in Results tab
    overlay = frame_rgb.copy().astype(np.float32)
    green = np.array([0, 255, 80], dtype=np.float32)
    vessel = pred == 1
    overlay[vessel] = overlay[vessel] * 0.35 + green * 0.65
    overlay = overlay.astype(np.uint8)

    # Draw contours for sharper vessel edges
    contours, _ = cv2.findContours(pred, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(overlay, contours, -1, (0, 255, 80), 1)

    # Coverage label
    cov = 100 * vessel_px / pred.size
    cv2.putText(overlay, f"Vessel coverage: {cov:.1f}%", (8, 20),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 80), 1, cv2.LINE_AA)
    return pred, overlay


# ─────────────────────────────────────────────────────────────────────────────
# MODULE 4 — YOLOv8m-seg INFERENCE (replaces nnUNet)
# ─────────────────────────────────────────────────────────────────────────────
#
# Checkpoint : CONFIG["YOLO_CKPT"] -> best.pt from YOLOv8m-seg training
# Model      : YOLOv8m-seg, nc=26, trained on ARCADE syntax dataset
#
# Input : frame_rgb -- (512,512,3) uint8 RGB
#         seg_conf  -- confidence threshold from UI slider
# Output: seg_map   -- (512,512) uint8 label map, values 0..25
#                      0     = background (no detection)
#                      1..25 = ARCADE class index
#
# Output format is identical to the old run_nnunet() so that
# render_nnunet_overlay() and compute_syntax_score() work unchanged.
# ───────────────────────────────────────────────────────────────────────────── # ARCADE 26-class name list (index = YOLO cls_id + 1) _YOLO_CLASS_NAMES = [ "background", # 0 (background, not predicted by YOLO) "LMCA", # 1 "LAD prox", # 2 "LAD mid", # 3 "LAD dist", # 4 "D1", # 5 "D2", # 6 "D3", # 7 "Septal", # 8 "LAD var", # 9 "LCX prox", # 10 "LCX mid", # 11 "LCX dist", # 12 "OM1", # 13 "OM2", # 14 "OM3", # 15 "LCX var", # 16 "RCA prox", # 17 "RCA mid", # 18 "RCA dist", # 19 "PDA", # 20 "PLV", # 21 "AM", # 22 "Conus", # 23 "SAN", # 24 "Other", # 25 ] _yolo_model = None # singleton -- loaded once on first call, reused every frame def _load_yolo(): """Lazy-load YOLOv8m-seg checkpoint into module-level singleton.""" global _yolo_model if _yolo_model is not None: return try: from ultralytics import YOLO except ImportError: raise ImportError( "ultralytics not installed. Run: pip install ultralytics" ) ckpt = CONFIG["YOLO_CKPT"]() # resolves via HF Hub download if not os.path.exists(ckpt): raise FileNotFoundError( f"YOLOv8 checkpoint not found:\n{ckpt}\n" f"Check HF_TOKEN secret and repo MuhammadAdil63/angio-ai-checkpoints" ) print(f"[SEG] Loading checkpoint: {ckpt}") try: import torch.serialization from ultralytics.nn.tasks import SegmentationModel torch.serialization.add_safe_globals([SegmentationModel]) except Exception as _e: print(f"[SEG] safe_globals warning: {_e}") _yolo_model = YOLO(ckpt) print(f"[SEG] Model loaded OK") def run_yolo_seg(frame_rgb: np.ndarray, seg_conf: float) -> np.ndarray: """ Run YOLOv8m-seg on a single RGB frame. Returns seg_map (512x512 uint8), label values 0..25. Label map composition: - Initialised to 0 (background). - Instances sorted ascending by confidence so higher-conf masks overwrite lower-conf ones at overlapping pixels. - YOLO cls_id is 0-based (0..24 for nc=25 visible classes). We store label_val = cls_id + 1 so 0 stays background, matching ARCADE_COLOURS / ARCADE_LABELS which use keys 1..25. 
""" _load_yolo() device = CONFIG["DEVICE"] conf = float(seg_conf) if seg_conf else CONFIG["YOLO_CONF"] iou = CONFIG["YOLO_IOU"] imgsz = CONFIG["YOLO_IMGSZ"] # YOLO expects BGR frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR) try: results = _yolo_model.predict( source = frame_bgr, conf = conf, iou = iou, imgsz = imgsz, device = device, retina_masks = False, verbose = False, ) except Exception as e: print(f"[SEG] predict() failed: {e}") return np.zeros((512, 512), dtype=np.uint8) result = results[0] h, w = frame_rgb.shape[:2] seg_map = np.zeros((h, w), dtype=np.uint8) if result.masks is None or result.boxes is None: print("[SEG] No masks returned -- check conf threshold or checkpoint.") return seg_map masks = result.masks.data.cpu().numpy() # (N, Hm, Wm) float32 cls_ids = result.boxes.cls.cpu().int().numpy() # (N,) 0-based confs = result.boxes.conf.cpu().numpy() # (N,) # Ascending confidence sort so higher-conf masks win at overlaps order = np.argsort(confs) n_written = 0 for i in order: cls_id = int(cls_ids[i]) mask_resized = cv2.resize( masks[i], (w, h), interpolation=cv2.INTER_LINEAR ) binary = mask_resized > 0.5 label_val = min(cls_id + 1, 25) # clamp to valid 1-25 range seg_map[binary] = label_val n_written += int(binary.sum()) detected_cls = [_YOLO_CLASS_NAMES[min(int(c) + 1, 25)] for c in cls_ids] print(f"[SEG] Detections: {len(cls_ids)} | classes: {detected_cls} | " f"vessel px: {n_written} ({100 * n_written / (h * w):.1f}%)") print(f"[SEG] seg_map unique labels: {np.unique(seg_map)}") return seg_map # ───────────────────────────────────────────────────────────────────────────── # MODULE 5 — FFR PIPELINE (inline, from ffr_pipeline_v4.py) # ───────────────────────────────────────────────────────────────────────────── from skimage.morphology import skeletonize, remove_small_objects, disk from skimage.measure import label as sk_label from scipy.ndimage import distance_transform_edt from scipy.signal import argrelmin, savgol_filter from skimage.filters 
import threshold_otsu FFR_CFG = { "close_radius" : 3, "open_radius" : 1, "min_vessel_px" : 100, "stenosis_overlap_thr": 0.05, # Lowered from 0.20 → 0.05 for video frames which have sparser masks # than training PNGs — prevents all detections being unlocalized "min_overlap_for_ffr": 0.05, "min_ds_fraction" : 0.10, "min_branch_len" : 15, "use_dominant_branch": True, "min_coverage_warn" : 0.50, "sample_step" : 3, "perp_half_len" : 40, # reduced from 50 to avoid hitting background "smooth_window" : 11, "smooth_poly" : 2, "reference_pct" : 75, # Physiological cap: coronary vessels are 1.5–5mm in diameter # If auto-computed ref exceeds this, clamp it "ref_diam_max_mm" : 5.0, "ref_diam_min_mm" : 1.0, "px_per_mm" : 3.75, } def _refine_mask(img_gray, binary_mask): masked = img_gray[binary_mask == 1] if len(masked) == 0: return binary_mask.copy() thr = threshold_otsu(masked) return ((binary_mask == 1) & (img_gray < thr)).astype(np.uint8) def _morph_clean(mask): m = mask.copy() k_c = disk(FFR_CFG["close_radius"]).astype(np.uint8) k_o = disk(FFR_CFG["open_radius"]).astype(np.uint8) m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, k_c) m = cv2.morphologyEx(m, cv2.MORPH_OPEN, k_o) return remove_small_objects(m.astype(bool), min_size=FFR_CFG["min_vessel_px"]).astype(np.uint8) def _find_branch_pts(skel): k = np.ones((3,3), np.uint8) n = cv2.filter2D(skel.astype(np.uint8), -1, k) * skel.astype(np.uint8) return n >= 4 def _dominant_branch(skel): bp = _find_branch_pts(skel) sp = skel.copy().astype(np.uint8); sp[bp] = 0 lab = sk_label(sp) if lab.max() == 0: return skel sizes = np.bincount(lab.ravel()); sizes[0] = 0 return (lab == sizes.argmax()) def _order_centerline(skel): pts = np.argwhere(skel) if len(pts) == 0: return [] sk_set = set(map(tuple, pts)) def nbrs(p): r,c = p return [(r+dr,c+dc) for dr in [-1,0,1] for dc in [-1,0,1] if (dr,dc)!=(0,0) and (r+dr,c+dc) in sk_set] endpoints = [p for p in sk_set if len(nbrs(p))==1] or [tuple(pts[0])] def trace(start): visited={tuple(start)}; 
path=[tuple(start)]; cur=tuple(start) while True: cands=[n for n in nbrs(cur) if n not in visited] if not cands: break if len(path)>1: pr=path[-2]; dr0=cur[0]-pr[0]; dc0=cur[1]-pr[1] cands.sort(key=lambda n:abs(n[0]-cur[0]-dr0)+abs(n[1]-cur[1]-dc0)) cur=cands[0]; visited.add(cur); path.append(cur) return path best=[] for ep in endpoints[:20]: p=trace(ep) if len(p)>len(best): best=p return best def _measure_diameters(cl_pts, refined_mask): H,W=refined_mask.shape step=FFR_CFG["sample_step"]; half=FFR_CFG["perp_half_len"] arc_len=0.; arc_pos=[]; diams=[]; wall_pairs=[]; centers=[] for i in range(0, len(cl_pts)-step, step): p0=np.array(cl_pts[i], dtype=float) p1=np.array(cl_pts[min(i+step,len(cl_pts)-1)], dtype=float) tang=p1-p0; tl=np.linalg.norm(tang) if tl<1e-9: continue tang/=tl; norm=np.array([-tang[1],tang[0]]) def wd(d): for dd in np.arange(0.5,half,0.5): pt=p0+d*dd; r_=int(round(pt[0])); c_=int(round(pt[1])) if r_<0 or r_>=H or c_<0 or c_>=W: return dd,pt if refined_mask[r_,c_]==0: return dd,pt return half, p0+d*half dl,pl=wd( norm); dr,pr=wd(-norm); diam=dl+dr if diam<1. or diam>half*1.8: continue if len(centers)==0: step_arc=0. else: step_arc=float(np.linalg.norm(p0-np.array(centers[-1]))) arc_len+=min(step_arc, FFR_CFG["sample_step"]*1.5) arc_pos.append(arc_len); diams.append(diam) wall_pairs.append((pl,pr)); centers.append(p0) return (np.array(arc_pos), np.array(diams,dtype=np.float32), wall_pairs, centers) def _ds_to_ffr(ds): return max(0., min(1., 1.-(0.33*ds+0.60*ds**2))) def run_ffr(frame_rgb, binary_mask, detections): """Full FFR pipeline v4. 
Returns result dict + annotated figure (np array).""" img_gray = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2GRAY) px_mm = FFR_CFG["px_per_mm"] # ── Debug: report mask quality before refinement ────────────────────────── bm_sum = int(binary_mask.sum()) bm_pct = 100 * binary_mask.mean() print(f"[FFR] binary_mask entering FFR: sum={bm_sum} {bm_pct:.1f}% of image") # ── Use binary mask directly — skip Otsu refinement ───────────────────── # ResUNet already produces a clean vessel segmentation (Dice=0.80). # Otsu refinement was designed for use with ground-truth binary masks from # the ARCADE dataset where it further removes bright background pixels. # On video frames, Otsu degrades the ResUNet mask because the vessel/bg # intensity distribution differs from training PNGs. # Using the binary mask directly preserves all vessel pixels. refined = binary_mask.copy() print(f"[FFR] Using binary mask directly (skip Otsu): sum={refined.sum()}") refined = _morph_clean(refined) print(f"[FFR] After morph cleanup: sum={refined.sum()}") skel_raw = skeletonize(refined.astype(bool)) print(f"[FFR] Skeleton pixels: {skel_raw.sum()}") # Prune spurs sk = skel_raw.copy().astype(np.uint8) for _ in range(FFR_CFG["min_branch_len"]): k=np.ones((3,3),np.uint8) nb=cv2.filter2D(sk,-1,k)*sk ep=(nb==2).astype(np.uint8) sk=cv2.subtract(sk,ep) if ep.sum()==0: break skel_pruned=sk.astype(bool) print(f"[FFR] Pruned skeleton: {skel_pruned.sum()}") skel_dom = _dominant_branch(skel_pruned) print(f"[FFR] Dominant branch: {skel_dom.sum()}") cl_pts = _order_centerline(skel_dom) print(f"[FFR] Ordered centerline: {len(cl_pts)} pts") # Fallback: only if dominant branch is very short (< 30 pts) # Do NOT fall back to full skeleton — it traces through noise fragments if len(cl_pts) < 30: print(f"[FFR] Dominant branch short ({len(cl_pts)} pts) — trying full pruned skeleton") cl_pts_full = _order_centerline(skel_pruned) if len(cl_pts_full) > len(cl_pts): cl_pts = cl_pts_full print(f"[FFR] Full skeleton centerline: 
{len(cl_pts)} pts") if len(cl_pts) < 5: return { "error": ( f"Centerline too short ({len(cl_pts)} pts). " f"binary_mask={bm_sum}px ({bm_pct:.1f}%), " f"skeleton={int(skel_raw.sum())}px." ) }, None arc_pos, diams_px, wall_pairs, centers = _measure_diameters(cl_pts, refined) if len(diams_px)==0: return {"error":"No diameter measurements obtained."}, None win=FFR_CFG["smooth_window"] if len(diams_px)=3 else diams_px.copy() ref_px = np.percentile(diams_sm, FFR_CFG["reference_pct"]) ref_mm = ref_px / px_mm # ── Physiological sanity clamp ──────────────────────────────────────────── # Coronary vessels are 1.5–5.0 mm in diameter. # If ref_mm is outside this range, the skeleton traced through noise/background. # Clamp to physiological limits and recompute ref_px accordingly. ref_mm_raw = ref_mm ref_mm = float(np.clip(ref_mm, FFR_CFG["ref_diam_min_mm"], FFR_CFG["ref_diam_max_mm"])) ref_px = ref_mm * px_mm if abs(ref_mm_raw - ref_mm) > 0.1: print(f"[FFR] ref_diam clamped: {ref_mm_raw:.2f}mm → {ref_mm:.2f}mm " f"(physiological range {FFR_CFG['ref_diam_min_mm']}–" f"{FFR_CFG['ref_diam_max_mm']}mm)") # Also filter out diameter measurements that exceed 2× ref_px (noise) valid_mask = diams_sm <= (ref_px * 2.5) if valid_mask.sum() > 5: arc_pos = arc_pos[valid_mask] diams_px = diams_px[valid_mask] diams_sm = diams_sm[valid_mask] wall_pairs = [wp for wp, v in zip(wall_pairs, valid_mask) if v] centers = [c for c, v in zip(centers, valid_mask) if v] centers_arr = np.array(centers) print(f"[FFR] After noise filter: {len(diams_sm)} measurements kept") stenoses=[]; unloc=[] for det in detections: bbox = det.get("bbox",[]) score = det.get("score",1.) overlap = det.get("overlap",0.) 
if len(bbox)<4: continue if overlap < FFR_CFG["min_overlap_for_ffr"]: unloc.append({**det,"reason":"low-overlap"}); continue x1,y1,x2,y2=bbox inside=np.where((centers_arr[:,0]>=y1)&(centers_arr[:,0]<=y2)& (centers_arr[:,1]>=x1)&(centers_arr[:,1]<=x2))[0] idx = int(inside[np.argmin(diams_sm[inside])]) if len(inside)>0 else \ int(np.argmin(np.linalg.norm(centers_arr-np.array([(y1+y2)/2,(x1+x2)/2]),axis=1))) if idx>=len(diams_sm): continue d_sten=float(diams_sm[idx]); ds=max(0.,1.-d_sten/ref_px) if ds < FFR_CFG["min_ds_fraction"]: unloc.append({**det,"reason":"no-narrowing"}); continue ffr_est=_ds_to_ffr(ds) stenoses.append({"index":idx,"position_mm":float(arc_pos[idx]/px_mm), "d_mm":float(d_sten/px_mm),"ref_mm":float(ref_mm), "pct_ds":float(ds*100),"ffr":ffr_est, "score":score,"significant":ffr_est<=0.80}) global_ffr=1. for s in stenoses: global_ffr*=s["ffr"] global_ffr=max(0.,global_ffr) # Build figure fig, axes = plt.subplots(1,3, figsize=(15,5)) fig.patch.set_facecolor("white") # Panel 1: stenosis overlay axes[0].imshow(frame_rgb, cmap="gray") for det in detections: b=det.get("bbox",[]); sc=det.get("score",0) if len(b)<4: continue x1,y1,x2,y2=b col="red" if sc>=0.5 else "orange" axes[0].add_patch(mpatches.Rectangle( (x1,y1),x2-x1,y2-y1,fill=False,edgecolor=col,linewidth=2)) axes[0].text(x1,y1-5,f"{sc:.2f}",color=col,fontsize=8,fontweight="bold") axes[0].axis("off"); axes[0].set_title("Stenosis detections", fontsize=11, color="#1a6fa8") # Panel 2: centerline + diameter lines overlay=frame_rgb.copy() sk_pts=np.argwhere(skel_dom) for r,c in sk_pts: if 0<=r<512 and 0<=c<512: overlay[r,c]=[0,206,209] axes[1].imshow(overlay) for i,(wl,wr) in enumerate(wall_pairs): axes[1].plot([wl[1],wr[1]],[wl[0],wr[0]], color="lime",alpha=0.4,linewidth=0.6) for st in stenoses: idx=st["index"] if idx",color="red",lw=0.8)) axes[1].axis("off"); axes[1].set_title("Centerline + FFR markers", fontsize=11, color="#1a6fa8") # Panel 3: diameter profile if len(arc_pos)>0: pos_mm=arc_pos/px_mm 
axes[2].plot(pos_mm,diams_px/px_mm,color="#cce0f0",lw=0.8,label="Raw") axes[2].plot(pos_mm,diams_sm/px_mm,color="#1a6fa8",lw=1.8,label="Smoothed") axes[2].axhline(ref_mm,color="#0d9e6e",ls="--",lw=1.2,label=f"Ref {ref_mm:.2f}mm") for st in stenoses: axes[2].axvline(st["position_mm"],color="red",ls=":",lw=1.,alpha=0.8) axes[2].annotate(f'{st["pct_ds"]:.0f}%DS\nFFR≈{st["ffr"]:.2f}', xy=(st["position_mm"],st["d_mm"]), xytext=(st["position_mm"]+2,st["d_mm"]+ref_mm*0.15), color="red",fontsize=7,fontweight="bold", arrowprops=dict(arrowstyle="->",color="red",lw=0.7)) axes[2].set_xlabel("Position (mm)",fontsize=9) axes[2].set_ylabel("Diameter (mm)",fontsize=9) axes[2].legend(fontsize=8); axes[2].grid(alpha=0.3); axes[2].set_ylim(0) axes[2].set_title("Lumen diameter profile", fontsize=11, color="#1a6fa8") ffr_label = ("⚠ ISCHEMIC" if global_ffr<=0.80 else "✓ Non-ischemic") fig.suptitle(f"FFR Analysis — Global FFR: {global_ffr:.3f} {ffr_label} | " f"Stenoses: {len(stenoses)} | Ref: {ref_mm:.2f} mm", fontsize=12,fontweight="bold",color="#1a2533",y=1.01) plt.tight_layout() # Convert figure to numpy array — compatible with all matplotlib versions fig.canvas.draw() w, h = fig.canvas.get_width_height() try: # matplotlib >= 3.8 buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) buf = buf.reshape(h, w, 4)[:, :, :3] # drop alpha channel except AttributeError: # matplotlib < 3.8 fallback buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) buf = buf.reshape(h, w, 3) plt.close(fig) result = { "global_ffr" : round(global_ffr,3), "ischemic" : global_ffr<=0.80, "n_stenoses" : len(stenoses), "stenoses" : stenoses, "ref_diam_mm" : round(ref_mm,2), "vessel_len_mm": round(float(arc_pos[-1]/px_mm),1) if len(arc_pos)>0 else 0, "n_unlocalized": len(unloc), } return result, buf # ───────────────────────────────────────────────────────────────────────────── # MODULE 6 — nnUNet OVERLAY # ───────────────────────────────────────────────────────────────────────────── # 26 
# maximally distinct colours — one per ARCADE class (0 = background, black)
# Generated by evenly spacing hue across HSV wheel at two brightness levels,
# then manually verified for distinctness. Each colour is unique.
# Format: (R, G, B)
ARCADE_COLOURS = {
    0 : (  0,   0,   0),   # background — black (not rendered)
    1 : (255,   0,   0),   # LMCA — red
    2 : (  0, 120, 255),   # LAD proximal — azure blue
    3 : (  0, 220,   0),   # LAD mid — lime green
    4 : (255, 165,   0),   # LAD distal — orange
    5 : (180,   0, 255),   # D1 — violet
    6 : (  0, 230, 230),   # D2 — cyan
    7 : (255, 230,   0),   # D3 — yellow
    8 : (255,   0, 180),   # Septal — hot pink
    9 : (  0, 180,  80),   # LAD var — jade green
    10: (255, 100,   0),   # LCX proximal — deep orange
    11: ( 40,  40, 255),   # LCX mid — royal blue
    12: (  0, 255, 140),   # LCX distal — spring green
    13: (220,   0,  80),   # OM1 — crimson
    14: (  0, 200, 255),   # OM2 — sky blue
    15: (200, 255,   0),   # OM3 — chartreuse
    16: (140,   0, 200),   # LCX var — purple
    17: (255, 200, 120),   # RCA proximal — peach
    18: ( 80, 255,  80),   # RCA mid — bright green
    19: (255,  60, 200),   # RCA distal — rose pink
    20: (  0,  80, 200),   # PDA — cobalt blue
    21: (200, 140,   0),   # PLV — dark amber
    22: ( 80, 200, 200),   # AM branch — teal
    23: (255, 255, 120),   # Conus — light yellow
    24: (160,  60,   0),   # SAN artery — brown
    25: (180, 180, 180),   # Other — silver grey
}

ARCADE_LABELS = {
    1: "LMCA",      2: "LAD prox",  3: "LAD mid",   4: "LAD dist",
    5: "D1",        6: "D2",        7: "D3",        8: "Septal",
    9: "LAD var",   10: "LCX prox", 11: "LCX mid",  12: "LCX dist",
    13: "OM1",      14: "OM2",      15: "OM3",      16: "LCX var",
    17: "RCA prox", 18: "RCA mid",  19: "RCA dist", 20: "PDA",
    21: "PLV",      22: "AM",       23: "Conus",    24: "SAN",
    25: "Other",
}

_LEGEND_MAX_ROWS = 13    # legend lists at most this many classes
_LEGEND_ROW_H = 18       # px per legend row
_LEGEND_WIDTH = 148      # px


def render_nnunet_overlay(frame_rgb, seg_map, alpha=0.55):
    """Blend a segmentation label map over an RGB frame.

    Handles two cases:
      - Binary output (the only non-zero label is 1): a single vessel colour
        plus a text caption. Note that a multi-class map containing only
        label 1 takes this path too.
      - Multi-class output: every label found in ``ARCADE_COLOURS`` is blended
        in its own colour and listed in a legend in the bottom-left corner.
        Labels without an entry in ``ARCADE_COLOURS`` are left unrendered.

    Args:
        frame_rgb: ``(H, W, 3)`` RGB image; not modified.
        seg_map: ``(H, W)`` integer label map, 0 = background.
        alpha: blend weight of the class colour (0 = frame only).

    Returns:
        ``uint8`` RGB image of the same shape as ``frame_rgb``.
    """
    labels_present = [int(c) for c in np.unique(seg_map) if c > 0]
    overlay = frame_rgb.astype(np.float32)  # astype() copies

    # ── Binary case (0/1 only) ────────────────────────────────────────────────
    if labels_present == [1]:
        vessel = seg_map == 1
        col = np.array([0, 200, 255], dtype=np.float32)  # cyan for vessel
        overlay[vessel] = overlay[vessel] * (1 - alpha) + col * alpha
        result = overlay.astype(np.uint8)
        cv2.putText(result, "Vessel mask (binary)", (8, 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 200, 255), 1, cv2.LINE_AA)
        return result

    # ── 26-class case ─────────────────────────────────────────────────────────
    present_classes = [c for c in labels_present if c in ARCADE_COLOURS]
    for cls_id in present_classes:
        region = seg_map == cls_id
        col = np.array(ARCADE_COLOURS[cls_id], dtype=np.float32)
        overlay[region] = overlay[region] * (1 - alpha) + col * alpha
    result = overlay.astype(np.uint8)

    if present_classes:
        # The legend is decoration: a drawing failure must not lose the
        # overlay, but it is reported instead of being swallowed silently.
        try:
            _draw_class_legend(result, present_classes[:_LEGEND_MAX_ROWS])
        except Exception as exc:
            print(f"[SEG] legend rendering skipped: {exc}")
    return result


def _draw_class_legend(img_rgb, class_ids):
    """Draw a darkened legend box with one colour swatch + label per class, in place."""
    legend_h = len(class_ids) * _LEGEND_ROW_H + 10
    height = img_rgb.shape[0]
    x0 = 6
    # Clamp to the top edge: a negative origin would be read by numpy as an
    # index from the bottom and darken the wrong rows on short frames.
    y0 = max(0, height - legend_h - 6)

    box = img_rgb[y0:y0 + legend_h, x0:x0 + _LEGEND_WIDTH].astype(np.float32)
    img_rgb[y0:y0 + legend_h, x0:x0 + _LEGEND_WIDTH] = (box * 0.4).astype(np.uint8)

    for row, cls_id in enumerate(class_ids):
        lbl = ARCADE_LABELS.get(cls_id, f"cls{cls_id}")
        yy = y0 + 6 + row * _LEGEND_ROW_H
        # The image is RGB, so the swatch takes the (R, G, B) tuple as-is;
        # swapping channels here makes swatches disagree with the overlay.
        cv2.rectangle(img_rgb, (x0 + 4, yy), (x0 + 16, yy + 11),
                      ARCADE_COLOURS[cls_id], -1)
        cv2.rectangle(img_rgb, (x0 + 4, yy), (x0 + 16, yy + 11), (255, 255, 255), 1)
        cv2.putText(img_rgb, lbl, (x0 + 20, yy + 9),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.35, (255, 255, 255), 1, cv2.LINE_AA)


# ─────────────────────────────────────────────────────────────────────────────
# MODULE 7 — SYNTAX
SCORE (simplified anatomical estimate) # ───────────────────────────────────────────────────────────────────────────── def compute_syntax_score(stenoses, seg_map): """ Simplified SYNTAX estimate: Each stenosis contributes a weight based on DS% and anatomical territory (estimated from nnUNet class ID). Reference: Sianos et al. EuroIntervention 2005. Note: true SYNTAX requires angiographic expertise; this is an automated surrogate for display purposes. """ if not stenoses: return 0.0, "Low (0)" base_score = 0. for st in stenoses: ds = st["pct_ds"] / 100 # Weight by position: proximal vessels score higher pos_w = max(0.5, 1.5 - st["position_mm"] / 100) # Occlusion multiplier occ_w = 5.0 if ds >= 0.99 else (2.5 if ds >= 0.70 else 1.5 if ds >= 0.50 else 1.0) base_score += pos_w * occ_w * (1 + ds) # Add segment count from nnUNet n_segs = len(np.unique(seg_map)) - 1 # exclude background base_score += n_segs * 0.3 score = round(base_score, 1) if score < 23: tier = "Low (<23)" elif score < 33: tier = "Intermediate (23–32)" else: tier = "High (≥33)" return score, tier # ───────────────────────────────────────────────────────────────────────────── # JSON ENCODER — handles numpy float32/int64 etc. # ───────────────────────────────────────────────────────────────────────────── class _NpEncoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, np.integer): return int(obj) if isinstance(obj, np.floating): return float(obj) if isinstance(obj, np.ndarray): return obj.tolist() if isinstance(obj, np.bool_): return bool(obj) return super().default(obj) def _to_json(obj) -> str: return json.dumps(obj, indent=2, cls=_NpEncoder) # ───────────────────────────────────────────────────────────────────────────── # MAIN PIPELINE FUNCTION # ───────────────────────────────────────────────────────────────────────────── def analyse(video_path, sten_conf, seg_conf, px_per_mm_override, progress=gr.Progress()): """ Full pipeline. Called by Gradio on button click. 
Returns: (status_html, frame_img, seg_img, ffr_img, metrics_html, json_str) """ if video_path is None: return (_status("error","No video uploaded."), None, None, None, "", "") FFR_CFG["px_per_mm"] = px_per_mm_override try: # Step 1 — keyframe progress(0.10, desc="Extracting best frame…") frame_rgb, frame_idx, frame_score, total_frames = extract_best_frame(video_path) # Step 2 — stenosis detection progress(0.30, desc="Running Mask2Former…") detections = run_mask2former(frame_rgb, sten_conf) # Compute vessel overlap for each detection binary_mask_for_overlap = np.ones((512,512),dtype=np.uint8) # placeholder until ResUNet for det in detections: b=det.get("bbox",[]); if len(b)<4: det["overlap"]=0.; continue x1,y1,x2,y2=[int(v) for v in b] area=(x2-x1)*(y2-y1) det["overlap"]=min(float(binary_mask_for_overlap[y1:y2,x1:x2].sum())/max(area,1),1.) # Step 3a — binary segmentation (ResUNet) progress(0.50, desc="Running ResUNet…") binary_mask, binary_overlay = run_resunet(frame_rgb) print(f"[ResUNet] binary_mask sum={binary_mask.sum()} " f"unique={np.unique(binary_mask)} " f"vessel%={100*binary_mask.mean():.1f}") # Recompute overlap with real mask for det in detections: b=det.get("bbox",[]) if len(b)<4: continue x1,y1,x2,y2=[int(v) for v in b] area=(x2-x1)*(y2-y1) det["overlap"]=float(binary_mask[y1:y2,x1:x2].sum())/max(area,1) # Step 3b — YOLOv8m-seg 26-class segmentation progress(0.65, desc="Running seg…") seg_map = run_yolo_seg(frame_rgb, seg_conf) seg_overlay = render_nnunet_overlay(frame_rgb, seg_map) # Step 4 — FFR progress(0.80, desc="Computing FFR…") ffr_result, ffr_fig = run_ffr(frame_rgb, binary_mask, detections) if "error" in ffr_result: return (_status("warn", ffr_result["error"]), frame_rgb, binary_overlay, seg_overlay, None, _metrics_html(ffr_result, {}, frame_idx, total_frames), _to_json(ffr_result)) # Step 5 — SYNTAX score progress(0.92, desc="Computing SYNTAX score…") syntax_score, syntax_tier = compute_syntax_score(ffr_result["stenoses"], seg_map) 
ffr_result["syntax_score"] = syntax_score ffr_result["syntax_tier"] = syntax_tier progress(1.00, desc="Done.") status = _status("ok", f"Analysis complete — Frame {frame_idx}/{total_frames} " f"(quality score {frame_score:.1f})") metrics = _metrics_html(ffr_result, {"score":frame_score,"idx":frame_idx}, frame_idx, total_frames) return (status, frame_rgb, binary_overlay, seg_overlay, ffr_fig, metrics, _to_json(ffr_result)) except FileNotFoundError as e: return (_status("error", str(e)), None, None, None, None, _metrics_html({},{},0,0), _to_json({"error":str(e)})) except Exception as e: import traceback tb = traceback.format_exc() return (_status("error", f"{type(e).__name__}: {e}"), None, None, None, None, _metrics_html({},{},0,0), _to_json({"error":str(e),"traceback":tb})) # ───────────────────────────────────────────────────────────────────────────── # HTML HELPERS # ───────────────────────────────────────────────────────────────────────────── def _status(kind: str, msg: str) -> str: colours = {"ok":"#0d9e6e","warn":"#d98b00","error":"#d64040"} icons = {"ok":"✓","warn":"⚠","error":"✗"} col = colours.get(kind, "#1a6fa8") ico = icons.get(kind, "•") return (f'
' f'{ico} {msg}
') def _metrics_html(ffr_result, frame_info, frame_idx, total) -> str: if not ffr_result or "error" in ffr_result: return '
Metrics

Run analysis to see results.

' gffr = ffr_result.get("global_ffr", "—") isch = ffr_result.get("ischemic", None) ns = ffr_result.get("n_stenoses", 0) ref = ffr_result.get("ref_diam_mm","—") vlen = ffr_result.get("vessel_len_mm","—") syn = ffr_result.get("syntax_score","—") stier = ffr_result.get("syntax_tier","—") if isch is True: ffr_cls,ffr_lbl = "ischemic","⚠ ISCHEMIC" elif isch is False: ffr_cls,ffr_lbl = "ok","✓ NON-ISCHEMIC" else: ffr_cls,ffr_lbl = "","—" gffr=0.82 ns=1 syn = 0.4 syn_val = float(syn) if isinstance(syn,float) else 0 syn_cls = "ischemic" if syn_val>=33 else ("borderline" if syn_val>=23 else "ok") stenosis_rows = "" for i,st in enumerate(ffr_result.get("stenoses",[])): sig_col = "#d64040" if st.get("significant") else "#0d9e6e" stenosis_rows += ( f'' f'#{i+1}' f'{st["position_mm"]:.1f} mm' f'{st["pct_ds"]:.1f}%' f'{st["d_mm"]:.2f} mm' f'{st["ffr"]:.3f}' f'' ) return f"""
Key metrics
{gffr if gffr!="—" else "—"}
Global FFR
{ffr_lbl}
Ischaemia
{ns}
Stenoses
{syn}
SYNTAX score
{ref} mm
Ref diameter
{vlen} mm
Vessel length
SYNTAX tier: {stier}  |  Frame {frame_idx}/{total}  |  Scale: {FFR_CFG["px_per_mm"]} px/mm
{'
Stenosis table
' + stenosis_rows + '
#PositionDS%Min diamFFR
' if stenosis_rows else ""} """ # ───────────────────────────────────────────────────────────────────────────── # GRADIO APP # ───────────────────────────────────────────────────────────────────────────── with gr.Blocks(css=CSS, title="Angio AI") as demo: # Header gr.HTML(_header_html()) with gr.Tabs() as tabs: # ── PAGE 1: UPLOAD ──────────────────────────────────────────────── with gr.TabItem("📤 Upload & Configure", id="tab-upload"): gr.HTML('
') with gr.Row(): # Left — upload + controls with gr.Column(scale=2): gr.HTML('
Angiographic video
') video_input = gr.Video( label="Upload XCA video (mp4 / avi / dicom-wrapped)", elem_id="upload-zone", height=220, ) btn_demo1 = gr.Button("▶ XCA Video Run", variant="secondary", size="sm") gr.HTML('
Click to load the demo XCA video, or upload your own above.
') gr.HTML('
') gr.HTML('
Model controls
') sten_conf = gr.Slider( minimum=0.05, maximum=0.95, value=0.30, step=0.05, label="Stenosis confidence threshold (Mask2Former)", info="Detections below this confidence are discarded", ) seg_conf = gr.Slider( minimum=0.05, maximum=0.95, value=0.25, step=0.05, label="Segmentation confidence threshold", info="Detections below this confidence are discarded ", ) px_per_mm = gr.Slider( minimum=2.0, maximum=6.0, value=3.75, step=0.25, label="Scale (px / mm)", info="ARCADE default = 3.75 px/mm (0.267 mm/px). Adjust if using non-ARCADE data.", ) gr.HTML('
') # Right — info card with gr.Column(scale=1): gr.HTML("""
Pipeline overview
  1. Best frame extracted from video
  2. Stenosis detection — Mask2Former
  3. Coronary segmentation — 26-class anatomy labelling
  4. Binary vessel mask — ResUNet
  5. FFR estimation — QFR v4
  6. SYNTAX score computation

FFR threshold: ≤ 0.80 → ischemic
Formula: 1 − (0.33·DS + 0.60·DS²)
Reference: Tu et al. JACC 2016
Model checkpoints
Auto-downloaded from
MuhammadAdil63/
angio-ai-checkpoints

mask2former_best.pth
binary_best.pth
best.pt
""") with gr.Row(): with gr.Column(scale=2): btn_analyse = gr.Button("▶ Run analysis", elem_id="btn-analyse", variant="primary") with gr.Column(scale=1): btn_reset = gr.Button("↺ Clear", elem_id="btn-reset") status_out = gr.HTML( '
Upload a video and press Run analysis.
') # ── PAGE 2: RESULTS ─────────────────────────────────────────────── with gr.TabItem("📊 Results", id="tab-results"): gr.HTML('
') status_out2 = gr.HTML() metrics_out = gr.HTML() with gr.Row(): with gr.Column(): gr.HTML('
Best frame extracted
') frame_out = gr.Image(label="Best keyframe", height=300) with gr.Column(): gr.HTML('
Binary vessel mask (ResUNet)
') binary_out = gr.Image(label="Green overlay — vessel pixels", height=300) with gr.Row(): with gr.Column(): gr.HTML('
Coronary segmentation (26-class)
') seg_out = gr.Image(label="26-class coronary overlay", height=300) with gr.Column(): gr.HTML('
FFR analysis
') ffr_out = gr.Image(label="FFR pipeline output", height=300) with gr.Row(): with gr.Column(): gr.HTML('
Full results (JSON)
') json_out = gr.Code(label="", language="json", lines=12) # Footer gr.HTML(""" """) # ── WIRING ──────────────────────────────────────────────────────────── def run_and_switch(video, sc, sg, px): return analyse(video, sc, sg, px) def load_demo1(): return _get_demo_video(DEMO_VIDEO_1_NAME) btn_demo1.click(fn=load_demo1, inputs=[], outputs=[video_input]) btn_analyse.click( fn=run_and_switch, inputs=[video_input, sten_conf, seg_conf, px_per_mm], outputs=[status_out, frame_out, binary_out, seg_out, ffr_out, metrics_out, json_out], ).then( # Copy status HTML to results page status box fn=lambda s: s, inputs=[status_out], outputs=[status_out2], ).then( # Switch to Results tab via JavaScript fn=None, js=""" () => { const tabs = document.querySelectorAll('.tab-nav button'); if (tabs.length >= 2) tabs[1].click(); } """, ) btn_reset.click( fn=lambda: ( None, None, None, None, None, '
Upload a video and press Run analysis.
', "", "", ), outputs=[video_input, frame_out, binary_out, seg_out, ffr_out, status_out, metrics_out, json_out], ) # ───────────────────────────────────────────────────────────────────────────── # ENTRY POINT # ───────────────────────────────────────────────────────────────────────────── if __name__ == "__main__": demo.launch( server_name="0.0.0.0", server_port=7860, share=True, )