Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| from typing import Any, Dict, Optional, Tuple | |
| import numpy as np | |
| from .common import to_bgr, to_rgb | |
| from .dl_adapters import get_adapter | |
| try: | |
| import onnxruntime as ort # type: ignore | |
| except Exception: # pragma: no cover | |
| ort = None # type: ignore | |
# Folder that holds the ONNX weight files. Resolved once, at import time,
# against the process working directory (not against this file's location).
MODEL_DIR = os.path.join(os.getcwd(), "models")

# Detector label -> default ONNX filenames expected inside MODEL_DIR.
# When no explicit model is chosen, _find_model returns the first entry of
# the matching list that exists on disk, so list order is priority order.
DL_MODELS = {
    "Edges (Canny)": ["hed.onnx", "dexined.onnx"],
    "Corners (Harris)": ["superpoint.onnx"],
    "Lines (Hough/LSD)": ["sold2.onnx", "hawp.onnx"],
    "Ellipses (Contours + fitEllipse)": ["ellipse_head.onnx"],
}
| def _find_model(detector: str, choice_name: Optional[str]) -> Optional[str]: | |
| if choice_name: | |
| p = os.path.join(MODEL_DIR, choice_name) | |
| return p if os.path.isfile(p) else None | |
| for fname in DL_MODELS.get(detector, []): | |
| p = os.path.join(MODEL_DIR, fname) | |
| if os.path.isfile(p): | |
| return p | |
| return None | |
def _load_session(path: str):
    """Create an onnxruntime ``InferenceSession`` for the model at *path*.

    On macOS (``sys.platform`` contains "darwin") the CoreML execution
    provider is preferred, with CPU as fallback. Only providers reported by
    ``ort.get_available_providers()`` are requested, so a build without
    CoreML does not get asked for it. If no preferred provider is available,
    CPU is still requested.

    Raises:
        RuntimeError: if onnxruntime is not importable, or if the session
            cannot be created. In the second case the original exception is
            chained as ``__cause__``.
    """
    if ort is None:
        raise RuntimeError("onnxruntime not installed. `pip install onnxruntime`.")

    preferred = ["CPUExecutionProvider"]
    if "darwin" in sys.platform:
        preferred.insert(0, "CoreMLExecutionProvider")
    available = set(ort.get_available_providers())
    # Keep preference order; never end up with an empty provider list.
    providers = [p for p in preferred if p in available] or ["CPUExecutionProvider"]

    try:
        return ort.InferenceSession(path, providers=providers)
    except Exception as e:
        raise RuntimeError(f"Failed to load ONNX model '{path}': {e}") from e
| def _extract_adapter_options(model_path: str, params: Optional[Dict[str, Any]]) -> Dict[str, Any]: | |
| if not params: | |
| return {} | |
| name = os.path.basename(model_path).lower() | |
| if "dexined" in name: | |
| keys = [ | |
| "dexined_threshold_mode", | |
| "dexined_threshold_sigma", | |
| "dexined_threshold_offset", | |
| "dexined_threshold_value", | |
| "dexined_use_marching_squares", | |
| ] | |
| return {k: params.get(k) for k in keys if k in params} | |
| return {} | |
def detect_dl(
    image: np.ndarray,
    detector: str,
    model_choice: Optional[str],
    params: Optional[Dict[str, Any]] = None,
) -> Tuple[np.ndarray, Dict[str, Any]]:
    """Run the ONNX model registered for *detector* on *image*.

    Pipeline: locate the model file -> load an onnxruntime session -> pick
    the model-specific adapter -> ``adapter.preprocess`` -> ``sess.run`` ->
    ``adapter.postprocess``.

    Args:
        image: Input image, normalised with the project's ``to_bgr`` and
            ``to_rgb`` helpers before use.
        detector: Detector label (a ``DL_MODELS`` key). Used for default
            model lookup and passed through to the adapter.
        model_choice: Optional explicit model filename inside the model
            directory. When falsy, the ``DL_MODELS`` defaults are tried.
        params: Optional UI parameters. The model-specific subset chosen by
            ``_extract_adapter_options`` is merged into the adapter context's
            ``extra`` dict (when it has one) and echoed in
            ``meta["adapter_options"]`` on success.

    Returns:
        ``(overlay, meta)``. ``meta`` always contains ``"path": "dl"``.
        Failures after image conversion are reported rather than raised: the
        function returns the RGB-converted input together with
        ``meta["warning"]`` (no model file found) or ``meta["error"]`` (session
        load, adapter selection, preprocess, inference or postprocess failure).
    """
    bgr = to_bgr(image)
    rgb = to_rgb(bgr)
    meta: Dict[str, Any] = {"path": "dl"}

    model_path = _find_model(detector, model_choice)
    if model_path is None:
        if model_choice:
            # An explicit choice is the only candidate tried, so listing the
            # default filenames here would be misleading.
            meta["warning"] = (
                f"ONNX model '{model_choice}' not found for '{detector}'."
                " Place it in ./models."
            )
        else:
            meta["warning"] = (
                f"No ONNX model found for '{detector}'. Place a model in ./models."
                f" Expected one of: {DL_MODELS.get(detector, [])}"
            )
        return rgb, meta
    meta["model_path"] = model_path

    try:
        sess = _load_session(model_path)
    except Exception as e:
        meta["error"] = str(e)
        return rgb, meta

    # Dispatch to the model-specific adapter. Guarded like every other stage,
    # so an unsupported model is reported in meta instead of escaping.
    try:
        adapter = get_adapter(model_path, detector)
    except Exception as e:
        meta["error"] = f"Adapter selection failed: {e}"
        return rgb, meta

    try:
        feed, ctx = adapter.preprocess(rgb, sess)
    except Exception as e:
        meta["error"] = f"Preprocess failed: {e}"
        return rgb, meta

    adapter_options = _extract_adapter_options(model_path, params)
    # Not every adapter context is guaranteed to carry an ``extra`` dict;
    # getattr avoids an unguarded AttributeError here.
    extra = getattr(ctx, "extra", None)
    if adapter_options and isinstance(extra, dict):
        extra.update(adapter_options)

    try:
        outputs = sess.run(None, feed)  # None -> fetch all model outputs
    except Exception as e:
        meta["error"] = f"ONNX inference failed: {e}"
        return rgb, meta

    try:
        overlay, post_meta = adapter.postprocess(outputs, rgb, ctx, detector)
        meta.update(post_meta)
        if adapter_options:
            meta["adapter_options"] = adapter_options
    except Exception as e:
        meta["error"] = f"Postprocess failed: {e}"
        return rgb, meta
    return overlay, meta