Oill_split / model /train.py
Utkarshres32's picture
Initial commit: AI-powered Oil Spill Detection and Monitoring System
7a5bb5d
import os
import argparse
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
import glob
from sklearn.model_selection import train_test_split
from unet import build_unet
from loss_metrics import bce_dice_loss, get_metrics
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from utils.data_loader import get_dataset
def get_file_paths(data_dir):
    """Collect aligned image and mask file paths from a dataset directory.

    Images are read from ``<data_dir>/images`` and masks from
    ``<data_dir>/masks``. An image is paired with a mask file that shares its
    base name (the extension may differ). Images without a mask are reported
    with a warning on stdout and left out of the result.

    Args:
        data_dir: Directory containing the ``images`` and ``masks`` folders.

    Returns:
        A tuple ``(image_paths, mask_paths)`` of two equally long lists in
        which ``mask_paths[i]`` is the mask for ``image_paths[i]``. Both
        lists are empty when no pair is found (including when the
        directories do not exist).
    """
    image_dir = os.path.join(data_dir, 'images')
    mask_dir = os.path.join(data_dir, 'masks')

    # Escape the directory parts so characters such as "[" or "*" in a path
    # are matched literally instead of being treated as glob wildcards.
    image_paths = sorted(glob.glob(os.path.join(glob.escape(image_dir), '*.*')))
    escaped_mask_dir = glob.escape(mask_dir)

    final_img_paths = []
    final_mask_paths = []
    for img_path in image_paths:
        name, _ = os.path.splitext(os.path.basename(img_path))

        # The base name is escaped for the same reason as the directories.
        candidates = glob.glob(
            os.path.join(escaped_mask_dir, f"{glob.escape(name)}.*")
        )
        if not candidates:
            print(f"Warning: No mask found for {img_path}")
            continue

        # "name.*" also matches files like "name.extra.png". Prefer a mask
        # whose stem equals the image stem exactly, and break remaining ties
        # alphabetically so the selection does not depend on glob's order.
        candidates.sort(
            key=lambda path: (
                os.path.splitext(os.path.basename(path))[0] != name,
                path,
            )
        )
        final_img_paths.append(img_path)
        final_mask_paths.append(candidates[0])

    return final_img_paths, final_mask_paths
def train(args):
    """Train the U-Net on the image/mask pairs found under ``args.data_dir``.

    The pairs are split 80/20 into training and validation sets with a fixed
    random seed, a U-Net with input shape ``(256, 256, 3)`` is compiled with
    Adam and ``bce_dice_loss``, and the model is fitted with checkpointing,
    early stopping and learning-rate reduction, all monitoring ``val_loss``.

    Args:
        args: Namespace-like object providing ``data_dir``, ``save_dir``,
            ``epochs``, ``batch_size`` and ``learning_rate``.

    Returns:
        None. When too few image/mask pairs are available an error is
        printed and the function returns without training.
    """
    print(f"Starting training process. Using data dir: {args.data_dir}")

    img_paths, mask_paths = get_file_paths(args.data_dir)
    print(f"Found {len(img_paths)} image/mask pairs.")
    if not img_paths:
        print("Error: No data found. Please check data_dir structure.")
        return
    if len(img_paths) < 2:
        # With a single sample the 80/20 split below would leave the train
        # set empty and train_test_split would raise ValueError.
        print("Error: At least 2 image/mask pairs are required for a train/validation split.")
        return

    # Split: 80% train, 20% validation; the fixed seed keeps it reproducible.
    train_x, val_x, train_y, val_y = train_test_split(
        img_paths, mask_paths, test_size=0.2, random_state=42
    )

    train_dataset = get_dataset(
        train_x, train_y, batch_size=args.batch_size, is_train=True
    )
    val_dataset = get_dataset(
        val_x, val_y, batch_size=args.batch_size, is_train=False
    )

    model = build_unet(input_shape=(256, 256, 3))
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
    model.compile(optimizer=optimizer, loss=bce_dice_loss, metrics=get_metrics())

    os.makedirs(args.save_dir, exist_ok=True)
    model_path = os.path.join(args.save_dir, "oil_spill_unet_best.keras")

    callbacks = [
        # Keep only the checkpoint with the lowest validation loss.
        ModelCheckpoint(model_path, verbose=1, save_best_only=True, monitor='val_loss'),
        # Stop after 15 epochs without improvement and roll back to the best weights.
        EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True),
        # Cut the learning rate tenfold after 5 stagnant epochs, down to 1e-6.
        ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, verbose=1, min_lr=1e-6),
    ]

    model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=args.epochs,
        callbacks=callbacks,
    )
    print("Training complete. Best model saved at:", model_path)
if __name__ == "__main__":
    # Command-line entry point: declare the options in one table, parse them
    # and hand the resulting namespace to train().
    cli = argparse.ArgumentParser("Train Oil Spill U-Net")
    option_table = (
        # (flag, type, default, help)
        ("--data_dir", str, "../data", "Directory containing images/ and masks/ folders."),
        ("--save_dir", str, "../model/saved_models", "Directory to save the trained model."),
        ("--epochs", int, 50, "Number of training epochs."),
        ("--batch_size", int, 16, "Batch size."),
        ("--learning_rate", float, 1e-4, "Learning rate."),
    )
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)
    train(cli.parse_args())