import gradio as gr
import spaces
from PIL import Image, ImageDraw, ImageFont
import re
import numpy as np
from skimage.measure import label, regionprops
from skimage.morphology import binary_dilation, disk
from detect_tools.upn import UPNWrapper
from vlm_fo1.model.builder import load_pretrained_model
from vlm_fo1.mm_utils import (
    prepare_inputs,
    extract_predictions_to_indexes,
)
from vlm_fo1.task_templates import *
import torch
import os
from copy import deepcopy

# Prompt templates for each supported task type.
TASK_TYPES = {
    "OD/REC": OD_template,
    "ODCounting": OD_Counting_template,
    "Region_OCR": "Please provide the ocr results of these regions in the image.",
    "Brief_Region_Caption": "Provide a brief description for these regions in the image.",
    "Detailed_Region_Caption": "Provide a detailed description for these regions in the image.",
    "Viusal_Region_Reasoning": Viusal_Region_Reasoning_template,
    "OD_All": OD_All_template,
    "Grounding": Grounding_template,
}

EXAMPLES = [
    ["demo_image.jpg", TASK_TYPES["OD/REC"].format("orange, apple"), "OD/REC"],
    ["demo_image_01.jpg", TASK_TYPES["ODCounting"].format("airplane with only one propeller"), "ODCounting"],
    ["demo_image_02.jpg", TASK_TYPES["OD/REC"].format("the ball closest to the bear"), "OD/REC"],
    ["demo_image_03.jpg", TASK_TYPES["OD_All"].format(""), "OD_All"],
    ["demo_image_03.jpg", TASK_TYPES["Viusal_Region_Reasoning"].format("What's the brand of this computer?"), "Viusal_Region_Reasoning"],
]


def get_valid_examples():
    """Return only the examples whose image files actually exist on disk."""
    valid_examples = []
    demo_dir = os.path.dirname(os.path.abspath(__file__))
    for example in EXAMPLES:
        img_path = example[0]
        full_path = os.path.join(demo_dir, img_path)
        if os.path.exists(full_path):
            valid_examples.append([full_path, example[1], example[2]])
        elif os.path.exists(img_path):
            valid_examples.append([img_path, example[1], example[2]])
    return valid_examples


def detect_model(image, threshold=0.3):
    """Run the UPN proposal model and keep at most 100 boxes above the score threshold."""
    proposals = upn_model.inference(image)
    filtered_proposals = upn_model.filter(proposals, min_score=threshold)
    return filtered_proposals['original_xyxy_boxes'][0][:100]


def multimodal_model(image, bboxes, text):
    """Query VLM-FO1 with the image, candidate boxes, and prompt, then parse the predicted regions."""
    # NOTE: the special-token literals in this function were lost when the source was
    # extracted; '<image>', '<ground>', and '<region{N}>' below are assumed placeholders
    # and must match the tokens actually used by prepare_inputs and the model vocabulary.
    if '<image>' in text:
        print(text)
        parts = [part.replace('\\n', '\n') for part in re.split(r'(<image>)', text) if part.strip()]
        print(parts)
        content = []
        for part in parts:
            if part == '<image>':
                content.append({"type": "image_url", "image_url": {"url": image}})
            else:
                content.append({"type": "text", "text": part})
    else:
        content = [
            {"type": "image_url", "image_url": {"url": image}},
            {"type": "text", "text": text},
        ]

    messages = [
        {
            "role": "user",
            "content": content,
            "bbox_list": bboxes,
        }
    ]
    generation_kwargs = prepare_inputs(
        model_path, model, image_processors, tokenizer, messages,
        max_tokens=4096, top_p=0.05, temperature=0.0, do_sample=False,
    )
    with torch.inference_mode():
        output_ids = model.generate(**generation_kwargs)
    outputs = tokenizer.decode(output_ids[0, generation_kwargs['inputs'].shape[1]:]).strip()
    print("========output========\n", outputs)

    if '<ground>' in outputs:
        prediction_dict = extract_predictions_to_indexes(outputs)
    else:
        # Fallback: pull bare region indices out of the raw text.
        match_pattern = r"<region(\d+)>"
        matches = re.findall(match_pattern, outputs)
        prediction_dict = {f"<region{m}>": {int(m)} for m in matches}

    ans_bbox_json = []
    ans_bbox_list = []
    for k, v in prediction_dict.items():
        for box_index in v:
            box_index = int(box_index)
            if box_index < len(bboxes):
                current_bbox = bboxes[box_index]
                ans_bbox_json.append({
                    "region_index": f"<region{box_index}>",
                    "xmin": current_bbox[0],
                    "ymin": current_bbox[1],
                    "xmax": current_bbox[2],
                    "ymax": current_bbox[3],
                    "label": k,
                })
                ans_bbox_list.append(current_bbox)
    return outputs, ans_bbox_json, ans_bbox_list


