Jayden Park commited on
Commit
576c8e4
·
verified ·
1 Parent(s): 6c7c280

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +122 -122
README.md CHANGED
@@ -1,122 +1,122 @@
1
- ---
2
- license: cc-by-nc-4.0
3
- ---
4
-
5
- <!-- markdownlint-disable first-line-h1 -->
6
- <!-- markdownlint-disable html -->
7
-
8
- <div align="center">
9
- <h1>
10
- M4CXR
11
- </h1>
12
- </div>
13
-
14
- <p align="center">
15
- 📝 <a href="" target="_blank">Paper</a> • 🤗 <a href="https://huggingface.co/Deepnoid/M4CXR" target="_blank">Hugging Face</a> • 🧩 <a href="" target="_blank">Github</a>
16
- </p>
17
-
18
- <div align="center">
19
- </div>
20
-
21
-
22
- ## 🎬 Get Started
23
-
24
- ```python
25
- import io
26
-
27
- import requests
28
- import torch
29
- from PIL import Image
30
- from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
31
-
32
-
33
- def load_image_from_url(url):
34
- try:
35
- response = requests.get(url)
36
- response.raise_for_status()
37
-
38
- image = Image.open(io.BytesIO(response.content))
39
- return image
40
-
41
- except requests.exceptions.RequestException as e:
42
- print(f"Error loading image: {e}")
43
- return None
44
-
45
-
46
- def do_generate(prompts, images, model, processor, generation_config):
47
- """The interface for generation
48
-
49
- Args:
50
- prompts (List[str]): List of prompt texts for entire batch
51
- images (List[str or PIL.Image]): Paths or PIL.Image of images for entire batch
52
- model (MllmForConditionalGeneration): MllmForConditionalGeneration
53
- processor (MllmProcessor): MllmProcessor
54
- generation_config (GenerationConfig): generation configurations
55
-
56
- Returns:
57
- outputs (List[str]): Generated responses for entire batch
58
- """
59
-
60
- # image, text processing
61
- inputs = processor(texts=prompts, images=images)
62
-
63
- # prepare inputs
64
- inputs = {
65
- k: v.to(model.dtype) if v.dtype == torch.float else v for k, v in inputs.items()
66
- }
67
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
68
-
69
- # batch decoding
70
- with torch.inference_mode():
71
- res = model.generate(**inputs, generation_config=generation_config)
72
-
73
- # decode tokens
74
- outputs = processor.batch_decode(res, skip_special_tokens=True)
75
- return outputs
76
-
77
-
78
- if __name__ == "__main__":
79
- # Setup constant
80
- device = torch.device("cuda")
81
- dtype = torch.bfloat16
82
- do_sample = False
83
-
84
- # Load Processor and Model
85
- processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR", trust_remote_code=True)
86
- generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR")
87
- model = AutoModelForCausalLM.from_pretrained(
88
- "Deepnoid/M4CXR",
89
- trust_remote_code=True,
90
- torch_dtype=dtype,
91
- device_map=device,
92
- )
93
-
94
- # Prepare images
95
- images = [
96
- load_image_from_url(
97
- "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
98
- ),
99
- load_image_from_url(
100
- "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
101
- ),
102
- ]
103
-
104
- # seperate question list
105
- questions = [
106
- "radiology image: <image> What is the view of this chest X-ray?",
107
- "radiology image: <image> Provide a description of the findings in the radiology image.",
108
- ]
109
-
110
- # build prompts with chat template
111
- prompts = []
112
- for question in questions:
113
- chats = [{"role": "user", "content": question}]
114
- prompt = processor.apply_chat_template(chats, tokenize=False)
115
- prompts.append(prompt)
116
-
117
- # Generate responses
118
- generation_config.do_sample = do_sample
119
- outputs = do_generate(prompts, images, model, processor, generation_config)
120
- print(outputs)
121
- ```
122
-
 
