"""Gradio demo for soccer object detection with the RF-DETR SoccerNet model.

At import time this script downloads ``inference.py`` and the model weights
from the Hugging Face Hub, patches the downloaded script to instantiate
``RFDETRLarge`` instead of ``RFDETRBase``, builds the detector, and then
launches a Gradio UI with an image tab, a video tab and an "About" tab.
"""

import contextlib
import functools
import os
import shutil
import sys
import tempfile
import traceback

import cv2
import gradio as gr
import numpy as np
import pandas as pd
from huggingface_hub import hf_hub_download
from PIL import Image, ImageDraw, ImageFont

repo_id = "julianzu9612/RFDETR-Soccernet"

_WEIGHTS_FILENAME = "weights/checkpoint_best_regular.pth"
_FONT_PATH = "/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf"
_FONT_SIZE = 16
_LABEL_OFFSET = 20  # pixels between the top of a box and its label
_DEFAULT_COLOR = (255, 255, 255)
_CLASS_COLORS = {
    'ball': (255, 0, 0),
    'player': (0, 255, 0),
    'referee': (255, 255, 0),
    'goalkeeper': (0, 0, 255),
}
# Used only when the container does not report a usable frame rate.
_FALLBACK_FPS = 30.0


def _patch_inference_source(path):
    """Rewrite ``path`` so that it uses ``RFDETRLarge`` instead of ``RFDETRBase``.

    The replacement is a plain text substitution and is idempotent; the file
    is only rewritten when the substitution actually changed something.

    NOTE(review): ``path`` is whatever ``hf_hub_download`` returned, so this
    edits the downloaded file in place. Confirm that modifying the Hub cache
    is acceptable for this deployment.
    """
    with open(path, 'r', encoding='utf-8') as f:
        original_code = f.read()

    patched_code = original_code.replace(
        'from rfdetr import RFDETRBase',
        'from rfdetr import RFDETRLarge'
    ).replace(
        'self.model = RFDETRBase()',
        'self.model = RFDETRLarge()'
    )

    if patched_code != original_code:
        with open(path, 'w', encoding='utf-8') as f:
            f.write(patched_code)


def _load_detector():
    """Download, patch and instantiate the detector.

    Returns:
        The ``RFDETRSoccerNet`` instance created from the downloaded
        ``inference`` module.

    Raises:
        Exception: anything raised by the download, the patching, the import
            of ``inference`` or the model constructor is propagated.
    """
    # Download inference.py
    print("\nDownloading inference.py...")
    inference_path = hf_hub_download(repo_id=repo_id, filename="inference.py")

    print("\n🔧 Patching inference.py...")
    print(" Changing: RFDETRBase() → RFDETRLarge()")
    _patch_inference_source(inference_path)
    print("✓ Patched inference.py successfully!")

    # Download weights
    print("\nDownloading model weights...")
    weights_path = hf_hub_download(repo_id=repo_id, filename=_WEIGHTS_FILENAME)
    print("✓ Downloaded weights")

    # Make the downloaded module importable.
    cache_dir = os.path.dirname(inference_path)
    if cache_dir not in sys.path:
        sys.path.insert(0, cache_dir)

    # Ensure the weights sit in a "weights" directory next to inference.py.
    weights_dir = os.path.join(cache_dir, "weights")
    os.makedirs(weights_dir, exist_ok=True)
    expected_weights = os.path.join(weights_dir, "checkpoint_best_regular.pth")
    if not os.path.exists(expected_weights):
        shutil.copy(weights_path, expected_weights)
        print(f"✓ Weights copied to: {expected_weights}")

    print("\n" + "=" * 60)
    print("Initializing RF-DETR SoccerNet Model...")
    print("=" * 60)

    # The model is constructed with the download directory as the working
    # directory (presumably so relative weight paths resolve); the previous
    # working directory is restored even if construction fails.
    original_dir = os.getcwd()
    os.chdir(cache_dir)
    try:
        from inference import RFDETRSoccerNet
        model = RFDETRSoccerNet()
    finally:
        os.chdir(original_dir)

    print("\n✅ Model loaded successfully!")
    return model


