| """ | |
| API Endpoints - Thin HTTP Layer | |
| This module provides FastAPI endpoints with NO business logic. | |
| All detection logic is delegated to the detection module. | |
| Architecture: | |
| - Validates HTTP requests | |
| - Delegates to detection.service for business logic | |
| - Returns standardized responses via detection.response_builder | |
| """ | |

import os

# Let PyTorch ops without Apple MPS kernels fall back to the CPU
# (set before importing torch)
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import io
import torch
from typing import Optional

# Import detection services
from detection.service_factory import get_detection_service
from detection import ocr_handler, response_builder

# Create FastAPI app
app = FastAPI(
    title="CU-1 UI Element Detector API",
    description="Detect and classify UI elements in screenshots using RF-DETR + CLIP + OCR + BLIP",
    version="1.0.0"
)

# Enable CORS
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
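
# NOTE: wildcard CORS origins keep the demo callable from any host;
# restrict allow_origins to trusted domains before a production deployment.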

@app.get("/")
async def root():
    """API root endpoint with documentation"""
    return {
        "name": "CU-1 UI Element Detector API",
        "version": "1.0.0",
        "architecture": "RF-DETR (Detection) + CLIP (Classification) + OCR + BLIP",
        "endpoints": {
            "/detect": "POST - Detect UI elements in an image",
            "/health": "GET - Health check",
            "/docs": "GET - Interactive API documentation"
        },
        "example": {
            "curl": """curl -X POST "http://localhost:8000/detect" \\
                -F "image=@screenshot.png" \\
                -F "confidence_threshold=0.35" \\
                -F "enable_clip=true" \\
                -F "enable_ocr=true" """
        }
    }

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "cuda_available": torch.cuda.is_available(),
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }

@app.post("/detect")
async def detect_ui_elements(
    image: UploadFile = File(..., description="Image file to process"),
    confidence_threshold: float = Form(0.35, description="Detection confidence threshold (0.1-0.9)"),
    line_thickness: int = Form(2, description="Bounding box thickness for annotated image (1-6)"),
    enable_clip: bool = Form(False, description="Enable CLIP classification"),
    enable_ocr: bool = Form(True, description="Enable OCR text extraction"),
    enable_blip: bool = Form(False, description="Enable BLIP visual description for icons"),
    blip_scope: str = Form("icons", description="BLIP scope: icons | all"),
    ocr_only: bool = Form(False, description="Run OCR across the full image and return OCR results only"),
    preprocess: bool = Form(False, description="Enable image preprocessing for cross-device consistency (Samsung, Pixel, Oppo, etc.)"),
    preprocess_mode: str = Form("rfdetr", description="Preprocessing mode: rfdetr (optimized for RF-DETR) | generic (for CLIP/OCR)"),
    preprocess_preset: str = Form("standard", description="Preprocessing preset (depends on mode)")
):
    """
    Detect UI elements in an uploaded image

    **Parameters:**
    - `image`: Image file (PNG, JPG, JPEG, WebP)
    - `confidence_threshold`: Detection sensitivity (0.1-0.9, default: 0.35)
    - `line_thickness`: Bounding box line thickness (1-6, default: 2)
    - `enable_clip`: Classify element types using CLIP (default: false)
    - `enable_ocr`: Extract text content using OCR (default: true)
    - `enable_blip`: Generate visual descriptions using BLIP (default: false)
    - `blip_scope`: BLIP scope - "icons" (image/button only) or "all" (default: icons)
    - `ocr_only`: Skip detection/classification, run OCR only (default: false)
    - `preprocess`: Enable image preprocessing for cross-device consistency (default: false)
    - `preprocess_mode`: Preprocessing mode - "rfdetr" (optimized for RF-DETR, preserves ImageNet norm) | "generic" (for CLIP/OCR) (default: rfdetr)
    - `preprocess_preset`: Preprocessing preset (depends on mode, default: standard)

    **Returns:**
    ```json
    {
        "success": true,
        "detections": [
            {
                "box": {"x1": 50, "y1": 100, "x2": 200, "y2": 150},
                "confidence": 0.79,
                "class_name": "button",
                "text": "Submit"
            }
        ],
        "total_detections": 1,
        "image_size": {"width": 1080, "height": 1920},
        "parameters": {...},
        "type_distribution": {"button": 5, "text": 12}
    }
    ```
    """
    try:
        # Validate confidence threshold
        if not 0.1 <= confidence_threshold <= 0.9:
            raise HTTPException(
                status_code=400,
                detail="confidence_threshold must be between 0.1 and 0.9"
            )

        # Validate line thickness
        if not 1 <= line_thickness <= 6:
            raise HTTPException(
                status_code=400,
                detail="line_thickness must be between 1 and 6"
            )

        # Read and validate image
        try:
            image_bytes = await image.read()
            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid image file: {str(e)}"
            )

        # Validate OCR-only mode: CLIP and BLIP are incompatible with OCR-only
        if ocr_only and (enable_clip or enable_blip):
            raise HTTPException(
                status_code=400,
                detail="When ocr_only=true, enable_clip and enable_blip must be false"
            )

        # OCR-only path: bypass the detection service entirely
        if ocr_only:
            detections = ocr_handler.process_ocr_only(pil_image)
            annotated = ocr_handler.annotate_ocr_detections(
                pil_image,
                detections,
                thickness=line_thickness,
                return_format="numpy"
            )
            return response_builder.build_ocr_only_response(
                detections=detections,
                image_width=pil_image.width,
                image_height=pil_image.height,
                annotated_image=annotated,
                confidence_threshold=confidence_threshold,
                line_thickness=line_thickness
            )

        # Standard detection path: use the detection service
        service = get_detection_service()

        # Run analysis (pass parameters directly to avoid race conditions)
        analysis = service.analyze(
            pil_image,
            confidence_threshold=confidence_threshold,
            extract_text=enable_ocr,
            use_clip=enable_clip,
            use_blip=enable_blip,
            merge_global_ocr=True,
            blip_scope=(blip_scope if blip_scope in {"icons", "all"} else "icons"),
            preprocess=preprocess,
            preprocess_mode=preprocess_mode,
            preprocess_preset=preprocess_preset
        )

        # Generate annotated image
        annotated = service.get_prediction_image(
            pil_image,
            confidence_threshold=confidence_threshold,
            extract_content=True,
            thickness=line_thickness,
            return_format="numpy",
            analysis=analysis
        )

        # Build response
        return response_builder.build_detection_response(
            analysis=analysis,
            image=pil_image,
            annotated_image=annotated,
            confidence_threshold=confidence_threshold,
            line_thickness=line_thickness,
            enable_clip=enable_clip,
            enable_ocr=enable_ocr,
            enable_blip=enable_blip,
            blip_scope=blip_scope,
            ocr_only=False,
            include_annotated_image=True
        )

    except HTTPException:
        # Re-raise explicit HTTP errors unchanged
        raise
    except Exception as e:
        import traceback
        error_msg = f"Error during detection: {str(e)}"
        print(f"{error_msg}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=error_msg)
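
# Local dev entry point - a minimal sketch, assuming uvicorn is installed
# (the usual server for FastAPI apps). The original Space may launch the app
# differently, e.g. via `uvicorn app:app` on the command line.
if __name__ == "__main__":
    import uvicorn

    # Bind to all interfaces on port 8000 to match the curl example above
    uvicorn.run(app, host="0.0.0.0", port=8000)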