""" API Endpoints - Thin HTTP Layer This module provides FastAPI endpoints with NO business logic. All detection logic is delegated to the detection module. Architecture: - Validates HTTP requests - Delegates to detection.service for business logic - Returns standardized responses via detection.response_builder """ import os os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' from fastapi import FastAPI, File, UploadFile, Form, HTTPException from fastapi.middleware.cors import CORSMiddleware from PIL import Image import io import torch from typing import Optional # Import detection services from detection.service_factory import get_detection_service from detection import ocr_handler, response_builder # Create FastAPI app app = FastAPI( title="CU-1 UI Element Detector API", description="Detect and classify UI elements in screenshots using RF-DETR + CLIP + OCR + BLIP", version="1.0.0" ) # Enable CORS app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/") async def root(): """API root endpoint with documentation""" return { "name": "CU-1 UI Element Detector API", "version": "1.0.0", "architecture": "RF-DETR (Detection) + CLIP (Classification) + OCR + BLIP", "endpoints": { "/detect": "POST - Detect UI elements in an image", "/health": "GET - Health check", "/warmup": "POST - Preload models to avoid timeout on first request", "/docs": "GET - Interactive API documentation" }, "example": { "curl": """curl -X POST "http://localhost:8000/detect" \\ -F "image=@screenshot.png" \\ -F "confidence_threshold=0.35" \\ -F "enable_clip=true" \\ -F "enable_ocr=true" """ } } @app.get("/health") async def health_check(): """Health check endpoint""" return { "status": "healthy", "cuda_available": torch.cuda.is_available(), "device": "cuda" if torch.cuda.is_available() else "cpu" } @app.post("/warmup") async def warmup_models(): """ Warmup endpoint to preload models before first detection request. This helps avoid timeout on the first run. 
""" try: service = get_detection_service() # Force loading of all models by accessing them # RF-DETR is already loaded in __init__ service._load_ocr() # Load OCR if enabled service._load_clip() # Load CLIP if enabled service._load_blip() # Load BLIP if enabled return { "status": "success", "message": "Models warmed up successfully", "models_loaded": { "rfdetr": service.model is not None, "ocr": service.ocr_reader is not None if service.enable_ocr else None, "clip": service.clip_processor is not None if service.enable_clip else None, "blip": service.blip_model is not None if service.enable_blip else None } } except Exception as e: import traceback error_msg = f"Error during warmup: {str(e)}" print(f"{error_msg}\n{traceback.format_exc()}") return { "status": "error", "message": error_msg } @app.post("/detect") async def detect_ui_elements( image: UploadFile = File(..., description="Image file to process"), confidence_threshold: float = Form(0.35, description="Detection confidence threshold (0.1-0.9)"), line_thickness: int = Form(2, description="Bounding box thickness for annotated image (1-6)"), enable_clip: bool = Form(False, description="Enable CLIP classification"), enable_ocr: bool = Form(True, description="Enable OCR text extraction"), enable_blip: bool = Form(False, description="Enable BLIP visual description for icons"), blip_scope: str = Form("icons", description="BLIP scope: icons | all"), ocr_only: bool = Form(False, description="Run OCR across the full image and return OCR results only"), preprocess: bool = Form(False, description="Enable image preprocessing for cross-device consistency (Samsung, Pixel, Oppo, etc.)"), preprocess_mode: str = Form("rfdetr", description="Preprocessing mode: rfdetr (optimized for RF-DETR) | generic (for CLIP/OCR)"), preprocess_preset: str = Form("standard", description="Preprocessing preset (depends on mode)") ): """ Detect UI elements in an uploaded image **Parameters:** - `image`: Image file (PNG, JPG, JPEG, WebP) - `confidence_threshold`: Detection sensitivity (0.1-0.9, default: 0.35) - `line_thickness`: Bounding box line thickness (1-6, default: 2) - `enable_clip`: Classify element types using CLIP (default: false) - `enable_ocr`: Extract text content using OCR (default: true) - `enable_blip`: Generate visual descriptions using BLIP (default: false) - `blip_scope`: BLIP scope - "icons" (image/button only) or "all" (default: icons) - `ocr_only`: Skip detection/classification, run OCR only (default: false) - `preprocess`: Enable image preprocessing for cross-device consistency (default: false) - `preprocess_mode`: Preprocessing mode - "rfdetr" (optimized for RF-DETR, preserves ImageNet norm) | "generic" (for CLIP/OCR) (default: rfdetr) - `preprocess_preset`: Preprocessing preset (depends on mode, default: standard) **Returns:** ```json { "success": true, "detections": [ { "box": {"x1": 50, "y1": 100, "x2": 200, "y2": 150}, "confidence": 0.79, "class_name": "button", "text": "Submit" } ], "total_detections": 1, "image_size": {"width": 1080, "height": 1920}, "parameters": {...}, "type_distribution": {"button": 5, "text": 12} } ``` """ try: # Validate confidence threshold if not 0.1 <= confidence_threshold <= 0.9: raise HTTPException( status_code=400, detail="confidence_threshold must be between 0.1 and 0.9" ) if not 1 <= line_thickness <= 6: raise HTTPException( status_code=400, detail="line_thickness must be between 1 and 6" ) # Read and validate image try: image_bytes = await image.read() pil_image = 
@app.post("/detect")
async def detect_ui_elements(
    image: UploadFile = File(..., description="Image file to process"),
    confidence_threshold: float = Form(0.35, description="Detection confidence threshold (0.1-0.9)"),
    line_thickness: int = Form(2, description="Bounding box thickness for annotated image (1-6)"),
    enable_clip: bool = Form(False, description="Enable CLIP classification"),
    enable_ocr: bool = Form(True, description="Enable OCR text extraction"),
    enable_blip: bool = Form(False, description="Enable BLIP visual description for icons"),
    blip_scope: str = Form("icons", description="BLIP scope: icons | all"),
    ocr_only: bool = Form(False, description="Run OCR across the full image and return OCR results only"),
    preprocess: bool = Form(False, description="Enable image preprocessing for cross-device consistency (Samsung, Pixel, Oppo, etc.)"),
    preprocess_mode: str = Form("rfdetr", description="Preprocessing mode: rfdetr (optimized for RF-DETR) | generic (for CLIP/OCR)"),
    preprocess_preset: str = Form("standard", description="Preprocessing preset (depends on mode)")
):
    """
    Detect UI elements in an uploaded image

    **Parameters:**
    - `image`: Image file (PNG, JPG, JPEG, WebP)
    - `confidence_threshold`: Detection sensitivity (0.1-0.9, default: 0.35)
    - `line_thickness`: Bounding box line thickness (1-6, default: 2)
    - `enable_clip`: Classify element types using CLIP (default: false)
    - `enable_ocr`: Extract text content using OCR (default: true)
    - `enable_blip`: Generate visual descriptions using BLIP (default: false)
    - `blip_scope`: BLIP scope - "icons" (image/button only) or "all" (default: icons)
    - `ocr_only`: Skip detection/classification, run OCR only (default: false)
    - `preprocess`: Enable image preprocessing for cross-device consistency (default: false)
    - `preprocess_mode`: Preprocessing mode - "rfdetr" (optimized for RF-DETR, preserves ImageNet norm) | "generic" (for CLIP/OCR) (default: rfdetr)
    - `preprocess_preset`: Preprocessing preset (depends on mode, default: standard)

    **Returns:**
    ```json
    {
        "success": true,
        "detections": [
            {
                "box": {"x1": 50, "y1": 100, "x2": 200, "y2": 150},
                "confidence": 0.79,
                "class_name": "button",
                "text": "Submit"
            }
        ],
        "total_detections": 1,
        "image_size": {"width": 1080, "height": 1920},
        "parameters": {...},
        "type_distribution": {"button": 5, "text": 12}
    }
    ```
    """
    try:
        # Validate confidence threshold
        if not 0.1 <= confidence_threshold <= 0.9:
            raise HTTPException(
                status_code=400,
                detail="confidence_threshold must be between 0.1 and 0.9"
            )

        if not 1 <= line_thickness <= 6:
            raise HTTPException(
                status_code=400,
                detail="line_thickness must be between 1 and 6"
            )

        # Read and validate image
        try:
            image_bytes = await image.read()
            pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid image file: {str(e)}"
            )

        # Validate OCR-only mode: CLIP and BLIP are incompatible with OCR-only
        if ocr_only and (enable_clip or enable_blip):
            raise HTTPException(
                status_code=400,
                detail="When ocr_only=true, enable_clip and enable_blip must be false"
            )

        # OCR-only path: bypass the detection service entirely
        if ocr_only:
            detections = ocr_handler.process_ocr_only(pil_image)
            annotated = ocr_handler.annotate_ocr_detections(
                pil_image,
                detections,
                thickness=line_thickness,
                return_format="numpy"
            )

            # Build analysis structure for simplified response
            analysis = {
                "detections": detections,
                "image_size": {"width": pil_image.width, "height": pil_image.height}
            }

            return response_builder.build_simplified_response(
                analysis=analysis,
                image=pil_image,
                annotated_image=annotated,
                confidence_threshold=confidence_threshold,
                line_thickness=line_thickness,
                enable_clip=False,
                enable_ocr=True,
                enable_blip=False,
                blip_scope=None,
                ocr_only=True
            )

        # Standard detection path: use the detection service
        start_time = time.time()
        print(f"[API] Starting detection - Image size: {pil_image.size}, CLIP: {enable_clip}, OCR: {enable_ocr}, BLIP: {enable_blip}")

        service = get_detection_service()

        # Run analysis (pass parameters directly to avoid race conditions)
        print("[API] Calling service.analyze()...")
        analysis_start = time.time()
        analysis = service.analyze(
            pil_image,
            confidence_threshold=confidence_threshold,
            extract_text=enable_ocr,
            use_clip=enable_clip,
            use_blip=enable_blip,
            merge_global_ocr=True,
            blip_scope=(blip_scope if blip_scope in {"icons", "all"} else "icons"),
            preprocess=preprocess,
            preprocess_mode=preprocess_mode,
            preprocess_preset=preprocess_preset
        )
        analysis_time = time.time() - analysis_start
        print(f"[API] service.analyze() completed in {analysis_time:.2f}s - Found {len(analysis.get('detections', []))} detections")

        # Generate annotated image
        print("[API] Generating annotated image...")
        annotated_start = time.time()
        annotated = service.get_prediction_image(
            pil_image,
            confidence_threshold=confidence_threshold,
            extract_content=True,
            thickness=line_thickness,
            return_format="numpy",
            analysis=analysis
        )
        annotated_time = time.time() - annotated_start
        print(f"[API] Annotated image generated in {annotated_time:.2f}s")

        total_time = time.time() - start_time
        print(f"[API] Total detection time: {total_time:.2f}s")

        # Build response using simplified format
        return response_builder.build_simplified_response(
            analysis=analysis,
            image=pil_image,
            annotated_image=annotated,
            confidence_threshold=confidence_threshold,
            line_thickness=line_thickness,
            enable_clip=enable_clip,
            enable_ocr=enable_ocr,
            enable_blip=enable_blip,
            blip_scope=blip_scope,
            ocr_only=False
        )

    except HTTPException:
        raise
    except Exception as e:
        error_msg = f"Error during detection: {str(e)}"
        print(f"{error_msg}\n{traceback.format_exc()}")
        raise HTTPException(status_code=500, detail=error_msg)
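
# Local entrypoint (a minimal sketch; assumes `uvicorn` is installed and that
# this file is served as `app` -- adjust host/port to your deployment):
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)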