strawberryPicker / scripts /train_yolov8.py
Gareth
Initial clean commit for Hugging Face
efb1801
#!/usr/bin/env python3
"""
YOLOv8 Training Script for Strawberry Detection
Compatible with: Local Python, WSL, Google Colab (VS Code extension)
"""
import os
import sys
import argparse
from pathlib import Path
import torch
import yaml
def check_environment():
    """Detect the runtime environment and the available accelerator.

    Returns:
        dict: A mapping with the keys
            ``'is_colab'`` - True when the ``COLAB_GPU`` environment variable
            is set or the working directory is ``/content`` or below it,
            ``'is_wsl'`` - True when ``WSL_DISTRO_NAME`` is set,
            ``'has_gpu'`` - result of ``torch.cuda.is_available()``,
            ``'gpu_name'`` - name of CUDA device 0, or ``'CPU'`` without CUDA.
    """
    # Compare against /content as a path prefix. A plain substring test would
    # also match unrelated directories such as /home/user/content/project and
    # wrongly switch the script to Colab paths.
    cwd = Path(os.getcwd()).as_posix()
    in_content_dir = cwd == '/content' or cwd.startswith('/content/')

    # Query CUDA once and reuse the answer for both fields.
    has_gpu = torch.cuda.is_available()

    return {
        'is_colab': 'COLAB_GPU' in os.environ or in_content_dir,
        'is_wsl': 'WSL_DISTRO_NAME' in os.environ,
        'has_gpu': has_gpu,
        'gpu_name': torch.cuda.get_device_name(0) if has_gpu else 'CPU',
    }
def setup_paths(dataset_path=None):
    """Work out dataset, weights and results locations for this environment.

    On Colab everything lives under ``/content``; otherwise paths are derived
    from the directory containing this script. The weights and results
    directories are created if they do not exist yet.

    Args:
        dataset_path: Optional dataset directory. When falsy, an
            environment-specific default is used.

    Returns:
        dict: ``'dataset_path'``, ``'weights_dir'``, ``'results_dir'`` and
        ``'base_path'``; ``'dataset_path'`` is always a ``Path``.
    """
    on_colab = check_environment()['is_colab']

    if on_colab:
        # Google Colab layout: artifacts sit directly under the project root.
        root = Path('/content/strawberry-picker')
        artifacts_root = root
        default_dataset = '/content/dataset'
    else:
        # Local/WSL layout: artifacts sit in a 'model' folder beside the script.
        root = Path(__file__).parent
        artifacts_root = root / 'model'
        default_dataset = artifacts_root / 'dataset' / 'straw-detect.v1-straw-detect.yolov8'

    weights_dir = artifacts_root / 'weights'
    results_dir = artifacts_root / 'results'

    # Make sure both output directories exist before anything writes to them.
    for directory in (weights_dir, results_dir):
        directory.mkdir(parents=True, exist_ok=True)

    return {
        'dataset_path': Path(dataset_path or default_dataset),
        'weights_dir': weights_dir,
        'results_dir': results_dir,
        'base_path': root,
    }
def _resolve_split_dir(dataset_root, entry):
    """Resolve a ``train``/``val`` entry from data.yaml against the dataset root."""
    candidate = dataset_root / entry
    # Some exported datasets write entries like '../train/images' although the
    # folders sit next to data.yaml. Ultralytics retries such entries without
    # the leading '../', so the same fallback is accepted here.
    if not candidate.exists() and isinstance(entry, str) and entry.startswith('../'):
        fallback = dataset_root / entry[3:]
        if fallback.exists():
            return fallback
    return candidate


