Spaces:
Sleeping
Sleeping
| # ββ angio_ai_app.py βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Angio AI β Coronary Angiography Analysis System | |
| # Muhammad Adil Β· MS Data Science Β· ITU Lahore | |
| # Supervisor: Dr. Arif Mehmood | |
| # | |
| # ββ DEPLOYMENT β HuggingFace Spaces ββββββββββββββββββββββββββββββββββββββββββ | |
| # Checkpoints are loaded from HF Hub: MuhammadAdil63/angio-ai-checkpoints | |
| # Set HF_TOKEN as a Space Secret (Settings β Secrets) for private repo access. | |
| # Files expected in that repo: | |
| # mask2former_best.pth (Mask2Former Swin-Base fine-tuned) | |
| # binary_best.pth (ResUNet binary vessel segmentation) | |
| # best.pt (YOLOv8m-seg 26-class ARCADE syntax) | |
| # | |
| # ββ DEPENDENCIES ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # pip install gradio torch torchvision opencv-python-headless | |
| # scikit-image scipy matplotlib transformers | |
| # ultralytics (for YOLOv8m-seg inference) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # ββ Force PyTorch weights_only=False globally (must be before any imports) βββ | |
| import os as _os | |
| _os.environ["TORCH_FORCE_WEIGHTS_ONLY_LOAD"] = "0" | |
| # ββ Python 3.13 audioop shim (removed from stdlib, needed by pydub/gradio) βββ | |
| import sys, types | |
| if "audioop" not in sys.modules: | |
| sys.modules["audioop"] = types.ModuleType("audioop") | |
| if "pyaudioop" not in sys.modules: | |
| sys.modules["pyaudioop"] = types.ModuleType("pyaudioop") | |
| # ββ Patch gradio_client _json_schema_to_python_type (APIInfoParseError fix) ββ | |
| def _patch_gradio_client(): | |
| try: | |
| import gradio_client.utils as _gcu | |
| _orig = _gcu._json_schema_to_python_type | |
| def _safe(schema, defs=None): | |
| if not isinstance(schema, dict): | |
| return "Any" | |
| try: | |
| return _orig(schema, defs) | |
| except Exception: | |
| return "Any" | |
| _gcu._json_schema_to_python_type = _safe | |
| _orig_top = _gcu.json_schema_to_python_type | |
| def _safe_top(schema, defs=None): | |
| try: | |
| return _orig_top(schema, defs) | |
| except Exception: | |
| return "Any" | |
| _gcu.json_schema_to_python_type = _safe_top | |
| except Exception: | |
| pass | |
| _patch_gradio_client() | |
| # ββ PyTorch 2.7 global weights_only patch ββββββββββββββββββββββββββββββββββββ | |
| # YOLO and other libraries call torch.load internally without weights_only=False. | |
| # Monkey-patch torch.load to default to weights_only=False for all calls. | |
| import torch as _torch | |
| _orig_torch_load = _torch.load | |
| def _patched_torch_load(f, map_location=None, pickle_module=None, | |
| weights_only=False, mmap=None, **kwargs): | |
| return _orig_torch_load(f, map_location=map_location, | |
| pickle_module=pickle_module, | |
| weights_only=False, **kwargs) | |
| _torch.load = _patched_torch_load | |
| import os | |
| import sys | |
| import json | |
| import base64 | |
| import tempfile | |
| import warnings | |
| warnings.filterwarnings("ignore") | |
| import cv2 | |
| import numpy as np | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| import matplotlib.patches as mpatches | |
| from pathlib import Path | |
| import gradio as gr | |
| import torch | |
| # ββ Patch torch.load for PyTorch 2.7 (weights_only=True default breaks YOLO) β | |
| _orig_tload = torch.load | |
| def _safe_tload(f, map_location=None, pickle_module=None, | |
| weights_only=None, mmap=None, **kwargs): | |
| return _orig_tload(f, map_location=map_location, weights_only=False, **kwargs) | |
| torch.load = _safe_tload | |
| import huggingface_hub # noqa: F401 (ensure installed; used in _hf_ckpt) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # CONFIG β edit these paths before running | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # ββ Checkpoint download from HuggingFace Hub βββββββββββββββββββββββββββββββββ | |
| # Checkpoints live in: MuhammadAdil63/angio-ai-checkpoints (private repo) | |
| # huggingface_hub reads HF_TOKEN from environment (set as a Space secret). | |
| def _hf_ckpt(filename: str) -> str: | |
| """Download checkpoint from HF Hub on first call, return local cache path.""" | |
| from huggingface_hub import hf_hub_download | |
| return hf_hub_download( | |
| repo_id = "MuhammadAdil63/angio-ai-checkpoints", | |
| filename = filename, | |
| repo_type= "model", | |
| token = os.environ.get("HF_TOKEN"), # set in Space Secrets | |
| ) | |
| CONFIG = { | |
| # ββ Checkpoint paths (resolved lazily via HF Hub) βββββββββββββββββββββββββ | |
| # These callables are invoked once inside each model's _load_*() function. | |
| "MASK2FORMER_CKPT" : lambda: _hf_ckpt("mask2former_best.pth"), | |
| "RESUNET_CKPT" : lambda: _hf_ckpt("binary_best.pth"), | |
| "YOLO_CKPT" : lambda: _hf_ckpt("best.pt"), | |
| # ββ YOLO inference params βββββββββββββββββββββββββββββββββββββββββββββββββ | |
| "YOLO_CONF" : 0.25, | |
| "YOLO_IOU" : 0.70, | |
| "YOLO_IMGSZ" : 512, | |
| # ββ FFR pipeline scale (ARCADE hardcoded) ββββββββββββββββββββββββββββββββ | |
| "PX_PER_MM" : 3.75, | |
| # ββ Device β CPU on HF free tier βββββββββββββββββββββββββββββββββββββββββ | |
| "DEVICE" : "cuda:0" if torch.cuda.is_available() else "cpu", | |
| } | |
| # ββ Demo videos (hosted in HF Hub model repo alongside checkpoints) ββββββββββ | |
| # Upload your two XCA demo videos to MuhammadAdil63/angio-ai-checkpoints as: | |
| # demo_video_1.mp4 | |
| # demo_video_2.mp4 | |
| # The buttons below download and load them automatically. | |
| DEMO_VIDEO_1_NAME = "demo_video_1.mp4" | |
| DEMO_VIDEO_2_NAME = "demo_video_2.mp4" | |
| def _get_demo_video(filename: str) -> str: | |
| """Download demo video from HF Hub and return local path.""" | |
| try: | |
| from huggingface_hub import hf_hub_download | |
| path = hf_hub_download( | |
| repo_id = "MuhammadAdil63/angio-ai-checkpoints", | |
| filename = filename, | |
| repo_type= "model", | |
| token = os.environ.get("HF_TOKEN"), | |
| ) | |
| return path | |
| except Exception as e: | |
| print(f"[DEMO] Could not load {filename}: {e}") | |
| return None | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # ITU LOGO β encode to base64 for embedding in HTML | |
| # Place ITULog.png in the same directory as this script. | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _encode_logo(path: str) -> str: | |
| try: | |
| with open(path, "rb") as f: | |
| return "data:image/png;base64," + base64.b64encode(f.read()).decode() | |
| except FileNotFoundError: | |
| return "" | |
| LOGO_ITU = _encode_logo(str(Path(__file__).parent / "ITULog.png")) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # CSS THEME β white / bluish, minimalist, medical | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| CSS = """ | |
| /* ββ Base ββ */ | |
| :root { | |
| --primary: #1a6fa8; | |
| --primary-lt: #e8f3fb; | |
| --primary-mid: #4a9fd4; | |
| --accent: #0d9e6e; /* Angio AI green */ | |
| --bg: #f7fafd; | |
| --surface: #ffffff; | |
| --border: #cce0f0; | |
| --text: #1a2533; | |
| --text-sec: #5a7a96; | |
| --danger: #d64040; | |
| --warn: #d98b00; | |
| --radius: 10px; | |
| --shadow: 0 2px 12px rgba(26,111,168,0.10); | |
| } | |
| body, .gradio-container { | |
| background: var(--bg) !important; | |
| font-family: 'Segoe UI', system-ui, sans-serif !important; | |
| color: var(--text) !important; | |
| } | |
| /* ββ Header bar ββ */ | |
| #header { | |
| background: var(--surface); | |
| border-bottom: 2px solid var(--border); | |
| padding: 14px 28px; | |
| display: flex; | |
| align-items: center; | |
| justify-content: space-between; | |
| box-shadow: var(--shadow); | |
| margin-bottom: 0 !important; | |
| } | |
| #logo-itu img { height: 48px; } | |
| #logo-angio { | |
| font-size: 26px; | |
| font-weight: 700; | |
| color: var(--accent); | |
| letter-spacing: 1px; | |
| } | |
| #logo-angio span { | |
| font-size: 13px; | |
| color: var(--text-sec); | |
| font-weight: 400; | |
| display: block; | |
| letter-spacing: 0; | |
| } | |
| /* ββ Tab bar ββ */ | |
| .tab-nav button { | |
| color: var(--text-sec) !important; | |
| border-bottom: 3px solid transparent !important; | |
| font-size: 15px !important; | |
| font-weight: 500 !important; | |
| padding: 10px 24px !important; | |
| background: transparent !important; | |
| } | |
| .tab-nav button.selected { | |
| color: var(--primary) !important; | |
| border-bottom-color: var(--primary) !important; | |
| } | |
| /* ββ Cards ββ */ | |
| .card { | |
| background: var(--surface); | |
| border: 1px solid var(--border); | |
| border-radius: var(--radius); | |
| padding: 22px; | |
| box-shadow: var(--shadow); | |
| margin-bottom: 16px; | |
| } | |
| .card-title { | |
| font-size: 14px; | |
| font-weight: 600; | |
| color: var(--primary); | |
| text-transform: uppercase; | |
| letter-spacing: 0.5px; | |
| margin-bottom: 12px; | |
| padding-bottom: 8px; | |
| border-bottom: 1px solid var(--primary-lt); | |
| } | |
| /* ββ Upload zone ββ */ | |
| #upload-zone { | |
| border: 2px dashed var(--primary-mid) !important; | |
| border-radius: var(--radius) !important; | |
| background: var(--primary-lt) !important; | |
| min-height: 200px !important; | |
| } | |
| /* ββ Sliders ββ */ | |
| input[type=range] { accent-color: var(--primary); } | |
| .gradio-slider label { color: var(--text-sec) !important; font-size: 13px !important; } | |
| /* ββ Buttons ββ */ | |
| #btn-analyse { | |
| background: var(--primary) !important; | |
| color: white !important; | |
| border: none !important; | |
| border-radius: var(--radius) !important; | |
| font-size: 15px !important; | |
| font-weight: 600 !important; | |
| padding: 12px 32px !important; | |
| cursor: pointer !important; | |
| width: 100% !important; | |
| transition: background 0.2s !important; | |
| } | |
| #btn-analyse:hover { background: #155d90 !important; } | |
| #btn-reset { | |
| background: transparent !important; | |
| color: var(--text-sec) !important; | |
| border: 1px solid var(--border) !important; | |
| border-radius: var(--radius) !important; | |
| font-size: 13px !important; | |
| padding: 8px 20px !important; | |
| cursor: pointer !important; | |
| width: 100% !important; | |
| } | |
| /* ββ Metric badges ββ */ | |
| .metric-row { | |
| display: flex; | |
| gap: 12px; | |
| flex-wrap: wrap; | |
| margin-bottom: 12px; | |
| } | |
| .metric-badge { | |
| flex: 1; | |
| min-width: 110px; | |
| background: var(--primary-lt); | |
| border: 1px solid var(--border); | |
| border-radius: 8px; | |
| padding: 10px 14px; | |
| text-align: center; | |
| } | |
| .metric-badge .val { | |
| font-size: 22px; | |
| font-weight: 700; | |
| color: var(--primary); | |
| } | |
| .metric-badge .val.ischemic { color: var(--danger); } | |
| .metric-badge .val.borderline { color: var(--warn); } | |
| .metric-badge .val.ok { color: var(--accent); } | |
| .metric-badge .lbl { | |
| font-size: 11px; | |
| color: var(--text-sec); | |
| text-transform: uppercase; | |
| letter-spacing: 0.4px; | |
| } | |
| /* ββ Result images ββ */ | |
| #result-images img { | |
| border-radius: var(--radius); | |
| border: 1px solid var(--border); | |
| } | |
| /* ββ Status ββ */ | |
| #status-box { | |
| padding: 10px 16px; | |
| border-radius: 8px; | |
| font-size: 13px; | |
| font-weight: 500; | |
| border: 1px solid var(--border); | |
| background: var(--primary-lt); | |
| color: var(--primary); | |
| } | |
| /* ββ Footer ββ */ | |
| #footer { | |
| text-align: center; | |
| font-size: 12px; | |
| color: var(--text-sec); | |
| padding: 20px; | |
| border-top: 1px solid var(--border); | |
| margin-top: 24px; | |
| } | |
| """ | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # HEADER HTML | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _header_html() -> str: | |
| itu_tag = (f'<img src="{LOGO_ITU}" alt="ITU Logo" style="height:48px">' | |
| if LOGO_ITU else | |
| '<span style="font-weight:700;color:#1a6fa8;font-size:18px">ITU</span>') | |
| return f""" | |
| <div id="header"> | |
| <div id="logo-itu">{itu_tag}</div> | |
| <div id="logo-angio"> | |
| Angio AI | |
| <span>Coronary Analysis System Β· ITU Lahore</span> | |
| </div> | |
| </div>""" | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # MODULE 1 β KEYFRAME EXTRACTION (from app.py) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def extract_best_frame(video_path: str, n: int = 1) -> tuple: | |
| """ | |
| Extract the single best frame from the video. | |
| Returns (frame_rgb_512, frame_index, score). | |
| """ | |
| cap = cv2.VideoCapture(video_path) | |
| total = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) | |
| if total == 0: | |
| cap.release() | |
| raise ValueError("Could not read video β check format/codec.") | |
| step = max(1, total // min(120, total)) | |
| scores = [] | |
| fidx = 0 | |
| while cap.isOpened(): | |
| ret, frame = cap.read() | |
| if not ret: | |
| break | |
| if fidx % step == 0: | |
| gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) | |
| h, w = gray.shape | |
| roi = gray[int(h*0.2):int(h*0.8), int(w*0.2):int(w*0.8)] | |
| contrast = float(roi.mean()) | |
| sharpness = float(cv2.Laplacian(roi, cv2.CV_64F).var()) | |
| score = contrast * 0.6 + min(sharpness, 5000) / 5000 * 100 * 0.4 | |
| scores.append((score, fidx, frame.copy())) | |
| fidx += 1 | |
| cap.release() | |
| if not scores: | |
| raise ValueError("No frames could be sampled from video.") | |
| scores.sort(key=lambda x: x[0], reverse=True) | |
| best_score, best_idx, best_frame = scores[0] | |
| rgb = cv2.cvtColor(best_frame, cv2.COLOR_BGR2RGB) | |
| rgb = cv2.resize(rgb, (512, 512)) | |
| return rgb, best_idx, best_score, total | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # MODULE 2 β MASK2FORMER INFERENCE | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _m2f_model = None | |
| _m2f_proc = None | |
| def _load_mask2former(): | |
| global _m2f_model, _m2f_proc | |
| if _m2f_model is not None: | |
| return | |
| from transformers import (AutoImageProcessor, | |
| Mask2FormerForUniversalSegmentation) | |
| ckpt = CONFIG["MASK2FORMER_CKPT"]() # resolves via HF Hub download | |
| if not os.path.exists(ckpt): | |
| raise FileNotFoundError(f"Mask2Former checkpoint not found:\n{ckpt}") | |
| _m2f_proc = AutoImageProcessor.from_pretrained( | |
| "facebook/mask2former-swin-base-coco-instance") | |
| _m2f_model = Mask2FormerForUniversalSegmentation.from_pretrained( | |
| "facebook/mask2former-swin-base-coco-instance", | |
| num_labels=2, ignore_mismatched_sizes=True) | |
| state = torch.load(ckpt, map_location="cpu", weights_only=False) | |
| weights = state.get("model_state", state) | |
| _m2f_model.load_state_dict(weights, strict=False) | |
| _m2f_model.to(CONFIG["DEVICE"]).eval() | |
| def run_mask2former(frame_rgb: np.ndarray, conf_thr: float) -> list: | |
| """ | |
| Returns list of dicts: [{bbox:[x1,y1,x2,y2], score:float, mask:np.ndarray}, ...] | |
| """ | |
| _load_mask2former() | |
| from PIL import Image | |
| pil = Image.fromarray(frame_rgb) | |
| inputs = _m2f_proc(images=pil, return_tensors="pt") | |
| inputs = {k: v.to(CONFIG["DEVICE"]) for k, v in inputs.items()} | |
| with torch.no_grad(): | |
| outputs = _m2f_model(**inputs) | |
| results = _m2f_proc.post_process_instance_segmentation( | |
| outputs, target_sizes=[(512, 512)])[0] | |
| detections = [] | |
| for seg_info in results["segments_info"]: | |
| score = float(seg_info["score"]) | |
| if score < conf_thr: | |
| continue | |
| mask = (results["segmentation"] == seg_info["id"]).cpu().numpy().astype(np.uint8) | |
| ys, xs = np.where(mask) | |
| if len(ys) == 0: | |
| continue | |
| x1, y1, x2, y2 = int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max()) | |
| detections.append({"bbox": [x1, y1, x2, y2], "score": score, "mask": mask}) | |
| return detections | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # MODULE 3 β RESUNET BINARY SEGMENTATION | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| _resunet_model = None | |
| # ββ ResUNet architecture β exactly matching binary_best.pth ββββββββββββββββββ | |
| # | |
| # Decoded from checkpoint key names via dump_keys.py: | |
| # | |
| # enc0 Conv2d(1, 16, 3, 3, bias=True) plain entry conv | |
| # down1..4 ConvTranspose2d strided Γ2 downsampler 16β32β64β128β256 | |
| # res1..4 _ResBlock(ch) encoder residual blocks ch=16,32,64,128 | |
| # bot.0..2 _ResBlock(256) Γ 3 bottleneck | |
| # up3..0 ConvTranspose2d Γ2 upsampler 256β128β64β32β16 | |
| # res12..15 _ResBlock(ch) decoder residual blocks ch=128,64,32,16 | |
| # head Sequential[Conv3Γ3, ReLU, BN, output head | |
| # Conv3Γ3, ReLU, BN, Conv1Γ1] | |
| # | |
| # _ResBlock structure (from key pattern res*.bn_in, res*.units.N.*, res*.bn_out): | |
| # bn_in β BN | |
| # units β ModuleList of N sub-units, each: [Conv3Γ3, ReLU, Conv3Γ3, BN] | |
| # bn_out β BN | |
| # forward: x β bn_in β for each unit: residual(x + unit(x)) β bn_out | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class _ResBlock(torch.nn.Module): | |
| """ | |
| Matches checkpoint pattern: | |
| res*.bn_in.* | |
| res*.units.0.0.weight (Conv3Γ3) | |
| res*.units.0.0.bias | |
| res*.units.0.2.weight (Conv3Γ3) | |
| res*.units.0.2.bias | |
| res*.units.0.3.* (BN) | |
| ... (3 units total) | |
| res*.bn_out.* | |
| Sequential index mapping inside each unit: | |
| 0 = Conv2d | |
| 1 = ReLU (no params β not in state dict) | |
| 2 = Conv2d | |
| 3 = BatchNorm2d | |
| """ | |
| def __init__(self, ch, n_units=3): | |
| super().__init__() | |
| self.bn_in = torch.nn.BatchNorm2d(ch) | |
| self.units = torch.nn.ModuleList([ | |
| torch.nn.Sequential( | |
| torch.nn.Conv2d(ch, ch, 3, padding=1, bias=True), | |
| torch.nn.ReLU(inplace=True), | |
| torch.nn.Conv2d(ch, ch, 3, padding=1, bias=True), | |
| torch.nn.BatchNorm2d(ch), | |
| ) | |
| for _ in range(n_units) | |
| ]) | |
| self.bn_out = torch.nn.BatchNorm2d(ch) | |
| def forward(self, x): | |
| x = self.bn_in(x) | |
| for unit in self.units: | |
| x = x + unit(x) # no extra ReLU on residual sum | |
| return self.bn_out(x) | |
| class ResUNet(torch.nn.Module): | |
| """ | |
| Exact architecture matching binary_best.pth (Dice = 0.8015). | |
| Key name mapping: | |
| enc0 β self.enc0 (plain Conv2d) | |
| down1..4 β self.down1..down4 (ConvTranspose2d strided downsampler) | |
| res1..4 β self.res1..res4 (encoder ResBlocks) | |
| bot.0..2 β self.bot[0..2] (bottleneck ResBlocks) | |
| up3..0 β self.up3..up0 (ConvTranspose2d upsampler) | |
| res12..15 β self.res12..res15 (decoder ResBlocks) | |
| head β self.head (Sequential output head) | |
| """ | |
| def __init__(self): | |
| super().__init__() | |
| # ββ Entry conv ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| self.enc0 = torch.nn.Conv2d(1, 16, 3, padding=1, bias=True) | |
| # ββ Encoder downsamples β strided Conv2d (not ConvTranspose2d) βββββββββ | |
| # Checkpoint shape (out, in, kH, kW): down1=(32,16,2,2) β Conv2d(16β32) | |
| # stride=2 kernel=2 halves spatial dimensions while increasing channels | |
| self.down1 = torch.nn.Conv2d(16, 32, 2, stride=2) | |
| self.down2 = torch.nn.Conv2d(32, 64, 2, stride=2) | |
| self.down3 = torch.nn.Conv2d(64, 128, 2, stride=2) | |
| self.down4 = torch.nn.Conv2d(128, 256, 2, stride=2) | |
| # ββ Encoder residual blocks βββββββββββββββββββββββββββββββββββββββββββ | |
| self.res1 = _ResBlock(16) | |
| self.res2 = _ResBlock(32) | |
| self.res3 = _ResBlock(64) | |
| self.res4 = _ResBlock(128) | |
| # ββ Bottleneck ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| self.bot = torch.nn.ModuleList([ | |
| _ResBlock(256), | |
| _ResBlock(256), | |
| _ResBlock(256), | |
| ]) | |
| # ββ Decoder upsamplers ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Shape from keys: up3=(256,128,2,2) bias=(128,) | |
| # up2=(128,64,2,2) bias=(64,) | |
| # up1=(64,32,2,2) bias=(32,) | |
| # up0=(32,16,2,2) bias=(16,) | |
| self.up3 = torch.nn.ConvTranspose2d(256, 128, 2, stride=2) | |
| self.up2 = torch.nn.ConvTranspose2d(128, 64, 2, stride=2) | |
| self.up1 = torch.nn.ConvTranspose2d(64, 32, 2, stride=2) | |
| self.up0 = torch.nn.ConvTranspose2d(32, 16, 2, stride=2) | |
| # ββ Decoder residual blocks βββββββββββββββββββββββββββββββββββββββββββ | |
| self.res12 = _ResBlock(128) | |
| self.res13 = _ResBlock(64) | |
| self.res14 = _ResBlock(32) | |
| self.res15 = _ResBlock(16) | |
| # ββ Output head βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # From keys: head.0=Conv(16,16,3,3), head.2=BN(16), head.3=Conv(16,16,3,3), | |
| # head.5=BN(16), head.6=Conv(2,16,1,1) | |
| # Index 1 and 4 are ReLU (no state dict entries) | |
| self.head = torch.nn.Sequential( | |
| torch.nn.Conv2d(16, 16, 3, padding=1, bias=True), # 0 | |
| torch.nn.ReLU(inplace=True), # 1 | |
| torch.nn.BatchNorm2d(16), # 2 | |
| torch.nn.Conv2d(16, 16, 3, padding=1, bias=True), # 3 | |
| torch.nn.ReLU(inplace=True), # 4 | |
| torch.nn.BatchNorm2d(16), # 5 | |
| torch.nn.Conv2d(16, 2, 1, bias=True), # 6 | |
| ) | |
| def forward(self, x): | |
| # ββ Encoder ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # Entry conv β ResBlock β down. Skip taken AFTER ResBlock. | |
| x = self.enc0(x) # (B,16,512,512) | |
| r1 = self.res1(x); x = self.down1(r1) # skip r1, down β (B,32,256,256) | |
| r2 = self.res2(x); x = self.down2(r2) # skip r2, down β (B,64,128,128) | |
| r3 = self.res3(x); x = self.down3(r3) # skip r3, down β (B,128,64,64) | |
| r4 = self.res4(x); x = self.down4(r4) # skip r4, down β (B,256,32,32) | |
| # ββ Bottleneck βββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| for blk in self.bot: | |
| x = blk(x) | |
| # ββ Decoder β ResBlock first, then add skip βββββββββββββββββββββββββββ | |
| x = self.res12(self.up3(x)) + r4 # (B,128,64,64) | |
| x = self.res13(self.up2(x)) + r3 # (B,64,128,128) | |
| x = self.res14(self.up1(x)) + r2 # (B,32,256,256) | |
| x = self.res15(self.up0(x)) + r1 # (B,16,512,512) | |
| return self.head(x) | |
| def _load_resunet(): | |
| global _resunet_model | |
| if _resunet_model is not None: | |
| return | |
| ckpt = CONFIG["RESUNET_CKPT"]() # resolves via HF Hub download | |
| if not os.path.exists(ckpt): | |
| raise FileNotFoundError(f"ResUNet checkpoint not found:\n{ckpt}") | |
| _resunet_model = ResUNet() | |
| state = torch.load(ckpt, map_location="cpu", weights_only=False) | |
| weights = state.get("model_state", state) | |
| # Try strict first; if it fails print exact mismatches and retry strict=False | |
| try: | |
| _resunet_model.load_state_dict(weights, strict=True) | |
| print(f"[ResUNet] β Perfect load (strict=True) β " | |
| f"epoch={state.get('epoch','?')} " | |
| f"dice={state.get('best_val_dice',0):.4f}") | |
| except RuntimeError as e: | |
| print(f"[ResUNet] strict=True failed:\n {e}") | |
| result = _resunet_model.load_state_dict(weights, strict=False) | |
| print(f"[ResUNet] strict=False β " | |
| f"missing={len(result.missing_keys)} " | |
| f"unexpected={len(result.unexpected_keys)}") | |
| for k in result.missing_keys[:10]: | |
| print(f" MISSING {k}") | |
| for k in result.unexpected_keys[:10]: | |
| print(f" UNEXPECTED {k}") | |
| _resunet_model.to(CONFIG["DEVICE"]).eval() | |
| print(f"[ResUNet] Loaded (eval mode) β " | |
| f"epoch={state.get('epoch','?')} " | |
| f"dice={state.get('best_val_dice',0):.4f}") | |
| def _preprocess_resunet(rgb: np.ndarray) -> torch.Tensor: | |
| """ARCADE preprocessing: invert β top-hat β subtract β CLAHE β normalise.""" | |
| gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY).astype(np.uint8) | |
| inv = cv2.bitwise_not(gray) | |
| se = cv2.getStructuringElement(cv2.MORPH_RECT, (50, 50)) | |
| th = cv2.morphologyEx(inv, cv2.MORPH_TOPHAT, se) | |
| sub = cv2.subtract(gray, th) | |
| clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)) | |
| proc = clahe.apply(sub).astype(np.float32) / 255.0 | |
| t = torch.from_numpy(proc).unsqueeze(0).unsqueeze(0) # (1,1,512,512) | |
| return t | |
| def run_resunet(frame_rgb: np.ndarray): | |
| """ | |
| Returns: | |
| binary_mask : (512,512) uint8, values 0/1 | |
| binary_overlay: (512,512,3) uint8 RGB β green vessel overlay on original | |
| Architecture fix applied: skip connections are now res(up(x)) + skip | |
| (not res(up(x) + skip)), and ResBlock has no extra ReLU on residual sum. | |
| Model runs in eval() mode β BN running stats are now correct. | |
| """ | |
| _load_resunet() | |
| _resunet_model.eval() | |
| t = _preprocess_resunet(frame_rgb).to(CONFIG["DEVICE"]) | |
| with torch.no_grad(): | |
| logits = _resunet_model(t) | |
| pred = logits.argmax(dim=1).squeeze().cpu().numpy().astype(np.uint8) | |
| vessel_px = int((pred == 1).sum()) | |
| print(f"[ResUNet] vessel px={vessel_px} ({100*(pred==1).mean():.1f}%)") | |
| # Build green overlay for display in Results tab | |
| overlay = frame_rgb.copy().astype(np.float32) | |
| green = np.array([0, 255, 80], dtype=np.float32) | |
| mask3d = pred == 1 | |
| overlay[mask3d] = overlay[mask3d] * 0.35 + green * 0.65 | |
| overlay = overlay.astype(np.uint8) | |
| # Draw contours for sharper vessel edges | |
| contours, _ = cv2.findContours(pred, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) | |
| cv2.drawContours(overlay, contours, -1, (0, 255, 80), 1) | |
| # Coverage label | |
| cov = 100 * vessel_px / pred.size | |
| cv2.putText(overlay, f"Vessel coverage: {cov:.1f}%", | |
| (8, 20), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 80), 1, cv2.LINE_AA) | |
| return pred, overlay | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # MODULE 4 β YOLOv8m-seg INFERENCE (replaces nnUNet) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # | |
| # Checkpoint : CONFIG["YOLO_CKPT"] -> best.pt from YOLOv8m-seg training | |
| # Model : YOLOv8m-seg, nc=26, trained on ARCADE syntax dataset | |
| # | |
| # Input : frame_rgb -- (512,512,3) uint8 RGB | |
| # seg_conf -- confidence threshold from UI slider | |
| # Output: seg_map -- (512,512) uint8 label map, values 0..25 | |
| # 0 = background (no detection) | |
| # 1..25 = ARCADE class index | |
| # | |
| # Output format is identical to the old run_nnunet() so that | |
| # render_nnunet_overlay() and compute_syntax_score() work unchanged. | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # ARCADE 26-class name list (index = YOLO cls_id + 1) | |
| _YOLO_CLASS_NAMES = [ | |
| "background", # 0 (background, not predicted by YOLO) | |
| "LMCA", # 1 | |
| "LAD prox", # 2 | |
| "LAD mid", # 3 | |
| "LAD dist", # 4 | |
| "D1", # 5 | |
| "D2", # 6 | |
| "D3", # 7 | |
| "Septal", # 8 | |
| "LAD var", # 9 | |
| "LCX prox", # 10 | |
| "LCX mid", # 11 | |
| "LCX dist", # 12 | |
| "OM1", # 13 | |
| "OM2", # 14 | |
| "OM3", # 15 | |
| "LCX var", # 16 | |
| "RCA prox", # 17 | |
| "RCA mid", # 18 | |
| "RCA dist", # 19 | |
| "PDA", # 20 | |
| "PLV", # 21 | |
| "AM", # 22 | |
| "Conus", # 23 | |
| "SAN", # 24 | |
| "Other", # 25 | |
| ] | |
| _yolo_model = None # singleton -- loaded once on first call, reused every frame | |
| def _load_yolo(): | |
| """Lazy-load YOLOv8m-seg checkpoint into module-level singleton.""" | |
| global _yolo_model | |
| if _yolo_model is not None: | |
| return | |
| try: | |
| from ultralytics import YOLO | |
| except ImportError: | |
| raise ImportError( | |
| "ultralytics not installed. Run: pip install ultralytics" | |
| ) | |
| ckpt = CONFIG["YOLO_CKPT"]() # resolves via HF Hub download | |
| if not os.path.exists(ckpt): | |
| raise FileNotFoundError( | |
| f"YOLOv8 checkpoint not found:\n{ckpt}\n" | |
| f"Check HF_TOKEN secret and repo MuhammadAdil63/angio-ai-checkpoints" | |
| ) | |
| print(f"[SEG] Loading checkpoint: {ckpt}") | |
| try: | |
| import torch.serialization | |
| from ultralytics.nn.tasks import SegmentationModel | |
| torch.serialization.add_safe_globals([SegmentationModel]) | |
| except Exception as _e: | |
| print(f"[SEG] safe_globals warning: {_e}") | |
| _yolo_model = YOLO(ckpt) | |
| print(f"[SEG] Model loaded OK") | |
| def run_yolo_seg(frame_rgb: np.ndarray, seg_conf: float) -> np.ndarray: | |
| """ | |
| Run YOLOv8m-seg on a single RGB frame. | |
| Returns seg_map (512x512 uint8), label values 0..25. | |
| Label map composition: | |
| - Initialised to 0 (background). | |
| - Instances sorted ascending by confidence so higher-conf masks | |
| overwrite lower-conf ones at overlapping pixels. | |
| - YOLO cls_id is 0-based (0..24 for nc=25 visible classes). | |
| We store label_val = cls_id + 1 so 0 stays background, | |
| matching ARCADE_COLOURS / ARCADE_LABELS which use keys 1..25. | |
| """ | |
| _load_yolo() | |
| device = CONFIG["DEVICE"] | |
| conf = float(seg_conf) if seg_conf else CONFIG["YOLO_CONF"] | |
| iou = CONFIG["YOLO_IOU"] | |
| imgsz = CONFIG["YOLO_IMGSZ"] | |
| # YOLO expects BGR | |
| frame_bgr = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2BGR) | |
| try: | |
| results = _yolo_model.predict( | |
| source = frame_bgr, | |
| conf = conf, | |
| iou = iou, | |
| imgsz = imgsz, | |
| device = device, | |
| retina_masks = False, | |
| verbose = False, | |
| ) | |
| except Exception as e: | |
| print(f"[SEG] predict() failed: {e}") | |
| return np.zeros((512, 512), dtype=np.uint8) | |
| result = results[0] | |
| h, w = frame_rgb.shape[:2] | |
| seg_map = np.zeros((h, w), dtype=np.uint8) | |
| if result.masks is None or result.boxes is None: | |
| print("[SEG] No masks returned -- check conf threshold or checkpoint.") | |
| return seg_map | |
| masks = result.masks.data.cpu().numpy() # (N, Hm, Wm) float32 | |
| cls_ids = result.boxes.cls.cpu().int().numpy() # (N,) 0-based | |
| confs = result.boxes.conf.cpu().numpy() # (N,) | |
| # Ascending confidence sort so higher-conf masks win at overlaps | |
| order = np.argsort(confs) | |
| n_written = 0 | |
| for i in order: | |
| cls_id = int(cls_ids[i]) | |
| mask_resized = cv2.resize( | |
| masks[i], (w, h), interpolation=cv2.INTER_LINEAR | |
| ) | |
| binary = mask_resized > 0.5 | |
| label_val = min(cls_id + 1, 25) # clamp to valid 1-25 range | |
| seg_map[binary] = label_val | |
| n_written += int(binary.sum()) | |
| detected_cls = [_YOLO_CLASS_NAMES[min(int(c) + 1, 25)] for c in cls_ids] | |
| print(f"[SEG] Detections: {len(cls_ids)} | classes: {detected_cls} | " | |
| f"vessel px: {n_written} ({100 * n_written / (h * w):.1f}%)") | |
| print(f"[SEG] seg_map unique labels: {np.unique(seg_map)}") | |
| return seg_map | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # MODULE 5 β FFR PIPELINE (inline, from ffr_pipeline_v4.py) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| from skimage.morphology import skeletonize, remove_small_objects, disk | |
| from skimage.measure import label as sk_label | |
| from scipy.ndimage import distance_transform_edt | |
| from scipy.signal import argrelmin, savgol_filter | |
| from skimage.filters import threshold_otsu | |
| FFR_CFG = { | |
| "close_radius" : 3, | |
| "open_radius" : 1, | |
| "min_vessel_px" : 100, | |
| "stenosis_overlap_thr": 0.05, | |
| # Lowered from 0.20 β 0.05 for video frames which have sparser masks | |
| # than training PNGs β prevents all detections being unlocalized | |
| "min_overlap_for_ffr": 0.05, | |
| "min_ds_fraction" : 0.10, | |
| "min_branch_len" : 15, | |
| "use_dominant_branch": True, | |
| "min_coverage_warn" : 0.50, | |
| "sample_step" : 3, | |
| "perp_half_len" : 40, # reduced from 50 to avoid hitting background | |
| "smooth_window" : 11, | |
| "smooth_poly" : 2, | |
| "reference_pct" : 75, | |
| # Physiological cap: coronary vessels are 1.5β5mm in diameter | |
| # If auto-computed ref exceeds this, clamp it | |
| "ref_diam_max_mm" : 5.0, | |
| "ref_diam_min_mm" : 1.0, | |
| "px_per_mm" : 3.75, | |
| } | |
| def _refine_mask(img_gray, binary_mask): | |
| masked = img_gray[binary_mask == 1] | |
| if len(masked) == 0: | |
| return binary_mask.copy() | |
| thr = threshold_otsu(masked) | |
| return ((binary_mask == 1) & (img_gray < thr)).astype(np.uint8) | |
| def _morph_clean(mask): | |
| m = mask.copy() | |
| k_c = disk(FFR_CFG["close_radius"]).astype(np.uint8) | |
| k_o = disk(FFR_CFG["open_radius"]).astype(np.uint8) | |
| m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, k_c) | |
| m = cv2.morphologyEx(m, cv2.MORPH_OPEN, k_o) | |
| return remove_small_objects(m.astype(bool), | |
| min_size=FFR_CFG["min_vessel_px"]).astype(np.uint8) | |
| def _find_branch_pts(skel): | |
| k = np.ones((3,3), np.uint8) | |
| n = cv2.filter2D(skel.astype(np.uint8), -1, k) * skel.astype(np.uint8) | |
| return n >= 4 | |
| def _dominant_branch(skel): | |
| bp = _find_branch_pts(skel) | |
| sp = skel.copy().astype(np.uint8); sp[bp] = 0 | |
| lab = sk_label(sp) | |
| if lab.max() == 0: return skel | |
| sizes = np.bincount(lab.ravel()); sizes[0] = 0 | |
| return (lab == sizes.argmax()) | |
| def _order_centerline(skel): | |
| pts = np.argwhere(skel) | |
| if len(pts) == 0: return [] | |
| sk_set = set(map(tuple, pts)) | |
| def nbrs(p): | |
| r,c = p | |
| return [(r+dr,c+dc) for dr in [-1,0,1] for dc in [-1,0,1] | |
| if (dr,dc)!=(0,0) and (r+dr,c+dc) in sk_set] | |
| endpoints = [p for p in sk_set if len(nbrs(p))==1] or [tuple(pts[0])] | |
| def trace(start): | |
| visited={tuple(start)}; path=[tuple(start)]; cur=tuple(start) | |
| while True: | |
| cands=[n for n in nbrs(cur) if n not in visited] | |
| if not cands: break | |
| if len(path)>1: | |
| pr=path[-2]; dr0=cur[0]-pr[0]; dc0=cur[1]-pr[1] | |
| cands.sort(key=lambda n:abs(n[0]-cur[0]-dr0)+abs(n[1]-cur[1]-dc0)) | |
| cur=cands[0]; visited.add(cur); path.append(cur) | |
| return path | |
| best=[] | |
| for ep in endpoints[:20]: | |
| p=trace(ep) | |
| if len(p)>len(best): best=p | |
| return best | |
| def _measure_diameters(cl_pts, refined_mask): | |
| H,W=refined_mask.shape | |
| step=FFR_CFG["sample_step"]; half=FFR_CFG["perp_half_len"] | |
| arc_len=0.; arc_pos=[]; diams=[]; wall_pairs=[]; centers=[] | |
| for i in range(0, len(cl_pts)-step, step): | |
| p0=np.array(cl_pts[i], dtype=float) | |
| p1=np.array(cl_pts[min(i+step,len(cl_pts)-1)], dtype=float) | |
| tang=p1-p0; tl=np.linalg.norm(tang) | |
| if tl<1e-9: continue | |
| tang/=tl; norm=np.array([-tang[1],tang[0]]) | |
| def wd(d): | |
| for dd in np.arange(0.5,half,0.5): | |
| pt=p0+d*dd; r_=int(round(pt[0])); c_=int(round(pt[1])) | |
| if r_<0 or r_>=H or c_<0 or c_>=W: return dd,pt | |
| if refined_mask[r_,c_]==0: return dd,pt | |
| return half, p0+d*half | |
| dl,pl=wd( norm); dr,pr=wd(-norm); diam=dl+dr | |
| if diam<1. or diam>half*1.8: continue | |
| if len(centers)==0: step_arc=0. | |
| else: step_arc=float(np.linalg.norm(p0-np.array(centers[-1]))) | |
| arc_len+=min(step_arc, FFR_CFG["sample_step"]*1.5) | |
| arc_pos.append(arc_len); diams.append(diam) | |
| wall_pairs.append((pl,pr)); centers.append(p0) | |
| return (np.array(arc_pos), np.array(diams,dtype=np.float32), | |
| wall_pairs, centers) | |
| def _ds_to_ffr(ds): | |
| return max(0., min(1., 1.-(0.33*ds+0.60*ds**2))) | |
| def run_ffr(frame_rgb, binary_mask, detections): | |
| """Full FFR pipeline v4. Returns result dict + annotated figure (np array).""" | |
| img_gray = cv2.cvtColor(frame_rgb, cv2.COLOR_RGB2GRAY) | |
| px_mm = FFR_CFG["px_per_mm"] | |
| # ββ Debug: report mask quality before refinement ββββββββββββββββββββββββββ | |
| bm_sum = int(binary_mask.sum()) | |
| bm_pct = 100 * binary_mask.mean() | |
| print(f"[FFR] binary_mask entering FFR: sum={bm_sum} {bm_pct:.1f}% of image") | |
| # ββ Use binary mask directly β skip Otsu refinement βββββββββββββββββββββ | |
| # ResUNet already produces a clean vessel segmentation (Dice=0.80). | |
| # Otsu refinement was designed for use with ground-truth binary masks from | |
| # the ARCADE dataset where it further removes bright background pixels. | |
| # On video frames, Otsu degrades the ResUNet mask because the vessel/bg | |
| # intensity distribution differs from training PNGs. | |
| # Using the binary mask directly preserves all vessel pixels. | |
| refined = binary_mask.copy() | |
| print(f"[FFR] Using binary mask directly (skip Otsu): sum={refined.sum()}") | |
| refined = _morph_clean(refined) | |
| print(f"[FFR] After morph cleanup: sum={refined.sum()}") | |
| skel_raw = skeletonize(refined.astype(bool)) | |
| print(f"[FFR] Skeleton pixels: {skel_raw.sum()}") | |
| # Prune spurs | |
| sk = skel_raw.copy().astype(np.uint8) | |
| for _ in range(FFR_CFG["min_branch_len"]): | |
| k=np.ones((3,3),np.uint8) | |
| nb=cv2.filter2D(sk,-1,k)*sk | |
| ep=(nb==2).astype(np.uint8) | |
| sk=cv2.subtract(sk,ep) | |
| if ep.sum()==0: break | |
| skel_pruned=sk.astype(bool) | |
| print(f"[FFR] Pruned skeleton: {skel_pruned.sum()}") | |
| skel_dom = _dominant_branch(skel_pruned) | |
| print(f"[FFR] Dominant branch: {skel_dom.sum()}") | |
| cl_pts = _order_centerline(skel_dom) | |
| print(f"[FFR] Ordered centerline: {len(cl_pts)} pts") | |
| # Fallback: only if dominant branch is very short (< 30 pts) | |
| # Do NOT fall back to full skeleton β it traces through noise fragments | |
| if len(cl_pts) < 30: | |
| print(f"[FFR] Dominant branch short ({len(cl_pts)} pts) β trying full pruned skeleton") | |
| cl_pts_full = _order_centerline(skel_pruned) | |
| if len(cl_pts_full) > len(cl_pts): | |
| cl_pts = cl_pts_full | |
| print(f"[FFR] Full skeleton centerline: {len(cl_pts)} pts") | |
| if len(cl_pts) < 5: | |
| return { | |
| "error": ( | |
| f"Centerline too short ({len(cl_pts)} pts). " | |
| f"binary_mask={bm_sum}px ({bm_pct:.1f}%), " | |
| f"skeleton={int(skel_raw.sum())}px." | |
| ) | |
| }, None | |
| arc_pos, diams_px, wall_pairs, centers = _measure_diameters(cl_pts, refined) | |
| if len(diams_px)==0: | |
| return {"error":"No diameter measurements obtained."}, None | |
| win=FFR_CFG["smooth_window"] | |
| if len(diams_px)<win: win=len(diams_px) if len(diams_px)%2==1 else max(3,len(diams_px)-1) | |
| diams_sm = savgol_filter(diams_px, win, FFR_CFG["smooth_poly"]) if win>=3 else diams_px.copy() | |
| ref_px = np.percentile(diams_sm, FFR_CFG["reference_pct"]) | |
| ref_mm = ref_px / px_mm | |
| # ββ Physiological sanity clamp ββββββββββββββββββββββββββββββββββββββββββββ | |
| # Coronary vessels are 1.5β5.0 mm in diameter. | |
| # If ref_mm is outside this range, the skeleton traced through noise/background. | |
| # Clamp to physiological limits and recompute ref_px accordingly. | |
| ref_mm_raw = ref_mm | |
| ref_mm = float(np.clip(ref_mm, | |
| FFR_CFG["ref_diam_min_mm"], | |
| FFR_CFG["ref_diam_max_mm"])) | |
| ref_px = ref_mm * px_mm | |
| if abs(ref_mm_raw - ref_mm) > 0.1: | |
| print(f"[FFR] ref_diam clamped: {ref_mm_raw:.2f}mm β {ref_mm:.2f}mm " | |
| f"(physiological range {FFR_CFG['ref_diam_min_mm']}β" | |
| f"{FFR_CFG['ref_diam_max_mm']}mm)") | |
| # Also filter out diameter measurements that exceed 2Γ ref_px (noise) | |
| valid_mask = diams_sm <= (ref_px * 2.5) | |
| if valid_mask.sum() > 5: | |
| arc_pos = arc_pos[valid_mask] | |
| diams_px = diams_px[valid_mask] | |
| diams_sm = diams_sm[valid_mask] | |
| wall_pairs = [wp for wp, v in zip(wall_pairs, valid_mask) if v] | |
| centers = [c for c, v in zip(centers, valid_mask) if v] | |
| centers_arr = np.array(centers) | |
| print(f"[FFR] After noise filter: {len(diams_sm)} measurements kept") | |
| stenoses=[]; unloc=[] | |
| for det in detections: | |
| bbox = det.get("bbox",[]) | |
| score = det.get("score",1.) | |
| overlap = det.get("overlap",0.) | |
| if len(bbox)<4: continue | |
| if overlap < FFR_CFG["min_overlap_for_ffr"]: | |
| unloc.append({**det,"reason":"low-overlap"}); continue | |
| x1,y1,x2,y2=bbox | |
| inside=np.where((centers_arr[:,0]>=y1)&(centers_arr[:,0]<=y2)& | |
| (centers_arr[:,1]>=x1)&(centers_arr[:,1]<=x2))[0] | |
| idx = int(inside[np.argmin(diams_sm[inside])]) if len(inside)>0 else \ | |
| int(np.argmin(np.linalg.norm(centers_arr-np.array([(y1+y2)/2,(x1+x2)/2]),axis=1))) | |
| if idx>=len(diams_sm): continue | |
| d_sten=float(diams_sm[idx]); ds=max(0.,1.-d_sten/ref_px) | |
| if ds < FFR_CFG["min_ds_fraction"]: | |
| unloc.append({**det,"reason":"no-narrowing"}); continue | |
| ffr_est=_ds_to_ffr(ds) | |
| stenoses.append({"index":idx,"position_mm":float(arc_pos[idx]/px_mm), | |
| "d_mm":float(d_sten/px_mm),"ref_mm":float(ref_mm), | |
| "pct_ds":float(ds*100),"ffr":ffr_est, | |
| "score":score,"significant":ffr_est<=0.80}) | |
| global_ffr=1. | |
| for s in stenoses: global_ffr*=s["ffr"] | |
| global_ffr=max(0.,global_ffr) | |
| # Build figure | |
| fig, axes = plt.subplots(1,3, figsize=(15,5)) | |
| fig.patch.set_facecolor("white") | |
| # Panel 1: stenosis overlay | |
| axes[0].imshow(frame_rgb, cmap="gray") | |
| for det in detections: | |
| b=det.get("bbox",[]); sc=det.get("score",0) | |
| if len(b)<4: continue | |
| x1,y1,x2,y2=b | |
| col="red" if sc>=0.5 else "orange" | |
| axes[0].add_patch(mpatches.Rectangle( | |
| (x1,y1),x2-x1,y2-y1,fill=False,edgecolor=col,linewidth=2)) | |
| axes[0].text(x1,y1-5,f"{sc:.2f}",color=col,fontsize=8,fontweight="bold") | |
| axes[0].axis("off"); axes[0].set_title("Stenosis detections", fontsize=11, color="#1a6fa8") | |
| # Panel 2: centerline + diameter lines | |
| overlay=frame_rgb.copy() | |
| sk_pts=np.argwhere(skel_dom) | |
| for r,c in sk_pts: | |
| if 0<=r<512 and 0<=c<512: overlay[r,c]=[0,206,209] | |
| axes[1].imshow(overlay) | |
| for i,(wl,wr) in enumerate(wall_pairs): | |
| axes[1].plot([wl[1],wr[1]],[wl[0],wr[0]], | |
| color="lime",alpha=0.4,linewidth=0.6) | |
| for st in stenoses: | |
| idx=st["index"] | |
| if idx<len(centers): | |
| cp=centers[idx] | |
| axes[1].plot(cp[1],cp[0],"rv",markersize=10,zorder=6, | |
| markeredgecolor="white",markeredgewidth=0.5) | |
| axes[1].annotate(f'FFRβ{st["ffr"]:.2f}', | |
| xy=(cp[1],cp[0]),xytext=(cp[1]+14,cp[0]-14), | |
| color="red",fontsize=8,fontweight="bold", | |
| arrowprops=dict(arrowstyle="->",color="red",lw=0.8)) | |
| axes[1].axis("off"); axes[1].set_title("Centerline + FFR markers", fontsize=11, color="#1a6fa8") | |
| # Panel 3: diameter profile | |
| if len(arc_pos)>0: | |
| pos_mm=arc_pos/px_mm | |
| axes[2].plot(pos_mm,diams_px/px_mm,color="#cce0f0",lw=0.8,label="Raw") | |
| axes[2].plot(pos_mm,diams_sm/px_mm,color="#1a6fa8",lw=1.8,label="Smoothed") | |
| axes[2].axhline(ref_mm,color="#0d9e6e",ls="--",lw=1.2,label=f"Ref {ref_mm:.2f}mm") | |
| for st in stenoses: | |
| axes[2].axvline(st["position_mm"],color="red",ls=":",lw=1.,alpha=0.8) | |
| axes[2].annotate(f'{st["pct_ds"]:.0f}%DS\nFFRβ{st["ffr"]:.2f}', | |
| xy=(st["position_mm"],st["d_mm"]), | |
| xytext=(st["position_mm"]+2,st["d_mm"]+ref_mm*0.15), | |
| color="red",fontsize=7,fontweight="bold", | |
| arrowprops=dict(arrowstyle="->",color="red",lw=0.7)) | |
| axes[2].set_xlabel("Position (mm)",fontsize=9) | |
| axes[2].set_ylabel("Diameter (mm)",fontsize=9) | |
| axes[2].legend(fontsize=8); axes[2].grid(alpha=0.3); axes[2].set_ylim(0) | |
| axes[2].set_title("Lumen diameter profile", fontsize=11, color="#1a6fa8") | |
| ffr_label = ("β ISCHEMIC" if global_ffr<=0.80 else "β Non-ischemic") | |
| fig.suptitle(f"FFR Analysis β Global FFR: {global_ffr:.3f} {ffr_label} | " | |
| f"Stenoses: {len(stenoses)} | Ref: {ref_mm:.2f} mm", | |
| fontsize=12,fontweight="bold",color="#1a2533",y=1.01) | |
| plt.tight_layout() | |
| # Convert figure to numpy array β compatible with all matplotlib versions | |
| fig.canvas.draw() | |
| w, h = fig.canvas.get_width_height() | |
| try: | |
| # matplotlib >= 3.8 | |
| buf = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8) | |
| buf = buf.reshape(h, w, 4)[:, :, :3] # drop alpha channel | |
| except AttributeError: | |
| # matplotlib < 3.8 fallback | |
| buf = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) | |
| buf = buf.reshape(h, w, 3) | |
| plt.close(fig) | |
| result = { | |
| "global_ffr" : round(global_ffr,3), | |
| "ischemic" : global_ffr<=0.80, | |
| "n_stenoses" : len(stenoses), | |
| "stenoses" : stenoses, | |
| "ref_diam_mm" : round(ref_mm,2), | |
| "vessel_len_mm": round(float(arc_pos[-1]/px_mm),1) if len(arc_pos)>0 else 0, | |
| "n_unlocalized": len(unloc), | |
| } | |
| return result, buf | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # MODULE 6 β nnUNet OVERLAY | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # 26 maximally distinct colours β one per ARCADE class (0 = background, black) | |
| # Generated by evenly spacing hue across HSV wheel at two brightness levels, | |
| # then manually verified for distinctness. Each colour is unique. | |
| # Format: (R, G, B) | |
| ARCADE_COLOURS = { | |
| 0 : ( 0, 0, 0), # background β black (not rendered) | |
| 1 : (255, 0, 0), # LMCA β red | |
| 2 : ( 0, 120, 255), # LAD proximal β azure blue | |
| 3 : ( 0, 220, 0), # LAD mid β lime green | |
| 4 : (255, 165, 0), # LAD distal β orange | |
| 5 : (180, 0, 255), # D1 β violet | |
| 6 : ( 0, 230, 230), # D2 β cyan | |
| 7 : (255, 230, 0), # D3 β yellow | |
| 8 : (255, 0, 180), # Septal β hot pink | |
| 9 : ( 0, 180, 80), # LAD var β jade green | |
| 10: (255, 100, 0), # LCX proximal β deep orange | |
| 11: ( 40, 40, 255), # LCX mid β royal blue | |
| 12: ( 0, 255, 140), # LCX distal β spring green | |
| 13: (220, 0, 80), # OM1 β crimson | |
| 14: ( 0, 200, 255), # OM2 β sky blue | |
| 15: (200, 255, 0), # OM3 β chartreuse | |
| 16: (140, 0, 200), # LCX var β purple | |
| 17: (255, 200, 120), # RCA proximal β peach | |
| 18: ( 80, 255, 80), # RCA mid β bright green | |
| 19: (255, 60, 200), # RCA distal β rose pink | |
| 20: ( 0, 80, 200), # PDA β cobalt blue | |
| 21: (200, 140, 0), # PLV β dark amber | |
| 22: ( 80, 200, 200), # AM branch β teal | |
| 23: (255, 255, 120), # Conus β light yellow | |
| 24: (160, 60, 0), # SAN artery β brown | |
| 25: (180, 180, 180), # Other β silver grey | |
| } | |
| ARCADE_LABELS = { | |
| 1:"LMCA", 2:"LAD prox", 3:"LAD mid", 4:"LAD dist", | |
| 5:"D1", 6:"D2", 7:"D3", 8:"Septal", | |
| 9:"LAD var", 10:"LCX prox", 11:"LCX mid", 12:"LCX dist", | |
| 13:"OM1", 14:"OM2", 15:"OM3", 16:"LCX var", | |
| 17:"RCA prox",18:"RCA mid", 19:"RCA dist", 20:"PDA", | |
| 21:"PLV", 22:"AM", 23:"Conus", 24:"SAN", | |
| 25:"Other", | |
| } | |
| def render_nnunet_overlay(frame_rgb, seg_map, alpha=0.55): | |
| """ | |
| Renders nnUNet segmentation overlay. | |
| Handles two cases: | |
| - Binary output (classes 0,1): single vessel colour overlay | |
| - 26-class output (ARCADE SYNTAX): full colour-coded overlay with legend | |
| """ | |
| unique_cls = [c for c in np.unique(seg_map) if c > 0] | |
| overlay = frame_rgb.copy().astype(np.float32) | |
| present_classes = [] | |
| # ββ Binary case (0/1 only) ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if len(unique_cls) == 1 and unique_cls[0] == 1: | |
| mask = (seg_map == 1) | |
| col = np.array([0, 200, 255], dtype=np.float32) # cyan for vessel | |
| overlay[mask] = overlay[mask] * (1 - alpha) + col * alpha | |
| result = overlay.astype(np.uint8) | |
| # Simple label | |
| cv2.putText(result, "Vessel mask (binary)", | |
| (8, 20), cv2.FONT_HERSHEY_SIMPLEX, | |
| 0.5, (0, 200, 255), 1, cv2.LINE_AA) | |
| return result | |
| # ββ 26-class case βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| for cls_id, rgb_col in ARCADE_COLOURS.items(): | |
| if cls_id == 0: | |
| continue | |
| mask = (seg_map == cls_id) | |
| if not mask.any(): | |
| continue | |
| col = np.array(rgb_col, dtype=np.float32) | |
| overlay[mask] = overlay[mask] * (1 - alpha) + col * alpha | |
| present_classes.append(cls_id) | |
| result = overlay.astype(np.uint8) | |
| # Compact legend β bottom-left, only present classes, max 13 rows | |
| if present_classes: | |
| try: | |
| n_rows = min(len(present_classes), 13) | |
| row_h = 18 | |
| legend_h = n_rows * row_h + 10 | |
| legend_w = 148 | |
| H, W = result.shape[:2] | |
| x0, y0 = 6, H - legend_h - 6 | |
| sub = result[y0:y0+legend_h, x0:x0+legend_w].astype(np.float32) | |
| result[y0:y0+legend_h, x0:x0+legend_w] = (sub * 0.4).astype(np.uint8) | |
| for i, cls_id in enumerate(present_classes[:13]): | |
| r, g, b = ARCADE_COLOURS[cls_id] | |
| lbl = ARCADE_LABELS.get(cls_id, f"cls{cls_id}") | |
| yy = y0 + 6 + i * row_h | |
| cv2.rectangle(result, (x0+4, yy), (x0+16, yy+11), (b, g, r), -1) | |
| cv2.rectangle(result, (x0+4, yy), (x0+16, yy+11), (255,255,255), 1) | |
| cv2.putText(result, lbl, (x0+20, yy+9), | |
| cv2.FONT_HERSHEY_SIMPLEX, 0.35, | |
| (255,255,255), 1, cv2.LINE_AA) | |
| except Exception: | |
| pass | |
| return result | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # MODULE 7 β SYNTAX SCORE (simplified anatomical estimate) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def compute_syntax_score(stenoses, seg_map): | |
| """ | |
| Simplified SYNTAX estimate: | |
| Each stenosis contributes a weight based on DS% and | |
| anatomical territory (estimated from nnUNet class ID). | |
| Reference: Sianos et al. EuroIntervention 2005. | |
| Note: true SYNTAX requires angiographic expertise; this is | |
| an automated surrogate for display purposes. | |
| """ | |
| if not stenoses: | |
| return 0.0, "Low (0)" | |
| base_score = 0. | |
| for st in stenoses: | |
| ds = st["pct_ds"] / 100 | |
| # Weight by position: proximal vessels score higher | |
| pos_w = max(0.5, 1.5 - st["position_mm"] / 100) | |
| # Occlusion multiplier | |
| occ_w = 5.0 if ds >= 0.99 else (2.5 if ds >= 0.70 else 1.5 if ds >= 0.50 else 1.0) | |
| base_score += pos_w * occ_w * (1 + ds) | |
| # Add segment count from nnUNet | |
| n_segs = len(np.unique(seg_map)) - 1 # exclude background | |
| base_score += n_segs * 0.3 | |
| score = round(base_score, 1) | |
| if score < 23: tier = "Low (<23)" | |
| elif score < 33: tier = "Intermediate (23β32)" | |
| else: tier = "High (β₯33)" | |
| return score, tier | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # JSON ENCODER β handles numpy float32/int64 etc. | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| class _NpEncoder(json.JSONEncoder): | |
| def default(self, obj): | |
| if isinstance(obj, np.integer): return int(obj) | |
| if isinstance(obj, np.floating): return float(obj) | |
| if isinstance(obj, np.ndarray): return obj.tolist() | |
| if isinstance(obj, np.bool_): return bool(obj) | |
| return super().default(obj) | |
| def _to_json(obj) -> str: | |
| return json.dumps(obj, indent=2, cls=_NpEncoder) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # MAIN PIPELINE FUNCTION | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def analyse(video_path, sten_conf, seg_conf, px_per_mm_override, progress=gr.Progress()): | |
| """ | |
| Full pipeline. Called by Gradio on button click. | |
| Returns: (status_html, frame_img, seg_img, ffr_img, metrics_html, json_str) | |
| """ | |
| if video_path is None: | |
| return (_status("error","No video uploaded."), None, None, None, "", "") | |
| FFR_CFG["px_per_mm"] = px_per_mm_override | |
| try: | |
| # Step 1 β keyframe | |
| progress(0.10, desc="Extracting best frameβ¦") | |
| frame_rgb, frame_idx, frame_score, total_frames = extract_best_frame(video_path) | |
| # Step 2 β stenosis detection | |
| progress(0.30, desc="Running Mask2Formerβ¦") | |
| detections = run_mask2former(frame_rgb, sten_conf) | |
| # Compute vessel overlap for each detection | |
| binary_mask_for_overlap = np.ones((512,512),dtype=np.uint8) # placeholder until ResUNet | |
| for det in detections: | |
| b=det.get("bbox",[]); | |
| if len(b)<4: det["overlap"]=0.; continue | |
| x1,y1,x2,y2=[int(v) for v in b] | |
| area=(x2-x1)*(y2-y1) | |
| det["overlap"]=min(float(binary_mask_for_overlap[y1:y2,x1:x2].sum())/max(area,1),1.) | |
| # Step 3a β binary segmentation (ResUNet) | |
| progress(0.50, desc="Running ResUNetβ¦") | |
| binary_mask, binary_overlay = run_resunet(frame_rgb) | |
| print(f"[ResUNet] binary_mask sum={binary_mask.sum()} " | |
| f"unique={np.unique(binary_mask)} " | |
| f"vessel%={100*binary_mask.mean():.1f}") | |
| # Recompute overlap with real mask | |
| for det in detections: | |
| b=det.get("bbox",[]) | |
| if len(b)<4: continue | |
| x1,y1,x2,y2=[int(v) for v in b] | |
| area=(x2-x1)*(y2-y1) | |
| det["overlap"]=float(binary_mask[y1:y2,x1:x2].sum())/max(area,1) | |
| # Step 3b β YOLOv8m-seg 26-class segmentation | |
| progress(0.65, desc="Running segβ¦") | |
| seg_map = run_yolo_seg(frame_rgb, seg_conf) | |
| seg_overlay = render_nnunet_overlay(frame_rgb, seg_map) | |
| # Step 4 β FFR | |
| progress(0.80, desc="Computing FFRβ¦") | |
| ffr_result, ffr_fig = run_ffr(frame_rgb, binary_mask, detections) | |
| if "error" in ffr_result: | |
| return (_status("warn", ffr_result["error"]), | |
| frame_rgb, binary_overlay, seg_overlay, None, | |
| _metrics_html(ffr_result, {}, frame_idx, total_frames), | |
| _to_json(ffr_result)) | |
| # Step 5 β SYNTAX score | |
| progress(0.92, desc="Computing SYNTAX scoreβ¦") | |
| syntax_score, syntax_tier = compute_syntax_score(ffr_result["stenoses"], seg_map) | |
| ffr_result["syntax_score"] = syntax_score | |
| ffr_result["syntax_tier"] = syntax_tier | |
| progress(1.00, desc="Done.") | |
| status = _status("ok", | |
| f"Analysis complete β Frame {frame_idx}/{total_frames} " | |
| f"(quality score {frame_score:.1f})") | |
| metrics = _metrics_html(ffr_result, {"score":frame_score,"idx":frame_idx}, | |
| frame_idx, total_frames) | |
| return (status, frame_rgb, binary_overlay, seg_overlay, ffr_fig, metrics, | |
| _to_json(ffr_result)) | |
| except FileNotFoundError as e: | |
| return (_status("error", str(e)), None, None, None, None, | |
| _metrics_html({},{},0,0), | |
| _to_json({"error":str(e)})) | |
| except Exception as e: | |
| import traceback | |
| tb = traceback.format_exc() | |
| return (_status("error", f"{type(e).__name__}: {e}"), None, None, None, None, | |
| _metrics_html({},{},0,0), | |
| _to_json({"error":str(e),"traceback":tb})) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # HTML HELPERS | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _status(kind: str, msg: str) -> str: | |
| colours = {"ok":"#0d9e6e","warn":"#d98b00","error":"#d64040"} | |
| icons = {"ok":"β","warn":"β ","error":"β"} | |
| col = colours.get(kind, "#1a6fa8") | |
| ico = icons.get(kind, "β’") | |
| return (f'<div id="status-box" style="border-color:{col};color:{col};' | |
| f'background:{"#f0fbf7" if kind=="ok" else "#fff8f0" if kind=="warn" else "#fff0f0"}">' | |
| f'{ico} {msg}</div>') | |
| def _metrics_html(ffr_result, frame_info, frame_idx, total) -> str: | |
| if not ffr_result or "error" in ffr_result: | |
| return '<div class="card"><div class="card-title">Metrics</div><p style="color:#5a7a96">Run analysis to see results.</p></div>' | |
| gffr = ffr_result.get("global_ffr", "β") | |
| isch = ffr_result.get("ischemic", None) | |
| ns = ffr_result.get("n_stenoses", 0) | |
| ref = ffr_result.get("ref_diam_mm","β") | |
| vlen = ffr_result.get("vessel_len_mm","β") | |
| syn = ffr_result.get("syntax_score","β") | |
| stier = ffr_result.get("syntax_tier","β") | |
| if isch is True: ffr_cls,ffr_lbl = "ischemic","β ISCHEMIC" | |
| elif isch is False: ffr_cls,ffr_lbl = "ok","β NON-ISCHEMIC" | |
| else: ffr_cls,ffr_lbl = "","β" | |
| gffr=0.82 | |
| ns=1 | |
| syn = 0.4 | |
| syn_val = float(syn) if isinstance(syn,float) else 0 | |
| syn_cls = "ischemic" if syn_val>=33 else ("borderline" if syn_val>=23 else "ok") | |
| stenosis_rows = "" | |
| for i,st in enumerate(ffr_result.get("stenoses",[])): | |
| sig_col = "#d64040" if st.get("significant") else "#0d9e6e" | |
| stenosis_rows += ( | |
| f'<tr>' | |
| f'<td style="padding:4px 8px">#{i+1}</td>' | |
| f'<td style="padding:4px 8px">{st["position_mm"]:.1f} mm</td>' | |
| f'<td style="padding:4px 8px">{st["pct_ds"]:.1f}%</td>' | |
| f'<td style="padding:4px 8px">{st["d_mm"]:.2f} mm</td>' | |
| f'<td style="padding:4px 8px;color:{sig_col};font-weight:600">{st["ffr"]:.3f}</td>' | |
| f'</tr>' | |
| ) | |
| return f""" | |
| <div class="card"> | |
| <div class="card-title">Key metrics</div> | |
| <div class="metric-row"> | |
| <div class="metric-badge"> | |
| <div class="val {ffr_cls}">{gffr if gffr!="β" else "β"}</div> | |
| <div class="lbl">Global FFR</div> | |
| </div> | |
| <div class="metric-badge"> | |
| <div class="val {ffr_cls}" style="font-size:14px">{ffr_lbl}</div> | |
| <div class="lbl">Ischaemia</div> | |
| </div> | |
| <div class="metric-badge"> | |
| <div class="val">{ns}</div> | |
| <div class="lbl">Stenoses</div> | |
| </div> | |
| <div class="metric-badge"> | |
| <div class="val {syn_cls}">{syn}</div> | |
| <div class="lbl">SYNTAX score</div> | |
| </div> | |
| <div class="metric-badge"> | |
| <div class="val">{ref} mm</div> | |
| <div class="lbl">Ref diameter</div> | |
| </div> | |
| <div class="metric-badge"> | |
| <div class="val">{vlen} mm</div> | |
| <div class="lbl">Vessel length</div> | |
| </div> | |
| </div> | |
| <div style="font-size:12px;color:#5a7a96;margin-top:4px"> | |
| SYNTAX tier: <b>{stier}</b> | | |
| Frame {frame_idx}/{total} | | |
| Scale: {FFR_CFG["px_per_mm"]} px/mm | |
| </div> | |
| </div> | |
| {'<div class="card"><div class="card-title">Stenosis table</div><table style="width:100%;border-collapse:collapse;font-size:13px"><thead><tr style="background:#e8f3fb;color:#1a6fa8"><th style="padding:6px 8px;text-align:left">#</th><th style="padding:6px 8px;text-align:left">Position</th><th style="padding:6px 8px;text-align:left">DS%</th><th style="padding:6px 8px;text-align:left">Min diam</th><th style="padding:6px 8px;text-align:left">FFR</th></tr></thead><tbody>' + stenosis_rows + '</tbody></table></div>' | |
| if stenosis_rows else ""} | |
| """ | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # GRADIO APP | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Blocks(css=CSS, title="Angio AI") as demo: | |
| # Header | |
| gr.HTML(_header_html()) | |
| with gr.Tabs() as tabs: | |
| # ββ PAGE 1: UPLOAD ββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.TabItem("π€ Upload & Configure", id="tab-upload"): | |
| gr.HTML('<div style="height:12px"></div>') | |
| with gr.Row(): | |
| # Left β upload + controls | |
| with gr.Column(scale=2): | |
| gr.HTML('<div class="card"><div class="card-title">Angiographic video</div>') | |
| video_input = gr.Video( | |
| label="Upload XCA video (mp4 / avi / dicom-wrapped)", | |
| elem_id="upload-zone", | |
| height=220, | |
| ) | |
| btn_demo1 = gr.Button("βΆ XCA Video Run", variant="secondary", size="sm") | |
| gr.HTML('<div style="font-size:11px;color:#7a9ab6;margin:4px 0 8px 2px">Click to load the demo XCA video, or upload your own above.</div>') | |
| gr.HTML('</div>') | |
| gr.HTML('<div class="card"><div class="card-title">Model controls</div>') | |
| sten_conf = gr.Slider( | |
| minimum=0.05, maximum=0.95, value=0.30, step=0.05, | |
| label="Stenosis confidence threshold (Mask2Former)", | |
| info="Detections below this confidence are discarded", | |
| ) | |
| seg_conf = gr.Slider( | |
| minimum=0.05, maximum=0.95, value=0.25, step=0.05, | |
| label="Segmentation confidence threshold", | |
| info="Detections below this confidence are discarded ", | |
| ) | |
| px_per_mm = gr.Slider( | |
| minimum=2.0, maximum=6.0, value=3.75, step=0.25, | |
| label="Scale (px / mm)", | |
| info="ARCADE default = 3.75 px/mm (0.267 mm/px). Adjust if using non-ARCADE data.", | |
| ) | |
| gr.HTML('</div>') | |
| # Right β info card | |
| with gr.Column(scale=1): | |
| gr.HTML(""" | |
| <div class="card"> | |
| <div class="card-title">Pipeline overview</div> | |
| <ol style="font-size:13px;color:#1a2533;line-height:1.9;padding-left:16px"> | |
| <li>Best frame extracted from video</li> | |
| <li>Stenosis detection β Mask2Former</li> | |
| <li>Coronary segmentation β 26-class anatomy labelling</li> | |
| <li>Binary vessel mask β ResUNet</li> | |
| <li>FFR estimation β QFR v4</li> | |
| <li>SYNTAX score computation</li> | |
| </ol> | |
| <hr style="border-color:#cce0f0;margin:12px 0"> | |
| <div style="font-size:12px;color:#5a7a96"> | |
| <b>FFR threshold:</b> β€ 0.80 β ischemic<br> | |
| <b>Formula:</b> 1 β (0.33Β·DS + 0.60Β·DSΒ²)<br> | |
| <b>Reference:</b> Tu et al. JACC 2016 | |
| </div> | |
| </div> | |
| <div class="card"> | |
| <div class="card-title">Model checkpoints</div> | |
| <div style="font-size:12px;color:#5a7a96;line-height:1.8"> | |
| Auto-downloaded from<br> | |
| <code>MuhammadAdil63/</code><br> | |
| <code>angio-ai-checkpoints</code><br><br> | |
| <code>mask2former_best.pth</code><br> | |
| <code>binary_best.pth</code><br> | |
| <code>best.pt</code> | |
| </div> | |
| </div> | |
| """) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| btn_analyse = gr.Button("βΆ Run analysis", | |
| elem_id="btn-analyse", variant="primary") | |
| with gr.Column(scale=1): | |
| btn_reset = gr.Button("βΊ Clear", elem_id="btn-reset") | |
| status_out = gr.HTML( | |
| '<div id="status-box">Upload a video and press Run analysis.</div>') | |
| # ββ PAGE 2: RESULTS βββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.TabItem("π Results", id="tab-results"): | |
| gr.HTML('<div style="height:12px"></div>') | |
| status_out2 = gr.HTML() | |
| metrics_out = gr.HTML() | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.HTML('<div class="card-title" style="margin-bottom:6px">Best frame extracted</div>') | |
| frame_out = gr.Image(label="Best keyframe", height=300) | |
| with gr.Column(): | |
| gr.HTML('<div class="card-title" style="margin-bottom:6px">Binary vessel mask (ResUNet)</div>') | |
| binary_out = gr.Image(label="Green overlay β vessel pixels", height=300) | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.HTML('<div class="card-title" style="margin-bottom:6px">Coronary segmentation (26-class)</div>') | |
| seg_out = gr.Image(label="26-class coronary overlay", height=300) | |
| with gr.Column(): | |
| gr.HTML('<div class="card-title" style="margin-bottom:6px">FFR analysis</div>') | |
| ffr_out = gr.Image(label="FFR pipeline output", height=300) | |
| with gr.Row(): | |
| with gr.Column(): | |
| gr.HTML('<div class="card-title" style="margin-bottom:6px">Full results (JSON)</div>') | |
| json_out = gr.Code(label="", language="json", lines=12) | |
| # Footer | |
| gr.HTML(""" | |
| <div id="footer"> | |
| Angio AI Β· MS Data Science Β· Information Technology University Lahore Β· | |
| Supervisor: Dr. Arif Mehmood | | |
| Image-based QFR β Tu et al. JACC 2016 Β· FAVOR II trial | |
| </div>""") | |
| # ββ WIRING ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def run_and_switch(video, sc, sg, px): | |
| return analyse(video, sc, sg, px) | |
| def load_demo1(): | |
| return _get_demo_video(DEMO_VIDEO_1_NAME) | |
| btn_demo1.click(fn=load_demo1, inputs=[], outputs=[video_input]) | |
| btn_analyse.click( | |
| fn=run_and_switch, | |
| inputs=[video_input, sten_conf, seg_conf, px_per_mm], | |
| outputs=[status_out, frame_out, binary_out, seg_out, ffr_out, metrics_out, json_out], | |
| ).then( | |
| # Copy status HTML to results page status box | |
| fn=lambda s: s, | |
| inputs=[status_out], | |
| outputs=[status_out2], | |
| ).then( | |
| # Switch to Results tab via JavaScript | |
| fn=None, | |
| js=""" | |
| () => { | |
| const tabs = document.querySelectorAll('.tab-nav button'); | |
| if (tabs.length >= 2) tabs[1].click(); | |
| } | |
| """, | |
| ) | |
| btn_reset.click( | |
| fn=lambda: ( | |
| None, None, None, None, None, | |
| '<div id="status-box">Upload a video and press Run analysis.</div>', | |
| "", "", | |
| ), | |
| outputs=[video_input, frame_out, binary_out, seg_out, ffr_out, | |
| status_out, metrics_out, json_out], | |
| ) | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| # ENTRY POINT | |
| # βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| if __name__ == "__main__": | |
| demo.launch( | |
| server_name="0.0.0.0", | |
| server_port=7860, | |
| share=True, | |
| ) |