Spaces:
Paused
Paused
| """ | |
| UI Element Detection API Server | |
| Combines OmniParser UI detection with template matching to provide | |
| precise coordinates for all UI elements in an image. | |
| Usage: | |
| python ui_element_api_server.py --port 8001 | |
| Then POST a PNG image to: http://localhost:8001/analyze | |
| Response includes: | |
| - JSON coordinates data | |
| - CSV format data | |
| - Visualization PNG with bounding boxes | |
| """ | |
| import cv2 | |
| import numpy as np | |
| import json | |
| import os | |
| import sys | |
| import io | |
| import time | |
| import base64 | |
| from pathlib import Path | |
| from contextlib import asynccontextmanager | |
| from fastapi import FastAPI, File, UploadFile, HTTPException | |
| from fastapi.responses import JSONResponse, FileResponse | |
| import argparse | |
| import uvicorn | |
| from typing import Dict, Any, Optional, Tuple | |
| from PIL import Image | |
| import csv | |
| import tempfile | |
| import threading | |
| # Add OmniParser to path dynamically | |
| from pathlib import Path | |
| omoi_root = Path(__file__).parent | |
| sys.path.insert(0, str(omoi_root / 'OmniParser')) | |
| from util.omniparser import Omniparser | |
| from config import get_omniparser_config | |
| # ============ Utility Functions ============ | |
def to_rgb(img: Optional[np.ndarray]) -> Optional[np.ndarray]:
    """Return *img* as a 3-channel image in OpenCV's BGR channel order.

    Despite the name, no R/B swap is performed: the result keeps the BGR order
    used by ``cv2.imdecode`` / ``cv2.imread``.

    - ``None`` is passed through as ``None``.
    - Single-channel images, ``(H, W)`` or ``(H, W, 1)``, get the gray channel
      replicated into all three channels.
    - Four-channel BGRA images have their alpha channel dropped.
    - Anything else (e.g. an existing 3-channel image) is returned unchanged,
      without copying.
    """
    if img is None:
        return None
    if img.ndim == 2:
        return cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    channels = img.shape[2]
    if channels == 1:
        # (H, W, 1) gray image: replicate the channel, as COLOR_GRAY2BGR does.
        # Returning it unchanged would make later 3-channel template matching fail.
        return np.repeat(img, 3, axis=2)
    if channels == 4:
        return cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
    return img
def _load_templates(cropped_images_dir: str) -> Dict[str, np.ndarray]:
    """Load every ``crop_*.png`` in *cropped_images_dir* as 3-channel BGR, keyed by file name."""
    templates = {}
    for template_file in sorted(Path(cropped_images_dir).glob('crop_*.png')):
        template_img = cv2.imread(str(template_file), cv2.IMREAD_UNCHANGED)
        if template_img is not None:
            templates[template_file.name] = to_rgb(template_img)
    return templates


