# File size: 4,260 Bytes
# 7a5bb5d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import argparse
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
import glob
from sklearn.model_selection import train_test_split

from unet import build_unet
from loss_metrics import bce_dice_loss, get_metrics
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
from utils.data_loader import get_dataset

def get_file_paths(data_dir):
    """Collect index-aligned image and mask file paths under ``data_dir``.

    Images are globbed from ``<data_dir>/images`` (any ``*.*`` entry) and
    masks are looked up in ``<data_dir>/masks``. An image is paired with a
    mask whose file name, minus its last extension, equals the image's.
    Images without any matching mask are reported with a warning on stdout
    and left out of the result.

    Args:
        data_dir: Directory containing the ``images`` and ``masks`` folders.

    Returns:
        A tuple ``(image_paths, mask_paths)`` of two equally long lists.
        ``image_paths`` is sorted and ``mask_paths[i]`` is the mask chosen
        for ``image_paths[i]``. Both lists are empty when nothing matches.
    """
    # Escape the directories so glob metacharacters in the path itself are
    # matched literally instead of being interpreted as a pattern.
    image_dir = glob.escape(os.path.join(data_dir, 'images'))
    mask_dir = glob.escape(os.path.join(data_dir, 'masks'))

    # Sorted so the returned pairs come back in a reproducible order.
    image_paths = sorted(glob.glob(os.path.join(image_dir, '*.*')))

    final_img_paths = []
    final_mask_paths = []
    for img_path in image_paths:
        name, _ = os.path.splitext(os.path.basename(img_path))
        # The mask may use a different extension than the image (usually
        # png). The base name is escaped so names such as "tile[3]" still
        # find their mask.
        candidates = glob.glob(os.path.join(mask_dir, f"{glob.escape(name)}.*"))
        if not candidates:
            print(f"Warning: No mask found for {img_path}")
            continue
        final_img_paths.append(img_path)
        # glob returns matches in arbitrary order; take the smallest path so
        # the choice is deterministic when several extensions exist.
        final_mask_paths.append(min(candidates))

    return final_img_paths, final_mask_paths

def train(args):
    """Run the full training pipeline described by ``args``.

    Reads image/mask pairs via ``get_file_paths(args.data_dir)``, splits them
    80/20 into training and validation sets, builds and compiles the U-Net,
    and fits it with checkpointing, early stopping and learning-rate
    reduction, all monitoring ``val_loss``.

    Args:
        args: Namespace-like object providing ``data_dir``, ``save_dir``,
            ``epochs``, ``batch_size`` and ``learning_rate``.

    Returns:
        None. Returns early, after printing an error, when no pairs are found.
    """
    print(f"Starting training process. Using data dir: {args.data_dir}")

    images, masks = get_file_paths(args.data_dir)
    print(f"Found {len(images)} image/mask pairs.")

    # Nothing to train on: report it and stop.
    if not images:
        print("Error: No data found. Please check data_dir structure.")
        return

    # Hold out 20% of the pairs for validation; the fixed seed keeps the
    # split identical between runs.
    images_train, images_val, masks_train, masks_val = train_test_split(
        images, masks, test_size=0.2, random_state=42
    )

    batch = args.batch_size
    train_dataset = get_dataset(images_train, masks_train, batch_size=batch, is_train=True)
    val_dataset = get_dataset(images_val, masks_val, batch_size=batch, is_train=False)

    model = build_unet(input_shape=(256, 256, 3))
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=args.learning_rate),
        loss=bce_dice_loss,
        metrics=get_metrics(),
    )

    os.makedirs(args.save_dir, exist_ok=True)
    model_path = os.path.join(args.save_dir, "oil_spill_unet_best.keras")

    # Keep only the best checkpoint on disk.
    checkpoint = ModelCheckpoint(model_path, verbose=1, save_best_only=True, monitor='val_loss')
    # Stop after 15 epochs without improvement and restore the best weights.
    early_stop = EarlyStopping(monitor='val_loss', patience=15, restore_best_weights=True)
    # Cut the learning rate tenfold after 5 stagnant epochs, down to 1e-6.
    lr_schedule = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=5, verbose=1, min_lr=1e-6)

    model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=args.epochs,
        callbacks=[checkpoint, early_stop, lr_schedule],
    )

    print("Training complete. Best model saved at:", model_path)

if __name__ == "__main__":
    # Command-line entry point: parse the training options and start training.
    # The title is passed as `description`. ArgumentParser's first positional
    # parameter is `prog`, so passing the title positionally replaced the
    # program name in the usage line instead of describing the tool.
    parser = argparse.ArgumentParser(description="Train Oil Spill U-Net")
    parser.add_argument("--data_dir", type=str, default="../data", help="Directory containing images/ and masks/ folders.")
    parser.add_argument("--save_dir", type=str, default="../model/saved_models", help="Directory to save the trained model.")
    parser.add_argument("--epochs", type=int, default=50, help="Number of training epochs.")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size.")
    parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.")

    args = parser.parse_args()
    train(args)