File size: 3,354 Bytes
ae51a24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import argparse
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pathlib

def load_dataset(folder_path, img_size=(180, 180), batch_size=32, shuffle=True):
    """Load an image-classification dataset from a directory tree.

    Each immediate subdirectory of ``folder_path`` is treated as one class;
    labels are inferred from the folder names and one-hot encoded
    (``label_mode='categorical'``).

    Args:
        folder_path: Path (str or ``pathlib.Path``) to the dataset root.
        img_size: (height, width) every image is resized to.
        batch_size: Number of images per yielded batch.
        shuffle: Whether to shuffle samples. Keep ``True`` for training;
            pass ``False`` for validation/test data so evaluation order
            is deterministic. Defaults to ``True`` (previous behavior).

    Returns:
        A batched ``tf.data.Dataset`` yielding (images, one-hot labels).

    Raises:
        ValueError: If ``folder_path`` does not exist.
    """
    folder_path = pathlib.Path(folder_path)

    if not folder_path.exists():
        raise ValueError(f"Folder does not exist: {folder_path}")

    return tf.keras.utils.image_dataset_from_directory(
        folder_path,
        labels='inferred',
        label_mode='categorical',
        shuffle=shuffle,
        batch_size=batch_size,
        image_size=img_size,
    )


def build_model(input_shape=(180, 180, 3), num_classes=4, learning_rate=0.001):
    """Construct and compile a small CNN classifier.

    Architecture: rescale pixels to [0, 1], then three Conv2D/MaxPooling2D
    stages with 32, 64 and 128 filters, followed by a dense head with
    dropout and a softmax output over ``num_classes``.

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: Number of output classes.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        A compiled ``keras.Model`` (categorical cross-entropy, accuracy metric).
    """
    model = keras.Sequential()
    model.add(layers.Rescaling(1./255, input_shape=input_shape))

    # Three identical conv/pool stages, doubling the filter count each time.
    for filters in (32, 64, 128):
        model.add(layers.Conv2D(filters, 3, activation="relu"))
        model.add(layers.MaxPooling2D())

    model.add(layers.Flatten())
    model.add(layers.Dense(128, activation="relu"))
    model.add(layers.Dropout(0.3))
    model.add(layers.Dense(num_classes, activation="softmax"))

    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )

    return model


def main():
    """Command-line entry point: load datasets, train the CNN, save the model.

    Requires --training_folder, --testing_folder and --output_folder;
    optional hyperparameters control epochs, learning rate, batch size,
    early-stopping patience and the saved model's name. The trained model
    is written to <output_folder>/<model_name>.h5.
    """
    parser = argparse.ArgumentParser(description="Train CNN on Alzheimer's MRI Dataset")

    # Required dataset / output locations.
    parser.add_argument("--training_folder", type=str, required=True)
    parser.add_argument("--testing_folder", type=str, required=True)
    parser.add_argument("--output_folder", type=str, required=True)

    # Optional training hyperparameters.
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--learning_rate", type=float, default=0.01)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--patience", type=int, default=10)
    parser.add_argument("--model_name", type=str, default="alzheimers-cnn")

    args = parser.parse_args()

    print("\n=== LOADING DATASETS ===")
    training_data = load_dataset(args.training_folder, batch_size=args.batch_size)
    validation_data = load_dataset(args.testing_folder, batch_size=args.batch_size)

    detected_classes = training_data.class_names
    n_classes = len(detected_classes)

    print(f"Detected classes: {detected_classes}")
    print(f"Number of classes: {n_classes}")

    print("\n=== BUILDING MODEL ===")
    model = build_model(num_classes=n_classes, learning_rate=args.learning_rate)
    model.summary()

    print("\n=== TRAINING STARTED ===")

    # Stop early when validation loss stalls, restoring the best weights seen.
    stopper = keras.callbacks.EarlyStopping(
        patience=args.patience,
        restore_best_weights=True,
    )

    model.fit(
        training_data,
        validation_data=validation_data,
        epochs=args.epochs,
        callbacks=[stopper],
    )

    print("\n=== TRAINING COMPLETED ===")

    os.makedirs(args.output_folder, exist_ok=True)
    save_path = os.path.join(args.output_folder, f"{args.model_name}.h5")

    print(f"Saving model to: {save_path}")
    model.save(save_path)

    print("\n✔ Model saved successfully!")
    print("✔ Training finished!")


if __name__ == "__main__":
    main()