# Source: omoi-ui-detector / ui_element_api_server.py
# Uploaded by makeitfr with huggingface_hub (commit 25207a2, verified), 15.1 kB
"""
UI Element Detection API Server

Combines OmniParser UI detection with template matching to provide
precise coordinates for all UI elements in an image.

Usage:
    python ui_element_api_server.py --port 8001

Then POST a PNG image to: http://localhost:8001/analyze

Response includes:
    - JSON coordinates data
    - CSV format data
    - Visualization PNG with bounding boxes (base64-encoded)
"""
import cv2
import numpy as np
import json
import os
import sys
import io
import time
import base64
from pathlib import Path
from contextlib import asynccontextmanager
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse, FileResponse
import argparse
import uvicorn
from typing import Dict, Any, Optional, Tuple
from PIL import Image
import csv
import tempfile
import threading
# Add OmniParser to path dynamically
from pathlib import Path
omoi_root = Path(__file__).parent
sys.path.insert(0, str(omoi_root / 'OmniParser'))
from util.omniparser import Omniparser
from config import get_omniparser_config
# ============ Utility Functions ============
def to_rgb(img: np.ndarray) -> Optional[np.ndarray]:
    """Return *img* as a 3-channel BGR image.

    Despite the name, the result follows OpenCV's BGR channel order.

    Args:
        img: Image array, or ``None``.

    Returns:
        ``None`` when *img* is ``None``; a converted copy for 2-D
        (grayscale) and 4-channel (BGRA) inputs; otherwise *img* itself,
        unchanged and not copied.
    """
    if img is None:
        return None

    # Pick the conversion code first so there is a single cvtColor call site.
    if len(img.shape) == 2:
        conversion = cv2.COLOR_GRAY2BGR
    elif img.shape[2] == 4:
        conversion = cv2.COLOR_BGRA2BGR
    else:
        # Any other channel count is handed back as-is.
        return img
    return cv2.cvtColor(img, conversion)
def match_ui_elements(
    original_image_array: np.ndarray,
    cropped_images_dir: str,
    threshold: float = 0.7
) -> Tuple[list, Dict]:
    """
    Match cropped UI templates against the original image.

    Every ``crop_*.png`` file in *cropped_images_dir* is used as a template
    and located in the original image with normalized cross-correlation
    (``cv2.TM_CCOEFF_NORMED``). Only the single best location per template
    is considered.

    Args:
        original_image_array: Decoded image (grayscale, BGR or BGRA).
        cropped_images_dir: Directory containing the ``crop_*.png`` templates.
        threshold: Minimum correlation score for a match to be reported.

    Returns:
        ``(matches, metadata)``. ``matches`` is a list of dicts (keys:
        ``template_id``, ``template_file``, ``confidence``, ``bbox``,
        ``center``, ``bbox_ratio``) sorted by descending confidence.
        ``metadata`` holds the image size, the number of templates loaded,
        the threshold and the number of matches.

    Raises:
        ValueError: If the original image cannot be converted.
    """
    original_img_rgb = to_rgb(original_image_array)
    if original_img_rgb is None:
        raise ValueError("Failed to convert original image")
    img_height, img_width = original_img_rgb.shape[:2]

    # Load templates (sorted for a deterministic processing order).
    templates = {}
    for template_file in sorted(Path(cropped_images_dir).glob('crop_*.png')):
        template_img = cv2.imread(str(template_file), cv2.IMREAD_UNCHANGED)
        if template_img is not None:
            templates[template_file.name] = to_rgb(template_img)

    # Match templates
    matches = []
    for template_name, template_img in templates.items():
        template_h, template_w = template_img.shape[:2]
        # A template larger than the image cannot be matched.
        if template_h > img_height or template_w > img_width:
            continue
        # Tiny templates are skipped.
        if template_h < 4 or template_w < 4:
            continue

        try:
            result = cv2.matchTemplate(original_img_rgb, template_img, cv2.TM_CCOEFF_NORMED)
            _, max_val, _, max_loc = cv2.minMaxLoc(result)
        except Exception as exc:
            # Best effort: one bad template (e.g. mismatched depth or channel
            # count) must not abort the whole request, but it should be
            # visible in the logs instead of disappearing silently.
            print(f"[WARN] Template matching failed for {template_name}: {exc}")
            continue

        # NOTE(review): a non-finite score would pass a plain ">=" test when
        # it is +inf and is likely to break strict JSON serialization of the
        # response -- skip such results explicitly.
        if not np.isfinite(max_val) or max_val < threshold:
            continue

        x1, y1 = max_loc
        x2 = x1 + template_w
        y2 = y1 + template_h
        center_x = (x1 + x2) / 2
        center_y = (y1 + y2) / 2
        matches.append({
            # Use the stem so that only the file extension is removed.
            'template_id': Path(template_name).stem,
            'template_file': template_name,
            'confidence': float(max_val),
            'bbox': {
                'x1': int(x1),
                'y1': int(y1),
                'x2': int(x2),
                'y2': int(y2),
                'width': int(template_w),
                'height': int(template_h)
            },
            'center': {
                'x': int(center_x),
                'y': int(center_y)
            },
            'bbox_ratio': {
                'x1': x1 / img_width,
                'y1': y1 / img_height,
                'x2': x2 / img_width,
                'y2': y2 / img_height
            }
        })

    matches.sort(key=lambda x: x['confidence'], reverse=True)
    metadata = {
        'image_size': {'width': img_width, 'height': img_height},
        'templates_loaded': len(templates),
        'threshold': threshold,
        'matches_found': len(matches)
    }
    return matches, metadata
