jgpark committed on
Commit
22ecd08
·
1 Parent(s): e445111

implement gradio app

Browse files
Files changed (1) hide show
  1. app.py +321 -0
app.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import gradio as gr
3
+ import torch
4
+ from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
5
+ from threading import Thread
6
+
7
# Comma-separated list of candidate findings that is substituted into the
# "{findings}" placeholder of the report-generation prompts.
findings = ", ".join(
    (
        "enlarged cardiomediastinum",
        "cardiomegaly",
        "lung opacity",
        "lung lesion",
        "edema",
        "consolidation",
        "pneumonia",
        "atelectasis",
        "pneumothorax",
        "pleural Effusion",
        "pleural other",
        "fracture",
        "support devices",
    )
)
8
+
9
# Prompt templates keyed by task.
# The three report-generation entries are (first question, follow-up question)
# pairs; the first question contains str.format placeholders. The remaining
# entries are single prompts that are offered as chat examples.
templates = {
    "single-image": (
        (
            "radiology image: <image> "
            "Which of the following findings are present in the radiology image? "
            "Findings: {findings}"
        ),
        (
            "Based on the previous conversation, "
            "provide a description of the findings in the radiology image."
        ),
    ),
    "multi-image": (
        (
            "radiology images: {images} "
            "Which of the following findings are present in the radiology images? "
            "Findings: {findings}"
        ),
        (
            "Based on the previous conversation, "
            "provide a description of the findings in the radiology images."
        ),
    ),
    "multi-study": (
        (
            "prior radiology images: {prior_images}, "
            "prior radiology report: {prior_report} "
            "follow-up images: {images}, "
            "The radiology studies are given in chronological order. "
            "Which of the following findings are present in the current "
            "follow-up radiology images? "
            "Findings: {findings}"
        ),
        (
            "Based on the previous conversation, "
            "provide a description of the findings in the current "
            "follow-up radiology images."
        ),
    ),
    "visual-grounding": (
        "Provide the bounding box coordinate of the region "
        "this phrase describes: {phrase}"
    ),
    "easy-language": "Explain the description with easy language.",
    "summarize": "Summarize the description in one concise sentence.",
    "recommend": (
        "What further diagnosis and treatment do you recommend "
        "based on the given x-ray?"
    ),
}
27
+
28
# Usage notes rendered as Markdown at the top of the page. The leading and
# trailing empty entries reproduce the newline at each end of the text.
title_markdown = "\n".join(
    (
        "",
        "**Usage Instructions**:",
        '1. Add chest x-ray images of a study to the "Study images" section.',
        '2. (Optional) Add "Prior study images" and "Prior study report".',
        '3. Click the "Medical Report Generation" button.',
        "4. You can also have additional conversations. "
        'Please refer to the "Examples" for guidance.',
        "",
        '**Notice**: Enabling "do_sample" in the "Parameters" '
        "may introduce some randomness to the output.",
        "",
    )
)
37
+
38
+
39
def load_model(device, dtype, model_id="Deepnoid/M4CXR"):
    """Load the processor and the causal language model for a checkpoint.

    Args:
        device: Forwarded unchanged as ``device_map`` to ``from_pretrained``.
        dtype: Forwarded unchanged as ``torch_dtype`` to ``from_pretrained``.
        model_id: Checkpoint identifier passed to both ``from_pretrained``
            calls. Defaults to the checkpoint this app was written for, so
            existing two-argument callers are unaffected.

    Returns:
        A ``(processor, model)`` tuple.
    """
    # NOTE: trust_remote_code=True runs Python code shipped with the
    # checkpoint, so model_id should only ever name a trusted repository.
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=dtype,
        device_map=device,
    )
    return processor, model
