P3ngLiu committed
Commit 131c173 · verified · 1 parent: b842d72

Update demo/gradio_demo.py

Files changed (1): demo/gradio_demo.py (+167 −43)
demo/gradio_demo.py CHANGED
@@ -1,5 +1,4 @@
 import gradio as gr
-import spaces
 from PIL import Image, ImageDraw, ImageFont
 import re
 import numpy as np
@@ -13,7 +12,8 @@ from vlm_fo1.mm_utils import (
 )
 from vlm_fo1.task_templates import *
 import torch
-
+import os
+from copy import deepcopy
 
 
 TASK_TYPES = {
@@ -22,10 +22,39 @@ TASK_TYPES = {
     "Region_OCR": "Please provide the ocr results of these regions in the image.",
     "Brief_Region_Caption": "Provide a brief description for these regions in the image.",
     "Detailed_Region_Caption": "Provide a detailed description for these regions in the image.",
-    "Grounding": Grounding_template,
     "Viusal_Region_Reasoning": Viusal_Region_Reasoning_template,
+    "OD_All": OD_All_template,
+    "Grounding": Grounding_template,
 }
 
+EXAMPLES = [
+    ["demo_image.jpg", TASK_TYPES["OD/REC"].format("orange, apple"), "OD/REC"],
+    ["demo_image_01.jpg", TASK_TYPES["ODCounting"].format("airplane with only one propeller"), "ODCounting"],
+    ["demo_image_02.jpg", TASK_TYPES["OD/REC"].format("the ball closest to the bear"), "OD/REC"],
+    ["demo_image_03.jpg", TASK_TYPES["OD_All"].format(""), "OD_All"],
+    ["demo_image_03.jpg", TASK_TYPES["Viusal_Region_Reasoning"].format("What's the brand of this computer?"), "Viusal_Region_Reasoning"],
+]
+
+
+def get_valid_examples():
+    valid_examples = []
+    demo_dir = os.path.dirname(os.path.abspath(__file__))
+    for example in EXAMPLES:
+        img_path = example[0]
+        full_path = os.path.join(demo_dir, img_path)
+        if os.path.exists(full_path):
+            valid_examples.append([
+                full_path,
+                example[1],
+                example[2]
+            ])
+        elif os.path.exists(img_path):
+            valid_examples.append([
+                img_path,
+                example[1],
+                example[2]
+            ])
+    return valid_examples
 
 
 def detect_model(image, threshold=0.3):
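
Note: the `EXAMPLES` entries above pre-render their prompts by calling `.format()` on the task-template strings. The real templates live in `vlm_fo1.task_templates` and are not shown in this diff, so the template text below is a made-up stand-in; only the `.format()` mechanics match the committed code:

```python
# Hypothetical template text; the actual OD/REC template ships with vlm_fo1.
TASK_TYPES = {"OD/REC": "Please detect the following targets in the image: {}."}

prompt = TASK_TYPES["OD/REC"].format("orange, apple")
print(prompt)  # Please detect the following targets in the image: orange, apple.
```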
@@ -70,7 +99,12 @@ def multimodal_model(image, bboxes, text):
     outputs = tokenizer.decode(output_ids[0, generation_kwargs['inputs'].shape[1]:]).strip()
     print("========output========\n", outputs)
 
-    prediction_dict = extract_predictions_to_indexes(outputs)
+    if '<ground>' in outputs:
+        prediction_dict = extract_predictions_to_indexes(outputs)
+    else:
+        match_pattern = r"<region(\d+)>"
+        matches = re.findall(match_pattern, outputs)
+        prediction_dict = {f"<region{m}>": {int(m)} for m in matches}
 
     ans_bbox_json = []
     ans_bbox_list = []
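
Note: the new branching keeps `extract_predictions_to_indexes` for outputs that contain a `<ground>` block and falls back to a bare `<regionN>` scan otherwise. A self-contained sketch of what the fallback yields; the sample output string is illustrative, not real model output:

```python
import re

# Illustrative model output with plain <regionN> references and no <ground> tags.
outputs = "The laptop sits in <region0> and its logo is visible in <region2>."

matches = re.findall(r"<region(\d+)>", outputs)
prediction_dict = {f"<region{m}>": {int(m)} for m in matches}  # values are sets of indexes
print(prediction_dict)  # {'<region0>': {0}, '<region2>': {2}}
```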
@@ -92,7 +126,6 @@ def multimodal_model(image, bboxes, text):
     return outputs, ans_bbox_json, ans_bbox_list
 
 
-
 def draw_bboxes(image, bboxes, labels=None):
     image = image.copy()
     draw = ImageDraw.Draw(image)
@@ -102,41 +135,67 @@ def draw_bboxes(image, bboxes, labels=None):
     return image
 
 
-def extract_bbox_and_original_image(edited_image: dict):
-    original_image = edited_image["background"]
-    bbox_list = []
-
-    if original_image is None:
-        return None, "Error, Please upload an image."
-
-    if edited_image["layers"] is None or len(edited_image["layers"]) == 0:
-        return original_image, []
-
-    drawing_layer = edited_image["layers"][0]
-    alpha_channel = drawing_layer.getchannel('A')
-    alpha_np = np.array(alpha_channel)
-
-    binary_mask = alpha_np > 0
-
-    structuring_element = disk(5)
-    dilated_mask = binary_dilation(binary_mask, structuring_element)
-
-    labeled_image = label(dilated_mask)
-    regions = regionprops(labeled_image)
-
-    for prop in regions:
-        y_min, x_min, y_max, x_max = prop.bbox
-        bbox_list.append((x_min, y_min, x_max, y_max))
-
-    return original_image, bbox_list
-
-@spaces.GPU
-def process(image, prompt, threshold):
+def extract_bbox_and_original_image(edited_image):
+    """Extract original image and bounding boxes from ImageEditor output"""
+    if edited_image is None:
+        return None, []
+
+    if isinstance(edited_image, dict):
+        original_image = edited_image.get("background")
+        bbox_list = []
+
+        if original_image is None:
+            return None, []
+
+        if edited_image.get("layers") is None or len(edited_image.get("layers", [])) == 0:
+            return original_image, []
+
+        try:
+            drawing_layer = edited_image["layers"][0]
+            alpha_channel = drawing_layer.getchannel('A')
+            alpha_np = np.array(alpha_channel)
+
+            binary_mask = alpha_np > 0
+
+            structuring_element = disk(5)
+            dilated_mask = binary_dilation(binary_mask, structuring_element)
+
+            labeled_image = label(dilated_mask)
+            regions = regionprops(labeled_image)
+
+            for prop in regions:
+                y_min, x_min, y_max, x_max = prop.bbox
+                bbox_list.append((x_min, y_min, x_max, y_max))
+        except Exception as e:
+            print(f"Error extracting bboxes from layers: {e}")
+            return original_image, []
+
+        return original_image, bbox_list
+    elif isinstance(edited_image, Image.Image):
+        return edited_image, []
+    else:
+        print(f"Unknown input type: {type(edited_image)}")
+        return None, []
+
+
+def process(image, example_image, prompt, threshold):
     image, bbox_list = extract_bbox_and_original_image(image)
-    image = image.convert('RGB')
+
+    if example_image is not None:
+        image = example_image
+
+    if image is None:
+        error_msg = "Error: Please upload an image or select a valid example."
+        print(f"Error: image is None, original input type: {type(image)}")
+        return None, None, error_msg, []
+
+    try:
+        image = image.convert('RGB')
+    except Exception as e:
+        error_msg = f"Error: Cannot process image - {str(e)}"
+        return None, None, error_msg, []
 
     if len(bbox_list) == 0:
-        # Get bboxes from detection model
         bboxes = detect_model(image, threshold)
     else:
         bboxes = bbox_list
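
Note: the brush-to-bbox logic now lives inside a `try/except` and tolerates non-dict inputs, but its core is unchanged: treat the drawing layer's alpha channel as a binary mask, dilate it to bridge gaps in a stroke, then take the bounding box of each connected component. A self-contained sketch, assuming `disk`, `binary_dilation`, `label`, and `regionprops` come from scikit-image as the call signatures suggest:

```python
import numpy as np
from skimage.morphology import binary_dilation, disk
from skimage.measure import label, regionprops

# Synthetic 100x100 alpha channel with two separate "brush strokes".
alpha_np = np.zeros((100, 100), dtype=np.uint8)
alpha_np[10:20, 10:20] = 255
alpha_np[60:70, 40:55] = 255

binary_mask = alpha_np > 0
dilated_mask = binary_dilation(binary_mask, disk(5))  # bridge small gaps in a stroke

bbox_list = []
for prop in regionprops(label(dilated_mask)):  # one region per connected stroke
    y_min, x_min, y_max, x_max = prop.bbox     # skimage bboxes are in (row, col) order
    bbox_list.append((x_min, y_min, x_max, y_max))
print(bbox_list)  # [(5, 5, 25, 25), (35, 55, 60, 75)]
```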
@@ -145,7 +204,6 @@ def process(image, prompt, threshold):
 
     ans, ans_bbox_json, ans_bbox_list = multimodal_model(image, bboxes, prompt)
 
-
     image_with_opn = draw_bboxes(image, bboxes)
 
     annotated_bboxes = []
@@ -172,14 +230,27 @@ def update_btn(is_processing):
 
 def launch_demo():
     with gr.Blocks() as demo:
-        gr.Markdown("## VLM-FO1 Demo")
+        gr.Markdown("# 🚀 VLM-FO1 Demo")
         gr.Markdown("""
-        **Instructions:**
-        1. Upload an image, then you can either draw circular regions on it using the red brush as the input regions or let the detection model detect the regions for you.
-        2. Select a task template and replace the [WRITE YOUR INPUT HERE] with your input targets, or write your own prompt.\n
-        For example, if you want to detect "person" and "dog", you can replace the [WRITE YOUR INPUT HERE] with "person, dog".\n
-        3. Adjust the detection threshold if needed
-        4. Click Submit to get results
+        ### 📋 Instructions
+
+        **Step 1: Prepare Your Image**
+        - Upload an image using the image editor below
+        - *Optional:* Draw circular regions with the red brush to specify areas of interest
+        - *Alternative:* If not drawing regions, the detection model will automatically identify regions
+
+        **Step 2: Configure Your Task**
+        - Select a task template from the dropdown menu
+        - Replace `[WRITE YOUR INPUT HERE]` with your target objects or query
+        - *Example:* For detecting "person" and "dog", replace with: `person, dog`
+        - *Or:* Write your own custom prompt
+
+        **Step 3: Fine-tune Detection** *(Optional)*
+        - Adjust the detection threshold slider to control sensitivity
+
+        **Step 4: Generate Results**
+        - Click the **Submit** button to process your request
+        - View the detection results and model outputs below
         """)
 
         with gr.Row():
@@ -197,6 +268,23 @@ def launch_demo():
 
                 def set_prompt_from_template(selected_task):
                     return gr.update(value=TASK_TYPES[selected_task].format("[WRITE YOUR INPUT HERE]"))
+
+                def load_example(prompt_input, task_type_input, hidden_image_box):
+                    cached_image = deepcopy(hidden_image_box)
+                    w, h = cached_image.size
+
+                    transparent_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0))
+
+                    new_editor_value = {
+                        "background": cached_image,
+                        "layers": [transparent_layer],
+                        "composite": None
+                    }
+
+                    return new_editor_value, prompt_input, task_type_input
+
+                def reset_hidden_image_box():
+                    return gr.update(value=None)
 
                 task_type_input = gr.Dropdown(
                     choices=list(TASK_TYPES.keys()),
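
Note: `load_example` repacks a cached PIL image into the dict shape the `gr.ImageEditor` component expects (`background` / `layers` / `composite`). A stripped-down sketch of the same hidden-component round trip; it assumes a Gradio 4.x build in which `gr.Examples` exposes `load_input_event`, which the wiring in the next hunk relies on, and the example file path is a placeholder:

```python
import gradio as gr
from PIL import Image

def to_editor_value(img):
    # Repack a plain image into the ImageEditor dict: background plus one empty layer.
    if img is None:
        return gr.update()
    layer = Image.new("RGBA", img.size, (0, 0, 0, 0))
    return {"background": img, "layers": [layer], "composite": None}

with gr.Blocks() as demo:
    editor = gr.ImageEditor(type="pil")
    hidden = gr.Image(type="pil", image_mode="RGBA", visible=False)
    ex = gr.Examples(examples=[["demo_image.jpg"]], inputs=[hidden])
    # Clicking an example fills the hidden image, then repack it into the editor.
    ex.load_input_event.then(to_editor_value, inputs=[hidden], outputs=[editor])

demo.launch()
```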
@@ -211,31 +299,67 @@ def launch_demo():
                     lines=2,
                 )
 
-                task_type_input.change(
+                task_type_input.select(
                     set_prompt_from_template,
                     inputs=task_type_input,
                     outputs=prompt_input
                 )
 
+                hidden_image_box = gr.Image(label="Image", type="pil", image_mode="RGBA", visible=False)
 
                 threshold_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Detection Model Threshold")
                 submit_btn = gr.Button("Submit", variant="primary")
+
+                valid_examples = get_valid_examples()
+                if len(valid_examples) > 0:
+                    gr.Markdown("### Examples")
+                    gr.Markdown("Click on the examples below to quickly load images and corresponding prompts:")
+
+                    examples_data = [[example[0], example[1], example[2]] for index, example in enumerate(valid_examples)]
+
+                    examples = gr.Examples(
+                        examples=examples_data,
+                        inputs=[hidden_image_box, prompt_input, task_type_input],
+                        label="Click to load example",
+                        examples_per_page=5
+                    )
+
+                    examples.load_input_event.then(
+                        fn=load_example,
+                        inputs=[prompt_input, task_type_input, hidden_image_box],
+                        outputs=[img_input_draw, prompt_input, task_type_input]
+                    )
+
+                img_input_draw.upload(
+                    fn=reset_hidden_image_box,
+                    outputs=[hidden_image_box]
+                )
 
             with gr.Column():
                 with gr.Accordion("Detection Result", open=True):
-                    image_output_opn = gr.Image(label="Detection Result")
+                    image_output_opn = gr.Image(label="Detection Result", height=200)
 
-                image_output = gr.AnnotatedImage(label="Multimodal Model Output", height=500)
+                image_output = gr.AnnotatedImage(label="VLM-FO1 Result", height=400)
 
-                result_output = gr.Textbox(label="Multimodal Model Output")
+                result_output = gr.Textbox(label="VLM-FO1 Output", lines=5)
                 ans_bbox_json = gr.JSON(label="Extracted Detection Output")
 
-        submit_btn.click(update_btn, inputs=[gr.State(True)], outputs=[submit_btn], queue=False).then(
+        submit_btn.click(
+            update_btn,
+            inputs=[gr.State(True)],
+            outputs=[submit_btn],
+            queue=False
+        ).then(
             process,
-            inputs=[img_input_draw, prompt_input, threshold_input],
+            inputs=[img_input_draw, hidden_image_box, prompt_input, threshold_input],
             outputs=[image_output, image_output_opn, result_output, ans_bbox_json],
             queue=True
-        ).then(update_btn, inputs=[gr.State(False)], outputs=[submit_btn], queue=False)
+        ).then(
+            update_btn,
+            inputs=[gr.State(False)],
+            outputs=[submit_btn],
+            queue=False
+        )
 
     return demo
 
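
Note: the reflowed submit chain keeps the button-state updates outside the queue so the UI reacts immediately, while the heavy `process` call runs queued in between. The body of `update_btn` is not shown in this diff; the implementation below is a plausible assumption that only illustrates the disable-run-re-enable pattern:

```python
import time
import gradio as gr

def update_btn(is_processing):
    # Assumed implementation: toggle the button while a job is running.
    return gr.update(interactive=not is_processing,
                     value="Processing..." if is_processing else "Submit")

def slow_job(text):
    time.sleep(2)  # stand-in for the model call
    return text.upper()

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    btn = gr.Button("Submit", variant="primary")
    btn.click(update_btn, inputs=[gr.State(True)], outputs=[btn], queue=False
    ).then(slow_job, inputs=[inp], outputs=[out], queue=True
    ).then(update_btn, inputs=[gr.State(False)], outputs=[btn], queue=False)

demo.launch()
```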