import gradio as gr
import cv2
import tempfile
import mediapipe as mp
from PIL import Image
from transformers import pipeline
# Initialize MediaPipe Pose
mp_pose = mp.solutions.pose
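# MediaPipe Pose estimates 33 body landmarks per frame; the detected skeleton
# can be drawn onto frames with mp.solutions.drawing_utils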
# Load the Hugging Face action-recognition model (a ViT image classifier
# for human action recognition)
action_model = pipeline("image-classification", model="rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224")
# Define action labels
action_labels = [
"calling", "clapping", "cycling", "dancing", "drinking", "eating", "fighting", "hugging",
"laughing", "listening_to_music", "running", "sitting", "sleeping", "texting", "using_laptop"
]
def detect_pose_and_activity(video_file):
    """
    Process the uploaded video: draw MediaPipe pose landmarks on each frame
    and classify the activity with the action-recognition model.
    Only the first 10 seconds are processed if the video is longer.
    Returns the path of the annotated video and the predicted activity label.
    """
    try:
        # Gradio passes the uploaded video as a file path, so it can be
        # opened directly; no temporary copy is needed
        cap = cv2.VideoCapture(video_file)
if not cap.isOpened():
return None, "Error: Could not open video file. Please upload a valid mp4 video."
fps = cap.get(cv2.CAP_PROP_FPS)
if fps == 0:
fps = 30 # fallback if fps is zero
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        max_frames = min(total_frames, int(10 * fps))  # process at most 10 seconds
        output_frames = []
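        # static_image_mode=False treats the input as a video stream, letting
        # MediaPipe track landmarks between frames instead of re-detecting
        # from scratch on every frame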
with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
for _ in range(max_frames):
ret, frame = cap.read()
if not ret:
break
image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
results = pose.process(image_rgb)
                # Overlay the detected skeleton on the frame
                if results.pose_landmarks:
                    mp.solutions.drawing_utils.draw_landmarks(
                        frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
                    )
output_frames.append(frame)
cap.release()
        if not output_frames:
            return None, "Error: No frames could be read from the video."
        # Classify the activity from the middle frame: the model is a per-image
        # classifier, so a single representative frame of the short clip is
        # used as its input
        mid_frame = output_frames[len(output_frames) // 2]
        pil_frame = Image.fromarray(cv2.cvtColor(mid_frame, cv2.COLOR_BGR2RGB))
        preds = action_model(pil_frame)
        action_label = preds[0]["label"]  # predictions are sorted by score, best first
# Save output video
output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
height, width, _ = output_frames[0].shape
out = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
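        # 'mp4v' writes MPEG-4 Part 2 video; if the preview does not play in
        # the browser, an H.264 fourcc such as 'avc1' may be needed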
for f in output_frames:
out.write(f)
out.release()
return output_file, f"Predicted Action: {action_label}"
except Exception as e:
return None, f"Runtime Error: {str(e)}"
# Gradio Interface
iface = gr.Interface(
fn=detect_pose_and_activity,
inputs=gr.Video(label="Upload a Video (max 10s)"),
outputs=[gr.Video(label="Pose Detection Output"), gr.Textbox(label="Detected Action")],
title="Human Pose & Activity Recognition",
description="Upload a short video (max 10s), and the app will detect human poses and predict the activity (e.g., ballet, cycling, running)."
)
iface.launch()
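# launch() starts the Gradio server and prints a local URL to open in a browser;
# on Hugging Face Spaces the app is served automatically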