fepegar committed

Commit 5c067a5 · 1 Parent(s): 9d5c69b

Add package code and model weights
.gitattributes CHANGED
@@ -19,6 +19,7 @@
  *.pb filter=lfs diff=lfs merge=lfs -text
  *.pickle filter=lfs diff=lfs merge=lfs -text
  *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
  *.rar filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,12 @@
+ # Python-generated files
+ __pycache__/
+ *.py[oc]
+ build/
+ dist/
+ wheels/
+ *.egg-info
+
+ # Virtual environments
+ .venv
+
+ .vscode/
.python-version ADDED
@@ -0,0 +1 @@
+ 3.13
COLIPRI_demo.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) Microsoft Corporation.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE
README.md CHANGED
@@ -1,5 +1,180 @@
  ---
  license: mit
+ language:
+ - en
+ pipeline_tag: zero-shot-image-classification
  ---

- This repository will be populated very soon.
+ # Model card for COLIPRI
+
+ <!-- Provide a quick summary of what the model is/does. -->
+
+ COLIPRI is a 3D vision&ndash;language transformer model trained to encode chest CT scans and reports.
+
+ ## Model description
+
+ <!-- Provide a longer summary of what this model is. -->
+
+ COLIPRI was trained using tens of thousands of chest CT scans and reports, without any annotations, using multiple objectives to learn strong joint representations of 3D images and text.
+ The procedure is described in detail in our manuscript, [_Comprehensive language-image pre-training for 3D medical image understanding_](https://arxiv.org/abs/2510.15042) (Wald et al., 2026).
+
+ The weights shared here correspond to our best-performing model, COLIPRI-CRM.
+
+ - **Developed by:** Microsoft Health Futures
+ - **Model type:** 3D vision&ndash;language encoder
+ - **License:** [MIT](./LICENSE)
+
+ ## Uses
+
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
+
+ COLIPRI is shared for research purposes only.
+ It is **not meant to be used for clinical practice**.
+
+ <!-- ### Downstream use -->
+
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
+
+ The encoders can be plugged into other models, or used independently or jointly for many downstream tasks, such as:
+
+ - Image classification with text prompts
+ - Image clustering
+ - Text clustering
+ - Text-to-image retrieval
+ - Image-to-image retrieval
+ - Image-to-text retrieval
+ - Text-to-text retrieval
+ - Image classification with a classifier
+ - Text classification with a classifier
+ - Image segmentation with a decoder
+ - Report generation with a language decoder
+
+ Fine-tuning COLIPRI is typically not necessary to obtain good performance on downstream tasks.
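The retrieval tasks listed above all reduce to ranking embeddings by similarity. A minimal, dependency-free sketch of that idea (the function names here are illustrative, not part of the `colipri` API):

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rank_candidates(query: list[float], candidates: list[list[float]]) -> list[int]:
    """Indices of candidate embeddings, most similar to the query first.

    For text-to-image retrieval, `query` would be a pooled text embedding and
    `candidates` pooled image embeddings (768-dimensional for this model).
    """
    scores = [cosine_similarity(query, c) for c in candidates]
    return sorted(range(len(candidates)), key=lambda i: scores[i], reverse=True)


# The second candidate points in nearly the same direction as the query.
ranking = rank_candidates([1.0, 0.0], [[0.0, 1.0], [0.9, 0.1]])
```

Because the pooled embeddings live in a shared space, the same ranking works in any direction (text-to-image, image-to-image, and so on).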
+
+ <!-- ### Out-of-scope use -->
+
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
+
+ ## Biases, risks, and limitations
+
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
+
+ COLIPRI was trained with data from Turkey and the USA only, so it may be biased towards the populations represented in the training data.
+ Underlying biases of the training datasets may not be well characterized.
+
+ ## Installation
+
+ ```shell
+ pip install git+https://huggingface.co/microsoft/colipri.git
+ ```
+
+ ## Usage examples
+
+ First, let's get a 3D chest CT we can use for demonstration.
+ The plotted slices intersect a lung nodule near the heart.
+
+ ```python
+ >>> from colipri import load_sample_ct
+ >>> image = load_sample_ct()
+ >>> image
+ ScalarImage(shape: (1, 512, 512, 139); spacing: (0.76, 0.76, 2.50); orientation: LPS+; dtype: torch.IntTensor; memory: 139.0 MiB)
+ >>> nodule_indices_ct = 180, 341, 76
+ >>> image.plot(indices=nodule_indices_ct)
+ ```
+
+ ![Input CT](assets/input.png)
+
+ Now, let's instantiate the model and processor.
+
+ ```python
+ >>> from colipri import get_model
+ >>> from colipri import get_processor
+ >>> model = get_model().cuda()
+ >>> processor = get_processor()
+ ```
+
+ ### Feature extraction
+
+ ```python
+ >>> import torch
+ >>> preprocessed_images = processor.process_images(image)
+ >>> preprocessed_images[0]
+ ScalarImage(shape: (1, 192, 192, 192); spacing: (2.00, 2.00, 2.00); orientation: SAR+; dtype: torch.FloatTensor; memory: 27.0 MiB)
+ >>> images_batch = processor.to_images_batch(preprocessed_images)
+ >>> images_batch.shape
+ torch.Size([1, 1, 192, 192, 192])
+ >>> with torch.no_grad():
+ ...     patch_embeddings = model.encode_image(images_batch)
+ >>> patch_embeddings.shape
+ torch.Size([1, 768, 24, 24, 24])
+ >>> with torch.no_grad():
+ ...     pooled_embeddings = model.encode_image(images_batch, pool=True)
+ >>> pooled_embeddings.shape
+ torch.Size([1, 768])
+ ```
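The 24×24×24 grid of patch embeddings follows directly from the configuration: a 192³ preprocessed volume divided by the 8³ patch-embedding size. A quick sanity check (values taken from the configs shipped in this repo):

```python
def patch_grid_shape(input_size: int = 192, patch_size: int = 8) -> tuple[int, int, int]:
    """Spatial shape of the patch-embedding grid for a cubic input.

    Defaults mirror this repository's configuration
    (input_size: 192, patch_embed_size: 8).
    """
    if input_size % patch_size != 0:
        raise ValueError("input size must be a multiple of the patch size")
    n = input_size // patch_size
    return (n, n, n)


grid = patch_grid_shape()  # matches the [1, 768, 24, 24, 24] output above
```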
+
+ ### Zero-shot classification
+
+ ```python
+ >>> from colipri import ZeroShotImageClassificationPipeline
+ >>> pipeline = ZeroShotImageClassificationPipeline(model, processor)
+ >>> pipeline(image, ["No lung nodules", "Lung nodules"])
+ [
+     {'score': 0.005, 'label': 'No lung nodules'},
+     {'score': 0.995, 'label': 'Lung nodules'}
+ ]
+ ```
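Under the hood, CLIP-style zero-shot classification typically scores each prompt by its cosine similarity to the image embedding and normalizes the scaled similarities with a softmax. A hedged, dependency-free sketch of that computation (the logit scale used by the actual pipeline may differ):

```python
import math


def zero_shot_scores(image_emb, prompt_embs, scale=100.0):
    """Softmax over scaled cosine similarities between one image embedding
    and each text-prompt embedding. `scale` is an assumed logit scale."""
    def cosine(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))

    logits = [scale * cosine(image_emb, p) for p in prompt_embs]
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


# The image embedding is much closer to the second prompt's embedding.
scores = zero_shot_scores([0.1, 0.9], [[1.0, 0.0], [0.0, 1.0]])
```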
+
+ ## Environmental impact
+
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
+
+ <!-- Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700). -->
+
+ - **Hardware type:** NVIDIA A100 GPUs
+ - **Hours used:** 72 hours × 4 GPUs = 288 GPU-hours
+ - **Cloud provider:** Azure
+ - **Compute region:** West US 2
+ - **Carbon emitted:** 21.6 kg CO₂ eq.
+
+ ### Compute infrastructure
+
+ COLIPRI was trained on [Azure Machine Learning](https://azure.microsoft.com/en-us/products/machine-learning).
+
+ #### Hardware
+
+ | Stage | Node type | Num. nodes | GPU type | GPUs per node |
+ | --- | --- | --- | --- | --- |
+ | Pre-training | [`Standard_NC96ads_A100_v4`](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nca100v4-series?tabs=sizeaccelerators) | 1 | NVIDIA A100 (80 GB) | 4 |
+ | Evaluation | [`Standard_NC24ads_A100_v4`](https://learn.microsoft.com/en-us/azure/virtual-machines/sizes/gpu-accelerated/nca100v4-series?tabs=sizeaccelerators) | 1 | NVIDIA A100 (80 GB) | 1 |
+
+ #### Software
+
+ The main software libraries used in this work were [nnSSL](https://github.com/MIC-DKFZ/nnssl) for training, [TorchIO](https://torchio.org/) for preprocessing and augmentation, [`nifti-zarr-py`](https://github.com/neuroscales/nifti-zarr-py) for data loading, and [nnU-Net](https://github.com/MIC-DKFZ/nnUNet) for segmentation evaluation.
+
+ ## Citation
+
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
+
+ ### BibTeX
+
+ ```bibtex
+ @misc{wald2026_colipri,
+     title={Comprehensive language-image pre-training for 3D medical image understanding},
+     author={Tassilo Wald and Ibrahim Ethem Hamamci and Yuan Gao and Sam Bond-Taylor and Harshita Sharma and Maximilian Ilse and Cynthia Lo and Olesya Melnichenko and Anton Schwaighofer and Noel C. F. Codella and Maria Teodora Wetscherek and Klaus H. Maier-Hein and Panagiotis Korfiatis and Valentina Salvatelli and Javier Alvarez-Valle and Fernando P{\'e}rez-Garc{\'i}a},
+     year={2026},
+     eprint={2510.15042},
+     archivePrefix={arXiv},
+     primaryClass={cs.CV},
+     url={https://arxiv.org/abs/2510.15042},
+ }
+ ```
+
+ ### APA
+
+ > Wald, T., Hamamci, I. E., Gao, Y., Bond-Taylor, S., Sharma, H., Ilse, M., Lo, C., Melnichenko, O., Schwaighofer, A., Codella, N. C. F., Wetscherek, M. T., Maier-Hein, K. H., Korfiatis, P., Salvatelli, V., Alvarez-Valle, J., & Pérez-García, F. (2026). Comprehensive language-image pre-training for 3D medical image understanding. arXiv. <https://doi.org/10.48550/ARXIV.2510.15042>
+
+ ## Model card contact
+
+ Fernando Pérez-García ([`fperezgarcia@microsoft.com`](mailto:fperezgarcia@microsoft.com)).
assets/input.png ADDED

Git LFS Details

  • SHA256: 9cb8413eaa5628f5499c6fc3a349561c60dfde81351afbb4a3c19b33cd54c547
  • Pointer size: 131 Bytes
  • Size of remote file: 928 kB
justfile ADDED
@@ -0,0 +1,2 @@
+ requirements:
+     uv export --quiet --no-hashes --no-emit-project --output-file requirements_demo.txt
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:97d66586507d1f48577b5149777d02e12d25787435a2a295cac5563df9291c14
+ size 1033541380
model_card.md ADDED
@@ -0,0 +1,115 @@
+
+ # Model card
+
+ Date effective: September 2025
+ Owner: ORA, Internal Policy & Practice
+
+ ## Model summary
+
+ | | |
+ | --- | --- |
+ | Developer | Microsoft Research |
+ | Description | COLIPRI is composed of a vision encoder for 3D CT scans and a text encoder for clinical reports. The encoders may be leveraged for downstream tasks such as classification or segmentation. The models were trained using CT-RATE and NLST. COLIPRI is not a generative model. |
+ | Model architecture | The vision encoder is Primus-M, a 3D vision transformer. The text encoder architecture is BERT. |
+ | Parameters | 1-500M |
+ | Inputs | 3D chest CT scans and/or radiology reports |
+ | Context length | N/A (not an LLM) |
+ | Outputs | Image and/or text embeddings, not directly human-interpretable. |
+ | GPUs | 4 × A100 GPUs |
+ | Training time | 1.75 days |
+ | Public data summary (or summaries) | Training data: National Lung Screening Trial (NLST) dataset: chest CT images from approximately 26k patients. CT-RATE dataset: approximately 25.7k CT acquisitions with associated reports. |
+ | Training dates | Dates of training: August 2024. Intended model weight release date: Dec 5, 2025 |
+ | Status | Static |
+ | Release date | Intended release date: Dec 5, 2025 |
+ | Release date in the EU (if different) | Intended release date: Dec 5, 2025 |
+ | License | Free/open source |
+ | Model dependencies | COLIPRI's text encoder is fine-tuned from `microsoft/BiomedVLP-CXR-BERT-specialized` on Hugging Face. |
+ | List and link to any additional related assets | N/A |
+ | Acceptable use policy | N/A |
+
+ ## Model overview
+
+ COLIPRI is an encoder designed for 3D CT scans. Its architecture combines a 3D vision encoder (the Primus-M transformer) with a biomedical text encoder (BiomedVLP-CXR-BERT), and extends the CLIP paradigm with additional training objectives and optimizations to account for limited data availability. This multi-objective training approach optimizes both global semantic alignment and dense feature learning. After pre-training, the model is evaluated comprehensively on downstream medical imaging tasks such as classification, retrieval, and semantic segmentation.
+ COLIPRI is not a generative model. It cannot generate any images or text.
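As background for the CLIP paradigm the overview refers to, the core contrastive objective can be sketched as a symmetric cross-entropy over pairwise similarities of matched image-report pairs. This is a simplified illustration, not COLIPRI's actual training code, which adds further objectives:

```python
import math


def _softmax_xent(row, target):
    """Cross-entropy of a softmax over `row` against index `target`."""
    peak = max(row)
    exps = [math.exp(v - peak) for v in row]
    return -math.log(exps[target] / sum(exps))


def clip_loss(image_embs, text_embs, temperature=0.07):
    """Symmetric InfoNCE loss over a batch of matched (image, text) pairs.

    Embeddings are assumed L2-normalized, so dot products are cosine
    similarities; diagonal entries of the logit matrix are the positives.
    """
    n = len(image_embs)
    logits = [
        [sum(a * b for a, b in zip(img, txt)) / temperature for txt in text_embs]
        for img in image_embs
    ]
    image_to_text = sum(_softmax_xent(logits[i], i) for i in range(n)) / n
    transposed = [[logits[i][j] for i in range(n)] for j in range(n)]
    text_to_image = sum(_softmax_xent(transposed[j], j) for j in range(n)) / n
    return (image_to_text + text_to_image) / 2


# Correctly matched pairs give a low loss; mismatched pairs a higher one.
aligned = clip_loss([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
shuffled = clip_loss([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
```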
+
+ ### Alignment approach
+
+ N/A
+
+ ## Usage
+
+ ### Primary use cases
+
+ COLIPRI is primarily intended for 3D medical imaging tasks such as segmentation and classification. It is intended for research only.
+
+ ### Out-of-scope use cases
+
+ The COLIPRI model family was trained on 3D CT volumes and is therefore best suited for inputs of this type. The model is for research use only; clinical or deployed use cases are out of scope.
+
+ ### Distribution channels
+
+ The model will be released with open access via Hugging Face.
+
+ ### Input formats
+
+ The vision encoder takes 3D chest CT scans, and the text encoder takes radiology reports.
+
+ ### Technical requirements and integration guidance
+
+ The model requires a Python-based environment with PyTorch. The only hardware requirement is a single GPU.
+
+ ### Responsible AI considerations
+
+ The model may underperform or produce unreliable outputs for non-English languages, out-of-distribution or adversarial inputs, or clinical scenarios not well represented in the training data. The model should be used in an AI-assisted research setup with human oversight, rather than as a sole diagnostic tool.
+
+ ## Data overview
+
+ ### Training, testing, and validation datasets
+
+ The COLIPRI model family was pre-trained on the CT-RATE dataset, consisting of approximately 25.7k CT acquisitions with associated reports, and the National Lung Screening Trial (NLST) dataset, consisting of chest CT images from approximately 26k patients.
+ Classification performance was evaluated on a withheld test set of CT-RATE and the publicly available subset of RAD-ChestCT, which comprises 3.6k chest CT volumes with 16 multi-abnormality labels. Semantic segmentation was evaluated by training with five-fold cross-validation on four datasets:
+
+ - LiTS: Task 3 of the Medical Segmentation Decathlon (MSD), containing segmentations of the liver and liver tumors
+ - Lung: Task 6 of the MSD, containing segmentations of primary lung cancers
+ - HVS: Task 8 of the MSD, containing segmentations of hepatic vessels and tumors next to such vessels
+ - KiTS23: containing segmentations of tumors, cysts, and the kidney
+
+ ## Quality and performance evaluation
+
+ ### Results summary
+
+ #### Classification probes
+
+ The COLIPRI family exceeded all third-party models on both the CT-RATE and RAD-ChestCT datasets across all metrics.
+
+ #### Zero-shot classification
+
+ COLIPRI encoders exceed the state of the art using short prompts on the CT-RATE and RAD-ChestCT datasets.
+
+ #### Report-to-image retrieval
+
+ COLIPRI encoders are substantially better at retrieving the associated image given a report, evaluated on the CT-RATE test set.
+
+ #### Segmentation
+
+ Across four datasets (LiTS, Lung, HVS, KiTS23), COLIPRI-CM and COLIPRI-CRM perform on par with or better than the state of the art.
+
+ #### Qualitative analysis
+
+ COLIPRI embeddings generate sharper and more coherent features than those of third-party models.
+
+ ### Long context
+
+ N/A (not an LLM).
+
+ ### Safety evaluation and red-teaming
+
+ N/A (the model is limited to the medical imaging domain). Model outputs were compared to ground-truth images, and the standard quality metrics were computed.
+
+ ## Tracked capability evaluations
+
+ N/A (this model is not a frontier model).
+
+ ## Additional information
+
+ Requests for additional information may be directed to [MSFTAIActRequest@microsoft.com](mailto:MSFTAIActRequest@microsoft.com).
pyproject.toml ADDED
@@ -0,0 +1,70 @@
+ [project]
+ name = "colipri"
+ version = "0.1.0"
+ description = "Vision-language encoder for chest CT scans and reports."
+ readme = "README.md"
+ authors = [
+     { name = "Microsoft Health Futures", email = "innereyedev@microsoft.com" },
+ ]
+ requires-python = ">=3.10"
+ dependencies = [
+     "accelerate",
+     "dynamic-network-architectures@git+https://github.com/fepegar/dynamic-network-architectures.git@88ad3c9",
+     "huggingface-hub",
+     "hydra-core",
+     "jaxtyping",
+     "safetensors",
+     "timm<1.0.23",  # RuntimeError related to position embeddings in 1.0.23
+     "torchio>=0.21.1",
+     "tqdm",
+     "transformers",
+     "typer",
+ ]
+
+ [project.urls]
+ Homepage = "https://huggingface.co/microsoft/colipri"
+ Source = "https://huggingface.co/microsoft/colipri"
+ "Issue tracker" = "https://huggingface.co/microsoft/colipri/discussions/new"
+ Documentation = "https://huggingface.co/microsoft/colipri/blob/main/README.md"
+
+ [project.optional-dependencies]
+ demo = [
+     "lovely-tensors",
+     "matplotlib",
+     "scikit-learn",
+     "torchinfo",
+ ]
+
+ [dependency-groups]
+ dev = [
+     "ipykernel",
+     "ipywidgets",
+ ]
+ types = ["ty"]
+
+ [build-system]
+ requires = ["uv_build"]
+ build-backend = "uv_build"
+
+ [tool.ruff.lint]
+ # Defaults from https://docs.astral.sh/ruff/linter/#rule-selection
+ select = [
+     # pycodestyle
+     "E",
+     # Pyflakes
+     "F",
+     # pyupgrade
+     "UP",
+     # flake8-bugbear
+     "B",
+     # flake8-simplify
+     "SIM",
+     # isort
+     "I",
+ ]
+ ignore = [
+     "F722",  # https://docs.kidger.site/jaxtyping/faq/#flake8-or-ruff-are-throwing-an-error
+ ]
+
+ [tool.ruff.lint.isort]
+ force-single-line = true
src/colipri/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from .model.multimodal import get_model
+ from .pipelines import ZeroShotImageClassificationPipeline
+ from .processor import get_processor
+ from .sample_data import load_sample_ct
+
+ __all__ = [
+     "get_model",
+     "get_processor",
+     "load_sample_ct",
+     "ZeroShotImageClassificationPipeline",
+ ]
src/colipri/checkpoint.py ADDED
@@ -0,0 +1,54 @@
+ from pathlib import Path
+
+ from huggingface_hub import hf_hub_download
+ from hydra import compose
+ from hydra import initialize_config_dir
+ from omegaconf import DictConfig
+ from torch import Tensor
+
+ from .defaults import REPO_NAME
+ from .defaults import REPO_USER
+ from .defaults import REVISION
+ from .defaults import ROOT_CONFIG_FILENAME
+ from .defaults import WEIGHTS_FILENAME
+ from .defaults import get_configs_dir
+
+ TypeStateDict = dict[str, Tensor]
+
+
+ def download_weights(**kwargs) -> Path:
+     weights_path = _download_from_hugging_face(**kwargs)
+     return weights_path
+
+
+ def _download_from_hugging_face(
+     repo_user: str = REPO_USER,
+     repo_name: str = REPO_NAME,
+     revision: str | None = REVISION,
+     filename: str = WEIGHTS_FILENAME,
+ ) -> Path:
+     repo_id = f"{repo_user}/{repo_name}"
+     try:
+         weights_path = hf_hub_download(
+             repo_id=repo_id,
+             filename=filename,
+             revision=revision,
+         )
+     except Exception as e:
+         msg = f'Failed to download "{filename}" from Hugging Face Hub repo "{repo_id}".'
+         raise RuntimeError(msg) from e
+     return Path(weights_path)
+
+
+ def _load_config(**kwargs) -> DictConfig:
+     with initialize_config_dir(str(get_configs_dir()), version_base=None):
+         config = compose(ROOT_CONFIG_FILENAME, **kwargs)
+     return config
+
+
+ def load_model_config(**kwargs) -> DictConfig:
+     return _load_config(**kwargs)["model"]
+
+
+ def load_processor_config(**kwargs) -> DictConfig:
+     return _load_config(**kwargs)["processor"]
src/colipri/configs/config.yaml ADDED
@@ -0,0 +1,10 @@
+ defaults:
+   - model: default
+   - processor: default
+   - _self_
+
+ input_size: 192
+ spacing: 2  # mm
+ padding_mode: minimum
+ image_embed_dim: 864
+ projection_dim: 768
src/colipri/configs/model/default.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+   - image_encoder: default
+   - text_encoder: default
+
+ _target_: colipri.model.multimodal.Model
src/colipri/configs/model/image_encoder/backbone/default.yaml ADDED
@@ -0,0 +1,20 @@
+ _target_: dynamic_network_architectures.architectures.primus.Primus
+ input_channels: 1
+ num_classes: 1
+ eva_depth: 16
+ eva_numheads: 12
+ embed_dim: ${image_embed_dim}
+ patch_embed_size:
+   - 8
+   - 8
+   - 8
+ input_shape:
+   - ${input_size}
+   - ${input_size}
+   - ${input_size}
+ use_rot_pos_emb: True
+ use_abs_pos_embed: False
+ drop_path_rate: 0.2
+ init_values: 0.1
+ scale_attn_inner: True
+ num_register_tokens: 0
src/colipri/configs/model/image_encoder/default.yaml ADDED
@@ -0,0 +1,6 @@
+ defaults:
+   - backbone: default
+   - pooler: default
+   - projector: default
+
+ _target_: colipri.model.image.ImageEncoder
src/colipri/configs/model/image_encoder/pooler/default.yaml ADDED
@@ -0,0 +1,3 @@
+ _target_: colipri.pooling.AttentionPool1D
+ embed_dim: ${projection_dim}
+ num_heads: 12
src/colipri/configs/model/image_encoder/projector/default.yaml ADDED
@@ -0,0 +1,4 @@
+ _target_: torch.nn.Conv3d
+ in_channels: ${image_embed_dim}
+ out_channels: ${projection_dim}
+ kernel_size: 1
src/colipri/configs/model/text_encoder/backbone/default.yaml ADDED
@@ -0,0 +1,4 @@
+ _target_: transformers.AutoModel.from_pretrained
+ pretrained_model_name_or_path: microsoft/BiomedVLP-CXR-BERT-specialized
+ trust_remote_code: yes
+ revision: 5157bdba1437a3aed316dacb1a5b68edf96b9902
src/colipri/configs/model/text_encoder/default.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+   - backbone: default
+   - pooler: default
+
+ _target_: colipri.model.text.TextEncoder
src/colipri/configs/model/text_encoder/pooler/default.yaml ADDED
@@ -0,0 +1,3 @@
+ _target_: colipri.pooling.AttentionPool1D
+ embed_dim: ${projection_dim}
+ num_heads: 12
src/colipri/configs/processor/default.yaml ADDED
@@ -0,0 +1,5 @@
+ defaults:
+   - image_transform: default
+   - tokenizer: default
+
+ _target_: colipri.processor.Processor
src/colipri/configs/processor/image_transform/default.yaml ADDED
@@ -0,0 +1,49 @@
+ # Exported automatically from nnssl: https://github.com/microsoft/nnssl/pull/41, then
+ # adapted for Colipri needs.
+ _target_: torchio.transforms.augmentation.composition.Compose
+ transforms:
+   - _target_: torchio.transforms.preprocessing.spatial.to_orientation.ToOrientation
+     copy: true
+     exclude: null
+     include: null
+     orientation: SAR
+   - _target_: torchio.transforms.preprocessing.spatial.resample.Resample
+     antialias: false
+     copy: true
+     exclude: null
+     image_interpolation: linear
+     include: null
+     label_interpolation: nearest
+     pre_affine_name: null
+     scalars_only: false
+     target: ${spacing}
+   - _target_: torchio.transforms.preprocessing.intensity.clamp.Clamp
+     out_min: -1000
+     out_max: 1000
+   - _target_: torchio.transforms.preprocessing.intensity.rescale.RescaleIntensity
+     copy: true
+     exclude: null
+     in_min_max:
+       - -1000
+       - 1000
+     include: null
+     masking_method: null
+     out_min_max:
+       - -1
+       - 1
+     percentiles:
+       - 0
+       - 100
+   - _target_: torchio.transforms.preprocessing.spatial.crop_or_pad.CropOrPad
+     copy: true
+     exclude: null
+     include: null
+     labels: null
+     mask_name: null
+     only_crop: false
+     only_pad: false
+     padding_mode: ${padding_mode}
+     target_shape:
+       - ${input_size}
+       - ${input_size}
+       - ${input_size}
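The `Clamp` and `RescaleIntensity` steps in the config above amount to a simple per-voxel mapping. A dependency-free sketch of that mapping, with the config's values hard-coded (-1000 to 1000 HU rescaled to [-1, 1]):

```python
def preprocess_intensity(hu: float, hu_min: float = -1000.0, hu_max: float = 1000.0) -> float:
    """Clamp a Hounsfield-unit value to [hu_min, hu_max], then rescale it
    linearly to [-1, 1], mirroring the Clamp + RescaleIntensity transforms."""
    clamped = max(hu_min, min(hu_max, hu))
    return 2 * (clamped - hu_min) / (hu_max - hu_min) - 1


# Air (≤ -1000 HU) maps to -1, water (0 HU) to 0, dense bone saturates at 1.
values = [preprocess_intensity(v) for v in (-2000.0, -1000.0, 0.0, 500.0, 3000.0)]
```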
src/colipri/configs/processor/tokenizer/default.yaml ADDED
@@ -0,0 +1,5 @@
+ _target_: transformers.AutoTokenizer.from_pretrained
+ pretrained_model_name_or_path: microsoft/BiomedVLP-CXR-BERT-specialized
+ trust_remote_code: yes
+ revision: 5157bdba1437a3aed316dacb1a5b68edf96b9902
+ model_max_length: 512
src/colipri/defaults.py ADDED
@@ -0,0 +1,21 @@
+ import importlib.resources
+ from pathlib import Path
+
+ package_name = __package__
+ assert isinstance(package_name, str)
+
+ REPO_USER = "microsoft"
+ REPO_NAME = package_name
+ REVISION = None
+ WEIGHTS_FILENAME = "model.safetensors"
+ ROOT_CONFIG_FILENAME = "config.yaml"
+
+
+ def get_configs_dir(pkg: str | None = None) -> Path:
+     package_dir = importlib.resources.files(pkg or package_name)
+     return Path(package_dir / "configs")
+
+
+ __all__ = [
+     "get_configs_dir",
+ ]
src/colipri/model/__init__.py ADDED
File without changes
src/colipri/model/base.py ADDED
@@ -0,0 +1,8 @@
+ import torch
+
+
+ class ModelMixin:
+     @property
+     def device(self) -> torch.device:
+         one_parameter = next(self.parameters())
+         return one_parameter.device
src/colipri/model/image.py ADDED
@@ -0,0 +1,109 @@
+ from __future__ import annotations
+
+ import torch
+ import torch.nn.functional as F
+ import torchio as tio
+ from dynamic_network_architectures.architectures.primus import Primus
+ from einops import rearrange
+ from torch import nn
+ from tqdm.auto import tqdm
+
+ from ..pooling import AttentionPool1D
+ from ..types import TypeImageEmbeddings
+ from ..types import TypeImagesTensor
+ from ..types import TypeIntOrTripletInt
+ from .base import ModelMixin
+
+
+ class ImageEncoder(nn.Module, ModelMixin):
+     def __init__(
+         self,
+         backbone: Primus,
+         projector: nn.Conv3d,
+         pooler: AttentionPool1D,
+     ):
+         super().__init__()
+         self.backbone = self._remove_decoder(backbone)
+         self.projector = projector
+         self.pooler = pooler
+
+     def _remove_decoder(self, backbone: Primus) -> Primus:
+         if hasattr(backbone, "up_projection"):
+             backbone.up_projection = nn.Identity()  # type: ignore
+         return backbone
+
+     @property
+     def patch_size(self) -> tuple[int, int, int]:
+         patch_embedder: nn.Conv3d = self.backbone.down_projection.proj  # type: ignore[reportAssignmentType]
+         patch_size: tuple[int, int, int] = patch_embedder.stride  # type: ignore[reportAssignmentType]
+         return patch_size
+
+     def encode(
+         self,
+         images: TypeImagesTensor,
+     ) -> TypeImageEmbeddings:
+         images = images.to(self.device)
+         embeddings: TypeImageEmbeddings = self.backbone(images)
+         return embeddings
+
+     def encode_sliding_window(
+         self,
+         images: TypeImagesTensor,
+         window_size: TypeIntOrTripletInt,
+         overlap: int = 0,
+     ) -> TypeImageEmbeddings:
+         if len(set(self.patch_size)) > 1:
+             msg = (
+                 "Sliding window encoding is only supported for models with cubic"
+                 " patch sizes for now."
+             )
+             raise NotImplementedError(msg)
+         else:
+             patch_size = self.patch_size[0]
+         image_key = "image"  # could be anything
+         embeddings = []
+         for image in images:
+             grid_sampler = tio.inference.GridSampler(
+                 tio.Subject(**{image_key: tio.ScalarImage(tensor=image)}),
+                 window_size,
+                 overlap,
+             )
+             patch_loader = tio.SubjectsLoader(grid_sampler)  # type: ignore[reportArgumentType]
+             aggregator = tio.data.GridAggregator(
+                 grid_sampler,
+                 downsampling_factor=patch_size,
+             )
+             for patches_batch in tqdm(patch_loader):
+                 input_tensor = patches_batch[image_key][tio.DATA].to(self.device)
+                 locations = patches_batch[tio.LOCATION]
+                 outputs = self.backbone(input_tensor)
+                 aggregator.add_batch(outputs, locations)
+             embeddings.append(aggregator.get_output_tensor())
+         return torch.stack(embeddings).to(self.device)
+
+     def forward(
+         self,
+         images: TypeImagesTensor,
+         *,
+         project: bool,
+         pool: bool,
+         normalize: bool,
+         window_size: TypeIntOrTripletInt | None = None,
+     ) -> TypeImageEmbeddings:
+         if pool and not project:
+             msg = "Pooling requires projection to be enabled. Set project=True."
+             raise NotImplementedError(msg)
+         if window_size is None:
+             embeddings = self.encode(images)
+         else:
+             embeddings = self.encode_sliding_window(images, window_size)
+         if project:
+             embeddings = self.projector(embeddings)
+             if pool:
+                 sequence = rearrange(embeddings, "b c x y z -> b (x y z) c")
+                 embeddings = self.pooler(sequence)
+             else:
+                 embeddings = self.pooler.to_dense()(embeddings)
+         if normalize:
+             embeddings = F.normalize(embeddings, dim=1)
+         return embeddings
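The sliding-window path above tiles each volume into windows and stitches the per-window embeddings back together. The tiling along a single dimension can be sketched like this (an illustration of the general scheme, not TorchIO's exact `GridSampler` logic):

```python
def window_starts(size: int, window: int, overlap: int = 0) -> list[int]:
    """Start indices of windows of length `window` covering a dimension of
    length `size`, stepping by `window - overlap`; the final window is
    shifted back so the whole extent is covered."""
    step = window - overlap
    starts = list(range(0, size - window + 1, step))
    if not starts:  # dimension smaller than the window: a single padded window
        starts = [0]
    if starts[-1] + window < size:
        starts.append(size - window)
    return starts


starts = window_starts(512, 192)  # three windows cover a 512-voxel axis
```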
src/colipri/model/multimodal.py ADDED
@@ -0,0 +1,184 @@
+ from __future__ import annotations
+
+ import logging
+ from pathlib import Path
+
+ import torch
+ from accelerate import init_empty_weights
+ from accelerate import load_checkpoint_and_dispatch
+ from einops import rearrange
+ from hydra.utils import instantiate
+ from safetensors.torch import load_model
+ from safetensors.torch import save_model
+ from torch import nn
+ from transformers.utils.logging import get_verbosity
+ from transformers.utils.logging import set_verbosity
+ from transformers.utils.logging import set_verbosity_error
+
+ from ..checkpoint import download_weights
+ from ..checkpoint import load_model_config
+ from ..types import TypeImageEmbeddings
+ from ..types import TypeImagesTensor
+ from ..types import TypeIntOrTripletInt
+ from ..types import TypePatchEmbeddings
+ from ..types import TypePatchLogits
+ from ..types import TypePatchProbabilities
+ from ..types import TypePath
+ from ..types import TypePooledEmbeddings
+ from ..types import TypePooledLogits
+ from ..types import TypePooledProbabilities
+ from ..types import TypeTextAttentionMask
+ from ..types import TypeTextEmbeddings
+ from ..types import TypeTextTokenIds
+ from .base import ModelMixin
+ from .image import ImageEncoder
+ from .text import TextEncoder
+
+
+ def get_model(
+     checkpoint_path: TypePath | None = None,
+     *,
+     pretrained: bool = True,
+     image_only: bool = False,
+     **kwargs,
+ ) -> Model:
+     if pretrained and checkpoint_path is None:
+         checkpoint_path = download_weights()
+
+     overrides = []
+     for key, value in kwargs.items():
+         overrides.append(f"{key}={value}")
+     config = load_model_config(overrides=overrides)
+
+     if image_only:
+         config.text_encoder = None
+
+     # Silence transformers and accelerate logging while instantiating
+     transformers_verbosity = get_verbosity()
+     set_verbosity_error()
+     if checkpoint_path is None:
+         model = instantiate(config)
+     else:
+         with init_empty_weights():
+             model = instantiate(config)
+         accelerate_logger = logging.getLogger("accelerate.utils.modeling")
+         old_level = accelerate_logger.getEffectiveLevel()
+         accelerate_logger.setLevel(logging.ERROR)
+         model = load_checkpoint_and_dispatch(model, checkpoint_path)
+         accelerate_logger.setLevel(old_level)
+     set_verbosity(transformers_verbosity)
+
+     assert isinstance(model, Model)
+     return model.eval()
+
+
+ class Model(nn.Module, ModelMixin):
+     def __init__(
+         self,
+         image_encoder: ImageEncoder,
+         text_encoder: TextEncoder,
+         temperature: float = 1,
+     ):
+         super().__init__()
+         self.image_encoder = image_encoder
+         self.text_encoder = text_encoder
+         self.register_buffer("softmax_temperature", torch.tensor(temperature))
+
+     @property
+     def patch_size(self) -> tuple[int, int, int]:
+         return self.image_encoder.patch_size
+
+     def encode_image(
+         self,
+         images: TypeImagesTensor,
+         *,
+         pool: bool = False,
+         project: bool = False,
+         normalize: bool = False,
+         window_size: TypeIntOrTripletInt | None = None,
+     ) -> TypeImageEmbeddings:
+         return self.image_encoder(
+             images,
+             project=project,
+             pool=pool,
+             normalize=normalize,
+             window_size=window_size,
+         )
+
+     def encode_text(
+         self,
+         token_ids: TypeTextTokenIds,
+         attention_mask: TypeTextAttentionMask | None = None,
+         *,
+         pool: bool = True,
+         normalize: bool = True,
+     ) -> TypeTextEmbeddings:
+         return self.text_encoder(
+             token_ids,
+             attention_mask,
+             pool=pool,
+             normalize=normalize,
+         )
+
+     def compute_similarities(
+         self,
+         image_embeddings: TypePooledEmbeddings | TypePatchEmbeddings,
+         text_embeddings: TypePooledEmbeddings,
+     ) -> TypePatchLogits | TypePooledLogits:
+         text_embeddings = rearrange(text_embeddings, "num_prompts c -> c num_prompts")
+         is_grid = image_embeddings.ndim == 5
+         if is_grid:
+             num_images, _, x, y, z = image_embeddings.shape
+             image_embeddings = rearrange(
+                 image_embeddings,
+                 "num_images c x y z -> (num_images x y z) c",
+             )
+             similarities_flat = image_embeddings @ text_embeddings
+             similarities = rearrange(
+                 similarities_flat,
+                 "(num_images x y z) num_prompts -> num_images num_prompts x y z",
+                 num_images=num_images,
+                 x=x,
+                 y=y,
+                 z=z,
+             )
+         else:
+             similarities = image_embeddings @ text_embeddings
+         return similarities
+
+     def classify(
+         self,
+         images: TypePooledEmbeddings | TypePatchEmbeddings,
+         text: TypePooledEmbeddings,
+     ) -> TypePooledProbabilities | TypePatchProbabilities:
+         logits = self.compute_similarities(images, text)
+         assert isinstance(self.softmax_temperature, torch.Tensor)
+         probabilities = (logits / self.softmax_temperature).softmax(dim=-1)
+         return probabilities
+
+     def save_weights(self, path: TypePath) -> None:
+         path = Path(path)
+         if path.suffix == ".safetensors":
+             save_model(self, str(path))
+         else:
+             weights = self.state_dict()
+             torch.save(weights, path)
+
+     def load_weights(
+         self,
+         path: TypePath,
+         device: torch.device | str | None = None,
+     ) -> None:
+         path = Path(path)
+         if not path.exists():
+             raise FileNotFoundError(f"Checkpoint file {path} does not exist.")
+         if device is None:
+             device = torch.device("cpu")
+         if path.suffix == ".safetensors":
+             missing, unexpected = load_model(self, path, device=str(device))
+             if missing or unexpected:
+                 msg = (
+                     f"Error loading weights from {path}:"
+                     f" missing keys {missing}, unexpected keys {unexpected}"
+                 )
+                 raise RuntimeError(msg)
+         else:
+             weights = torch.load(path, map_location=device)
+             self.load_state_dict(weights, strict=False)
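For reference, `compute_similarities` and `classify` reduce to a cosine-similarity matrix followed by a temperature-scaled softmax. A minimal standalone sketch with random embeddings (the dimensions and the 0.07 temperature are illustrative, not the model's trained buffer value):

```python
import torch
import torch.nn.functional as F

# Two pooled image embeddings and three text prompt embeddings,
# L2-normalized so the dot product is a cosine similarity
image_emb = F.normalize(torch.randn(2, 16), dim=1)
text_emb = F.normalize(torch.randn(3, 16), dim=1)

logits = image_emb @ text_emb.T                   # [num_images, num_prompts]
probabilities = (logits / 0.07).softmax(dim=-1)   # temperature-scaled softmax

print(probabilities.shape)        # torch.Size([2, 3])
print(probabilities.sum(dim=-1))  # each row sums to 1
```

A lower temperature sharpens the distribution over prompts; the real value is stored in the `softmax_temperature` buffer and loaded with the checkpoint.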
src/colipri/model/text.py ADDED
@@ -0,0 +1,42 @@
+ from __future__ import annotations
+
+ import torch.nn.functional as F
+ from torch import nn
+ from transformers import BertModel
+
+ from ..types import TypeTextAttentionMask
+ from ..types import TypeTextEmbeddings
+ from ..types import TypeTextTokenIds
+ from .base import ModelMixin
+
+
+ class TextEncoder(nn.Module, ModelMixin):
+     def __init__(self, backbone: nn.Module, pooler: nn.Module):
+         super().__init__()
+         if hasattr(backbone, "bert"):
+             bert = backbone.bert
+             assert isinstance(bert, BertModel)
+             bert.pooler = None  # we use our own pooler
+             backbone = bert
+         self.backbone = backbone
+         self.pooler = pooler
+
+     def forward(
+         self,
+         token_ids: TypeTextTokenIds,
+         attention_mask: TypeTextAttentionMask | None = None,
+         *,
+         pool: bool = True,
+         normalize: bool = True,
+     ) -> TypeTextEmbeddings:
+         token_ids = token_ids.to(self.device)
+         if attention_mask is not None:
+             attention_mask = attention_mask.to(self.device)
+         # We didn't use a projector during pre-training
+         output = self.backbone(token_ids, attention_mask=attention_mask)
+         embeddings = output["last_hidden_state"]
+         if pool:
+             embeddings = self.pooler(embeddings)
+         if normalize:
+             # Normalize the embedding (last) dimension, which is correct for
+             # both pooled [B, D] and sequence [B, T, D] outputs
+             embeddings = F.normalize(embeddings, dim=-1)
+         return embeddings
src/colipri/pipelines.py ADDED
@@ -0,0 +1,65 @@
+ import torch
+
+ from .model.multimodal import Model
+ from .processor import Processor
+ from .types import TypeImage
+ from .types import TypePath
+ from .types import TypeScores
+
+
+ class ZeroShotImageClassificationPipeline:
+     def __init__(
+         self,
+         model: Model,
+         processor: Processor,
+         batch_size: int = 1,  # TODO
+         num_workers: int = 0,  # TODO
+     ):
+         self._model = model.eval()
+         self._processor = processor
+
+     @property
+     def device(self) -> torch.device:
+         return self._model.device
+
+     # TODO: add support for multiple prompts per class
+     @torch.no_grad()
+     def __call__(
+         self,
+         images: TypePath | TypeImage | list[TypePath | TypeImage],
+         prompts: str | list[str],
+     ) -> TypeScores | list[TypeScores]:
+         preprocessed_images = self._processor.process_images(images)
+         images_batch = self._processor.to_images_batch(preprocessed_images)
+
+         text_token_ids, text_attention_mask = self._processor.process_text(prompts)
+
+         image_embeddings_batch = self._model.encode_image(
+             images_batch,
+             project=True,
+             pool=True,
+             normalize=True,
+         )
+         text_embeddings_batch = self._model.encode_text(
+             text_token_ids,
+             text_attention_mask,
+         )
+
+         probabilities = self._model.classify(
+             image_embeddings_batch,
+             text_embeddings_batch,
+         )
+
+         if not isinstance(prompts, list):
+             prompts = [prompts]
+
+         all_results = []
+         for image_probabilities in probabilities:
+             image_results = []
+             for prompt, score in zip(prompts, image_probabilities, strict=True):
+                 image_results.append({"score": score.item(), "label": prompt})
+             all_results.append(image_results)
+
+         if len(images_batch) == 1 and not isinstance(images, list):
+             all_results = all_results[0]
+         return all_results
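The result packaging at the end of `__call__` is plain Python and can be sketched standalone (the prompt strings and probabilities below are made up, not real model outputs):

```python
# Hypothetical softmax outputs for 2 images over 3 prompts
probabilities = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6]]
prompts = ["pneumonia", "pleural effusion", "no finding"]

# One list of {"score", "label"} dicts per image, in prompt order
all_results = [
    [{"score": score, "label": label} for label, score in zip(prompts, row)]
    for row in probabilities
]

print(all_results[0][0])  # {'score': 0.7, 'label': 'pneumonia'}
```

This mirrors the Hugging Face zero-shot pipeline output shape, so downstream code written against that format should work unchanged.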
src/colipri/pooling.py ADDED
@@ -0,0 +1,74 @@
+ import torch
+ from einops import rearrange
+ from torch import nn
+
+ from .types import TypePooledEmbeddings
+ from .types import TypeSequenceEmbeddings
+
+
+ class AttentionPool1D(nn.Module):
+     def __init__(self, embed_dim: int, num_heads: int):
+         super().__init__()
+         self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
+
+     def forward(self, x: TypeSequenceEmbeddings) -> TypePooledEmbeddings:
+         query = x.mean(dim=1, keepdim=True)
+         key = value = x
+         pooled, _ = self.attn(query, key, value)
+         return rearrange(pooled, "batch 1 embed_dim -> batch embed_dim")
+
+     def to_dense(self):
+         # Extract the value and output projections from the attention layer;
+         # in_proj_weight stacks the Q, K and V projections along dim 0
+         v_proj_in_weight_qkv = self.get_parameter("attn.in_proj_weight")
+         v_proj_in_bias_qkv = self.get_parameter("attn.in_proj_bias")
+         v_proj_out_weight = self.get_parameter("attn.out_proj.weight")
+         v_proj_out_bias = self.get_parameter("attn.out_proj.bias")
+         dim = v_proj_in_weight_qkv.shape[0] // 3
+         v_proj_in_weight_v = v_proj_in_weight_qkv[2 * dim :]
+         v_proj_in_bias_v = v_proj_in_bias_qkv[2 * dim :]
+
+         value_projection = nn.Conv3d(
+             in_channels=dim,
+             out_channels=dim,
+             kernel_size=1,
+         )
+         value_projection.weight.data = rearrange(
+             v_proj_in_weight_v,
+             "c_out c_in -> c_out c_in 1 1 1",
+         )
+         assert value_projection.bias is not None
+         value_projection.bias.data = v_proj_in_bias_v
+
+         out_projection = nn.Conv3d(
+             in_channels=dim,
+             out_channels=dim,
+             kernel_size=1,
+         )
+         out_projection.weight.data = rearrange(
+             v_proj_out_weight,
+             "c_out c_in -> c_out c_in 1 1 1",
+         )
+         assert out_projection.bias is not None
+         out_projection.bias.data = v_proj_out_bias
+
+         return nn.Sequential(
+             value_projection,
+             out_projection,
+         )
+
+
+ class MultiLearnedQueryAttentionPool1D(AttentionPool1D):
+     def __init__(self, embed_dim: int, num_heads: int):
+         super().__init__(embed_dim, num_heads)
+         # Four learned queries instead of the single mean-token query
+         self.query = nn.Parameter(torch.randn(1, 4, embed_dim) / embed_dim**0.5)
+
+     def forward(self, x: TypeSequenceEmbeddings) -> TypePooledEmbeddings:
+         """Pool a sequence of token embeddings [B, T, D] into [B, D]."""
+         B, T, D = x.shape
+         query = self.query.expand(B, -1, -1)  # [B, 4, D]
+         pooled, _ = self.attn(query, x, x)  # [B, 4, D] (batch_first)
+         # Average over the four learned queries to get [B, D]
+         return pooled.mean(dim=1)
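`to_dense` above relies on the fact that a linear projection applied per token is the same operation as a `kernel_size=1` `Conv3d` applied per voxel. That equivalence can be checked numerically in isolation (the dimensions here are arbitrary):

```python
import torch
from torch import nn

dim = 4
linear = nn.Linear(dim, dim)

# Copy the linear weights into a 1x1x1 3D convolution
conv = nn.Conv3d(dim, dim, kernel_size=1)
conv.weight.data = linear.weight.data.view(dim, dim, 1, 1, 1)
conv.bias.data = linear.bias.data

grid = torch.randn(2, dim, 3, 3, 3)  # [batch, channels, x, y, z]
# Reference: move channels last, apply the linear layer per voxel, move back
ref = linear(grid.permute(0, 2, 3, 4, 1)).permute(0, 4, 1, 2, 3)

print(torch.allclose(conv(grid), ref, atol=1e-5))  # True
```

This is why slicing the V and output projections out of `nn.MultiheadAttention` and reshaping them into 1x1x1 convolutions yields a dense, spatially resolved version of the pooler.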
src/colipri/processor.py ADDED
@@ -0,0 +1,90 @@
+ from __future__ import annotations
+
+ import torch
+ import torchio as tio
+ from hydra.utils import instantiate
+ from transformers import BertTokenizer
+ from transformers.tokenization_utils_base import BatchEncoding
+
+ from .checkpoint import load_processor_config
+ from .types import TypeImage
+ from .types import TypeImages
+ from .types import TypeImagesTensor
+ from .types import TypePath
+ from .types import TypeStringOrStrings
+ from .types import TypeTextAttentionMask
+ from .types import TypeTextTokenIds
+
+
+ def get_processor(*, image_only: bool = False, **kwargs) -> Processor:
+     overrides = []
+     for key, value in kwargs.items():
+         overrides.append(f"{key}={value}")
+     config = load_processor_config(overrides=overrides)
+     if image_only:
+         config.tokenizer = None
+     return instantiate(config)
+
+
+ class Processor:
+     def __init__(
+         self,
+         image_transform: tio.Transform,
+         tokenizer: BertTokenizer,
+     ):
+         self._image_transform = image_transform
+         self._text_tokenizer = tokenizer
+
+     def __repr__(self) -> str:
+         lines = [
+             f"{self.__class__.__name__}(",
+             f"    image_transform={self._image_transform},",
+             f"    text_tokenizer={self._text_tokenizer},",
+             ")",
+         ]
+         return "\n".join(lines)
+
+     def process_images(
+         self,
+         inputs: TypePath | TypeImage | list[TypePath | TypeImage],
+     ) -> TypeImages:
+         if not isinstance(inputs, list):
+             inputs = [inputs]
+         images = []
+         for image_or_path in inputs:
+             if isinstance(image_or_path, tio.ScalarImage):
+                 image = image_or_path
+             else:
+                 image = tio.ScalarImage(image_or_path)
+             images.append(image)
+         return [self._image_transform(image) for image in images]
+
+     def to_images_batch(
+         self,
+         images: TypeImages,
+     ) -> TypeImagesTensor:
+         if not isinstance(images, list):
+             msg = f"Expected images to be a list, got {type(images)}"
+             raise TypeError(msg)
+         return torch.stack([image.data for image in images])
+
+     def process_text(
+         self,
+         text: TypeStringOrStrings,
+         **kwargs,
+     ) -> tuple[TypeTextTokenIds, TypeTextAttentionMask]:
+         encodings: BatchEncoding = self._text_tokenizer(
+             text,
+             max_length=512,  # TODO
+             padding="max_length",
+             truncation=True,
+             return_tensors="pt",
+             **kwargs,
+         )
+         token_ids: TypeTextTokenIds = encodings["input_ids"]  # type: ignore[reportAssignmentType]
+         attention_mask: TypeTextAttentionMask = encodings["attention_mask"]  # type: ignore[reportAssignmentType]
+         return token_ids, attention_mask
src/colipri/sample_data.py ADDED
@@ -0,0 +1,24 @@
+ from pathlib import Path
+
+ import torchio as tio
+
+
+ def load_sample_ct(clip: bool = True) -> tio.ScalarImage:
+     """Load a sample CT image, optionally clamping intensities to [-1000, 1000]."""
+     ct = get_sample_ct()
+     if clip:
+         clamp = tio.Clamp(-1000, 1000)
+         ct = clamp(ct)
+     else:
+         ct.load()
+     return ct
+
+
+ def get_sample_ct() -> tio.ScalarImage:
+     """Download a sample CT image if not already present."""
+     return tio.datasets.Slicer("CTChest").CT_chest  # type: ignore[reportAttributeAccessIssue]
+
+
+ def get_sample_ct_path() -> Path:
+     """Get the path to a sample CT image."""
+     return get_sample_ct().path  # type: ignore[reportReturnType]
src/colipri/types.py ADDED
@@ -0,0 +1,42 @@
+ import os
+ from typing import TypeAlias
+
+ import torchio as tio
+ from jaxtyping import Float
+ from jaxtyping import Int64
+ from torch import Tensor
+
+ TypePath: TypeAlias = str | os.PathLike[str]
+ TypePaths: TypeAlias = list[TypePath]
+
+ # Raw inputs
+ TypeImage: TypeAlias = tio.ScalarImage
+ TypeImages: TypeAlias = list[tio.ScalarImage]
+ TypeImageOrImages: TypeAlias = TypeImage | TypeImages
+
+ TypeString: TypeAlias = str
+ TypeStrings: TypeAlias = list[str]
+ TypeStringOrStrings: TypeAlias = TypeString | TypeStrings
+
+ TypeIntOrTripletInt: TypeAlias = int | tuple[int, int, int]
+
+ # Processed inputs
+ TypeImagesTensor: TypeAlias = Float[Tensor, "batch 1 x_in y_in z_in"]
+
+ TypeTextTokenIds: TypeAlias = Int64[Tensor, "batch text_length"]
+ TypeTextAttentionMask: TypeAlias = Int64[Tensor, "batch text_length"]
+
+ # Outputs
+ TypeSequenceEmbeddings: TypeAlias = Float[Tensor, "batch sequence_length embed_dim"]
+ TypePooledEmbeddings: TypeAlias = Float[Tensor, "batch embed_dim"]
+
+ TypePatchEmbeddings: TypeAlias = Float[Tensor, "batch embed_dim x_out y_out z_out"]
+ TypeImageEmbeddings: TypeAlias = TypePatchEmbeddings | TypePooledEmbeddings
+
+ TypeTextEmbeddings: TypeAlias = TypeSequenceEmbeddings | TypePooledEmbeddings
+
+ TypePatchLogits: TypeAlias = Float[Tensor, "num_images num_prompts x_out y_out z_out"]
+ TypePooledLogits: TypeAlias = Float[Tensor, "num_images num_prompts"]
+ TypePatchProbabilities: TypeAlias = TypePatchLogits
+ TypePooledProbabilities: TypeAlias = TypePooledLogits
+ TypeScores: TypeAlias = list[dict[str, float | str]]
uv.lock ADDED
The diff for this file is too large to render. See raw diff