# FlexiCT-2D / image_processing_flexict.py
# Uploaded by ricklisz123 using huggingface_hub (commit fcefac1, verified).
"""Image processors for FlexiCT Hugging Face model repos."""
from __future__ import annotations
from pathlib import Path
from typing import Any
import numpy as np
import torch
import torch.nn.functional as F
from transformers import BatchFeature
from transformers.image_processing_utils import ImageProcessingMixin
def _as_float_array(image: Any) -> tuple[np.ndarray, dict[str, Any]]:
    """Convert one input sample to a float32 NumPy array plus metadata.

    Paths (``str`` / ``Path``) are delegated to ``_load_medical_image_array``.
    Torch tensors are detached and moved to the CPU. Anything else goes through
    ``np.asarray(..., dtype=np.float32)``.

    Returns:
        ``(array, metadata)``. For in-memory inputs the metadata is
        ``{"source": "array"}``.
    """
    if isinstance(image, (str, Path)):
        return _load_medical_image_array(image)
    if isinstance(image, torch.Tensor):
        # Cast inside torch first: ``Tensor.numpy()`` raises for dtypes NumPy
        # has no equivalent for (e.g. bfloat16). For CPU float32 tensors
        # ``.to`` is a no-op, so no extra copy is made.
        image = image.detach().to(device="cpu", dtype=torch.float32).numpy()
    array = np.asarray(image, dtype=np.float32)
    return array, {"source": "array"}
def _load_medical_image_array(path: str | Path) -> tuple[np.ndarray, dict[str, Any]]:
    """Read a CT volume with SimpleITK and reorient it to LPS.

    Returns the voxels as a float32 array in SimpleITK array order (the shape
    is recorded as ``loaded_shape_zyx``). Also returns metadata with the
    spacing, origin and direction of the *reoriented* image, in SimpleITK's
    ``(x, y, z)`` convention.

    Raises:
        RuntimeError: if SimpleITK is not installed.
    """
    try:
        import SimpleITK as sitk
    except ImportError as exc:  # pragma: no cover - runtime dependency branch.
        raise RuntimeError("SimpleITK is required to load CT paths with FlexiCTImageProcessor.") from exc

    resolved = str(Path(path).expanduser())
    oriented = sitk.DICOMOrient(sitk.ReadImage(resolved), "LPS")
    voxels = sitk.GetArrayFromImage(oriented).astype(np.float32, copy=False)

    def _floats(values):
        # SimpleITK returns tuples of C doubles; store plain Python floats.
        return [float(v) for v in values]

    info = {
        "source": str(path),
        "spacing_xyz": _floats(oriented.GetSpacing()),
        "origin_xyz": _floats(oriented.GetOrigin()),
        "direction": _floats(oriented.GetDirection()),
        "loaded_shape_zyx": [int(v) for v in voxels.shape],
    }
    return voxels, info
def _resample_array_zyx(
    array: np.ndarray,
    input_spacing_xyz: tuple[float, float, float] | None,
    target_spacing_xyz: tuple[float, float, float],
) -> np.ndarray:
    """Trilinearly resample a ``(z, y, x)`` volume to a new voxel spacing.

    Both spacings are given in ``(x, y, z)`` order and reversed here to line up
    with the array axes. Each output axis has
    ``round(size * spacing / target)`` voxels, with a minimum of 1.

    Returns:
        The resampled float32 array, or ``array`` unchanged when
        ``input_spacing_xyz`` is None.
    """
    if input_spacing_xyz is None:
        return array
    spacing_zyx = tuple(float(v) for v in input_spacing_xyz[::-1])
    target_zyx = tuple(float(v) for v in target_spacing_xyz[::-1])
    out_shape = [
        max(1, int(round(size * spacing / target)))
        for size, spacing, target in zip(array.shape, spacing_zyx, target_zyx)
    ]
    # torch.from_numpy rejects negative strides (np.flip / [::-1] views), and
    # astype(copy=False) does not fix that for float32 input. ascontiguousarray
    # copies only when actually needed.
    source = np.ascontiguousarray(array, dtype=np.float32)
    tensor = torch.from_numpy(source)[None, None]
    resized = F.interpolate(tensor, size=out_shape, mode="trilinear", align_corners=False)
    return resized[0, 0].cpu().numpy()
