farrell236 committed
Commit 5420fa9 · verified · 1 Parent(s): 9210fed

Update app.py

Replace the gr.ChatInterface demo with a gr.Blocks web UI adapted from Qwen2.5-VL's web_demo_mm.py: add command-line options for checkpoint path, CPU-only mode, flash_attention_2, and server settings; stream output with TextIteratorStreamer; and accept image and video uploads.

Files changed (1):
  app.py  +311 -201
app.py CHANGED
@@ -1,209 +1,319 @@
- import os
- import time
- import torch
- import requests
-
- from PIL import Image
- from collections.abc import Iterator
- from threading import Thread
-
- import gradio as gr
- from gradio import FileData
-
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
- from qwen_vl_utils import process_vision_info
-
- DESCRIPTION = """\
- # Qwen2.5-VL-32B-Instruct
- """
-
- MAX_MAX_NEW_TOKENS = 2048
- DEFAULT_MAX_NEW_TOKENS = 1024
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
- auth_token = os.environ.get("HF_spaces")
-
- model_id = 'Qwen/Qwen2.5-VL-3B-Instruct'
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-     'farrell236/test_model',
-     use_auth_token=auth_token,
-     # torch_dtype=torch.bfloat16,
-     # attn_implementation="flash_attention_2",
-     device_map="auto"
- )
- processor = AutoProcessor.from_pretrained(model_id)
-
-
- import base64
- from PIL import Image
- import io
-
- # Function to encode the image (scaled down by half)
- def encode_image(image_path, scale=0.25):
-     with Image.open(image_path) as img:
-         # Resize image to half its size
-         new_size = (int(img.width * scale), int(img.height * scale))
-         img = img.resize(new_size)
-
-         # Save the resized image to a bytes buffer
-         buffer = io.BytesIO()
-         img.save(buffer, format="JPEG")  # Change format if needed (e.g., JPEG)
-         buffer.seek(0)
-
-         # Encode to base64
-         return base64.b64encode(buffer.read()).decode('utf-8')
-
-
- def generate(
-     message: str,
-     history: list[dict],
-     max_new_tokens: int = 1024,
-     temperature: float = 0.6,
-     top_p: float = 0.9,
-     top_k: int = 50,
-     num_beams: int = 1,
-     repetition_penalty: float = 1.2,
- ) -> Iterator[str]:
-
-     txt = message["text"]
-     ext_buffer = f"{txt}"
-
-     messages = []
-     images = []
-
-
-     for i, msg in enumerate(history):
-         if isinstance(msg[0], tuple):
-             print('HIT2', msg[0])
-             messages.append({"role": "user", "content": [
-                 {"type": "text", "text": history[i+1][0]},
-                 {"type": "image", "image": f"data:image/jpeg;base64,{encode_image(msg[0][0])}"}
-             ]})
-             messages.append({"role": "assistant", "content": [{"type": "text", "text": history[i+1][1]}]})
-         elif isinstance(history[i-1], tuple) and isinstance(msg[0], str):
-             # messages are already handled
-             pass
-         elif isinstance(history[i-1][0], str) and isinstance(msg[0], str):  # text only turn
-             messages.append({"role": "user", "content": [{"type": "text", "text": msg[0]}]})
-             messages.append({"role": "assistant", "content": [{"type": "text", "text": msg[1]}]})
-
-     # add current message
-     if len(message["files"]) == 1:
-
-         if isinstance(message["files"][0], str):  # examples
-             base64_image = encode_image(message["files"][0])
-         else:  # regular input
-             base64_image = encode_image(message["files"][0]["path"])
-         messages.append({"role": "user", "content": [
-             {"type": "text", "text": txt},
-             {"type": "image", "image": f"data:image/jpeg;base64,{base64_image}"}]})
-     else:
-         messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})
-
-
-     texts = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-     image_inputs, video_inputs = process_vision_info(messages)
-     inputs = processor(
-         text=[texts],
-         images=image_inputs,
-         videos=video_inputs,
-         padding=True,
-         return_tensors="pt",
-     )  # .to("cuda")
-     streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
-
-     generation_kwargs = dict(
-         inputs,
-         streamer=streamer,
-         max_new_tokens=max_new_tokens,
-         do_sample=True,
-         top_p=top_p,
-         top_k=top_k,
-         temperature=temperature,
-         num_beams=num_beams,
-         # repetition_penalty=repetition_penalty,
-     )
-     generated_text = ""
-
-     thread = Thread(target=model.generate, kwargs=generation_kwargs)
-     thread.start()
-     buffer = ""
-
-     for new_text in streamer:
-         buffer += new_text
-         generated_text_without_prompt = buffer
-         time.sleep(0.01)
-         yield buffer
-
- demo = gr.ChatInterface(fn=generate, title="Multimodal Qwen", examples=[
-     [{"text": """\
- You are a highly experienced ophthalmologist specializing in retinal diseases.
- You will be shown a color fundus photograph of a patient's eye.
- Your task is to identify key retinal features and return a structured response.
- You must only respond in JSON format using the following fields:
- - ADVAMD: 1 if advanced age-related macular degeneration is present, otherwise 0
- - PIG: 1 if abnormal pigmentary is present, otherwise 0
- - DRUS: 0 if no drusen or small drusen, 1 if intermediate or medium drusen, 2 if large drusen
- - RPD: 1 if reticular pseudodrusen are present, otherwise 0
- - NVAMD: 1 if neovascular AMD is present, otherwise 0
- - GA: 1 if geographic atrophy is present, otherwise 0
-
- Do not include any explanation, just return the JSON object.
-
- Please assess this fundus image and return your findings in the specified JSON format.""",
-       "files": ["./examples/ret-hem250-304.jpg"]},
-      1024],
-     ],
-     textbox=gr.MultimodalTextbox(),
-     additional_inputs=[
-         gr.Slider(
-             label="Max new tokens",
-             minimum=1,
-             maximum=MAX_MAX_NEW_TOKENS,
-             step=1,
-             value=DEFAULT_MAX_NEW_TOKENS,
-         ),
-         gr.Slider(
-             label="Temperature",
-             minimum=0.1,
-             maximum=4.0,
-             step=0.1,
-             value=0.6,
-         ),
-         gr.Slider(
-             label="Top-p (nucleus sampling)",
-             minimum=0.05,
-             maximum=1.0,
-             step=0.05,
-             value=0.9,
-         ),
-         gr.Slider(
-             label="Top-k",
-             minimum=1,
-             maximum=1000,
-             step=1,
-             value=50,
-         ),
-         gr.Slider(
-             label="Beam Search",
-             minimum=1,
-             maximum=1,
-             step=1,
-             value=1,
-         ),
-         gr.Slider(
-             label="Repetition penalty",
-             minimum=1.0,
-             maximum=2.0,
-             step=0.05,
-             value=1.2,
-         ),
-     ],
-     cache_examples=False,
-     description=DESCRIPTION,
-     stop_btn="Stop Generation",
-     fill_height=True,
-     multimodal=True)
-
- if __name__ == "__main__":
-     demo.launch()
 
+ # Copyright (c) 2025 Team OpthChat.
+ #
+ # This source code is based on web_demo_mm.py by Alibaba Cloud.
+ # Licensed under the Apache License 2.0.
+
+ import os
+ import copy
+ import re
+ from argparse import ArgumentParser
+ from threading import Thread
+
+ import gradio as gr
+ import torch
+ from qwen_vl_utils import process_vision_info
+ from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer
+
+ DEFAULT_CKPT_PATH = 'farrell236/test_model'
+ AUTH_TOKEN = os.environ.get("HF_spaces")
+
+ def _get_args():
+     parser = ArgumentParser()
+     parser.add_argument('-c',
+                         '--checkpoint-path',
+                         type=str,
+                         default=DEFAULT_CKPT_PATH,
+                         help='Checkpoint name or path, default to %(default)r')
+     parser.add_argument('--cpu-only', action='store_true', help='Run demo with CPU only')
+     parser.add_argument('--flash-attn2',
+                         action='store_true',
+                         default=False,
+                         help='Enable flash_attention_2 when loading the model.')
+     parser.add_argument('--share',
+                         action='store_true',
+                         default=False,
+                         help='Create a publicly shareable link for the interface.')
+     parser.add_argument('--inbrowser',
+                         action='store_true',
+                         default=False,
+                         help='Automatically launch the interface in a new tab on the default browser.')
+     parser.add_argument('--server-port', type=int, default=7860, help='Demo server port.')
+     parser.add_argument('--server-name', type=str, default='0.0.0.0', help='Demo server name.')
+
+     args = parser.parse_args()
+     return args
+
+
+ def _load_model_processor(args):
+     if args.cpu_only:
+         device_map = 'cpu'
+     else:
+         device_map = 'auto'
+
+     # Check if flash-attn2 flag is enabled and load model accordingly
+     if args.flash_attn2:
+         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+             args.checkpoint_path,
+             use_auth_token=AUTH_TOKEN,
+             torch_dtype='auto',
+             attn_implementation='flash_attention_2',
+             device_map=device_map)
+     else:
+         model = Qwen2_5_VLForConditionalGeneration.from_pretrained(args.checkpoint_path, use_auth_token=AUTH_TOKEN, device_map=device_map)
+
+     processor = AutoProcessor.from_pretrained(args.checkpoint_path, use_auth_token=AUTH_TOKEN)
+     return model, processor
+
+
+ def _parse_text(text):
+     lines = text.split('\n')
+     lines = [line for line in lines if line != '']
+     count = 0
+     for i, line in enumerate(lines):
+         if '```' in line:
+             count += 1
+             items = line.split('`')
+             if count % 2 == 1:
+                 lines[i] = f'<pre><code class="language-{items[-1]}">'
+             else:
+                 lines[i] = '<br></code></pre>'
+         else:
+             if i > 0:
+                 if count % 2 == 1:
+                     line = line.replace('`', r'\`')
+                     line = line.replace('<', '&lt;')
+                     line = line.replace('>', '&gt;')
+                     line = line.replace(' ', '&nbsp;')
+                     line = line.replace('*', '&ast;')
+                     line = line.replace('_', '&lowbar;')
+                     line = line.replace('-', '&#45;')
+                     line = line.replace('.', '&#46;')
+                     line = line.replace('!', '&#33;')
+                     line = line.replace('(', '&#40;')
+                     line = line.replace(')', '&#41;')
+                     line = line.replace('$', '&#36;')
+                 lines[i] = '<br>' + line
+     text = ''.join(lines)
+     return text
+
+
+ def _remove_image_special(text):
+     text = text.replace('<ref>', '').replace('</ref>', '')
+     return re.sub(r'<box>.*?(</box>|$)', '', text)
+
+
+ def _is_video_file(filename):
+     video_extensions = ['.mp4', '.avi', '.mkv', '.mov', '.wmv', '.flv', '.webm', '.mpeg']
+     return any(filename.lower().endswith(ext) for ext in video_extensions)
+
+
+ def _gc():
+     import gc
+     gc.collect()
+     if torch.cuda.is_available():
+         torch.cuda.empty_cache()
+
+
+ def _transform_messages(original_messages):
+     transformed_messages = []
+     for message in original_messages:
+         new_content = []
+         for item in message['content']:
+             if 'image' in item:
+                 new_item = {'type': 'image', 'image': item['image']}
+             elif 'text' in item:
+                 new_item = {'type': 'text', 'text': item['text']}
+             elif 'video' in item:
+                 new_item = {'type': 'video', 'video': item['video']}
+             else:
+                 continue
+             new_content.append(new_item)
+
+         new_message = {'role': message['role'], 'content': new_content}
+         transformed_messages.append(new_message)
+
+     return transformed_messages
+
+
+ def _launch_demo(args, model, processor):
+
+     def call_local_model(model, processor, messages,
+                          max_tokens=1024, temperature=0.6,
+                          top_p=0.9, top_k=50,
+                          repetition_penalty=1.2):
+
+         messages = _transform_messages(messages)
+
+         text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         image_inputs, video_inputs = process_vision_info(messages)
+         inputs = processor(text=[text], images=image_inputs, videos=video_inputs, padding=True, return_tensors='pt')
+         inputs = inputs.to(model.device)
+
+         tokenizer = processor.tokenizer
+         streamer = TextIteratorStreamer(tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
+
+         gen_kwargs = {'max_new_tokens': max_tokens,
+                       'streamer': streamer,
+                       'temperature': temperature,
+                       'top_p': top_p,
+                       'top_k': top_k,
+                       'repetition_penalty': repetition_penalty,
+                       **inputs}
+
+         thread = Thread(target=model.generate, kwargs=gen_kwargs)
+         thread.start()
+
+         generated_text = ''
+         for new_text in streamer:
+             generated_text += new_text
+             yield generated_text
+
+     def create_predict_fn():
+
+         def predict(_chatbot, task_history,
+                     max_tokens, temperature, top_p, top_k, repetition_penalty):
+             nonlocal model, processor
+             chat_query = _chatbot[-1][0]
+             query = task_history[-1][0]
+             if len(chat_query) == 0:
+                 _chatbot.pop()
+                 task_history.pop()
+                 return _chatbot
+             print('User: ' + _parse_text(query))
+             history_cp = copy.deepcopy(task_history)
+             full_response = ''
+             messages = []
+             content = []
+             for q, a in history_cp:
+                 if isinstance(q, (tuple, list)):
+                     if _is_video_file(q[0]):
+                         content.append({'video': f'file://{q[0]}'})
+                     else:
+                         content.append({'image': f'file://{q[0]}'})
+                 else:
+                     content.append({'text': q})
+                     messages.append({'role': 'user', 'content': content})
+                     messages.append({'role': 'assistant', 'content': [{'text': a}]})
+                     content = []
+             messages.pop()
+
+             for response in call_local_model(model, processor, messages, max_tokens, temperature, top_p, top_k, repetition_penalty):
+                 _chatbot[-1] = (_parse_text(chat_query), _remove_image_special(_parse_text(response)))
+
+                 yield _chatbot
+                 full_response = _parse_text(response)
+
+             task_history[-1] = (query, full_response)
+             print('Qwen-VL-Chat: ' + _parse_text(full_response))
+             yield _chatbot
+
+         return predict
+
+     def create_regenerate_fn():
+
+         def regenerate(_chatbot, task_history, max_tokens, temperature, top_p, top_k, repetition_penalty):
+             nonlocal model, processor
+             if not task_history:
+                 return _chatbot
+             item = task_history[-1]
+             if item[1] is None:
+                 return _chatbot
+             task_history[-1] = (item[0], None)
+             chatbot_item = _chatbot.pop(-1)
+             if chatbot_item[0] is None:
+                 _chatbot[-1] = (_chatbot[-1][0], None)
+             else:
+                 _chatbot.append((chatbot_item[0], None))
+             _chatbot_gen = predict(_chatbot, task_history, max_tokens, temperature, top_p, top_k, repetition_penalty)
+             for _chatbot in _chatbot_gen:
+                 yield _chatbot
+
+         return regenerate
+
+     predict = create_predict_fn()
+     regenerate = create_regenerate_fn()
+
+     def add_text(history, task_history, text):
+         task_text = text
+         history = history if history is not None else []
+         task_history = task_history if task_history is not None else []
+         history = history + [(_parse_text(text), None)]
+         task_history = task_history + [(task_text, None)]
+         return history, task_history
+
+     def add_file(history, task_history, file):
+         history = history if history is not None else []
+         task_history = task_history if task_history is not None else []
+         history = history + [((file.name,), None)]
+         task_history = task_history + [((file.name,), None)]
+         return history, task_history
+
+     def reset_user_input():
+         return gr.update(value='')
+
+     def reset_state(_chatbot, task_history):
+         task_history.clear()
+         _chatbot.clear()
+         _gc()
+         return []
+
+     with gr.Blocks() as demo:
+         gr.Markdown("""\
+ <p align="center"><img src="https://home.mmc.edu/wp-content/uploads/2017/10/nih-logo-color.png" style="height: 80px"/></p>
+ <center><font size=6>Qwen2.5-VL (model_a) for OpthChat</font></center>
+ <center><font size=4></font></center>
+ <center><font size=4></font></center>
+ <center><font size=4></font></center>
+ """)
+
+         chatbot = gr.Chatbot(label='Qwen2.5-VL', elem_classes='control-height', height=500)
+
+         with gr.Accordion("Generation Parameters", open=False):
+             max_tokens = gr.Slider(64, 4096, value=512, step=64, label="Max Tokens")
+             temperature = gr.Slider(0.0, 2.0, value=0.6, step=0.1, label="Temperature")
+             top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
+             top_k = gr.Slider(0, 100, value=50, step=1, label="Top-k")
+             repetition_penalty = gr.Slider(0.5, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
+
+         query = gr.Textbox(lines=2, label='Input')
+         task_history = gr.State([])
+
+         with gr.Row():
+             addfile_btn = gr.UploadButton('📁 Upload', file_types=['image', 'video'])
+             submit_btn = gr.Button('🚀 Submit')
+             regen_btn = gr.Button('♻️ Regenerate')
+             empty_bin = gr.Button('🧹 Clear History')
+
+         submit_btn.click(add_text,
+                          [chatbot, task_history, query],
+                          [chatbot, task_history]).then(predict,
+                                                        [chatbot, task_history, max_tokens,
+                                                         temperature, top_p, top_k, repetition_penalty],
+                                                        [chatbot], show_progress=True)
+         submit_btn.click(reset_user_input, [], [query])
+         empty_bin.click(reset_state, [chatbot, task_history], [chatbot], show_progress=True)
+         regen_btn.click(regenerate, [chatbot, task_history, max_tokens, temperature, top_p, top_k, repetition_penalty], [chatbot], show_progress=True)
+         addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)
+
+         gr.Markdown("""\
+ <font size=2>Note: This demo is governed by the original license of Qwen2.5-VL.
+ The WebUI is based on [Qwen2.5-VL](https://github.com/QwenLM/Qwen2.5-VL/blob/main/web_demo_mm.py),
+ developed by Alibaba Cloud and modified by Team OpthChat.</font>
+ """)
+
+     demo.queue().launch(
+         share=args.share,
+         inbrowser=args.inbrowser,
+         server_port=args.server_port,
+         server_name=args.server_name,
+     )
+
+
+ def main():
+     args = _get_args()
+     model, processor = _load_model_processor(args)
+     _launch_demo(args, model, processor)
+
+
+ if __name__ == '__main__':
+     main()
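
For reference, a minimal headless sketch of the inference path this commit introduces (the same messages -> apply_chat_template -> process_vision_info -> threaded generate flow that call_local_model implements). The checkpoint name and image path below are placeholders, not values from this commit:

# Minimal sketch of the new streaming inference path, outside the Gradio UI.
# 'Qwen/Qwen2.5-VL-3B-Instruct' and the image path are placeholders.
from threading import Thread

from qwen_vl_utils import process_vision_info
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, TextIteratorStreamer

ckpt = 'Qwen/Qwen2.5-VL-3B-Instruct'  # placeholder checkpoint
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(ckpt, torch_dtype='auto', device_map='auto')
processor = AutoProcessor.from_pretrained(ckpt)

# Message layout as produced by _transform_messages: one dict per turn,
# each content item tagged with an explicit 'type'.
messages = [{'role': 'user', 'content': [
    {'type': 'image', 'image': 'file:///path/to/fundus.jpg'},  # placeholder file
    {'type': 'text', 'text': 'Describe the key retinal features.'},
]}]

text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(text=[text], images=image_inputs, videos=video_inputs,
                   padding=True, return_tensors='pt').to(model.device)

# Stream tokens from a worker thread, as call_local_model does.
streamer = TextIteratorStreamer(processor.tokenizer, skip_prompt=True, skip_special_tokens=True)
Thread(target=model.generate, kwargs={'streamer': streamer, 'max_new_tokens': 512, **inputs}).start()
for chunk in streamer:
    print(chunk, end='', flush=True)

With the new argparse wiring, the Space can also be launched locally, e.g. python app.py --checkpoint-path farrell236/test_model --flash-attn2 --share (flags as defined in _get_args).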