# Dance-AI / app.py
# (Hugging Face Space header: uploaded by Jiya-08, commit 6de1fe3 "Update app.py")
import gradio as gr
import cv2
import mediapipe as mp
import numpy as np
import os
# Initialize MediaPipe Pose in video mode (static_image_mode=False), so the
# tracker carries landmarks between consecutive frames of both input videos.
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(
    static_image_mode=False,
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)

# MediaPipe Pose landmark ids for the limb joints used in the comparison;
# face and hand landmarks are intentionally excluded.
landmark_indices = [
    11, 13, 15,  # Left shoulder, elbow, wrist
    12, 14, 16,  # Right shoulder, elbow, wrist
    23, 25, 27,  # Left hip, knee, ankle
    24, 26, 28,  # Right hip, knee, ankle
]
# Function to calculate frame similarity for the selected landmarks
def calculate_frame_similarity(ref_landmarks, inp_landmarks, indices=None):
    """Compute a 0..1 pose-similarity score between two landmark sets.

    Each pose is scale-normalized by its own torso length (distance from the
    first selected landmark, left shoulder 11, to the seventh, left hip 23),
    then scored as ``max(0, 1 - L2 distance)`` between the normalized sets.

    Args:
        ref_landmarks: indexable landmarks with ``.x``/``.y`` (reference pose).
        inp_landmarks: same, for the pose being compared.
        indices: landmark ids to compare; defaults to the module-level
            ``landmark_indices`` subset. Must contain at least 7 entries.

    Returns:
        float in [0, 1]. Returns 0.0 when either torso length is zero
        (degenerate detection) instead of dividing by zero.
    """
    if indices is None:
        indices = landmark_indices
    ref = np.array([(ref_landmarks[i].x, ref_landmarks[i].y) for i in indices])
    inp = np.array([(inp_landmarks[i].x, inp_landmarks[i].y) for i in indices])
    # Torso length per pose; points are already 2-D, no slicing needed.
    ref_torso = np.linalg.norm(ref[0] - ref[6])
    inp_torso = np.linalg.norm(inp[0] - inp[6])
    if ref_torso == 0 or inp_torso == 0:
        # Coincident shoulder/hip means the scale is undefined -- treat as
        # "no match" rather than raising a ZeroDivisionError mid-video.
        return 0.0
    distance = np.linalg.norm(ref / ref_torso - inp / inp_torso)
    return max(0, 1 - distance)
# Function to draw selected landmarks
def draw_selected_landmarks(frame, landmarks):
    """Draw the tracked joint subset onto ``frame`` in place.

    Green filled circles mark joints with visibility above 0.5; blue lines
    join limb segments whose two endpoints are both sufficiently visible.
    """
    limb_segments = (
        (11, 13), (13, 15),
        (12, 14), (14, 16),
        (23, 25), (25, 27),
        (24, 26), (26, 28),
    )
    height, width = frame.shape[0], frame.shape[1]
    # Joint markers.
    for joint_idx in landmark_indices:
        point = landmarks[joint_idx]
        if point.visibility > 0.5:
            px = int(point.x * width)
            py = int(point.y * height)
            cv2.circle(frame, (px, py), 5, (0, 255, 0), -1)
    # Limb connections.
    for a_idx, b_idx in limb_segments:
        a, b = landmarks[a_idx], landmarks[b_idx]
        if a.visibility > 0.5 and b.visibility > 0.5:
            a_point = (int(a.x * width), int(a.y * height))
            b_point = (int(b.x * width), int(b.y * height))
            cv2.line(frame, a_point, b_point, (255, 0, 0), 2)
# Function to get color based on similarity score
def get_similarity_color(similarity_score):
    """Map a similarity score to a BGR color.

    Red below 0.25, yellow below 0.75, green otherwise.
    """
    bands = (
        (0.25, (0, 0, 255)),    # poor match -> red
        (0.75, (0, 255, 255)),  # partial match -> yellow
    )
    for upper_bound, bgr in bands:
        if similarity_score < upper_bound:
            return bgr
    return (0, 255, 0)  # good match -> green
def generate_feedback(ref_landmarks, inp_landmarks, diff_threshold=0.05):
    """Describe the single most misaligned joint as a short coaching hint.

    Finds the selected landmark with the largest positional deviation between
    the two poses; if that deviation exceeds ``diff_threshold`` the returned
    string names the joint and suggests vertical/horizontal corrections,
    otherwise an empty string is returned.
    """
    # Mapping of landmark indices to joint names
    joint_names = {
        11: "left shoulder", 13: "left elbow", 15: "left wrist",
        12: "right shoulder", 14: "right elbow", 16: "right wrist",
        23: "left hip", 25: "left knee", 27: "left ankle",
        24: "right hip", 26: "right knee", 28: "right ankle",
    }

    def deviation(idx):
        # Euclidean distance between the two poses' positions of this joint.
        ref_pt = np.array([ref_landmarks[idx].x, ref_landmarks[idx].y])
        inp_pt = np.array([inp_landmarks[idx].x, inp_landmarks[idx].y])
        return np.linalg.norm(ref_pt - inp_pt)

    worst_idx = max(landmark_indices, key=deviation)
    if deviation(worst_idx) <= diff_threshold:
        return ""

    joint = joint_names.get(worst_idx, f"landmark {worst_idx}")
    ref_point = np.array([ref_landmarks[worst_idx].x, ref_landmarks[worst_idx].y])
    inp_point = np.array([inp_landmarks[worst_idx].x, inp_landmarks[worst_idx].y])
    advice = f"Your {joint} seems misaligned. "
    # Vertical adjustment (note: in image coordinates, a larger y means lower)
    if inp_point[1] > ref_point[1] + diff_threshold:
        advice += f"Try raising your {joint}."
    elif inp_point[1] < ref_point[1] - diff_threshold:
        advice += f"Try lowering your {joint}."
    # Horizontal adjustment
    if inp_point[0] < ref_point[0] - diff_threshold:
        advice += f" Also, move it to the right."
    elif inp_point[0] > ref_point[0] + diff_threshold:
        advice += f" Also, move it to the left."
    return advice
