File size: 3,433 Bytes
015f0f2
f7c47ba
 
af17b79
 
8a40a91
f7c47ba
956cfce
f7c47ba
a55051b
af17b79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
015f0f2
 
 
 
 
956cfce
015f0f2
 
 
 
 
a55051b
015f0f2
af17b79
956cfce
015f0f2
 
 
 
 
 
956cfce
015f0f2
 
956cfce
015f0f2
af17b79
 
 
 
 
 
 
956cfce
015f0f2
 
956cfce
015f0f2
 
af17b79
 
 
 
 
 
 
 
 
015f0f2
 
 
 
 
 
 
 
af17b79
015f0f2
 
af17b79
015f0f2
 
af17b79
015f0f2
af17b79
 
 
015f0f2
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import tempfile

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np
import torch

# Shorthand for MediaPipe's pose solution module. The rest of the app uses it
# to construct the `Pose` estimator and to fetch `POSE_CONNECTIONS` when
# drawing detected skeletons onto frames.
mp_pose = mp.solutions.pose

# Dummy ST-GCN++ model (replace with actual model)
class SimpleSTGCNPlusPlus(torch.nn.Module):
    """Placeholder activity classifier standing in for a real ST-GCN++.

    Despite the name it has no graph or temporal structure: it is a two-layer
    MLP (``input_size`` -> 64 -> ``num_classes``) applied to the last
    dimension of its input. The default ``input_size`` of 99 corresponds to
    33 keypoints x 3 coordinates.

    Args:
        input_size: Length of the feature vector fed to the first layer.
        num_classes: Number of output logits.
    """

    def __init__(self, input_size=99, num_classes=5):
        super().__init__()
        hidden_units = 64
        layers = [
            torch.nn.Linear(input_size, hidden_units),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_units, num_classes),
        ]
        # A single Sequential named `fc` keeps the parameter names
        # (fc.0.*, fc.2.*) identical to the original layout.
        self.fc = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Return raw (unnormalised) class logits for ``x``."""
        logits = self.fc(x)
        return logits

# Activity names, indexed by the classifier's argmax output.
labels = ["Ballet Dancing", "Cycling", "Running", "Jumping", "Walking"]

# Instantiate the model. `num_classes` is tied to `labels` so the two cannot
# drift apart: an extra logit could otherwise produce an argmax index with no
# matching label (IndexError at lookup time).
model = SimpleSTGCNPlusPlus(num_classes=len(labels))
# Inference only: put the module in eval mode. The placeholder has no
# dropout/batch-norm layers, but a real drop-in replacement likely would.
model.eval()

_MAX_CLIP_SECONDS = 10  # only the first 10 s of an upload are analysed
_FALLBACK_FPS = 30.0  # used when the container reports no usable frame rate


def _resolve_fps(cap):
    """Return ``cap``'s frame rate, or ``_FALLBACK_FPS`` if it is 0, NaN or inf."""
    fps = cap.get(cv2.CAP_PROP_FPS)
    # Chained comparison rejects 0/negative, +inf, and NaN (NaN compares False).
    return fps if 0 < fps < float("inf") else _FALLBACK_FPS


def _landmarks_to_vector(pose_landmarks):
    """Flatten MediaPipe landmarks into ``[x0, y0, z0, x1, y1, z1, ...]``."""
    return [coord for lm in pose_landmarks.landmark for coord in (lm.x, lm.y, lm.z)]


def _open_writer(frame, fps):
    """Reserve a temp ``.mp4`` path and open a ``cv2.VideoWriter`` sized to ``frame``.

    Returns:
        ``(writer, path)``. The caller owns both: it must release the writer
        and delete the path if the file is not handed back to the UI.
    """
    height, width = frame.shape[:2]
    # Close the temp-file handle immediately; only the reserved path is needed.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
        path = tmp.name
    writer = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
    return writer, path


def detect_pose_and_activity(video_file):
    """Annotate poses in a video and predict an activity label.

    Reads at most the first ``_MAX_CLIP_SECONDS`` seconds of ``video_file``
    and runs MediaPipe Pose on each frame. Detected skeletons are drawn onto
    the frames, which are streamed into a temporary MP4. The landmark vectors
    of frames where a pose was found are averaged and classified by ``model``.

    Args:
        video_file: Path to the uploaded video (as supplied by ``gr.Video``),
            or ``None`` if nothing was uploaded.

    Returns:
        ``(output_path, message)`` on success. ``(None, message)`` when no
        video was given, no pose was found, or processing failed. Errors are
        reported in the message rather than raised, so the UI can show them.
    """
    if not video_file:
        return None, "No video provided."

    cap = None
    writer = None
    output_file = None
    delivered = False
    try:
        # Read the upload directly; no intermediate copy is needed.
        cap = cv2.VideoCapture(video_file)
        if not cap.isOpened():
            return None, "Error during processing: could not open video."

        fps = _resolve_fps(cap)
        # Bound by time rather than CAP_PROP_FRAME_COUNT, which some containers
        # report as 0 or only estimate; shorter clips stop when cap.read() fails.
        max_frames = int(_MAX_CLIP_SECONDS * fps)
        detected = []  # one flattened landmark vector per frame with a pose

        with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
            for _ in range(max_frames):
                ret, frame = cap.read()
                if not ret:
                    break

                results = pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
                if results.pose_landmarks:
                    detected.append(_landmarks_to_vector(results.pose_landmarks))
                    mp.solutions.drawing_utils.draw_landmarks(
                        frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS
                    )

                # Stream frames to disk instead of buffering the whole clip in RAM.
                if writer is None:
                    writer, output_file = _open_writer(frame, fps)
                writer.write(frame)

        if writer is not None:
            writer.release()
            writer = None

        # Frames without a detection are skipped rather than zero-filled:
        # zeros would drag the mean toward the origin and would make this
        # check unreachable for videos that have frames but no person.
        if not detected:
            return None, "No pose detected."

        keypoints_tensor = torch.tensor(detected, dtype=torch.float32).mean(dim=0, keepdim=True)
        with torch.no_grad():
            preds = model(keypoints_tensor)
        action_label = labels[torch.argmax(preds, dim=1).item()]

        delivered = True
        return output_file, f"Predicted Action: {action_label}"

    except Exception as e:
        # Top-level UI boundary: surface any failure as text in the Gradio output.
        return None, f"Error during processing: {str(e)}"

    finally:
        if cap is not None:
            cap.release()
        if writer is not None:
            writer.release()
        # Don't leave an orphaned annotated video behind on failure / no pose.
        if not delivered and output_file is not None:
            try:
                os.remove(output_file)
            except OSError:
                pass  # deliberate best-effort cleanup; never mask the real result

# Gradio UI: one video input; outputs are the annotated video and the
# predicted action text returned by detect_pose_and_activity.
iface = gr.Interface(
    fn=detect_pose_and_activity,
    inputs=gr.Video(label="Upload a Video (max 10s)"),
    outputs=[gr.Video(label="Pose Detection Output"), gr.Textbox(label="Detected Action")],
    title="Human Pose & Activity Recognition",
    description="Upload a short video, and this app will detect human poses and predict the activity (e.g., ballet, cycling)."
)

# Start the web server only when run as a script, so importing this module
# (e.g. from tests or another app) does not launch a server as a side effect.
if __name__ == "__main__":
    iface.launch()