omoi-ui-detector / ui_element_client.py
makeitfr's picture
Upload ui_element_client.py with huggingface_hub
d6f28c9 verified
Raw
History Blame Contribute Delete
6.73 kB
#!/usr/bin/env python3
"""
UI Element Detection API - Python Client Example
Demonstrates how to use the API to analyze screenshots
"""
import requests
import json
import base64
import argparse
from pathlib import Path
from PIL import Image
import io
class UIElementDetectionClient:
    """HTTP client for the UI Element Detection API.

    The analysis helpers all upload the image to ``POST {api_url}/analyze``
    and work on the JSON document the server returns.
    """

    def __init__(self, api_url="http://127.0.0.1:8001"):
        """Create a client.

        Args:
            api_url: Base URL of the API server. A trailing slash is removed so
                endpoint URLs are not built as ``...//analyze``.
        """
        self.api_url = api_url.rstrip("/")

    def health_check(self):
        """Check if API is running.

        Returns:
            The decoded JSON body of ``GET {api_url}/health`` on a successful
            (2xx) response, otherwise ``{"error": <message>}``. Never raises.
        """
        try:
            response = requests.get(f"{self.api_url}/health", timeout=5)
            # An HTTP error status means "not available"; don't try to read a
            # health payload out of an error body.
            response.raise_for_status()
            return response.json()
        except Exception as e:  # best-effort probe: report the problem, never raise
            return {"error": str(e)}

    def analyze_image(self, image_path):
        """
        Analyze an image and get UI element coordinates.

        Args:
            image_path: Path (str or ``Path``) of the image file to upload.

        Returns:
            The decoded JSON response of ``POST {api_url}/analyze``.

        Raises:
            FileNotFoundError: If ``image_path`` does not exist.
            RuntimeError: If the server answers with a status other than 200.
        """
        if not Path(image_path).exists():
            raise FileNotFoundError(f"Image not found: {image_path}")

        print(f"[Client] Analyzing: {image_path}")
        print(f"[Client] Uploading to {self.api_url}/analyze...")
        with open(image_path, 'rb') as f:
            response = requests.post(
                f"{self.api_url}/analyze",
                files={'file': f},
                # Much longer than the health probe's 5 s: analysis presumably
                # takes far longer than a liveness check.
                timeout=300,
            )

        if response.status_code != 200:
            raise RuntimeError(f"API error: {response.status_code} - {response.text}")
        return response.json()

    def get_element_coordinates(self, image_path, save_outputs=False):
        """Analyze an image, print a summary, and optionally save the exports.

        Args:
            image_path: Path of the image file to analyze.
            save_outputs: If true, write ``<stem>_coordinates.json``,
                ``<stem>_coordinates.csv`` and ``<stem>_visualization.png``
                into the current working directory.

        Returns:
            The full analysis response dictionary.

        Raises:
            RuntimeError: If the response's ``status`` is not ``"success"``.
        """
        result = self.analyze_image(image_path)
        # .get: a malformed response should produce the descriptive error
        # below, not a bare KeyError.
        if result.get('status') != 'success':
            raise RuntimeError(f"Analysis failed: {result}")

        size = result['image_info']['size']
        print("\n[Results] Successfully analyzed image")
        print(f" Total UI Elements: {result['analysis']['total_elements_detected']}")
        print(f" Processing Time: {result['processing_time_seconds']:.2f}s")
        print(f" Image Size: {size['width']}x{size['height']}")

        if save_outputs:
            self._save_outputs(result, Path(image_path).stem)
        return result

    @staticmethod
    def _save_outputs(result, base_name):
        """Write the JSON, CSV and PNG exports of ``result`` to the current directory."""
        json_file = f"{base_name}_coordinates.json"
        with open(json_file, 'w', encoding='utf-8') as f:
            json.dump(result['analysis'], f, indent=2)
        print(f" Saved JSON: {json_file}")

        csv_file = f"{base_name}_coordinates.csv"
        # newline='' writes the server's CSV text verbatim. Without it, text
        # mode translates "\n" to os.linesep, turning an existing "\r\n"
        # terminator into "\r\r\n" on Windows.
        with open(csv_file, 'w', encoding='utf-8', newline='') as f:
            f.write(result['exports']['csv_data'])
        print(f" Saved CSV: {csv_file}")

        viz_file = f"{base_name}_visualization.png"
        viz_bytes = base64.b64decode(result['exports']['visualization_png_base64'])
        Image.open(io.BytesIO(viz_bytes)).save(viz_file)
        print(f" Saved Visualization: {viz_file}")

    def get_element_by_id(self, image_path, element_id):
        """Analyze ``image_path`` and return the first element whose
        ``template_id`` equals ``element_id``, or ``None`` if there is none."""
        result = self.analyze_image(image_path)
        return self._find_element(result['analysis']['elements'], element_id)

    @staticmethod
    def _find_element(elements, element_id):
        """Return the first element dict with ``template_id == element_id``, else ``None``."""
        return next(
            (elem for elem in elements if elem.get('template_id') == element_id),
            None,
        )

    def find_elements_in_region(self, image_path, x1, y1, x2, y2):
        """Analyze ``image_path`` and return the elements overlapping a region.

        The region's corners may be given in either order: ``(x2, y2, x1, y1)``
        selects the same area as ``(x1, y1, x2, y2)``. Overlap is strict, so an
        element that only touches the region's edge is not included.
        """
        result = self.analyze_image(image_path)
        return self._elements_in_region(result['analysis']['elements'], x1, y1, x2, y2)

    @staticmethod
    def _elements_in_region(elements, x1, y1, x2, y2):
        """Return the elements whose ``bbox`` strictly overlaps the given rectangle."""
        # Normalise the region so swapped corners still describe the same area.
        left, right = sorted((x1, x2))
        top, bottom = sorted((y1, y2))
        # NOTE(review): assumes each bbox already has x1 <= x2 and y1 <= y2.
        return [
            elem for elem in elements
            if elem['bbox']['x1'] < right and elem['bbox']['x2'] > left
            and elem['bbox']['y1'] < bottom and elem['bbox']['y2'] > top
        ]
