# Hugging Face Space header (scrape residue): "Spaces: Running on Zero"
| import gradio as gr | |
| import spaces | |
| from PIL import Image, ImageDraw, ImageFont | |
| import re | |
| import random | |
| import numpy as np | |
| from skimage.measure import label, regionprops | |
| from skimage.morphology import binary_dilation, disk | |
| from sam3.model_builder import build_sam3_image_model | |
| from sam3.model.sam3_image_processor import Sam3Processor | |
| from sam3.visualization_utils import plot_bbox, plot_mask, COLORS | |
| import matplotlib.pyplot as plt | |
| from detect_tools.upn import UPNWrapper | |
| from vlm_fo1.model.builder import load_pretrained_model | |
| from vlm_fo1.mm_utils import ( | |
| prepare_inputs, | |
| extract_predictions_to_indexes, | |
| ) | |
| from vlm_fo1.task_templates import * | |
| import torch | |
| import os | |
| from copy import deepcopy | |
# [image path, text prompt] pairs surfaced in the Gradio Examples panel.
EXAMPLES = [
    ["demo/sam3_examples/00000-72.jpg", "airplane with letter AE on its body"],
    ["demo/sam3_examples/00000-32.jpg", "the lying cat which is not black"],
    ["demo/sam3_examples/00000-22.jpg", "person wearing a black top"],
    ["demo/sam3_examples/000000378453.jpg", "zebra inside the mud puddle"],
]
def get_valid_examples(examples=None):
    """Filter example entries down to those whose image file exists on disk.

    Each entry is ``[image_path, prompt]``. A path is kept if it exists
    either relative to this file's directory (preferred; the absolute path
    is returned) or relative to the current working directory.

    Args:
        examples: Optional list of ``[path, prompt]`` entries. Defaults to
            the module-level ``EXAMPLES`` list, so existing callers are
            unaffected.

    Returns:
        A list of ``[resolved_path, prompt]`` entries for existing images.
    """
    if examples is None:
        examples = EXAMPLES
    valid_examples = []
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    for example in examples:
        img_path = example[0]
        full_path = os.path.join(demo_dir, img_path)
        if os.path.exists(full_path):
            # Prefer the path anchored at this file's directory.
            valid_examples.append([full_path, example[1]])
        elif os.path.exists(img_path):
            # Fall back to a path relative to the current working directory.
            valid_examples.append([img_path, example[1]])
    return valid_examples
def detect_model_upn(image, threshold=0.3):
    """Generate class-agnostic box proposals with UPN.

    Filters the raw proposals by ``threshold`` (min score) and returns at
    most the first 100 boxes in original xyxy pixel coordinates.
    Relies on the module-level ``upn_model`` loaded in ``__main__``.
    """
    raw_proposals = upn_model.inference(image)
    kept = upn_model.filter(raw_proposals, min_score=threshold)
    return kept['original_xyxy_boxes'][0][:100]
def detect_model_sam3(image, text, threshold=0.3):
    """Run SAM3 text-prompted detection and keep the top-100 results.

    Detections are sorted by descending score. Returns the boxes, scores
    and masks both as plain Python lists and as a tensor dict consumed by
    the visualization helper. Relies on the module-level ``sam3_processor``.

    NOTE(review): ``threshold`` is accepted but not used here — SAM3's own
    confidence threshold is set when the processor is built; confirm before
    wiring it through.
    """
    state = sam3_processor.set_image(image)
    result = sam3_processor.set_text_prompt(state=state, prompt=text)
    order = torch.argsort(result["scores"], descending=True)
    top_boxes = result["boxes"][order][:100, :]
    top_scores = result["scores"][order][:100]
    top_masks = result["masks"][order][:100]
    tensor_output = {
        "boxes": top_boxes,
        "scores": top_scores,
        "masks": top_masks,
    }
    return top_boxes.tolist(), top_scores.tolist(), top_masks.tolist(), tensor_output
