Spaces:
Sleeping
Sleeping
| import argparse | |
| import os | |
| import tensorflow as tf | |
| from tensorflow import keras | |
| from tensorflow.keras import layers | |
| import pathlib | |
def load_dataset(folder_path, img_size=(180, 180), batch_size=32, shuffle=True, seed=None):
    """Load a labelled image dataset from a class-per-subdirectory folder.

    Thin wrapper around ``tf.keras.utils.image_dataset_from_directory`` with
    labels inferred from the subdirectory names and encoded as categorical
    (one-hot) vectors.

    Args:
        folder_path: Path (``str`` or path-like) to a directory whose
            subdirectories are the classes.
        img_size: ``(height, width)`` that every image is resized to.
        batch_size: Number of images per batch.
        shuffle: Whether to shuffle the data. Defaults to ``True`` (the
            original behaviour); ``False`` is usually preferable for
            evaluation data.
        seed: Optional seed for shuffling, for reproducible ordering.

    Returns:
        A batched ``tf.data.Dataset`` of ``(images, labels)`` that also
        exposes the inferred ``class_names`` attribute.

    Raises:
        ValueError: If ``folder_path`` does not exist or is not a directory.
    """
    folder_path = pathlib.Path(folder_path)
    if not folder_path.exists():
        raise ValueError(f"Folder does not exist: {folder_path}")
    # A regular file would pass the exists() check and then fail inside
    # TensorFlow with an unhelpful error, so reject it here instead.
    if not folder_path.is_dir():
        raise ValueError(f"Path is not a directory: {folder_path}")
    return tf.keras.utils.image_dataset_from_directory(
        str(folder_path),
        labels="inferred",
        label_mode="categorical",
        shuffle=shuffle,
        seed=seed,
        batch_size=batch_size,
        image_size=img_size,
    )
def build_model(input_shape=(180, 180, 3), num_classes=4, learning_rate=0.001):
    """Build and compile a simple CNN for Alzheimer's MRI classification.

    Architecture: input rescaling by 1/255, three Conv2D(3x3, ReLU) +
    MaxPooling2D stages with 32, 64 and 128 filters, then Flatten,
    Dense(128, ReLU), Dropout(0.3) and a softmax output layer.

    Args:
        input_shape: ``(height, width, channels)`` of each input image.
        num_classes: Number of output units in the softmax layer.
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        A compiled ``keras.Sequential`` model using categorical cross-entropy
        loss (so labels are expected one-hot) and accuracy as its metric.
    """
    model = keras.Sequential([
        # Declare the input with an explicit Input object; passing
        # ``input_shape=`` to an ordinary layer is deprecated in Keras 3.
        keras.Input(shape=input_shape),
        # Assumes raw pixel values in [0, 255]; scales them to [0, 1].
        layers.Rescaling(1.0 / 255),
        layers.Conv2D(32, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(64, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Conv2D(128, 3, activation="relu"),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation="relu"),
        layers.Dropout(0.3),
        layers.Dense(num_classes, activation="softmax"),
    ])
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model
| def main(): | |
| parser = argparse.ArgumentParser(description="Train CNN on Alzheimer's MRI Dataset") | |
| parser.add_argument("--training_folder", type=str, required=True) | |
| parser.add_argument("--testing_folder", type=str, required=True) | |
| parser.add_argument("--output_folder", type=str, required=True) | |
| parser.add_argument("--epochs", type=int, default=5) | |
| parser.add_argument("--learning_rate", type=float, default=0.01) | |
| parser.add_argument("--batch_size", type=int, default=32) | |
| parser.add_argument("--patience", type=int, default=10) | |
| parser.add_argument("--model_name", type=str, default="alzheimers-cnn") | |
| args = parser.parse_args() | |
| print("\n=== LOADING DATASETS ===") | |
| train_ds = load_dataset(args.training_folder, batch_size=args.batch_size) | |
| val_ds = load_dataset(args.testing_folder, batch_size=args.batch_size) | |
| class_names = train_ds.class_names | |
| num_classes = len(class_names) | |
| print(f"Detected classes: {class_names}") | |
| print(f"Number of classes: {num_classes}") | |
| print("\n=== BUILDING MODEL ===") | |
| model = build_model(num_classes=num_classes, learning_rate=args.learning_rate) | |
| model.summary() | |
| print("\n=== TRAINING STARTED ===") | |
| early_stop = keras.callbacks.EarlyStopping( | |
| patience=args.patience, | |
| restore_best_weights=True | |
| ) | |
| history = model.fit( | |
| train_ds, | |
| validation_data=val_ds, | |
| epochs=args.epochs, | |
| callbacks=[early_stop] | |
| ) | |
| print("\n=== TRAINING COMPLETED ===") | |
| os.makedirs(args.output_folder, exist_ok=True) | |
| output_path = os.path.join(args.output_folder, f"{args.model_name}.h5") | |
| print(f"Saving model to: {output_path}") | |
| model.save(output_path) | |
| print("\n✔ Model saved successfully!") | |
| print("✔ Training finished!") | |
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()