jonggwon-park committed on
Commit bc4db88 · 1 Parent(s): 5cd156b

Update Readme

Files changed (3)
  1. .gitattributes +1 -0
  2. README.md +60 -11
  3. requirements.txt +265 -0
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  data/MIMIC-CXR/*.json filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  data/MIMIC-CXR/*.json filter=lfs diff=lfs merge=lfs -text
37
+ misc/*.png filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -2,26 +2,57 @@
2
  license: cc-by-nc-4.0
3
  ---
4
 
5
- <!-- markdownlint-disable first-line-h1 -->
6
- <!-- markdownlint-disable html -->
7
-
8
  <div align="center">
9
- <h1>
10
- RadZero: Similarity-Based Cross-Attention for Explainable Vision-Language Alignment in Radiology with Zero-Shot Multi-Task Capability
11
- </h1>
12
- </div>
13
 
14
  <p align="center">
15
- 📝 <a href="" target="_blank">Paper</a> • 🤗 <a href="https://huggingface.co/Deepnoid/RadZero" target="_blank">Hugging Face</a> • 🧩 <a href="" target="_blank">Github</a>
16
  </p>
17
 
18
- <div align="center">
19
  </div>
20
 
21
 
22
- ## 🎬 Get Started
23
 
24
- ```python
25
  # Deepnoid/RadZero/inference.py
26
  import warnings
27
 
@@ -76,3 +107,21 @@ if __name__ == "__main__":
76
  print(similarity_map.shape)
77
  ```
78
 
2
  license: cc-by-nc-4.0
3
  ---
4
 
 
 
 
5
  <div align="center">
6
+
7
+ # RadZero: Similarity-Based Cross-Attention for Explainable Vision-Language Alignment in Chest X-ray with Zero-Shot Multi-Task Capability [NeurIPS 2025]
8
+
9
+
10
 
11
  <p align="center">
12
+ 📝 <a href="https://arxiv.org/abs/2504.07416" target="_blank">Paper</a> • 🤗 <a href="https://huggingface.co/Deepnoid/RadZero" target="_blank">Model</a> • 🧩 <a href="https://github.com/deepnoid-ai/RadZero" target="_blank">Code</a>
13
  </p>
14
 
 
15
  </div>
16
 
17
+ ## Introduction
18
+
19
+ <p align="center">
20
+ <img src="misc/introduction.png" alt="Key Differences vs. Existing Methods" width="80%" />
21
+ </p>
22
+ <p align="center">
23
+ <em>Figure 1. Comparison of attention maps and the proposed VL similarity map for visualizing VL
24
+ alignment. (a) While traditional attention maps inevitably exhibit high values at certain points due to softmax activation, the proposed VL similarity maps yield low values for unrelated image-text pairs. (b) Their fixed scale, originating from cosine similarity, enables open-vocabulary semantic segmentation through simple thresholding.</em>
25
+ </p>
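
The contrast in Figure 1(a) can be illustrated with a toy calculation. The sketch below is not taken from the RadZero code base: the 2-D patch features, the text embedding, and the helper names `softmax` and `cosine` are all made up for illustration. It shows that softmax attention weights always sum to 1, so at least one patch receives a weight of at least 1/N even when the text is unrelated to every patch, whereas cosine similarities stay near zero in that case.

```python
import math


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


# Hypothetical toy data: four patch features and a text embedding that is
# nearly orthogonal to all of them (an "unrelated" image-text pair).
patches = [[1.0, 0.1], [0.9, -0.1], [1.0, 0.0], [0.8, 0.2]]
unrelated_text = [0.0, 1.0]

# Softmax attention over raw dot products: forced to sum to 1.
dots = [sum(a * b for a, b in zip(p, unrelated_text)) for p in patches]
attention = softmax(dots)

# Cosine similarity: bounded in [-1, 1] and free to be low everywhere.
similarities = [cosine(p, unrelated_text) for p in patches]

print("attention:   ", [round(a, 3) for a in attention])
print("similarities:", [round(s, 3) for s in similarities])
```

Because the cosine values live on a fixed scale, a single threshold can be reused across images and prompts, which is the property Figure 1(b) relies on.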
26
+
27
+ <p align="center">
28
+ <img src="misc/method.png" alt="RadZero Method Overview" width="80%" />
29
+
30
+ </p>
31
+ <p align="center">
32
+ <em>Figure 2. Overview of the RadZero framework. Finding-sentences are extracted from reports and aligned with local image patch features through similarity-based cross-attention (VL-CABS), enabling zero-shot classification, grounding, and segmentation.</em>
33
+ </p>
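
As a rough illustration of how a similarity map can drive segmentation, the following sketch computes the cosine similarity between one text embedding and every cell of a small patch grid, then thresholds the result into a binary mask. It is a simplified stand-in under stated assumptions, not the VL-CABS implementation: the 2x2 grid, the 2-D features, the threshold of 0.5, and the function names `similarity_map` and `threshold_mask` are hypothetical, and the upsampling to pixel resolution is omitted.

```python
import math


def cosine(u, v):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def similarity_map(patch_grid, text_embedding):
    """Cosine similarity between the text embedding and each patch feature."""
    return [[cosine(p, text_embedding) for p in row] for row in patch_grid]


def threshold_mask(sim_map, threshold):
    """Binary mask: 1 where the similarity reaches the threshold, else 0."""
    return [[1 if s >= threshold else 0 for s in row] for row in sim_map]


# Hypothetical 2x2 grid of patch features; the bottom row matches the text.
patch_grid = [
    [[1.0, 0.0], [0.9, 0.1]],
    [[0.1, 1.0], [0.0, 1.0]],
]
text_embedding = [0.0, 1.0]

sim_map = similarity_map(patch_grid, text_embedding)
mask = threshold_mask(sim_map, threshold=0.5)  # threshold chosen arbitrarily
print(mask)
```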
34
+
35
+
36
+ ## Abstract
37
 
38
+ > Recent advancements in multimodal models have significantly improved vision-language (VL) alignment in radiology. However, existing approaches struggle to effectively utilize complex radiology reports for learning and offer limited interpretability through attention probability visualizations. To address these challenges, we introduce RadZero, a novel framework for VL alignment in chest X-ray with zero-shot multi-task capability. A key component of our approach is VL-CABS (Vision-Language Cross-Attention Based on Similarity), which aligns text embeddings with local image features for interpretable, fine-grained VL reasoning. RadZero leverages large language models to extract concise semantic sentences from radiology reports and employs multi-positive contrastive training to effectively capture relationships between images and multiple relevant textual descriptions. It uses a pre-trained vision encoder with additional trainable Transformer layers, allowing efficient high-resolution image processing. By computing similarity between text embeddings and local image patch features, VL-CABS enables zero-shot inference with similarity probability for classification, and pixel-level VL similarity maps for grounding and segmentation. Experimental results on public chest radiograph benchmarks show that RadZero outperforms state-of-the-art methods in zero-shot classification, grounding, and segmentation. Furthermore, VL similarity map analysis highlights the potential of VL-CABS for improving explainability in VL alignment. Additionally, qualitative evaluation demonstrates RadZero's capability for open-vocabulary semantic segmentation, further validating its effectiveness in medical imaging.
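
The abstract mentions multi-positive contrastive training without giving the loss. One common formulation, shown here only as a sketch and not necessarily the loss RadZero uses, takes the negative log of the softmax probability mass that an image assigns to all of its positive sentences. The function name `multi_positive_nce`, the temperature of 0.1, and the example similarities are assumptions for illustration.

```python
import math


def multi_positive_nce(similarities, positive_mask, temperature=0.1):
    """Negative log of the softmax mass assigned to the positive texts.

    similarities:  image-to-text similarity per candidate sentence.
    positive_mask: True where the sentence describes the image.
    """
    if not any(positive_mask):
        raise ValueError("at least one positive text is required")
    logits = [s / temperature for s in similarities]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    positive_mass = sum(e for e, is_pos in zip(exps, positive_mask) if is_pos)
    return -math.log(positive_mass / sum(exps))


# Hypothetical example: two of four candidate sentences match the image.
mask = [True, True, False, False]
loss_aligned = multi_positive_nce([0.9, 0.8, 0.1, 0.0], mask)
loss_misaligned = multi_positive_nce([0.1, 0.0, 0.9, 0.8], mask)
print(loss_aligned, loss_misaligned)
```

The loss is small when the positive sentences score highest and grows when negatives outrank them, so several finding-sentences from the same report can all act as targets for one image.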
39
 
40
+
41
+ ## RadZero Model Inference
42
+
43
+ ### Install dependencies
44
+
45
+ ```
46
+ pip install -r requirements.txt
47
+ ```
48
+
49
+ ### Model Inference Code
50
+
51
+ RadZero can perform **zero-shot classification / grounding / segmentation for chest X-ray**
52
+ using the **RadZero model** on 🤗 <a href="https://huggingface.co/Deepnoid/RadZero" target="_blank">Hugging Face</a>.
53
+
54
+
55
+ ```
56
  # Deepnoid/RadZero/inference.py
57
  import warnings
58
 
 
107
  print(similarity_map.shape)
108
  ```
109
 
110
+
111
+
112
+
113
+ ## References
114
+
115
+ - **Pretrained models**
116
+ - **Vision encoder**: [XrayDINOv2](https://huggingface.co/StanfordAIMI/dinov2-base-xray-224)
117
+ - **Text encoder**: [all-mpnet-base-v2](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)
118
+
119
+
120
+ ## Acknowledgments
121
+ This work was supported by the Technology Innovation Program (RS-2025-02221011, Development
122
+ of Medical-Specialized Multimodal Hyperscale Generative AI Technology for Global Integration)
123
+ funded by the Ministry of Trade, Industry & Energy (MOTIE, South Korea).
124
+
125
+ ## LICENSE
126
+ [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-CC%20BY--NC%204.0-lightgrey.svg)](https://creativecommons.org/licenses/by-nc/4.0/)
127
+
requirements.txt ADDED
@@ -0,0 +1,265 @@
1
+ accelerate==0.27.2
2
+ aiofiles==23.2.1
3
+ aiohttp==3.9.3
4
+ aiosignal==1.3.1
5
+ albucore==0.0.23
6
+ albumentations==2.0.3
7
+ altair==5.2.0
8
+ annotated-types==0.6.0
9
+ antlr4-python3-runtime==4.9.3
10
+ anyio==4.2.0
11
+ appdirs==1.4.4
12
+ argon2-cffi==21.3.0
13
+ argon2-cffi-bindings==21.2.0
14
+ asttokens==2.0.5
15
+ async-lru==2.0.4
16
+ attrs==23.1.0
17
+ Babel==2.11.0
18
+ backoff==2.2.1
19
+ beautifulsoup4==4.12.2
20
+ bert-score==0.3.13
21
+ bitsandbytes==0.43.0
22
+ black==24.2.0
23
+ bleach==4.1.0
24
+ blessed==1.20.0
25
+ Brotli==1.0.9
26
+ certifi==2024.2.2
27
+ cffi==1.16.0
28
+ cfgv==3.4.0
29
+ chardet==5.2.0
30
+ charset-normalizer==2.0.4
31
+ click==8.1.7
32
+ colorama==0.4.6
33
+ comm==0.2.1
34
+ contourpy==1.2.0
35
+ cycler==0.12.1
36
+ datasets==2.18.0
37
+ debugpy==1.6.7
38
+ decorator==5.1.1
39
+ deepspeed==0.14.0
40
+ defusedxml==0.7.1
41
+ dill==0.3.8
42
+ distlib==0.3.8
43
+ distro==1.9.0
44
+ docker-pycreds==0.4.0
45
+ docstring-parser==0.15
46
+ efficientnet-pytorch==0.7.1
47
+ einops==0.7.0
48
+ executing==0.8.3
49
+ f1chexbert==0.0.2
50
+ fairscale==0.4.13
51
+ faiss-gpu-cu12==1.8.0.2
52
+ fastapi==0.110.0
53
+ fastjsonschema==2.16.2
54
+ ffmpy==0.3.2
55
+ filelock==3.13.1
56
+ flake8==7.0.0
57
+ fonttools==4.49.0
58
+ frozenlist==1.4.1
59
+ fsspec==2024.2.0
60
+ ftfy==6.2.0
61
+ gitdb==4.0.11
62
+ GitPython==3.1.42
63
+ gpustat==1.1.1
64
+ gradio==4.21.0
65
+ gradio_client==0.12.0
66
+ h11==0.14.0
67
+ hjson==3.1.0
68
+ httpcore==1.0.4
69
+ httpx==0.27.0
70
+ huggingface-hub==0.28.1
71
+ identify==2.5.35
72
+ idna==3.4
73
+ imageio==2.34.1
74
+ importlib_resources==6.1.3
75
+ ipdb==0.13.13
76
+ ipykernel==6.28.0
77
+ ipython==8.20.0
78
+ ipywidgets==8.1.2
79
+ isort==5.13.2
80
+ jedi==0.18.1
81
+ Jinja2==3.1.3
82
+ joblib==1.3.2
83
+ json5==0.9.6
84
+ jsonschema==4.19.2
85
+ jsonschema-specifications==2023.7.1
86
+ jupyter==1.0.0
87
+ jupyter_client==8.6.0
88
+ jupyter-console==6.6.3
89
+ jupyter_core==5.5.0
90
+ jupyter-events==0.8.0
91
+ jupyter-lsp==2.2.0
92
+ jupyter_server==2.10.0
93
+ jupyter_server_terminals==0.4.4
94
+ jupyterlab==4.0.11
95
+ jupyterlab-pygments==0.1.2
96
+ jupyterlab_server==2.25.1
97
+ jupyterlab-widgets==3.0.10
98
+ kaggle==1.7.4.5
99
+ kiwisolver==1.4.5
100
+ lazy_loader==0.4
101
+ lightning-utilities==0.10.1
102
+ markdown-it-py==3.0.0
103
+ MarkupSafe==2.1.3
104
+ matplotlib==3.8.3
105
+ matplotlib-inline==0.1.6
106
+ mccabe==0.7.0
107
+ mdurl==0.1.2
108
+ mistune==2.0.4
109
+ mpmath==1.3.0
110
+ msgpack==1.1.0
111
+ multidict==6.0.5
112
+ multiprocess==0.70.16
113
+ munch==4.0.0
114
+ mypy-extensions==1.0.0
115
+ natsort==8.4.0
116
+ nbclient==0.8.0
117
+ nbconvert==7.10.0
118
+ nbformat==5.9.2
119
+ nest-asyncio==1.6.0
120
+ networkx==3.2.1
121
+ nibabel==5.3.2
122
+ ninja==1.11.1.1
123
+ nltk==3.8.1
124
+ nodeenv==1.8.0
125
+ notebook==7.0.8
126
+ notebook_shim==0.2.3
127
+ numpy==1.26.4
128
+ nvidia-cublas-cu12==12.1.3.1
129
+ nvidia-cuda-cupti-cu12==12.1.105
130
+ nvidia-cuda-nvrtc-cu12==12.1.105
131
+ nvidia-cuda-runtime-cu12==12.1.105
132
+ nvidia-cudnn-cu12==8.9.2.26
133
+ nvidia-cufft-cu12==11.0.2.54
134
+ nvidia-curand-cu12==10.3.2.106
135
+ nvidia-cusolver-cu12==11.4.5.107
136
+ nvidia-cusparse-cu12==12.1.0.106
137
+ nvidia-ml-py==12.535.133
138
+ nvidia-nccl-cu12==2.19.3
139
+ nvidia-nvjitlink-cu12==12.4.99
140
+ nvidia-nvtx-cu12==12.1.105
141
+ omegaconf==2.3.0
142
+ open-clip-torch==2.24.0
143
+ openai==1.35.7
144
+ opencv-python==4.9.0.80
145
+ opencv-python-headless==4.11.0.86
146
+ orjson==3.9.15
147
+ overrides==7.4.0
148
+ packaging==23.1
149
+ pandas==2.2.1
150
+ pandocfilters==1.5.0
151
+ parso==0.8.3
152
+ pathspec==0.12.1
153
+ patsy==0.5.6
154
+ peft==0.9.0
155
+ pexpect==4.8.0
156
+ pillow==10.2.0
157
+ pip==23.3.1
158
+ pip-chill==1.0.3
159
+ platformdirs==3.10.0
160
+ ply==3.11
161
+ pre-commit==3.6.2
162
+ pretrainedmodels==0.7.4
163
+ prometheus-client==0.14.1
164
+ prompt-toolkit==3.0.43
165
+ protobuf==4.25.3
166
+ psutil==5.9.0
167
+ ptyprocess==0.7.0
168
+ pure-eval==0.2.2
169
+ py-cpuinfo==9.0.0
170
+ pyarrow==15.0.1
171
+ pyarrow-hotfix==0.6
172
+ pycocoevalcap==1.2
173
+ pycocotools==2.0.8
174
+ pycodestyle==2.11.1
175
+ pycparser==2.21
176
+ pydantic==2.10.6
177
+ pydantic_core==2.27.2
178
+ pydicom==2.4.4
179
+ pydub==0.25.1
180
+ pyflakes==3.2.0
181
+ Pygments==2.15.1
182
+ pynvml==11.5.0
183
+ pyparsing==3.1.2
184
+ PyQt5==5.15.10
185
+ PyQt5-sip==12.13.0
186
+ PySocks==1.7.1
187
+ python-dateutil==2.8.2
188
+ python-dotenv==1.0.1
189
+ python-json-logger==2.0.7
190
+ python-multipart==0.0.9
191
+ python-slugify==8.0.4
192
+ pytorch-lightning==2.5.0.post0
193
+ pytz==2023.3.post1
194
+ PyYAML==6.0.1
195
+ pyzmq==25.1.2
196
+ qtconsole==5.5.1
197
+ QtPy==2.4.1
198
+ referencing==0.30.2
199
+ regex==2023.12.25
200
+ requests==2.31.0
201
+ rfc3339-validator==0.1.4
202
+ rfc3986-validator==0.1.1
203
+ rich==13.7.1
204
+ rpds-py==0.10.6
205
+ ruff==0.3.2
206
+ safetensors==0.4.2
207
+ scikit-image==0.23.2
208
+ scikit-learn==1.4.1.post1
209
+ scipy==1.12.0
210
+ seaborn==0.13.2
211
+ segmentation_models_pytorch==0.4.0
212
+ semantic-version==2.10.0
213
+ Send2Trash==1.8.2
214
+ sentencepiece==0.2.0
215
+ sentry-sdk==1.41.0
216
+ setproctitle==1.3.3
217
+ setuptools==68.2.2
218
+ shellingham==1.5.4
219
+ shtab==1.7.1
220
+ simsimd==6.2.1
221
+ sip==6.7.12
222
+ six==1.16.0
223
+ smmap==5.0.1
224
+ sniffio==1.3.0
225
+ soupsieve==2.5
226
+ stack-data==0.2.0
227
+ starlette==0.36.3
228
+ statsmodels==0.14.1
229
+ stringzilla==3.11.3
230
+ sympy==1.12
231
+ terminado==0.17.1
232
+ text-unidecode==1.3
233
+ threadpoolctl==3.3.0
234
+ tifffile==2024.4.18
235
+ timm==0.9.16
236
+ tinycss2==1.2.1
237
+ tokenizers==0.15.2
238
+ tomlkit==0.12.0
239
+ toolz==0.12.1
240
+ torch==2.2.1
241
+ torchaudio==2.2.1
242
+ torchmetrics==1.6.1
243
+ torchvision==0.17.1
244
+ tornado==6.3.3
245
+ tqdm==4.66.2
246
+ traitlets==5.7.1
247
+ transformers==4.39.3
248
+ triton==2.2.0
249
+ trl==0.7.11
250
+ typer==0.9.0
251
+ typing_extensions==4.12.2
252
+ tyro==0.7.3
253
+ tzdata==2024.1
254
+ urllib3==2.1.0
255
+ uvicorn==0.28.0
256
+ virtualenv==20.25.1
257
+ wandb==0.16.4
258
+ wcwidth==0.2.13
259
+ webencodings==0.5.1
260
+ websocket-client==0.58.0
261
+ websockets==11.0.3
262
+ wheel==0.41.2
263
+ widgetsnbextension==4.0.10
264
+ xxhash==3.4.1
265
+ yarl==1.9.4