mmrech and fepegar committed on

Commit 4886d4e · 0 parents

Duplicate from microsoft/colipri

Co-authored-by: Fernando Pérez-García <fepegar@users.noreply.huggingface.co>

.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
+
+ .vscode/
.python-version ADDED
@@ -0,0 +1 @@
+ 3.13
COLIPRI_demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) Microsoft Corporation.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,176 @@
+ ---
+ license: mit
+ language:
+ - en
+ library_name: colipri
+ pipeline_tag: zero-shot-image-classification
+ ---
+
+ # COLIPRI
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+ COLIPRI is a 3D vision&ndash;language transformer model trained to encode chest CT scans and reports.
+
+ ## Model description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ COLIPRI was trained on tens of thousands of chest CT scans and reports, without any annotations, using multiple objectives to learn strong joint representations of 3D images and text.
+ The procedure is described in detail in our manuscript, [_Comprehensive language-image pre-training for 3D medical image understanding_](https://arxiv.org/abs/2510.15042) (Wald et al. 2026).
+
+ The weights shared here correspond to our best-performing model, COLIPRI-CRM.
+
+ - **Developed by:** Microsoft Health Futures
+ - **Model type:** 3D vision&ndash;language encoder
+ - **License:** [MIT](./LICENSE)
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ COLIPRI is shared for research purposes only.
+ It is **not meant to be used for clinical practice**.
+
+ The encoders can be plugged into other models, or used independently or jointly for many downstream tasks, such as:
+
+ - Image classification with text prompts
+ - Image clustering
+ - Text clustering
+ - Text-to-image retrieval
+ - Image-to-image retrieval
+ - Image-to-text retrieval
+ - Text-to-text retrieval
+ - Image classification with a classifier
+ - Text classification with a classifier
+ - Image segmentation with a decoder
+ - Report generation with a language decoder
+
+ Fine-tuning COLIPRI is typically not necessary to obtain good performance on downstream tasks.
+
+ ## Getting started
+
+ ### Installation
+
+ ```shell
+ pip install colipri
+ ```
+
+ ### Usage examples
+
+ Below we share some usage snippets to get started with COLIPRI.
+ A more complete [Jupyter notebook](./COLIPRI_demo.ipynb) is also available.
+
+ First, let's get a 3D chest CT we can use for demonstration.
+ The plotted slices intersect a lung nodule near the heart.
+
+ ```python
+ >>> from colipri import load_sample_ct
+ >>> image = load_sample_ct()
+ >>> image
+ ScalarImage(shape: (1, 512, 512, 139); spacing: (0.76, 0.76, 2.50); orientation: LPS+; dtype: torch.IntTensor; memory: 139.0 MiB)
+ ```
+
+ The image looks like this:
+
+ ![Input CT](assets/input.png)
+
+ Now, let's instantiate the model and processor.
+
+ ```python
+ >>> from colipri import get_model
+ >>> from colipri import get_processor
+ >>> model = get_model().cuda()
+ >>> processor = get_processor()
+ ```
+
+ #### Zero-shot classification
+
+ ```python
+ >>> from colipri import ZeroShotImageClassificationPipeline
+ >>> pipeline = ZeroShotImageClassificationPipeline(model, processor)
+ >>> pipeline(image, ["No lung nodules", "Lung nodules"])
+ [
+     {'score': 0.005, 'label': 'No lung nodules'},
+     {'score': 0.995, 'label': 'Lung nodules'}
+ ]
+ ```
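Conceptually, this kind of prompt-based classification follows the CLIP recipe: the image embedding is compared with one text embedding per prompt via cosine similarity, and the similarities are turned into scores with a softmax. A minimal dependency-free sketch of that scoring step, using toy 3-dimensional vectors in place of the real embeddings (`zero_shot_scores` is an illustrative helper, not part of the `colipri` API, and the learned temperature is omitted):

```python
import math


def zero_shot_scores(image_emb, prompt_embs):
    """CLIP-style scoring: cosine similarity between the image embedding and
    each prompt embedding, followed by a softmax over the prompts."""

    def cosine(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        norm_a = math.sqrt(sum(x * x for x in a))
        norm_b = math.sqrt(sum(y * y for y in b))
        return dot / (norm_a * norm_b)

    sims = [cosine(image_emb, p) for p in prompt_embs]
    exps = [math.exp(s) for s in sims]  # softmax numerator (temperature omitted)
    total = sum(exps)
    return [e / total for e in exps]


# Toy 3-D embeddings: the image is closer to the second prompt.
image_emb = [0.1, 0.9, 0.2]
prompt_embs = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.1]]
scores = zero_shot_scores(image_emb, prompt_embs)
```

The scores sum to 1 and can be read as relative confidences over the provided prompts, as in the pipeline output above.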
+
+ #### Feature extraction
+
+ ```python
+ >>> import torch
+ >>> preprocessed_images = processor.process_images(image)
+ >>> preprocessed_images[0]
+ ScalarImage(shape: (1, 192, 192, 192); spacing: (2.00, 2.00, 2.00); orientation: SAR+; dtype: torch.FloatTensor; memory: 27.0 MiB)
+ >>> images_batch = processor.to_images_batch(preprocessed_images)
+ >>> images_batch.shape
+ torch.Size([1, 1, 192, 192, 192])
+ >>> with torch.no_grad():
+ ...     patch_embeddings = model.encode_image(images_batch)
+ >>> patch_embeddings.shape
+ torch.Size([1, 768, 24, 24, 24])
+ >>> with torch.no_grad():
+ ...     pooled_embeddings = model.encode_image(images_batch, pool=True, project=True)
+ >>> pooled_embeddings.shape
+ torch.Size([1, 768])
+ ```
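Pooled embeddings like these can be used directly for the retrieval tasks listed above: normalize each embedding to unit length, then rank candidates by dot product (equivalently, cosine similarity). A dependency-free sketch with toy 3-dimensional vectors standing in for the 768-dimensional pooled embeddings (the helper names are illustrative, not part of the `colipri` API):

```python
import math


def normalize(v):
    """Scale a vector to unit length so dot products equal cosine similarities."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def rank_by_similarity(query, gallery):
    """Return gallery indices sorted from most to least similar to the query."""
    q = normalize(query)
    sims = [sum(a * b for a, b in zip(q, normalize(g))) for g in gallery]
    return sorted(range(len(gallery)), key=lambda i: sims[i], reverse=True)


# Toy stand-ins for pooled embeddings; gallery item 1 is closest to the query.
query = [0.9, 0.1, 0.0]
gallery = [[0.0, 1.0, 0.0], [1.0, 0.2, 0.0], [0.0, 0.0, 1.0]]
ranking = rank_by_similarity(query, gallery)
```

The same ranking logic applies whether the query and gallery embeddings come from images, text, or a mix of both.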
+
+ ## Biases, risks, and limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ COLIPRI was trained with data from Turkey and the USA only; it may therefore be biased towards the populations represented in the training data.
+ Underlying biases of the training datasets may not be well characterized.
+
+ ## Environmental impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ - **Hardware type:** NVIDIA A100 GPUs
+ - **Hours used:** 72 hours × 4 GPUs = 288 GPU-hours
+ - **Cloud provider:** Azure
+ - **Compute region:** West US 2
+ - **Carbon emitted:** 21.6 kg CO₂ eq.
+
+ ### Compute infrastructure
+
+ COLIPRI was trained on [Azure Machine Learning](https://azure.microsoft.com/en-us/products/machine-learning).
+
+ #### Hardware
+
+ | Stage | Node type | Num. nodes | GPU type | GPUs per node |
+ | --- | --- | --- | --- | --- |
+ | Pre-training | [`Standard_NC96ads_A100_v4`](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nca100v4-series?tabs=sizeaccelerators) | 1 | NVIDIA A100 (80 GB) | 4 |
+ | Evaluation | [`Standard_NC24ads_A100_v4`](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nca100v4-series?tabs=sizeaccelerators) | 1 | NVIDIA A100 (80 GB) | 1 |
+
+ #### Software
+
+ The main software libraries used in this work were [nnSSL](https://github.com/MIC-DKFZ/nnssl) for training, [TorchIO](https://torchio.org/) for preprocessing and augmentation, [`nifti-zarr-py`](https://github.com/neuroscales/nifti-zarr-py) for data loading, and [nnU-Net](https://github.com/MIC-DKFZ/nnUNet) for segmentation evaluation.
+
+ ## Citation
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ ### BibTeX
+
+ ```bibtex
+ @misc{
+     wald2026_colipri,
+     title={Comprehensive language-image pre-training for 3D medical image understanding},
+     author={Tassilo Wald and Ibrahim Ethem Hamamci and Yuan Gao and Sam Bond-Taylor and Harshita Sharma and Maximilian Ilse and Cynthia Lo and Olesya Melnichenko and Anton Schwaighofer and Noel C. F. Codella and Maria Teodora Wetscherek and Klaus H. Maier-Hein and Panagiotis Korfiatis and Valentina Salvatelli and Javier Alvarez-Valle and Fernando P{\'e}rez-Garc{\'i}a},
+     year={2026},
+     eprint={2510.15042},
+     archivePrefix={arXiv},
+     primaryClass={cs.CV},
+     url={https://arxiv.org/abs/2510.15042},
+ }
+ ```
+
+ ### APA
+
+ > Wald, T., Hamamci, I. E., Gao, Y., Bond-Taylor, S., Sharma, H., Ilse, M., Lo, C., Melnichenko, O., Schwaighofer, A., Codella, N. C. F., Wetscherek, M. T., Maier-Hein, K. H., Korfiatis, P., Salvatelli, V., Alvarez-Valle, J., & Pérez-García, F. (2026). Comprehensive language-image pre-training for 3D medical image understanding. arXiv. <https://doi.org/10.48550/ARXIV.2510.15042>
+
+ ## Model card contact
+
+ Fernando Pérez-García ([`fperezgarcia@microsoft.com`](mailto:fperezgarcia@microsoft.com)).
assets/input.png ADDED

Git LFS Details

  • SHA256: 9cb8413eaa5628f5499c6fc3a349561c60dfde81351afbb4a3c19b33cd54c547
  • Pointer size: 131 Bytes
  • Size of remote file: 928 kB
mise.toml ADDED
@@ -0,0 +1,24 @@
+ [tasks.build]
+ description = "Build the package"
+ run = "rm -rf dist/ && uv build"
+
+ [tasks.publish]
+ description = "Publish the package"
+ run = "uv publish"
+
+ [tasks.bump]
+ description = "Bump the version (usage: mise run bump [patch|minor|major])"
+ run = "uv run --group maintain bump-my-version bump {{ arg(i=0, name='part', default='patch') }} --verbose"
+
+ [tasks."bump-dry"]
+ description = "Dry-run version bump"
+ run = "uv run --group maintain bump-my-version bump {{ arg(i=0, name='part', default='patch') }} --dry-run --verbose --allow-dirty"
+
+ [tasks.push]
+ description = "Push commits and tags"
+ run = "git push && git push --tags"
+
+ [tasks.release]
+ description = "Bump version and push (usage: mise run release [patch|minor|major])"
+ depends = ["bump"]
+ run = "git push && git push --tags"
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:97d66586507d1f48577b5149777d02e12d25787435a2a295cac5563df9291c14
+ size 1033541380
model_card.md ADDED
@@ -0,0 +1,115 @@
+
+ # Model card
+
+ Date effective: September 2025
+ Owner: ORA, Internal Policy & Practice
+
+ ## Model summary
+
+ | | |
+ | --- | --- |
+ | Developer | Microsoft Research |
+ | Description | COLIPRI is composed of a vision encoder for 3D CT scans and a text encoder for clinical reports. The encoders may be leveraged for downstream tasks such as classification or segmentation. The models were trained using CT-RATE and NLST. COLIPRI is not a generative model. |
+ | Model architecture | The vision encoder is Primus-M, a 3D vision transformer. The text encoder architecture is BERT. |
+ | Parameters | 1-500M |
+ | Inputs | 3D chest CT scans and/or radiology reports |
+ | Context length | N/A (not an LLM) |
+ | Outputs | Image and/or text embeddings, not directly human-interpretable. |
+ | GPUs | 4 × A100 GPUs |
+ | Training time | 1.75 days |
+ | Public data summary (or summaries) | Training data: the National Lung Screening Trial (NLST) dataset (chest CT images from approximately 26k patients) and the CT-RATE dataset (approximately 25.7k CT acquisitions with associated reports). |
+ | Training dates | Dates of training: August 2024. Intended model weight release date: Dec 5, 2025 |
+ | Status | Static |
+ | Release date | Intended release date: Dec 5, 2025 |
+ | Release date in the EU (if different) | Intended release date: Dec 5, 2025 |
+ | License | Free/open source |
+ | Model dependencies | COLIPRI's text encoder is fine-tuned from [microsoft/BiomedVLP-CXR-BERT-specialized](https://huggingface.co/microsoft/BiomedVLP-CXR-BERT-specialized) on Hugging Face. |
+ | List and link to any additional related assets | N/A |
+ | Acceptable use policy | N/A |
+
+ ## Model overview
+
+ COLIPRI is an encoder designed for 3D CT scans. Its architecture combines a 3D vision encoder (the Primus-M transformer) with a biomedical text encoder (BiomedVLP-CXR-BERT), and it extends the CLIP paradigm with additional training objectives and optimizations to account for limited data availability. This multi-objective training approach optimizes both global semantic alignment and dense feature learning. After pre-training, the model is evaluated comprehensively on downstream medical imaging tasks such as classification, retrieval, and semantic segmentation.
+ COLIPRI is not a generative model. It cannot generate any images or text.
+
+ ### Alignment approach
+
+ N/A
+
+ ## Usage
+
+ ### Primary use cases
+
+ COLIPRI is primarily intended for 3D medical imaging tasks such as segmentation and classification. It is intended for research only.
+
+ ### Out-of-scope use cases
+
+ The COLIPRI model family was trained on 3D CT volumes and is therefore best suited for inputs of this type. The model is for research use only; clinical or deployed use cases are out of scope.
+
+ ### Distribution channels
+
+ The model will be released with open access via Hugging Face.
+
+ ### Input formats
+
+ The vision encoder takes 3D chest CT scans, and the text encoder takes radiology reports.
+
+ ### Technical requirements and integration guidance
+
+ The model requires a Python-based environment with PyTorch. The only hardware requirement is one GPU.
+
+ ### Responsible AI considerations
+
+ The model may underperform or produce unreliable outputs for non-English languages, out-of-distribution or adversarial inputs, or clinical scenarios not well represented in the training data. The model should be used in an AI-assisted research setup with human oversight, rather than as a sole diagnostic tool.
+
+ ## Data overview
+
+ ### Training, testing, and validation datasets
+
+ The COLIPRI model family was pre-trained on the CT-RATE dataset, consisting of approximately 25.7k CT acquisitions with associated reports, and the National Lung Screening Trial (NLST) dataset, consisting of chest CT images from approximately 26k patients.
+ Classification performance was evaluated on a withheld test set of CT-RATE and the publicly available subset of RAD-ChestCT, which comprises 3.6k chest CT volumes with 16 multi-abnormality labels. Semantic segmentation was evaluated by training a five-fold cross-validation on four datasets:
+
+ - LiTS: Task 3 of the Medical Segmentation Decathlon (MSD), containing segmentations of the liver and liver tumors
+ - Lung: Task 6 of the MSD, containing segmentations of primary lung cancers
+ - HVS: Task 8 of the MSD, containing segmentations of hepatic vessels and tumors next to such vessels
+ - KiTS23: containing segmentations of tumors, cysts, and the kidney
+
+ ## Quality and performance evaluation
+
+ ### Results summary
+
+ #### Classification probes
+
+ The COLIPRI family outperformed all third-party models on both the CT-RATE and RAD-ChestCT datasets across all metrics.
+
+ #### Zero-shot classification
+
+ COLIPRI encoders exceed the state of the art using short prompts on the CT-RATE and RAD-ChestCT datasets.
+
+ #### Report-to-image retrieval
+
+ COLIPRI encoders are substantially better at retrieving the associated image given a report, evaluated on the CT-RATE test set.
+
+ #### Segmentation
+
+ Across four datasets (LiTS, Lung, HVS, KiTS23), COLIPRI-CM and COLIPRI-CRM perform on par with or better than the state of the art.
+
+ #### Qualitative analysis
+
+ COLIPRI generates sharper and more coherent features than third-party models.
+
+ ### Long context
+
+ N/A – not an LLM.
+
+ ### Safety evaluation and red-teaming
+
+ N/A – the model is limited to the medical imaging domain. Model outputs were compared to ground-truth images, and the standard quality metrics were computed.
+
+ ## Tracked capability evaluations
+
+ N/A. This model is not a frontier model.
+
+ ## Additional information
+
+ Requests for additional information may be directed to [MSFTAIActRequest@microsoft.com](mailto:MSFTAIActRequest@microsoft.com).
pyproject.toml ADDED
@@ -0,0 +1,89 @@
+ [project]
+ name = "colipri"
+ version = "0.1.1"
+ description = "Vision-language encoder for chest CT scans and reports."
+ readme = "README.md"
+ authors = [
+     { name = "Microsoft Health Futures", email = "innereyedev@microsoft.com" },
+ ]
+ requires-python = ">=3.10"
+ dependencies = [
+     "accelerate",
+     "dynamic-network-architectures>=0.4.3",
+     "huggingface-hub",
+     "hydra-core",
+     "jaxtyping",
+     "safetensors",
+     "timm<1.0.23",  # RuntimeError related to position embeddings in 1.0.23
+     "torchio>=0.21.1",
+     "tqdm",
+     "transformers",
+     "typer",
+ ]
+
+ [project.urls]
+ Homepage = "https://huggingface.co/microsoft/colipri"
+ Source = "https://huggingface.co/microsoft/colipri"
+ "Issue tracker" = "https://huggingface.co/microsoft/colipri/discussions/new"
+ Documentation = "https://huggingface.co/microsoft/colipri/blob/main/README.md"
+
+ [project.optional-dependencies]
+ demo = [
+     "lovely-tensors",
+     "matplotlib",
+     "scikit-learn",
+     "torchinfo",
+ ]
+
+ [dependency-groups]
+ dev = [
+     "ipykernel",
+     "ipywidgets",
+     "pytest-sugar>=1.1.1",
+ ]
+ maintain = [
+     "bump-my-version",
+ ]
+ types = ["ty"]
+
+ [build-system]
+ requires = ["uv_build"]
+ build-backend = "uv_build"
+
+ [tool.ruff.lint]
+ # Defaults from https://docs.astral.sh/ruff/linter/#rule-selection
+ select = [
+     # pycodestyle
+     "E",
+     # Pyflakes
+     "F",
+     # pyupgrade
+     "UP",
+     # flake8-bugbear
+     "B",
+     # flake8-simplify
+     "SIM",
+     # isort
+     "I",
+ ]
+ ignore = [
+     "F722",  # https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error
+ ]
+
+ [tool.ruff.lint.isort]
+ force-single-line = true
+
+ [tool.bumpversion]
+ current_version = "0.1.1"
+ commit = true
+ tag = true
+
+ [[tool.bumpversion.files]]
+ filename = "pyproject.toml"
+ search = 'version = "{current_version}"'
+ replace = 'version = "{new_version}"'
+
+ [[tool.bumpversion.files]]
+ filename = "src/colipri/__init__.py"
+ search = '__version__ = "{current_version}"'
+ replace = '__version__ = "{new_version}"'
src/colipri/__init__.py ADDED
@@ -0,0 +1,14 @@
+ from .model.multimodal import get_model
+ from .pipelines import ZeroShotImageClassificationPipeline
+ from .processor import get_processor
+ from .sample_data import load_sample_ct
+
+ __all__ = [
+     "get_model",
+     "get_processor",
+     "load_sample_ct",
+     "ZeroShotImageClassificationPipeline",
+ ]
+
+ __version__ = "0.1.1"
+
src/colipri/checkpoint.py ADDED
@@ -0,0 +1,80 @@
+ from pathlib import Path
+
+ from huggingface_hub import hf_hub_download
+ from omegaconf import DictConfig
+ from omegaconf import OmegaConf
+ from torch import Tensor
+
+ from .defaults import REPO_NAME
+ from .defaults import REPO_USER
+ from .defaults import REVISION
+ from .defaults import ROOT_CONFIG_FILENAME
+ from .defaults import WEIGHTS_FILENAME
+ from .defaults import get_configs_dir
+
+ TypeStateDict = dict[str, Tensor]
+
+
+ def download_weights(**kwargs) -> Path:
+     weights_path = _download_from_hugging_face(**kwargs)
+     return weights_path
+
+
+ def _download_from_hugging_face(
+     repo_user: str = REPO_USER,
+     repo_name: str = REPO_NAME,
+     revision: str | None = REVISION,
+     filename: str = WEIGHTS_FILENAME,
+ ) -> Path:
+     repo_id = f"{repo_user}/{repo_name}"
+     try:
+         weights_path = hf_hub_download(
+             repo_id=repo_id,
+             filename=filename,
+             revision=revision,
+         )
+     except Exception as e:
+         msg = f'Failed to download "{filename}" from Hugging Face Hub repo "{repo_id}".'
+         raise RuntimeError(msg) from e
+     return Path(weights_path)
+
+
+ def _resolve_defaults(config_dir: Path, config_name: str) -> DictConfig:
+     """Load a YAML config and recursively resolve its Hydra-style defaults."""
+     config_path = config_dir / f"{config_name}.yaml"
+     raw = OmegaConf.load(config_path)
+     assert isinstance(raw, DictConfig)
+
+     defaults = OmegaConf.to_container(raw.pop("defaults", OmegaConf.create([])))
+     assert isinstance(defaults, list)
+
+     merged = OmegaConf.create()
+     for default in defaults:
+         if default == "_self_":
+             continue
+         if isinstance(default, dict):
+             for group, option in default.items():
+                 sub = _resolve_defaults(config_dir / group, str(option))
+                 merged = OmegaConf.merge(merged, OmegaConf.create({group: sub}))
+
+     # _self_ is always applied last (standard Hydra behavior)
+     merged = OmegaConf.merge(merged, raw)
+     assert isinstance(merged, DictConfig)
+     return merged
+
+
+ def _load_config(*, overrides: list[str] | None = None) -> DictConfig:
+     config_name = ROOT_CONFIG_FILENAME.removesuffix(".yaml")
+     config = _resolve_defaults(get_configs_dir(), config_name)
+     if overrides:
+         config = OmegaConf.merge(config, OmegaConf.from_dotlist(overrides))
+     assert isinstance(config, DictConfig)
+     return config
+
+
+ def load_model_config(**kwargs) -> DictConfig:
+     return _load_config(**kwargs)["model"]
+
+
+ def load_processor_config(**kwargs) -> DictConfig:
+     return _load_config(**kwargs)["processor"]
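The resolution logic above mirrors Hydra's merge order: each config group in the defaults list is merged first, and the file's own content (`_self_`) is applied last, so it overrides the group values. A minimal dict-based sketch of that semantics, where `deep_merge` is an illustrative stand-in for `OmegaConf.merge`:

```python
def deep_merge(base: dict, override: dict) -> dict:
    """Recursively merge override into base; override wins on conflicts."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(merged.get(key), dict) and isinstance(value, dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


# Group defaults are merged first; the root config (_self_) is applied last,
# so its embed_dim overrides the group's value while num_heads is kept.
group_defaults = {"model": {"embed_dim": 768, "num_heads": 12}}
root_config = {"model": {"embed_dim": 864}}
config = deep_merge(group_defaults, root_config)
```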
src/colipri/configs/config.yaml ADDED
@@ -0,0 +1,10 @@
+ defaults:
+ - model: default
+ - processor: default
+ - _self_
+
+ input_size: 192
+ spacing: 2  # mm
+ padding_mode: minimum
+ image_embed_dim: 864
+ projection_dim: 768
src/colipri/configs/model/default.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+ - image_encoder: default
+ - text_encoder: default
+
+ _target_: colipri.model.multimodal.Model
src/colipri/configs/model/image_encoder/backbone/default.yaml ADDED
@@ -0,0 +1,20 @@
+ _target_: dynamic_network_architectures.architectures.primus.Primus
+ input_channels: 1
+ num_classes: 1
+ eva_depth: 16
+ eva_numheads: 12
+ embed_dim: ${image_embed_dim}
+ patch_embed_size:
+ - 8
+ - 8
+ - 8
+ input_shape:
+ - ${input_size}
+ - ${input_size}
+ - ${input_size}
+ use_rot_pos_emb: True
+ use_abs_pos_embed: False
+ drop_path_rate: 0.2
+ init_values: 0.1
+ scale_attn_inner: True
+ num_register_tokens: 0
src/colipri/configs/model/image_encoder/default.yaml ADDED
@@ -0,0 +1,6 @@
+ defaults:
+ - backbone: default
+ - pooler: default
+ - projector: default
+
+ _target_: colipri.model.image.ImageEncoder
src/colipri/configs/model/image_encoder/pooler/default.yaml ADDED
@@ -0,0 +1,3 @@
+ _target_: colipri.pooling.AttentionPool1D
+ embed_dim: ${projection_dim}
+ num_heads: 12
src/colipri/configs/model/image_encoder/projector/default.yaml ADDED
@@ -0,0 +1,4 @@
+ _target_: torch.nn.Conv3d
+ in_channels: ${image_embed_dim}
+ out_channels: ${projection_dim}
+ kernel_size: 1
src/colipri/configs/model/text_encoder/backbone/default.yaml ADDED
@@ -0,0 +1,4 @@
+ _target_: transformers.AutoModel.from_pretrained
+ pretrained_model_name_or_path: microsoft/BiomedVLP-CXR-BERT-specialized
+ trust_remote_code: yes
+ revision: 5157bdba1437a3aed316dacb1a5b68edf96b9902
src/colipri/configs/model/text_encoder/default.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+ - backbone: default
+ - pooler: default
+
+ _target_: colipri.model.text.TextEncoder
src/colipri/configs/model/text_encoder/pooler/default.yaml ADDED
@@ -0,0 +1,3 @@
+ _target_: colipri.pooling.AttentionPool1D
+ embed_dim: ${projection_dim}
+ num_heads: 12
src/colipri/configs/processor/default.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+ - image_transform: default
+ - tokenizer: default
+
+ _target_: colipri.processor.Processor
src/colipri/configs/processor/image_transform/default.yaml ADDED
@@ -0,0 +1,49 @@
+ # Exported automatically from nnssl: https://github.com/microsoft/nnssl/pull/41, then
+ # adapted for Colipri needs.
+ _target_: torchio.transforms.augmentation.composition.Compose
+ transforms:
+ - _target_: torchio.transforms.preprocessing.spatial.to_orientation.ToOrientation
+   copy: true
+   exclude: null
+   include: null
+   orientation: SAR
+ - _target_: torchio.transforms.preprocessing.spatial.resample.Resample
+   antialias: false
+   copy: true
+   exclude: null
+   image_interpolation: linear
+   include: null
+   label_interpolation: nearest
+   pre_affine_name: null
+   scalars_only: false
+   target: ${spacing}
+ - _target_: torchio.transforms.preprocessing.intensity.clamp.Clamp
+   out_min: -1000
+   out_max: 1000
+ - _target_: torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity
+   copy: true
+   exclude: null
+   in_min_max:
+   - -1000
+   - 1000
+   include: null
+   masking_method: null
+   out_min_max:
+   - -1
+   - 1
+   percentiles:
+   - 0
+   - 100
+ - _target_: torchio.transforms.preprocessing.spatial.crop_or_pad.CropOrPad
+   copy: true
+   exclude: null
+   include: null
+   labels: null
+   mask_name: null
+   only_crop: false
+   only_pad: false
+   padding_mode: ${padding_mode}
+   target_shape:
+   - ${input_size}
+   - ${input_size}
+   - ${input_size}
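The intensity steps in this transform chain (clamp to the [-1000, 1000] HU window, then rescale to [-1, 1]) can be sketched in plain Python; `normalize_hu` is an illustrative helper, not part of the package:

```python
def normalize_hu(values, out_min=-1000.0, out_max=1000.0):
    """Clamp Hounsfield units to [out_min, out_max], then rescale to [-1, 1]."""
    normalized = []
    span = out_max - out_min
    for v in values:
        v = max(out_min, min(out_max, v))  # Clamp step
        normalized.append((v - out_min) / span * 2 - 1)  # RescaleIntensity step
    return normalized


# Air (-1000 HU) maps to -1, water (0 HU) to 0, and very dense values clamp to 1.
result = normalize_hu([-2000, -1000, 0, 500, 3000])
```

The spatial steps (reorientation to SAR, resampling to 2 mm isotropic spacing, and cropping or padding to 192³ voxels) are handled by the corresponding TorchIO transforms listed above.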
src/colipri/configs/processor/tokenizer/default.yaml ADDED
@@ -0,0 +1,5 @@
+ _target_: transformers.AutoTokenizer.from_pretrained
+ pretrained_model_name_or_path: microsoft/BiomedVLP-CXR-BERT-specialized
+ trust_remote_code: yes
+ revision: 5157bdba1437a3aed316dacb1a5b68edf96b9902
+ model_max_length: 512
src/colipri/defaults.py ADDED
@@ -0,0 +1,21 @@
+ import importlib.resources
+ from pathlib import Path
+
+ package_name = __package__
+ assert isinstance(package_name, str)
+
+ REPO_USER = "microsoft"
+ REPO_NAME = package_name
+ REVISION = None
+ WEIGHTS_FILENAME = "model.safetensors"
+ ROOT_CONFIG_FILENAME = "config.yaml"
+
+
+ def get_configs_dir(pkg: str | None = None) -> Path:
+     package_dir = importlib.resources.files(pkg or package_name)
+     return Path(package_dir / "configs")
+
+
+ __all__ = [
+     "get_configs_dir",
+ ]
src/colipri/model/__init__.py ADDED
File without changes
src/colipri/model/base.py ADDED
@@ -0,0 +1,8 @@
+ import torch
+
+
+ class ModelMixin:
+     @property
+     def device(self) -> torch.device:
+         one_parameter = next(self.parameters())
+         return one_parameter.device
src/colipri/model/image.py ADDED
@@ -0,0 +1,109 @@
+ from __future__ import annotations
+
+ import torch
+ import torch.nn.functional as F
+ import torchio as tio
+ from dynamic_network_architectures.architectures.primus import Primus
+ from einops import rearrange
+ from torch import nn
+ from tqdm.auto import tqdm
+
+ from ..pooling import AttentionPool1D
+ from ..types import TypeImageEmbeddings
+ from ..types import TypeImagesTensor
+ from ..types import TypeIntOrTripletInt
+ from .base import ModelMixin
+
+
+ class ImageEncoder(nn.Module, ModelMixin):
+     def __init__(
+         self,
+         backbone: Primus,
+         projector: nn.Conv3d,
+         pooler: AttentionPool1D,
+     ):
+         super().__init__()
+         self.backbone = self._remove_decoder(backbone)
+         self.projector = projector
+         self.pooler = pooler
+
+     def _remove_decoder(self, backbone: Primus) -> Primus:
+         if hasattr(backbone, "up_projection"):
+             backbone.up_projection = nn.Identity()  # type: ignore
+         return backbone
+
+     @property
+     def patch_size(self) -> tuple[int, int, int]:
+         patch_embedder: nn.Conv3d = self.backbone.down_projection.proj  # type: ignore[reportAssignmentType]
+         patch_size: tuple[int, int, int] = patch_embedder.stride  # type: ignore[reportAssignmentType]
+         return patch_size
+
+     def encode(
+         self,
+         images: TypeImagesTensor,
+     ) -> TypeImageEmbeddings:
+         images = images.to(self.device)
+         embeddings: TypeImageEmbeddings = self.backbone(images)
+         return embeddings
+
+     def encode_sliding_window(
+         self,
+         images: TypeImagesTensor,
+         window_size: TypeIntOrTripletInt,
+         overlap: int = 0,
+     ) -> TypeImageEmbeddings:
+         if len(set(self.patch_size)) > 1:
+             msg = (
+                 "Sliding window encoding is only supported for models with cubic"
+                 " patch sizes for now."
+             )
+             raise NotImplementedError(msg)
+         else:
+             patch_size = self.patch_size[0]
+         image_key = "image"  # could be anything
+         embeddings = []
+         for image in images:
+             grid_sampler = tio.inference.GridSampler(
+                 tio.Subject(**{image_key: tio.ScalarImage(tensor=image)}),
+                 window_size,
+                 overlap,
+             )
+             patch_loader = tio.SubjectsLoader(grid_sampler)  # type: ignore[reportArgumentType]
+             aggregator = tio.data.GridAggregator(
+                 grid_sampler,
+                 downsampling_factor=patch_size,
+             )
+             for patches_batch in tqdm(patch_loader):
+                 input_tensor = patches_batch[image_key][tio.DATA].to(self.device)
+                 locations = patches_batch[tio.LOCATION]
+                 outputs = self.backbone(input_tensor)
+                 aggregator.add_batch(outputs, locations)
+             embeddings.append(aggregator.get_output_tensor())
+         return torch.stack(embeddings).to(self.device)
+
+     def forward(
+         self,
+         images: TypeImagesTensor,
+         *,
+         project: bool,
+         pool: bool,
+         normalize: bool,
+         window_size: TypeIntOrTripletInt | None = None,
+     ) -> TypeImageEmbeddings:
+         if pool and not project:
+             msg = "Pooling requires projection to be enabled. Set project=True."
+             raise NotImplementedError(msg)
+         if window_size is None:
+             embeddings = self.encode(images)
+         else:
+             embeddings = self.encode_sliding_window(images, window_size)
+         if project:
+             embeddings = self.projector(embeddings)
+         if pool:
103
+ sequence = rearrange(embeddings, "b c x y z -> b (x y z) c")
104
+ embeddings = self.pooler(sequence)
105
+ else:
106
+ embeddings = self.pooler.to_dense()(embeddings)
107
+ if normalize:
108
+ embeddings = F.normalize(embeddings, dim=1)
109
+ return embeddings
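The sliding-window path assumes the backbone downsamples each spatial axis by the patch size, which is why the `GridAggregator` is built with `downsampling_factor=patch_size`. A toy illustration of that shape relation, using a single `nn.Conv3d` patch embedder as a stand-in for the real Primus backbone (the sizes here are made up):

```python
import torch
from torch import nn

patch_size = 4
# Stand-in "backbone": a patch embedder that downsamples each axis by `patch_size`.
patch_embedder = nn.Conv3d(1, 8, kernel_size=patch_size, stride=patch_size)

images = torch.randn(2, 1, 16, 16, 16)  # batch of 2 single-channel volumes
embeddings = patch_embedder(images)
print(embeddings.shape)  # torch.Size([2, 8, 4, 4, 4]): each axis shrinks 16 -> 4
```

Because kernel size equals stride, the output grid has one embedding per non-overlapping patch, which is the grid the aggregator stitches back together.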
src/colipri/model/multimodal.py ADDED
@@ -0,0 +1,184 @@
+ from __future__ import annotations
+
+ import logging
+ from pathlib import Path
+
+ import torch
+ from accelerate import init_empty_weights
+ from accelerate import load_checkpoint_and_dispatch
+ from einops import rearrange
+ from hydra.utils import instantiate
+ from safetensors.torch import load_model
+ from safetensors.torch import save_model
+ from torch import nn
+ from transformers.utils.logging import get_verbosity
+ from transformers.utils.logging import set_verbosity
+ from transformers.utils.logging import set_verbosity_error
+
+ from ..checkpoint import download_weights
+ from ..checkpoint import load_model_config
+ from ..types import TypeImageEmbeddings
+ from ..types import TypeImagesTensor
+ from ..types import TypeIntOrTripletInt
+ from ..types import TypePatchEmbeddings
+ from ..types import TypePatchLogits
+ from ..types import TypePatchProbabilities
+ from ..types import TypePath
+ from ..types import TypePooledEmbeddings
+ from ..types import TypePooledLogits
+ from ..types import TypePooledProbabilities
+ from ..types import TypeTextAttentionMask
+ from ..types import TypeTextEmbeddings
+ from ..types import TypeTextTokenIds
+ from .base import ModelMixin
+ from .image import ImageEncoder
+ from .text import TextEncoder
+
+
+ def get_model(
+     checkpoint_path: TypePath | None = None,
+     *,
+     pretrained: bool = True,
+     image_only: bool = False,
+     **kwargs,
+ ) -> Model:
+     if pretrained and checkpoint_path is None:
+         checkpoint_path = download_weights()
+
+     overrides = []
+     for key, value in kwargs.items():
+         overrides.append(f"{key}={value}")
+     config = load_model_config(overrides=overrides)
+
+     if image_only:
+         config.text_encoder = None
+
+     transformers_verbosity = get_verbosity()
+     set_verbosity_error()
+     if checkpoint_path is None:
+         model = instantiate(config)
+     else:
+         with init_empty_weights():
+             model = instantiate(config)
+         accelerate_logger = logging.getLogger("accelerate.utils.modeling")
+         old_level = accelerate_logger.getEffectiveLevel()
+         accelerate_logger.setLevel(logging.ERROR)
+         model = load_checkpoint_and_dispatch(model, str(checkpoint_path))
+         accelerate_logger.setLevel(old_level)
+     set_verbosity(transformers_verbosity)
+
+     assert isinstance(model, Model)
+     return model.eval()
+
+
+ class Model(nn.Module, ModelMixin):
+     def __init__(
+         self,
+         image_encoder: ImageEncoder,
+         text_encoder: TextEncoder,
+         temperature: float = 1,
+     ):
+         super().__init__()
+         self.image_encoder = image_encoder
+         self.text_encoder = text_encoder
+         self.register_buffer("softmax_temperature", torch.tensor(temperature))
+
+     @property
+     def patch_size(self) -> tuple[int, int, int]:
+         return self.image_encoder.patch_size
+
+     def encode_image(
+         self,
+         images: TypeImagesTensor,
+         *,
+         pool: bool = False,
+         project: bool = False,
+         normalize: bool = False,
+         window_size: TypeIntOrTripletInt | None = None,
+     ) -> TypeImageEmbeddings:
+         return self.image_encoder(
+             images,
+             project=project,
+             pool=pool,
+             normalize=normalize,
+             window_size=window_size,
+         )
+
+     def encode_text(
+         self,
+         token_ids: TypeTextTokenIds,
+         attention_mask: TypeTextAttentionMask | None = None,
+         *,
+         pool: bool = True,
+         normalize: bool = True,
+     ) -> TypeTextEmbeddings:
+         return self.text_encoder(
+             token_ids,
+             attention_mask,
+             pool=pool,
+             normalize=normalize,
+         )
+
+     def compute_similarities(
+         self,
+         image_embeddings: TypePooledEmbeddings | TypePatchEmbeddings,
+         text_embeddings: TypePooledEmbeddings,
+     ) -> TypePatchLogits | TypePooledLogits:
+         text_embeddings = rearrange(text_embeddings, "num_prompts c -> c num_prompts")
+         is_grid = image_embeddings.ndim == 5
+         if is_grid:
+             num_images, _, x, y, z = image_embeddings.shape
+             image_embeddings = rearrange(
+                 image_embeddings,
+                 "num_images c x y z -> (num_images x y z) c",
+             )
+             similarities_flat = image_embeddings @ text_embeddings
+             similarities = rearrange(
+                 similarities_flat,
+                 "(num_images x y z) num_prompts -> num_images num_prompts x y z",
+                 num_images=num_images,
+                 x=x,
+                 y=y,
+                 z=z,
+             )
+         else:
+             similarities = image_embeddings @ text_embeddings
+         return similarities
+
+     def classify(
+         self,
+         images: TypePooledEmbeddings | TypePatchEmbeddings,
+         text: TypePooledEmbeddings,
+     ) -> TypePooledProbabilities | TypePatchProbabilities:
+         logits = self.compute_similarities(images, text)
+         assert isinstance(self.softmax_temperature, torch.Tensor)
+         probabilities = (logits / self.softmax_temperature).softmax(dim=-1)
+         return probabilities
+
+     def save_weights(self, path: TypePath) -> None:
+         path = Path(path)
+         if path.suffix == ".safetensors":
+             save_model(self, str(path))
+         else:
+             weights = self.state_dict()
+             torch.save(weights, path)
+
+     def load_weights(
+         self,
+         path: TypePath,
+         device: torch.device | str | None = None,
+     ) -> None:
+         path = Path(path)
+         if not path.exists():
+             raise FileNotFoundError(f"Checkpoint file {path} does not exist.")
+         if device is None:
+             device = torch.device("cpu")
+         if path.suffix == ".safetensors":
+             device = str(device)  # safetensors expects the device as a string
+             missing, unexpected = load_model(self, path, device=device)
+             if missing or unexpected:
+                 msg = f"Error loading weights. Missing keys: {missing}. Unexpected keys: {unexpected}."
+                 raise RuntimeError(msg)
+         else:
+             weights = torch.load(path, map_location=device)
+             self.load_state_dict(weights, strict=False)
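`compute_similarities` flattens the voxel grid of patch embeddings, takes a dot product against each text embedding, and folds the grid back, yielding one similarity volume per prompt. A sketch of the same reshaping in plain `torch` (no einops), on random tensors with made-up sizes:

```python
import torch

num_images, c, x, y, z = 2, 8, 3, 3, 3
num_prompts = 4
image_embeddings = torch.randn(num_images, c, x, y, z)
text_embeddings = torch.randn(num_prompts, c)

# Move channels last, flatten every voxel into a row, multiply by the
# (c, num_prompts) text matrix, then fold the grid back.
flat = image_embeddings.permute(0, 2, 3, 4, 1).reshape(-1, c)
similarities_flat = flat @ text_embeddings.T
similarities = similarities_flat.reshape(num_images, x, y, z, num_prompts)
similarities = similarities.permute(0, 4, 1, 2, 3)
print(similarities.shape)  # torch.Size([2, 4, 3, 3, 3])
```

Each entry `similarities[n, p, i, j, k]` is the dot product between one voxel's embedding and prompt `p`, which is what the pipeline later turns into per-voxel probabilities.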
src/colipri/model/text.py ADDED
@@ -0,0 +1,42 @@
+ from __future__ import annotations
+
+ import torch.nn.functional as F
+ from torch import nn
+ from transformers import BertModel
+
+ from ..types import TypeTextAttentionMask
+ from ..types import TypeTextEmbeddings
+ from ..types import TypeTextTokenIds
+ from .base import ModelMixin
+
+
+ class TextEncoder(nn.Module, ModelMixin):
+     def __init__(self, backbone: nn.Module, pooler: nn.Module):
+         super().__init__()
+         if hasattr(backbone, "bert"):
+             bert = backbone.bert
+             assert isinstance(bert, BertModel)
+             bert.pooler = None  # we use our own pooler
+             backbone = bert
+         self.backbone = backbone
+         self.pooler = pooler
+
+     def forward(
+         self,
+         token_ids: TypeTextTokenIds,
+         attention_mask: TypeTextAttentionMask | None = None,
+         *,
+         pool: bool = True,
+         normalize: bool = True,
+     ) -> TypeTextEmbeddings:
+         token_ids = token_ids.to(self.device)
+         if attention_mask is not None:
+             attention_mask = attention_mask.to(self.device)
+         # We didn't use a projector during pre-training
+         output = self.backbone(token_ids, attention_mask=attention_mask)
+         embeddings = output["last_hidden_state"]
+         if pool:
+             embeddings = self.pooler(embeddings)
+         if normalize:
+             embeddings = F.normalize(embeddings, dim=1)
+         return embeddings
src/colipri/pipelines.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+
+ from .model.multimodal import Model
+ from .processor import Processor
+ from .types import TypeImage
+ from .types import TypePath
+ from .types import TypeScores
+
+
+ class ZeroShotImageClassificationPipeline:
+     def __init__(
+         self,
+         model: Model,
+         processor: Processor,
+         batch_size: int = 1,  # TODO
+         num_workers: int = 0,  # TODO
+     ):
+         self._model = model.eval()
+         self._processor = processor
+
+     @property
+     def device(self) -> torch.device:
+         return self._model.device
+
+     # TODO: add support for multiple prompts per class
+     @torch.no_grad()
+     def __call__(
+         self,
+         images: TypePath | TypeImage | list[TypePath | TypeImage],
+         prompts: str | list[str],
+     ) -> TypeScores | list[TypeScores]:
+         preprocessed_images = self._processor.process_images(images)
+         images_batch = self._processor.to_images_batch(preprocessed_images)
+
+         text_token_ids, text_attention_mask = self._processor.process_text(prompts)
+
+         image_embeddings_batch = self._model.encode_image(
+             images_batch,
+             project=True,
+             pool=True,
+             normalize=True,
+         )
+         text_embeddings_batch = self._model.encode_text(
+             text_token_ids,
+             text_attention_mask,
+         )
+
+         probabilities = self._model.classify(
+             image_embeddings_batch,
+             text_embeddings_batch,
+         )
+
+         if not isinstance(prompts, list):
+             prompts = [prompts]
+
+         all_results = []
+         for image_probabilities in probabilities:
+             image_results = []
+             for prompt, score in zip(prompts, image_probabilities, strict=True):
+                 image_results.append({"score": score.item(), "label": prompt})
+             all_results.append(image_results)
+
+         if len(images_batch) == 1 and not isinstance(images, list):
+             all_results = all_results[0]
+         return all_results
src/colipri/pooling.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ from einops import rearrange
+ from torch import nn
+
+ from .types import TypePooledEmbeddings
+ from .types import TypeSequenceEmbeddings
+
+
+ class AttentionPool1D(nn.Module):
+     def __init__(self, embed_dim: int, num_heads: int):
+         super().__init__()
+         self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
+
+     def forward(self, x: TypeSequenceEmbeddings) -> TypePooledEmbeddings:
+         query = x.mean(dim=1, keepdim=True)
+         key = value = x
+         pooled, _ = self.attn(query, key, value)
+         return rearrange(pooled, "batch 1 embed_dim -> batch embed_dim")
+
+     def to_dense(self):
+         v_proj_in_weight_qkv = self.get_parameter("attn.in_proj_weight")
+         v_proj_in_bias_qkv = self.get_parameter("attn.in_proj_bias")
+         v_proj_out_weight = self.get_parameter("attn.out_proj.weight")
+         v_proj_out_bias = self.get_parameter("attn.out_proj.bias")
+         dim = v_proj_in_weight_qkv.shape[0] // 3
+         v_proj_in_weight_v = v_proj_in_weight_qkv[2 * dim :]
+         v_proj_in_bias_v = v_proj_in_bias_qkv[2 * dim :]
+
+         value_projection = nn.Conv3d(
+             in_channels=dim,
+             out_channels=dim,
+             kernel_size=1,
+         )
+         value_projection.weight.data = rearrange(
+             v_proj_in_weight_v,
+             "c_out c_in -> c_out c_in 1 1 1",
+         )
+         assert value_projection.bias is not None
+         value_projection.bias.data = v_proj_in_bias_v
+
+         out_projection = nn.Conv3d(
+             in_channels=dim,
+             out_channels=dim,
+             kernel_size=1,
+         )
+         out_projection.weight.data = rearrange(
+             v_proj_out_weight,
+             "c_out c_in -> c_out c_in 1 1 1",
+         )
+         assert out_projection.bias is not None
+         out_projection.bias.data = v_proj_out_bias
+
+         return nn.Sequential(
+             value_projection,
+             out_projection,
+         )
+
+
+ class MultiLearnedQueryAttentionPool1D(AttentionPool1D):
+     def __init__(self, embed_dim: int, num_heads: int):
+         super().__init__(embed_dim, num_heads)
+         # 4 queries instead of 1
+         self.query = nn.Parameter(torch.randn(1, 4, embed_dim) / embed_dim**0.5)
+
+     def forward(self, x: TypeSequenceEmbeddings) -> TypePooledEmbeddings:
+         """
+         x: [B, T, D] — sequence of token embeddings
+         returns: [B, D] — pooled representation
+         """
+         B, T, D = x.shape
+         query = self.query.expand(B, -1, -1)  # [B, 4, D]
+         pooled, _ = self.attn(query, x, x)  # [B, 4, D] since batch_first=True
+         # Pool over the four queries (mean) to get [B, D]
+         return pooled.mean(dim=1)  # [B, D]
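`to_dense` works because a 1×1×1 convolution applies the same linear map independently at every voxel, so the attention pool's value and output projections can be replayed densely over a feature grid. A quick standalone check of that equivalence on random weights (not the repo's code):

```python
import torch
from torch import nn

dim = 6
linear = nn.Linear(dim, dim)
conv = nn.Conv3d(dim, dim, kernel_size=1)
# Copy the linear weights into the conv: (c_out, c_in) -> (c_out, c_in, 1, 1, 1)
conv.weight.data = linear.weight.data.reshape(dim, dim, 1, 1, 1)
conv.bias.data = linear.bias.data

x = torch.randn(2, dim, 3, 3, 3)
out_conv = conv(x)
# Apply the linear map voxel-wise for comparison (channels last, then back).
out_linear = linear(x.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)
print(torch.allclose(out_conv, out_linear, atol=1e-5))
```

The slicing `[2 * dim:]` in `to_dense` picks the value block out of the packed query/key/value projection, which `nn.MultiheadAttention` stores stacked as `(3 * dim, dim)`.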
src/colipri/processor.py ADDED
@@ -0,0 +1,90 @@
+ from __future__ import annotations
+
+ import torch
+ import torchio as tio
+ from hydra.utils import instantiate
+ from transformers import BertTokenizer
+ from transformers.tokenization_utils_base import BatchEncoding
+
+ from .checkpoint import load_processor_config
+ from .types import TypeImage
+ from .types import TypeImages
+ from .types import TypeImagesTensor
+ from .types import TypePath
+ from .types import TypeStringOrStrings
+ from .types import TypeTextAttentionMask
+ from .types import TypeTextTokenIds
+
+
+ def get_processor(*, image_only: bool = False, **kwargs) -> Processor:
+     overrides = []
+     for key, value in kwargs.items():
+         overrides.append(f"{key}={value}")
+     config = load_processor_config(overrides=overrides)
+     if image_only:
+         config.tokenizer = None
+     return instantiate(config)
+
+
+ class Processor:
+     def __init__(
+         self,
+         image_transform: tio.Transform,
+         tokenizer: BertTokenizer,
+     ):
+         self._image_transform = image_transform
+         self._text_tokenizer = tokenizer
+
+     def __repr__(self) -> str:
+         lines = [
+             f"{self.__class__.__name__}(",
+             f"    image_transform={self._image_transform},",
+             f"    text_tokenizer={self._text_tokenizer},",
+             ")",
+         ]
+         return "\n".join(lines)
+
+     def process_images(
+         self,
+         inputs: TypePath | TypeImage | list[TypePath | TypeImage],
+     ) -> TypeImages:
+         if not isinstance(inputs, list):
+             inputs = [inputs]
+         images = []
+         for image_or_path in inputs:
+             is_image = isinstance(image_or_path, tio.ScalarImage)
+             if is_image:
+                 image = image_or_path
+             else:
+                 path = image_or_path
+                 image = tio.ScalarImage(path)
+             images.append(image)
+         images = [self._image_transform(image) for image in images]
+         return images
+
+     def to_images_batch(
+         self,
+         images: TypeImages,
+     ) -> TypeImagesTensor:
+         if not isinstance(images, list):
+             msg = f"Expected images to be a list, got {type(images)}"
+             raise TypeError(msg)
+         images_tensor = torch.stack([image.data for image in images])
+         return images_tensor
+
+     def process_text(
+         self,
+         text: TypeStringOrStrings,
+         **kwargs,
+     ) -> tuple[TypeTextTokenIds, TypeTextAttentionMask]:
+         encodings: BatchEncoding = self._text_tokenizer(
+             text,
+             max_length=512,  # TODO
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt",
+             **kwargs,
+         )
+         token_ids: TypeTextTokenIds = encodings["input_ids"]  # type: ignore[reportAssignmentType]
+         attention_mask: TypeTextAttentionMask = encodings["attention_mask"]  # type: ignore[reportAssignmentType]
+         return token_ids, attention_mask
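`process_text` pads every prompt to `max_length` and returns a matching attention mask (1 for real tokens, 0 for padding). A toy illustration of that contract in plain Python, independent of `transformers`, with a hypothetical 8-token budget and made-up token ids:

```python
max_length = 8  # stand-in for the processor's max_length
pad_id = 0
token_ids_per_prompt = [[101, 7592, 102], [101, 2307, 3239, 102]]  # made-up ids

padded, masks = [], []
for ids in token_ids_per_prompt:
    ids = ids[:max_length]  # truncation
    masks.append([1] * len(ids) + [0] * (max_length - len(ids)))
    padded.append(ids + [pad_id] * (max_length - len(ids)))

print(padded[0])  # [101, 7592, 102, 0, 0, 0, 0, 0]
print(masks[1])   # [1, 1, 1, 1, 0, 0, 0, 0]
```

The mask is what lets the text encoder's attention ignore padding positions.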
src/colipri/sample_data.py ADDED
@@ -0,0 +1,24 @@
+ from pathlib import Path
+
+ import torchio as tio
+
+
+ def load_sample_ct(clip: bool = True) -> tio.ScalarImage:
+     """Get a sample CT image."""
+     ct = get_sample_ct()
+     if clip:
+         clamp = tio.Clamp(-1000, 1000)
+         ct = clamp(ct)
+     else:
+         ct.load()
+     return ct
+
+
+ def get_sample_ct() -> tio.ScalarImage:
+     """Download a sample CT image if not already present."""
+     return tio.datasets.Slicer("CTChest").CT_chest  # type: ignore[reportAttributeAccessIssue]
+
+
+ def get_sample_ct_path() -> Path:
+     """Get the path to a sample CT image."""
+     return get_sample_ct().path  # type: ignore[reportReturnType]
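Clipping to [-1000, 1000] Hounsfield units bounds the intensities roughly between air and dense bone. The same operation on a raw tensor (torch sketch with made-up values; the repo applies `tio.Clamp` to the whole image):

```python
import torch

hounsfield = torch.tensor([-2048.0, -1000.0, 0.0, 400.0, 3000.0])
clipped = hounsfield.clamp(-1000, 1000)  # values outside the range are saturated
print(clipped)
```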
src/colipri/types.py ADDED
@@ -0,0 +1,42 @@
+ import os
+ from typing import TypeAlias
+
+ import torchio as tio
+ from jaxtyping import Float
+ from jaxtyping import Int64
+ from torch import Tensor
+
+ TypePath: TypeAlias = str | os.PathLike[str]
+ TypePaths: TypeAlias = list[TypePath]
+
+ # Raw inputs
+ TypeImage: TypeAlias = tio.ScalarImage
+ TypeImages: TypeAlias = list[tio.ScalarImage]
+ TypeImageOrImages: TypeAlias = TypeImage | TypeImages
+
+ TypeString: TypeAlias = str
+ TypeStrings: TypeAlias = list[str]
+ TypeStringOrStrings: TypeAlias = TypeString | TypeStrings
+
+ TypeIntOrTripletInt: TypeAlias = int | tuple[int, int, int]
+
+ # Processed inputs
+ TypeImagesTensor: TypeAlias = Float[Tensor, "batch 1 x_in y_in z_in"]
+
+ TypeTextTokenIds: TypeAlias = Int64[Tensor, "batch text_length"]
+ TypeTextAttentionMask: TypeAlias = Int64[Tensor, "batch text_length"]
+
+ # Outputs
+ TypeSequenceEmbeddings: TypeAlias = Float[Tensor, "batch sequence_length embed_dim"]
+ TypePooledEmbeddings: TypeAlias = Float[Tensor, "batch embed_dim"]
+
+ TypePatchEmbeddings: TypeAlias = Float[Tensor, "batch embed_dim x_out y_out z_out"]
+ TypeImageEmbeddings: TypeAlias = TypePatchEmbeddings | TypePooledEmbeddings
+
+ TypeTextEmbeddings: TypeAlias = TypeSequenceEmbeddings | TypePooledEmbeddings
+
+ TypePatchLogits: TypeAlias = Float[Tensor, "num_images num_prompts x_out y_out z_out"]
+ TypePooledLogits: TypeAlias = Float[Tensor, "num_images num_prompts"]
+ TypePatchProbabilities: TypeAlias = TypePatchLogits
+ TypePooledProbabilities: TypeAlias = TypePooledLogits
+ TypeScores: TypeAlias = list[dict[str, float | str]]
tests/test_config.py ADDED
@@ -0,0 +1,68 @@
+ """Tests for config loading, including GlobalHydra conflict regression."""
+
+ from omegaconf import OmegaConf
+
+ from colipri.checkpoint import _load_config
+ from colipri.checkpoint import load_model_config
+ from colipri.checkpoint import load_processor_config
+
+
+ def test_load_config():
+     config = _load_config()
+     assert config.input_size == 192
+     assert config.spacing == 2
+     assert config.model._target_ == "colipri.model.multimodal.Model"
+     assert config.processor._target_ == "colipri.processor.Processor"
+
+
+ def test_load_model_config():
+     config = load_model_config()
+     assert config._target_ == "colipri.model.multimodal.Model"
+     assert "image_encoder" in config
+     assert "text_encoder" in config
+
+
+ def test_load_processor_config():
+     config = load_processor_config()
+     assert config._target_ == "colipri.processor.Processor"
+     assert "image_transform" in config
+     assert "tokenizer" in config
+
+
+ def test_overrides():
+     config = _load_config(overrides=["input_size=256", "spacing=3"])
+     assert config.input_size == 256
+     assert config.spacing == 3
+
+
+ def test_interpolation():
+     config = _load_config()
+     resolved = OmegaConf.to_container(config, resolve=True)
+     assert isinstance(resolved, dict)
+     backbone = resolved["model"]["image_encoder"]["backbone"]
+     assert backbone["embed_dim"] == 864  # ${image_embed_dim}
+     assert backbone["input_shape"] == [192, 192, 192]  # ${input_size}
+
+
+ def test_config_loading_with_hydra_preinitialized():
+     """Regression test: COLIPRI must work when GlobalHydra is already initialized.
+
+     See https://huggingface.co/microsoft/colipri/discussions/3
+     """
+     from hydra import initialize
+     from hydra.core.global_hydra import GlobalHydra
+
+     with initialize(config_path=None, version_base=None):
+         assert GlobalHydra.instance().is_initialized()
+
+         model_cfg = load_model_config()
+         proc_cfg = load_processor_config()
+         assert model_cfg._target_ == "colipri.model.multimodal.Model"
+         assert proc_cfg._target_ == "colipri.processor.Processor"
+
+         # Overrides must also work
+         config = _load_config(overrides=["input_size=256"])
+         assert config.input_size == 256
+
+     # GlobalHydra state must not be corrupted
+     assert not GlobalHydra.instance().is_initialized()
uv.lock ADDED
The diff for this file is too large to render. See raw diff