"""Registry and factory helpers for selecting an object detector by name."""

import os
from functools import lru_cache
from typing import Callable, Dict, Optional

from models.detectors.base import ObjectDetector
from models.detectors.detr import DetrDetector
from models.detectors.grounding_dino import GroundingDinoDetector
from models.detectors.yolov11 import Yolo11Detector
from models.detectors.yolov8_visdrone import YoloV8VisDroneDetector

# Detector used when neither an explicit name nor OBJECT_DETECTOR is provided.
DEFAULT_DETECTOR = "yolo11"

# Maps a detector name to the callable that builds it. Factories are invoked
# with optional keyword arguments (e.g. ``device=...``), hence ``Callable[...]``.
_REGISTRY: Dict[str, Callable[..., ObjectDetector]] = {
    "yolo11": Yolo11Detector,
    "detr_resnet50": DetrDetector,
    "grounding_dino": GroundingDinoDetector,
    "yolov8_visdrone": YoloV8VisDroneDetector,
}


def _create_detector(name: str, **kwargs) -> ObjectDetector:
    """Build a new detector registered under *name*.

    Args:
        name: Key into the detector registry.
        **kwargs: Forwarded unchanged to the registered factory.

    Returns:
        Whatever the registered factory returns for the given kwargs.

    Raises:
        ValueError: If *name* is not a registered detector. The message lists
            the available names in sorted order.
    """
    try:
        factory = _REGISTRY[name]
    except KeyError as exc:
        available = ", ".join(sorted(_REGISTRY))
        raise ValueError(
            f"Unknown detector '{name}'. Available: {available}"
        ) from exc
    return factory(**kwargs)


@lru_cache(maxsize=None)
def _get_cached_detector(name: str) -> ObjectDetector:
    """Return the detector for *name*, constructing it only on the first call.

    The cache is keyed by name. Unknown names raise and are therefore never
    cached, so the cache holds at most one entry per registered detector.
    """
    return _create_detector(name)


def load_detector(name: Optional[str] = None) -> ObjectDetector:
    """Return a cached detector instance selected via arg or OBJECT_DETECTOR env.

    Resolution order: the *name* argument, then the ``OBJECT_DETECTOR``
    environment variable, then ``DEFAULT_DETECTOR``. Empty values are skipped
    at each step.

    Raises:
        ValueError: If the resolved name is not a registered detector.
    """
    # Chain with ``or`` so that an OBJECT_DETECTOR that is set but empty falls
    # through to the default instead of being looked up as detector ''.
    detector_name = name or os.getenv("OBJECT_DETECTOR") or DEFAULT_DETECTOR
    return _get_cached_detector(detector_name)


def prefetch_weights(name: str) -> None:
    """Pre-download model weights (call before parallel GPU init).

    Calls the registered factory's ``ensure_weights`` attribute when it has
    one. Unknown names and factories without ``ensure_weights`` are silently
    ignored, so this is a best-effort operation.
    """
    factory = _REGISTRY.get(name)
    if factory and hasattr(factory, "ensure_weights"):
        factory.ensure_weights()


def load_detector_on_device(name: str, device: str) -> ObjectDetector:
    """Create a new detector instance on the specified device (no caching).

    *device* is passed to the registered factory as the ``device`` keyword.

    Raises:
        ValueError: If *name* is not a registered detector.
    """
    return _create_detector(name, device=device)


# Backwards compatibility for existing callers.
def load_model() -> ObjectDetector:
    """Return the detector chosen by ``load_detector()`` with no explicit name."""
    return load_detector()