gibil committed on
Commit
3259094
·
verified ·
1 Parent(s): aa938e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -46
app.py CHANGED
@@ -1,6 +1,9 @@
1
  import os
2
- import tempfile
3
  import shutil
 
 
 
4
  import cv2
5
  import numpy as np
6
  import mediapipe as mp
@@ -10,9 +13,13 @@ import gradio as gr
10
  # -----------------------
11
  # Core pipeline function
12
  # -----------------------
13
- def analyze_pushup_video(video_path: str, save_annotated: bool = True, annotated_out_path: str | None = None):
 
 
 
 
14
  """
15
- Returns:
16
  {
17
  "ok": bool,
18
  "error": str | None,
@@ -22,6 +29,7 @@ def analyze_pushup_video(video_path: str, save_annotated: bool = True, annotated
22
  "annotated_video_path": str | None
23
  }
24
  """
 
25
  if not os.path.exists(video_path):
26
  return {
27
  "ok": False,
@@ -32,6 +40,7 @@ def analyze_pushup_video(video_path: str, save_annotated: bool = True, annotated
32
  "annotated_video_path": None,
33
  }
34
 
 
35
  def clamp(x, lo=0.0, hi=1.0):
36
  return max(lo, min(hi, x))
37
 
@@ -51,12 +60,12 @@ def analyze_pushup_video(video_path: str, save_annotated: bool = True, annotated
51
  return 1.0
52
  if val < good_lo:
53
  return clamp((val - ok_lo) / (good_lo - ok_lo))
54
- else:
55
- return clamp((ok_hi - val) / (ok_hi - good_hi))
56
 
57
  def ema(prev, x, a=0.25):
58
  return x if prev is None else (a * x + (1 - a) * prev)
59
 
 
60
  mp_pose = mp.solutions.pose
61
  pose = mp_pose.Pose(
62
  static_image_mode=False,
@@ -67,12 +76,13 @@ def analyze_pushup_video(video_path: str, save_annotated: bool = True, annotated
67
  min_tracking_confidence=0.5,
68
  )
69
 
 
70
  cap = cv2.VideoCapture(video_path)
71
  if not cap.isOpened():
72
  pose.close()
73
  return {
74
  "ok": False,
75
- "error": "OpenCV could not open the video. Try a different MP4 encoding.",
76
  "rep_count": 0,
77
  "avg_rep_prob": None,
78
  "rep_events": [],
@@ -83,8 +93,8 @@ def analyze_pushup_video(video_path: str, save_annotated: bool = True, annotated
83
  W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 0
84
  H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 0
85
 
86
- annotated_path = None
87
  writer = None
 
88
  if save_annotated:
89
  if annotated_out_path is None:
90
  annotated_out_path = os.path.join(tempfile.mkdtemp(), "annotated.mp4")
@@ -92,9 +102,10 @@ def analyze_pushup_video(video_path: str, save_annotated: bool = True, annotated
92
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
93
  writer = cv2.VideoWriter(annotated_path, fourcc, fps, (W, H))
94
 
95
- state = "UNKNOWN"
96
- rep_events = []
97
- current_rep = None
 
98
  rep_count = 0
99
 
100
  ema_elbow = None
@@ -188,7 +199,7 @@ def analyze_pushup_video(video_path: str, save_annotated: bool = True, annotated
188
  rep_prob = float(np.mean(probs))
189
 
190
  rep_events.append({
191
- "rep": int(rep_count),
192
  "start_t": float(current_rep["start_f"] / fps),
193
  "end_t": float(end_f / fps),
194
  "prob": float(rep_prob),
@@ -205,6 +216,7 @@ def analyze_pushup_video(video_path: str, save_annotated: bool = True, annotated
205
 
206
  debug_txt = f"{'L' if left_side else 'R'} vis={ema_vis:.2f} elbow={ema_elbow:.0f} straight={ema_straight:.0f} p={frame_prob:.2f} state={state}"
207
 
 
208
  cv2.putText(frame, f"Reps: {rep_count}", (20, 40),
209
  cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2, cv2.LINE_AA)
210
  cv2.putText(frame, debug_txt[:90], (20, 75),
@@ -247,55 +259,57 @@ def analyze_pushup_video(video_path: str, save_annotated: bool = True, annotated
247
 
248
 
249
  # -----------------------
250
- # Gradio API wrapper
251
  # -----------------------
252
def api_run(uploaded_file):
    """
    Normalize a Gradio upload to a real filesystem path, run the pushup
    pipeline on a private copy, and return (result JSON, annotated path).

    The upload may arrive as an UploadedFile-like object with ``.name``,
    a dict carrying ``'path'``, or a plain string path.
    """
    if uploaded_file is None:
        return {"ok": False, "error": "No file uploaded."}, None

    # Resolve the incoming object to a filesystem path.
    if isinstance(uploaded_file, str):
        path = uploaded_file
    elif isinstance(uploaded_file, dict):
        path = uploaded_file.get("path")
    else:
        path = getattr(uploaded_file, "name", None)

    if not path or not os.path.exists(path):
        return {"ok": False, "error": f"Upload path missing or not found: {path}"}, None

    # Work on a private copy so the pipeline always sees a stable .mp4 path.
    workdir = tempfile.mkdtemp()
    in_path = os.path.join(workdir, "input.mp4")
    shutil.copy(path, in_path)

    out_path = os.path.join(workdir, "annotated.mp4")
    result = analyze_pushup_video(in_path, save_annotated=True, annotated_out_path=out_path)

    # Gradio serves the annotated file directly from this path.
    return result, result["annotated_video_path"]
 
 
 
 
 
 
 
 
284
 
 
 
 
 
 
285
 
286
# Gradio UI wiring: upload -> analyze -> (JSON, annotated video).
with gr.Blocks(title="Pushup Prototype API") as demo:
    gr.Markdown("# Pushup Prototype (API)\nUpload a video -> get JSON result + annotated output.")
    video_file = gr.File(label="Upload MP4", file_types=["video"])
    run_btn = gr.Button("Analyze")
    out_json = gr.JSON(label="Result JSON")
    out_vid = gr.Video(label="Annotated Output")

    run_btn.click(
        fn=api_run,
        inputs=[video_file],
        outputs=[out_json, out_vid],
        api_name="analyze"  # IMPORTANT for GitHub JS calls: names the API endpoint
    )

# Queue requests so long-running analyses don't block the server.
demo.queue()
demo.launch()
 
1
  import os
2
+ import math
3
  import shutil
4
+ import tempfile
5
+ from typing import Any, Dict, List, Optional, Union
6
+
7
  import cv2
8
  import numpy as np
9
  import mediapipe as mp
 
13
  # -----------------------
14
  # Core pipeline function
15
  # -----------------------
16
+ def analyze_pushup_video(
17
+ video_path: str,
18
+ save_annotated: bool = True,
19
+ annotated_out_path: Optional[str] = None,
20
+ ) -> Dict[str, Any]:
21
  """
22
+ Runs MediaPipe Pose on a video, counts pushup reps, and returns:
23
  {
24
  "ok": bool,
25
  "error": str | None,
 
29
  "annotated_video_path": str | None
30
  }
31
  """
32
+
33
  if not os.path.exists(video_path):
34
  return {
35
  "ok": False,
 
40
  "annotated_video_path": None,
41
  }
42
 
43
+ # ---------- Helpers ----------
44
  def clamp(x, lo=0.0, hi=1.0):
45
  return max(lo, min(hi, x))
46
 
 
60
  return 1.0
61
  if val < good_lo:
62
  return clamp((val - ok_lo) / (good_lo - ok_lo))
63
+ return clamp((ok_hi - val) / (ok_hi - good_hi))
 
64
 
65
  def ema(prev, x, a=0.25):
66
  return x if prev is None else (a * x + (1 - a) * prev)
67
 
68
+ # ---------- Pose ----------
69
  mp_pose = mp.solutions.pose
70
  pose = mp_pose.Pose(
71
  static_image_mode=False,
 
76
  min_tracking_confidence=0.5,
77
  )
78
 
79
+ # ---------- Video ----------
80
  cap = cv2.VideoCapture(video_path)
81
  if not cap.isOpened():
82
  pose.close()
83
  return {
84
  "ok": False,
85
+ "error": "OpenCV could not open the video. Try re-exporting to a standard H.264 mp4.",
86
  "rep_count": 0,
87
  "avg_rep_prob": None,
88
  "rep_events": [],
 
93
  W = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) or 0
94
  H = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) or 0
95
 
 
96
  writer = None
97
+ annotated_path = None
98
  if save_annotated:
99
  if annotated_out_path is None:
100
  annotated_out_path = os.path.join(tempfile.mkdtemp(), "annotated.mp4")
 
102
  fourcc = cv2.VideoWriter_fourcc(*"mp4v")
103
  writer = cv2.VideoWriter(annotated_path, fourcc, fps, (W, H))
104
 
105
+ # ---------- Detection state ----------
106
+ state = "UNKNOWN" # UP/DOWN
107
+ rep_events: List[Dict[str, Any]] = []
108
+ current_rep: Optional[Dict[str, Any]] = None
109
  rep_count = 0
110
 
111
  ema_elbow = None
 
199
  rep_prob = float(np.mean(probs))
200
 
201
  rep_events.append({
202
+ "rep": rep_count,
203
  "start_t": float(current_rep["start_f"] / fps),
204
  "end_t": float(end_f / fps),
205
  "prob": float(rep_prob),
 
216
 
217
  debug_txt = f"{'L' if left_side else 'R'} vis={ema_vis:.2f} elbow={ema_elbow:.0f} straight={ema_straight:.0f} p={frame_prob:.2f} state={state}"
218
 
219
+ # annotate frame
220
  cv2.putText(frame, f"Reps: {rep_count}", (20, 40),
221
  cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2, cv2.LINE_AA)
222
  cv2.putText(frame, debug_txt[:90], (20, 75),
 
259
 
260
 
261
  # -----------------------
262
+ # Gradio API function
263
  # -----------------------
264
def api_analyze(file_obj: Union[str, Dict[str, Any], None]):
    """
    Gradio API endpoint: normalize the uploaded file to a path and analyze it.

    Depending on the Gradio version, the File component may deliver:
      - a plain string path
      - a dict with a "name" or "path" key
      - an UploadedFile-like object exposing a ``.name`` attribute

    Returns:
        A JSON-serializable dict with the analysis result, or an
        ``{"ok": False, "error": ...}`` dict when the upload cannot
        be resolved to an existing file.
    """
    if file_obj is None:
        return {"ok": False, "error": "No file received."}

    # Extract path robustly across Gradio versions. The "name" check comes
    # first to preserve the original behavior for dicts carrying both keys.
    if isinstance(file_obj, str):
        src_path = file_obj
    elif isinstance(file_obj, dict) and "name" in file_obj:
        src_path = file_obj["name"]
    elif isinstance(file_obj, dict) and "path" in file_obj:
        src_path = file_obj["path"]
    elif hasattr(file_obj, "name"):
        src_path = file_obj.name
    else:
        return {"ok": False, "error": f"Unsupported file object: {type(file_obj)}"}

    if not os.path.exists(src_path):
        return {"ok": False, "error": f"Upload path not found on server: {src_path}"}

    # Copy into an isolated work dir so the pipeline sees a stable .mp4 path.
    workdir = tempfile.mkdtemp()
    in_path = os.path.join(workdir, "input.mp4")
    shutil.copy(src_path, in_path)

    out_path = os.path.join(workdir, "annotated.mp4")
    result = analyze_pushup_video(in_path, save_annotated=True, annotated_out_path=out_path)

    # Return only what the frontend needs + a file path for the annotated video.
    return {
        "ok": result["ok"],
        "error": result["error"],
        "rep_count": result["rep_count"],
        "avg_rep_prob": result["avg_rep_prob"],
        "rep_events": result["rep_events"],
        "annotated_video_path": result["annotated_video_path"],
    }
300
+
301
 
302
# -----------------------
# Gradio app
# -----------------------
# Minimal UI; the primary consumer is the named API endpoint wired below.
with gr.Blocks(title="Pushup Analyzer API") as demo:
    gr.Markdown("# Pushup Analyzer (Backend)\nThis Space provides an API + optional Gradio UI.")

    file_in = gr.File(label="Upload video", file_types=["video"])
    btn = gr.Button("Analyze")
    json_out = gr.JSON(label="Result JSON")

    # api_name is IMPORTANT: frontend will call this endpoint by name
    btn.click(fn=api_analyze, inputs=file_in, outputs=json_out, api_name="analyze")

demo.launch()