print("=" * 60)
print("Setting up RF-DETR SoccerNet Model...")
print("=" * 60)

try:
    detector = _load_detector()
except Exception as e:
    print(f"\n❌ Error: {e}")
    traceback.print_exc()
    raise


# Helper functions for Gradio
@functools.lru_cache(maxsize=1)
def _load_label_font():
    """Return the label font, falling back to PIL's default if unavailable."""
    try:
        return ImageFont.truetype(_FONT_PATH, _FONT_SIZE)
    except (OSError, ImportError):
        # Font file missing/unreadable or FreeType support not available.
        return ImageFont.load_default()


def draw_detections_on_image(image, df):
    """Draw bounding boxes and labels on a PIL image.

    Args:
        image: PIL image to annotate. It is modified in place.
        df: DataFrame with one detection per row; the columns ``x1``, ``y1``,
            ``x2``, ``y2``, ``class_name`` and ``confidence`` are read.

    Returns:
        The same ``image`` object, annotated.
    """
    draw = ImageDraw.Draw(image)
    font = _load_label_font()

    columns = (
        df['x1'], df['y1'], df['x2'], df['y2'],
        df['class_name'], df['confidence'],
    )
    for x1, y1, x2, y2, class_name, conf in zip(*columns):
        color = _CLASS_COLORS.get(class_name, _DEFAULT_COLOR)
        draw.rectangle([x1, y1, x2, y2], outline=color, width=3)

        text = f"{class_name}: {conf:.2f}"
        # Keep the label inside the image for boxes touching the top edge.
        label_pos = (x1, max(0, y1 - _LABEL_OFFSET))
        bbox = draw.textbbox(label_pos, text, font=font)
        draw.rectangle(
            [bbox[0] - 2, bbox[1] - 2, bbox[2] + 2, bbox[3] + 2],
            fill=color,
        )
        draw.text(label_pos, text, fill=(0, 0, 0), font=font)

    return image


def process_image_interface(image, confidence_threshold):
    """Run the detector on one image and return the annotated result.

    Args:
        image: Image from the Gradio input (a numpy array, or anything
            ``np.asarray`` accepts), or ``None`` when nothing was uploaded.
        confidence_threshold: Value forwarded to ``detector.process_image``.

    Returns:
        ``(annotated_image, detections)``. On missing input or on any error
        the result is ``(None, empty DataFrame)``; errors are printed.
    """
    if image is None:
        return None, pd.DataFrame()

    temp_path = None
    try:
        # The detector is called with a file path, so write a temporary JPEG.
        # JPEG cannot store an alpha channel, hence the RGB conversion.
        fd, temp_path = tempfile.mkstemp(suffix='.jpg')
        os.close(fd)
        Image.fromarray(np.asarray(image)).convert('RGB').save(temp_path)

        # Process with model
        df = detector.process_image(temp_path, confidence_threshold=confidence_threshold)

        # Draw on a detached copy so the file handle is closed before cleanup.
        with Image.open(temp_path) as saved:
            annotated_img = draw_detections_on_image(saved.copy(), df)

        return annotated_img, df

    except Exception as e:
        print(f"Error processing image: {e}")
        traceback.print_exc()
        return None, pd.DataFrame()

    finally:
        if temp_path is not None:
            with contextlib.suppress(OSError):
                os.remove(temp_path)


