Aditya2162's picture
Upload folder using huggingface_hub
1d197a4 verified
#!/usr/bin/env python3
"""Shared utilities for lumen class discovery and fine-tuning scripts."""
from __future__ import annotations
import json
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable
import numpy as np
from matplotlib.path import Path as MplPath
# common.py lives at scripts/finetune/shared/common.py, so repo root is parents[3].
REPO_ROOT = Path(__file__).resolve().parents[3]
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from deepivus.io.dicom import read_dicom
from deepivus.processing.preprocessing import apply_center_black_circle
@dataclass(frozen=True, eq=False)
class LumenAnnotation:
    """Single manually-annotated lumen polygon for one frame.

    ``eq=False`` disables the dataclass-generated ``__eq__``/``__hash__``. Those
    would compare the ``np.ndarray`` polygon fields with ``==``, whose
    element-wise result raises ``ValueError`` when coerced to ``bool``, and
    would hash the arrays, which raises ``TypeError``. Equivalent value
    semantics are implemented explicitly below instead.
    """

    bank_path: Path  # frame-bank .jsonl file this annotation was read from
    group: str
    dicom_path: Path
    frame_idx: int  # index into the frame axis of the DICOM stack
    bifurcation: bool
    lumen_x: np.ndarray  # polygon vertex x-coordinates (image columns)
    lumen_y: np.ndarray  # polygon vertex y-coordinates (image rows)

    def _scalar_key(self) -> tuple:
        """Return the non-array fields, used for hashing and equality."""
        return (self.bank_path, self.group, self.dicom_path, self.frame_idx, self.bifurcation)

    def __eq__(self, other: object) -> bool:
        if other.__class__ is not self.__class__:
            return NotImplemented
        return (
            self._scalar_key() == other._scalar_key()
            and np.array_equal(self.lumen_x, other.lumen_x)
            and np.array_equal(self.lumen_y, other.lumen_y)
        )

    def __hash__(self) -> int:
        # Arrays are left out of the hash; equal objects still hash equally
        # because equality requires equal scalar keys.
        return hash(self._scalar_key())

    @property
    def sample_id(self) -> str:
        """Return ``<bank path relative to REPO_ROOT, POSIX form>::<frame_idx>``.

        Raises ``ValueError`` if the resolved bank path is not inside ``REPO_ROOT``.
        """
        rel_bank = self.bank_path.resolve().relative_to(REPO_ROOT)
        return f"{rel_bank.as_posix()}::{self.frame_idx}"
@dataclass(frozen=True)
class BifurcationAnnotation:
    """Frame-level bifurcation label, carried without any lumen polygon.

    Instances are immutable; equality and hashing are the dataclass defaults
    over all five fields.
    """

    bank_path: Path  # frame-bank .jsonl file the label came from
    group: str
    dicom_path: Path
    frame_idx: int  # index into the frame axis of the DICOM stack
    bifurcation: bool

    @property
    def sample_id(self) -> str:
        """Return ``<bank path relative to REPO_ROOT, POSIX form>::<frame_idx>``.

        Raises ``ValueError`` if the resolved bank path is not inside ``REPO_ROOT``.
        """
        bank_rel = self.bank_path.resolve().relative_to(REPO_ROOT).as_posix()
        return "::".join((bank_rel, str(self.frame_idx)))
def _read_jsonl(path: Path) -> list[dict]:
    """Parse a JSON Lines file into a list of decoded records.

    Lines that are empty after stripping whitespace are skipped. JSON decoding
    errors propagate to the caller.
    """
    with path.open("r", encoding="utf-8") as handle:
        stripped = (line.strip() for line in handle)
        return [json.loads(text) for text in stripped if text]
