Spaces:
Sleeping
Sleeping
| import json | |
| from typing import Any, Dict, Optional | |
| import cv2 | |
| import numpy as np | |
| from fastapi import APIRouter, File, Form, HTTPException, UploadFile, WebSocket, WebSocketDisconnect | |
| from ...models.schemas import DetectionParams, DetectionRequest, DetectionResponse | |
| from ...services.runtime_adapter import DetectionResult, run_detection | |
| from ...utils.image_io import decode_base64_image, encode_png_base64 | |
# Every detection route in this module is mounted under /v1/detect and grouped
# under the "detection" tag in the generated OpenAPI schema.
router = APIRouter(tags=["detection"], prefix="/v1/detect")
# Public detector keys (used in route paths and websocket messages) mapped to
# the human-readable labels that are passed to run_detection.
DETECTOR_KEYS: Dict[str, str] = dict(
    edges="Edges (Canny)",
    corners="Corners (Harris)",
    lines="Lines (Hough/LSD)",
    ellipses="Ellipses (Contours + fitEllipse)",
)

# Runtime modes accepted from clients; _resolve_mode maps anything else to
# "classical".
ALLOWED_MODES = {"both", "classical", "dl"}
def _detector_label(key: str) -> str:
    """Translate a public detector key into the label ``run_detection`` expects.

    Args:
        key: One of the keys of ``DETECTOR_KEYS``.

    Returns:
        The matching label string.

    Raises:
        HTTPException: 404 when ``key`` is not a known detector.
    """
    label = DETECTOR_KEYS.get(key)
    if label is None:
        raise HTTPException(status_code=404, detail=f"Unknown detector '{key}'.")
    return label
| def _resolve_mode(mode: str, compare: bool) -> str: | |
| if compare: | |
| return "both" | |
| if mode not in ALLOWED_MODES: | |
| return "classical" | |
| return mode | |
| def _choose_primary(mode: str, overlays: Dict[str, str]) -> Optional[str]: | |
| if mode == "dl" and "dl" in overlays: | |
| return "dl" | |
| if mode == "classical" and "classical" in overlays: | |
| return "classical" | |
| if mode == "both": | |
| if "classical" in overlays: | |
| return "classical" | |
| if "dl" in overlays: | |
| return "dl" | |
| return next(iter(overlays.keys()), None) | |
def _format_result(result: DetectionResult, mode: str) -> DetectionResponse:
    """Convert a runtime ``DetectionResult`` into the API response model.

    Each overlay image is PNG/base64-encoded (``None`` overlays stay ``None``).
    The primary overlay is chosen among the non-empty encodings via
    ``_choose_primary``. Model info comes from that overlay's entry in
    ``result.models``, falling back to the ``"classical"`` entry, then ``{}``.

    Args:
        result: Output of ``run_detection``.
        mode: The resolved runtime mode used for the run.

    Returns:
        A populated ``DetectionResponse``.
    """
    overlays_encoded: Dict[str, Optional[str]] = {
        path: None if image is None else encode_png_base64(image)
        for path, image in result.overlays.items()
    }
    # Only overlays that actually produced data are eligible to be primary.
    available = {name: data for name, data in overlays_encoded.items() if data}
    primary = _choose_primary(mode, available)

    classical_model = result.models.get("classical", {})
    model_info = result.models.get(primary or "classical", classical_model)
    primary_overlay = overlays_encoded.get(primary) if primary else None

    return DetectionResponse(
        overlay=primary_overlay,
        overlays=overlays_encoded,
        features=result.features,
        timings=result.timings_ms,
        fps_estimate=result.fps_estimate,
        model=model_info,
        models=result.models,
    )
def _json_params(params: Optional[DetectionParams]) -> Optional[Dict[str, Any]]:
    """Serialise optional detection parameters to a plain dict.

    Fields whose value is ``None`` are dropped. Pydantic v2's ``model_dump`` is
    preferred when available. Otherwise the v1 ``dict`` method is used, so the
    helper works on either major version without v2 deprecation warnings.

    Args:
        params: Parsed request parameters, or ``None``.

    Returns:
        ``None`` when ``params`` is ``None``, otherwise the dumped fields.
    """
    if params is None:
        return None
    dump = getattr(params, "model_dump", None) or params.dict
    return dump(exclude_none=True)
