Jayden Park committed on
Commit
789ad27
·
1 Parent(s): 576c8e4

Update README and add requirements.txt

Browse files
Files changed (5) hide show
  1. README.md +205 -101
  2. app.py +2 -2
  3. interface.py +3 -3
  4. requirements.txt +40 -0
  5. task_examples.py +3 -3
README.md CHANGED
@@ -1,122 +1,226 @@
1
  ---
2
  license: cc-by-nc-4.0
 
 
 
 
 
 
 
 
 
 
 
3
  ---
4
 
5
  <!-- markdownlint-disable first-line-h1 -->
6
  <!-- markdownlint-disable html -->
7
 
8
- <div align="center">
9
- <h1>
10
- M4CXR
11
- </h1>
12
- </div>
13
 
14
  <p align="center">
15
- 📝 <a href="https://www.arxiv.org/abs/2408.16213" target="_blank">Paper</a> • 🤗 <a href="https://huggingface.co/Deepnoid/M4CXR" target="_blank">Hugging Face</a> • 🧩 <a href="" target="_blank">Github</a>
 
 
 
16
  </p>
17
 
18
- <div align="center">
19
- </div>
20
 
 
21
 
22
- ## 🎬 Get Started
 
 
23
 
24
- ```python
25
- import io
 
 
 
 
 
 
 
 
 
26
 
27
- import requests
 
 
 
 
 
28
  import torch
29
- from PIL import Image
30
  from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- def load_image_from_url(url):
34
- try:
35
- response = requests.get(url)
36
- response.raise_for_status()
37
-
38
- image = Image.open(io.BytesIO(response.content))
39
- return image
40
-
41
- except requests.exceptions.RequestException as e:
42
- print(f"Error loading image: {e}")
43
- return None
44
-
45
-
46
- def do_generate(prompts, images, model, processor, generation_config):
47
- """The interface for generation
48
-
49
- Args:
50
- prompts (List[str]): List of prompt texts for entire batch
51
- images (List[str or PIL.Image]): Paths or PIL.Image of images for entire batch
52
- model (MllmForConditionalGeneration): MllmForConditionalGeneration
53
- processor (MllmProcessor): MllmProcessor
54
- generation_config (GenerationConfig): generation configurations
55
-
56
- Returns:
57
- outputs (List[str]): Generated responses for entire batch
58
- """
59
-
60
- # image, text processing
61
- inputs = processor(texts=prompts, images=images)
62
-
63
- # prepare inputs
64
- inputs = {
65
- k: v.to(model.dtype) if v.dtype == torch.float else v for k, v in inputs.items()
66
- }
67
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
68
-
69
- # batch decoding
70
- with torch.inference_mode():
71
- res = model.generate(**inputs, generation_config=generation_config)
72
-
73
- # decode tokens
74
- outputs = processor.batch_decode(res, skip_special_tokens=True)
75
- return outputs
76
-
77
-
78
- if __name__ == "__main__":
79
- # Setup constant
80
- device = torch.device("cuda")
81
- dtype = torch.bfloat16
82
- do_sample = False
83
-
84
- # Load Processor and Model
85
- processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR", trust_remote_code=True)
86
- generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR")
87
- model = AutoModelForCausalLM.from_pretrained(
88
- "Deepnoid/M4CXR",
89
- trust_remote_code=True,
90
- torch_dtype=dtype,
91
- device_map=device,
92
- )
93
-
94
- # Prepare images
95
- images = [
96
- load_image_from_url(
97
- "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
98
- ),
99
- load_image_from_url(
100
- "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
101
- ),
102
- ]
103
-
104
- # separate question list
105
- questions = [
106
- "radiology image: <image> What is the view of this chest X-ray?",
107
- "radiology image: <image> Provide a description of the findings in the radiology image.",
108
- ]
109
-
110
- # build prompts with chat template
111
- prompts = []
112
- for question in questions:
113
- chats = [{"role": "user", "content": question}]
114
- prompt = processor.apply_chat_template(chats, tokenize=False)
115
- prompts.append(prompt)
116
-
117
- # Generate responses
118
- generation_config.do_sample = do_sample
119
- outputs = do_generate(prompts, images, model, processor, generation_config)
120
- print(outputs)
121
  ```
122
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: cc-by-nc-4.0
3
+ pipeline_tag: image-text-to-text
4
+ tags:
5
+ - medical
6
+ - chest-x-ray
7
+ - radiology
8
+ - multi-modal
9
+ - multi-task
10
+ - vision-language
11
+ - report-generation
12
+ - visual-grounding
13
+ - vqa
14
  ---
15
 
16
  <!-- markdownlint-disable first-line-h1 -->
17
  <!-- markdownlint-disable html -->