# Function to process videos and generate comparison
def process_videos(reference_video_path, input_video_path):
    """Render a side-by-side comparison video with per-frame similarity.

    Reads both videos frame-by-frame (stopping at the shorter one), runs
    MediaPipe Pose on each frame, overlays the tracked joints, scores the
    similarity, and annotates the concatenated frame with the score and a
    coaching hint when similarity drops below 0.75.

    Args:
        reference_video_path: path to the reference dance video.
        input_video_path: path to the dancer's attempt.

    Returns:
        (output_video_path, low_similarity_frames): path of the rendered
        MP4 and a list of RGB numpy frames snapshotted at low-similarity
        (< 0.5) moments, spaced at least ~1 second apart.
    """
    output_video_path = "output_comparison.mp4"
    low_similarity_frames = []
    cap_ref = cv2.VideoCapture(reference_video_path)
    cap_inp = cv2.VideoCapture(input_video_path)
    # Some containers report 0 FPS; fall back to a sane default so the
    # VideoWriter stays valid and the snapshot throttle below still works.
    fps = int(cap_ref.get(cv2.CAP_PROP_FPS)) or 30
    min_frame_gap = fps  # at least one second gap; adjust this as needed
    frame_index = 0
    last_saved_frame_index = -min_frame_gap
    # Output canvas is sized to fit the larger of the two inputs.
    ref_width = int(cap_ref.get(cv2.CAP_PROP_FRAME_WIDTH))
    ref_height = int(cap_ref.get(cv2.CAP_PROP_FRAME_HEIGHT))
    inp_width = int(cap_inp.get(cv2.CAP_PROP_FRAME_WIDTH))
    inp_height = int(cap_inp.get(cv2.CAP_PROP_FRAME_HEIGHT))
    output_width = max(ref_width, inp_width)
    output_height = max(ref_height, inp_height)
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(output_video_path, fourcc, fps, (2 * output_width, output_height))
    try:
        while cap_ref.isOpened() and cap_inp.isOpened():
            ret_ref, frame_ref = cap_ref.read()
            ret_inp, frame_inp = cap_inp.read()
            if not ret_ref or not ret_inp:
                # End of the shorter video.
                break
            frame_ref_resized = cv2.resize(frame_ref, (output_width, output_height))
            frame_inp_resized = cv2.resize(frame_inp, (output_width, output_height))
            # MediaPipe expects RGB input; OpenCV decodes frames as BGR.
            results_ref = pose.process(cv2.cvtColor(frame_ref_resized, cv2.COLOR_BGR2RGB))
            results_inp = pose.process(cv2.cvtColor(frame_inp_resized, cv2.COLOR_BGR2RGB))
            similarity_score = 0
            feedback = ""
            if results_ref.pose_landmarks and results_inp.pose_landmarks:
                draw_selected_landmarks(frame_ref_resized, results_ref.pose_landmarks.landmark)
                draw_selected_landmarks(frame_inp_resized, results_inp.pose_landmarks.landmark)
                # Score on world landmarks (less framing-dependent than
                # image landmarks, per MediaPipe docs).
                similarity_score = calculate_frame_similarity(
                    results_ref.pose_world_landmarks.landmark,
                    results_inp.pose_world_landmarks.landmark,
                )
                if similarity_score < 0.75:
                    feedback = generate_feedback(
                        results_ref.pose_landmarks.landmark,
                        results_inp.pose_landmarks.landmark,
                    )
            similarity_color = get_similarity_color(similarity_score)
            combined_frame = cv2.hconcat([frame_ref_resized, frame_inp_resized])
            cv2.putText(
                combined_frame,
                f"Similarity: {similarity_score*100:.2f}",
                (50, 50),
                cv2.FONT_HERSHEY_SIMPLEX,
                2,
                similarity_color,
                2
            )
            cv2.putText(
                combined_frame,
                "Feedback: " + feedback,
                (50, 100),
                cv2.FONT_HERSHEY_SIMPLEX,
                1,
                (255, 255, 255),
                2
            )
            out.write(combined_frame)
            # Snapshot clearly-bad moments, throttled to about one per second.
            if similarity_score < 0.5 and (frame_index - last_saved_frame_index) >= min_frame_gap:
                low_similarity_frames.append(cv2.cvtColor(combined_frame.copy(), cv2.COLOR_BGR2RGB))
                last_saved_frame_index = frame_index
            frame_index += 1
    finally:
        # Release capture/writer handles even if pose processing raises.
        cap_ref.release()
        cap_inp.release()
        out.release()
    return output_video_path, low_similarity_frames
# Gradio interface
def compare_dance_videos(reference_video, input_video):
    """Gradio callback: delegate to process_videos and pass its results through."""
    return process_videos(reference_video, input_video)
# Gradio app setup. Bound to `demo` (the Hugging Face Spaces convention) and
# launched under a __main__ guard so importing this module has no side effects.
demo = gr.Interface(
    fn=compare_dance_videos,
    inputs=[
        gr.Video(label="Reference Video"),
        gr.Video(label="Input Video"),
    ],
    outputs=[
        gr.Video(label="Output Comparison Video"),
        gr.Gallery(label="Low Similarity Frames"),
    ],
    title="Dance Comparison Tool",
    description="Upload two videos to compare the dance sequences and generate a similarity score visualization.",
)

if __name__ == "__main__":
    demo.launch()