import gradio as gr
import cv2
import mediapipe as mp
import torch
import tempfile

# Initialize MediaPipe Pose
mp_pose = mp.solutions.pose

# Dummy stand-in for an ST-GCN++ model (replace with an actual pretrained model)
class SimpleSTGCNPlusPlus(torch.nn.Module):
    def __init__(self, input_size=99, num_classes=5):  # 33 keypoints x 3 coords
        super().__init__()
        # Simple MLP head; a real ST-GCN++ convolves over the spatio-temporal
        # skeleton graph rather than a flat, time-averaged feature vector.
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(input_size, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, num_classes),
        )

    def forward(self, x):
        return self.fc(x)

# Instantiate the model and the label set it predicts over
model = SimpleSTGCNPlusPlus()
labels = ["Ballet Dancing", "Cycling", "Running", "Jumping", "Walking"]
def detect_pose_and_activity(video_file):
    try:
        # Copy the uploaded video to a temporary file so OpenCV can read it
        temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        with open(video_file, "rb") as src:
            temp_video.write(src.read())
        temp_video.close()

        cap = cv2.VideoCapture(temp_video.name)
        fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back if FPS metadata is missing
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps
        max_frames = int(min(duration, 10) * fps)  # process at most 10 seconds

        output_frames = []
        keypoints_sequence = []
        with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
            frame_count = 0
            while frame_count < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                results = pose.process(image_rgb)
                if results.pose_landmarks:
                    keypoints = []
                    for lm in results.pose_landmarks.landmark:
                        keypoints.extend([lm.x, lm.y, lm.z])
                    keypoints_sequence.append(keypoints)
                    mp.solutions.drawing_utils.draw_landmarks(
                        frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
                    )
                else:
                    # No person in this frame; pad with zeros to keep the sequence aligned
                    keypoints_sequence.append([0.0] * 99)
                output_frames.append(frame)
                frame_count += 1
        cap.release()

        if not output_frames:
            return None, "Could not read any frames from the video."

        # Average the per-frame keypoints into one feature vector for the dummy classifier
        keypoints_tensor = torch.tensor(keypoints_sequence, dtype=torch.float32).mean(dim=0, keepdim=True)
        with torch.no_grad():
            preds = model(keypoints_tensor)
            action_idx = torch.argmax(preds, dim=1).item()
            action_label = labels[action_idx]

        # Write the annotated frames back out as an MP4
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
        height, width, _ = output_frames[0].shape
        out = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
        for f in output_frames:
            out.write(f)
        out.release()

        return output_file, f"Predicted Action: {action_label}"
    except Exception as e:
        return None, f"Error during processing: {e}"
iface = gr.Interface(
    fn=detect_pose_and_activity,
    inputs=gr.Video(label="Upload a Video (max 10s)"),
    outputs=[gr.Video(label="Pose Detection Output"), gr.Textbox(label="Detected Action")],
    title="Human Pose & Activity Recognition",
    description="Upload a short video, and this app will detect human poses and predict the activity (e.g., ballet, cycling).",
)
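
# launch() with no arguments is sufficient on Hugging Face Spaces; when running
# locally, iface.launch(share=True) (a standard Gradio option) creates a
# temporary public link.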
iface.launch()