Instructions to use Aditya2162/ivus-segmentation with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Keras
How to use Aditya2162/ivus-segmentation with Keras:
```python
# Available backend options are: "jax", "torch", "tensorflow".
import os
os.environ["KERAS_BACKEND"] = "jax"
import keras
model = keras.saving.load_model("hf://Aditya2162/ivus-segmentation")
```
- Notebooks
- Google Colab
- Kaggle
| #!/usr/bin/env python3 | |
| """Shared utilities for lumen class discovery and fine-tuning scripts.""" | |
| from __future__ import annotations | |
| import json | |
| import sys | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Iterable | |
| import numpy as np | |
| from matplotlib.path import Path as MplPath | |
# This file is expected at scripts/finetune/shared/common.py, so the repository
# root sits three levels above the directory that contains it.
REPO_ROOT = Path(__file__).resolve().parents[3]
# Prepend the repository root to the import path (once) so the ``deepivus``
# imports that follow resolve to this checkout.
if sys.path.count(str(REPO_ROOT)) == 0:
    sys.path.insert(0, str(REPO_ROOT))
| from deepivus.io.dicom import read_dicom | |
| from deepivus.processing.preprocessing import apply_center_black_circle | |
@dataclass(eq=False)
class LumenAnnotation:
    """Single manually-annotated lumen polygon for one frame.

    ``@dataclass`` supplies the keyword constructor that
    ``load_lumen_annotations`` relies on. Value equality is disabled
    (``eq=False``): the generated ``__eq__`` would compare the ndarray fields,
    which raises for multi-element arrays. Instances therefore compare by
    identity and remain hashable.
    """

    # Frame-bank JSONL file this annotation was read from.
    bank_path: Path
    # Group name: the bank's ``meta["group"]`` or, if absent, its parent directory name.
    group: str
    # DICOM file that contains the annotated frame.
    dicom_path: Path
    # Index of the annotated frame along axis 0 of the DICOM frame stack.
    frame_idx: int
    # Frame-level bifurcation label.
    bifurcation: bool
    # Polygon vertex x (column) coordinates in image pixels; float32 as loaded.
    lumen_x: np.ndarray
    # Polygon vertex y (row) coordinates in image pixels; float32 as loaded.
    lumen_y: np.ndarray

    @property
    def sample_id(self) -> str:
        """Return ``"<bank path relative to REPO_ROOT>::<frame_idx>"``.

        This is the identifier written into split JSON files, so it must be
        accessed as an attribute (``ann.sample_id``).

        Raises:
            ValueError: if ``bank_path`` does not resolve to a location under
                ``REPO_ROOT``.
        """
        rel_bank = self.bank_path.resolve().relative_to(REPO_ROOT)
        return f"{rel_bank.as_posix()}::{self.frame_idx}"
@dataclass
class BifurcationAnnotation:
    """Single frame-level bifurcation annotation.

    ``@dataclass`` supplies the keyword constructor that
    ``load_bifurcation_annotations`` relies on, plus ``__repr__`` and
    field-wise ``__eq__``.
    """

    # Frame-bank JSONL file this annotation was read from.
    bank_path: Path
    # Group name: the bank's ``meta["group"]`` or, if absent, its parent directory name.
    group: str
    # DICOM file that contains the annotated frame.
    dicom_path: Path
    # Index of the annotated frame within its DICOM.
    frame_idx: int
    # Frame-level bifurcation label.
    bifurcation: bool

    @property
    def sample_id(self) -> str:
        """Return ``"<bank path relative to REPO_ROOT>::<frame_idx>"``.

        Raises:
            ValueError: if ``bank_path`` does not resolve to a location under
                ``REPO_ROOT``.
        """
        rel_bank = self.bank_path.resolve().relative_to(REPO_ROOT)
        return f"{rel_bank.as_posix()}::{self.frame_idx}"
