vu0018 commited on
Commit
af17b79
·
verified ·
1 Parent(s): e24d81c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -31
app.py CHANGED
@@ -1,39 +1,46 @@
1
  import gradio as gr
2
  import cv2
3
  import mediapipe as mp
 
 
4
  import tempfile
5
- import os
6
 
7
  # Initialize MediaPipe Pose
8
  mp_pose = mp.solutions.pose
9
 
10
- def detect_pose(video_file):
11
- """
12
- This function takes an uploaded video file, limits it to 10 seconds,
13
- applies human pose estimation using MediaPipe, and returns a new video
14
- with the detected poses drawn on the frames.
15
- """
 
 
 
 
 
 
 
 
 
 
 
 
16
  try:
17
  # Save uploaded video to a temporary file
18
  temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
19
  temp_video.write(open(video_file, "rb").read())
20
  temp_video.close()
21
 
22
- # Open video using OpenCV
23
  cap = cv2.VideoCapture(temp_video.name)
24
- if not cap.isOpened():
25
- return "Error: Could not open video file."
26
-
27
  fps = cap.get(cv2.CAP_PROP_FPS)
28
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
29
  duration = total_frames / fps
30
-
31
- # Limit processing to max 10 seconds
32
  max_frames = int(min(duration, 10) * fps)
33
 
34
  output_frames = []
 
35
 
36
- # Initialize MediaPipe Pose for pose detection
37
  with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
38
  frame_count = 0
39
  while frame_count < max_frames:
@@ -41,26 +48,33 @@ def detect_pose(video_file):
41
  if not ret:
42
  break
43
 
44
- # Convert frame to RGB for MediaPipe
45
  image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
46
  results = pose.process(image_rgb)
47
 
48
- # Draw pose landmarks if detected
49
  if results.pose_landmarks:
50
- mp.solutions.drawing_utils.draw_landmarks(
51
- frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
52
- )
 
 
 
 
53
 
54
  output_frames.append(frame)
55
  frame_count += 1
56
 
57
  cap.release()
58
 
59
- # Check if any frames were processed
60
- if len(output_frames) == 0:
61
- return "Error: No frames to process."
 
 
 
 
 
 
62
 
63
- # Save output video
64
  output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
65
  height, width, _ = output_frames[0].shape
66
  out = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
@@ -68,19 +82,17 @@ def detect_pose(video_file):
68
  out.write(f)
69
  out.release()
70
 
71
- return output_file
72
 
73
  except Exception as e:
74
- # Catch any exceptions and return error message
75
- return f"Error during processing: {str(e)}"
76
 
77
- # Gradio interface
78
  iface = gr.Interface(
79
- fn=detect_pose,
80
  inputs=gr.Video(label="Upload a Video (max 10s)"),
81
- outputs=gr.Video(label="Pose Detection Output"),
82
- title="Human Pose Estimation",
83
- description="Upload a short video, and this app will detect human poses (max 10 seconds)."
84
  )
85
 
86
  iface.launch()
 
1
  import gradio as gr
2
  import cv2
3
  import mediapipe as mp
4
+ import torch
5
+ import numpy as np
6
  import tempfile
 
7
 
8
  # Initialize MediaPipe Pose
9
  mp_pose = mp.solutions.pose
10
 
11
+ # Dummy ST-GCN++ model (replace with actual model)
12
+ class SimpleSTGCNPlusPlus(torch.nn.Module):
13
+ def __init__(self, input_size=99, num_classes=5): # 33 keypoints x 3 coords
14
+ super().__init__()
15
+ self.fc = torch.nn.Sequential(
16
+ torch.nn.Linear(input_size, 64),
17
+ torch.nn.ReLU(),
18
+ torch.nn.Linear(64, num_classes)
19
+ )
20
+
21
+ def forward(self, x):
22
+ return self.fc(x)
23
+
24
+ # Instantiate the model
25
+ model = SimpleSTGCNPlusPlus()
26
+ labels = ["Ballet Dancing", "Cycling", "Running", "Jumping", "Walking"]
27
+
28
+ def detect_pose_and_activity(video_file):
29
  try:
30
  # Save uploaded video to a temporary file
31
  temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
32
  temp_video.write(open(video_file, "rb").read())
33
  temp_video.close()
34
 
 
35
  cap = cv2.VideoCapture(temp_video.name)
 
 
 
36
  fps = cap.get(cv2.CAP_PROP_FPS)
37
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
38
  duration = total_frames / fps
 
 
39
  max_frames = int(min(duration, 10) * fps)
40
 
41
  output_frames = []
42
+ keypoints_sequence = []
43
 
 
44
  with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
45
  frame_count = 0
46
  while frame_count < max_frames:
 
48
  if not ret:
49
  break
50
 
 
51
  image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
52
  results = pose.process(image_rgb)
53
 
 
54
  if results.pose_landmarks:
55
+ keypoints = []
56
+ for lm in results.pose_landmarks.landmark:
57
+ keypoints.extend([lm.x, lm.y, lm.z])
58
+ keypoints_sequence.append(keypoints)
59
+ mp.solutions.drawing_utils.draw_landmarks(frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
60
+ else:
61
+ keypoints_sequence.append([0] * 99)
62
 
63
  output_frames.append(frame)
64
  frame_count += 1
65
 
66
  cap.release()
67
 
68
+ if not keypoints_sequence:
69
+ return None, "No pose detected."
70
+
71
+ keypoints_tensor = torch.tensor(keypoints_sequence, dtype=torch.float32).mean(dim=0, keepdim=True)
72
+
73
+ with torch.no_grad():
74
+ preds = model(keypoints_tensor)
75
+ action_idx = torch.argmax(preds, dim=1).item()
76
+ action_label = labels[action_idx]
77
 
 
78
  output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
79
  height, width, _ = output_frames[0].shape
80
  out = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
 
82
  out.write(f)
83
  out.release()
84
 
85
+ return output_file, f"Predicted Action: {action_label}"
86
 
87
  except Exception as e:
88
+ return None, f"Error during processing: {str(e)}"
 
89
 
 
90
  iface = gr.Interface(
91
+ fn=detect_pose_and_activity,
92
  inputs=gr.Video(label="Upload a Video (max 10s)"),
93
+ outputs=[gr.Video(label="Pose Detection Output"), gr.Textbox(label="Detected Action")],
94
+ title="Human Pose & Activity Recognition",
95
+ description="Upload a short video, and this app will detect human poses and predict the activity (e.g., ballet, cycling)."
96
  )
97
 
98
  iface.launch()