18
 
19
+ # M4CXR: Exploring Multi-task Potentials of Multi-modal Large Language Models for Chest X-ray Interpretation [IEEE TNNLS]
 
 
 
 
20
 
21
  <p align="center">
22
+ 📝 <a href="https://arxiv.org/abs/2408.16213" target="_blank">arXiv</a> •
23
+ 📖 <a href="https://ieeexplore.ieee.org/abstract/document/11106750" target="_blank">IEEE TNNLS</a> •
24
+ 🤗 <a href="https://huggingface.co/Deepnoid/M4CXR-TNNLS" target="_blank">Model</a> •
25
+ 🧩 <a href="https://github.com/deepnoid-ai/M4CXR-TNNLS" target="_blank">Code</a>
26
  </p>
27
 
28
+ ## Introduction
 
29
 
30
+ **M4CXR** is a multi-modal large language model (MLLM) designed for **chest X-ray (CXR) interpretation**, capable of handling **multiple tasks** in a unified conversational framework. It is trained on a visual instruction-following dataset assembled from diverse CXR tasks, and supports:
31
 
32
+ - 📝 **Medical Report Generation (MRG)** — single-image, multi-image, and multi-study (with prior reports) scenarios, powered by a **chain-of-thought (CoT)** prompting strategy for state-of-the-art clinical accuracy.
33
+ - 🎯 **Visual Grounding** — localizing anatomical regions or findings described in free-text phrases.
34
+ - 💬 **Visual Question Answering (VQA)** — answering open-ended questions about CXR images, including difference VQA across studies.
35
 
36
+ ## Abstract
37
+
38
+ > The rapid evolution of artificial intelligence, especially in large language models (LLMs), has significantly impacted various domains, including healthcare. In chest X-ray (CXR) analysis, previous studies have employed LLMs, but with limitations: either underutilizing the LLMs' capability for multitask learning or lacking clinical accuracy. This article presents M4CXR, a multimodal LLM designed to enhance CXR interpretation. The model is trained on a visual instruction-following dataset that integrates various task-specific datasets in a conversational format. As a result, the model supports multiple tasks such as medical report generation (MRG), visual grounding, and visual question answering (VQA). M4CXR achieves state-of-the-art clinical accuracy in MRG by employing a chain-of-thought (CoT) prompting strategy, in which it identifies findings in CXR images and subsequently generates corresponding reports. The model is adaptable to various MRG scenarios depending on the available inputs, such as single-image, multiimage, and multistudy contexts. In addition to MRG, M4CXR performs visual grounding at a level comparable to specialized models and demonstrates outstanding performance in VQA. Both quantitative and qualitative assessments reveal M4CXR's versatility in MRG, visual grounding, and VQA, while consistently maintaining clinical accuracy.
39
+
40
+ ## Get Started
41
+
42
+ ### Install dependencies
43
+
44
+ ```bash
45
+ pip install -r requirements.txt
46
+ ```
47
 
48
+ ### Basic Inference
49
+
50
+ A minimal example — load the model, feed a chest X-ray with a text question, and get a response.
51
+ The full runnable script is available as [interface.py](./interface.py).
52
+
53
+ ```python
54
  import torch
 
55
  from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
56
 
57
+ from interface import do_generate, load_image_from_url
58
+
59
+
60
+ # Setup
61
+ device = torch.device("cuda")
62
+ dtype = torch.bfloat16
63
+
64
+ # Load processor, model, and generation config
65
+ processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR-TNNLS", trust_remote_code=True)
66
+ generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR-TNNLS")
67
+ model = AutoModelForCausalLM.from_pretrained(
68
+ "Deepnoid/M4CXR-TNNLS",
69
+ trust_remote_code=True,
70
+ torch_dtype=dtype,
71
+ device_map=device,
72
+ )
73
+
74
+ # Prepare a batch of images and questions
75
+ images = [
76
+ load_image_from_url(
77
+ "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
78
+ ),
79
+ load_image_from_url(
80
+ "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
81
+ ),
82
+ ]
83
+ questions = [
84
+ "radiology image: <image> What is the view of this chest X-ray?",
85
+ "radiology image: <image> Provide a description of the findings in the radiology image.",
86
+ ]
87
+
88
+ # Build prompts with the chat template
89
+ prompts = [
90
+ processor.apply_chat_template([{"role": "user", "content": q}], tokenize=False)
91
+ for q in questions
92
+ ]
93
+
94
+ # Generate
95
+ generation_config.do_sample = False
96
+ outputs = do_generate(prompts, images, model, processor, generation_config)
97
+ print(outputs)
98
+ ```
99
+
100
+ ## Task-specific Usage
101
+
102
+ M4CXR supports diverse CXR interpretation tasks through single- or multi-turn conversations. Full runnable examples are provided in [task_examples.py](./task_examples.py).
103
 
