import os
import argparse
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
import glob
from sklearn.model_selection import train_test_split
from unet import build_unet
from loss_metrics import bce_dice_loss, get_metrics
import sys

# Make the parent directory importable so that ``utils`` resolves when this
# script is run directly. This must run before the ``utils`` import below.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from utils.data_loader import get_dataset


def get_file_paths(data_dir):
    """Return aligned lists of image paths and their matching mask paths.

    Images are read from ``<data_dir>/images`` and masks from
    ``<data_dir>/masks``. An image and a mask are paired when their file
    names share the same base name (the name without its final extension),
    e.g. ``images/scene_01.jpg`` pairs with ``masks/scene_01.png``. Images
    without a matching mask are reported and skipped.

    Args:
        data_dir: Directory containing ``images/`` and ``masks/`` subfolders.

    Returns:
        A ``(image_paths, mask_paths)`` tuple of equal-length lists, ordered
        by sorted image path, where ``mask_paths[i]`` is the mask for
        ``image_paths[i]``.
    """
    image_dir = os.path.join(data_dir, 'images')
    mask_dir = os.path.join(data_dir, 'masks')

    # Regular files with an extension only. Sorting makes the order (and thus
    # the seeded train/val split downstream) reproducible across file systems,
    # since glob returns entries in arbitrary order.
    image_paths = sorted(
        p for p in glob.glob(os.path.join(image_dir, '*.*')) if os.path.isfile(p)
    )

    # Index the mask directory once instead of globbing it per image. This
    # avoids repeated directory scans and building glob patterns from file
    # names that may contain glob metacharacters such as '[' or '?'. Masks
    # are sorted first so that, if several share a base name, the choice is
    # deterministic (the alphabetically first one wins).
    mask_by_stem = {}
    for mask_path in sorted(glob.glob(os.path.join(mask_dir, '*.*'))):
        if os.path.isfile(mask_path):
            stem = os.path.splitext(os.path.basename(mask_path))[0]
            mask_by_stem.setdefault(stem, mask_path)

    final_img_paths = []
    final_mask_paths = []
    for img_path in image_paths:
        stem = os.path.splitext(os.path.basename(img_path))[0]
        mask_path = mask_by_stem.get(stem)
        if mask_path is None:
            # Drop the image too, so the two lists stay aligned 1:1.
            print(f"Warning: No mask found for {img_path}")
            continue
        final_img_paths.append(img_path)
        final_mask_paths.append(mask_path)

    return final_img_paths, final_mask_paths


def train(args):
    """Train the oil-spill U-Net on the image/mask pairs under ``args.data_dir``.

    Pairs are split 80/20 into training and validation sets with a fixed seed
    (``random_state=42``). The model is compiled with Adam and
    ``bce_dice_loss``. The checkpoint with the lowest validation loss is
    written to ``<args.save_dir>/oil_spill_unet_best.keras``.

    Args:
        args: Namespace with ``data_dir``, ``save_dir``, ``epochs``,
            ``batch_size`` and ``learning_rate`` attributes (see the CLI
            below).

    Returns:
        The ``History`` object returned by ``model.fit``, or ``None`` if there
        was not enough data to train.
    """
    print(f"Starting training process. Using data dir: {args.data_dir}")
    img_paths, mask_paths = get_file_paths(args.data_dir)
    print(f"Found {len(img_paths)} image/mask pairs.")

    if not img_paths:
        print("Error: No data found. Please check data_dir structure.")
        return None
    if len(img_paths) < 2:
        # train_test_split raises ValueError when one sample cannot yield
        # both a non-empty train set and a non-empty validation set.
        print("Error: At least 2 image/mask pairs are required for a train/validation split.")
        return None

    # Split: 80% Train, 20% Val (seeded for reproducibility).
    train_x, val_x, train_y, val_y = train_test_split(
        img_paths, mask_paths, test_size=0.2, random_state=42
    )

    train_dataset = get_dataset(train_x, train_y, batch_size=args.batch_size, is_train=True)
    val_dataset = get_dataset(val_x, val_y, batch_size=args.batch_size, is_train=False)

    # NOTE(review): input shape is hard-coded; presumably it must match the
    # size produced by utils.data_loader.get_dataset - confirm before changing.
    model = build_unet(input_shape=(256, 256, 3))

    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
    model.compile(optimizer=optimizer, loss=bce_dice_loss, metrics=get_metrics())

    os.makedirs(args.save_dir, exist_ok=True)
    model_path = os.path.join(args.save_dir, "oil_spill_unet_best.keras")

    callbacks = [
        # Persist only the checkpoint with the best (lowest) validation loss.
        ModelCheckpoint(model_path, verbose=1, save_best_only=True, monitor='val_loss'),
        # Stop after 15 stagnant epochs and roll back to the best weights.
        EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True),
        # Cut the learning rate 10x after 5 stagnant epochs, down to 1e-6.
        ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, verbose=1, min_lr=1e-6),
    ]

    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=args.epochs,
        callbacks=callbacks,
    )

    print("Training complete. Best model saved at:", model_path)
    return history


if __name__ == "__main__":
    # The first positional argument of ArgumentParser is ``prog`` (the program
    # name shown in usage), so the title must be passed as ``description``.
    parser = argparse.ArgumentParser(description="Train Oil Spill U-Net")
    parser.add_argument("--data_dir", type=str, default="../data",
                        help="Directory containing images/ and masks/ folders.")
    parser.add_argument("--save_dir", type=str, default="../model/saved_models",
                        help="Directory to save the trained model.")
    parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs.")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size.")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")

    args = parser.parse_args()
    train(args)