File size: 9,087 Bytes
dd85fb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaa448c
dd85fb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
import json
from typing import Any, Dict, Optional

import cv2
import numpy as np
from fastapi import APIRouter, File, Form, HTTPException, UploadFile, WebSocket, WebSocketDisconnect

from ...models.schemas import DetectionParams, DetectionRequest, DetectionResponse
from ...services.runtime_adapter import DetectionResult, run_detection
from ...utils.image_io import decode_base64_image, encode_png_base64


# Router for every detection endpoint defined in this module: each path below
# is mounted under the "/v1/detect" prefix and grouped under the "detection"
# tag in the generated OpenAPI schema.
router = APIRouter(prefix="/v1/detect", tags=["detection"])

# Public detector key (as used in URLs and WebSocket messages) -> the label
# string handed to run_detection via _detector_label.
DETECTOR_KEYS: Dict[str, str] = dict(
    edges="Edges (Canny)",
    corners="Corners (Harris)",
    lines="Lines (Hough/LSD)",
    ellipses="Ellipses (Contours + fitEllipse)",
)

# Runtime modes a client may request; _resolve_mode maps anything else to
# "classical".
ALLOWED_MODES = {"both", "classical", "dl"}


def _detector_label(key: str) -> str:
    """Translate a public detector key into the label run_detection expects.

    Raises:
        HTTPException: 404 when *key* is not a key of DETECTOR_KEYS.
    """
    label = DETECTOR_KEYS.get(key)
    if label is None:
        raise HTTPException(status_code=404, detail=f"Unknown detector '{key}'.")
    return label


def _resolve_mode(mode: str, compare: bool) -> str:
    if compare:
        return "both"
    if mode not in ALLOWED_MODES:
        return "classical"
    return mode


def _choose_primary(mode: str, overlays: Dict[str, str]) -> Optional[str]:
    if mode == "dl" and "dl" in overlays:
        return "dl"
    if mode == "classical" and "classical" in overlays:
        return "classical"
    if mode == "both":
        if "classical" in overlays:
            return "classical"
        if "dl" in overlays:
            return "dl"
    return next(iter(overlays.keys()), None)


def _format_result(result: DetectionResult, mode: str) -> DetectionResponse:
    """Convert a raw DetectionResult into the API response model.

    Each overlay image is encoded with ``encode_png_base64``, and ``None``
    entries stay ``None``. The primary ``overlay`` is chosen among the non-empty
    encodings by ``_choose_primary``. ``model`` carries that overlay's entry
    from ``result.models``. It falls back to the ``"classical"`` entry, or
    ``{}``, when there is no matching entry.
    """
    encoded_overlays: Dict[str, Optional[str]] = {
        name: None if img is None else encode_png_base64(img)
        for name, img in result.overlays.items()
    }
    non_empty = {name: data for name, data in encoded_overlays.items() if data}
    primary = _choose_primary(mode, non_empty)

    classical_info = result.models.get("classical", {})
    model_info = result.models.get(primary or "classical", classical_info)

    primary_overlay = None
    if primary:
        primary_overlay = encoded_overlays.get(primary)

    return DetectionResponse(
        overlay=primary_overlay,
        overlays=encoded_overlays,
        features=result.features,
        timings=result.timings_ms,
        fps_estimate=result.fps_estimate,
        model=model_info,
        models=result.models,
    )


def _json_params(params: Optional[DetectionParams]) -> Optional[Dict[str, Any]]:
    """Serialise optional detection params into a plain dict for run_detection.

    Returns ``None`` when no params were supplied. Otherwise it returns the
    model's ``.dict()`` with every field whose value is ``None`` omitted.
    """
    return None if params is None else params.dict(exclude_none=True)