104
+ The examples below use the helpers from [interface.py](./interface.py) and the multi-turn driver defined in [task_examples.py](./task_examples.py):
105
+
106
+ ```python
107
+ findings = (
108
+ "enlarged cardiomediastinum, cardiomegaly, lung opacity, lung lesion, edema, "
109
+ "consolidation, pneumonia, atelectasis, pneumothorax, pleural Effusion, "
110
+ "pleural other, fracture, support devices"
111
+ )
112
+ ```
113
+
114
+ ### 1. Single-image Medical Report Generation (CoT)
115
+
116
+ The model first predicts findings from a list of candidates, then writes the report conditioned on its own predictions.
117
+
118
+ ```python
119
+ images = [image]
120
+ questions = [
121
+ f"radiology image: <image> Which of the following findings are present in the radiology image? Findings: {findings}",
122
+ "Based on the previous conversation, provide a description of the findings in the radiology image.",
123
+ ]
124
+ chats = do_generate_multi_turn(questions, images, model, processor, generation_config)
125
+ ```
126
+
127
+ ### 2. Multi-image Medical Report Generation (CoT)
128
+
129
+ Multiple views of the same study can be provided in a single prompt.
130
+
131
+ ```python
132
+ images = [image_pa, image_lat] # e.g., PA + lateral
133
+ image_tokens = " ".join("<image>" for _ in images)
134
+ questions = [
135
+ f"radiology images: {image_tokens} Which of the following findings are present in the radiology images? Findings: {findings}",
136
+ "Based on the previous conversation, provide a description of the findings in the radiology images.",
137
+ ]
138
+ chats = do_generate_multi_turn(questions, images, model, processor, generation_config)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  ```
140
 