| def _read_jsonl(path: Path) -> list[dict]: | |
| rows: list[dict] = [] | |
| with path.open("r", encoding="utf-8") as fp: | |
| for raw in fp: | |
| line = raw.strip() | |
| if not line: | |
| continue | |
| rows.append(json.loads(line)) | |
| return rows | |
| def _resolve_dicom_path(bank_path: Path, meta: dict) -> Path | None: | |
| group = str(meta.get("group", bank_path.parent.name)) | |
| dicom_raw = meta.get("dicom_path") | |
| if dicom_raw: | |
| dicom_path = Path(str(dicom_raw)) | |
| if not dicom_path.is_absolute(): | |
| dicom_path = (REPO_ROOT / dicom_path).resolve() | |
| if dicom_path.exists(): | |
| return dicom_path | |
| fallback = REPO_ROOT / "data" / group / f"{bank_path.stem}.dcm" | |
| if fallback.exists(): | |
| return fallback.resolve() | |
| return None | |
def _iter_frame_records(frame_bank_root: Path) -> Iterable[tuple[Path, str, Path, dict]]:
    """Yield ``(bank_path, group, dicom_path, record)`` for every frame record.

    Bank files are matched with ``*/*.jsonl`` under *frame_bank_root* and
    visited in sorted order. A bank is skipped when it is empty, when its first
    row is not a ``"meta"`` record, or when ``_resolve_dicom_path`` cannot find
    its DICOM. Only rows whose ``record_type`` is ``"frame"`` are yielded.
    """
    for bank_path in sorted(frame_bank_root.glob("*/*.jsonl")):
        rows = _read_jsonl(bank_path)
        if not rows or rows[0].get("record_type") != "meta":
            continue
        meta = rows[0]
        group = str(meta.get("group", bank_path.parent.name))
        dicom_path = _resolve_dicom_path(bank_path, meta)
        if dicom_path is None:
            continue
        for rec in rows:
            if rec.get("record_type") == "frame":
                yield bank_path, group, dicom_path, rec


def load_lumen_annotations(frame_bank_root: Path) -> list[LumenAnnotation]:
    """Load all frame-bank lumen annotations with valid polygons.

    A frame record is kept only when its ``"lumen"`` entry is an object whose
    ``"x"`` and ``"y"`` values form 1-D, equal-length coordinate lists with at
    least three vertices and only finite values. Banks are filtered as in
    ``_iter_frame_records``.

    Raises:
        KeyError: if a kept frame record has no ``"frame"`` key.
    """
    annotations: list[LumenAnnotation] = []
    for bank_path, group, dicom_path, rec in _iter_frame_records(frame_bank_root):
        lumen = rec.get("lumen")
        # "lumen" may be missing or not an object (e.g. JSON null); treat that
        # as "no polygon" instead of crashing on ``None.get``.
        if not isinstance(lumen, dict):
            continue
        xs = np.asarray(lumen.get("x", []), dtype=np.float32)
        ys = np.asarray(lumen.get("y", []), dtype=np.float32)
        # Nested lists would pass a pure size check and only fail later during
        # rasterization, so require flat coordinate arrays here.
        if xs.ndim != 1 or ys.ndim != 1 or xs.size < 3 or xs.size != ys.size:
            continue
        if not (np.isfinite(xs).all() and np.isfinite(ys).all()):
            continue
        annotations.append(
            LumenAnnotation(
                bank_path=bank_path,
                group=group,
                dicom_path=dicom_path,
                frame_idx=int(rec["frame"]),
                bifurcation=bool(rec.get("bifurcation", False)),
                lumen_x=xs,
                lumen_y=ys,
            )
        )
    return annotations


def load_bifurcation_annotations(frame_bank_root: Path) -> list[BifurcationAnnotation]:
    """Load all frame-bank bifurcation labels, independent of lumen polygon presence.

    Frame records whose ``"bifurcation"`` value is missing or null are skipped;
    any other value is converted with ``bool``. Banks are filtered as in
    ``_iter_frame_records``.

    Raises:
        KeyError: if a kept frame record has no ``"frame"`` key.
    """
    return [
        BifurcationAnnotation(
            bank_path=bank_path,
            group=group,
            dicom_path=dicom_path,
            frame_idx=int(rec["frame"]),
            bifurcation=bool(rec["bifurcation"]),
        )
        for bank_path, group, dicom_path, rec in _iter_frame_records(frame_bank_root)
        if rec.get("bifurcation") is not None
    ]
def polygon_to_mask(x_coords: np.ndarray, y_coords: np.ndarray, image_shape: tuple[int, int]) -> np.ndarray:
    """Rasterize a polygon in image coordinates to a binary mask.

    Args:
        x_coords: Vertex x (column) coordinates.
        y_coords: Vertex y (row) coordinates; same length as *x_coords*.
        image_shape: ``(height, width)`` of the output mask.

    Returns:
        Boolean array of shape *image_shape*. Pixel ``[row, col]`` is True when
        the point ``(col, row)`` lies inside the closed polygon, using
        matplotlib's ``contains_points`` with ``radius=0.5``. The mask is all
        False when either coordinate array has fewer than three entries.
    """
    if x_coords.size < 3 or y_coords.size < 3:
        return np.zeros(image_shape, dtype=bool)
    vertices = np.column_stack((x_coords.astype(np.float32), y_coords.astype(np.float32)))
    # With ``closed=True`` matplotlib assigns the CLOSEPOLY code to the LAST
    # vertex and ignores its coordinates. An open outline would therefore
    # silently lose its final point (a triangle would collapse to a segment).
    # Repeat the first vertex unless the outline is already explicitly closed.
    if not np.array_equal(vertices[0], vertices[-1]):
        vertices = np.vstack((vertices, vertices[:1]))
    polygon = MplPath(vertices, closed=True)
    h, w = image_shape
    yy, xx = np.mgrid[0:h, 0:w]
    # Query points are (x, y) == (column, row) in row-major pixel order.
    points = np.column_stack((xx.ravel(), yy.ravel()))
    return polygon.contains_points(points, radius=0.5).reshape((h, w))
