# Page residue: Spaces: Running | File size: 5,464 Bytes | d2885a7
from pathlib import Path
from src.entity.config_entity import MMPoseConfig
from src.utils.common import resolve_device
from src.utils.logger import get_logger
class MMPoseLoader:
"""
Load and manage MMPose runtime models.
This class wraps:
- person detector
- pose estimator
- optional visualizer
It keeps loading logic separate from inference logic.
"""
def __init__(
self, config: MMPoseConfig, log_dir: Path | None = None, log_level: str = "INFO"
) -> None:
self.config = config
self.logger = get_logger(
self.__class__.__name__, log_dir=log_dir, level=log_level
)
self.device = resolve_device(self.config.device_preference)
self._bbox_detector = None
self._pose_estimator = None
self._visualizer = None
def _validate_files(self) -> None:
"""
Make sure config and checkpoint files exist before loading.
"""
detector_config = Path(self.config.detector.config_file)
detector_checkpoint = Path(self.config.detector.checkpoint_file)
pose_config = Path(self.config.pose_estimator.config_file)
pose_checkpoint = Path(self.config.pose_estimator.checkpoint_file)
if not detector_config.exists():
raise FileNotFoundError(
f"Person detector config file not found: {detector_config}"
)
if not detector_checkpoint.exists():
raise FileNotFoundError(
f"Person detector checkpoint not found: {detector_checkpoint}"
)
if not pose_config.exists():
raise FileNotFoundError(
f"Pose estimator config file not found: {pose_config}"
)
if not pose_checkpoint.exists():
raise FileNotFoundError(
f"Pose estimator checkpoint not found: {pose_checkpoint}"
)
def load(self):
"""
Load MMPose models and visualizer.
Important:
Imports are done inside this method so the project does not fail
immediately on import if mmpose/mmdet are not installed yet.
"""
self._validate_files()
try:
from mmdet.apis import init_detector
from mmpose.apis import init_model as init_pose_estimator
from mmpose.registry import VISUALIZERS
from mmpose.utils import register_all_modules, adapt_mmdet_pipeline
except (ImportError, OSError) as exc:
raise RuntimeError(
"Unable to import MMPose/MMDetection. Ensure the packages are installed "
"and the environment can access their dependencies."
) from exc
register_all_modules()
self.logger.info("Loading MMPose detector on device: %s", self.device)
bbox_detector = init_detector(
self.config.detector.config_file,
str(self.config.detector.checkpoint_file),
device=self.device,
)
bbox_detector.cfg = adapt_mmdet_pipeline(bbox_detector.cfg)
self.logger.info("Loading MMPose pose estimator on device: %s", self.device)
pose_estimator = init_pose_estimator(
self.config.pose_estimator.config_file,
str(self.config.pose_estimator.checkpoint_file),
device=self.device,
cfg_options=dict(
model=dict(
test_cfg=dict(output_heatmaps=self.config.visualizer.draw_heatmap)
)
),
)
# Apply visualizer settings from config
pose_estimator.cfg.visualizer.radius = self.config.visualizer.radius
pose_estimator.cfg.visualizer.alpha = self.config.visualizer.alpha
pose_estimator.cfg.visualizer.line_width = self.config.visualizer.thickness
visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
visualizer.set_dataset_meta(
pose_estimator.dataset_meta,
skeleton_style=self.config.visualizer.skeleton_style,
)
self._bbox_detector = bbox_detector
self._pose_estimator = pose_estimator
self._visualizer = visualizer
self.logger.info("MMPose models loaded successfully.")
return self._bbox_detector, self._pose_estimator, self._visualizer
@property
def bbox_detector(self):
if self._bbox_detector is None:
raise RuntimeError("BBox detector is not loaded yet. Call load() first.")
return self._bbox_detector
@property
def pose_estimator(self):
if self._pose_estimator is None:
raise RuntimeError("Pose estimator is not loaded yet. Call load() first.")
return self._pose_estimator
@property
def visualizer(self):
if self._visualizer is None:
raise RuntimeError("Visualizer is not loaded yet. Call load() first.")
return self._visualizer
def get_runtime_bundle(self) -> dict:
"""
Return a clean dictionary bundle for downstream pipeline code.
"""
if (
self._bbox_detector is None
or self._pose_estimator is None
or self._visualizer is None
):
self.load()
return {
"bbox_detector_model": self._bbox_detector,
"pose_estimator_model": self._pose_estimator,
"visualizer": self._visualizer,
"device": self.device,
}
|