vu0018 committed on
Commit
8cac824
·
verified ·
1 Parent(s): 21fd24d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -43
app.py CHANGED
@@ -2,28 +2,19 @@ import gradio as gr
2
  import cv2
3
  import mediapipe as mp
4
  import torch
5
- import numpy as np
6
  import tempfile
7
- from transformers import pipeline
8
- from PIL import Image
9
 
10
  # Load YOLOv5 model from torch hub
11
  yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, trust_repo=True)
12
  yolo_model.conf = 0.4 # confidence threshold
13
- yolo_model.classes = [0] # only detect persons (class 0)
14
 
15
  # Initialize MediaPipe Pose
16
  mp_pose = mp.solutions.pose
 
17
 
18
- # Hugging Face pretrained model for action recognition
19
- action_model = pipeline(
20
- "image-classification",
21
- model="rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224"
22
- )
23
-
24
- def detect_pose_and_activity(video_file):
25
  try:
26
- # Save uploaded video temporarily
27
  temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
28
  temp_video.write(open(video_file, "rb").read())
29
  temp_video.close()
@@ -32,15 +23,11 @@ def detect_pose_and_activity(video_file):
32
  if not cap.isOpened():
33
  return None, "Error: Could not open video."
34
 
35
- fps = cap.get(cv2.CAP_PROP_FPS)
36
- if fps == 0:
37
- fps = 30 # fallback
38
-
39
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
40
- max_frames = int(min(total_frames / fps, 10) * fps) # limit 10s
41
 
42
  output_frames = []
43
- action_predictions = []
44
 
45
  with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
46
  for _ in range(max_frames):
@@ -48,47 +35,31 @@ def detect_pose_and_activity(video_file):
48
  if not ret:
49
  break
50
 
51
- # Detect people using YOLOv5
52
  results = yolo_model(frame)
53
  detections = results.xyxy[0].cpu().numpy()
54
 
55
- frame_actions = []
56
-
57
  for det in detections:
58
- x1, y1, x2, y2, conf, cls = map(int, det[:6])
59
  person_crop = frame[y1:y2, x1:x2]
60
 
61
- # Pose estimation on cropped person
62
  person_rgb = cv2.cvtColor(person_crop, cv2.COLOR_BGR2RGB)
63
  pose_result = pose.process(person_rgb)
64
 
65
  if pose_result.pose_landmarks:
66
- mp.solutions.drawing_utils.draw_landmarks(
67
  person_crop, pose_result.pose_landmarks, mp_pose.POSE_CONNECTIONS
68
  )
69
 
70
- # Action recognition
71
- pil_image = Image.fromarray(cv2.cvtColor(person_crop, cv2.COLOR_BGR2RGB))
72
- pred = action_model(pil_image)
73
- frame_actions.append(pred[0]['label'])
74
-
75
  # Draw bounding box
76
  cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
77
 
78
  output_frames.append(frame)
79
 
80
- if frame_actions:
81
- action_predictions.append(max(set(frame_actions), key=frame_actions.count))
82
-
83
  cap.release()
84
 
85
- if len(output_frames) == 0:
86
- return None, "Error: No frames to process."
87
-
88
- # Take the most frequent predicted action
89
- action_label = max(set(action_predictions), key=action_predictions.count) if action_predictions else "Unknown"
90
 
91
- # Save annotated video
92
  output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
93
  height, width, _ = output_frames[0].shape
94
  out = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
@@ -96,18 +67,18 @@ def detect_pose_and_activity(video_file):
96
  out.write(f)
97
  out.release()
98
 
99
- return output_file, f"Predicted Action: {action_label}"
100
 
101
  except Exception as e:
102
  return None, f"Runtime Error: {str(e)}"
103
 
104
  # Gradio Interface
105
  iface = gr.Interface(
106
- fn=detect_pose_and_activity,
107
  inputs=gr.Video(label="Upload a Video (max 10s)"),
108
- outputs=[gr.Video(label="Pose Multiple Detection Output"), gr.Textbox(label="Detected Pose")],
109
- title="Multi-Person Pose & Activity Recognition",
110
- description="Upload a short video (max 10s). The app detects multiple people, estimates their poses, and predicts their actions."
111
  )
