# File size: 3,797 Bytes
# dfccaa2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import gradio as gr
import cv2
import numpy as np
from ultralytics import YOLO
from huggingface_hub import hf_hub_download
import tempfile
import os

# --- 1. SET UP MODELS ---
# The volleyball-specific weights (court keypoints, ball detector) are fetched
# from the Davidsv/volley-ref-ai repository on the Hugging Face Hub;
# hf_hub_download returns a local file path that is then loaded with YOLO.
# The pose model is a stock Ultralytics checkpoint requested by name.
#
# NOTE(review): any failure here is only printed. Every *_model name assigned
# after the failing call is then left unbound, so the first video request will
# fail with a NameError inside process_volleyball_video instead of a clear
# model-loading error.
try:
    court_model_path = hf_hub_download(
        repo_id="Davidsv/volley-ref-ai",
        filename="yolo_court_keypoints.pt",
    )
    ball_model_path = hf_hub_download(
        repo_id="Davidsv/volley-ref-ai",
        filename="yolo_volleyball_ball.pt",
    )

    court_model = YOLO(court_model_path)
    ball_model = YOLO(ball_model_path)
    pose_model = YOLO("yolo11n-pose.pt")  # generic human-pose checkpoint
except Exception as load_error:
    print(f"Error loading models: {load_error}")

# --- 2. VIDEO PROCESSING ---
def _net_top_y(court_res, fallback):
    """Return the pixel row of the net top, or *fallback* if it was not found.

    Reads keypoint 6 of the first court detection, which this script treats as
    the net top (NOTE(review): confirm against the volley-ref-ai keypoint
    layout). Falls back when there is no court detection, when fewer than 8
    keypoints are present, or when keypoint 6 was not located (reported at
    y == 0).
    """
    if court_res.keypoints is None:
        return fallback
    xy = court_res.keypoints.xy
    # With zero detections xy is empty and xy[0] would raise IndexError.
    if len(xy) == 0 or len(xy[0]) <= 7:
        return fallback
    net_y = float(xy[0][6][1])
    return int(net_y) if net_y > 0 else fallback


def _annotate_player(frame, person, index, net_y, touch_tolerance=10):
    """Draw the spike / net-touch labels for one detected player onto *frame*.

    *person* holds that player's (x, y) keypoints. Indices 5/6 are the left and
    right shoulders and 9/10 the left and right wrists, following the pose
    model's keypoint order. Keypoints that were not located come back at
    (0, 0), so any point with y <= 0 is treated as missing and never triggers
    a label. *index* stacks multiple warnings vertically.
    """
    if len(person) < 11:
        return

    # Plain floats, so the comparisons below are ordinary Python booleans.
    l_shoulder, r_shoulder, l_wrist, r_wrist = (
        (float(person[k][0]), float(person[k][1])) for k in (5, 6, 9, 10)
    )

    # ANALYSIS 1: "Spike" = a located wrist above its own shoulder (a smaller y
    # is higher in image coordinates). Each side is checked independently, so a
    # missing wrist on one side can neither trigger nor suppress the other.
    left_up = 0 < l_wrist[1] < l_shoulder[1]
    right_up = 0 < r_wrist[1] < r_shoulder[1]
    if left_up or right_up:
        # Anchor on the raised arm's shoulder; it is guaranteed to be located.
        anchor = l_shoulder if left_up else r_shoulder
        cv2.putText(frame, "SPIKE ATTACK", (int(anchor[0]), int(anchor[1] - 20)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)

    # ANALYSIS 2: "Net touch" heuristic = a located wrist within
    # touch_tolerance pixels of the net-top row. This is vertical proximity
    # only; no motion or horizontal position is considered.
    if any(wrist[1] > 0 and abs(wrist[1] - net_y) < touch_tolerance
           for wrist in (l_wrist, r_wrist)):
        cv2.putText(frame, "WARNING: NET TOUCH", (50, 50 + index * 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3)


def _annotate_frame(frame):
    """Run the court, pose and ball models on one frame; return an annotated copy."""
    court_res = court_model(frame, verbose=False)[0]
    pose_res = pose_model(frame, verbose=False)[0]
    ball_res = ball_model(frame, verbose=False)[0]

    annotated = frame.copy()
    # If the net is not located, assume it sits at mid-frame.
    net_y = _net_top_y(court_res, fallback=frame.shape[0] // 2)

    if pose_res.keypoints is not None:
        for index, person in enumerate(pose_res.keypoints.xy):
            _annotate_player(annotated, person, index, net_y)

    # Overlay the raw detections on top of the text labels. Ball detections
    # are drawn too, so the ball model's inference shows up in the output.
    annotated = pose_res.plot(img=annotated)
    annotated = court_res.plot(img=annotated)
    annotated = ball_res.plot(img=annotated)
    return annotated


def process_volleyball_video(video_path):
    """Annotate every frame of a volleyball video and return the new file's path.

    Args:
        video_path: Path to the uploaded video, or a falsy value when nothing
            was uploaded.

    Returns:
        Path to a newly created ``.mp4`` file containing the annotated frames,
        or ``None`` when *video_path* is empty. The file is not deleted here.

    Raises:
        ValueError: If the input video cannot be opened or the output writer
            cannot be created.
    """
    if not video_path:
        return None

    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        cap.release()
        raise ValueError(f"Could not open video: {video_path}")

    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Keep fractional rates such as 29.97: truncating to int would make the
    # output play at the wrong speed. Some containers report 0 (or NaN), so
    # fall back to a sane default; `not (fps > 0)` also catches NaN.
    fps = cap.get(cv2.CAP_PROP_FPS)
    if not (fps > 0):
        fps = 30.0

    # Only a unique path is needed, since cv2.VideoWriter opens the file itself.
    # Close our handle right away instead of leaking it; delete=False keeps the
    # file on disk so it can be returned to the caller.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
        output_path = tmp.name

    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
    try:
        if not out.isOpened():
            raise ValueError(f"Could not create output video: {output_path}")
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            out.write(_annotate_frame(frame))
    finally:
        # Release both even if a model call raises, so handles don't leak and
        # whatever was written is finalized.
        cap.release()
        out.release()

    return output_path

# --- 3. GRADIO INTERFACE ---
# One video in, one video out: the uploaded clip is handed to
# process_volleyball_video, and the video file it returns is displayed.
interface = gr.Interface(
    process_volleyball_video,
    gr.Video(label="Upload Volleyball Match"),
    gr.Video(label="AI Analysis (Detections & Mistakes)"),
    title="🏐 AI Volleyball Performance Lab",
    description="This app uses YOLOv11 and specialized Volleyball-Ref-AI models to detect court lines, ball movement, and player form to identify mistakes.",
    theme="soft",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    interface.launch()