#!/usr/bin/env python3
"""
Real-time pose classifier
Uses MediaPipe to capture camera input, perform pose recognition and classification, and display results on screen
Features:
1. Use MediaPipe to obtain real-time pose data from camera
2. Extract joint coordinates and preprocess them
3. Use trained machine learning models for pose classification
4. Display classification results and keypoints in real-time on video screen
Dependencies:
pip install opencv-python mediapipe numpy scikit-learn
Usage:
python realtime_pose_classifier.py [--model MODEL_PATH] [--camera CAMERA_ID]
"""
import cv2
import mediapipe as mp
import numpy as np
import json
import joblib
import argparse
import time
from pathlib import Path
import sys
import traceback
from collections import Counter
class RealtimePoseClassifier:
def __init__(self, model_path=None, camera_id=0):
"""
Initialize real-time pose classifier
Args:
model_path (str): Path to the model file; auto-detected if None
camera_id (int): Camera index, default 0
"""
self.camera_id = camera_id
# Initialize MediaPipe
self.mp_pose = mp.solutions.pose
self.mp_drawing = mp.solutions.drawing_utils
self.mp_drawing_styles = mp.solutions.drawing_styles
# Configure pose detector
self.pose = self.mp_pose.Pose(
static_image_mode=False,
model_complexity=1, # Use lower complexity for real-time applications
enable_segmentation=False,
min_detection_confidence=0.7,
min_tracking_confidence=0.5
)
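# Notes on the configuration above:
# - static_image_mode=False treats input as a video stream: full detection
#   runs only when tracking is lost, which is cheaper per frame.
# - model_complexity can be 0, 1, or 2; 1 trades some accuracy for speed.
# - min_detection_confidence gates initial person detection, while
#   min_tracking_confidence decides when tracking is considered lost.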
# MediaPipe landmark name mapping
self.landmark_names = [
'nose', 'left_eye_inner', 'left_eye', 'left_eye_outer',
'right_eye_inner', 'right_eye', 'right_eye_outer',
'left_ear', 'right_ear', 'mouth_left', 'mouth_right',
'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
'left_wrist', 'right_wrist', 'left_pinky', 'right_pinky',
'left_index', 'right_index', 'left_thumb', 'right_thumb',
'left_hip', 'right_hip', 'left_knee', 'right_knee',
'left_ankle', 'right_ankle', 'left_heel', 'right_heel',
'left_foot_index', 'right_foot_index'
]
# Load model
self.model = None
self.scaler = None
self.label_encoder = None
self.target_joints = None
self.model_info = None
self.load_model(model_path)
# Prediction result cache
self.prediction_history = []
self.history_size = 5 # Smooth over the 5 most recent predictions
# Performance statistics
self.fps_counter = 0
self.fps_start_time = time.time()
self.current_fps = 0
# Timing statistics
self.mediapipe_time_total = 0.0
self.mediapipe_time_count = 0
self.feature_pred_time_total = 0.0
self.feature_pred_time_count = 0
# Display settings
self.show_landmarks = True
self.show_connections = True
def load_model(self, model_path=None):
"""Load trained model"""
if model_path is None:
# Auto-detect available model files
possible_models = [
'pose_classifier_random_forest.pkl',
'pose_classifier_logistic.pkl',
'pose_classifier_distilled_rf.pkl'
]
for model_file in possible_models:
if Path(model_file).exists():
model_path = model_file
break
if model_path is None:
raise FileNotFoundError("No model file found; specify one with --model")
try:
print(f"Loading model: {model_path}")
model_data = joblib.load(model_path)
self.model = model_data['model']
self.scaler = model_data['scaler']
self.label_encoder = model_data['label_encoder']
self.target_joints = model_data['target_joints']
# Try to load corresponding labels file
labels_path = model_path.replace('.pkl', '_labels.json')
if Path(labels_path).exists():
with open(labels_path, 'r') as f:
self.model_info = json.load(f)
print(f"Loaded label information: {labels_path}")
print("Model loaded successfully!")
print(f"Target joints: {self.target_joints}")
print(f"Classification classes: {self.label_encoder.classes_}")
except Exception as e:
raise RuntimeError(f"Model loading failed: {e}")
def extract_pose_features(self, landmarks):
"""
Extract pose features from MediaPipe landmarks (vectorized optimized version)
"""
if landmarks is None:
return None
# Get all joint coordinates as NumPy array
coords = np.array([[lm.x, lm.y, lm.z] for lm in landmarks.landmark], dtype=np.float32)
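# MediaPipe Pose always reports 33 landmarks, so coords has shape (33, 3)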
# Get head position (nose as reference point)
try:
head_idx = self.landmark_names.index('nose')
head_pos = coords[head_idx]
except ValueError:
return None
# Build target joint indices list
joint_indices = [self.landmark_names.index(j) if j in self.landmark_names else -1 for j in self.target_joints]
# Extract target joint coordinates (zero-fill any joint not in the landmark list)
joint_coords = np.array([
coords[idx] if idx >= 0 else np.zeros(3, dtype=np.float32)
for idx in joint_indices
], dtype=np.float32)
# Compute positions relative to the head and scale (matches the training preprocessing)
relative_coords = (joint_coords - head_pos) * 100
# Round to two decimal places, also matching the training preprocessing
features = np.round(relative_coords, 2).flatten()
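# Layout after flatten(): [dx_0, dy_0, dz_0, dx_1, dy_1, dz_1, ...], one
# (x, y, z) offset from the nose per target joint, so the vector length is
# 3 * len(self.target_joints)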
return features
def predict_pose(self, features):
"""
Use machine learning model to predict pose
Args:
features: Feature vector
Returns:
dict: Prediction result containing label, confidence, etc.
"""
if features is None or self.model is None:
return None
try:
# Standardize features
features_scaled = self.scaler.transform(features.reshape(1, -1))
# Predict
prediction = self.model.predict(features_scaled)[0]
predicted_label = self.label_encoder.inverse_transform([prediction])[0]
# Get confidence (if model supports probability prediction)
confidence = 0.0
probabilities = None
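# Both RandomForestClassifier and LogisticRegression implement
# predict_proba, so the auto-detected sklearn models should all report a
# confidence; the hasattr check keeps this working for models that don't.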
if hasattr(self.model, 'predict_proba'):
probs = self.model.predict_proba(features_scaled)[0]
confidence = float(np.max(probs))
probabilities = dict(zip(self.label_encoder.classes_, probs))
return {
'predicted_label': predicted_label,
'confidence': confidence,
'probabilities': probabilities
}
except Exception as e:
print(f"Prediction error: {e}")
return None
def smooth_predictions(self, current_prediction):
"""
Smooth prediction results
Args:
current_prediction: Current prediction result
Returns:
dict: Smoothed prediction result
"""
if current_prediction is None:
return None
# Add to history
self.prediction_history.append(current_prediction)
if len(self.prediction_history) > self.history_size:
self.prediction_history.pop(0)
# If history is insufficient, return current prediction directly
if len(self.prediction_history) < 3:
return current_prediction
# Count recent prediction labels
recent_labels = [pred['predicted_label'] for pred in self.prediction_history]
# Use the most frequent recent label (the mode) as the final prediction
label_counts = Counter(recent_labels)
most_common_label = label_counts.most_common(1)[0][0]
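# Example (illustrative labels): with history ['tree', 'tree', 'warrior',
# 'tree', 'tree'] the mode is 'tree' and the stability below is 4/5 = 0.8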
# Calculate average confidence for this label
avg_confidence = np.mean([
pred['confidence'] for pred in self.prediction_history
if pred['predicted_label'] == most_common_label
])
return {
'predicted_label': most_common_label,
'confidence': avg_confidence,
'stability': label_counts[most_common_label] / len(recent_labels)
}
def draw_pose_info(self, image, landmarks, prediction_result):
"""
Draw pose information on image
Args:
image: OpenCV image
landmarks: MediaPipe landmarks
prediction_result: Prediction result
"""
height, width = image.shape[:2]
# Draw pose skeleton
if landmarks and self.show_connections:
self.mp_drawing.draw_landmarks(
image,
landmarks,
self.mp_pose.POSE_CONNECTIONS,
landmark_drawing_spec=self.mp_drawing_styles.get_default_pose_landmarks_style()
)
# Draw keypoints
if landmarks and self.show_landmarks:
for i, landmark in enumerate(landmarks.landmark):
if self.landmark_names[i] in self.target_joints:
x = int(landmark.x * width)
y = int(landmark.y * height)
cv2.circle(image, (x, y), 8, (0, 255, 0), -1)
cv2.putText(image, self.landmark_names[i], (x + 10, y - 10),
cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)
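# (landmark.x / landmark.y are normalized to [0, 1], hence the scaling by
# frame width and height above)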
# Display prediction results
if prediction_result:
label = prediction_result['predicted_label']
confidence = prediction_result.get('confidence', 0.0)
stability = prediction_result.get('stability', 1.0)
# Set color based on confidence
if confidence > 0.8:
color = (0, 255, 0) # Green - high confidence
elif confidence > 0.6:
color = (0, 255, 255) # Yellow - medium confidence
else:
color = (0, 0, 255) # Red - low confidence
# Draw prediction result background box
cv2.rectangle(image, (10, 10), (400, 120), (0, 0, 0), -1)
cv2.rectangle(image, (10, 10), (400, 120), color, 2)
# Display prediction label
cv2.putText(image, f"Pose: {label}", (20, 40),
cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
# Display confidence
cv2.putText(image, f"Confidence: {confidence:.2f}", (20, 70),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
# Display stability
cv2.putText(image, f"Stability: {stability:.2f}", (20, 95),
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
# Display FPS
cv2.putText(image, f"FPS: {self.current_fps:.1f}", (width - 150, 30),
cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)
# Display control instructions
instructions = [
"Controls:",
"Q - Quit",
"L - Toggle Landmarks",
"C - Toggle Connections",
"R - Reset History"
]
for i, instruction in enumerate(instructions):
cv2.putText(image, instruction, (width - 200, height - 120 + i * 25),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1)
# Display timing statistics
mp_avg = self.mediapipe_time_total / self.mediapipe_time_count if self.mediapipe_time_count else 0.0
fp_avg = self.feature_pred_time_total / self.feature_pred_time_count if self.feature_pred_time_count else 0.0
cv2.putText(image, f"MP avg: {mp_avg*1000:.1f}ms", (width - 150, 55),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
cv2.putText(image, f"FP avg: {fp_avg*1000:.1f}ms", (width - 150, 75),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
# Display average processing throughput (MediaPipe + feature/prediction time only)
total_frames = max(self.mediapipe_time_count, 1)
avg_fps = total_frames / max(self.mediapipe_time_total + self.feature_pred_time_total, 1e-6)
cv2.putText(image, f"Avg FPS: {avg_fps:.1f}", (width - 150, 95),
cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
def update_fps(self):
"""Update FPS calculation"""
self.fps_counter += 1
if self.fps_counter >= 30: # Update FPS every 30 frames
current_time = time.time()
self.current_fps = 30 / (current_time - self.fps_start_time)
self.fps_start_time = current_time
self.fps_counter = 0
def run(self):
"""Run real-time pose classification"""
print("Starting real-time pose classifier...")
print("Press 'Q' to quit, 'L' to toggle landmark display, 'C' to toggle skeleton connections, 'R' to reset history")
# Initialize camera
cap = cv2.VideoCapture(self.camera_id)
if not cap.isOpened():
raise RuntimeError(f"Cannot open camera {self.camera_id}")
# Set camera parameters
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
cap.set(cv2.CAP_PROP_FPS, 30)
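# Note: cap.set() is best-effort; the backend may silently keep a different
# resolution or frame rate than requested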
try:
while True:
success, frame = cap.read()
if not success:
print("Cannot read camera frame")
break
# Flip image horizontally (mirror effect)
frame = cv2.flip(frame, 1)
# Convert color space
rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
# Time MediaPipe pose detection
mp_start = time.time()
results = self.pose.process(rgb_frame)
mp_end = time.time()
self.mediapipe_time_total += (mp_end - mp_start)
self.mediapipe_time_count += 1
# Extract features and predict
fp_start = time.time()
prediction_result = None
if results.pose_landmarks:
features = self.extract_pose_features(results.pose_landmarks)
if features is not None:
raw_prediction = self.predict_pose(features)
prediction_result = self.smooth_predictions(raw_prediction)
fp_end = time.time()
self.feature_pred_time_total += (fp_end - fp_start)
self.feature_pred_time_count += 1
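# Note: this timer also covers smoothing, and it still ticks on frames
# with no detected pose (recording a near-zero duration for them)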
# Draw results
self.draw_pose_info(frame, results.pose_landmarks, prediction_result)
# Update FPS
self.update_fps()
# Display image
cv2.imshow('Real-time Pose Classification', frame)
# Handle key presses
key = cv2.waitKey(1) & 0xFF
if key == ord('q') or key == ord('Q'):
break
elif key == ord('l') or key == ord('L'):
self.show_landmarks = not self.show_landmarks
print(f"Landmark display: {'On' if self.show_landmarks else 'Off'}")
elif key == ord('c') or key == ord('C'):
self.show_connections = not self.show_connections
print(f"Skeleton connection display: {'On' if self.show_connections else 'Off'}")
elif key == ord('r') or key == ord('R'):
self.prediction_history.clear()
print("Prediction history reset")
except KeyboardInterrupt:
print("\nUser interrupted program")
except Exception as e:
print(f"Runtime error: {e}")
traceback.print_exc()
finally:
cap.release()
cv2.destroyAllWindows()
print("Program exited")
def main():
"""Main function"""
parser = argparse.ArgumentParser(description='Real-time pose classifier')
parser.add_argument('--model', '-m', type=str, default=None,
help='Model file path (auto-detect by default)')
parser.add_argument('--camera', '-c', type=int, default=0,
help='Camera ID (default 0)')
args = parser.parse_args()
try:
classifier = RealtimePoseClassifier(
model_path=args.model,
camera_id=args.camera
)
classifier.run()
except Exception as e:
print(f"Program startup failed: {e}")
return 1
return 0
if __name__ == "__main__":
sys.exit(main())
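# Example invocations (model filenames are the auto-detected defaults above):
#   python realtime_pose_classifier.py
#   python realtime_pose_classifier.py --model pose_classifier_logistic.pkl --camera 1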