import tempfile

import cv2
import gradio as gr
import mediapipe as mp
from PIL import Image
from transformers import pipeline
# Initialize MediaPipe Pose
mp_pose = mp.solutions.pose
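# MediaPipe Pose reports 33 body landmarks per frame; each landmark carries
# normalized x, y, z coordinates plus a visibility score. We keep x, y, z
# only, i.e. 33 * 3 = 99 values per frame.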
# Load the Hugging Face action-recognition model (a ViT fine-tuned on the
# Human Action Recognition dataset)
action_model = pipeline("image-classification", model="rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224")
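# The image-classification pipeline returns a list of {"label": ..., "score": ...}
# dicts sorted by descending score, so preds[0]["label"] is the top prediction.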
# Action labels predicted by the model (kept here for reference; the pipeline
# returns label strings directly from the model config)
action_labels = [
    "calling", "clapping", "cycling", "dancing", "drinking", "eating", "fighting", "hugging",
    "laughing", "listening_to_music", "running", "sitting", "sleeping", "texting", "using_laptop",
]
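# Optional sanity check: the model's own label map can be inspected via
#   action_model.model.config.id2label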
def detect_pose_and_activity(video_file):
    """
    Process the uploaded video: detect human poses with MediaPipe and classify
    the activity with the ViT action model. Only the first 10 seconds are
    processed if the video is longer. Returns the annotated video path and the
    predicted activity label.
    """
    try:
        # Copy the uploaded video to a temporary file so OpenCV can read it
        temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        with open(video_file, "rb") as src:
            temp_video.write(src.read())
        temp_video.close()

        cap = cv2.VideoCapture(temp_video.name)
        if not cap.isOpened():
            return None, "Error: Could not open video file. Please upload a valid mp4 video."

        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps == 0:
            fps = 30  # fall back to a sane default if the container reports no FPS
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        max_frames = int(min(total_frames / fps, 10) * fps)  # cap processing at 10 seconds
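        # CAP_PROP_FRAME_COUNT is only an estimate for some containers; an
        # overestimate is harmless because the read loop below also stops at
        # the first failed read.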
        output_frames = []
        keypoints_sequence = []
        mid_frame_rgb = None  # clean copy of the middle frame, used for action classification

        with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
            for frame_idx in range(max_frames):
                ret, frame = cap.read()
                if not ret:
                    break
                image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                if frame_idx == max_frames // 2:
                    mid_frame_rgb = image_rgb.copy()
                results = pose.process(image_rgb)

                # Extract keypoints (33 landmarks x 3 coordinates = 99 values)
                if results.pose_landmarks:
                    keypoints = []
                    for lm in results.pose_landmarks.landmark:
                        keypoints.extend([lm.x, lm.y, lm.z])
                    keypoints_sequence.append(keypoints)
                    mp.solutions.drawing_utils.draw_landmarks(
                        frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
                    )
                else:
                    # Pad with zeros so the sequence stays aligned with the frames
                    keypoints_sequence.append([0.0] * 99)
                output_frames.append(frame)
        cap.release()
        if not output_frames:
            return None, "Error: No frames could be read from the video."
        if not any(any(k) for k in keypoints_sequence):
            return None, "Error: No person detected in the video."
        # Classify the activity on a clean (unannotated) frame from the middle
        # of the clip, where the action is most likely to be fully visible
        if mid_frame_rgb is None:
            mid_frame_rgb = cv2.cvtColor(output_frames[len(output_frames) // 2], cv2.COLOR_BGR2RGB)
        preds = action_model(Image.fromarray(mid_frame_rgb))
        action_label = preds[0]["label"]
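        # A more robust variant could classify several sampled frames and
        # majority-vote the labels, or feed keypoints_sequence to a temporal,
        # skeleton-based classifier, at the cost of extra inference time.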
        # Save the annotated output video. Note: depending on the OpenCV build,
        # "mp4v" output may not play inline in every browser; re-encoding to
        # H.264 is a common workaround if playback fails.
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
        height, width, _ = output_frames[0].shape
        out = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
        for f in output_frames:
            out.write(f)
        out.release()

        return output_file, f"Predicted Action: {action_label}"
    except Exception as e:
        return None, f"Runtime Error: {str(e)}"
# Gradio Interface
iface = gr.Interface(
    fn=detect_pose_and_activity,
    inputs=gr.Video(label="Upload a Video (max 10s)"),
    outputs=[gr.Video(label="Pose Detection Output"), gr.Textbox(label="Detected Action")],
    title="Human Pose & Activity Recognition",
    description="Upload a short video (max 10s); the app overlays detected human poses and predicts the activity (e.g., cycling, dancing, running).",
)
iface.launch()
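# Quick smoke test without the web UI (file path is illustrative):
#   out_path, label = detect_pose_and_activity("example.mp4")
#   print(label)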