# NOTE(review): no route registration for the JSON handlers below is visible in
# this section (unlike the upload routes); confirm they are attached to `router`.
# NOTE(review): run_detection is called synchronously from these async handlers,
# so it runs on the event loop thread; if it is CPU-heavy, consider offloading
# it (e.g. run_in_threadpool) once its thread-safety is confirmed.


def _detect_from_json(detector_key: str, payload: DetectionRequest) -> DetectionResponse:
    """Shared implementation of the JSON (base64 image) detection handlers.

    Args:
        detector_key: Key into ``DETECTOR_KEYS`` selecting the detector.
        payload: Request body carrying the base64 image and run options.

    Returns:
        The formatted ``DetectionResponse``.

    Raises:
        HTTPException: 400 when ``payload.image`` cannot be decoded;
            404 when ``detector_key`` is unknown.
    """
    try:
        image = decode_base64_image(payload.image)
    except Exception as exc:
        raise HTTPException(status_code=400, detail=f"Invalid image payload: {exc}") from exc
    runtime_mode = _resolve_mode(payload.mode, payload.compare)
    result = run_detection(
        image,
        _detector_label(detector_key),
        params=_json_params(payload.params),
        mode=runtime_mode,
        # Blank/missing model names are passed through as None.
        dl_choice=payload.dl_model.strip() if payload.dl_model else None,
    )
    return _format_result(result, runtime_mode)


async def detect_edges(payload: DetectionRequest):
    """Run the edge detector ("Edges (Canny)") on a base64-encoded image."""
    return _detect_from_json("edges", payload)


async def detect_corners(payload: DetectionRequest):
    """Run the corner detector ("Corners (Harris)") on a base64-encoded image."""
    return _detect_from_json("corners", payload)


async def detect_lines(payload: DetectionRequest):
    """Run the line detector ("Lines (Hough/LSD)") on a base64-encoded image."""
    return _detect_from_json("lines", payload)


async def detect_ellipses(payload: DetectionRequest):
    """Run the ellipse detector ("Ellipses (Contours + fitEllipse)") on a base64-encoded image."""
    return _detect_from_json("ellipses", payload)
async def _handle_upload(
    detector_key: str,
    file: UploadFile,
    params: Optional[str],
    mode: str,
    compare: bool,
    dl_model: Optional[str],
) -> DetectionResponse:
    """Decode a multipart image upload and run the requested detector on it.

    Args:
        detector_key: Key into ``DETECTOR_KEYS`` selecting the detector.
        file: Uploaded image in any format ``cv2.imdecode`` can read.
        params: Optional JSON string that must decode to an object.
        mode: Requested runtime mode (normalised by ``_resolve_mode``).
        compare: When true, forces ``"both"`` mode.
        dl_model: Optional DL model name; surrounding whitespace is stripped.

    Returns:
        The formatted ``DetectionResponse``.

    Raises:
        HTTPException: 400 for an empty or undecodable upload, or invalid
            ``params`` JSON; 404 for an unknown ``detector_key``.
    """
    content = await file.read()
    # cv2.imdecode fails an internal assertion (raising cv2.error) on an empty
    # buffer instead of returning None, which would surface as a 500.
    if not content:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    array = np.frombuffer(content, dtype=np.uint8)
    try:
        decoded = cv2.imdecode(array, cv2.IMREAD_COLOR)
    except cv2.error:
        decoded = None
    if decoded is None:
        raise HTTPException(status_code=400, detail="Unable to decode uploaded image.")
    # OpenCV decodes to BGR; convert to RGB. Presumably this matches what
    # decode_base64_image yields on the JSON path — TODO confirm.
    image = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)

    params_dict: Optional[Dict[str, Any]] = None
    if params:
        try:
            parsed = json.loads(params)
        except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
            raise HTTPException(status_code=400, detail=f"Invalid params: {exc}") from exc
        if not isinstance(parsed, dict):
            raise HTTPException(
                status_code=400,
                detail="Invalid params: params JSON must decode to an object.",
            )
        params_dict = parsed

    runtime_mode = _resolve_mode(mode, compare)
    result = run_detection(
        image,
        _detector_label(detector_key),
        params=params_dict,
        mode=runtime_mode,
        dl_choice=dl_model.strip() if dl_model else None,
    )
    return _format_result(result, runtime_mode)
