Spaces:
Sleeping
Sleeping
| import os | |
| import time | |
| from dataclasses import dataclass, field | |
| from typing import Any, Dict, Literal, Optional | |
| import cv2 | |
| import numpy as np | |
| from ..inference.classical import detect_classical | |
| from ..inference.dl import DL_MODELS, detect_dl | |
# Baseline value for every tunable parameter. ``merge_params`` only accepts
# overrides for keys that appear here, so this mapping is also the allow-list.
DEFAULT_PARAMS: Dict[str, Any] = dict(
    # The canny_* / harris_* / hough_* / ellipse_* / max_ellipses /
    # line_detector values are passed positionally to ``detect_classical``
    # by ``run_detection``.
    canny_low=50,
    canny_high=150,
    harris_k=0.05,
    harris_block=2,
    harris_ksize=3,
    hough_thresh=50,
    hough_min_len=30,
    hough_max_gap=5,
    ellipse_min_area=300,
    max_ellipses=5,
    line_detector="hough",
    # The dexined_* values are not passed to ``detect_classical``; they reach
    # ``detect_dl`` only through the merged ``params`` dict.
    dexined_threshold_mode="adaptive",
    dexined_threshold_sigma=1.0,
    dexined_threshold_offset=0.0,
    dexined_threshold_value=0.3,
    dexined_use_marching_squares=False,
)
| def _to_bool(value: Any) -> bool: | |
| if isinstance(value, bool): | |
| return value | |
| if isinstance(value, str): | |
| return value.strip().lower() in {"1", "true", "yes", "on"} | |
| return bool(value) | |
def _lowercase_str(value: Any) -> str:
    """Coerce *value* to a lower-cased string."""
    return str(value).lower()


# Coercion applied by ``merge_params`` to each user-supplied override. If a
# caster raises ``TypeError``/``ValueError`` the override is dropped and the
# default from ``DEFAULT_PARAMS`` is kept.
PARAM_TYPES: Dict[str, Any] = dict(
    canny_low=int,
    canny_high=int,
    harris_k=float,
    harris_block=int,
    harris_ksize=int,
    hough_thresh=int,
    hough_min_len=int,
    hough_max_gap=int,
    ellipse_min_area=int,
    max_ellipses=int,
    line_detector=_lowercase_str,
    dexined_threshold_mode=_lowercase_str,
    dexined_threshold_sigma=float,
    dexined_threshold_offset=float,
    dexined_threshold_value=float,
    dexined_use_marching_squares=_to_bool,
)
# Provenance of the OpenCV-based pipeline; run_detection reports it under
# DetectionResult.models["classical"].
CLASSICAL_MODEL_INFO = dict(name="opencv-classical", version=cv2.__version__)

# onnxruntime is optional. Any failure while importing it, not only
# ImportError, leaves ``ort`` as None so this module still loads.
try:
    import onnxruntime as ort  # type: ignore
except Exception:  # pragma: no cover
    ort = None  # type: ignore

