"""
API Endpoints - Thin HTTP Layer

This module provides FastAPI endpoints with NO business logic.
All detection logic is delegated to the detection module.

Architecture:
- Validates HTTP requests
- Delegates to detection.service for business logic
- Returns standardized responses via detection.response_builder
"""

import os

# Set before torch is imported (below) so PyTorch sees it at import time.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

import io
import logging
from typing import Optional

import torch
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image

# Import detection services
from detection import ocr_handler, response_builder
from detection.service_factory import get_detection_service

logger = logging.getLogger(__name__)

# Accepted values for the ``blip_scope`` form field. Anything else is replaced
# by the default, and the same normalized value is used for the analysis and
# echoed in the response, so the two always agree.
_BLIP_SCOPES = frozenset({"icons", "all"})
_DEFAULT_BLIP_SCOPE = "icons"

# Create FastAPI app
app = FastAPI(
    title="CU-1 UI Element Detector API",
    description="Detect and classify UI elements in screenshots using RF-DETR + CLIP + OCR + BLIP",
    version="1.0.0"
)

# Enable CORS
# NOTE(review): a wildcard origin combined with allow_credentials=True lets any
# site make credentialed requests; FastAPI's CORS docs recommend listing
# explicit origins when credentials are allowed. Confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
async def root():
    """API root endpoint with documentation"""
    return {
        "name": "CU-1 UI Element Detector API",
        "version": "1.0.0",
        "architecture": "RF-DETR (Detection) + CLIP (Classification) + OCR + BLIP",
        "endpoints": {
            "/detect": "POST - Detect UI elements in an image",
            "/health": "GET - Health check",
            "/docs": "GET - Interactive API documentation"
        },
        "example": {
            # Each trailing ``\\`` must end its line so the shell treats it as
            # a line continuation and the command can be pasted as-is.
            "curl": """curl -X POST "http://localhost:8000/detect" \\
  -F "image=@screenshot.png" \\
  -F "confidence_threshold=0.35" \\
  -F "enable_clip=true" \\
  -F "enable_ocr=true"
"""
        }
    }


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    return {
        "status": "healthy",
        "cuda_available": torch.cuda.is_available(),
        "device": "cuda" if torch.cuda.is_available() else "cpu"
    }


