import time

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
from mediapipe.python.solutions import drawing_utils as mp_drawing

from PoseClassification.pose_classifier import PoseClassifier
from PoseClassification.pose_embedding import FullBodyPoseEmbedding
from PoseClassification.utils import EMADictSmoothing

# --- Module-level singletons: pose model, embedder, classifier, smoother ---
mp_pose = mp.solutions.pose
pose_tracker = mp_pose.Pose()
pose_embedder = FullBodyPoseEmbedding()
pose_classifier = PoseClassifier(
    pose_samples_folder="data/yoga_poses_csvs_out",
    pose_embedder=pose_embedder,
    top_n_by_max_distance=30,
    top_n_by_mean_distance=10,
)
# Exponential-moving-average smoothing of the per-frame classification scores.
pose_classification_filter = EMADictSmoothing(window_size=10, alpha=0.2)

class_names = ["chair", "cobra", "dog", "goddess", "plank", "tree", "warrior", "none"]
# Minimum smoothed score a pose must reach before it is reported as current.
position_threshold = 8.0


def check_major_current_position(positions_detected: dict, threshold_position) -> str:
    """Return the dominant pose name, or "none" if no score clears the threshold.

    Args:
        positions_detected: mapping of pose name -> smoothed confidence score.
        threshold_position: minimum score required to report a pose.

    Returns:
        The pose name with the highest score, or "none".
    """
    # Guard: max() on an empty dict would raise ValueError.
    if not positions_detected:
        return "none"
    if max(positions_detected.values()) < float(threshold_position):
        return "none"
    return max(positions_detected, key=positions_detected.get)


def process_frame(frame):
    """Run pose detection and classification on a single video frame.

    Args:
        frame: HxWx3 image array from the webcam stream.

    Returns:
        (major_pose_name, frame, raw_landmarks) where raw_landmarks is the
        MediaPipe landmark list (or None when no person was detected), returned
        so the caller can draw the skeleton WITHOUT re-running inference.
    """
    # NOTE(review): this converts BGR->RGB, but Gradio webcam frames are
    # usually already RGB — confirm the stream's actual color order.
    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    result = pose_tracker.process(image=frame_rgb)
    raw_landmarks = result.pose_landmarks

    if raw_landmarks is not None:
        frame_height, frame_width = frame.shape[0], frame.shape[1]
        # Scale normalized landmarks to pixel coordinates; z is scaled by the
        # frame width, matching the classifier's expected convention.
        landmarks = np.array(
            [
                [lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width]
                for lmk in raw_landmarks.landmark
            ],
            dtype=np.float32,
        )
        pose_classification = pose_classifier(landmarks)
        current_position = pose_classification_filter(pose_classification)
    else:
        # No person detected: report a dominant "none" score.
        current_position = {"none": 10.0}

    current_position_major = check_major_current_position(
        current_position, position_threshold
    )
    return current_position_major, frame, raw_landmarks


