# Hugging Face Space (page status at capture time: Sleeping)
| import gradio as gr | |
| import cv2 | |
| import mediapipe as mp | |
| import torch | |
| import numpy as np | |
| import tempfile | |
# Alias for MediaPipe's Pose solution module. No detector is created here; the
# module's `Pose` detector class and `POSE_CONNECTIONS` are used later in this file.
mp_pose = mp.solutions.pose
# Placeholder standing in for an ST-GCN++ skeleton-action classifier: a plain
# two-layer MLP over one flattened pose vector. Replace with the real model.
class SimpleSTGCNPlusPlus(torch.nn.Module):
    """Tiny stand-in classifier for pose-based action recognition.

    Args:
        input_size: Length of the flattened feature vector
            (33 keypoints x 3 coords = 99 by default).
        num_classes: Number of output logits.
    """

    def __init__(self, input_size=99, num_classes=5):
        super().__init__()
        hidden_units = 64
        stack = [
            torch.nn.Linear(input_size, hidden_units),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_units, num_classes),
        ]
        # Kept as a single Sequential named `fc` so parameter names
        # (fc.0.*, fc.2.*) stay stable for loading weights.
        self.fc = torch.nn.Sequential(*stack)

    def forward(self, x):
        """Map features shaped (..., input_size) to logits shaped (..., num_classes)."""
        logits = self.fc(x)
        return logits
# Action names, indexed by the position of the model's arg-max logit.
labels = ["Ballet Dancing", "Cycling", "Running", "Jumping", "Walking"]

# Instantiate the (placeholder) model with exactly one output per label so the
# two cannot drift apart, and put it in inference mode: eval() disables dropout
# and makes batch norm use running statistics. That is a no-op for the current
# Linear/ReLU stack but required once a real ST-GCN++ is plugged in.
model = SimpleSTGCNPlusPlus(num_classes=len(labels))
model.eval()
# Only this many seconds from the start of the clip are analysed.
_MAX_SECONDS = 10
# Frame rate assumed when the container reports none (0 / NaN). The value is an
# arbitrary but common default, not something read from the video.
_FALLBACK_FPS = 30.0


def _read_video_fps(cap):
    """Return the capture's FPS, or ``_FALLBACK_FPS`` if it is not a positive number."""
    fps = cap.get(cv2.CAP_PROP_FPS)
    # `not (fps > 0)` also catches NaN; a zero FPS would otherwise break the
    # frame budget and the VideoWriter.
    if not (fps > 0):
        return _FALLBACK_FPS
    return fps


def _extract_pose_frames(cap, max_frames):
    """Run MediaPipe Pose over at most ``max_frames`` frames read from ``cap``.

    Returns:
        ``(frames, keypoints_sequence)``: every frame read (BGR, landmarks drawn
        in place when a pose was found) and one flattened ``[x, y, z, ...]``
        list per frame in which a pose was detected.
    """
    frames = []
    keypoints_sequence = []
    draw_landmarks = mp.solutions.drawing_utils.draw_landmarks  # hoisted out of the loop
    with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
        while len(frames) < max_frames:
            ret, frame = cap.read()
            if not ret:
                break  # end of stream (or read failure)
            results = pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            if results.pose_landmarks:
                keypoints_sequence.append(
                    [c for lm in results.pose_landmarks.landmark for c in (lm.x, lm.y, lm.z)]
                )
                draw_landmarks(frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
            frames.append(frame)
    return frames, keypoints_sequence


def _classify_keypoints(keypoints_sequence):
    """Mean-pool per-frame keypoints over time and return the predicted label."""
    features = torch.tensor(keypoints_sequence, dtype=torch.float32).mean(dim=0, keepdim=True)
    with torch.no_grad():
        logits = model(features)
    return labels[int(torch.argmax(logits, dim=1).item())]


def _write_video(frames, fps, path):
    """Encode ``frames`` to ``path`` at ``fps``; raise RuntimeError if the writer cannot open."""
    height, width = frames[0].shape[:2]
    # NOTE(review): 'mp4v' (MPEG-4 Part 2) is often not playable in browsers;
    # confirm the Gradio version in use re-encodes it, or switch to an H.264
    # fourcc if the OpenCV build supports one.
    writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    try:
        if not writer.isOpened():
            raise RuntimeError(f"could not open video writer for {path}")
        for frame in frames:
            writer.write(frame)
    finally:
        writer.release()


def detect_pose_and_activity(video_file):
    """Annotate an uploaded video with pose landmarks and predict the activity.

    Only the first ``_MAX_SECONDS`` seconds are processed. The landmark
    coordinates of frames in which a pose was detected are averaged over time
    and fed to the module-level ``model``; the arg-max logit indexes ``labels``.

    Args:
        video_file: Filesystem path of the uploaded video. Guarded against
            ``None``/empty in case the form is submitted without a file.

    Returns:
        ``(output_path, message)``: an .mp4 path with landmarks drawn (``None``
        on failure) and either the prediction or an error description. Errors
        are reported in the message instead of raised so the UI always has
        something to display.
    """
    if not video_file:
        return None, "Please upload a video."
    cap = None
    try:
        # Read the upload in place; copying it to another temp file was
        # unnecessary and leaked both a file handle and the copy.
        cap = cv2.VideoCapture(video_file)
        if not cap.isOpened():
            return None, "Could not open the uploaded video."
        fps = _read_video_fps(cap)
        # Frame budget from FPS alone; the reported frame count can be 0/-1 or
        # inexact, and the read loop stops at end of stream anyway.
        max_frames = int(_MAX_SECONDS * fps)
        frames, keypoints_sequence = _extract_pose_frames(cap, max_frames)
        if not frames:
            return None, "Could not read any frames from the video."
        if not keypoints_sequence:
            return None, "No pose detected."
        action_label = _classify_keypoints(keypoints_sequence)
        # Reserve a path and close the handle immediately so VideoWriter can open it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
            output_file = tmp.name
        _write_video(frames, fps, output_file)
        return output_file, f"Predicted Action: {action_label}"
    except Exception as e:  # UI boundary: surface any failure as a message
        return None, f"Error during processing: {e}"
    finally:
        if cap is not None:
            cap.release()
# Gradio UI: one video in; the annotated video and the predicted label out.
iface = gr.Interface(
    fn=detect_pose_and_activity,
    inputs=gr.Video(label="Upload a Video (max 10s)"),
    outputs=[gr.Video(label="Pose Detection Output"), gr.Textbox(label="Detected Action")],
    title="Human Pose & Activity Recognition",
    description="Upload a short video, and this app will detect human poses and predict the activity (e.g., ballet, cycling).",
)

# Start the server only when this file is executed as a script, so importing
# the module (e.g. from tests or another app) does not block on launch().
if __name__ == "__main__":
    iface.launch()