112
 
113
  iface.launch()
 
2
  import cv2
3
  import mediapipe as mp
4
  import torch
 
5
  import tempfile
 
 
6
 
7
# Load YOLOv5 model from torch hub.
# 'yolov5s' is the small pretrained variant; trust_repo=True suppresses the
# interactive trust prompt on first download of the Ultralytics hub repo.
yolo_model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True, trust_repo=True)
yolo_model.conf = 0.4  # confidence threshold for detections
yolo_model.classes = [0]  # only detect persons (COCO class 0)
11
 
12
# Initialize MediaPipe Pose.
# mp_pose provides the Pose estimator and POSE_CONNECTIONS topology;
# mp_drawing renders landmarks onto image arrays in place.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils
15
 
16
def detect_pose(video_file):
    """Run multi-person pose estimation on an uploaded video.

    Detects people with YOLOv5 (restricted to the person class), runs
    MediaPipe pose estimation on each detected person crop, draws the
    landmarks and bounding boxes onto the frames, and writes an annotated
    MP4 covering at most the first 10 seconds of input.

    Parameters
    ----------
    video_file : str
        Path to the uploaded video file (as supplied by gr.Video).

    Returns
    -------
    tuple[str | None, str]
        (path to the annotated video, status message) on success, or
        (None, error message) on failure.
    """
    try:
        # Copy the upload to a stable temp file. The source handle is
        # closed via `with` instead of leaked (original used a bare open()).
        temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        with open(video_file, "rb") as src:
            temp_video.write(src.read())
        temp_video.close()

        cap = cv2.VideoCapture(temp_video.name)
        if not cap.isOpened():
            return None, "Error: Could not open video."

        # cap.get can report 0 fps for some containers; fall back to 30.
        fps = cap.get(cv2.CAP_PROP_FPS) or 30

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        max_frames = int(min(total_frames / fps, 10) * fps)  # limit to 10s

        output_frames = []

        with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
            for _ in range(max_frames):
                ret, frame = cap.read()
                if not ret:
                    break

                # YOLOv5 person detection; rows are [x1, y1, x2, y2, conf, cls].
                results = yolo_model(frame)
                detections = results.xyxy[0].cpu().numpy()

                for det in detections:
                    x1, y1, x2, y2 = map(int, det[:4])
                    # Clamp to the frame and skip degenerate boxes: an empty
                    # crop would make cv2.cvtColor raise.
                    x1, y1 = max(x1, 0), max(y1, 0)
                    person_crop = frame[y1:y2, x1:x2]
                    if person_crop.size == 0:
                        continue

                    # person_crop is a view into `frame`, so landmarks drawn
                    # on the crop appear in the full output frame.
                    person_rgb = cv2.cvtColor(person_crop, cv2.COLOR_BGR2RGB)
                    pose_result = pose.process(person_rgb)

                    if pose_result.pose_landmarks:
                        mp_drawing.draw_landmarks(
                            person_crop, pose_result.pose_landmarks, mp_pose.POSE_CONNECTIONS
                        )

                    # Draw bounding box
                    cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)

                output_frames.append(frame)

        cap.release()

        if not output_frames:
            return None, "Error: No frames processed."

        # Write the annotated frames as an MP4 at the source frame rate.
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
        height, width, _ = output_frames[0].shape
        out = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        for f in output_frames:
            out.write(f)
        out.release()

        return output_file, "Pose detection completed."

    except Exception as e:
        # Surface any failure to the UI as a status string rather than
        # crashing the Gradio app.
        return None, f"Runtime Error: {str(e)}"
74
 
75
# Gradio Interface: wire the detector into an upload -> (video, status) UI.
_video_input = gr.Video(label="Upload a Video (max 10s)")
_video_output = gr.Video(label="Pose Detection Output")
_status_output = gr.Textbox(label="Status")

iface = gr.Interface(
    fn=detect_pose,
    inputs=_video_input,
    outputs=[_video_output, _status_output],
    title="Multi-Person Pose Detection",
    description="Upload a short video (max 10s). The app detects multiple people and estimates their poses.",
)

iface.launch()