def yoga_position_from_stream():
    """Build the Gradio UI for live yoga-pose classification and recording.

    Returns:
        The root gr.Column containing the whole interface.
    """
    # Shared state captured by the nested callbacks below.
    current_position = "none"
    position_timer = 0
    last_update_time = 0
    recording = False
    recorded_frames = []
    start_time = 0
    frame_count = 0

    def classify_pose(frame):
        """Per-frame stream callback: classify, annotate, optionally record."""
        nonlocal current_position, position_timer, last_update_time
        nonlocal recording, recorded_frames, start_time, frame_count

        if frame is None:
            return (
                None,
                None,
                current_position,
                f"Duration: {int(position_timer)} seconds",
            )

        new_position, processed_frame, landmarks = process_frame(frame)

        # Reset the hold timer whenever the detected pose changes; otherwise
        # accumulate wall-clock time spent in the current pose.
        if new_position != current_position:
            current_position = new_position
            position_timer = 0
            last_update_time = cv2.getTickCount() / cv2.getTickFrequency()
        else:
            current_time = cv2.getTickCount() / cv2.getTickFrequency()
            position_timer += current_time - last_update_time
            last_update_time = current_time

        # Draw the skeleton from the landmarks computed in process_frame.
        # (Previously the pose model ran a SECOND time per frame just for
        # drawing, doubling inference cost and perturbing tracker state.)
        if landmarks is not None:
            mp_drawing.draw_landmarks(
                image=processed_frame,
                landmark_list=landmarks,
                connections=mp_pose.POSE_CONNECTIONS,
            )

        cv2.putText(
            processed_frame,
            f"Pose: {current_position}",
            (10, 30),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 255, 0),
            2,
        )
        cv2.putText(
            processed_frame,
            f"Duration: {int(position_timer)} seconds",
            (10, 70),
            cv2.FONT_HERSHEY_SIMPLEX,
            1,
            (0, 255, 0),
            2,
        )

        if recording:
            recorded_frames.append(processed_frame)
            frame_count += 1
            # Anchor the FPS computation to the first recorded frame.
            if frame_count == 1:
                start_time = time.time()

        return (
            frame,
            processed_frame,
            current_position,
            f"Duration: {int(position_timer)} seconds",
        )

    def toggle_debug(debug_mode):
        """Show the annotated feed (and recording controls) in debug mode."""
        return [
            gr.update(visible=debug_mode),
            gr.update(visible=not debug_mode),
            gr.update(visible=debug_mode),
        ]

    def start_recording():
        """Begin capturing annotated frames; resets any previous session."""
        nonlocal recording, recorded_frames, start_time, frame_count
        recording = True
        recorded_frames = []
        start_time = 0
        frame_count = 0
        return "Recording started"

    def stop_recording():
        """Stop capturing frames (recorded frames are kept for saving)."""
        nonlocal recording
        recording = False
        return "Recording stopped"

    def save_video():
        """Write the recorded frames to an MP4 file at the measured FPS.

        Returns:
            (output_path, status_message); output_path is None on failure.
        """
        nonlocal recorded_frames, start_time, frame_count
        if not recorded_frames:
            return None, "No recorded frames available"

        output_path = "recorded_yoga_session.mp4"
        height, width, _ = recorded_frames[0].shape

        # Derive the actual capture frame rate; fall back to 30 FPS if the
        # elapsed time is unusable.
        elapsed_time = time.time() - start_time
        fps = frame_count / elapsed_time if elapsed_time > 0 else 30.0

        fourcc = cv2.VideoWriter_fourcc(*"mp4v")
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
        for recorded in recorded_frames:
            # NOTE(review): frames are converted RGB->BGR for VideoWriter,
            # which assumes they were RGB in memory — consistent with the
            # Gradio stream but not with the BGR2RGB call in process_frame;
            # verify the actual color order end to end.
            frame_bgr = cv2.cvtColor(recorded, cv2.COLOR_RGB2BGR)
            out.write(frame_bgr)
        out.release()
        return output_path, f"Video saved successfully at {fps:.2f} FPS"

    with gr.Column() as yoga_stream:
        gr.Markdown("# Yoga Position Classifier", elem_classes=["custom-title"])
        gr.Markdown(
            "Stream live yoga sessions and get real-time pose classification.",
            elem_classes=["custom-subtitle"],
        )
        with gr.Row():
            with gr.Column(scale=3):
                video_feed = gr.Webcam(streaming=True, elem_classes=["custom-webcam"])
            with gr.Column(scale=2):
                pose_output = gr.Textbox(
                    label="Current Pose", elem_classes=["custom-textbox"]
                )
                timer_output = gr.Textbox(
                    label="Pose Duration", elem_classes=["custom-textbox"]
                )
                debug_toggle = gr.Checkbox(
                    label="Debug Mode", value=False, elem_classes=["custom-checkbox"]
                )
        with gr.Column(visible=False) as debug_view:
            classified_video = gr.Image(
                label="Classified Video Feed", elem_classes=["custom-image"]
            )
            with gr.Row():
                start_button = gr.Button(
                    "Start Recording", elem_classes=["custom-button"]
                )
                stop_button = gr.Button(
                    "Stop Recording", elem_classes=["custom-button"]
                )
                save_button = gr.Button("Save Recording", elem_classes=["custom-button"])
            recording_status = gr.Textbox(
                label="Recording Status", elem_classes=["custom-textbox"]
            )
            recorded_video = gr.Video(
                label="Recorded Video", elem_classes=["custom-video"]
            )
            download_button = gr.Button(
                "Download Recorded Video", elem_classes=["custom-button"]
            )
            # Declared explicitly in the layout instead of being instantiated
            # inside the click() outputs list (which implicitly inserted an
            # anonymous component into the page).
            download_file = gr.File()

        debug_toggle.change(
            toggle_debug,
            inputs=[debug_toggle],
            outputs=[debug_view, video_feed, classified_video],
        )
        video_feed.stream(
            classify_pose,
            inputs=[video_feed],
            outputs=[video_feed, classified_video, pose_output, timer_output],
            show_progress=False,
        )
        start_button.click(start_recording, outputs=[recording_status])
        stop_button.click(stop_recording, outputs=[recording_status])
        save_button.click(save_video, outputs=[recorded_video, recording_status])
        download_button.click(
            lambda: "recorded_yoga_session.mp4", outputs=[download_file]
        )

    return yoga_stream


if __name__ == "__main__":
    with gr.Blocks(
        css="""
        .custom-title { font-size: 36px; font-weight: bold; margin-bottom: 10px; }
        .custom-subtitle { font-size: 18px; margin-bottom: 20px; }
        .custom-webcam { height: 480px; }
        .custom-textbox input { font-size: 24px; }
        .custom-checkbox label { font-size: 18px; }
        .custom-button { font-size: 18px; }
        .custom-image img { max-height: 400px; }
        .custom-video video { max-height: 400px; }
        """
    ) as demo:
        yoga_position_from_stream()
    demo.launch()