vu0018 committed
Commit c5fba2e · verified · 1 Parent(s): 8fcf823

Update app.py

Files changed (1)
  app.py +26 -41
app.py CHANGED
@@ -1,90 +1,75 @@
 import gradio as gr
 import cv2
+import mediapipe as mp
 import torch
 import numpy as np
 import tempfile
 from transformers import pipeline
 from PIL import Image
-import requests
-import mediapipe as mp
 
 # Initialize MediaPipe Pose
 mp_pose = mp.solutions.pose
 
-# Load Hugging Face models
-action_model = pipeline("image-classification", model="rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224")
-pose_model = torch.hub.load("facebookresearch/ViTPose", "vitpose", pretrained=True)
-
-# Define action labels
-action_labels = [
-    "calling", "clapping", "cycling", "dancing", "drinking", "eating", "fighting", "hugging",
-    "laughing", "listening_to_music", "running", "sitting", "sleeping", "texting", "using_laptop"
-]
+# Hugging Face pretrained model for action recognition
+action_model = pipeline(
+    "image-classification",
+    model="rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224"
+)
 
 def detect_pose_and_activity(video_file):
     """
-    Process the uploaded video to detect human poses and classify the activity.
-    Video is trimmed to 10 seconds if longer.
-    Returns the annotated video and predicted activity label.
+    Process the uploaded video to detect human poses and classify activity.
+    Video is limited to 10 seconds. Returns annotated video and predicted action.
     """
     try:
-        # Save uploaded video to a temporary file
+        # Save uploaded video temporarily
         temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
         temp_video.write(open(video_file, "rb").read())
         temp_video.close()
 
         cap = cv2.VideoCapture(temp_video.name)
         if not cap.isOpened():
-            return None, "Error: Could not open video file. Please upload a valid mp4 video."
+            return None, "Error: Could not open video."
 
         fps = cap.get(cv2.CAP_PROP_FPS)
         if fps == 0:
-            fps = 30  # fallback if fps is zero
+            fps = 30  # fallback
 
         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-        max_frames = int(min(total_frames/fps, 10) * fps)  # limit to 10 seconds
+        max_frames = int(min(total_frames/fps, 10) * fps)  # limit 10s
 
         output_frames = []
-        keypoints_sequence = []
+        action_predictions = []
 
+        # Process frames
         with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
             for _ in range(max_frames):
                 ret, frame = cap.read()
                 if not ret:
                     break
 
+                # Pose detection
                 image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                 results = pose.process(image_rgb)
-
-                # Extract keypoints
                 if results.pose_landmarks:
-                    keypoints = []
-                    for lm in results.pose_landmarks.landmark:
-                        keypoints.extend([lm.x, lm.y, lm.z])
-                    if len(keypoints) != 99:
-                        keypoints = [0]*99
-                    keypoints_sequence.append(keypoints)
                     mp.solutions.drawing_utils.draw_landmarks(frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
-                else:
-                    keypoints_sequence.append([0]*99)
 
                 output_frames.append(frame)
 
-        cap.release()
+                # Convert frame to PIL image for Hugging Face model
+                pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+                pred = action_model(pil_image)
+                action_predictions.append(pred[0]['label'])
 
-        if len(keypoints_sequence) == 0 or len(output_frames) == 0:
-            return None, "Error: No frames or poses detected."
+        cap.release()
 
-        # Convert keypoints sequence to tensor
-        keypoints_tensor = torch.tensor(keypoints_sequence, dtype=torch.float32).mean(dim=0, keepdim=True)
+        if len(output_frames) == 0:
+            return None, "Error: No frames to process."
 
-        # Predict activity
-        with torch.no_grad():
-            preds = pose_model(keypoints_tensor)
-            action_idx = torch.argmax(preds, dim=1).item()
-            action_label = action_labels[action_idx]
+        # Take the most frequent predicted action
+        action_label = max(set(action_predictions), key=action_predictions.count)
 
-        # Save output video
+        # Save annotated video
         output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
         height, width, _ = output_frames[0].shape
         out = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
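A note on the aggregation in the hunk above: max(set(action_predictions), key=action_predictions.count) rescans the list once per distinct label and raises ValueError on an empty list (the len(output_frames) == 0 guard happens to prevent that, since both lists grow in lockstep). A minimal equivalent sketch using collections.Counter; majority_label is a hypothetical helper, not code from this commit:

    from collections import Counter

    def majority_label(labels, default="unknown"):
        # most_common(1) returns [(label, count)] for the top label
        if not labels:
            return default  # guard: nothing was classified
        return Counter(labels).most_common(1)[0][0]

    # majority_label(["running", "running", "dancing"]) -> "running"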
@@ -103,7 +88,7 @@ iface = gr.Interface(
     inputs=gr.Video(label="Upload a Video (max 10s)"),
     outputs=[gr.Video(label="Pose Detection Output"), gr.Textbox(label="Detected Action")],
     title="Human Pose & Activity Recognition",
-    description="Upload a short video (max 10s), and the app will detect human poses and predict the activity (e.g., ballet, cycling, running)."
+    description="Upload a short video (max 10s). The app detects human poses and predicts the activity (e.g., dancing, cycling, running)."
 )
 
 iface.launch()
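The rewritten loop runs the ViT image-classification pipeline on every frame, so a 10-second clip at 30 fps means roughly 300 forward passes. A sketch of classifying only a sample of frames instead; classify_sampled_frames and the stride of 5 are illustrative assumptions, not part of this commit:

    import cv2
    from PIL import Image

    def classify_sampled_frames(cap, action_model, max_frames, stride=5):
        # Run the classifier on every stride-th frame only;
        # action_model is the transformers pipeline defined in app.py.
        predictions = []
        for i in range(max_frames):
            ret, frame = cap.read()
            if not ret:
                break
            if i % stride == 0:
                pil_image = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                predictions.append(action_model(pil_image)[0]["label"])
        return predictions

With stride=5 the same clip needs about 60 pipeline calls instead of 300, and the majority vote over the sampled labels is usually unchanged.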
 
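One caveat that survives this commit: the annotated video is still written with the mp4v FourCC, which many browsers refuse to play inline, so the gr.Video output can appear blank even when the file is valid. A sketch of preferring H.264 when the local OpenCV build supports it (codec availability depends on how OpenCV was compiled; open_writer is a hypothetical helper):

    import cv2

    def open_writer(path, fps, size):
        # Try H.264 ('avc1') first for browser-friendly playback
        out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"avc1"), fps, size)
        if not out.isOpened():
            # Fall back to the mp4v codec app.py already uses
            out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), fps, size)
        return out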