"""AI Dance Partner.

Detects body poses in an uploaded dance video with MediaPipe and renders a
mirrored "AI partner" skeleton video, served through a Gradio web UI.
"""

import os
import tempfile

import cv2
import gradio as gr
import mediapipe as mp
import numpy as np


# ---------------- POSE DETECTION CLASS -----------------
class PoseDetector:
    """Thin wrapper around MediaPipe Pose returning landmarks as NumPy arrays."""

    def __init__(self):
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5,
        )

    def detect_pose(self, frame):
        """Detect a pose in a single BGR frame.

        Args:
            frame: BGR image as produced by ``cv2.VideoCapture.read``.

        Returns:
            An ``(N, 3)`` float32 array of normalized ``(x, y, z)`` landmark
            coordinates, or ``None`` when no pose is detected.
        """
        # MediaPipe expects RGB input; OpenCV delivers BGR.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = self.pose.process(rgb_frame)
        if results.pose_landmarks is None:
            return None
        # BUG FIX: downstream code (DanceGenerator._mirror_movements) calls
        # `.copy()` and indexes with `[:, 0]`, i.e. NumPy semantics. The raw
        # MediaPipe landmark-list object supports neither, so it must be
        # converted to an array here.
        return np.array(
            [[lm.x, lm.y, lm.z] for lm in results.pose_landmarks.landmark],
            dtype=np.float32,
        )


# ---------------- DANCE GENERATOR CLASS -----------------
class DanceGenerator:
    """Builds the AI partner's frames by mirroring detected poses."""

    def generate_dance_sequence(self, all_poses, total_frames, frame_size):
        """Generate a simple mirrored AI dance partner sequence.

        Args:
            all_poses: list of per-frame ``(N, 3)`` landmark arrays (or None).
            total_frames: number of frames to render.
            frame_size: ``(height, width)`` of the output frames.

        Returns:
            List of ``total_frames`` BGR uint8 images.
        """
        height, width = frame_size
        sequence = []
        for i in range(total_frames):
            frame = np.zeros((height, width, 3), dtype=np.uint8)
            # BUG FIX: guard the index — CAP_PROP_FRAME_COUNT frequently
            # overestimates the number of decodable frames, so `all_poses`
            # may be shorter than `total_frames`.
            if i < len(all_poses) and all_poses[i] is not None:
                mirrored = self._mirror_movements(all_poses[i])
                frame = self._create_dance_frame(mirrored, frame_size)
            sequence.append(frame)
        return sequence

    def _mirror_movements(self, landmarks):
        """Return a copy of *landmarks* mirrored about the vertical axis."""
        mirrored = landmarks.copy()
        # Coordinates are normalized to [0, 1], so flipping is 1 - x.
        mirrored[:, 0] = 1.0 - mirrored[:, 0]
        return mirrored

    def _create_dance_frame(self, pose_array, frame_size):
        """Render landmark points of *pose_array* onto a black BGR frame."""
        height, width = frame_size
        frame = np.zeros((height, width, 3), dtype=np.uint8)
        # Scale normalized coordinates up to pixel positions.
        points = (pose_array[:, :2] * [width, height]).astype(int)
        for point in points:
            cv2.circle(frame, tuple(point), 4, (0, 255, 0), -1)
        return frame


# ---------------- AI DANCE PARTNER CLASS -----------------
class AIDancePartner:
    """End-to-end pipeline: video in -> mirrored partner video out."""

    def __init__(self):
        self.pose_detector = PoseDetector()
        self.dance_generator = DanceGenerator()

    def process_video(self, video_path):
        """Process *video_path* and return the path of the rendered output.

        Args:
            video_path: path to the uploaded input video.

        Returns:
            Filesystem path of the generated ``output_dance.mp4``.
        """
        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, 'output_dance.mp4')

        cap = cv2.VideoCapture(video_path)
        try:
            # BUG FIX: some containers report 0 fps; fall back to a sane
            # default so the writer produces a playable file.
            fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            # Single decode pass: collect a pose (or None) for every frame.
            all_poses = []
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                all_poses.append(self.pose_detector.detect_pose(frame))
        finally:
            # Release the capture even if decoding/detection raises.
            cap.release()

        # BUG FIX: use the number of frames actually decoded instead of the
        # unreliable CAP_PROP_FRAME_COUNT metadata value.
        total_frames = len(all_poses)
        ai_sequence = self.dance_generator.generate_dance_sequence(
            all_poses, total_frames, (frame_height, frame_width)
        )

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width, frame_height))
        try:
            # The original re-read the source video here only to discard the
            # frames; writing the generated sequence directly is equivalent
            # and avoids a second decode pass.
            for ai_frame in ai_sequence:
                out.write(ai_frame)
        finally:
            out.release()

        return output_path


# ---------------- GRADIO WEB INTERFACE -----------------
def generate_dance(video_file):
    """Gradio callback: run the full pipeline on the uploaded video."""
    ai_dance = AIDancePartner()
    return ai_dance.process_video(video_file)


iface = gr.Interface(
    fn=generate_dance,
    inputs=gr.Video(label="Upload Your Dance Video"),
    outputs=gr.Video(label="AI Dance Partner Output"),
    title="AI Dance Partner",
    description="Upload a dance video and get an AI-generated dance partner synchronized to your moves!",
)

if __name__ == "__main__":
    iface.launch()