def _upload_endpoint(detector_key: str):
    """Build a multipart-upload FastAPI handler bound to ``detector_key``.

    The returned coroutine function declares the form fields FastAPI parses
    (``file``, ``params``, ``mode``, ``compare``, ``dl_model``) and forwards
    them unchanged to ``_handle_upload``.
    """

    async def endpoint(
        file: UploadFile = File(...),
        params: Optional[str] = Form(None),
        mode: str = Form("classical"),
        compare: bool = Form(False),
        dl_model: Optional[str] = Form(None),
    ):
        # Deliberately named `endpoint` and left without a docstring: FastAPI
        # derives the route name from the function name and publishes the
        # docstring as the route description.
        return await _handle_upload(
            detector_key=detector_key,
            file=file,
            params=params,
            mode=mode,
            compare=compare,
            dl_model=dl_model,
        )

    return endpoint
# One multipart upload route per detector, e.g. POST /v1/detect/edges/upload,
# registered in the same order as before.
for _upload_key in ("edges", "corners", "lines", "ellipses"):
    router.add_api_route(
        f"/{_upload_key}/upload",
        _upload_endpoint(_upload_key),
        methods=["POST"],
        response_model=DetectionResponse,
    )
del _upload_key  # keep the loop variable out of the module namespace
async def detection_stream(websocket: WebSocket):
    """Serve detection requests over a websocket, one JSON text message each.

    After accepting, the server sends ``{"ready": true}``. Each client message
    must be a JSON object with:
      * ``detector``: a key of ``DETECTOR_KEYS``;
      * ``image``: a base64-encoded image;
      * optionally ``params`` (object), ``mode``, ``compare``, and
        ``dl_model`` (or its alias ``model``).

    Each message is answered with either the serialised ``DetectionResponse``
    or ``{"error": ...}``. Invalid messages and detection failures are
    reported as errors instead of terminating the handler. The loop ends
    when the client disconnects.
    """
    await websocket.accept()
    await websocket.send_json({"ready": True})
    try:
        while True:
            message = await websocket.receive_text()
            try:
                payload = json.loads(message)
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON payload."})
                continue
            # Valid JSON that is not an object (list, string, number) has no
            # .get(); report it rather than crashing the handler.
            if not isinstance(payload, dict):
                await websocket.send_json({"error": "Payload must be a JSON object."})
                continue

            detector_key = payload.get("detector")
            # Type check first: an unhashable value (list/object) would make the
            # dict membership test raise TypeError.
            if not isinstance(detector_key, str) or detector_key not in DETECTOR_KEYS:
                await websocket.send_json({"error": "Unknown detector key."})
                continue

            image_b64 = payload.get("image")
            if not image_b64:
                await websocket.send_json({"error": "Missing 'image' field."})
                continue
            try:
                image = decode_base64_image(image_b64)
            except Exception as exc:
                await websocket.send_json({"error": f"Invalid image payload: {exc}"})
                continue

            params = payload.get("params")
            if params is not None and not isinstance(params, dict):
                await websocket.send_json({"error": "'params' must be an object."})
                continue

            mode = payload.get("mode", "classical")
            if not isinstance(mode, str):
                # Unknown modes fall back to "classical" anyway; normalising
                # here keeps unhashable values out of the set lookup.
                mode = "classical"
            compare = bool(payload.get("compare", False))
            dl_model = payload.get("dl_model") or payload.get("model")
            runtime_mode = _resolve_mode(mode, compare)

            # NOTE(review): run_detection is synchronous and runs on the event
            # loop here; consider offloading it if it is CPU-heavy.
            try:
                result = run_detection(
                    image,
                    _detector_label(detector_key),
                    params=params,
                    mode=runtime_mode,
                    dl_choice=dl_model.strip() if dl_model else None,
                )
                # Formatting (overlay encoding) is inside the try so that its
                # failures are reported to the client too.
                response = _format_result(result, runtime_mode)
            except Exception as exc:  # pragma: no cover
                await websocket.send_json({"error": str(exc)})
                continue

            # Prefer Pydantic v2's model_dump; fall back to v1's dict().
            dump = getattr(response, "model_dump", None) or response.dict
            await websocket.send_json(dump())
    except WebSocketDisconnect:
        return