def multimodal_model(image, bboxes, text, scores=None):
    """Ask the VLM-FO1 model about candidate regions of an image.

    Builds a chat message carrying the image, the text prompt and the
    candidate ``bboxes``, generates a response, then parses the referenced
    ``<region{i}>`` indexes back into concrete boxes.

    Returns ``(raw_model_output, bbox_json_items, bbox_list)``; when
    ``bboxes`` is empty returns ``(None, {}, [])``.
    Relies on module-level ``model_path``, ``model``, ``image_processors``
    and ``tokenizer`` loaded in ``__main__``.
    """
    # bboxes may be a list or an array-like; length check avoids ambiguity.
    if len(bboxes) == 0:
        return None, {}, []

    if '<image>' in text:
        print(text)
        # Split on the <image> placeholder (kept via the capture group) and
        # turn literal "\n" sequences back into real newlines.
        segments = [seg.replace('\\n', '\n') for seg in re.split(r'(<image>)', text) if seg.strip()]
        print(segments)
        content = []
        for seg in segments:
            if seg == '<image>':
                content.append({"type": "image_url", "image_url": {"url": image}})
            else:
                content.append({"type": "text", "text": seg})
    else:
        content = [
            {"type": "image_url", "image_url": {"url": image}},
            {"type": "text", "text": text},
        ]

    messages = [{"role": "user", "content": content, "bbox_list": bboxes}]
    generation_kwargs = prepare_inputs(
        model_path, model, image_processors, tokenizer, messages,
        max_tokens=4096, top_p=0.05, temperature=0.0, do_sample=False, image_size=1024,
    )
    with torch.inference_mode():
        output_ids = model.generate(**generation_kwargs)
    # Decode only the newly generated tokens (skip the prompt portion).
    prompt_len = generation_kwargs['inputs'].shape[1]
    outputs = tokenizer.decode(output_ids[0, prompt_len:]).strip()
    print("========output========\n", outputs)

    if '<ground>' in outputs:
        prediction_dict = extract_predictions_to_indexes(outputs)
    else:
        # Fallback: collect every <regionN> mention directly.
        prediction_dict = {
            f"<region{m}>": {int(m)}
            for m in re.findall(r"<region(\d+)>", outputs)
        }

    ans_bbox_json = []
    ans_bbox_list = []
    for region_label, indexes in prediction_dict.items():
        for raw_index in indexes:
            idx = int(raw_index)
            if idx >= len(bboxes):
                # The model referenced a region we never proposed; skip it.
                continue
            box = bboxes[idx]
            entry = {
                "region_index": f"<region{idx}>",
                "xmin": box[0],
                "ymin": box[1],
                "xmax": box[2],
                "ymax": box[3],
                "label": region_label,
            }
            if scores is not None and idx < len(scores):
                entry["score"] = scores[idx]
            ans_bbox_json.append(entry)
            ans_bbox_list.append(box)
    return outputs, ans_bbox_json, ans_bbox_list
def draw_sam3_results(img, results):
    """Render SAM3 masks and scored boxes onto the image via matplotlib.

    ``results`` is the tensor dict produced by ``detect_model_sam3``
    (keys ``boxes``/``scores``/``masks``). Returns the rendered figure as
    an RGBA PIL image.
    """
    figure, axis = plt.subplots(figsize=(12, 8))
    axis.imshow(img)
    num_objects = len(results["scores"])
    print(f"found {num_objects} object(s)")
    width, height = img.size
    for idx in range(num_objects):
        draw_color = COLORS[idx % len(COLORS)]
        plot_mask(results["masks"][idx].squeeze(0).cpu(), color=draw_color)
        # Keep the variable named `prob`: the f-string below uses {prob=}.
        prob = results["scores"][idx].item()
        plot_bbox(
            height,
            width,
            results["boxes"][idx].cpu(),
            text=f"(id={idx}, {prob=:.2f})",
            box_format="XYXY",
            color=draw_color,
            relative_coords=False,
        )
    axis.axis("off")
    figure.tight_layout(pad=0)
    # Rasterize the figure and hand it back as a PIL image.
    figure.canvas.draw()
    rgba = figure.canvas.buffer_rgba()
    rendered = Image.frombytes('RGBA', figure.canvas.get_width_height(), rgba)
    plt.close(figure)
    return rendered
