"""Train a simple CNN classifier on an Alzheimer's MRI image dataset.

The training and testing folders must use the layout expected by
``tf.keras.utils.image_dataset_from_directory``: one sub-directory per class.
"""

import argparse
import os
import pathlib

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers


def load_dataset(folder_path, img_size=(180, 180), batch_size=32, shuffle=True):
    """Load a labelled image dataset from a class-per-subfolder directory.

    Args:
        folder_path: Directory containing one sub-directory per class.
        img_size: ``(height, width)`` every image is resized to.
        batch_size: Number of images per batch.
        shuffle: Whether to shuffle file order. Defaults to ``True`` (the
            original, unconditional behaviour); pass ``False`` for evaluation
            data, where order does not affect the metrics.

    Returns:
        The batched dataset produced by
        ``tf.keras.utils.image_dataset_from_directory`` with one-hot
        (``"categorical"``) labels; it exposes a ``class_names`` attribute.

    Raises:
        ValueError: If ``folder_path`` does not exist or is not a directory.
    """
    folder_path = pathlib.Path(folder_path)
    if not folder_path.exists():
        raise ValueError(f"Folder does not exist: {folder_path}")
    if not folder_path.is_dir():
        raise ValueError(f"Path is not a folder: {folder_path}")

    return tf.keras.utils.image_dataset_from_directory(
        folder_path,
        labels="inferred",           # class = name of the sub-directory
        label_mode="categorical",    # one-hot, matches categorical_crossentropy
        shuffle=shuffle,
        batch_size=batch_size,
        image_size=img_size,
    )


def build_model(input_shape=(180, 180, 3), num_classes=4, learning_rate=0.001):
    """Build and compile a small CNN for multi-class image classification.

    Args:
        input_shape: ``(height, width, channels)`` of the input images; must
            agree with the ``img_size`` used in ``load_dataset``.
        num_classes: Number of output classes (softmax units).
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        A compiled ``keras.Sequential`` model that expects one-hot labels.
    """
    model = keras.Sequential([
        # Declare the input explicitly instead of passing the deprecated
        # ``input_shape=`` kwarg to the first layer (warns under Keras 3).
        keras.Input(shape=input_shape),
        # image_dataset_from_directory yields pixel values in [0, 255];
        # scale them to [0, 1].
        layers.Rescaling(1.0 / 255),
        layers.Conv2D(32, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(128, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dropout(0.3),
        layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model


def _parse_args():
    """Parse the command-line options for one training run."""
    parser = argparse.ArgumentParser(description="Train CNN on Alzheimer's MRI Dataset")
    parser.add_argument("--training_folder", type=str, required=True,
                        help="Folder with one sub-folder of training images per class.")
    parser.add_argument("--testing_folder", type=str, required=True,
                        help="Folder with one sub-folder of held-out images per class.")
    parser.add_argument("--output_folder", type=str, required=True,
                        help="Folder the trained model is written to.")
    parser.add_argument("--epochs", type=int, default=5)
    # NOTE(review): this default (0.01) is 10x build_model's default (0.001);
    # confirm it is intended, as 0.01 is high for Adam.
    parser.add_argument("--learning_rate", type=float, default=0.01)
    parser.add_argument("--batch_size", type=int, default=32)
    # NOTE(review): with the default 5 epochs, a patience of 10 means early
    # stopping can never trigger; only restore_best_weights has an effect.
    parser.add_argument("--patience", type=int, default=10)
    parser.add_argument("--model_name", type=str, default="alzheimers-cnn",
                        help="Base name of the saved model file (without extension).")
    return parser.parse_args()


def main():
    """Train the CNN and save it as ``<output_folder>/<model_name>.h5``.

    Raises:
        ValueError: If a dataset folder is missing, or the training and
            testing folders do not define the same classes.
    """
    args = _parse_args()

    print("\n=== LOADING DATASETS ===")
    train_ds = load_dataset(args.training_folder, batch_size=args.batch_size)
    # Evaluation order does not change the metrics, so skip shuffling.
    val_ds = load_dataset(args.testing_folder, batch_size=args.batch_size, shuffle=False)

    class_names = train_ds.class_names
    num_classes = len(class_names)
    print(f"Detected classes: {class_names}")
    print(f"Number of classes: {num_classes}")

    # Labels are one-hot indices into each folder's own class list, so both
    # folders must define identical classes. Otherwise validation labels are
    # silently wrong, or fail with a shape error mid-training.
    if val_ds.class_names != class_names:
        raise ValueError(
            f"Class mismatch between training folder {class_names} "
            f"and testing folder {val_ds.class_names}"
        )

    # Overlap input loading with training. This is done only after reading
    # ``class_names``, because the prefetched dataset no longer carries it.
    train_ds = train_ds.prefetch(tf.data.AUTOTUNE)
    val_ds = val_ds.prefetch(tf.data.AUTOTUNE)

    print("\n=== BUILDING MODEL ===")
    model = build_model(num_classes=num_classes, learning_rate=args.learning_rate)
    model.summary()

    print("\n=== TRAINING STARTED ===")
    # NOTE(review): the *testing* folder drives early stopping and best-weight
    # selection, so any accuracy later reported on it is optimistically biased.
    # Consider a separate validation split.
    early_stop = keras.callbacks.EarlyStopping(
        monitor="val_loss",  # Keras default, stated explicitly
        patience=args.patience,
        restore_best_weights=True,
    )
    model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=args.epochs,
        callbacks=[early_stop],
    )
    print("\n=== TRAINING COMPLETED ===")

    os.makedirs(args.output_folder, exist_ok=True)
    # The .h5 suffix selects the legacy HDF5 format; kept for compatibility
    # with existing consumers of this file.
    output_path = os.path.join(args.output_folder, f"{args.model_name}.h5")
    print(f"Saving model to: {output_path}")
    model.save(output_path)

    print("\n✔ Model saved successfully!")
    print("✔ Training finished!")


if __name__ == "__main__":
    main()