"""Run deep-learning (ONNX) variants of the classic feature detectors.

Models are looked up in ``MODEL_DIR`` and executed with onnxruntime; the
model-specific pre- and post-processing is delegated to an adapter obtained
from ``get_adapter``.
"""

import os
import sys
from typing import Any, Dict, Optional, Tuple

import numpy as np

from .common import to_bgr, to_rgb
from .dl_adapters import get_adapter

try:
    import onnxruntime as ort  # type: ignore
except Exception:  # pragma: no cover
    # onnxruntime is optional; _load_session reports its absence at call time.
    ort = None  # type: ignore

# Resolved once, at import time, relative to the process working directory.
MODEL_DIR = os.path.join(os.getcwd(), "models")

# Candidate model filenames per detector, in lookup priority order.
DL_MODELS = {
    "Edges (Canny)": ["hed.onnx", "dexined.onnx"],
    "Corners (Harris)": ["superpoint.onnx"],
    "Lines (Hough/LSD)": ["sold2.onnx", "hawp.onnx"],
    "Ellipses (Contours + fitEllipse)": ["ellipse_head.onnx"],
}


def _find_model(detector: str, choice_name: Optional[str]) -> Optional[str]:
    """Return the path of the model file to use for *detector*, or ``None``.

    If *choice_name* is given, only that file (relative to ``MODEL_DIR``) is
    considered; there is no fallback to the defaults when it is missing.
    Otherwise the first existing file listed in ``DL_MODELS[detector]`` wins.

    NOTE: ``os.path.join`` discards ``MODEL_DIR`` when *choice_name* is an
    absolute path, so absolute paths are accepted as-is.
    """
    if choice_name:
        p = os.path.join(MODEL_DIR, choice_name)
        return p if os.path.isfile(p) else None
    for fname in DL_MODELS.get(detector, []):
        p = os.path.join(MODEL_DIR, fname)
        if os.path.isfile(p):
            return p
    return None


def _load_session(path: str) -> Any:
    """Create an onnxruntime ``InferenceSession`` for the model at *path*.

    On macOS the CoreML execution provider is preferred when the installed
    onnxruntime build offers it; the CPU provider is always the fallback.

    Raises:
        RuntimeError: if onnxruntime is not installed or the model fails to
            load (the underlying exception is chained as ``__cause__``).
    """
    if ort is None:
        raise RuntimeError("onnxruntime not installed. `pip install onnxruntime`.")
    preferred = ["CPUExecutionProvider"]
    if sys.platform.startswith("darwin"):
        preferred.insert(0, "CoreMLExecutionProvider")
    # Only request providers this build actually ships, to avoid
    # "provider not available" warnings/failures on non-CoreML builds.
    available = set(ort.get_available_providers())
    providers = [p for p in preferred if p in available] or ["CPUExecutionProvider"]
    try:
        return ort.InferenceSession(path, providers=providers)
    except Exception as e:
        raise RuntimeError(f"Failed to load ONNX model '{path}': {e}") from e


def _extract_adapter_options(model_path: str, params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Pick the subset of *params* that the adapter for *model_path* consumes.

    Currently only DexiNed models (filename contains ``"dexined"``) take
    options; for every other model an empty dict is returned. Keys absent
    from *params* are omitted rather than set to ``None``.
    """
    if not params:
        return {}
    name = os.path.basename(model_path).lower()
    if "dexined" in name:
        keys = (
            "dexined_threshold_mode",
            "dexined_threshold_sigma",
            "dexined_threshold_offset",
            "dexined_threshold_value",
            "dexined_use_marching_squares",
        )
        return {k: params[k] for k in keys if k in params}
    return {}


def detect_dl(
    image: np.ndarray,
    detector: str,
    model_choice: Optional[str],
    params: Optional[Dict[str, Any]] = None,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Run the ONNX model for *detector* on *image* and return an overlay.

    Args:
        image: input image; normalised via ``to_bgr`` then ``to_rgb``.
        detector: detector label, used as a key into ``DL_MODELS`` and passed
            to the adapter.
        model_choice: explicit model filename in ``MODEL_DIR``; when falsy the
            first available default for *detector* is used.
        params: optional user parameters; adapter-relevant entries are
            forwarded via ``ctx.extra`` (see ``_extract_adapter_options``).

    Returns:
        ``(overlay, meta)``. ``meta`` always has ``"path": "dl"``. On any
        failure the RGB input is returned unchanged together with a
        ``"warning"`` (no model found) or ``"error"`` entry in ``meta``; this
        function does not raise for pipeline failures.
    """
    bgr = to_bgr(image)
    rgb = to_rgb(bgr)
    meta: Dict[str, Any] = {"path": "dl"}

    model_path = _find_model(detector, model_choice)
    if model_path is None:
        meta["warning"] = (
            f"No ONNX model found for '{detector}'. Place a model in ./models."
            f" Expected one of: {DL_MODELS.get(detector, [])}"
        )
        return rgb, meta
    meta["model_path"] = model_path

    try:
        sess = _load_session(model_path)
    except Exception as e:
        meta["error"] = str(e)
        return rgb, meta

    # Dispatch to the model-specific adapter. Guarded like every other stage
    # so a lookup failure degrades to an error entry instead of propagating.
    try:
        adapter = get_adapter(model_path, detector)
    except Exception as e:
        meta["error"] = f"Adapter selection failed: {e}"
        return rgb, meta

    try:
        feed, ctx = adapter.preprocess(rgb, sess)
    except Exception as e:
        meta["error"] = f"Preprocess failed: {e}"
        return rgb, meta

    adapter_options = _extract_adapter_options(model_path, params)
    # getattr: tolerate adapter contexts that carry no ``extra`` attribute.
    extra = getattr(ctx, "extra", None)
    if adapter_options and isinstance(extra, dict):
        extra.update(adapter_options)

    try:
        outputs = sess.run(None, feed)  # None -> fetch all model outputs
    except Exception as e:
        meta["error"] = f"ONNX inference failed: {e}"
        return rgb, meta

    try:
        overlay, post_meta = adapter.postprocess(outputs, rgb, ctx, detector)
        meta.update(post_meta)
        if adapter_options:
            meta["adapter_options"] = adapter_options
    except Exception as e:
        meta["error"] = f"Postprocess failed: {e}"
        return rgb, meta

    return overlay, meta