49
+
50
+
51
def medical_report_generation(history, *args):
    """Generate a report by running the two-turn templated conversation.

    The first templated question asks which findings are present; the second
    asks for a description based on that answer. Each answer is the last
    value yielded by ``predict`` for that question.

    Args:
        history: Current chat history. Must be empty (or ``None``); a
            non-empty history is rejected.
        *args: The same positional values ``predict`` expects after
            ``history``: study images first, generation parameters in the
            middle, prior images and prior report last.

    Returns:
        The new two-turn history, twice (one value per configured output).

    Raises:
        gr.Error: If the chat history is not empty or no study images were
            provided.
    """
    # Only the first and the last two values are needed here; predict()
    # unpacks and validates the complete argument tuple.
    study_images, *_generation_params, prior_images, prior_report = args

    if history:
        raise gr.Error("Please Clear the chat history.")

    if not study_images:
        raise gr.Error("Please add Study images. (right image box)")

    # The incoming history is falsy at this point and may even be None, so
    # start from a fresh list instead of appending to it.
    history = []

    if prior_images:
        first_question, follow_up = templates["multi-study"]
        first_question = first_question.format(
            prior_images=" ".join("<image>" for _ in prior_images),
            # A missing report must not be rendered as the text "None".
            prior_report=prior_report or "",
            images=" ".join("<image>" for _ in study_images),
            findings=findings,
        )
    elif len(study_images) == 1:
        first_question, follow_up = templates["single-image"]
        first_question = first_question.format(findings=findings)
    else:
        first_question, follow_up = templates["multi-image"]
        first_question = first_question.format(
            images=" ".join("<image>" for _ in study_images),
            findings=findings,
        )

    for question in (first_question, follow_up):
        # Default to an empty answer so a generator that yields nothing does
        # not leave `response` unbound.
        response = ""
        for partial in predict(question, history, *args):
            response = partial
        history.append([question, response])

    return history, history
132
+
133
+
134
def predict(message, history, *args):
    """Stream the model's reply to ``message`` as a growing string.

    Args:
        message: The new user message.
        history: Iterable of ``(question, answer)`` pairs from earlier turns.
        *args: Exactly eleven values, in order: study images, do_sample,
            temperature, top_k, top_p, length_penalty, num_beams,
            no_repeat_ngram_size, max_new_tokens, prior images, prior report.

    Yields:
        The text generated so far, extended by each streamed chunk.
    """
    (
        study_images,
        do_sample,
        temperature,
        top_k,
        top_p,
        length_penalty,
        num_beams,
        no_repeat_ngram_size,
        max_new_tokens,
        prior_images,
        prior_report,
    ) = args

    # Flatten the earlier turns into alternating user/assistant messages and
    # finish with the new user message.
    conversation = [
        {"role": role, "content": text}
        for question, answer in history
        for role, text in (("user", question), ("assistant", answer))
    ]
    conversation.append({"role": "user", "content": message})
    prompt = processor.apply_chat_template(conversation, tokenize=False)

    # Each gallery entry is indexed with [0] to get the image (presumably an
    # (image, caption) pair; confirm). Prior images go before study images.
    if not study_images:
        pil_images = None
    else:
        pil_images = [entry[0] for entry in study_images]
        if prior_images:
            pil_images = [entry[0] for entry in prior_images] + pil_images

    encoded = processor(texts=[prompt], images=pil_images)

    # Cast float32 tensors to the model dtype, then move everything to the
    # model device.
    model_inputs = {}
    for key, tensor in encoded.items():
        if tensor.dtype == torch.float:
            tensor = tensor.to(model.dtype)
        model_inputs[key] = tensor.to(model.device)

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        timeout=10.0,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # NOTE(review): transformers documents that streamers are not supported
    # with beam search; confirm num_beams > 1 works with this model.
    generate_kwargs = {
        **model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": num_beams,
        "length_penalty": length_penalty,
        "no_repeat_ngram_size": no_repeat_ngram_size,
    }

    # Generation runs in a background thread; this generator consumes the
    # streamer and yields the accumulated text after every chunk.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    text_so_far = ""
    for chunk in streamer:
        text_so_far = text_so_far + chunk
        yield text_so_far
