"""Image processors for FlexiCT Hugging Face model repos."""

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np
import torch
import torch.nn.functional as F
from transformers import BatchFeature
from transformers.image_processing_utils import ImageProcessingMixin


def _as_float_array(image: Any, orient_lps: bool = True) -> tuple[np.ndarray, dict[str, Any]]:
    """Convert ``image`` into a float32 NumPy array plus a metadata dict.

    Args:
        image: A filesystem path (loaded with SimpleITK), a ``torch.Tensor``,
            or anything ``np.asarray`` accepts.
        orient_lps: Forwarded to ``_load_medical_image_array``; only relevant
            when ``image`` is a path.

    Returns:
        ``(array, metadata)``. For in-memory inputs the metadata is
        ``{"source": "array"}``.
    """
    if isinstance(image, (str, Path)):
        return _load_medical_image_array(image, orient_lps=orient_lps)
    if isinstance(image, torch.Tensor):
        # ``.float()`` first: NumPy has no bfloat16, so ``.numpy()`` would fail on bf16 tensors.
        image = image.detach().cpu().float().numpy()
    array = np.asarray(image, dtype=np.float32)
    return array, {"source": "array"}


def _load_medical_image_array(path: str | Path, orient_lps: bool = True) -> tuple[np.ndarray, dict[str, Any]]:
    """Read an image file with SimpleITK and return it as a float32 array.

    Args:
        path: File path readable by ``sitk.ReadImage``; ``~`` is expanded.
        orient_lps: When true (the default), reorient the image with
            ``sitk.DICOMOrient(image, "LPS")`` before converting it.

    Returns:
        ``(array, metadata)``. ``metadata`` records the source path, the
        spacing / origin / direction reported by SimpleITK (``*_xyz`` keys) and
        the loaded array shape (``loaded_shape_zyx``).

    Raises:
        RuntimeError: If SimpleITK is not installed.
    """
    try:
        import SimpleITK as sitk
    except ImportError as exc:  # pragma: no cover - runtime dependency branch.
        raise RuntimeError("SimpleITK is required to load CT paths with FlexiCTImageProcessor.") from exc

    image = sitk.ReadImage(str(Path(path).expanduser()))
    if orient_lps:
        image = sitk.DICOMOrient(image, "LPS")
    array = sitk.GetArrayFromImage(image).astype(np.float32, copy=False)
    metadata = {
        "source": str(path),
        "spacing_xyz": [float(v) for v in image.GetSpacing()],
        "origin_xyz": [float(v) for v in image.GetOrigin()],
        "direction": [float(v) for v in image.GetDirection()],
        "loaded_shape_zyx": [int(v) for v in array.shape],
    }
    return array, metadata


def _interpolate(array: np.ndarray, size: tuple[int, ...] | list[int], mode: str) -> np.ndarray:
    """Resize ``array`` to ``size`` with ``F.interpolate`` (``align_corners=False``).

    The array is wrapped as a ``(1, 1, *array.shape)`` tensor, so ``size`` must
    have ``array.ndim`` entries. The result is returned as a float32 NumPy array.
    """
    # ``torch.from_numpy`` rejects negative strides (e.g. ``np.flip`` views); this copies only when needed.
    tensor = torch.from_numpy(np.ascontiguousarray(array, dtype=np.float32)[None, None])
    resized = F.interpolate(tensor, size=tuple(int(v) for v in size), mode=mode, align_corners=False)
    return resized[0, 0].cpu().numpy().astype(np.float32, copy=False)


def _resample_array_zyx(
    array: np.ndarray,
    input_spacing_xyz: tuple[float, float, float] | None,
    target_spacing_xyz: tuple[float, float, float],
) -> np.ndarray:
    """Trilinearly resample a ``(z, y, x)`` array from one voxel spacing to another.

    Both spacings are given in ``(x, y, z)`` order and reversed internally to
    match the array axes. Each output axis has ``round(size * spacing / target)``
    voxels, with a minimum of 1. If ``input_spacing_xyz`` is None, the array is
    returned unchanged.
    """
    if input_spacing_xyz is None:
        return array
    spacing_zyx = [float(v) for v in input_spacing_xyz[::-1]]
    target_zyx = [float(v) for v in target_spacing_xyz[::-1]]
    out_shape = [
        max(1, int(round(size * spacing / target)))
        for size, spacing, target in zip(array.shape, spacing_zyx, target_zyx)
    ]
    return _interpolate(array, out_shape, mode="trilinear")


