"""Real-time push-up counter using MediaPipe Pose and a k-NN pose classifier.

Reads frames from the default webcam, classifies each frame's pose against
pre-computed pose samples, smooths the per-class confidences with an EMA
filter, and counts push-up repetitions from enter/exit threshold crossings.
Press 'q' to quit, 'r' to reset the counter.
"""

import tqdm
import cv2
import numpy as np
import mediapipe as mp
from mediapipe.python.solutions import drawing_utils as mp_drawing

from PoseClassification.pose_embedding import FullBodyPoseEmbedding
from PoseClassification.pose_classifier import PoseClassifier
from PoseClassification.utils import EMADictSmoothing
from PoseClassification.utils import RepetitionCounter
from PoseClassification.visualize import PoseClassificationVisualizer

mp_pose = mp.solutions.pose

# Folder with CSVs of pre-computed pose samples (one CSV per pose class).
POSE_SAMPLES_FOLDER = "data/fitness_poses_csvs_out"
# Class whose enter/exit transitions are counted as one repetition.
CLASS_NAME = "pushups_down"
# Logo overlaid in the upper-right corner of every output frame.
LOGO_PATH = "src/logo_impredalam.jpg"

# Requested capture resolution (the camera may pick the nearest mode).
VIDEO_WIDTH = 1280
VIDEO_HEIGHT = 720
# Height in pixels of the white banner painted over the top of each frame.
BANNER_HEIGHT = 180


def _load_logo(path):
    """Load and downscale the logo once; return None if the file is unreadable.

    cv2.imread returns None (it does not raise) when the file is missing, so
    the caller must tolerate a None result.
    """
    logo = cv2.imread(path)
    if logo is None:
        print(f"Warning: could not read logo at {path!r}; continuing without it.")
        return None
    logo_height, logo_width = logo.shape[:2]
    # Resize to 1/3 scale so the logo fits inside the banner.
    return cv2.resize(logo, (logo_width // 3, logo_height // 3))


def _landmarks_to_array(pose_landmarks, frame_width, frame_height):
    """Convert MediaPipe normalized landmarks to a (33, 3) float32 pixel array.

    x/y are scaled to pixel coordinates; z is scaled by frame width, matching
    the convention the classifier's training samples were built with.

    Raises:
        ValueError: if the landmark array is not shaped (33, 3).
    """
    landmarks = np.array(
        [
            [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
            for lmk in pose_landmarks.landmark
        ],
        dtype=np.float32,
    )
    if landmarks.shape != (33, 3):
        # Explicit raise instead of `assert`: not stripped under `python -O`.
        raise ValueError("Unexpected landmarks shape: {}".format(landmarks.shape))
    return landmarks


def main():
    """Run the capture/classify/count loop until 'q' is pressed or video ends."""
    pose_tracker = mp_pose.Pose()

    pose_embedder = FullBodyPoseEmbedding()
    pose_classifier = PoseClassifier(
        pose_samples_folder=POSE_SAMPLES_FOLDER,
        pose_embedder=pose_embedder,
        top_n_by_max_distance=30,
        top_n_by_mean_distance=10,
    )
    # EMA smoothing damps per-frame classification jitter before counting.
    pose_classification_filter = EMADictSmoothing(window_size=10, alpha=0.2)
    repetition_counter = RepetitionCounter(
        class_name=CLASS_NAME, enter_threshold=6, exit_threshold=4
    )
    # NOTE(review): constructed but never used below — kept for parity with
    # the original script; confirm whether plot rendering should be wired in.
    pose_classification_visualizer = PoseClassificationVisualizer(
        class_name=CLASS_NAME, plot_x_max=1000, plot_y_max=10
    )

    video_cap = cv2.VideoCapture(0)
    video_cap.set(cv2.CAP_PROP_FRAME_WIDTH, VIDEO_WIDTH)
    video_cap.set(cv2.CAP_PROP_FRAME_HEIGHT, VIDEO_HEIGHT)

    # Bug fix: load and resize the logo ONCE instead of reading it from disk
    # on every frame, and survive a missing file instead of crashing.
    logo = _load_logo(LOGO_PATH)

    try:
        with tqdm.tqdm(position=0, leave=True) as pbar:
            while True:
                success, input_frame = video_cap.read()
                if not success:
                    print("Unable to read input video frame, breaking!")
                    break

                # Run pose tracker (MediaPipe expects RGB; OpenCV yields BGR).
                input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
                result = pose_tracker.process(image=input_frame_rgb)
                pose_landmarks = result.pose_landmarks

                # Prepare the output frame with a white banner on top.
                output_frame = input_frame.copy()
                output_frame[0:BANNER_HEIGHT, :] = (255, 255, 255)

                # Overlay the logo in the upper-right corner (if it loaded).
                if logo is not None:
                    output_frame[
                        0 : logo.shape[0], output_frame.shape[1] - logo.shape[1] :
                    ] = logo

                if pose_landmarks is not None:
                    mp_drawing.draw_landmarks(
                        image=output_frame,
                        landmark_list=pose_landmarks,
                        connections=mp_pose.POSE_CONNECTIONS,
                    )

                    frame_height, frame_width = output_frame.shape[:2]
                    landmarks = _landmarks_to_array(
                        pose_landmarks, frame_width, frame_height
                    )

                    # Classify the current frame, smooth with EMA, then count.
                    pose_classification = pose_classifier(landmarks)
                    pose_classification_filtered = pose_classification_filter(
                        pose_classification
                    )
                    repetitions_count = repetition_counter(
                        pose_classification_filtered
                    )

                    # Display repetitions count on the frame.
                    cv2.putText(
                        output_frame,
                        f"Push-Ups: {repetitions_count}",
                        (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1,
                        (0, 0, 0),
                        2,
                        cv2.LINE_AA,
                    )
                    # Display raw (unsmoothed) classification on the frame.
                    cv2.putText(
                        output_frame,
                        f"Pose: {pose_classification}",
                        (10, 70),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1.2,
                        (0, 0, 0),
                        1,
                        cv2.LINE_AA,
                    )
                else:
                    # No landmarks detected: still display the last count
                    # (green, to signal that tracking was lost this frame).
                    repetitions_count = repetition_counter.n_repeats
                    cv2.putText(
                        output_frame,
                        f"Push-Ups: {repetitions_count}",
                        (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        1,
                        (0, 255, 0),
                        2,
                        cv2.LINE_AA,
                    )

                cv2.imshow("Push-Up Counter", output_frame)
                key = cv2.waitKey(1) & 0xFF
                if key == ord("q"):
                    break
                elif key == ord("r"):
                    repetition_counter.reset()
                    print("Counter reset!")

                pbar.update()
    finally:
        # Release the tracker, camera, and GUI even if the loop raised.
        pose_tracker.close()
        video_cap.release()
        cv2.destroyAllWindows()


if __name__ == "__main__":
    main()