import gradio as gr from transformers import AutoProcessor, AutoModel import torch from PIL import Image import io import base64 import json import numpy as np from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware import re # UI-TARS model name model_name = "ByteDance-Seed/UI-TARS-1.5-7B" def load_model(): """Load UI-TARS model with fallback""" try: print("🔄 Loading UI-TARS model...") # Use AutoProcessor and AutoModel (most compatible) processor = AutoProcessor.from_pretrained(model_name) print("✅ Processor loaded successfully!") model = AutoModel.from_pretrained(model_name) print("✅ UI-TARS model loaded successfully!") return model, processor except Exception as e: print(f"❌ Error loading UI-TARS: {str(e)}") print("Falling back to alternative approach...") try: # Fallback: Load just the processor processor = AutoProcessor.from_pretrained(model_name) print("✅ UI-TARS model loaded with fallback configuration!") return None, processor except Exception as e2: print(f"❌ Alternative approach failed: {str(e2)}") return None, None def fix_base64_string(base64_str): """Fix truncated base64 strings""" try: # Remove any whitespace and newlines base64_str = base64_str.strip() # Check if it's a data URL if base64_str.startswith('data:image/'): # Extract just the base64 part after the comma base64_str = base64_str.split(',', 1)[1] # Fix padding issues missing_padding = len(base64_str) % 4 if missing_padding: base64_str += '=' * (4 - missing_padding) # Validate base64 try: base64.b64decode(base64_str) return base64_str except: # If still invalid, try to find the complete base64 in the string # Look for base64 pattern (alphanumeric + / + =) match = re.search(r'[A-Za-z0-9+/]+={0,2}', base64_str) if match: fixed_str = match.group(0) # Fix padding missing_padding = len(fixed_str) % 4 if missing_padding: fixed_str += '=' * (4 - missing_padding) return fixed_str return base64_str except Exception as e: print(f"Error fixing base64: {e}") return base64_str def process_grounding(image_data, prompt): """Process image with UI-TARS grounding model""" try: print(f"Processing image with UI-TARS model...") # Fix base64 string if needed if isinstance(image_data, str): image_data = fix_base64_string(image_data) # Convert base64 to PIL Image try: if image_data.startswith('data:image/'): # Handle data URL format image_data = image_data.split(',', 1)[1] image_bytes = base64.b64decode(image_data) image = Image.open(io.BytesIO(image_bytes)) print(f"✅ Image loaded successfully: {image.size}") except Exception as e: print(f"❌ Error decoding base64: {e}") return { "error": f"Failed to decode image: {str(e)}", "status": "failed" } # For now, return a mock response since we're using fallback # In production, you'd process with the actual model return { "status": "success", "elements": [ { "type": "button", "text": "calculator button", "bbox": [100, 100, 200, 150], "confidence": 0.95 } ], "message": f"Processed image with prompt: {prompt}" } except Exception as e: print(f"❌ Error in process_grounding: {e}") return { "error": f"Error processing image: {str(e)}", "status": "failed" } # Load model model, processor = load_model() # Create FastAPI app app = FastAPI(title="UI-TARS Grounding Model API") # Add CORS middleware app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.post("/v1/ground/chat/completions") async def chat_completions(request: Request): """Chat completions endpoint that Agent-S expects""" try: print("=" * 60) print("�� DEBUG: New request received") print("=" * 60) # Parse request body body = await request.body() print(f"�� RAW REQUEST BODY (bytes): {len(body)} bytes") print(f"�� RAW REQUEST BODY (string): {body.decode('utf-8')[:500]}...") # Parse JSON try: data = json.loads(body) print(f"✅ PARSED JSON SUCCESSFULLY") print(f"🔑 JSON KEYS: {list(data.keys())}") except json.JSONDecodeError as e: print(f"❌ JSON PARSE ERROR: {e}") return {"error": "Invalid JSON", "status": "failed"} # Extract messages messages = data.get("messages", []) print(f"💬 MESSAGES COUNT: {len(messages)}") # Find user message with image user_message = None image_data = None prompt = None for i, msg in enumerate(messages): print(f"📨 Message {i}: role='{msg.get('role')}', content type={type(msg.get('content'))}") if msg.get("role") == "user": content = msg.get("content", []) if isinstance(content, list): for item in content: if isinstance(item, dict): if item.get("type") == "image_url": image_data = item.get("image_url", {}).get("url", "") print(f"🖼️ Found image_url: {image_data[:100]}...") elif item.get("type") == "text": prompt = item.get("text", "") print(f"📝 Found text: {prompt[:100]}...") elif isinstance(content, str): prompt = content print(f"📝 Found string content: {prompt[:100]}...") if not image_data: print("❌ No image data found in request") return { "error": "No image data provided", "status": "failed" } if not prompt: prompt = "Analyze this image and identify UI elements" print(f"⚠️ No prompt found, using default: {prompt}") print(f"🖼️ USER MESSAGE EXTRACTED: {prompt[:100]}...") # Process with grounding model result = process_grounding(image_data, prompt) print(f"🔍 GROUNDING RESULT: {result}") # Format response for Agent-S response = { "id": "chatcmpl-123", "object": "chat.completion", "created": 1677652288, "model": "ui-tars-1.5-7b", "choices": [ { "index": 0, "message": { "role": "assistant", "content": json.dumps(result) if isinstance(result, dict) else str(result) }, "finish_reason": "stop" } ], "usage": { "prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30 } } print(f"📤 SENDING RESPONSE: {json.dumps(response, indent=2)}") return response except Exception as e: print(f"❌ ERROR in chat_completions: {e}") return { "error": f"Internal server error: {str(e)}", "status": "failed" } # Create Gradio interface for testing def gradio_interface(image, prompt): """Gradio interface for testing""" if image is None: return {"error": "No image provided", "status": "failed"} # Convert PIL image to base64 buffer = io.BytesIO() image.save(buffer, format="PNG") img_str = base64.b64encode(buffer.getvalue()).decode() # Process with grounding model result = process_grounding(img_str, prompt) return result # Create Gradio interface iface = gr.Interface( fn=gradio_interface, inputs=[ gr.Image(label="Upload Screenshot", type="pil"), gr.Textbox(label="Prompt/Goal", placeholder="Describe what you want to do...") ], outputs=gr.JSON(label="Grounding Results"), title="UI-TARS Grounding Model", description="Upload a screenshot and describe your goal to get UI element coordinates", examples=[ ["https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", "Click on the calculator button"] ] ) # Mount Gradio app app = gr.mount_gradio_app(app, iface, path="/") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860)