#!/usr/bin/env python3
"""
All-in-one script for training YOLOv5 player detection model.
Downloads dataset from Roboflow and trains on GPU.
"""

import os
import sys
import subprocess
from pathlib import Path

# Configuration
# NOTE: relative path, so it is resolved against the current working
# directory of the process, not against this script's location.
VENV_YOLO = "../.venv/bin/yolo"
# SECURITY: the literal fallback below is a credential committed to source
# control. Prefer exporting ROBOFLOW_API_KEY in the environment; the literal
# is kept only so existing invocations keep working and should be rotated
# and removed.
ROBOFLOW_API_KEY = os.environ.get("ROBOFLOW_API_KEY", "ZzD21wz5oTPdE0fhb04C")
WORKSPACE = "tomatoes-iicln"
PROJECT = "nbl"
DATASET_VERSION = 13
DATASET_FORMAT = "yolov5"

# Training parameters
MODEL = "yolov5l6u.pt"  # Large model
EPOCHS = 100
IMG_SIZE = 640
BATCH_SIZE = 32  # Players are large; 32 is safe for 4080
PLOTS = True


def setup_venv():
    """Verify venv and YOLO are available.

    Exits the process with status 1 when the YOLO executable configured in
    ``VENV_YOLO`` does not exist.
    """
    if not os.path.exists(VENV_YOLO):
        print(f"✗ YOLO not found at {VENV_YOLO}")
        sys.exit(1)
    print(f"✓ YOLO found: {VENV_YOLO}")


def download_dataset():
    """Download dataset from Roboflow.

    Returns:
        The ``location`` attribute of the dataset object returned by the
        Roboflow client for the configured workspace/project/version.

    Exits the process with status 1 when the ``roboflow`` package is not
    installed or when the download fails for any reason.
    """
    print("\n" + "=" * 60)
    print("📥 Downloading Player Detection Dataset from Roboflow")
    print("=" * 60)

    try:
        # Imported lazily so a missing package yields an install hint
        # instead of an import-time traceback.
        from roboflow import Roboflow
    except ImportError:
        print("✗ Roboflow not installed. Run: pip install roboflow")
        sys.exit(1)

    print(f" Workspace: {WORKSPACE}")
    print(f" Project: {PROJECT}")
    print(f" Version: {DATASET_VERSION}")
    print(f" Format: {DATASET_FORMAT}")

    try:
        rf = Roboflow(api_key=ROBOFLOW_API_KEY)
        project = rf.workspace(WORKSPACE).project(PROJECT)
        version = project.version(DATASET_VERSION)
        dataset = version.download(DATASET_FORMAT)
    except Exception as e:
        # Top-level boundary of a CLI script: report any client/network
        # failure and stop rather than dumping a traceback.
        print(f"✗ Failed to download dataset: {e}")
        sys.exit(1)

    print(f"\n✓ Dataset downloaded to: {dataset.location}")
    return dataset.location


def resolve_dataset_dir(downloaded_location, search_root=None):
    """Pick the directory that actually holds the downloaded dataset.

    The location reported by the download is trusted whenever it is an
    existing directory. Only when it is not does the function fall back to
    ``<search_root>/<PROJECT upper-cased>-<DATASET_VERSION>``. Deriving the
    folder name from ``DATASET_VERSION`` keeps a stale folder from an older
    version (for example ``NBL-6``) from being picked up by mistake.

    Args:
        downloaded_location: Path returned by ``download_dataset``.
        search_root: Directory searched for the fallback folder; defaults
            to the current working directory.

    Returns:
        The resolved dataset directory as a string, or
        ``downloaded_location`` unchanged when nothing better is found (the
        subsequent verification step then reports the problem).
    """
    if downloaded_location and os.path.isdir(downloaded_location):
        return downloaded_location

    root = Path(search_root) if search_root is not None else Path.cwd()
    candidate = root / f"{PROJECT.upper()}-{DATASET_VERSION}"
    if candidate.is_dir():
        return str(candidate)
    return downloaded_location


def _count_entries(directory):
    """Return the number of entries in *directory*, or 0 if it is not a directory."""
    return len(os.listdir(directory)) if os.path.isdir(directory) else 0