def validate_dataset(dataset_path):
    """Validate the structure of a YOLO dataset directory.

    Checks that ``data.yaml`` exists, parses to a mapping, contains the keys
    ``train``, ``val``, ``nc`` and ``names``, and that the train and val
    locations exist on disk. A short summary is printed on success.

    Args:
        dataset_path: Directory that contains ``data.yaml``.

    Returns:
        Path: Path to the validated ``data.yaml`` file.

    Raises:
        FileNotFoundError: If ``data.yaml`` or a train/val directory is missing.
        ValueError: If ``data.yaml`` is empty, is not a mapping, or lacks a
            required key.
    """
    dataset_path = Path(dataset_path)
    data_yaml = dataset_path / 'data.yaml'
    if not data_yaml.exists():
        raise FileNotFoundError(f"data.yaml not found at {data_yaml}")

    # Read as UTF-8 explicitly so parsing does not depend on the platform locale.
    with open(data_yaml, 'r', encoding='utf-8') as f:
        data = yaml.safe_load(f)

    # safe_load returns None for an empty file; anything but a mapping cannot
    # carry the required keys.
    if not isinstance(data, dict):
        raise ValueError(f"data.yaml at {data_yaml} is empty or not a mapping")

    required_keys = ['train', 'val', 'nc', 'names']
    for key in required_keys:
        if key not in data:
            raise ValueError(f"Missing required key '{key}' in data.yaml")

    # Entries are treated as relative to the directory holding data.yaml.
    train_path = _resolve_split_dir(dataset_path, data['train'])
    val_path = _resolve_split_dir(dataset_path, data['val'])

    if not train_path.exists():
        raise FileNotFoundError(f"Training images not found at {train_path}")
    if not val_path.exists():
        raise FileNotFoundError(f"Validation images not found at {val_path}")

    print(f"✓ Dataset validated: {data['nc']} classes - {data['names']}")
    print(f"✓ Training images: {train_path}")
    print(f"✓ Validation images: {val_path}")
    return data_yaml
def train_model(data_yaml, weights_dir, results_dir, epochs=100, img_size=640, batch_size=16, weights=None, resume=False):
    """Train a YOLOv8 model and save the final weights.

    Args:
        data_yaml: Path to the dataset ``data.yaml``.
        weights_dir: Directory that receives ``strawberry_yolov8n.pt``.
        results_dir: Ultralytics ``project`` directory; the run is written to
            ``results_dir / 'strawberry_detection'``.
        epochs: Number of training epochs.
        img_size: Training image size (passed as ``imgsz``).
        batch_size: Batch size; capped at 16 when running on Colab.
        weights: Optional path to pretrained weights or a checkpoint. When
            omitted, ``yolov8n.pt`` is used, or the run's ``last.pt`` when
            ``resume`` is set.
        resume: Resume an interrupted run instead of starting a new one.

    Returns:
        tuple: ``(results, final_model_path)`` where ``results`` is whatever
        ``model.train`` returns and ``final_model_path`` is the saved ``.pt``.

    Raises:
        FileNotFoundError: If ``resume`` is set without ``weights`` and no
            ``last.pt`` checkpoint exists for the run.
        SystemExit: If the ``ultralytics`` package is not installed.
    """
    try:
        from ultralytics import YOLO
    except ImportError:
        print("ERROR: ultralytics not installed. Run: pip install ultralytics")
        sys.exit(1)

    # Accept plain strings as well as Path objects for the directories.
    weights_dir = Path(weights_dir)
    results_dir = Path(results_dir)
    run_name = 'strawberry_detection'
    run_dir = results_dir / run_name

    env = check_environment()
    print(f"\n{'='*60}")
    print(f"Environment: {'Google Colab' if env['is_colab'] else 'Local/WSL'}")
    print(f"GPU Available: {env['has_gpu']} ({env['gpu_name']})")
    print(f"{'='*60}\n")

    # Use GPU if available
    device = '0' if env['has_gpu'] else 'cpu'

    # Pick the model to load. Resuming only works from a training checkpoint,
    # so without explicit weights fall back to this run's last.pt (Ultralytics
    # stores checkpoints under <project>/<name>/weights/) rather than the
    # stock yolov8n.pt, which has no training state to resume.
    if weights:
        model_source = Path(weights)
    elif resume:
        model_source = run_dir / 'weights' / 'last.pt'
        if not model_source.exists():
            raise FileNotFoundError(
                f"Cannot resume: no checkpoint found at {model_source}. "
                "Pass --weights path/to/last.pt or start a new run without --resume."
            )
    else:
        model_source = 'yolov8n.pt'

    print(f"Loading model from {model_source}...")
    model = YOLO(str(model_source))

    # Training arguments
    train_args = {
        'data': str(data_yaml),
        'epochs': epochs,
        'imgsz': img_size,
        'batch': batch_size,
        'device': device,
        'project': str(results_dir),
        'name': run_name,
        'exist_ok': True,
        'patience': 20,  # Early stopping patience
        'save': True,
        'save_period': 10,  # Save checkpoint every 10 epochs
        'cache': True,  # Cache images for faster training
    }
    if resume:
        train_args['resume'] = True

    # Adjust batch size for Colab's limited RAM
    if env['is_colab'] and batch_size > 16:
        train_args['batch'] = 16
        print("Adjusted batch size to 16 for Colab environment")

    print("\nStarting training with arguments:")
    for key, value in train_args.items():
        print(f" {key}: {value}")

    # Train the model
    print(f"\n{'='*60}")
    print("TRAINING STARTED")
    print(f"{'='*60}\n")
    results = model.train(**train_args)

    # Save final model; make sure the target directory exists first.
    weights_dir.mkdir(parents=True, exist_ok=True)
    final_model_path = weights_dir / 'strawberry_yolov8n.pt'
    model.save(str(final_model_path))

    print(f"\n{'='*60}")
    print("Training completed!")
    print(f"Final model saved to: {final_model_path}")
    print(f"Results saved to: {run_dir}")
    print(f"{'='*60}\n")
    return results, final_model_path
