|
|
|
|
|
""" |
|
|
Water Surface Segmentation Training Script |
|
|
Train YOLOv11n model for water surface segmentation on beach images. |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import os |
|
|
import sys |
|
|
from pathlib import Path |
|
|
from ultralytics import YOLO |
|
|
|
|
|
|
|
|
def parse_arguments() -> argparse.Namespace:
    """Build the training command-line interface and parse ``sys.argv``.

    ``--data`` is the only mandatory option. Every other option carries a
    default, which ``ArgumentDefaultsHelpFormatter`` echoes in ``--help``.

    Returns:
        The populated ``argparse.Namespace`` with attributes ``data``,
        ``weights``, ``img``, ``batch``, ``epochs``, ``device``, ``project``,
        ``name``, ``patience`` and ``save_period``.
    """
    # (flag, value type, default, help text), listed in --help order.
    optional_flags = (
        ("--weights", str, "yolov11n-seg.pt", "Path to pretrained weights"),
        ("--img", int, 640, "Image size for training"),
        ("--batch", int, 16, "Batch size"),
        ("--epochs", int, 50, "Number of training epochs"),
        ("--device", str, "", "Device to use for training (cpu, cuda, mps)"),
        ("--project", str, "runs/segment", "Project directory"),
        ("--name", str, "nwsd_train", "Experiment name"),
        ("--patience", int, 10, "Early stopping patience"),
        ("--save-period", int, 5, "Save model every n epochs"),
    )

    cli = argparse.ArgumentParser(
        description="Train YOLOv11n model for water surface segmentation",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # The dataset config has no sensible default, so it is registered on its
    # own as a required option ahead of the defaulted ones.
    cli.add_argument(
        "--data",
        type=str,
        required=True,
        help="Path to data.yaml file",
    )

    for flag, caster, fallback, description in optional_flags:
        cli.add_argument(flag, type=caster, default=fallback, help=description)

    return cli.parse_args()
|
|
|
|
|
|
|
|
def validate_inputs(args: argparse.Namespace) -> None:
    """Check that the dataset config and weights named in ``args`` are usable.

    Args:
        args: Parsed command-line namespace. Only ``args.data`` and
            ``args.weights`` are read.

    Raises:
        FileNotFoundError: If ``args.data`` does not exist, or if
            ``args.weights`` is neither an existing path nor a name that
            starts with a recognised built-in model prefix.
    """
    if not os.path.exists(args.data):
        raise FileNotFoundError(f"Data configuration file not found: {args.data}")

    # Weights whose name starts with one of these prefixes skip the local-file
    # check (presumably so the YOLO loader can fetch them itself - confirm).
    # "yolo11" is the spelling Ultralytics uses for its released weights
    # (e.g. "yolo11n-seg.pt"); "yolov11" is kept so existing invocations that
    # relied on the old check behave exactly as before.
    builtin_prefixes = ("yolov11", "yolo11")
    if not args.weights.startswith(builtin_prefixes) and not os.path.exists(args.weights):
        raise FileNotFoundError(f"Weights file not found: {args.weights}")
|
|
|
|
|
|
|
|
def main() -> None:
    """Train the segmentation model and export the best checkpoint.

    Parses the command line, validates the inputs, loads the model named by
    ``--weights``, runs ``model.train`` with the CLI-derived parameters, and
    copies the best checkpoint to ``model/nwsd-v2.pt``.

    Any exception raised during validation, loading, training or export is
    reported on stderr and the process exits with status 1.
    """
    args = parse_arguments()

    try:
        validate_inputs(args)

        print(f"Loading model: {args.weights}")
        model = YOLO(args.weights)

        train_params = {
            'data': args.data,
            'imgsz': args.img,
            'batch': args.batch,
            'epochs': args.epochs,
            'device': args.device,
            'project': args.project,
            'name': args.name,
            'patience': args.patience,
            'save_period': args.save_period,
            'save': True,
            'verbose': True,
            'plots': True,
            'val': True,
        }

        print("Starting training with parameters:")
        for key, value in train_params.items():
            print(f" {key}: {value}")

        model.train(**train_params)

        # Prefer the checkpoint path the trainer itself reports. Rebuilding it
        # from --project/--name can point at a stale or missing file if the
        # run was written to a different (e.g. auto-incremented) directory.
        # NOTE(review): relies on the trainer exposing a `best` attribute;
        # if it does not, the derived path is used, as before.
        trainer = getattr(model, "trainer", None)
        reported_best = getattr(trainer, "best", None)
        if reported_best:
            model_save_path = Path(reported_best)
        else:
            model_save_path = Path(args.project) / args.name / "weights" / "best.pt"

        final_model_path = os.path.join("model", "nwsd-v2.pt")

        os.makedirs("model", exist_ok=True)

        if os.path.exists(model_save_path):
            import shutil

            shutil.copy2(model_save_path, final_model_path)
            print(f"Best model saved to: {final_model_path}")
        else:
            # Make a missing checkpoint visible instead of silently skipping
            # the export; training itself still finished, so do not fail.
            print(
                f"Warning: best weights not found at {model_save_path}; "
                f"nothing was copied to {final_model_path}",
                file=sys.stderr,
            )

        print("Training completed successfully!")

    except Exception as e:
        # Top-level boundary: report the failure and signal it via exit code.
        print(f"Error: {str(e)}", file=sys.stderr)
        sys.exit(1)
|
|
|
|
|
|
|
|
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|