| """ |
| UI Element Detection API Server |
| Combines OmniParser UI detection with template matching to provide |
| precise coordinates for all UI elements in an image. |
| |
| Usage: |
| python ui_element_api_server.py --port 8001 |
| |
| Then POST a PNG image to: http://localhost:8001/analyze |
| |
| Response includes: |
| - JSON coordinates data |
| - CSV format data |
| - Visualization PNG with bounding boxes |
| """ |
|
|
| import cv2 |
| import numpy as np |
| import json |
| import os |
| import sys |
| import io |
| import time |
| import base64 |
| from pathlib import Path |
| from contextlib import asynccontextmanager |
| from fastapi import FastAPI, File, UploadFile, HTTPException |
| from fastapi.responses import JSONResponse, FileResponse |
| import argparse |
| import uvicorn |
| from typing import Dict, Any, Optional, Tuple |
| from PIL import Image |
| import csv |
| import tempfile |
| import threading |
|
|
| |
# Put the OmniParser checkout that sits next to this file at the front of the
# module search path before importing the project modules below.
# NOTE(review): presumably `util.omniparser` lives inside that directory --
# confirm where `config` is resolved from.
from pathlib import Path
omoi_root = Path(__file__).parent
sys.path.insert(0, str(omoi_root.joinpath('OmniParser')))
from util.omniparser import Omniparser
from config import get_omniparser_config
|
|
| |
|
|
def to_rgb(img: np.ndarray) -> Optional[np.ndarray]:
    """Return a 3-channel BGR version of ``img``.

    Despite the name, the result keeps OpenCV's BGR channel order.

    Args:
        img: Image array, or ``None``.

    Returns:
        ``None`` when ``img`` is ``None``; a new BGR array for 2-D
        (grayscale) and 4-channel (BGRA) inputs; otherwise ``img`` itself,
        unchanged and not copied.
    """
    if img is None:
        return None

    dims = img.shape
    if len(dims) == 2:
        # Single-plane image: replicate the plane into three channels.
        conversion = cv2.COLOR_GRAY2BGR
    elif dims[2] == 4:
        # Drop the alpha channel.
        conversion = cv2.COLOR_BGRA2BGR
    else:
        return img
    return cv2.cvtColor(img, conversion)
