Safetensors
tapct
custom_code
TimVeenboer committed on
Commit
9be891b
·
1 Parent(s): 8650a91

feat(tap-hf): image processor

Browse files
Files changed (2) hide show
  1. preprocessor_config.json +12 -0
  2. tapct_processor.py +179 -0
preprocessor_config.json ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "image_processor_type": "TAPCTProcessor",
3
+ "resize_dims": [224, 224],
4
+ "divisible_pad_z": 1,
5
+ "clip_range": [-1008.0, 822.0],
6
+ "norm_mean": -86.80862426757812,
7
+ "norm_std": 322.63470458984375,
8
+ "auto_map": {
9
+ "AutoImageProcessor": "tapct_processor.TAPCTProcessor"
10
+ }
11
+ }
12
+
tapct_processor.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from transformers.image_processing_utils import BaseImageProcessor
7
+
8
+
9
class TAPCTProcessor(BaseImageProcessor):
    """
    Hugging Face image processor for TAP-CT 3D CT volumes.

    The preprocessing pipeline applies, in order:

    1. Spatial resizing of each volume to ``resize_dims`` (H, W).
    2. Axial padding of the depth axis with -1024 HU so the slice
       count becomes divisible by ``divisible_pad_z``.
    3. Intensity clipping to ``clip_range`` (in HU).
    4. Z-score normalization with ``norm_mean`` / ``norm_std``.

    Parameters
    ----------
    resize_dims : tuple[int, int], default=(224, 224)
        Target spatial (H, W) size after resizing.
    divisible_pad_z : int, default=4
        The depth axis is padded up to the next multiple of this value.
    clip_range : tuple[float, float], default=(-1008.0, 822.0)
        Lower and upper HU clipping bounds (min, max).
    norm_mean : float, default=-86.80862426757812
        Mean subtracted during z-score normalization.
    norm_std : float, default=322.63470458984375
        Standard deviation divided by during z-score normalization.
    **kwargs
        Forwarded unchanged to ``BaseImageProcessor``.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        resize_dims: tuple[int, int] = (224, 224),
        divisible_pad_z: int = 4,
        clip_range: tuple[float, float] = (-1008.0, 822.0),
        norm_mean: float = -86.80862426757812,
        norm_std: float = 322.63470458984375,
        **kwargs
    ) -> None:
        super().__init__(**kwargs)
        self.resize_dims = resize_dims
        self.divisible_pad_z = divisible_pad_z
        self.clip_range = clip_range
        self.norm_mean = norm_mean
        self.norm_std = norm_std

    def preprocess(
        self,
        images: Union[torch.Tensor, np.ndarray],
        return_tensors: str = "pt",
        **kwargs
    ) -> dict[str, torch.Tensor]:
        """
        Run the full preprocessing pipeline on a batch of CT volumes.

        Parameters
        ----------
        images : torch.Tensor or np.ndarray
            Input of shape (B, C, D, H, W): batch, channels,
            depth/slices, height, width.
        return_tensors : str, default="pt"
            Output format; only "pt" (PyTorch) is supported.
        **kwargs
            Ignored; accepted for interface compatibility.

        Returns
        -------
        dict[str, torch.Tensor]
            ``{"pixel_values": tensor}`` of shape (B, C, D', H', W'),
            where D' may have been padded for divisibility.

        Raises
        ------
        ValueError
            If ``return_tensors`` is not "pt" or the input is not 5D.
        """
        if return_tensors != "pt":
            raise ValueError(f"Only 'pt' return_tensors is supported, got {return_tensors}")

        # Accept numpy input by converting to a tensor up front.
        if isinstance(images, np.ndarray):
            images = torch.from_numpy(images)

        # All subsequent math is done in float32.
        images = images.float()

        if images.ndim != 5:
            raise ValueError(f"Expected 5D input (B, C, D, H, W), got shape {images.shape}")

        # 1) Resize only when the spatial size differs from the target.
        target_h, target_w = self.resize_dims
        if images.shape[-2:] != (target_h, target_w):
            images = self._resize_spatial(images, target_h, target_w)

        # 2) Pad the depth axis for patch-size divisibility.
        images = self._pad_axial(images)

        # 3) Clip intensities to the configured HU window.
        images = torch.clamp(images, min=self.clip_range[0], max=self.clip_range[1])

        # 4) Z-score normalize.
        images = (images - self.norm_mean) / self.norm_std

        return {"pixel_values": images}

    def _resize_spatial(
        self,
        images: torch.Tensor,
        target_h: int,
        target_w: int
    ) -> torch.Tensor:
        """
        Resize the (H, W) axes via trilinear interpolation.

        The depth axis is passed through unchanged by targeting the
        current depth in the interpolation size.

        Parameters
        ----------
        images : torch.Tensor
            Input of shape (B, C, D, H, W).
        target_h : int
            Desired height.
        target_w : int
            Desired width.

        Returns
        -------
        torch.Tensor
            Tensor of shape (B, C, D, target_h, target_w).
        """
        depth = images.shape[2]
        return F.interpolate(
            images,
            size=(depth, target_h, target_w),
            mode='trilinear',
            align_corners=False,
        )

    def _pad_axial(self, images: torch.Tensor) -> torch.Tensor:
        """
        Pad the depth (axial) axis with -1024 HU (air) up to the next
        multiple of ``divisible_pad_z``.

        Parameters
        ----------
        images : torch.Tensor
            Input of shape (B, C, D, H, W).

        Returns
        -------
        torch.Tensor
            Tensor of shape (B, C, D', H, W) with D' a multiple of
            ``divisible_pad_z``.
        """
        depth = images.shape[2]
        # Extra slices needed so depth % divisible_pad_z == 0.
        pad_z = -depth % self.divisible_pad_z
        if pad_z == 0:
            return images

        # F.pad orders padding from the last dim backwards:
        # (W_left, W_right, H_left, H_right, D_left, D_right).
        return F.pad(images, (0, 0, 0, 0, 0, pad_z), mode='constant', value=-1024.0)