Spaces:
Sleeping
Sleeping
| import os | |
| import argparse | |
| import tensorflow as tf | |
| from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau | |
| import glob | |
| from sklearn.model_selection import train_test_split | |
| from unet import build_unet | |
| from loss_metrics import bce_dice_loss, get_metrics | |
| import sys | |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) | |
| from utils.data_loader import get_dataset | |
def get_file_paths(data_dir):
    """Pair every image in ``<data_dir>/images`` with its mask in ``<data_dir>/masks``.

    An image and a mask are paired when their file names are identical once the
    extension is removed, e.g. ``images/scene_01.jpg`` <-> ``masks/scene_01.png``.
    Images without a matching mask are reported with a warning and skipped.

    Args:
        data_dir: Directory expected to contain ``images/`` and ``masks/``.

    Returns:
        A tuple ``(image_paths, mask_paths)`` of equal-length lists ordered by image
        path, where ``mask_paths[i]`` is the mask for ``image_paths[i]``. Both lists
        are empty when nothing matches (including when the folders are missing).
    """
    image_dir = os.path.join(data_dir, 'images')
    mask_dir = os.path.join(data_dir, 'masks')

    def _list_files(directory):
        # glob.escape stops '[', '*' or '?' in the directory name from acting as
        # wildcards; '*.*' keeps only names that have an extension.
        pattern = os.path.join(glob.escape(directory), '*.*')
        return sorted(p for p in glob.glob(pattern) if os.path.isfile(p))

    def _stem(path):
        # File name without directory and without its final extension.
        return os.path.splitext(os.path.basename(path))[0]

    # Index masks by exact stem once, instead of running one glob per image.
    # A prefix pattern such as "a.*" would also match "a.b.png".
    # Sorted iteration + setdefault makes the choice deterministic when several
    # masks share a stem: the lexicographically first path wins.
    mask_by_stem = {}
    for mask_path in _list_files(mask_dir):
        mask_by_stem.setdefault(_stem(mask_path), mask_path)

    final_img_paths = []
    final_mask_paths = []
    for img_path in _list_files(image_dir):
        mask_path = mask_by_stem.get(_stem(img_path))
        if mask_path is None:
            print(f"Warning: No mask found for {img_path}")
            continue
        final_img_paths.append(img_path)
        final_mask_paths.append(mask_path)
    return final_img_paths, final_mask_paths
def train(args):
    """Train the oil-spill U-Net on the image/mask pairs found under ``args.data_dir``.

    Steps:
      1. Discover pairs with ``get_file_paths``.
      2. Split them 80/20 into train/validation with a fixed ``random_state=42``.
      3. Build the input pipelines with ``get_dataset``.
      4. Compile ``build_unet`` with Adam and ``bce_dice_loss``.
      5. Fit with checkpointing, early stopping and learning-rate reduction, all
         monitored on ``val_loss``.

    Args:
        args: Namespace providing ``data_dir``, ``save_dir``, ``epochs``,
            ``batch_size`` and ``learning_rate`` (as produced by this script's
            command-line parser).

    Returns:
        The ``History`` object returned by ``model.fit``, or ``None`` when there
        is not enough data to train (an error is printed to stderr).
    """
    print(f"Starting training process. Using data dir: {args.data_dir}")
    img_paths, mask_paths = get_file_paths(args.data_dir)
    print(f"Found {len(img_paths)} image/mask pairs.")
    if not img_paths:
        print("Error: No data found. Please check data_dir structure.", file=sys.stderr)
        return None
    if len(img_paths) < 2:
        # train_test_split raises ValueError when either side of the split would
        # be empty; with test_size=0.2 that happens for a single pair.
        print(
            "Error: At least 2 image/mask pairs are needed for a train/validation split.",
            file=sys.stderr,
        )
        return None

    # Split: 80% train, 20% validation (fixed seed so splits are reproducible).
    train_x, val_x, train_y, val_y = train_test_split(
        img_paths, mask_paths, test_size=0.2, random_state=42
    )
    train_dataset = get_dataset(train_x, train_y, batch_size=args.batch_size, is_train=True)
    val_dataset = get_dataset(val_x, val_y, batch_size=args.batch_size, is_train=False)

    # NOTE(review): input shape is hard-coded; presumably get_dataset yields
    # 256x256 RGB tensors — confirm and keep the two in sync.
    model = build_unet(input_shape=(256, 256, 3))
    optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
    model.compile(optimizer=optimizer, loss=bce_dice_loss, metrics=get_metrics())

    os.makedirs(args.save_dir, exist_ok=True)
    model_path = os.path.join(args.save_dir, "oil_spill_unet_best.keras")
    callbacks = [
        # Write the model only when val_loss improves on the best value so far.
        ModelCheckpoint(model_path, verbose=1, save_best_only=True, monitor='val_loss'),
        # Stop after 15 epochs without improvement and restore the best weights.
        EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True),
        # Multiply the LR by 0.1 after 5 stagnant epochs, never below 1e-6.
        ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, verbose=1, min_lr=1e-6),
    ]
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=args.epochs,
        callbacks=callbacks,
    )
    print("Training complete. Best model saved at:", model_path)
    return history
def _positive_int(text):
    """argparse type: parse *text* as an int and require it to be > 0."""
    value = int(text)
    if value <= 0:
        raise argparse.ArgumentTypeError(f"expected a positive integer, got {text!r}")
    return value


def build_arg_parser():
    """Return the command-line parser for the training script.

    The first positional argument of ``argparse.ArgumentParser`` is ``prog``, the
    program name shown in the usage line. The human-readable title is therefore
    passed as ``description``.
    """
    parser = argparse.ArgumentParser(
        description="Train Oil Spill U-Net",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--data_dir", type=str, default="../data", help="Directory containing images/ and masks/ folders.")
    parser.add_argument("--save_dir", type=str, default="../model/saved_models", help="Directory to save the trained model.")
    parser.add_argument("--epochs", type=_positive_int, default=50, help="Number of training epochs.")
    parser.add_argument("--batch_size", type=_positive_int, default=16, help="Batch size.")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()
    train(args)