Add package code and model weights
Browse files- .gitattributes +1 -0
- .gitignore +12 -0
- .python-version +1 -0
- COLIPRI_demo.ipynb +0 -0
- LICENSE +21 -0
- README.md +176 -1
- assets/input.png +3 -0
- justfile +2 -0
- model.safetensors +3 -0
- model_card.md +115 -0
- pyproject.toml +70 -0
- src/colipri/__init__.py +11 -0
- src/colipri/checkpoint.py +54 -0
- src/colipri/configs/config.yaml +10 -0
- src/colipri/configs/model/default.yaml +5 -0
- src/colipri/configs/model/image_encoder/backbone/default.yaml +20 -0
- src/colipri/configs/model/image_encoder/default.yaml +6 -0
- src/colipri/configs/model/image_encoder/pooler/default.yaml +3 -0
- src/colipri/configs/model/image_encoder/projector/default.yaml +4 -0
- src/colipri/configs/model/text_encoder/backbone/default.yaml +4 -0
- src/colipri/configs/model/text_encoder/default.yaml +5 -0
- src/colipri/configs/model/text_encoder/pooler/default.yaml +3 -0
- src/colipri/configs/processor/default.yaml +5 -0
- src/colipri/configs/processor/image_transform/default.yaml +49 -0
- src/colipri/configs/processor/tokenizer/default.yaml +5 -0
- src/colipri/defaults.py +21 -0
- src/colipri/model/__init__.py +0 -0
- src/colipri/model/base.py +8 -0
- src/colipri/model/image.py +109 -0
- src/colipri/model/multimodal.py +184 -0
- src/colipri/model/text.py +42 -0
- src/colipri/pipelines.py +65 -0
- src/colipri/pooling.py +74 -0
- src/colipri/processor.py +90 -0
- src/colipri/sample_data.py +24 -0
- src/colipri/types.py +42 -0
- uv.lock +0 -0
.gitattributes
CHANGED
|
@@ -19,6 +19,7 @@
|
|
| 19 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
|
|
|
| 19 |
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
| 23 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 24 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 25 |
*.rar filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
# Virtual environments
|
| 10 |
+
.venv
|
| 11 |
+
|
| 12 |
+
.vscode/
|
.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.13
|
COLIPRI_demo.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) Microsoft Corporation.
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE
|
README.md
CHANGED
|
@@ -1,5 +1,180 @@
|
|
| 1 |
---
|
| 2 |
license: mit
|
|
|
|
|
|
|
|
|
|
| 3 |
---
|
| 4 |
|
| 5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
pipeline_tag: zero-shot-image-classification
|
| 6 |
---
|
| 7 |
|
| 8 |
+
# Model card for COLIPRI
|
| 9 |
+
|
| 10 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 11 |
+
|
| 12 |
+
COLIPRI is a 3D vision–language transformer model trained to encode chest CT scans and reports.
|
| 13 |
+
|
| 14 |
+
## Model description
|
| 15 |
+
|
| 16 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 17 |
+
|
| 18 |
+
COLIPRI was trained using tens of thousands of chest CT scans and reports, without any annotations, using multiple objectives to learn strong joint representations of 3D images and text.
|
| 19 |
+
The procedure is described in detail in our manuscript, [_Comprehensive language-image pre-training for 3D medical image understanding_](https://arxiv.org/abs/2510.15042) (Wald et al. 2026).
|
| 20 |
+
|
| 21 |
+
The weights shared here correspond to our best-performing model, COLIPRI-CRM.
|
| 22 |
+
|
| 23 |
+
- **Developed by:** Microsoft Health Futures
|
| 24 |
+
- **Model type:** 3D vision–language encoder
|
| 25 |
+
- **License:** [MIT](./LICENSE)
|
| 26 |
+
|
| 27 |
+
## Uses
|
| 28 |
+
|
| 29 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 30 |
+
|
| 31 |
+
COLIPRI is shared for research purposes only.
|
| 32 |
+
It is **not meant to be used for clinical practice**.
|
| 33 |
+
|
| 34 |
+
<!-- ### Downstream use -->
|
| 35 |
+
|
| 36 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 37 |
+
|
| 38 |
+
The encoders be plugged to other models, or used independently or jointly for many downstream tasks, such as:
|
| 39 |
+
|
| 40 |
+
- Image classification with text prompts
|
| 41 |
+
- Image clustering
|
| 42 |
+
- Text clustering
|
| 43 |
+
- Text-to-image retrieval
|
| 44 |
+
- Image-to-image retrieval
|
| 45 |
+
- Image-to-text retrieval
|
| 46 |
+
- Text-to-text retrieval
|
| 47 |
+
- Image classification with a classifier
|
| 48 |
+
- Text classification with a classifier
|
| 49 |
+
- Image segmentation with a decoder
|
| 50 |
+
- Report generation with a language decoder
|
| 51 |
+
|
| 52 |
+
Fine-tuning COLIPRI is typically not necessary to obtain good performance in downstream tasks.
|
| 53 |
+
|
| 54 |
+
<!-- ### Out-of-scope use -->
|
| 55 |
+
|
| 56 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 57 |
+
|
| 58 |
+
## Biases, risks, and limitations
|
| 59 |
+
|
| 60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 61 |
+
|
| 62 |
+
COLIPRI was trained with data from Turkey and the USA only, therefore it might be biased towards population in the training data.
|
| 63 |
+
Underlying biases of the training datasets may not be well characterized.
|
| 64 |
+
|
| 65 |
+
## Installation
|
| 66 |
+
|
| 67 |
+
```shell
|
| 68 |
+
pip install git+https://huggingface.co/microsoft/colipri.git
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
## Usage examples
|
| 72 |
+
|
| 73 |
+
First, let's get a 3D chest CT we can use for demonstration.
|
| 74 |
+
The plotted slices intersect a lung nodule near the heart.
|
| 75 |
+
|
| 76 |
+
```python
|
| 77 |
+
>>> from colipri import load_sample_ct
|
| 78 |
+
>>> image = load_sample_ct()
|
| 79 |
+
>>> image
|
| 80 |
+
ScalarImage(shape: (1, 512, 512, 139); spacing: (0.76, 0.76, 2.50); orientation: LPS+; dtype: torch.IntTensor; memory: 139.0 MiB)
|
| 81 |
+
>>> nodule_indices_ct = 180, 341, 76
|
| 82 |
+
>>> image.plot(indices=nodule_indices_ct)
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+

|
| 86 |
+
|
| 87 |
+
Now, let's instantiate the model and processor.
|
| 88 |
+
|
| 89 |
+
```python
|
| 90 |
+
>>> from colipri import get_model
|
| 91 |
+
>>> from colipri import get_processor
|
| 92 |
+
>>> model = get_model().cuda()
|
| 93 |
+
>>> processor = get_processor()
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
### Feature extraction
|
| 97 |
+
|
| 98 |
+
```python
|
| 99 |
+
>>> import torch
|
| 100 |
+
>>> preprocessed_images = processor.process_images(image)
|
| 101 |
+
>>> preprocessed_images[0]
|
| 102 |
+
ScalarImage(shape: (1, 192, 192, 192); spacing: (2.00, 2.00, 2.00); orientation: SAR+; dtype: torch.FloatTensor; memory: 27.0 MiB)
|
| 103 |
+
>>> images_batch = processor.to_images_batch(preprocessed_images)
|
| 104 |
+
images_batch.shape
|
| 105 |
+
torch.Size([1, 1, 192, 192, 192])
|
| 106 |
+
>>> with torch.no_grad():
|
| 107 |
+
... patch_embeddings = model.encode_image(images_batch)
|
| 108 |
+
>>> patch_embeddings.shape
|
| 109 |
+
torch.Size([1, 768, 24, 24, 24])
|
| 110 |
+
>>> with torch.no_grad():
|
| 111 |
+
... pooled_embeddings = model.encode_image(images_batch, pool=True)
|
| 112 |
+
>>> pooled_embeddings.shape
|
| 113 |
+
torch.Size([1, 768])
|
| 114 |
+
```
|
| 115 |
+
|
| 116 |
+
### Zero-shot classification
|
| 117 |
+
|
| 118 |
+
```python
|
| 119 |
+
>>> from colipri import ZeroShotImageClassificationPipeline
|
| 120 |
+
>>> pipeline = ZeroShotImageClassificationPipeline(model, processor)
|
| 121 |
+
>>> pipeline(image, ["No lung nodules", "Lung nodules"])
|
| 122 |
+
[
|
| 123 |
+
{'score': 0.005, 'label': 'No lung nodules'},
|
| 124 |
+
{'score': 0.995, 'label': 'Lung nodules'}
|
| 125 |
+
]
|
| 126 |
+
```
|
| 127 |
+
|
| 128 |
+
## Environmental impact
|
| 129 |
+
|
| 130 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 131 |
+
|
| 132 |
+
<!-- Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). -->
|
| 133 |
+
|
| 134 |
+
- **Hardware type:** NVIDIA A100 GPUs
|
| 135 |
+
- **Hours used:** 72 hours × 4 GPUs = 288 GPU-hours
|
| 136 |
+
- **Cloud provider:** Azure
|
| 137 |
+
- **Compute region:** West US 2
|
| 138 |
+
- **Carbon emitted:** 21.6 kg CO₂ eq.
|
| 139 |
+
|
| 140 |
+
### Compute infrastructure
|
| 141 |
+
|
| 142 |
+
COLIPRI was trained on [Azure Machine Learning](https://azure.microsoft.com/en-us/products/machine-learning).
|
| 143 |
+
|
| 144 |
+
#### Hardware
|
| 145 |
+
|
| 146 |
+
| Stage | Node type | Num. nodes | GPU type | GPUs per node |
|
| 147 |
+
| --- | --- | --- | --- | --- |
|
| 148 |
+
| Pre-training | [`Standard_NC96ads_A100_v4`](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nca100v4-series?tabs=sizeaccelerators) | 1 | NVIDIA A100 (80 GB) | 4 |
|
| 149 |
+
| Evaluation | [`Standard_NC24ads_A100_v4`](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nca100v4-series?tabs=sizeaccelerators) | 1 | NVIDIA A100 (80 GB) | 1 |
|
| 150 |
+
|
| 151 |
+
#### Software
|
| 152 |
+
|
| 153 |
+
The main software libraries used in this work were [nnSSL](https://github.com/MIC-DKFZ/nnssl) for training, [TorchIO](https://torchio.org/) for preprocessing and augmentation, [`nifti-zarr-py`](https://github.com/neuroscales/nifti-zarr-py) for data loading, and [nnU-Net](https://github.com/MIC-DKFZ/nnUNet) for segmentation evaluation.
|
| 154 |
+
|
| 155 |
+
## Citation
|
| 156 |
+
|
| 157 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 158 |
+
|
| 159 |
+
### BibTeX
|
| 160 |
+
|
| 161 |
+
```bibtex
|
| 162 |
+
@misc{
|
| 163 |
+
wald2026_colipri,
|
| 164 |
+
title={Comprehensive language-image pre-training for 3D medical image understanding},
|
| 165 |
+
author={Tassilo Wald and Ibrahim Ethem Hamamci and Yuan Gao and Sam Bond-Taylor and Harshita Sharma and Maximilian Ilse and Cynthia Lo and Olesya Melnichenko and Anton Schwaighofer and Noel C. F. Codella and Maria Teodora Wetscherek and Klaus H. Maier-Hein and Panagiotis Korfiatis and Valentina Salvatelli and Javier Alvarez-Valle and P{\'e}rez-Garc{\'i}a},
|
| 166 |
+
year={2026},
|
| 167 |
+
eprint={2510.15042},
|
| 168 |
+
archivePrefix={arXiv},
|
| 169 |
+
primaryClass={cs.CV},
|
| 170 |
+
url={https://arxiv.org/abs/2510.15042},
|
| 171 |
+
}
|
| 172 |
+
```
|
| 173 |
+
|
| 174 |
+
### APA
|
| 175 |
+
|
| 176 |
+
> Wald, T., Hamamci, I. E., Gao, Y., Bond-Taylor, S., Sharma, H., Ilse, M., Lo, C., Melnichenko, O., Schwaighofer, A., Codella, N. C. F., Wetscherek, M. T., Maier-Hein, K. H., Korfiatis, P., Salvatelli, V., Alvarez-Valle, J., & Pérez-García, F. (2026). Comprehensive language-image pre-training for 3D medical image understanding. arXiv. <https://doi.org/10.48550/ARXIV.2510.15042>
|
| 177 |
+
|
| 178 |
+
## Model card contact
|
| 179 |
+
|
| 180 |
+
Fernando Pérez-García ([`fperezgarcia@microsoft.com`](mailto:fperezgarcia@microsoft.com)).
|
assets/input.png
ADDED
|
Git LFS Details
|
justfile
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
requirements:
|
| 2 |
+
uv export --quiet --no-hashes --no-emit-project --output-file requirements_demo.txt
|
model.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:97d66586507d1f48577b5149777d02e12d25787435a2a295cac5563df9291c14
|
| 3 |
+
size 1033541380
|
model_card.md
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Model card
|
| 3 |
+
|
| 4 |
+
Date effective: September 2025
|
| 5 |
+
Owner: ORA, Internal Policy & Practice
|
| 6 |
+
|
| 7 |
+
## Model summary
|
| 8 |
+
|
| 9 |
+
| | |
|
| 10 |
+
| --- | --- |
|
| 11 |
+
| Developer | Microsoft Research |
|
| 12 |
+
| Description | COLIPRI is composed of a vision encoder for 3D CT scans and a text encoder for clinical reports. The encoders may be leveraged for downstream tasks such as classification or segmentation. The models were trained using CT-RATE and NLST. COLIPRI is not a generative model. |
|
| 13 |
+
| Model architecture | The vision encoder is Primus-M, a 3D vision transformer. The text encoder architecture is BERT. |
|
| 14 |
+
| Parameters | 1-500M |
|
| 15 |
+
| Inputs | 3D chest CT scans and/or radiology reports |
|
| 16 |
+
| Context length | N/A (not an LLM) |
|
| 17 |
+
| Outputs | Image and/or text embeddings, not directly human-interpretable. |
|
| 18 |
+
| GPUs | 4 x A100 GPUs |
|
| 19 |
+
| Training time | 1.75 Days |
|
| 20 |
+
| Public data summary (or summaries) | Training Data: National Lung Screening Trial (NLST) dataset: Chest CT images from approximately 26k patients CT-RATE dataset: Approximately 25.7k CT acquisitions with associate reports |
|
| 21 |
+
| Training dates | Dates of training: August 2024. Intended model weight release date: Dec 5, 2025 |
|
| 22 |
+
| Status | Static |
|
| 23 |
+
| Release date | Intended release date: Dec 5, 2025 |
|
| 24 |
+
| Release date in the EU (if different) | Intended release date: Dec 5, 2025 |
|
| 25 |
+
| License | Free/open source |
|
| 26 |
+
| Model dependencies | COLIPRI’s text decoder is fine-tuned from microsoft/BiomedVLP-CXR-BERT-specialized · Hugging Face. |
|
| 27 |
+
| List and link to any additional related assets | N/A |
|
| 28 |
+
| Acceptable use policy | N/A |
|
| 29 |
+
|
| 30 |
+
## Model overview
|
| 31 |
+
|
| 32 |
+
COLIPRI is an encoder designed for 3D CT scans. Its architecture combines a 3D vision encoder (PRIMUS-M transformer) with a biomedical text encoder (BiomedVLP-CXR-BERT) and enhances existing models by adding additional training objectives and optimizations. These include the inclusion of new training objectives and optimizations to the CLIP paradigm to account for limited data availability. This multi-objective training approach optimizes both global semantic alignment and dense feature learning. After pre-training, the model is evaluated comprehensively on downstream medical imaging tasks such as classification, retrieval, and semantic segmentation.
|
| 33 |
+
COLIPRI is not a generative model. It cannot generate any images or text.
|
| 34 |
+
|
| 35 |
+
### Alignment approach
|
| 36 |
+
|
| 37 |
+
N/A
|
| 38 |
+
|
| 39 |
+
## Usage
|
| 40 |
+
|
| 41 |
+
### Primary use cases
|
| 42 |
+
|
| 43 |
+
COLIPRI is primarily intended for 3D medical imaging tasks such as segmentation and classification. It is intended for research only.
|
| 44 |
+
|
| 45 |
+
### Out-of-scope use cases
|
| 46 |
+
|
| 47 |
+
The COLIPRI model family was trained on 3D CT volumes and is therefore best suited for inputs of this type. It is intended for research only. The model is for research use only; clinical or deployed use cases are out of scope.
|
| 48 |
+
|
| 49 |
+
### Distribution channels
|
| 50 |
+
|
| 51 |
+
The model will be released with open access via Hugging Face.
|
| 52 |
+
|
| 53 |
+
### Input formats
|
| 54 |
+
|
| 55 |
+
The vision encoder takes 3D chest CT scans, and the text encoder takes radiology reports.
|
| 56 |
+
|
| 57 |
+
### Technical requirements and integration guidance
|
| 58 |
+
|
| 59 |
+
The model requires a Python-based environment with PyTorch. The only hardware requirement is one GPU.
|
| 60 |
+
|
| 61 |
+
### Responsible AI considerations
|
| 62 |
+
|
| 63 |
+
The model may underperform or produce unreliable outputs for non-English languages, out-of-distribution or adversarial inputs, or clinical scenarios not well represented in the training data. The model should be used in an AI-assisted research setup with human oversight, rather than as a sole diagnostic tool.
|
| 64 |
+
|
| 65 |
+
## Data overview
|
| 66 |
+
|
| 67 |
+
### Training, testing, and validation datasets
|
| 68 |
+
|
| 69 |
+
The COLIPRI model family was pre-trained on the CT-RATE dataset, consisting of approximately 25.7k CT acquisitions with associated reports, and the National Lung Screening Trial (NLST) dataset, consisting of chest CT images from approximately 26k patients.
|
| 70 |
+
Classification performance was evaluated on a withheld test set of CT-RATE and the publicly available subset of RAD-ChestCT, which comprises 3.6k chest CT volumes with 16 multi-abnormality labels. Semantic segmentation was evaluated by training a five-fold cross validation on four datasets:
|
| 71 |
+
|
| 72 |
+
- LiTS: Task 3 of the Medical Segmentation Decathlon (MSD), containing segmentations for liver and liver tumors
|
| 73 |
+
- Lung: Task 6 of MSD, containing segmentations of primary lung cancers
|
| 74 |
+
- HVS: Task 8 of the MSD, containing segmentations of hepatic vessels and tumors next to such vessels
|
| 75 |
+
- KiTS23: containing segmentations of tumors, cysts, and the kidney
|
| 76 |
+
|
| 77 |
+
## Quality and performance evaluation
|
| 78 |
+
|
| 79 |
+
### Results summary
|
| 80 |
+
|
| 81 |
+
#### Classification Probes
|
| 82 |
+
|
| 83 |
+
The COLIPRI family exceeded all third-party models on both the CT-RATE and RAD-ChestCT datasets across all metrics.
|
| 84 |
+
|
| 85 |
+
#### Zero-Shot Classification
|
| 86 |
+
|
| 87 |
+
COLIPRI encoders exceed the state of the art using short prompts on the CT-RATE and RAD-ChestCT datasets.
|
| 88 |
+
|
| 89 |
+
#### Report-to-Image Retrieval
|
| 90 |
+
|
| 91 |
+
COLIPRI encoders are substantially better in retrieving the associated image given a report, evaluated on the CT-RATE test set.
|
| 92 |
+
|
| 93 |
+
#### Segmentation
|
| 94 |
+
|
| 95 |
+
Across four datasets (LiTs, Lung, HVS, KiTS23), COLIPRI-CM and COLIPRI-CRM perform on-par or better than the state of the art.
|
| 96 |
+
|
| 97 |
+
#### Qualitative Analysis
|
| 98 |
+
|
| 99 |
+
COLIPRI embeddings generates sharper and more coherent features than third-party models.
|
| 100 |
+
|
| 101 |
+
### Long context
|
| 102 |
+
|
| 103 |
+
N/A – not an LLM.
|
| 104 |
+
|
| 105 |
+
### Safety evaluation and red-teaming
|
| 106 |
+
|
| 107 |
+
N/A - model is limited to the medical imaging domain. Model outputs were compared to ground-truth images. The standard quality metrics were computed.
|
| 108 |
+
|
| 109 |
+
## Tracked capability evaluations
|
| 110 |
+
|
| 111 |
+
N/A This model is not a frontier model.
|
| 112 |
+
|
| 113 |
+
## Additional information
|
| 114 |
+
|
| 115 |
+
Requests for additional information may be directed to [MSFTAIActRequest@microsoft.com](mailto:MSFTAIActRequest@microsoft.com).
|
pyproject.toml
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "colipri"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Vision-language encoder for chest CT scans and reports."
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
authors = [
|
| 7 |
+
{ name = "Microsoft Health Futures", email = "innereyedev@microsoft.com" },
|
| 8 |
+
]
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"accelerate",
|
| 12 |
+
"dynamic-network-architectures@git+https://github.com/fepegar/dynamic-network-architectures.git@88ad3c9",
|
| 13 |
+
"huggingface-hub",
|
| 14 |
+
"hydra-core",
|
| 15 |
+
"jaxtyping",
|
| 16 |
+
"safetensors",
|
| 17 |
+
"timm<1.0.23", # RuntimeError related to position embeddings in 1.0.23
|
| 18 |
+
"torchio>=0.21.1",
|
| 19 |
+
"tqdm",
|
| 20 |
+
"transformers",
|
| 21 |
+
"typer",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
[project.urls]
|
| 25 |
+
Homepage = "https://huggingface.co/microsoft/colipri"
|
| 26 |
+
Source = "https://huggingface.co/microsoft/colipri"
|
| 27 |
+
"Issue tracker" = "https://huggingface.co/microsoft/colipri/discussions/new"
|
| 28 |
+
Documentation = "https://huggingface.co/microsoft/colipri/blob/main/README.md"
|
| 29 |
+
|
| 30 |
+
[project.optional-dependencies]
|
| 31 |
+
demo = [
|
| 32 |
+
"lovely-tensors",
|
| 33 |
+
"matplotlib",
|
| 34 |
+
"scikit-learn",
|
| 35 |
+
"torchinfo",
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
[dependency-groups]
|
| 39 |
+
dev = [
|
| 40 |
+
"ipykernel",
|
| 41 |
+
"ipywidgets",
|
| 42 |
+
]
|
| 43 |
+
types = ["ty"]
|
| 44 |
+
|
| 45 |
+
[build-system]
|
| 46 |
+
requires = ["uv_build"]
|
| 47 |
+
build-backend = "uv_build"
|
| 48 |
+
|
| 49 |
+
[tool.ruff.lint]
|
| 50 |
+
# Defaults from https://docs.astral.sh/ruff/linter/#rule-selection
|
| 51 |
+
select = [
|
| 52 |
+
# pycodestyle
|
| 53 |
+
"E",
|
| 54 |
+
# Pyflakes
|
| 55 |
+
"F",
|
| 56 |
+
# pyupgrade
|
| 57 |
+
"UP",
|
| 58 |
+
# flake8-bugbear
|
| 59 |
+
"B",
|
| 60 |
+
# flake8-simplify
|
| 61 |
+
"SIM",
|
| 62 |
+
# isort
|
| 63 |
+
"I",
|
| 64 |
+
]
|
| 65 |
+
ignore = [
|
| 66 |
+
"F722", # https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
[tool.ruff.lint.isort]
|
| 70 |
+
force-single-line = true
|
src/colipri/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .model.multimodal import get_model
|
| 2 |
+
from .pipelines import ZeroShotImageClassificationPipeline
|
| 3 |
+
from .processor import get_processor
|
| 4 |
+
from .sample_data import load_sample_ct
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"get_model",
|
| 8 |
+
"get_processor",
|
| 9 |
+
"load_sample_ct",
|
| 10 |
+
"ZeroShotImageClassificationPipeline",
|
| 11 |
+
]
|
src/colipri/checkpoint.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
from huggingface_hub import hf_hub_download
|
| 4 |
+
from hydra import compose
|
| 5 |
+
from hydra import initialize_config_dir
|
| 6 |
+
from omegaconf import DictConfig
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
from .defaults import REPO_NAME
|
| 10 |
+
from .defaults import REPO_USER
|
| 11 |
+
from .defaults import REVISION
|
| 12 |
+
from .defaults import ROOT_CONFIG_FILENAME
|
| 13 |
+
from .defaults import WEIGHTS_FILENAME
|
| 14 |
+
from .defaults import get_configs_dir
|
| 15 |
+
|
| 16 |
+
TypeStateDict = dict[str, Tensor]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def download_weights(**kwargs) -> Path:
|
| 20 |
+
weights_path = _download_from_hugging_face(**kwargs)
|
| 21 |
+
return weights_path
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _download_from_hugging_face(
|
| 25 |
+
repo_user: str = REPO_USER,
|
| 26 |
+
repo_name: str = REPO_NAME,
|
| 27 |
+
revision: str | None = REVISION,
|
| 28 |
+
filename: str = WEIGHTS_FILENAME,
|
| 29 |
+
) -> Path:
|
| 30 |
+
repo_id = f"{repo_user}/{repo_name}"
|
| 31 |
+
try:
|
| 32 |
+
weights_path = hf_hub_download(
|
| 33 |
+
repo_id=repo_id,
|
| 34 |
+
filename=filename,
|
| 35 |
+
revision=revision,
|
| 36 |
+
)
|
| 37 |
+
except Exception as e:
|
| 38 |
+
msg = f'Failed to download "{filename}" from Hugging Face Hub repo "{repo_id}".'
|
| 39 |
+
raise RuntimeError(msg) from e
|
| 40 |
+
return Path(weights_path)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _load_config(**kwargs) -> DictConfig:
|
| 44 |
+
with initialize_config_dir(str(get_configs_dir()), version_base=None):
|
| 45 |
+
config = compose(ROOT_CONFIG_FILENAME, **kwargs)
|
| 46 |
+
return config
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_model_config(**kwargs) -> DictConfig:
|
| 50 |
+
return _load_config(**kwargs)["model"]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_processor_config(**kwargs) -> DictConfig:
|
| 54 |
+
return _load_config(**kwargs)["processor"]
|
src/colipri/configs/config.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- model: default
|
| 3 |
+
- processor: default
|
| 4 |
+
- _self_
|
| 5 |
+
|
| 6 |
+
input_size: 192
|
| 7 |
+
spacing: 2 # mm
|
| 8 |
+
padding_mode: minimum
|
| 9 |
+
image_embed_dim: 864
|
| 10 |
+
projection_dim: 768
|
src/colipri/configs/model/default.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- image_encoder: default
|
| 3 |
+
- text_encoder: default
|
| 4 |
+
|
| 5 |
+
_target_: colipri.model.multimodal.Model
|
src/colipri/configs/model/image_encoder/backbone/default.yaml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: dynamic_network_architectures.architectures.primus.Primus
|
| 2 |
+
input_channels: 1
|
| 3 |
+
num_classes: 1
|
| 4 |
+
eva_depth: 16
|
| 5 |
+
eva_numheads: 12
|
| 6 |
+
embed_dim: ${image_embed_dim}
|
| 7 |
+
patch_embed_size:
|
| 8 |
+
- 8
|
| 9 |
+
- 8
|
| 10 |
+
- 8
|
| 11 |
+
input_shape:
|
| 12 |
+
- ${input_size}
|
| 13 |
+
- ${input_size}
|
| 14 |
+
- ${input_size}
|
| 15 |
+
use_rot_pos_emb: True
|
| 16 |
+
use_abs_pos_embed: False
|
| 17 |
+
drop_path_rate: 0.2
|
| 18 |
+
init_values: 0.1
|
| 19 |
+
scale_attn_inner: True
|
| 20 |
+
num_register_tokens: 0
|
src/colipri/configs/model/image_encoder/default.yaml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- backbone: default
|
| 3 |
+
- pooler: default
|
| 4 |
+
- projector: default
|
| 5 |
+
|
| 6 |
+
_target_: colipri.model.image.ImageEncoder
|
src/colipri/configs/model/image_encoder/pooler/default.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: colipri.pooling.AttentionPool1D
|
| 2 |
+
embed_dim: ${projection_dim}
|
| 3 |
+
num_heads: 12
|
src/colipri/configs/model/image_encoder/projector/default.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: torch.nn.Conv3d
|
| 2 |
+
in_channels: ${image_embed_dim}
|
| 3 |
+
out_channels: ${projection_dim}
|
| 4 |
+
kernel_size: 1
|
src/colipri/configs/model/text_encoder/backbone/default.yaml
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: transformers.AutoModel.from_pretrained
|
| 2 |
+
pretrained_model_name_or_path: microsoft/BiomedVLP-CXR-BERT-specialized
|
| 3 |
+
trust_remote_code: yes
|
| 4 |
+
revision: 5157bdba1437a3aed316dacb1a5b68edf96b9902
|
src/colipri/configs/model/text_encoder/default.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- backbone: default
|
| 3 |
+
- pooler: default
|
| 4 |
+
|
| 5 |
+
_target_: colipri.model.text.TextEncoder
|
src/colipri/configs/model/text_encoder/pooler/default.yaml
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: colipri.pooling.AttentionPool1D
|
| 2 |
+
embed_dim: ${projection_dim}
|
| 3 |
+
num_heads: 12
|
src/colipri/configs/processor/default.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
defaults:
|
| 2 |
+
- image_transform: default
|
| 3 |
+
- tokenizer: default
|
| 4 |
+
|
| 5 |
+
_target_: colipri.processor.Processor
|
src/colipri/configs/processor/image_transform/default.yaml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Exported automatically from nnssl: https://github.com/microsoft/nnssl/pull/41, then
|
| 2 |
+
# adapted for Colipri needs.
|
| 3 |
+
_target_: torchio.transforms.augmentation.composition.Compose
|
| 4 |
+
transforms:
|
| 5 |
+
- _target_: torchio.transforms.preprocessing.spatial.to_orientation.ToOrientation
|
| 6 |
+
copy: true
|
| 7 |
+
exclude: null
|
| 8 |
+
include: null
|
| 9 |
+
orientation: SAR
|
| 10 |
+
- _target_: torchio.transforms.preprocessing.spatial.resample.Resample
|
| 11 |
+
antialias: false
|
| 12 |
+
copy: true
|
| 13 |
+
exclude: null
|
| 14 |
+
image_interpolation: linear
|
| 15 |
+
include: null
|
| 16 |
+
label_interpolation: nearest
|
| 17 |
+
pre_affine_name: null
|
| 18 |
+
scalars_only: false
|
| 19 |
+
target: ${spacing}
|
| 20 |
+
- _target_: torchio.transforms.preprocessing.intensity.clamp.Clamp
|
| 21 |
+
out_min: -1000
|
| 22 |
+
out_max: 1000
|
| 23 |
+
- _target_: torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity
|
| 24 |
+
copy: true
|
| 25 |
+
exclude: null
|
| 26 |
+
in_min_max:
|
| 27 |
+
- -1000
|
| 28 |
+
- 1000
|
| 29 |
+
include: null
|
| 30 |
+
masking_method: null
|
| 31 |
+
out_min_max:
|
| 32 |
+
- -1
|
| 33 |
+
- 1
|
| 34 |
+
percentiles:
|
| 35 |
+
- 0
|
| 36 |
+
- 100
|
| 37 |
+
- _target_: torchio.transforms.preprocessing.spatial.crop_or_pad.CropOrPad
|
| 38 |
+
copy: true
|
| 39 |
+
exclude: null
|
| 40 |
+
include: null
|
| 41 |
+
labels: null
|
| 42 |
+
mask_name: null
|
| 43 |
+
only_crop: false
|
| 44 |
+
only_pad: false
|
| 45 |
+
padding_mode: ${padding_mode}
|
| 46 |
+
target_shape:
|
| 47 |
+
- ${input_size}
|
| 48 |
+
- ${input_size}
|
| 49 |
+
- ${input_size}
|
src/colipri/configs/processor/tokenizer/default.yaml
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
_target_: transformers.AutoTokenizer.from_pretrained
|
| 2 |
+
pretrained_model_name_or_path: microsoft/BiomedVLP-CXR-BERT-specialized
|
| 3 |
+
trust_remote_code: yes
|
| 4 |
+
revision: 5157bdba1437a3aed316dacb1a5b68edf96b9902
|
| 5 |
+
model_max_length: 512
|
src/colipri/defaults.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import importlib.resources
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
|
| 4 |
+
package_name = __package__
|
| 5 |
+
assert isinstance(package_name, str)
|
| 6 |
+
|
| 7 |
+
REPO_USER = "microsoft"
|
| 8 |
+
REPO_NAME = package_name
|
| 9 |
+
REVISION = None
|
| 10 |
+
WEIGHTS_FILENAME = "model.safetensors"
|
| 11 |
+
ROOT_CONFIG_FILENAME = "config.yaml"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_configs_dir(pkg: str | None = None) -> Path:
|
| 15 |
+
package_dir = importlib.resources.files(package_name)
|
| 16 |
+
return Path(package_dir / "configs")
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
__all__ = [
|
| 20 |
+
"get_configs_dir",
|
| 21 |
+
]
|
src/colipri/model/__init__.py
ADDED
|
File without changes
|
src/colipri/model/base.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ModelMixin:
|
| 5 |
+
@property
|
| 6 |
+
def device(self) -> torch.device:
|
| 7 |
+
one_parameter = next(self.parameters())
|
| 8 |
+
return one_parameter.device
|
src/colipri/model/image.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import torchio as tio
|
| 6 |
+
from dynamic_network_architectures.architectures.primus import Primus
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
from torch import nn
|
| 9 |
+
from tqdm.auto import tqdm
|
| 10 |
+
|
| 11 |
+
from ..pooling import AttentionPool1D
|
| 12 |
+
from ..types import TypeImageEmbeddings
|
| 13 |
+
from ..types import TypeImagesTensor
|
| 14 |
+
from ..types import TypeIntOrTripletInt
|
| 15 |
+
from .base import ModelMixin
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ImageEncoder(nn.Module, ModelMixin):
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
backbone: Primus,
|
| 22 |
+
projector: nn.Conv3d,
|
| 23 |
+
pooler: AttentionPool1D,
|
| 24 |
+
):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.backbone = self._remove_decoder(backbone)
|
| 27 |
+
self.projector = projector
|
| 28 |
+
self.pooler = pooler
|
| 29 |
+
|
| 30 |
+
def _remove_decoder(self, backbone: Primus) -> Primus:
|
| 31 |
+
if hasattr(backbone, "up_projection"):
|
| 32 |
+
backbone.up_projection = nn.Identity() # type: ignore
|
| 33 |
+
return backbone
|
| 34 |
+
|
| 35 |
+
@property
|
| 36 |
+
def patch_size(self) -> tuple[int, int, int]:
|
| 37 |
+
patch_embedder: nn.Conv3d = self.backbone.down_projection.proj # type: ignore[reportAssignmentType]
|
| 38 |
+
patch_size: tuple[int, int, int] = patch_embedder.stride # type: ignore[reportAssignmentType]
|
| 39 |
+
return patch_size
|
| 40 |
+
|
| 41 |
+
def encode(
|
| 42 |
+
self,
|
| 43 |
+
images: TypeImagesTensor,
|
| 44 |
+
) -> TypeImageEmbeddings:
|
| 45 |
+
images = images.to(self.device)
|
| 46 |
+
embeddings: TypeImageEmbeddings = self.backbone(images)
|
| 47 |
+
return embeddings
|
| 48 |
+
|
| 49 |
+
def encode_sliding_window(
|
| 50 |
+
self,
|
| 51 |
+
images: TypeImagesTensor,
|
| 52 |
+
window_size: TypeIntOrTripletInt,
|
| 53 |
+
overlap: int = 0,
|
| 54 |
+
) -> TypeImageEmbeddings:
|
| 55 |
+
if len(set(self.patch_size)) > 1:
|
| 56 |
+
msg = (
|
| 57 |
+
"Sliding window encoding is only supported for models with cubic"
|
| 58 |
+
" patch sizes for now."
|
| 59 |
+
)
|
| 60 |
+
raise NotImplementedError(msg)
|
| 61 |
+
else:
|
| 62 |
+
patch_size = self.patch_size[0]
|
| 63 |
+
image_key = "image" # could be anything
|
| 64 |
+
embeddings = []
|
| 65 |
+
for image in images:
|
| 66 |
+
grid_sampler = tio.inference.GridSampler(
|
| 67 |
+
tio.Subject(**{image_key: tio.ScalarImage(tensor=image)}),
|
| 68 |
+
window_size,
|
| 69 |
+
overlap,
|
| 70 |
+
)
|
| 71 |
+
patch_loader = tio.SubjectsLoader(grid_sampler) # type: ignore[reportArgumentType]
|
| 72 |
+
aggregator = tio.data.GridAggregator(
|
| 73 |
+
grid_sampler,
|
| 74 |
+
downsampling_factor=patch_size,
|
| 75 |
+
)
|
| 76 |
+
for patches_batch in tqdm(patch_loader):
|
| 77 |
+
input_tensor = patches_batch[image_key][tio.DATA].to(self.device)
|
| 78 |
+
locations = patches_batch[tio.LOCATION]
|
| 79 |
+
outputs = self.backbone(input_tensor)
|
| 80 |
+
aggregator.add_batch(outputs, locations)
|
| 81 |
+
embeddings.append(aggregator.get_output_tensor())
|
| 82 |
+
return torch.stack(embeddings).to(self.device)
|
| 83 |
+
|
| 84 |
+
def forward(
|
| 85 |
+
self,
|
| 86 |
+
images: TypeImagesTensor,
|
| 87 |
+
*,
|
| 88 |
+
project: bool,
|
| 89 |
+
pool: bool,
|
| 90 |
+
normalize: bool,
|
| 91 |
+
window_size: TypeIntOrTripletInt | None = None,
|
| 92 |
+
) -> TypeImageEmbeddings:
|
| 93 |
+
if pool and not project:
|
| 94 |
+
msg = "Pooling requires projection to be enabled. Set project=True."
|
| 95 |
+
raise NotImplementedError(msg)
|
| 96 |
+
if window_size is None:
|
| 97 |
+
embeddings = self.encode(images)
|
| 98 |
+
else:
|
| 99 |
+
embeddings = self.encode_sliding_window(images, window_size)
|
| 100 |
+
if project:
|
| 101 |
+
embeddings = self.projector(embeddings)
|
| 102 |
+
if pool:
|
| 103 |
+
sequence = rearrange(embeddings, "b c x y z -> b (x y z) c")
|
| 104 |
+
embeddings = self.pooler(sequence)
|
| 105 |
+
else:
|
| 106 |
+
embeddings = self.pooler.to_dense()(embeddings)
|
| 107 |
+
if normalize:
|
| 108 |
+
embeddings = F.normalize(embeddings, dim=1)
|
| 109 |
+
return embeddings
|
src/colipri/model/multimodal.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from accelerate import init_empty_weights
|
| 8 |
+
from accelerate import load_checkpoint_and_dispatch
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
from hydra.utils import instantiate
|
| 11 |
+
from safetensors.torch import load_model
|
| 12 |
+
from safetensors.torch import save_model
|
| 13 |
+
from torch import nn
|
| 14 |
+
from transformers.utils.logging import get_verbosity
|
| 15 |
+
from transformers.utils.logging import set_verbosity
|
| 16 |
+
from transformers.utils.logging import set_verbosity_error
|
| 17 |
+
|
| 18 |
+
from ..checkpoint import download_weights
|
| 19 |
+
from ..checkpoint import load_model_config
|
| 20 |
+
from ..types import TypeImageEmbeddings
|
| 21 |
+
from ..types import TypeImagesTensor
|
| 22 |
+
from ..types import TypeIntOrTripletInt
|
| 23 |
+
from ..types import TypePatchEmbeddings
|
| 24 |
+
from ..types import TypePatchLogits
|
| 25 |
+
from ..types import TypePatchProbabilities
|
| 26 |
+
from ..types import TypePath
|
| 27 |
+
from ..types import TypePooledEmbeddings
|
| 28 |
+
from ..types import TypePooledLogits
|
| 29 |
+
from ..types import TypePooledProbabilities
|
| 30 |
+
from ..types import TypeTextAttentionMask
|
| 31 |
+
from ..types import TypeTextEmbeddings
|
| 32 |
+
from ..types import TypeTextTokenIds
|
| 33 |
+
from .base import ModelMixin
|
| 34 |
+
from .image import ImageEncoder
|
| 35 |
+
from .text import TextEncoder
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_model(
|
| 39 |
+
checkpoint_path: TypePath | None = None,
|
| 40 |
+
*,
|
| 41 |
+
pretrained: bool = True,
|
| 42 |
+
image_only: bool = False,
|
| 43 |
+
**kwargs,
|
| 44 |
+
) -> Model:
|
| 45 |
+
if pretrained and checkpoint_path is None:
|
| 46 |
+
checkpoint_path = download_weights()
|
| 47 |
+
|
| 48 |
+
overrides = []
|
| 49 |
+
for key, value in kwargs.items():
|
| 50 |
+
overrides.append(f"{key}={value}")
|
| 51 |
+
config = load_model_config(overrides=overrides)
|
| 52 |
+
|
| 53 |
+
if image_only:
|
| 54 |
+
config.text_encoder = None
|
| 55 |
+
|
| 56 |
+
transformers_verbosity = get_verbosity()
|
| 57 |
+
set_verbosity_error()
|
| 58 |
+
if checkpoint_path is None:
|
| 59 |
+
model = instantiate(config)
|
| 60 |
+
else:
|
| 61 |
+
with init_empty_weights():
|
| 62 |
+
model = instantiate(config)
|
| 63 |
+
accelerate_logger = logging.getLogger("accelerate.utils.modeling")
|
| 64 |
+
old_level = accelerate_logger.getEffectiveLevel()
|
| 65 |
+
accelerate_logger.setLevel(logging.ERROR)
|
| 66 |
+
model = load_checkpoint_and_dispatch(model, checkpoint_path)
|
| 67 |
+
accelerate_logger.setLevel(old_level)
|
| 68 |
+
set_verbosity(transformers_verbosity)
|
| 69 |
+
|
| 70 |
+
assert isinstance(model, Model)
|
| 71 |
+
return model.eval()
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class Model(nn.Module, ModelMixin):
|
| 75 |
+
def __init__(
|
| 76 |
+
self,
|
| 77 |
+
image_encoder: ImageEncoder,
|
| 78 |
+
text_encoder: TextEncoder,
|
| 79 |
+
temperature: float = 1,
|
| 80 |
+
):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.image_encoder = image_encoder
|
| 83 |
+
self.text_encoder = text_encoder
|
| 84 |
+
self.register_buffer("softmax_temperature", torch.tensor(temperature))
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def patch_size(self) -> tuple[int, int, int]:
|
| 88 |
+
return self.image_encoder.patch_size
|
| 89 |
+
|
| 90 |
+
def encode_image(
|
| 91 |
+
self,
|
| 92 |
+
images: TypeImagesTensor,
|
| 93 |
+
*,
|
| 94 |
+
pool: bool = False,
|
| 95 |
+
project: bool = False,
|
| 96 |
+
normalize: bool = False,
|
| 97 |
+
window_size: TypeIntOrTripletInt | None = None,
|
| 98 |
+
) -> TypeImageEmbeddings:
|
| 99 |
+
return self.image_encoder(
|
| 100 |
+
images,
|
| 101 |
+
project=project,
|
| 102 |
+
pool=pool,
|
| 103 |
+
normalize=normalize,
|
| 104 |
+
window_size=window_size,
|
| 105 |
+
)
|
| 106 |
+
|
| 107 |
+
def encode_text(
|
| 108 |
+
self,
|
| 109 |
+
token_ids: TypeTextTokenIds,
|
| 110 |
+
attention_mask: TypeTextAttentionMask | None = None,
|
| 111 |
+
*,
|
| 112 |
+
pool: bool = True,
|
| 113 |
+
normalize: bool = True,
|
| 114 |
+
) -> TypeTextEmbeddings:
|
| 115 |
+
return self.text_encoder(
|
| 116 |
+
token_ids,
|
| 117 |
+
attention_mask,
|
| 118 |
+
pool=pool,
|
| 119 |
+
normalize=normalize,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
def compute_similarities(
|
| 123 |
+
self,
|
| 124 |
+
image_embeddings: TypePooledEmbeddings | TypePatchEmbeddings,
|
| 125 |
+
text_embeddings: TypePooledEmbeddings,
|
| 126 |
+
) -> TypePatchLogits | TypePooledLogits:
|
| 127 |
+
text_embeddings = rearrange(text_embeddings, "num_prompts c -> c num_prompts")
|
| 128 |
+
is_grid = image_embeddings.ndim == 5
|
| 129 |
+
if is_grid:
|
| 130 |
+
num_images, _, x, y, z = image_embeddings.shape
|
| 131 |
+
image_embeddings = rearrange(
|
| 132 |
+
image_embeddings,
|
| 133 |
+
"num_images c x y z -> (num_images x y z) c",
|
| 134 |
+
)
|
| 135 |
+
similarities_flat = image_embeddings @ text_embeddings
|
| 136 |
+
similarities = rearrange(
|
| 137 |
+
similarities_flat,
|
| 138 |
+
"(num_images x y z) num_prompts -> num_images num_prompts x y z",
|
| 139 |
+
num_images=num_images,
|
| 140 |
+
x=x,
|
| 141 |
+
y=y,
|
| 142 |
+
z=z,
|
| 143 |
+
)
|
| 144 |
+
else:
|
| 145 |
+
similarities = image_embeddings @ text_embeddings
|
| 146 |
+
return similarities
|
| 147 |
+
|
| 148 |
+
def classify(
|
| 149 |
+
self,
|
| 150 |
+
images: TypePooledEmbeddings | TypePatchEmbeddings,
|
| 151 |
+
text: TypePooledEmbeddings,
|
| 152 |
+
) -> TypePooledProbabilities | TypePatchProbabilities:
|
| 153 |
+
logits = self.compute_similarities(images, text)
|
| 154 |
+
assert isinstance(self.softmax_temperature, torch.Tensor)
|
| 155 |
+
probabilities = (logits / self.softmax_temperature).softmax(dim=-1)
|
| 156 |
+
return probabilities
|
| 157 |
+
|
| 158 |
+
def save_weights(self, path: TypePath) -> None:
|
| 159 |
+
path = Path(path)
|
| 160 |
+
if path.suffix == ".safetensors":
|
| 161 |
+
save_model(self, str(path))
|
| 162 |
+
else:
|
| 163 |
+
weights = self.state_dict()
|
| 164 |
+
torch.save(weights, path)
|
| 165 |
+
|
| 166 |
+
def load_weights(
|
| 167 |
+
self,
|
| 168 |
+
path: TypePath,
|
| 169 |
+
device: torch.device | str | None = None,
|
| 170 |
+
) -> None:
|
| 171 |
+
path = Path(path)
|
| 172 |
+
if not path.exists():
|
| 173 |
+
raise FileNotFoundError(f"Checkpoint file {path} does not exist.")
|
| 174 |
+
if device is None:
|
| 175 |
+
device = torch.device("cpu")
|
| 176 |
+
if path.suffix == ".safetensors":
|
| 177 |
+
if device is not None:
|
| 178 |
+
device = str(device)
|
| 179 |
+
missing, unexpected = load_model(self, path, device=device)
|
| 180 |
+
if missing or unexpected:
|
| 181 |
+
raise RuntimeError("TODO")
|
| 182 |
+
else:
|
| 183 |
+
weights = torch.load(path, map_location=device)
|
| 184 |
+
self.load_state_dict(weights, strict=False)
|
src/colipri/model/text.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch import nn
|
| 5 |
+
from transformers import BertModel
|
| 6 |
+
|
| 7 |
+
from ..types import TypeTextAttentionMask
|
| 8 |
+
from ..types import TypeTextEmbeddings
|
| 9 |
+
from ..types import TypeTextTokenIds
|
| 10 |
+
from .base import ModelMixin
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class TextEncoder(nn.Module, ModelMixin):
|
| 14 |
+
def __init__(self, backbone: nn.Module, pooler: nn.Module):
|
| 15 |
+
super().__init__()
|
| 16 |
+
if hasattr(backbone, "bert"):
|
| 17 |
+
bert = backbone.bert
|
| 18 |
+
assert isinstance(bert, BertModel)
|
| 19 |
+
bert.pooler = None # we use our own pooler
|
| 20 |
+
backbone = bert
|
| 21 |
+
self.backbone = backbone
|
| 22 |
+
self.pooler = pooler
|
| 23 |
+
|
| 24 |
+
def forward(
|
| 25 |
+
self,
|
| 26 |
+
token_ids: TypeTextTokenIds,
|
| 27 |
+
attention_mask: TypeTextAttentionMask | None = None,
|
| 28 |
+
*,
|
| 29 |
+
pool: bool = True,
|
| 30 |
+
normalize: bool = True,
|
| 31 |
+
) -> TypeTextEmbeddings:
|
| 32 |
+
token_ids = token_ids.to(self.device)
|
| 33 |
+
if attention_mask is not None:
|
| 34 |
+
attention_mask = attention_mask.to(self.device)
|
| 35 |
+
# We didn't use a projector during pre-training
|
| 36 |
+
output = self.backbone(token_ids, attention_mask=attention_mask)
|
| 37 |
+
embeddings = output["last_hidden_state"]
|
| 38 |
+
if pool:
|
| 39 |
+
embeddings = self.pooler(embeddings)
|
| 40 |
+
if normalize:
|
| 41 |
+
embeddings = F.normalize(embeddings, dim=1)
|
| 42 |
+
return embeddings
|
src/colipri/pipelines.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from .model.multimodal import Model
|
| 4 |
+
from .processor import Processor
|
| 5 |
+
from .types import TypeImage
|
| 6 |
+
from .types import TypePath
|
| 7 |
+
from .types import TypeScores
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ZeroShotImageClassificationPipeline:
|
| 11 |
+
def __init__(
|
| 12 |
+
self,
|
| 13 |
+
model: Model,
|
| 14 |
+
processor: Processor,
|
| 15 |
+
batch_size: int = 1, # TODO
|
| 16 |
+
num_workers: int = 0, # TODO
|
| 17 |
+
):
|
| 18 |
+
self._model = model.eval()
|
| 19 |
+
self._processor = processor
|
| 20 |
+
|
| 21 |
+
@property
|
| 22 |
+
def device(self) -> torch.device:
|
| 23 |
+
return self._model.device
|
| 24 |
+
|
| 25 |
+
# TODO: add support for multiple prompts per class
|
| 26 |
+
@torch.no_grad()
|
| 27 |
+
def __call__(
|
| 28 |
+
self,
|
| 29 |
+
images: TypePath | TypeImage | list[TypePath | TypeImage],
|
| 30 |
+
prompts: str | list[str],
|
| 31 |
+
) -> TypeScores | list[TypeScores]:
|
| 32 |
+
preprocessed_images = self._processor.process_images(images)
|
| 33 |
+
images_batch = self._processor.to_images_batch(preprocessed_images)
|
| 34 |
+
|
| 35 |
+
text_token_ids, text_attention_mask = self._processor.process_text(prompts)
|
| 36 |
+
|
| 37 |
+
image_embeddings_batch = self._model.encode_image(
|
| 38 |
+
images_batch,
|
| 39 |
+
project=True,
|
| 40 |
+
pool=True,
|
| 41 |
+
normalize=True,
|
| 42 |
+
)
|
| 43 |
+
text_embeddings_batch = self._model.encode_text(
|
| 44 |
+
text_token_ids,
|
| 45 |
+
text_attention_mask,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
probabilities = self._model.classify(
|
| 49 |
+
image_embeddings_batch,
|
| 50 |
+
text_embeddings_batch,
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
if not isinstance(prompts, list):
|
| 54 |
+
prompts = [prompts]
|
| 55 |
+
|
| 56 |
+
all_results = []
|
| 57 |
+
for image_probabilities in probabilities:
|
| 58 |
+
image_results = []
|
| 59 |
+
for prompt, score in zip(prompts, image_probabilities, strict=True):
|
| 60 |
+
image_results.append({"score": score.item(), "label": prompt})
|
| 61 |
+
all_results.append(image_results)
|
| 62 |
+
|
| 63 |
+
if len(images_batch) == 1 and not isinstance(images, list):
|
| 64 |
+
all_results = all_results[0]
|
| 65 |
+
return all_results
|
src/colipri/pooling.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
from .types import TypePooledEmbeddings
|
| 6 |
+
from .types import TypeSequenceEmbeddings
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class AttentionPool1D(nn.Module):
|
| 10 |
+
def __init__(self, embed_dim: int, num_heads: int):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
|
| 13 |
+
|
| 14 |
+
def forward(self, x: TypeSequenceEmbeddings) -> TypePooledEmbeddings:
|
| 15 |
+
query = x.mean(dim=1, keepdim=True)
|
| 16 |
+
key = value = x
|
| 17 |
+
pooled, _ = self.attn(query, key, value)
|
| 18 |
+
return rearrange(pooled, "batch 1 embed_dim -> batch embed_dim")
|
| 19 |
+
|
| 20 |
+
def to_dense(self):
|
| 21 |
+
v_proj_in_weight_qkv = self.get_parameter("attn.in_proj_weight")
|
| 22 |
+
v_proj_in_bias_qkv = self.get_parameter("attn.in_proj_bias")
|
| 23 |
+
v_proj_out_weight = self.get_parameter("attn.out_proj.weight")
|
| 24 |
+
v_proj_out_bias = self.get_parameter("attn.out_proj.bias")
|
| 25 |
+
dim = v_proj_in_weight_qkv.shape[0] // 3
|
| 26 |
+
v_proj_in_weight_v = v_proj_in_weight_qkv[2 * dim :]
|
| 27 |
+
v_proj_in_bias_v = v_proj_in_bias_qkv[2 * dim :]
|
| 28 |
+
|
| 29 |
+
value_projection = nn.Conv3d(
|
| 30 |
+
in_channels=dim,
|
| 31 |
+
out_channels=dim,
|
| 32 |
+
kernel_size=1,
|
| 33 |
+
)
|
| 34 |
+
value_projection.weight.data = rearrange(
|
| 35 |
+
v_proj_in_weight_v,
|
| 36 |
+
"c_out c_in -> c_out c_in 1 1 1",
|
| 37 |
+
)
|
| 38 |
+
assert value_projection.bias is not None
|
| 39 |
+
value_projection.bias.data = v_proj_in_bias_v
|
| 40 |
+
|
| 41 |
+
out_projection = nn.Conv3d(
|
| 42 |
+
in_channels=dim,
|
| 43 |
+
out_channels=dim,
|
| 44 |
+
kernel_size=1,
|
| 45 |
+
)
|
| 46 |
+
out_projection.weight.data = rearrange(
|
| 47 |
+
v_proj_out_weight,
|
| 48 |
+
"c_out c_in -> c_out c_in 1 1 1",
|
| 49 |
+
)
|
| 50 |
+
assert out_projection.bias is not None
|
| 51 |
+
out_projection.bias.data = v_proj_out_bias
|
| 52 |
+
|
| 53 |
+
return nn.Sequential(
|
| 54 |
+
value_projection,
|
| 55 |
+
out_projection,
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class MultiLearnedQueryAttentionPool1D(AttentionPool1D):
|
| 60 |
+
def __init__(self, embed_dim: int, num_heads: int):
|
| 61 |
+
super().__init__(embed_dim, num_heads)
|
| 62 |
+
# 4 Queries instead of 1
|
| 63 |
+
self.query = nn.Parameter(torch.randn(1, 4, embed_dim) / embed_dim**0.5)
|
| 64 |
+
|
| 65 |
+
def forward(self, x: TypeSequenceEmbeddings) -> TypePooledEmbeddings:
|
| 66 |
+
"""
|
| 67 |
+
x: [B, T, D] — sequence of token embeddings
|
| 68 |
+
returns: [B, D] — pooled representation
|
| 69 |
+
"""
|
| 70 |
+
B, T, D = x.shape
|
| 71 |
+
query = self.query.expand(B, -1, -1) # [B, 4, D]
|
| 72 |
+
pooled, _ = self.attn(query, x, x) # [B, 4, D]
|
| 73 |
+
# pooled: [4, B, D_out], want [B, D_out] by pooling over queries (mean)
|
| 74 |
+
return pooled.mean(dim=1) # [B, D]
|
src/colipri/processor.py
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torchio as tio
|
| 5 |
+
from hydra.utils import instantiate
|
| 6 |
+
from transformers import BertTokenizer
|
| 7 |
+
from transformers.tokenization_utils_base import BatchEncoding
|
| 8 |
+
|
| 9 |
+
from .checkpoint import load_processor_config
|
| 10 |
+
from .types import TypeImage
|
| 11 |
+
from .types import TypeImages
|
| 12 |
+
from .types import TypeImagesTensor
|
| 13 |
+
from .types import TypePath
|
| 14 |
+
from .types import TypeStringOrStrings
|
| 15 |
+
from .types import TypeTextAttentionMask
|
| 16 |
+
from .types import TypeTextTokenIds
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_processor(*, image_only: bool = False, **kwargs) -> Processor:
|
| 20 |
+
overrides = []
|
| 21 |
+
for key, value in kwargs.items():
|
| 22 |
+
overrides.append(f"{key}={value}")
|
| 23 |
+
config = load_processor_config(overrides=overrides)
|
| 24 |
+
if image_only:
|
| 25 |
+
config.tokenizer = None
|
| 26 |
+
return instantiate(config)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Processor:
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
image_transform: tio.Transform,
|
| 33 |
+
tokenizer: BertTokenizer,
|
| 34 |
+
):
|
| 35 |
+
self._image_transform = image_transform
|
| 36 |
+
self._text_tokenizer = tokenizer
|
| 37 |
+
|
| 38 |
+
def __repr__(self) -> str:
|
| 39 |
+
lines = [
|
| 40 |
+
f"{self.__class__.__name__}(",
|
| 41 |
+
f" image_transform={self._image_transform},",
|
| 42 |
+
f" text_tokenizer={self._text_tokenizer},",
|
| 43 |
+
")",
|
| 44 |
+
]
|
| 45 |
+
return "\n".join(lines)
|
| 46 |
+
|
| 47 |
+
def process_images(
|
| 48 |
+
self,
|
| 49 |
+
inputs: TypePath | TypeImage | list[TypePath | TypeImage],
|
| 50 |
+
) -> TypeImages:
|
| 51 |
+
if not isinstance(inputs, list):
|
| 52 |
+
inputs = [inputs]
|
| 53 |
+
images = []
|
| 54 |
+
for image_or_path in inputs:
|
| 55 |
+
is_image = isinstance(image_or_path, tio.ScalarImage)
|
| 56 |
+
if is_image:
|
| 57 |
+
image = image_or_path
|
| 58 |
+
else:
|
| 59 |
+
path = image_or_path
|
| 60 |
+
image = tio.ScalarImage(path)
|
| 61 |
+
images.append(image)
|
| 62 |
+
images = [self._image_transform(image) for image in images]
|
| 63 |
+
return images
|
| 64 |
+
|
| 65 |
+
def to_images_batch(
|
| 66 |
+
self,
|
| 67 |
+
images: TypeImages,
|
| 68 |
+
) -> TypeImagesTensor:
|
| 69 |
+
if not isinstance(images, list):
|
| 70 |
+
msg = f"Expected images to be a list, got {type(images)}"
|
| 71 |
+
raise TypeError(msg)
|
| 72 |
+
images_tensor = torch.stack([image.data for image in images])
|
| 73 |
+
return images_tensor
|
| 74 |
+
|
| 75 |
+
def process_text(
|
| 76 |
+
self,
|
| 77 |
+
text: TypeStringOrStrings,
|
| 78 |
+
**kwargs,
|
| 79 |
+
) -> tuple[TypeTextTokenIds, TypeTextAttentionMask]:
|
| 80 |
+
encodings: BatchEncoding = self._text_tokenizer(
|
| 81 |
+
text,
|
| 82 |
+
max_length=512, # TODO
|
| 83 |
+
padding="max_length",
|
| 84 |
+
truncation=True,
|
| 85 |
+
return_tensors="pt",
|
| 86 |
+
**kwargs,
|
| 87 |
+
)
|
| 88 |
+
token_ids: TypeTextTokenIds = encodings["input_ids"] # type: ignore[reportAssignmentType]
|
| 89 |
+
attention_mask: TypeTextAttentionMask = encodings["attention_mask"] # type: ignore[reportAssignmentType]
|
| 90 |
+
return token_ids, attention_mask
|
src/colipri/sample_data.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
|
| 3 |
+
import torchio as tio
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def load_sample_ct(clip: bool = True) -> tio.ScalarImage:
|
| 7 |
+
"""Get a sample CT image."""
|
| 8 |
+
ct = get_sample_ct()
|
| 9 |
+
if clip:
|
| 10 |
+
clamp = tio.Clamp(-1000, 1000)
|
| 11 |
+
ct = clamp(ct)
|
| 12 |
+
else:
|
| 13 |
+
ct.load()
|
| 14 |
+
return ct
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def get_sample_ct() -> tio.ScalarImage:
|
| 18 |
+
"""Download a sample CT image if not already present."""
|
| 19 |
+
return tio.datasets.Slicer("CTChest").CT_chest # type: ignore[reportAttributeAccessIssue]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def get_sample_ct_path() -> Path:
|
| 23 |
+
"""Get the path to a sample CT image."""
|
| 24 |
+
return get_sample_ct().path # type: ignore[reportReturnType]
|
src/colipri/types.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import TypeAlias
|
| 3 |
+
|
| 4 |
+
import torchio as tio
|
| 5 |
+
from jaxtyping import Float
|
| 6 |
+
from jaxtyping import Int64
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
|
| 9 |
+
TypePath: TypeAlias = str | os.PathLike[str]
|
| 10 |
+
TypePaths: TypeAlias = list[TypePath]
|
| 11 |
+
|
| 12 |
+
# Raw inputs
|
| 13 |
+
TypeImage: TypeAlias = tio.ScalarImage
|
| 14 |
+
TypeImages: TypeAlias = list[tio.ScalarImage]
|
| 15 |
+
TypeImageOrImages: TypeAlias = TypeImage | TypeImages
|
| 16 |
+
|
| 17 |
+
TypeString: TypeAlias = str
|
| 18 |
+
TypeStrings: TypeAlias = list[str]
|
| 19 |
+
TypeStringOrStrings: TypeAlias = TypeString | TypeStrings
|
| 20 |
+
|
| 21 |
+
TypeIntOrTripletInt: TypeAlias = int | tuple[int, int, int]
|
| 22 |
+
|
| 23 |
+
# Processed inputs
|
| 24 |
+
TypeImagesTensor: TypeAlias = Float[Tensor, "batch 1 x_in y_in z_in"]
|
| 25 |
+
|
| 26 |
+
TypeTextTokenIds: TypeAlias = Int64[Tensor, "batch text_length"]
|
| 27 |
+
TypeTextAttentionMask: TypeAlias = Int64[Tensor, "batch text_length"]
|
| 28 |
+
|
| 29 |
+
# Outputs
|
| 30 |
+
TypeSequenceEmbeddings: TypeAlias = Float[Tensor, "batch sequence_length embed_dim"]
|
| 31 |
+
TypePooledEmbeddings: TypeAlias = Float[Tensor, "batch embed_dim"]
|
| 32 |
+
|
| 33 |
+
TypePatchEmbeddings: TypeAlias = Float[Tensor, "batch embed_dim x_out y_out z_out"]
|
| 34 |
+
TypeImageEmbeddings: TypeAlias = TypePatchEmbeddings | TypePooledEmbeddings
|
| 35 |
+
|
| 36 |
+
TypeTextEmbeddings: TypeAlias = TypeSequenceEmbeddings | TypePooledEmbeddings
|
| 37 |
+
|
| 38 |
+
TypePatchLogits: TypeAlias = Float[Tensor, "num_images num_prompts x_out y_out z_out"]
|
| 39 |
+
TypePooledLogits: TypeAlias = Float[Tensor, "num_images num_prompts"]
|
| 40 |
+
TypePatchProbabilities: TypeAlias = TypePatchLogits
|
| 41 |
+
TypePooledProbabilities: TypeAlias = TypePooledLogits
|
| 42 |
+
TypeScores: TypeAlias = list[dict[str, float | str]]
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|