"""Shared offline-aligned preprocessing helpers for 3D brain MRI volumes."""

from __future__ import annotations

import math
from pathlib import Path
from typing import Any, Mapping

import numpy as np
import torch
import torch.nn.functional as F

try:
    import nibabel as nib
except ImportError:  # pragma: no cover - optional import surface
    # Only `load_canonical_nifti` needs nibabel; the array helpers below work
    # without it, mirroring how the optional scipy import is handled.
    nib = None

try:
    from scipy import ndimage as scipy_ndimage
except Exception:  # pragma: no cover - optional import surface
    scipy_ndimage = None


TARGET_SHAPE = (128, 192, 192)
TARGET_SPACING = (1.25, 1.0, 1.0)
CROP_MARGIN_MM = 5.0
FOREGROUND_THRESHOLD = 1e-3
BACKGROUND_VALUE = -1.0
FOREGROUND_STRATEGY = "largest_component_nonzero"
GENERIC_RECIPE_ID = "generic_foreground_128x192x192_fp16_v1"
GENERIC_CACHE_VERSION = 1


def _positive_spacing(value: Any, label: str, axis: int) -> float:
    """Return ``value`` as a float, rejecting non-finite or non-positive spacings."""
    spacing = float(value)
    if not math.isfinite(spacing) or spacing <= 0.0:
        raise ValueError(
            f"`{label}` must be finite and positive on every axis, "
            f"got {spacing} on axis {axis}."
        )
    return spacing


def load_canonical_nifti(path: str | Path):
    """Load the image at ``path`` and reorient it with ``nib.as_closest_canonical``.

    Raises:
        ImportError: If nibabel is not installed.
    """
    if nib is None:
        raise ImportError("nibabel is required to load NIfTI images.")
    return nib.as_closest_canonical(nib.load(str(path)))


def load_image_spacing(image) -> tuple[float, float, float]:
    """Return the first three header zooms of ``image`` as floats.

    Raises:
        ValueError: If the header exposes fewer than three zooms.
    """
    zooms = image.header.get_zooms()[:3]
    if len(zooms) != 3:
        raise ValueError(f"Expected a 3D image spacing tuple, got {zooms}.")
    return tuple(float(value) for value in zooms)


def coerce_volume_to_3d(volume: np.ndarray) -> np.ndarray:
    """Reduce a 3D or 4D array to a single float32 3D volume.

    A 3D input is only cast. For a 4D input the first entry of the extra axis
    is kept: the leading axis is treated as the extra one when it has at most
    4 entries and the trailing axis has more than 4; otherwise the trailing
    axis is.

    Raises:
        ValueError: If ``volume`` is neither 3D nor 4D.
    """
    if volume.ndim == 3:
        return volume.astype(np.float32, copy=False)
    if volume.ndim != 4:
        raise ValueError(f"Expected a 3D or 4D volume, got shape {volume.shape}.")
    if volume.shape[0] <= 4 and volume.shape[-1] > 4:
        selected = volume[0]
    else:
        selected = volume[..., 0]
    return np.asarray(selected, dtype=np.float32)


def largest_connected_component(mask: np.ndarray) -> np.ndarray:
    """Return a mask of the largest connected component of ``mask``.

    The input is returned unchanged when it is empty, when scipy is not
    available, or when it has at most one labelled component.
    """
    if not mask.any() or scipy_ndimage is None:
        return mask
    structure = scipy_ndimage.generate_binary_structure(mask.ndim, 1)
    labels, num_labels = scipy_ndimage.label(mask, structure=structure)
    if num_labels <= 1:
        return mask
    counts = np.bincount(labels.reshape(-1))
    if counts.size <= 1:
        return mask
    counts[0] = 0  # label 0 is background; never let it win
    winning_label = int(counts.argmax())
    if winning_label <= 0 or counts[winning_label] <= 0:
        return mask
    return labels == winning_label


def build_foreground_mask(
    volume: np.ndarray,
    threshold: float = FOREGROUND_THRESHOLD,
) -> np.ndarray:
    """Build a boolean foreground mask from voxels with ``|value| > threshold``.

    Non-finite voxels count as zero. If nothing exceeds the threshold, an
    all-True mask is returned. Otherwise the largest connected component is
    used, unless it has fewer than 512 voxels while the raw mask has more, in
    which case the raw thresholded mask is returned.
    """
    sanitized = np.nan_to_num(volume, nan=0.0, posinf=0.0, neginf=0.0)
    raw_mask = np.abs(sanitized) > float(threshold)
    if not raw_mask.any():
        return np.ones_like(sanitized, dtype=bool)
    component_mask = largest_connected_component(raw_mask)
    component_count = int(component_mask.sum())
    raw_count = int(raw_mask.sum())
    if component_count <= 0:
        return raw_mask
    if component_count < 512 and raw_count > component_count:
        return raw_mask
    return component_mask