def match_ui_elements(
    original_image_array: np.ndarray,
    cropped_images_dir: str,
    threshold: float = 0.7
) -> Tuple[list, Dict]:
    """Locate each cropped UI template inside the original image.

    Every ``crop_*.png`` in *cropped_images_dir* is matched against the full
    image with ``cv2.TM_CCOEFF_NORMED``. Only the best location per template is
    kept, and only when its score is at least *threshold*.

    Args:
        original_image_array: Image to search, as decoded by OpenCV. Gray and
            BGRA inputs are converted to 3-channel BGR first.
        cropped_images_dir: Directory holding the ``crop_*.png`` templates. A
            missing or empty directory simply yields no matches.
        threshold: Minimum normalized correlation score for a match.

    Returns:
        ``(matches, metadata)``. *matches* is a list of dicts (template id,
        file name, confidence, pixel bbox, integer center, bbox as fractions of
        the image size), sorted by descending confidence. *metadata* holds the
        image size, number of templates loaded, threshold and match count.

    Raises:
        ValueError: If the original image is None and cannot be converted.
    """
    original_img_rgb = to_rgb(original_image_array)
    if original_img_rgb is None:
        raise ValueError("Failed to convert original image")
    img_height, img_width = original_img_rgb.shape[:2]

    templates = _load_templates(cropped_images_dir)

    matches = []
    for template_name, template_img in templates.items():
        template_h, template_w = template_img.shape[:2]
        # A template must fit inside the image; very small ones are skipped.
        if template_h > img_height or template_w > img_width:
            continue
        if template_h < 4 or template_w < 4:
            continue
        try:
            result = cv2.matchTemplate(original_img_rgb, template_img, cv2.TM_CCOEFF_NORMED)
            _, max_val, _, max_loc = cv2.minMaxLoc(result)
        except Exception as exc:
            # Best effort: one unusable template (e.g. a dtype matchTemplate
            # rejects) must not fail the whole request, but report the skip.
            print(f"[WARN] Skipping template {template_name}: {exc}")
            continue
        # Non-finite scores carry no meaning, and Starlette's JSONResponse
        # refuses to serialize NaN/inf, so never let them through.
        if not np.isfinite(max_val) or max_val < threshold:
            continue

        x1, y1 = max_loc
        x2 = x1 + template_w
        y2 = y1 + template_h
        matches.append({
            'template_id': Path(template_name).stem,
            'template_file': template_name,
            'confidence': float(max_val),
            'bbox': {
                'x1': int(x1),
                'y1': int(y1),
                'x2': int(x2),
                'y2': int(y2),
                'width': int(template_w),
                'height': int(template_h)
            },
            'center': {
                'x': int((x1 + x2) / 2),
                'y': int((y1 + y2) / 2)
            },
            'bbox_ratio': {
                'x1': x1 / img_width,
                'y1': y1 / img_height,
                'x2': x2 / img_width,
                'y2': y2 / img_height
            }
        })

    matches.sort(key=lambda m: m['confidence'], reverse=True)
    metadata = {
        'image_size': {'width': img_width, 'height': img_height},
        'templates_loaded': len(templates),
        'threshold': threshold,
        'matches_found': len(matches)
    }
    return matches, metadata
def visualize_matches(
    original_image_array: np.ndarray,
    matches: list
) -> np.ndarray:
    """Draw template matches on a copy of the image and return it.

    Each match gets a green bounding box, a red dot at its center and a blue
    ``ID:<template_id> (<confidence>)`` label (colors in OpenCV BGR order).
    The input array is never modified.

    The canvas is always 3-channel BGR. The colors are 3-tuples: on a BGRA
    image they would write alpha = 0, making the boxes invisible once the
    result is saved as PNG. On a gray image only their first component would
    be used.

    Args:
        original_image_array: Image as decoded by OpenCV (gray, BGR or BGRA).
        matches: Match dicts; the ``bbox``, ``center``, ``confidence`` and
            ``template_id`` keys are read.

    Returns:
        A new 3-channel BGR image with the annotations drawn on it.
    """
    img = to_rgb(original_image_array)
    if img is original_image_array:
        # to_rgb hands 3-channel input back as-is; copy so the caller's array is untouched.
        img = img.copy()

    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.4
    box_color = (0, 255, 0)      # green
    center_color = (0, 0, 255)   # red
    label_color = (255, 0, 0)    # blue

    for match in matches:
        bbox = match['bbox']
        center = match['center']
        label = f"ID:{match['template_id']} ({match['confidence']:.2f})"

        cv2.rectangle(img, (bbox['x1'], bbox['y1']), (bbox['x2'], bbox['y2']), box_color, 2)
        cv2.circle(img, (center['x'], center['y']), 3, center_color, -1)

        # The label normally sits just above the box. If that would run off
        # the top of the image, draw it just inside the box's top edge instead.
        (_, text_h), _ = cv2.getTextSize(label, font, font_scale, 1)
        label_y = bbox['y1'] - 5
        if label_y - text_h < 0:
            label_y = bbox['y1'] + text_h + 5
        cv2.putText(img, label, (bbox['x1'], label_y), font, font_scale, label_color, 1)
    return img