def draw_bboxes(image, bboxes, labels=None):
    """Draw a red rectangle for every box on a copy of the image."""
    image = image.copy()
    draw = ImageDraw.Draw(image)
    for bbox in bboxes:
        draw.rectangle(bbox, outline="red", width=3)
    return image


def extract_bbox_and_original_image(edited_image):
    """Extract original image and bounding boxes from ImageEditor output."""
    if edited_image is None:
        return None, []
    if isinstance(edited_image, dict):
        original_image = edited_image.get("background")
        bbox_list = []
        if original_image is None:
            return None, []
        if edited_image.get("layers") is None or len(edited_image.get("layers", [])) == 0:
            return original_image, []
        try:
            # Brush strokes live in the alpha channel of the first drawing layer.
            # Dilate the strokes, label connected components, and take each
            # component's bounding box as a user-specified region.
            drawing_layer = edited_image["layers"][0]
            alpha_channel = drawing_layer.getchannel('A')
            alpha_np = np.array(alpha_channel)
            binary_mask = alpha_np > 0
            structuring_element = disk(5)
            dilated_mask = binary_dilation(binary_mask, structuring_element)
            labeled_image = label(dilated_mask)
            regions = regionprops(labeled_image)
            for prop in regions:
                y_min, x_min, y_max, x_max = prop.bbox
                bbox_list.append((x_min, y_min, x_max, y_max))
        except Exception as e:
            print(f"Error extracting bboxes from layers: {e}")
            return original_image, []
        return original_image, bbox_list
    elif isinstance(edited_image, Image.Image):
        return edited_image, []
    else:
        print(f"Unknown input type: {type(edited_image)}")
        return None, []


@spaces.GPU
def process(image, example_image, prompt, threshold):
    image, bbox_list = extract_bbox_and_original_image(image)
    if example_image is not None:
        image = example_image
    if image is None:
        error_msg = "Error: Please upload an image or select a valid example."
        print(f"Error: image is None, original input type: {type(image)}")
        return None, None, error_msg, []
    try:
        image = image.convert('RGB')
    except Exception as e:
        error_msg = f"Error: Cannot process image - {str(e)}"
        return None, None, error_msg, []

    # Use user-drawn regions when available, otherwise fall back to the detection model.
    if len(bbox_list) == 0:
        bboxes = detect_model(image, threshold)
    else:
        bboxes = bbox_list

    # Append one region placeholder per candidate box to the prompt.
    # NOTE: the literal token was stripped from the original source; '<region{idx}>' is an assumption.
    for idx in range(len(bboxes)):
        prompt += f'<region{idx}>'

    ans, ans_bbox_json, ans_bbox_list = multimodal_model(image, bboxes, prompt)
    image_with_opn = draw_bboxes(image, bboxes)

    annotated_bboxes = []
    if len(ans_bbox_json) > 0:
        for item in ans_bbox_json:
            annotated_bboxes.append(
                ((int(item['xmin']), int(item['ymin']), int(item['xmax']), int(item['ymax'])), item['label'])
            )
    annotated_image = (image, annotated_bboxes)
    return annotated_image, image_with_opn, ans, ans_bbox_json


def show_label_input(choice):
    return gr.update(visible=(choice == "OmDet"))


def update_btn(is_processing):
    if is_processing:
        return gr.update(value="Processing...", interactive=False)
    else:
        return gr.update(value="Submit", interactive=True)