def _resolve_dicom_path(bank_path: Path, meta: dict) -> Path | None:
    """Locate the DICOM file that a frame bank refers to.

    The ``dicom_path`` entry of the bank's meta record is tried first; a
    relative value is interpreted against ``REPO_ROOT``. If it is missing or
    does not exist, ``REPO_ROOT/data/<group>/<bank stem>.dcm`` is tried, where
    ``group`` defaults to the bank's parent directory name.

    Every returned path is ``resolve()``-d. Callers such as ``group_by_dicom``
    key on this path, so absolute and relative spellings of the same file must
    normalize to the same key.

    Args:
        bank_path: Path of the frame-bank ``.jsonl`` file.
        meta: The bank's first (``meta``) record.

    Returns:
        The resolved DICOM path, or ``None`` if no candidate exists.
    """
    dicom_raw = meta.get("dicom_path")
    if dicom_raw:
        candidate = Path(str(dicom_raw))
        if not candidate.is_absolute():
            candidate = REPO_ROOT / candidate
        if candidate.exists():
            return candidate.resolve()
    group = str(meta.get("group", bank_path.parent.name))
    fallback = REPO_ROOT / "data" / group / f"{bank_path.stem}.dcm"
    if fallback.exists():
        return fallback.resolve()
    return None
def load_lumen_annotations(frame_bank_root: Path) -> list[LumenAnnotation]:
    """Load all frame-bank lumen annotations with valid polygons.

    Scans ``<frame_bank_root>/*/*.jsonl`` in sorted order. A bank is used only
    when its first record has ``record_type == "meta"`` and its DICOM file can
    be resolved with ``_resolve_dicom_path``. Each ``frame`` record whose
    ``lumen`` value is an object holding 1-D ``x``/``y`` coordinate lists of
    equal length, with at least three points and all values finite, becomes
    one ``LumenAnnotation``. All other frame records are skipped.

    Args:
        frame_bank_root: Directory whose immediate sub-directories hold the
            ``.jsonl`` frame banks; the sub-directory name is the default group.

    Returns:
        Annotations in bank-file order, then record order within each bank.

    Raises:
        KeyError: if an otherwise valid frame record has no ``"frame"`` field.
    """
    annotations: list[LumenAnnotation] = []
    bank_files = sorted(frame_bank_root.glob("*/*.jsonl"))
    for bank_path in bank_files:
        rows = _read_jsonl(bank_path)
        if not rows:
            continue
        meta = rows[0]
        if meta.get("record_type") != "meta":
            continue
        group = str(meta.get("group", bank_path.parent.name))
        dicom_path = _resolve_dicom_path(bank_path, meta)
        if dicom_path is None:
            continue
        for rec in rows:
            if rec.get("record_type") != "frame":
                continue
            lumen = rec.get("lumen")
            # Missing, null, or non-object lumen entries carry no polygon. A
            # null value previously raised AttributeError on `.get`.
            if not isinstance(lumen, dict):
                continue
            xs = np.asarray(lumen.get("x", []), dtype=np.float32)
            ys = np.asarray(lumen.get("y", []), dtype=np.float32)
            # Only flat coordinate lists describe a polygon; nested lists
            # would otherwise pass the size check and break rasterization.
            if xs.ndim != 1 or ys.ndim != 1:
                continue
            if xs.size < 3 or ys.size < 3 or xs.size != ys.size:
                continue
            if not np.all(np.isfinite(xs)) or not np.all(np.isfinite(ys)):
                continue
            annotations.append(
                LumenAnnotation(
                    bank_path=bank_path,
                    group=group,
                    dicom_path=dicom_path,
                    frame_idx=int(rec["frame"]),
                    bifurcation=bool(rec.get("bifurcation", False)),
                    lumen_x=xs,
                    lumen_y=ys,
                )
            )
    return annotations