def _clip_zscore(
    array: np.ndarray,
    clip_range: tuple[float, float],
    eps: float,
) -> tuple[np.ndarray, dict[str, float]]:
    """Clip intensities to ``clip_range``, then standardize with the clipped mean/std.

    If the clipped standard deviation is below ``eps`` (e.g. a constant image),
    the divisor falls back to 1.0, so the output is only mean-centred.

    Returns:
        The normalized float32 array and a stats dict with keys ``clip_min``,
        ``clip_max``, ``mean`` and ``std``. ``std`` is the divisor actually used.
    """
    low, high = clip_range[0], clip_range[1]
    clipped = np.clip(array.astype(np.float32, copy=False), low, high)
    mean = float(clipped.mean())
    spread = float(clipped.std())
    divisor = 1.0 if spread < eps else spread
    stats = {
        "clip_min": float(low),
        "clip_max": float(high),
        "mean": mean,
        "std": divisor,
    }
    standardized = (clipped - mean) / divisor
    return standardized.astype(np.float32, copy=False), stats
def _pad_to_shape(
    array: np.ndarray,
    target_shape: tuple[int, ...],
    fill_value: float,
) -> tuple[np.ndarray, list[int], list[int]]:
    """Pad ``array`` symmetrically with ``fill_value`` up to ``target_shape``.

    Axes already at least as large as the target are left alone; this function
    never crops. With an odd amount of padding, the extra element goes after
    the data.

    Returns:
        ``(padded float32 array, pad_before, pad_after)``, with one entry per
        padded axis.
    """
    widths = []
    for current, wanted in zip(array.shape, target_shape):
        deficit = max(0, int(wanted) - int(current))
        widths.append((deficit // 2, deficit - deficit // 2))
    # Skip np.pad (and its copy) when nothing needs padding.
    if any(lead or trail for lead, trail in widths):
        array = np.pad(array, widths, mode="constant", constant_values=float(fill_value))
    pad_before = [lead for lead, _ in widths]
    pad_after = [trail for _, trail in widths]
    return array.astype(np.float32, copy=False), pad_before, pad_after
def _center_crop(array: np.ndarray, target_shape: tuple[int, ...]) -> tuple[np.ndarray, list[int]]:
    """Take the centred ``target_shape`` window of ``array``.

    Axes smaller than the target are kept whole, starting at 0, so the result
    may be smaller than ``target_shape`` there.

    Returns:
        The float32 crop and the per-axis start indices.
    """
    starts: list[int] = []
    window: list[slice] = []
    for extent, wanted in zip(array.shape, target_shape):
        begin = max(0, (int(extent) - int(wanted)) // 2)
        starts.append(begin)
        window.append(slice(begin, begin + int(wanted)))
    return array[tuple(window)].astype(np.float32, copy=False), starts
def _resize_2d(array: np.ndarray, output_size: int) -> np.ndarray:
    """Resize a 2D array to ``(output_size, output_size)`` with bilinear interpolation.

    Uses ``F.interpolate(mode="bilinear", align_corners=False)`` and returns a
    float32 array.
    """
    # torch.from_numpy rejects negative strides (np.flip / [::-1] views), so
    # normalize to a C-contiguous float32 buffer. This is a no-op when the
    # array already is one.
    source = np.ascontiguousarray(array, dtype=np.float32)
    tensor = torch.from_numpy(source)[None, None]
    resized = F.interpolate(tensor, size=(output_size, output_size), mode="bilinear", align_corners=False)
    return resized[0, 0].cpu().numpy()
def _resize_3d(array: np.ndarray, output_size: tuple[int, int, int]) -> np.ndarray:
    """Resize a 3D array to ``output_size`` with trilinear interpolation.

    Uses ``F.interpolate(mode="trilinear", align_corners=False)`` and returns a
    float32 array.
    """
    # torch.from_numpy rejects negative strides (np.flip / [::-1] views), so
    # normalize to a C-contiguous float32 buffer. This is a no-op when the
    # array already is one.
    source = np.ascontiguousarray(array, dtype=np.float32)
    tensor = torch.from_numpy(source)[None, None]
    resized = F.interpolate(tensor, size=output_size, mode="trilinear", align_corners=False)
    return resized[0, 0].cpu().numpy()
def _input_ndim(images: Any) -> int:
    """Return the array rank of ``images`` without building a NumPy copy of it.

    Paths count as rank 0. A list or tuple, possibly ragged, counts as one more
    dimension than its first element. That is enough to tell a single sample
    apart from a batch of samples.
    """
    if isinstance(images, (str, Path)):
        return 0
    if isinstance(images, torch.Tensor):
        return images.dim()
    if isinstance(images, np.ndarray):
        return images.ndim
    if isinstance(images, (list, tuple)):
        return 1 + _input_ndim(images[0]) if images else 1
    return int(np.ndim(images))


def _listify_images(images: Any, spatial_dims: int) -> list[Any]:
    """Split ``images`` into a list of per-sample inputs.

    A single path, or an input whose rank equals ``spatial_dims``, is one
    sample. Anything else is iterated along its first axis. The rank is
    computed without ``np.asarray`` on the whole input, which would crash on
    ragged lists of differently-sized volumes and on CUDA or requires-grad
    tensors inside lists, and would copy the entire batch.
    """
    if isinstance(images, (str, Path)):
        return [images]
    if _input_ndim(images) == spatial_dims:
        return [images]
    return list(images)
class FlexiCTImageProcessor(ImageProcessingMixin):
    """Preprocess CT arrays or image paths for FlexiCT model variants.

    Every sample is clipped to ``clip_range``, z-score normalized, and brought
    to a fixed spatial size:

    * ``model_variant="2d"``: a 2D slice, or one slice of a 3D volume, is padded
      to a square and bilinearly resized to ``image_size``.
    * ``model_variant`` ``"3d"`` / ``"vlm"``: a 3D volume is optionally
      resampled to ``target_spacing`` and then shaped according to ``preset``:
      ``"default"`` pads and centre-crops, ``"local_path"`` pads to a cube then
      resizes, and ``"retrieval_roi"`` crops a region of interest then resizes.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        model_variant: str = "3d",
        preset: str = "default",
        image_size: int | list[int] | tuple[int, ...] | None = None,
        clip_range: list[float] | tuple[float, float] = (-1000.0, 1000.0),
        target_spacing: list[float] | tuple[float, float, float] = (2.0, 2.0, 2.0),
        do_resample: bool = True,
        do_orient_lps: bool = True,
        eps: float = 1e-6,
        **kwargs: Any,
    ):
        """Validate and store the preprocessing options.

        Args:
            model_variant: ``"2d"``, ``"3d"`` or ``"vlm"`` (case-insensitive).
                ``"vlm"`` uses the 3D pipeline.
            preset: 3D shaping strategy: ``"default"``, ``"local_path"`` or
                ``"retrieval_roi"``. The 2D pipeline does not consult it.
            image_size: Output size. Defaults to 512 for 2D (square side) and
                ``[160, 160, 160]`` otherwise. For 3D, a single int means a cube.
            clip_range: ``(min, max)`` intensity window applied before
                normalization.
            target_spacing: Resampling target in ``(x, y, z)`` order, the same
                order as the loader's ``spacing_xyz``.
            do_resample: Resample 3D inputs whenever their spacing is known.
            do_orient_lps: Stored on the instance only (see the note below).
            eps: Standard deviations below this fall back to 1 during
                normalization.
            **kwargs: Forwarded to ``ImageProcessingMixin``.

        Raises:
            ValueError: on an unknown ``model_variant`` or ``preset``.
        """
        super().__init__(**kwargs)
        model_variant = model_variant.lower()
        if model_variant not in {"2d", "3d", "vlm"}:
            raise ValueError("model_variant must be one of '2d', '3d', or 'vlm'")
        if preset not in {"default", "local_path", "retrieval_roi"}:
            raise ValueError("preset must be 'default', 'local_path', or 'retrieval_roi'")
        self.model_variant = model_variant
        self.preset = preset
        if image_size is None:
            image_size = 512 if model_variant == "2d" else [160, 160, 160]
        self.image_size = list(image_size) if isinstance(image_size, (list, tuple)) else int(image_size)
        self.clip_range = [float(clip_range[0]), float(clip_range[1])]
        self.target_spacing = [float(v) for v in target_spacing]
        self.do_resample = bool(do_resample)
        # NOTE(review): no code in this module reads do_orient_lps. Path
        # inputs are always reoriented to LPS by the loader. Confirm whether
        # this flag is meant to gate that.
        self.do_orient_lps = bool(do_orient_lps)
        self.eps = float(eps)

    def __call__(
        self,
        images: Any,
        return_tensors: str | None = "pt",
        return_metadata: bool = False,
        **kwargs: Any,
    ) -> BatchFeature:
        """Preprocess one image or volume, or a batch of them.

        Args:
            images: A path, array or tensor, or a sequence / stacked batch of
                those.
            return_tensors: ``"pt"`` returns a torch tensor; ``"np"`` or None
                returns a NumPy array.
            return_metadata: Also return a per-sample list of metadata dicts.
            **kwargs: Per-call options forwarded to every sample. For 2D these
                are ``slice_index`` and ``slice_axis``; for 3D they are
                ``input_spacing``, ``roi_center``, ``roi_size``, ``bbox`` and
                ``mask``. The same values apply to all samples in the batch.

        Returns:
            ``BatchFeature`` whose ``pixel_values`` has shape
            ``(batch, 1, *spatial)`` and dtype float32.

        Raises:
            ValueError: if ``return_tensors`` is not ``"pt"``, ``"np"`` or None,
                or a sample has an unsupported rank.
        """
        # Validate up front so a bad option fails before the whole batch is
        # processed. A tuple is used (compared with ==) rather than a set, so
        # str-based enums such as transformers' TensorType are accepted.
        if return_tensors is not None and return_tensors not in ("pt", "np"):
            raise ValueError("return_tensors must be 'pt', 'np', or None")
        # Only the 2D variant accepts bare 2D slices. _input_ndim avoids
        # np.asarray, which fails on CUDA or requires-grad tensors and ragged
        # lists.
        spatial_dims = 2 if self.model_variant == "2d" and self._input_ndim(images) == 2 else 3
        samples = _listify_images(images, spatial_dims=spatial_dims)
        process = self._process_2d if self.model_variant == "2d" else self._process_3d
        processed = []
        metadata = []
        for sample in samples:
            array, meta = process(sample, **kwargs)
            processed.append(array[None])  # add the channel axis
            metadata.append(meta)
        batch_array = np.stack(processed, axis=0).astype(np.float32, copy=False)
        pixel_values = torch.from_numpy(batch_array) if return_tensors == "pt" else batch_array
        data: dict[str, Any] = {"pixel_values": pixel_values}
        if return_metadata:
            data["metadata"] = metadata
        return BatchFeature(data=data)

    @staticmethod
    def _input_ndim(images: Any) -> int:
        """Return the array rank of ``images`` without copying it into NumPy.

        Paths count as rank 0. A (possibly ragged) list or tuple counts as one
        more dimension than its first element.
        """
        if isinstance(images, (str, Path)):
            return 0
        if isinstance(images, torch.Tensor):
            return images.dim()
        if isinstance(images, np.ndarray):
            return images.ndim
        if isinstance(images, (list, tuple)):
            return 1 + FlexiCTImageProcessor._input_ndim(images[0]) if images else 1
        return int(np.ndim(images))

    def _target_shape_3d(self) -> tuple[int, int, int]:
        """Return ``image_size`` as a 3-tuple; a single int (or 1-list) means a cube.

        Raises:
            ValueError: if ``image_size`` has neither 1 nor 3 entries.
        """
        size = self.image_size
        dims = tuple(int(v) for v in size) if isinstance(size, (list, tuple)) else (int(size),)
        if len(dims) == 1:
            dims = dims * 3
        if len(dims) != 3:
            raise ValueError(f"3D image_size must have 1 or 3 entries, got {size!r}")
        return dims

    def _process_2d(self, image: Any, slice_index: int | None = None, slice_axis: int = 0, **_: Any):
        """Slice (if needed), normalize, pad to a square and resize one 2D sample.

        For a 3D input, the slice at ``slice_index`` along ``slice_axis`` is
        used; by default this is the middle slice along axis 0.

        Raises:
            ValueError: if the (sliced) input is not 2D.
        """
        array, metadata = _as_float_array(image)
        metadata["original_shape"] = [int(v) for v in array.shape]
        if array.ndim == 3:
            if slice_index is None:
                slice_index = array.shape[slice_axis] // 2
            array = np.take(array, int(slice_index), axis=int(slice_axis))
            metadata["slice_index"] = int(slice_index)
            metadata["slice_axis"] = int(slice_axis)
        if array.ndim != 2:
            raise ValueError(f"FlexiCT-2D expects a 2D slice or 3D volume, got shape {array.shape}")
        array, stats = _clip_zscore(array, tuple(self.clip_range), self.eps)
        # Pad to a square with the normalized minimum so the resize keeps the
        # aspect ratio.
        side = max(array.shape)
        array, pad_before, pad_after = _pad_to_shape(array, (side, side), float(array.min()))
        output_size = int(self.image_size)
        array = _resize_2d(array, output_size)
        metadata.update(stats)
        metadata.update(
            {
                "pad_before_yx": pad_before,
                "pad_after_yx": pad_after,
                "processed_shape_yx": [output_size, output_size],
            }
        )
        return array, metadata

    def _process_3d(
        self,
        image: Any,
        input_spacing: tuple[float, float, float] | None = None,
        roi_center: tuple[int, int, int] | None = None,
        roi_size: int | tuple[int, int, int] | None = None,
        bbox: tuple[int, int, int, int, int, int] | None = None,
        mask: Any | None = None,
        **_: Any,
    ):
        """Resample (optionally), normalize and shape one 3D volume per ``preset``.

        ``input_spacing`` (``(x, y, z)`` order) overrides the spacing read from
        a loaded file.

        Raises:
            ValueError: if the input is not 3D, or ``image_size`` is invalid.
        """
        array, metadata = _as_float_array(image)
        if array.ndim != 3:
            raise ValueError(f"FlexiCT-3D expects a 3D volume, got shape {array.shape}")
        # Capture the shape *before* resampling. Previously this was recorded
        # afterwards and always equalled resampled_shape_zyx. Pad/crop offsets
        # below are relative to resampled_shape_zyx when that key is present.
        metadata["original_shape_zyx"] = [int(v) for v in array.shape]
        if input_spacing is None and "spacing_xyz" in metadata:
            input_spacing = tuple(metadata["spacing_xyz"])
        if self.do_resample and input_spacing is not None:
            array = _resample_array_zyx(array, input_spacing, tuple(self.target_spacing))
            metadata["resampled_shape_zyx"] = [int(v) for v in array.shape]
        array, stats = _clip_zscore(array, tuple(self.clip_range), self.eps)
        metadata.update(stats)
        target_shape = self._target_shape_3d()
        if self.preset == "default":
            return self._default_3d(array, target_shape, metadata)
        if self.preset == "local_path":
            return self._local_path_3d(array, target_shape, metadata)
        return self._retrieval_roi_3d(array, target_shape, metadata, roi_center, roi_size, bbox, mask)

    def _default_3d(self, array: np.ndarray, target_shape: tuple[int, int, int], metadata: dict[str, Any]):
        """Pad (with the array minimum) and then centre-crop to ``target_shape``."""
        array, pad_before, pad_after = _pad_to_shape(array, target_shape, float(array.min()))
        array, crop_start = _center_crop(array, target_shape)
        metadata.update(
            {
                "pad_before_zyx": pad_before,
                "pad_after_zyx": pad_after,
                "crop_start_zyx": crop_start,
                "processed_shape_zyx": [int(v) for v in array.shape],
            }
        )
        return array, metadata

    def _local_path_3d(self, array: np.ndarray, target_shape: tuple[int, int, int], metadata: dict[str, Any]):
        """Pad to a cube (with the array minimum), then resize trilinearly to ``target_shape``."""
        side = max(int(v) for v in array.shape)
        array, pad_before, pad_after = _pad_to_shape(array, (side, side, side), float(array.min()))
        metadata.update(
            {
                "cubic_pad_before_zyx": pad_before,
                "cubic_pad_after_zyx": pad_after,
                "cubic_padded_shape_zyx": [int(v) for v in array.shape],
            }
        )
        array = _resize_3d(array, target_shape)
        metadata.update({"processed_shape_zyx": [int(v) for v in array.shape], "resize_mode": "trilinear"})
        return array, metadata

    def _retrieval_roi_3d(
        self,
        array: np.ndarray,
        target_shape: tuple[int, int, int],
        metadata: dict[str, Any],
        roi_center: tuple[int, int, int] | None,
        roi_size: int | tuple[int, int, int] | None,
        bbox: tuple[int, int, int, int, int, int] | None,
        mask: Any | None,
    ):
        """Crop a ``roi_size`` box around a region of interest and resize it to ``target_shape``.

        The box centre is chosen in priority order from:
        ``bbox`` (``(z0, y0, x0, z1, y1, x1)``), the rounded foreground centroid
        of ``mask``, ``roi_center``, or the volume centre. Parts of the box
        outside the volume are filled with the array minimum, so the padded
        crop is always exactly ``roi_size``.

        NOTE(review): the coordinates index ``array`` as passed in. That array
        has already been resampled by ``_process_3d`` when spacing was known,
        so ``mask``/``bbox``/``roi_center`` given in the original voxel grid
        will not line up. Confirm the intended convention.

        Raises:
            ValueError: if ``mask`` has no foreground voxels.
        """
        if roi_size is None:
            roi_size = target_shape
        # np.ndim(...) == 0 accepts Python and NumPy scalars alike
        # (isinstance(int) rejected np.int64).
        if np.ndim(roi_size) == 0:
            roi_shape = (int(roi_size),) * 3
        else:
            roi_shape = tuple(int(v) for v in roi_size)
        if bbox is not None:
            z0, y0, x0, z1, y1, x1 = [int(v) for v in bbox]
            roi_center = ((z0 + z1) // 2, (y0 + y1) // 2, (x0 + x1) // 2)
        elif mask is not None:
            coords = np.argwhere(np.asarray(mask) > 0)
            if coords.size == 0:
                raise ValueError("mask does not contain any foreground voxels")
            roi_center = tuple(int(v) for v in coords.mean(axis=0).round())
        elif roi_center is None:
            roi_center = tuple(int(v // 2) for v in array.shape)
        src_starts: list[int] = []
        src_ends: list[int] = []
        pad_before: list[int] = []
        pad_after: list[int] = []
        for center, size, dim in zip(roi_center, roi_shape, array.shape):
            start = int(center) - size // 2
            dim = int(dim)
            # Clamp the readable window into [0, dim]. A box that sticks out,
            # or lies entirely outside the volume, must never yield a negative
            # slice bound, which Python would read as "count from the end".
            lo = min(max(start, 0), dim)
            hi = max(lo, min(start + size, dim))
            valid = hi - lo
            before = min(max(lo - start, 0), size - valid)
            src_starts.append(lo)
            src_ends.append(hi)
            pad_before.append(before)
            pad_after.append(size - valid - before)
        window = tuple(slice(lo, hi) for lo, hi in zip(src_starts, src_ends))
        crop = np.pad(
            array[window],
            tuple(zip(pad_before, pad_after)),
            mode="constant",
            constant_values=float(array.min()),
        ).astype(np.float32, copy=False)
        resized = _resize_3d(crop, target_shape)
        metadata.update(
            {
                "roi_center_zyx": [int(v) for v in roi_center],
                "roi_crop_start_zyx": src_starts,
                "roi_crop_end_zyx": src_ends,
                "roi_pad_before_zyx": pad_before,
                "roi_pad_after_zyx": pad_after,
                "roi_padded_shape_zyx": [int(v) for v in crop.shape],
                "processed_shape_zyx": [int(v) for v in resized.shape],
                "resize_mode": "trilinear",
            }
        )
        return resized, metadata