# NOTE(review): removed non-code residue captured from the hosting UI
# ("Spaces: Running Running") that preceded the source listing.
#!/usr/bin/env python3
"""
FastAPI server for OmniParser with detailed endpoints including coordinates.
Run with: uvicorn server:app --host 0.0.0.0 --port 8000
"""
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.responses import JSONResponse
import torch
from PIL import Image
import io
import base64
from typing import List, Dict, Any, Optional
import numpy as np
# Project helpers: OCR box detection, YOLO/caption model loaders, and the
# set-of-marks (SOM) image labeling routine.
from util.utils import check_ocr_box, get_yolo_model, get_caption_model_processor, get_som_labeled_img
from huggingface_hub import snapshot_download
# Monkey patch for a gradio_client JSON-schema bug: json_schema_to_python_type
# raises "argument of type 'bool' is not iterable" on boolean schemas
# (e.g. "additionalProperties": true/false). Wrap it to fall back to "Any".
try:
    from gradio_client import utils as gradio_client_utils

    original_json_schema_to_python_type = gradio_client_utils.json_schema_to_python_type

    def patched_json_schema_to_python_type(schema, *args, **kwargs):
        """Tolerant wrapper around gradio_client's json_schema_to_python_type.

        Returns "Any" for non-dict schemas or for the known boolean-schema
        TypeError; all other errors propagate.  Extra positional/keyword
        arguments are forwarded unchanged — the upstream function is also
        called with a second `defs` argument, which the original patch
        (taking only `schema`) would have rejected with a TypeError.
        """
        try:
            if not isinstance(schema, dict):
                return "Any"
            return original_json_schema_to_python_type(schema, *args, **kwargs)
        except (TypeError, AttributeError) as e:
            if "argument of type 'bool' is not iterable" in str(e):
                return "Any"
            raise

    gradio_client_utils.json_schema_to_python_type = patched_json_schema_to_python_type
except Exception as e:  # gradio_client may be absent or its internals may differ
    print(f"Warning: Could not apply gradio_client patch: {e}")
# Initialize FastAPI app
app = FastAPI(
    title="OmniParser API",
    description="Screen parsing tool to convert GUI screens to structured elements with coordinates",
    version="2.0.0"
)

# Global model cache — populated lazily by load_models() and reused across requests.
_yolo_model = None
_caption_model_processor = None

# Proper device handling: prefer CUDA when available, otherwise CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")
def load_models():
    """Return the cached (yolo_model, caption_model_processor) pair.

    On first use, downloads the OmniParser-v2.0 weights snapshot into the
    local ``weights`` directory and instantiates both models; later calls
    hit the module-level cache.
    """
    global _yolo_model, _caption_model_processor

    # Fast path: both models already cached from a previous call.
    if _yolo_model is not None and _caption_model_processor is not None:
        return _yolo_model, _caption_model_processor

    weights_repo = "microsoft/OmniParser-v2.0"
    weights_dir = "weights"
    print(f"Downloading repository to: {weights_dir}...")
    snapshot_download(repo_id=weights_repo, local_dir=weights_dir)
    print(f"Repository downloaded to: {weights_dir}")

    _yolo_model = get_yolo_model(model_path='weights/icon_detect/model.pt')
    _caption_model_processor = get_caption_model_processor(
        model_name="florence2",
        model_name_or_path="weights/icon_caption",
        device=DEVICE,
    )
    return _yolo_model, _caption_model_processor
# Response "models": lightweight dict subclasses used as documentation markers
# only — endpoints return plain dicts, not validated pydantic models.
# Subclassing the builtin ``dict`` is runtime-identical to the previous
# ``Dict[str, Any]`` base while avoiding the deprecated typing-generic
# subclass idiom (PEP 585).
class BoundingBox(dict):
    """Bounding box with coordinates (free-form dict)."""


class Element(dict):
    """UI element with all details (free-form dict)."""


class ParseResult(dict):
    """Complete parse result (free-form dict)."""
@app.on_event("startup")
async def startup_event():
    """Eagerly load models when the server starts.

    NOTE(review): this handler was defined but never registered with the app;
    without the decorator it was dead code and models only loaded lazily on
    the first request.
    """
    try:
        load_models()
        print("Models loaded successfully")
    except Exception as e:
        # Best-effort: the server still comes up; load_models() will retry
        # on the first request that needs the models.
        print(f"Warning: Models not fully loaded on startup: {e}")
