Safetensors
tapct
custom_code
tap-ct-s-3d / tapct_processor.py
TimVeenboer
feat(tap-hf): Add preprocessor
7fb44d5
from typing import Union
import numpy as np
import torch
import torch.nn.functional as F
from transformers.image_processing_utils import BaseImageProcessor
class TAPCTProcessor(BaseImageProcessor):
    """
    Image processor for TAP-CT 3D volumes.

    Processes CT volumes with the following pipeline:

    1. Spatial Resizing: Resize to (z, H', W') where H', W' are resize_dims
    2. Axial Padding: Pad z-axis with ``pad_value`` HU for divisibility by patch size
    3. Intensity Clipping: Clip to HU range
    4. Normalization: Z-score normalization

    Note that padding (step 2) runs before clipping (step 3), so padded voxels
    end up at ``clip_range[0]`` after the full pipeline.

    Parameters
    ----------
    resize_dims : tuple[int, int], default=(224, 224)
        Target spatial dimensions (H, W) for resizing.
    divisible_pad_z : int, default=4
        Pad the z-axis to be divisible by this value. Must be >= 1.
    clip_range : tuple[float, float], default=(-1008.0, 822.0)
        HU intensity clipping range (min, max).
    norm_mean : float, default=-86.80862426757812
        Mean for z-score normalization.
    norm_std : float, default=322.63470458984375
        Standard deviation for z-score normalization. Must be non-zero.
    pad_value : float, default=-1024.0
        HU value used to pad the axial (depth) axis; -1024 HU is the
        conventional CT value for air (background).
    **kwargs
        Additional arguments passed to BaseImageProcessor.

    Raises
    ------
    ValueError
        If divisible_pad_z < 1 or norm_std == 0.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        resize_dims: tuple[int, int] = (224, 224),
        divisible_pad_z: int = 4,
        clip_range: tuple[float, float] = (-1008.0, 822.0),
        norm_mean: float = -86.80862426757812,
        norm_std: float = 322.63470458984375,
        pad_value: float = -1024.0,
        **kwargs
    ) -> None:
        # Fail fast on configurations that would otherwise surface later as a
        # ZeroDivisionError in _pad_axial or inf/nan pixel values in preprocess.
        if divisible_pad_z < 1:
            raise ValueError(f"divisible_pad_z must be >= 1, got {divisible_pad_z}")
        if norm_std == 0:
            raise ValueError("norm_std must be non-zero")
        super().__init__(**kwargs)
        self.resize_dims = resize_dims
        self.divisible_pad_z = divisible_pad_z
        self.clip_range = clip_range
        self.norm_mean = norm_mean
        self.norm_std = norm_std
        self.pad_value = pad_value

    def preprocess(
        self,
        images: Union[torch.Tensor, np.ndarray],
        return_tensors: str = "pt",
        **kwargs
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess CT volumes.

        Parameters
        ----------
        images : torch.Tensor or np.ndarray
            Input tensor or numpy array of shape (B, C, D, H, W) where
            B=batch, C=channels, D=depth/slices, H=height, W=width.
        return_tensors : str, default="pt"
            Return format. Only "pt" (PyTorch) is supported.
        **kwargs
            Additional keyword arguments (unused).

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary with "pixel_values" containing processed tensor of shape
            (B, C, D', H', W') where D' may be padded for divisibility.

        Raises
        ------
        ValueError
            If return_tensors is not "pt" or input is not 5D.
        """
        if return_tensors != "pt":
            raise ValueError(f"Only 'pt' return_tensors is supported, got {return_tensors}")
        # Convert numpy to tensor if needed. ascontiguousarray handles
        # non-contiguous or negatively-strided views (e.g. sliced/flipped CT
        # arrays), which torch.from_numpy rejects; it is a no-op for arrays
        # that are already contiguous.
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(np.ascontiguousarray(images))
        # Ensure float32 dtype for processing (CT data is often int16 HU).
        images = images.float()
        # Validate input shape
        if images.ndim != 5:
            raise ValueError(f"Expected 5D input (B, C, D, H, W), got shape {images.shape}")
        B, C, D, H, W = images.shape
        # Step 1: Spatial Resizing - resize H, W dimensions to resize_dims
        target_h, target_w = self.resize_dims
        if H != target_h or W != target_w:
            images = self._resize_spatial(images, target_h, target_w)
        # Step 2: Axial Padding - pad z-axis with pad_value for divisibility
        images = self._pad_axial(images)
        # Step 3: Intensity Clipping - clip to HU range
        images = torch.clamp(images, min=self.clip_range[0], max=self.clip_range[1])
        # Step 4: Z-score Normalization
        images = (images - self.norm_mean) / self.norm_std
        return {"pixel_values": images}

    def _resize_spatial(
        self,
        images: torch.Tensor,
        target_h: int,
        target_w: int
    ) -> torch.Tensor:
        """
        Resize spatial dimensions (H, W) using trilinear interpolation.

        The depth dimension is passed through unchanged by targeting the
        existing D in the interpolation size.

        Parameters
        ----------
        images : torch.Tensor
            Tensor of shape (B, C, D, H, W).
        target_h : int
            Target height.
        target_w : int
            Target width.

        Returns
        -------
        torch.Tensor
            Resized tensor of shape (B, C, D, target_h, target_w).
        """
        D = images.shape[2]
        # Apply trilinear interpolation, keeping depth unchanged
        images = F.interpolate(
            images,
            size=(D, target_h, target_w),
            mode='trilinear',
            align_corners=False
        )
        return images

    def _pad_axial(self, images: torch.Tensor) -> torch.Tensor:
        """
        Pad the axial (z/depth) dimension with pad_value HU for divisibility.

        Parameters
        ----------
        images : torch.Tensor
            Tensor of shape (B, C, D, H, W).

        Returns
        -------
        torch.Tensor
            Padded tensor of shape (B, C, D', H, W) where D' is divisible
            by divisible_pad_z.
        """
        D = images.shape[2]
        remainder = D % self.divisible_pad_z
        if remainder == 0:
            return images
        pad_z = self.divisible_pad_z - remainder
        # F.pad expects padding in reverse dimension order: (W_l, W_r, H_l, H_r, D_l, D_r, ...)
        # To pad depth at the end: (0, 0, 0, 0, 0, pad_z)
        padding = (0, 0, 0, 0, 0, pad_z)
        images = F.pad(images, padding, mode='constant', value=self.pad_value)
        return images