def draw_bboxes_simple(image, bboxes, labels=None):
    """Return a copy of the image with each bbox outlined in red.

    NOTE(review): ``labels`` is accepted for interface compatibility but is
    currently ignored — confirm whether label text should be drawn.
    """
    canvas = image.copy()
    pen = ImageDraw.Draw(canvas)
    for box in bboxes:
        pen.rectangle(box, outline="red", width=3)
    return canvas
def process(image, prompt, threshold=0.3):
    """Run both the SAM3 and the UPN + VLM-FO1 pipelines on one image.

    Returns a 5-tuple matching the Gradio ``outputs`` wiring:
    ``(sam3_annotated_image, sam3_detection_image,
       upn_annotated_image, upn_detection_image, upn_ans_bbox_json)``.

    Bug fix: the error paths previously returned SIX values while the
    success path (and the Gradio outputs list) expects FIVE — that arity
    mismatch made Gradio error out on invalid input instead of failing
    gracefully. The unused ``error_msg`` locals are now actually reported.
    """
    if image is None:
        print(f"Error: image is None, original input type: {type(image)}")
        print("Error: Please upload an image or select a valid example.")
        return None, None, None, None, []
    try:
        image = image.convert('RGB')
    except Exception as e:
        print(f"Error: Cannot process image - {str(e)}")
        return None, None, None, None, []

    # --- SAM3 Pipeline ---
    print("Running SAM3 Pipeline...")
    sam3_bboxes, sam3_scores, masks, sam3_output = detect_model_sam3(image, prompt, threshold)
    # Visualization straight from SAM3 (no VLM-FO1 involved).
    sam3_detection_image = draw_sam3_results(image, sam3_output)
    sam3_annotated_bboxes = []
    # NOTE: collected but not returned; kept for parity with the original.
    sam3_ans_bbox_json = []
    img_width, img_height = image.size
    for i, bbox in enumerate(sam3_bboxes):
        # Clamp pixel coordinates into the image frame for the annotator.
        xmin = max(0, min(img_width, int(bbox[0])))
        ymin = max(0, min(img_height, int(bbox[1])))
        xmax = max(0, min(img_width, int(bbox[2])))
        ymax = max(0, min(img_height, int(bbox[3])))
        score = sam3_scores[i]
        # Format label with score
        label_text = f"{prompt} {score:.2f}"
        sam3_annotated_bboxes.append(((xmin, ymin, xmax, ymax), label_text))
        sam3_ans_bbox_json.append({
            "region_index": i,
            "xmin": bbox[0],
            "ymin": bbox[1],
            "xmax": bbox[2],
            "ymax": bbox[3],
            "label": prompt,
            "score": score,
        })
    sam3_annotated_image = (image, sam3_annotated_bboxes)

    # --- UPN Pipeline ---
    print("Running UPN Pipeline...")
    upn_bboxes = detect_model_upn(image, threshold=0.3)  # Use default threshold for UPN
    fo1_prompt_upn = OD_template.format(prompt)
    upn_bboxes = upn_bboxes[::-1]
    upn_ans, upn_ans_bbox_json, upn_ans_bbox_list = multimodal_model(image, upn_bboxes, fo1_prompt_upn)
    upn_detection_image = draw_bboxes_simple(image, upn_bboxes)
    upn_annotated_bboxes = []
    if len(upn_ans_bbox_json) > 0:
        img_width, img_height = image.size
        for item in upn_ans_bbox_json:
            xmin = max(0, min(img_width, int(item['xmin'])))
            ymin = max(0, min(img_height, int(item['ymin'])))
            xmax = max(0, min(img_width, int(item['xmax'])))
            ymax = max(0, min(img_height, int(item['ymax'])))
            upn_annotated_bboxes.append(((xmin, ymin, xmax, ymax), item['label']))
    upn_annotated_image = (image, upn_annotated_bboxes)

    return sam3_annotated_image, sam3_detection_image, \
        upn_annotated_image, upn_detection_image, upn_ans_bbox_json