def group_by_dicom(annotations: Iterable[LumenAnnotation]) -> dict[Path, list[LumenAnnotation]]:
    """Bucket annotations by their ``dicom_path``.

    Keys appear in order of first occurrence, and each bucket preserves the
    input order of its annotations.
    """
    buckets: dict[Path, list[LumenAnnotation]] = {}
    for item in annotations:
        key = item.dicom_path
        if key in buckets:
            buckets[key].append(item)
        else:
            buckets[key] = [item]
    return buckets
def load_preprocessed_stack(dicom_path: Path, diameter: int) -> np.ndarray:
    """Return the frames of *dicom_path* after inference-time preprocessing.

    ``read_dicom`` is expected to return a 2-item result whose second item is
    the frame stack. That stack is passed to ``apply_center_black_circle``
    with the given *diameter*, matching the preprocessing used by pipeline
    inference.
    """
    _ignored, frames = read_dicom(str(dicom_path))
    preprocessed = apply_center_black_circle(frames, diameter=diameter)
    return preprocessed
def build_images_and_masks(
    annotations: list[LumenAnnotation],
    diameter: int,
) -> tuple[np.ndarray, np.ndarray, list[LumenAnnotation]]:
    """Materialize image and mask arrays for a list of annotations.

    Each referenced DICOM is loaded once with ``load_preprocessed_stack``, and
    every annotation on it is rasterized with ``polygon_to_mask``. An
    annotation is skipped when its ``frame_idx`` is outside the loaded stack
    or its polygon produces an empty mask.

    Args:
        annotations: Annotations to materialize.
        diameter: Forwarded to ``load_preprocessed_stack``.

    Returns:
        ``(images, masks, kept)``: ``images[i]`` is the preprocessed frame,
        ``masks[i]`` its float32 0/1 lumen mask, and ``kept[i]`` the annotation
        that produced them. Entries follow ``group_by_dicom`` order (grouped by
        DICOM file), not the order of *annotations*.

    Raises:
        RuntimeError: if no annotation survives the filtering.
    """
    images_out: list[np.ndarray] = []
    masks_out: list[np.ndarray] = []
    kept: list[LumenAnnotation] = []
    for dicom_path, ann_list in group_by_dicom(annotations).items():
        stack = load_preprocessed_stack(dicom_path, diameter=diameter)
        n_frames = int(stack.shape[0])
        frame_shape = (int(stack.shape[1]), int(stack.shape[2]))
        for ann in ann_list:
            if not 0 <= ann.frame_idx < n_frames:
                continue
            mask = polygon_to_mask(ann.lumen_x, ann.lumen_y, frame_shape)
            if not np.any(mask):
                continue
            # Copy the frame: basic indexing of an ndarray returns a view whose
            # base is the whole stack, so keeping views would hold every loaded
            # DICOM stack in memory until the final ``np.stack``.
            images_out.append(np.array(stack[ann.frame_idx], copy=True))
            masks_out.append(mask.astype(np.float32))
            kept.append(ann)
        # Release this stack before the next DICOM is loaded.
        del stack
    if not images_out:
        raise RuntimeError("No usable annotations found after loading DICOM frames and rasterizing masks.")
    return np.stack(images_out, axis=0), np.stack(masks_out, axis=0), kept
