import os
import re
import traceback
from datetime import datetime
from typing import Any, Literal

import gradio as gr
import numpy as np
import requests
import spaces
import torch
from PIL import Image, ImageDraw
from pydantic import BaseModel, Field
from transformers import AutoProcessor
from transformers.models.auto.modeling_auto import AutoModelForImageTextToText
from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize

MODEL_ID = "Hcompany/Holo1.5-7B"

print(f"Loading model and processor for {MODEL_ID}...")
model = None
processor = None
model_loaded = False
load_error_message = ""

try:
    model = AutoModelForImageTextToText.from_pretrained(
        MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
    ).to("cuda")
    processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
    model_loaded = True
    print("Model and processor loaded successfully.")
except Exception as e:
    load_error_message = (
        f"Error loading model/processor: {e}\n"
        "This might be due to network issues, an incorrect model ID, or a missing dependency "
        "(e.g. flash_attention_2 if the model config enables it by default).\n"
        "Ensure you have a stable internet connection and the necessary libraries installed."
    )
    print(load_error_message)
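
# Note: model_loaded / load_error_message are retained so the UI could surface a
# friendly error if loading failed; the handlers below assume loading succeeded.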


title = "Holo1.5-7B: Localization VLM Demo"

description = """
This demo showcases [**Holo1.5-7B**](https://huggingface.co/Hcompany/Holo1.5-7B), a new version of the Action Vision-Language Model developed by HCompany, fine-tuned from Qwen/Qwen2.5-VL-7B-Instruct.
It is designed to perform complex navigation tasks in Web, Android, and Desktop interfaces.

**How to use:**
1. Upload an image (e.g., a screenshot of a UI; see the example below).
2. Describe the target UI element (e.g., "Docs tab").
3. The model predicts the coordinates of the element on the screenshot.

The model's processor resizes your input image; predicted coordinates are relative to this resized image.
"""


def array_to_image_path(image_array):
    """Save an uploaded numpy image to a timestamped PNG and return its absolute path."""
    if image_array is None:
        raise ValueError("No image provided. Please upload an image before submitting.")

    img = Image.fromarray(np.uint8(image_array))

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"image_{timestamp}.png"
    img.save(filename)

    return os.path.abspath(filename)


LOCALIZATION_PROMPT: str = """Localize an element on the GUI image according to the provided target and output a click position.
* Only output the click position, do not output any other text.
* The click position should be in the format 'Click(x, y)' with x: num pixels from the left edge and y: num pixels from the top edge
Your target is:"""
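
# With this prompt, the model is expected to reply with a single string of the form
# 'Click(x, y)', e.g. 'Click(1132, 84)' (illustrative values only).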


class ClickAbsoluteAction(BaseModel):
    """Click at absolute coordinates."""

    action: Literal["click_absolute"] = "click_absolute"
    x: int = Field(description="The x coordinate, number of pixels from the left edge.")
    y: int = Field(description="The y coordinate, number of pixels from the top edge.")
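

# Illustrative sketch (not wired into the demo): the schema above is unused in this
# script, but raw model output could be parsed into it like so. The helper name
# `parse_click_action` is an assumption, not part of the original app.
def parse_click_action(raw_output: str) -> ClickAbsoluteAction | None:
    match = re.search(r"Click\((\d+),\s*(\d+)\)", raw_output)
    if match is None:
        return None
    return ClickAbsoluteAction(x=int(match.group(1)), y=int(match.group(2)))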


def get_localization_prompt(component, image, step=1):
    """
    Build the chat-format prompt for the localization task.
    - component: The UI component to localize
    - image: The current screenshot of the page
    - step: The current step of the task (currently unused)
    """
    return [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {"type": "text", "text": LOCALIZATION_PROMPT + "\n" + component},
            ],
        },
    ]


def array_to_image(image_array: np.ndarray) -> Image.Image:
    if image_array is None:
        raise ValueError("No image provided. Please upload an image before submitting.")

    return Image.fromarray(np.uint8(image_array))


@spaces.GPU(duration=20)
def run_inference_localization(
    messages_for_template: list[dict[str, Any]], pil_image_for_processing: Image.Image
) -> str:
    """
    Runs inference using the Holo1.5 model.
    - messages_for_template: The prompt structure, potentially including the PIL image object
      (which apply_chat_template converts to an image tag).
    - pil_image_for_processing: The actual PIL image to be processed into tensors.
    """
    model.to("cuda")
    torch.cuda.set_device(0)

    text_prompt = processor.apply_chat_template(messages_for_template, tokenize=False, add_generation_prompt=True)

    inputs = processor(
        text=[text_prompt],
        images=[pil_image_for_processing],
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)

    # Trim the prompt tokens so only newly generated tokens are decoded.
    generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]

    decoded_output = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return decoded_output[0] if decoded_output else ""
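
# Minimal usage sketch (illustrative; assumes model/processor loaded and a GPU):
#   img = Image.open("screenshot.png")  # hypothetical path
#   raw = run_inference_localization(get_localization_prompt("Docs tab", img), img)
# Note: localize() below resizes with smart_resize first so the model's pixel
# coordinates match the image shown to the user.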


def localize(input_numpy_image: np.ndarray, task: str) -> tuple[str, Image.Image]:
    """Run localization for `task` on the uploaded image; return the raw model output and an annotated copy."""
    input_pil_image = array_to_image(input_numpy_image)
    assert isinstance(input_pil_image, Image.Image)

    image_proc_config = processor.image_processor
    try:
        # Mirror the processor's own resizing so that the model's pixel coordinates
        # map directly onto the image shown to the user.
        resized_height, resized_width = smart_resize(
            input_pil_image.height,
            input_pil_image.width,
            factor=image_proc_config.patch_size * image_proc_config.merge_size,
            min_pixels=image_proc_config.min_pixels,
            max_pixels=image_proc_config.max_pixels,
        )
        resized_image = input_pil_image.resize(
            size=(resized_width, resized_height),
            resample=Image.Resampling.LANCZOS,
        )
    except Exception as e:
        print(f"Error resizing image: {e}")
        return f"Error resizing image: {e}", input_pil_image.copy().convert("RGB")

    prompt = get_localization_prompt(task, resized_image, step=1)

    print("Prompt:")
    print(prompt)

    try:
        localization = run_inference_localization(prompt, resized_image)
    except Exception as e:
        print(f"Error during model inference: {e}")
        return f"Error during model inference: {e}", resized_image.copy().convert("RGB")

    # Parse the expected 'Click(x, y)' output and draw a marker at that position.
    output_image_with_click = resized_image.copy().convert("RGB")
    match = re.search(r"Click\((\d+),\s*(\d+)\)", localization)
    if match:
        try:
            x = int(match.group(1))
            y = int(match.group(2))
            draw = ImageDraw.Draw(output_image_with_click)
            radius = max(5, min(resized_width // 100, resized_height // 100, 15))
            bbox = (x - radius, y - radius, x + radius, y + radius)
            draw.ellipse(bbox, outline="red", width=max(2, radius // 4))
            print(f"Predicted and drawn click at: ({x}, {y}) on resized image ({resized_width}x{resized_height})")
        except Exception as e:
            print(f"Error drawing on image: {e}")
            traceback.print_exc()
    else:
        print(f"Could not parse 'Click(x, y)' from model output: {localization}")

    return localization, output_image_with_click
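

# Illustrative helper (assumption, not used by the UI): predicted coordinates are
# relative to the smart_resize'd image, so mapping them back onto the original
# screenshot is a simple proportional rescale.
def to_original_coords(x: int, y: int, resized_w: int, resized_h: int, orig_w: int, orig_h: int) -> tuple[int, int]:
    return round(x * orig_w / resized_w), round(y * orig_h / resized_h)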


example_image_url = "https://huggingface.co/spaces/Hcompany/Holo1.5-Localization/resolve/main/desktop_3.png"
example_image = Image.open(requests.get(example_image_url, stream=True).raw)
example_task = "Email quote for Hyundai Kona"


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
    gr.Markdown(description)

    with gr.Row():
        with gr.Column():
            input_image_component = gr.Image(label="Input UI Image", height=400)
            task_component = gr.Textbox(
                label="Component",
                placeholder="Email quote for Hyundai Kona",
                info="Describe the UI component to find.",
            )
            submit_button = gr.Button("Localize", variant="primary")

        with gr.Column():
            output_coords_component = gr.Textbox(label="Localization Step")
            output_image_component = gr.Image(
                type="pil", label="Image with coordinates of the component", height=400, interactive=False
            )

    submit_button.click(
        localize, [input_image_component, task_component], [output_coords_component, output_image_component]
    )

    gr.Examples(
        examples=[[example_image, example_task]],
        inputs=[input_image_component, task_component],
        outputs=[output_coords_component, output_image_component],
        fn=localize,
        cache_examples="lazy",
    )

demo.queue(api_open=False)
demo.launch(debug=True)