TimVeenboer committed · Commit eb26bbb · Parent(s): 9be891b

docs(tap-hf): readme and modeling docs

Files changed: README.md +68 -4, modeling_tapct.py +10 -5
README.md
CHANGED
TAP-CT is a suite of foundation models for computed tomography (CT) imaging, pretrained in a task-agnostic manner through an adaptation of DINOv2 for volumetric data. These models learn robust 3D representations from CT scans without requiring task-specific annotations.

This repository provides TAP-CT-B-2.5D, a Vision Transformer (ViT-Base) architecture pretrained on volumetric inputs with a spatial resolution of (6, 224, 224) and a patch size of (1, 16, 16). For inference on full-resolution CT volumes, a sliding window approach can be employed to extract features across the entire scan.
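
As a rough sanity check of what this implies (an assumption based on standard non-overlapping ViT patching, with the `[CLS]` token counted separately), the crop and patch sizes above yield 6 × 14 × 14 patch tokens:

```python
# Hypothetical back-of-the-envelope check of the patch-token count
# implied by the stated crop size and patch size.
crop_size = (6, 224, 224)   # (depth, height, width) of one input crop
patch_size = (1, 16, 16)    # (depth, height, width) of one patch

tokens_per_axis = [c // p for c, p in zip(crop_size, patch_size)]
num_patch_tokens = tokens_per_axis[0] * tokens_per_axis[1] * tokens_per_axis[2]
print(tokens_per_axis, num_patch_tokens)  # [6, 14, 14] 1176
```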
## Preprocessing

### Using dedicated image processor

Each TAP-CT model repository provides its own dedicated image processor and configuration file. To ensure proper preprocessing, it is recommended to instantiate the corresponding image processor using the `AutoImageProcessor` class from Hugging Face Transformers. This can be accomplished as follows:

```python
from transformers import AutoImageProcessor

preprocessor = AutoImageProcessor.from_pretrained(
    'fomofo/tap-ct-b-2-5d',
    trust_remote_code=True
)
```

This approach automatically loads the appropriate processor and configuration for the selected TAP-CT model.
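
As a quick, hypothetical smoke test (mirroring the call pattern of the sliding-window example further below), the loaded processor can be applied to a dummy volume laid out as `(B, C, D, H, W)`:

```python
import numpy as np

# Dummy CT-like volume in (B, C, D, H, W) layout; real inputs come from a scan as shown below.
dummy = np.random.randn(1, 1, 6, 224, 224).astype(np.float32)

# The processor returns a dict-like object; 'pixel_values' holds the model-ready input.
pixel_values = preprocessor(dummy)['pixel_values']
print(type(pixel_values))
```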
### Preprocessing without pipeline

1. **Orientation**: Convert the volume to LPS (Left-Posterior-Superior) orientation. While the model is likely orientation-invariant, all evaluations were conducted using LPS orientation.
2. **Spatial Resizing**: Resize the volume to a spatial resolution of (z, 224, 224) or (z, 512, 512), where z represents the number of slices along the axial dimension (see the sketch after this list).
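
A minimal sketch of these two manual steps with SimpleITK is given below; the linear interpolator and the 224×224 in-plane target are assumptions, and the dedicated image processor above may handle resizing differently:

```python
# Minimal sketch of manual preprocessing (assumed, not the official pipeline).
import SimpleITK as sitk

volume = sitk.ReadImage('/path/to/ct-scan.nii.gz')

# 1. Orientation: reorient to LPS.
volume = sitk.DICOMOrient(volume, 'LPS')

# 2. Spatial resizing: resample in-plane to 224x224 while keeping every axial slice.
#    SimpleITK sizes and spacings are ordered (x, y, z).
old_size = volume.GetSize()
new_size = (224, 224, old_size[2])
new_spacing = [
    sp * osz / nsz
    for sp, osz, nsz in zip(volume.GetSpacing(), old_size, new_size)
]
volume = sitk.Resample(
    volume,
    new_size,
    sitk.Transform(),        # identity transform
    sitk.sitkLinear,
    volume.GetOrigin(),
    new_spacing,
    volume.GetDirection(),
)
```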
## Usage

### Default Usage

```python
import torch
from transformers import AutoModel

# Load the model
model = AutoModel.from_pretrained('fomofo/tap-ct-b-2-5d', trust_remote_code=True)

x = torch.randn((16, 1, 6, 224, 224))

# Forward pass
with torch.no_grad():
    output = model.forward(x)
```
### Usage with Preprocessor, loading CT volumes & sliding window inference

```python
import numpy as np
import SimpleITK as sitk
import torch
from transformers import AutoModel, AutoImageProcessor

# Load the model
model = AutoModel.from_pretrained('fomofo/tap-ct-b-2-5d', trust_remote_code=True)
preprocessor = AutoImageProcessor.from_pretrained('fomofo/tap-ct-b-2-5d', trust_remote_code=True)

# Load image & set orientation to LPS
volume = sitk.ReadImage('/path/to/ct-scan.nii.gz')
volume = sitk.DICOMOrient(volume, 'LPS')

# Get array, expand to (B, C, D, H, W) and preprocess
array = sitk.GetArrayFromImage(volume)
array = np.expand_dims(array, axis=(0, 1))
x = preprocessor(array)['pixel_values']

# Forward pass
with torch.no_grad():
    output = model.forward(x)

# OR

# Forward pass with sliding window
from monai.inferers import SlidingWindowInferer

def predictor_fn(x):
    # Reshape the patch tokens to resemble a 3D feature map
    out = model(x, reshape=True)
    return out.last_hidden_state

inferer = SlidingWindowInferer(
    roi_size=[6, 224, 224],
    sw_batch_size=1,
    overlap=0.75,
    mode='gaussian'
)

with torch.no_grad():
    output = inferer(x, predictor_fn)
```
The model returns a `BaseModelOutputWithPooling` object from the transformers library. The `output.pooler_output` contains the pooled `[CLS]` token representation, while `output.last_hidden_state` contains the spatial patch token embeddings. To extract features from all intermediate transformer layers, pass `output_hidden_states=True` to the forward method.
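
As a small, hypothetical illustration of unpacking these fields (shapes assume the `(16, 1, 6, 224, 224)` example input, a ViT-Base hidden width of 768, and the 6 × 14 × 14 = 1176 patch tokens estimated earlier):

```python
# Hypothetical inspection of the returned fields; exact shapes depend on the checkpoint.
with torch.no_grad():
    output = model(x, output_hidden_states=True)

cls_embedding = output.pooler_output      # e.g. (16, 768): pooled [CLS] representation
patch_tokens = output.last_hidden_state   # e.g. (16, 1176, 768): flattened patch tokens
all_layers = output.hidden_states         # tuple with one entry per returned layer

# With reshape=True (see the modeling_tapct.py change below), the patch tokens
# come back as a spatial map instead, e.g. (16, 6, 14, 14, 768) for this input.
spatial_map = model(x, reshape=True).last_hidden_state
```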
- **Model Type**: 3D CT Vision Foundation Model
- **Input Shape**: `(batch_size, 1, depth, height, width)`
- **Example Input**: `(16, 1, 6, 224, 224)` - batch of 16 CT crops with 6 slices at 224×224 resolution
- **License**: CC-BY-NC-4.0

modeling_tapct.py
CHANGED
@@ -94,6 +94,7 @@ class TAPCTModel(TAPCTPreTrainedModel):
         pixel_values: torch.Tensor,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        reshape: bool = False
     ) -> BaseModelOutputWithPooling:
         """
         Forward pass of the TAP-CT model.

@@ -101,11 +102,15 @@ class TAPCTModel(TAPCTPreTrainedModel):
         Parameters
         ----------
         pixel_values : torch.Tensor
-            Input images. Shape (B, C, H, W) for 2D or (B, C, D, H, W) for 3D
+            Input images. Shape (B, C, H, W) for 2D or (B, C, D, H, W) for 3D.
         output_hidden_states : Optional[bool], optional
-            Whether to return hidden states from all layers
+            Whether to return hidden states from all layers.
         return_dict : Optional[bool], optional
-            Whether to return a ModelOutput instead of a plain tuple
+            Whether to return a ModelOutput instead of a plain tuple.
+        reshape : bool, default=False
+            Whether to reshape output features to spatial dimensions. If True,
+            returns shape (B, H, W, C) for 2D or (B, D, H, W, C) for 3D instead
+            of flattened (B, N, C) where N is the number of patches.

         Returns
         -------

@@ -123,7 +128,7 @@ class TAPCTModel(TAPCTPreTrainedModel):
             pixel_values,
             n=self.model.n_blocks,
             return_class_token=True,
-            reshape=
+            reshape=reshape
         )
         outputs = tuple(o[0] for o in outputs_tuple)
         class_tokens = tuple(o[1] for o in outputs_tuple)

@@ -136,7 +141,7 @@ class TAPCTModel(TAPCTPreTrainedModel):
             pixel_values,
             n=1,
             return_class_token=True,
-            reshape=
+            reshape=reshape
         )
         last_hidden_state = outputs_tuple[0][0]
         pooler_output = outputs_tuple[0][1]