artificialguybr committed
Commit 46b916c · 1 Parent(s): c21abe0

Refactor qwen-vl space for dynamic Qwen2.5-VL model selection

Files changed (2)
  1. app.py +293 -145
  2. requirements.txt +5 -13
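The core change: instead of hardcoding `Qwen/Qwen-VL-Chat-Int4` with a GPTQ/exllama config, the Space now builds its model dropdown at startup from the Hugging Face Hub API, keeping only Qwen2.5-VL checkpoints updated since 2025-03-01 that fit on an 80GB GPU. A minimal standalone sketch of that startup query (it mirrors `_fetch_model_catalog` in the diff below; the printed format is illustrative only):

```python
# Preview which checkpoints would populate the dropdown, using the same
# Hub API query that _fetch_model_catalog sends at startup.
import requests

resp = requests.get(
    "https://huggingface.co/api/models",
    params={"author": "Qwen", "search": "Qwen2.5-VL", "full": "true", "limit": 200},
    timeout=60,
)
resp.raise_for_status()
for item in resp.json():
    if item.get("pipeline_tag") == "image-text-to-text":
        print(item["id"], (item.get("lastModified") or "")[:10])
```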
app.py CHANGED
@@ -1,156 +1,304 @@
 import gradio as gr
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, TextStreamer
 import torch
 from PIL import Image
-import re
-import requests
-from io import BytesIO
-import copy
-import secrets
-from pathlib import Path
-
-tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-VL-Chat-Int4", trust_remote_code=True)
-config = AutoConfig.from_pretrained("Qwen/Qwen-VL-Chat-Int4", trust_remote_code=True, torch_dtype=torch.float16)
-#config.quantization_config["use_exllama"] = True
-config.quantization_config["disable_exllama"] = False
-config.quantization_config["exllama_config"] = {"version":2}
-model = AutoModelForCausalLM.from_config(config, trust_remote_code=True, torch_dtype=torch.float16)
-
-BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
-PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
-
-def _parse_text(text):
-    lines = text.split("\n")
-    lines = [line for line in lines if line != ""]
-    count = 0
-    for i, line in enumerate(lines):
-        if "```" in line:
-            count += 1
-            items = line.split("`")
-            if count % 2 == 1:
-                lines[i] = f'<pre><code class="language-{items[-1]}">'
-            else:
-                lines[i] = f"<br></code></pre>"
-        else:
-            if i > 0:
-                if count % 2 == 1:
-                    line = line.replace("`", r"\`")
-                    line = line.replace("<", "&lt;")
-                    line = line.replace(">", "&gt;")
-                    line = line.replace(" ", "&nbsp;")
-                    line = line.replace("*", "&ast;")
-                    line = line.replace("_", "&lowbar;")
-                    line = line.replace("-", "&#45;")
-                    line = line.replace(".", "&#46;")
-                    line = line.replace("!", "&#33;")
-                    line = line.replace("(", "&#40;")
-                    line = line.replace(")", "&#41;")
-                    line = line.replace("$", "&#36;")
-                lines[i] = "<br>" + line
-    text = "".join(lines)
-    return text
-
-def predict(_chatbot, task_history):
-    chat_query = _chatbot[-1][0]
-    query = task_history[-1][0]
-    history_cp = copy.deepcopy(task_history)
-    full_response = ""
-
-    history_filter = []
-    pic_idx = 1
-    pre = ""
-    for i, (q, a) in enumerate(history_cp):
-        if isinstance(q, (tuple, list)):
-            q = f'Picture {pic_idx}: <img>{q[0]}</img>'
-            pre += q + '\n'
-            pic_idx += 1
-        else:
-            pre += q
-        history_filter.append((pre, a))
-        pre = ""
-    history, message = history_filter[:-1], history_filter[-1][0]
-    response, history = model.chat(tokenizer, message, history=history)
-    image = tokenizer.draw_bbox_on_latest_picture(response, history)
     if image is not None:
-        temp_dir = secrets.token_hex(20)
-        temp_dir = Path("/tmp") / temp_dir
-        temp_dir.mkdir(exist_ok=True, parents=True)
-        name = f"tmp{secrets.token_hex(5)}.jpg"
-        filename = temp_dir / name
-        image.save(str(filename))
-        _chatbot[-1] = (_parse_text(chat_query), (str(filename),))
-        chat_response = response.replace("<ref>", "")
-        chat_response = chat_response.replace(r"</ref>", "")
-        chat_response = re.sub(BOX_TAG_PATTERN, "", chat_response)
-        if chat_response != "":
-            _chatbot.append((None, chat_response))
-    else:
-        _chatbot[-1] = (_parse_text(chat_query), response)
-        full_response = _parse_text(response)
-    task_history[-1] = (query, full_response)
-    return _chatbot
-
-def add_text(history, task_history, text):
-    task_text = text
-    if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
-        task_text = text[:-1]
-    history = history + [(_parse_text(text), None)]
-    task_history = task_history + [(task_text, None)]
-    return history, task_history, ""
-
-def add_file(history, task_history, file):
-    history = history + [((file.name,), None)]
-    task_history = task_history + [((file.name,), None)]
-    return history, task_history
-
-def reset_user_input():
-    return gr.update(value="")
-
-def reset_state(task_history):
-    task_history.clear()
-    return []
-
-def regenerate(_chatbot, task_history):
-    print("Regenerate clicked")
-    print("Before:", task_history, _chatbot)
-    if not task_history:
-        return _chatbot
-    item = task_history[-1]
-    if item[1] is None:
-        return _chatbot
-    task_history[-1] = (item[0], None)
-    chatbot_item = _chatbot.pop(-1)
-    if chatbot_item[0] is None:
-        _chatbot[-1] = (_chatbot[-1][0], None)
     else:
-        _chatbot.append((chatbot_item[0], None))
-    print("After:", task_history, _chatbot)
-    return predict(_chatbot, task_history)

-css = '''
-.gradio-container{max-width:800px !important}
-'''

-with gr.Blocks(css=css) as demo:
-    gr.Markdown("# Qwen-VL-Chat Bot")
-    gr.Markdown("## Qwen-VL: A Multimodal Large Vision Language Model by Alibaba Cloud **Space by [@Artificialguybr](https://twitter.com/artificialguybr). Test the [QwenLLM-14B](https://huggingface.co/spaces/artificialguybr/qwen-14b-chat-demo) here for free!</center>")
-    chatbot = gr.Chatbot(label='Qwen-VL-Chat', elem_classes="control-height", height=520)
-    query = gr.Textbox(lines=2, label='Input')
-    task_history = gr.State([])

     with gr.Row():
-        addfile_btn = gr.UploadButton("📁 Upload", file_types=["image"])
-        submit_btn = gr.Button("🚀 Submit")
-        regen_btn = gr.Button("🤔️ Regenerate")
-        empty_bin = gr.Button("🧹 Clear History")
-
-    gr.Markdown("### Key Features:\n- **Strong Performance**: Surpasses existing LVLMs on multiple English benchmarks including Zero-shot Captioning and VQA.\n- **Multi-lingual Support**: Supports English, Chinese, and multi-lingual conversation.\n- **High Resolution**: Utilizes 448*448 resolution for fine-grained recognition and understanding.")
-    submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history]).then(
-        predict, [chatbot, task_history], [chatbot], show_progress=True
     )
-    submit_btn.click(reset_user_input, [], [query])
-    empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
-    regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
-    addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)

-demo.launch()
+from __future__ import annotations
+
+import gc
+import re
+from typing import Any
+
 import gradio as gr
+import requests
+import spaces
 import torch
 from PIL import Image
+from transformers import AutoModelForImageTextToText, AutoProcessor
+
+
+HF_MODELS_API = "https://huggingface.co/api/models"
+MIN_UPDATED_DATE = "2025-03-01"
+ORG = "Qwen"
+SEARCH_TERM = "Qwen2.5-VL"
+
+DEFAULT_MODELS = [
+    {
+        "id": "Qwen/Qwen2.5-VL-3B-Instruct",
+        "updated": "2025-04-06",
+        "fit_note": "Best speed/quality for most tasks on 80GB.",
+    },
+    {
+        "id": "Qwen/Qwen2.5-VL-7B-Instruct",
+        "updated": "2025-04-06",
+        "fit_note": "Higher quality, still comfortable on 80GB.",
+    },
+    {
+        "id": "Qwen/Qwen2.5-VL-32B-Instruct-AWQ",
+        "updated": "2025-04-06",
+        "fit_note": "Strong quality with 4-bit AWQ quantization.",
+    },
+    {
+        "id": "Qwen/Qwen2.5-VL-72B-Instruct-AWQ",
+        "updated": "2025-03-07",
+        "fit_note": "Largest option; can fit on 80GB but heavier/less headroom.",
+    },
+]
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+LOADED_MODEL_ID: str | None = None
+LOADED_MODEL: AutoModelForImageTextToText | None = None
+LOADED_PROCESSOR: AutoProcessor | None = None
+
+
+def _parse_param_billions(model_id: str) -> int:
+    match = re.search(r"-(\d+)B-", model_id)
+    if not match:
+        return 0
+    return int(match.group(1))
+
+
+def _fits_80gb(model_id: str, tags: list[str]) -> bool:
+    params_b = _parse_param_billions(model_id)
+    lower_id = model_id.lower()
+    lower_tags = " ".join(str(tag).lower() for tag in tags)
+
+    if params_b == 0:
+        return False
+    if params_b <= 32:
+        return True
+    if params_b <= 72 and ("awq" in lower_id or "awq" in lower_tags):
+        return True
+    return False
+
+
+def _fetch_model_catalog() -> list[dict[str, str]]:
+    params = {
+        "author": ORG,
+        "search": SEARCH_TERM,
+        "full": "true",
+        "limit": 200,
+    }
+    response = requests.get(HF_MODELS_API, params=params, timeout=60)
+    response.raise_for_status()
+    models = response.json()
+
+    selected: list[dict[str, str]] = []
+    for item in models:
+        model_id = item.get("id", "")
+        pipeline = item.get("pipeline_tag")
+        updated = (item.get("lastModified") or "")[:10]
+        tags = item.get("tags") or []
+
+        if pipeline != "image-text-to-text":
+            continue
+        if not model_id.startswith("Qwen/Qwen2.5-VL"):
+            continue
+        if not updated or updated < MIN_UPDATED_DATE:
+            continue
+        if not _fits_80gb(model_id, tags):
+            continue
+        if "gguf" in model_id.lower():
+            continue
+
+        selected.append(
+            {
+                "id": model_id,
+                "updated": updated,
+                "fit_note": "Auto-selected by VRAM fit heuristic for 80GB.",
+            }
+        )
+
+    selected.sort(key=lambda x: (_parse_param_billions(x["id"]), x["id"]))
+    return selected
+
+
+def get_model_catalog() -> list[dict[str, str]]:
+    try:
+        models = _fetch_model_catalog()
+        if models:
+            return models
+    except Exception:
+        pass
+    return DEFAULT_MODELS
+
+
+MODEL_CATALOG = get_model_catalog()
+MODEL_LABELS = {
+    item["id"]: f"{item['id']} | updated {item['updated']}"
+    for item in MODEL_CATALOG
+}
+
+
+def _dtype_for_model(model_id: str) -> torch.dtype:
+    if DEVICE != "cuda":
+        return torch.float32
+    if "awq" in model_id.lower():
+        return torch.float16
+    return torch.bfloat16
+
+
+def unload_current_model() -> None:
+    global LOADED_MODEL, LOADED_PROCESSOR, LOADED_MODEL_ID
+    LOADED_MODEL = None
+    LOADED_PROCESSOR = None
+    LOADED_MODEL_ID = None
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
+def _first_model_device(model: AutoModelForImageTextToText) -> torch.device:
+    try:
+        return next(model.parameters()).device
+    except StopIteration:
+        return torch.device("cuda:0" if DEVICE == "cuda" else "cpu")
+
+
+def load_model(model_id: str) -> tuple[AutoModelForImageTextToText, AutoProcessor]:
+    global LOADED_MODEL, LOADED_PROCESSOR, LOADED_MODEL_ID
+    if LOADED_MODEL is not None and LOADED_PROCESSOR is not None and LOADED_MODEL_ID == model_id:
+        return LOADED_MODEL, LOADED_PROCESSOR
+
+    unload_current_model()
+    dtype = _dtype_for_model(model_id)
+
+    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+    model = AutoModelForImageTextToText.from_pretrained(
+        model_id,
+        trust_remote_code=True,
+        torch_dtype=dtype,
+        device_map="auto" if DEVICE == "cuda" else None,
+    )
+    model.eval()
+
+    LOADED_MODEL_ID = model_id
+    LOADED_MODEL = model
+    LOADED_PROCESSOR = processor
+    return model, processor
+
+
+def format_model_status(model_id: str) -> str:
+    entry = next((item for item in MODEL_CATALOG if item["id"] == model_id), None)
+    if entry is None:
+        return f"**Model:** `{model_id}`"
+    return (
+        f"**Model:** `{entry['id']}`\n"
+        f"- Updated: **{entry['updated']}**\n"
+        f"- 80GB fit note: {entry['fit_note']}"
+    )
+
+
+def _build_messages(prompt: str, image: Image.Image | None) -> list[dict[str, Any]]:
+    content: list[dict[str, Any]] = []
     if image is not None:
+        content.append({"type": "image", "image": image})
+    content.append({"type": "text", "text": prompt})
+    return [{"role": "user", "content": content}]
+
+
+@spaces.GPU(duration=120)
+def run_vl(
+    model_id: str,
+    image: Image.Image | None,
+    prompt: str,
+    max_new_tokens: int,
+    temperature: float,
+    top_p: float,
+) -> tuple[str, str]:
+    if not prompt or not prompt.strip():
+        raise gr.Error("Prompt is required.")
+    if image is None:
+        raise gr.Error("Upload an image first.")
+
+    model, processor = load_model(model_id)
+    messages = _build_messages(prompt.strip(), image)
+    text = processor.apply_chat_template(
+        messages, tokenize=False, add_generation_prompt=True
+    )
+    inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+    model_device = _first_model_device(model)
+    inputs = {k: (v.to(model_device) if torch.is_tensor(v) else v) for k, v in inputs.items()}
+
+    generate_kwargs: dict[str, Any] = {
+        "max_new_tokens": int(max_new_tokens),
+        "top_p": float(top_p),
+    }
+    if temperature > 0:
+        generate_kwargs["do_sample"] = True
+        generate_kwargs["temperature"] = float(temperature)
     else:
+        generate_kwargs["do_sample"] = False
+
+    with torch.inference_mode():
+        output_ids = model.generate(**inputs, **generate_kwargs)

+    prompt_len = inputs["input_ids"].shape[1]
+    completion_ids = output_ids[:, prompt_len:]
+    answer = processor.batch_decode(
+        completion_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
+    )[0].strip()

+    return answer, format_model_status(model_id)
+
+
+def on_model_change(model_id: str) -> str:
+    return format_model_status(model_id)
+
+
+default_model_id = MODEL_CATALOG[0]["id"]
+
+with gr.Blocks() as demo:
+    gr.Markdown("# Qwen2.5-VL Multi-Model Playground")
+    gr.Markdown(
+        "Select a Qwen VL model, upload an image, and ask questions or run extraction. "
+        f"Criterion applied: Qwen2.5-VL models updated on or after {MIN_UPDATED_DATE} that fit on 80GB."
+    )
+
+    with gr.Row():
+        model_id = gr.Dropdown(
+            label="Model",
+            choices=[(MODEL_LABELS[item["id"]], item["id"]) for item in MODEL_CATALOG],
+            value=default_model_id,
+        )
+        model_status = gr.Markdown(value=format_model_status(default_model_id))

     with gr.Row():
+        image_input = gr.Image(type="pil", label="Image")
+        answer_output = gr.Textbox(lines=16, label="Answer")
+
+    prompt = gr.Textbox(
+        lines=3,
+        label="Prompt",
+        placeholder="Describe this image in detail. / Extract all text. / What's happening?",
+    )
+
+    with gr.Accordion("Advanced generation settings", open=False):
+        with gr.Row():
+            max_new_tokens = gr.Slider(
+                label="Max new tokens", minimum=32, maximum=2048, value=512, step=32
+            )
+            temperature = gr.Slider(
+                label="Temperature", minimum=0.0, maximum=1.5, value=0.2, step=0.05
+            )
+            top_p = gr.Slider(label="Top-p", minimum=0.1, maximum=1.0, value=0.9, step=0.05)
+
+    with gr.Row():
+        run_btn = gr.Button("Run", variant="primary")
+        clear_btn = gr.Button("Clear")
+        unload_btn = gr.Button("Unload current model")
+
+    run_btn.click(
+        fn=run_vl,
+        inputs=[model_id, image_input, prompt, max_new_tokens, temperature, top_p],
+        outputs=[answer_output, model_status],
+    )
+    model_id.change(fn=on_model_change, inputs=[model_id], outputs=[model_status])
+    clear_btn.click(
+        fn=lambda selected_model: (None, "", "", format_model_status(selected_model)),
+        inputs=[model_id],
+        outputs=[image_input, prompt, answer_output, model_status],
+    )
+    unload_btn.click(
+        fn=lambda: (unload_current_model() or "Model unloaded from memory."),
+        outputs=[answer_output],
     )

+demo.queue(max_size=10)
+demo.launch()
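The new inference path can be smoke-tested without launching the UI. A sketch under two assumptions: that app.py is importable as a module named `app` (hypothetical; the Space only runs it as a script), and that `@spaces.GPU` degrades to a no-op outside a ZeroGPU Space, which is its documented behavior:

```python
# Hypothetical smoke test for run_vl outside Gradio; requires a GPU with
# enough memory for the selected checkpoint.
from PIL import Image

import app  # assumes app.py is on the import path

img = Image.new("RGB", (448, 448), "white")  # placeholder input image
answer, status = app.run_vl(
    model_id="Qwen/Qwen2.5-VL-3B-Instruct",
    image=img,
    prompt="Describe this image.",
    max_new_tokens=64,
    temperature=0.0,  # exercises the greedy (do_sample=False) branch
    top_p=0.9,
)
print(status)
print(answer)
```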
requirements.txt CHANGED
@@ -1,16 +1,8 @@
-transformers
 Pillow
 requests
-accelerate
-tiktoken
-einops
-transformers_stream_generator==0.0.4
-scipy
 torchvision
-pillow
-tensorboard
-matplotlib
-bitsandbytes
-optimum
-auto-gptq
-torch
+accelerate
+gradio
 Pillow
 requests
+spaces
+torch
 torchvision
+transformers>=4.52.0
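The trimmed list drops the legacy Qwen-VL dependencies (`auto-gptq`, `optimum`, `bitsandbytes`, `tiktoken`, `einops`, `transformers_stream_generator`) in favor of stock `transformers` with a `>=4.52.0` floor, presumably for `AutoModelForImageTextToText` support on Qwen2.5-VL. A quick environment check (a sketch; `packaging` is installed as a transformers dependency):

```python
# Verify the installed stack satisfies the new requirements floor.
from packaging import version

import torch
import transformers

assert version.parse(transformers.__version__) >= version.parse("4.52.0"), (
    f"transformers {transformers.__version__} is below the 4.52.0 floor"
)
print("torch", torch.__version__, "| CUDA available:", torch.cuda.is_available())
```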