Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import re | |
| import json | |
| import gc | |
| import time | |
| import shutil | |
| import uuid | |
| import tempfile | |
| import unicodedata | |
| from io import BytesIO | |
| from typing import Tuple, Optional, List, Dict, Any | |
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import spaces | |
| from PIL import Image, ImageDraw, ImageFont | |
| # Transformers & Qwen Utils | |
| from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor | |
| from qwen_vl_utils import process_vision_info | |
| # ----------------------------------------------------------------------------- | |
| # 1. CONSTANTS & SYSTEM PROMPT | |
| # ----------------------------------------------------------------------------- | |
# Maps the labels shown in the UI dropdown to Hugging Face model repo ids.
MODEL_MAP = {
    "Fara-7B": "microsoft/Fara-7B",
    # The label, the UI description and the Qwen2.5-VL loader class used by
    # this app all refer to UI-TARS 1.5, so the id points at that checkpoint.
    # NOTE(review): the previous value referenced a "UI-TARS-7B-SFT" repo,
    # which did not match the label -- confirm the id against the model card.
    "UI-TARS-1.5-7B": "ByteDance-Seed/UI-TARS-1.5-7B",
}

# Fall back to CPU when no CUDA device is visible to torch.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Module-level cache holding the single model that is currently loaded,
# its processor, and the UI label it was loaded under.
CURRENT_MODEL = None
CURRENT_PROCESSOR = None
CURRENT_MODEL_NAME = None

