Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import mediapipe as mp | |
| from mediapipe.python.solutions import drawing_utils as mp_drawing | |
| from PoseClassification.pose_embedding import FullBodyPoseEmbedding | |
| from PoseClassification.pose_classifier import PoseClassifier | |
| from PoseClassification.utils import EMADictSmoothing | |
| import time | |
# --- Pose pipeline configuration and module-level singletons ---------------

# Pose labels, including the "none" fallback reported when nothing is confident.
class_names = [
    "chair",
    "cobra",
    "dog",
    "goddess",
    "plank",
    "tree",
    "warrior",
    "none",
]

# Minimum smoothed score the best pose must reach before it is reported;
# anything lower is reported as "none".
# NOTE(review): with top_n_by_mean_distance=10 the per-class score presumably
# tops out at 10 — confirm against PoseClassifier.
position_threshold = 8.0

# MediaPipe pose estimator shared by every stream handler.
mp_pose = mp.solutions.pose
pose_tracker = mp_pose.Pose()

# Embedding + classifier matched against the pose samples stored in
# `pose_samples_folder`; the top_n_* arguments presumably bound how many of the
# closest samples are kept (TODO confirm in PoseClassifier).
pose_embedder = FullBodyPoseEmbedding()
pose_classifier = PoseClassifier(
    pose_embedder=pose_embedder,
    pose_samples_folder="data/yoga_poses_csvs_out",
    top_n_by_mean_distance=10,
    top_n_by_max_distance=30,
)

# Temporal smoothing of per-frame classification scores (an exponential moving
# average over recent results, going by the helper's name and arguments).
pose_classification_filter = EMADictSmoothing(alpha=0.2, window_size=10)
def check_major_current_position(positions_detected: dict, threshold_position) -> str:
    """Return the dominant pose label, or ``"none"`` when no label is confident.

    Args:
        positions_detected: Mapping of pose label -> score (e.g. the smoothed
            classifier output). May be empty.
        threshold_position: Minimum score the best label must reach; anything
            ``float()`` accepts. Below it, ``"none"`` is returned.

    Returns:
        The highest-scoring label (the first one on ties), or ``"none"``.
    """
    if not positions_detected:
        # max() over an empty mapping raises ValueError; no scores means no pose.
        return "none"
    best_label = max(positions_detected, key=positions_detected.get)
    if positions_detected[best_label] < float(threshold_position):
        return "none"
    return best_label