def launch_demo():
    with gr.Blocks() as demo:
        gr.Markdown("# 🚀 VLM-FO1 Demo")
        gr.Markdown("""
        ### 📋 Instructions

        **Step 1: Prepare Your Image**
        - Upload an image using the image editor below
        - *Optional:* Draw circular regions with the red brush to specify areas of interest
        - *Alternative:* If not drawing regions, the detection model will automatically identify regions

        **Step 2: Configure Your Task**
        - Select a task template from the dropdown menu
        - Replace `[WRITE YOUR INPUT HERE]` with your target objects or query
        - *Example:* For detecting "person" and "dog", replace with: `person, dog`
        - *Or:* Write your own custom prompt

        **Step 3: Fine-tune Detection** *(Optional)*
        - Adjust the detection threshold slider to control sensitivity

        **Step 4: Generate Results**
        - Click the **Submit** button to process your request
        - View the detection results and model outputs below

        🔗 [GitHub Repository](https://github.com/om-ai-lab/VLM-FO1)
        """)

        with gr.Row():
            with gr.Column():
                img_input_draw = gr.ImageEditor(
                    label="Image Input",
                    image_mode="RGBA",
                    type="pil",
                    sources=['upload'],
                    brush=gr.Brush(colors=["#FF0000"], color_mode="fixed", default_size=2),
                    interactive=True,
                )
                gr.Markdown("### Prompt & Parameters")

                def set_prompt_from_template(selected_task):
                    return gr.update(value=TASK_TYPES[selected_task].format("[WRITE YOUR INPUT HERE]"))

                def load_example(prompt_input, task_type_input, hidden_image_box):
                    # Rebuild the ImageEditor value from the cached example image,
                    # with an empty transparent drawing layer.
                    cached_image = deepcopy(hidden_image_box)
                    w, h = cached_image.size
                    transparent_layer = Image.new('RGBA', (w, h), (0, 0, 0, 0))
                    new_editor_value = {
                        "background": cached_image,
                        "layers": [transparent_layer],
                        "composite": None,
                    }
                    return new_editor_value, prompt_input, task_type_input

                def reset_hidden_image_box():
                    return gr.update(value=None)

                task_type_input = gr.Dropdown(
                    choices=list(TASK_TYPES.keys()),
                    value="OD/REC",
                    label="Prompt Templates",
                    info="Select the prompt template for the task, or write your own prompt.",
                )
                prompt_input = gr.Textbox(
                    label="Task Prompt",
                    value=TASK_TYPES["OD/REC"].format("[WRITE YOUR INPUT HERE]"),
                    lines=2,
                )
                task_type_input.select(
                    set_prompt_from_template,
                    inputs=task_type_input,
                    outputs=prompt_input,
                )
                hidden_image_box = gr.Image(label="Image", type="pil", image_mode="RGBA", visible=False)
                threshold_input = gr.Slider(
                    minimum=0.0, maximum=1.0, value=0.3, step=0.01, label="Detection Model Threshold"
                )
                submit_btn = gr.Button("Submit", variant="primary")

                valid_examples = get_valid_examples()
                if len(valid_examples) > 0:
                    gr.Markdown("### Examples")
                    gr.Markdown("Click on the examples below to quickly load images and corresponding prompts:")
                    examples_data = [[example[0], example[1], example[2]] for example in valid_examples]
                    examples = gr.Examples(
                        examples=examples_data,
                        inputs=[hidden_image_box, prompt_input, task_type_input],
                        label="Click to load example",
                        examples_per_page=5,
                    )
                    examples.load_input_event.then(
                        fn=load_example,
                        inputs=[prompt_input, task_type_input, hidden_image_box],
                        outputs=[img_input_draw, prompt_input, task_type_input],
                    )
                img_input_draw.upload(
                    fn=reset_hidden_image_box,
                    outputs=[hidden_image_box],
                )

            with gr.Column():
                with gr.Accordion("Detection Result", open=True):
                    image_output_opn = gr.Image(label="Detection Result", height=200)
                image_output = gr.AnnotatedImage(label="VLM-FO1 Result", height=400)
                result_output = gr.Textbox(label="VLM-FO1 Output", lines=5)
                ans_bbox_json = gr.JSON(label="Extracted Detection Output")

        # Disable the button while processing, run inference, then re-enable it.
        submit_btn.click(
            update_btn,
            inputs=[gr.State(True)],
            outputs=[submit_btn],
            queue=False,
        ).then(
            process,
            inputs=[img_input_draw, hidden_image_box, prompt_input, threshold_input],
            outputs=[image_output, image_output_opn, result_output, ans_bbox_json],
            queue=True,
        ).then(
            update_btn,
            inputs=[gr.State(False)],
            outputs=[submit_btn],
            queue=False,
        )
    return demo


if __name__ == "__main__":
    model_path = 'omlab/VLM-FO1_Qwen2.5-VL-3B-v01'
    upn_ckpt_path = "./resources/upn_large.pth"
    tokenizer, model, image_processors = load_pretrained_model(
        model_path=model_path,
        device="cuda:0",
    )
    upn_model = UPNWrapper(upn_ckpt_path)
    demo = launch_demo()
    demo.launch()