File size: 5,090 Bytes
c75ced2
 
 
 
 
 
 
 
 
 
 
 
 
1c58706
 
 
 
 
c75ced2
 
 
1c58706
 
 
 
 
 
 
 
c75ced2
1c58706
c75ced2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c58706
c75ced2
 
 
 
 
1c58706
 
 
 
 
 
c75ced2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import cv2
import mediapipe as mp
import numpy as np
import logging
import os

from mediapipe.tasks import python
from mediapipe.tasks.python import vision

logger = logging.getLogger(__name__)

# Filesystem layout: model files are looked up under <BASE_DIR>/data/models,
# where BASE_DIR is the directory one level above the one containing this file.
BASE_DIR = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)

# Pose model variant name -> model file name.
# NOTE(review): "heavy" maps to "pose_landmarker.task" rather than a
# "*_heavy.task" name like its siblings -- confirm this matches the file on disk.
MODEL_FILES = dict(
    lite="pose_landmarker_lite.task",
    full="pose_landmarker_full.task",
    heavy="pose_landmarker.task",
)

# Full path of the hand landmarker model file.
HAND_MODEL_PATH = os.path.join(
    BASE_DIR,
    "data",
    "models",
    "hand_landmarker.task",
)

class PoseEstimator:
    """Detect pose and hand landmarks on BGR frames with MediaPipe Tasks.

    The instance owns one ``PoseLandmarker`` and one ``HandLandmarker`` that
    share the same running mode. Both must be released with :meth:`close`;
    the instance can also be used as a context manager, which calls
    :meth:`close` on exit.
    """

    # Step (in ms) added to the running timestamp after each frame processed
    # in VIDEO mode (approx. 30 fps).
    _FRAME_INTERVAL_MS = 33

    # Pose landmark index pairs joined by a line in draw_landmarks.
    _POSE_CONNECTIONS = (
        (11, 13), (13, 15), (12, 14), (14, 16),  # Arms
        (11, 12), (23, 24), (11, 23), (12, 24),  # Torso
        (23, 25), (25, 27), (24, 26), (26, 28),  # Legs
    )

    def __init__(self, static_image_mode=False, model_type="full", resize_width=None):
        """Create the pose and hand landmarkers.

        Args:
            static_image_mode: If true, both landmarkers run in IMAGE mode;
                otherwise they run in VIDEO mode and receive a running
                timestamp on every call to :meth:`process_frame`.
            model_type: "lite", "full", or "heavy". Any other value falls
                back to the "full" model file and logs a warning.
            resize_width: If set, frames wider than this are resized to this
                width (keeping the aspect ratio) before processing.
        """
        model_name = MODEL_FILES.get(model_type)
        if model_name is None:
            # Keep the historical fallback, but make the substitution visible.
            logger.warning(
                "Unknown model_type %r; falling back to pose_landmarker_full.task",
                model_type,
            )
            model_name = "pose_landmarker_full.task"
        pose_model_path = os.path.join(BASE_DIR, "data", "models", model_name)

        if static_image_mode:
            running_mode = vision.RunningMode.IMAGE
        else:
            running_mode = vision.RunningMode.VIDEO

        # Initializing Pose Landmarker
        pose_options = vision.PoseLandmarkerOptions(
            base_options=python.BaseOptions(model_asset_path=pose_model_path),
            running_mode=running_mode,
            output_segmentation_masks=False,
        )
        self.pose_landmarker = vision.PoseLandmarker.create_from_options(pose_options)

        # Initializing Hand Landmarker
        try:
            hand_options = vision.HandLandmarkerOptions(
                base_options=python.BaseOptions(model_asset_path=HAND_MODEL_PATH),
                running_mode=running_mode,
                num_hands=2,
            )
            self.hand_landmarker = vision.HandLandmarker.create_from_options(hand_options)
        except Exception:
            # Do not leak the already-created pose landmarker when the hand
            # landmarker cannot be built (e.g. missing model file).
            self.pose_landmarker.close()
            raise

        self.timestamp = 0
        self.static_image_mode = static_image_mode
        self.resize_width = resize_width

    def __enter__(self):
        """Return the estimator itself so it can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close both landmarkers; exceptions from the block are not suppressed."""
        self.close()

    def process_frame(self, frame):
        """Process a single image frame.

        Args:
            frame: Image array passed to ``cv2.cvtColor`` with
                ``COLOR_BGR2RGB``, i.e. expected in BGR channel order.

        Returns:
            A dict with the raw landmarker results under the keys "pose"
            and "hands".
        """
        # Resize if requested (only ever downscales).
        if self.resize_width and frame.shape[1] > self.resize_width:
            aspect_ratio = frame.shape[0] / frame.shape[1]
            # Clamp to 1 px: cv2.resize rejects a zero-sized target.
            target_height = max(1, int(self.resize_width * aspect_ratio))
            frame = cv2.resize(frame, (self.resize_width, target_height))

        image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_rgb)

        if self.static_image_mode:
            pose_result = self.pose_landmarker.detect(mp_image)
            hand_result = self.hand_landmarker.detect(mp_image)
        else:
            # Both detectors receive the same timestamp for a given frame.
            pose_result = self.pose_landmarker.detect_for_video(mp_image, self.timestamp)
            hand_result = self.hand_landmarker.detect_for_video(mp_image, self.timestamp)
            self.timestamp += self._FRAME_INTERVAL_MS

        return {"pose": pose_result, "hands": hand_result}

    def extract_landmarks(self, results):
        """Convert raw landmarker results into plain nested lists.

        Args:
            results: Dict as returned by :meth:`process_frame`, with the
                keys "pose" and "hands".

        Returns:
            A dict with three keys:
              * "pose": list of ``[x, y, z, visibility]`` for the first
                detected person, or None when no pose was detected.
              * "left_hand" / "right_hand": list of ``[x, y, z]`` per hand
                landmark, or None when that hand was not detected. A hand
                whose top handedness category is "Left" goes to
                "left_hand"; every other label goes to "right_hand".
        """
        data = {
            "pose": None,
            "left_hand": None,
            "right_hand": None,
        }

        # Extract Pose
        pose_res = results["pose"]
        if pose_res.pose_landmarks:
            # We take the first person detected
            data["pose"] = [
                [lm.x, lm.y, lm.z, lm.visibility]
                for lm in pose_res.pose_landmarks[0]
            ]

        # Extract Hands
        hand_res = results["hands"]
        if hand_res.hand_landmarks:
            for idx, hand_lms in enumerate(hand_res.hand_landmarks):
                label = hand_res.handedness[idx][0].category_name  # "Left" or "Right"
                key = "left_hand" if label == "Left" else "right_hand"
                data[key] = [[lm.x, lm.y, lm.z] for lm in hand_lms]

        return data

    def draw_landmarks(self, frame, results):
        """Return a copy of ``frame`` with a simple pose skeleton drawn on it.

        Custom drawing since mp.solutions.drawing_utils is missing. Only the
        pose connections (arms, torso, legs) are drawn; hand landmarks are
        not rendered.

        Args:
            frame: Image array to annotate; it is copied, not modified.
            results: Dict as returned by :meth:`process_frame`.

        Returns:
            The annotated copy of ``frame``.
        """
        annotated_frame = frame.copy()
        # Only height and width are needed, so 2-D frames work as well.
        h, w = frame.shape[:2]

        pose = self.extract_landmarks(results)["pose"]
        if not pose:
            return annotated_frame

        def to_pixel(landmark):
            # Landmark x/y are scaled by the frame width/height to get pixels.
            return int(landmark[0] * w), int(landmark[1] * h)

        landmark_count = len(pose)
        for start_idx, end_idx in self._POSE_CONNECTIONS:
            # Skip connections that reference landmarks missing from the result.
            if start_idx >= landmark_count or end_idx >= landmark_count:
                continue
            cv2.line(
                annotated_frame,
                to_pixel(pose[start_idx]),
                to_pixel(pose[end_idx]),
                (0, 255, 0),
                2,
            )

        return annotated_frame

    def close(self):
        """Close both landmarkers.

        The hand landmarker is closed even if closing the pose landmarker
        raises; the exception is then propagated to the caller.
        """
        try:
            self.pose_landmarker.close()
        finally:
            self.hand_landmarker.close()