import json
from typing import Any, Dict, Optional

import cv2
import numpy as np
from fastapi import (
    APIRouter,
    File,
    Form,
    HTTPException,
    UploadFile,
    WebSocket,
    WebSocketDisconnect,
)

from ...models.schemas import DetectionParams, DetectionRequest, DetectionResponse
from ...services.runtime_adapter import DetectionResult, run_detection
from ...utils.image_io import decode_base64_image, encode_png_base64

router = APIRouter(prefix="/v1/detect", tags=["detection"])

# Maps the short detector key used in URLs / stream payloads to the label
# that is handed to ``run_detection``.
DETECTOR_KEYS: Dict[str, str] = {
    "edges": "Edges (Canny)",
    "corners": "Corners (Harris)",
    "lines": "Lines (Hough/LSD)",
    "ellipses": "Ellipses (Contours + fitEllipse)",
}

ALLOWED_MODES = {"classical", "dl", "both"}

# NOTE(review): ``run_detection`` is called synchronously from ``async``
# handlers below. If it is CPU-bound it will block the event loop; consider
# ``fastapi.concurrency.run_in_threadpool`` once its thread-safety is confirmed.


def _detector_label(key: str) -> str:
    """Return the runtime label for a detector key.

    Raises:
        HTTPException: 404 when ``key`` is not a known detector.
    """
    try:
        return DETECTOR_KEYS[key]
    except KeyError:
        raise HTTPException(
            status_code=404, detail=f"Unknown detector '{key}'."
        ) from None


def _resolve_mode(mode: str, compare: bool) -> str:
    """Return the effective runtime mode.

    ``compare`` forces ``"both"``. Any unknown mode falls back to
    ``"classical"``; non-string values (which can arrive from untyped
    WebSocket payloads and may be unhashable) are treated the same way
    instead of raising ``TypeError`` on the set lookup.
    """
    if compare:
        return "both"
    if not isinstance(mode, str) or mode not in ALLOWED_MODES:
        return "classical"
    return mode


def _choose_primary(mode: str, overlays: Dict[str, str]) -> Optional[str]:
    """Pick which overlay key is reported as the primary one.

    The overlay matching ``mode`` is preferred; for ``"both"`` the classical
    overlay wins over the DL one. If no preference applies, the first
    available key is returned, or ``None`` when ``overlays`` is empty.
    """
    if mode == "both":
        preferred = ("classical", "dl")
    elif mode in ("dl", "classical"):
        preferred = (mode,)
    else:
        preferred = ()
    for candidate in preferred:
        if candidate in overlays:
            return candidate
    return next(iter(overlays), None)


def _normalize_dl_choice(dl_model: Optional[str]) -> Optional[str]:
    """Strip surrounding whitespace from a DL model name; falsy becomes ``None``."""
    return dl_model.strip() if dl_model else None


def _format_result(result: DetectionResult, mode: str) -> DetectionResponse:
    """Convert a ``DetectionResult`` into the API response model.

    Every non-``None`` overlay is PNG/base64 encoded; ``None`` overlays are
    passed through unchanged. The primary overlay is chosen only among the
    overlays that actually encoded to a non-empty value.
    """
    overlays_encoded: Dict[str, Optional[str]] = {
        path: None if image is None else encode_png_base64(image)
        for path, image in result.overlays.items()
    }
    available = {key: value for key, value in overlays_encoded.items() if value}
    primary = _choose_primary(mode, available)
    # Fall back to the classical model info when the primary path has none.
    model_info = result.models.get(
        primary or "classical", result.models.get("classical", {})
    )
    return DetectionResponse(
        overlay=overlays_encoded.get(primary) if primary else None,
        overlays=overlays_encoded,
        features=result.features,
        timings=result.timings_ms,
        fps_estimate=result.fps_estimate,
        model=model_info,
        models=result.models,
    )


def _json_params(params: Optional[DetectionParams]) -> Optional[Dict[str, Any]]:
    """Return ``params`` as a dict without ``None`` values, or ``None``."""
    if params is None:
        return None
    return params.dict(exclude_none=True)


def _run_json_detection(
    detector_key: str, payload: DetectionRequest
) -> DetectionResponse:
    """Shared implementation of the JSON (base64 image) endpoints.

    Raises:
        HTTPException: 400 when the base64 image cannot be decoded.
    """
    try:
        image = decode_base64_image(payload.image)
    except Exception as exc:
        raise HTTPException(
            status_code=400, detail=f"Invalid image payload: {exc}"
        ) from exc
    runtime_mode = _resolve_mode(payload.mode, payload.compare)
    result = run_detection(
        image,
        _detector_label(detector_key),
        params=_json_params(payload.params),
        mode=runtime_mode,
        dl_choice=_normalize_dl_choice(payload.dl_model),
    )
    return _format_result(result, runtime_mode)


@router.post("/edges", response_model=DetectionResponse)
async def detect_edges(payload: DetectionRequest):
    """Run edge detection on a base64-encoded image."""
    return _run_json_detection("edges", payload)


@router.post("/corners", response_model=DetectionResponse)
async def detect_corners(payload: DetectionRequest):
    """Run corner detection on a base64-encoded image."""
    return _run_json_detection("corners", payload)


