"""
This file loads the YOLO phone detection model for runtime inference.
It supports both your trained phone detector and the fallback pretrained YOLO model.
The class keeps model-loading logic separate from prediction logic.
This makes the project cleaner, easier to debug, and easier to track in MLOps pipelines.
"""
from pathlib import Path
from src.entity.config_entity import PhoneDetectorConfig
from src.utils.logger import get_logger
class PhoneModelLoader:
"""
Load and manage YOLO phone detection models.
Supported modes:
- trained custom phone detector
- fallback pretrained detector
"""
def __init__(
self,
config: PhoneDetectorConfig,
log_dir: Path | None = None,
log_level: str = "INFO",
) -> None:
self.config = config
self.logger = get_logger(
self.__class__.__name__, log_dir=log_dir, level=log_level
)
self._trained_model = None
self._fallback_model = None
def _import_yolo(self):
"""
Import YOLO only when needed.
"""
try:
from ultralytics import YOLO
return YOLO
except (ImportError, OSError) as exc:
raise RuntimeError(
"Unable to import Ultralytics YOLO. Ensure the package is installed "
"and the environment has permission to access Ultralytics settings."
) from exc
def load_trained_model(self):
"""
Load your trained phone detector model.
"""
model_path = Path(self.config.weights.default_weight_file)
if not model_path.exists():
raise FileNotFoundError(f"Trained phone detector not found: {model_path}")
YOLO = self._import_yolo()
self._trained_model = YOLO(str(model_path))
self.logger.info("Loaded trained phone detector from: %s", model_path)
return self._trained_model
def load_fallback_model(self):
"""
Load the fallback pretrained YOLO model.
"""
model_path = Path(self.config.weights.fallback_weight_file)
if not model_path.exists():
raise FileNotFoundError(f"Fallback phone detector not found: {model_path}")
YOLO = self._import_yolo()
self._fallback_model = YOLO(str(model_path))
self.logger.info("Loaded fallback phone detector from: %s", model_path)
return self._fallback_model
def get_model(self, use_trained: bool = True):
"""
Return the requested model and load it if needed.
"""
if use_trained:
if self._trained_model is None:
self.load_trained_model()
return self._trained_model
if self._fallback_model is None:
self.load_fallback_model()
return self._fallback_model
|