strawberryPicker / scripts /setup_training.py
Gareth
Initial clean commit for Hugging Face
efb1801
#!/usr/bin/env python3
"""
Setup script for Strawberry Picker ML Training Environment
This script installs dependencies and validates the training setup.
"""
import os
import sys
import subprocess
import argparse
from pathlib import Path
def run_command(cmd, description=""):
    """Run a command, echo its captured output, and report success.

    Args:
        cmd: Either a command string, which is executed through the shell,
            or a sequence of program arguments, which is executed directly
            without a shell (so no quoting of individual arguments is needed).
        description: Optional label shown in the banner instead of the
            command itself.

    Returns:
        True if the command exited with status 0. False if it exited with a
        non-zero status, if the argument sequence was empty, or if the
        program could not be started at all.
    """
    # Only plain strings go through the shell; argument sequences bypass it.
    use_shell = isinstance(cmd, str)
    if not use_shell:
        cmd = [str(part) for part in cmd]
    shown = cmd if use_shell else " ".join(cmd)

    print(f"\n{'='*60}")
    print(f"Running: {description or shown}")
    print(f"{'='*60}\n")

    if not use_shell and not cmd:
        print("ERROR: Empty command")
        return False

    try:
        result = subprocess.run(
            cmd, shell=use_shell, check=True, capture_output=True, text=True
        )
    except subprocess.CalledProcessError as e:
        print(f"ERROR: Command failed with return code {e.returncode}")
        if e.stdout:
            print(f"STDOUT: {e.stdout}")
        if e.stderr:
            print(f"STDERR: {e.stderr}")
        return False
    except (OSError, ValueError) as e:
        # Raised when the program cannot be launched (for example, the
        # executable does not exist) or an argument is malformed.
        print(f"ERROR: Command could not be started: {e}")
        return False

    if result.stdout:
        print(result.stdout)
    return True
def check_python_version(min_version=(3, 8)):
    """Check that the running interpreter is at least *min_version*.

    Args:
        min_version: Minimum acceptable version as a tuple of integers,
            e.g. ``(3, 8)``. Defaults to Python 3.8.

    Returns:
        True if ``sys.version_info`` is at least *min_version*, otherwise
        False (after printing an error message).
    """
    print("Checking Python version...")
    version = sys.version_info
    required = tuple(min_version)
    # Tuple comparison handles major/minor (and any further components)
    # in a single lexicographic check.
    if version < required:
        required_text = ".".join(str(part) for part in required)
        print(f"ERROR: Python {required_text}+ required. Found {version.major}.{version.minor}")
        return False
    print(f"✓ Python {version.major}.{version.minor}.{version.micro}")
    return True
def check_pip():
    """Check that pip is available for the interpreter running this script.

    pip is invoked as ``"<sys.executable>" -m pip`` so the check targets the
    current interpreter rather than whichever ``pip`` happens to be first on
    PATH. If ``sys.executable`` is empty, it falls back to a bare ``pip``.

    Returns:
        The boolean result of ``run_command`` for ``pip --version``.
    """
    print("Checking pip availability...")
    # Double quotes keep an interpreter path containing spaces intact when
    # the command string is passed through the shell.
    pip = f'"{sys.executable}" -m pip' if sys.executable else "pip"
    return run_command(f"{pip} --version", "Check pip version")
def install_requirements():
    """Install the packages listed in the requirements.txt beside this script.

    pip is invoked as ``"<sys.executable>" -m pip`` so packages are installed
    into the interpreter running this script, and the requirements path is
    quoted so a checkout directory containing spaces does not split the
    command line.

    Returns:
        True if pip was upgraded and the requirements were installed.
        False if requirements.txt is missing or either pip command fails.
    """
    print("Installing Python dependencies...")
    requirements_file = Path(__file__).parent / "requirements.txt"
    if not requirements_file.exists():
        print(f"ERROR: requirements.txt not found at {requirements_file}")
        return False

    # Fall back to a bare "pip" only if the interpreter path is unknown.
    pip = f'"{sys.executable}" -m pip' if sys.executable else "pip"

    # Upgrade pip first
    if not run_command(f"{pip} install --upgrade pip", "Upgrade pip"):
        return False

    # Install requirements
    return run_command(f'{pip} install -r "{requirements_file}"', "Install requirements")
def check_ultralytics():
    """Report whether ``YOLO`` can be imported from the ultralytics package.

    Returns:
        True if the import succeeds, False if it raises ImportError.
    """
    print("Checking ultralytics installation...")
    try:
        # Imported purely to prove the package loads; the name is not used.
        from ultralytics import YOLO
    except ImportError as import_error:
        print(f"ERROR: Failed to import ultralytics: {import_error}")
        return False
    else:
        print("✓ ultralytics installed successfully")
        return True
def check_torch():
    """Report the installed PyTorch version and whether a CUDA GPU is usable.

    A missing GPU is only a warning; the check still passes.

    Returns:
        True if torch imports, False if importing it raises ImportError.
    """
    print("Checking PyTorch installation...")
    try:
        import torch

        print(f"✓ PyTorch version: {torch.__version__}")
        has_gpu = torch.cuda.is_available()
        if not has_gpu:
            print("⚠ GPU not available, will use CPU for training")
            return True

        # Report the first CUDA device and the CUDA version torch exposes.
        print(f"✓ GPU available: {torch.cuda.get_device_name(0)}")
        print(f"✓ CUDA version: {torch.version.cuda}")
        return True
    except ImportError as import_error:
        print(f"ERROR: Failed to import torch: {import_error}")
        return False