def _detect_from_json(detector_key: str, payload: DetectionRequest) -> DetectionResponse:
    """Run one detector on a base64-encoded JSON request.

    Shared implementation of the ``POST /v1/detect/<detector>`` JSON routes.

    Args:
        detector_key: A key of ``DETECTOR_KEYS`` selecting the detector.
        payload: Parsed request body carrying the image and run options.

    Returns:
        The formatted ``DetectionResponse``.

    Raises:
        HTTPException: 400 when ``payload.image`` cannot be decoded.
    """
    try:
        image = decode_base64_image(payload.image)
    except Exception as exc:  # any decoder failure is a client-side input error
        raise HTTPException(status_code=400, detail=f"Invalid image payload: {exc}") from exc
    runtime_mode = _resolve_mode(payload.mode, payload.compare)
    # A blank/whitespace-only model name is treated like an omitted one,
    # instead of being forwarded as "".
    dl_choice = (payload.dl_model or "").strip() or None
    # NOTE(review): run_detection is a synchronous call made from async routes,
    # so it blocks the event loop for its whole duration. Consider offloading it
    # (e.g. fastapi.concurrency.run_in_threadpool) once its thread-safety is confirmed.
    result = run_detection(
        image,
        _detector_label(detector_key),
        params=_json_params(payload.params),
        mode=runtime_mode,
        dl_choice=dl_choice,
    )
    return _format_result(result, runtime_mode)


@router.post("/edges", response_model=DetectionResponse)
async def detect_edges(payload: DetectionRequest):
    """Run the edge detector on a base64-encoded image."""
    return _detect_from_json("edges", payload)


@router.post("/corners", response_model=DetectionResponse)
async def detect_corners(payload: DetectionRequest):
    """Run the corner detector on a base64-encoded image."""
    return _detect_from_json("corners", payload)


@router.post("/lines", response_model=DetectionResponse)
async def detect_lines(payload: DetectionRequest):
    """Run the line detector on a base64-encoded image."""
    return _detect_from_json("lines", payload)


@router.post("/ellipses", response_model=DetectionResponse)
async def detect_ellipses(payload: DetectionRequest):
    """Run the ellipse detector on a base64-encoded image."""
    return _detect_from_json("ellipses", payload)


async def _handle_upload(
    detector_key: str,
    file: UploadFile,
    params: Optional[str],
    mode: str,
    compare: bool,
    dl_model: Optional[str],
) -> DetectionResponse:
    """Run a detector on a multipart-uploaded image file.

    Args:
        detector_key: A key of ``DETECTOR_KEYS`` selecting the detector.
        file: Uploaded image bytes in any format ``cv2.imdecode`` accepts.
        params: Optional JSON text that must decode to an object.
        mode: Requested runtime mode (resolved by ``_resolve_mode``).
        compare: When true, forces the ``"both"`` mode.
        dl_model: Optional DL model name; blank values are treated as omitted.

    Returns:
        The formatted ``DetectionResponse``.

    Raises:
        HTTPException: 400 for an empty or undecodable upload, or invalid params.
    """
    # NOTE(review): the whole upload is read into memory with no size cap.
    content = await file.read()
    # cv2.imdecode raises cv2.error (rather than returning None) on an empty
    # buffer, which would surface as a 500; reject it explicitly as a 400.
    if not content:
        raise HTTPException(status_code=400, detail="Uploaded file is empty.")
    array = np.frombuffer(content, dtype=np.uint8)
    try:
        decoded = cv2.imdecode(array, cv2.IMREAD_COLOR)
    except cv2.error:
        decoded = None
    if decoded is None:
        raise HTTPException(status_code=400, detail="Unable to decode uploaded image.")
    # OpenCV decodes to BGR channel order; convert to RGB before detection.
    image = cv2.cvtColor(decoded, cv2.COLOR_BGR2RGB)

    params_dict: Optional[Dict[str, Any]] = None
    if params:
        try:
            parsed = json.loads(params)
        except ValueError as exc:  # json.JSONDecodeError subclasses ValueError
            raise HTTPException(status_code=400, detail=f"Invalid params: {exc}") from exc
        if not isinstance(parsed, dict):
            raise HTTPException(
                status_code=400,
                detail="Invalid params: params JSON must decode to an object.",
            )
        params_dict = parsed

    runtime_mode = _resolve_mode(mode, compare)
    # A blank/whitespace-only model name is treated like an omitted one.
    dl_choice = (dl_model or "").strip() or None
    # NOTE(review): synchronous call inside an async handler; it blocks the
    # event loop for its duration. Consider offloading once thread-safety is confirmed.
    result = run_detection(
        image,
        _detector_label(detector_key),
        params=params_dict,
        mode=runtime_mode,
        dl_choice=dl_choice,
    )
    return _format_result(result, runtime_mode)


