Spaces:
Running
Running
| import os | |
| import json | |
| import math | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Dict, Tuple | |
| import cv2 | |
| import numpy as np | |
| import mediapipe as mp | |
| import gradio as gr | |
# ------------------------------
# Configuration
# ------------------------------
# JSON file holding the reference joint angles for each pose; created with
# built-in defaults by load_reference_poses() on first run.
REFERENCE_POSES_FILE = "reference_poses.json"
# Default angular tolerance in degrees used when scoring a detected joint
# angle against its reference value.
DEFAULT_TOLERANCE = 15.0

mp_pose = mp.solutions.pose
LANDMARK = mp_pose.PoseLandmark

# Each named joint angle is defined by three MediaPipe landmark indices:
# (parent, joint, child). The angle is measured at `joint` between the
# segments joint->parent and joint->child.
JOINT_TRIPLETS = {
    "left_elbow": (LANDMARK.LEFT_SHOULDER, LANDMARK.LEFT_ELBOW, LANDMARK.LEFT_WRIST),
    "right_elbow": (LANDMARK.RIGHT_SHOULDER, LANDMARK.RIGHT_ELBOW, LANDMARK.RIGHT_WRIST),
    "left_shoulder": (LANDMARK.LEFT_HIP, LANDMARK.LEFT_SHOULDER, LANDMARK.LEFT_ELBOW),
    "right_shoulder": (LANDMARK.RIGHT_HIP, LANDMARK.RIGHT_SHOULDER, LANDMARK.RIGHT_ELBOW),
    "left_knee": (LANDMARK.LEFT_HIP, LANDMARK.LEFT_KNEE, LANDMARK.LEFT_ANKLE),
    "right_knee": (LANDMARK.RIGHT_HIP, LANDMARK.RIGHT_KNEE, LANDMARK.RIGHT_ANKLE),
    "left_hip": (LANDMARK.LEFT_SHOULDER, LANDMARK.LEFT_HIP, LANDMARK.LEFT_KNEE),
    "right_hip": (LANDMARK.RIGHT_SHOULDER, LANDMARK.RIGHT_HIP, LANDMARK.RIGHT_KNEE),
}
| # ------------------------------ | |
| # Utility functions | |
| # ------------------------------ | |
def load_reference_poses(path: str = REFERENCE_POSES_FILE) -> Dict:
    """Load the reference pose angle table from *path*.

    If the file does not exist yet, a default table (Warrior II, Tree,
    Downward Dog) is written to *path* and returned, so the app always has
    something to compare against and users can later edit the file to tune
    the reference angles.

    Args:
        path: Location of the JSON file mapping pose name to a
            {joint_name: reference_angle_degrees} dict.

    Returns:
        Dict mapping pose names to per-joint reference angles in degrees.
    """
    if not os.path.exists(path):
        default = {
            "Warrior II": {
                "left_elbow": 170, "right_elbow": 170,
                "left_shoulder": 90, "right_shoulder": 90,
                "left_knee": 90, "right_knee": 170,
                "left_hip": 170, "right_hip": 170
            },
            "Tree": {
                "left_elbow": 170, "right_elbow": 170,
                "left_shoulder": 120, "right_shoulder": 120,
                "left_knee": 170, "right_knee": 40,
                "left_hip": 170, "right_hip": 40
            },
            "Downward Dog": {
                "left_elbow": 170, "right_elbow": 170,
                "left_shoulder": 70, "right_shoulder": 70,
                "left_knee": 170, "right_knee": 170,
                "left_hip": 160, "right_hip": 160
            }
        }
        # Persist the defaults; explicit UTF-8 avoids platform-dependent
        # encodings (e.g. cp1252 on Windows).
        with open(path, "w", encoding="utf-8") as f:
            json.dump(default, f, indent=2)
        return default
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
def vector(a, b):
    """Return the 2-D displacement vector pointing from point *a* to *b*."""
    (ax, ay), (bx, by) = a, b
    return np.array([bx - ax, by - ay])
