jiang-cc committed on
Commit
a3b3596
·
verified ·
1 Parent(s): 875cb52

feat: auto-visualize bounding boxes on test image when model outputs bbox JSON

Browse files
Files changed (1) hide show
  1. app.py +85 -6
app.py CHANGED
@@ -2,14 +2,16 @@
2
  AD-Copilot Demo: Comparison-Aware Anomaly Detection with Vision-Language Model
3
  """
4
 
 
5
  import os
 
6
  import traceback
7
  import spaces
8
  import gradio as gr
9
  import torch
10
  from transformers import AutoModelForImageTextToText, AutoProcessor
11
  from qwen_vl_utils import process_vision_info
12
- from PIL import Image
13
 
14
  # ---------------------------------------------------------------------------
15
  # Model loading (happens once at Space startup; weights stay on CPU until
@@ -31,6 +33,75 @@ model = AutoModelForImageTextToText.from_pretrained(
31
  ).eval()
32
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # ---------------------------------------------------------------------------
35
  # Inference
36
  # ---------------------------------------------------------------------------
@@ -42,7 +113,7 @@ def predict(
42
  max_new_tokens: float,
43
  ):
44
  if reference_image is None or test_image is None:
45
- return "Please upload both a reference (good) image and a test image."
46
 
47
  try:
48
  max_new_tokens = int(max_new_tokens)
@@ -88,11 +159,18 @@ def predict(
88
  skip_special_tokens=True,
89
  clean_up_tokenization_spaces=False,
90
  )[0]
91
- return output
 
 
 
 
 
 
 
92
  except Exception as e:
93
  tb = traceback.format_exc()
94
  print(tb, flush=True)
95
- return f"Error:\n{tb}"
96
 
97
 
98
  # ---------------------------------------------------------------------------
@@ -183,17 +261,18 @@ with gr.Blocks(theme=gr.themes.Soft(), title=TITLE) as demo:
183
  run_btn = gr.Button("Detect Anomaly", variant="primary", scale=2)
184
 
185
  output = gr.Textbox(label="Model Output", lines=4)
 
186
 
187
  run_btn.click(
188
  fn=predict,
189
  inputs=[ref_img, test_img, prompt, max_tokens],
190
- outputs=output,
191
  )
192
 
193
  gr.Examples(
194
  examples=EXAMPLES,
195
  inputs=[ref_img, test_img, prompt, max_tokens],
196
- outputs=output,
197
  fn=predict,
198
  cache_examples=False,
199
  )
 
2
  AD-Copilot Demo: Comparison-Aware Anomaly Detection with Vision-Language Model
3
  """
4
 
5
+ import json
6
  import os
7
+ import re
8
  import traceback
9
  import spaces
10
  import gradio as gr
11
  import torch
12
  from transformers import AutoModelForImageTextToText, AutoProcessor
13
  from qwen_vl_utils import process_vision_info
14
+ from PIL import Image, ImageDraw, ImageFont
15
 
16
  # ---------------------------------------------------------------------------
17
  # Model loading (happens once at Space startup; weights stay on CPU until
 
33
  ).eval()
34
 
35
 
36
+ # ---------------------------------------------------------------------------
37
+ # BBox visualization
38
+ # ---------------------------------------------------------------------------
39
# Palette for bounding-box outlines and label backgrounds. draw_bboxes indexes
# it with the detection index modulo the palette length, so colors repeat
# after eight boxes.
COLORS = [
    "#FF4444", "#44AA44", "#4488FF", "#FF8800",
    "#AA44FF", "#00CCCC", "#FF44AA", "#88AA00",
]
43
+
44
+
45
def parse_bboxes(text):
    """Try to extract a bbox JSON array from model output.

    Two candidate locations are examined, in order:

    1. a fenced code block (```json ... ``` or ``` ... ```) holding a JSON
       array;
    2. a bare JSON array of objects anywhere in the text.

    A candidate is accepted only when it decodes to a non-empty list whose
    first element is a dict containing the key "bbox_2d". If the fenced
    candidate is missing, malformed, or not a bbox list, the bare-array
    candidate is tried before giving up.

    Args:
        text: Raw text produced by the model.

    Returns:
        The decoded list (returned as-is, without filtering later entries),
        or None when no usable array is found.
    """
    if not isinstance(text, str):
        return None

    candidates = []

    # Match ```json ... ``` (the "json" tag is optional).
    fenced = re.search(r'```(?:json)?\s*(\[.*?\])\s*```', text, re.DOTALL)
    if fenced:
        candidates.append(fenced.group(1))

    # Bare JSON array of objects, e.g. [{"bbox_2d": [...], "label": "..."}].
    bare = re.search(r'(\[\s*\{.*?\}\s*\])', text, re.DOTALL)
    if bare:
        candidates.append(bare.group(1))

    for raw in candidates:
        try:
            parsed = json.loads(raw)
        except json.JSONDecodeError:
            continue
        # The dict check matters: `"bbox_2d" in 1` raises TypeError when the
        # array holds scalars (e.g. a fenced "[1, 2, 3]").
        if (
            isinstance(parsed, list)
            and parsed
            and isinstance(parsed[0], dict)
            and "bbox_2d" in parsed[0]
        ):
            return parsed
    return None