def stratified_frame_split(
    annotations: list[LumenAnnotation],
    train_fraction: float,
    val_fraction: float,
    test_fraction: float,
    seed: int,
) -> dict[str, list[int]]:
    """Frame-level stratified split by bifurcation label.

    Frames are partitioned by ``bifurcation`` (False, then True). Each class is
    shuffled and cut into train/val/test portions proportional to the
    normalized fractions. Any class with at least three frames gets at least
    one frame in every split.

    Args:
        annotations: Frames to split; only ``ann.bifurcation`` is read.
        train_fraction: Relative train size.
        val_fraction: Relative validation size.
        test_fraction: Relative test size. The three fractions are normalized
            by their sum, so they need not add up to 1.
        seed: Seed for ``np.random.default_rng``; identical inputs and seed
            give an identical split.

    Returns:
        ``{"train": [...], "val": [...], "test": [...]}`` holding indices into
        *annotations*; every index appears in exactly one list.

    Raises:
        ValueError: if any fraction is negative or non-finite, or the
            fractions do not sum to a positive value.
    """
    fractions = (train_fraction, val_fraction, test_fraction)
    # A negative fraction yields a negative slice length below, which can
    # place one index in two splits (e.g. both train and test).
    if any(not np.isfinite(f) or f < 0 for f in fractions):
        raise ValueError("Split fractions must be finite and non-negative.")
    total = train_fraction + val_fraction + test_fraction
    if total <= 0:
        raise ValueError("Split fractions must sum to a positive value.")
    train_fraction /= total
    val_fraction /= total
    # The test split receives whatever remains, so its normalized fraction is
    # not needed explicitly.
    labels = np.asarray([1 if ann.bifurcation else 0 for ann in annotations], dtype=np.int32)
    indices = np.arange(len(annotations), dtype=np.int64)
    rng = np.random.default_rng(seed)
    train_ids: list[int] = []
    val_ids: list[int] = []
    test_ids: list[int] = []
    for label in (0, 1):
        cls_idx = indices[labels == label]
        rng.shuffle(cls_idx)
        n = len(cls_idx)
        if n == 0:
            continue
        n_train = int(round(n * train_fraction))
        n_val = int(round(n * val_fraction))
        n_test = n - n_train - n_val
        # Keep at least one sample per split when class has enough samples.
        if n >= 3:
            n_train = max(1, n_train)
            n_val = max(1, n_val)
            n_test = max(1, n_test)
            # Clamping can push the total above n. Take the excess back from
            # the largest split (ties: train, then val, then test) without
            # dropping any split below one.
            overflow = n_train + n_val + n_test - n
            while overflow > 0:
                if n_train >= n_val and n_train >= n_test and n_train > 1:
                    n_train -= 1
                elif n_val >= n_test and n_val > 1:
                    n_val -= 1
                elif n_test > 1:
                    n_test -= 1
                overflow -= 1
        train_ids.extend(cls_idx[:n_train].tolist())
        val_ids.extend(cls_idx[n_train : n_train + n_val].tolist())
        test_ids.extend(cls_idx[n_train + n_val : n_train + n_val + n_test].tolist())
    rng.shuffle(train_ids)
    rng.shuffle(val_ids)
    rng.shuffle(test_ids)
    return {"train": train_ids, "val": val_ids, "test": test_ids}
def split_summary(annotations: list[LumenAnnotation], split: dict[str, list[int]]) -> dict[str, dict[str, int]]:
    """Count frames and bifurcation labels in each split part.

    For every ``part -> indices`` entry of *split* (in its iteration order),
    report ``"count"``, ``"bifurcation_true"`` and ``"bifurcation_false"``,
    where "true" means ``annotations[i].bifurcation`` is truthy.
    """
    summary: dict[str, dict[str, int]] = {}
    for part_name, index_list in split.items():
        total = len(index_list)
        positives = len([idx for idx in index_list if annotations[idx].bifurcation])
        summary[part_name] = {
            "count": total,
            "bifurcation_true": positives,
            "bifurcation_false": total - positives,
        }
    return summary
def save_split_json(path: Path, annotations: list[LumenAnnotation], split: dict[str, list[int]], seed: int) -> None:
    """Write *split* to *path* as indented JSON, storing sample ids instead of indices.

    The document holds ``"seed"``, the ``split_summary`` counts under
    ``"summary"``, and under ``"splits"`` each part's indices mapped to
    ``annotations[i].sample_id``. Missing parent directories are created first.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    summary = split_summary(annotations, split)
    sample_ids = {
        part_name: [annotations[idx].sample_id for idx in index_list]
        for part_name, index_list in split.items()
    }
    document = {"seed": int(seed), "summary": summary, "splits": sample_ids}
    with path.open("w", encoding="utf-8") as handle:
        json.dump(document, handle, indent=2)
def load_split_ids(path: Path) -> dict[str, set[str]]:
    """Read a split JSON file and return its sample ids per part.

    Always returns the keys ``"train"``, ``"val"`` and ``"test"`` (in that
    order). A part missing from the file's ``"splits"`` maps to an empty set,
    and every stored id is converted with ``str``.
    """
    with path.open("r", encoding="utf-8") as handle:
        document = json.load(handle)
    stored = document.get("splits", {})
    return {
        part_name: {str(sample) for sample in stored.get(part_name, [])}
        for part_name in ("train", "val", "test")
    }