def angle_between_points(a, b, c):
    """Return the angle in degrees at vertex *b* formed by the points a-b-c."""
    # Displacement vectors from the vertex toward each endpoint.
    ba = np.array([a[0] - b[0], a[1] - b[1]])
    bc = np.array([c[0] - b[0], c[1] - b[1]])
    # Tiny epsilon in the denominator guards against degenerate
    # (coincident-point) inputs dividing by zero.
    denom = np.linalg.norm(ba) * np.linalg.norm(bc) + 1e-8
    cos_theta = np.clip(ba.dot(bc) / denom, -1.0, 1.0)
    return math.degrees(math.acos(cos_theta))
def landmarks_to_xy(landmark_list, width, height):
    """Convert normalized MediaPipe landmarks to pixel coordinates.

    Returns a dict mapping landmark index -> (x_px, y_px, visibility),
    where x/y are scaled by the frame width/height.
    """
    return {
        index: (lm.x * width, lm.y * height, lm.visibility)
        for index, lm in enumerate(landmark_list.landmark)
    }
def compute_joint_angles(landmarks_xy: Dict[int, Tuple[float, float, float]]) -> Dict[str, float]:
    """Compute the angle in degrees for every joint in JOINT_TRIPLETS.

    A joint maps to None when any of its three landmarks is missing from
    *landmarks_xy* or has visibility below 0.3.
    """
    angles = {}
    for name, (parent_idx, joint_idx, child_idx) in JOINT_TRIPLETS.items():
        try:
            parent = landmarks_xy[parent_idx]
            joint = landmarks_xy[joint_idx]
            child = landmarks_xy[child_idx]
        except KeyError:
            angles[name] = None
            continue
        # Skip joints MediaPipe is not confident about (visibility < 0.3).
        if min(parent[2], joint[2], child[2]) < 0.3:
            angles[name] = None
        else:
            angles[name] = angle_between_points(parent[:2], joint[:2], child[:2])
    return angles
def compare_angles(detected, reference, tolerance=DEFAULT_TOLERANCE):
    """Score detected joint angles against a reference pose.

    A joint scores 100 when it matches the reference exactly and falls
    linearly to 0 once it deviates by 2*tolerance degrees. Undetected
    joints (None) receive None score/diff and are excluded from the
    overall average.

    Returns:
        (overall_percent, per_joint_score, per_joint_diff)
    """
    per_joint_score = {}
    per_joint_diff = {}
    for joint, ref_angle in reference.items():
        measured = detected.get(joint)
        if measured is None:
            per_joint_score[joint] = None
            per_joint_diff[joint] = None
            continue
        delta = measured - ref_angle
        per_joint_diff[joint] = delta
        # Linear falloff, floored at 0 and capped at 100.
        raw = 100.0 * (1.0 - abs(delta) / (2.0 * tolerance))
        per_joint_score[joint] = float(np.clip(max(raw, 0.0), 0.0, 100.0))
    scored = [s for s in per_joint_score.values() if s is not None]
    overall = float(np.mean(scored)) if scored else 0.0
    return overall, per_joint_score, per_joint_diff
| # ------------------------------ | |
| # Video processing | |
| # ------------------------------ | |
def _joint_color(score):
    """Map a per-joint score to a BGR line color: red (<33), orange (<66), else green."""
    if score is not None and score < 33:
        return (0, 0, 255)
    if score is not None and score < 66:
        return (0, 165, 255)
    return (0, 255, 0)


