| """Custom MONAI transforms for binary coronary artery segmentation.""" |
|
|
| import json |
|
|
| import numpy as np |
| from pathlib import Path |
| from typing import Dict, Hashable, Mapping, Optional, Any |
|
|
| import torch |
| from monai import transforms |
| from monai.config.type_definitions import KeysCollection, NdarrayOrTensor |
| from monai.utils.enums import TransformBackends |
| from scipy import ndimage |
|
|
|
|
class ApplyWindowing(transforms.Transform):
    """
    Clip CT intensities to a display window.

    Windowing adapts the greyscale component of a CT image to highlight particular structures
    by reducing the range of Hounsfield units (HU) to be displayed.

    Exactly one of three specifications must be given: a preset ``window`` name,
    an explicit ``lower``/``upper`` pair, or a ``width``/``level`` pair.

    Args:
        window: name of a preset window (brain, subdural, stroke, temporal bone,
            lungs, abdomen, liver, bone). An empty string counts as "not given".
        upper: upper threshold for windowing.
        lower: lower threshold for windowing.
        width: window width.
        level: window level (or window center).

    Raises:
        ValueError: if none or multiple of window/lower+upper/width+level are specified,
            if a pair is only half specified, or if ``window`` is not a known preset.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    # Preset name -> (width, level).
    # NOTE(review): a width of 150 for "lungs" is unusually narrow (lung windows
    # are commonly quoted with a width around 1500) -- confirm this preset.
    _PRESETS = {
        "brain": (80, 40),
        "subdural": (130, 50),
        "stroke": (8, 40),
        "temporal bone": (2800, 700),
        "lungs": (150, -600),
        "abdomen": (400, 50),
        "liver": (150, 30),
        "bone": (1800, 400),
    }

    def __init__(
        self,
        window: Optional[str] = None,
        upper: Optional[int] = None,
        lower: Optional[int] = None,
        width: Optional[int] = None,
        level: Optional[int] = None,
    ):
        error_message = "Please specifiy either window or upper/lower or width/level."

        # Numeric arguments are tested against None rather than by truthiness,
        # because 0 is a legitimate threshold, width or level.
        window_given = bool(window)
        bounds_given = (upper is not None, lower is not None)
        width_level_given = (width is not None, level is not None)

        half_specified = (
            any(bounds_given) != all(bounds_given)
            or any(width_level_given) != all(width_level_given)
        )
        n_modes = sum((window_given, all(bounds_given), all(width_level_given)))
        if half_specified or n_modes != 1:
            raise ValueError(error_message)

        if window_given:
            # Fail at construction time instead of leaving the bounds as None
            # and erroring on the first call.
            if window not in self._PRESETS:
                raise ValueError(
                    f"Unknown window preset {window!r}; "
                    f"expected one of {sorted(self._PRESETS)}."
                )
            width, level = self._PRESETS[window]

        if width is not None and level is not None:
            upper = level + width // 2
            lower = level - width // 2

        self.upper = upper
        self.lower = lower

    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        """Return ``img`` with values clipped to ``[self.lower, self.upper]``."""
        return img.clip(self.lower, self.upper)
|
|
|
|
class ApplyWindowingd(transforms.MapTransform):
    """Dictionary-based wrapper of :py:class:`ApplyWindowing`.

    The window arguments are forwarded unchanged to :py:class:`ApplyWindowing`;
    the resulting transform is applied to every key yielded by
    ``self.key_iterator``.
    """

    def __init__(
        self,
        keys: KeysCollection,
        window: Optional[str] = None,
        upper: Optional[int] = None,
        lower: Optional[int] = None,
        width: Optional[int] = None,
        level: Optional[int] = None,
        allow_missing_keys: bool = False,
    ):
        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
        window_spec = {
            "window": window,
            "upper": upper,
            "lower": lower,
            "width": width,
            "level": level,
        }
        self.windowing = ApplyWindowing(**window_spec)

    def __call__(
        self, data: Mapping[Hashable, NdarrayOrTensor]
    ) -> Dict[Hashable, NdarrayOrTensor]:
        """Return a shallow copy of ``data`` with the selected entries windowed."""
        out = dict(data)
        for name in self.key_iterator(out):
            out[name] = self.windowing(out[name])
        return out
|
|
|
|
| |
| |
| |
|
|
|
|
| def _to_numpy(img: NdarrayOrTensor) -> np.ndarray: |
| """Convert tensor to numpy for percentile/statistics computation.""" |
| if isinstance(img, torch.Tensor): |
| return img.cpu().numpy() |
| return np.asarray(img) |
|
|
|
|
| def _from_numpy(arr: np.ndarray, reference: NdarrayOrTensor) -> NdarrayOrTensor: |
| """Convert numpy back to the same type as reference, preserving MetaTensor metadata.""" |
| if isinstance(reference, torch.Tensor): |
| result = torch.from_numpy(arr).to(reference.device) |
| if hasattr(reference, 'meta'): |
| from monai.data import MetaTensor |
| result = MetaTensor(result, meta=reference.meta) |
| return result |
| return arr |
|
|
|
|
class ZScoreForegroundNormalize(transforms.Transform):
    """
    Z-score normalization driven by foreground statistics.

    Intended to run AFTER windowing. Mean and standard deviation are computed
    from the voxels strictly above ``background_threshold`` and then applied
    to the whole image. If no voxel exceeds the threshold, whole-image
    statistics are used instead. The result is always float32.

    Args:
        background_threshold: Voxels at or below this value are treated as
            background and excluded from the statistics. After windowing to
            [-100, 900], -50 excludes low-intensity regions.
    """

    backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

    def __init__(self, background_threshold: float = -50) -> None:
        self.background_threshold = background_threshold

    def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
        values = _to_numpy(img)
        foreground = values > self.background_threshold
        # Use the foreground voxels when there are any, otherwise every voxel.
        sample = values[foreground] if foreground.any() else values
        # The small epsilon keeps the division finite for constant images.
        normalized = (values - sample.mean()) / (sample.std() + 1e-8)
        return _from_numpy(normalized.astype(np.float32), img)
|
|
|
|
class ZScoreForegroundNormalized(transforms.MapTransform):
    """Dictionary-based wrapper of :py:class:`ZScoreForegroundNormalize`.

    Builds one :py:class:`ZScoreForegroundNormalize` with the given
    ``background_threshold`` and applies it to every key yielded by
    ``self.key_iterator``.
    """

    def __init__(
        self,
        keys: KeysCollection,
        background_threshold: float = -50,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
        self.normalizer = ZScoreForegroundNormalize(background_threshold=background_threshold)

    def __call__(
        self, data: Mapping[Hashable, NdarrayOrTensor]
    ) -> Dict[Hashable, NdarrayOrTensor]:
        """Return a shallow copy of ``data`` with the selected entries normalized."""
        out = dict(data)
        for name in self.key_iterator(out):
            out[name] = self.normalizer(out[name])
        return out
|
|
|
|
| |
| |
| |
|
|
|
|
| def _get_neighbors(point, skel_arr): |
| """Get 26-connected skeleton neighbors of a point.""" |
| neighbors = [] |
| for dx in (-1, 0, 1): |
| for dy in (-1, 0, 1): |
| for dz in (-1, 0, 1): |
| if dx == 0 and dy == 0 and dz == 0: |
| continue |
| nb = (point[0] + dx, point[1] + dy, point[2] + dz) |
| if (0 <= nb[0] < skel_arr.shape[0] |
| and 0 <= nb[1] < skel_arr.shape[1] |
| and 0 <= nb[2] < skel_arr.shape[2] |
| and skel_arr[nb]): |
| neighbors.append(nb) |
| return neighbors |
|
|
|
|
def _trace_branch(start, skel_arr, visited, branch_points):
    """Walk the skeleton from ``start`` and return the ordered voxel path.

    Each step moves to an unvisited skeleton neighbour. With several
    candidates, the one whose offset has the largest dot product with the
    previous step direction is taken; if the path has no previous step yet,
    the first candidate is taken. The walk ends when no unvisited neighbour
    remains or right after stepping onto a voxel in ``branch_points``.

    ``visited`` is updated in place with every voxel on the returned path.
    """
    trail = [start]
    visited.add(start)
    head = start
    while True:
        candidates = [v for v in _get_neighbors(head, skel_arr) if v not in visited]
        if not candidates:
            break

        if len(candidates) > 1 and len(trail) >= 2:
            # Prefer the candidate that best continues the current heading.
            heading = np.array(trail[-1]) - np.array(trail[-2])
            alignment = [
                np.dot(np.array(v) - np.array(head), heading) for v in candidates
            ]
            step = candidates[int(np.argmax(alignment))]
        else:
            step = candidates[0]

        head = step
        visited.add(head)
        trail.append(head)
        if head in branch_points:
            break
    return trail
|
|
|
|
| def _smooth_branch(points, affine, smoothing_factor=2.0): |
| """Fit a B-spline to branch points and resample at ~1mm intervals. |
| |
| Args: |
| points: List of (x, y, z) voxel coordinates. |
| affine: 4x4 affine matrix mapping voxel to physical (mm). |
| smoothing_factor: Spline smoothing (higher = smoother). |
| |
| Returns: |
| List of [x, y, z] physical coordinates (mm), rounded to 2 decimals. |
| """ |
| from scipy.interpolate import splprep, splev |
|
|
| pts = np.array(points, dtype=float) |
|
|
| |
| ones = np.ones((len(pts), 1)) |
| homogeneous = np.hstack([pts, ones]) |
| physical = (affine @ homogeneous.T).T[:, :3] |
|
|
| if len(physical) < 4: |
| return [[round(float(c), 2) for c in p] for p in physical] |
|
|
| try: |
| k = min(3, len(physical) - 1) |
| tck, u = splprep( |
| [physical[:, 0], physical[:, 1], physical[:, 2]], |
| s=len(physical) * smoothing_factor, |
| k=k, |
| ) |
| |
| diffs = np.diff(physical, axis=0) |
| total_length = float(np.sum(np.sqrt(np.sum(diffs ** 2, axis=1)))) |
| n_out = max(int(total_length), 4) |
| u_new = np.linspace(0, 1, n_out) |
| smooth = np.array(splev(u_new, tck)).T |
| return [[round(float(c), 2) for c in p] for p in smooth] |
| except Exception: |
| return [[round(float(c), 2) for c in p] for p in physical] |
|
|
|
|
def extract_centerlines(binary_mask, affine, min_branch_points=3,
                        min_length_mm=5.0, smoothing_factor=2.0):
    """Extract vessel centerlines from a binary mask.

    The mask is skeletonized, skeleton voxels are classified by their number
    of 26-connected skeleton neighbours (1 = endpoint, 3 or more = branch
    point), branches are traced from those voxels, and each branch is
    converted to smoothed physical coordinates.

    Args:
        binary_mask: 3D numpy array (bool or int); singleton axes are squeezed
            and any non-zero value counts as foreground.
        affine: 4x4 affine matrix (voxel to mm).
        min_branch_points: Discard branches with fewer raw skeleton points.
        min_length_mm: Discard branches shorter than this (mm) after smoothing.
        smoothing_factor: Spline smoothing parameter.

    Returns:
        Dict with 'branches' list, each containing 'id', 'points_mm',
        'length_mm', and 'n_points'. An all-zero mask yields an empty list.

    Raises:
        ValueError: if a non-empty mask is not 3D after squeezing.
    """
    arr = np.asarray(binary_mask).squeeze().astype(bool)
    if not arr.any():
        return {"branches": []}
    if arr.ndim != 3:
        raise ValueError(
            "binary_mask must be 3D after squeezing singleton axes, "
            f"got shape {arr.shape}."
        )

    # Imported after the cheap early exits so empty masks never pay for it.
    from skimage.morphology import skeletonize

    # Force a boolean skeleton: the neighbour counting below assumes 0/1
    # values. NOTE(review): some scikit-image releases reportedly return a
    # 0/255 uint8 array for 3D input -- confirm against the pinned version.
    skel = np.asarray(skeletonize(arr)).astype(bool)
    skel_int = skel.astype(np.int32)

    # Convolving with the full 3x3x3 structure counts the voxel itself too,
    # so subtract the skeleton to leave only the neighbour count.
    struct = ndimage.generate_binary_structure(3, 3)
    neighbor_count = ndimage.convolve(
        skel_int, struct.astype(np.int32), mode="constant"
    ) - skel_int

    endpoints = set(map(tuple, np.argwhere(skel & (neighbor_count == 1))))
    branch_points = set(map(tuple, np.argwhere(skel & (neighbor_count >= 3))))

    # Trace from endpoints first, then from branch points not yet reached.
    # NOTE(review): a branch point already visited by an earlier trace is
    # skipped entirely, so a segment joining two such branch points does not
    # appear to be traced from either side -- confirm whether that is intended.
    visited = set()
    raw_branches = []
    for start in list(endpoints) + list(branch_points):
        if start in visited:
            continue
        path = _trace_branch(start, skel, visited, branch_points)
        if len(path) >= min_branch_points:
            raw_branches.append(path)

        if start in branch_points:
            # Follow the remaining arms of this junction, prefixing the
            # junction voxel so each arm starts at it.
            for nb in _get_neighbors(start, skel):
                if nb not in visited:
                    path2 = _trace_branch(nb, skel, visited, branch_points)
                    if len(path2) >= min_branch_points:
                        raw_branches.append([start] + path2)

    affine_np = np.array(affine, dtype=float)
    branches = []
    for raw in raw_branches:
        pts_mm = _smooth_branch(raw, affine_np, smoothing_factor)
        if len(pts_mm) < 2:
            continue
        diffs = np.diff(pts_mm, axis=0)
        length = float(np.sum(np.sqrt(np.sum(diffs ** 2, axis=1))))
        if length < min_length_mm:
            continue
        # Ids are consecutive over the branches that survive filtering.
        branches.append({
            "id": len(branches),
            "points_mm": pts_mm,
            "length_mm": round(length, 2),
            "n_points": len(pts_mm),
        })

    return {"branches": branches}
|
|
|
|
class ExtractCenterlinesd(transforms.MapTransform):
    """Extract vessel centerlines from binary mask and save as JSON.

    Post-processing transform for inference. Takes the predicted binary mask,
    extracts a spline-smoothed centerline, and writes a JSON file with
    ordered branch points in physical (mm) coordinates. The input dictionary
    is returned as a shallow copy with no entries changed.

    Output file: ``{output_dir}/{patient_name}_centerline.json``

    Args:
        keys: Key of the binary mask prediction (typically "pred").
        image_key: Key of the input image (for filename extraction).
        output_dir: Directory to write JSON files.
        min_branch_points: Minimum raw skeleton points per branch.
        min_length_mm: Discard branches shorter than this (mm).
        smoothing_factor: B-spline smoothing (higher = smoother).
        allow_missing_keys: Passed through to the MapTransform base class.
    """

    # Format suffixes that can remain at the end of ``Path.stem``
    # (e.g. "case.nii" from "case.nii.gz").
    _IMAGE_SUFFIXES = (".nii", ".nrrd", ".dcm")

    def __init__(
        self,
        keys: KeysCollection,
        image_key: str = "image",
        output_dir: str = "./output",
        min_branch_points: int = 3,
        min_length_mm: float = 5.0,
        smoothing_factor: float = 2.0,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys=keys, allow_missing_keys=allow_missing_keys)
        self.image_key = image_key
        self.output_dir = output_dir
        self.min_branch_points = min_branch_points
        self.min_length_mm = min_length_mm
        self.smoothing_factor = smoothing_factor

    def _output_stem(self, d: Mapping[Hashable, Any]) -> str:
        """Derive the output file stem from the image's ``filename_or_obj`` metadata."""
        img = d.get(self.image_key)
        if img is None or not hasattr(img, "meta"):
            return "unknown"
        raw = img.meta.get("filename_or_obj", "unknown")
        stem = Path(str(raw)).stem
        # Strip only trailing format suffixes; removing them anywhere in the
        # name would mangle stems that merely contain e.g. ".nii".
        for suffix in self._IMAGE_SUFFIXES:
            if stem.endswith(suffix):
                stem = stem[: -len(suffix)]
        return stem

    def __call__(self, data: Mapping[Hashable, Any]) -> Dict[Hashable, Any]:
        d = dict(data)
        # The output name depends only on the image entry, so compute it once.
        # NOTE(review): with more than one key, every key writes to this same
        # path and later keys overwrite earlier ones -- confirm single-key use.
        out_path = Path(self.output_dir) / f"{self._output_stem(d)}_centerline.json"

        for key in self.key_iterator(d):
            pred = d[key]
            mask_np = _to_numpy(pred)

            # Fall back to the identity (voxel units) when the prediction
            # carries no affine metadata.
            affine = np.eye(4)
            if hasattr(pred, "meta") and "affine" in pred.meta:
                # _to_numpy also copes with tensors that do not live on the CPU.
                affine = np.asarray(_to_numpy(pred.meta["affine"]), dtype=float)

            centerlines = extract_centerlines(
                mask_np, affine,
                min_branch_points=self.min_branch_points,
                min_length_mm=self.min_length_mm,
                smoothing_factor=self.smoothing_factor,
            )

            out_path.parent.mkdir(parents=True, exist_ok=True)
            with open(out_path, "w", encoding="utf-8") as f:
                json.dump(centerlines, f)

        return d
|
|