#!/usr/bin/env python3 """Shared utilities for lumen class discovery and fine-tuning scripts.""" from __future__ import annotations import json import sys from dataclasses import dataclass from pathlib import Path from typing import Iterable import numpy as np from matplotlib.path import Path as MplPath # common.py lives at scripts/finetune/shared/common.py, so repo root is parents[3]. REPO_ROOT = Path(__file__).resolve().parents[3] if str(REPO_ROOT) not in sys.path: sys.path.insert(0, str(REPO_ROOT)) from deepivus.io.dicom import read_dicom from deepivus.processing.preprocessing import apply_center_black_circle @dataclass(frozen=True) class LumenAnnotation: """Single manually-annotated lumen polygon for one frame.""" bank_path: Path group: str dicom_path: Path frame_idx: int bifurcation: bool lumen_x: np.ndarray lumen_y: np.ndarray @property def sample_id(self) -> str: rel_bank = self.bank_path.resolve().relative_to(REPO_ROOT) return f"{rel_bank.as_posix()}::{self.frame_idx}" @dataclass(frozen=True) class BifurcationAnnotation: """Single frame-level bifurcation annotation.""" bank_path: Path group: str dicom_path: Path frame_idx: int bifurcation: bool @property def sample_id(self) -> str: rel_bank = self.bank_path.resolve().relative_to(REPO_ROOT) return f"{rel_bank.as_posix()}::{self.frame_idx}" def _read_jsonl(path: Path) -> list[dict]: rows: list[dict] = [] with path.open("r", encoding="utf-8") as fp: for raw in fp: line = raw.strip() if not line: continue rows.append(json.loads(line)) return rows def _resolve_dicom_path(bank_path: Path, meta: dict) -> Path | None: group = str(meta.get("group", bank_path.parent.name)) dicom_raw = meta.get("dicom_path") if dicom_raw: dicom_path = Path(str(dicom_raw)) if not dicom_path.is_absolute(): dicom_path = (REPO_ROOT / dicom_path).resolve() if dicom_path.exists(): return dicom_path fallback = REPO_ROOT / "data" / group / f"{bank_path.stem}.dcm" if fallback.exists(): return fallback.resolve() return None def load_lumen_annotations(frame_bank_root: Path) -> list[LumenAnnotation]: """Load all frame-bank lumen annotations with valid polygons.""" annotations: list[LumenAnnotation] = [] bank_files = sorted(frame_bank_root.glob("*/*.jsonl")) for bank_path in bank_files: rows = _read_jsonl(bank_path) if not rows: continue meta = rows[0] if meta.get("record_type") != "meta": continue group = str(meta.get("group", bank_path.parent.name)) dicom_path = _resolve_dicom_path(bank_path, meta) if dicom_path is None: continue for rec in rows: if rec.get("record_type") != "frame": continue lumen = rec.get("lumen", {}) xs = np.asarray(lumen.get("x", []), dtype=np.float32) ys = np.asarray(lumen.get("y", []), dtype=np.float32) if xs.size < 3 or ys.size < 3 or xs.size != ys.size: continue if not np.all(np.isfinite(xs)) or not np.all(np.isfinite(ys)): continue annotations.append( LumenAnnotation( bank_path=bank_path, group=group, dicom_path=dicom_path, frame_idx=int(rec["frame"]), bifurcation=bool(rec.get("bifurcation", False)), lumen_x=xs, lumen_y=ys, ) ) return annotations def load_bifurcation_annotations(frame_bank_root: Path) -> list[BifurcationAnnotation]: """Load all frame-bank bifurcation labels, independent of lumen polygon presence.""" annotations: list[BifurcationAnnotation] = [] bank_files = sorted(frame_bank_root.glob("*/*.jsonl")) for bank_path in bank_files: rows = _read_jsonl(bank_path) if not rows: continue meta = rows[0] if meta.get("record_type") != "meta": continue group = str(meta.get("group", bank_path.parent.name)) dicom_path = _resolve_dicom_path(bank_path, meta) if dicom_path is None: continue for rec in rows: if rec.get("record_type") != "frame": continue bif = rec.get("bifurcation") if bif is None: continue annotations.append( BifurcationAnnotation( bank_path=bank_path, group=group, dicom_path=dicom_path, frame_idx=int(rec["frame"]), bifurcation=bool(bif), ) ) return annotations def polygon_to_mask(x_coords: np.ndarray, y_coords: np.ndarray, image_shape: tuple[int, int]) -> np.ndarray: """Rasterize a polygon in image coordinates to a binary mask.""" if x_coords.size < 3 or y_coords.size < 3: return np.zeros(image_shape, dtype=bool) vertices = np.column_stack((x_coords.astype(np.float32), y_coords.astype(np.float32))) polygon = MplPath(vertices, closed=True) h, w = image_shape yy, xx = np.mgrid[0:h, 0:w] points = np.column_stack((xx.ravel(), yy.ravel())) mask = polygon.contains_points(points, radius=0.5).reshape((h, w)) return mask def group_by_dicom(annotations: Iterable[LumenAnnotation]) -> dict[Path, list[LumenAnnotation]]: grouped: dict[Path, list[LumenAnnotation]] = {} for ann in annotations: grouped.setdefault(ann.dicom_path, []).append(ann) return grouped def load_preprocessed_stack(dicom_path: Path, diameter: int) -> np.ndarray: """Load one DICOM stack and apply the same preprocessing as pipeline inference.""" _, images = read_dicom(str(dicom_path)) return apply_center_black_circle(images, diameter=diameter) def build_images_and_masks( annotations: list[LumenAnnotation], diameter: int, ) -> tuple[np.ndarray, np.ndarray, list[LumenAnnotation]]: """Materialize image and mask arrays for a list of annotations.""" images_out: list[np.ndarray] = [] masks_out: list[np.ndarray] = [] kept: list[LumenAnnotation] = [] grouped = group_by_dicom(annotations) for dicom_path, ann_list in grouped.items(): stack = load_preprocessed_stack(dicom_path, diameter=diameter) h, w = int(stack.shape[1]), int(stack.shape[2]) for ann in ann_list: if ann.frame_idx < 0 or ann.frame_idx >= int(stack.shape[0]): continue mask = polygon_to_mask(ann.lumen_x, ann.lumen_y, (h, w)) if not np.any(mask): continue images_out.append(stack[ann.frame_idx]) masks_out.append(mask.astype(np.float32)) kept.append(ann) if not images_out: raise RuntimeError("No usable annotations found after loading DICOM frames and rasterizing masks.") return np.stack(images_out, axis=0), np.stack(masks_out, axis=0), kept def stratified_frame_split( annotations: list[LumenAnnotation], train_fraction: float, val_fraction: float, test_fraction: float, seed: int, ) -> dict[str, list[int]]: """Frame-level stratified split by bifurcation label.""" total = train_fraction + val_fraction + test_fraction if total <= 0: raise ValueError("Split fractions must sum to a positive value.") train_fraction /= total val_fraction /= total test_fraction /= total labels = np.asarray([1 if ann.bifurcation else 0 for ann in annotations], dtype=np.int32) indices = np.arange(len(annotations), dtype=np.int64) rng = np.random.default_rng(seed) train_ids: list[int] = [] val_ids: list[int] = [] test_ids: list[int] = [] for label in (0, 1): cls_idx = indices[labels == label] rng.shuffle(cls_idx) n = len(cls_idx) if n == 0: continue n_train = int(round(n * train_fraction)) n_val = int(round(n * val_fraction)) n_test = n - n_train - n_val # Keep at least one sample per split when class has enough samples. if n >= 3: n_train = max(1, n_train) n_val = max(1, n_val) n_test = max(1, n_test) overflow = n_train + n_val + n_test - n while overflow > 0: if n_train >= n_val and n_train >= n_test and n_train > 1: n_train -= 1 elif n_val >= n_test and n_val > 1: n_val -= 1 elif n_test > 1: n_test -= 1 overflow -= 1 train_ids.extend(cls_idx[:n_train].tolist()) val_ids.extend(cls_idx[n_train : n_train + n_val].tolist()) test_ids.extend(cls_idx[n_train + n_val : n_train + n_val + n_test].tolist()) rng.shuffle(train_ids) rng.shuffle(val_ids) rng.shuffle(test_ids) return {"train": train_ids, "val": val_ids, "test": test_ids} def split_summary(annotations: list[LumenAnnotation], split: dict[str, list[int]]) -> dict[str, dict[str, int]]: out: dict[str, dict[str, int]] = {} for part, ids in split.items(): bif = sum(1 for i in ids if annotations[i].bifurcation) non_bif = len(ids) - bif out[part] = { "count": len(ids), "bifurcation_true": bif, "bifurcation_false": non_bif, } return out def save_split_json(path: Path, annotations: list[LumenAnnotation], split: dict[str, list[int]], seed: int) -> None: path.parent.mkdir(parents=True, exist_ok=True) payload = { "seed": int(seed), "summary": split_summary(annotations, split), "splits": { part: [annotations[i].sample_id for i in ids] for part, ids in split.items() }, } with path.open("w", encoding="utf-8") as fp: json.dump(payload, fp, indent=2) def load_split_ids(path: Path) -> dict[str, set[str]]: with path.open("r", encoding="utf-8") as fp: payload = json.load(fp) splits = payload.get("splits", {}) out = {} for part in ("train", "val", "test"): out[part] = set(str(v) for v in splits.get(part, [])) return out