CU1-X / api /endpoints.py
AI-DrivenTesting's picture
init
77da9e2
raw
history blame
8.11 kB
"""
API Endpoints - Thin HTTP Layer
This module provides FastAPI endpoints with NO business logic.
All detection logic is delegated to the detection module.
Architecture:
- Validates HTTP requests
- Delegates to detection.service for business logic
- Returns standardized responses via detection.response_builder
"""
import os
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
import io
import torch
from typing import Optional
# Import detection services
from detection.service_factory import get_detection_service
from detection import ocr_handler, response_builder
# Create FastAPI app
app = FastAPI(
title="CU-1 UI Element Detector API",
description="Detect and classify UI elements in screenshots using RF-DETR + CLIP + OCR + BLIP",
version="1.0.0"
)
# Enable CORS
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/")
async def root():
"""API root endpoint with documentation"""
return {
"name": "CU-1 UI Element Detector API",
"version": "1.0.0",
"architecture": "RF-DETR (Detection) + CLIP (Classification) + OCR + BLIP",
"endpoints": {
"/detect": "POST - Detect UI elements in an image",
"/health": "GET - Health check",
"/docs": "GET - Interactive API documentation"
},
"example": {
"curl": """curl -X POST "http://localhost:8000/detect" \\
-F "image=@screenshot.png" \\
-F "confidence_threshold=0.35" \\
-F "enable_clip=true" \\
-F "enable_ocr=true" """
}
}
@app.get("/health")
async def health_check():
"""Health check endpoint"""
return {
"status": "healthy",
"cuda_available": torch.cuda.is_available(),
"device": "cuda" if torch.cuda.is_available() else "cpu"
}
@app.post("/detect")
async def detect_ui_elements(
image: UploadFile = File(..., description="Image file to process"),
confidence_threshold: float = Form(0.35, description="Detection confidence threshold (0.1-0.9)"),
line_thickness: int = Form(2, description="Bounding box thickness for annotated image (1-6)"),
enable_clip: bool = Form(False, description="Enable CLIP classification"),
enable_ocr: bool = Form(True, description="Enable OCR text extraction"),
enable_blip: bool = Form(False, description="Enable BLIP visual description for icons"),
blip_scope: str = Form("icons", description="BLIP scope: icons | all"),
ocr_only: bool = Form(False, description="Run OCR across the full image and return OCR results only"),
preprocess: bool = Form(False, description="Enable image preprocessing for cross-device consistency (Samsung, Pixel, Oppo, etc.)"),
preprocess_mode: str = Form("rfdetr", description="Preprocessing mode: rfdetr (optimized for RF-DETR) | generic (for CLIP/OCR)"),
preprocess_preset: str = Form("standard", description="Preprocessing preset (depends on mode)")
):
"""
Detect UI elements in an uploaded image
**Parameters:**
- `image`: Image file (PNG, JPG, JPEG, WebP)
- `confidence_threshold`: Detection sensitivity (0.1-0.9, default: 0.35)
- `line_thickness`: Bounding box line thickness (1-6, default: 2)
- `enable_clip`: Classify element types using CLIP (default: false)
- `enable_ocr`: Extract text content using OCR (default: true)
- `enable_blip`: Generate visual descriptions using BLIP (default: false)
- `blip_scope`: BLIP scope - "icons" (image/button only) or "all" (default: icons)
- `ocr_only`: Skip detection/classification, run OCR only (default: false)
- `preprocess`: Enable image preprocessing for cross-device consistency (default: false)
- `preprocess_mode`: Preprocessing mode - "rfdetr" (optimized for RF-DETR, preserves ImageNet norm) | "generic" (for CLIP/OCR) (default: rfdetr)
- `preprocess_preset`: Preprocessing preset (depends on mode, default: standard)
**Returns:**
```json
{
"success": true,
"detections": [
{
"box": {"x1": 50, "y1": 100, "x2": 200, "y2": 150},
"confidence": 0.79,
"class_name": "button",
"text": "Submit"
}
],
"total_detections": 1,
"image_size": {"width": 1080, "height": 1920},
"parameters": {...},
"type_distribution": {"button": 5, "text": 12}
}
```
"""
try:
# Validate confidence threshold
if not 0.1 <= confidence_threshold <= 0.9:
raise HTTPException(
status_code=400,
detail="confidence_threshold must be between 0.1 and 0.9"
)
if not 1 <= line_thickness <= 6:
raise HTTPException(
status_code=400,
detail="line_thickness must be between 1 and 6"
)
# Read and validate image
try:
image_bytes = await image.read()
pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
except Exception as e:
raise HTTPException(
status_code=400,
detail=f"Invalid image file: {str(e)}"
)
# Validate OCR-only mode: CLIP and BLIP are incompatible with OCR-only
if ocr_only and (enable_clip or enable_blip):
raise HTTPException(
status_code=400,
detail="When ocr_only=true, enable_clip and enable_blip must be false"
)
# OCR-only path: Bypass detection service
if ocr_only:
detections = ocr_handler.process_ocr_only(pil_image)
annotated = ocr_handler.annotate_ocr_detections(
pil_image,
detections,
thickness=line_thickness,
return_format="numpy"
)
return response_builder.build_ocr_only_response(
detections=detections,
image_width=pil_image.width,
image_height=pil_image.height,
annotated_image=annotated,
confidence_threshold=confidence_threshold,
line_thickness=line_thickness
)
# Standard detection path: Use detection service
service = get_detection_service()
# Run analysis (pass parameters directly to avoid race conditions)
analysis = service.analyze(
pil_image,
confidence_threshold=confidence_threshold,
extract_text=enable_ocr,
use_clip=enable_clip,
use_blip=enable_blip,
merge_global_ocr=True,
blip_scope=(blip_scope if blip_scope in {"icons", "all"} else "icons"),
preprocess=preprocess,
preprocess_mode=preprocess_mode,
preprocess_preset=preprocess_preset
)
# Generate annotated image
annotated = service.get_prediction_image(
pil_image,
confidence_threshold=confidence_threshold,
extract_content=True,
thickness=line_thickness,
return_format="numpy",
analysis=analysis
)
# Build response
return response_builder.build_detection_response(
analysis=analysis,
image=pil_image,
annotated_image=annotated,
confidence_threshold=confidence_threshold,
line_thickness=line_thickness,
enable_clip=enable_clip,
enable_ocr=enable_ocr,
enable_blip=enable_blip,
blip_scope=blip_scope,
ocr_only=False,
include_annotated_image=True
)
except HTTPException:
raise
except Exception as e:
import traceback
error_msg = f"Error during detection: {str(e)}"
print(f"{error_msg}\n{traceback.format_exc()}")
raise HTTPException(status_code=500, detail=error_msg)