# app.py — Hugging Face Space "VB" by Jofax (commit dfccaa2, "Create app.py")
import gradio as gr
import cv2
import numpy as np
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
import tempfile
import os
# --- 1. SET UP MODELS ---
# Download specialized volleyball models from the Davidsv/volley-ref-ai repo
# at import time so the Space fails fast if the weights are unavailable.
try:
    court_model_path = hf_hub_download(repo_id="Davidsv/volley-ref-ai", filename="yolo_court_keypoints.pt")
    ball_model_path = hf_hub_download(repo_id="Davidsv/volley-ref-ai", filename="yolo_volleyball_ball.pt")
    court_model = YOLO(court_model_path)  # court line / keypoint detector
    ball_model = YOLO(ball_model_path)    # volleyball (ball) detector
    pose_model = YOLO("yolo11n-pose.pt")  # general human pose model (auto-downloaded by ultralytics)
except Exception as e:
    print(f"Error loading models: {e}")
    # Re-raise: swallowing the error left court_model/ball_model/pose_model
    # undefined, producing an opaque NameError on the first inference instead
    # of a clear startup failure.
    raise
def process_volleyball_video(video_path):
    """Annotate a volleyball video with detections and rule-violation hints.

    Runs three YOLO models per frame (court keypoints, human pose, ball),
    flags "spike" poses (a wrist above its shoulder) and potential net
    touches (a wrist within 10 px of the estimated net line), and overlays
    all detections on each frame.

    Args:
        video_path: Path to the uploaded video, or a falsy value.

    Returns:
        Path to an annotated .mp4 temporary file, or None when no video was
        supplied or the input could not be opened.
    """
    if not video_path:
        return None
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Unreadable/corrupt upload — bail out instead of writing an empty video.
        cap.release()
        return None
    out = None
    try:
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # Some containers report 0 fps; fall back to 30 so VideoWriter works.
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        # Create a temporary file to save the processed video. Close our handle
        # immediately — cv2 reopens the file by name (required on Windows).
        temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        temp_output.close()
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(temp_output.name, fourcc, fps, (width, height))
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # Run detections
            court_res = court_model(frame, verbose=False)[0]
            pose_res = pose_model(frame, verbose=False)[0]
            ball_res = ball_model(frame, verbose=False)[0]
            annotated_frame = frame.copy()
            # Estimate the net height from court keypoints. Keypoints 6 and 7
            # are presumed to mark the net top in this model — TODO confirm
            # against the volley-ref-ai keypoint layout.
            net_y = height // 2  # Default fallback: mid-frame
            if court_res.keypoints is not None and len(court_res.keypoints.xy[0]) > 7:
                net_y = int(court_res.keypoints.xy[0][6][1])  # Y-coord of net top
            # Process players
            if pose_res.keypoints is not None:
                for i, person in enumerate(pose_res.keypoints.xy):
                    if len(person) < 11:
                        continue  # not enough visible joints for this person
                    # COCO joint indices: 5=L_Shoulder, 6=R_Shoulder, 9=L_Wrist, 10=R_Wrist
                    l_shoulder, r_shoulder = person[5], person[6]
                    l_wrist, r_wrist = person[9], person[10]
                    # ANALYSIS 1: "Spike" — a wrist above its shoulder (image y
                    # grows downward, so smaller y means higher). The y > 0
                    # check rejects undetected (zeroed) keypoints.
                    if (l_wrist[1] < l_shoulder[1] or r_wrist[1] < r_shoulder[1]) and l_wrist[1] > 0:
                        cv2.putText(annotated_frame, "SPIKE ATTACK", (int(l_shoulder[0]), int(l_shoulder[1]-20)),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)
                    # ANALYSIS 2: possible net touch — wrist within 10 px of the net line.
                    if abs(l_wrist[1] - net_y) < 10 or abs(r_wrist[1] - net_y) < 10:
                        cv2.putText(annotated_frame, "WARNING: NET TOUCH", (50, 50 + (i*30)),
                                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3)
            # Draw detections (ball plot was previously computed but never drawn).
            annotated_frame = pose_res.plot(img=annotated_frame)
            annotated_frame = court_res.plot(img=annotated_frame)
            annotated_frame = ball_res.plot(img=annotated_frame)
            out.write(annotated_frame)
    finally:
        # Always release capture/writer, even if inference raised mid-video.
        cap.release()
        if out is not None:
            out.release()
    return temp_output.name
# --- 3. GRADIO INTERFACE ---
# Single video-in / video-out interface wired to the frame-by-frame analyzer.
interface = gr.Interface(
    fn=process_volleyball_video,
    inputs=gr.Video(label="Upload Volleyball Match"),
    outputs=gr.Video(label="AI Analysis (Detections & Mistakes)"),
    # Fixed mojibake: the title contained a mis-decoded volleyball emoji.
    title="🏐 AI Volleyball Performance Lab",
    description="This app uses YOLOv11 and specialized Volleyball-Ref-AI models to detect court lines, ball movement, and player form to identify mistakes.",
    theme="soft",
)

if __name__ == "__main__":
    interface.launch()