Simrandhiman commited on
Commit
2c54b9f
·
verified ·
1 Parent(s): 41a4882

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +317 -0
app.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import os
3
+ import json
4
+ import math
5
+ import tempfile
6
+ from pathlib import Path
7
+ from typing import Dict, List, Tuple
8
+
9
+ import cv2
10
+ import numpy as np
11
+ import mediapipe as mp
12
+ import gradio as gr
13
+
14
+ # --- Config / reference poses (angles in degrees) ---
15
+ REFERENCE_POSES_FILE = "reference_poses.json"
16
+
17
+ # Mediapipe utils
18
+ mp_pose = mp.solutions.pose
19
+ mp_drawing = mp.solutions.drawing_utils
20
+
21
+ # Useful landmark indices from MediaPipe Pose
22
+ LANDMARK = mp_pose.PoseLandmark
23
+ # Example joints we will compute angles for (triplet: parent, joint, child)
24
+ JOINT_TRIPLETS = {
25
+ "left_elbow": (LANDMARK.LEFT_SHOULDER, LANDMARK.LEFT_ELBOW, LANDMARK.LEFT_WRIST),
26
+ "right_elbow": (LANDMARK.RIGHT_SHOULDER, LANDMARK.RIGHT_ELBOW, LANDMARK.RIGHT_WRIST),
27
+ "left_shoulder": (LANDMARK.LEFT_HIP, LANDMARK.LEFT_SHOULDER, LANDMARK.LEFT_ELBOW),
28
+ "right_shoulder": (LANDMARK.RIGHT_HIP, LANDMARK.RIGHT_SHOULDER, LANDMARK.RIGHT_ELBOW),
29
+ "left_knee": (LANDMARK.LEFT_HIP, LANDMARK.LEFT_KNEE, LANDMARK.LEFT_ANKLE),
30
+ "right_knee": (LANDMARK.RIGHT_HIP, LANDMARK.RIGHT_KNEE, LANDMARK.RIGHT_ANKLE),
31
+ "left_hip": (LANDMARK.LEFT_SHOULDER, LANDMARK.LEFT_HIP, LANDMARK.LEFT_KNEE),
32
+ "right_hip": (LANDMARK.RIGHT_SHOULDER, LANDMARK.RIGHT_HIP, LANDMARK.RIGHT_KNEE),
33
+ }
34
+
35
+ # thresholds (degrees) for "correct" per joint
36
+ DEFAULT_TOLERANCE = 15.0
37
+
38
+ # --- Helper functions ---
39
+ def load_reference_poses(path: str = REFERENCE_POSES_FILE) -> Dict:
40
+ if not os.path.exists(path):
41
+ # create a default one if missing
42
+ default = {
43
+ "Warrior II": {
44
+ "left_elbow": 170,
45
+ "right_elbow": 170,
46
+ "left_shoulder": 90,
47
+ "right_shoulder": 90,
48
+ "left_knee": 90,
49
+ "right_knee": 170,
50
+ "left_hip": 170,
51
+ "right_hip": 170
52
+ },
53
+ "Tree": {
54
+ "left_elbow": 170,
55
+ "right_elbow": 170,
56
+ "left_shoulder": 120,
57
+ "right_shoulder": 120,
58
+ "left_knee": 170,
59
+ "right_knee": 40,
60
+ "left_hip": 170,
61
+ "right_hip": 40
62
+ },
63
+ "Downward Dog": {
64
+ "left_elbow": 170,
65
+ "right_elbow": 170,
66
+ "left_shoulder": 70,
67
+ "right_shoulder": 70,
68
+ "left_knee": 170,
69
+ "right_knee": 170,
70
+ "left_hip": 160,
71
+ "right_hip": 160
72
+ }
73
+ }
74
+ with open(path, "w") as f:
75
+ json.dump(default, f, indent=2)
76
+ return default
77
+ with open(path, "r") as f:
78
+ return json.load(f)
79
+
80
+ def vector(a: Tuple[float, float], b: Tuple[float, float]) -> np.ndarray:
81
+ return np.array([b[0]-a[0], b[1]-a[1]])
82
+
83
+ def angle_between_points(a, b, c) -> float:
84
+ """
85
+ Returns the angle ABC (in degrees) formed at point b by points a-b-c.
86
+ Points are (x, y).
87
+ """
88
+ v1 = vector(b, a)
89
+ v2 = vector(b, c)
90
+ dot = v1.dot(v2)
91
+ norm = (np.linalg.norm(v1) * np.linalg.norm(v2)) + 1e-8
92
+ cosang = np.clip(dot / norm, -1.0, 1.0)
93
+ ang = math.degrees(math.acos(cosang))
94
+ return ang
95
+
96
+ def landmarks_to_xy(landmark_list, image_width, image_height):
97
+ coords = {}
98
+ for idx, lm in enumerate(landmark_list.landmark):
99
+ coords[idx] = (lm.x * image_width, lm.y * image_height, lm.visibility if hasattr(lm, "visibility") else 1.0)
100
+ return coords
101
+
102
+ def compute_joint_angles(landmarks_xy: Dict[int, Tuple[float, float, float]]) -> Dict[str, float]:
103
+ angles = {}
104
+ for name, (p_idx, j_idx, c_idx) in JOINT_TRIPLETS.items():
105
+ try:
106
+ pa = landmarks_xy[p_idx]
107
+ jb = landmarks_xy[j_idx]
108
+ ca = landmarks_xy[c_idx]
109
+ # ignore if visibility low (z could be used too)
110
+ if pa[2] < 0.3 or jb[2] < 0.3 or ca[2] < 0.3:
111
+ angles[name] = None
112
+ else:
113
+ ang = angle_between_points((pa[0], pa[1]), (jb[0], jb[1]), (ca[0], ca[1]))
114
+ angles[name] = ang
115
+ except KeyError:
116
+ angles[name] = None
117
+ return angles
118
+
119
+ def compare_angles(detected: Dict[str, float], reference: Dict[str, float], tolerance=DEFAULT_TOLERANCE):
120
+ per_joint_score = {}
121
+ per_joint_diff = {}
122
+ for joint, ref_ang in reference.items():
123
+ det_ang = detected.get(joint)
124
+ if det_ang is None:
125
+ per_joint_score[joint] = None
126
+ per_joint_diff[joint] = None
127
+ else:
128
+ diff = abs(det_ang - ref_ang)
129
+ per_joint_diff[joint] = det_ang - ref_ang
130
+ # score: linear falloff: diff 0 -> 100, diff >= 2*tolerance -> 0
131
+ score = max(0.0, 100.0 * (1 - (diff / (2 * tolerance))))
132
+ per_joint_score[joint] = float(np.clip(score, 0.0, 100.0))
133
+ # final percent: average of available joint scores
134
+ valid_scores = [v for v in per_joint_score.values() if v is not None]
135
+ final_percent = float(np.mean(valid_scores)) if valid_scores else 0.0
136
+ return final_percent, per_joint_score, per_joint_diff
137
+
138
+ def suggest_corrections(per_joint_diff: Dict[str, float], tol=DEFAULT_TOLERANCE) -> List[str]:
139
+ suggestions = []
140
+ for joint, diff in per_joint_diff.items():
141
+ if diff is None:
142
+ suggestions.append(f"{joint}: can't detect reliably.")
143
+ continue
144
+ if abs(diff) <= tol:
145
+ suggestions.append(f"{joint}: good (within ±{tol}°).")
146
+ else:
147
+ if diff > 0:
148
+ # detected angle larger than reference -> joint more open than desired
149
+ suggestions.append(f"{joint}: decrease angle by {abs(diff):.0f}° (e.g. bend more).")
150
+ else:
151
+ suggestions.append(f"{joint}: increase angle by {abs(diff):.0f}° (e.g. straighten more).")
152
+ return suggestions
153
+
154
+ # --- Video processing ---
155
+ def process_video(input_path: str, pose_name: str, tolerance: float = DEFAULT_TOLERANCE):
156
+ # load reference poses
157
+ ref_poses = load_reference_poses()
158
+ if pose_name not in ref_poses:
159
+ return None, f"Pose '{pose_name}' not found in reference poses."
160
+
161
+ reference = ref_poses[pose_name]
162
+
163
+ cap = cv2.VideoCapture(input_path)
164
+ if not cap.isOpened():
165
+ return None, "Failed to open uploaded video."
166
+
167
+ fps = cap.get(cv2.CAP_PROP_FPS) or 20.0
168
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 640)
169
+ height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 480)
170
+ fourcc = cv2.VideoWriter_fourcc(*"mp4v")
171
+ tmp_out = os.path.join(tempfile.gettempdir(), f"annotated_{Path(input_path).stem}.mp4")
172
+ out = cv2.VideoWriter(tmp_out, fourcc, fps, (width, height))
173
+
174
+ pose = mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5, min_tracking_confidence=0.5)
175
+ frame_idx = 0
176
+ aggregate_scores = []
177
+ joint_scores_over_time = []
178
+
179
+ while True:
180
+ ret, frame = cap.read()
181
+ if not ret:
182
+ break
183
+ frame_idx += 1
184
+ image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
185
+ results = pose.process(image_rgb)
186
+ annotated = frame.copy()
187
+
188
+ if results.pose_landmarks:
189
+ landmark_xy = landmarks_to_xy(results.pose_landmarks, width, height)
190
+ detected_angles = compute_joint_angles(landmark_xy)
191
+ final_percent, per_joint_score, per_joint_diff = compare_angles(detected_angles, reference, tolerance)
192
+ aggregate_scores.append(final_percent)
193
+ joint_scores_over_time.append(per_joint_score)
194
+
195
+ # draw skeleton - color joints green if within tolerance else red
196
+ for joint, (p_idx, j_idx, c_idx) in JOINT_TRIPLETS.items():
197
+ # draw lines parent->joint and joint->child
198
+ if j_idx in landmark_xy and p_idx in landmark_xy:
199
+ x1, y1, v1 = landmark_xy[p_idx]
200
+ x2, y2, v2 = landmark_xy[j_idx]
201
+ score = per_joint_score.get(joint)
202
+ if score is None:
203
+ color = (0, 255, 255) # yellow for unknown
204
+ else:
205
+ color = (0, 255, 0) if score >= 66 else (0, 165, 255) if score >= 33 else (0, 0, 255)
206
+ cv2.line(annotated, (int(x1), int(y1)), (int(x2), int(y2)), color, 3)
207
+
208
+ if j_idx in landmark_xy and c_idx in landmark_xy:
209
+ x2, y2, v2 = landmark_xy[j_idx]
210
+ x3, y3, v3 = landmark_xy[c_idx]
211
+ score = per_joint_score.get(joint)
212
+ if score is None:
213
+ color = (0, 255, 255)
214
+ else:
215
+ color = (0, 255, 0) if score >= 66 else (0, 165, 255) if score >= 33 else (0, 0, 255)
216
+ cv2.line(annotated, (int(x2), int(y2)), (int(x3), int(y3)), color, 3)
217
+
218
+ # draw circles at joints with ang value and highlight bad ones
219
+ for joint, (_, j_idx, _) in JOINT_TRIPLETS.items():
220
+ if j_idx in landmark_xy:
221
+ x, y, v = landmark_xy[j_idx]
222
+ score = per_joint_score.get(joint)
223
+ if score is None:
224
+ cv2.circle(annotated, (int(x), int(y)), 6, (0, 255, 255), -1)
225
+ else:
226
+ color = (0, 255, 0) if score >= 66 else (0, 165, 255) if score >= 33 else (0, 0, 255)
227
+ cv2.circle(annotated, (int(x), int(y)), 8, color, -1)
228
+ # put text of angle difference small
229
+ diff = per_joint_diff.get(joint)
230
+ if diff is not None:
231
+ txt = f"{diff:+.0f}°"
232
+ cv2.putText(annotated, txt, (int(x)+6, int(y)-6), cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
233
+
234
+ # frame-level overlay of percent and pose name
235
+ cv2.putText(annotated, f"{pose_name} - {final_percent:.0f}% correct", (10, 30),
236
+ cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2, cv2.LINE_AA)
237
+ else:
238
+ # no landmarks, show message
239
+ cv2.putText(annotated, "No person detected", (10,30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0,0,255), 2)
240
+
241
+ out.write(annotated)
242
+
243
+ cap.release()
244
+ out.release()
245
+ pose.close()
246
+
247
+ # aggregate results
248
+ overall_percent = float(np.mean(aggregate_scores)) if aggregate_scores else 0.0
249
+ # use last frame joint scores to produce suggestions (or averaged)
250
+ last_joint_scores = joint_scores_over_time[-1] if joint_scores_over_time else {}
251
+ # compute last diffs using detected angles from last frame - but we saved diffs only inside loop
252
+ # For simplicity, recompute suggestions by re-reading last frame's per_joint_diff from process? We'll use the last computed per_joint_diff stored implicitly above:
253
+ # To keep consistent, re-open video and compute final detected angles on last frame:
254
+ cap2 = cv2.VideoCapture(input_path)
255
+ last_frame = None
256
+ while True:
257
+ ret, f = cap2.read()
258
+ if not ret:
259
+ break
260
+ last_frame = f
261
+ cap2.release()
262
+
263
+ suggestions = ["(no reliable pose detected)"]
264
+ if last_frame is not None:
265
+ h, w = last_frame.shape[:2]
266
+ with mp_pose.Pose(static_image_mode=True, min_detection_confidence=0.5) as pose2:
267
+ res = pose2.process(cv2.cvtColor(last_frame, cv2.COLOR_BGR2RGB))
268
+ if res.pose_landmarks:
269
+ landmark_xy = landmarks_to_xy(res.pose_landmarks, w, h)
270
+ detected_angles = compute_joint_angles(landmark_xy)
271
+ _, _, per_joint_diff = compare_angles(detected_angles, reference, tolerance)
272
+ suggestions = suggest_corrections(per_joint_diff, tol=tolerance)
273
+ else:
274
+ suggestions = ["No person detected in final frame to produce suggestions."]
275
+
276
+ # return annotated video path and a JSON-like result
277
+ result = {
278
+ "pose": pose_name,
279
+ "score_percent": overall_percent,
280
+ "suggestions": suggestions
281
+ }
282
+ return tmp_out, result
283
+
284
+ # --- Gradio UI ---
285
+ ref_poses = load_reference_poses()
286
+
287
+ pose_list = list(ref_poses.keys())
288
+
289
+ with gr.Blocks(title="Yoga Pose Correctness Checker") as demo:
290
+ gr.Markdown(
291
+ """
292
+ # Yoga Pose Correctness Checker
293
+ Upload a short video or use your webcam. The app will analyze each frame, compute joint angles via MediaPipe,
294
+ compare them to a reference pose, and return a percentage correctness plus per-joint corrections.
295
+ """
296
+ )
297
+ with gr.Row():
298
+ video_in = gr.Video(source="webcam", label="Webcam (or upload a video file)", type="filepath")
299
+ with gr.Column():
300
+ pose_dropdown = gr.Dropdown(choices=pose_list, value=pose_list[0], label="Reference Pose")
301
+ tol_slider = gr.Slider(5, 40, value=DEFAULT_TOLERANCE, step=1, label="Tolerance (degrees)")
302
+ run_btn = gr.Button("Analyze")
303
+ output_video = gr.Video(label="Annotated video (downloadable)")
304
+ output_json = gr.JSON(label="Results and suggestions")
305
+
306
+ def analyze(video_path, pose_name, tolerance):
307
+ if not video_path:
308
+ return None, {"error": "No input video provided"}
309
+ annotated_path, result = process_video(video_path, pose_name, tolerance)
310
+ if annotated_path is None:
311
+ return None, {"error": result}
312
+ return annotated_path, result
313
+
314
+ run_btn.click(analyze, inputs=[video_in, pose_dropdown, tol_slider], outputs=[output_video, output_json])
315
+
316
+ if __name__ == "__main__":
317
+ demo.launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))