import gradio as gr import spaces from PIL import Image, ImageDraw, ImageFont import re import random import numpy as np from skimage.measure import label, regionprops from skimage.morphology import binary_dilation, disk from sam3.model_builder import build_sam3_image_model from sam3.model.sam3_image_processor import Sam3Processor from sam3.visualization_utils import plot_bbox, plot_mask, COLORS import matplotlib.pyplot as plt from detect_tools.upn import UPNWrapper from vlm_fo1.model.builder import load_pretrained_model from vlm_fo1.mm_utils import ( prepare_inputs, extract_predictions_to_indexes, ) from vlm_fo1.task_templates import * import torch import os from copy import deepcopy EXAMPLES = [ ["demo/sam3_examples/00000-72.jpg","airplane with letter AE on its body"], ["demo/sam3_examples/00000-32.jpg","the lying cat which is not black"], ["demo/sam3_examples/00000-22.jpg","person wearing a black top"], ["demo/sam3_examples/000000378453.jpg", "zebra inside the mud puddle"], ] def get_valid_examples(): valid_examples = [] demo_dir = os.path.dirname(os.path.abspath(__file__)) for example in EXAMPLES: img_path = example[0] full_path = os.path.join(demo_dir, img_path) if os.path.exists(full_path): valid_examples.append([ full_path, example[1], ]) elif os.path.exists(img_path): valid_examples.append([ img_path, example[1], ]) return valid_examples def detect_model_upn(image, threshold=0.3): proposals = upn_model.inference(image) filtered_proposals = upn_model.filter(proposals, min_score=threshold) picked_proposals = filtered_proposals['original_xyxy_boxes'][0][:100] return picked_proposals def detect_model_sam3(image, text, threshold=0.3): inference_state = sam3_processor.set_image(image) output = sam3_processor.set_text_prompt(state=inference_state, prompt=text) boxes, scores, masks = output["boxes"], output["scores"], output["masks"] sorted_indices = torch.argsort(scores, descending=True) boxes = boxes[sorted_indices][:100, :] scores = scores[sorted_indices][:100] masks = masks[sorted_indices][:100] output = { "boxes": boxes, "scores": scores, "masks": masks, } return boxes.tolist(), scores.tolist(), masks.tolist(), output def multimodal_model(image, bboxes, text, scores=None): if len(bboxes) == 0: return None, {}, [] if '' in text: print(text) parts = [part.replace('\\n', '\n') for part in re.split(rf'()', text) if part.strip()] print(parts) content = [] for part in parts: if part == '': content.append({"type": "image_url", "image_url": {"url": image}}) else: content.append({"type": "text", "text": part}) else: content = [{ "type": "image_url", "image_url": { "url": image } }, { "type": "text", "text": text }] messages = [ { "role": "user", "content": content, "bbox_list": bboxes } ] generation_kwargs = prepare_inputs(model_path, model, image_processors, tokenizer, messages, max_tokens=4096, top_p=0.05, temperature=0.0, do_sample=False, image_size=1024) with torch.inference_mode(): output_ids = model.generate(**generation_kwargs) outputs = tokenizer.decode(output_ids[0, generation_kwargs['inputs'].shape[1]:]).strip() print("========output========\n", outputs) if '' in outputs: prediction_dict = extract_predictions_to_indexes(outputs) else: match_pattern = r"" matches = re.findall(match_pattern, outputs) prediction_dict = {f"": {int(m)} for m in matches} ans_bbox_json = [] ans_bbox_list = [] for k, v in prediction_dict.items(): for box_index in v: box_index = int(box_index) if box_index < len(bboxes): current_bbox = bboxes[box_index] item = { "region_index": f"", "xmin": current_bbox[0], "ymin": current_bbox[1], "xmax": current_bbox[2], "ymax": current_bbox[3], "label": k, } if scores is not None and box_index < len(scores): item["score"] = scores[box_index] ans_bbox_json.append(item) ans_bbox_list.append(current_bbox) return outputs, ans_bbox_json, ans_bbox_list def draw_sam3_results(img, results): fig, ax = plt.subplots(figsize=(12, 8)) # fig.subplots_adjust(0, 0, 1, 1) ax.imshow(img) nb_objects = len(results["scores"]) print(f"found {nb_objects} object(s)") for i in range(nb_objects): color = COLORS[i % len(COLORS)] plot_mask(results["masks"][i].squeeze(0).cpu(), color=color) w, h = img.size prob = results["scores"][i].item() plot_bbox( h, w, results["boxes"][i].cpu(), text=f"(id={i}, {prob=:.2f})", box_format="XYXY", color=color, relative_coords=False, ) ax.axis("off") fig.tight_layout(pad=0) # Convert matplotlib figure to PIL Image fig.canvas.draw() buf = fig.canvas.buffer_rgba() pil_img = Image.frombytes('RGBA', fig.canvas.get_width_height(), buf) plt.close(fig) return pil_img def draw_bboxes_simple(image, bboxes, labels=None): image = image.copy() draw = ImageDraw.Draw(image) for bbox in bboxes: draw.rectangle(bbox, outline="red", width=3) return image @spaces.GPU def process(image, prompt, threshold=0.3): if image is None: error_msg = "Error: Please upload an image or select a valid example." print(f"Error: image is None, original input type: {type(image)}") return None, None, None, None, [], [] try: image = image.convert('RGB') except Exception as e: error_msg = f"Error: Cannot process image - {str(e)}" return None, None, None, None, [], [] # --- SAM3 Pipeline --- print("Running SAM3 Pipeline...") sam3_bboxes, sam3_scores, masks, sam3_output = detect_model_sam3(image, prompt, threshold) # Generate SAM3 outputs (Directly from SAM3, no VLM-FO1) sam3_detection_image = draw_sam3_results(image, sam3_output) sam3_annotated_bboxes = [] sam3_ans_bbox_json = [] img_width, img_height = image.size for i, bbox in enumerate(sam3_bboxes): xmin = max(0, min(img_width, int(bbox[0]))) ymin = max(0, min(img_height, int(bbox[1]))) xmax = max(0, min(img_width, int(bbox[2]))) ymax = max(0, min(img_height, int(bbox[3]))) score = sam3_scores[i] # Format label with score label_text = f"{prompt} {score:.2f}" sam3_annotated_bboxes.append( ((xmin, ymin, xmax, ymax), label_text) ) sam3_ans_bbox_json.append({ "region_index": i, "xmin": bbox[0], "ymin": bbox[1], "xmax": bbox[2], "ymax": bbox[3], "label": prompt, "score": score }) sam3_annotated_image = (image, sam3_annotated_bboxes) # --- UPN Pipeline --- print("Running UPN Pipeline...") upn_bboxes = detect_model_upn(image, threshold=0.3) # Use default threshold for UPN fo1_prompt_upn = OD_template.format(prompt) upn_bboxes = upn_bboxes[::-1] upn_ans, upn_ans_bbox_json, upn_ans_bbox_list = multimodal_model(image, upn_bboxes, fo1_prompt_upn) upn_detection_image = draw_bboxes_simple(image, upn_bboxes) upn_annotated_bboxes = [] if len(upn_ans_bbox_json) > 0: img_width, img_height = image.size for item in upn_ans_bbox_json: xmin = max(0, min(img_width, int(item['xmin']))) ymin = max(0, min(img_height, int(item['ymin']))) xmax = max(0, min(img_width, int(item['xmax']))) ymax = max(0, min(img_height, int(item['ymax']))) upn_annotated_bboxes.append( ((xmin, ymin, xmax, ymax), item['label']) ) upn_annotated_image = (image, upn_annotated_bboxes) return sam3_annotated_image, sam3_detection_image, \ upn_annotated_image, upn_detection_image, upn_ans_bbox_json def update_btn(is_processing): if is_processing: return gr.update(value="Processing...", interactive=False) else: return gr.update(value="Submit", interactive=True) def launch_demo(): with gr.Blocks() as demo: gr.Markdown("# 🚀 VLM-FO1 vs SAM3 Demo") gr.Markdown(""" ### 📋 Instructions Compare the detection performance of **SAM3** vs **VLM-FO1**. **How it works** 1. Upload or pick an example image. 2. Describe the target object in natural language. 3. Hit **Submit** to run both pipelines. """) with gr.Row(): with gr.Column(): img_input_draw = gr.Image( label="Image Input", type="pil", sources=['upload'], ) gr.Markdown("### Prompt") prompt_input = gr.Textbox( label="Label Prompt", lines=2, ) submit_btn = gr.Button("Submit", variant="primary") examples = gr.Examples( examples=EXAMPLES, inputs=[img_input_draw, prompt_input], label="Click to load example", examples_per_page=5 ) with gr.Column(): gr.Markdown("### SAM3 Result") with gr.Accordion("SAM3 Masks & Boxes", open=False): sam3_detection_output = gr.Image(label="SAM3 Visualization", height=300) sam3_final_output = gr.AnnotatedImage(label="SAM3 Detections", height=400) # sam3_json_output = gr.JSON(label="SAM3 Output Data") with gr.Column(): gr.Markdown("### VLM-FO1 Result") with gr.Accordion("Bboxes Proposals", open=False): upn_detection_output = gr.Image(label="Bboxes", height=300) upn_final_output = gr.AnnotatedImage(label="VLM-FO1 Final", height=400) upn_json_output = gr.JSON(label="VLM-FO1 Details") submit_btn.click( update_btn, inputs=[gr.State(True)], outputs=[submit_btn], queue=False ).then( process, inputs=[img_input_draw, prompt_input], outputs=[ sam3_final_output, sam3_detection_output, upn_final_output, upn_detection_output, upn_json_output ], queue=True ).then( update_btn, inputs=[gr.State(False)], outputs=[submit_btn], queue=False ) return demo if __name__ == "__main__": import os exit_code = os.system(f"wget -c https://airesources.oss-cn-hangzhou.aliyuncs.com/lp/wheel/sam3.pt") model_path = 'omlab/VLM-FO1_Qwen2.5-VL-3B-v01' # sam3_model_path = './resources/sam3/sam3.pt' upn_ckpt_path = "./resources/upn_large.pth" # Load FO1 tokenizer, model, image_processors = load_pretrained_model( model_path=model_path, device="cuda:0", ) # Load SAM3 sam3_model = build_sam3_image_model(checkpoint_path='./sam3.pt', device="cuda",bpe_path='/home/user/app/resources/bpe_simple_vocab_16e6.txt.gz') sam3_processor = Sam3Processor(sam3_model, confidence_threshold=0.5, device="cuda") # Load UPN upn_model = UPNWrapper(upn_ckpt_path) demo = launch_demo() demo.launch()