File size: 5,607 Bytes
9be891b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
from typing import Union
import numpy as np
import torch
import torch.nn.functional as F
from transformers.image_processing_utils import BaseImageProcessor
class TAPCTProcessor(BaseImageProcessor):
    """
    Image processor for TAP-CT 3D volumes.

    Runs a four-stage pipeline over a batch of CT volumes:

    1. Spatial resizing: interpolate (H, W) to ``resize_dims``, depth kept.
    2. Axial padding: extend the z-axis with -1024 HU until the depth is
       divisible by ``divisible_pad_z``.
    3. Intensity clipping: clamp values into the HU window ``clip_range``.
    4. Normalization: z-score with ``norm_mean`` / ``norm_std``.

    NOTE(review): because clipping runs after padding, the -1024 pad value
    ends up clamped to ``clip_range[0]`` — presumably intentional given the
    documented stage order; confirm with the model's training pipeline.

    Parameters
    ----------
    resize_dims : tuple[int, int], default=(224, 224)
        Target spatial dimensions (H, W) for resizing.
    divisible_pad_z : int, default=4
        Pad the z-axis to be divisible by this value.
    clip_range : tuple[float, float], default=(-1008.0, 822.0)
        HU intensity clipping range (min, max).
    norm_mean : float, default=-86.80862426757812
        Mean for z-score normalization.
    norm_std : float, default=322.63470458984375
        Standard deviation for z-score normalization.
    **kwargs
        Additional arguments passed to BaseImageProcessor.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        resize_dims: tuple[int, int] = (224, 224),
        divisible_pad_z: int = 4,
        clip_range: tuple[float, float] = (-1008.0, 822.0),
        norm_mean: float = -86.80862426757812,
        norm_std: float = 322.63470458984375,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        # Stored verbatim; all values are consumed lazily in preprocess().
        self.resize_dims = resize_dims
        self.divisible_pad_z = divisible_pad_z
        self.clip_range = clip_range
        self.norm_mean = norm_mean
        self.norm_std = norm_std

    def preprocess(
        self,
        images: Union[torch.Tensor, np.ndarray],
        return_tensors: str = "pt",
        **kwargs
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess CT volumes.

        Parameters
        ----------
        images : torch.Tensor or np.ndarray
            Input of shape (B, C, D, H, W) where B=batch, C=channels,
            D=depth/slices, H=height, W=width.
        return_tensors : str, default="pt"
            Return format. Only "pt" (PyTorch) is supported.
        **kwargs
            Additional keyword arguments (unused).

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary with "pixel_values" containing the processed tensor
            of shape (B, C, D', H', W'), where D' may be padded for
            divisibility.

        Raises
        ------
        ValueError
            If return_tensors is not "pt" or input is not 5D.
        """
        if return_tensors != "pt":
            raise ValueError(f"Only 'pt' return_tensors is supported, got {return_tensors}")

        # Accept numpy input; everything downstream works in float32 torch.
        volume = torch.from_numpy(images) if isinstance(images, np.ndarray) else images
        volume = volume.float()

        if volume.ndim != 5:
            raise ValueError(f"Expected 5D input (B, C, D, H, W), got shape {volume.shape}")

        # Stage 1: resize (H, W) only when they differ from the target.
        height, width = volume.shape[3], volume.shape[4]
        want_h, want_w = self.resize_dims
        if (height, width) != (want_h, want_w):
            volume = self._resize_spatial(volume, want_h, want_w)

        # Stage 2: pad depth for patch-size divisibility.
        volume = self._pad_axial(volume)

        # Stage 3: clamp intensities into the configured HU window.
        low, high = self.clip_range
        volume = volume.clamp(min=low, max=high)

        # Stage 4: z-score normalization.
        volume = (volume - self.norm_mean) / self.norm_std

        return {"pixel_values": volume}

    def _resize_spatial(
        self,
        images: torch.Tensor,
        target_h: int,
        target_w: int
    ) -> torch.Tensor:
        """
        Resize spatial dimensions (H, W) using trilinear interpolation.

        Parameters
        ----------
        images : torch.Tensor
            Tensor of shape (B, C, D, H, W).
        target_h : int
            Target height.
        target_w : int
            Target width.

        Returns
        -------
        torch.Tensor
            Resized tensor of shape (B, C, D, target_h, target_w).
        """
        # Depth is passed through unchanged; only H and W are resampled.
        depth = images.shape[2]
        return F.interpolate(
            images,
            size=(depth, target_h, target_w),
            mode='trilinear',
            align_corners=False
        )

    def _pad_axial(self, images: torch.Tensor) -> torch.Tensor:
        """
        Pad the axial (z/depth) dimension with -1024 HU for divisibility.

        Parameters
        ----------
        images : torch.Tensor
            Tensor of shape (B, C, D, H, W).

        Returns
        -------
        torch.Tensor
            Padded tensor of shape (B, C, D', H, W) where D' is divisible
            by divisible_pad_z.
        """
        leftover = images.shape[2] % self.divisible_pad_z
        if leftover:
            tail = self.divisible_pad_z - leftover
            # F.pad lists pads from the last dimension inward:
            # (W_left, W_right, H_left, H_right, D_left, D_right).
            # Only the trailing depth side receives padding.
            images = F.pad(images, (0, 0, 0, 0, 0, tail), mode='constant', value=-1024.0)
        return images
|