import gradio as gr
import cv2
import tempfile
from transformers import pipeline
from PIL import Image
import mediapipe as mp

# Initialize MediaPipe Pose
mp_pose = mp.solutions.pose

# Load the Hugging Face action-recognition model (a ViT fine-tuned on 15 human actions).
# Pose detection is handled locally by MediaPipe, so no separate pose model is needed.
action_model = pipeline("image-classification", model="rvv-karma/Human-Action-Recognition-VIT-Base-patch16-224")
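# The image-classification pipeline returns ranked {label, score} dicts, e.g.:
#   action_model(img) -> [{"label": "dancing", "score": 0.93}, ...]  (score illustrative)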

# The model's 15 action labels (the pipeline returns these directly; listed
# here for reference):
action_labels = [
    "calling", "clapping", "cycling", "dancing", "drinking", "eating", "fighting", "hugging",
    "laughing", "listening_to_music", "running", "sitting", "sleeping", "texting", "using_laptop"
]

def detect_pose_and_activity(video_file):
    """
    Process the uploaded video: draw MediaPipe pose landmarks on each frame
    and classify the activity with the action-recognition model.
    The video is trimmed to its first 10 seconds if longer.
    Returns the annotated video path and the predicted activity label.
    """
    try:
        # Copy the uploaded video to a temporary file
        temp_video = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        with open(video_file, "rb") as src:
            temp_video.write(src.read())
        temp_video.close()

        cap = cv2.VideoCapture(temp_video.name)
        if not cap.isOpened():
            return None, "Error: Could not open video file. Please upload a valid mp4 video."

        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps == 0:
            fps = 30  # fall back to a sane default if OpenCV reports 0

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        max_frames = int(min(total_frames / fps, 10) * fps)  # process at most 10 seconds

        output_frames = []  # annotated BGR frames for the output video
        rgb_frames = []     # unannotated RGB frames, kept for classification
        pose_frames = 0     # number of frames in which a pose was detected

        with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
            for _ in range(max_frames):
                ret, frame = cap.read()
                if not ret:
                    break

                image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                rgb_frames.append(image_rgb)  # saved before landmarks are drawn
                results = pose.process(image_rgb)

                # Draw the detected pose landmarks onto the BGR frame
                if results.pose_landmarks:
                    pose_frames += 1
                    mp.solutions.drawing_utils.draw_landmarks(
                        frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
                    )

                output_frames.append(frame)

        cap.release()

        if not output_frames:
            return None, "Error: No frames could be read from the video."
        if pose_frames == 0:
            return None, "Error: No human poses detected in the video."

        # Classify the activity from a representative (middle) frame, using the
        # unannotated copy so drawn landmarks do not bias the model
        mid_frame = Image.fromarray(rgb_frames[len(rgb_frames) // 2])
        preds = action_model(mid_frame)
        action_label = preds[0]["label"]

        # Save output video
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
        height, width, _ = output_frames[0].shape
        out = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*'mp4v'), fps, (width, height))
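        # Note: 'mp4v' output may not play in all browsers; if playback fails in
        # the Gradio UI, re-encoding to H.264 (e.g., via ffmpeg) is a common fix.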
        for f in output_frames:
            out.write(f)
        out.release()

        return output_file, f"Predicted Action: {action_label}"

    except Exception as e:
        return None, f"Runtime Error: {str(e)}"

# Gradio Interface
iface = gr.Interface(
    fn=detect_pose_and_activity,
    inputs=gr.Video(label="Upload a Video (max 10s)"),
    outputs=[gr.Video(label="Pose Detection Output"), gr.Textbox(label="Detected Action")],
    title="Human Pose & Activity Recognition",
    description="Upload a short video (max 10s), and the app will detect human poses and predict the activity (e.g., ballet, cycling, running)."
)

iface.launch()
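
# Assumed dependencies (unpinned; adjust to your environment):
#   pip install gradio opencv-python transformers torch mediapipe Pillow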