# dexined configurable (commit 0a3fef3)
import os
import sys
from typing import Any, Dict, Optional, Tuple
import numpy as np
from .common import to_bgr, to_rgb
from .dl_adapters import get_adapter
# onnxruntime is an optional dependency: fall back to None so the module
# imports cleanly and _load_session can raise a helpful error at use time.
try:
    import onnxruntime as ort  # type: ignore
except Exception:  # pragma: no cover
    ort = None  # type: ignore
# Directory scanned for ONNX model files (relative to the working directory).
MODEL_DIR = os.path.join(os.getcwd(), "models")

# Known ONNX model filenames per detector, in priority order; when the user
# does not pick a model explicitly, the first file found in MODEL_DIR wins.
DL_MODELS = {
    "Edges (Canny)": ["hed.onnx", "dexined.onnx"],
    "Corners (Harris)": ["superpoint.onnx"],
    "Lines (Hough/LSD)": ["sold2.onnx", "hawp.onnx"],
    "Ellipses (Contours + fitEllipse)": ["ellipse_head.onnx"],
}
def _find_model(detector: str, choice_name: Optional[str]) -> Optional[str]:
    """Resolve the ONNX model file to use for *detector*.

    An explicit *choice_name* takes precedence and is only accepted if the
    file exists. Otherwise the known filenames for the detector are probed
    in priority order. Returns the path of the first existing file, or
    ``None`` when no usable model is present in ``MODEL_DIR``.
    """
    if choice_name:
        explicit = os.path.join(MODEL_DIR, choice_name)
        return explicit if os.path.isfile(explicit) else None
    candidates = (
        os.path.join(MODEL_DIR, fname) for fname in DL_MODELS.get(detector, [])
    )
    return next((path for path in candidates if os.path.isfile(path)), None)
def _load_session(path: str):
    """Create an ONNX Runtime inference session for the model at *path*.

    Raises:
        RuntimeError: if onnxruntime is not installed, or if the model
            fails to load (with the original exception chained as cause).
    """
    if ort is None:
        raise RuntimeError("onnxruntime not installed. `pip install onnxruntime`.")
    # Prefer the CoreML accelerator on macOS; everywhere else use plain CPU.
    # sys.platform is exactly "darwin" on macOS, so compare for equality
    # rather than the original fragile substring test.
    if sys.platform == "darwin":
        providers = ["CoreMLExecutionProvider", "CPUExecutionProvider"]
    else:
        providers = ["CPUExecutionProvider"]
    try:
        return ort.InferenceSession(path, providers=providers)
    except Exception as e:
        # Chain the original exception so the root cause stays in the traceback.
        raise RuntimeError(f"Failed to load ONNX model '{path}': {e}") from e
def _extract_adapter_options(model_path: str, params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
if not params:
return {}
name = os.path.basename(model_path).lower()
if "dexined" in name:
keys = [
"dexined_threshold_mode",
"dexined_threshold_sigma",
"dexined_threshold_offset",
"dexined_threshold_value",
"dexined_use_marching_squares",
]
return {k: params.get(k) for k in keys if k in params}
return {}
def detect_dl(
    image: np.ndarray,
    detector: str,
    model_choice: Optional[str],
    params: Optional[Dict[str, Any]] = None,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Run a deep-learning detector over *image* via an ONNX model.

    Returns an ``(overlay_rgb, meta)`` pair. This function never raises:
    every failure path (missing model, session load, preprocess, inference,
    postprocess) returns the untouched RGB image together with a
    ``"warning"`` or ``"error"`` entry in *meta* describing what happened.

    Args:
        image: Input image; converted to BGR then back to RGB via the
            project's common helpers before processing.
        detector: Detector display name, used as key into ``DL_MODELS``.
        model_choice: Optional explicit model filename within ``MODEL_DIR``.
        params: Optional UI parameters; adapter-specific options (currently
            DexiNed only) are extracted from here.
    """
    bgr = to_bgr(image)
    rgb = to_rgb(bgr)
    meta: Dict[str, Any] = {"path": "dl"}
    model_path = _find_model(detector, model_choice)
    if model_path is None:
        # No model on disk for this detector: report, but do not fail.
        meta["warning"] = (
            f"No ONNX model found for '{detector}'. Place a model in ./models."
            f" Expected one of: {DL_MODELS.get(detector, [])}"
        )
        return rgb, meta
    meta["model_path"] = model_path
    try:
        sess = _load_session(model_path)
    except Exception as e:
        meta["error"] = str(e)
        return rgb, meta
    # Dispatch to model-specific adapter
    adapter = get_adapter(model_path, detector)
    try:
        feed, ctx = adapter.preprocess(rgb, sess)
    except Exception as e:
        meta["error"] = f"Preprocess failed: {e}"
        return rgb, meta
    # Thread adapter options through ctx.extra so postprocess can see them;
    # this mutation must happen before sess.run so the run/postprocess pair
    # observes a consistent context.
    adapter_options = _extract_adapter_options(model_path, params)
    if adapter_options and isinstance(ctx.extra, dict):
        ctx.extra.update(adapter_options)
    try:
        outputs = sess.run(None, feed)
    except Exception as e:
        meta["error"] = f"ONNX inference failed: {e}"
        return rgb, meta
    try:
        overlay, post_meta = adapter.postprocess(outputs, rgb, ctx, detector)
        meta.update(post_meta)
        if adapter_options:
            # Echo the options actually used so the UI can display them.
            meta["adapter_options"] = adapter_options
    except Exception as e:
        meta["error"] = f"Postprocess failed: {e}"
        return rgb, meta
    return overlay, meta