Safetensors · tapct · custom_code
Files changed (4)
  1. README.md +5 -91
  2. modeling_tapct.py +5 -10
  3. preprocessor_config.json +0 -13
  4. tapct_processor.py +0 -179
README.md CHANGED
@@ -3,41 +3,23 @@ license: cc-by-nc-4.0
 ---
 
 # TAP-CT: 3D Task-Agnostic Pretraining of CT Foundation Models
-[![arXiv](https://img.shields.io/badge/arXiv-TAP--CT-b31b1b.svg)](https://arxiv.org/abs/2512.00872)
 
 TAP-CT is a suite of foundation models for computed tomography (CT) imaging, pretrained in a task-agnostic manner through an adaptation of DINOv2 for volumetric data. These models learn robust 3D representations from CT scans without requiring task-specific annotations.
 
-This repository provides TAP-CT-B-3D, a Vision Transformer (ViT-Base) architecture pretrained on volumetric inputs with a spatial resolution of (12, 224, 224) and a patch size of (4, 8, 8). For inference on full-resolution CT volumes, a sliding window approach can be employed to extract features across the entire scan.
+This repository provides TAP-CT-B-3D, a Vision Transformer (ViT-Base) architecture pretrained on volumetric inputs with a spatial resolution of (12, 224, 224) and a patch size of (4, 8, 8). For inference on full-resolution CT volumes, a sliding window approach can be employed to extract features across the entire scan. Additional TAP-CT model variants, as well as the image processor, will be released in future updates.
 
 ## Preprocessing
 
-### Using dedicated image processor
-
-Each TAP-CT model repository provides its own dedicated image processor and configuration file. To ensure proper preprocessing, it is recommended to instantiate the corresponding image processor using the `AutoImageProcessor` class from Hugging Face Transformers. This can be accomplished as follows:
-
-```python
-from transformers import AutoImageProcessor
-
-preprocessor = AutoImageProcessor.from_pretrained(
-    'fomofo/tap-ct-b-3d',
-    trust_remote_code=True
-)
-```
-
-This approach automatically loads the appropriate processor and configuration for the selected TAP-CT model.
-
-### Preprocessing without pipeline
+While a dedicated image processor will be released in future updates, optimal feature extraction requires the following preprocessing pipeline:
 
 1. **Orientation**: Convert the volume to LPS (Left-Posterior-Superior) orientation. While the model is likely orientation-invariant, all evaluations were conducted using LPS orientation.
 2. **Spatial Resizing**: Resize the volume to a spatial resolution of \(z, 224, 224\) or \(z, 512, 512\), where \(z\) represents the number of slices along the axial dimension.
-3. **Axial Padding**: Apply -1024 padding along the \(z\)-axis to ensure divisibility by 4, accommodating the model's patch size of (4, 8, 8).
+3. **Axial Padding**: Apply zero-padding along the \(z\)-axis to ensure divisibility by 4, accommodating the model's patch size of (4, 8, 8).
 4. **Intensity Clipping**: Clip voxel intensities to the range \([-1008, 822]\) HU (Hounsfield Units).
 5. **Normalization**: Apply z-score normalization using \(mean = -86.8086\) and \(std = 322.6347\).
 
 ## Usage
 
-### Default Usage
-
 ```python
 import torch
 from transformers import AutoModel
@@ -49,62 +31,7 @@ model = AutoModel.from_pretrained('fomofo/tap-ct-b-3d', trust_remote_code=True)
 x = torch.randn((16, 1, 12, 224, 224))
 
 # Forward pass
-with torch.no_grad():
-    output = model.forward(x)
-```
-
-### Usage with Preprocessor, loading CT volumes & sliding window inference
-
-**Recommended environment:**
-- Python >= 3.11
-- torch >= 2.8
-- numpy >= 2.35
-- SimpleITK >= 2.52
-- monai >= 1.4.0
-- xformers >= 0.0.32 (optional, recommended for CUDA)
-
-```python
-import numpy as np
-import SimpleITK as sitk
-import torch
-from transformers import AutoModel, AutoImageProcessor
-
-# Load the model
-model = AutoModel.from_pretrained('fomofo/tap-ct-b-3d', trust_remote_code=True)
-preprocessor = AutoImageProcessor.from_pretrained('fomofo/tap-ct-b-3d', trust_remote_code=True)
-
-# Load image & set orientation to LPS
-volume = sitk.ReadImage('/path/to/ct-scan.nii.gz')
-volume = sitk.DICOMOrient(volume, 'LPS')
-
-# Get array, expand to (B, C, D, H, W) and preprocess
-array = sitk.GetArrayFromImage(volume)
-array = np.expand_dims(array, axis=(0, 1))
-x = preprocessor(array)['pixel_values']
-
-# Forward pass
-with torch.no_grad():
-    output = model.forward(x)
-
-# OR
-
-# Forward pass with sliding window
-from monai.inferers import SlidingWindowInferer
-
-def predictor_fn(x):
-    # Reshape the patch tokens to resemble a 3D feature map
-    out = model(x, reshape=True)
-    return out.last_hidden_state
-
-inferer = SlidingWindowInferer(
-    roi_size=[12, 224, 224],
-    sw_batch_size=1,
-    overlap=0.75,
-    mode='gaussian'
-)
-
-with torch.no_grad():
-    output = inferer(x, predictor_fn)
+output = model.forward(x)
 ```
 
 The model returns a `BaseModelOutputWithPooling` object from the transformers library. The `output.pooler_output` contains the pooled `[CLS]` token representation, while `output.last_hidden_state` contains the spatial patch token embeddings. To extract features from all intermediate transformer layers, pass `output_hidden_states=True` to the forward method.
@@ -114,17 +41,4 @@ The model returns a `BaseModelOutputWithPooling` object from the transformers li
 - **Model Type**: 3D CT Vision Foundation Model
 - **Input Shape**: `(batch_size, 1, depth, height, width)`
 - **Example Input**: `(16, 1, 12, 224, 224)` - batch of 16 CT crops with 12 slices at 224×224 resolution
-- **License**: CC-BY-NC-4.0
-
-## Citation
-
-If you find this work useful, please cite:
-
-```bibtex
-@article{veenboer2025tapct,
-  title={TAP-CT: 3D Task-Agnostic Pretraining of Computed Tomography Foundation Models},
-  author={Veenboer, Tim and Yiasemis, George and Marcus, Eric and Van Veldhuizen, Vivien and Snoek, Cees G. M. and Teuwen, Jonas and Groot Lipman, Kevin B. W.},
-  journal={arXiv preprint arXiv:2512.00872},
-  year={2025}
-}
-```
+- **License**: CC-BY-NC-4.0
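With the dedicated image processor removed in this PR, the README's five preprocessing steps have to be applied by hand. A minimal sketch in plain PyTorch, assuming an LPS-oriented `(B, C, D, H, W)` float tensor as input; the function name `preprocess_ct` and the 224×224 target are illustrative, while the constants are taken from the README:

```python
import torch
import torch.nn.functional as F

def preprocess_ct(volume: torch.Tensor) -> torch.Tensor:
    # Assumes step 1 (LPS orientation) was already applied to the image.
    depth = volume.shape[2]
    # Step 2: resize H, W to 224x224 while keeping the slice count.
    volume = F.interpolate(volume, size=(depth, 224, 224),
                           mode='trilinear', align_corners=False)
    # Step 3: zero-pad the z-axis until the depth is divisible by the
    # patch depth of 4; F.pad's 6-tuple pads the last three dims.
    pad_z = (-depth) % 4
    volume = F.pad(volume, (0, 0, 0, 0, 0, pad_z))
    # Step 4: clip intensities to the stated HU window.
    volume = volume.clamp(-1008.0, 822.0)
    # Step 5: z-score normalization with the README statistics.
    return (volume - (-86.8086)) / 322.6347

x = preprocess_ct(torch.randn(1, 1, 10, 160, 160))
print(tuple(x.shape))  # (1, 1, 12, 224, 224)
```

Note that the updated README specifies zero-padding, whereas the deleted processor padded with -1024 HU before clipping; the sketch follows the README.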
modeling_tapct.py CHANGED
@@ -94,7 +94,6 @@ class TAPCTModel(TAPCTPreTrainedModel):
         pixel_values: torch.Tensor,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
-        reshape: bool = False
     ) -> BaseModelOutputWithPooling:
         """
         Forward pass of the TAP-CT model.
@@ -102,15 +101,11 @@ class TAPCTModel(TAPCTPreTrainedModel):
         Parameters
         ----------
         pixel_values : torch.Tensor
-            Input images. Shape (B, C, H, W) for 2D or (B, C, D, H, W) for 3D.
+            Input images. Shape (B, C, H, W) for 2D or (B, C, D, H, W) for 3D
         output_hidden_states : Optional[bool], optional
-            Whether to return hidden states from all layers.
+            Whether to return hidden states from all layers
         return_dict : Optional[bool], optional
-            Whether to return a ModelOutput instead of a plain tuple.
-        reshape : bool, default=False
-            Whether to reshape output features to spatial dimensions. If True,
-            returns shape (B, H, W, C) for 2D or (B, D, H, W, C) for 3D instead
-            of flattened (B, N, C) where N is the number of patches.
+            Whether to return a ModelOutput instead of a plain tuple
 
         Returns
         -------
@@ -128,7 +123,7 @@ class TAPCTModel(TAPCTPreTrainedModel):
             pixel_values,
             n=self.model.n_blocks,
             return_class_token=True,
-            reshape=reshape
+            reshape=False
        )
         outputs = tuple(o[0] for o in outputs_tuple)
         class_tokens = tuple(o[1] for o in outputs_tuple)
@@ -141,7 +136,7 @@ class TAPCTModel(TAPCTPreTrainedModel):
             pixel_values,
             n=1,
             return_class_token=True,
-            reshape=reshape
+            reshape=False
        )
         last_hidden_state = outputs_tuple[0][0]
         pooler_output = outputs_tuple[0][1]
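This change drops the `reshape` forward argument, so `last_hidden_state` is always the flattened `(B, N, C)` token sequence. Callers that want a spatial feature map can rebuild it themselves; a sketch assuming the README's input size (12, 224, 224), patch size (4, 8, 8), and depth-major token ordering (the variable names are illustrative, not part of the model API):

```python
import torch

batch, embed_dim = 2, 768                   # ViT-Base embedding width
d, h, w = 12 // 4, 224 // 8, 224 // 8       # patch grid: 3 x 28 x 28
tokens = torch.randn(batch, d * h * w, embed_dim)  # stand-in for last_hidden_state

# Fold the token axis back into a (B, D', H', W', C) feature map.
feature_map = tokens.reshape(batch, d, h, w, embed_dim)
print(tuple(feature_map.shape))  # (2, 3, 28, 28, 768)
```
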
preprocessor_config.json DELETED
@@ -1,13 +0,0 @@
-{
-  "image_processor_type": "TAPCTProcessor",
-  "use_fast": false,
-  "resize_dims": [224, 224],
-  "divisible_pad_z": 4,
-  "clip_range": [-1008.0, 822.0],
-  "norm_mean": -86.80862426757812,
-  "norm_std": 322.63470458984375,
-  "auto_map": {
-    "AutoImageProcessor": "tapct_processor.TAPCTProcessor"
-  }
-}
-
tapct_processor.py DELETED
@@ -1,179 +0,0 @@
-from typing import Union
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from transformers.image_processing_utils import BaseImageProcessor
-
-
-class TAPCTProcessor(BaseImageProcessor):
-    """
-    Image processor for TAP-CT 3D volumes.
-
-    Processes CT volumes with the following pipeline:
-
-    1. Spatial Resizing: Resize to (z, H', W') where H', W' are resize_dims
-    2. Axial Padding: Pad z-axis with -1024 HU for divisibility by patch size
-    3. Intensity Clipping: Clip to HU range
-    4. Normalization: Z-score normalization
-
-    Parameters
-    ----------
-    resize_dims : tuple[int, int], default=(224, 224)
-        Target spatial dimensions (H, W) for resizing.
-    divisible_pad_z : int, default=4
-        Pad the z-axis to be divisible by this value.
-    clip_range : tuple[float, float], default=(-1008.0, 822.0)
-        HU intensity clipping range (min, max).
-    norm_mean : float, default=-86.80862426757812
-        Mean for z-score normalization.
-    norm_std : float, default=322.63470458984375
-        Standard deviation for z-score normalization.
-    **kwargs
-        Additional arguments passed to BaseImageProcessor.
-    """
-
-    model_input_names = ["pixel_values"]
-
-    def __init__(
-        self,
-        resize_dims: tuple[int, int] = (224, 224),
-        divisible_pad_z: int = 4,
-        clip_range: tuple[float, float] = (-1008.0, 822.0),
-        norm_mean: float = -86.80862426757812,
-        norm_std: float = 322.63470458984375,
-        **kwargs
-    ) -> None:
-        super().__init__(**kwargs)
-        self.resize_dims = resize_dims
-        self.divisible_pad_z = divisible_pad_z
-        self.clip_range = clip_range
-        self.norm_mean = norm_mean
-        self.norm_std = norm_std
-
-    def preprocess(
-        self,
-        images: Union[torch.Tensor, np.ndarray],
-        return_tensors: str = "pt",
-        **kwargs
-    ) -> dict[str, torch.Tensor]:
-        """
-        Preprocess CT volumes.
-
-        Parameters
-        ----------
-        images : torch.Tensor or np.ndarray
-            Input tensor or numpy array of shape (B, C, D, H, W) where
-            B=batch, C=channels, D=depth/slices, H=height, W=width.
-        return_tensors : str, default="pt"
-            Return format. Only "pt" (PyTorch) is supported.
-        **kwargs
-            Additional keyword arguments (unused).
-
-        Returns
-        -------
-        dict[str, torch.Tensor]
-            Dictionary with "pixel_values" containing processed tensor of shape
-            (B, C, D', H', W') where D' may be padded for divisibility.
-
-        Raises
-        ------
-        ValueError
-            If return_tensors is not "pt" or input is not 5D.
-        """
-        if return_tensors != "pt":
-            raise ValueError(f"Only 'pt' return_tensors is supported, got {return_tensors}")
-
-        # Convert numpy to tensor if needed
-        if isinstance(images, np.ndarray):
-            images = torch.from_numpy(images)
-
-        # Ensure float32 dtype for processing
-        images = images.float()
-
-        # Validate input shape
-        if images.ndim != 5:
-            raise ValueError(f"Expected 5D input (B, C, D, H, W), got shape {images.shape}")
-
-        B, C, D, H, W = images.shape
-
-        # Step 1: Spatial Resizing - resize H, W dimensions to resize_dims
-        target_h, target_w = self.resize_dims
-        if H != target_h or W != target_w:
-            images = self._resize_spatial(images, target_h, target_w)
-
-        # Step 2: Axial Padding - pad z-axis with -1024 for divisibility
-        images = self._pad_axial(images)
-
-        # Step 3: Intensity Clipping - clip to HU range
-        images = torch.clamp(images, min=self.clip_range[0], max=self.clip_range[1])
-
-        # Step 4: Z-score Normalization
-        images = (images - self.norm_mean) / self.norm_std
-
-        return {"pixel_values": images}
-
-    def _resize_spatial(
-        self,
-        images: torch.Tensor,
-        target_h: int,
-        target_w: int
-    ) -> torch.Tensor:
-        """
-        Resize spatial dimensions (H, W) using trilinear interpolation.
-
-        Parameters
-        ----------
-        images : torch.Tensor
-            Tensor of shape (B, C, D, H, W).
-        target_h : int
-            Target height.
-        target_w : int
-            Target width.
-
-        Returns
-        -------
-        torch.Tensor
-            Resized tensor of shape (B, C, D, target_h, target_w).
-        """
-        D = images.shape[2]
-
-        # Apply trilinear interpolation, keeping depth unchanged
-        images = F.interpolate(
-            images,
-            size=(D, target_h, target_w),
-            mode='trilinear',
-            align_corners=False
-        )
-
-        return images
-
-    def _pad_axial(self, images: torch.Tensor) -> torch.Tensor:
-        """
-        Pad the axial (z/depth) dimension with -1024 HU for divisibility.
-
-        Parameters
-        ----------
-        images : torch.Tensor
-            Tensor of shape (B, C, D, H, W).
-
-        Returns
-        -------
-        torch.Tensor
-            Padded tensor of shape (B, C, D', H, W) where D' is divisible
-            by divisible_pad_z.
-        """
-        D = images.shape[2]
-        remainder = D % self.divisible_pad_z
-
-        if remainder == 0:
-            return images
-
-        pad_z = self.divisible_pad_z - remainder
-
-        # F.pad expects padding in reverse dimension order: (W_l, W_r, H_l, H_r, D_l, D_r, ...)
-        # To pad depth at the end: (0, 0, 0, 0, 0, pad_z)
-        padding = (0, 0, 0, 0, 0, pad_z)
-        images = F.pad(images, padding, mode='constant', value=-1024.0)
-
-        return images
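The deleted `_pad_axial` relies on `F.pad`'s reverse-dimension tuple convention, which is easy to get wrong. A standalone check confirming that a 6-tuple pads the last three dimensions of a 5D tensor, innermost first, and fills the new slices with the pad value:

```python
import torch
import torch.nn.functional as F

x = torch.zeros(1, 1, 10, 224, 224)

# The 6-tuple pads the LAST three dims in reverse order:
# (W_left, W_right, H_left, H_right, D_left, D_right)
padded = F.pad(x, (0, 0, 0, 0, 0, 2), mode='constant', value=-1024.0)

print(tuple(padded.shape))  # (1, 1, 12, 224, 224)
# The two appended slices hold the pad value, the original slices are untouched.
print(padded[0, 0, -1, 0, 0].item())  # -1024.0
```
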