def compute_crop_bbox(
    mask: np.ndarray,
    spacing: tuple[float, float, float],
    margin_mm: float = CROP_MARGIN_MM,
) -> tuple[tuple[int, int], ...]:
    """Return per-axis ``(start, stop)`` bounds of ``mask`` plus a margin.

    The margin is ``margin_mm`` converted to voxels per axis (rounded up) and
    the bounds are clamped to the array extent.

    Raises:
        ValueError: If ``mask`` has no True voxels, or if a spacing entry is
            not finite and positive.
    """
    coords = np.where(mask)
    if coords[0].size == 0:
        raise ValueError("Foreground mask contains no positive voxels after selection.")
    bbox = []
    for axis, values in enumerate(coords):
        # A zero spacing would divide by zero and a negative one would shrink
        # the crop instead of growing it, so reject both up front.
        axis_spacing = _positive_spacing(spacing[axis], "spacing", axis)
        margin_voxels = int(math.ceil(float(margin_mm) / axis_spacing))
        start = max(0, int(values.min()) - margin_voxels)
        stop = min(mask.shape[axis], int(values.max()) + margin_voxels + 1)
        bbox.append((start, stop))
    return tuple(bbox)


def crop_volume_and_mask(
    volume: np.ndarray,
    mask: np.ndarray,
    spacing: tuple[float, float, float],
    margin_mm: float = CROP_MARGIN_MM,
) -> tuple[np.ndarray, np.ndarray, tuple[tuple[int, int], ...]]:
    """Crop ``volume`` and ``mask`` to the mask's bounding box plus margin.

    Returns:
        The cropped volume, the cropped mask and the bounding box used. The
        crops are basic slices of the inputs, not copies.
    """
    bbox = compute_crop_bbox(mask, spacing, margin_mm=margin_mm)
    slices = tuple(slice(start, stop) for start, stop in bbox)
    return volume[slices], mask[slices], bbox