# NOTE(review): route path assumed from convention (the handler was never
# registered with the app) — confirm against deployment health probes.
@app.get("/health")
async def health_check():
    """Health check endpoint: service identity, version, and compute device."""
    return {
        "status": "healthy",
        "service": "OmniParser API",
        "version": "2.0.0",
        "device": str(DEVICE)
    }
# NOTE(review): the handler was never registered with the app; root path is
# assumed for this API-description endpoint — confirm with clients.
@app.get("/")
async def info():
    """Get API information: available endpoints and accepted parameters."""
    return {
        "name": "OmniParser V2",
        "description": "Screen parsing tool to convert general GUI screens to structured elements",
        "version": "2.0.0",
        "device": str(DEVICE),
        "endpoints": {
            "parse": "/parse - POST - Parse an image and return structured elements",
            "parse_detailed": "/parse/detailed - POST - Parse with full coordinate details",
            "parse_batch": "/parse/batch - POST - Parse multiple images"
        },
        "parameters": {
            "box_threshold": "Confidence threshold for bounding boxes (0.01-1.0)",
            "iou_threshold": "IOU threshold (0.01-1.0)",
            "use_paddleocr": "Use PaddleOCR for text detection (true/false)",
            "imgsz": "Image size for detection (640-1920, step 32)"
        }
    }
# Route path grounded by info()'s endpoint listing ("/parse - POST").
# NOTE(review): the handler was defined but never registered with the app.
@app.post("/parse")
async def parse_image(
    file: UploadFile = File(...),
    box_threshold: float = Form(0.05),
    iou_threshold: float = Form(0.1),
    use_paddleocr: bool = Form(True),
    imgsz: int = Form(640)
):
    """
    Parse an image and return UI elements.

    Returns:
        - elements: List of detected UI elements
        - count: Total number of elements
        - image_base64: Parsed image with bounding boxes

    Raises:
        HTTPException 400 on invalid parameters, 500 on processing errors.
    """
    try:
        # Validate parameters (ValueError is mapped to HTTP 400 below).
        if not 0.01 <= box_threshold <= 1.0:
            raise ValueError("box_threshold must be between 0.01 and 1.0")
        if not 0.01 <= iou_threshold <= 1.0:
            raise ValueError("iou_threshold must be between 0.01 and 1.0")
        if not (640 <= imgsz <= 1920 and imgsz % 32 == 0):
            raise ValueError("imgsz must be between 640-1920 and divisible by 32")

        # Read the uploaded image fully into memory.
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
        # Convert RGBA to RGB if necessary (downstream models expect 3 channels).
        if image.mode == 'RGBA':
            image = image.convert('RGB')

        # Load (cached) models.
        yolo_model, caption_model_processor = load_models()

        # Scale annotation drawing parameters to the image width; 3200 px
        # appears to be the reference width — presumably from the OmniParser
        # demo defaults (TODO confirm).
        box_overlay_ratio = image.size[0] / 3200
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }

        # OCR pass: detect text boxes so they can be merged with icon detections.
        ocr_bbox_rslt, _ = check_ocr_box(
            image,
            display_img=False,
            output_bb_format='xyxy',
            goal_filtering=None,
            easyocr_args={'paragraph': False, 'text_threshold': 0.9},
            use_paddleocr=use_paddleocr
        )
        text, ocr_bbox = ocr_bbox_rslt

        # Detection + captioning pass; coordinates are returned as 0-1 ratios
        # (output_coord_in_ratio=True).
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            image,
            yolo_model,
            BOX_TRESHOLD=box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
            imgsz=imgsz,
            use_local_semantics=True,
            scale_img=False,
            batch_size=128
        )

        # Format results as human-readable "icon N: ..." strings.
        elements = [f"icon {i}: {str(v)}" for i, v in enumerate(parsed_content_list)]
        return {
            "status": "success",
            "elements": elements,
            "count": len(elements),
            "image_base64": dino_labled_img,
            "parameters": {
                "box_threshold": box_threshold,
                "iou_threshold": iou_threshold,
                "use_paddleocr": use_paddleocr,
                "imgsz": imgsz
            }
        }
    except ValueError as e:
        # Parameter validation errors -> 400 Bad Request.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
