Image Feature Extraction
Transformers
Safetensors
flexict
feature-extraction
medical-imaging
ct
vision
custom_code
Instructions to use ricklisz123/FlexiCT-2D with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use ricklisz123/FlexiCT-2D with Transformers:
# Use a pipeline as a high-level helper
from transformers import pipeline
pipe = pipeline("image-feature-extraction", model="ricklisz123/FlexiCT-2D", trust_remote_code=True)
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("ricklisz123/FlexiCT-2D", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| """Image processors for FlexiCT Hugging Face model repos.""" | |
| from __future__ import annotations | |
| from pathlib import Path | |
| from typing import Any | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| from transformers import BatchFeature | |
| from transformers.image_processing_utils import ImageProcessingMixin | |
| def _as_float_array(image: Any) -> tuple[np.ndarray, dict[str, Any]]: | |
| if isinstance(image, (str, Path)): | |
| return _load_medical_image_array(image) | |
| if isinstance(image, torch.Tensor): | |
| image = image.detach().cpu().numpy() | |
| array = np.asarray(image, dtype=np.float32) | |
| return array, {"source": "array"} | |
| def _load_medical_image_array(path: str | Path) -> tuple[np.ndarray, dict[str, Any]]: | |
| try: | |
| import SimpleITK as sitk | |
| except ImportError as exc: # pragma: no cover - runtime dependency branch. | |
| raise RuntimeError("SimpleITK is required to load CT paths with FlexiCTImageProcessor.") from exc | |
| image = sitk.ReadImage(str(Path(path).expanduser())) | |
| image = sitk.DICOMOrient(image, "LPS") | |
| array = sitk.GetArrayFromImage(image).astype(np.float32, copy=False) | |
| metadata = { | |
| "source": str(path), | |
| "spacing_xyz": [float(v) for v in image.GetSpacing()], | |
| "origin_xyz": [float(v) for v in image.GetOrigin()], | |
| "direction": [float(v) for v in image.GetDirection()], | |
| "loaded_shape_zyx": [int(v) for v in array.shape], | |
| } | |
| return array, metadata | |
| def _resample_array_zyx( | |
| array: np.ndarray, | |
| input_spacing_xyz: tuple[float, float, float] | None, | |
| target_spacing_xyz: tuple[float, float, float], | |
| ) -> np.ndarray: | |
| if input_spacing_xyz is None: | |
| return array | |
| spacing_zyx = tuple(float(v) for v in input_spacing_xyz[::-1]) | |
| target_zyx = tuple(float(v) for v in target_spacing_xyz[::-1]) | |
| out_shape = [ | |
| max(1, int(round(size * spacing / target))) | |
| for size, spacing, target in zip(array.shape, spacing_zyx, target_zyx) | |
| ] | |
| tensor = torch.from_numpy(array[None, None].astype(np.float32, copy=False)) | |
| resized = F.interpolate(tensor, size=out_shape, mode="trilinear", align_corners=False) | |
| return resized[0, 0].cpu().numpy().astype(np.float32, copy=False) | |
| def _clip_zscore( | |
| array: np.ndarray, | |
| clip_range: tuple[float, float], | |
| eps: float, | |
| ) -> tuple[np.ndarray, dict[str, float]]: | |
| clipped = np.clip(array.astype(np.float32, copy=False), clip_range[0], clip_range[1]) | |
| mean = float(clipped.mean()) | |
| std = float(clipped.std()) | |
| if std < eps: | |
| std = 1.0 | |
| normalized = (clipped - mean) / std | |
| return normalized.astype(np.float32, copy=False), { | |
| "clip_min": float(clip_range[0]), | |
| "clip_max": float(clip_range[1]), | |
| "mean": mean, | |
| "std": std, | |
| } | |
| def _pad_to_shape( | |
| array: np.ndarray, | |
| target_shape: tuple[int, ...], | |
| fill_value: float, | |
| ) -> tuple[np.ndarray, list[int], list[int]]: | |
| pad_before: list[int] = [] | |
| pad_after: list[int] = [] | |
| pads = [] | |
| for size, target in zip(array.shape, target_shape): | |
| total = max(0, int(target) - int(size)) | |
| before = total // 2 | |
| after = total - before | |
| pad_before.append(before) | |
| pad_after.append(after) | |
| pads.append((before, after)) | |
| if any(before or after for before, after in pads): | |
| array = np.pad(array, pads, mode="constant", constant_values=float(fill_value)) | |
| return array.astype(np.float32, copy=False), pad_before, pad_after | |
| def _center_crop(array: np.ndarray, target_shape: tuple[int, ...]) -> tuple[np.ndarray, list[int]]: | |
| starts = [max(0, (int(size) - int(target)) // 2) for size, target in zip(array.shape, target_shape)] | |
| slices = tuple(slice(start, start + int(target)) for start, target in zip(starts, target_shape)) | |
| return array[slices].astype(np.float32, copy=False), starts | |
| def _resize_2d(array: np.ndarray, output_size: int) -> np.ndarray: | |
| tensor = torch.from_numpy(array[None, None].astype(np.float32, copy=False)) | |
| resized = F.interpolate(tensor, size=(output_size, output_size), mode="bilinear", align_corners=False) | |
| return resized[0, 0].cpu().numpy().astype(np.float32, copy=False) | |
| def _resize_3d(array: np.ndarray, output_size: tuple[int, int, int]) -> np.ndarray: | |
| tensor = torch.from_numpy(array[None, None].astype(np.float32, copy=False)) | |
| resized = F.interpolate(tensor, size=output_size, mode="trilinear", align_corners=False) | |
| return resized[0, 0].cpu().numpy().astype(np.float32, copy=False) | |
| def _listify_images(images: Any, spatial_dims: int) -> list[Any]: | |
| if isinstance(images, (str, Path)): | |
| return [images] | |
| if isinstance(images, torch.Tensor): | |
| ndim = images.dim() | |
| else: | |
| ndim = np.asarray(images).ndim | |
| if ndim == spatial_dims: | |
| return [images] | |
| if ndim == spatial_dims + 1: | |
| return [sample for sample in images] | |
| return list(images) | |
class FlexiCTImageProcessor(ImageProcessingMixin):
    """Preprocess CT arrays or image paths for FlexiCT model variants.

    Each sample is converted to float32, clipped to ``clip_range`` and z-score
    normalised on its own statistics, then brought to ``image_size``:

    * ``model_variant="2d"``: a 2D slice (or slice ``slice_index`` along
      ``slice_axis`` of a 3D input, the middle one by default) is padded to a
      square with its normalised minimum and bilinearly resized to
      ``image_size`` x ``image_size``.
    * ``model_variant="3d"`` / ``"vlm"``: a volume is optionally resampled to
      ``target_spacing`` and then, depending on ``preset``:

      - ``"default"``: padded with the volume minimum, then centre-cropped to
        ``image_size``;
      - ``"local_path"``: padded to a cube, then trilinearly resized;
      - ``"retrieval_roi"``: an ROI (from ``bbox``, ``mask``, ``roi_center`` or
        the volume centre) is cut out, padded where it leaves the volume, and
        trilinearly resized.

    NOTE(review): ``do_orient_lps`` is stored on the instance but no method of
    this class reads it; path inputs are loaded through ``_as_float_array``,
    which is called without any orientation argument. Confirm whether the flag
    is meant to be honoured.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        model_variant: str = "3d",
        preset: str = "default",
        image_size: int | list[int] | tuple[int, ...] | None = None,
        clip_range: list[float] | tuple[float, float] = (-1000.0, 1000.0),
        target_spacing: list[float] | tuple[float, float, float] = (2.0, 2.0, 2.0),
        do_resample: bool = True,
        do_orient_lps: bool = True,
        eps: float = 1e-6,
        **kwargs: Any,
    ):
        """Validate and store the preprocessing configuration.

        Args:
            model_variant: ``"2d"``, ``"3d"`` or ``"vlm"`` (case-insensitive).
            preset: 3D strategy: ``"default"``, ``"local_path"`` or
                ``"retrieval_roi"`` (case-insensitive). Not used by ``"2d"``.
            image_size: Output size. Defaults to 512 for ``"2d"`` and
                ``[160, 160, 160]`` otherwise. An int means a square (2D) or
                cube (3D).
            clip_range: ``(min, max)`` intensity window applied before z-scoring.
            target_spacing: Spacing in ``(x, y, z)`` order that 3D inputs are
                resampled to when ``do_resample`` is true and a spacing is known.
            do_resample: Whether 3D inputs are resampled to ``target_spacing``.
            do_orient_lps: Stored only; see the class note.
            eps: Standard deviations below this are replaced by 1.0 when
                normalising.
            **kwargs: Forwarded to ``ImageProcessingMixin.__init__``.

        Raises:
            ValueError: for an unknown ``model_variant`` or ``preset``, or a
                ``clip_range`` whose minimum exceeds its maximum.
        """
        super().__init__(**kwargs)
        model_variant = model_variant.lower()
        preset = preset.lower()
        if model_variant not in {"2d", "3d", "vlm"}:
            raise ValueError("model_variant must be one of '2d', '3d', or 'vlm'")
        if preset not in {"default", "local_path", "retrieval_roi"}:
            raise ValueError("preset must be 'default', 'local_path', or 'retrieval_roi'")
        clip_min, clip_max = float(clip_range[0]), float(clip_range[1])
        if clip_min > clip_max:
            # np.clip with an inverted window silently maps every voxel to one value.
            raise ValueError(f"clip_range minimum must not exceed its maximum, got {[clip_min, clip_max]}")
        self.model_variant = model_variant
        self.preset = preset
        if image_size is None:
            image_size = 512 if model_variant == "2d" else [160, 160, 160]
        self.image_size = list(image_size) if isinstance(image_size, (list, tuple)) else int(image_size)
        self.clip_range = [clip_min, clip_max]
        self.target_spacing = [float(v) for v in target_spacing]
        self.do_resample = bool(do_resample)
        self.do_orient_lps = bool(do_orient_lps)
        self.eps = float(eps)

    @staticmethod
    def _input_ndim(images: Any) -> int | None:
        """Return the dimensionality of ``images``, or ``None`` if it has none.

        Tensors report ``dim()`` directly, because ``np.asarray`` cannot convert
        CUDA tensors or tensors that require grad. ``None`` is returned when
        NumPy cannot form a single array, e.g. for a ragged list of samples.
        """
        if isinstance(images, torch.Tensor):
            return images.dim()
        try:
            return np.ndim(images)
        except ValueError:
            return None

    def _target_size_2d(self) -> int:
        """Return the square output side used by the 2D variant.

        ``image_size`` may be an int, or a sequence whose entries are all equal
        (e.g. ``[512, 512]``), because the 2D pipeline only produces squares.
        """
        if isinstance(self.image_size, (list, tuple)):
            sides = {int(v) for v in self.image_size}
            if len(sides) != 1:
                raise ValueError(f"FlexiCT-2D produces square outputs; got image_size={self.image_size}")
            return sides.pop()
        return int(self.image_size)

    def _target_shape_3d(self) -> tuple[int, ...]:
        """Return the 3D output shape; an int ``image_size`` means a cube."""
        if isinstance(self.image_size, (list, tuple)):
            return tuple(int(v) for v in self.image_size)
        side = int(self.image_size)
        return (side, side, side)

    def __call__(
        self,
        images: Any,
        return_tensors: str | None = "pt",
        return_metadata: bool = False,
        **kwargs: Any,
    ) -> BatchFeature:
        """Preprocess one input or a batch into ``pixel_values``.

        Args:
            images: A path, an array/tensor, a batch array/tensor with a leading
                sample axis, or a sequence of such inputs.
            return_tensors: ``"pt"`` for a torch tensor, ``"np"`` or ``None``
                for a NumPy array.
            return_metadata: Also include a per-sample list of preprocessing
                metadata under ``"metadata"``.
            **kwargs: Forwarded to every sample's 2D/3D processing, e.g.
                ``slice_index`` / ``slice_axis`` (2D) or ``input_spacing``,
                ``roi_center``, ``roi_size``, ``bbox``, ``mask`` (3D).

        Returns:
            ``BatchFeature`` whose float32 ``pixel_values`` has shape
            ``(batch, 1, *spatial)``.

        Raises:
            ValueError: for an unsupported ``return_tensors`` or an empty batch.
        """
        # Validate before doing any (potentially expensive) preprocessing.
        if return_tensors not in {"pt", "np", None}:
            raise ValueError("return_tensors must be 'pt', 'np', or None")
        # For the 2D variant only a literal 2D input is one slice; a 3D input is
        # a volume from which ``_process_2d`` extracts a slice.
        spatial_dims = 2 if self.model_variant == "2d" and self._input_ndim(images) == 2 else 3
        samples = _listify_images(images, spatial_dims=spatial_dims)
        if not samples:
            raise ValueError("images must contain at least one sample")
        process = self._process_2d if self.model_variant == "2d" else self._process_3d
        processed = []
        metadata = []
        for sample in samples:
            array, meta = process(sample, **kwargs)
            processed.append(array[None])  # add the channel axis
            metadata.append(meta)
        batch_array = np.stack(processed, axis=0).astype(np.float32, copy=False)
        data: dict[str, Any] = {
            "pixel_values": torch.from_numpy(batch_array) if return_tensors == "pt" else batch_array
        }
        if return_metadata:
            data["metadata"] = metadata
        return BatchFeature(data=data)

    def _process_2d(self, image: Any, slice_index: int | None = None, slice_axis: int = 0, **_: Any):
        """Normalise one 2D slice (or one slice of a 3D input) for FlexiCT-2D.

        Args:
            image: Path, array or tensor; 2D, or 3D to take a slice from.
            slice_index: Slice to take from a 3D input; defaults to the middle.
            slice_axis: Axis of a 3D input to slice along.
            **_: Ignored, so 3D-only keyword arguments are harmless here.

        Returns:
            ``(array, metadata)`` with ``array`` of shape ``(side, side)``.

        Raises:
            ValueError: if the (sliced) input is not 2D, or ``image_size``
                does not describe a square.
        """
        array, metadata = _as_float_array(image)
        metadata["original_shape"] = [int(v) for v in array.shape]
        if array.ndim == 3:
            if slice_index is None:
                slice_index = array.shape[slice_axis] // 2
            array = np.take(array, int(slice_index), axis=int(slice_axis))
            metadata["slice_index"] = int(slice_index)
            metadata["slice_axis"] = int(slice_axis)
        if array.ndim != 2:
            raise ValueError(f"FlexiCT-2D expects a 2D slice or 3D volume, got shape {array.shape}")
        array, stats = _clip_zscore(array, tuple(self.clip_range), self.eps)
        # Pad to a square with the normalised minimum so resizing preserves aspect ratio.
        side = max(array.shape)
        array, pad_before, pad_after = _pad_to_shape(array, (side, side), float(array.min()))
        output_size = self._target_size_2d()
        array = _resize_2d(array, output_size)
        metadata.update(stats)
        metadata.update(
            {
                "pad_before_yx": pad_before,
                "pad_after_yx": pad_after,
                "processed_shape_yx": [output_size, output_size],
            }
        )
        return array, metadata

    def _process_3d(
        self,
        image: Any,
        input_spacing: tuple[float, float, float] | None = None,
        roi_center: tuple[int, int, int] | None = None,
        roi_size: int | tuple[int, int, int] | None = None,
        bbox: tuple[int, int, int, int, int, int] | None = None,
        mask: Any | None = None,
        **_: Any,
    ):
        """Normalise one 3D volume and apply the configured ``preset``.

        Args:
            image: Path, array or tensor holding a 3D volume.
            input_spacing: ``(x, y, z)`` spacing of ``image``; defaults to the
                spacing found in the loaded file's metadata, if any.
            roi_center, roi_size, bbox, mask: Only used by the
                ``"retrieval_roi"`` preset.
            **_: Ignored, so 2D-only keyword arguments are harmless here.

        Returns:
            ``(array, metadata)`` for the processed volume.

        Raises:
            ValueError: if the input is not 3D.
        """
        array, metadata = _as_float_array(image)
        if array.ndim != 3:
            raise ValueError(f"FlexiCT-3D expects a 3D volume, got shape {array.shape}")
        # Record the shape as received, before resampling can change it.
        metadata["original_shape_zyx"] = [int(v) for v in array.shape]
        if input_spacing is None and "spacing_xyz" in metadata:
            input_spacing = tuple(metadata["spacing_xyz"])
        if self.do_resample and input_spacing is not None:
            array = _resample_array_zyx(array, input_spacing, tuple(self.target_spacing))
            metadata["resampled_shape_zyx"] = [int(v) for v in array.shape]
        array, stats = _clip_zscore(array, tuple(self.clip_range), self.eps)
        metadata.update(stats)
        target_shape = self._target_shape_3d()
        if self.preset == "default":
            return self._default_3d(array, target_shape, metadata)
        if self.preset == "local_path":
            return self._local_path_3d(array, target_shape, metadata)
        return self._retrieval_roi_3d(array, target_shape, metadata, roi_center, roi_size, bbox, mask)

    def _default_3d(self, array: np.ndarray, target_shape: tuple[int, int, int], metadata: dict[str, Any]):
        """Pad with the volume minimum up to ``target_shape``, then centre-crop to it."""
        array, pad_before, pad_after = _pad_to_shape(array, target_shape, float(array.min()))
        array, crop_start = _center_crop(array, target_shape)
        metadata.update(
            {
                "pad_before_zyx": pad_before,
                "pad_after_zyx": pad_after,
                "crop_start_zyx": crop_start,
                "processed_shape_zyx": [int(v) for v in array.shape],
            }
        )
        return array, metadata

    def _local_path_3d(self, array: np.ndarray, target_shape: tuple[int, int, int], metadata: dict[str, Any]):
        """Pad to a cube of side ``max(shape)`` with the volume minimum, then resize trilinearly."""
        side = max(int(v) for v in array.shape)
        array, pad_before, pad_after = _pad_to_shape(array, (side, side, side), float(array.min()))
        metadata.update(
            {
                "cubic_pad_before_zyx": pad_before,
                "cubic_pad_after_zyx": pad_after,
                "cubic_padded_shape_zyx": [int(v) for v in array.shape],
            }
        )
        array = _resize_3d(array, target_shape)
        metadata.update({"processed_shape_zyx": [int(v) for v in array.shape], "resize_mode": "trilinear"})
        return array, metadata

    def _retrieval_roi_3d(
        self,
        array: np.ndarray,
        target_shape: tuple[int, int, int],
        metadata: dict[str, Any],
        roi_center: tuple[int, int, int] | None,
        roi_size: int | tuple[int, int, int] | None,
        bbox: tuple[int, int, int, int, int, int] | None,
        mask: Any | None,
    ):
        """Cut an ROI around a centre, pad it where it leaves the volume, and resize.

        The centre comes, in priority order, from ``bbox`` (its midpoint,
        given as ``(z0, y0, x0, z1, y1, x1)``), ``mask`` (rounded centroid of
        voxels > 0), ``roi_center``, or the volume centre. ``roi_size`` (a
        scalar for a cube, or per-axis sizes) defaults to ``target_shape``.

        NOTE(review): these coordinates are applied to ``array`` as received,
        i.e. after ``_process_3d`` has optionally resampled it. Coordinates or
        masks expressed in the original voxel grid will be misplaced whenever
        resampling changes the shape; confirm the intended coordinate space.

        Raises:
            ValueError: if ``mask`` has no foreground voxels.
        """
        if roi_size is None:
            roi_size = target_shape
        # np.ndim also recognises NumPy integers and 0-d arrays/tensors as
        # scalars, which an ``isinstance(roi_size, int)`` check would not.
        if np.ndim(roi_size) == 0:
            roi_shape = (int(roi_size),) * 3
        else:
            roi_shape = tuple(int(v) for v in roi_size)
        if bbox is not None:
            z0, y0, x0, z1, y1, x1 = (int(v) for v in bbox)
            roi_center = ((z0 + z1) // 2, (y0 + y1) // 2, (x0 + x1) // 2)
        elif mask is not None:
            coords = np.argwhere(np.asarray(mask) > 0)
            if coords.size == 0:
                raise ValueError("mask does not contain any foreground voxels")
            roi_center = tuple(int(v) for v in coords.mean(axis=0).round())
        elif roi_center is None:
            roi_center = tuple(int(v // 2) for v in array.shape)
        starts = [int(center) - size // 2 for center, size in zip(roi_center, roi_shape)]
        # Clamp the requested window to the volume and pad whatever falls
        # outside. Bounds are kept inside [0, dim] and ordered, so a window that
        # lies partly or wholly outside never yields a negative slice bound
        # (which NumPy would read from the far end of the axis).
        src_starts: list[int] = []
        src_ends: list[int] = []
        pad_before: list[int] = []
        pad_after: list[int] = []
        for start, size, dim in zip(starts, roi_shape, array.shape):
            lo = min(max(start, 0), dim)
            hi = max(min(start + size, dim), lo)
            before = min(size, max(0, -start))
            src_starts.append(lo)
            src_ends.append(hi)
            pad_before.append(before)
            pad_after.append(size - before - (hi - lo))
        crop = array[tuple(slice(lo, hi) for lo, hi in zip(src_starts, src_ends))]
        crop = np.pad(
            crop,
            tuple(zip(pad_before, pad_after)),
            mode="constant",
            constant_values=float(array.min()),
        ).astype(np.float32, copy=False)
        resized = _resize_3d(crop, target_shape)
        metadata.update(
            {
                "roi_center_zyx": [int(v) for v in roi_center],
                "roi_crop_start_zyx": src_starts,
                "roi_crop_end_zyx": src_ends,
                "roi_pad_before_zyx": pad_before,
                "roi_pad_after_zyx": pad_after,
                "roi_padded_shape_zyx": [int(v) for v in crop.shape],
                "processed_shape_zyx": [int(v) for v in resized.shape],
                "resize_mode": "trilinear",
            }
        )
        return resized, metadata