@app.post("/detect")
async def detect_ui_elements(
    image: UploadFile = File(..., description="Image file to process"),
    confidence_threshold: float = Form(0.35, description="Detection confidence threshold (0.1-0.9)"),
    line_thickness: int = Form(2, description="Bounding box thickness for annotated image (1-6)"),
    enable_clip: bool = Form(False, description="Enable CLIP classification"),
    enable_ocr: bool = Form(True, description="Enable OCR text extraction"),
    enable_blip: bool = Form(False, description="Enable BLIP visual description for icons"),
    blip_scope: str = Form("icons", description="BLIP scope: icons | all"),
    ocr_only: bool = Form(False, description="Run OCR across the full image and return OCR results only"),
    preprocess: bool = Form(False, description="Enable image preprocessing for cross-device consistency (Samsung, Pixel, Oppo, etc.)"),
    preprocess_mode: str = Form("rfdetr", description="Preprocessing mode: rfdetr (optimized for RF-DETR) | generic (for CLIP/OCR)"),
    preprocess_preset: str = Form("standard", description="Preprocessing preset (depends on mode)")
):
    """
    Detect UI elements in an uploaded image

    **Parameters:**
    - `image`: Image file (PNG, JPG, JPEG, WebP)
    - `confidence_threshold`: Detection sensitivity (0.1-0.9, default: 0.35)
    - `line_thickness`: Bounding box line thickness (1-6, default: 2)
    - `enable_clip`: Classify element types using CLIP (default: false)
    - `enable_ocr`: Extract text content using OCR (default: true)
    - `enable_blip`: Generate visual descriptions using BLIP (default: false)
    - `blip_scope`: BLIP scope - "icons" (image/button only) or "all" (default: icons).
      Any other value is treated as "icons".
    - `ocr_only`: Skip detection/classification, run OCR only (default: false).
      Cannot be combined with `enable_clip` or `enable_blip`.
    - `preprocess`: Enable image preprocessing for cross-device consistency (default: false)
    - `preprocess_mode`: Preprocessing mode - "rfdetr" (optimized for RF-DETR, preserves ImageNet norm) | "generic" (for CLIP/OCR) (default: rfdetr)
    - `preprocess_preset`: Preprocessing preset (depends on mode, default: standard)

    **Errors:**
    - 400: out-of-range parameters, `ocr_only` combined with CLIP/BLIP, or an
      unreadable image file
    - 500: any unexpected failure during detection

    **Returns:**
    ```json
    {
        "success": true,
        "detections": [
            {
                "box": {"x1": 50, "y1": 100, "x2": 200, "y2": 150},
                "confidence": 0.79,
                "class_name": "button",
                "text": "Submit"
            }
        ],
        "total_detections": 1,
        "image_size": {"width": 1080, "height": 1920},
        "parameters": {...},
        "type_distribution": {"button": 5, "text": 12}
    }
    ```
    """
    # NOTE(review): this handler is `async`, but the PIL decode and model
    # inference below are synchronous calls. They therefore run on the event
    # loop and block other requests (including /health) while they execute.
    # Offloading them via `fastapi.concurrency.run_in_threadpool` would avoid
    # that, but only once the shared detection service is confirmed to be
    # thread-safe.
    try:
        # --- Parameter validation (cheap checks before touching the upload) ---
        if not 0.1 <= confidence_threshold <= 0.9:
            raise HTTPException(
                status_code=400,
                detail="confidence_threshold must be between 0.1 and 0.9"
            )

        if not 1 <= line_thickness <= 6:
            raise HTTPException(
                status_code=400,
                detail="line_thickness must be between 1 and 6"
            )

        # Validate OCR-only mode: CLIP and BLIP are incompatible with OCR-only
        if ocr_only and (enable_clip or enable_blip):
            raise HTTPException(
                status_code=400,
                detail="When ocr_only=true, enable_clip and enable_blip must be false"
            )

        # Normalize once so the analysis and the echoed parameters agree.
        effective_blip_scope = (
            blip_scope if blip_scope in _BLIP_SCOPES else _DEFAULT_BLIP_SCOPE
        )

        # --- Read and validate image ---
        try:
            image_bytes = await image.read()
            # .convert() forces a full decode, so truncated/corrupt files fail here.
            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid image file: {e}"
            ) from e

        # --- OCR-only path: bypass detection service ---
        if ocr_only:
            detections = ocr_handler.process_ocr_only(pil_image)
            annotated = ocr_handler.annotate_ocr_detections(
                pil_image,
                detections,
                thickness=line_thickness,
                return_format="numpy"
            )
            return response_builder.build_ocr_only_response(
                detections=detections,
                image_width=pil_image.width,
                image_height=pil_image.height,
                annotated_image=annotated,
                confidence_threshold=confidence_threshold,
                line_thickness=line_thickness
            )

        # --- Standard detection path: use detection service ---
        service = get_detection_service()

        # Run analysis (pass parameters directly to avoid race conditions)
        analysis = service.analyze(
            pil_image,
            confidence_threshold=confidence_threshold,
            extract_text=enable_ocr,
            use_clip=enable_clip,
            use_blip=enable_blip,
            merge_global_ocr=True,
            blip_scope=effective_blip_scope,
            preprocess=preprocess,
            preprocess_mode=preprocess_mode,
            preprocess_preset=preprocess_preset
        )

        # Generate annotated image, reusing the analysis computed above
        annotated = service.get_prediction_image(
            pil_image,
            confidence_threshold=confidence_threshold,
            extract_content=True,
            thickness=line_thickness,
            return_format="numpy",
            analysis=analysis
        )

        # Build response
        return response_builder.build_detection_response(
            analysis=analysis,
            image=pil_image,
            annotated_image=annotated,
            confidence_threshold=confidence_threshold,
            line_thickness=line_thickness,
            enable_clip=enable_clip,
            enable_ocr=enable_ocr,
            enable_blip=enable_blip,
            blip_scope=effective_blip_scope,
            ocr_only=False,
            include_annotated_image=True
        )

    except HTTPException:
        # Deliberate client/validation errors pass through unchanged.
        raise
    except Exception as e:
        error_msg = f"Error during detection: {e}"
        # logger.exception records the full traceback alongside the message.
        logger.exception("%s", error_msg)
        raise HTTPException(status_code=500, detail=error_msg) from e