# Route path grounded by info()'s endpoint listing ("/parse/detailed - POST").
# NOTE(review): the handler was defined but never registered with the app.
@app.post("/parse/detailed")
async def parse_image_detailed(
    file: UploadFile = File(...),
    box_threshold: float = Form(0.05),
    iou_threshold: float = Form(0.1),
    use_paddleocr: bool = Form(True),
    imgsz: int = Form(640)
):
    """
    Parse an image with detailed coordinate information.

    Returns:
        - elements: List of elements with full coordinate details
        - coordinates: Bounding box coordinates for each element
        - image_size: Original image dimensions
        - image_base64: Parsed image with annotations

    Raises:
        HTTPException 400 on invalid parameters, 500 on processing errors.
    """
    try:
        # Validate parameters (ValueError is mapped to HTTP 400 below).
        if not 0.01 <= box_threshold <= 1.0:
            raise ValueError("box_threshold must be between 0.01 and 1.0")
        if not 0.01 <= iou_threshold <= 1.0:
            raise ValueError("iou_threshold must be between 0.01 and 1.0")
        if not (640 <= imgsz <= 1920 and imgsz % 32 == 0):
            raise ValueError("imgsz must be between 640-1920 and divisible by 32")

        # Read the uploaded image fully into memory.
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))
        # Convert RGBA to RGB if necessary (downstream models expect 3 channels).
        if image.mode == 'RGBA':
            image = image.convert('RGB')
        original_size = image.size  # (width, height) in pixels

        # Load (cached) models.
        yolo_model, caption_model_processor = load_models()

        # Scale annotation drawing parameters to the image width.
        box_overlay_ratio = image.size[0] / 3200
        draw_bbox_config = {
            'text_scale': 0.8 * box_overlay_ratio,
            'text_thickness': max(int(2 * box_overlay_ratio), 1),
            'text_padding': max(int(3 * box_overlay_ratio), 1),
            'thickness': max(int(3 * box_overlay_ratio), 1),
        }

        # OCR pass: detect text boxes so they can be merged with icon detections.
        ocr_bbox_rslt, _ = check_ocr_box(
            image,
            display_img=False,
            output_bb_format='xyxy',
            goal_filtering=None,
            easyocr_args={'paragraph': False, 'text_threshold': 0.9},
            use_paddleocr=use_paddleocr
        )
        text, ocr_bbox = ocr_bbox_rslt

        # Detection + captioning pass; coordinates are returned as 0-1 ratios.
        dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
            image,
            yolo_model,
            BOX_TRESHOLD=box_threshold,
            output_coord_in_ratio=True,
            ocr_bbox=ocr_bbox,
            draw_bbox_config=draw_bbox_config,
            caption_model_processor=caption_model_processor,
            ocr_text=text,
            iou_threshold=iou_threshold,
            imgsz=imgsz,
            use_local_semantics=True,
            scale_img=False,
            batch_size=128
        )

        # Pair each caption with its bounding box.
        # NOTE(review): this relies on label_coordinates (a dict) iterating in
        # the same order as parsed_content_list — verify against the
        # get_som_labeled_img implementation after upgrades.
        elements_detailed = []
        for i, (content, coords) in enumerate(zip(parsed_content_list, label_coordinates.values())):
            # coords are in ratio format (0-1) from get_som_labeled_img;
            # the len() guards tolerate truncated coordinate tuples.
            element = {
                "id": i,
                "label": f"icon_{i}",
                "content": str(content),
                "coordinates": {
                    "format": "normalized_bbox",  # Values are between 0 and 1
                    "x_min": float(coords[0]) if len(coords) > 0 else 0,
                    "y_min": float(coords[1]) if len(coords) > 1 else 0,
                    "x_max": float(coords[2]) if len(coords) > 2 else 0,
                    "y_max": float(coords[3]) if len(coords) > 3 else 0,
                    "width": float(coords[2] - coords[0]) if len(coords) > 2 else 0,
                    "height": float(coords[3] - coords[1]) if len(coords) > 3 else 0,
                    "center_x": float((coords[0] + coords[2]) / 2) if len(coords) > 2 else 0,
                    "center_y": float((coords[1] + coords[3]) / 2) if len(coords) > 3 else 0,
                    # Absolute pixel coordinates recovered from the ratios and
                    # the original image size.
                    "pixel_coordinates": {
                        "x_min": int(coords[0] * original_size[0]) if len(coords) > 0 else 0,
                        "y_min": int(coords[1] * original_size[1]) if len(coords) > 1 else 0,
                        "x_max": int(coords[2] * original_size[0]) if len(coords) > 2 else 0,
                        "y_max": int(coords[3] * original_size[1]) if len(coords) > 3 else 0,
                    }
                }
            }
            elements_detailed.append(element)

        return {
            "status": "success",
            "image_size": {
                "width": original_size[0],
                "height": original_size[1]
            },
            "elements": elements_detailed,
            "count": len(elements_detailed),
            "image_base64": dino_labled_img,
            "parameters": {
                "box_threshold": box_threshold,
                "iou_threshold": iou_threshold,
                "use_paddleocr": use_paddleocr,
                "imgsz": imgsz
            }
        }
    except ValueError as e:
        # Parameter validation errors -> 400 Bad Request.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