141
+ ### 3. Multi-study Medical Report Generation (CoT)
142
+
143
+ Condition on prior images and the prior report to generate a follow-up report that references temporal changes.
144
+
145
+ ```python
146
+ prior_images = [prior_pa, prior_lat]
147
+ prior_report = "The lungs are clear. There is no pneumothorax."
148
+ follow_up_images = [current_pa, current_lat]
149
+ images = prior_images + follow_up_images
150
+
151
+ prior_tokens = " ".join("<image>" for _ in prior_images)
152
+ current_tokens = " ".join("<image>" for _ in follow_up_images)
153
+
154
+ questions = [
155
+ (
156
+ f"prior radiology images: {prior_tokens}, prior radiology report: {prior_report} "
157
+ f"follow-up images: {current_tokens}, The radiology studies are given in chronological order. "
158
+ f"Which of the following findings are present in the current follow-up radiology images? "
159
+ f"Findings: {findings}"
160
+ ),
161
+ "Based on the previous conversation, provide a description of the findings in the current follow-up radiology images.",
162
+ ]
163
+ chats = do_generate_multi_turn(questions, images, model, processor, generation_config)
164
+ ```
165
+
166
+ ### 4. Visual Grounding
167
+
168
+ Given a phrase, the model returns the bounding box of the region it describes.
169
+
170
+ ```python
171
+ images = [image]
172
+ phrase = "right lower lobe"
173
+ questions = [
174
+ f"radiology image: <image> Provide the bounding box coordinate of the region this phrase describes: {phrase}",
175
+ ]
176
+ chats = do_generate_multi_turn(questions, images, model, processor, generation_config)
177
+ ```
178
+
179
+ ### 5. Report Summarization
180
+
181
+ Chain MRG with a follow-up summarization turn to obtain a concise one-sentence summary.
182
+
183
+ ```python
184
+ images = [image]
185
+ questions = [
186
+ f"radiology image: <image> Which of the following findings are present in the radiology image? Findings: {findings}",
187
+ "Based on the previous conversation, provide a description of the findings in the radiology image.",
188
+ "Summarize the description in one concise sentence.",
189
+ ]
190
+ chats = do_generate_multi_turn(questions, images, model, processor, generation_config)
191
+ ```
192
+
193
+ ## Citation
194
+
195
+ If you use M4CXR in your research, please cite:
196
+
197
+ ```bibtex
198
+ @article{park2025m4cxr,
199
+ author={Park, Jonggwon and Kim, Soobum and Yoon, Byungmu and Hyun, Jihun and Choi, Kyoyun},
200
+ journal={IEEE Transactions on Neural Networks and Learning Systems},
201
+ title={M4CXR: Exploring Multitask Potentials of Multimodal Large Language Models for Chest X-Ray Interpretation},
202
+ year={2025},
203
+ volume={36},
204
+ number={10},
205
+ pages={17841-17855},
206
+ doi={10.1109/TNNLS.2025.3587687}
207
+ }
208
+ ```
209
+
210
+ ## References
211
+
212
+ - **Pretrained models**
213
+ - **Vision encoder**: [RAD-DINO](https://huggingface.co/microsoft/rad-dino)
214
+ - **Language model**: [Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
215
+ - **Visual projector**
216
+ - **C-Abstractor** from [Honeybee (CVPR 2024)](https://github.com/khanrc/honeybee)
217
+
218
+ ## Acknowledgments
219
+
220
+ This work was supported by the Technology Innovation Program (RS-2025-02221011, Development of Medical-Specialized Multimodal Hyperscale Generative AI Technology for Global Integration) funded by the Ministry of Trade, Industry & Energy (MOTIE, South Korea).
221
+
222
+ ## License
223
+
224
+ [![License: CC BY-NC 4.0](https://img.shields.io/badge/License-CC%20BY--NC%204.0-lightgrey.svg)](https://creativecommons.org/licenses/by-nc/4.0/)
225
+
226
+ Released under **CC BY-NC 4.0**. The model and its outputs are provided **for research purposes only** and are **not intended for clinical use or medical decision-making**.
app.py CHANGED
@@ -38,9 +38,9 @@ title_markdown = """
38
 
39
  def load_model(device, dtype):
40
  # Load Processor and Model
41
- processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR", trust_remote_code=True)
42
  model = AutoModelForCausalLM.from_pretrained(
43
- "Deepnoid/M4CXR",
44
  trust_remote_code=True,
45
  torch_dtype=dtype,
46
  device_map=device,
 
38
 
39
  def load_model(device, dtype):
40
  # Load Processor and Model
41
+ processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR-TNNLS", trust_remote_code=True)
42
  model = AutoModelForCausalLM.from_pretrained(
43
+ "Deepnoid/M4CXR-TNNLS",
44
  trust_remote_code=True,
45
  torch_dtype=dtype,
46
  device_map=device,
interface.py CHANGED
@@ -58,10 +58,10 @@ if __name__ == "__main__":
58
  do_sample = False
59
 
60
  # Load Processor and Model
61
- processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR", trust_remote_code=True)
62
- generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR")
63
  model = AutoModelForCausalLM.from_pretrained(
64
- "Deepnoid/M4CXR",
65
  trust_remote_code=True,
66
  torch_dtype=dtype,
67
  device_map=device,
 
58
  do_sample = False
59
 
60
  # Load Processor and Model
61
+ processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR-TNNLS", trust_remote_code=True)
62
+ generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR-TNNLS")
63
  model = AutoModelForCausalLM.from_pretrained(
64
+ "Deepnoid/M4CXR-TNNLS",
65
  trust_remote_code=True,
66
  torch_dtype=dtype,
67
  device_map=device,
requirements.txt ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=2.1.0
2
+ torchvision>=0.16.2
3
+ torchaudio
4
+ isort
5
+ black
6
+ flake8
7
+ einops>=0.7.0
8
+ gradio
9
+ tokenizers
10
+ wandb
11
+ deepspeed
12
+ peft
13
+ bitsandbytes
14
+ scikit-learn
15
+ requests
16
+ transformers==4.39.3
17
+ pillow>=10.0.1
18
+ seaborn>=0.13.0
19
+ timm>=0.9.12
20
+ accelerate>=0.25.0
21
+ trl
22
+ omegaconf
23
+ datasets
24
+ pre-commit
25
+ opencv-python
26
+ chardet
27
+ torchmetrics
28
+ statsmodels
29
+ nltk
30
+ ninja
31
+ natsort
32
+ pydicom
33
+ scikit-image
34
+ openai
35
+ open_clip_torch
36
+ fast-bleu
37
+ bert-score
38
+ python-dotenv
39
+ f1chexbert==0.0.2
40
+ pycocoevalcap
task_examples.py CHANGED
@@ -55,10 +55,10 @@ if __name__ == "__main__":
55
  do_sample = False
56
 
57
  # Load Processor and Model
58
- processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR", trust_remote_code=True)
59
- generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR")
60
  model = AutoModelForCausalLM.from_pretrained(
61
- "Deepnoid/M4CXR",
62
  trust_remote_code=True,
63
  torch_dtype=dtype,
64
  device_map=device,
 
55
  do_sample = False
56
 
57
  # Load Processor and Model
58
+ processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR-TNNLS", trust_remote_code=True)
59
+ generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR-TNNLS")
60
  model = AutoModelForCausalLM.from_pretrained(
61
+ "Deepnoid/M4CXR-TNNLS",
62
  trust_remote_code=True,
63
  torch_dtype=dtype,
64
  device_map=device,