def matches_to_csv(matches: list, image_width: int, image_height: int) -> str:
    """Serialize template matches as CSV text, header row first.

    Pixel coordinates are written as-is. The confidence is formatted with
    four decimals and the bbox ratios with six. Rows end with the ``csv``
    module's default CRLF terminator.

    ``image_width`` and ``image_height`` are accepted for API compatibility
    but are not used: ratios come pre-computed in each match's ``bbox_ratio``.
    """
    columns = (
        'Element_ID', 'Template_File', 'Confidence',
        'X1', 'Y1', 'X2', 'Y2', 'Width', 'Height',
        'Center_X', 'Center_Y',
        'Ratio_X1', 'Ratio_Y1', 'Ratio_X2', 'Ratio_Y2',
    )
    corners = ('x1', 'y1', 'x2', 'y2')

    def as_row(match):
        box = match['bbox']
        mid = match['center']
        rel = match['bbox_ratio']
        return (
            match['template_id'],
            match['template_file'],
            format(match['confidence'], '.4f'),
            *(box[k] for k in corners),
            box['width'],
            box['height'],
            mid['x'],
            mid['y'],
            *(format(rel[k], '.6f') for k in corners),
        )

    sink = io.StringIO()
    out = csv.writer(sink)
    out.writerow(columns)
    out.writerows(as_row(m) for m in matches)
    return sink.getvalue()
