Spaces:
Build error
Build error
| """ | |
| UI-TARS-1.5-7B API Server for Hugging Face Spaces (Optimized) | |
| ============================================================== | |
| نسخة محسنة تستخدم Hugging Face Inference API للعمل بسرعة على النسخة المجانية | |
| Author: AI Assistant | |
| Model: ByteDance-Seed/UI-TARS-1.5-7B | |
| """ | |
import asyncio
import base64
import io
import json
import os
import re
import time
from contextlib import asynccontextmanager
from typing import Any, Dict, List, Optional, Union

import httpx
import uvicorn
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from PIL import Image
from pydantic import BaseModel, Field
| # ============================================================================ | |
| # Configuration | |
| # ============================================================================ | |
| MODEL_NAME = os.getenv("MODEL_NAME", "ByteDance-Seed/UI-TARS-1.5-7B") | |
| HF_TOKEN = os.getenv("HF_TOKEN", None) # Optional: للنماذج الخاصة | |
| TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7")) | |
| TOP_P = float(os.getenv("TOP_P", "0.9")) | |
| MAX_TOKENS = int(os.getenv("MAX_TOKENS", "2048")) | |
| # Hugging Face Inference API endpoint | |
| HF_API_URL = f"https://api-inference.huggingface.co/models/{MODEL_NAME}" | |
| # System prompts | |
| COMPUTER_USE_SYSTEM_PROMPT = """You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. | |
| ## Output Format | |
| Thought: ... | |
| Action: ... | |
| ## Action Space | |
| click(start_box='<|box_start|>(x1,y1)<|box_end|>') | |
| left_double(start_box='<|box_start|>(x1,y1)<|box_end|>') | |
| right_single(start_box='<|box_start|>(x1,y1)<|box_end|>') | |
| drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>') | |
| hotkey(key='') | |
| type(content='xxx') | |
| scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left') | |
| wait() | |
| finished(content='xxx') | |
| ## Note | |
| - Use English in `Thought` part. | |
| - Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part. | |
| ## User Instruction | |
| {instruction} | |
| """ | |
| MOBILE_USE_SYSTEM_PROMPT = """You are a GUI agent for mobile devices. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task. | |
| ## Output Format | |
| Thought: ... | |
| Action: ... | |
| ## Action Space | |
| click(start_box='<|box_start|>(x1,y1)<|box_end|>') | |
| long_press(start_box='<|box_start|>(x1,y1)<|box_end|>') | |
| drag(start_box='<|box_start|>(x1,y1)<|box_end|>', end_box='<|box_start|>(x3,y3)<|box_end|>') | |
| type(content='xxx') | |
| scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left') | |
| open_app(app_name='xxx') | |
| press_home() | |
| press_back() | |
| wait() | |
| finished(content='xxx') | |
| ## Note | |
| - Use English in `Thought` part. | |
| - Write a small plan and finally summarize your next action (with its target element) in one sentence in `Thought` part. | |
| ## User Instruction | |
| {instruction} | |
| """ | |
| GROUNDING_SYSTEM_PROMPT = """Output only the coordinate of one point in your response. What element matches the following task: {instruction}""" | |
| # ============================================================================ | |
| # Pydantic Models | |
| # ============================================================================ | |
| class InferenceRequest(BaseModel): | |
| """Inference request model""" | |
| instruction: str = Field(..., description="User instruction/task") | |
| image: Optional[str] = Field(default=None, description="Base64 encoded screenshot image") | |
| system_prompt_type: str = Field(default="computer", description="Type: computer, mobile, grounding") | |
| language: str = Field(default="English", description="Language for thought process") | |
| temperature: float = Field(default=TEMPERATURE, ge=0.0, le=2.0) | |
| top_p: float = Field(default=TOP_P, ge=0.0, le=1.0) | |
| max_tokens: int = Field(default=MAX_TOKENS, ge=1, le=8192) | |
| use_thought: bool = Field(default=True, description="Enable thought decomposition") | |
| class InferenceResponse(BaseModel): | |
| """Inference response model""" | |
| thought: Optional[str] = Field(default=None, description="Agent's reasoning") | |
| action: str = Field(..., description="Predicted action") | |
| raw_response: str = Field(..., description="Raw model output") | |
| coordinates: Optional[Dict[str, int]] = Field(default=None, description="Parsed coordinates if applicable") | |
| class BatchInferenceRequest(BaseModel): | |
| """Batch inference request""" | |
| requests: List[InferenceRequest] | |
| class HealthResponse(BaseModel): | |
| """Health check response""" | |
| status: str | |
| api_available: bool | |
| model_name: str | |
| class ModelInfoResponse(BaseModel): | |
| """Model information response""" | |
| model_name: str | |
| api_type: str | |
| temperature: float | |
| top_p: float | |
| max_tokens: int | |
| capabilities: List[str] | |
| # ============================================================================ | |
| # Model Manager (Using HF Inference API) | |
| # ============================================================================ | |
| class ModelManager: | |
| """Manages inference using Hugging Face Inference API""" | |
| def __init__(self): | |
| self.api_url = HF_API_URL | |
| self.headers = {} | |
| if HF_TOKEN: | |
| self.headers["Authorization"] = f"Bearer {HF_TOKEN}" | |
| self.client = httpx.AsyncClient(timeout=120.0) | |
| self.is_available = False | |
| async def check_availability(self): | |
| """Check if the API is available""" | |
| try: | |
| # Simple health check | |
| response = await self.client.get( | |
| self.api_url, | |
| headers=self.headers | |
| ) | |
| self.is_available = response.status_code in [200, 503] # 503 means loading | |
| return self.is_available | |
| except Exception as e: | |
| print(f"API check failed: {e}") | |
| self.is_available = False | |
| return False | |
| def get_system_prompt(self, prompt_type: str, instruction: str, language: str = "English") -> str: | |
| """Get the appropriate system prompt""" | |
| if prompt_type == "computer": | |
| return COMPUTER_USE_SYSTEM_PROMPT.format(instruction=instruction, language=language) | |
| elif prompt_type == "mobile": | |
| return MOBILE_USE_SYSTEM_PROMPT.format(instruction=instruction, language=language) | |
| elif prompt_type == "grounding": | |
| return GROUNDING_SYSTEM_PROMPT.format(instruction=instruction) | |
| else: | |
| return COMPUTER_USE_SYSTEM_PROMPT.format(instruction=instruction, language=language) | |
| def parse_action(self, response: str) -> Dict[str, Any]: | |
| """Parse the model response to extract thought and action""" | |
| result = { | |
| "thought": None, | |
| "action": None, | |
| "coordinates": None | |
| } | |
| # Extract thought | |
| thought_match = re.search(r'Thought:\s*(.+?)(?=\nAction:|$)', response, re.DOTALL) | |
| if thought_match: | |
| result["thought"] = thought_match.group(1).strip() | |
| # Extract action | |
| action_match = re.search(r'Action:\s*(.+?)(?=\n|$)', response, re.DOTALL) | |
| if action_match: | |
| result["action"] = action_match.group(1).strip() | |
| else: | |
| # No "Action:" prefix, try to parse the whole response | |
| result["action"] = response.strip() | |
| # Extract coordinates if present | |
| coord_pattern = r'<\|box_start\|\>\((\d+),(\d+)\)<\|box_end\|\>' | |
| coord_match = re.search(coord_pattern, result.get("action", "")) | |
| if coord_match: | |
| result["coordinates"] = { | |
| "x": int(coord_match.group(1)), | |
| "y": int(coord_match.group(2)) | |
| } | |
| return result | |
| async def inference( | |
| self, | |
| instruction: str, | |
| image_data: Optional[str] = None, | |
| system_prompt_type: str = "computer", | |
| language: str = "English", | |
| temperature: float = TEMPERATURE, | |
| top_p: float = TOP_P, | |
| max_tokens: int = MAX_TOKENS, | |
| use_thought: bool = True | |
| ) -> Dict[str, Any]: | |
| """Run inference using HF Inference API""" | |
| # Build the prompt | |
| system_prompt = self.get_system_prompt(system_prompt_type, instruction, language) | |
| # Prepare the payload for HF Inference API | |
| payload = { | |
| "inputs": system_prompt, | |
| "parameters": { | |
| "temperature": temperature, | |
| "top_p": top_p, | |
| "max_new_tokens": max_tokens, | |
| "return_full_text": False | |
| } | |
| } | |
| # If image is provided, include it | |
| if image_data: | |
| # HF Inference API expects the image in specific format | |
| # For vision models, we need to format the request differently | |
| try: | |
| # Decode base64 image | |
| image_bytes = base64.b64decode(image_data) | |
| # Make request with image | |
| files = { | |
| "file": ("image.png", io.BytesIO(image_bytes), "image/png") | |
| } | |
| data = { | |
| "inputs": system_prompt, | |
| "parameters": json.dumps(payload["parameters"]) | |
| } | |
| max_retries = 3 | |
| retry_delay = 2 | |
| for attempt in range(max_retries): | |
| try: | |
| response = await self.client.post( | |
| self.api_url, | |
| headers=self.headers, | |
| files=files, | |
| data=data | |
| ) | |
| if response.status_code == 503: | |
| # Model is loading | |
| if attempt < max_retries - 1: | |
| wait_time = retry_delay * (attempt + 1) | |
| print(f"Model loading, waiting {wait_time}s...") | |
| await asyncio.sleep(wait_time) | |
| continue | |
| else: | |
| return { | |
| "thought": "Model is still loading. Please try again in a moment.", | |
| "action": "wait()", | |
| "raw_response": "Model loading...", | |
| "coordinates": None | |
| } | |
| response.raise_for_status() | |
| result = response.json() | |
| break | |
| except httpx.HTTPStatusError as e: | |
| if attempt < max_retries - 1 and e.response.status_code in [503, 429]: | |
| wait_time = retry_delay * (attempt + 1) | |
| await asyncio.sleep(wait_time) | |
| continue | |
| else: | |
| raise | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") | |
| else: | |
| # Text-only request | |
| max_retries = 3 | |
| retry_delay = 2 | |
| for attempt in range(max_retries): | |
| try: | |
| response = await self.client.post( | |
| self.api_url, | |
| headers=self.headers, | |
| json=payload | |
| ) | |
| if response.status_code == 503: | |
| if attempt < max_retries - 1: | |
| wait_time = retry_delay * (attempt + 1) | |
| print(f"Model loading, waiting {wait_time}s...") | |
| await asyncio.sleep(wait_time) | |
| continue | |
| else: | |
| return { | |
| "thought": "Model is still loading. Please try again in a moment.", | |
| "action": "wait()", | |
| "raw_response": "Model loading...", | |
| "coordinates": None | |
| } | |
| response.raise_for_status() | |
| result = response.json() | |
| break | |
| except httpx.HTTPStatusError as e: | |
| if attempt < max_retries - 1 and e.response.status_code in [503, 429]: | |
| wait_time = retry_delay * (attempt + 1) | |
| await asyncio.sleep(wait_time) | |
| continue | |
| else: | |
| raise | |
| # Parse the response | |
| if isinstance(result, list) and len(result) > 0: | |
| generated_text = result[0].get("generated_text", "") | |
| elif isinstance(result, dict): | |
| generated_text = result.get("generated_text", str(result)) | |
| else: | |
| generated_text = str(result) | |
| # Parse thought and action | |
| parsed = self.parse_action(generated_text) | |
| return { | |
| "thought": parsed["thought"], | |
| "action": parsed["action"] or "wait()", | |
| "raw_response": generated_text, | |
| "coordinates": parsed["coordinates"] | |
| } | |
| def convert_coordinates(x_rel: int, y_rel: int, screen_width: int, screen_height: int) -> Dict[str, int]: | |
| """Convert relative coordinates (0-1000) to absolute""" | |
| return { | |
| "x": round(screen_width * x_rel / 1000), | |
| "y": round(screen_height * y_rel / 1000) | |
| } | |
| # ============================================================================ | |
| # FastAPI App | |
| # ============================================================================ | |
| model_manager = ModelManager() | |
| async def lifespan(app: FastAPI): | |
| """Startup and shutdown events""" | |
| print("🚀 Starting UI-TARS API Server (Optimized for HF Spaces)") | |
| print(f"📦 Model: {MODEL_NAME}") | |
| print(f"🔗 API URL: {HF_API_URL}") | |
| # Check API availability | |
| await model_manager.check_availability() | |
| if model_manager.is_available: | |
| print("✅ Hugging Face Inference API is available") | |
| else: | |
| print("⚠️ Hugging Face Inference API may be loading") | |
| yield | |
| # Cleanup | |
| await model_manager.client.aclose() | |
| print("👋 Shutting down UI-TARS API Server") | |
| app = FastAPI( | |
| title="UI-TARS-1.5-7B API", | |
| description="Optimized API for UI automation using ByteDance's UI-TARS-1.5-7B via HF Inference API", | |
| version="2.0.0", | |
| lifespan=lifespan | |
| ) | |
| # CORS middleware | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| # Import asyncio for sleep | |
| import asyncio | |
| # ============================================================================ | |
| # API Endpoints | |
| # ============================================================================ | |
| async def root(): | |
| """Root endpoint with API info""" | |
| return { | |
| "name": "UI-TARS-1.5-7B API", | |
| "version": "2.0.0", | |
| "model": MODEL_NAME, | |
| "api_type": "Hugging Face Inference API", | |
| "description": "Optimized for free Hugging Face Spaces", | |
| "endpoints": { | |
| "health": "/health", | |
| "model_info": "/model/info", | |
| "inference": "/v1/inference", | |
| "inference_file": "/v1/inference/file", | |
| "chat_completions": "/v1/chat/completions", | |
| "grounding": "/v1/grounding", | |
| "batch": "/v1/batch/inference" | |
| }, | |
| "documentation": "/docs" | |
| } | |
| async def health_check(): | |
| """Health check endpoint""" | |
| await model_manager.check_availability() | |
| return HealthResponse( | |
| status="healthy" if model_manager.is_available else "loading", | |
| api_available=model_manager.is_available, | |
| model_name=MODEL_NAME | |
| ) | |
| async def model_info(): | |
| """Get model information""" | |
| return ModelInfoResponse( | |
| model_name=MODEL_NAME, | |
| api_type="Hugging Face Inference API", | |
| temperature=TEMPERATURE, | |
| top_p=TOP_P, | |
| max_tokens=MAX_TOKENS, | |
| capabilities=[ | |
| "gui_automation", | |
| "computer_use", | |
| "mobile_use", | |
| "grounding", | |
| "screenshot_analysis", | |
| "action_prediction" | |
| ] | |
| ) | |
| async def inference(request: InferenceRequest): | |
| """ | |
| Run inference on a single request | |
| This endpoint processes a screenshot and instruction to predict the next GUI action. | |
| """ | |
| try: | |
| result = await model_manager.inference( | |
| instruction=request.instruction, | |
| image_data=request.image, | |
| system_prompt_type=request.system_prompt_type, | |
| language=request.language, | |
| temperature=request.temperature, | |
| top_p=request.top_p, | |
| max_tokens=request.max_tokens, | |
| use_thought=request.use_thought | |
| ) | |
| return InferenceResponse(**result) | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def inference_with_file( | |
| instruction: str = Form(...), | |
| system_prompt_type: str = Form(default="computer"), | |
| language: str = Form(default="English"), | |
| temperature: float = Form(default=TEMPERATURE), | |
| top_p: float = Form(default=TOP_P), | |
| max_tokens: int = Form(default=MAX_TOKENS), | |
| use_thought: bool = Form(default=True), | |
| image: Optional[UploadFile] = File(default=None) | |
| ): | |
| """ | |
| Run inference with file upload | |
| Upload a screenshot image file along with the instruction. | |
| """ | |
| try: | |
| image_data = None | |
| if image: | |
| contents = await image.read() | |
| image_data = base64.b64encode(contents).decode('utf-8') | |
| result = await model_manager.inference( | |
| instruction=instruction, | |
| image_data=image_data, | |
| system_prompt_type=system_prompt_type, | |
| language=language, | |
| temperature=temperature, | |
| top_p=top_p, | |
| max_tokens=max_tokens, | |
| use_thought=use_thought | |
| ) | |
| return InferenceResponse(**result) | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def chat_completions(request: Dict[str, Any]): | |
| """ | |
| OpenAI-compatible chat completions endpoint | |
| Compatible with OpenAI's API format for easy integration. | |
| """ | |
| try: | |
| messages = request.get("messages", []) | |
| temperature = request.get("temperature", TEMPERATURE) | |
| top_p = request.get("top_p", TOP_P) | |
| max_tokens = request.get("max_tokens", MAX_TOKENS) | |
| # Extract the last user message | |
| instruction = "" | |
| image_data = None | |
| for msg in messages: | |
| if msg.get("role") == "user": | |
| content = msg.get("content", "") | |
| if isinstance(content, list): | |
| for item in content: | |
| if item.get("type") == "text": | |
| instruction = item.get("text", "") | |
| elif item.get("type") == "image_url": | |
| image_url = item.get("image_url", {}).get("url", "") | |
| if image_url.startswith("data:image"): | |
| # Extract base64 data | |
| image_data = image_url.split(",")[1] | |
| else: | |
| instruction = content | |
| result = await model_manager.inference( | |
| instruction=instruction, | |
| image_data=image_data, | |
| temperature=temperature, | |
| top_p=top_p, | |
| max_tokens=max_tokens | |
| ) | |
| # Format as OpenAI response | |
| return { | |
| "id": "chatcmpl-ui-tars", | |
| "object": "chat.completion", | |
| "created": int(time.time()), | |
| "model": MODEL_NAME, | |
| "choices": [{ | |
| "index": 0, | |
| "message": { | |
| "role": "assistant", | |
| "content": result["raw_response"] | |
| }, | |
| "finish_reason": "stop" | |
| }], | |
| "usage": { | |
| "prompt_tokens": 0, | |
| "completion_tokens": 0, | |
| "total_tokens": 0 | |
| } | |
| } | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def grounding( | |
| instruction: str = Form(...), | |
| image: UploadFile = File(...), | |
| image_width: int = Form(default=1920), | |
| image_height: int = Form(default=1080) | |
| ): | |
| """ | |
| Grounding endpoint - Get coordinates for an element | |
| Returns the coordinates of the element matching the instruction. | |
| """ | |
| try: | |
| contents = await image.read() | |
| image_data = base64.b64encode(contents).decode('utf-8') | |
| result = await model_manager.inference( | |
| instruction=instruction, | |
| image_data=image_data, | |
| system_prompt_type="grounding", | |
| max_tokens=128 | |
| ) | |
| # Convert coordinates if present | |
| if result["coordinates"]: | |
| abs_coords = model_manager.convert_coordinates( | |
| result["coordinates"]["x"], | |
| result["coordinates"]["y"], | |
| image_width, | |
| image_height | |
| ) | |
| result["absolute_coordinates"] = abs_coords | |
| return result | |
| except Exception as e: | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def batch_inference(request: BatchInferenceRequest): | |
| """ | |
| Batch inference endpoint | |
| Process multiple requests in one call. | |
| """ | |
| results = [] | |
| for req in request.requests: | |
| try: | |
| result = await model_manager.inference( | |
| instruction=req.instruction, | |
| image_data=req.image, | |
| system_prompt_type=req.system_prompt_type, | |
| language=req.language, | |
| temperature=req.temperature, | |
| top_p=req.top_p, | |
| max_tokens=req.max_tokens, | |
| use_thought=req.use_thought | |
| ) | |
| results.append({"success": True, "result": result}) | |
| except Exception as e: | |
| results.append({"success": False, "error": str(e)}) | |
| return {"results": results} | |
| # ============================================================================ | |
| # Main Entry Point | |
| # ============================================================================ | |
| if __name__ == "__main__": | |
| port = int(os.getenv("PORT", "7860")) | |
| host = os.getenv("HOST", "0.0.0.0") | |
| uvicorn.run( | |
| app, | |
| host=host, | |
| port=port, | |
| log_level="info" | |
| ) | |