Spaces:
Sleeping
Sleeping
| """ | |
| YOLOv11-Pose wrapper class | |
| A wrapper around the YOLOv11-Pose model for real-time pose estimation. | |
| """ | |
| import logging | |
| from typing import Optional | |
| import numpy as np | |
| import torch | |
| from ultralytics import YOLO | |
| class PoseEstimator: | |
| """YOLOv11-Pose ๊ธฐ๋ฐ ํฌ์ฆ ์ถ์ ๊ธฐ""" | |
| def __init__( | |
| self, | |
| model_path: str = "yolo11m-pose.pt", | |
| conf_threshold: float = 0.5, | |
| imgsz: int = 640, | |
| device: str = "cuda:0", | |
| logger: Optional[logging.Logger] = None | |
| ): | |
| """ | |
| Args: | |
| model_path: YOLOv11-Pose ๋ชจ๋ธ ๊ฒฝ๋ก | |
| conf_threshold: ๊ฐ์ง ์ ๋ขฐ๋ ์๊ณ๊ฐ | |
| imgsz: ์ ๋ ฅ ์ด๋ฏธ์ง ํฌ๊ธฐ | |
| device: ๋๋ฐ์ด์ค (cuda:0, cpu ๋ฑ) | |
| logger: ๋ก๊ฑฐ ์ธ์คํด์ค | |
| """ | |
| self.device = torch.device(device if torch.cuda.is_available() else "cpu") | |
| self.conf_threshold = conf_threshold | |
| self.imgsz = imgsz | |
| self.logger = logger or logging.getLogger(__name__) | |
| # ๋ชจ๋ธ ๋ก๋ | |
| self.logger.info(f"[Stage 1] YOLOv11-Pose ๋ก๋ ์ค: {model_path}") | |
| self.model = YOLO(model_path) | |
| self.model.to(self.device) | |
| self.logger.info(f" - Confidence threshold: {conf_threshold}") | |
| self.logger.info(f" - Image size: {imgsz}") | |
| self.logger.info(f" - Device: {self.device}") | |
| def extract(self, frame: np.ndarray, debug: bool = False) -> Optional[np.ndarray]: | |
| """ | |
| ํ๋ ์์์ pose keypoints ์ถ์ถ | |
| Args: | |
| frame: OpenCV ์ด๋ฏธ์ง (H, W, 3) | |
| debug: ๋๋ฒ๊ทธ ๋ก๊ทธ ์ถ๋ ฅ ์ฌ๋ถ | |
| Returns: | |
| keypoints: (17, 3) numpy array ๋๋ None (์ฌ๋์ด ๊ฐ์ง๋์ง ์์ ๊ฒฝ์ฐ) | |
| ๊ฐ keypoint๋ (x, y, confidence) ํํ | |
| """ | |
| results = self.model.predict( | |
| frame, | |
| imgsz=self.imgsz, | |
| conf=self.conf_threshold, | |
| verbose=False | |
| ) | |
| if results and len(results) > 0 and results[0].keypoints is not None: | |
| keypoints_data = results[0].keypoints.data.cpu().numpy() | |
| if len(keypoints_data) > 0: | |
| # ๊ฐ์ฅ ์ ๋ขฐ๋ ๋์ ์ฌ๋ ์ ํ | |
| if results[0].boxes is not None: | |
| confidences = results[0].boxes.conf.cpu().numpy() | |
| best_idx = np.argmax(confidences) | |
| keypoints = keypoints_data[best_idx] # (17, 3) | |
| else: | |
| keypoints = keypoints_data[0] | |
| if debug: | |
| avg_conf = keypoints[:, 2].mean() | |
| self.logger.debug(f" Pose detected: avg_conf={avg_conf:.3f}") | |
| return keypoints | |
| if debug: | |
| self.logger.debug(" No pose detected") | |
| return None | |
| def extract_batch( | |
| self, frames: list[np.ndarray] | np.ndarray, debug: bool = False | |
| ) -> list[Optional[np.ndarray]]: | |
| """ | |
| ์ฌ๋ฌ ํ๋ ์์์ ๋ฐฐ์น๋ก pose keypoints ์ถ์ถ (GPU ํ์ฉ ๊ทน๋ํ) | |
| Args: | |
| frames: OpenCV ์ด๋ฏธ์ง ๋ฆฌ์คํธ [(H, W, 3), ...] ๋๋ numpy ๋ฐฐ์ด (N, H, W, C) | |
| debug: ๋๋ฒ๊ทธ ๋ก๊ทธ ์ถ๋ ฅ ์ฌ๋ถ | |
| Returns: | |
| keypoints_list: [(17, 3) numpy array or None, ...] ๊ฐ ํ๋ ์๋ณ keypoints | |
| """ | |
| # ๋น ์ ๋ ฅ ์ฒดํฌ (๋ฆฌ์คํธ์ numpy ๋ฐฐ์ด ๋ชจ๋ ์ง์) | |
| if isinstance(frames, np.ndarray): | |
| if frames.size == 0: | |
| return [] | |
| # numpy ๋ฐฐ์ด์ ๋ฆฌ์คํธ๋ก ๋ณํ | |
| frames = list(frames) | |
| elif not frames: | |
| return [] | |
| # YOLO ๋ฐฐ์น ์ถ๋ก | |
| results = self.model.predict( | |
| frames, | |
| imgsz=self.imgsz, | |
| conf=self.conf_threshold, | |
| verbose=False | |
| ) | |
| keypoints_list = [] | |
| for i, result in enumerate(results): | |
| if result.keypoints is not None: | |
| keypoints_data = result.keypoints.data.cpu().numpy() | |
| if len(keypoints_data) > 0: | |
| # ๊ฐ์ฅ ์ ๋ขฐ๋ ๋์ ์ฌ๋ ์ ํ | |
| if result.boxes is not None: | |
| confidences = result.boxes.conf.cpu().numpy() | |
| best_idx = np.argmax(confidences) | |
| keypoints = keypoints_data[best_idx] # (17, 3) | |
| else: | |
| keypoints = keypoints_data[0] | |
| if debug: | |
| avg_conf = keypoints[:, 2].mean() | |
| self.logger.debug( | |
| f" Batch[{i}] Pose detected: avg_conf={avg_conf:.3f}" | |
| ) | |
| keypoints_list.append(keypoints) | |
| continue | |
| if debug: | |
| self.logger.debug(f" Batch[{i}] No pose detected") | |
| keypoints_list.append(None) | |
| return keypoints_list | |
| def get_empty_keypoints(self) -> np.ndarray: | |
| """๋น keypoints ๋ฐฐ์ด ๋ฐํ (์ฌ๋์ด ๊ฐ์ง๋์ง ์์ ๊ฒฝ์ฐ ์ฌ์ฉ)""" | |
| return np.zeros((17, 3), dtype=np.float32) | |