def _write_annotated_video(video, df):
    """Copy ``video`` to a temporary .mp4, drawing the detections in ``df``.

    Every frame of the input is written; frames whose index appears in the
    ``frame`` column of ``df`` get their detections drawn first.

    Returns:
        Path of the written video file.

    Raises:
        OSError: if the input video cannot be opened.
    """
    cap = cv2.VideoCapture(video)
    out = None
    output_path = None
    finished = False
    try:
        if not cap.isOpened():
            raise OSError(f"Could not open video: {video}")

        # Keep the fractional frame rate (e.g. 29.97) so playback speed is
        # preserved; fall back when the container reports 0 or NaN.
        fps = cap.get(cv2.CAP_PROP_FPS)
        if not fps > 0:
            fps = _FALLBACK_FPS
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        fd, output_path = tempfile.mkstemp(suffix='.mp4')
        os.close(fd)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

        # Group once instead of filtering the whole frame table per frame.
        detections_by_frame = {}
        if not df.empty and 'frame' in df.columns:
            detections_by_frame = dict(iter(df.groupby('frame')))

        frame_num = 0
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            frame_detections = detections_by_frame.get(frame_num)
            if frame_detections is not None:
                # OpenCV frames are BGR; PIL drawing is done in RGB.
                rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_img = Image.fromarray(rgb_frame)
                annotated_pil = draw_detections_on_image(pil_img, frame_detections)
                frame = cv2.cvtColor(np.array(annotated_pil), cv2.COLOR_RGB2BGR)

            out.write(frame)
            frame_num += 1

        finished = True
        return output_path

    finally:
        cap.release()
        if out is not None:
            out.release()
        # Do not leave a partial output file behind on failure.
        if not finished and output_path is not None:
            with contextlib.suppress(OSError):
                os.remove(output_path)


def process_video_interface(video, confidence_threshold, frame_skip, max_frames):
    """Run the detector on a video and return an annotated copy.

    Args:
        video: Path of the uploaded video, or ``None``.
        confidence_threshold: Value forwarded to ``detector.process_video``.
        frame_skip: Forwarded to the detector as an ``int``.
        max_frames: Forwarded to the detector as an ``int``; ``0``, a negative
            value or ``None`` means no limit (``None`` is passed).

    Returns:
        ``(output_video_path, detections)``. On missing input or on any error
        the result is ``(None, empty DataFrame)``; errors are printed.
    """
    if video is None:
        return None, pd.DataFrame()

    try:
        has_limit = max_frames is not None and max_frames > 0
        max_frames_val = int(max_frames) if has_limit else None

        # Process video
        print(f"Processing video with confidence={confidence_threshold}, frame_skip={frame_skip}, max_frames={max_frames_val}")
        df = detector.process_video(
            video,
            confidence_threshold=confidence_threshold,
            frame_skip=int(frame_skip),
            max_frames=max_frames_val
        )

        # Create annotated video
        output_path = _write_annotated_video(video, df)
        return output_path, df

    except Exception as e:
        print(f"Error processing video: {e}")
        traceback.print_exc()
        return None, pd.DataFrame()


