import tempfile

import cv2
import gradio as gr
import mediapipe as mp
import torch

# Initialize MediaPipe Pose and its drawing helper
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils


# Dummy stand-in for an ST-GCN++ model (replace with an actual trained model).
# It mean-pools the keypoint sequence over time, so it carries no real temporal
# information -- a true ST-GCN++ would consume the full keypoint sequence.
class SimpleSTGCNPlusPlus(torch.nn.Module):
    def __init__(self, input_size=99, num_classes=5):  # 33 keypoints x 3 coords
        super().__init__()
        self.fc = torch.nn.Sequential(
            torch.nn.Linear(input_size, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, num_classes),
        )

    def forward(self, x):
        return self.fc(x)


# Instantiate the model in inference mode
model = SimpleSTGCNPlusPlus()
model.eval()
labels = ["Ballet Dancing", "Cycling", "Running", "Jumping", "Walking"]


def detect_pose_and_activity(video_file):
    try:
        # Gradio's Video input hands us a file path, so OpenCV can read it directly
        cap = cv2.VideoCapture(video_file)
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps <= 0:
            fps = 30  # some containers report no FPS; fall back to a sane default
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        duration = total_frames / fps
        max_frames = int(min(duration, 10) * fps)  # cap processing at 10 seconds

        output_frames = []
        keypoints_sequence = []

        with mp_pose.Pose(static_image_mode=False, min_detection_confidence=0.5) as pose:
            frame_count = 0
            while frame_count < max_frames:
                ret, frame = cap.read()
                if not ret:
                    break
                # MediaPipe expects RGB; OpenCV delivers BGR
                image_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                results = pose.process(image_rgb)
                if results.pose_landmarks:
                    keypoints = []
                    for lm in results.pose_landmarks.landmark:
                        keypoints.extend([lm.x, lm.y, lm.z])
                    keypoints_sequence.append(keypoints)
                    mp_drawing.draw_landmarks(frame, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
                else:
                    # Zero-pad frames with no detected pose (33 landmarks x 3 coords)
                    keypoints_sequence.append([0.0] * 99)
                output_frames.append(frame)
                frame_count += 1
        cap.release()

        if not output_frames:
            return None, "No frames could be read from the video."

        # Mean-pool the keypoints over time into a single feature vector for the dummy model
        keypoints_tensor = torch.tensor(keypoints_sequence, dtype=torch.float32).mean(dim=0, keepdim=True)
        with torch.no_grad():
            preds = model(keypoints_tensor)
            action_idx = torch.argmax(preds, dim=1).item()
            action_label = labels[action_idx]

        # Write the annotated frames to a temporary video for display
        output_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4").name
        height, width, _ = output_frames[0].shape
        out = cv2.VideoWriter(output_file, cv2.VideoWriter_fourcc(*"mp4v"), fps, (width, height))
        for f in output_frames:
            out.write(f)
        out.release()

        return output_file, f"Predicted Action: {action_label}"
    except Exception as e:
        return None, f"Error during processing: {e}"


iface = gr.Interface(
    fn=detect_pose_and_activity,
    inputs=gr.Video(label="Upload a Video (max 10s)"),
    outputs=[gr.Video(label="Pose Detection Output"), gr.Textbox(label="Detected Action")],
    title="Human Pose & Activity Recognition",
    description="Upload a short video; the app detects human poses and predicts the activity (e.g., ballet, cycling).",
)

iface.launch()
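
# A minimal way to try this (assuming the code above is saved as app.py):
#   pip install gradio opencv-python mediapipe torch
#   python app.py
# Gradio prints a local URL; open it in a browser and upload a short clip.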