def visualize_matches(
    original_image_array: np.ndarray,
    matches: list
) -> np.ndarray:
    """Create a visualization with bounding boxes, center points and labels.

    The input is never modified. It is first normalized to 3-channel BGR so
    that the 3-component colors used below render as intended: drawn directly
    on a BGRA image they would get an alpha of 0, and on a grayscale image
    only the first component would be used.

    Args:
        original_image_array: Decoded image (grayscale, BGR or BGRA).
        matches: Match dicts as produced by ``match_ui_elements``.

    Returns:
        A new 3-channel BGR image with the annotations drawn on it.
    """
    img = to_rgb(original_image_array)
    if img is original_image_array:
        # to_rgb returns 3-channel input unchanged; copy so the caller's
        # array is not drawn on.
        img = img.copy()

    for match in matches:
        bbox = match['bbox']
        center = match['center']
        confidence = match['confidence']
        template_id = match['template_id']

        # Draw bounding box
        color = (0, 255, 0)  # Green
        thickness = 2
        cv2.rectangle(img, (bbox['x1'], bbox['y1']), (bbox['x2'], bbox['y2']), color, thickness)

        # Draw center point
        cv2.circle(img, (center['x'], center['y']), 3, (0, 0, 255), -1)  # Red

        # Draw label above the box; when the box touches the top edge there
        # is no room above it, so place the label just inside the box.
        label = f"ID:{template_id} ({confidence:.2f})"
        label_y = bbox['y1'] - 5 if bbox['y1'] >= 15 else bbox['y1'] + 12
        cv2.putText(img, label, (bbox['x1'], label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 0, 0), 1)
    return img
