vu0018's picture
Update app.py
6734399 verified
import os
import tempfile

import cv2
import gradio as gr
import mediapipe as mp
import torch
# Person detector: the small YOLOv5 checkpoint, fetched through torch hub.
yolo_model = torch.hub.load(
    'ultralytics/yolov5',
    'yolov5s',
    pretrained=True,
    trust_repo=True,
)
yolo_model.classes = [0]  # keep only class index 0 (persons)
yolo_model.conf = 0.4  # minimum confidence for a detection to be reported
# MediaPipe helpers: the pose solution and its landmark-drawing utilities.
mp_drawing = mp.solutions.drawing_utils
mp_pose = mp.solutions.pose
def _annotate_people(frame, pose):
    """Draw a bounding box and pose skeleton on *frame* for every detected person.

    The BGR ``frame`` is modified in place; nothing is returned.
    """
    # OpenCV decodes frames as BGR, while the YOLOv5 hub wrapper treats numpy
    # input as RGB, so convert before running detection.
    rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    detections = yolo_model(rgb_frame).xyxy[0].cpu().numpy()

    for det in detections:
        x1, y1, x2, y2 = map(int, det[:4])
        # A negative start index would wrap around when slicing; clamp it.
        x1, y1 = max(x1, 0), max(y1, 0)

        # This slice is a view, so landmarks drawn on it land on the frame.
        person_crop = frame[y1:y2, x1:x2]
        if person_crop.size == 0:
            # Degenerate box (zero width/height after int truncation):
            # cvtColor would raise on an empty image.
            continue

        person_rgb = cv2.cvtColor(person_crop, cv2.COLOR_BGR2RGB)
        pose_result = pose.process(person_rgb)
        if pose_result.pose_landmarks:
            mp_drawing.draw_landmarks(
                person_crop, pose_result.pose_landmarks, mp_pose.POSE_CONNECTIONS
            )
        # Draw bounding box
        cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)


def detect_pose(video_file):
    """Annotate up to the first 15 seconds of a video with per-person poses.

    Every frame is passed through the YOLOv5 person detector; each detected
    person is cropped, run through MediaPipe Pose, and drawn back onto the
    frame together with a green bounding box.  Annotated frames are streamed
    to a new temporary ``.mp4`` file instead of being collected in memory.

    Args:
        video_file: Path of the input video, or ``None``/empty when no video
            was supplied.

    Returns:
        A ``(output_path, status_message)`` tuple.  ``output_path`` is
        ``None`` whenever processing fails, and ``status_message`` then
        describes the error.
    """
    if not video_file:
        return None, "Error: No video provided."

    max_seconds = 15
    cap = None
    writer = None
    output_file = None
    succeeded = False
    try:
        # Read the upload directly; copying it to another temp file first
        # only duplicated the data and left the copy behind on disk.
        cap = cv2.VideoCapture(video_file)
        if not cap.isOpened():
            return None, "Error: Could not open video."

        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:  # rejects 0, negative values and NaN alike
            fps = 30
        # Rely on end-of-stream detection rather than CAP_PROP_FRAME_COUNT,
        # which can be missing or wrong for some containers.
        max_frames = int(max_seconds * fps)

        # Consecutive inputs are crops of different people, i.e. unrelated
        # images, so run MediaPipe in static-image mode instead of letting it
        # track landmarks from one crop to the next.
        with mp_pose.Pose(static_image_mode=True, min_detection_confidence=0.5) as pose:
            for _ in range(max_frames):
                ret, frame = cap.read()
                if not ret:
                    break

                _annotate_people(frame, pose)

                if writer is None:
                    # Create the output lazily so its size comes from a real
                    # frame and no file is created for an unreadable video.
                    height, width = frame.shape[:2]
                    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
                        output_file = tmp.name
                    writer = cv2.VideoWriter(
                        output_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height)
                    )
                    if not writer.isOpened():
                        return None, "Error: Could not create output video."
                writer.write(frame)

        if writer is None:
            return None, "Error: No frames processed."

        succeeded = True
        return output_file, "Pose detection completed."
    except Exception as e:
        # Top-level UI boundary: report the failure in the status box.
        return None, f"Runtime Error: {str(e)}"
    finally:
        if cap is not None:
            cap.release()
        if writer is not None:
            writer.release()
        if output_file is not None and not succeeded:
            # Best-effort removal of a partial output file.
            try:
                os.remove(output_file)
            except OSError:
                pass
# Gradio interface: one uploaded video in, annotated video plus status text out.
iface = gr.Interface(
    fn=detect_pose,
    # The label states the same 15-second limit as the description below and
    # as detect_pose enforces (it previously said 10s).
    inputs=gr.Video(label="Upload a Video (max 15s)"),
    outputs=[gr.Video(label="Pose Detection Output"), gr.Textbox(label="Status")],
    title="Multi-Person Pose Detection",
    description="Upload a short video (max 15s). The app detects multiple people and estimates their poses.",
)
iface.launch()