# Runtime provenance for the DL pipeline. Placeholder strings are used when
# onnxruntime is unavailable or does not expose a version.
DL_MODEL_INFO = dict(
    name="onnxruntime-missing" if ort is None else "onnxruntime",
    version=getattr(ort, "__version__", "unknown"),
)
def merge_params(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Overlay user overrides on ``DEFAULT_PARAMS`` and return a new dict.

    Overrides are ignored, and the default is kept, when any of these hold:
    the key is not in ``DEFAULT_PARAMS``, the value is ``None``, or the
    caster from ``PARAM_TYPES`` raises ``TypeError``/``ValueError``.
    Keys without a caster are copied as-is. ``DEFAULT_PARAMS`` itself is
    never mutated.
    """
    result = dict(DEFAULT_PARAMS)
    for name, raw in (params or {}).items():
        if raw is None or name not in DEFAULT_PARAMS:
            continue
        convert = PARAM_TYPES.get(name, lambda v: v)
        try:
            result[name] = convert(raw)
        except (TypeError, ValueError):
            # Best-effort: an uncoercible override silently keeps the default.
            pass
    return result
@dataclass
class DetectionResult:
    """Bundle of outputs produced by one ``run_detection`` call.

    ``run_detection`` keys each mapping by pipeline name (``"classical"``
    and/or ``"dl"``); ``timings_ms`` additionally carries a ``"total"`` entry.
    The ``@dataclass`` decorator is required. Without it the class has no
    keyword ``__init__``, and the ``field(...)`` defaults would remain raw
    ``Field`` objects shared at class level.
    """

    # Overlay image returned by each pipeline's detector.
    overlays: Dict[str, np.ndarray] = field(default_factory=dict)
    # Metadata dict returned by each pipeline's detector.
    features: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    # Wall-clock duration per pipeline in milliseconds, rounded to 2 decimals.
    timings_ms: Dict[str, float] = field(default_factory=dict)
    # 1000 / total_ms, or None when no time was measured.
    fps_estimate: Optional[float] = None
    # {"name": ..., "version": ...} describing the model/runtime per pipeline.
    models: Dict[str, Dict[str, Any]] = field(default_factory=dict)
def _elapsed_ms(start: float) -> float:
    """Return milliseconds elapsed since *start*, a ``time.perf_counter()`` reading."""
    return (time.perf_counter() - start) * 1000.0


def run_detection(
    image: np.ndarray,
    detector: str,
    params: Optional[Dict[str, Any]] = None,
    mode: Literal["classical", "dl", "both"] = "classical",
    dl_choice: Optional[str] = None,
) -> DetectionResult:
    """Run the classical and/or deep-learning detection pipeline on *image*.

    Args:
        image: Input image, forwarded unchanged to the detector backends.
        detector: Detector name, forwarded to ``detect_classical`` and
            ``detect_dl``.
        params: Optional overrides for ``DEFAULT_PARAMS``. They are sanitised
            by ``merge_params``: unknown keys, ``None`` values and values that
            fail type coercion are ignored.
        mode: Which pipelines to run: ``"classical"``, ``"dl"`` or ``"both"``.
        dl_choice: Forwarded to ``detect_dl`` (presumably selects the DL model;
            confirm against ``DL_MODELS``).

    Returns:
        A ``DetectionResult`` whose ``overlays``, ``features``, ``timings_ms``
        and ``models`` are keyed by ``"classical"`` and/or ``"dl"``.
        ``timings_ms["total"]`` is the summed stage time in ms.
        ``fps_estimate`` is ``1000 / total_ms``, or ``None`` if no time elapsed.

    Raises:
        ValueError: If *mode* is not one of the supported values. Previously an
            unknown mode silently ran nothing and returned an empty result.
    """
    if mode not in ("classical", "dl", "both"):
        raise ValueError(f"mode must be 'classical', 'dl' or 'both', got {mode!r}")

    merged = merge_params(params)
    overlays: Dict[str, np.ndarray] = {}
    features: Dict[str, Dict[str, Any]] = {}
    timings: Dict[str, float] = {}
    models: Dict[str, Dict[str, Any]] = {}
    total_ms = 0.0

    if mode in ("classical", "both"):
        t0 = time.perf_counter()
        classical_img, classical_meta = detect_classical(
            image,
            detector,
            merged["canny_low"],
            merged["canny_high"],
            merged["harris_k"],
            merged["harris_block"],
            merged["harris_ksize"],
            merged["hough_thresh"],
            merged["hough_min_len"],
            merged["hough_max_gap"],
            merged["ellipse_min_area"],
            merged["max_ellipses"],
            merged["line_detector"],
        )
        t_ms = _elapsed_ms(t0)
        overlays["classical"] = classical_img
        features["classical"] = classical_meta
        timings["classical"] = round(t_ms, 2)
        # Copy so a caller mutating the result cannot alter the shared
        # module-level CLASSICAL_MODEL_INFO (the "dl" entry is already fresh).
        models["classical"] = dict(CLASSICAL_MODEL_INFO)
        total_ms += t_ms

    if mode in ("dl", "both"):
        t0 = time.perf_counter()
        dl_img, dl_meta = detect_dl(image, detector, dl_choice, params=merged)
        t_ms = _elapsed_ms(t0)
        overlays["dl"] = dl_img
        features["dl"] = dl_meta
        timings["dl"] = round(t_ms, 2)
        # Prefer the model file name reported by the backend. Fall back to the
        # runtime name when no path is reported or the path is None/empty,
        # which would otherwise crash basename() or produce an empty name.
        model_path = dl_meta.get("model_path")
        model_name = os.path.basename(model_path) if model_path else DL_MODEL_INFO["name"]
        models["dl"] = {"name": model_name, "version": DL_MODEL_INFO["version"]}
        total_ms += t_ms

    # The total is computed from unrounded stage times; only the reported
    # values are rounded.
    timings["total"] = round(total_ms, 2)
    fps = round(1000.0 / total_ms, 2) if total_ms > 0 else None
    return DetectionResult(
        overlays=overlays,
        features=features,
        timings_ms=timings,
        fps_estimate=fps,
        models=models,
    )
# Names exported by ``from <this module> import *``.
__all__ = [
    "DetectionResult", "DEFAULT_PARAMS", "DL_MODELS", "merge_params", "run_detection",
]