def export_model(model_path, weights_dir, img_size=640):
    """Export a trained model to ONNX and place it in ``weights_dir``.

    Args:
        model_path: Path to the trained ``.pt`` model.
        weights_dir: Directory that should hold ``strawberry_yolov8n.onnx``.
        img_size: Image size used for the export (passed as ``imgsz``).

    Returns:
        Path | None: Location of the ONNX file inside ``weights_dir``, or
        ``None`` when ``ultralytics`` is not installed.
    """
    try:
        from ultralytics import YOLO
    except ImportError:
        print("ERROR: ultralytics not installed")
        return None
    import shutil

    print("\nExporting model to ONNX...")
    model = YOLO(str(model_path))

    weights_dir = Path(weights_dir)
    onnx_path = weights_dir / 'strawberry_yolov8n.onnx'

    # export() returns the path it actually wrote to; use that instead of
    # assuming the file landed at onnx_path.
    exported_path = Path(model.export(format='onnx', imgsz=img_size, dynamic=True))

    # Move the file only when it was written somewhere other than the
    # location this function reports and returns.
    if exported_path.resolve() != onnx_path.resolve():
        weights_dir.mkdir(parents=True, exist_ok=True)
        shutil.move(str(exported_path), str(onnx_path))

    print(f"ONNX model exported to: {onnx_path}")
    return onnx_path
def _parse_cli_args():
    """Build the command-line parser and return the parsed arguments."""
    cli = argparse.ArgumentParser(description='Train YOLOv8 for strawberry detection')
    cli.add_argument('--dataset', type=str, help='Path to dataset directory')
    cli.add_argument('--epochs', type=int, default=100, help='Number of training epochs')
    cli.add_argument('--img-size', type=int, default=640, help='Image size for training')
    cli.add_argument('--batch-size', type=int, default=16, help='Batch size for training')
    cli.add_argument('--weights', type=str, help='Path to pretrained weights or checkpoint')
    cli.add_argument('--resume', action='store_true', help='Resume training from the latest checkpoint')
    cli.add_argument('--export-onnx', action='store_true', help='Export to ONNX after training')
    cli.add_argument('--validate-only', action='store_true', help='Only validate dataset without training')
    return cli.parse_args()


def main():
    """Run the pipeline: resolve paths, validate the dataset, train, export.

    Any exception raised by a pipeline step is reported on stdout and the
    process exits with status 1.
    """
    options = _parse_cli_args()

    try:
        # Resolve and report the working locations.
        locations = setup_paths(options.dataset)
        summary = (
            ("Base path", 'base_path'),
            ("Dataset path", 'dataset_path'),
            ("Weights directory", 'weights_dir'),
            ("Results directory", 'results_dir'),
        )
        for label, key in summary:
            print(f"{label}: {locations[key]}")

        # Check the dataset before spending time on training.
        print("\nValidating dataset...")
        yaml_file = validate_dataset(locations['dataset_path'])

        if options.validate_only:
            print("Dataset validation completed. Exiting without training.")
            return

        # Training returns (results, model path); only the path is used here.
        _, trained_model = train_model(
            data_yaml=yaml_file,
            weights_dir=locations['weights_dir'],
            results_dir=locations['results_dir'],
            epochs=options.epochs,
            img_size=options.img_size,
            batch_size=options.batch_size,
            weights=options.weights,
            resume=options.resume,
        )

        # Optional ONNX export of the freshly trained model.
        if options.export_onnx:
            export_model(trained_model, locations['weights_dir'])

        print("\n✓ Training pipeline completed successfully!")
    except Exception as err:
        print(f"\n✗ Error: {err}")
        sys.exit(1)
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()