@router.post("/lines", response_model=DetectionResponse)
async def detect_lines(payload: DetectionRequest):
    """Run line detection on a base64-encoded image."""
    return _run_json_detection("lines", payload)


@router.post("/ellipses", response_model=DetectionResponse)
async def detect_ellipses(payload: DetectionRequest):
    """Run ellipse detection on a base64-encoded image."""
    return _run_json_detection("ellipses", payload)


async def _handle_upload(
    detector_key: str,
    file: UploadFile,
    params: Optional[str],
    mode: str,
    compare: bool,
    dl_model: Optional[str],
) -> DetectionResponse:
    """Run a detector on an uploaded image file.

    Args:
        detector_key: Key into ``DETECTOR_KEYS``.
        file: Uploaded image; decoded with OpenCV and converted BGR -> RGB.
        params: Optional JSON string that must decode to an object.
        mode: Requested mode, normalised by ``_resolve_mode``.
        compare: When true, forces the ``"both"`` mode.
        dl_model: Optional DL model name.

    Raises:
        HTTPException: 400 for an empty/undecodable image or invalid params.
    """
    content = await file.read()
    # cv2.imdecode raises on an empty buffer instead of returning None, which
    # would surface as a 500; reject empty uploads explicitly.
    if not content:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    array = np.frombuffer(content, dtype=np.uint8)
    decoded = cv2.imdecode(array, cv2.IMREAD_COLOR)
    if decoded is None:
        raise HTTPException(
            status_code=400, detail="Unable to decode uploaded image."
        )
    image = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)

    params_dict: Optional[Dict[str, Any]] = None
    if params:
        try:
            params_dict = json.loads(params)
        except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
            raise HTTPException(
                status_code=400, detail=f"Invalid params: {exc}"
            ) from exc
        if not isinstance(params_dict, dict):
            raise HTTPException(
                status_code=400,
                detail="Invalid params: params JSON must decode to an object.",
            )

    runtime_mode = _resolve_mode(mode, compare)
    result = run_detection(
        image,
        _detector_label(detector_key),
        params=params_dict,
        mode=runtime_mode,
        dl_choice=_normalize_dl_choice(dl_model),
    )
    return _format_result(result, runtime_mode)


def _upload_endpoint(detector_key: str):
    """Build the multipart upload endpoint bound to ``detector_key``."""

    async def endpoint(
        file: UploadFile = File(...),
        params: Optional[str] = Form(None),
        mode: str = Form("classical"),
        compare: bool = Form(False),
        dl_model: Optional[str] = Form(None),
    ):
        return await _handle_upload(
            detector_key, file, params, mode, compare, dl_model
        )

    return endpoint


# One upload route per detector, registered in DETECTOR_KEYS order
# (edges, corners, lines, ellipses).
for _upload_key in DETECTOR_KEYS:
    router.add_api_route(
        f"/{_upload_key}/upload",
        _upload_endpoint(_upload_key),
        methods=["POST"],
        response_model=DetectionResponse,
    )


def _stream_reply(message: str) -> Dict[str, Any]:
    """Process one text frame from the stream and build the JSON reply.

    Returns either ``{"error": ...}`` for an invalid request / failed
    detection, or the serialised ``DetectionResponse``.
    """
    try:
        payload = json.loads(message)
    except json.JSONDecodeError:
        return {"error": "Invalid JSON payload."}
    # Valid JSON is not necessarily an object; ``.get`` below needs a dict.
    if not isinstance(payload, dict):
        return {"error": "Payload must be a JSON object."}

    detector_key = payload.get("detector")
    # The type check also protects the lookup from unhashable values.
    if not isinstance(detector_key, str) or detector_key not in DETECTOR_KEYS:
        return {"error": "Unknown detector key."}

    image_b64 = payload.get("image")
    if not image_b64:
        return {"error": "Missing 'image' field."}
    try:
        image = decode_base64_image(image_b64)
    except Exception as exc:
        return {"error": f"Invalid image payload: {exc}"}

    params = payload.get("params")
    if params is not None and not isinstance(params, dict):
        return {"error": "'params' must be an object."}

    # "model" is accepted as an alias for "dl_model".
    dl_model = payload.get("dl_model") or payload.get("model")
    if dl_model and not isinstance(dl_model, str):
        return {"error": "'dl_model' must be a string."}

    runtime_mode = _resolve_mode(
        payload.get("mode", "classical"), bool(payload.get("compare", False))
    )
    try:
        result = run_detection(
            image,
            _detector_label(detector_key),
            params=params,
            mode=runtime_mode,
            dl_choice=_normalize_dl_choice(dl_model),
        )
    except Exception as exc:  # pragma: no cover
        return {"error": str(exc)}
    return _format_result(result, runtime_mode).dict()


@router.websocket("/stream")
async def detection_stream(websocket: WebSocket):
    """Serve detections over a WebSocket.

    After accepting, sends ``{"ready": true}``; then answers every received
    text frame with one JSON message until the client disconnects. Invalid
    requests produce an ``{"error": ...}`` reply and keep the socket open.
    """
    await websocket.accept()
    await websocket.send_json({"ready": True})
    try:
        while True:
            message = await websocket.receive_text()
            await websocket.send_json(_stream_reply(message))
    except WebSocketDisconnect:
        return