import functools
import math
import os
from collections import defaultdict

import numpy as np
import PIL
import torch
from PIL import Image, ImageDraw, ImageFile
from transformers import AutoModelForImageTextToText, AutoProcessor
import gradio as gr
import spaces
from molmo_utils import process_vision_info

from typing import Iterable
from gradio.themes import Soft
from gradio.themes.utils import colors, fonts, sizes

# Disable PIL's decompression-bomb pixel limit so very large screenshots can be opened,
# and let PIL decode files whose data is truncated instead of raising.
Image.MAX_IMAGE_PIXELS = None
ImageFile.LOAD_TRUNCATED_IMAGES = True

# ── Constants ──────────────────────────────────────────────────────────────────

MODEL_ID = "allenai/MolmoPoint-GUI-8B"
MAX_IMAGE_SIZE = 512  # used as the pixel height of the input/output galleries
POINT_SIZE = 0.01  # marker radius as a fraction of the image's longer side
MAX_NEW_TOKENS = 2048  # default value of the max_tokens slider

COLORS = [
    "rgb(255, 100, 180)",
    "rgb(100, 180, 255)",
    "rgb(180, 255, 100)",
    "rgb(255, 180, 100)",
    "rgb(100, 255, 180)",
    "rgb(180, 100, 255)",
    "rgb(255, 255, 100)",
    "rgb(100, 255, 255)",
    "rgb(255, 120, 120)",
    "rgb(120, 255, 255)",
    "rgb(255, 255, 120)",
    "rgb(255, 120, 255)",
]

# ── Model loading ──────────────────────────────────────────────────────────────

print(f"Loading {MODEL_ID}...")
processor = AutoProcessor.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    padding_side="left",
)
model = AutoModelForImageTextToText.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    dtype="bfloat16",
    device_map="auto",
)
print("Model loaded successfully.")


# ── Helper functions ───────────────────────────────────────────────────────────

def cast_float_bf16(t: torch.Tensor):
    """Cast a floating-point tensor to bfloat16; return non-float tensors unchanged."""
    if torch.is_floating_point(t):
        t = t.to(torch.bfloat16)
    return t


def draw_points(image, points, color=COLORS[0]):
    """Return an image with a filled circle drawn at every ``(x, y)`` point.

    Args:
        image: A PIL image (drawn on a copy) or a NumPy array, which is converted
            with ``PIL.Image.fromarray``.
        points: Iterable of ``(x, y)`` pixel coordinates.
        color: Fill color used for every marker; defaults to the first palette entry.

    Returns:
        The annotated PIL image.
    """
    if isinstance(image, np.ndarray):
        annotation = PIL.Image.fromarray(image)
    else:
        annotation = image.copy()
    draw = ImageDraw.Draw(annotation)
    w, h = annotation.size
    # Scale the marker with the image so it stays visible on large screenshots,
    # but never go below a 5px radius.
    radius = max(5, int(max(w, h) * POINT_SIZE))
    for x, y in points:
        draw.ellipse((x - radius, y - radius, x + radius, y + radius), fill=color, outline=None)
    return annotation


def format_points_list(points):
    """Format extracted points as a flat Python list string.

    Each point ``(object_id, image_index, x, y)`` is rendered as
    ``[object_id, image_index, x, y]`` with the first two fields cast to ``int`` and the
    coordinates shown with one decimal place. Returns ``"[]"`` when there are no points.
    """
    # Explicit None/len check instead of truthiness so array-like results also work.
    if points is None or len(points) == 0:
        return "[]"
    rows = [
        f"[{int(object_id)}, {int(ix)}, {float(x):.1f}, {float(y):.1f}]"
        for object_id, ix, x, y in points
    ]
    return "[" + ", ".join(rows) + "]"


def _load_pil_images(input_images):
    """Open every gallery entry as an RGB PIL image, closing each source file."""
    pil_images = []
    for entry in input_images:
        # Gallery entries may be (filepath, caption) tuples; keep only the path.
        path = entry[0] if isinstance(entry, tuple) else entry
        # convert() returns a fully loaded copy, so the file can be closed right away.
        with Image.open(path) as img:
            pil_images.append(img.convert("RGB"))
    return pil_images


def _annotate_images(images, points):
    """Draw each extracted point on the image it refers to.

    Returns one annotated image per entry of *images*, in input order, so images that
    received no point still appear in the output gallery.
    """
    points_by_image = defaultdict(list)
    for _object_id, ix, x, y in points:
        # int() normalises tensor / numpy scalars so they group under one key.
        points_by_image[int(ix)].append((x, y))
    return [draw_points(img, points_by_image.get(i, ())) for i, img in enumerate(images)]


# ── Inference functions ────────────────────────────────────────────────────────