|
|
def match_ui_elements(
    original_image_array: np.ndarray,
    cropped_images_dir: str,
    threshold: float = 0.7
) -> Tuple[list, Dict]:
    """
    Locate each cropped UI template inside the original image.

    Every ``crop_*.png`` file in ``cropped_images_dir`` is matched against
    the image with ``cv2.TM_CCOEFF_NORMED``. Only the single best location
    per template is considered, and it is kept when its score reaches
    ``threshold``.

    Args:
        original_image_array: Image to search in (grayscale, BGR or BGRA).
        cropped_images_dir: Directory holding the ``crop_*.png`` templates.
        threshold: Minimum match score for a template to be reported.

    Returns:
        ``(matches, metadata)``. ``matches`` is a list of dicts sorted by
        descending confidence, each with ``template_id``, ``template_file``,
        ``confidence``, ``bbox`` (pixels), ``center`` (pixels) and
        ``bbox_ratio`` (fractions of the image size). ``metadata`` holds the
        image size, the number of templates loaded, the threshold and the
        number of matches.

    Raises:
        ValueError: If the original image is ``None``.
    """
    original_img_rgb = to_rgb(original_image_array)
    if original_img_rgb is None:
        raise ValueError("Failed to convert original image")

    img_height, img_width = original_img_rgb.shape[:2]

    # Load every readable template, normalised to 3 channels so it is
    # comparable with the converted original.
    templates = {}
    for template_file in sorted(Path(cropped_images_dir).glob('crop_*.png')):
        template_img = cv2.imread(str(template_file), cv2.IMREAD_UNCHANGED)
        if template_img is not None:
            templates[template_file.name] = to_rgb(template_img)

    matches = []
    for template_name, template_img in templates.items():
        template_h, template_w = template_img.shape[:2]

        # A template larger than the image cannot be matched, and tiny
        # templates are skipped.
        if template_h > img_height or template_w > img_width:
            continue
        if template_h < 4 or template_w < 4:
            continue

        try:
            result = cv2.matchTemplate(original_img_rgb, template_img, cv2.TM_CCOEFF_NORMED)
            _, max_val, _, max_loc = cv2.minMaxLoc(result)
        except Exception as exc:
            # Still best-effort (one bad template must not abort the run),
            # but report the failure instead of hiding it.
            print(f"[WARN] Template matching failed for {template_name}: {exc}")
            continue

        # NOTE(review): the normalised score may come back as inf/NaN for
        # flat (zero-variance) regions on some OpenCV builds -- confirm.
        # Such a value must not reach the JSON response, so it is rejected.
        if not np.isfinite(max_val) or max_val < threshold:
            continue

        x1, y1 = max_loc
        x2 = x1 + template_w
        y2 = y1 + template_h

        matches.append({
            # Strip only the file extension, not every '.png' in the name.
            'template_id': Path(template_name).stem,
            'template_file': template_name,
            'confidence': float(max_val),
            'bbox': {
                'x1': int(x1),
                'y1': int(y1),
                'x2': int(x2),
                'y2': int(y2),
                'width': int(template_w),
                'height': int(template_h)
            },
            'center': {
                'x': int((x1 + x2) / 2),
                'y': int((y1 + y2) / 2)
            },
            'bbox_ratio': {
                'x1': x1 / img_width,
                'y1': y1 / img_height,
                'x2': x2 / img_width,
                'y2': y2 / img_height
            }
        })

    matches.sort(key=lambda x: x['confidence'], reverse=True)

    metadata = {
        'image_size': {'width': img_width, 'height': img_height},
        'templates_loaded': len(templates),
        'threshold': threshold,
        'matches_found': len(matches)
    }

    return matches, metadata