def main():
    """Run the command-line client.

    Checks API health, then performs one action on the image. It looks up a
    single element (``--element``), lists the elements overlapping a region
    (``--region``), or by default prints a full summary (optionally saving the
    exports with ``--save``).

    Returns:
        Exit status: 0 on success, 1 if the API is unavailable, the requested
        element is not found, or the request/analysis fails.
    """
    parser = argparse.ArgumentParser(description='UI Element Detection API Client')
    parser.add_argument('image', help='Path to image file')
    parser.add_argument('--api', default='http://127.0.0.1:8001', help='API URL')
    parser.add_argument('--save', action='store_true', help='Save output files')
    # --element and --region select different actions; making them exclusive
    # stops one from being silently ignored when both are passed.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument('--element', help='Get specific element by ID')
    mode.add_argument('--region', nargs=4, type=int, metavar=('X1', 'Y1', 'X2', 'Y2'),
                      help='Find elements in region')
    args = parser.parse_args()

    client = UIElementDetectionClient(args.api)

    # Check API health
    print("[Client] Checking API health...")
    health = client.health_check()
    if 'error' in health:
        print(f"[ERROR] API not available: {health['error']}")
        return 1
    # The health payload's schema is not enforced; don't crash if "status" is absent.
    print(f"[Client] API Status: {health.get('status', 'unknown')}")

    # Analyze image
    print()
    try:
        if args.element is not None:
            # Get specific element
            element = client.get_element_by_id(args.image, args.element)
            if element is None:
                print(f"[ERROR] Element '{args.element}' not found")
                return 1
            bbox = element['bbox']
            print(f"\n[Element: {args.element}]")
            print(f" Position (center): ({element['center']['x']}, {element['center']['y']})")
            print(f" Bounding Box: ({bbox['x1']}, {bbox['y1']}) -> ({bbox['x2']}, {bbox['y2']})")
            print(f" Size: {bbox['width']}x{bbox['height']}")
            print(f" Confidence: {element['confidence']:.4f}")
        elif args.region is not None:
            # Find in region
            elements = client.find_elements_in_region(args.image, *args.region)
            print(f"\n[Found {len(elements)} elements in region {args.region}]")
            for elem in elements:
                print(f" - {elem['template_id']} @ ({elem['center']['x']}, {elem['center']['y']})")
        else:
            # Full analysis
            result = client.get_element_coordinates(args.image, save_outputs=args.save)
            # Sort client-side so the heading below is true regardless of the
            # order the server returns; the stable sort keeps server order on ties.
            top = sorted(
                result['analysis']['elements'],
                key=lambda e: e['confidence'],
                reverse=True,
            )[:5]
            print("\n[Top 5 Elements by Confidence]")
            for rank, elem in enumerate(top, 1):
                print(f" {rank}. {elem['template_id']} @ ({elem['center']['x']}, {elem['center']['y']}) - {elem['confidence']:.4f}")
    except Exception as e:  # top-level CLI boundary: report and signal failure
        print(f"[ERROR] {e}")
        import traceback
        traceback.print_exc()
        return 1
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status; a None
    # return still exits with status 0.
    raise SystemExit(main())