|
|
from typing import Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from transformers.image_processing_utils import BaseImageProcessor |
|
|
|
|
|
|
|
|
class TAPCTProcessor(BaseImageProcessor):
    """
    Image processor for TAP-CT 3D volumes.

    Processes CT volumes with the following pipeline:

    1. Spatial Resizing: Resize to (z, H', W') where H', W' are resize_dims
    2. Axial Padding: Pad z-axis with ``pad_value`` HU for divisibility by patch size
    3. Intensity Clipping: Clip to HU range
    4. Normalization: Z-score normalization

    Parameters
    ----------
    resize_dims : tuple[int, int], default=(224, 224)
        Target spatial dimensions (H, W) for resizing.
    divisible_pad_z : int, default=4
        Pad the z-axis to be divisible by this value.
    clip_range : tuple[float, float], default=(-1008.0, 822.0)
        HU intensity clipping range (min, max).
    norm_mean : float, default=-86.80862426757812
        Mean for z-score normalization.
    norm_std : float, default=322.63470458984375
        Standard deviation for z-score normalization.
    pad_value : float, default=-1024.0
        Fill value (in HU; -1024 is air) used when padding the axial
        dimension. Note that padding is applied *before* clipping, so the
        padded slices end up at ``clip_range[0]`` (then normalized) whenever
        ``pad_value`` lies below the clip minimum.
    **kwargs
        Additional arguments passed to BaseImageProcessor.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        resize_dims: tuple[int, int] = (224, 224),
        divisible_pad_z: int = 4,
        clip_range: tuple[float, float] = (-1008.0, 822.0),
        norm_mean: float = -86.80862426757812,
        norm_std: float = 322.63470458984375,
        pad_value: float = -1024.0,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self.resize_dims = resize_dims
        self.divisible_pad_z = divisible_pad_z
        self.clip_range = clip_range
        self.norm_mean = norm_mean
        self.norm_std = norm_std
        # Generalized from the previously hard-coded -1024.0 in _pad_axial;
        # the default preserves the original behavior exactly.
        self.pad_value = pad_value

    def preprocess(
        self,
        images: Union[torch.Tensor, np.ndarray],
        return_tensors: str = "pt",
        **kwargs,
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess CT volumes.

        Parameters
        ----------
        images : torch.Tensor or np.ndarray
            Input tensor or numpy array of shape (B, C, D, H, W) where
            B=batch, C=channels, D=depth/slices, H=height, W=width.
        return_tensors : str, default="pt"
            Return format. Only "pt" (PyTorch) is supported.
        **kwargs
            Additional keyword arguments (unused).

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary with "pixel_values" containing processed tensor of shape
            (B, C, D', H', W') where D' may be padded for divisibility.

        Raises
        ------
        ValueError
            If return_tensors is not "pt" or input is not 5D.
        """
        if return_tensors != "pt":
            raise ValueError(f"Only 'pt' return_tensors is supported, got {return_tensors}")

        # Accept numpy input; everything downstream works on float32 tensors.
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)
        images = images.float()

        if images.ndim != 5:
            raise ValueError(f"Expected 5D input (B, C, D, H, W), got shape {images.shape}")

        # Only the two trailing spatial dims matter for the resize decision.
        H, W = images.shape[-2:]
        target_h, target_w = self.resize_dims
        if H != target_h or W != target_w:
            images = self._resize_spatial(images, target_h, target_w)

        # Pad BEFORE clipping so the fill value is squashed into the valid
        # HU range along with real voxels (matches original pipeline order).
        images = self._pad_axial(images)

        images = torch.clamp(images, min=self.clip_range[0], max=self.clip_range[1])

        # Z-score normalization with dataset-level statistics.
        images = (images - self.norm_mean) / self.norm_std

        return {"pixel_values": images}

    def _resize_spatial(
        self,
        images: torch.Tensor,
        target_h: int,
        target_w: int,
    ) -> torch.Tensor:
        """
        Resize spatial dimensions (H, W) using trilinear interpolation.

        The depth dimension is passed through unchanged (the interpolation
        target keeps the original D), so only the in-plane resolution moves.

        Parameters
        ----------
        images : torch.Tensor
            Tensor of shape (B, C, D, H, W).
        target_h : int
            Target height.
        target_w : int
            Target width.

        Returns
        -------
        torch.Tensor
            Resized tensor of shape (B, C, D, target_h, target_w).
        """
        D = images.shape[2]
        return F.interpolate(
            images,
            size=(D, target_h, target_w),
            mode='trilinear',
            align_corners=False,
        )

    def _pad_axial(self, images: torch.Tensor) -> torch.Tensor:
        """
        Pad the axial (z/depth) dimension with ``self.pad_value`` HU so the
        depth becomes divisible by ``self.divisible_pad_z``.

        Parameters
        ----------
        images : torch.Tensor
            Tensor of shape (B, C, D, H, W).

        Returns
        -------
        torch.Tensor
            Padded tensor of shape (B, C, D', H, W) where D' is divisible
            by divisible_pad_z. Returned unchanged when already divisible.
        """
        D = images.shape[2]
        remainder = D % self.divisible_pad_z
        if remainder == 0:
            return images

        pad_z = self.divisible_pad_z - remainder
        # F.pad takes pairs from the LAST dim backwards:
        # (W_left, W_right, H_left, H_right, D_left, D_right)
        # -> pad only the trailing end of the depth axis.
        padding = (0, 0, 0, 0, 0, pad_z)
        return F.pad(images, padding, mode='constant', value=self.pad_value)
|
|
|