# pose_detector.py — Pose-Detection-App
# MediaPipe-based pose detection with Maya-style skeleton rendering.
import mediapipe as mp
import numpy as np
import cv2
from typing import List, Tuple, Optional
class PoseDetector:
    """Detect human poses with MediaPipe and render Maya-style skeleton overlays.

    Holds a simplified bone-chain connection list (a subset of MediaPipe's 33
    pose landmarks), shared drawing specs, and a persistent video-mode Pose
    instance so that inter-frame tracking and landmark smoothing work.
    """

    def __init__(self):
        self.mp_pose = mp.solutions.pose
        self.mp_drawing = mp.solutions.drawing_utils
        self.mp_drawing_styles = mp.solutions.drawing_styles

        # Persistent Pose instance for video, created lazily on first frame.
        # Recreating Pose per frame (as a `with` block would) resets MediaPipe's
        # internal tracker and defeats smooth_landmarks / min_tracking_confidence.
        self._video_pose = None

        # Landmarks from the last frame where detection succeeded; used to
        # smooth jitter and to bridge frames where detection drops out.
        self.previous_landmarks = None

        # Simplified skeleton: (start_index, end_index) landmark pairs forming
        # spine, arm, and leg bone chains for smooth animation rendering.
        self.pose_connections = [
            # Spine chain
            (self.mp_pose.PoseLandmark.NOSE.value, self.mp_pose.PoseLandmark.LEFT_SHOULDER.value),
            (self.mp_pose.PoseLandmark.NOSE.value, self.mp_pose.PoseLandmark.RIGHT_SHOULDER.value),
            (self.mp_pose.PoseLandmark.LEFT_SHOULDER.value, self.mp_pose.PoseLandmark.RIGHT_SHOULDER.value),
            (self.mp_pose.PoseLandmark.LEFT_SHOULDER.value, self.mp_pose.PoseLandmark.LEFT_HIP.value),
            (self.mp_pose.PoseLandmark.RIGHT_SHOULDER.value, self.mp_pose.PoseLandmark.RIGHT_HIP.value),
            (self.mp_pose.PoseLandmark.LEFT_HIP.value, self.mp_pose.PoseLandmark.RIGHT_HIP.value),
            # Left arm chain
            (self.mp_pose.PoseLandmark.LEFT_SHOULDER.value, self.mp_pose.PoseLandmark.LEFT_ELBOW.value),
            (self.mp_pose.PoseLandmark.LEFT_ELBOW.value, self.mp_pose.PoseLandmark.LEFT_WRIST.value),
            (self.mp_pose.PoseLandmark.LEFT_WRIST.value, self.mp_pose.PoseLandmark.LEFT_THUMB.value),
            # Right arm chain
            (self.mp_pose.PoseLandmark.RIGHT_SHOULDER.value, self.mp_pose.PoseLandmark.RIGHT_ELBOW.value),
            (self.mp_pose.PoseLandmark.RIGHT_ELBOW.value, self.mp_pose.PoseLandmark.RIGHT_WRIST.value),
            (self.mp_pose.PoseLandmark.RIGHT_WRIST.value, self.mp_pose.PoseLandmark.RIGHT_THUMB.value),
            # Left leg chain
            (self.mp_pose.PoseLandmark.LEFT_HIP.value, self.mp_pose.PoseLandmark.LEFT_KNEE.value),
            (self.mp_pose.PoseLandmark.LEFT_KNEE.value, self.mp_pose.PoseLandmark.LEFT_ANKLE.value),
            (self.mp_pose.PoseLandmark.LEFT_ANKLE.value, self.mp_pose.PoseLandmark.LEFT_FOOT_INDEX.value),
            # Right leg chain
            (self.mp_pose.PoseLandmark.RIGHT_HIP.value, self.mp_pose.PoseLandmark.RIGHT_KNEE.value),
            (self.mp_pose.PoseLandmark.RIGHT_KNEE.value, self.mp_pose.PoseLandmark.RIGHT_ANKLE.value),
            (self.mp_pose.PoseLandmark.RIGHT_ANKLE.value, self.mp_pose.PoseLandmark.RIGHT_FOOT_INDEX.value),
        ]

        # Shared drawing specifications (BGR colors, as used by OpenCV).
        self.landmark_drawing_spec = self.mp_drawing.DrawingSpec(
            color=(0, 255, 0),  # green
            thickness=2,
            circle_radius=2,
        )
        self.connection_drawing_spec = self.mp_drawing.DrawingSpec(
            color=(255, 255, 0),  # yellow
            thickness=2,
        )

    def detect(self, image, manual_corrections=None) -> Tuple[Optional[np.ndarray], np.ndarray]:
        """Detect pose in the given still image.

        Args:
            image: Input image (BGR ndarray).
            manual_corrections: Optional dict of landmark indices mapped to
                corrected positions.

        Returns:
            Tuple of (landmarks, annotated_image).

        NOTE(review): the body of this method was not present in this revision
        (original contained only a stub); main logic intentionally unchanged.
        """
        pass

    def draw_corrected_pose(self, image: np.ndarray, corrected_joints: dict) -> np.ndarray:
        """Draw the skeleton using manually corrected joint positions.

        Args:
            image: Input image (BGR ndarray).
            corrected_joints: Mapping of joint key -> {'position': (x, y)} with
                normalized [0, 1] coordinates. NOTE(review): pose_connections
                holds integer landmark indices, so keys are presumably landmark
                indices — confirm against the caller.

        Returns:
            A copy of the image with bones, joints, and joint labels drawn.
        """
        annotated_image = image.copy()
        h, w = image.shape[:2]

        # Bones: only draw a connection when both endpoints were supplied.
        for start_name, end_name in self.pose_connections:
            if start_name in corrected_joints and end_name in corrected_joints:
                start_pos = corrected_joints[start_name]['position']
                end_pos = corrected_joints[end_name]['position']
                # Convert normalized coordinates to pixel coordinates.
                start_px = (int(start_pos[0] * w), int(start_pos[1] * h))
                end_px = (int(end_pos[0] * w), int(end_pos[1] * h))
                cv2.line(annotated_image, start_px, end_px, (0, 255, 0), 3)

        # Joints (radius reduced from 5 to 3) with a label above each one.
        for joint_name, joint_data in corrected_joints.items():
            pos = joint_data['position']
            px_pos = (int(pos[0] * w), int(pos[1] * h))
            cv2.circle(annotated_image, px_pos, 3, (0, 255, 255), -1)
            cv2.putText(annotated_image, joint_name,
                        (px_pos[0], px_pos[1] - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        return annotated_image

    def _draw_smooth_connections(self, image: np.ndarray, landmarks, connections: List[Tuple[int, int]]):
        """Draw anti-aliased bones/joints between landmarks in Maya-like style.

        Args:
            image: Image to draw on, modified in place.
            landmarks: Object exposing a `.landmark` sequence with normalized
                `.x`/`.y` attributes (e.g. MediaPipe's landmark list).
            connections: (start_index, end_index) landmark pairs to connect.
        """
        h, w = image.shape[:2]

        # Maya-style bone names for the landmark indices we label on screen.
        bone_names = {
            0: "Head",
            11: "Neck",
            12: "Spine2",
            23: "Hips",
            24: "Spine",
            13: "LeftArm",
            14: "RightArm",
            15: "LeftForeArm",
            16: "RightForeArm",
            25: "LeftLeg",
            26: "RightLeg",
            27: "LeftFoot",
            28: "RightFoot",
            31: "LeftToeBase",
            32: "RightToeBase",
        }

        joint_color = (0, 255, 255)  # cyan
        bone_color = (0, 255, 0)     # green
        text_color = (255, 255, 255)  # white

        for start_idx, end_idx in connections:
            start_point = landmarks.landmark[start_idx]
            end_point = landmarks.landmark[end_idx]
            # Normalized -> pixel coordinates.
            start_pos = (int(start_point.x * w), int(start_point.y * h))
            end_pos = (int(end_point.x * w), int(end_point.y * h))

            cv2.line(image, start_pos, end_pos, bone_color, 3, cv2.LINE_AA)

            # Joints drawn with reduced radius (5 -> 3), labeled when named.
            for pos, idx in [(start_pos, start_idx), (end_pos, end_idx)]:
                cv2.circle(image, pos, 3, joint_color, -1, cv2.LINE_AA)
                if idx in bone_names:
                    text_pos = (pos[0], pos[1] - 10)
                    cv2.putText(image, bone_names[idx], text_pos,
                                cv2.FONT_HERSHEY_SIMPLEX, 0.5, text_color, 1,
                                cv2.LINE_AA)

    def detect_video_frame(self, frame):
        """Detect pose in a video frame with video-optimized parameters.

        Uses a single persistent Pose instance so MediaPipe's inter-frame
        tracking and `smooth_landmarks` actually take effect (creating a new
        Pose per frame would reset the tracker every call). Applies an extra
        exponential smoothing pass against the previous frame, and falls back
        to the previous frame's landmarks when detection drops out.

        Args:
            frame: BGR video frame (ndarray).

        Returns:
            Tuple of (landmarks, annotated_frame) where landmarks is an
            (N, 3) ndarray of normalized (x, y, z), or (None, frame copy)
            when nothing was ever detected.
        """
        if self._video_pose is None:
            self._video_pose = self.mp_pose.Pose(
                static_image_mode=False,
                model_complexity=2,
                smooth_landmarks=True,
                min_detection_confidence=0.5,  # raised threshold for accuracy
                min_tracking_confidence=0.5,   # raised threshold for accuracy
                enable_segmentation=False,
            )

        # MediaPipe expects RGB; OpenCV frames are BGR.
        image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = self._video_pose.process(image_rgb)
        annotated_frame = frame.copy()

        if results.pose_landmarks:
            # Thicker strokes for video; smaller joint radius.
            self.landmark_drawing_spec.thickness = 6
            self.connection_drawing_spec.thickness = 5
            self.landmark_drawing_spec.circle_radius = 2

            if self.previous_landmarks is not None:
                # Exponential smoothing toward the previous frame, applied when
                # the current landmark is uncertain or the previous one was
                # confidently visible.
                smoothing_factor = 0.8
                for i, landmark in enumerate(results.pose_landmarks.landmark):
                    prev = self.previous_landmarks[i]
                    if landmark.visibility < 0.7 or prev.visibility > 0.8:
                        landmark.x = smoothing_factor * prev.x + (1 - smoothing_factor) * landmark.x
                        landmark.y = smoothing_factor * prev.y + (1 - smoothing_factor) * landmark.y
                        landmark.z = smoothing_factor * prev.z + (1 - smoothing_factor) * landmark.z
                        # Decay previous visibility so stale confidence fades.
                        landmark.visibility = max(landmark.visibility, prev.visibility * 0.9)

            self.previous_landmarks = results.pose_landmarks.landmark

            self._draw_smooth_connections(
                annotated_frame,
                results.pose_landmarks,
                self.pose_connections,
            )
            self.mp_drawing.draw_landmarks(
                image=annotated_frame,
                landmark_list=results.pose_landmarks,
                connections=self.mp_pose.POSE_CONNECTIONS,
                landmark_drawing_spec=self.landmark_drawing_spec,
                connection_drawing_spec=self.connection_drawing_spec,
            )

            landmarks = np.array([[lm.x, lm.y, lm.z] for lm in results.pose_landmarks.landmark])
            return landmarks, annotated_frame

        if self.previous_landmarks is not None:
            # Detection dropped out: reuse the previous frame's landmarks.
            # `results` is an immutable NamedTuple, so we cannot assign to
            # results.pose_landmarks; build a lightweight stand-in instead.
            fake_landmarks = type('obj', (object,), {'landmark': self.previous_landmarks})
            self._draw_smooth_connections(
                annotated_frame,
                fake_landmarks,
                self.pose_connections,
            )
            landmarks = np.array([[lm.x, lm.y, lm.z] for lm in self.previous_landmarks])
            return landmarks, annotated_frame

        return None, annotated_frame