"""🏐 AI Volleyball Performance Lab.

Gradio app that runs three YOLO models over an uploaded volleyball video —
a court-keypoint model, a ball detector, and a general human-pose model —
and writes out an annotated copy flagging spikes and possible net touches.
"""
import os
import tempfile

import cv2
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

# --- 1. SET UP MODELS ---
# Downloading specialized volleyball models from Davidsv/volley-ref-ai.
# Keep None sentinels so a download failure at startup does not leave the
# names undefined (the original later crashed with NameError in that case).
court_model = None
ball_model = None
pose_model = None
try:
    court_model_path = hf_hub_download(
        repo_id="Davidsv/volley-ref-ai", filename="yolo_court_keypoints.pt"
    )
    ball_model_path = hf_hub_download(
        repo_id="Davidsv/volley-ref-ai", filename="yolo_volleyball_ball.pt"
    )
    court_model = YOLO(court_model_path)
    ball_model = YOLO(ball_model_path)
    pose_model = YOLO("yolo11n-pose.pt")  # General human pose model
except Exception as e:  # startup download/load failure is non-fatal
    print(f"Error loading models: {e}")


def process_volleyball_video(video_path):
    """Annotate a volleyball match video with detections and mistakes.

    Parameters
    ----------
    video_path : str | None
        Path to the uploaded video file (supplied by Gradio).

    Returns
    -------
    str | None
        Path to the annotated temporary ``.mp4`` file, or ``None`` when no
        video was supplied or the models failed to load at startup.
    """
    if not video_path:
        return None
    if court_model is None or ball_model is None or pose_model is None:
        return None  # models failed to load at startup; nothing to run

    cap = cv2.VideoCapture(video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    # Some containers report 0 fps, which would produce a broken output file.
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30

    # Create a temporary file to save the processed video. Close the handle
    # immediately so cv2.VideoWriter can open the path (required on Windows);
    # delete=False keeps the file around for Gradio to serve.
    temp_output = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
    temp_output.close()
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(temp_output.name, fourcc, fps, (width, height))

    try:
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Run detections
            court_res = court_model(frame, verbose=False)[0]
            pose_res = pose_model(frame, verbose=False)[0]
            ball_res = ball_model(frame, verbose=False)[0]

            annotated_frame = frame.copy()

            # Logic: find the net height using court keypoints. Index 6 is
            # assumed to be the net top in this model's keypoint layout —
            # TODO confirm against the volley-ref-ai repo's keypoint map.
            net_y = height // 2  # default fallback when keypoints are missing
            if court_res.keypoints is not None and len(court_res.keypoints.xy[0]) > 7:
                net_y = int(court_res.keypoints.xy[0][6][1])  # Y-coord of net top

            # Process players. COCO keypoint indices:
            # 5=L_Shoulder, 6=R_Shoulder, 9=L_Wrist, 10=R_Wrist.
            if pose_res.keypoints is not None:
                for i, person in enumerate(pose_res.keypoints.xy):
                    if len(person) < 11:
                        continue
                    l_shoulder, r_shoulder = person[5], person[6]
                    l_wrist, r_wrist = person[9], person[10]
                    # Undetected keypoints come back as (0, 0); gate each
                    # wrist on its own visibility (the original gated both
                    # sides on the left wrist only, causing missed spikes
                    # and false positives).
                    l_visible = l_wrist[1] > 0
                    r_visible = r_wrist[1] > 0

                    # ANALYSIS 1: detection of a "spike" (hand above shoulder)
                    spike = (l_visible and l_wrist[1] < l_shoulder[1]) or (
                        r_visible and r_wrist[1] < r_shoulder[1]
                    )
                    if spike:
                        cv2.putText(
                            annotated_frame,
                            "SPIKE ATTACK",
                            (int(l_shoulder[0]), int(l_shoulder[1] - 20)),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            0.7,
                            (0, 255, 255),
                            2,
                        )

                    # ANALYSIS 2: net touch mistake — a visible wrist within
                    # 10px of the net's y-coordinate.
                    net_touch = (l_visible and abs(l_wrist[1] - net_y) < 10) or (
                        r_visible and abs(r_wrist[1] - net_y) < 10
                    )
                    if net_touch:
                        cv2.putText(
                            annotated_frame,
                            "WARNING: NET TOUCH",
                            (50, 50 + (i * 30)),
                            cv2.FONT_HERSHEY_SIMPLEX,
                            1,
                            (0, 0, 255),
                            3,
                        )

            # Draw detections (ball_res was previously computed but never
            # drawn, despite the app advertising ball detection).
            annotated_frame = pose_res.plot(img=annotated_frame)
            annotated_frame = court_res.plot(img=annotated_frame)
            annotated_frame = ball_res.plot(img=annotated_frame)

            out.write(annotated_frame)
    finally:
        # Always release the capture/writer, even if inference raises.
        cap.release()
        out.release()

    return temp_output.name


# --- 3. GRADIO INTERFACE ---
interface = gr.Interface(
    fn=process_volleyball_video,
    inputs=gr.Video(label="Upload Volleyball Match"),
    outputs=gr.Video(label="AI Analysis (Detections & Mistakes)"),
    title="🏐 AI Volleyball Performance Lab",
    description="This app uses YOLOv11 and specialized Volleyball-Ref-AI models to detect court lines, ball movement, and player form to identify mistakes.",
    theme="soft",
)

if __name__ == "__main__":
    interface.launch()