@spaces.GPU
def process_images(user_text, input_images, max_tokens):
    """Run the model on the uploaded images and return its answer plus annotations.

    Args:
        user_text: The natural-language instruction.
        input_images: Gallery value: a list of file paths or ``(path, caption)`` tuples.
        max_tokens: Upper bound on the number of newly generated tokens.

    Returns:
        ``(generated_text, gallery_images, points_string)``.
    """
    if not input_images:
        return "Please upload at least one image.", [], "[]"

    pil_images = _load_pil_images(input_images)

    # Build a single user turn: the instruction followed by every image.
    content = [dict(type="text", text=user_text)]
    content.extend(dict(type="image", image=img) for img in pil_images)
    messages = [{"role": "user", "content": content}]

    # Process inputs
    images, _, _ = process_vision_info(messages)
    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    print(f"Prompt: {text}")
    inputs = processor(
        images=images,
        text=text,
        padding=True,
        return_tensors="pt",
        return_pointing_metadata=True,
    )
    # Kept aside: the metadata is consumed by extract_image_points below, not by generate().
    metadata = inputs.pop("metadata")
    inputs = {k: cast_float_bf16(v.to(model.device)) for k, v in inputs.items()}

    # Generate (greedy decoding)
    with torch.inference_mode(), torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
        output = model.generate(
            **inputs,
            logits_processor=model.build_logit_processor_from_inputs(inputs),
            max_new_tokens=int(max_tokens),
            do_sample=False,
        )

    # The processor pads on the left, so everything after the prompt length is new text.
    generated_tokens = output[0, inputs["input_ids"].size(1):]
    generated_text = processor.decode(
        generated_tokens, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    # Extract points
    points = model.extract_image_points(
        generated_text,
        metadata["token_pooling"],
        metadata["subpatch_mapping"],
        metadata["image_sizes"],
    )
    points_table = format_points_list(points)
    print(f"Output text: {generated_text}")
    print("Extracted points:", points_table)

    if points is None or len(points) == 0:
        return generated_text, pil_images, points_table
    return generated_text, _annotate_images(images, points), points_table


# ── Gradio UI ──────────────────────────────────────────────────────────────────

css = """
#col-container { margin: 0 auto; max-width: 960px; }
#main-title h1 {font-size: 2.3em !important;}
#input_image img { object-fit: contain !important; }
.gallery-item img { border: none !important; outline: none !important; }
"""

with gr.Blocks() as demo:
    gr.Markdown("# **MolmoPoint-GUI-8B Demo (GUI-Specialized)**", elem_id="main-title")
    gr.Markdown(
        "Single-point prediction on GUI screenshots using the "
        "[MolmoPoint-GUI-8B](https://huggingface.co/allenai/MolmoPoint-GUI-8B) model. "
        "Given a natural language instruction, the model predicts the single UI element to click."
    )

    with gr.Row():
        # ── LEFT COLUMN: Inputs ──
        with gr.Column():
            images_input = gr.Gallery(
                label="Input Images",
                elem_id="input_image",
                type="filepath",
                height=MAX_IMAGE_SIZE,
            )
            input_text = gr.Textbox(placeholder="Enter the prompt", label="Input text")
            max_tok_slider = gr.Slider(
                label="max_tokens", minimum=1, maximum=4096, step=1, value=MAX_NEW_TOKENS
            )
            with gr.Row():
                submit_button = gr.Button("Submit", variant="primary", scale=3)
                clear_all_button = gr.ClearButton(
                    components=[images_input, input_text],
                    value="Clear All",
                    scale=1,
                )

        # ── RIGHT COLUMN: Outputs ──
        with gr.Column():
            with gr.Tabs():
                with gr.TabItem("Output Text"):
                    output_text = gr.Textbox(placeholder="Output text", label="Output text", lines=10)
                with gr.TabItem("Extracted Points"):
                    output_points = gr.Textbox(
                        label="Extracted Points ([[id, index, x, y]])",
                        lines=15,
                    )
            with gr.Group():
                gr.Markdown("*Click a frame to zoom in. Press Esc to go back.*")
                output_annotations_img = gr.Gallery(label="Annotated Images", height=MAX_IMAGE_SIZE)

    # "Clear All" should also reset the results, not just the inputs.
    clear_all_button.add([output_text, output_points, output_annotations_img])

    # ── Examples ──
    with gr.Group():
        gr.Markdown("### Image Examples")
        gr.Examples(
            examples=[
                [["example-images/example-1.png"], "open the attachment folder"],
                [["example-images/example-2.png"], "check new york knicks"],
                [["example-images/example-3.jpg"], "change the smoothing percentage"],
                [["example-images/example-4.png"], "click the cell F-11"],
                [["example-images/example-5.png"], "point to section 303"],
                [["example-images/example-6.jpg"], "change profile photo"],
            ],
            inputs=[images_input, input_text],
            label="Image Pointing Examples",
        )

    submit_button.click(
        fn=process_images,
        inputs=[input_text, images_input, max_tok_slider],
        outputs=[output_text, output_annotations_img, output_points],
    )

if __name__ == "__main__":
    demo.launch(css=css, mcp_server=True, ssr_mode=False, show_error=True, share=True)