def update_btn(is_processing):
    """Toggle the submit button between its busy and idle states."""
    label = "Processing..." if is_processing else "Submit"
    return gr.update(value=label, interactive=not is_processing)
def launch_demo():
    """Build the comparison UI: one input column and two result columns.

    Left column takes the image and the text prompt; the middle and right
    columns show the SAM3 and VLM-FO1 results respectively. The submit
    button is disabled while ``process`` runs and re-enabled afterwards.
    """
    with gr.Blocks() as demo:
        gr.Markdown("# π VLM-FO1 vs SAM3 Demo")
        gr.Markdown("""
### π Instructions
Compare the detection performance of **SAM3** vs **VLM-FO1**.
**How it works**
1. Upload or pick an example image.
2. Describe the target object in natural language.
3. Hit **Submit** to run both pipelines.
""")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    label="Image Input",
                    type="pil",
                    sources=['upload'],
                )
                gr.Markdown("### Prompt")
                prompt_box = gr.Textbox(
                    label="Label Prompt",
                    lines=2,
                )
                run_button = gr.Button("Submit", variant="primary")
                gr.Examples(
                    examples=EXAMPLES,
                    inputs=[image_input, prompt_box],
                    label="Click to load example",
                    examples_per_page=5,
                )
            with gr.Column():
                gr.Markdown("### SAM3 Result")
                with gr.Accordion("SAM3 Masks & Boxes", open=False):
                    sam3_viz = gr.Image(label="SAM3 Visualization", height=300)
                sam3_annotated = gr.AnnotatedImage(label="SAM3 Detections", height=400)
            with gr.Column():
                gr.Markdown("### VLM-FO1 Result")
                with gr.Accordion("Bboxes Proposals", open=False):
                    proposal_viz = gr.Image(label="Bboxes", height=300)
                fo1_annotated = gr.AnnotatedImage(label="VLM-FO1 Final", height=400)
                fo1_json = gr.JSON(label="VLM-FO1 Details")

        # Chain: disable button -> run both pipelines -> re-enable button.
        run_button.click(
            update_btn,
            inputs=[gr.State(True)],
            outputs=[run_button],
            queue=False,
        ).then(
            process,
            inputs=[image_input, prompt_box],
            outputs=[
                sam3_annotated, sam3_viz,
                fo1_annotated, proposal_viz, fo1_json,
            ],
            queue=True,
        ).then(
            update_btn,
            inputs=[gr.State(False)],
            outputs=[run_button],
            queue=False,
        )
    return demo
if __name__ == "__main__":
    # Fetch the SAM3 checkpoint; `wget -c` resumes / skips if already present.
    # (Removed a redundant `import os` — os is imported at the top of the
    # file — and an f-string that had no placeholders.)
    exit_code = os.system("wget -c https://airesources.oss-cn-hangzhou.aliyuncs.com/lp/wheel/sam3.pt")
    if exit_code != 0:
        print(f"Warning: checkpoint download exited with code {exit_code}")

    model_path = 'omlab/VLM-FO1_Qwen2.5-VL-3B-v01'
    # sam3_model_path = './resources/sam3/sam3.pt'
    upn_ckpt_path = "./resources/upn_large.pth"

    # Load FO1 — these become the module-level globals used by
    # multimodal_model().
    tokenizer, model, image_processors = load_pretrained_model(
        model_path=model_path,
        device="cuda:0",
    )

    # Load SAM3 and wrap it in the processor used by detect_model_sam3().
    sam3_model = build_sam3_image_model(
        checkpoint_path='./sam3.pt',
        device="cuda",
        bpe_path='/home/user/app/resources/bpe_simple_vocab_16e6.txt.gz',
    )
    sam3_processor = Sam3Processor(sam3_model, confidence_threshold=0.5, device="cuda")

    # Load the UPN proposal generator used by detect_model_upn().
    upn_model = UPNWrapper(upn_ckpt_path)

    demo = launch_demo()
    demo.launch()