def matches_to_csv(matches: list, image_width: int, image_height: int) -> str:
    """Serialize match dicts to CSV text.

    Args:
        matches: Match dicts with ``template_id``, ``template_file``,
            ``confidence``, ``bbox``, ``center`` and ``bbox_ratio`` entries.
        image_width: Accepted for interface compatibility; not used.
        image_height: Accepted for interface compatibility; not used.

    Returns:
        The CSV document as a string: one header row followed by one row per
        match. Confidence is written with 4 decimals, ratios with 6.
    """
    rows = [[
        'Element_ID', 'Template_File', 'Confidence',
        'X1', 'Y1', 'X2', 'Y2', 'Width', 'Height',
        'Center_X', 'Center_Y',
        'Ratio_X1', 'Ratio_Y1', 'Ratio_X2', 'Ratio_Y2'
    ]]
    for entry in matches:
        box = entry['bbox']
        mid = entry['center']
        rel = entry['bbox_ratio']
        rows.append(
            [entry['template_id'], entry['template_file'], f"{entry['confidence']:.4f}"]
            + [box[key] for key in ('x1', 'y1', 'x2', 'y2', 'width', 'height')]
            + [mid['x'], mid['y']]
            + [f"{rel[key]:.6f}" for key in ('x1', 'y1', 'x2', 'y2')]
        )

    buffer = io.StringIO()
    csv.writer(buffer).writerows(rows)
    return buffer.getvalue()