def load_bifurcation_annotations(frame_bank_root: Path) -> list[BifurcationAnnotation]:
    """Collect every frame-level bifurcation label under ``frame_bank_root``.

    Banks are the ``<frame_bank_root>/*/*.jsonl`` files, visited in sorted
    order. A bank contributes only if its first record has
    ``record_type == "meta"`` and ``_resolve_dicom_path`` finds its DICOM.
    Within a bank, each ``frame`` record whose ``bifurcation`` value is
    present and not null yields one annotation, whether or not it has a lumen
    polygon.
    """
    collected: list[BifurcationAnnotation] = []
    for bank_file in sorted(frame_bank_root.glob("*/*.jsonl")):
        records = _read_jsonl(bank_file)
        if not records:
            continue
        header = records[0]
        if header.get("record_type") != "meta":
            continue
        group_name = str(header.get("group", bank_file.parent.name))
        dicom_file = _resolve_dicom_path(bank_file, header)
        if dicom_file is None:
            continue
        collected.extend(
            BifurcationAnnotation(
                bank_path=bank_file,
                group=group_name,
                dicom_path=dicom_file,
                frame_idx=int(record["frame"]),
                bifurcation=bool(record["bifurcation"]),
            )
            for record in records
            if record.get("record_type") == "frame" and record.get("bifurcation") is not None
        )
    return collected
def polygon_to_mask(
    x_coords: np.ndarray,
    y_coords: np.ndarray,
    image_shape: tuple[int, int],
    *,
    radius: float = 0.5,
) -> np.ndarray:
    """Rasterize a polygon in image coordinates to a binary mask.

    Pixel ``(row, col)`` is tested at the point ``(x=col, y=row)``.

    Args:
        x_coords: Vertex x-coordinates (any array-like).
        y_coords: Vertex y-coordinates (any array-like).
        image_shape: ``(height, width)`` of the output mask.
        radius: Margin passed through to matplotlib's
            ``Path.contains_points``. The default 0.5 matches the previous
            hard-coded value.

    Returns:
        A boolean array of shape ``image_shape``. It is all-False when either
        coordinate array has fewer than three points.

    Raises:
        ValueError: if the two coordinate arrays differ in length.
    """
    xs = np.asarray(x_coords, dtype=np.float32)
    ys = np.asarray(y_coords, dtype=np.float32)
    if xs.size < 3 or ys.size < 3:
        return np.zeros(image_shape, dtype=bool)
    if xs.size != ys.size:
        raise ValueError(
            f"x_coords and y_coords must have the same number of points, got {xs.size} and {ys.size}."
        )
    vertices = np.column_stack((xs, ys))
    # With closed=True matplotlib assigns CLOSEPOLY to the LAST vertex and
    # ignores its coordinates. Append the first vertex so every real vertex
    # stays part of the polygon; a pre-closed contour just gains a
    # zero-length edge.
    vertices = np.vstack((vertices, vertices[:1]))
    polygon = MplPath(vertices, closed=True)
    h, w = image_shape
    yy, xx = np.mgrid[0:h, 0:w]
    points = np.column_stack((xx.ravel(), yy.ravel()))
    return polygon.contains_points(points, radius=radius).reshape((h, w))
def group_by_dicom(annotations: Iterable[LumenAnnotation]) -> dict[Path, list[LumenAnnotation]]:
    """Bucket annotations by their ``dicom_path``.

    Keys appear in first-seen order. Each bucket keeps the input order of its
    annotations.
    """
    buckets: dict[Path, list[LumenAnnotation]] = {}
    for item in annotations:
        if item.dicom_path not in buckets:
            buckets[item.dicom_path] = []
        buckets[item.dicom_path].append(item)
    return buckets
def load_preprocessed_stack(dicom_path: Path, diameter: int) -> np.ndarray:
    """Read a DICOM file and apply the same preprocessing as pipeline inference.

    The first element of ``read_dicom``'s result is discarded. The frame data
    (second element) is passed to ``apply_center_black_circle`` with the given
    ``diameter``.
    """
    _unused, frames = read_dicom(str(dicom_path))
    preprocessed = apply_center_black_circle(frames, diameter=diameter)
    return preprocessed
