File size: 3,486 Bytes
dd85fb6
aaa448c
dd85fb6
 
 
 
 
aaa448c
dd85fb6
 
 
 
 
 
 
 
 
 
 
 
aaa448c
dd85fb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaa448c
dd85fb6
 
 
 
 
 
0a3fef3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dd85fb6
 
 
 
0a3fef3
dd85fb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aaa448c
 
 
 
 
 
 
dd85fb6
0a3fef3
 
 
 
dd85fb6
aaa448c
dd85fb6
 
 
 
aaa448c
 
 
0a3fef3
 
aaa448c
 
 
dd85fb6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import os
import sys
from typing import Any, Dict, Optional, Tuple

import numpy as np

from .common import to_bgr, to_rgb
from .dl_adapters import get_adapter

try:
    import onnxruntime as ort  # type: ignore
except Exception:  # pragma: no cover
    ort = None  # type: ignore


# Directory searched for ONNX weight files. It is resolved once, when this
# module is imported, relative to the process's current working directory
# (not relative to this source file).
MODEL_DIR = os.path.join(os.getcwd(), "models")

# Detector label -> candidate ONNX filenames (looked up inside MODEL_DIR).
# When no explicit model is chosen, _find_model tries these in list order and
# uses the first one present on disk. Values stay lists: their repr is
# embedded verbatim in detect_dl's "no model found" warning.
DL_MODELS: Dict[str, list] = {
    "Edges (Canny)": [
        "hed.onnx",
        "dexined.onnx",
    ],
    "Corners (Harris)": [
        "superpoint.onnx",
    ],
    "Lines (Hough/LSD)": [
        "sold2.onnx",
        "hawp.onnx",
    ],
    "Ellipses (Contours + fitEllipse)": [
        "ellipse_head.onnx",
    ],
}


def _find_model(detector: str, choice_name: Optional[str]) -> Optional[str]:
    """Resolve the ONNX model file to use for *detector*.

    If *choice_name* is truthy it is the only candidate: its path under
    ``MODEL_DIR`` is returned when that file exists, otherwise ``None`` is
    returned; the defaults are NOT tried as a fallback. If *choice_name* is
    falsy, the filenames registered for *detector* in ``DL_MODELS`` are tried
    in order and the first existing file wins. An unknown *detector* yields
    ``None``.

    Note: ``os.path.join`` drops ``MODEL_DIR`` when *choice_name* is an
    absolute path, so such a value is checked as given.
    """
    candidates = [choice_name] if choice_name else DL_MODELS.get(detector, [])
    paths = (os.path.join(MODEL_DIR, name) for name in candidates)
    return next((path for path in paths if os.path.isfile(path)), None)


def _load_session(path: str):
    """Create an onnxruntime ``InferenceSession`` for the model at *path*.

    On macOS the CoreML execution provider is preferred, with CPU as the
    fallback; elsewhere only the CPU provider is requested. Preferred
    providers that the installed onnxruntime build does not offer are dropped
    up front, so onnxruntime does not have to warn about and discard them.

    Raises:
        RuntimeError: if onnxruntime could not be imported, or if creating the
            session fails. In the latter case the original exception is
            chained as ``__cause__``.
    """
    if ort is None:
        raise RuntimeError("onnxruntime not installed. `pip install onnxruntime`.")
    if sys.platform == "darwin":
        preferred = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
    else:
        preferred = ["CPUExecutionProvider"]
    available = set(ort.get_available_providers())
    # If nothing matches (unusual build), pass the preference list unchanged
    # and let onnxruntime decide, as the previous behavior did.
    providers = [p for p in preferred if p in available] or preferred
    try:
        return ort.InferenceSession(path, providers=providers)
    except Exception as e:
        # Chain the cause so the underlying onnxruntime error stays inspectable.
        raise RuntimeError(f"Failed to load ONNX model '{path}': {e}") from e


