"""Gradio demo comparing SAM3 text-prompted detection with the VLM-FO1 + UPN pipeline (demo/gradio_demo2.py)."""
import gradio as gr
import spaces
from PIL import Image, ImageDraw, ImageFont
import re
import random
import numpy as np
from skimage.measure import label, regionprops
from skimage.morphology import binary_dilation, disk
from sam3.model_builder import build_sam3_image_model
from sam3.model.sam3_image_processor import Sam3Processor
from sam3.visualization_utils import plot_bbox, plot_mask, COLORS
import matplotlib.pyplot as plt
from detect_tools.upn import UPNWrapper
from vlm_fo1.model.builder import load_pretrained_model
from vlm_fo1.mm_utils import (
prepare_inputs,
extract_predictions_to_indexes,
)
from vlm_fo1.task_templates import *
import torch
import os
from copy import deepcopy
EXAMPLES = [
["demo/sam3_examples/00000-72.jpg","airplane with letter AE on its body"],
["demo/sam3_examples/00000-32.jpg","the lying cat which is not black"],
["demo/sam3_examples/00000-22.jpg","person wearing a black top"],
["demo/sam3_examples/000000378453.jpg", "zebra inside the mud puddle"],
]
def get_valid_examples():
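    """Resolve the EXAMPLES image paths relative to this file and keep only those that exist on disk."""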
valid_examples = []
demo_dir = os.path.dirname(os.path.abspath(__file__))
for example in EXAMPLES:
img_path = example[0]
full_path = os.path.join(demo_dir, img_path)
if os.path.exists(full_path):
valid_examples.append([
full_path,
example[1],
])
elif os.path.exists(img_path):
valid_examples.append([
img_path,
example[1],
])
return valid_examples
def detect_model_upn(image, threshold=0.3):
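    """Run UPN to get class-agnostic box proposals, keep those scoring above `threshold`, and return the top 100 in xyxy format."""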
proposals = upn_model.inference(image)
filtered_proposals = upn_model.filter(proposals, min_score=threshold)
picked_proposals = filtered_proposals['original_xyxy_boxes'][0][:100]
return picked_proposals
def detect_model_sam3(image, text, threshold=0.3):
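    """Run SAM3 with a text prompt and return the top-100 boxes, scores, and masks (sorted by descending score), both as Python lists and as the raw tensor dict."""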
inference_state = sam3_processor.set_image(image)
output = sam3_processor.set_text_prompt(state=inference_state, prompt=text)
boxes, scores, masks = output["boxes"], output["scores"], output["masks"]
sorted_indices = torch.argsort(scores, descending=True)
boxes = boxes[sorted_indices][:100, :]
scores = scores[sorted_indices][:100]
masks = masks[sorted_indices][:100]
output = {
"boxes": boxes,
"scores": scores,
"masks": masks,
}
return boxes.tolist(), scores.tolist(), masks.tolist(), output
def multimodal_model(image, bboxes, text, scores=None):
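    """Query VLM-FO1 with the image, candidate boxes, and a text prompt.

    Returns the raw model output, a JSON-style list of the selected boxes, and the
    selected boxes as plain [xmin, ymin, xmax, ymax] lists.
    """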
if len(bboxes) == 0:
        return None, [], []
if '<image>' in text:
print(text)
        parts = [part.replace('\\n', '\n') for part in re.split(r'(<image>)', text) if part.strip()]
print(parts)
content = []
for part in parts:
if part == '<image>':
content.append({"type": "image_url", "image_url": {"url": image}})
else:
content.append({"type": "text", "text": part})
else:
content = [{
"type": "image_url",
"image_url": {
"url": image
}
}, {
"type": "text",
"text": text
}]
messages = [
{
"role": "user",
"content": content,
"bbox_list": bboxes
}
]
generation_kwargs = prepare_inputs(model_path, model, image_processors, tokenizer, messages,
max_tokens=4096, top_p=0.05, temperature=0.0, do_sample=False, image_size=1024)
with torch.inference_mode():
output_ids = model.generate(**generation_kwargs)
outputs = tokenizer.decode(output_ids[0, generation_kwargs['inputs'].shape[1]:]).strip()
print("========output========\n", outputs)
if '<ground>' in outputs:
prediction_dict = extract_predictions_to_indexes(outputs)
else:
match_pattern = r"<region(\d+)>"
matches = re.findall(match_pattern, outputs)
prediction_dict = {f"<region{m}>": {int(m)} for m in matches}
ans_bbox_json = []
ans_bbox_list = []
for k, v in prediction_dict.items():
for box_index in v:
box_index = int(box_index)
if box_index < len(bboxes):
current_bbox = bboxes[box_index]
item = {
"region_index": f"<region{box_index}>",
"xmin": current_bbox[0],
"ymin": current_bbox[1],
"xmax": current_bbox[2],
"ymax": current_bbox[3],
"label": k,
}
if scores is not None and box_index < len(scores):
item["score"] = scores[box_index]
ans_bbox_json.append(item)
ans_bbox_list.append(current_bbox)
return outputs, ans_bbox_json, ans_bbox_list
def draw_sam3_results(img, results):
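    """Overlay SAM3 masks and scored boxes on the image with matplotlib and return the rendered figure as a PIL image."""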
fig, ax = plt.subplots(figsize=(12, 8))
# fig.subplots_adjust(0, 0, 1, 1)
ax.imshow(img)
nb_objects = len(results["scores"])
print(f"found {nb_objects} object(s)")
for i in range(nb_objects):
color = COLORS[i % len(COLORS)]
plot_mask(results["masks"][i].squeeze(0).cpu(), color=color)
w, h = img.size
prob = results["scores"][i].item()
plot_bbox(
h,
w,
results["boxes"][i].cpu(),
text=f"(id={i}, {prob=:.2f})",
box_format="XYXY",
color=color,
relative_coords=False,
)
ax.axis("off")
fig.tight_layout(pad=0)
# Convert matplotlib figure to PIL Image
fig.canvas.draw()
buf = fig.canvas.buffer_rgba()
pil_img = Image.frombytes('RGBA', fig.canvas.get_width_height(), buf)
plt.close(fig)
return pil_img
def draw_bboxes_simple(image, bboxes, labels=None):
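    """Draw plain red rectangles for each box on a copy of the image."""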
image = image.copy()
draw = ImageDraw.Draw(image)
for bbox in bboxes:
draw.rectangle(bbox, outline="red", width=3)
return image
@spaces.GPU
def process(image, prompt, threshold=0.3):
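    """Run both pipelines on the uploaded image.

    SAM3 is prompted directly with the text, while VLM-FO1 re-ranks class-agnostic
    UPN proposals against the same prompt. Returns the annotated images,
    visualizations, and the VLM-FO1 detections as JSON for the Gradio outputs.
    """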
    if image is None:
        print("Error: no image provided, please upload an image or select a valid example.")
        return None, None, None, None, []
    try:
        image = image.convert('RGB')
    except Exception as e:
        print(f"Error: cannot process image - {e}")
        return None, None, None, None, []
# --- SAM3 Pipeline ---
print("Running SAM3 Pipeline...")
sam3_bboxes, sam3_scores, masks, sam3_output = detect_model_sam3(image, prompt, threshold)
# Generate SAM3 outputs (Directly from SAM3, no VLM-FO1)
sam3_detection_image = draw_sam3_results(image, sam3_output)
sam3_annotated_bboxes = []
sam3_ans_bbox_json = []
img_width, img_height = image.size
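    # Clamp each box to the image bounds and pair it with a "<prompt> <score>" label
    # for the Gradio AnnotatedImage component.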
for i, bbox in enumerate(sam3_bboxes):
xmin = max(0, min(img_width, int(bbox[0])))
ymin = max(0, min(img_height, int(bbox[1])))
xmax = max(0, min(img_width, int(bbox[2])))
ymax = max(0, min(img_height, int(bbox[3])))
score = sam3_scores[i]
# Format label with score
label_text = f"{prompt} {score:.2f}"
sam3_annotated_bboxes.append(
((xmin, ymin, xmax, ymax), label_text)
)
sam3_ans_bbox_json.append({
"region_index": i,
"xmin": bbox[0],
"ymin": bbox[1],
"xmax": bbox[2],
"ymax": bbox[3],
"label": prompt,
"score": score
})
sam3_annotated_image = (image, sam3_annotated_bboxes)
# --- UPN Pipeline ---
print("Running UPN Pipeline...")
upn_bboxes = detect_model_upn(image, threshold=0.3) # Use default threshold for UPN
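    # UPN returns class-agnostic box proposals; VLM-FO1 then selects the ones matching
    # the prompt, using the OD_template prompt format.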
fo1_prompt_upn = OD_template.format(prompt)
upn_bboxes = upn_bboxes[::-1]
upn_ans, upn_ans_bbox_json, upn_ans_bbox_list = multimodal_model(image, upn_bboxes, fo1_prompt_upn)
upn_detection_image = draw_bboxes_simple(image, upn_bboxes)
upn_annotated_bboxes = []
if len(upn_ans_bbox_json) > 0:
img_width, img_height = image.size
for item in upn_ans_bbox_json:
xmin = max(0, min(img_width, int(item['xmin'])))
ymin = max(0, min(img_height, int(item['ymin'])))
xmax = max(0, min(img_width, int(item['xmax'])))
ymax = max(0, min(img_height, int(item['ymax'])))
upn_annotated_bboxes.append(
((xmin, ymin, xmax, ymax), item['label'])
)
upn_annotated_image = (image, upn_annotated_bboxes)
return sam3_annotated_image, sam3_detection_image, \
upn_annotated_image, upn_detection_image, upn_ans_bbox_json
def update_btn(is_processing):
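    """Toggle the Submit button between its idle and 'Processing...' states."""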
if is_processing:
return gr.update(value="Processing...", interactive=False)
else:
return gr.update(value="Submit", interactive=True)
def launch_demo():
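    """Build the Gradio Blocks UI and wire the Submit button to the processing pipeline."""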
with gr.Blocks() as demo:
gr.Markdown("# πŸš€ VLM-FO1 vs SAM3 Demo")
gr.Markdown("""
### πŸ“‹ Instructions
Compare the detection performance of **SAM3** vs **VLM-FO1**.
**How it works**
1. Upload or pick an example image.
2. Describe the target object in natural language.
3. Hit **Submit** to run both pipelines.
""")
with gr.Row():
with gr.Column():
img_input_draw = gr.Image(
label="Image Input",
type="pil",
sources=['upload'],
)
gr.Markdown("### Prompt")
prompt_input = gr.Textbox(
label="Label Prompt",
lines=2,
)
submit_btn = gr.Button("Submit", variant="primary")
examples = gr.Examples(
examples=EXAMPLES,
inputs=[img_input_draw, prompt_input],
label="Click to load example",
examples_per_page=5
)
with gr.Column():
gr.Markdown("### SAM3 Result")
with gr.Accordion("SAM3 Masks & Boxes", open=False):
sam3_detection_output = gr.Image(label="SAM3 Visualization", height=300)
sam3_final_output = gr.AnnotatedImage(label="SAM3 Detections", height=400)
# sam3_json_output = gr.JSON(label="SAM3 Output Data")
with gr.Column():
gr.Markdown("### VLM-FO1 Result")
with gr.Accordion("Bboxes Proposals", open=False):
upn_detection_output = gr.Image(label="Bboxes", height=300)
upn_final_output = gr.AnnotatedImage(label="VLM-FO1 Final", height=400)
upn_json_output = gr.JSON(label="VLM-FO1 Details")
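        # Chain: disable the Submit button, run both pipelines, then re-enable the button.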
submit_btn.click(
update_btn,
inputs=[gr.State(True)],
outputs=[submit_btn],
queue=False
).then(
process,
inputs=[img_input_draw, prompt_input],
outputs=[
sam3_final_output, sam3_detection_output,
upn_final_output, upn_detection_output, upn_json_output
],
queue=True
).then(
update_btn,
inputs=[gr.State(False)],
outputs=[submit_btn],
queue=False
)
return demo
if __name__ == "__main__":
    # Download the SAM3 checkpoint (wget -c resumes a partial download).
    exit_code = os.system("wget -c https://airesources.oss-cn-hangzhou.aliyuncs.com/lp/wheel/sam3.pt")
    if exit_code != 0:
        raise RuntimeError("Failed to download the sam3.pt checkpoint")
model_path = 'omlab/VLM-FO1_Qwen2.5-VL-3B-v01'
# sam3_model_path = './resources/sam3/sam3.pt'
upn_ckpt_path = "./resources/upn_large.pth"
# Load FO1
tokenizer, model, image_processors = load_pretrained_model(
model_path=model_path,
device="cuda:0",
)
# Load SAM3
    sam3_model = build_sam3_image_model(
        checkpoint_path='./sam3.pt',
        device="cuda",
        bpe_path='/home/user/app/resources/bpe_simple_vocab_16e6.txt.gz',
    )
sam3_processor = Sam3Processor(sam3_model, confidence_threshold=0.5, device="cuda")
# Load UPN
upn_model = UPNWrapper(upn_ckpt_path)
demo = launch_demo()
demo.launch()