def build_images_and_masks(
    annotations: list[LumenAnnotation],
    diameter: int,
) -> tuple[np.ndarray, np.ndarray, list[LumenAnnotation]]:
    """Materialize image and mask arrays for a list of annotations.

    Annotations are grouped by DICOM so each stack is loaded and preprocessed
    once. An annotation is dropped if its frame index is out of range for the
    stack or its polygon rasterizes to an empty mask.

    Args:
        annotations: Lumen annotations to materialize.
        diameter: Forwarded to ``load_preprocessed_stack``.

    Returns:
        ``(images, masks, kept)``:
        - ``images`` stacks the selected frames along axis 0.
        - ``masks`` stacks the rasterized polygons as float32 0/1 arrays.
        - ``kept`` lists the annotations in the same row order. Rows follow
          DICOM grouping order, not input order, so use ``kept`` to map rows
          back to annotations.

    Raises:
        RuntimeError: if no annotation survives.
    """
    images_out: list[np.ndarray] = []
    masks_out: list[np.ndarray] = []
    kept: list[LumenAnnotation] = []
    grouped = group_by_dicom(annotations)
    for dicom_path, ann_list in grouped.items():
        stack = load_preprocessed_stack(dicom_path, diameter=diameter)
        n_frames = int(stack.shape[0])
        h, w = int(stack.shape[1]), int(stack.shape[2])
        for ann in ann_list:
            if not 0 <= ann.frame_idx < n_frames:
                continue
            mask = polygon_to_mask(ann.lumen_x, ann.lumen_y, (h, w))
            if not np.any(mask):
                continue
            # Copy the frame. A plain index is a view that would pin the whole
            # DICOM stack in memory until the final np.stack below.
            images_out.append(np.array(stack[ann.frame_idx], copy=True))
            masks_out.append(mask.astype(np.float32))
            kept.append(ann)
        # Drop this stack before loading the next one so it can be freed.
        del stack
    if not images_out:
        raise RuntimeError("No usable annotations found after loading DICOM frames and rasterizing masks.")
    return np.stack(images_out, axis=0), np.stack(masks_out, axis=0), kept
def _allocate_split_sizes(n: int, fractions: tuple[float, float, float]) -> list[int]:
    """Convert normalized (train, val, test) fractions into counts summing to ``n``.

    - A split whose fraction is zero always receives zero samples.
    - When ``n`` is at least the number of enabled splits, every enabled split
      gets at least one sample.
    - Any surplus this creates is removed one at a time from the largest
      enabled split (ties go to train, then val, then test).
    """
    active = [fraction > 0 for fraction in fractions]
    n_train = int(round(n * fractions[0]))
    n_val = int(round(n * fractions[1]))
    # Test takes the remainder, which is negative when both roundings go up.
    counts = [n_train, n_val, n - n_train - n_val]
    if not active[2] and counts[2] != 0:
        # A disabled test split must stay empty: hand the remainder to train,
        # or to val when train is disabled too.
        receiver = 0 if active[0] else 1
        counts[receiver] += counts[2]
        counts[2] = 0
    if n >= sum(active):
        for i in range(3):
            if active[i]:
                counts[i] = max(1, counts[i])
        overflow = sum(counts) - n
        while overflow > 0:
            shrinkable = [i for i in range(3) if active[i] and counts[i] > 1]
            if not shrinkable:
                break
            # max() returns the first maximal index, giving train > val > test ties.
            largest = max(shrinkable, key=counts.__getitem__)
            counts[largest] -= 1
            overflow -= 1
    return counts


