| import cv2 |
| import mediapipe as mp |
| import numpy as np |
| import logging |
| import os |
|
|
| from mediapipe.tasks import python |
| from mediapipe.tasks.python import vision |
|
|
logger = logging.getLogger(__name__)

# Project root, taken as the grandparent directory of this source file.
_THIS_FILE = os.path.abspath(__file__)
BASE_DIR = os.path.dirname(os.path.dirname(_THIS_FILE))

# Pose model file names, keyed by the ``model_type`` values that
# PoseEstimator accepts; PoseEstimator resolves them under <BASE_DIR>/data/models.
# NOTE(review): "heavy" maps to "pose_landmarker.task" rather than a
# "*_heavy" name like its siblings — confirm this matches the file on disk.
MODEL_FILES = dict(
    lite="pose_landmarker_lite.task",
    full="pose_landmarker_full.task",
    heavy="pose_landmarker.task",
)

# Hand landmarker model, stored in the same directory as the pose models.
_MODELS_DIR = os.path.join(BASE_DIR, "data", "models")
HAND_MODEL_PATH = os.path.join(_MODELS_DIR, "hand_landmarker.task")
|
|
class PoseEstimator:
    """Pose and hand landmark detection built on the MediaPipe Tasks API.

    Wraps a PoseLandmarker and a HandLandmarker that share one running mode
    (IMAGE or VIDEO). Call ``close()``, or use the instance as a context
    manager, to release both landmarkers.
    """

    # Pairs of pose landmark indices joined by a line in draw_landmarks:
    # arms, the shoulder/hip torso box, and legs (index layout of the
    # 33-point MediaPipe pose landmark model).
    POSE_CONNECTIONS = (
        (11, 13), (13, 15), (12, 14), (14, 16),
        (11, 12), (23, 24), (11, 23), (12, 24),
        (23, 25), (25, 27), (24, 26), (26, 28),
    )

    def __init__(self, static_image_mode=False, model_type="full", resize_width=None,
                 frame_interval_ms=33):
        """Create the pose and hand landmarkers.

        Args:
            static_image_mode: If True, both landmarkers run in IMAGE mode and
                each frame is processed independently; otherwise they run in
                VIDEO mode and every frame is given a timestamp.
            model_type: "lite", "full", or "heavy" (keys of MODEL_FILES). Any
                other value logs a warning and falls back to
                "pose_landmarker_full.task".
            resize_width: If set, frames wider than this are downscaled to this
                width (keeping the aspect ratio) before processing.
            frame_interval_ms: Amount, in milliseconds, by which the internal
                VIDEO-mode timestamp advances after each frame. The default of
                33 corresponds to roughly 30 fps.

        Raises:
            ValueError: If ``frame_interval_ms`` is not positive.
        """
        if frame_interval_ms <= 0:
            raise ValueError("frame_interval_ms must be positive")

        if model_type not in MODEL_FILES:
            logger.warning(
                "Unknown model_type %r; falling back to the 'full' pose model.",
                model_type,
            )
        model_name = MODEL_FILES.get(model_type, "pose_landmarker_full.task")
        pose_model_path = os.path.join(BASE_DIR, "data", "models", model_name)

        running_mode = (
            vision.RunningMode.IMAGE if static_image_mode else vision.RunningMode.VIDEO
        )

        pose_options = vision.PoseLandmarkerOptions(
            base_options=python.BaseOptions(model_asset_path=pose_model_path),
            running_mode=running_mode,
            output_segmentation_masks=False,
        )
        self.pose_landmarker = vision.PoseLandmarker.create_from_options(pose_options)

        try:
            hand_options = vision.HandLandmarkerOptions(
                base_options=python.BaseOptions(model_asset_path=HAND_MODEL_PATH),
                running_mode=running_mode,
                num_hands=2,
            )
            self.hand_landmarker = vision.HandLandmarker.create_from_options(hand_options)
        except Exception:
            # Don't leak the pose landmarker if the hand model fails to load.
            self.pose_landmarker.close()
            raise

        # Next timestamp (ms) passed to detect_for_video in VIDEO mode.
        self.timestamp = 0
        self.static_image_mode = static_image_mode
        self.resize_width = resize_width
        self.frame_interval_ms = frame_interval_ms

    def process_frame(self, frame, timestamp_ms=None):
        """Run pose and hand landmark detection on a single frame.

        Args:
            frame: Image array in BGR channel order. It is converted with
                cv2.COLOR_BGR2RGB before being passed to MediaPipe.
            timestamp_ms: Optional VIDEO-mode timestamp in milliseconds. When
                omitted, the internal counter is used. MediaPipe's VIDEO mode
                requires timestamps to increase monotonically. The argument is
                ignored in static-image mode.

        Returns:
            A dict with keys "pose" (PoseLandmarker result) and "hands"
            (HandLandmarker result).
        """
        if self.resize_width and frame.shape[1] > self.resize_width:
            height, width = frame.shape[:2]
            aspect_ratio = height / width
            # Clamp to 1 so an extremely wide frame never yields a zero height.
            target_height = max(1, int(self.resize_width * aspect_ratio))
            frame = cv2.resize(frame, (self.resize_width, target_height))

        image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_rgb)

        if self.static_image_mode:
            pose_result = self.pose_landmarker.detect(mp_image)
            hand_result = self.hand_landmarker.detect(mp_image)
        else:
            ts = self.timestamp if timestamp_ms is None else int(timestamp_ms)
            pose_result = self.pose_landmarker.detect_for_video(mp_image, ts)
            hand_result = self.hand_landmarker.detect_for_video(mp_image, ts)
            # Advance past the timestamp just used, so the next implicit call
            # stays strictly increasing even after an explicit timestamp.
            self.timestamp = ts + self.frame_interval_ms

        return {"pose": pose_result, "hands": hand_result}

    def extract_landmarks(self, results):
        """Convert process_frame output into plain nested lists.

        Maps the Tasks API output to the legacy-compatible format.

        Args:
            results: Dict returned by process_frame.

        Returns:
            A dict with:
              - "pose": a list of [x, y, z, visibility] for the first detected
                pose, or None.
              - "left_hand" / "right_hand": a list of [x, y, z] per hand
                landmark, or None.
            If several hands share a handedness label, the one with the
            highest handedness score is kept.
        """
        data = {
            "pose": None,
            "left_hand": None,
            "right_hand": None,
        }

        pose_res = results["pose"]
        if pose_res.pose_landmarks:
            # Only the first detected pose is reported.
            data["pose"] = [
                [lm.x, lm.y, lm.z, lm.visibility] for lm in pose_res.pose_landmarks[0]
            ]

        hand_res = results["hands"]
        best_scores = {}
        # NOTE(review): MediaPipe documents handedness as assuming a mirrored
        # (selfie) image — confirm this labelling matches the capture setup.
        for hand_lms, categories in zip(hand_res.hand_landmarks or [],
                                        hand_res.handedness or []):
            top = categories[0]
            key = "left_hand" if top.category_name == "Left" else "right_hand"
            # If the same side was already filled, keep the more confident hand.
            if key in best_scores and top.score < best_scores[key]:
                continue
            best_scores[key] = top.score
            data[key] = [[lm.x, lm.y, lm.z] for lm in hand_lms]

        return data

    def draw_landmarks(self, frame, results):
        """Return a copy of ``frame`` with the pose skeleton drawn on it.

        Custom drawing used because mp.solutions.drawing_utils is unavailable.
        Only the pose connections in POSE_CONNECTIONS are drawn (green, 2 px);
        hand landmarks are not rendered. ``frame`` itself is not modified.
        """
        annotated_frame = frame.copy()
        h, w = frame.shape[:2]

        pose = self.extract_landmarks(results)["pose"]
        if pose:
            for start_idx, end_idx in self.POSE_CONNECTIONS:
                # Landmarks are in normalized image coordinates; scale to pixels.
                p1 = (int(pose[start_idx][0] * w), int(pose[start_idx][1] * h))
                p2 = (int(pose[end_idx][0] * w), int(pose[end_idx][1] * h))
                cv2.line(annotated_frame, p1, p2, (0, 255, 0), 2)

        return annotated_frame

    def close(self):
        """Release both landmarkers; the hand landmarker is closed even if the pose one fails."""
        try:
            self.pose_landmarker.close()
        finally:
            self.hand_landmarker.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always release the landmarkers; exceptions are not suppressed.
        self.close()
|
|