"""
Real-time pose classifier.

Uses MediaPipe to capture camera input, perform pose detection and
classification, and display the results on screen.

Features:
1. Obtain real-time pose landmarks from the camera via MediaPipe.
2. Extract joint coordinates and preprocess them into feature vectors.
3. Classify the pose with a trained machine-learning model.
4. Display the classification result and keypoints on the video in real time.

Dependencies:
    pip install opencv-python mediapipe numpy scikit-learn

Usage:
    python realtime_pose_classifier.py [--model MODEL_PATH] [--camera CAMERA_ID]
"""

import argparse
import json
import sys
import time
import traceback
from collections import Counter
from pathlib import Path

import cv2
import joblib
import mediapipe as mp
import numpy as np


class RealtimePoseClassifier:
    def __init__(self, model_path=None, camera_id=0):
        """
        Initialize the real-time pose classifier.

        Args:
            model_path (str): Model file path; auto-detected if None.
            camera_id (int): Camera ID, default 0.
        """
        self.camera_id = camera_id

        # MediaPipe pose solution and drawing helpers.
        self.mp_pose = mp.solutions.pose
        self.mp_drawing = mp.solutions.drawing_utils
        self.mp_drawing_styles = mp.solutions.drawing_styles

        # Video-stream mode: static_image_mode=False lets MediaPipe track
        # landmarks between frames instead of re-detecting every frame.
        self.pose = self.mp_pose.Pose(
            static_image_mode=False,
            model_complexity=1,
            enable_segmentation=False,
            min_detection_confidence=0.7,
            min_tracking_confidence=0.5
        )

        # Names of the 33 MediaPipe pose landmarks, in model output order.
        self.landmark_names = [
            'nose', 'left_eye_inner', 'left_eye', 'left_eye_outer',
            'right_eye_inner', 'right_eye', 'right_eye_outer',
            'left_ear', 'right_ear', 'mouth_left', 'mouth_right',
            'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
            'left_wrist', 'right_wrist', 'left_pinky', 'right_pinky',
            'left_index', 'right_index', 'left_thumb', 'right_thumb',
            'left_hip', 'right_hip', 'left_knee', 'right_knee',
            'left_ankle', 'right_ankle', 'left_heel', 'right_heel',
            'left_foot_index', 'right_foot_index'
        ]

        # Model artifacts, populated by load_model().
        self.model = None
        self.scaler = None
        self.label_encoder = None
        self.target_joints = None
        self.model_info = None

        self.load_model(model_path)

        # Sliding window of recent predictions for temporal smoothing.
        self.prediction_history = []
        self.history_size = 5

        # FPS bookkeeping for the on-screen counter.
        self.fps_counter = 0
        self.fps_start_time = time.time()
        self.current_fps = 0

        # Cumulative timings for the MediaPipe and feature/prediction stages.
        self.mediapipe_time_total = 0.0
        self.mediapipe_time_count = 0
        self.feature_pred_time_total = 0.0
        self.feature_pred_time_count = 0

        # Display toggles (keyboard-controlled at runtime).
        self.show_landmarks = True
        self.show_connections = True

    def load_model(self, model_path=None):
        """Load the trained model bundle, auto-detecting the file if needed."""
        if model_path is None:
            # Try the default filenames produced by the training scripts.
            possible_models = [
                'pose_classifier_random_forest.pkl',
                'pose_classifier_logistic.pkl',
                'pose_classifier_distilled_rf.pkl'
            ]
            for model_file in possible_models:
                if Path(model_file).exists():
                    model_path = model_file
                    break

        if model_path is None:
            raise FileNotFoundError("No available model file found; please specify a model path")

        try:
            print(f"Loading model: {model_path}")
            model_data = joblib.load(model_path)

            self.model = model_data['model']
            self.scaler = model_data['scaler']
            self.label_encoder = model_data['label_encoder']
            self.target_joints = model_data['target_joints']

            # Optional sidecar file with label metadata.
            labels_path = model_path.replace('.pkl', '_labels.json')
            if Path(labels_path).exists():
                with open(labels_path, 'r') as f:
                    self.model_info = json.load(f)
                print(f"Loaded label information: {labels_path}")

            print("Model loaded successfully!")
            print(f"Target joints: {self.target_joints}")
            print(f"Classification classes: {self.label_encoder.classes_}")

        except Exception as e:
            raise RuntimeError(f"Model loading failed: {e}") from e
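
    # The optional '*_labels.json' sidecar is only kept in self.model_info for
    # reference. Its exact shape is defined by the training script; a plausible
    # (hypothetical) example:
    #
    #     {"classes": ["stand", "squat", "t_pose"], "num_features": 12}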

    def extract_pose_features(self, landmarks):
        """
        Extract pose features from MediaPipe landmarks (vectorized version).

        Returns a flat array of nose-relative (x, y, z) offsets for each
        target joint, or None if no landmarks are available.
        """
        if landmarks is None:
            return None

        # All 33 landmarks as an (N, 3) array of normalized coordinates.
        coords = np.array([[lm.x, lm.y, lm.z] for lm in landmarks.landmark],
                          dtype=np.float32)

        # The nose serves as the reference point for relative coordinates.
        try:
            head_idx = self.landmark_names.index('nose')
            head_pos = coords[head_idx]
        except ValueError:
            return None

        # Map each target joint to its landmark index (-1 if unknown).
        joint_indices = [self.landmark_names.index(j) if j in self.landmark_names else -1
                         for j in self.target_joints]

        # Unknown joints fall back to zeros so the feature length stays fixed.
        joint_coords = np.array([
            coords[idx] if idx >= 0 else np.zeros(3, dtype=np.float32)
            for idx in joint_indices
        ], dtype=np.float32)

        # Nose-relative offsets, scaled by 100 and rounded to 2 decimals;
        # this must mirror the training-time preprocessing.
        relative_coords = (joint_coords - head_pos) * 100
        features = np.round(relative_coords, 2).flatten()

        return features
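
    # Worked example (hypothetical numbers): with target_joints =
    # ['left_wrist', 'right_wrist'], a nose at (0.50, 0.40, 0.00) and a left
    # wrist at (0.62, 0.55, -0.10) contribute the offsets (12.0, 15.0, -10.0).
    # Each joint adds 3 values, so the model must have been trained on
    # len(target_joints) * 3 input features.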

    def predict_pose(self, features):
        """
        Predict the pose from a feature vector with the loaded model.

        Args:
            features: Feature vector.

        Returns:
            dict: Prediction result with label, confidence, and per-class
            probabilities, or None if prediction is not possible.
        """
        if features is None or self.model is None:
            return None

        try:
            # Scale with the scaler fitted at training time.
            features_scaled = self.scaler.transform(features.reshape(1, -1))

            prediction = self.model.predict(features_scaled)[0]
            predicted_label = self.label_encoder.inverse_transform([prediction])[0]

            # Confidence is only available for models exposing predict_proba;
            # otherwise it stays at 0.0.
            confidence = 0.0
            probabilities = None
            if hasattr(self.model, 'predict_proba'):
                probs = self.model.predict_proba(features_scaled)[0]
                confidence = float(np.max(probs))
                probabilities = dict(zip(self.label_encoder.classes_, probs))

            return {
                'predicted_label': predicted_label,
                'confidence': confidence,
                'probabilities': probabilities
            }

        except Exception as e:
            print(f"Prediction error: {e}")
            return None
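
    # Note: the StandardScaler must be the one fitted during training;
    # refitting it on live data would shift the feature distribution away
    # from what the model was trained on.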

    def smooth_predictions(self, current_prediction):
        """
        Smooth prediction results over the recent history.

        Args:
            current_prediction: Current prediction result.

        Returns:
            dict: Smoothed prediction result — the majority label over the
            window, the mean confidence of that label, and a stability ratio.
        """
        if current_prediction is None:
            return None

        # Maintain a fixed-size sliding window of predictions.
        self.prediction_history.append(current_prediction)
        if len(self.prediction_history) > self.history_size:
            self.prediction_history.pop(0)

        # Too few samples to smooth; pass the raw prediction through.
        if len(self.prediction_history) < 3:
            return current_prediction

        # Majority vote over the window.
        recent_labels = [pred['predicted_label'] for pred in self.prediction_history]
        label_counts = Counter(recent_labels)
        most_common_label = label_counts.most_common(1)[0][0]

        # Average confidence over the frames that voted for the winner.
        avg_confidence = np.mean([
            pred['confidence'] for pred in self.prediction_history
            if pred['predicted_label'] == most_common_label
        ])

        return {
            'predicted_label': most_common_label,
            'confidence': avg_confidence,
            'stability': label_counts[most_common_label] / len(recent_labels)
        }
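
    # Example (hypothetical history): with the last five labels
    # ['squat', 'stand', 'squat', 'squat', 'stand'], the smoothed label is
    # 'squat' with stability 3/5 = 0.6, and confidence is averaged over the
    # three 'squat' frames.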

    def draw_pose_info(self, image, landmarks, prediction_result):
        """
        Draw pose information on the image.

        Args:
            image: OpenCV image (BGR).
            landmarks: MediaPipe landmarks.
            prediction_result: Prediction result.
        """
        height, width = image.shape[:2]

        # Skeleton connections.
        if landmarks and self.show_connections:
            self.mp_drawing.draw_landmarks(
                image,
                landmarks,
                self.mp_pose.POSE_CONNECTIONS,
                landmark_drawing_spec=self.mp_drawing_styles.get_default_pose_landmarks_style()
            )

        # Highlight and label the target joints used by the classifier.
        if landmarks and self.show_landmarks:
            for i, landmark in enumerate(landmarks.landmark):
                if self.landmark_names[i] in self.target_joints:
                    x = int(landmark.x * width)
                    y = int(landmark.y * height)
                    cv2.circle(image, (x, y), 8, (0, 255, 0), -1)
                    cv2.putText(image, self.landmark_names[i], (x + 10, y - 10),
                                cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1)

        # Classification panel (top-left).
        if prediction_result:
            label = prediction_result['predicted_label']
            confidence = prediction_result.get('confidence', 0.0)
            stability = prediction_result.get('stability', 1.0)

            # Color-code by confidence (BGR): green > 0.8, yellow > 0.6, else red.
            if confidence > 0.8:
                color = (0, 255, 0)
            elif confidence > 0.6:
                color = (0, 255, 255)
            else:
                color = (0, 0, 255)

            cv2.rectangle(image, (10, 10), (400, 120), (0, 0, 0), -1)
            cv2.rectangle(image, (10, 10), (400, 120), color, 2)

            cv2.putText(image, f"Pose: {label}", (20, 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1.0, color, 2)
            cv2.putText(image, f"Confidence: {confidence:.2f}", (20, 70),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
            cv2.putText(image, f"Stability: {stability:.2f}", (20, 95),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)

        # Live FPS (top-right).
        cv2.putText(image, f"FPS: {self.current_fps:.1f}", (width - 150, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2)

        # Keyboard controls legend (bottom-right).
        instructions = [
            "Controls:",
            "Q - Quit",
            "L - Toggle Landmarks",
            "C - Toggle Connections",
            "R - Reset History"
        ]
        for i, instruction in enumerate(instructions):
            cv2.putText(image, instruction, (width - 200, height - 120 + i * 25),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (200, 200, 200), 1)

        # Average per-stage timings (MediaPipe vs. feature extraction +
        # prediction) and the implied processing-only FPS. Note this excludes
        # capture and drawing time, so it is an upper bound on the live FPS.
        mp_avg = self.mediapipe_time_total / self.mediapipe_time_count if self.mediapipe_time_count else 0.0
        fp_avg = self.feature_pred_time_total / self.feature_pred_time_count if self.feature_pred_time_count else 0.0
        cv2.putText(image, f"MP avg: {mp_avg*1000:.1f}ms", (width - 150, 55),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)
        cv2.putText(image, f"FP avg: {fp_avg*1000:.1f}ms", (width - 150, 75),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

        total_frames = max(self.mediapipe_time_count, 1)
        avg_fps = total_frames / max(self.mediapipe_time_total + self.feature_pred_time_total, 1e-6)
        cv2.putText(image, f"Avg FPS: {avg_fps:.1f}", (width - 150, 95),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 1)

    def update_fps(self):
        """Update the FPS estimate over a 30-frame window."""
        self.fps_counter += 1
        if self.fps_counter >= 30:
            current_time = time.time()
            self.current_fps = 30 / (current_time - self.fps_start_time)
            self.fps_start_time = current_time
            self.fps_counter = 0

    def run(self):
        """Run real-time pose classification."""
        print("Starting real-time pose classifier...")
        print("Press 'Q' to quit, 'L' to toggle landmark display, "
              "'C' to toggle skeleton connections, 'R' to reset history")

        cap = cv2.VideoCapture(self.camera_id)
        if not cap.isOpened():
            raise RuntimeError(f"Cannot open camera {self.camera_id}")

        # Request 720p at 30 FPS; the driver may fall back to other values.
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
        cap.set(cv2.CAP_PROP_FPS, 30)

        try:
            while True:
                success, frame = cap.read()
                if not success:
                    print("Cannot read camera frame")
                    break

                # Mirror the frame for a more natural selfie view.
                frame = cv2.flip(frame, 1)

                # MediaPipe expects RGB input; OpenCV delivers BGR.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

                # Stage 1: MediaPipe pose estimation (timed).
                mp_start = time.time()
                results = self.pose.process(rgb_frame)
                mp_end = time.time()
                self.mediapipe_time_total += (mp_end - mp_start)
                self.mediapipe_time_count += 1

                # Stage 2: feature extraction + classification (timed).
                fp_start = time.time()
                prediction_result = None
                if results.pose_landmarks:
                    features = self.extract_pose_features(results.pose_landmarks)
                    if features is not None:
                        raw_prediction = self.predict_pose(features)
                        prediction_result = self.smooth_predictions(raw_prediction)
                fp_end = time.time()
                self.feature_pred_time_total += (fp_end - fp_start)
                self.feature_pred_time_count += 1

                self.draw_pose_info(frame, results.pose_landmarks, prediction_result)
                self.update_fps()

                cv2.imshow('Real-time Pose Classification', frame)

                # Keyboard controls.
                key = cv2.waitKey(1) & 0xFF
                if key in (ord('q'), ord('Q')):
                    break
                elif key in (ord('l'), ord('L')):
                    self.show_landmarks = not self.show_landmarks
                    print(f"Landmark display: {'On' if self.show_landmarks else 'Off'}")
                elif key in (ord('c'), ord('C')):
                    self.show_connections = not self.show_connections
                    print(f"Skeleton connection display: {'On' if self.show_connections else 'Off'}")
                elif key in (ord('r'), ord('R')):
                    self.prediction_history.clear()
                    print("Prediction history reset")

        except KeyboardInterrupt:
            print("\nUser interrupted program")
        except Exception as e:
            print(f"Runtime error: {e}")
            traceback.print_exc()
        finally:
            cap.release()
            cv2.destroyAllWindows()
            print("Program exited")
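

# Programmatic use, without the CLI (a sketch; the model path below is a
# placeholder):
#
#     classifier = RealtimePoseClassifier(model_path='pose_classifier_random_forest.pkl')
#     classifier.run()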


def main():
    """Main function."""
    parser = argparse.ArgumentParser(description='Real-time pose classifier')
    parser.add_argument('--model', '-m', type=str, default=None,
                        help='Model file path (auto-detect by default)')
    parser.add_argument('--camera', '-c', type=int, default=0,
                        help='Camera ID (default 0)')

    args = parser.parse_args()

    try:
        classifier = RealtimePoseClassifier(
            model_path=args.model,
            camera_id=args.camera
        )
        classifier.run()
    except Exception as e:
        print(f"Program startup failed: {e}")
        return 1

    return 0


if __name__ == "__main__":
    sys.exit(main())