"""
UI Element Detection API Server

Combines OmniParser UI detection with template matching to provide
precise coordinates for all UI elements in an image.

Usage:
    python ui_element_api_server.py --port 8001

Then POST a PNG image to:
    http://localhost:8001/analyze

Response includes:
    - JSON coordinates data
    - CSV format data
    - Visualization PNG with bounding boxes
"""

import argparse
import base64
import csv
import io
import json
import os
import sys
import tempfile
import threading
import time
import traceback
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import cv2
import numpy as np
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import FileResponse, JSONResponse
from PIL import Image

# Add OmniParser to path dynamically (must happen before the imports below).
omoi_root = Path(__file__).parent
sys.path.insert(0, str(omoi_root / 'OmniParser'))

from util.omniparser import Omniparser
from config import get_omniparser_config


# ============ Utility Functions ============

def to_rgb(img: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Return *img* as a 3-channel BGR array.

    Despite the name, the conversions used are OpenCV's ``*2BGR`` ones:
    2-D (grayscale) input is expanded to BGR and 4-channel input is treated
    as BGRA and has its alpha channel dropped. Any other input is returned
    unchanged (the same object, not a copy). ``None`` is passed through.
    """
    if img is None:
        return None
    if len(img.shape) == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    if img.shape[2] == 4:
        return cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    return img


def _load_templates(cropped_images_dir: str) -> Dict[str, np.ndarray]:
    """Load every readable ``crop_*.png`` in the directory as a BGR array.

    Files are visited in sorted order and keyed by file name; files that
    OpenCV cannot read are skipped.
    """
    templates = {}
    for template_file in sorted(Path(cropped_images_dir).glob('crop_*.png')):
        template_img = cv2.imread(str(template_file), cv2.IMREAD_UNCHANGED)
        if template_img is not None:
            templates[template_file.name] = to_rgb(template_img)
    return templates


def match_ui_elements(
    original_image_array: np.ndarray,
    cropped_images_dir: str,
    threshold: float = 0.7
) -> Tuple[list, Dict]:
    """
    Match cropped UI templates against original image.

    Each ``crop_*.png`` in ``cropped_images_dir`` is located in the image with
    normalized cross-correlation; only the single best location per template
    is considered, and it is kept when its score is at least ``threshold``.

    Args:
        original_image_array: Decoded image (grayscale, BGR or BGRA).
        cropped_images_dir: Directory containing the ``crop_*.png`` templates.
        threshold: Minimum correlation score for a match to be reported.

    Returns:
        ``(matches, metadata)``. ``matches`` is a list of dicts (template id,
        file name, confidence, pixel bbox, center, bbox as image ratios)
        sorted by descending confidence. ``metadata`` holds the image size,
        number of templates loaded, the threshold and the match count.

    Raises:
        ValueError: If the original image is ``None``.
    """
    original_img_rgb = to_rgb(original_image_array)
    if original_img_rgb is None:
        raise ValueError("Failed to convert original image")

    img_height, img_width = original_img_rgb.shape[:2]
    templates = _load_templates(cropped_images_dir)

    matches = []
    for template_name, template_img in templates.items():
        template_h, template_w = template_img.shape[:2]

        # A template larger than the image cannot be matched, and very small
        # crops are skipped as well.
        if template_h > img_height or template_w > img_width:
            continue
        if template_h < 4 or template_w < 4:
            continue

        try:
            result = cv2.matchTemplate(
                original_img_rgb, template_img, cv2.TM_CCOEFF_NORMED
            )
            _, max_val, _, max_loc = cv2.minMaxLoc(result)
        except Exception as exc:
            # Best effort: one bad template (e.g. a dtype mismatch) must not
            # abort the whole analysis, but it should not vanish silently.
            print(f"[WARN] Template matching failed for {template_name}: {exc}")
            continue

        # Written as a positive test so NaN never passes; non-finite scores
        # are rejected because they cannot be serialized to strict JSON.
        if not (np.isfinite(max_val) and max_val >= threshold):
            continue

        x1, y1 = max_loc
        x2 = x1 + template_w
        y2 = y1 + template_h
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2

        matches.append({
            'template_id': template_name.replace('.png', ''),
            'template_file': template_name,
            'confidence': float(max_val),
            'bbox': {
                'x1': int(x1),
                'y1': int(y1),
                'x2': int(x2),
                'y2': int(y2),
                'width': int(template_w),
                'height': int(template_h)
            },
            'center': {
                'x': int(center_x),
                'y': int(center_y)
            },
            'bbox_ratio': {
                'x1': x1 / img_width,
                'y1': y1 / img_height,
                'x2': x2 / img_width,
                'y2': y2 / img_height
            }
        })

    matches.sort(key=lambda x: x['confidence'], reverse=True)

    metadata = {
        'image_size': {'width': img_width, 'height': img_height},
        'templates_loaded': len(templates),
        'threshold': threshold,
        'matches_found': len(matches)
    }
    return matches, metadata


def visualize_matches(
    original_image_array: np.ndarray,
    matches: list
) -> np.ndarray:
    """Create visualization with bounding boxes.

    Draws a green box, a red center dot and a blue ``ID:<id> (<score>)``
    label for every match on a copy of the image; the input is not modified.

    The copy is converted to 3-channel BGR first. Drawing directly on a BGRA
    array with 3-component colors leaves alpha at 0 for the drawn pixels
    (transparent boxes), and on a grayscale array only the first color
    component is used.

    Args:
        original_image_array: Decoded image (grayscale, BGR or BGRA).
        matches: Match dicts as produced by ``match_ui_elements``.

    Returns:
        A new BGR image with the annotations drawn on it.
    """
    # to_rgb returns the same object for 3-channel input, so always copy.
    img = to_rgb(original_image_array).copy()

    box_color = (0, 255, 0)      # Green
    center_color = (0, 0, 255)   # Red
    label_color = (255, 0, 0)    # Blue
    thickness = 2

    for match in matches:
        bbox = match['bbox']
        center = match['center']
        confidence = match['confidence']
        template_id = match['template_id']

        # Draw bounding box
        cv2.rectangle(img, (bbox['x1'], bbox['y1']), (bbox['x2'], bbox['y2']),
                      box_color, thickness)

        # Draw center point
        cv2.circle(img, (center['x'], center['y']), 3, center_color, -1)

        # Draw label just above the box; clamp the baseline so labels of
        # boxes touching the top edge stay inside the image.
        label = f"ID:{template_id} ({confidence:.2f})"
        label_y = max(bbox['y1'] - 5, 10)
        cv2.putText(img, label, (bbox['x1'], label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, label_color, 1)

    return img


def matches_to_csv(matches: list, image_width: int, image_height: int) -> str:
    """Convert matches to CSV format (returns string).

    Args:
        matches: Match dicts as produced by ``match_ui_elements``.
        image_width: Not used by the current implementation; kept so existing
            callers keep working.
        image_height: Not used by the current implementation; kept so
            existing callers keep working.

    Returns:
        CSV text with a header row followed by one row per match.
        Confidence is written with 4 decimals and ratios with 6.
    """
    output = io.StringIO()
    writer = csv.writer(output)

    writer.writerow([
        'Element_ID', 'Template_File', 'Confidence',
        'X1', 'Y1', 'X2', 'Y2', 'Width', 'Height',
        'Center_X', 'Center_Y',
        'Ratio_X1', 'Ratio_Y1', 'Ratio_X2', 'Ratio_Y2'
    ])

    for match in matches:
        bbox = match['bbox']
        center = match['center']
        ratio = match['bbox_ratio']
        writer.writerow([
            match['template_id'],
            match['template_file'],
            f"{match['confidence']:.4f}",
            bbox['x1'], bbox['y1'], bbox['x2'], bbox['y2'],
            bbox['width'], bbox['height'],
            center['x'], center['y'],
            f"{ratio['x1']:.6f}", f"{ratio['y1']:.6f}",
            f"{ratio['x2']:.6f}", f"{ratio['y2']:.6f}"
        ])

    return output.getvalue()


# ============ FastAPI Server ============

# Global OmniParser instance; stays None when initialization fails, which the
# endpoints report as HTTP 503.
omniparser = None
omniparser_lock = threading.Lock()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize and cleanup on server startup/shutdown.

    A failure to build the OmniParser instance is logged rather than raised,
    so the server still starts with ``omniparser`` left as ``None``.
    """
    global omniparser
    try:
        with omniparser_lock:
            config = get_omniparser_config()
            omniparser = Omniparser(config)
        print("[Server] OmniParser initialized successfully")
    except Exception as e:
        print(f"[ERROR] Failed to initialize OmniParser: {str(e)}")
        traceback.print_exc()

    yield  # Application runs here

    # Cleanup (if any)
    print("[Server] Shutting down...")


app = FastAPI(
    title="UI Element Detection API",
    description="Detects and locates all UI elements in screenshots",
    lifespan=lifespan
)
# Imported next to its only users: lets the async endpoints below hand the
# blocking detection pipeline to a worker thread.
from fastapi.concurrency import run_in_threadpool

# Directory the template-matching step reads the cropped element images from.
# NOTE(review): presumably populated by omniparser.parse() - confirm.
CROPPED_IMAGES_DIR = '/tmp/omoi_cropped_images'


@app.get("/health")
async def health():
    """Health check endpoint."""
    return {"status": "ok", "service": "UI Element Detection API"}


def _decode_upload(content: bytes) -> np.ndarray:
    """Decode uploaded bytes into an image array, or raise HTTP 400.

    Empty uploads are rejected up front instead of being handed to
    ``cv2.imdecode``.
    """
    if not content:
        raise HTTPException(status_code=400, detail="Uploaded file is empty")
    np_array = np.frombuffer(content, np.uint8)
    original_img = cv2.imdecode(np_array, cv2.IMREAD_UNCHANGED)
    if original_img is None:
        raise HTTPException(status_code=400, detail="Failed to decode image")
    return original_img


def _encode_png_base64(img: np.ndarray) -> str:
    """PNG-encode *img* and return the bytes as a base64 string."""
    ok, buffer = cv2.imencode('.png', img)
    if not ok:
        raise RuntimeError("Failed to encode image as PNG")
    return base64.b64encode(buffer).decode()


def _run_pipeline(original_img: np.ndarray, log=None) -> Dict[str, Any]:
    """Run detection, template matching, visualization and CSV export.

    This is blocking work and is meant to be called from a worker thread.

    Args:
        original_img: Decoded input image.
        log: Optional callable taking one message string, used for step-by-step
            progress output. When omitted nothing is logged.

    Returns:
        Dict with ``matches``, ``metadata``, ``csv_data``, ``viz_base64``,
        ``omni_elapsed`` and ``match_elapsed``.

    Raises:
        HTTPException: 500 when the cropped images directory is missing.
    """
    if log is None:
        def log(_message):
            return None

    log("[Step 2] Encoding for OmniParser...")
    image_base64 = _encode_png_base64(original_img)

    # Parsing and matching communicate through one fixed directory, so two
    # analyses must not interleave between these steps.
    with omniparser_lock:
        log("[Step 3] Running OmniParser detection...")
        omni_start = time.time()
        omniparser.parse(image_base64)
        omni_elapsed = time.time() - omni_start
        log(f"[Step 3] OmniParser complete in {omni_elapsed:.2f}s")

        if not Path(CROPPED_IMAGES_DIR).exists():
            raise HTTPException(status_code=500,
                                detail="Cropped images directory not found")

        log("[Step 4] Matching templates...")
        match_start = time.time()
        matches, metadata = match_ui_elements(
            original_img, CROPPED_IMAGES_DIR, threshold=0.7
        )
        match_elapsed = time.time() - match_start
        log(f"[Step 4] Matching complete in {match_elapsed:.2f}s - "
            f"Found {len(matches)} elements")

    log("[Step 5] Creating visualization...")
    viz_base64 = _encode_png_base64(visualize_matches(original_img, matches))

    log("[Step 6] Generating CSV...")
    csv_data = matches_to_csv(
        matches,
        metadata['image_size']['width'],
        metadata['image_size']['height'],
    )

    return {
        'matches': matches,
        'metadata': metadata,
        'csv_data': csv_data,
        'viz_base64': viz_base64,
        'omni_elapsed': omni_elapsed,
        'match_elapsed': match_elapsed,
    }


@app.post("/analyze")
async def analyze_image(file: UploadFile = File(...)):
    """
    Analyze an image for UI elements.

    Returns:
        JSON response with coordinates, CSV data, and base64-encoded visualization

    Raises:
        HTTPException: 503 when OmniParser is not initialized, 400 when the
            upload is empty or cannot be decoded, 500 on any other failure.
    """
    if omniparser is None:
        raise HTTPException(status_code=503, detail="OmniParser not initialized")

    try:
        print(f"\n[Analysis] Starting analysis for: {file.filename}")
        start_time = time.time()

        print("[Step 1] Reading image file...")
        original_img = _decode_upload(await file.read())
        print(f"[Step 1] Image loaded: {original_img.shape}")

        # Heavy, blocking work runs in a worker thread so the event loop
        # (and /health) stays responsive.
        result = await run_in_threadpool(_run_pipeline, original_img, print)
        matches = result['matches']
        metadata = result['metadata']

        print("[Step 7] Preparing response...")
        response_data = {
            'status': 'success',
            'processing_time_seconds': time.time() - start_time,
            'timing': {
                'omniparser_seconds': result['omni_elapsed'],
                'template_matching_seconds': result['match_elapsed']
            },
            'image_info': {
                'filename': file.filename,
                'size': metadata['image_size']
            },
            'analysis': {
                'total_elements_detected': len(matches),
                'elements': matches
            },
            'exports': {
                'csv_data': result['csv_data'],
                'visualization_png_base64': result['viz_base64']
            }
        }

        total_time = time.time() - start_time
        print(f"[Analysis] Complete in {total_time:.2f}s")

        return JSONResponse(content=response_data)

    except HTTPException:
        raise
    except Exception as e:
        print(f"[ERROR] Analysis failed: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")


@app.post("/analyze_batch")
async def analyze_batch(file: UploadFile = File(...)):
    """
    Analyze image and return as separate parts for easier client handling.

    Returns:
    {
        'metadata': analysis metadata,
        'coordinates_json': full coordinates data,
        'csv_data': CSV string,
        'visualization_png_base64': visualization image
    }

    Raises:
        HTTPException: 503 when OmniParser is not initialized, 400 when the
            upload is empty or cannot be decoded, 500 on any other failure.
    """
    if omniparser is None:
        raise HTTPException(status_code=503, detail="OmniParser not initialized")

    try:
        print(f"\n[Batch Analysis] Starting for: {file.filename}")

        original_img = _decode_upload(await file.read())
        result = await run_in_threadpool(_run_pipeline, original_img)
        matches = result['matches']
        metadata = result['metadata']

        coordinates_json = {
            'source_image': file.filename,
            'image_size': metadata['image_size'],
            'total_elements': len(matches),
            'elements': matches
        }

        return JSONResponse(content={
            'metadata': {
                'filename': file.filename,
                'image_size': metadata['image_size'],
                'total_elements_detected': len(matches),
                'templates_loaded': metadata['templates_loaded']
            },
            'coordinates_json': coordinates_json,
            'csv_data': result['csv_data'],
            'visualization_png_base64': result['viz_base64']
        })

    except HTTPException:
        # Keep deliberate status codes (e.g. 400 for an undecodable upload)
        # instead of rewrapping them as 500.
        raise
    except Exception as e:
        print(f"[ERROR] Batch analysis failed: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))


def main() -> None:
    """Parse command-line options and run the server with uvicorn."""
    import multiprocessing

    parser = argparse.ArgumentParser(description='UI Element Detection API Server')
    parser.add_argument('--host', type=str, default='127.0.0.1',
                        help='Host to bind to')
    parser.add_argument('--port', type=int, default=8001,
                        help='Port to listen on')
    parser.add_argument('--reload', action='store_true',
                        help='Enable auto-reload')
    parser.add_argument('--workers', type=int, default=1,
                        help='Number of worker processes (default: 1, use for production with module import)')

    args = parser.parse_args()

    # Get CPU count for reference
    cpu_count = multiprocessing.cpu_count()

    print(f"\n{'='*70}")
    print("UI Element Detection API Server - Optimized")
    print(f"{'='*70}")
    print(f"Starting server on http://{args.host}:{args.port}")
    print(f"CPU Cores Available: {cpu_count}")
    print(f"Workers: {args.workers} (direct mode - async concurrency enabled)")
    print("\nEndpoints:")
    print(" POST /analyze - Analyze image with details")
    print(" POST /analyze_batch - Analyze image with structured response")
    print(" GET /health - Health check")
    print(f"{'='*70}\n")

    # uvicorn only supports reload / multiple workers when the application is
    # given as an import string; with an app object it refuses to start.
    if args.reload or args.workers > 1:
        app_target = f"{Path(__file__).stem}:app"
    else:
        app_target = app

    try:
        uvicorn.run(
            app_target,
            host=args.host,
            port=args.port,
            reload=args.reload,
            workers=args.workers,
            loop="auto"
        )
    except KeyboardInterrupt:
        print("\n[Server] Shutting down...")
    except Exception as e:
        print(f"\n[ERROR] Server error: {str(e)}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    main()