def normalize_foreground_only(volume: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Rescale ``volume`` to ``[-1, 1]`` using foreground percentiles.

    The 0.5th and 99.5th percentiles of the voxels selected by ``mask`` define
    the clipping window, which is then mapped linearly onto ``[-1, 1]``. If
    the window is degenerate or non-finite, an all-zero volume is returned.

    Raises:
        ValueError: If ``mask`` selects no voxels.
    """
    sanitized = np.nan_to_num(volume, nan=0.0, posinf=0.0, neginf=0.0).astype(
        np.float32, copy=False
    )
    foreground_values = sanitized[mask]
    if foreground_values.size == 0:
        raise ValueError("Cannot normalize volume because the foreground mask is empty.")
    if foreground_values.size > 1_000_000:
        # Subsample with a fixed stride to bound the cost of the percentile.
        step = max(1, foreground_values.size // 1_000_000)
        foreground_values = foreground_values[::step]
    low, high = np.percentile(foreground_values, [0.5, 99.5])
    if not np.isfinite(low) or not np.isfinite(high) or high <= low:
        normalized = np.zeros_like(sanitized, dtype=np.float32)
    else:
        normalized = np.clip(sanitized, float(low), float(high))
        normalized = np.clip((normalized - float(low)) / float(high - low), 0.0, 1.0)
        normalized = normalized * 2.0 - 1.0
    return normalized.astype(np.float32, copy=False)


def resize_volume(volume: np.ndarray, size: tuple[int, int, int], mode: str) -> np.ndarray:
    """Resize ``volume`` to ``size`` with ``F.interpolate`` and return float32.

    ``align_corners=False`` is passed for the linear/cubic modes. The input is
    made contiguous first because ``torch.from_numpy`` rejects negative-stride
    views (for example flipped arrays), and non-floating input is cast to
    float32 because the interpolating modes do not accept integer tensors.
    """
    array = np.ascontiguousarray(volume)
    if not np.issubdtype(array.dtype, np.floating):
        array = array.astype(np.float32)
    tensor = torch.from_numpy(array).unsqueeze(0).unsqueeze(0)
    kwargs = {}
    if mode in {"linear", "bilinear", "bicubic", "trilinear"}:
        kwargs["align_corners"] = False
    tensor = F.interpolate(tensor, size=size, mode=mode, **kwargs)
    return tensor.squeeze(0).squeeze(0).cpu().numpy().astype(np.float32, copy=False)


def resize_mask(mask: np.ndarray, size: tuple[int, int, int]) -> np.ndarray:
    """Resize ``mask`` to ``size`` with nearest-neighbour interpolation.

    Returns a boolean array (values above 0.5 after interpolation).
    """
    # ascontiguousarray also covers float32 masks that are negative-stride
    # views, which `astype(copy=False)` would have passed through untouched.
    as_float = np.ascontiguousarray(mask, dtype=np.float32)
    tensor = torch.from_numpy(as_float).unsqueeze(0).unsqueeze(0)
    tensor = F.interpolate(tensor, size=size, mode="nearest")
    return tensor.squeeze(0).squeeze(0).cpu().numpy() > 0.5


def resample_to_target_spacing(
    volume: np.ndarray,
    mask: np.ndarray,
    source_spacing: tuple[float, float, float],
    target_spacing: tuple[float, float, float] = TARGET_SPACING,
) -> tuple[np.ndarray, np.ndarray]:
    """Resample ``volume`` and ``mask`` from ``source_spacing`` to ``target_spacing``.

    Each axis is resized to ``round(size * source / target)`` voxels, with a
    minimum of one. The volume uses trilinear and the mask nearest-neighbour
    interpolation. When the shape would not change, the volume is only cast
    to float32 and the mask is returned as-is.

    Raises:
        ValueError: If any spacing entry is not finite and positive.
    """
    target_shape = []
    for axis, (current_size, src, dst) in enumerate(
        zip(volume.shape, source_spacing, target_spacing)
    ):
        src_spacing = _positive_spacing(src, "source_spacing", axis)
        dst_spacing = _positive_spacing(dst, "target_spacing", axis)
        target_shape.append(
            max(1, int(round(float(current_size) * src_spacing / dst_spacing)))
        )
    target_shape_tuple = tuple(target_shape)
    if target_shape_tuple == tuple(int(v) for v in volume.shape):
        return volume.astype(np.float32, copy=False), mask
    return (
        resize_volume(volume, target_shape_tuple, mode="trilinear"),
        resize_mask(mask, target_shape_tuple),
    )


def downscale_to_fit(
    volume: np.ndarray,
    mask: np.ndarray,
    target_shape: tuple[int, int, int] = TARGET_SHAPE,
) -> tuple[np.ndarray, np.ndarray]:
    """Shrink ``volume`` and ``mask`` isotropically until they fit ``target_shape``.

    Inputs that already fit are returned unchanged. Otherwise every axis is
    scaled by the smallest ``target / current`` ratio, so the aspect ratio is
    kept. The function never upscales.
    """
    current_shape = tuple(int(v) for v in volume.shape)
    if all(current <= target for current, target in zip(current_shape, target_shape)):
        return volume, mask
    # At least one axis exceeds its target here, so the scale is always < 1.
    scale = min(
        float(target) / float(current)
        for current, target in zip(current_shape, target_shape)
    )
    # NOTE(review): float rounding in `current * scale` can land just below an
    # integer, leaving an axis one voxel smaller than the exact ratio; the
    # arithmetic is kept as-is so previously cached outputs stay reproducible.
    new_shape = tuple(
        min(target, max(1, int(math.floor(float(current) * scale))))
        for current, target in zip(current_shape, target_shape)
    )
    return (
        resize_volume(volume, new_shape, mode="trilinear"),
        resize_mask(mask, new_shape),
    )


def center_pad(
    array: np.ndarray,
    target_shape: tuple[int, int, int] = TARGET_SHAPE,
    fill_value: float = BACKGROUND_VALUE,
) -> np.ndarray:
    """Pad ``array`` with ``fill_value`` so it is centred inside ``target_shape``.

    When the padding on an axis is odd, the extra voxel goes after the data.

    Raises:
        ValueError: If ``array`` is larger than ``target_shape`` on any axis.
    """
    if any(current > target for current, target in zip(array.shape, target_shape)):
        raise ValueError(
            f"Cannot center-pad shape {array.shape} into smaller target {target_shape}."
        )
    pad_width = []
    for current, target in zip(array.shape, target_shape):
        delta = target - current
        before = delta // 2
        after = delta - before
        pad_width.append((before, after))
    return np.pad(
        array,
        pad_width=tuple(pad_width),
        mode="constant",
        constant_values=fill_value,
    )


def preprocess_image_with_foreground_mask(
    image_path: str | Path,
    *,
    target_shape: tuple[int, int, int] = TARGET_SHAPE,
    target_spacing: tuple[float, float, float] = TARGET_SPACING,
    crop_margin_mm: float = CROP_MARGIN_MM,
    foreground_threshold: float = FOREGROUND_THRESHOLD,
    background_value: float = BACKGROUND_VALUE,
    foreground_strategy: str = FOREGROUND_STRATEGY,
    recipe_id: str = GENERIC_RECIPE_ID,
    cache_version: int = GENERIC_CACHE_VERSION,
) -> dict[str, object]:
    """Run the full preprocessing recipe on one image file.

    Steps, in order: load and reorient, reduce to 3D, build the foreground
    mask, crop to it, normalise to ``[-1, 1]``, resample to ``target_spacing``,
    downscale to fit ``target_shape``, overwrite non-mask voxels with
    ``background_value``, centre-pad to ``target_shape`` and convert to a
    float16 tensor with a leading channel axis.

    ``foreground_strategy``, ``recipe_id`` and ``cache_version`` are recorded
    in the returned payload only; they do not alter the processing.

    Returns:
        A dict with ``pixel_values`` (float16 tensor of shape
        ``(1, *target_shape)``) and the metadata keys ``source_image``,
        ``source_shape``, ``source_spacing``, ``crop_bbox``,
        ``foreground_strategy``, ``recipe_id`` and ``cache_version``.
    """
    image_path = Path(image_path)
    image = load_canonical_nifti(image_path)
    source_shape = tuple(int(value) for value in image.shape)
    source_spacing = load_image_spacing(image)

    volume = np.asarray(image.get_fdata(dtype=np.float32), dtype=np.float32)
    volume = coerce_volume_to_3d(volume)

    foreground_mask = build_foreground_mask(volume, threshold=foreground_threshold)
    cropped_volume, cropped_mask, crop_bbox = crop_volume_and_mask(
        volume,
        foreground_mask,
        source_spacing,
        margin_mm=crop_margin_mm,
    )
    normalized_volume = normalize_foreground_only(cropped_volume, cropped_mask)
    resampled_volume, resampled_mask = resample_to_target_spacing(
        normalized_volume,
        cropped_mask,
        source_spacing=source_spacing,
        target_spacing=target_spacing,
    )
    fitted_volume, fitted_mask = downscale_to_fit(
        resampled_volume,
        resampled_mask,
        target_shape=target_shape,
    )

    # np.clip returns a new array, so the in-place background fill below does
    # not touch the resampled volume.
    fitted_volume = np.clip(fitted_volume, -1.0, 1.0).astype(np.float32, copy=False)
    fitted_volume[~fitted_mask] = float(background_value)
    padded_volume = center_pad(
        fitted_volume,
        target_shape=target_shape,
        fill_value=float(background_value),
    ).astype(np.float32, copy=False)

    pixel_values = (
        torch.from_numpy(padded_volume).unsqueeze(0).to(dtype=torch.float16).contiguous()
    )
    return {
        "pixel_values": pixel_values,
        "source_image": str(image_path),
        "source_shape": list(source_shape),
        "source_spacing": list(source_spacing),
        "crop_bbox": [[int(start), int(stop)] for start, stop in crop_bbox],
        "foreground_strategy": foreground_strategy,
        "recipe_id": recipe_id,
        "cache_version": int(cache_version),
    }


def validate_fixed_payload(
    payload: Mapping[str, Any],
    *,
    target_shape: tuple[int, int, int] = TARGET_SHAPE,
) -> None:
    """Check that ``payload["pixel_values"]`` matches the fixed recipe contract.

    The tensor must be a float16 ``torch.Tensor`` of shape
    ``(1, *target_shape)`` with finite values inside ``[-1.01, 1.01]``.

    Raises:
        TypeError: If ``pixel_values`` is missing or not a tensor.
        ValueError: If the shape, dtype, finiteness or value range is wrong.
    """
    pixel_values = payload.get("pixel_values")
    if not isinstance(pixel_values, torch.Tensor):
        raise TypeError("`pixel_values` must be a torch.Tensor.")
    expected_shape = (1,) + tuple(target_shape)
    if tuple(pixel_values.shape) != expected_shape:
        raise ValueError(
            f"Expected tensor shape {expected_shape}, got {tuple(pixel_values.shape)}."
        )
    if pixel_values.dtype != torch.float16:
        raise ValueError(f"Expected tensor dtype torch.float16, got {pixel_values.dtype}.")
    if not torch.isfinite(pixel_values).all():
        raise ValueError("Tensor contains non-finite values.")
    min_value = float(pixel_values.min().item())
    max_value = float(pixel_values.max().item())
    if min_value < -1.01 or max_value > 1.01:
        raise ValueError(
            f"Expected tensor values in [-1, 1]. Got min={min_value}, max={max_value}."
        )