def process_frame(frame, bgr_input: bool = False):
    """Classify the yoga pose visible in a single video frame.

    Args:
        frame: Image array of shape ``(height, width, 3)``. Gradio's webcam
            component delivers RGB arrays, which is what MediaPipe Pose expects,
            so by default the frame is used as-is.
        bgr_input: Set to ``True`` when ``frame`` comes from OpenCV (BGR channel
            order); it is then converted to RGB before pose estimation.

    Returns:
        ``(label, frame)``: the dominant pose label (or ``"none"``) and the
        input frame, unmodified.
    """
    # MediaPipe needs RGB. The previous unconditional BGR->RGB conversion
    # swapped the red/blue channels of Gradio's (already RGB) frames.
    if bgr_input:
        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    else:
        frame_rgb = np.ascontiguousarray(frame)
    result = pose_tracker.process(image=frame_rgb)
    landmarks = result.pose_landmarks

    if landmarks is None:
        # No person detected: report a confident "none" without feeding the
        # smoothing filter.
        current_position = {"none": 10.0}
    else:
        frame_height, frame_width = frame.shape[0], frame.shape[1]
        # Normalised landmarks -> pixel coordinates. z is scaled by the width
        # because MediaPipe's z uses roughly the same scale as x.
        pose_landmarks = np.array(
            [
                [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
                for lmk in landmarks.landmark
            ],
            dtype=np.float32,
        )
        pose_classification = pose_classifier(pose_landmarks)
        current_position = pose_classification_filter(pose_classification)

    current_position_major = check_major_current_position(
        current_position, position_threshold
    )
    return current_position_major, frame
def yoga_position_from_stream():
    """Build the live yoga-pose streaming UI and wire up its event handlers.

    Intended to be called inside a ``gr.Blocks`` context (as in ``__main__``).
    All handlers share the mutable state below through closures, so every
    client served by this process sees the same pose timer and recording
    buffer.

    Returns:
        The ``gr.Column`` containing the whole UI.
    """
    output_path = "recorded_yoga_session.mp4"

    # --- Shared state ---------------------------------------------------------
    current_position = "none"
    position_timer = 0.0  # seconds spent in current_position
    last_update_time = None  # time.monotonic() of the previous frame; None before the first
    recording = False
    recorded_frames = []  # annotated frames captured while recording
    start_time = 0.0  # time.monotonic() of the first recorded frame
    last_frame_time = 0.0  # time.monotonic() of the most recent recorded frame
    frame_count = 0  # frames captured in the current recording
    saved_video_path = None  # set once save_video() has written a file

    def _duration_text():
        """Format the pose timer as shown in both the overlay and the textbox."""
        return f"Duration: {int(position_timer)} seconds"

    def classify_pose(frame):
        """Stream handler: classify one webcam frame, annotate it, update timers.

        Returns ``(frame, annotated_frame, pose_label, duration_text)`` for the
        ``video_feed``, ``classified_video``, ``pose_output`` and
        ``timer_output`` components. Drawing happens in place on the frame
        returned by ``process_frame``.
        """
        nonlocal current_position, position_timer, last_update_time
        nonlocal start_time, last_frame_time, frame_count
        if frame is None:
            return None, None, current_position, _duration_text()

        new_position, processed_frame = process_frame(frame)

        now = time.monotonic()
        if last_update_time is None or new_position != current_position:
            # New pose, or the very first frame: restart the timer. Without the
            # first-frame check, the initial "none" frame would add the raw
            # clock value (time since an arbitrary epoch) to the timer.
            current_position = new_position
            position_timer = 0.0
        else:
            position_timer += now - last_update_time
        last_update_time = now

        # NOTE(review): this runs pose estimation a second time on the same frame
        # only to get landmarks for drawing; returning them from process_frame
        # would halve the per-frame inference cost. Stream frames are RGB, so
        # they are passed to MediaPipe without a channel swap.
        drawing_landmarks = pose_tracker.process(
            np.ascontiguousarray(processed_frame)
        ).pose_landmarks
        mp_drawing.draw_landmarks(
            image=processed_frame,
            landmark_list=drawing_landmarks,
            connections=mp_pose.POSE_CONNECTIONS,
        )
        for text, y in ((f"Pose: {current_position}", 30), (_duration_text(), 70)):
            cv2.putText(
                processed_frame,
                text,
                (10, y),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (0, 255, 0),
                2,
            )

        if recording:
            recorded_frames.append(processed_frame)
            frame_count += 1
            if frame_count == 1:
                start_time = now
            # Remember when the last frame arrived so the saved FPS reflects the
            # real capture rate, not the time the user waited before saving.
            last_frame_time = now

        return frame, processed_frame, current_position, _duration_text()

    def toggle_debug(debug_mode):
        """Show the debug view and annotated feed (and hide the webcam) when on."""
        return [
            gr.update(visible=debug_mode),
            gr.update(visible=not debug_mode),
            gr.update(visible=debug_mode),
        ]

    def start_recording():
        """Begin buffering annotated frames, discarding any previous recording."""
        nonlocal recording, recorded_frames, start_time, last_frame_time, frame_count
        recording = True
        recorded_frames = []
        start_time = 0.0
        last_frame_time = 0.0
        frame_count = 0
        return "Recording started"

    def stop_recording():
        """Stop buffering frames; the buffer is kept for save_video()."""
        nonlocal recording
        recording = False
        return "Recording stopped"

    def save_video():
        """Encode the buffered frames to an MP4 at the measured capture rate.

        Returns:
            ``(path_or_None, status_message)`` for the video and status outputs.
        """
        nonlocal saved_video_path
        # Snapshot the buffer so frames appended by a concurrently running
        # stream event cannot change the list while it is being written.
        frames = list(recorded_frames)
        if not frames:
            return None, "No recorded frames available"

        height, width = frames[0].shape[:2]
        elapsed = last_frame_time - start_time
        # N frames span N - 1 inter-frame intervals. Fall back to 30 FPS when
        # there is no measurable span.
        if len(frames) > 1 and elapsed > 0:
            fps = (len(frames) - 1) / elapsed
        else:
            fps = 30.0

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        if not writer.isOpened():
            return None, f"Could not open a video writer for {output_path}"
        try:
            for rgb_frame in frames:
                if rgb_frame.shape[:2] != (height, width):
                    # VideoWriter may silently drop frames whose size differs from
                    # the size it was opened with (e.g. after a resolution change).
                    rgb_frame = cv2.resize(rgb_frame, (width, height))
                # Stream frames are RGB; OpenCV writes BGR.
                writer.write(cv2.cvtColor(rgb_frame, cv2.COLOR_RGB2BGR))
        finally:
            writer.release()

        saved_video_path = output_path
        return output_path, f"Video saved successfully at {fps:.2f} FPS"

    def download_video():
        """Offer the saved recording, or nothing if this session has not saved one."""
        return saved_video_path

    with gr.Column() as yoga_stream:
        gr.Markdown("# Yoga Position Classifier", elem_classes=["custom-title"])
        gr.Markdown(
            "Stream live yoga sessions and get real-time pose classification.",
            elem_classes=["custom-subtitle"],
        )
        with gr.Row():
            with gr.Column(scale=3):
                video_feed = gr.Webcam(streaming=True, elem_classes=["custom-webcam"])
            with gr.Column(scale=2):
                pose_output = gr.Textbox(
                    label="Current Pose", elem_classes=["custom-textbox"]
                )
                timer_output = gr.Textbox(
                    label="Pose Duration", elem_classes=["custom-textbox"]
                )
                debug_toggle = gr.Checkbox(
                    label="Debug Mode", value=False, elem_classes=["custom-checkbox"]
                )
        with gr.Column(visible=False) as debug_view:
            classified_video = gr.Image(
                label="Classified Video Feed", elem_classes=["custom-image"]
            )
            with gr.Row():
                start_button = gr.Button(
                    "Start Recording", elem_classes=["custom-button"]
                )
                stop_button = gr.Button(
                    "Stop Recording", elem_classes=["custom-button"]
                )
                save_button = gr.Button("Save Recording", elem_classes=["custom-button"])
            recording_status = gr.Textbox(
                label="Recording Status", elem_classes=["custom-textbox"]
            )
            recorded_video = gr.Video(
                label="Recorded Video", elem_classes=["custom-video"]
            )
            download_button = gr.Button(
                "Download Recorded Video", elem_classes=["custom-button"]
            )

        debug_toggle.change(
            toggle_debug,
            inputs=[debug_toggle],
            outputs=[debug_view, video_feed, classified_video],
        )
        video_feed.stream(
            classify_pose,
            inputs=[video_feed],
            outputs=[video_feed, classified_video, pose_output, timer_output],
            show_progress=False,
        )
        start_button.click(start_recording, outputs=[recording_status])
        stop_button.click(stop_recording, outputs=[recording_status])
        save_button.click(save_video, outputs=[recorded_video, recording_status])
        download_button.click(download_video, outputs=[gr.File()])

    return yoga_stream
if __name__ == "__main__":
    # Styles for the elem_classes assigned inside yoga_position_from_stream().
    css_rules = [
        ".custom-title { font-size: 36px; font-weight: bold; margin-bottom: 10px; }",
        ".custom-subtitle { font-size: 18px; margin-bottom: 20px; }",
        ".custom-webcam { height: 480px; }",
        ".custom-textbox input { font-size: 24px; }",
        ".custom-checkbox label { font-size: 18px; }",
        ".custom-button { font-size: 18px; }",
        ".custom-image img { max-height: 400px; }",
        ".custom-video video { max-height: 400px; }",
    ]
    app_css = "\n" + "\n".join(css_rules) + "\n"

    with gr.Blocks(css=app_css) as demo:
        yoga_position_from_stream()
    demo.launch()