def _clip_zscore(
    array: np.ndarray,
    clip_range: tuple[float, float],
    eps: float,
) -> tuple[np.ndarray, dict[str, float]]:
    """Clip intensities to ``clip_range``, then z-score with the clipped mean/std.

    If the std is below ``eps`` (e.g. a constant image), it is replaced by 1.0,
    so the output is only mean-centered. Returns the normalized array and the
    statistics that were used.
    """
    clipped = np.clip(array.astype(np.float32, copy=False), clip_range[0], clip_range[1])
    mean = float(clipped.mean())
    std = float(clipped.std())
    if std < eps:
        std = 1.0
    normalized = (clipped - mean) / std
    return normalized.astype(np.float32, copy=False), {
        "clip_min": float(clip_range[0]),
        "clip_max": float(clip_range[1]),
        "mean": mean,
        "std": std,
    }


def _pad_to_shape(
    array: np.ndarray,
    target_shape: tuple[int, ...],
    fill_value: float,
) -> tuple[np.ndarray, list[int], list[int]]:
    """Center-pad ``array`` up to ``target_shape`` with a constant ``fill_value``.

    Axes already at least as large as the target are left as-is (no cropping).
    For odd padding amounts, the extra voxel goes after.

    Returns:
        ``(padded, pad_before, pad_after)`` with per-axis pad amounts.
    """
    pads: list[tuple[int, int]] = []
    for size, target in zip(array.shape, target_shape):
        total = max(0, int(target) - int(size))
        pads.append((total // 2, total - total // 2))
    if any(before or after for before, after in pads):
        array = np.pad(array, pads, mode="constant", constant_values=float(fill_value))
    pad_before = [before for before, _ in pads]
    pad_after = [after for _, after in pads]
    return array.astype(np.float32, copy=False), pad_before, pad_after


def _center_crop(array: np.ndarray, target_shape: tuple[int, ...]) -> tuple[np.ndarray, list[int]]:
    """Center-crop ``array`` to ``target_shape``; returns the crop and its start indices.

    Axes smaller than the target are not padded, so callers pad first.
    """
    starts = [max(0, (int(size) - int(target)) // 2) for size, target in zip(array.shape, target_shape)]
    slices = tuple(slice(start, start + int(target)) for start, target in zip(starts, target_shape))
    return array[slices].astype(np.float32, copy=False), starts


def _resize_2d(array: np.ndarray, output_size: int) -> np.ndarray:
    """Bilinearly resize a 2D array to ``output_size`` x ``output_size``."""
    return _interpolate(array, (output_size, output_size), mode="bilinear")


def _resize_3d(array: np.ndarray, output_size: tuple[int, int, int]) -> np.ndarray:
    """Trilinearly resize a 3D array to ``output_size``."""
    return _interpolate(array, output_size, mode="trilinear")


def _infer_ndim(images: Any) -> int | None:
    """Return the number of dimensions of ``images``, or None if it is not array-like.

    Paths count as 0-d. Tensors report ``dim()`` directly, so CUDA tensors and
    tensors that require grad never go through NumPy. Ragged sequences (for
    example volumes of different shapes) return None.
    """
    if isinstance(images, (str, Path)):
        return 0
    if isinstance(images, torch.Tensor):
        return images.dim()
    if isinstance(images, np.ndarray):
        return images.ndim
    try:
        return np.asarray(images).ndim
    except (ValueError, RuntimeError):
        # NumPy refuses inhomogeneous sequences; torch refuses grad-tracking tensors.
        return None


def _listify_images(images: Any, spatial_dims: int) -> list[Any]:
    """Split ``images`` into a list of per-sample inputs.

    A single path or a ``spatial_dims``-D array is one sample. An array with one
    extra leading dimension is split along that dimension. Anything else
    (including ragged sequences) is converted with ``list()``.
    """
    if isinstance(images, (str, Path)):
        return [images]
    ndim = _infer_ndim(images)
    if ndim == spatial_dims:
        return [images]
    if ndim == spatial_dims + 1:
        return [sample for sample in images]
    return list(images)


class FlexiCTImageProcessor(ImageProcessingMixin):
    """Preprocess CT arrays or image paths for FlexiCT model variants.

    3D / VLM variants: load, optionally resample to ``target_spacing``, clip to
    ``clip_range`` and z-score, then apply the ``preset``:

    * ``"default"``: center-pad then center-crop to ``image_size``.
    * ``"local_path"``: pad to a cube, then trilinearly resize to ``image_size``.
    * ``"retrieval_roi"``: crop an ROI (from ``bbox``, ``mask``, ``roi_center``
      or the volume center), pad it, then trilinearly resize to ``image_size``.

    2D variant: take a 2D slice (or one slice of a 3D volume, the middle one by
    default), clip and z-score it, pad it to a square, and bilinearly resize it
    to ``image_size`` x ``image_size``.

    Args:
        model_variant: ``"2d"``, ``"3d"`` or ``"vlm"`` (case-insensitive).
        preset: ``"default"``, ``"local_path"`` or ``"retrieval_roi"`` (3D only).
        image_size: Output size. Defaults to 512 for 2D and ``[160, 160, 160]``
            otherwise. A single int is broadcast to a cube for 3D.
        clip_range: ``(min, max)`` intensity clip applied before z-scoring.
        target_spacing: Resampling spacing in ``(x, y, z)`` order.
        do_resample: Resample 3D inputs when a spacing is known.
        do_orient_lps: Reorient path inputs to LPS when loading them.
        eps: Std threshold below which z-scoring only mean-centers.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        model_variant: str = "3d",
        preset: str = "default",
        image_size: int | list[int] | tuple[int, ...] | None = None,
        clip_range: list[float] | tuple[float, float] = (-1000.0, 1000.0),
        target_spacing: list[float] | tuple[float, float, float] = (2.0, 2.0, 2.0),
        do_resample: bool = True,
        do_orient_lps: bool = True,
        eps: float = 1e-6,
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        model_variant = model_variant.lower()
        if model_variant not in {"2d", "3d", "vlm"}:
            raise ValueError("model_variant must be one of '2d', '3d', or 'vlm'")
        if preset not in {"default", "local_path", "retrieval_roi"}:
            raise ValueError("preset must be 'default', 'local_path', or 'retrieval_roi'")

        self.model_variant = model_variant
        self.preset = preset
        if image_size is None:
            image_size = 512 if model_variant == "2d" else [160, 160, 160]
        self.image_size = list(image_size) if isinstance(image_size, (list, tuple)) else int(image_size)
        self.clip_range = [float(clip_range[0]), float(clip_range[1])]
        self.target_spacing = [float(v) for v in target_spacing]
        self.do_resample = bool(do_resample)
        self.do_orient_lps = bool(do_orient_lps)
        self.eps = float(eps)

    def __call__(
        self,
        images: Any,
        return_tensors: str | None = "pt",
        return_metadata: bool = False,
        **kwargs: Any,
    ) -> BatchFeature:
        """Preprocess one image or a batch of images.

        Args:
            images: A path, an array/tensor, a batch array/tensor with a leading
                batch axis, or a sequence of any of these.
            return_tensors: ``"pt"`` for a torch tensor; ``"np"`` or None for NumPy.
            return_metadata: Also return a per-sample list of metadata dicts.
            **kwargs: Forwarded to ``_process_2d`` (``slice_index``, ``slice_axis``)
                or ``_process_3d`` (``input_spacing``, ``roi_center``, ``roi_size``,
                ``bbox``, ``mask``).

        Returns:
            A ``BatchFeature`` whose ``pixel_values`` has shape
            ``(batch, 1, *spatial)``.

        Raises:
            ValueError: On an unsupported ``return_tensors`` or empty input.
        """
        # Validate before doing the (potentially expensive) per-sample work.
        if return_tensors not in {"pt", "np", None}:
            raise ValueError("return_tensors must be 'pt', 'np', or None")

        spatial_dims = 2 if self.model_variant == "2d" and _infer_ndim(images) == 2 else 3
        samples = _listify_images(images, spatial_dims=spatial_dims)
        if not samples:
            raise ValueError("No images were provided to FlexiCTImageProcessor.")

        process = self._process_2d if self.model_variant == "2d" else self._process_3d
        processed = []
        metadata = []
        for sample in samples:
            array, meta = process(sample, **kwargs)
            processed.append(array[None])  # add the channel axis
            metadata.append(meta)

        batch_array = np.stack(processed, axis=0).astype(np.float32, copy=False)
        pixel_values = torch.from_numpy(batch_array) if return_tensors == "pt" else batch_array
        data: dict[str, Any] = {"pixel_values": pixel_values}
        if return_metadata:
            data["metadata"] = metadata
        return BatchFeature(data=data)

    def _output_size_2d(self) -> int:
        """Return the square output side for the 2D variant.

        Accepts an int, or a list/tuple whose entries are all equal.
        """
        size = self.image_size
        if isinstance(size, (list, tuple)):
            if len({int(v) for v in size}) != 1:
                raise ValueError(f"FlexiCT-2D expects a square image_size, got {size}")
            size = size[0]
        return int(size)

    def _target_shape_3d(self) -> tuple[int, ...]:
        """Return the 3D output shape; a scalar ``image_size`` is broadcast to a cube."""
        if isinstance(self.image_size, (list, tuple)):
            return tuple(int(v) for v in self.image_size)
        return (int(self.image_size),) * 3

    def _process_2d(self, image: Any, slice_index: int | None = None, slice_axis: int = 0, **_: Any):
        """Produce one normalized square 2D slice and its metadata.

        A 3D input is reduced to the slice at ``slice_index`` along
        ``slice_axis`` (the middle slice when ``slice_index`` is None).
        """
        array, metadata = _as_float_array(image, orient_lps=self.do_orient_lps)
        metadata["original_shape"] = [int(v) for v in array.shape]
        if array.ndim == 3:
            if slice_index is None:
                slice_index = array.shape[slice_axis] // 2
            array = np.take(array, int(slice_index), axis=int(slice_axis))
            metadata["slice_index"] = int(slice_index)
            metadata["slice_axis"] = int(slice_axis)
        if array.ndim != 2:
            raise ValueError(f"FlexiCT-2D expects a 2D slice or 3D volume, got shape {array.shape}")

        array, stats = _clip_zscore(array, tuple(self.clip_range), self.eps)
        side = max(array.shape)
        # Pad with the normalized minimum so the padding looks like the darkest clipped tissue.
        array, pad_before, pad_after = _pad_to_shape(array, (side, side), float(array.min()))
        output_size = self._output_size_2d()
        array = _resize_2d(array, output_size)
        metadata.update(stats)
        metadata.update(
            {
                "pad_before_yx": pad_before,
                "pad_after_yx": pad_after,
                "processed_shape_yx": [output_size, output_size],
            }
        )
        return array, metadata

    def _process_3d(
        self,
        image: Any,
        input_spacing: tuple[float, float, float] | None = None,
        roi_center: tuple[int, int, int] | None = None,
        roi_size: int | tuple[int, int, int] | None = None,
        bbox: tuple[int, int, int, int, int, int] | None = None,
        mask: Any | None = None,
        **_: Any,
    ):
        """Normalize one 3D volume, then dispatch to the configured preset.

        ``input_spacing`` (``(x, y, z)``) overrides the spacing read from a
        file; resampling only happens when ``do_resample`` is set and a spacing
        is known.
        """
        array, metadata = _as_float_array(image, orient_lps=self.do_orient_lps)
        if array.ndim != 3:
            raise ValueError(f"FlexiCT-3D expects a 3D volume, got shape {array.shape}")
        # Record the shape BEFORE resampling; previously this captured the resampled shape.
        metadata["original_shape_zyx"] = [int(v) for v in array.shape]
        if input_spacing is None and "spacing_xyz" in metadata:
            input_spacing = tuple(metadata["spacing_xyz"])
        if self.do_resample and input_spacing is not None:
            array = _resample_array_zyx(array, input_spacing, tuple(self.target_spacing))
            metadata["resampled_shape_zyx"] = [int(v) for v in array.shape]

        array, stats = _clip_zscore(array, tuple(self.clip_range), self.eps)
        metadata.update(stats)
        target_shape = self._target_shape_3d()
        if self.preset == "default":
            return self._default_3d(array, target_shape, metadata)
        if self.preset == "local_path":
            return self._local_path_3d(array, target_shape, metadata)
        return self._retrieval_roi_3d(array, target_shape, metadata, roi_center, roi_size, bbox, mask)

    def _default_3d(self, array: np.ndarray, target_shape: tuple[int, int, int], metadata: dict[str, Any]):
        """Center-pad, then center-crop, to exactly ``target_shape`` (no resizing)."""
        array, pad_before, pad_after = _pad_to_shape(array, target_shape, float(array.min()))
        array, crop_start = _center_crop(array, target_shape)
        metadata.update(
            {
                "pad_before_zyx": pad_before,
                "pad_after_zyx": pad_after,
                "crop_start_zyx": crop_start,
                "processed_shape_zyx": [int(v) for v in array.shape],
            }
        )
        return array, metadata

    def _local_path_3d(self, array: np.ndarray, target_shape: tuple[int, int, int], metadata: dict[str, Any]):
        """Pad to a cube of the longest side, then trilinearly resize to ``target_shape``."""
        side = max(int(v) for v in array.shape)
        array, pad_before, pad_after = _pad_to_shape(array, (side, side, side), float(array.min()))
        metadata.update(
            {
                "cubic_pad_before_zyx": pad_before,
                "cubic_pad_after_zyx": pad_after,
                "cubic_padded_shape_zyx": [int(v) for v in array.shape],
            }
        )
        array = _resize_3d(array, target_shape)
        metadata.update({"processed_shape_zyx": [int(v) for v in array.shape], "resize_mode": "trilinear"})
        return array, metadata

    def _retrieval_roi_3d(
        self,
        array: np.ndarray,
        target_shape: tuple[int, int, int],
        metadata: dict[str, Any],
        roi_center: tuple[int, int, int] | None,
        roi_size: int | tuple[int, int, int] | None,
        bbox: tuple[int, int, int, int, int, int] | None,
        mask: Any | None,
    ):
        """Crop an ROI of ``roi_size`` (default ``target_shape``), pad out-of-volume parts, and resize.

        The ROI center comes from, in priority order: the midpoint of
        ``bbox`` (``z0, y0, x0, z1, y1, x1``), the rounded centroid of
        ``mask > 0``, ``roi_center``, or the volume center.

        NOTE(review): ``_process_3d`` resamples before calling this method, so
        ``bbox`` / ``mask`` / ``roi_center`` are interpreted in the resampled
        voxel grid. If callers compute them on the input volume while
        resampling is active, they will not line up. TODO: confirm the intended
        coordinate space.

        Raises:
            ValueError: If ``roi_size`` is not positive or ``mask`` is empty.
        """
        if roi_size is None:
            roi_size = target_shape
        # ``np.ndim(...) == 0`` also accepts NumPy integer scalars, which ``isinstance(x, int)`` rejects.
        if np.ndim(roi_size) == 0:
            roi_shape = (int(roi_size),) * 3
        else:
            roi_shape = tuple(int(v) for v in roi_size)
        if any(size <= 0 for size in roi_shape):
            raise ValueError(f"roi_size must be positive, got {roi_shape}")

        if bbox is not None:
            z0, y0, x0, z1, y1, x1 = (int(v) for v in bbox)
            roi_center = ((z0 + z1) // 2, (y0 + y1) // 2, (x0 + x1) // 2)
        elif mask is not None:
            if isinstance(mask, torch.Tensor):
                mask = mask.detach().cpu().numpy()
            coords = np.argwhere(np.asarray(mask) > 0)
            if coords.size == 0:
                raise ValueError("mask does not contain any foreground voxels")
            roi_center = tuple(int(v) for v in coords.mean(axis=0).round())
        elif roi_center is None:
            roi_center = tuple(int(v // 2) for v in array.shape)

        # Split each axis of the requested window [start, start + size) into the part
        # inside the volume (copied) and the parts outside it (padded). This stays
        # consistent even when the window lies entirely outside the volume.
        src_starts: list[int] = []
        src_ends: list[int] = []
        pad_before: list[int] = []
        pad_after: list[int] = []
        for center, size, dim in zip(roi_center, roi_shape, array.shape):
            start = int(center) - size // 2
            end = start + size
            src_start = min(max(start, 0), dim)
            src_len = max(0, min(end, dim) - max(start, 0))
            before = min(size, max(0, -start))
            src_starts.append(src_start)
            src_ends.append(src_start + src_len)
            pad_before.append(before)
            pad_after.append(size - before - src_len)

        crop = array[tuple(slice(start, end) for start, end in zip(src_starts, src_ends))]
        crop = np.pad(
            crop,
            tuple(zip(pad_before, pad_after)),
            mode="constant",
            constant_values=float(array.min()),
        ).astype(np.float32, copy=False)
        resized = _resize_3d(crop, target_shape)
        metadata.update(
            {
                "roi_center_zyx": [int(v) for v in roi_center],
                "roi_crop_start_zyx": src_starts,
                "roi_crop_end_zyx": src_ends,
                "roi_pad_before_zyx": pad_before,
                "roi_pad_after_zyx": pad_after,
                "roi_padded_shape_zyx": [int(v) for v in crop.shape],
                "processed_shape_zyx": [int(v) for v in resized.shape],
                "resize_mode": "trilinear",
            }
        )
        return resized, metadata