sharathmajjigi's picture
Implement proper UI-TARS grounding model with Qwen2.5-VL architecture
efd12df
raw
history blame
5.36 kB
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoProcessor
import torch
from PIL import Image
import io
import base64
import json
import numpy as np
# UI-TARS 1.5 is built on the Qwen2.5-VL architecture; this is its
# Hugging Face Hub repository id, used by all from_pretrained calls below.
model_name = "ByteDance-Seed/UI-TARS-1.5-7B"
def load_model():
    """Load the UI-TARS grounding model and its processor.

    UI-TARS 1.5 uses the Qwen2.5-VL architecture. The correct transformers
    classes are ``Qwen2_5_VLForConditionalGeneration`` and ``AutoProcessor``.
    (The names ``Qwen2_5VLMForCausalLM`` / ``Qwen2_5VLMProcessor`` used
    previously do not exist in transformers, so the primary path always
    raised ImportError and execution fell through to the fallback.)

    Returns:
        tuple: ``(model, processor)`` on success, ``(None, None)`` if both
        loading strategies fail.
    """
    try:
        # Documented class for Qwen2.5-VL vision-language generation
        # (available in transformers >= 4.49).
        from transformers import Qwen2_5_VLForConditionalGeneration

        processor = AutoProcessor.from_pretrained(
            model_name,
            trust_remote_code=True,
        )
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch.float16,  # half precision for memory efficiency
            device_map="auto",          # automatic device placement
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )
        print("βœ… UI-TARS model loaded successfully!")
        return model, processor
    except Exception as e:
        print(f"❌ Error loading UI-TARS: {e}")
        print("Falling back to alternative approach...")
        try:
            # Fallback: let AutoModelForCausalLM resolve the architecture via
            # trust_remote_code. NOTE(review): this may not wire up the vision
            # tower correctly for a VL model — confirm before relying on it.
            processor = AutoProcessor.from_pretrained(
                model_name,
                trust_remote_code=True,
            )
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16,
                device_map="auto",
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
            print("βœ… UI-TARS loaded with AutoModelForCausalLM")
            return model, processor
        except Exception as e2:
            print(f"❌ Alternative approach failed: {e2}")
            return None, None
# Load model at startup
# NOTE: these module-level globals are read by process_grounding();
# both are None when loading failed (load_model returns (None, None)).
print("πŸ”„ Loading UI-TARS model...")
model, processor = load_model()
def process_grounding(image, prompt):
    """Run UI-TARS grounding on a screenshot.

    Args:
        image: A PIL.Image, or a base64-encoded image string.
        prompt: Natural-language description of the goal/task.

    Returns:
        str: Pretty-printed JSON grounding results on success; an error
        payload with ``"status": "failed"`` when the model is unavailable
        or an exception occurs; or the raw model text when the response
        cannot be parsed as JSON.
    """
    try:
        if model is None or processor is None:
            return json.dumps({
                "error": "Model not loaded",
                "status": "failed"
            }, indent=2)
        # Accept base64-encoded screenshots in addition to PIL images.
        if isinstance(image, str):
            image_data = base64.b64decode(image)
            image = Image.open(io.BytesIO(image_data))
        # The prompt asks the model to answer in a fixed JSON schema for
        # UI elements and suggested actions.
        formatted_prompt = f"""<image>
Please analyze this screenshot and provide grounding information for the following task: {prompt}
Please identify UI elements and provide:
1. Element locations (x, y coordinates)
2. Element types (button, text field, etc.)
3. Recommended actions (click, type, etc.)
4. Confidence scores
Format your response as JSON with the following structure:
{{
"elements": [
{{"type": "button", "x": 100, "y": 200, "text": "Click me", "confidence": 0.9}}
],
"actions": [
{{"action": "click", "x": 100, "y": 200, "description": "Click button"}}
]
}}"""
        inputs = processor(
            text=formatted_prompt,
            images=image,
            return_tensors="pt"
        )
        # Keep inputs on whatever device the (possibly sharded) model landed on.
        device = next(model.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.1
            )
        # BUG FIX: generate() returns prompt tokens followed by the
        # completion, and the prompt itself contains a literal JSON example.
        # Searching the full decoded text for '{' therefore always matched
        # the template in the prompt, never the model's answer. Decode only
        # the newly generated tokens so the search runs on the response alone.
        prompt_len = inputs["input_ids"].shape[1]
        result_text = processor.decode(
            outputs[0][prompt_len:], skip_special_tokens=True
        )
        response_start = result_text.find('{')
        if response_start != -1:
            response_json = result_text[response_start:]
            try:
                # Re-serialize so the UI always shows normalized, indented JSON.
                parsed_result = json.loads(response_json)
                return json.dumps(parsed_result, indent=2)
            except json.JSONDecodeError:
                # Model produced non-JSON text; surface it verbatim.
                return f"Raw Response:\n{result_text}\n\nNote: Response could not be parsed as JSON"
        else:
            return f"Model Response:\n{result_text}"
    except Exception as e:
        # Catch-all boundary: this function feeds a UI textbox, so every
        # failure is reported as a JSON error payload instead of raising.
        return json.dumps({
            "error": f"Error processing image: {str(e)}",
            "status": "failed"
        }, indent=2)
# Create Gradio interface
# Single-function app: a screenshot plus a free-text goal go in; JSON
# grounding results (or an error payload) come back as text.
iface = gr.Interface(
    fn=process_grounding,
    inputs=[
        # type="pil" hands process_grounding a PIL.Image directly.
        gr.Image(type="pil", label="Upload Screenshot"),
        gr.Textbox(label="Prompt/Goal", placeholder="What do you want to do?")
    ],
    outputs=gr.Textbox(label="Grounding Results", lines=15),
    title="UI-TARS Grounding Model",
    description="Upload a screenshot and describe your goal to get grounding results from UI-TARS"
)
# Starts the web server; blocks for the lifetime of the app.
iface.launch()