# myappname / app.py
# Author: kimappl — "Update app.py" (commit 9aa31ae, verified)
import cv2
import mediapipe as mp
import numpy as np
import gradio as gr
import tempfile
import os
# ---------------- POSE DETECTION CLASS -----------------
class PoseDetector:
    """Wraps MediaPipe Pose to extract body landmarks from BGR video frames."""

    def __init__(self):
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5
        )

    def detect_pose(self, frame):
        """Detect pose landmarks in a single BGR frame.

        Returns an (N, 3) float32 ndarray of normalized (x, y, z) landmark
        coordinates, or None when no pose is detected.

        Converting to a NumPy array here is required: downstream,
        DanceGenerator copies the result and slices it with `[:, 0]`,
        which the raw MediaPipe landmark protobuf does not support.
        """
        # MediaPipe expects RGB input; OpenCV frames arrive as BGR.
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = self.pose.process(rgb_frame)
        if results.pose_landmarks is None:
            return None
        return np.array(
            [[lm.x, lm.y, lm.z] for lm in results.pose_landmarks.landmark],
            dtype=np.float32,
        )
# ---------------- DANCE GENERATOR CLASS -----------------
class DanceGenerator:
    """Builds a mirrored "AI dance partner" frame sequence from detected poses."""

    def generate_dance_sequence(self, all_poses, total_frames, frame_size):
        """Generate a simple mirrored AI dance partner sequence.

        Args:
            all_poses: list of per-frame pose arrays (or None where no
                pose was detected); each array is (N, 3) normalized coords.
            total_frames: number of frames to emit.
            frame_size: (height, width) of the output frames.

        Returns:
            List of `total_frames` BGR uint8 frames; frames without a
            pose are left black.
        """
        height, width = frame_size
        n_poses = len(all_poses)
        sequence = []
        for i in range(total_frames):
            frame = np.zeros((height, width, 3), dtype=np.uint8)
            # Bounds guard: CAP_PROP_FRAME_COUNT (the usual source of
            # total_frames) is only an estimate for some codecs, so it may
            # exceed the number of poses actually collected.
            if i < n_poses and all_poses[i] is not None:
                mirrored = self._mirror_movements(all_poses[i])
                frame = self._create_dance_frame(mirrored, frame_size)
            sequence.append(frame)
        return sequence

    def _mirror_movements(self, landmarks):
        """Return a horizontally mirrored copy (normalized x -> 1 - x)."""
        mirrored = landmarks.copy()
        mirrored[:, 0] = 1 - mirrored[:, 0]  # Flip x-coordinates
        return mirrored

    def _create_dance_frame(self, pose_array, frame_size):
        """Render landmark points as green dots on a black BGR frame."""
        height, width = frame_size
        frame = np.zeros((height, width, 3), dtype=np.uint8)
        # Scale normalized [0, 1] coordinates to pixel positions.
        points = (pose_array[:, :2] * [width, height]).astype(int)
        for point in points:
            cv2.circle(frame, tuple(point), 4, (0, 255, 0), -1)
        return frame
# ---------------- AI DANCE PARTNER CLASS -----------------
class AIDancePartner:
    """End-to-end pipeline: read a dance video, detect per-frame poses,
    and render a mirrored AI dance partner video."""

    def __init__(self):
        self.pose_detector = PoseDetector()
        self.dance_generator = DanceGenerator()

    def process_video(self, video_path):
        """Process `video_path` and return the path of the rendered mp4.

        The output video contains only the AI partner (green landmark
        dots on black), one frame per successfully read input frame.
        """
        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, 'output_dance.mp4')

        cap = cv2.VideoCapture(video_path)
        try:
            # CAP_PROP_FPS can report 0 for some containers; a zero fps
            # would produce an unplayable file, so fall back to 30.
            fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            all_poses = []
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                all_poses.append(self.pose_detector.detect_pose(frame))
        finally:
            cap.release()

        # Use the number of frames actually read rather than
        # CAP_PROP_FRAME_COUNT: the property is an estimate and can
        # overshoot, which indexed past the end of all_poses downstream.
        ai_sequence = self.dance_generator.generate_dance_sequence(
            all_poses, len(all_poses), (frame_height, frame_width)
        )

        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps,
                              (frame_width, frame_height))
        try:
            # The original re-read the input video here only to discard
            # the frames; writing the generated sequence directly is
            # equivalent and avoids the second decode pass.
            for ai_frame in ai_sequence:
                out.write(ai_frame)
        finally:
            out.release()

        return output_path
# ---------------- GRADIO WEB INTERFACE -----------------
def generate_dance(video_file):
    """Gradio callback: run the AI dance pipeline on the uploaded video
    and return the path of the rendered partner video."""
    return AIDancePartner().process_video(video_file)
# ---------------- GRADIO WEB INTERFACE -----------------
APP_TITLE = "AI Dance Partner"
APP_DESCRIPTION = (
    "Upload a dance video and get an AI-generated dance partner "
    "synchronized to your moves!"
)

# Wire the processing callback to video upload/download widgets.
iface = gr.Interface(
    fn=generate_dance,
    inputs=gr.Video(label="Upload Your Dance Video"),
    outputs=gr.Video(label="AI Dance Partner Output"),
    title=APP_TITLE,
    description=APP_DESCRIPTION,
)

if __name__ == "__main__":
    iface.launch()