# MolmoPoint-GUI-8B demo — Hugging Face Space (ZeroGPU hardware, "Running on Zero").
| import functools | |
| import math | |
| import os | |
| from collections import defaultdict | |
| import numpy as np | |
| import PIL | |
| import torch | |
| from PIL import Image, ImageDraw, ImageFile | |
| from transformers import AutoModelForImageTextToText, AutoProcessor | |
| import gradio as gr | |
| import spaces | |
| from molmo_utils import process_vision_info | |
| from typing import Iterable | |
| from gradio.themes import Soft | |
| from gradio.themes.utils import colors, fonts, sizes | |
# Accept arbitrarily large uploads (e.g. very tall GUI screenshots): disable
# PIL's decompression-bomb guard and tolerate truncated image files.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True
# ── Constants ──────────────────────────────────────────────────────────────────
MODEL_ID = "allenai/MolmoPoint-GUI-8B"  # Hugging Face Hub id of the pointing model
MAX_IMAGE_SIZE = 512      # display height (px) of the input/output galleries
POINT_SIZE = 0.01         # marker radius as a fraction of the longest image side
MAX_NEW_TOKENS = 2048     # default generation budget for the slider
# Palette of marker colors for annotated points (CSS rgb() strings, PIL-parsable).
COLORS = [
    "rgb(255, 100, 180)",
    "rgb(100, 180, 255)",
    "rgb(180, 255, 100)",
    "rgb(255, 180, 100)",
    "rgb(100, 255, 180)",
    "rgb(180, 100, 255)",
    "rgb(255, 255, 100)",
    "rgb(100, 255, 255)",
    "rgb(255, 120, 120)",
    "rgb(120, 255, 255)",
    "rgb(255, 255, 120)",
    "rgb(255, 120, 255)",
]
# ── Model loading ──────────────────────────────────────────────────────────────
print(f"Loading {MODEL_ID}...")
# Left padding so batched generation keeps prompts right-aligned with the
# generated continuation.
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    padding_side="left",
)
# NOTE(review): `dtype=` is the newer transformers keyword; older releases use
# `torch_dtype=` — confirm against the pinned transformers version.
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    dtype="bfloat16",
    device_map="auto",
)
print("Model loaded successfully.")
| # ββ Helper functions βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def cast_float_bf16(t: torch.Tensor) -> torch.Tensor:
    """Return *t* as bfloat16 when it is floating-point, otherwise unchanged.

    Integer/bool tensors (input_ids, attention masks, ...) must keep their
    dtype, so only floating tensors are converted.
    """
    return t.to(torch.bfloat16) if torch.is_floating_point(t) else t
def draw_points(image, points):
    """Draw one filled circular marker per point on a copy of *image*.

    Args:
        image: a ``PIL.Image.Image`` or an HxWxC numpy array; never mutated.
        points: iterable of ``(x, y)`` pixel coordinates.

    Returns:
        A new ``PIL.Image.Image`` with the markers drawn, cycling through
        ``COLORS`` so successive points are visually distinct.
    """
    if isinstance(image, np.ndarray):
        annotation = PIL.Image.fromarray(image)
    else:
        annotation = image.copy()
    draw = ImageDraw.Draw(annotation)
    w, h = annotation.size
    # Scale the marker with the image, but keep it visible on tiny images.
    size = max(5, int(max(w, h) * POINT_SIZE))
    for i, (x, y) in enumerate(points):
        # Bug fix: the loop index was computed but COLORS[0] was always used,
        # so every point came out the same color despite the 12-entry palette.
        color = COLORS[i % len(COLORS)]
        draw.ellipse((x - size, y - size, x + size, y + size), fill=color, outline=None)
    return annotation
def format_points_list(points):
    """Render extracted points as a flat Python-list-style string.

    Each ``(object_id, index, x, y)`` tuple becomes ``"[id, index, x, y]"``
    with integer id/index and coordinates fixed to one decimal place; the
    whole result is wrapped in a single pair of brackets ("[]" when empty).
    """
    entries = ", ".join(
        f"[{int(obj)}, {int(idx)}, {float(px):.1f}, {float(py):.1f}]"
        for obj, idx, px, py in points
    )
    return f"[{entries}]"