# System prompt asking the model to answer with <tool_call> JSON blocks, the
# format that parse_tool_calls() understands.
OS_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and a screenshot of the current status.
You need to generate the next action to complete the task.
Output your action inside a <tool_call> block using JSON format.
Include "coordinate": [x, y] in pixels for interactions.
Examples:
<tool_call>
{"name": "User", "arguments": {"action": "click", "coordinate": [400, 300]}}
</tool_call>
<tool_call>
{"name": "User", "arguments": {"action": "type", "coordinate": [100, 200], "text": "hello"}}
</tool_call>
"""
| # ----------------------------------------------------------------------------- | |
| # 2. MODEL LOADING LOGIC | |
| # ----------------------------------------------------------------------------- | |
def load_model_to_device(model_name: str):
    """Return the ``(model, processor)`` pair for ``model_name``, loading it if needed.

    Only one model is kept in the module-level cache at a time. A request for
    a different model first drops the cached one (and releases cached CUDA
    memory), then loads the requested checkpoint and stores it in the cache.

    Args:
        model_name: A key of ``MODEL_MAP``. Any other value is passed to
            ``from_pretrained`` unchanged as the model id.

    Returns:
        Tuple ``(model, processor)``; the model is put in eval mode.

    Raises:
        Exception: Any error raised while loading is logged and re-raised.
            The cache is left empty in that case.
    """
    global CURRENT_MODEL, CURRENT_PROCESSOR, CURRENT_MODEL_NAME

    target_id = MODEL_MAP.get(model_name, model_name)

    # Cache hit: the requested model is already resident.
    if CURRENT_MODEL_NAME == model_name and CURRENT_MODEL is not None:
        return CURRENT_MODEL, CURRENT_PROCESSOR

    print(f"🔄 Switching model to: {model_name} ({target_id})...")

    # 1. Drop the previous model before loading the next one to save memory.
    if CURRENT_MODEL is not None:
        CURRENT_MODEL = None
        CURRENT_PROCESSOR = None
        # Clear the name as well so the cache never advertises a model that
        # is no longer loaded (e.g. if the load below fails).
        CURRENT_MODEL_NAME = None
        gc.collect()
        torch.cuda.empty_cache()
        print("🗑️ Previous model unloaded.")

    # 2. Load the requested model.
    # NOTE(review): trust_remote_code=True executes code shipped in the model
    # repo, and model_name values outside MODEL_MAP reach from_pretrained
    # unchanged -- make sure callers cannot supply arbitrary ids.
    try:
        processor = AutoProcessor.from_pretrained(target_id, trust_remote_code=True)
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            target_id,
            trust_remote_code=True,
            torch_dtype=torch.bfloat16 if DEVICE == "cuda" else torch.float32,
            device_map="auto" if DEVICE == "cuda" else None,
        )
        if DEVICE == "cpu":
            model.to("cpu")
        model.eval()
    except Exception as e:
        print(f"❌ Error loading {model_name}: {e}")
        # Bare raise keeps the original traceback intact.
        raise

    CURRENT_MODEL = model
    CURRENT_PROCESSOR = processor
    CURRENT_MODEL_NAME = model_name
    print(f"✅ {model_name} loaded successfully.")
    return model, processor
def generate_response(model, processor, messages, max_new_tokens=512):
    """Run one generation pass over ``messages`` and return the decoded reply.

    The chat template is rendered to text, the vision inputs are extracted
    from the same messages, and both are fed through ``processor`` before
    calling ``model.generate``. Only the first decoded sequence is returned.
    """
    # Render the conversation to a prompt string (no tokenization yet).
    prompt_text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Pull the image / video payloads out of the messages.
    images, videos = process_vision_info(messages)

    # Build the tensor batch and move it to wherever the model lives.
    batch = processor(
        text=[prompt_text],
        images=images,
        videos=videos,
        padding=True,
        return_tensors="pt",
    )
    batch = batch.to(model.device)

    with torch.no_grad():
        output_ids = model.generate(**batch, max_new_tokens=max_new_tokens)

    # Skip the first len(prompt) ids of every output sequence so that only
    # the part after the prompt is decoded.
    completions = []
    for prompt_ids, full_ids in zip(batch.input_ids, output_ids):
        completions.append(full_ids[len(prompt_ids) :])

    decoded = processor.batch_decode(
        completions, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return decoded[0]
| # ----------------------------------------------------------------------------- | |
| # 3. PARSING & VISUALIZATION LOGIC | |
| # ----------------------------------------------------------------------------- | |
def array_to_image(image_array: np.ndarray) -> Image.Image:
    """Convert an image array to a PIL image, casting its values to uint8.

    Raises:
        ValueError: If ``image_array`` is ``None``.
    """
    if image_array is not None:
        pixels = np.uint8(image_array)
        return Image.fromarray(pixels)
    raise ValueError("No image provided. Please upload an image.")
def get_navigation_prompt(task, image):
    """Build the two-message chat (system + user) for one navigation step.

    The system message carries ``OS_SYSTEM_PROMPT``; the user message carries
    the screenshot followed by the task wrapped as ``"Instruction: <task>"``.
    """
    system_message = {
        "role": "system",
        "content": [{"type": "text", "text": OS_SYSTEM_PROMPT}],
    }
    user_message = {
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": f"Instruction: {task}"},
        ],
    }
    return [system_message, user_message]
def parse_tool_calls(response: str) -> list[dict]:
    """Extract coordinate-bearing actions from ``<tool_call>{JSON}</tool_call>`` blocks.

    Each block is expected to hold a JSON object whose ``"arguments"`` entry
    contains ``"coordinate": [x, y]`` and optionally ``"action"`` and
    ``"text"``. Blocks that are not valid JSON, are not shaped like that, or
    carry non-numeric coordinates are logged and skipped rather than raising.

    Args:
        response: Raw model output. ``None`` or an empty string yields ``[]``.

    Returns:
        One dict per usable block with the keys ``type``, ``x``, ``y``
        (floats), ``text`` and ``raw_json`` (the decoded JSON object).
    """
    actions = []

    # Non-greedy + DOTALL: each tag pair is matched separately and may span lines.
    matches = re.findall(r"<tool_call>(.*?)</tool_call>", response or "", re.DOTALL)

    for match in matches:
        json_str = match.strip()
        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            print(f"Failed to parse JSON: {e}")
            continue

        # Valid JSON is not necessarily an object (it may be a list, a number,
        # ...), and "arguments" may be null or a scalar.
        args = data.get("arguments", {}) if isinstance(data, dict) else None
        if not isinstance(args, dict):
            print(f"Non-coordinate action or invalid: {json_str}")
            continue

        coords = args.get("coordinate", [])
        action_type = args.get("action", "unknown")
        text_content = args.get("text", "")

        # The action name is later used as a dict key and a label, so fall
        # back to "unknown" for anything that is not a string.
        if not isinstance(action_type, str):
            action_type = "unknown"

        if not (isinstance(coords, list) and len(coords) == 2):
            # Some actions like 'scroll' might not have coordinates in some models
            print(f"Non-coordinate action or invalid: {json_str}")
            continue

        try:
            x, y = float(coords[0]), float(coords[1])
        except (TypeError, ValueError):
            print(f"Non-coordinate action or invalid: {json_str}")
            continue

        actions.append({
            "type": action_type,
            "x": x,
            "y": y,
            "text": text_content,
            "raw_json": data,
        })
        print(f"Parsed Action: {action_type} at {coords}")

    return actions
def create_localized_image(original_image: Image.Image, actions: list[dict]) -> Optional[Image.Image]:
    """Draw a marker and a label for every action on a copy of the image.

    Args:
        original_image: Screenshot to annotate; it is copied, not modified.
        actions: Dicts with the keys ``x``, ``y``, ``type`` and ``text`` as
            produced by ``parse_tool_calls``.

    Returns:
        The annotated copy, or ``None`` when ``actions`` is empty.
    """
    if not actions:
        return None

    img_copy = original_image.copy()
    draw = ImageDraw.Draw(img_copy)
    width, height = img_copy.size

    # Best effort: when the default font cannot be loaded, font=None is
    # passed on to the drawing calls instead of failing.
    try:
        font = ImageFont.load_default()
    except Exception:
        font = None

    colors = {
        'type': 'blue',
        'click': 'red',
        'left_click': 'red',
        'right_click': 'purple',
        'double_click': 'orange',
        'unknown': 'green',
    }
    radius = 15  # Radius of the outer target ring, in pixels.

    for act in actions:
        x = act['x']
        y = act['y']

        # Values inside [0, 1] on BOTH axes are treated as normalized and
        # scaled by the image size; everything else is taken as absolute
        # pixels. The test is symmetric so a point on the left/top edge
        # (x == 0 or y == 0) is classified the same way as its neighbours.
        if 0.0 <= x <= 1.0 and 0.0 <= y <= 1.0:
            pixel_x = int(x * width)
            pixel_y = int(y * height)
        else:
            pixel_x = int(x)
            pixel_y = int(y)

        action_type = act['type']
        color = colors.get(action_type, 'green')

        # Outer ring around the target.
        draw.ellipse(
            [pixel_x - radius, pixel_y - radius, pixel_x + radius, pixel_y + radius],
            outline=color,
            width=4,
        )
        # Filled centre dot.
        draw.ellipse(
            [pixel_x - 4, pixel_y - 4, pixel_x + 4, pixel_y + 4],
            fill=color,
        )

        # Label: the action name, plus the typed text when there is one.
        label_text = f"{action_type}"
        if act['text']:
            label_text += f": '{act['text']}'"

        # Black background box (text bounds padded by 2px) so the white label
        # stays readable on any screenshot.
        text_pos = (pixel_x + 18, pixel_y - 12)
        bbox = draw.textbbox(text_pos, label_text, font=font)
        bbox = (bbox[0] - 2, bbox[1] - 2, bbox[2] + 2, bbox[3] + 2)
        draw.rectangle(bbox, fill="black")
        draw.text(text_pos, label_text, fill="white", font=font)

    return img_copy
| # ----------------------------------------------------------------------------- | |
| # 4. GRADIO LOGIC | |
| # ----------------------------------------------------------------------------- | |
def process_screenshot(input_numpy_image: np.ndarray, task: str, model_choice: str) -> Tuple[str, Optional[Image.Image]]:
    """Run the selected model on a screenshot and visualize the predicted actions.

    Args:
        input_numpy_image: Uploaded screenshot as an array, or ``None`` when
            nothing was uploaded.
        task: Natural-language instruction for the agent.
        model_choice: Model label passed to ``load_model_to_device``.

    Returns:
        ``(raw_model_output, image)``. The image is the screenshot with action
        markers drawn on it, the unmodified screenshot when no action could be
        parsed, or ``None`` when the request was rejected (missing image or
        blank task), in which case the text is a warning message.
    """
    if input_numpy_image is None:
        return "⚠️ Please upload an image first.", None

    # Reject blank instructions before paying for a model load / generation.
    if not task or not task.strip():
        return "⚠️ Please enter a task instruction.", None

    # 1. Load the requested model (switching if necessary).
    model, processor = load_model_to_device(model_choice)

    # 2. Prepare the prompt.
    input_pil_image = array_to_image(input_numpy_image)
    prompt = get_navigation_prompt(task, input_pil_image)

    # 3. Generate.
    print(f"Generating response using {model_choice}...")
    raw_response = generate_response(model, processor, prompt, max_new_tokens=512)
    print(f"Raw Output:\n{raw_response}")

    # 4. Parse and visualize. create_localized_image returns None when there
    #    is nothing to draw, in which case the plain screenshot is shown.
    actions = parse_tool_calls(raw_response)
    visualized = create_localized_image(input_pil_image, actions)
    output_image = visualized if visualized is not None else input_pil_image

    return raw_response, output_image
| # ----------------------------------------------------------------------------- | |
| # 5. GRADIO UI SETUP | |
| # ----------------------------------------------------------------------------- | |
# Page heading rendered at the top of the app.
title = "CUA GUI Agent 🖥️"

# Markdown blurb shown under the heading.
description = """
**Computer Use Agent (CUA)** Demo.
Upload a screenshot and provide a task instruction. The model will analyze the UI and output the precise coordinates and actions required.
**Models Supported:**
* **Fara-7B**: Microsoft's GUI agent model.
* **UI-TARS-1.5-7B**: ByteDance's GUI agent model.
"""

# Styles the element with id "out_img" (the output image component).
custom_css = """
#out_img { height: 600px; object-fit: contain; }
"""
# Build the page: heading and blurb on top, then a two-column row with the
# inputs on the left and the model outputs on the right.
soft_theme = gr.themes.Soft()
with gr.Blocks(theme=soft_theme, css=custom_css) as demo:
    gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
    gr.Markdown(description)

    with gr.Row():
        # Left column: screenshot, model picker, instruction, submit button.
        with gr.Column():
            input_image = gr.Image(label="Upload Screenshot", height=500)
            with gr.Row():
                model_choice = gr.Dropdown(
                    label="Choose CUA Model",
                    choices=list(MODEL_MAP),
                    value="Fara-7B",
                    interactive=True,
                )
            task_input = gr.Textbox(
                label="Task Instruction",
                placeholder="e.g. Input the server address readyforquantum.com...",
                lines=2,
            )
            submit_btn = gr.Button("Analyze UI & Generate Action", variant="primary")

        # Right column: annotated screenshot and the raw model text.
        with gr.Column():
            output_image = gr.Image(label="Visualized Action Points", elem_id="out_img", height=500)
            output_text = gr.Textbox(label="Raw Model Output", lines=8, show_copy_button=True)

    # The button and the examples feed the same three input components.
    request_inputs = [input_image, task_input, model_choice]

    # process_screenshot returns (text, image), hence the output order.
    submit_btn.click(
        fn=process_screenshot,
        inputs=request_inputs,
        outputs=[output_text, output_image],
    )

    # One canned example for quick testing.
    gr.Examples(
        examples=[
            ["./assets/google.png", "Search for 'Hugging Face'", "Fara-7B"],
        ],
        inputs=request_inputs,
        label="Quick Examples",
    )

if __name__ == "__main__":
    # No model is pre-loaded here; loading happens on the first request.
    demo.queue().launch()