# Create Gradio interface
with gr.Blocks(title="⚽ Soccer Object Detection", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # ⚽ Soccer Object Detection with RF-DETR

    Professional-grade object detection for soccer videos using RF-DETR-Large model.

    ### Model: [julianzu9612/RFDETR-Soccernet](https://huggingface.co/julianzu9612/RFDETR-Soccernet)
    - **Architecture**: RF-DETR-Large (128M parameters)
    - **Performance**: 85.7% mAP@50, 49.8% mAP
    - **Dataset**: SoccerNet-Tracking 2023 (42,750 images)
    - **Classes**: Ball, Player, Referee, Goalkeeper
    """)

    with gr.Tab("📸 Image Detection"):
        gr.Markdown("### Upload a soccer image to detect objects")

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Upload Soccer Image", type="numpy")
                image_confidence = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="Confidence Threshold",
                    info="Lower values detect more objects but may include false positives"
                )
                image_button = gr.Button("🔍 Detect Objects", variant="primary", size="lg")

            with gr.Column():
                image_output = gr.Image(label="Detected Objects")
                image_detections = gr.Dataframe(
                    label="Detection Results",
                    wrap=True,
                    interactive=False
                )

        image_button.click(
            fn=process_image_interface,
            inputs=[image_input, image_confidence],
            outputs=[image_output, image_detections]
        )

        gr.Examples(
            examples=[],
            inputs=image_input,
            label="Example Images (Upload your own!)"
        )

    with gr.Tab("🎥 Video Detection"):
        gr.Markdown("### Upload a soccer video to track objects frame by frame")

        with gr.Row():
            with gr.Column():
                video_input = gr.Video(label="Upload Soccer Video")
                video_confidence = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.5,
                    step=0.05,
                    label="Confidence Threshold"
                )
                video_frame_skip = gr.Slider(
                    minimum=1,
                    maximum=10,
                    value=5,
                    step=1,
                    label="Frame Skip",
                    info="Process every Nth frame (higher = faster but less detections)"
                )
                video_max_frames = gr.Number(
                    value=300,
                    label="Max Frames to Process",
                    info="Set to 0 to process entire video (300 frames ≈ 10 seconds at 30 FPS)"
                )

                gr.Markdown("""
                #### ⚡ Performance Tips:
                - **CPU**: 2-3 FPS (slow) - Use frame_skip=5 and limit frames
                - **GPU**: 12-30 FPS (fast) - Can process full videos
                - **Quick test**: Use 300 frames with frame_skip=5
                """)

                video_button = gr.Button("🎬 Process Video", variant="primary", size="lg")

            with gr.Column():
                video_output = gr.Video(label="Annotated Video")
                video_detections = gr.Dataframe(
                    label="Detection Results",
                    wrap=True,
                    interactive=False
                )

        video_button.click(
            fn=process_video_interface,
            inputs=[video_input, video_confidence, video_frame_skip, video_max_frames],
            outputs=[video_output, video_detections]
        )

    with gr.Tab("ℹ️ About"):
        gr.Markdown("""
        ## About This Model

        ### 🎯 Detected Classes

        | Class | Color | Precision | Description |
        |-------|-------|-----------|-------------|
        | 🔴 Ball | Red | 78.5% | Soccer ball detection |
        | 🟢 Player | Green | 91.3% | Field players from both teams |
        | 🟡 Referee | Yellow | 85.2% | Match officials |
        | 🔵 Goalkeeper | Blue | 88.9% | Specialized goalkeeper detection |

        ### 📊 Model Performance
        - **mAP@50**: 85.7%
        - **mAP**: 49.8%
        - **mAP@75**: 52.0%
        - **Parameters**: 128M
        - **Training Time**: ~14 hours on NVIDIA A100 40GB

        ### 🎓 Training Details
        - **Dataset**: SoccerNet-Tracking 2023
        - **Images**: 42,750 annotated images
        - **Source**: Professional soccer broadcasts
        - **Input Resolution**: 1280x1280 pixels
        - **Optimizer**: AdamW (lr=1e-4)

        ### 💡 Best Practices

        1. **Confidence Threshold**:
           - Use 0.5 for general detection
           - Use 0.7+ for high-precision applications

        2. **Video Quality**:
           - Works best on 720p+ broadcast footage
           - Standard broadcast camera angles preferred

        3. **Frame Processing**:
           - frame_skip=1: Every frame (best accuracy, slow)
           - frame_skip=5: Every 5th frame (good balance)
           - frame_skip=10: Every 10th frame (fast, lower accuracy)

        ### 🚨 Limitations
        - Optimized for professional broadcast footage
        - May have reduced accuracy in poor lighting
        - Small balls may be missed when heavily occluded
        - Camera angle dependency

        ### 📚 Use Cases
        - **Sports Analytics**: Player tracking, formation analysis
        - **Broadcast Enhancement**: Automatic highlighting, statistics overlay
        - **Research**: Tactical analysis, computer vision benchmarking
        - **Video Analytics**: Automated video processing pipelines

        ### 🔗 Links
        - [Model on Hugging Face](https://huggingface.co/julianzu9612/RFDETR-Soccernet)
        - [SoccerNet Dataset](https://www.soccer-net.org/)
        - [RF-DETR Paper](https://arxiv.org/abs/2304.08069)

        ### 📄 Citation
        ```bibtex
        @misc{rfdetr-soccernet-2025,
          title={RF-DETR SoccerNet: High-Performance Soccer Object Detection},
          author={Computer Vision Research Team},
          year={2025},
          publisher={Hugging Face},
          url={https://huggingface.co/julianzu9612/rf-detr-soccernet}
        }
        ```

        ---

        **License**: Apache 2.0
        """)

print("\n" + "=" * 60)
print("🚀 Launching Gradio Interface...")
print("=" * 60)

demo.launch()