201
+
202
+
203
def build_demo(model_name: str = "M4CXR"):
    """Build the Gradio Blocks UI for report generation and follow-up chat.

    The left column holds the report button, the image galleries, the prior
    report box and the generation parameters; the right column holds the
    chat interface backed by ``predict``.

    Args:
        model_name: Shown as the page title and as the heading.

    Returns:
        The assembled ``gr.Blocks`` instance (not yet launched).
    """
    title_model_name = f"""<h1 align="center">{model_name} </h1>"""

    with gr.Blocks(title=model_name) as demo:
        gr.Markdown(title_model_name)
        gr.Markdown(title_markdown)

        with gr.Row():
            with gr.Column(scale=3):
                mrg = gr.Button(value="Medical Report Generation", variant="primary")

                with gr.Row(visible=True):
                    prior_images = gr.Gallery(label="Prior study images", type="pil")
                    study_images = gr.Gallery(label="Study images", type="pil")
                # NOTE(review): placed below the gallery row; confirm this
                # matches the intended layout.
                prior_report = gr.Textbox(label="Prior study report")

                with gr.Accordion("Parameters", open=False, visible=True):
                    do_sample = gr.Checkbox(
                        interactive=True, value=False, label="do_sample"
                    )
                    # gr.Slider(minimum, maximum, value, step, ...)
                    # The minimum is 0.1 rather than 0: transformers rejects a
                    # non-positive temperature when sampling is enabled.
                    temperature = gr.Slider(
                        0.1, 1, 1, step=0.1, interactive=True, label="Temperature"
                    )
                    top_k = gr.Slider(1, 5, 3, step=1, interactive=True, label="Top K")
                    top_p = gr.Slider(
                        0, 1, 0.9, step=0.1, interactive=True, label="Top p"
                    )
                    length_penalty = gr.Slider(
                        1, 5, 1, step=0.1, interactive=True, label="length_penalty"
                    )
                    num_beams = gr.Slider(
                        1, 5, 1, step=1, interactive=True, label="Beam Size"
                    )
                    no_repeat_ngram_size = gr.Slider(
                        1, 5, 2, step=1, interactive=True, label="no_repeat_ngram_size"
                    )
                    max_new_tokens = gr.Slider(
                        0,
                        1024,
                        512,
                        step=64,
                        interactive=True,
                        label="Max New tokens",
                    )

            # One list shared by the chat interface and the report button.
            # The order must match the positional unpacking in predict() and
            # medical_report_generation().
            generation_inputs = [
                study_images,
                do_sample,
                temperature,
                top_k,
                top_p,
                length_penalty,
                num_beams,
                no_repeat_ngram_size,
                max_new_tokens,
                prior_images,
                prior_report,
            ]

            with gr.Column(scale=6):
                chat_interface = gr.ChatInterface(
                    fn=predict,
                    additional_inputs=list(generation_inputs),
                    examples=[
                        [templates["summarize"]],
                        [templates["easy-language"]],
                        [templates["recommend"]],
                        [templates["visual-grounding"]],
                    ],
                )

        # Connect the button to the function: the chat state goes in as the
        # history, and the result updates both the chatbot and its state.
        mrg.click(
            medical_report_generation,
            inputs=[chat_interface.chatbot_state, *generation_inputs],
            outputs=[
                chat_interface.chatbot,
                chat_interface.chatbot_state,
            ],
        )

    return demo
303
+
304
+
305
def parse_dtype(name):
    """Resolve a dtype name to a ``torch.dtype`` without evaluating code.

    Args:
        name: Either ``"torch.<dtype>"`` (the original spelling) or the bare
            ``"<dtype>"`` attribute name, e.g. ``"torch.bfloat16"``.

    Returns:
        The matching ``torch.dtype`` object.

    Raises:
        argparse.ArgumentTypeError: If ``name`` has a prefix other than
            ``torch`` or does not name a ``torch.dtype``.
    """
    prefix, _, attr = name.strip().rpartition(".")
    # Plain attribute lookup replaces eval(): arbitrary expressions passed on
    # the command line can no longer be executed.
    dtype = getattr(torch, attr, None) if prefix in ("", "torch") else None
    if not isinstance(dtype, torch.dtype):
        raise argparse.ArgumentTypeError(f"unsupported torch dtype: {name!r}")
    return dtype


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--debug", action="store_true", help="using debug mode")
    parser.add_argument("--port", type=int)
    parser.add_argument("--share", action="store_true", help="share")
    # argparse also runs `type` on string defaults, so args.dtype is always a
    # torch.dtype after parsing.
    parser.add_argument("--dtype", type=parse_dtype, default="torch.bfloat16")
    args = parser.parse_args()

    device = torch.device("cuda")
    dtype = args.dtype
    # predict() reads `processor` and `model` as module-level globals.
    processor, model = load_model(device, dtype)

    demo = build_demo("M4CXR")
    demo.queue(status_update_rate=10, api_open=False).launch(
        server_name=args.host, debug=args.debug, server_port=args.port, share=args.share
    )