def process_video(input_path: str, pose_name: str, tolerance: float = DEFAULT_TOLERANCE):
    """Annotate a video with pose-correctness feedback for *pose_name*.

    Every frame is run through MediaPipe Pose; each joint segment is drawn
    colored by how closely its angle matches the reference pose, and the
    per-frame correctness percentage is overlaid.

    Args:
        input_path: Path to the uploaded video file.
        pose_name: Key into the reference-poses table.
        tolerance: Angular tolerance in degrees passed to compare_angles().

    Returns:
        (annotated_video_path, result_dict) on success, or
        (None, error_message) when the pose name is unknown or the video
        cannot be opened.
    """
    ref_poses = load_reference_poses()
    if pose_name not in ref_poses:
        return None, f"Pose '{pose_name}' not found."
    reference = ref_poses[pose_name]

    cap = cv2.VideoCapture(input_path)
    if not cap.isOpened():
        return None, "Failed to open uploaded video."

    # Fall back to 20 fps when the container does not report a frame rate.
    fps = cap.get(cv2.CAP_PROP_FPS) or 20.0
    # Named property constants instead of the magic indices 3 and 4.
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out_path = os.path.join(tempfile.gettempdir(), f"annotated_{Path(input_path).stem}.mp4")
    out = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
    pose = mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5, min_tracking_confidence=0.5)

    aggregate_scores = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # MediaPipe expects RGB; OpenCV delivers BGR.
            results = pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            annotated = frame.copy()
            if results.pose_landmarks:
                lm_xy = landmarks_to_xy(results.pose_landmarks, width, height)
                detected_angles = compute_joint_angles(lm_xy)
                percent, per_joint_score, _per_joint_diff = compare_angles(detected_angles, reference, tolerance)
                aggregate_scores.append(percent)
                for joint, (p, j, c) in JOINT_TRIPLETS.items():
                    # .get tolerates joints present in JOINT_TRIPLETS but
                    # absent from the reference table (unscored -> green).
                    color = _joint_color(per_joint_score.get(joint))
                    # Draw both halves of the joint: parent->joint and joint->child.
                    if j in lm_xy and p in lm_xy:
                        cv2.line(annotated, (int(lm_xy[p][0]), int(lm_xy[p][1])),
                                 (int(lm_xy[j][0]), int(lm_xy[j][1])), color, 3)
                    if j in lm_xy and c in lm_xy:
                        cv2.line(annotated, (int(lm_xy[j][0]), int(lm_xy[j][1])),
                                 (int(lm_xy[c][0]), int(lm_xy[c][1])), color, 3)
                cv2.putText(annotated, f"{pose_name}: {percent:.0f}%", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
            else:
                cv2.putText(annotated, "No pose detected", (10, 30),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            out.write(annotated)
    finally:
        # Release capture/writer/model even if OpenCV or MediaPipe raises
        # mid-stream; the original leaked all three on error.
        cap.release()
        out.release()
        pose.close()

    avg_score = float(np.mean(aggregate_scores)) if aggregate_scores else 0.0
    return out_path, {
        "pose": pose_name,
        "score_percent": avg_score,
        "suggestions": [
            f"Try maintaining stability. Overall correctness: {avg_score:.1f}%."
        ],
    }
# ------------------------------
# Gradio Interface
# ------------------------------
# Load the reference table at import time so the dropdown reflects the
# poses on disk (the file is created with defaults on first run).
ref_poses = load_reference_poses()
pose_list = list(ref_poses.keys())

with gr.Blocks(title="Yoga Pose Correctness Checker") as demo:
    gr.Markdown("""
# π§ Yoga Pose Correctness Checker
Upload a short video of your yoga pose.
The app will analyze:
- β Pose correctness percentage
- π Joint-by-joint feedback
- π‘ Suggestions for improvement
""")
    # Input widgets: video upload, pose selection, and scoring tolerance.
    video_in = gr.Video(label="Upload a video (MP4/MOV)")
    pose_dropdown = gr.Dropdown(choices=pose_list, value=pose_list[0], label="Select Pose")
    tol_slider = gr.Slider(5, 40, value=DEFAULT_TOLERANCE, step=1, label="Tolerance (degrees)")
    run_btn = gr.Button("Analyze Pose")
    # Output widgets: the annotated video and the JSON score/suggestions.
    output_video = gr.Video(label="Annotated Video Output")
    output_json = gr.JSON(label="Results and Suggestions")

    def analyze(video_path, pose_name, tolerance):
        # Thin wrapper around process_video() so failures surface in the
        # JSON output panel instead of raising inside the Gradio worker.
        if not video_path:
            return None, {"error": "Please upload a video first."}
        annotated_path, result = process_video(video_path, pose_name, tolerance)
        if annotated_path is None:
            # process_video returned (None, error_message_string).
            return None, {"error": result}
        return annotated_path, result

    run_btn.click(analyze, inputs=[video_in, pose_dropdown, tol_slider], outputs=[output_video, output_json])

if __name__ == "__main__":
    # Bind to all interfaces for containerized deployment; the PORT env
    # variable (if set) overrides the default 7860.
    demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))