merve (HF Staff) committed · verified
Commit f793138 · 1 parent: b83d741

Add Qwen3VL & moondream3
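
This swaps the Qwen path from Qwen2.5-VL 3B to Qwen3-VL 4B Instruct and the Moondream path from moondream2 to moondream3-preview. The Qwen3-VL example prompts now ask for bounding boxes as JSON, and the new `parse_qwen3_json` / `create_annotated_image_qwen3` helpers in the diff parse a list of `{"bbox_2d": [...], "label": ...}` entries and scale the coordinates from a 0-1000 grid to the image size. A minimal sketch of how those helpers could be exercised on their own (the response string and image path are placeholders, not app output; assumes the helpers from app.py are in scope):

```python
# Illustrative only: a hand-written Qwen3-VL-style response.
# The bbox values and image path below are made-up placeholders.
from PIL import Image

sample_response = '[{"bbox_2d": [120, 200, 480, 660], "label": "red car"}]'

boxes = parse_qwen3_json(sample_response)    # -> list of {"bbox_2d", "label"} dicts
image = Image.open("examples/example_1.jpg")  # any RGB image works here
annotated = create_annotated_image_qwen3(image, sample_response)  # PIL image with drawn boxes
```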

Files changed (1): app.py (+105 -50)
app.py CHANGED
@@ -9,22 +9,20 @@ from qwen_vl_utils import process_vision_info
 from transformers import (
     AutoModelForCausalLM,
     AutoProcessor,
-    Qwen2_5_VLForConditionalGeneration,
+    Qwen3VLForConditionalGeneration,
 )

 from spaces import GPU
 import supervision as sv

-# --- Config ---
-model_qwen_id = "Qwen/Qwen2.5-VL-3B-Instruct"
-model_moondream_id = "vikhyatk/moondream2"
+model_qwen_id = "Qwen/Qwen3-VL-4B-Instruct"
+model_moondream_id = "moondream/moondream3-preview"

-model_qwen = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    model_qwen_id, torch_dtype="auto", device_map="auto"
+model_qwen = Qwen3VLForConditionalGeneration.from_pretrained(
+    model_qwen_id, torch_dtype="auto", device_map="auto",
 )
 model_moondream = AutoModelForCausalLM.from_pretrained(
     model_moondream_id,
-    revision="2025-06-21",
     trust_remote_code=True,
     device_map={"": "cuda"},
 )
@@ -34,15 +32,11 @@ def extract_model_short_name(model_id):
     return model_id.split("/")[-1].replace("-", " ").replace("_", " ")


-model_qwen_name = extract_model_short_name(model_qwen_id)  # → "Qwen2.5 VL 3B Instruct"
-model_moondream_name = extract_model_short_name(model_moondream_id)  # → "moondream2"
+model_qwen_name = extract_model_short_name(model_qwen_id)
+model_moondream_name = extract_model_short_name(model_moondream_id)


-min_pixels = 224 * 224
-max_pixels = 1024 * 1024
-processor_qwen = AutoProcessor.from_pretrained(
-    "Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels
-)
+processor_qwen = AutoProcessor.from_pretrained(model_qwen_id)


 def create_annotated_image(image, json_data, height, width):
@@ -112,7 +106,6 @@ def create_annotated_image_normalized(image, json_data, label="object"):
     original_width, original_height = image.size
     annotated_image = np.array(image.convert("RGB"))

-    # Handle points for keypoint detection
     points = []
     if "points" in json_data:
         for point in json_data.get("points", []):
@@ -154,6 +147,79 @@ def create_annotated_image_normalized(image, json_data, label="object"):

     return Image.fromarray(annotated_image)

+
+def parse_qwen3_json(json_output):
+    lines = json_output.splitlines()
+    for i, line in enumerate(lines):
+        if line == "```json":
+            json_output = "\n".join(lines[i+1:])
+            json_output = json_output.split("```")[0]
+            break
+
+    try:
+        boxes = json.loads(json_output)
+    except json.JSONDecodeError:
+        end_idx = json_output.rfind('"}') + len('"}')
+        truncated_text = json_output[:end_idx] + "]"
+        boxes = json.loads(truncated_text)
+
+    if not isinstance(boxes, list):
+        boxes = [boxes]
+
+    return boxes
+
+
+def create_annotated_image_qwen3(image, json_output):
+    try:
+        boxes = parse_qwen3_json(json_output)
+    except Exception as e:
+        print(f"Error parsing JSON: {e}")
+        return image
+
+    if not boxes:
+        return image
+
+    original_width, original_height = image.size
+    annotated_image = np.array(image.convert("RGB"))
+
+    xyxy = []
+    labels = []
+
+    for box in boxes:
+        if "bbox_2d" in box and "label" in box:
+            x1, y1, x2, y2 = box["bbox_2d"]
+            scale = 1000
+            x1 = max(0, min(scale, x1)) / scale * original_width
+            y1 = max(0, min(scale, y1)) / scale * original_height
+            x2 = max(0, min(scale, x2)) / scale * original_width
+            y2 = max(0, min(scale, y2)) / scale * original_height
+            # Ensure x1 <= x2 and y1 <= y2
+            if x1 > x2: x1, x2 = x2, x1
+            if y1 > y2: y1, y2 = y2, y1
+            xyxy.append([int(x1), int(y1), int(x2), int(y2)])
+            labels.append(box["label"])
+
+    if not xyxy:
+        return image
+
+    detections = sv.Detections(
+        xyxy=np.array(xyxy),
+        class_id=np.arange(len(xyxy))
+    )
+
+    bounding_box_annotator = sv.BoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
+    label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
+
+    annotated_image = bounding_box_annotator.annotate(
+        scene=annotated_image, detections=detections
+    )
+    annotated_image = label_annotator.annotate(
+        scene=annotated_image, detections=detections, labels=labels
+    )
+
+    return Image.fromarray(annotated_image)
+
+
 @GPU
 def detect_qwen(image, prompt):
     messages = [
@@ -167,16 +233,12 @@ def detect_qwen(image, prompt):
     ]

     t0 = time.perf_counter()
-    text = processor_qwen.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-    image_inputs, video_inputs = process_vision_info(messages)
-    inputs = processor_qwen(
-        text=[text],
-        images=image_inputs,
-        videos=video_inputs,
-        padding=True,
-        return_tensors="pt",
+    inputs = processor_qwen.apply_chat_template(
+        messages,
+        tokenize=True,
+        add_generation_prompt=True,
+        return_dict=True,
+        return_tensors="pt"
     ).to(model_qwen.device)

     generated_ids = model_qwen.generate(**inputs, max_new_tokens=1024)
@@ -186,20 +248,15 @@ def detect_qwen(image, prompt):
     ]
     output_text = processor_qwen.batch_decode(
         generated_ids_trimmed,
-        do_sample=True,
         skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
     )[0]
     elapsed_ms = (time.perf_counter() - t0) * 1_000

-    input_height = inputs["image_grid_thw"][0][1] * 14
-    input_width = inputs["image_grid_thw"][0][2] * 14
-
-    annotated_image = create_annotated_image(
-        image, output_text, input_height, input_width
-    )
+    annotated_image = create_annotated_image_qwen3(image, output_text)

     time_taken = f"**Inference time ({model_qwen_name}):** {elapsed_ms:.0f} ms"
+
     return annotated_image, output_text, time_taken


@@ -251,15 +308,14 @@ button#gradio-share-link-button-0 {
 }
 """

-# --- Gradio Interface ---
 with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
     gr.Markdown("# 👓 Object Understanding with Vision Language Models")
     gr.Markdown(
         "### Explore object detection, visual grounding, keypoint detection, and/or object counting through natural language prompts."
     )
     gr.Markdown("""
-    *Powered by [Qwen2.5-VL 3B](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) and [Moondream 2B (revision="2025-06-21")](https://huggingface.co/vikhyatk/moondream2). Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
-    *Moondream 2B uses the [moondream.py API](https://huggingface.co/vikhyatk/moondream2/blob/main/moondream.py), selecting `detect` for categories with "Object Detection" `point` for the ones with "Keypoint Detection", and reasoning-based querying for all others.*
+    *Powered by [Qwen3-VL 4B](https://huggingface.co/Qwen/Qwen3-VL-4B-Instruct) and [Moondream 3 Preview](https://huggingface.co/moondream/moondream3-preview). Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
+    *Moondream 3 uses the [moondream-preview](https://huggingface.co/vikhyatk/moondream2/blob/main/moondream.py), selecting `detect` for categories with "Object Detection" `point` for the ones with "Keypoint Detection", and reasoning-based querying for all others.*
     """)

     with gr.Row():
@@ -312,13 +368,13 @@ with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
     example_prompts = [
         [
             "examples/example_1.jpg",
-            "Detect all objects in the image and return their locations and labels.",
+            "locate every instance in the image. Report bbox coordinates in JSON format.",
             "objects",
             "Object Detection",
         ],
         [
             "examples/example_2.JPG",
-            "Detect all the individual candies in the image and return their locations and labels.",
+            'locate every instance that belongs to the following categories: "candy, hand". Report bbox coordinates in JSON format.',
             "candies",
             "Object Detection",
         ],
@@ -336,7 +392,7 @@ with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
         ],
         [
             "examples/example_1.jpg",
-            "Identify the red cars in this image, detect their key points and return their positions in the form of points.",
+            'locate every instance that belongs to the following categories: "red car". Report bbox coordinates in JSON format..',
             "red cars",
             "Visual Grounding + Keypoint Detection",
         ],
@@ -348,27 +404,26 @@ with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
         ],
         [
             "examples/example_1.jpg",
-            "Detect the red car that is leading in this image and return its location and label.",
+            'locate every instance that belongs to the following categories: "leading red car". Report bbox coordinates in JSON format..',
             "leading red car",
             "Visual Grounding + Object Detection",
         ],
         [
             "examples/example_2.JPG",
-            "Detect the blue candy located at the top of the group in this image and return its location and label.",
+            'locate every instance that belongs to the following categories: "blue candy located at the top of the group". Report bbox coordinates in JSON format.',
             "blue candy located at the top of the group",
             "Visual Grounding + Object Detection",
         ],
     ]
-
     gr.Examples(
-        examples=example_prompts,
-        inputs=[
-            image_input,
-            prompt_input_model_1,
-            prompt_input_model_2,
-            category_input,
-        ],
-        label="Click an example to populate the input",
+        examples=example_prompts,
+        inputs=[
+            image_input,
+            prompt_input_model_1,
+            prompt_input_model_2,
+            category_input,
+        ],
+        label="Click an example to populate the input",
     )

     generate_btn.click(
@@ -390,4 +445,4 @@ with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
     )

 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
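
The in-app description keeps the same Moondream dispatch: `detect` for categories containing "Object Detection", `point` for "Keypoint Detection", and reasoning-based querying otherwise. That branch is untouched by this commit and therefore not shown above; the sketch below is a hypothetical illustration of that dispatch, assuming `moondream/moondream3-preview` exposes the same `detect` / `point` / `query` methods as the moondream2 revisions.

```python
# Hypothetical dispatch mirroring the in-app description; method names are
# assumed to match the moondream2 API and may differ for moondream3-preview.
from PIL import Image


def moondream_dispatch(model, image: Image.Image, prompt: str, category: str):
    if "Object Detection" in category:
        return model.detect(image, prompt)   # expected: {"objects": [...]}
    if "Keypoint Detection" in category:
        return model.point(image, prompt)    # expected: {"points": [...]}
    return model.query(image, prompt)        # expected: {"answer": "..."}
```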