Update demo/gradio_demo.py
demo/gradio_demo.py (CHANGED, +167 -43)
@@ -1,5 +1,4 @@
 import gradio as gr
-import spaces
 from PIL import Image, ImageDraw, ImageFont
 import re
 import numpy as np
@@ -13,7 +12,8 @@ from vlm_fo1.mm_utils import (
 )
 from vlm_fo1.task_templates import *
 import torch
-
+import os
+from copy import deepcopy
 
 
 TASK_TYPES = {
@@ -22,10 +22,39 @@ TASK_TYPES = {
     "Region_OCR": "Please provide the ocr results of these regions in the image.",
     "Brief_Region_Caption": "Provide a brief description for these regions in the image.",
     "Detailed_Region_Caption": "Provide a detailed description for these regions in the image.",
-    "Grounding": Grounding_template,
     "Viusal_Region_Reasoning": Viusal_Region_Reasoning_template,
+    "OD_All": OD_All_template,
+    "Grounding": Grounding_template,
 }
 
+EXAMPLES = [
+    ["demo_image.jpg", TASK_TYPES["OD/REC"].format("orange, apple"), "OD/REC"],
+    ["demo_image_01.jpg", TASK_TYPES["ODCounting"].format("airplane with only one propeller"), "ODCounting"],
+    ["demo_image_02.jpg", TASK_TYPES["OD/REC"].format("the ball closest to the bear"), "OD/REC"],
+    ["demo_image_03.jpg", TASK_TYPES["OD_All"].format(""), "OD_All"],
+    ["demo_image_03.jpg", TASK_TYPES["Viusal_Region_Reasoning"].format("What's the brand of this computer?"), "Viusal_Region_Reasoning"],
+]
+
+
+def get_valid_examples():
+    valid_examples = []
+    demo_dir = os.path.dirname(os.path.abspath(__file__))
+    for example in EXAMPLES:
+        img_path = example[0]
+        full_path = os.path.join(demo_dir, img_path)
+        if os.path.exists(full_path):
+            valid_examples.append([
+                full_path,
+                example[1],
+                example[2]
+            ])
+        elif os.path.exists(img_path):
+            valid_examples.append([
+                img_path,
+                example[1],
+                example[2]
+            ])
+    return valid_examples
 
 
 def detect_model(image, threshold=0.3):
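The `EXAMPLES` entries pair a demo image with a prompt built by formatting a task template. A minimal sketch of that mechanism, using a hypothetical stand-in for the real `OD/REC` template (the actual text lives in `vlm_fo1.task_templates` and is not shown in this diff):

```python
# Hypothetical template text; the real OD/REC template is imported
# from vlm_fo1.task_templates and is not part of this diff.
OD_REC_template = "Please detect the following objects in the image: {}."

TASK_TYPES = {"OD/REC": OD_REC_template}

# This is how each EXAMPLES row builds its prompt string.
prompt = TASK_TYPES["OD/REC"].format("orange, apple")
print(prompt)  # Please detect the following objects in the image: orange, apple.
```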
@@ -70,7 +99,12 @@ def multimodal_model(image, bboxes, text):
     outputs = tokenizer.decode(output_ids[0, generation_kwargs['inputs'].shape[1]:]).strip()
     print("========output========\n", outputs)
 
-
+    if '<ground>' in outputs:
+        prediction_dict = extract_predictions_to_indexes(outputs)
+    else:
+        match_pattern = r"<region(\d+)>"
+        matches = re.findall(match_pattern, outputs)
+        prediction_dict = {f"<region{m}>": {int(m)} for m in matches}
 
     ans_bbox_json = []
     ans_bbox_list = []
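For reference, the fallback branch added here recovers region indexes from `<regionN>` tokens whenever the output carries no `<ground>` tag. A standalone sketch with a made-up model output string:

```python
import re

# Hypothetical model output containing region tags.
outputs = "The fruit is in <region0> and the knife is in <region3>."

matches = re.findall(r"<region(\d+)>", outputs)               # ['0', '3']
prediction_dict = {f"<region{m}>": {int(m)} for m in matches}
print(prediction_dict)  # {'<region0>': {0}, '<region3>': {3}}
```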
@@ -92,7 +126,6 @@ def multimodal_model(image, bboxes, text):
     return outputs, ans_bbox_json, ans_bbox_list
 
 
-
 def draw_bboxes(image, bboxes, labels=None):
     image = image.copy()
     draw = ImageDraw.Draw(image)
@@ -102,41 +135,67 @@ def draw_bboxes(image, bboxes, labels=None):
     return image
 
 
-def extract_bbox_and_original_image(edited_image
-
-
+def extract_bbox_and_original_image(edited_image):
+    """Extract original image and bounding boxes from ImageEditor output"""
+    if edited_image is None:
+        return None, []
+
+    if isinstance(edited_image, dict):
+        original_image = edited_image.get("background")
+        bbox_list = []
+
+        if original_image is None:
+            return None, []
 
-
-
+        if edited_image.get("layers") is None or len(edited_image.get("layers", [])) == 0:
+            return original_image, []
 
-
-
+        try:
+            drawing_layer = edited_image["layers"][0]
+            alpha_channel = drawing_layer.getchannel('A')
+            alpha_np = np.array(alpha_channel)
 
-
-        alpha_channel = drawing_layer.getchannel('A')
-        alpha_np = np.array(alpha_channel)
+            binary_mask = alpha_np > 0
 
-
+            structuring_element = disk(5)
+            dilated_mask = binary_dilation(binary_mask, structuring_element)
 
-
-
+            labeled_image = label(dilated_mask)
+            regions = regionprops(labeled_image)
 
-
-
+            for prop in regions:
+                y_min, x_min, y_max, x_max = prop.bbox
+                bbox_list.append((x_min, y_min, x_max, y_max))
+        except Exception as e:
+            print(f"Error extracting bboxes from layers: {e}")
+            return original_image, []
 
-
-
-
+        return original_image, bbox_list
+    elif isinstance(edited_image, Image.Image):
+        return edited_image, []
+    else:
+        print(f"Unknown input type: {type(edited_image)}")
+        return None, []
 
-    return original_image, bbox_list
 
-
-def process(image, prompt, threshold):
+def process(image, example_image, prompt, threshold):
     image, bbox_list = extract_bbox_and_original_image(image)
-
+
+    if example_image is not None:
+        image = example_image
+
+    if image is None:
+        error_msg = "Error: Please upload an image or select a valid example."
+        print(f"Error: image is None, original input type: {type(image)}")
+        return None, None, error_msg, []
+
+    try:
+        image = image.convert('RGB')
+    except Exception as e:
+        error_msg = f"Error: Cannot process image - {str(e)}"
+        return None, None, error_msg, []
 
     if len(bbox_list) == 0:
-        # Get bboxes from detection model
         bboxes = detect_model(image, threshold)
     else:
        bboxes = bbox_list
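The rewritten `extract_bbox_and_original_image` turns brush strokes into boxes by thresholding the drawing layer's alpha channel, dilating the mask, and taking each connected component's bounding box. A self-contained sketch of that pipeline, assuming `disk`/`binary_dilation` come from `skimage.morphology` and `label`/`regionprops` from `skimage.measure` (their imports sit outside this diff):

```python
import numpy as np
from skimage.morphology import disk, binary_dilation
from skimage.measure import label, regionprops

# Synthetic alpha channel standing in for one drawn stroke.
alpha_np = np.zeros((64, 64), dtype=np.uint8)
alpha_np[10:20, 10:30] = 255

binary_mask = alpha_np > 0                             # any painted pixel counts
dilated_mask = binary_dilation(binary_mask, disk(5))   # bridge small gaps in the stroke

for prop in regionprops(label(dilated_mask)):          # one region per connected stroke
    y_min, x_min, y_max, x_max = prop.bbox             # skimage bbox is (row, col) ordered
    print((x_min, y_min, x_max, y_max))                # reordered to (x, y, x, y) as above
```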
@@ -145,7 +204,6 @@ def process(image, prompt, threshold):
 
     ans, ans_bbox_json, ans_bbox_list = multimodal_model(image, bboxes, prompt)
 
-
     image_with_opn = draw_bboxes(image, bboxes)
 
     annotated_bboxes = []
@@ -172,14 +230,27 @@ def update_btn(is_processing):
 
 def launch_demo():
     with gr.Blocks() as demo:
-        gr.Markdown("
+        gr.Markdown("# 🚀 VLM-FO1 Demo")
         gr.Markdown("""
-
-
-
-
-
-
+        ### 📋 Instructions
+
+        **Step 1: Prepare Your Image**
+        - Upload an image using the image editor below
+        - *Optional:* Draw circular regions with the red brush to specify areas of interest
+        - *Alternative:* If not drawing regions, the detection model will automatically identify regions
+
+        **Step 2: Configure Your Task**
+        - Select a task template from the dropdown menu
+        - Replace `[WRITE YOUR INPUT HERE]` with your target objects or query
+        - *Example:* For detecting "person" and "dog", replace with: `person, dog`
+        - *Or:* Write your own custom prompt
+
+        **Step 3: Fine-tune Detection** *(Optional)*
+        - Adjust the detection threshold slider to control sensitivity
+
+        **Step 4: Generate Results**
+        - Click the **Submit** button to process your request
+        - View the detection results and model outputs below
         """)
 
         with gr.Row():
@@ -197,6 +268,23 @@ def launch_demo():
 
             def set_prompt_from_template(selected_task):
                 return gr.update(value=TASK_TYPES[selected_task].format("[WRITE YOUR INPUT HERE]"))
+
+            def load_example(prompt_input, task_type_input, hidden_image_box):
+                cached_image = deepcopy(hidden_image_box)
+                w, h = cached_image.size
+
+                transparent_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0))
+
+                new_editor_value = {
+                    "background": cached_image,
+                    "layers": [transparent_layer],
+                    "composite": None
+                }
+
+                return new_editor_value, prompt_input, task_type_input
+
+            def reset_hidden_image_box():
+                return gr.update(value=None)
 
             task_type_input = gr.Dropdown(
                 choices=list(TASK_TYPES.keys()),
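The `load_example` helper above rebuilds the dict value that `gr.ImageEditor` exchanges: a background image plus a list of RGBA layers. A sketch of that structure with a synthetic background (field names as used above; the fully transparent layer keeps `extract_bbox_and_original_image` from finding any stroke, so detection falls through to `detect_model`):

```python
from PIL import Image

# Hypothetical background; the demo caches the clicked example image here.
background = Image.new("RGBA", (640, 480), (255, 255, 255, 255))
w, h = background.size

editor_value = {
    "background": background,                             # image being annotated
    "layers": [Image.new("RGBA", (w, h), (0, 0, 0, 0))],  # one empty brush layer
    "composite": None,                                    # let the editor recompute it
}
```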
@@ -211,31 +299,67 @@ def launch_demo():
                 lines=2,
             )
 
-            task_type_input.
+            task_type_input.select(
                 set_prompt_from_template,
                 inputs=task_type_input,
                 outputs=prompt_input
             )
 
+            hidden_image_box = gr.Image(label="Image", type="pil", image_mode="RGBA", visible=False)
 
             threshold_input = gr.Slider(minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Detection Model Threshold")
             submit_btn = gr.Button("Submit", variant="primary")
+
+            valid_examples = get_valid_examples()
+            if len(valid_examples) > 0:
+                gr.Markdown("### Examples")
+                gr.Markdown("Click on the examples below to quickly load images and corresponding prompts:")
+
+                examples_data = [[example[0], example[1], example[2]] for index, example in enumerate(valid_examples)]
+
+                examples = gr.Examples(
+                    examples=examples_data,
+                    inputs=[hidden_image_box, prompt_input, task_type_input],
+                    label="Click to load example",
+                    examples_per_page=5
+                )
+
+                examples.load_input_event.then(
+                    fn=load_example,
+                    inputs=[prompt_input, task_type_input, hidden_image_box],
+                    outputs=[img_input_draw, prompt_input, task_type_input]
+                )
+
+            img_input_draw.upload(
+                fn=reset_hidden_image_box,
+                outputs=[hidden_image_box]
+            )
 
         with gr.Column():
             with gr.Accordion("Detection Result", open=True):
-                image_output_opn = gr.Image(label="Detection Result")
+                image_output_opn = gr.Image(label="Detection Result", height=200)
 
-            image_output = gr.AnnotatedImage(label="
+            image_output = gr.AnnotatedImage(label="VLM-FO1 Result", height=400)
 
-            result_output = gr.Textbox(label="
+            result_output = gr.Textbox(label="VLM-FO1 Output", lines=5)
             ans_bbox_json = gr.JSON(label="Extracted Detection Output")
 
-        submit_btn.click(
+        submit_btn.click(
+            update_btn,
+            inputs=[gr.State(True)],
+            outputs=[submit_btn],
+            queue=False
+        ).then(
             process,
-            inputs=[img_input_draw, prompt_input, threshold_input],
+            inputs=[img_input_draw, hidden_image_box, prompt_input, threshold_input],
             outputs=[image_output, image_output_opn, result_output, ans_bbox_json],
             queue=True
-        ).then(
+        ).then(
+            update_btn,
+            inputs=[gr.State(False)],
+            outputs=[submit_btn],
+            queue=False
+        )
 
     return demo
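The reworked `submit_btn.click` chain brackets `process` with two `update_btn` calls so the button state flips immediately (`queue=False`) around the queued inference step. A minimal standalone sketch of this pattern; `update_btn`'s body is not shown in this diff, so the implementation below is an assumption:

```python
import time
import gradio as gr

def update_btn(is_processing):
    # Assumed body: the diff only wires update_btn up, it does not define it.
    return gr.update(interactive=not is_processing)

def slow_process(text):
    time.sleep(2)  # stand-in for detection + VLM inference
    return text.upper()

with gr.Blocks() as demo:
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    btn = gr.Button("Submit", variant="primary")
    btn.click(
        update_btn, inputs=[gr.State(True)], outputs=[btn], queue=False
    ).then(
        slow_process, inputs=[inp], outputs=[out], queue=True
    ).then(
        update_btn, inputs=[gr.State(False)], outputs=[btn], queue=False
    )

demo.launch()
```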