TimVeenboer committed
Commit eb26bbb · 1 Parent(s): 9be891b

docs(tap-hf): readme and modeling docs

Files changed (2):
  1. README.md +68 -4
  2. modeling_tapct.py +10 -5
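
For reproducibility, the snippet below is a minimal sketch of loading the repository at exactly this commit rather than the moving `main` branch. The repository id `fomofo/tap-ct-b-2-5d` is taken from the README diff below, and `revision` is the standard `from_pretrained` argument for pinning a git revision; whether the abbreviated hash resolves is an assumption, so substitute the full commit hash if needed.

```python
from transformers import AutoModel

# Pin the download to commit eb26bbb (this commit). If the short hash is not
# accepted by the Hub, use the full 40-character commit hash instead.
model = AutoModel.from_pretrained(
    'fomofo/tap-ct-b-2-5d',
    revision='eb26bbb',
    trust_remote_code=True,
)
```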
README.md CHANGED
@@ -6,11 +6,26 @@ license: cc-by-nc-4.0
 
  TAP-CT is a suite of foundation models for computed tomography (CT) imaging, pretrained in a task-agnostic manner through an adaptation of DINOv2 for volumetric data. These models learn robust 3D representations from CT scans without requiring task-specific annotations.
 
- This repository provides TAP-CT-B-2.5D, a Vision Transformer (ViT-Small) architecture pretrained on volumetric inputs with a spatial resolution of (6, 224, 224) and a patch size of (1, 16, 16). For inference on full-resolution CT volumes, a sliding window approach can be employed to extract features across the entire scan. Additional TAP-CT model variants, as well as the image processor, will be released in future updates.
+ This repository provides TAP-CT-B-2.5D, a Vision Transformer (ViT-Base) architecture pretrained on volumetric inputs with a spatial resolution of (6, 224, 224) and a patch size of (1, 16, 16). For inference on full-resolution CT volumes, a sliding window approach can be employed to extract features across the entire scan.
 
  ## Preprocessing
 
- While a dedicated image processor will be released in future updates, optimal feature extraction requires the following preprocessing pipeline:
+ ### Using dedicated image processor
+
+ Each TAP-CT model repository provides its own dedicated image processor and configuration file. To ensure proper preprocessing, it is recommended to instantiate the corresponding image processor using the `AutoImageProcessor` class from Hugging Face Transformers. This can be accomplished as follows:
+
+ ```python
+ from transformers import AutoImageProcessor
+
+ preprocessor = AutoImageProcessor.from_pretrained(
+     'fomofo/tap-ct-b-2-5d',
+     trust_remote_code=True
+ )
+ ```
+
+ This approach automatically loads the appropriate processor and configuration for the selected TAP-CT model.
+
+ ### Preprocessing without pipeline
 
  1. **Orientation**: Convert the volume to LPS (Left-Posterior-Superior) orientation. While the model is likely orientation-invariant, all evaluations were conducted using LPS orientation.
  2. **Spatial Resizing**: Resize the volume to a spatial resolution of \(z, 224, 224\) or \(z, 512, 512\), where \(z\) represents the number of slices along the axial dimension.
@@ -19,6 +34,8 @@ While a dedicated image processor will be released in future updates, optimal fe
 
  ## Usage
 
+ ### Default Usage
+
  ```python
  import torch
  from transformers import AutoModel
@@ -30,7 +47,54 @@ model = AutoModel.from_pretrained('fomofo/tap-ct-b-2-5d', trust_remote_code=True
  x = torch.randn((16, 1, 6, 224, 224))
 
  # Forward pass
- output = model.forward(x)
+ with torch.no_grad():
+     output = model.forward(x)
+ ```
+
+ ### Usage with Preprocessor, loading CT volumes & sliding window inference
+
+ ```python
+ import numpy as np
+ import SimpleITK as sitk
+ import torch
+ from transformers import AutoModel, AutoImageProcessor
+
+ # Load the model
+ model = AutoModel.from_pretrained('fomofo/tap-ct-b-2-5d', trust_remote_code=True)
+ preprocessor = AutoImageProcessor.from_pretrained('fomofo/tap-ct-b-2-5d', trust_remote_code=True)
+
+ # Load image & set orientation to LPS
+ volume = sitk.ReadImage('/path/to/ct-scan.nii.gz')
+ volume = sitk.DICOMOrient(volume, 'LPS')
+
+ # Get array, expand to (B, C, D, H, W) and preprocess
+ array = sitk.GetArrayFromImage(volume)
+ array = np.expand_dims(array, axis=(0, 1))
+ x = preprocessor(array)['pixel_values']
+
+ # Forward pass
+ with torch.no_grad():
+     output = model.forward(x)
+
+ # OR
+
+ # Forward pass with sliding window
+ from monai.inferers import SlidingWindowInferer
+
+ def predictor_fn(x):
+     # Reshape the patch tokens to resemble a 3D feature map
+     out = model(x, reshape=True)
+     return out.last_hidden_state
+
+ inferer = SlidingWindowInferer(
+     roi_size=[6, 224, 224],
+     sw_batch_size=1,
+     overlap=0.75,
+     mode='gaussian'
+ )
+
+ with torch.no_grad():
+     output = inferer(x, predictor_fn)
  ```
 
  The model returns a `BaseModelOutputWithPooling` object from the transformers library. The `output.pooler_output` contains the pooled `[CLS]` token representation, while `output.last_hidden_state` contains the spatial patch token embeddings. To extract features from all intermediate transformer layers, pass `output_hidden_states=True` to the forward method.
@@ -40,4 +104,4 @@ The model returns a `BaseModelOutputWithPooling` object from the transformers li
  - **Model Type**: 3D CT Vision Foundation Model
  - **Input Shape**: `(batch_size, 1, depth, height, width)`
  - **Example Input**: `(16, 1, 6, 224, 224)` - batch of 16 CT crops with 6 slices at 224×224 resolution
- - **License**: CC-BY-NC-4.0
+ - **License**: CC-BY-NC-4.0
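
The README paragraph above mentions `output_hidden_states=True` without showing a call. The following is a minimal sketch of that usage, reusing the repository id and input shape from the README; the layout of `output.hidden_states` as one tensor per transformer block follows the standard `BaseModelOutputWithPooling` convention and is not spelled out in the diff.

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained('fomofo/tap-ct-b-2-5d', trust_remote_code=True)
x = torch.randn((1, 1, 6, 224, 224))  # single dummy CT crop, shaped as in the README

with torch.no_grad():
    output = model(x, output_hidden_states=True)

# Pooled [CLS] representation and final patch tokens, as described in the README
print(output.pooler_output.shape)      # (1, hidden_dim)
print(output.last_hidden_state.shape)  # (1, num_patches, hidden_dim)

# Intermediate features from all transformer layers (assumed per-layer tuple)
for i, hidden in enumerate(output.hidden_states):
    print(i, hidden.shape)
```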
modeling_tapct.py CHANGED
@@ -94,6 +94,7 @@ class TAPCTModel(TAPCTPreTrainedModel):
          pixel_values: torch.Tensor,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
+         reshape: bool = False
      ) -> BaseModelOutputWithPooling:
          """
          Forward pass of the TAP-CT model.
@@ -101,11 +102,15 @@ class TAPCTModel(TAPCTPreTrainedModel):
          Parameters
          ----------
          pixel_values : torch.Tensor
-             Input images. Shape (B, C, H, W) for 2D or (B, C, D, H, W) for 3D
+             Input images. Shape (B, C, H, W) for 2D or (B, C, D, H, W) for 3D.
          output_hidden_states : Optional[bool], optional
-             Whether to return hidden states from all layers
+             Whether to return hidden states from all layers.
          return_dict : Optional[bool], optional
-             Whether to return a ModelOutput instead of a plain tuple
+             Whether to return a ModelOutput instead of a plain tuple.
+         reshape : bool, default=False
+             Whether to reshape output features to spatial dimensions. If True,
+             returns shape (B, H, W, C) for 2D or (B, D, H, W, C) for 3D instead
+             of flattened (B, N, C) where N is the number of patches.
 
          Returns
          -------
@@ -123,7 +128,7 @@ class TAPCTModel(TAPCTPreTrainedModel):
                  pixel_values,
                  n=self.model.n_blocks,
                  return_class_token=True,
-                 reshape=False
+                 reshape=reshape
              )
              outputs = tuple(o[0] for o in outputs_tuple)
              class_tokens = tuple(o[1] for o in outputs_tuple)
@@ -136,7 +141,7 @@ class TAPCTModel(TAPCTPreTrainedModel):
                  pixel_values,
                  n=1,
                  return_class_token=True,
-                 reshape=False
+                 reshape=reshape
              )
              last_hidden_state = outputs_tuple[0][0]
              pooler_output = outputs_tuple[0][1]
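
As a usage illustration of the `reshape` flag documented above, here is a minimal sketch based purely on the new docstring: without the flag the patch tokens come back flattened as (B, N, C), with it they come back as a (B, D, H, W, C) feature map. The concrete grid sizes in the comments are assumptions derived from the (1, 16, 16) patch size and (6, 224, 224) crop described in the README.

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained('fomofo/tap-ct-b-2-5d', trust_remote_code=True)
x = torch.randn((1, 1, 6, 224, 224))

with torch.no_grad():
    flat = model(x)                # default: flattened patch tokens, (B, N, C)
    grid = model(x, reshape=True)  # new flag: spatial feature map, (B, D, H, W, C)

# With a (1, 16, 16) patch size on a (6, 224, 224) crop this should give
# N = 6 * 14 * 14 = 1176 tokens and a (1, 6, 14, 14, C) feature map (assumed).
print(flat.last_hidden_state.shape)
print(grid.last_hidden_state.shape)
```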