1
+ ---
2
+ license: cc-by-nc-4.0
3
+ ---
4
+
5
+ <!-- markdownlint-disable first-line-h1 -->
6
+ <!-- markdownlint-disable html -->
7
+
8
+ <div align="center">
9
+ <h1>
10
+ M4CXR
11
+ </h1>
12
+ </div>
13
+
14
+ <p align="center">
15
+ 📝 <a href="https://www.arxiv.org/abs/2408.16213" target="_blank">Paper</a> • 🤗 <a href="https://huggingface.co/Deepnoid/M4CXR" target="_blank">Hugging Face</a> • 🧩 <a href="" target="_blank">Github</a>
16
+ </p>
17
+
18
+ <div align="center">
19
+ </div>
20
+
21
+
22
+ ## 🎬 Get Started
23
+
24
+ ```python
25
+ import io
26
+
27
+ import requests
28
+ import torch
29
+ from PIL import Image
30
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
31
+
32
+
33
+ def load_image_from_url(url):
34
+ try:
35
+ response = requests.get(url)
36
+ response.raise_for_status()
37
+
38
+ image = Image.open(io.BytesIO(response.content))
39
+ return image
40
+
41
+ except requests.exceptions.RequestException as e:
42
+ print(f"Error loading image: {e}")
43
+ return None
44
+
45
+
46
+ def do_generate(prompts, images, model, processor, generation_config):
47
+ """The interface for generation
48
+
49
+ Args:
50
+ prompts (List[str]): List of prompt texts for entire batch
51
+ images (List[str or PIL.Image]): Paths or PIL.Image of images for entire batch
52
+ model (MllmForConditionalGeneration): MllmForConditionalGeneration
53
+ processor (MllmProcessor): MllmProcessor
54
+ generation_config (GenerationConfig): generation configurations
55
+
56
+ Returns:
57
+ outputs (List[str]): Generated responses for entire batch
58
+ """
59
+
60
+ # image, text processing
61
+ inputs = processor(texts=prompts, images=images)
62
+
63
+ # prepare inputs
64
+ inputs = {
65
+ k: v.to(model.dtype) if v.dtype == torch.float else v for k, v in inputs.items()
66
+ }
67
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
68
+
69
+ # batch decoding
70
+ with torch.inference_mode():
71
+ res = model.generate(**inputs, generation_config=generation_config)
72
+
73
+ # decode tokens
74
+ outputs = processor.batch_decode(res, skip_special_tokens=True)
75
+ return outputs
76
+
77
+
78
+ if __name__ == "__main__":
79
+ # Setup constant
80
+ device = torch.device("cuda")
81
+ dtype = torch.bfloat16
82
+ do_sample = False
83
+
84
+ # Load Processor and Model
85
+ processor = AutoProcessor.from_pretrained("Deepnoid/M4CXR", trust_remote_code=True)
86
+ generation_config = GenerationConfig.from_pretrained("Deepnoid/M4CXR")
87
+ model = AutoModelForCausalLM.from_pretrained(
88
+ "Deepnoid/M4CXR",
89
+ trust_remote_code=True,
90
+ torch_dtype=dtype,
91
+ device_map=device,
92
+ )
93
+
94
+ # Prepare images
95
+ images = [
96
+ load_image_from_url(
97
+ "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
98
+ ),
99
+ load_image_from_url(
100
+ "https://upload.wikimedia.org/wikipedia/commons/a/a1/Normal_posteroanterior_%28PA%29_chest_radiograph_%28X-ray%29.jpg"
101
+ ),
102
+ ]
103
+
104
+ # seperate question list
105
+ questions = [
106
+ "radiology image: <image> What is the view of this chest X-ray?",
107
+ "radiology image: <image> Provide a description of the findings in the radiology image.",
108
+ ]
109
+
110
+ # build prompts with chat template
111
+ prompts = []
112
+ for question in questions:
113
+ chats = [{"role": "user", "content": question}]
114
+ prompt = processor.apply_chat_template(chats, tokenize=False)
115
+ prompts.append(prompt)
116
+
117
+ # Generate responses
118
+ generation_config.do_sample = do_sample
119
+ outputs = do_generate(prompts, images, model, processor, generation_config)
120
+ print(outputs)
121
+ ```
122
+