# Route path grounded by info()'s endpoint listing ("/parse/batch - POST").
# NOTE(review): the handler was defined but never registered with the app.
@app.post("/parse/batch")
async def parse_batch(
    files: List[UploadFile] = File(...),
    box_threshold: float = Form(0.05),
    iou_threshold: float = Form(0.1),
    use_paddleocr: bool = Form(True),
    imgsz: int = Form(640)
):
    """
    Parse multiple images in batch (sequentially, sharing one model load).

    Returns:
        - results: List of parse results for each image
        - total_processed: Total number of images processed
        - errors: Any errors encountered (None if every image succeeded)

    Raises:
        HTTPException 400 on invalid parameters, 500 on batch-level errors;
        per-image failures are reported in "errors" rather than aborting.
    """
    results = []
    errors = []
    try:
        # Validate parameters once for the whole batch (ValueError -> 400).
        if not 0.01 <= box_threshold <= 1.0:
            raise ValueError("box_threshold must be between 0.01 and 1.0")
        if not 0.01 <= iou_threshold <= 1.0:
            raise ValueError("iou_threshold must be between 0.01 and 1.0")
        if not (640 <= imgsz <= 1920 and imgsz % 32 == 0):
            raise ValueError("imgsz must be between 640-1920 and divisible by 32")

        # Load models once for the whole batch.
        yolo_model, caption_model_processor = load_models()

        for idx, file in enumerate(files):
            try:
                contents = await file.read()
                image = Image.open(io.BytesIO(contents))
                # Convert RGBA to RGB if necessary (models expect 3 channels).
                if image.mode == 'RGBA':
                    image = image.convert('RGB')

                # Per-image annotation scaling (see parse_image).
                box_overlay_ratio = image.size[0] / 3200
                draw_bbox_config = {
                    'text_scale': 0.8 * box_overlay_ratio,
                    'text_thickness': max(int(2 * box_overlay_ratio), 1),
                    'text_padding': max(int(3 * box_overlay_ratio), 1),
                    'thickness': max(int(3 * box_overlay_ratio), 1),
                }

                # OCR pass.
                ocr_bbox_rslt, _ = check_ocr_box(
                    image,
                    display_img=False,
                    output_bb_format='xyxy',
                    goal_filtering=None,
                    easyocr_args={'paragraph': False, 'text_threshold': 0.9},
                    use_paddleocr=use_paddleocr
                )
                text, ocr_bbox = ocr_bbox_rslt

                # Detection + captioning pass.
                dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(
                    image,
                    yolo_model,
                    BOX_TRESHOLD=box_threshold,
                    output_coord_in_ratio=True,
                    ocr_bbox=ocr_bbox,
                    draw_bbox_config=draw_bbox_config,
                    caption_model_processor=caption_model_processor,
                    ocr_text=text,
                    iou_threshold=iou_threshold,
                    imgsz=imgsz,
                    use_local_semantics=True,
                    scale_img=False,
                    batch_size=128
                )

                elements = [f"icon {i}: {str(v)}" for i, v in enumerate(parsed_content_list)]
                results.append({
                    "filename": file.filename,
                    "status": "success",
                    "elements": elements,
                    "count": len(elements)
                })
            except Exception as e:
                # A bad image should not abort the whole batch.
                errors.append({
                    "filename": file.filename,
                    "error": str(e)
                })

        return {
            "status": "completed",
            "total_processed": len(results),
            "total_errors": len(errors),
            "results": results,
            "errors": errors if errors else None
        }
    except ValueError as e:
        # Parameter validation errors -> 400 Bad Request.
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        import traceback
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=f"Batch processing error: {str(e)}")
| if __name__ == "__main__": | |
| import uvicorn | |
| uvicorn.run(app, host="0.0.0.0", port=8000) | |