import gradio as gr
import cv2
import mediapipe as mp
import torch
import numpy as np
import tempfile
from transformers import pipeline
from PIL import Image
# Initialize MediaPipe Pose
mp_pose = mp.solutions.pose
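# MediaPipe Pose tracks a single person and returns 33 body landmarks per frame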
# Hugging Face pretrained model for action recognition
# (device selection uses the torch import above: GPU if available, else CPU)
action_model = pipeline(
    "image-classification",
    model="rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224",
    device=0 if torch.cuda.is_available() else -1,
)
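
# Note (sketch of the pipeline's output shape): for a list of images the
# image-classification pipeline returns top-k predictions per image, e.g.
#   action_model([img]) -> [[{'label': 'running', 'score': 0.93}, ...]]
# (label shown is illustrative), which is why the code below reads
# pred[0]['label'] for each sampled frame.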
def detect_pose_and_activity(video_file):
    """
    Process the uploaded video to detect human poses and classify the activity.

    Optimizations:
    - Skip frames for action classification
    - Resize frames before classification
    - Batch the action predictions

    Returns the annotated video path and the predicted action.
    """
    try:
        # Copy the uploaded video to a temporary file that OpenCV can read
        temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        with open(video_file, "rb") as src:
            temp_video.write(src.read())
        temp_video.close()

        cap = cv2.VideoCapture(temp_video.name)
        if not cap.isOpened():
            return None, "Error: Could not open video."

        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps == 0:
            fps = 30
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames <= 0:  # some containers report 0 frames; fall back to the 10 s cap
            total_frames = int(fps * 10)
        max_frames = int(min(total_frames / fps, 10) * fps)  # limit to 10 seconds

        output_frames = []
        pil_frames_for_model = []
        frame_skip = 2            # send every 2nd frame to the classifier
        target_size = (224, 224)  # ViT-Base-patch16-224 input size; resizing speeds up inference

        with mp_pose.Pose(static_image_mode=False,
                          min_detection_confidence=0.5,
                          min_tracking_confidence=0.5) as pose:
            frame_index = 0
            while frame_index < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break

                # Small RGB copy for the action classifier
                frame_small = cv2.resize(frame, target_size)
                image_rgb = cv2.cvtColor(frame_small, cv2.COLOR_BGR2RGB)

                # Pose detection on the full-resolution frame
                results = pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                if results.pose_landmarks:
                    mp.solutions.drawing_utils.draw_landmarks(
                        frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
                output_frames.append(frame)

                # Only queue every Nth frame for action prediction
                if frame_index % frame_skip == 0:
                    pil_frames_for_model.append(Image.fromarray(image_rgb))
                frame_index += 1

        cap.release()
        if len(output_frames) == 0:
            return None, "Error: No frames to process."

        # Batch prediction: the pipeline returns a top-k list of
        # {'label', 'score'} dicts per image, so take the top entry per frame
        preds = action_model(pil_frames_for_model)
        action_labels = [pred[0]['label'] for pred in preds]

        # Take the most frequent predicted action across the sampled frames
        final_action = max(set(action_labels), key=action_labels.count)

        # Save the annotated video (mp4v is MPEG-4 Part 2; some browsers will not
        # play it inline, in which case an H.264 fourcc such as 'avc1' can help
        # if the local OpenCV build supports it)
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
        height, width, _ = output_frames[0].shape
        out = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
        for f in output_frames:
            out.write(f)
        out.release()

        return output_file, f"Predicted Action: {final_action}"
    except Exception as e:
        return None, f"Runtime Error: {str(e)}"
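
# Quick local check without the Gradio UI (hypothetical sample path):
#   out_path, label = detect_pose_and_activity("sample.mp4")
#   print(label)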
# Gradio Interface
iface = gr.Interface(
    fn=detect_pose_and_activity,
    inputs=gr.Video(label="Upload a Video (max 10s)"),
    outputs=[gr.Video(label="Pose Detection Output"), gr.Textbox(label="Detected Action")],
    title="Human Pose & Activity Recognition (Optimized)",
    description="Upload a short video (max 10s). The app detects human poses and predicts the activity quickly using frame skipping, resizing, and batch predictions.",
)

iface.launch()