def _extract_adapter_options(model_path: str, params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    if not params:
        return {}
    name = os.path.basename(model_path).lower()
    if "dexined" in name:
        keys = [
            "dexined_threshold_mode",
            "dexined_threshold_sigma",
            "dexined_threshold_offset",
            "dexined_threshold_value",
            "dexined_use_marching_squares",
        ]
        return {k: params.get(k) for k in keys if k in params}
    return {}


def detect_dl(
    image: np.ndarray,
    detector: str,
    model_choice: Optional[str],
    params: Optional[Dict[str, Any]] = None,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Run an ONNX model for *detector* on *image*.

    Failures in model lookup, session loading, adapter selection,
    preprocessing, inference and postprocessing are reported through the
    returned ``meta`` dict instead of being raised. In those cases the
    RGB-converted input is returned unchanged, together with a ``"warning"``
    entry (no model file found) or an ``"error"`` entry.

    Args:
        image: Input image. It is passed through ``to_bgr`` and then
            ``to_rgb`` from ``.common`` before use; channel-order assumptions
            live in those helpers.
        detector: Detector label. It selects default models from
            ``DL_MODELS`` and is forwarded to ``get_adapter`` and to the
            adapter's ``postprocess``.
        model_choice: Explicit model filename inside ``MODEL_DIR``. When it is
            falsy, the defaults for *detector* are tried in order.
        params: Optional parameters. Only the adapter-specific subset chosen
            by ``_extract_adapter_options`` is used.

    Returns:
        ``(overlay, meta)`` on success, where ``overlay`` is produced by the
        adapter's ``postprocess``. ``(rgb, meta)`` on failure. ``meta`` always
        contains ``"path": "dl"``. It gains ``"model_path"`` once a model file
        is found, any keys returned by ``postprocess``, and
        ``"adapter_options"`` when options were actually injected into the
        adapter context.
    """
    bgr = to_bgr(image)
    rgb = to_rgb(bgr)
    meta: Dict[str, Any] = {"path": "dl"}

    model_path = _find_model(detector, model_choice)
    if model_path is None:
        meta["warning"] = (
            f"No ONNX model found for '{detector}'. Place a model in ./models."
            f" Expected one of: {DL_MODELS.get(detector, [])}"
        )
        return rgb, meta

    meta["model_path"] = model_path

    try:
        sess = _load_session(model_path)
    except Exception as e:
        meta["error"] = str(e)
        return rgb, meta

    # Dispatch to the model-specific adapter. This stage is guarded like every
    # other one, so an unsupported model is reported in meta rather than
    # propagating an exception to the caller.
    try:
        adapter = get_adapter(model_path, detector)
    except Exception as e:
        meta["error"] = f"Adapter selection failed: {e}"
        return rgb, meta

    try:
        feed, ctx = adapter.preprocess(rgb, sess)
    except Exception as e:
        meta["error"] = f"Preprocess failed: {e}"
        return rgb, meta

    # Options are injected after preprocess, so only postprocess can see them.
    # Injection needs ctx.extra to be a dict. getattr avoids an uncaught
    # AttributeError for contexts that have no `extra` attribute at all.
    adapter_options = _extract_adapter_options(model_path, params)
    applied_options: Dict[str, Any] = {}
    extra = getattr(ctx, "extra", None)
    if adapter_options and isinstance(extra, dict):
        extra.update(adapter_options)
        applied_options = adapter_options

    try:
        outputs = sess.run(None, feed)  # None -> fetch all model outputs
    except Exception as e:
        meta["error"] = f"ONNX inference failed: {e}"
        return rgb, meta

    try:
        overlay, post_meta = adapter.postprocess(outputs, rgb, ctx, detector)
        meta.update(post_meta)
    except Exception as e:
        meta["error"] = f"Postprocess failed: {e}"
        return rgb, meta

    # Report only options the adapter could actually see.
    if applied_options:
        meta["adapter_options"] = applied_options
    return overlay, meta