# Uploaded by Alessia2004 via huggingface_hub (commit ae51a24, verified)
import argparse
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import pathlib
def load_dataset(folder_path, img_size=(180, 180), batch_size=32):
    """Load an image-classification dataset from a directory tree.

    Class labels are inferred from the sub-directory names and returned
    one-hot encoded (categorical). Samples are shuffled and resized to
    *img_size*.

    Raises:
        ValueError: if *folder_path* does not exist.
    """
    root = pathlib.Path(folder_path)
    if not root.exists():
        raise ValueError(f"Folder does not exist: {root}")
    return tf.keras.utils.image_dataset_from_directory(
        root,
        labels='inferred',
        label_mode='categorical',
        shuffle=True,
        batch_size=batch_size,
        image_size=img_size
    )
def build_model(input_shape=(180, 180, 3), num_classes=4, learning_rate=0.001):
    """Build and compile a small CNN for Alzheimer's MRI classification.

    Architecture: pixel rescaling to [0, 1], then three Conv2D/MaxPooling
    stages (32, 64, 128 filters), followed by a 128-unit dense head with
    dropout and a softmax output over *num_classes*. Compiled with Adam
    and categorical cross-entropy.
    """
    model = keras.Sequential()
    model.add(layers.Rescaling(1.0 / 255, input_shape=input_shape))
    # Three convolutional stages with doubling filter counts.
    for filters in (32, 64, 128):
        model.add(layers.Conv2D(filters, 3, activation="relu"))
        model.add(layers.MaxPooling2D())
    model.add(layers.Flatten())
    model.add(layers.Dense(128, activation="relu"))
    model.add(layers.Dropout(0.3))
    model.add(layers.Dense(num_classes, activation="softmax"))
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model
def main():
    """Parse CLI arguments, train the CNN on the MRI dataset, save the model.

    Expects --training_folder / --testing_folder directory trees whose
    sub-directory names are the class labels, and writes the trained model
    to <output_folder>/<model_name>.h5.
    """
    parser = argparse.ArgumentParser(description="Train CNN on Alzheimer's MRI Dataset")
    parser.add_argument("--training_folder", type=str, required=True)
    parser.add_argument("--testing_folder", type=str, required=True)
    parser.add_argument("--output_folder", type=str, required=True)
    parser.add_argument("--epochs", type=int, default=5)
    parser.add_argument("--learning_rate", type=float, default=0.01)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--patience", type=int, default=10)
    parser.add_argument("--model_name", type=str, default="alzheimers-cnn")
    args = parser.parse_args()

    print("\n=== LOADING DATASETS ===")
    train_ds = load_dataset(args.training_folder, batch_size=args.batch_size)
    val_ds = load_dataset(args.testing_folder, batch_size=args.batch_size)

    # Capture class names BEFORE applying tf.data transformations:
    # prefetch() returns a dataset that no longer exposes .class_names.
    class_names = train_ds.class_names
    num_classes = len(class_names)
    print(f"Detected classes: {class_names}")
    print(f"Number of classes: {num_classes}")

    # Overlap input-pipeline work with training steps (free throughput win).
    autotune = tf.data.AUTOTUNE
    train_ds = train_ds.prefetch(autotune)
    val_ds = val_ds.prefetch(autotune)

    print("\n=== BUILDING MODEL ===")
    model = build_model(num_classes=num_classes, learning_rate=args.learning_rate)
    model.summary()

    print("\n=== TRAINING STARTED ===")
    # Monitors val_loss by default and restores the best epoch's weights.
    early_stop = keras.callbacks.EarlyStopping(
        patience=args.patience,
        restore_best_weights=True
    )
    # Return value (History) was previously bound but never used — dropped.
    model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=args.epochs,
        callbacks=[early_stop]
    )
    print("\n=== TRAINING COMPLETED ===")

    os.makedirs(args.output_folder, exist_ok=True)
    # NOTE(review): .h5 is the legacy Keras save format; kept so existing
    # consumers of this script's output are unaffected.
    output_path = os.path.join(args.output_folder, f"{args.model_name}.h5")
    print(f"Saving model to: {output_path}")
    model.save(output_path)
    print("\n✔ Model saved successfully!")
    print("✔ Training finished!")
# Script entry point: run training only when executed directly, not on import.
if __name__ == "__main__":
    main()