def _upload_endpoint(detector_key: str):
    """Build a multipart-upload route handler bound to *detector_key*.

    The returned coroutine function declares the form fields FastAPI parses
    from the request (``file`` plus optional ``params``, ``mode``, ``compare``
    and ``dl_model``) and forwards them unchanged to ``_handle_upload``.
    """

    async def endpoint(
        file: UploadFile = File(...),
        params: Optional[str] = Form(None),
        mode: str = Form("classical"),
        compare: bool = Form(False),
        dl_model: Optional[str] = Form(None),
    ):
        # FastAPI introspects this signature, so parameter names/defaults here
        # define the public form fields of the upload routes.
        return await _handle_upload(
            detector_key=detector_key,
            file=file,
            params=params,
            mode=mode,
            compare=compare,
            dl_model=dl_model,
        )

    return endpoint


# Register one multipart upload route per detector
# (POST /v1/detect/<key>/upload), each served by a closure from _upload_endpoint.
for _upload_key in ("edges", "corners", "lines", "ellipses"):
    router.add_api_route(
        f"/{_upload_key}/upload",
        _upload_endpoint(_upload_key),
        methods=["POST"],
        response_model=DetectionResponse,
    )
del _upload_key  # keep the loop variable out of the module namespace


@router.websocket("/stream")
async def detection_stream(websocket: WebSocket):
    """Serve detection requests over a WebSocket, one JSON message per request.

    After accepting, the server sends ``{"ready": true}``. Each incoming text
    message must be a JSON object with ``detector`` (a key of DETECTOR_KEYS)
    and ``image`` (base64). It may also carry ``params`` (object), ``mode``,
    ``compare`` and ``dl_model`` (or its alias ``model``).

    Each request is answered with the serialised DetectionResponse or with
    ``{"error": "..."}``. Malformed requests and detection failures produce an
    error message rather than closing the connection. The loop ends when the
    client disconnects.
    """
    await websocket.accept()
    await websocket.send_json({"ready": True})
    try:
        while True:
            message = await websocket.receive_text()
            try:
                payload = json.loads(message)
            except json.JSONDecodeError:
                await websocket.send_json({"error": "Invalid JSON payload."})
                continue
            # Valid JSON may still be an array/string/number; .get() needs a dict.
            if not isinstance(payload, dict):
                await websocket.send_json({"error": "Payload must be a JSON object."})
                continue

            detector_key = payload.get("detector")
            # Type check first: an unhashable value (e.g. a list) would make the
            # dict membership test raise TypeError and drop the connection.
            if not isinstance(detector_key, str) or detector_key not in DETECTOR_KEYS:
                await websocket.send_json({"error": "Unknown detector key."})
                continue

            image_b64 = payload.get("image")
            if not image_b64:
                await websocket.send_json({"error": "Missing 'image' field."})
                continue

            try:
                image = decode_base64_image(image_b64)
            except Exception as exc:
                await websocket.send_json({"error": f"Invalid image payload: {exc}"})
                continue

            params = payload.get("params")
            if params is not None and not isinstance(params, dict):
                await websocket.send_json({"error": "'params' must be an object."})
                continue

            mode = payload.get("mode", "classical")
            if not isinstance(mode, str):
                # Same fallback _resolve_mode applies to unrecognised strings;
                # also avoids hashing e.g. a list inside _resolve_mode.
                mode = "classical"
            compare = bool(payload.get("compare", False))
            dl_model = payload.get("dl_model") or payload.get("model")
            if dl_model and not isinstance(dl_model, str):
                await websocket.send_json({"error": "'dl_model' must be a string."})
                continue
            # A blank/whitespace-only model name is treated like an omitted one.
            dl_choice = (dl_model.strip() or None) if dl_model else None

            runtime_mode = _resolve_mode(mode, compare)
            # NOTE(review): run_detection is synchronous, so this blocks the event
            # loop (and all other connections) for its duration.
            try:
                result = run_detection(
                    image,
                    _detector_label(detector_key),
                    params=params,
                    mode=runtime_mode,
                    dl_choice=dl_choice,
                )
                # Formatting (image encoding) can fail too; keep it inside the
                # try so the client gets an error message instead of a dropped socket.
                response = _format_result(result, runtime_mode).dict()
            except Exception as exc:  # pragma: no cover
                await websocket.send_json({"error": str(exc)})
                continue

            # Sent outside the try so a disconnect during send is not
            # misreported as a detection error.
            await websocket.send_json(response)
    except WebSocketDisconnect:
        return