# ============ FastAPI Server ============
# Global OmniParser instance.
# Assigned by `lifespan` at server startup; it stays None when initialization
# fails, in which case the analysis endpoints answer with HTTP 503.
omniparser = None
# Held by `lifespan` while the instance is being created. The request
# handlers in this file do not acquire it.
omniparser_lock = threading.Lock()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Create the global OmniParser on startup and log on shutdown.

    A failed initialization is logged with its traceback but does not stop
    the server from starting; ``omniparser`` then keeps its previous value.
    """
    global omniparser
    try:
        with omniparser_lock:
            omniparser = Omniparser(get_omniparser_config())
            print("[Server] OmniParser initialized successfully")
    except Exception as init_error:
        import traceback
        print(f"[ERROR] Failed to initialize OmniParser: {init_error!s}")
        traceback.print_exc()

    # Hand control to the running application.
    yield

    # Nothing to release; just note the shutdown.
    print("[Server] Shutting down...")
# ASGI application object. `lifespan` runs the OmniParser start-up logic
# before the first request is served.
app = FastAPI(
    title="UI Element Detection API",
    description="Detects and locates all UI elements in screenshots",
    lifespan=lifespan
)
@app.get("/health")
async def health():
    """Health check endpoint.

    Returns:
        A dict with the constant ``status`` and ``service`` fields plus
        ``omniparser_initialized``, which is ``False`` when the model failed
        to load at startup (the analysis endpoints answer 503 in that state).
    """
    return {
        "status": "ok",
        "service": "UI Element Detection API",
        "omniparser_initialized": omniparser is not None,
    }
@app.post("/analyze")
async def analyze_image(file: UploadFile = File(...)):
    """
    Analyze an image for UI elements.

    Args:
        file: Uploaded image file.

    Returns:
        JSON response with timing information, element coordinates, CSV data
        and a base64-encoded PNG visualization.

    Raises:
        HTTPException: 503 when OmniParser is not initialized, 400 when the
            upload is empty or cannot be decoded as an image, 500 when the
            cropped images directory is missing or any processing step fails.
    """
    if omniparser is None:
        raise HTTPException(status_code=503, detail="OmniParser not initialized")
    try:
        print(f"\n[Analysis] Starting analysis for: {file.filename}")
        start_time = time.time()

        # 1. Read and decode image
        print("[Step 1] Reading image file...")
        content = await file.read()
        if not content:
            # cv2.imdecode raises on an empty buffer, which would otherwise
            # be reported as a 500 instead of a client error.
            raise HTTPException(status_code=400, detail="Uploaded file is empty")
        np_array = np.frombuffer(content, np.uint8)
        original_img = cv2.imdecode(np_array, cv2.IMREAD_UNCHANGED)
        if original_img is None:
            raise HTTPException(status_code=400, detail="Failed to decode image")
        print(f"[Step 1] Image loaded: {original_img.shape}")

        # 2. Re-encode as PNG and base64 for OmniParser
        print("[Step 2] Encoding for OmniParser...")
        _, buffer = cv2.imencode('.png', original_img)
        image_base64 = base64.b64encode(buffer).decode()

        # 3. Run OmniParser. Its return value is not used here; the crops are
        # read back from disk below.
        # NOTE(review): presumably parse() writes the crops into cropped_dir
        # as a side effect -- confirm, and confirm that crops left over from
        # an earlier request cannot be picked up.
        print("[Step 3] Running OmniParser detection...")
        omni_time = time.time()
        omniparser.parse(image_base64)
        omni_elapsed = time.time() - omni_time
        print(f"[Step 3] OmniParser complete in {omni_elapsed:.2f}s")

        # 4. Locate the cropped images directory
        cropped_dir = '/tmp/omoi_cropped_images'
        if not Path(cropped_dir).exists():
            raise HTTPException(status_code=500, detail="Cropped images directory not found")

        # 5. Match UI elements
        print("[Step 4] Matching templates...")
        match_time = time.time()
        matches, metadata = match_ui_elements(original_img, cropped_dir, threshold=0.7)
        match_elapsed = time.time() - match_time
        print(f"[Step 4] Matching complete in {match_elapsed:.2f}s - Found {len(matches)} elements")

        # 6. Create visualization
        print("[Step 5] Creating visualization...")
        viz_img = visualize_matches(original_img, matches)
        _, viz_buffer = cv2.imencode('.png', viz_img)
        viz_base64 = base64.b64encode(viz_buffer).decode()

        # 7. Generate CSV
        print("[Step 6] Generating CSV...")
        image_size = metadata['image_size']
        csv_data = matches_to_csv(matches, image_size['width'], image_size['height'])

        # 8. Prepare response
        print("[Step 7] Preparing response...")
        # Measure once so the reported and the logged duration agree.
        total_time = time.time() - start_time
        response_data = {
            'status': 'success',
            'processing_time_seconds': total_time,
            'timing': {
                'omniparser_seconds': omni_elapsed,
                'template_matching_seconds': match_elapsed
            },
            'image_info': {
                'filename': file.filename,
                'size': image_size
            },
            'analysis': {
                'total_elements_detected': len(matches),
                'elements': matches
            },
            'exports': {
                'csv_data': csv_data,
                'visualization_png_base64': viz_base64
            }
        }
        print(f"[Analysis] Complete in {total_time:.2f}s")
        return JSONResponse(content=response_data)
    except HTTPException:
        # Deliberate HTTP errors keep their status code.
        raise
    except Exception as e:
        print(f"[ERROR] Analysis failed: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Analysis failed: {str(e)}")
@app.post("/analyze_batch")
async def analyze_batch(file: UploadFile = File(...)):
    """
    Analyze image and return as separate parts for easier client handling.

    Args:
        file: Uploaded image file.

    Returns:
        {
            'metadata': analysis metadata,
            'coordinates_json': full coordinates data,
            'csv_data': CSV string,
            'visualization_png_base64': visualization image
        }

    Raises:
        HTTPException: 503 when OmniParser is not initialized, 400 when the
            upload is empty or cannot be decoded as an image, 500 when the
            cropped images directory is missing or any processing step fails.
    """
    if omniparser is None:
        raise HTTPException(status_code=503, detail="OmniParser not initialized")
    try:
        print(f"\n[Batch Analysis] Starting for: {file.filename}")

        # Read image
        content = await file.read()
        if not content:
            # cv2.imdecode raises on an empty buffer; report a client error.
            raise HTTPException(status_code=400, detail="Uploaded file is empty")
        np_array = np.frombuffer(content, np.uint8)
        original_img = cv2.imdecode(np_array, cv2.IMREAD_UNCHANGED)
        if original_img is None:
            raise HTTPException(status_code=400, detail="Failed to decode image")

        # Run OmniParser. Its return value is not used here; the crops are
        # read back from disk below.
        # NOTE(review): presumably parse() writes the crops into cropped_dir
        # as a side effect -- confirm.
        image_base64 = base64.b64encode(cv2.imencode('.png', original_img)[1]).decode()
        omniparser.parse(image_base64)

        # Match templates. Same guard as /analyze: without it a missing
        # directory would silently be reported as "0 elements found".
        cropped_dir = '/tmp/omoi_cropped_images'
        if not Path(cropped_dir).exists():
            raise HTTPException(status_code=500, detail="Cropped images directory not found")
        matches, metadata = match_ui_elements(original_img, cropped_dir, threshold=0.7)

        # Create visualization
        viz_img = visualize_matches(original_img, matches)
        _, viz_buffer = cv2.imencode('.png', viz_img)
        viz_base64 = base64.b64encode(viz_buffer).decode()

        # CSV data
        image_size = metadata['image_size']
        csv_data = matches_to_csv(matches, image_size['width'], image_size['height'])

        # Create JSON structure
        coordinates_json = {
            'source_image': file.filename,
            'image_size': image_size,
            'total_elements': len(matches),
            'elements': matches
        }
        return JSONResponse(content={
            'metadata': {
                'filename': file.filename,
                'image_size': image_size,
                'total_elements_detected': len(matches),
                'templates_loaded': metadata['templates_loaded']
            },
            'coordinates_json': coordinates_json,
            'csv_data': csv_data,
            'visualization_png_base64': viz_base64
        })
    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the
        # 400 responses raised above would be rewritten as 500 below.
        raise
    except Exception as e:
        print(f"[ERROR] Batch analysis failed: {str(e)}")
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
    # Command-line entry point: parse options, print a banner, run uvicorn.
    import multiprocessing

    parser = argparse.ArgumentParser(description='UI Element Detection API Server')
    parser.add_argument('--host', type=str, default='127.0.0.1', help='Host to bind to')
    parser.add_argument('--port', type=int, default=8001, help='Port to listen on')
    parser.add_argument('--reload', action='store_true', help='Enable auto-reload')
    parser.add_argument('--workers', type=int, default=1, help='Number of worker processes (default: 1, use for production with module import)')
    args = parser.parse_args()

    # Get CPU count for reference
    cpu_count = multiprocessing.cpu_count()

    # Direct mode always serves from a single process; --workers is only
    # meaningful when the app is started through the uvicorn CLI.
    effective_workers = 1
    module_name = Path(__file__).stem
    if args.workers > 1:
        print(f"[WARN] --workers={args.workers} is ignored in direct mode; "
              f"run 'uvicorn {module_name}:app --workers {args.workers}' instead")

    print(f"\n{'='*70}")
    print("UI Element Detection API Server - Optimized")
    print(f"{'='*70}")
    print(f"Starting server on http://{args.host}:{args.port}")
    print(f"CPU Cores Available: {cpu_count}")
    print(f"Workers: {effective_workers} (direct mode - async concurrency enabled)")
    print(f"\nEndpoints:")
    print(f" POST /analyze - Analyze image with details")
    print(f" POST /analyze_batch - Analyze image with structured response")
    print(f" GET /health - Health check")
    print(f"{'='*70}\n")

    # uvicorn only supports reload when the application is given as an
    # import string; handed the app object together with reload=True it logs
    # a warning and exits. The script's directory is on sys.path when the
    # file is run directly, so the module can be imported by its stem.
    app_target = f"{module_name}:app" if args.reload else app

    try:
        # Run with async concurrency instead of multiple workers for direct instantiation
        uvicorn.run(
            app_target,
            host=args.host,
            port=args.port,
            reload=args.reload,
            loop="auto"
        )
    except KeyboardInterrupt:
        print("\n[Server] Shutting down...")
    except Exception as e:
        print(f"\n[ERROR] Server error: {str(e)}")
        import traceback
        traceback.print_exc()