from typing import Union

import numpy as np
import torch
import torch.nn.functional as F
from transformers.image_processing_utils import BaseImageProcessor


class TAPCTProcessor(BaseImageProcessor):
    """
    Image processor for TAP-CT 3D volumes.

    Processes CT volumes with the following pipeline:
        1. Spatial Resizing: Resize to (z, H', W') where H', W' are resize_dims
        2. Axial Padding: Pad z-axis with ``pad_value`` HU for divisibility
           by patch size
        3. Intensity Clipping: Clip to HU range
        4. Normalization: Z-score normalization

    Parameters
    ----------
    resize_dims : tuple[int, int], default=(224, 224)
        Target spatial dimensions (H, W) for resizing.
    divisible_pad_z : int, default=4
        Pad the z-axis to be divisible by this value.
    clip_range : tuple[float, float], default=(-1008.0, 822.0)
        HU intensity clipping range (min, max).
    norm_mean : float, default=-86.80862426757812
        Mean for z-score normalization.
    norm_std : float, default=322.63470458984375
        Standard deviation for z-score normalization.
    pad_value : float, default=-1024.0
        HU value used to fill axial padding (-1024 HU corresponds to air).
        Note that padding happens before intensity clipping, so the padded
        region ends up at max(pad_value, clip_range[0]) after preprocessing.
    **kwargs
        Additional arguments passed to BaseImageProcessor.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        resize_dims: tuple[int, int] = (224, 224),
        divisible_pad_z: int = 4,
        clip_range: tuple[float, float] = (-1008.0, 822.0),
        norm_mean: float = -86.80862426757812,
        norm_std: float = 322.63470458984375,
        pad_value: float = -1024.0,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.resize_dims = resize_dims
        self.divisible_pad_z = divisible_pad_z
        self.clip_range = clip_range
        self.norm_mean = norm_mean
        self.norm_std = norm_std
        self.pad_value = pad_value

    def preprocess(
        self,
        images: Union[torch.Tensor, np.ndarray],
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess CT volumes.

        Parameters
        ----------
        images : torch.Tensor or np.ndarray
            Input tensor or numpy array of shape (B, C, D, H, W) where
            B=batch, C=channels, D=depth/slices, H=height, W=width.
        return_tensors : str, default="pt"
            Return format. Only "pt" (PyTorch) is supported.
        **kwargs
            Additional keyword arguments (unused).

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary with "pixel_values" containing processed tensor of
            shape (B, C, D', H', W') where D' may be padded for divisibility.

        Raises
        ------
        ValueError
            If return_tensors is not "pt" or input is not 5D.
        """
        if return_tensors != "pt":
            raise ValueError(f"Only 'pt' return_tensors is supported, got {return_tensors}")

        # Convert numpy to tensor if needed. ascontiguousarray handles
        # non-contiguous / negative-stride views (e.g. flipped volumes),
        # which torch.from_numpy rejects; it is a no-op for contiguous input.
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(np.ascontiguousarray(images))

        # Ensure float32 dtype for processing
        images = images.float()

        # Validate input shape
        if images.ndim != 5:
            raise ValueError(f"Expected 5D input (B, C, D, H, W), got shape {images.shape}")

        B, C, D, H, W = images.shape

        # Step 1: Spatial Resizing - resize H, W dimensions to resize_dims
        # (skipped entirely when the volume already matches the target size)
        target_h, target_w = self.resize_dims
        if H != target_h or W != target_w:
            images = self._resize_spatial(images, target_h, target_w)

        # Step 2: Axial Padding - pad z-axis with pad_value for divisibility
        images = self._pad_axial(images)

        # Step 3: Intensity Clipping - clip to HU range. This also clips any
        # padded voxels below clip_range[0] up to the range minimum.
        images = torch.clamp(images, min=self.clip_range[0], max=self.clip_range[1])

        # Step 4: Z-score Normalization
        images = (images - self.norm_mean) / self.norm_std

        return {"pixel_values": images}

    def _resize_spatial(
        self,
        images: torch.Tensor,
        target_h: int,
        target_w: int,
    ) -> torch.Tensor:
        """
        Resize spatial dimensions (H, W) using trilinear interpolation.

        Parameters
        ----------
        images : torch.Tensor
            Tensor of shape (B, C, D, H, W).
        target_h : int
            Target height.
        target_w : int
            Target width.

        Returns
        -------
        torch.Tensor
            Resized tensor of shape (B, C, D, target_h, target_w).
        """
        D = images.shape[2]
        # Trilinear interpolation over (D, H, W); passing the current depth
        # as the target keeps the slice count unchanged.
        images = F.interpolate(
            images,
            size=(D, target_h, target_w),
            mode='trilinear',
            align_corners=False,
        )
        return images

    def _pad_axial(self, images: torch.Tensor) -> torch.Tensor:
        """
        Pad the axial (z/depth) dimension with pad_value HU for divisibility.

        Parameters
        ----------
        images : torch.Tensor
            Tensor of shape (B, C, D, H, W).

        Returns
        -------
        torch.Tensor
            Padded tensor of shape (B, C, D', H, W) where D' is divisible
            by divisible_pad_z.
        """
        D = images.shape[2]
        remainder = D % self.divisible_pad_z
        if remainder == 0:
            return images

        pad_z = self.divisible_pad_z - remainder

        # F.pad expects padding in reverse dimension order:
        # (W_l, W_r, H_l, H_r, D_l, D_r, ...).
        # To pad depth at the end: (0, 0, 0, 0, 0, pad_z)
        padding = (0, 0, 0, 0, 0, pad_z)
        images = F.pad(images, padding, mode='constant', value=self.pad_value)

        return images