#!/usr/bin/env python3
"""
Water Surface Segmentation Training Script

Train a YOLO11n segmentation model for water surface segmentation on beach
images, then copy the best checkpoint produced by the run to a stable path.
"""

import argparse
import shutil
import sys
from pathlib import Path
from typing import Any, Dict, Optional, Sequence

# Official Ultralytics YOLO11 checkpoints are published as ``yolo11*.pt``
# (no "v"). The ``yolov11`` prefix is still accepted by local validation so
# existing invocations behave as before; Ultralytics reports its own error if
# such a name is not a downloadable asset.
_HUB_WEIGHT_PREFIXES = ("yolo11", "yolov11")

DEFAULT_WEIGHTS = "yolo11n-seg.pt"
DEFAULT_OUTPUT = "model/nwsd-v2.pt"


def parse_arguments(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse command line arguments.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, matching the original CLI behaviour.

    Returns:
        The parsed arguments namespace.
    """
    parser = argparse.ArgumentParser(
        description="Train YOLOv11n model for water surface segmentation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--data", type=str, required=True, help="Path to data.yaml file")
    parser.add_argument(
        "--weights", type=str, default=DEFAULT_WEIGHTS, help="Path to pretrained weights"
    )
    parser.add_argument("--img", type=int, default=640, help="Image size for training")
    parser.add_argument("--batch", type=int, default=16, help="Batch size")
    parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs")
    parser.add_argument(
        "--device",
        type=str,
        default="",
        help="Device to use for training (cpu, cuda, mps)",
    )
    parser.add_argument(
        "--project", type=str, default="runs/segment", help="Project directory"
    )
    parser.add_argument("--name", type=str, default="nwsd_train", help="Experiment name")
    parser.add_argument("--patience", type=int, default=10, help="Early stopping patience")
    parser.add_argument(
        "--save-period", type=int, default=5, help="Save model every n epochs"
    )
    parser.add_argument(
        "--output",
        type=str,
        default=DEFAULT_OUTPUT,
        help="Where to copy the best checkpoint after training",
    )
    return parser.parse_args(argv)


def validate_inputs(args: argparse.Namespace) -> None:
    """Validate input arguments before any expensive work starts.

    Raises:
        FileNotFoundError: if the data config is not an existing file, or if
            the weights are neither an existing file nor a bare YOLO11 asset
            name that Ultralytics can download.
    """
    if not Path(args.data).is_file():
        raise FileNotFoundError(f"Data configuration file not found: {args.data}")

    weights = Path(args.weights)
    # Only a bare file name (no directory part) is treated as a hub asset;
    # anything that looks like a path must already exist locally.
    is_hub_name = weights.name == args.weights and args.weights.startswith(
        _HUB_WEIGHT_PREFIXES
    )
    if not is_hub_name and not weights.is_file():
        raise FileNotFoundError(f"Weights file not found: {args.weights}")


def build_train_params(args: argparse.Namespace) -> Dict[str, Any]:
    """Map parsed CLI arguments onto Ultralytics ``Model.train`` keyword arguments."""
    return {
        "data": args.data,
        "imgsz": args.img,
        "batch": args.batch,
        "epochs": args.epochs,
        "device": args.device,
        "project": args.project,
        "name": args.name,
        "patience": args.patience,
        "save_period": args.save_period,
        "save": True,
        "verbose": True,
        "plots": True,
        "val": True,
    }


def find_best_weights(model: Any, args: argparse.Namespace) -> Path:
    """Return the path of the best checkpoint written by the last training run.

    Ultralytics appends a numeric suffix to the run directory when
    ``project/name`` already exists, so the trainer's own ``best`` path is
    preferred. The reconstructed ``project/name/weights/best.pt`` is only a
    fallback for when no trainer information is available.
    """
    trainer = getattr(model, "trainer", None)
    best = getattr(trainer, "best", None)
    if best is not None:
        return Path(best)
    return Path(args.project) / args.name / "weights" / "best.pt"


def main() -> None:
    """Parse arguments, train the model, and export the best checkpoint."""
    args = parse_arguments()
    try:
        validate_inputs(args)

        # Imported lazily: ultralytics pulls in torch, which is slow to load.
        # Deferring it keeps argument parsing/validation usable without it,
        # and an ImportError is reported through the handler below.
        from ultralytics import YOLO

        print(f"Loading model: {args.weights}")
        model = YOLO(args.weights)

        train_params = build_train_params(args)
        print("Starting training with parameters:")
        for key, value in train_params.items():
            print(f"  {key}: {value}")

        model.train(**train_params)

        best_path = find_best_weights(model, args)
        final_path = Path(args.output)
        if best_path.is_file():
            final_path.parent.mkdir(parents=True, exist_ok=True)
            shutil.copy2(best_path, final_path)
            print(f"Best model saved to: {final_path}")
        else:
            # Don't fail the run (training itself finished), but make the
            # missing export visible instead of silently skipping it.
            print(
                f"Warning: best checkpoint not found at {best_path}; "
                f"nothing copied to {final_path}",
                file=sys.stderr,
            )

        print("Training completed successfully!")
    except Exception as e:  # top-level CLI boundary: report and exit non-zero
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()