vu0018 committed · verified
Commit f5ffff9 · Parent: af17b79

Update app.py

Files changed (1)
  1. app.py +39 -28
app.py CHANGED
@@ -1,31 +1,32 @@
 import gradio as gr
 import cv2
-import mediapipe as mp
 import torch
 import numpy as np
 import tempfile
+from transformers import pipeline
+from PIL import Image
+import requests
+import mediapipe as mp
 
 # Initialize MediaPipe Pose
 mp_pose = mp.solutions.pose
 
-# Dummy ST-GCN++ model (replace with actual model)
-class SimpleSTGCNPlusPlus(torch.nn.Module):
-    def __init__(self, input_size=99, num_classes=5):  # 33 keypoints x 3 coords
-        super().__init__()
-        self.fc = torch.nn.Sequential(
-            torch.nn.Linear(input_size, 64),
-            torch.nn.ReLU(),
-            torch.nn.Linear(64, num_classes)
-        )
-
-    def forward(self, x):
-        return self.fc(x)
+# Load Hugging Face models
+action_model = pipeline("image-classification", model="rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224")
+pose_model = torch.hub.load("facebookresearch/ViTPose", "vitpose", pretrained=True)
 
-# Instantiate the model
-model = SimpleSTGCNPlusPlus()
-labels = ["Ballet Dancing", "Cycling", "Running", "Jumping", "Walking"]
+# Define action labels
+action_labels = [
+    "calling", "clapping", "cycling", "dancing", "drinking", "eating", "fighting", "hugging",
+    "laughing", "listening_to_music", "running", "sitting", "sleeping", "texting", "using_laptop"
+]
 
 def detect_pose_and_activity(video_file):
+    """
+    Process the uploaded video to detect human poses and classify the activity.
+    Video is trimmed to 10 seconds if longer.
+    Returns the annotated video and predicted activity label.
+    """
     try:
         # Save uploaded video to a temporary file
         temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
@@ -33,17 +34,21 @@ def detect_pose_and_activity(video_file):
         temp_video.close()
 
         cap = cv2.VideoCapture(temp_video.name)
+        if not cap.isOpened():
+            return None, "Error: Could not open video file. Please upload a valid mp4 video."
+
         fps = cap.get(cv2.CAP_PROP_FPS)
+        if fps == 0:
+            fps = 30  # fallback if fps is zero
+
         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        duration = total_frames / fps
-        max_frames = int(min(duration, 10) * fps)
+        max_frames = int(min(total_frames/fps, 10) * fps)  # limit to 10 seconds
 
         output_frames = []
         keypoints_sequence = []
 
         with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
-            frame_count = 0
-            while frame_count < max_frames:
+            for _ in range(max_frames):
                 ret, frame = cap.read()
                 if not ret:
                     break
@@ -51,30 +56,35 @@ def detect_pose_and_activity(video_file):
                 image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                 results = pose.process(image_rgb)
 
+                # Extract keypoints
                 if results.pose_landmarks:
                     keypoints = []
                     for lm in results.pose_landmarks.landmark:
                         keypoints.extend([lm.x, lm.y, lm.z])
+                    if len(keypoints) != 99:
+                        keypoints = [0]*99
                     keypoints_sequence.append(keypoints)
                     mp.solutions.drawing_utils.draw_landmarks(frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
                 else:
-                    keypoints_sequence.append([0] * 99)
+                    keypoints_sequence.append([0]*99)
 
                 output_frames.append(frame)
-                frame_count += 1
 
         cap.release()
 
-        if not keypoints_sequence:
-            return None, "No pose detected."
+        if len(keypoints_sequence) == 0 or len(output_frames) == 0:
+            return None, "Error: No frames or poses detected."
 
+        # Convert keypoints sequence to tensor
         keypoints_tensor = torch.tensor(keypoints_sequence, dtype=torch.float32).mean(dim=0, keepdim=True)
 
+        # Predict activity
         with torch.no_grad():
-            preds = model(keypoints_tensor)
+            preds = pose_model(keypoints_tensor)
             action_idx = torch.argmax(preds, dim=1).item()
-        action_label = labels[action_idx]
+        action_label = action_labels[action_idx]
 
+        # Save output video
         output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
         height, width, _ = output_frames[0].shape
         out = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
@@ -85,14 +95,15 @@ def detect_pose_and_activity(video_file):
         return output_file, f"Predicted Action: {action_label}"
 
     except Exception as e:
-        return None, f"Error during processing: {str(e)}"
+        return None, f"Runtime Error: {str(e)}"
 
+# Gradio Interface
 iface = gr.Interface(
     fn=detect_pose_and_activity,
     inputs=gr.Video(label="Upload a Video (max 10s)"),
     outputs=[gr.Video(label="Pose Detection Output"), gr.Textbox(label="Detected Action")],
     title="Human Pose & Activity Recognition",
-    description="Upload a short video, and this app will detect human poses and predict the activity (e.g., ballet, cycling)."
+    description="Upload a short video (max 10s), and the app will detect human poses and predict the activity (e.g., ballet, cycling, running)."
 )
 
 iface.launch()
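Note on the new prediction path: the commit loads `action_model` (an image-classification pipeline whose 15 labels match `action_labels`) but never calls it, while `preds = pose_model(keypoints_tensor)` feeds the averaged 99-value keypoint vector to what is loaded as a pose-estimation model; a pose estimator takes images, not keypoint vectors, so this call should fail at runtime and surface as the `except` message. The `torch.hub.load("facebookresearch/ViTPose", ...)` entrypoint also does not appear to exist upstream, so loading may fail before any video is processed. Below is a minimal sketch of how the already-loaded pipeline could produce the label instead, assuming BGR frames collected as in `output_frames`; `classify_middle_frame` is a hypothetical helper, not part of this commit:

    import cv2
    from PIL import Image
    from transformers import pipeline

    # Same checkpoint the commit loads; its labels match action_labels.
    action_model = pipeline(
        "image-classification",
        model="rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224",
    )

    def classify_middle_frame(output_frames):
        """Classify the clip from its middle frame (hypothetical helper)."""
        # Use the middle frame as a cheap single-image summary of the clip.
        frame = output_frames[len(output_frames) // 2]
        # The pipeline expects an RGB PIL image; OpenCV frames are BGR arrays.
        pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        # Returns a list of {"label": ..., "score": ...} dicts, best score first.
        preds = action_model(pil_image)
        return preds[0]["label"]

Sampling a few evenly spaced frames and majority-voting over their top labels would be more robust than a single frame, at the cost of a few extra forward passes.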