66
+
67
+
68
def draw_bboxes(image, bboxes):
    """Return a copy of ``image`` with labelled bounding boxes drawn on it.

    Args:
        image: Image to annotate. It is copied first, so the caller's object
            is not modified. NOTE(review): assumed to be a PIL image because
            it is handed to ``ImageDraw.Draw`` -- confirm against the caller.
        bboxes: Iterable of entries, each expected to be a dict with a
            "bbox_2d" value of four numbers ``[x1, y1, x2, y2]`` and an
            optional "label". Entries that are not dicts, or whose box is not
            exactly four numbers, are skipped instead of raising.

    Returns:
        The annotated copy of the image.
    """
    img = image.copy()
    draw = ImageDraw.Draw(img)

    # Prefer DejaVu Sans for labels; fall back to Pillow's built-in font when
    # the file is not installed on the host.
    try:
        label_font = ImageFont.truetype(
            "/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 13
        )
    except (IOError, OSError):
        label_font = ImageFont.load_default()

    outline_px = 3  # outline thickness, grown outward one pixel at a time

    for i, bbox_info in enumerate(bboxes):
        if not isinstance(bbox_info, dict):
            continue

        bbox = bbox_info.get("bbox_2d")
        if not isinstance(bbox, (list, tuple)) or len(bbox) != 4:
            continue
        if not all(isinstance(v, (int, float)) for v in bbox):
            continue

        # Normalise corner order: model output may list the corners swapped,
        # and ImageDraw.rectangle can reject x1 < x0 / y1 < y0.
        x1, x2 = sorted((bbox[0], bbox[2]))
        y1, y2 = sorted((bbox[1], bbox[3]))

        label = bbox_info.get("label")
        if label is None:
            label = f"defect_{i}"
        label = str(label)  # labels may decode from JSON as numbers

        color = COLORS[i % len(COLORS)]

        # Draw box with thicker outline
        for w in range(outline_px):
            draw.rectangle([x1 - w, y1 - w, x2 + w, y2 + w], outline=color)

        # Draw label background, placed above the box but clamped to the top
        # edge of the image.
        text_bbox = draw.textbbox((0, 0), label, font=label_font)
        tw = text_bbox[2] - text_bbox[0] + 8
        th = text_bbox[3] - text_bbox[1] + 6
        label_y = max(0, y1 - th - 2)
        draw.rectangle([x1, label_y, x1 + tw, label_y + th], fill=color)
        draw.text((x1 + 4, label_y + 2), label, fill="white", font=label_font)

    return img
103
+
104
+
105
  # ---------------------------------------------------------------------------
106
  # Inference
107
  # ---------------------------------------------------------------------------
 
113
  max_new_tokens: float,
114
  ):
115
  if reference_image is None or test_image is None:
116
+ return "Please upload both a reference (good) image and a test image.", None
117
 
118
  try:
119
  max_new_tokens = int(max_new_tokens)
 
159
  skip_special_tokens=True,
160
  clean_up_tokenization_spaces=False,
161
  )[0]
162
+
163
+ # Try to visualize bboxes if present
164
+ bboxes = parse_bboxes(output)
165
+ vis_image = None
166
+ if bboxes:
167
+ vis_image = draw_bboxes(test_image, bboxes)
168
+
169
+ return output, vis_image
170
  except Exception as e:
171
  tb = traceback.format_exc()
172
  print(tb, flush=True)
173
+ return f"Error:\n{tb}", None
174
 
175
 
176
  # ---------------------------------------------------------------------------
 
261
  run_btn = gr.Button("Detect Anomaly", variant="primary", scale=2)
262
 
263
  output = gr.Textbox(label="Model Output", lines=4)
264
+ vis_output = gr.Image(label="Detection Visualization", visible=True)
265
 
266
  run_btn.click(
267
  fn=predict,
268
  inputs=[ref_img, test_img, prompt, max_tokens],
269
+ outputs=[output, vis_output],
270
  )
271
 
272
  gr.Examples(
273
  examples=EXAMPLES,
274
  inputs=[ref_img, test_img, prompt, max_tokens],
275
+ outputs=[output, vis_output],
276
  fn=predict,
277
  cache_examples=False,
278
  )