def verify_dataset(dataset_dir):
    """Verify dataset structure.

    Checks that ``dataset_dir`` exists, contains ``data.yaml`` and has at
    least one entry in both ``train/images`` and ``valid/images``. The
    ``test/images`` split is reported but not required.

    Args:
        dataset_dir: Root directory of the extracted dataset.

    Returns:
        Path to the dataset's ``data.yaml``.

    Exits the process with status 1 when any required piece is missing.
    """
    print("\n" + "=" * 60)
    print("✓ Verifying Dataset Structure")
    print("=" * 60)

    if not os.path.exists(dataset_dir):
        print(f"✗ Dataset directory not found: {dataset_dir}")
        sys.exit(1)

    data_yaml = os.path.join(dataset_dir, "data.yaml")
    if not os.path.exists(data_yaml):
        print(f"✗ data.yaml not found: {data_yaml}")
        sys.exit(1)

    train_count = _count_entries(os.path.join(dataset_dir, "train", "images"))
    valid_count = _count_entries(os.path.join(dataset_dir, "valid", "images"))
    test_count = _count_entries(os.path.join(dataset_dir, "test", "images"))

    print(f" Dataset: {dataset_dir}")
    print(f" ├── train/images: {train_count} files")
    print(f" ├── valid/images: {valid_count} files")
    print(f" └── test/images: {test_count} files")

    if train_count == 0 or valid_count == 0:
        print("✗ Dataset is missing training or validation images")
        sys.exit(1)

    print("\n✓ Dataset structure verified")
    return data_yaml


def build_train_command(dataset_dir):
    """Build the YOLO CLI argument list for training on *dataset_dir*.

    Args:
        dataset_dir: Root directory of the dataset containing ``data.yaml``.

    Returns:
        A list of strings suitable for ``subprocess.run`` (no shell).
    """
    return [
        VENV_YOLO,
        "task=detect",
        "mode=train",
        f"model={MODEL}",
        f"data={os.path.join(dataset_dir, 'data.yaml')}",
        f"epochs={EPOCHS}",
        f"imgsz={IMG_SIZE}",
        f"batch={BATCH_SIZE}",
        f"plots={PLOTS}",
        "device=0",
        "workers=8",
        "amp=True",
    ]


def train_model(dataset_dir):
    """Start YOLO training.

    Runs the YOLO executable as a child process and waits for it to finish.

    Args:
        dataset_dir: Root directory of the dataset containing ``data.yaml``.

    Exits the process with status 1 when the executable cannot be launched
    or returns a non-zero exit code.
    """
    print("\n" + "=" * 60)
    print("🚀 Starting YOLOv5 Player Detection Training")
    print("=" * 60)
    print(f" Model: {MODEL}")
    print(f" Dataset: {dataset_dir}")
    print(f" Epochs: {EPOCHS}")
    print(f" Image Size: {IMG_SIZE}")
    print(f" Batch Size: {BATCH_SIZE}")
    print(f" Plots: {PLOTS}")
    print("=" * 60 + "\n")

    cmd = build_train_command(dataset_dir)

    try:
        # check=False: the return code is inspected below so a failure
        # produces a friendly message instead of CalledProcessError.
        result = subprocess.run(cmd, check=False)
    except (OSError, subprocess.SubprocessError) as e:
        print(f"✗ Failed to start training: {e}")
        sys.exit(1)

    if result.returncode != 0:
        print("\n✗ Training failed with return code:", result.returncode)
        sys.exit(1)

    print("\n✓ Training completed successfully!")
    print(" Results saved to: runs/detect/trainX (check for highest X)")


def main():
    """Main training pipeline.

    Runs, in order: environment check, dataset download, dataset
    verification and training. Each step exits the process on failure.
    """
    print("\n" + "█" * 60)
    print(" YOLOv5 PLAYER DETECTION TRAINING")
    print("█" * 60 + "\n")

    # Step 1: Setup
    print("Step 1: Verifying Virtual Environment")
    setup_venv()

    # Step 2: Download dataset and locate the extracted folder
    print("\nStep 2: Downloading Dataset")
    downloaded_location = download_dataset()
    dataset_location = resolve_dataset_dir(downloaded_location)
    if dataset_location != downloaded_location:
        print(f" Found dataset at: {dataset_location}")

    # Step 3: Verify dataset
    print("\nStep 3: Verifying Dataset")
    verify_dataset(dataset_location)

    # Step 4: Train
    print("\nStep 4: Starting Training")
    train_model(dataset_location)

    print("\n" + "█" * 60)
    print(" TRAINING PIPELINE COMPLETE")
    print("█" * 60 + "\n")


if __name__ == "__main__":
    main()