# AI-Coach / src/core/pose.py
# anhlehong
# feat/enhance
# 1c58706
import cv2
import mediapipe as mp
import numpy as np
import logging
import os
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
logger = logging.getLogger(__name__)

# Model paths.
# BASE_DIR is the parent of the directory containing this module; model
# bundles are resolved under BASE_DIR/data/models.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
BASE_DIR = os.path.dirname(_MODULE_DIR)

# Pose model file name for each supported ``model_type`` value.
# NOTE(review): unlike the other two, the "heavy" entry carries no ``_heavy``
# suffix -- confirm it matches the file actually present in data/models.
MODEL_FILES = dict(
    lite="pose_landmarker_lite.task",
    full="pose_landmarker_full.task",
    heavy="pose_landmarker.task",
)

# Absolute path of the hand landmarker model bundle.
HAND_MODEL_PATH = os.path.join(BASE_DIR, "data", "models", "hand_landmarker.task")
class PoseEstimator:
    """Run MediaPipe pose and hand landmark detection on image frames.

    Wraps a ``vision.PoseLandmarker`` and a ``vision.HandLandmarker`` built
    from the ``.task`` model bundles under ``BASE_DIR/data/models``. Call
    :meth:`close` (or use the instance as a context manager) to release both.
    """

    # Pose landmark index pairs drawn by draw_landmarks (subset of joints).
    _POSE_CONNECTIONS = (
        (11, 13), (13, 15), (12, 14), (14, 16),  # Arms
        (11, 12), (23, 24), (11, 23), (12, 24),  # Torso
        (23, 25), (25, 27), (24, 26), (26, 28),  # Legs
    )

    # Timestamp step (ms) used in VIDEO mode when the caller supplies none;
    # 33 ms approximates 30 fps.
    _FRAME_INTERVAL_MS = 33

    def __init__(self, static_image_mode=False, model_type="full", resize_width=None):
        """Create the pose and hand landmarkers.

        Args:
            static_image_mode: if True, both landmarkers run in IMAGE mode;
                otherwise they run in VIDEO mode and need increasing timestamps.
            model_type: "lite", "full", or "heavy" (a key of MODEL_FILES).
                Unknown values log a warning and fall back to "full".
            resize_width: if set, frames wider than this are downscaled to
                this width (aspect ratio preserved) before processing.
        """
        model_name = MODEL_FILES.get(model_type)
        if model_name is None:
            logger.warning("Unknown model_type %r; falling back to 'full'", model_type)
            model_name = MODEL_FILES["full"]
        pose_model_path = os.path.join(BASE_DIR, "data", "models", model_name)

        running_mode = (
            vision.RunningMode.IMAGE if static_image_mode else vision.RunningMode.VIDEO
        )

        pose_options = vision.PoseLandmarkerOptions(
            base_options=python.BaseOptions(model_asset_path=pose_model_path),
            running_mode=running_mode,
            output_segmentation_masks=False,
        )
        hand_options = vision.HandLandmarkerOptions(
            base_options=python.BaseOptions(model_asset_path=HAND_MODEL_PATH),
            running_mode=running_mode,
            num_hands=2,
        )

        self.pose_landmarker = vision.PoseLandmarker.create_from_options(pose_options)
        try:
            self.hand_landmarker = vision.HandLandmarker.create_from_options(hand_options)
        except Exception:
            # Don't leak the already-created pose landmarker if the hand
            # model fails to load; the original error is re-raised.
            self.pose_landmarker.close()
            raise

        # Next timestamp (ms) passed to detect_for_video in VIDEO mode.
        self.timestamp = 0
        self.static_image_mode = static_image_mode
        self.resize_width = resize_width

    def __enter__(self):
        """Return self so the estimator can be used in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc, tb):
        """Close both landmarkers; exceptions from the block propagate."""
        self.close()

    def process_frame(self, frame, timestamp_ms=None):
        """Process a single image frame.

        Args:
            frame: image array in BGR channel order (it is converted with
                ``cv2.COLOR_BGR2RGB``).
            timestamp_ms: optional frame timestamp in milliseconds, used only
                in VIDEO mode. When omitted, an internal counter advanced by
                33 ms per call is used. MediaPipe requires VIDEO-mode
                timestamps to increase monotonically between calls.

        Returns:
            dict with "pose" (pose landmarker result) and "hands" (hand
            landmarker result).
        """
        # Downscale wide frames only; INTER_AREA is OpenCV's recommended
        # interpolation for shrinking and avoids aliasing.
        if self.resize_width and frame.shape[1] > self.resize_width:
            height, width = frame.shape[:2]
            aspect_ratio = height / width
            target_height = max(1, int(self.resize_width * aspect_ratio))
            frame = cv2.resize(
                frame, (self.resize_width, target_height), interpolation=cv2.INTER_AREA
            )

        image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=image_rgb)

        if self.static_image_mode:
            pose_result = self.pose_landmarker.detect(mp_image)
            hand_result = self.hand_landmarker.detect(mp_image)
        else:
            if timestamp_ms is not None:
                self.timestamp = int(timestamp_ms)
            pose_result = self.pose_landmarker.detect_for_video(mp_image, self.timestamp)
            hand_result = self.hand_landmarker.detect_for_video(mp_image, self.timestamp)
            self.timestamp += self._FRAME_INTERVAL_MS

        return {"pose": pose_result, "hands": hand_result}

    def extract_landmarks(self, results):
        """Convert Tasks API results into a plain, legacy-compatible dict.

        Args:
            results: dict as returned by :meth:`process_frame`.

        Returns:
            dict with keys "pose", "left_hand", "right_hand". "pose" is a list
            of [x, y, z, visibility] for the first detected person; each hand
            is a list of [x, y, z]. Entries are None when nothing is detected.
            If two hands receive the same handedness label, the one with the
            higher handedness score is kept.

        NOTE(review): labels are used exactly as MediaPipe reports them;
        MediaPipe documents handedness as assuming a mirrored (selfie) image,
        so confirm whether callers pass mirrored frames.
        """
        data = {"pose": None, "left_hand": None, "right_hand": None}

        pose_res = results["pose"]
        if pose_res.pose_landmarks:
            # Only the first detected person is reported.
            data["pose"] = [
                [lm.x, lm.y, lm.z, lm.visibility] for lm in pose_res.pose_landmarks[0]
            ]

        hand_res = results["hands"]
        best_scores = {}
        for hand_lms, categories in zip(
            hand_res.hand_landmarks or [], hand_res.handedness or []
        ):
            top = categories[0]
            key = "left_hand" if top.category_name == "Left" else "right_hand"
            score = top.score or 0.0
            # Keep the more confident hand when both share the same label.
            if key in best_scores and best_scores[key] >= score:
                continue
            best_scores[key] = score
            data[key] = [[lm.x, lm.y, lm.z] for lm in hand_lms]

        return data

    def draw_landmarks(self, frame, results, min_visibility=None):
        """Draw a subset of pose connections on a copy of ``frame``.

        Custom drawing is used because ``mp.solutions.drawing_utils`` is not
        available. Only pose connections are drawn.

        Args:
            frame: image array to annotate (left unmodified).
            results: dict as returned by :meth:`process_frame`.
            min_visibility: if set, skip connections where either endpoint's
                visibility is below this value (a missing visibility counts as
                0). The default None draws every connection.

        Returns:
            The annotated copy of ``frame``.
        """
        annotated_frame = frame.copy()
        # shape[:2] also works for single-channel (2-D) frames.
        h, w = frame.shape[:2]

        pose = self.extract_landmarks(results)["pose"]
        if not pose:
            return annotated_frame

        for start_idx, end_idx in self._POSE_CONNECTIONS:
            start, end = pose[start_idx], pose[end_idx]
            if min_visibility is not None and (
                min(start[3] or 0.0, end[3] or 0.0) < min_visibility
            ):
                continue
            # Landmarks are normalized, so scale by the frame size.
            p1 = (int(start[0] * w), int(start[1] * h))
            p2 = (int(end[0] * w), int(end[1] * h))
            cv2.line(annotated_frame, p1, p2, (0, 255, 0), 2)

        return annotated_frame

    def close(self):
        """Release both landmarkers; the hand one is closed even if the pose one fails."""
        try:
            self.pose_landmarker.close()
        finally:
            self.hand_landmarker.close()