Spaces:
Sleeping
Sleeping
| import tqdm | |
| import cv2 | |
| import numpy as np | |
| from mediapipe.python.solutions import drawing_utils as mp_drawing | |
| import mediapipe as mp | |
| from PoseClassification.pose_embedding import FullBodyPoseEmbedding | |
| from PoseClassification.pose_classifier import PoseClassifier | |
| from PoseClassification.utils import EMADictSmoothing | |
| from PoseClassification.utils import RepetitionCounter | |
| from PoseClassification.visualize import PoseClassificationVisualizer | |
# ---------------------------------------------------------------------------
# Configuration constants
# ---------------------------------------------------------------------------
# Folder handed to the pose classifier as its sample source, and the class
# name handed to the repetition counter and the visualizer.
pose_samples_folder = "data/fitness_poses_csvs_out"
class_name = "pushups_down"

# Capture settings. Only the width and height are applied to the capture
# below; video_fps is defined but not referenced anywhere in this section.
video_fps, video_width, video_height = 30, 1280, 720

# ---------------------------------------------------------------------------
# Pose tracking and classification pipeline
# ---------------------------------------------------------------------------
mp_pose = mp.solutions.pose
pose_tracker = mp_pose.Pose()

pose_embedder = FullBodyPoseEmbedding()
pose_classifier = PoseClassifier(
    pose_samples_folder=pose_samples_folder,
    pose_embedder=pose_embedder,
    top_n_by_max_distance=30,
    top_n_by_mean_distance=10,
)

# Filter applied to each raw classification result before it is counted.
pose_classification_filter = EMADictSmoothing(window_size=10, alpha=0.2)

# NOTE(review): the thresholds presumably act on the smoothed confidence of
# `class_name` -- confirm against RepetitionCounter.
repetition_counter = RepetitionCounter(
    class_name=class_name, enter_threshold=6, exit_threshold=4
)

# Constructed here; not referenced again in the visible part of the script.
pose_classification_visualizer = PoseClassificationVisualizer(
    class_name=class_name, plot_x_max=1000, plot_y_max=10
)

# ---------------------------------------------------------------------------
# Video source
# ---------------------------------------------------------------------------
# Capture device index 0, with the frame size requested via capture
# properties (width first, then height).
video_cap = cv2.VideoCapture(0)
for _prop_id, _prop_value in (
    (cv2.CAP_PROP_FRAME_WIDTH, video_width),
    (cv2.CAP_PROP_FRAME_HEIGHT, video_height),
):
    video_cap.set(_prop_id, _prop_value)

# Loop state: number of frames processed so far and the last composed frame.
frame_idx, output_frame = 0, None
def _load_logo(path, shrink_factor=3):
    """Read the logo at *path* and shrink it by *shrink_factor*.

    Returns the resized image, or None when the file cannot be read
    (cv2.imread reports failure by returning None instead of raising).
    """
    image = cv2.imread(path)
    if image is None:
        print(f"Unable to read logo image '{path}', continuing without it.")
        return None
    height, width = image.shape[:2]
    # Never ask cv2.resize for a zero-sized target on very small images.
    target_size = (
        max(1, width // shrink_factor),
        max(1, height // shrink_factor),
    )
    return cv2.resize(image, target_size)


def _overlay_top_right(frame, overlay):
    """Paste *overlay* into the upper-right corner of *frame*, in place.

    The overlay is cropped to the frame size first, so a frame smaller than
    the overlay no longer makes the slice assignment fail.
    """
    height = min(overlay.shape[0], frame.shape[0])
    width = min(overlay.shape[1], frame.shape[1])
    frame[0:height, frame.shape[1] - width :] = overlay[0:height, 0:width]


def _draw_label(frame, text, origin, scale, color, thickness):
    """Draw *text* on *frame* at *origin* with the script's common font."""
    cv2.putText(
        frame,
        text,
        origin,
        cv2.FONT_HERSHEY_SIMPLEX,
        scale,
        color,
        thickness,
        cv2.LINE_AA,
    )


# Number of rows at the top of every output frame that are painted white.
banner_height = 180

try:
    # The logo never changes between frames, so read and resize it once
    # instead of hitting the disk on every loop iteration.
    logo = _load_logo("src/logo_impredalam.jpg")

    with tqdm.tqdm(position=0, leave=True) as pbar:
        while True:
            success, input_frame = video_cap.read()
            if not success:
                print("Unable to read input video frame, breaking!")
                break

            # Run pose tracker on an RGB copy of the captured BGR frame.
            input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
            result = pose_tracker.process(image=input_frame_rgb)
            pose_landmarks = result.pose_landmarks

            # Prepare the output frame: white banner on top, logo in the
            # upper-right corner (skipped when the logo could not be loaded).
            output_frame = input_frame.copy()
            output_frame[0:banner_height, :] = (255, 255, 255)
            if logo is not None:
                _overlay_top_right(output_frame, logo)

            if pose_landmarks is not None:
                mp_drawing.draw_landmarks(
                    image=output_frame,
                    landmark_list=pose_landmarks,
                    connections=mp_pose.POSE_CONNECTIONS,
                )

                # Scale the landmark coordinates to pixels; z is scaled by
                # the frame width, the same factor as x.
                frame_height, frame_width = output_frame.shape[:2]
                pose_landmarks = np.array(
                    [
                        [
                            lmk.x * frame_width,
                            lmk.y * frame_height,
                            lmk.z * frame_width,
                        ]
                        for lmk in pose_landmarks.landmark
                    ],
                    dtype=np.float32,
                )
                assert pose_landmarks.shape == (
                    33,
                    3,
                ), "Unexpected landmarks shape: {}".format(pose_landmarks.shape)

                # Classify the pose, smooth the result, then count reps.
                pose_classification = pose_classifier(pose_landmarks)
                pose_classification_filtered = pose_classification_filter(
                    pose_classification
                )
                repetitions_count = repetition_counter(
                    pose_classification_filtered
                )

                _draw_label(
                    output_frame,
                    f"Push-Ups: {repetitions_count}",
                    (10, 30),
                    1,
                    (0, 0, 0),
                    2,
                )
                # The raw (unsmoothed) classification is what gets shown.
                _draw_label(
                    output_frame,
                    f"Pose: {pose_classification}",
                    (10, 70),
                    1.2,
                    (0, 0, 0),
                    1,
                )
            else:
                # No landmarks detected: still display the last count,
                # drawn in a different colour than the tracked case.
                repetitions_count = repetition_counter.n_repeats
                _draw_label(
                    output_frame,
                    f"Push-Ups: {repetitions_count}",
                    (10, 30),
                    1,
                    (0, 255, 0),
                    2,
                )

            cv2.imshow("Push-Up Counter", output_frame)

            # Keyboard controls: "q" quits, "r" resets the counter.
            key = cv2.waitKey(1) & 0xFF
            if key == ord("q"):
                break
            elif key == ord("r"):
                repetition_counter.reset()
                print("Counter reset!")

            frame_idx += 1
            pbar.update()
finally:
    # Always release the tracker, the camera and the GUI window, including
    # when the loop exits through an exception.
    pose_tracker.close()
    video_cap.release()
    cv2.destroyAllWindows()