def _count_images(directory):
    """Count .jpg/.jpeg/.png files (any letter case) directly in *directory*."""
    image_suffixes = {".jpg", ".jpeg", ".png"}
    return sum(
        1
        for entry in directory.iterdir()
        if entry.is_file() and entry.suffix.lower() in image_suffixes
    )


def validate_dataset():
    """Validate the dataset layout described by data.yaml.

    Looks for ``model/dataset/straw-detect.v1-straw-detect.yolov8/data.yaml``
    beside this script, loads it, prints the ``nc`` and ``names`` entries, and
    reports how many images are in the ``train`` and ``val`` directories. A
    missing train/val directory is only a warning.

    Returns:
        True if data.yaml exists, parses to a mapping, and contains the
        ``nc``, ``names``, ``train`` and ``val`` keys. False otherwise,
        including when loading or inspecting the dataset raises.
    """
    print("Validating dataset structure...")
    dataset_path = Path(__file__).parent / "model" / "dataset" / "straw-detect.v1-straw-detect.yolov8"
    data_yaml = dataset_path / "data.yaml"
    if not data_yaml.exists():
        print(f"ERROR: data.yaml not found at {data_yaml}")
        print("Please ensure your dataset is in the correct location")
        return False
    try:
        import yaml

        with open(data_yaml, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)

        # An empty or non-mapping YAML file would otherwise surface as an
        # opaque TypeError/KeyError message.
        if not isinstance(data, dict):
            print(f"ERROR: Failed to validate dataset: {data_yaml} does not contain a mapping")
            return False
        missing = [key for key in ("nc", "names", "train", "val") if key not in data]
        if missing:
            print(f"ERROR: Failed to validate dataset: missing keys in data.yaml: {', '.join(missing)}")
            return False

        print("✓ Dataset configuration loaded")
        print(f" Classes: {data['nc']}")
        print(f" Names: {data['names']}")

        # Check training and validation images
        for label, key in (("Training", "train"), ("Validation", "val")):
            split_path = dataset_path / data[key]
            if split_path.is_dir():
                print(f" {label} images: {_count_images(split_path)}")
            else:
                print(f"⚠ {label} path not found: {split_path}")
        return True
    except Exception as e:
        # Broad on purpose: a missing PyYAML, a parse error, or an unexpected
        # value type should all be reported as a failed validation.
        print(f"ERROR: Failed to validate dataset: {e}")
        return False
def create_directories():
    """Create the model output folders next to this script.

    Ensures ``model/weights``, ``model/results`` and ``model/exports`` exist
    (existing directories are left alone) and prints each path.

    Returns:
        Always True.
    """
    print("Creating project directories...")
    model_root = Path(__file__).parent / "model"
    for leaf in ("weights", "results", "exports"):
        target = model_root / leaf
        target.mkdir(parents=True, exist_ok=True)
        print(f"✓ Created: {target}")
    return True
def _confirm(prompt):
    """Ask a yes/no question and return True only for an explicit 'y'.

    A closed or exhausted stdin (EOFError, as in CI or other non-interactive
    runs) is treated as "no" instead of crashing with a traceback.
    """
    try:
        response = input(prompt)
    except EOFError:
        print()
        return False
    return response.strip().lower() == 'y'


def main():
    """Run the setup steps in order, exiting with status 1 on a hard failure.

    Command-line flags:
        --skip-install   Skip package installation.
        --validate-only  Only validate; no installation and no directory
                         creation.

    Steps: check the Python version, check pip, install requirements (unless
    skipped), check ultralytics, check PyTorch, validate the dataset, and
    create the project directories. Failed installation or dataset validation
    prompts the user whether to continue; anything but 'y' exits.
    """
    parser = argparse.ArgumentParser(description='Setup training environment for strawberry detection')
    parser.add_argument('--skip-install', action='store_true', help='Skip package installation')
    parser.add_argument('--validate-only', action='store_true', help='Only validate setup without installing')
    args = parser.parse_args()

    print("="*60)
    print("Strawberry Picker ML Training Environment Setup")
    print("="*60)

    # Step 1: Check Python version
    if not check_python_version():
        sys.exit(1)

    # Step 2: Check pip
    if not check_pip():
        sys.exit(1)

    # Step 3: Install requirements (unless skipped)
    if not args.skip_install and not args.validate_only:
        if not install_requirements():
            print("\n⚠ Installation failed. Please check the errors above.")
            if not _confirm("Continue with validation anyway? (y/n): "):
                sys.exit(1)

    # Step 4: Check ultralytics
    if not check_ultralytics():
        sys.exit(1)

    # Step 5: Check PyTorch
    if not check_torch():
        sys.exit(1)

    # Step 6: Validate dataset
    # NOTE(review): in --validate-only mode a failed dataset validation only
    # prints a warning and the script still exits with status 0.
    if not validate_dataset():
        print("\n⚠ Dataset validation failed. Please fix the issues above.")
        if not args.validate_only:
            if not _confirm("Continue with directory creation anyway? (y/n): "):
                sys.exit(1)

    # Step 7: Create directories
    if not args.validate_only:
        if not create_directories():
            sys.exit(1)

    print("\n" + "="*60)
    if args.validate_only:
        print("Setup validation completed!")
    else:
        print("Setup completed successfully!")
        print("\nNext steps:")
        print("1. Run training: python train_yolov8.py")
        print("2. Or open train_yolov8_colab.ipynb in Google Colab")
        print("3. Check README.md for detailed instructions")
    print("="*60)
# Run the setup only when executed as a script, not when imported.
if __name__ == "__main__":
    main()