| # ββ Inference functions ββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
def process_images(user_text, input_images, max_tokens):
    """Run MolmoPoint generation over the uploaded images and extract points.

    Args:
        user_text: natural-language pointing instruction.
        input_images: Gradio gallery value — a list of file paths or
            ``(path, caption)`` tuples.
        max_tokens: generation budget (slider value; may arrive as float).

    Returns:
        ``(generated_text, gallery_images, points_table)`` where
        ``gallery_images`` are annotated copies when points were extracted
        (otherwise the raw uploads) and ``points_table`` is a
        ``"[[id, index, x, y], ...]"`` string.
    """
    # NOTE(review): on ZeroGPU Spaces this function normally needs the
    # @spaces.GPU decorator to have a GPU allocated — confirm deployment target.
    if not input_images:
        return "Please upload at least one image.", [], "[]"
    pil_images = []
    for item in input_images:
        # Gallery items may be (path, caption) tuples.
        path = item[0] if isinstance(item, tuple) else item
        # Use a context manager so the file handle is closed promptly
        # (the original leaked one handle per upload).
        with Image.open(path) as img:
            pil_images.append(img.convert("RGB"))
    # Build a single user turn: the instruction first, then every image.
    content = [dict(type="text", text=user_text)]
    content.extend(dict(type="image", image=img) for img in pil_images)
    messages = [{"role": "user", "content": content}]
    # Process inputs
    images, _, _ = process_vision_info(messages)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(f"Prompt: {text}")
    inputs = processor(
        images=images,
        text=text,
        padding=True,
        return_tensors="pt",
        return_pointing_metadata=True,
    )
    # Pointing metadata is consumed by extract_image_points below, not by
    # generate(), so pop it before moving tensors to the device.
    metadata = inputs.pop("metadata")
    inputs = {k: cast_float_bf16(v.to(model.device)) for k, v in inputs.items()}
    # Generate
    with torch.inference_mode():
        with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
            output = model.generate(
                **inputs,
                logits_processor=model.build_logit_processor_from_inputs(inputs),
                max_new_tokens=int(max_tokens),
                # Greedy decoding. The original passed temperature=0, which
                # recent transformers either warns on or rejects; do_sample=False
                # is the supported way to request deterministic decoding.
                do_sample=False,
            )
    # Keep only the newly generated continuation (strip the prompt tokens).
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    generated_text = processor.decode(generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False)
    # Extract points
    points = model.extract_image_points(
        generated_text,
        metadata["token_pooling"],
        metadata["subpatch_mapping"],
        metadata["image_sizes"],
    )
    points_table = format_points_list(points)
    print(f"Output text: {generated_text}")
    print("Extracted points:", points_table)
    if points:
        # Group points by the image index they refer to, then annotate each.
        group_by_index = defaultdict(list)
        for _object_id, ix, x, y in points:
            group_by_index[ix].append((x, y))
        annotated = [draw_points(images[ix], pts) for ix, pts in group_by_index.items()]
        return generated_text, annotated, points_table
    return generated_text, pil_images, points_table
| # ββ Gradio UI ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
css = """
#col-container {
    margin: 0 auto;
    max-width: 960px;
}
#main-title h1 {font-size: 2.3em !important;}
#input_image image {
    object-fit: contain !important;
}
.gallery-item img {
    border: none !important;
    outline: none !important;
}
"""
# Bug fix: custom CSS must be passed to the Blocks constructor; Blocks.launch()
# has no `css` parameter, so the stylesheet above was never applied.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# **MolmoPoint-GUI-8B Demo (GUI-Specialized)**", elem_id="main-title")
    gr.Markdown(
        "Single-point prediction on GUI screenshots using the "
        "[MolmoPoint-GUI-8B](https://huggingface.co/allenai/MolmoPoint-GUI-8B) model. "
        "Given a natural language instruction, the model predicts the single UI element to click."
    )
    with gr.Row():
        # ── LEFT COLUMN: Inputs ──
        with gr.Column():
            images_input = gr.Gallery(
                label="Input Images", elem_id="input_image", type="filepath", height=MAX_IMAGE_SIZE,
            )
            input_text = gr.Textbox(placeholder="Enter the prompt", label="Input text")
            max_tok_slider = gr.Slider(label="max_tokens", minimum=1, maximum=4096, step=1, value=MAX_NEW_TOKENS)
            with gr.Row():
                submit_button = gr.Button("Submit", variant="primary", scale=3)
                clear_all_button = gr.ClearButton(
                    components=[images_input, input_text], value="Clear All", scale=1,
                )
        # ── RIGHT COLUMN: Outputs ──
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Output Text"):
                    output_text = gr.Textbox(placeholder="Output text", label="Output text", lines=10)
                with gr.TabItem("Extracted Points"):
                    output_points = gr.Textbox(
                        label="Extracted Points ([[id, index, x, y]])", lines=15,
                    )
            with gr.Group():
                gr.Markdown("*Click a frame to zoom in. Press Esc to go back.*")
                output_annotations_img = gr.Gallery(label="Annotated Images", height=MAX_IMAGE_SIZE)
    # ── Examples ──
    with gr.Group():
        gr.Markdown("### Image Examples")
        gr.Examples(
            examples=[
                [["example-images/example-1.png"], "open the attachment folder"],
                [["example-images/example-2.png"], "check new york knicks"],
                [["example-images/example-3.jpg"], "change the smoothing percentage"],
                [["example-images/example-4.png"], "click the cell F-11"],
                [["example-images/example-5.png"], "point to section 303"],
                [["example-images/example-6.jpg"], "change profile photo"],
            ],
            inputs=[images_input, input_text],
            label="Image Pointing Examples",
        )
    submit_button.click(
        fn=process_images,
        inputs=[input_text, images_input, max_tok_slider],
        outputs=[output_text, output_annotations_img, output_points],
    )
if __name__ == "__main__":
    demo.launch(mcp_server=True, ssr_mode=False, show_error=True, share=True)