Safetensors
tapct
custom_code
File size: 5,607 Bytes
7fb44d5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
from typing import Union

import numpy as np
import torch
import torch.nn.functional as F
from transformers.image_processing_utils import BaseImageProcessor


class TAPCTProcessor(BaseImageProcessor):
    """
    Image processor for TAP-CT 3D volumes.

    Processes CT volumes with the following pipeline:

    1. Spatial Resizing: Resize to (z, H', W') where H', W' are resize_dims
    2. Axial Padding: Pad z-axis with ``pad_value`` HU for divisibility by patch size
    3. Intensity Clipping: Clip to HU range
    4. Normalization: Z-score normalization

    Parameters
    ----------
    resize_dims : tuple[int, int], default=(224, 224)
        Target spatial dimensions (H, W) for resizing.
    divisible_pad_z : int, default=4
        Pad the z-axis to be divisible by this value.
    clip_range : tuple[float, float], default=(-1008.0, 822.0)
        HU intensity clipping range (min, max).
    norm_mean : float, default=-86.80862426757812
        Mean for z-score normalization.
    norm_std : float, default=322.63470458984375
        Standard deviation for z-score normalization.
    pad_value : float, default=-1024.0
        HU fill value for axial padding (-1024 HU is air). Note that padding
        happens before clipping, so the pad value is subsequently clipped to
        ``clip_range`` — this matches the original pipeline order.
    **kwargs
        Additional arguments passed to BaseImageProcessor.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        resize_dims: tuple[int, int] = (224, 224),
        divisible_pad_z: int = 4,
        clip_range: tuple[float, float] = (-1008.0, 822.0),
        norm_mean: float = -86.80862426757812,
        norm_std: float = 322.63470458984375,
        pad_value: float = -1024.0,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.resize_dims = resize_dims
        self.divisible_pad_z = divisible_pad_z
        self.clip_range = clip_range
        self.norm_mean = norm_mean
        self.norm_std = norm_std
        # Previously hard-coded inside _pad_axial; exposed as a parameter so
        # alternative background values can be configured. Default preserves
        # the original behavior exactly.
        self.pad_value = pad_value

    def preprocess(
        self,
        images: Union[torch.Tensor, np.ndarray],
        return_tensors: str = "pt",
        **kwargs
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess CT volumes.

        Parameters
        ----------
        images : torch.Tensor or np.ndarray
            Input tensor or numpy array of shape (B, C, D, H, W) where
            B=batch, C=channels, D=depth/slices, H=height, W=width.
        return_tensors : str, default="pt"
            Return format. Only "pt" (PyTorch) is supported.
        **kwargs
            Additional keyword arguments (unused).

        Returns
        -------
        dict[str, torch.Tensor]
            Dictionary with "pixel_values" containing processed tensor of shape
            (B, C, D', H', W') where D' may be padded for divisibility.

        Raises
        ------
        ValueError
            If return_tensors is not "pt" or input is not 5D.
        """
        if return_tensors != "pt":
            raise ValueError(f"Only 'pt' return_tensors is supported, got {return_tensors}")

        # Convert numpy to tensor if needed
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)

        # Ensure float32 dtype for processing (interpolation requires float)
        images = images.float()

        # Validate input shape
        if images.ndim != 5:
            raise ValueError(f"Expected 5D input (B, C, D, H, W), got shape {images.shape}")

        B, C, D, H, W = images.shape

        # Step 1: Spatial Resizing - resize H, W dimensions to resize_dims.
        # Skipped entirely when already at target size to avoid interpolation.
        target_h, target_w = self.resize_dims
        if H != target_h or W != target_w:
            images = self._resize_spatial(images, target_h, target_w)

        # Step 2: Axial Padding - pad z-axis with pad_value for divisibility
        images = self._pad_axial(images)

        # Step 3: Intensity Clipping - clip to HU range (also clips any
        # out-of-range pad values introduced in step 2)
        images = torch.clamp(images, min=self.clip_range[0], max=self.clip_range[1])

        # Step 4: Z-score Normalization
        images = (images - self.norm_mean) / self.norm_std

        return {"pixel_values": images}

    def _resize_spatial(
        self,
        images: torch.Tensor,
        target_h: int,
        target_w: int
    ) -> torch.Tensor:
        """
        Resize spatial dimensions (H, W) using trilinear interpolation.

        The depth dimension is passed through at its current size, so only
        the in-plane (H, W) axes are actually rescaled.

        Parameters
        ----------
        images : torch.Tensor
            Tensor of shape (B, C, D, H, W).
        target_h : int
            Target height.
        target_w : int
            Target width.

        Returns
        -------
        torch.Tensor
            Resized tensor of shape (B, C, D, target_h, target_w).
        """
        D = images.shape[2]

        # Apply trilinear interpolation, keeping depth unchanged
        images = F.interpolate(
            images,
            size=(D, target_h, target_w),
            mode='trilinear',
            align_corners=False
        )

        return images

    def _pad_axial(self, images: torch.Tensor) -> torch.Tensor:
        """
        Pad the axial (z/depth) dimension with ``self.pad_value`` HU so the
        depth is divisible by ``self.divisible_pad_z``.

        Parameters
        ----------
        images : torch.Tensor
            Tensor of shape (B, C, D, H, W).

        Returns
        -------
        torch.Tensor
            Padded tensor of shape (B, C, D', H, W) where D' is divisible
            by divisible_pad_z. Returned unchanged when already divisible.
        """
        D = images.shape[2]
        remainder = D % self.divisible_pad_z

        if remainder == 0:
            return images

        pad_z = self.divisible_pad_z - remainder

        # F.pad expects padding in reverse dimension order: (W_l, W_r, H_l, H_r, D_l, D_r, ...)
        # To pad depth at the end only: (0, 0, 0, 0, 0, pad_z)
        padding = (0, 0, 0, 0, 0, pad_z)
        # getattr keeps backward compatibility with instances deserialized
        # from configs saved before pad_value existed.
        fill = getattr(self, "pad_value", -1024.0)
        images = F.pad(images, padding, mode='constant', value=fill)

        return images