| import gradio as gr |
| import cv2 |
| import numpy as np |
| from ultralytics import YOLO |
| from huggingface_hub import hf_hub_download |
| import tempfile |
| import os |
|
|
| |
| |
# Load all three models once at import time.
# Pre-bind to None so a partial/failed download leaves the names defined:
# downstream code can then check "model is None" instead of hitting a
# confusing NameError at inference time.
court_model = ball_model = pose_model = None
try:
    # Specialized volleyball models from the Hugging Face Hub.
    court_model_path = hf_hub_download(repo_id="Davidsv/volley-ref-ai", filename="yolo_court_keypoints.pt")
    ball_model_path = hf_hub_download(repo_id="Davidsv/volley-ref-ai", filename="yolo_volleyball_ball.pt")

    court_model = YOLO(court_model_path)
    ball_model = YOLO(ball_model_path)
    # Generic pretrained pose model (downloaded by ultralytics on first use).
    pose_model = YOLO("yolo11n-pose.pt")
except Exception as e:
    # Deliberate best-effort: keep the app importable (and the UI startable)
    # even when the Hub is unreachable; log the cause for the operator.
    print(f"Error loading models: {e}")
|
|
def process_volleyball_video(video_path):
    """Annotate a volleyball video with court/pose/ball detections and heuristics.

    Runs the court-keypoint, player-pose and ball models on every frame,
    overlays their plots, and flags two heuristic events:
    spike attacks (a wrist above its shoulder) and possible net touches
    (a wrist within 10 px of the estimated net line).

    Args:
        video_path: Path to the input video file, or falsy for no input.

    Returns:
        Path to the annotated .mp4 temp file, or None when no input was given.
    """
    if not video_path:
        return None

    cap = cv2.VideoCapture(video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Some containers report fps=0, which would yield a broken VideoWriter.
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30

    # cv2 writes by path; close the handle immediately so the file isn't
    # held open (fd leak, and sharing violation on Windows). delete=False
    # keeps the file around for Gradio to serve.
    temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
    temp_output.close()
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(temp_output.name, fourcc, fps, (width, height))

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            court_res = court_model(frame, verbose=False)[0]
            pose_res = pose_model(frame, verbose=False)[0]
            # Ball inference is run per frame; its result is not yet overlaid.
            ball_res = ball_model(frame, verbose=False)[0]

            annotated_frame = frame.copy()

            # Net line estimate: keypoint index 6 of the court model
            # (assumed to lie on the net tape — TODO confirm against the
            # Davidsv/volley-ref-ai keypoint layout); fall back to mid-frame.
            net_y = height // 2
            if court_res.keypoints is not None and len(court_res.keypoints.xy[0]) > 7:
                net_y = int(court_res.keypoints.xy[0][6][1])

            if pose_res.keypoints is not None:
                for i, person in enumerate(pose_res.keypoints.xy):
                    if len(person) < 11:
                        continue

                    # COCO keypoint layout: 5/6 = shoulders, 9/10 = wrists.
                    l_shoulder, r_shoulder = person[5], person[6]
                    l_wrist, r_wrist = person[9], person[10]

                    # Undetected keypoints come back as (0, 0); gate each
                    # wrist independently so a missing left wrist no longer
                    # suppresses a legitimate right-hand spike (and an
                    # undetected wrist can no longer fire a false alarm).
                    l_visible = l_wrist[1] > 0
                    r_visible = r_wrist[1] > 0

                    # Spike heuristic: a visible wrist raised above its shoulder.
                    if (l_visible and l_wrist[1] < l_shoulder[1]) or \
                       (r_visible and r_wrist[1] < r_shoulder[1]):
                        cv2.putText(annotated_frame, "SPIKE ATTACK", (int(l_shoulder[0]), int(l_shoulder[1]-20)),
                                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 255, 255), 2)

                    # Net-touch heuristic: a visible wrist within 10 px of the
                    # net line. Offset the banner per player so they stack.
                    if (l_visible and abs(l_wrist[1] - net_y) < 10) or \
                       (r_visible and abs(r_wrist[1] - net_y) < 10):
                        cv2.putText(annotated_frame, "WARNING: NET TOUCH", (50, 50 + (i*30)),
                                    cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 3)

            # Overlay the models' own rendering on top of our banners.
            annotated_frame = pose_res.plot(img=annotated_frame)
            annotated_frame = court_res.plot(img=annotated_frame)

            out.write(annotated_frame)
    finally:
        # Always release, even if model inference raises mid-video.
        cap.release()
        out.release()

    return temp_output.name
|
|
| |
# Gradio UI: one video in, one annotated video out.
# Title emoji was mojibake ("๐", a stray byte of a broken multi-byte
# character); restored to the volleyball emoji.
interface = gr.Interface(
    fn=process_volleyball_video,
    inputs=gr.Video(label="Upload Volleyball Match"),
    outputs=gr.Video(label="AI Analysis (Detections & Mistakes)"),
    title="🏐 AI Volleyball Performance Lab",
    description="This app uses YOLOv11 and specialized Volleyball-Ref-AI models to detect court lines, ball movement, and player form to identify mistakes.",
    theme="soft"
)
|
|
# Launch the Gradio server only when run as a script (not on import).
if __name__ == "__main__":
    interface.launch()