File size: 5,464 Bytes
d2885a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
from pathlib import Path

from src.entity.config_entity import MMPoseConfig
from src.utils.common import resolve_device
from src.utils.logger import get_logger


class MMPoseLoader:
    """
    Load and manage MMPose runtime models.

    This class wraps:
    - person detector
    - pose estimator
    - visualizer (always built together with the pose estimator)

    It keeps loading logic separate from inference logic. Nothing is loaded
    at construction time: call ``load()`` explicitly, or use
    ``get_runtime_bundle()``, which loads on demand.
    """

    def __init__(
        self, config: MMPoseConfig, log_dir: Path | None = None, log_level: str = "INFO"
    ) -> None:
        """
        Store the configuration and resolve the target device.

        Args:
            config: Loader settings. This class reads ``device_preference``,
                ``detector.config_file`` / ``detector.checkpoint_file``,
                ``pose_estimator.config_file`` /
                ``pose_estimator.checkpoint_file`` and the ``visualizer``
                options (``draw_heatmap``, ``radius``, ``alpha``,
                ``thickness``, ``skeleton_style``).
            log_dir: Forwarded to ``get_logger`` as ``log_dir``.
            log_level: Forwarded to ``get_logger`` as ``level``.
        """
        self.config = config
        self.logger = get_logger(
            self.__class__.__name__, log_dir=log_dir, level=log_level
        )

        self.device = resolve_device(self.config.device_preference)

        # Populated by load(); None means "not loaded yet".
        self._bbox_detector = None
        self._pose_estimator = None
        self._visualizer = None

    def _validate_files(self) -> None:
        """
        Make sure config and checkpoint files exist before loading.

        Every path must point to an existing regular file. A directory is
        rejected as well, so a misconfigured path fails here with a clear
        message instead of later inside the model-loading libraries.

        Raises:
            FileNotFoundError: For the first path that is not an existing
                file, checked in this order: detector config, detector
                checkpoint, pose config, pose checkpoint.
        """
        detector = self.config.detector
        pose = self.config.pose_estimator

        # (label used in the error message, path to check), in check order.
        required_files = (
            ("Person detector config file", Path(detector.config_file)),
            ("Person detector checkpoint", Path(detector.checkpoint_file)),
            ("Pose estimator config file", Path(pose.config_file)),
            ("Pose estimator checkpoint", Path(pose.checkpoint_file)),
        )

        for label, path in required_files:
            if not path.is_file():
                raise FileNotFoundError(f"{label} not found: {path}")

    def load(self) -> tuple:
        """
        Load MMPose models and visualizer.

        Important:
        Imports are done inside this method so the project does not fail
        immediately on import if mmpose/mmdet are not installed yet.

        Every call initialises the models again; the instance attributes are
        only replaced once all three objects were created successfully.

        Returns:
            The tuple ``(bbox_detector, pose_estimator, visualizer)``.

        Raises:
            FileNotFoundError: If a config or checkpoint file is missing.
            RuntimeError: If mmpose/mmdet (or their dependencies) cannot be
                imported.
        """
        self._validate_files()

        try:
            from mmdet.apis import init_detector
            from mmpose.apis import init_model as init_pose_estimator
            from mmpose.registry import VISUALIZERS
            from mmpose.utils import register_all_modules, adapt_mmdet_pipeline
        except (ImportError, OSError) as exc:
            raise RuntimeError(
                "Unable to import MMPose/MMDetection. Ensure the packages are installed "
                "and the environment can access their dependencies."
            ) from exc

        register_all_modules()

        detector_settings = self.config.detector
        pose_settings = self.config.pose_estimator
        visualizer_settings = self.config.visualizer

        self.logger.info("Loading MMPose detector on device: %s", self.device)
        # Config and checkpoint paths are both passed as plain strings.
        bbox_detector = init_detector(
            str(detector_settings.config_file),
            str(detector_settings.checkpoint_file),
            device=self.device,
        )
        # Let MMPose adapt the detector's pipeline config
        # (see mmpose.utils.adapt_mmdet_pipeline).
        bbox_detector.cfg = adapt_mmdet_pipeline(bbox_detector.cfg)

        self.logger.info("Loading MMPose pose estimator on device: %s", self.device)
        pose_estimator = init_pose_estimator(
            str(pose_settings.config_file),
            str(pose_settings.checkpoint_file),
            device=self.device,
            # Only request heatmap output when the visualizer is set to draw it.
            cfg_options=dict(
                model=dict(
                    test_cfg=dict(output_heatmaps=visualizer_settings.draw_heatmap)
                )
            ),
        )

        # Apply visualizer settings from config
        pose_estimator.cfg.visualizer.radius = visualizer_settings.radius
        pose_estimator.cfg.visualizer.alpha = visualizer_settings.alpha
        pose_estimator.cfg.visualizer.line_width = visualizer_settings.thickness

        visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
        visualizer.set_dataset_meta(
            pose_estimator.dataset_meta,
            skeleton_style=visualizer_settings.skeleton_style,
        )

        # Publish the results only after everything above succeeded, so a
        # failed load never leaves the instance half-initialised.
        self._bbox_detector = bbox_detector
        self._pose_estimator = pose_estimator
        self._visualizer = visualizer

        self.logger.info("MMPose models loaded successfully.")
        return self._bbox_detector, self._pose_estimator, self._visualizer

    @property
    def bbox_detector(self):
        """Return the loaded detector; raise RuntimeError before load()."""
        if self._bbox_detector is None:
            raise RuntimeError("BBox detector is not loaded yet. Call load() first.")
        return self._bbox_detector

    @property
    def pose_estimator(self):
        """Return the loaded pose estimator; raise RuntimeError before load()."""
        if self._pose_estimator is None:
            raise RuntimeError("Pose estimator is not loaded yet. Call load() first.")
        return self._pose_estimator

    @property
    def visualizer(self):
        """Return the built visualizer; raise RuntimeError before load()."""
        if self._visualizer is None:
            raise RuntimeError("Visualizer is not loaded yet. Call load() first.")
        return self._visualizer

    def get_runtime_bundle(self) -> dict:
        """
        Return a clean dictionary bundle for downstream pipeline code.

        Calls ``load()`` first when any of the three objects is still
        missing, so this may raise whatever ``load()`` raises.

        Returns:
            A dict with the keys ``bbox_detector_model``,
            ``pose_estimator_model``, ``visualizer`` and ``device``.
        """
        if (
            self._bbox_detector is None
            or self._pose_estimator is None
            or self._visualizer is None
        ):
            self.load()

        return {
            "bbox_detector_model": self._bbox_detector,
            "pose_estimator_model": self._pose_estimator,
            "visualizer": self._visualizer,
            "device": self.device,
        }