Safetensors · tapct · custom_code

TimVeenboer committed
Commit 3eca297 · 1 Parent(s): 7fb44d5

docs(tap-hf): README & docs editing

Files changed (2):
  1. README.md +67 -3
  2. modeling_tapct.py +7 -3
README.md CHANGED
@@ -6,11 +6,26 @@ license: cc-by-nc-4.0
 
 TAP-CT is a suite of foundation models for computed tomography (CT) imaging, pretrained in a task-agnostic manner through an adaptation of DINOv2 for volumetric data. These models learn robust 3D representations from CT scans without requiring task-specific annotations.
 
- This repository provides TAP-CT-S-3D, a Vision Transformer (ViT-Small) architecture pretrained on volumetric inputs with a spatial resolution of (12, 224, 224) and a patch size of (4, 8, 8). For inference on full-resolution CT volumes, a sliding window approach can be employed to extract features across the entire scan. Additional TAP-CT model variants, as well as the image processor, will be released in future updates.
+ This repository provides TAP-CT-S-3D, a Vision Transformer (ViT-Small) architecture pretrained on volumetric inputs with a spatial resolution of (12, 224, 224) and a patch size of (4, 8, 8). For inference on full-resolution CT volumes, a sliding window approach can be employed to extract features across the entire scan.
 
 ## Preprocessing
 
- While a dedicated image processor will be released in future updates, optimal feature extraction requires the following preprocessing pipeline:
+ ### Using the dedicated image processor
+
+ Each TAP-CT model repository provides its own dedicated image processor and configuration file. To ensure proper preprocessing, it is recommended to instantiate the corresponding image processor using the `AutoImageProcessor` class from Hugging Face Transformers. This can be accomplished as follows:
+
+ ```python
+ from transformers import AutoImageProcessor
+
+ preprocessor = AutoImageProcessor.from_pretrained(
+     'fomofo/tap-ct-s-3d',
+     trust_remote_code=True
+ )
+ ```
+
+ This approach automatically loads the appropriate processor and configuration for the selected TAP-CT model.
+
+ ### Preprocessing without the image processor
 
 1. **Orientation**: Convert the volume to LPS (Left-Posterior-Superior) orientation. While the model is likely orientation-invariant, all evaluations were conducted using LPS orientation.
  2. **Spatial Resizing**: Resize the volume to a spatial resolution of \(z, 224, 224\) or \(z, 512, 512\), where \(z\) represents the number of slices along the axial dimension.
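
For readers who preprocess volumes manually, a minimal sketch of the two steps listed above (LPS reorientation and in-plane resizing) could look as follows. SimpleITK is used for loading and reorientation, matching the repository's own example further down; the torch-based resize and the trilinear interpolation mode are assumptions, not part of the released code.

```python
import SimpleITK as sitk
import torch
import torch.nn.functional as F

# Step 1: load the CT scan and reorient it to LPS
volume = sitk.ReadImage('/path/to/ct-scan.nii.gz')
volume = sitk.DICOMOrient(volume, 'LPS')

# Step 2: resize the in-plane resolution to 224 x 224 while keeping all z slices
array = sitk.GetArrayFromImage(volume)            # (z, H, W), axial slices first
x = torch.from_numpy(array).float()[None, None]   # (1, 1, z, H, W)
x = F.interpolate(
    x,
    size=(x.shape[2], 224, 224),
    mode='trilinear',          # interpolation choice is an assumption
    align_corners=False,
)
```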
 
@@ -20,6 +35,8 @@ While a dedicated image processor will be released in future updates, optimal fe
 
 ## Usage
 
+ ### Default Usage
+
 ```python
 import torch
 from transformers import AutoModel
@@ -31,7 +48,54 @@ model = AutoModel.from_pretrained('fomofo/tap-ct-s-3d', trust_remote_code=True)
 x = torch.randn((16, 1, 12, 224, 224))
 
 # Forward pass
- output = model.forward(x)
+ with torch.no_grad():
+     output = model.forward(x)
+ ```
+
+ ### Usage with the preprocessor, CT volume loading & sliding-window inference
+
+ ```python
+ import numpy as np
+ import SimpleITK as sitk
+ import torch
+ from transformers import AutoModel, AutoImageProcessor
+
+ # Load the model and its image processor
+ model = AutoModel.from_pretrained('fomofo/tap-ct-s-3d', trust_remote_code=True)
+ preprocessor = AutoImageProcessor.from_pretrained('fomofo/tap-ct-s-3d', trust_remote_code=True)
+
+ # Load image & set orientation to LPS
+ volume = sitk.ReadImage('/path/to/ct-scan.nii.gz')
+ volume = sitk.DICOMOrient(volume, 'LPS')
+
+ # Get array, expand to (B, C, D, H, W) and preprocess
+ array = sitk.GetArrayFromImage(volume)
+ array = np.expand_dims(array, axis=(0, 1))
+ x = preprocessor(array)['pixel_values']
+
+ # Forward pass
+ with torch.no_grad():
+     output = model.forward(x)
+
+ # OR
+
+ # Forward pass with sliding window
+ from monai.inferers import SlidingWindowInferer
+
+ def predictor_fn(x):
+     # Reshape the patch tokens to resemble a 3D feature map
+     out = model(x, reshape=True)
+     return out.last_hidden_state
+
+ inferer = SlidingWindowInferer(
+     roi_size=[12, 224, 224],
+     sw_batch_size=1,
+     overlap=0.75,
+     mode='gaussian'
+ )
+
+ with torch.no_grad():
+     output = inferer(x, predictor_fn)
 ```
 
  The model returns a `BaseModelOutputWithPooling` object from the transformers library. The `output.pooler_output` contains the pooled `[CLS]` token representation, while `output.last_hidden_state` contains the spatial patch token embeddings. To extract features from all intermediate transformer layers, pass `output_hidden_states=True` to the forward method.
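
As a brief illustration of that output structure, here is a hedged sketch of inspecting the returned fields; exact token counts and shapes depend on the model configuration and are only indicated symbolically in the comments.

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained('fomofo/tap-ct-s-3d', trust_remote_code=True)
x = torch.randn((1, 1, 12, 224, 224))

with torch.no_grad():
    output = model(x, output_hidden_states=True)

print(output.pooler_output.shape)      # pooled [CLS] representation: (batch, hidden_size)
print(output.last_hidden_state.shape)  # patch token embeddings: (batch, num_tokens, hidden_size)
print(len(output.hidden_states))       # per-layer hidden states, returned because output_hidden_states=True
```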
modeling_tapct.py CHANGED
@@ -102,11 +102,15 @@ class TAPCTModel(TAPCTPreTrainedModel):
         Parameters
         ----------
         pixel_values : torch.Tensor
-            Input images. Shape (B, C, H, W) for 2D or (B, C, D, H, W) for 3D
+            Input images. Shape (B, C, H, W) for 2D or (B, C, D, H, W) for 3D.
         output_hidden_states : Optional[bool], optional
-            Whether to return hidden states from all layers
+            Whether to return hidden states from all layers.
         return_dict : Optional[bool], optional
-            Whether to return a ModelOutput instead of a plain tuple
+            Whether to return a ModelOutput instead of a plain tuple.
+        reshape : bool, default=False
+            Whether to reshape output features to spatial dimensions. If True,
+            returns shape (B, H, W, C) for 2D or (B, D, H, W, C) for 3D instead
+            of flattened (B, N, C) where N is the number of patches.
 
         Returns
         -------