Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import cv2 | |
| import mediapipe as mp | |
| import torch | |
| import tempfile | |
# --- Model setup -----------------------------------------------------------
# Person detector: pretrained YOLOv5-small pulled from the Ultralytics hub.
yolo_model = torch.hub.load(
    'ultralytics/yolov5', 'yolov5s', pretrained=True, trust_repo=True
)
yolo_model.conf = 0.4     # discard detections below 40% confidence
yolo_model.classes = [0]  # COCO class 0 == "person"; ignore everything else

# Pose estimator and landmark renderer from MediaPipe.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils
def detect_pose(video_file):
    """Run multi-person pose detection on a video file.

    Each frame goes through YOLOv5 to find person bounding boxes, then
    MediaPipe Pose estimates landmarks inside each box. Processing is
    capped at the first 15 seconds of video.

    Args:
        video_file: Path to the input video (as supplied by gr.Video).

    Returns:
        (output_path, status_message): output_path is the annotated .mp4
        file, or None when processing failed.
    """
    try:
        # gr.Video already hands us a readable file path; open it directly
        # instead of copying the bytes into a second (never-deleted) temp file.
        cap = cv2.VideoCapture(video_file)
        if not cap.isOpened():
            return None, "Error: Could not open video."
        output_frames = []
        try:
            fps = cap.get(cv2.CAP_PROP_FPS) or 30  # some containers report 0 fps
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Limit processing to the first 15 seconds. Some containers
            # report 0 total frames; fall back to the time cap alone.
            if total_frames > 0:
                max_frames = int(min(total_frames / fps, 15) * fps)
            else:
                max_frames = int(15 * fps)
            with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
                for _ in range(max_frames):
                    ret, frame = cap.read()
                    if not ret:
                        break
                    frame_h, frame_w = frame.shape[:2]
                    results = yolo_model(frame)
                    detections = results.xyxy[0].cpu().numpy()
                    for det in detections:
                        x1, y1, x2, y2 = map(int, det[:4])
                        # Clamp the box to the frame and skip degenerate ones:
                        # cv2.cvtColor raises on an empty crop.
                        x1, y1 = max(x1, 0), max(y1, 0)
                        x2, y2 = min(x2, frame_w), min(y2, frame_h)
                        if x2 <= x1 or y2 <= y1:
                            continue
                        # person_crop is a view into frame, so landmarks drawn
                        # on it appear in the output frame as well.
                        person_crop = frame[y1:y2, x1:x2]
                        person_rgb = cv2.cvtColor(person_crop, cv2.COLOR_BGR2RGB)
                        pose_result = pose.process(person_rgb)
                        if pose_result.pose_landmarks:
                            mp_drawing.draw_landmarks(
                                person_crop, pose_result.pose_landmarks, mp_pose.POSE_CONNECTIONS
                            )
                        # Draw bounding box
                        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
                    output_frames.append(frame)
        finally:
            cap.release()  # release even if YOLO/MediaPipe raises mid-loop
        if not output_frames:
            return None, "Error: No frames processed."
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
        height, width, _ = output_frames[0].shape
        out = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        for f in output_frames:
            out.write(f)
        out.release()
        return output_file, "Pose detection completed."
    except Exception as e:
        # Top-level boundary for the Gradio UI: surface the error as a status
        # string instead of crashing the app.
        return None, f"Runtime Error: {str(e)}"
# Gradio Interface.
# detect_pose caps processing at 15 seconds, so both the input label and the
# description advertise the same limit (the label previously said "max 10s").
iface = gr.Interface(
    fn=detect_pose,
    inputs=gr.Video(label="Upload a Video (max 15s)"),
    outputs=[gr.Video(label="Pose Detection Output"), gr.Textbox(label="Status")],
    title="Multi-Person Pose Detection",
    description="Upload a short video (max 15s). The app detects multiple people and estimates their poses."
)
iface.launch()