Spaces:
Sleeping
Sleeping
| import tqdm | |
| import cv2 | |
| import numpy as np | |
| import re | |
| import os | |
| from mediapipe.python.solutions import drawing_utils as mp_drawing | |
| import mediapipe as mp | |
| from PoseClassification.pose_embedding import FullBodyPoseEmbedding | |
| from PoseClassification.pose_classifier import PoseClassifier | |
| from PoseClassification.utils import EMADictSmoothing | |
| # from PoseClassification.utils import RepetitionCounter | |
| from PoseClassification.visualize import PoseClassificationVisualizer | |
| import argparse | |
| from PoseClassification.utils import show_image | |
# Settings shared by the live (webcam) and recorded-video pipelines.
_CLASS_NAME = 'tree'
# Folder with pose class CSVs. It must be the same folder that was used while
# building (bootstrapping) the classifier.
_POSE_SAMPLES_FOLDER = 'data/yoga_poses_csvs_out'
# Naming base for the output files when reading from the webcam.
_LIVE_VIDEO_PATH = 'data/live.mp4'
# Logo drawn in the upper-right corner of the live display.
_LOGO_PATH = "src/logo_impredalam.jpg"
# Height in pixels of the white banner painted at the top of live frames.
_BANNER_HEIGHT = 180


def main():
    """Classify yoga poses in a video file or in the live webcam stream.

    Command line: one positional argument ``video_path``. The special value
    ``live`` reads from camera device 0 and shows the annotated stream in a
    window (press ``q`` to stop). Any other value is opened as a video file.

    Outputs are written next to the input. In live mode ``data/live.mp4`` is
    used as the naming base:

    * ``<stem>_classified_video.mp4``: the annotated video.
    * ``<stem>_classified_results.csv``: one ``frame_idx;position`` line per
      processed frame. The file is overwritten on each run.

    Returns:
        The position of the last processed frame. When a pose was detected,
        this is the value returned by ``EMADictSmoothing`` (presumably a
        class-name -> score dict; confirm in PoseClassification.utils).
        Otherwise it is ``{'None': 10.0}`` in live mode or the string
        ``'None'`` in video-file mode. If no frame was processed at all, it is
        ``{"none": 10.0}``.

    Raises:
        AssertionError: the video file reports no frames (e.g. it cannot be
            opened).
    """
    # Load arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument("video_path", help="string video path in")
    args = parser.parse_args()

    video_path_in = args.video_path
    direct_video = video_path_in == "live"
    if direct_video:
        video_path_in = _LIVE_VIDEO_PATH
    video_path_out, results_classification_path_out = _derive_output_paths(video_path_in)

    if direct_video:
        return _run_live(video_path_out, results_classification_path_out)
    return _run_recorded(video_path_in, video_path_out, results_classification_path_out)


def _derive_output_paths(video_path_in):
    """Return ``(video_path_out, results_csv_path)`` derived from the input path.

    The input's extension, whatever it is, is replaced. The outputs therefore
    can never collide with, and overwrite, the input file. The previous regex
    ``'.mp4'`` did allow that for non-``.mp4`` inputs.
    """
    stem, _ext = os.path.splitext(video_path_in)
    return stem + '_classified_video.mp4', stem + '_classified_results.csv'


def _build_pipeline(plot_x_max):
    """Create the pose tracker, classifier, EMA filter and plot visualizer.

    Returns:
        A 5-tuple ``(mp_pose, pose_tracker, pose_classifier,
        pose_classification_filter, pose_classification_visualizer)``. The
        caller must ``close()`` ``pose_tracker``.
    """
    # Initialize tracker.
    mp_pose = mp.solutions.pose
    pose_tracker = mp_pose.Pose()
    # Initialize embedder and classifier. Check that these are the same
    # parameters as during bootstrapping.
    pose_embedder = FullBodyPoseEmbedding()
    pose_classifier = PoseClassifier(
        pose_samples_folder=_POSE_SAMPLES_FOLDER,
        pose_embedder=pose_embedder,
        top_n_by_max_distance=30,
        top_n_by_mean_distance=10)
    # Initialize EMA smoothing.
    pose_classification_filter = EMADictSmoothing(
        window_size=10,
        alpha=0.2)
    # Initialize renderer.
    pose_classification_visualizer = PoseClassificationVisualizer(
        class_name=_CLASS_NAME,
        plot_x_max=plot_x_max,
        # Graphic looks nicer if it's the same as `top_n_by_mean_distance`.
        plot_y_max=10)
    return (mp_pose, pose_tracker, pose_classifier,
            pose_classification_filter, pose_classification_visualizer)