def stratified_frame_split(
    annotations: list[LumenAnnotation],
    train_fraction: float,
    val_fraction: float,
    test_fraction: float,
    seed: int,
) -> dict[str, list[int]]:
    """Frame-level split stratified by bifurcation label.

    Fractions are normalized by their sum. The non-bifurcation and then the
    bifurcation indices are each shuffled with ``np.random.default_rng(seed)``
    and cut into train/val/test according to ``_allocate_split_sizes``. Each
    resulting list is shuffled once more. When all fractions are positive, the
    result matches the historical allocation exactly.

    Args:
        annotations: Items whose ``bifurcation`` truthiness defines the strata.
        train_fraction: Relative weight of the train split (>= 0).
        val_fraction: Relative weight of the validation split (>= 0).
        test_fraction: Relative weight of the test split (>= 0). Zero means
            the test split stays empty.
        seed: Seed for the random generator.

    Returns:
        ``{"train": [...], "val": [...], "test": [...]}`` of indices into
        ``annotations``.

    Raises:
        ValueError: if a fraction is negative or the fractions do not sum to
            a positive value.
    """
    if train_fraction < 0 or val_fraction < 0 or test_fraction < 0:
        raise ValueError("Split fractions must be non-negative.")
    total = train_fraction + val_fraction + test_fraction
    if total <= 0:
        raise ValueError("Split fractions must sum to a positive value.")
    fractions = (train_fraction / total, val_fraction / total, test_fraction / total)
    labels = np.asarray([1 if ann.bifurcation else 0 for ann in annotations], dtype=np.int32)
    indices = np.arange(len(annotations), dtype=np.int64)
    rng = np.random.default_rng(seed)
    train_ids: list[int] = []
    val_ids: list[int] = []
    test_ids: list[int] = []
    for label in (0, 1):
        cls_idx = indices[labels == label]
        # Shuffle before the emptiness check to keep the RNG stream unchanged.
        rng.shuffle(cls_idx)
        n = len(cls_idx)
        if n == 0:
            continue
        n_train, n_val, n_test = _allocate_split_sizes(n, fractions)
        train_ids.extend(cls_idx[:n_train].tolist())
        val_ids.extend(cls_idx[n_train : n_train + n_val].tolist())
        test_ids.extend(cls_idx[n_train + n_val : n_train + n_val + n_test].tolist())
    rng.shuffle(train_ids)
    rng.shuffle(val_ids)
    rng.shuffle(test_ids)
    return {"train": train_ids, "val": val_ids, "test": test_ids}
def split_summary(annotations: list[LumenAnnotation], split: dict[str, list[int]]) -> dict[str, dict[str, int]]:
    """Count samples and bifurcation labels in each split.

    Args:
        annotations: Items indexed by the integers in ``split``; only the
            truthiness of ``.bifurcation`` is read.
        split: Mapping from split name to annotation indices.

    Returns:
        For each split name, a dict with ``count``, ``bifurcation_true`` and
        ``bifurcation_false``.
    """
    def _tally(members: list[int]) -> dict[str, int]:
        positives = sum(bool(annotations[idx].bifurcation) for idx in members)
        return {
            "count": len(members),
            "bifurcation_true": positives,
            "bifurcation_false": len(members) - positives,
        }

    return {name: _tally(members) for name, members in split.items()}
def save_split_json(path: Path, annotations: list[LumenAnnotation], split: dict[str, list[int]], seed: int) -> None:
    """Write a split assignment to ``path`` as indented JSON.

    The file contains three keys:
    - ``seed``: the seed as an int.
    - ``summary``: per-split counts from ``split_summary``.
    - ``splits``: for each split, the ``sample_id`` of every member.

    Parent directories are created as needed; an existing file is overwritten.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    seed_value = int(seed)
    summary = split_summary(annotations, split)
    sample_ids: dict[str, list[str]] = {}
    for part_name, members in split.items():
        sample_ids[part_name] = [annotations[idx].sample_id for idx in members]
    payload = {"seed": seed_value, "summary": summary, "splits": sample_ids}
    with path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2)
def load_split_ids(path: Path) -> dict[str, set[str]]:
    """Read a split JSON file and return the sample IDs of each standard split.

    Only the ``train``, ``val`` and ``test`` entries under ``splits`` are
    returned, each converted to a set of strings. A missing entry yields an
    empty set.
    """
    with path.open("r", encoding="utf-8") as handle:
        splits = json.load(handle).get("splits", {})
    return {name: {str(item) for item in splits.get(name, [])} for name in ("train", "val", "test")}