|
|
def visualize_matches(
    original_image_array: np.ndarray,
    matches: list
) -> np.ndarray:
    """
    Draw every match onto a copy of the image.

    For each match a green bounding box, a red centre dot and a blue
    ``ID:<template_id> (<confidence>)`` label are drawn (colours in BGR).

    Args:
        original_image_array: Image the matches refer to; it is not modified.
        matches: Match dicts with ``bbox``, ``center``, ``confidence`` and
            ``template_id`` keys, as produced by ``match_ui_elements``.

    Returns:
        An annotated copy with the same shape as the input.
    """
    img = original_image_array.copy()

    # On a 4-channel (BGRA) image a 3-component colour leaves alpha at 0,
    # which makes the annotations fully transparent in the exported PNG.
    # Give every colour an opaque alpha component in that case.
    has_alpha = img.ndim == 3 and img.shape[2] == 4

    def _opaque(bgr):
        return bgr + (255,) if has_alpha else bgr

    box_colour = _opaque((0, 255, 0))
    centre_colour = _opaque((0, 0, 255))
    label_colour = _opaque((255, 0, 0))
    thickness = 2

    for match in matches:
        bbox = match['bbox']
        center = match['center']
        confidence = match['confidence']
        template_id = match['template_id']

        cv2.rectangle(img, (bbox['x1'], bbox['y1']), (bbox['x2'], bbox['y2']), box_colour, thickness)

        cv2.circle(img, (center['x'], center['y']), 3, centre_colour, -1)

        # The label normally sits just above the box. For boxes touching the
        # top edge that would fall outside the image, so place it just
        # inside the box instead.
        label_y = bbox['y1'] - 5
        if label_y < 10:
            label_y = bbox['y1'] + 12

        label = f"ID:{template_id} ({confidence:.2f})"
        cv2.putText(img, label, (bbox['x1'], label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, label_colour, 1)

    return img
|
|
def matches_to_csv(matches: list, image_width: int, image_height: int) -> str:
    """Serialise matches as CSV text with a header row.

    Args:
        matches: Match dicts with ``template_id``, ``template_file``,
            ``confidence``, ``bbox``, ``center`` and ``bbox_ratio`` keys.
        image_width: Accepted for interface compatibility; not used.
        image_height: Accepted for interface compatibility; not used.

    Returns:
        The CSV document as a string. Confidence is written with 4 decimals
        and the ratio columns with 6.
    """
    header = [
        'Element_ID', 'Template_File', 'Confidence',
        'X1', 'Y1', 'X2', 'Y2', 'Width', 'Height',
        'Center_X', 'Center_Y',
        'Ratio_X1', 'Ratio_Y1', 'Ratio_X2', 'Ratio_Y2'
    ]

    def _as_row(entry):
        box = entry['bbox']
        mid = entry['center']
        rel = entry['bbox_ratio']
        return [
            entry['template_id'],
            entry['template_file'],
            f"{entry['confidence']:.4f}",
            *(box[key] for key in ('x1', 'y1', 'x2', 'y2', 'width', 'height')),
            mid['x'], mid['y'],
            *(f"{rel[key]:.6f}" for key in ('x1', 'y1', 'x2', 'y2')),
        ]

    buffer = io.StringIO()
    sheet = csv.writer(buffer)
    sheet.writerow(header)
    sheet.writerows(_as_row(entry) for entry in matches)
    return buffer.getvalue()
|
|
| |
|
|
| |
# Shared OmniParser instance. It stays None until lifespan() has built it,
# and remains None if initialisation fails.
omniparser = None
# Held while the shared OmniParser instance is being created.
omniparser_lock = threading.Lock()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the shared OmniParser at startup and log at shutdown.

    Initialisation is best-effort: a failure is printed with its traceback
    and the server still starts, leaving ``omniparser`` as ``None``.
    """
    global omniparser

    try:
        with omniparser_lock:
            omniparser = Omniparser(get_omniparser_config())
        print("[Server] OmniParser initialized successfully")
    except Exception as init_error:
        import traceback
        print(f"[ERROR] Failed to initialize OmniParser: {init_error!s}")
        traceback.print_exc()

    # The application serves requests while suspended here.
    yield

    print("[Server] Shutting down...")
|
|
# ASGI application object; `lifespan` runs the startup and shutdown hooks.
app = FastAPI(
    lifespan=lifespan,
    title="UI Element Detection API",
    description="Detects and locates all UI elements in screenshots",
)
|
|
@app.get("/health")
async def health():
    """Report that the server process is up.

    The reply is constant; it does not reflect whether OmniParser was
    initialised.
    """
    status_report = dict(status="ok", service="UI Element Detection API")
    return status_report
|
|
@app.post("/analyze")
async def analyze_image(file: UploadFile = File(...)):
    """
    Analyze an uploaded image for UI elements.

    Steps: decode the upload, run OmniParser on a PNG/base64 copy,
    template-match the crops in ``/tmp/omoi_cropped_images`` against the
    original image, then build the visualization and CSV exports.

    Args:
        file: Uploaded image (form field ``file``).

    Returns:
        JSONResponse with status, timings, image info, the matched elements,
        CSV data and a base64-encoded PNG visualization.

    Raises:
        HTTPException: 503 if OmniParser is not initialized; 400 if the
            upload is empty or cannot be decoded; 500 for any other failure,
            including a missing crops directory.
    """
    if not omniparser:
        raise HTTPException(status_code=503, detail="OmniParser not initialized")

    try:
        print(f"\n[Analysis] Starting analysis for: {file.filename}")
        start_time = time.time()

        print("[Step 1] Reading image file...")
        content = await file.read()
        if not content:
            # Reject an empty upload as a client error here; otherwise the
            # decode step below would fail and be reported as a 500.
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        np_array = np.frombuffer(content, np.uint8)
        original_img = cv2.imdecode(np_array, cv2.IMREAD_UNCHANGED)

        if original_img is None:
            raise HTTPException(status_code=400, detail="Failed to decode image")

        print(f"[Step 1] Image loaded: {original_img.shape}")

        print("[Step 2] Encoding for OmniParser...")
        encoded_ok, buffer = cv2.imencode('.png', original_img)
        if not encoded_ok:
            raise RuntimeError("Failed to encode image as PNG")
        image_base64 = base64.b64encode(buffer).decode()

        print("[Step 3] Running OmniParser detection...")
        omni_time = time.time()
        # The return value is not used.
        # NOTE(review): presumably parse() is what writes the crop_*.png
        # files read in step 4 -- confirm against OmniParser.
        omniparser.parse(image_base64)
        omni_elapsed = time.time() - omni_time
        print(f"[Step 3] OmniParser complete in {omni_elapsed:.2f}s")

        cropped_dir = '/tmp/omoi_cropped_images'
        if not Path(cropped_dir).exists():
            raise HTTPException(status_code=500, detail="Cropped images directory not found")

        print("[Step 4] Matching templates...")
        match_time = time.time()
        matches, metadata = match_ui_elements(original_img, cropped_dir, threshold=0.7)
        match_elapsed = time.time() - match_time
        print(f"[Step 4] Matching complete in {match_elapsed:.2f}s - Found {len(matches)} elements")

        print("[Step 5] Creating visualization...")
        viz_img = visualize_matches(original_img, matches)
        viz_ok, viz_buffer = cv2.imencode('.png', viz_img)
        if not viz_ok:
            raise RuntimeError("Failed to encode visualization as PNG")
        viz_base64 = base64.b64encode(viz_buffer).decode()

        print("[Step 6] Generating CSV...")
        image_size = metadata['image_size']
        csv_data = matches_to_csv(matches, image_size['width'], image_size['height'])

        print("[Step 7] Preparing response...")
        # Measure once so the reported and the logged totals agree.
        total_time = time.time() - start_time
        response_data = {
            'status': 'success',
            'processing_time_seconds': total_time,
            'timing': {
                'omniparser_seconds': omni_elapsed,
                'template_matching_seconds': match_elapsed
            },
            'image_info': {
                'filename': file.filename,
                'size': image_size
            },
            'analysis': {
                'total_elements_detected': len(matches),
                'elements': matches
            },
            'exports': {
                'csv_data': csv_data,
                'visualization_png_base64': viz_base64
            }
        }

        print(f"[Analysis] Complete in {total_time:.2f}s")

        return JSONResponse(content=response_data)

    except HTTPException:
        # Deliberate HTTP errors keep their own status code.
        raise
    except Exception as e:
        print(f"[ERROR] Analysis failed: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}") from e
|
|
@app.post("/analyze_batch")
async def analyze_batch(file: UploadFile = File(...)):
    """
    Analyze an image and return the results as separate top-level parts.

    Runs the same steps as ``/analyze`` but without timing information.

    Args:
        file: Uploaded image (form field ``file``).

    Returns:
        {
            'metadata': analysis metadata,
            'coordinates_json': full coordinates data,
            'csv_data': CSV string,
            'visualization_png_base64': visualization image
        }

    Raises:
        HTTPException: 503 if OmniParser is not initialized; 400 if the
            upload is empty or cannot be decoded; 500 for any other failure,
            including a missing crops directory.
    """
    if not omniparser:
        raise HTTPException(status_code=503, detail="OmniParser not initialized")

    try:
        print(f"\n[Batch Analysis] Starting for: {file.filename}")

        content = await file.read()
        if not content:
            # Reject an empty upload as a client error here; otherwise the
            # decode step below would fail and be reported as a 500.
            raise HTTPException(status_code=400, detail="Uploaded file is empty")

        np_array = np.frombuffer(content, np.uint8)
        original_img = cv2.imdecode(np_array, cv2.IMREAD_UNCHANGED)

        if original_img is None:
            raise HTTPException(status_code=400, detail="Failed to decode image")

        encoded_ok, png_buffer = cv2.imencode('.png', original_img)
        if not encoded_ok:
            raise RuntimeError("Failed to encode image as PNG")
        image_base64 = base64.b64encode(png_buffer).decode()
        # The return value is not used.
        # NOTE(review): presumably parse() is what writes the crop_*.png
        # files matched below -- confirm against OmniParser.
        omniparser.parse(image_base64)

        cropped_dir = '/tmp/omoi_cropped_images'
        # Same guard as /analyze: without it a missing directory would be
        # reported as a successful analysis with zero elements.
        if not Path(cropped_dir).exists():
            raise HTTPException(status_code=500, detail="Cropped images directory not found")
        matches, metadata = match_ui_elements(original_img, cropped_dir, threshold=0.7)

        viz_img = visualize_matches(original_img, matches)
        viz_ok, viz_buffer = cv2.imencode('.png', viz_img)
        if not viz_ok:
            raise RuntimeError("Failed to encode visualization as PNG")
        viz_base64 = base64.b64encode(viz_buffer).decode()

        image_size = metadata['image_size']
        csv_data = matches_to_csv(matches, image_size['width'], image_size['height'])

        coordinates_json = {
            'source_image': file.filename,
            'image_size': image_size,
            'total_elements': len(matches),
            'elements': matches
        }

        return JSONResponse(content={
            'metadata': {
                'filename': file.filename,
                'image_size': image_size,
                'total_elements_detected': len(matches),
                'templates_loaded': metadata['templates_loaded']
            },
            'coordinates_json': coordinates_json,
            'csv_data': csv_data,
            'visualization_png_base64': viz_base64
        })

    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the
        # 400 responses raised above would be rewritten as 500s.
        raise
    except Exception as e:
        print(f"[ERROR] Batch analysis failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, print a banner and run the
    # server with uvicorn until interrupted.
    import multiprocessing

    parser = argparse.ArgumentParser(description='UI Element Detection API Server')
    parser.add_argument('--host', type=str, default='127.0.0.1', help='Host to bind to')
    parser.add_argument('--port', type=int, default=8001, help='Port to listen on')
    parser.add_argument('--reload', action='store_true', help='Enable auto-reload')
    parser.add_argument('--workers', type=int, default=1, help='Number of worker processes (default: 1, use for production with module import)')
    args = parser.parse_args()

    cpu_count = multiprocessing.cpu_count()

    print(f"\n{'='*70}")
    print("UI Element Detection API Server - Optimized")
    print(f"{'='*70}")
    print(f"Starting server on http://{args.host}:{args.port}")
    print(f"CPU Cores Available: {cpu_count}")
    print(f"Workers: {args.workers} (direct mode - async concurrency enabled)")
    print("\nEndpoints:")
    print(" POST /analyze - Analyze image with details")
    print(" POST /analyze_batch - Analyze image with structured response")
    print(" GET /health - Health check")
    print(f"{'='*70}\n")

    if args.workers > 1:
        # The value is never handed to uvicorn in direct mode, so say so
        # instead of silently ignoring the flag.
        print(f"[Server] Note: --workers={args.workers} is not applied in direct mode; "
              f"start the server through the uvicorn CLI with an import string to use workers")

    # uvicorn can only auto-reload when it can re-import the application by
    # name; given an app object together with reload=True it exits
    # immediately. Pass "<module>:app" in that case.
    app_target = f"{Path(__file__).stem}:app" if args.reload else app

    try:
        uvicorn.run(
            app_target,
            host=args.host,
            port=args.port,
            reload=args.reload,
            loop="auto"
        )
    except KeyboardInterrupt:
        print("\n[Server] Shutting down...")
    except Exception as e:
        print(f"\n[ERROR] Server error: {str(e)}")
        import traceback
        traceback.print_exc()
|
|