# ============ FastAPI Server ============
# Process-wide OmniParser instance. It is assigned by `lifespan` at startup and
# stays None if initialization fails.
omniparser: "Optional[Omniparser]" = None
# Lock taken around creation of, and work with, the shared OmniParser.
omniparser_lock = threading.Lock()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the shared OmniParser at startup; log when the server shuts down.

    Initialization errors are printed rather than raised, so the server still
    starts. In that case ``omniparser`` stays None and the analysis endpoints
    answer 503 until the process is restarted.
    """
    global omniparser
    try:
        with omniparser_lock:
            config = get_omniparser_config()
            omniparser = Omniparser(config)
        print("[Server] OmniParser initialized successfully")
    except Exception as e:
        print(f"[ERROR] Failed to initialize OmniParser: {str(e)}")
        import traceback
        traceback.print_exc()
    try:
        yield  # Application runs here
    finally:
        # Runs on normal shutdown and when the app exits with an error.
        print("[Server] Shutting down...")
# The ASGI application. Its `lifespan` handler attempts to load OmniParser
# before requests are served.
app = FastAPI(
    lifespan=lifespan,
    title="UI Element Detection API",
    description="Detects and locates all UI elements in screenshots",
)
@app.get("/health")
async def health():
    """Health check endpoint: liveness plus whether OmniParser is loaded.

    ``omniparser`` stays None when startup initialization fails, and the
    analysis endpoints then answer 503. ``omniparser_ready`` makes that
    state visible without sending an image.
    """
    return {
        "status": "ok",
        "service": "UI Element Detection API",
        "omniparser_ready": omniparser is not None,
    }
@app.post("/analyze")
async def analyze_image(file: UploadFile = File(...)):
    """
    Analyze an uploaded image for UI elements.

    Steps:
    1. Decode the upload.
    2. Run OmniParser on it.
    3. Template-match the ``crop_*.png`` files in ``/tmp/omoi_cropped_images``
       back onto the image.
    4. Build a visualization and a CSV export.

    OmniParser and matching run in a worker thread so the event loop is not
    blocked while they work.

    Returns:
        JSON response with coordinates, CSV data, and base64-encoded visualization

    Raises:
        HTTPException: 503 if OmniParser is not initialized, 400 for an empty
            or undecodable upload, 500 for any other failure.
    """
    parser = omniparser
    if not parser:
        raise HTTPException(status_code=503, detail="OmniParser not initialized")

    # Local import keeps this block self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    cropped_dir = '/tmp/omoi_cropped_images'

    def detect_and_match(image_base64: str, original_img: np.ndarray):
        """Run OmniParser, then match the crops in cropped_dir (runs in a worker thread)."""
        # The parser and the crop directory are shared by all requests. The
        # crops are presumably written by parse(). Hold the lock until they
        # have been matched so a concurrent request cannot swap or add crops
        # in between.
        with omniparser_lock:
            print("[Step 3] Running OmniParser detection...")
            omni_start = time.time()
            parser.parse(image_base64)  # parsed content itself is not used here
            omni_elapsed = time.time() - omni_start
            print(f"[Step 3] OmniParser complete in {omni_elapsed:.2f}s")

            if not Path(cropped_dir).exists():
                raise HTTPException(status_code=500, detail="Cropped images directory not found")

            print("[Step 4] Matching templates...")
            match_start = time.time()
            matches, metadata = match_ui_elements(original_img, cropped_dir, threshold=0.7)
            match_elapsed = time.time() - match_start
            print(f"[Step 4] Matching complete in {match_elapsed:.2f}s - Found {len(matches)} elements")
        return matches, metadata, omni_elapsed, match_elapsed

    try:
        print(f"\n[Analysis] Starting analysis for: {file.filename}")
        start_time = time.time()

        # 1. Read and decode image
        print("[Step 1] Reading image file...")
        content = await file.read()
        if not content:
            # cv2.imdecode asserts on an empty buffer instead of returning None.
            raise HTTPException(status_code=400, detail="Empty upload")
        original_img = cv2.imdecode(np.frombuffer(content, np.uint8), cv2.IMREAD_UNCHANGED)
        if original_img is None:
            raise HTTPException(status_code=400, detail="Failed to decode image")
        print(f"[Step 1] Image loaded: {original_img.shape}")

        # 2. Encode for OmniParser
        print("[Step 2] Encoding for OmniParser...")
        ok, buffer = cv2.imencode('.png', original_img)
        if not ok:
            raise HTTPException(status_code=500, detail="Failed to encode image for OmniParser")
        image_base64 = base64.b64encode(buffer).decode()

        # 3-4. Blocking, CPU-heavy work goes to a worker thread so other
        # requests (e.g. /health) keep being served.
        matches, metadata, omni_elapsed, match_elapsed = await run_in_threadpool(
            detect_and_match, image_base64, original_img
        )

        # 5. Create visualization
        print("[Step 5] Creating visualization...")
        viz_img = visualize_matches(original_img, matches)
        ok, viz_buffer = cv2.imencode('.png', viz_img)
        if not ok:
            raise HTTPException(status_code=500, detail="Failed to encode visualization")
        viz_base64 = base64.b64encode(viz_buffer).decode()

        # 6. Generate CSV
        print("[Step 6] Generating CSV...")
        csv_data = matches_to_csv(matches, metadata['image_size']['width'], metadata['image_size']['height'])

        # 7. Prepare response
        print("[Step 7] Preparing response...")
        response_data = {
            'status': 'success',
            'processing_time_seconds': time.time() - start_time,
            'timing': {
                'omniparser_seconds': omni_elapsed,
                'template_matching_seconds': match_elapsed
            },
            'image_info': {
                'filename': file.filename,
                'size': metadata['image_size']
            },
            'analysis': {
                'total_elements_detected': len(matches),
                'elements': matches
            },
            'exports': {
                'csv_data': csv_data,
                'visualization_png_base64': viz_base64
            }
        }
        total_time = time.time() - start_time
        print(f"[Analysis] Complete in {total_time:.2f}s")
        return JSONResponse(content=response_data)
    except HTTPException:
        # Keep deliberate status codes (400/500 above) as they are.
        raise
    except Exception as e:
        print(f"[ERROR] Analysis failed: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
@app.post("/analyze_batch")
async def analyze_batch(file: UploadFile = File(...)):
    """
    Analyze image and return as separate parts for easier client handling.

    Runs OmniParser on the upload, then template-matches the ``crop_*.png``
    files in ``/tmp/omoi_cropped_images`` back onto the image. The blocking
    work runs in a worker thread.

    Returns:
        {
            'metadata': analysis metadata,
            'coordinates_json': full coordinates data,
            'csv_data': CSV string,
            'visualization_png_base64': visualization image
        }

    Raises:
        HTTPException: 503 if OmniParser is not initialized, 400 for an empty
            or undecodable upload, 500 for any other failure.
    """
    parser = omniparser
    if not parser:
        raise HTTPException(status_code=503, detail="OmniParser not initialized")

    # Local import keeps this block self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    cropped_dir = '/tmp/omoi_cropped_images'

    def detect_and_match(image_base64: str, original_img: np.ndarray):
        """Run OmniParser, then match the crops in cropped_dir (runs in a worker thread)."""
        # Parser and crop directory are shared by every request. Hold the lock
        # until the crops are read so another request's crops cannot mix in.
        with omniparser_lock:
            parser.parse(image_base64)  # parsed content itself is not used here
            if not Path(cropped_dir).exists():
                raise HTTPException(status_code=500, detail="Cropped images directory not found")
            return match_ui_elements(original_img, cropped_dir, threshold=0.7)

    try:
        print(f"\n[Batch Analysis] Starting for: {file.filename}")

        # Read image
        content = await file.read()
        if not content:
            # cv2.imdecode asserts on an empty buffer instead of returning None.
            raise HTTPException(status_code=400, detail="Empty upload")
        original_img = cv2.imdecode(np.frombuffer(content, np.uint8), cv2.IMREAD_UNCHANGED)
        if original_img is None:
            raise HTTPException(status_code=400, detail="Failed to decode image")

        # Run OmniParser + match templates (off the event loop)
        ok, buffer = cv2.imencode('.png', original_img)
        if not ok:
            raise HTTPException(status_code=500, detail="Failed to encode image for OmniParser")
        image_base64 = base64.b64encode(buffer).decode()
        matches, metadata = await run_in_threadpool(detect_and_match, image_base64, original_img)

        # Create visualization
        viz_img = visualize_matches(original_img, matches)
        ok, viz_buffer = cv2.imencode('.png', viz_img)
        if not ok:
            raise HTTPException(status_code=500, detail="Failed to encode visualization")
        viz_base64 = base64.b64encode(viz_buffer).decode()

        # CSV data
        csv_data = matches_to_csv(matches, metadata['image_size']['width'], metadata['image_size']['height'])

        # Create JSON structure
        coordinates_json = {
            'source_image': file.filename,
            'image_size': metadata['image_size'],
            'total_elements': len(matches),
            'elements': matches
        }
        return JSONResponse(content={
            'metadata': {
                'filename': file.filename,
                'image_size': metadata['image_size'],
                'total_elements_detected': len(matches),
                'templates_loaded': metadata['templates_loaded']
            },
            'coordinates_json': coordinates_json,
            'csv_data': csv_data,
            'visualization_png_base64': viz_base64
        })
    except HTTPException:
        # Without this, the generic handler below would turn the 400s above
        # into a 500 whose detail is "400: ...".
        raise
    except Exception as e:
        print(f"[ERROR] Batch analysis failed: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    import multiprocessing

    parser = argparse.ArgumentParser(description='UI Element Detection API Server')
    parser.add_argument('--host', type=str, default='127.0.0.1', help='Host to bind to')
    parser.add_argument('--port', type=int, default=8001, help='Port to listen on')
    parser.add_argument('--reload', action='store_true', help='Enable auto-reload')
    parser.add_argument('--workers', type=int, default=1, help='Number of worker processes (default: 1, use for production with module import)')
    args = parser.parse_args()
    if args.workers < 1:
        parser.error("--workers must be at least 1")

    # uvicorn refuses to reload or spawn several workers when handed an app
    # object (it logs a warning and exits), so pass it an import string in
    # those modes. Every worker process then imports this module and runs its
    # own lifespan, i.e. loads its own OmniParser.
    use_import_string = args.reload or args.workers > 1
    app_target = f"{Path(__file__).stem}:app" if use_import_string else app

    # Get CPU count for reference
    cpu_count = multiprocessing.cpu_count()
    mode = "import-string mode" if use_import_string else "single process, in-process app"
    print(f"\n{'='*70}")
    print("UI Element Detection API Server - Optimized")
    print(f"{'='*70}")
    print(f"Starting server on http://{args.host}:{args.port}")
    print(f"CPU Cores Available: {cpu_count}")
    print(f"Workers: {args.workers} ({mode})")
    if args.reload:
        print("Auto-reload: enabled")
        if args.workers > 1:
            print("[WARN] --workers is ignored by uvicorn when --reload is enabled")
    print("\nEndpoints:")
    print("  POST /analyze - Analyze image with details")
    print("  POST /analyze_batch - Analyze image with structured response")
    print("  GET  /health - Health check")
    print(f"{'='*70}\n")
    try:
        uvicorn.run(
            app_target,
            host=args.host,
            port=args.port,
            reload=args.reload,
            workers=args.workers,
            loop="auto"
        )
    except KeyboardInterrupt:
        print("\n[Server] Shutting down...")
    except Exception as e:
        print(f"\n[ERROR] Server error: {str(e)}")
        import traceback
        traceback.print_exc()