def _landmarks_to_array(pose_landmarks, frame_shape):
    """Convert MediaPipe pose landmarks to a (33, 3) float32 array.

    x and y are scaled by the frame width and height. z is scaled by the frame
    width, the same convention used for bootstrapping the classifier in this
    project (assumption; confirm against the bootstrap code).
    """
    frame_height, frame_width = frame_shape[0], frame_shape[1]
    landmarks = np.array(
        [[lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
         for lmk in pose_landmarks.landmark],
        dtype=np.float32)
    assert landmarks.shape == (33, 3), 'Unexpected landmarks shape: {}'.format(landmarks.shape)
    return landmarks


def _load_logo():
    """Load the banner logo at 1/3 scale, or return None if it cannot be read."""
    logo = cv2.imread(_LOGO_PATH)
    if logo is None:
        # cv2.imread returns None instead of raising on a missing/unreadable file.
        print(f"Warning: unable to read logo '{_LOGO_PATH}', continuing without it.")
        return None
    logo_height, logo_width = logo.shape[:2]
    return cv2.resize(logo, (logo_width // 3, logo_height // 3))


def _run_live(video_path_out, results_path_out):
    """Classify poses from webcam 0, display and record them until 'q' or end of stream.

    Returns the position of the last processed frame (see ``main``).
    """
    video_cap = cv2.VideoCapture(0)
    video_fps = 30
    # Use the size the camera really delivers: cv2.VideoWriter silently drops
    # frames whose size differs from the one it was opened with.
    video_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 1280
    video_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 720

    (mp_pose, pose_tracker, pose_classifier,
     pose_classification_filter, _visualizer) = _build_pipeline(plot_x_max=1000)
    # The logo is loaded once instead of on every frame.
    logo = _load_logo()
    out_video = cv2.VideoWriter(video_path_out, cv2.VideoWriter_fourcc(*'mp4v'),
                                video_fps, (video_width, video_height))
    current_position = {"none": 10.0}
    try:
        with open(results_path_out, 'w', encoding='utf-8') as results_file, \
                tqdm.tqdm(position=0, leave=True) as pbar:
            frame_idx = 0
            while True:
                success, input_frame = video_cap.read()
                if not success:
                    print("Unable to read input video frame, breaking!")
                    break

                # Run pose tracker (MediaPipe expects RGB, OpenCV delivers BGR).
                input_frame_rgb = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
                pose_landmarks = pose_tracker.process(image=input_frame_rgb).pose_landmarks

                # Prepare the output frame: white banner plus logo in the upper-right corner.
                output_frame = input_frame.copy()
                output_frame[0:_BANNER_HEIGHT, :] = (255, 255, 255)
                if (logo is not None
                        and logo.shape[0] <= output_frame.shape[0]
                        and logo.shape[1] <= output_frame.shape[1]):
                    output_frame[0:logo.shape[0], output_frame.shape[1] - logo.shape[1]:] = logo

                if pose_landmarks is not None:
                    mp_drawing.draw_landmarks(
                        image=output_frame,
                        landmark_list=pose_landmarks,
                        connections=mp_pose.POSE_CONNECTIONS,
                    )
                    landmarks = _landmarks_to_array(pose_landmarks, output_frame.shape)
                    # Classify the pose on the current frame, then smooth it using EMA.
                    current_position = pose_classification_filter(pose_classifier(landmarks))
                else:
                    # Still feed an empty classification to keep the EMA smoothing
                    # correct for future frames (same as the recorded-video path).
                    pose_classification_filter(dict())
                    current_position = {'None': 10.0}

                # Record the position of *this* frame.
                results_file.write(f'{frame_idx};{current_position}\n')

                cv2.putText(
                    output_frame,
                    f"Pose: {current_position}",
                    (10, 70),
                    cv2.FONT_HERSHEY_SIMPLEX,
                    1.2,  # Smaller font size
                    (0, 0, 0),
                    1,  # Thinner line
                    cv2.LINE_AA,
                )
                out_video.write(output_frame)
                cv2.imshow("Yoga position classification", output_frame)

                key = cv2.waitKey(1) & 0xFF
                if key == ord("q"):
                    break
                elif key == ord("r"):
                    # The repetition counter is currently disabled; the key binding is kept.
                    print("Counter reset!")
                frame_idx += 1
                pbar.update()
    finally:
        pose_tracker.close()
        video_cap.release()
        out_video.release()
        cv2.destroyAllWindows()
    return current_position


def _run_recorded(video_path_in, video_path_out, results_path_out):
    """Classify poses in a video file and write the annotated video and per-frame CSV.

    Returns the position of the last processed frame (see ``main``).
    Raises AssertionError if the video reports no frames.
    """
    # Open the video and read its parameters.
    video_cap = cv2.VideoCapture(video_path_in)
    video_n_frames = video_cap.get(cv2.CAP_PROP_FRAME_COUNT)
    video_fps = video_cap.get(cv2.CAP_PROP_FPS)
    video_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    video_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Explicit check (not `assert`) so it also runs under `python -O`; the
    # exception type is kept for backward compatibility.
    if not video_n_frames > 0.0:
        video_cap.release()
        raise AssertionError('Error in input video frames number : no frame. Abort.')

    (mp_pose, pose_tracker, pose_classifier,
     pose_classification_filter, pose_classification_visualizer) = _build_pipeline(
        plot_x_max=video_n_frames)
    out_video = cv2.VideoWriter(video_path_out, cv2.VideoWriter_fourcc(*'mp4v'),
                                video_fps, (video_width, video_height))
    current_position = {"none": 10.0}
    output_frame = None
    try:
        with open(results_path_out, 'w', encoding='utf-8') as results_file, \
                tqdm.tqdm(total=video_n_frames, position=0, leave=True) as pbar:
            frame_idx = 0
            while True:
                # Get next frame of the video.
                success, input_frame = video_cap.read()
                if not success:
                    print("unable to read input video frame, breaking!")
                    break

                # Run pose tracker; from here on the frame is RGB.
                input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB)
                pose_landmarks = pose_tracker.process(image=input_frame).pose_landmarks

                # Draw pose prediction.
                output_frame = input_frame.copy()
                if pose_landmarks is not None:
                    mp_drawing.draw_landmarks(
                        image=output_frame,
                        landmark_list=pose_landmarks,
                        connections=mp_pose.POSE_CONNECTIONS)
                    landmarks = _landmarks_to_array(pose_landmarks, output_frame.shape)
                    # Classify the pose on the current frame, then smooth it using EMA.
                    pose_classification = pose_classifier(landmarks)
                    pose_classification_filtered = pose_classification_filter(pose_classification)
                    current_position = pose_classification_filtered
                else:
                    # No pose => no classification on the current frame. Still add an
                    # empty classification to the filter to maintain correct smoothing
                    # for future frames.
                    pose_classification = None
                    pose_classification_filter(dict())
                    pose_classification_filtered = None
                    current_position = 'None'

                # Record the position of *this* frame.
                results_file.write(f'{frame_idx};{current_position}\n')

                # Draw classification plot (the repetition counter is disabled).
                output_frame = pose_classification_visualizer(
                    frame=output_frame,
                    pose_classification=pose_classification,
                    pose_classification_filtered=pose_classification_filtered,
                    repetitions_count='0'
                )
                # Save the output frame (back to BGR for OpenCV).
                out_video.write(cv2.cvtColor(np.array(output_frame), cv2.COLOR_RGB2BGR))

                # Show intermediate frames of the video to track progress.
                if frame_idx % 50 == 0:
                    show_image(output_frame)
                frame_idx += 1
                pbar.update()
    finally:
        # Release the output video, the MediaPipe resources and the input video.
        out_video.release()
        pose_tracker.close()
        video_cap.release()

    # Show the last frame of the video.
    if output_frame is not None:
        show_image(output_frame)
    return current_position
# Run the command-line entry point only when executed as a script, not on import.
if "__main__" == __name__:
    main()