Spaces:
Sleeping
Sleeping
| import tensorflow as tf | |
| import keras | |
| from keras.applications import MobileNetV3Small, mobilenet_v3 | |
| from keras import layers,models | |
| from keras.utils import image_dataset_from_directory | |
| import matplotlib.pyplot as plt | |
def preprocess(x):
    """Return ``x`` transformed by ``mobilenet_v3.preprocess_input``.

    NOTE(review): per the Keras docs this is a pass-through for MobileNetV3,
    whose model carries its own rescaling when ``include_preprocessing=True``
    (the default) -- confirm against the installed Keras version.
    """
    transformed = mobilenet_v3.preprocess_input(x)
    return transformed
# Dataset root directory; get_datasets() appends the "train" and "valid"
# sub-directories to it.
path = "/media/data/plants_diseases_dataset/"
# Split directories under `path`. main() does not read these (get_datasets()
# builds its own from `path`); they are derived from `path` so the root is
# spelled out only once.
train_dir = path.rstrip("/") + "/train"
val_dir = path.rstrip("/") + "/valid"
# Height/width images are resized to when loaded.
img_size = (224, 224)
batch_size = 32
# Model input shape: the loader's image size plus 3 channels. Derived from
# img_size so the loader and the model can never disagree on resolution.
INPUT_SHAPE = (*img_size, 3)
num_epochs = 10
print(INPUT_SHAPE)
def get_datasets(path):
    """Build the training and validation input pipelines from a dataset root.

    ``path`` must contain ``train`` and ``valid`` sub-directories, each with
    one sub-folder per class (labels are inferred from the folder names).
    Images are resized to the module-level ``img_size`` and batched by the
    module-level ``batch_size``; labels are one-hot
    (``label_mode="categorical"``).

    Args:
        path: Dataset root directory (with or without a trailing separator).

    Returns:
        ``(train_ds, val_ds, num_classes)``: both datasets yield
        ``(images, one_hot_labels)`` batches passed through ``preprocess`` and
        prefetched; ``num_classes`` is the number of training classes.

    Raises:
        ValueError: (from Keras) if the ``valid`` class folders do not match
            the training classes.
    """
    import os

    # os.path.join avoids the "//" that string concatenation produces when
    # `path` already ends with a separator (as the module default does).
    train_dir = os.path.join(path, "train")
    val_dir = os.path.join(path, "valid")

    train_ds = image_dataset_from_directory(
        train_dir,
        image_size=img_size,
        batch_size=batch_size,
        label_mode="categorical",
    )
    class_names = train_ds.class_names
    # Pin the validation label order (and one-hot width) to the training
    # classes. Without this each split infers its own alphabetical mapping,
    # so a missing/extra folder in "valid" would silently misalign labels.
    val_ds = image_dataset_from_directory(
        val_dir,
        image_size=img_size,
        batch_size=batch_size,
        label_mode="categorical",
        class_names=class_names,
    )
    num_classes = len(class_names)

    autotune = tf.data.AUTOTUNE
    train_ds = train_ds.map(
        lambda x, y: (preprocess(x), y), num_parallel_calls=autotune
    ).prefetch(buffer_size=autotune)
    val_ds = val_ds.map(
        lambda x, y: (preprocess(x), y), num_parallel_calls=autotune
    ).prefetch(buffer_size=autotune)
    return train_ds, val_ds, num_classes
def create_MobileNet(
    INPUT_SHAPE,
    NUM_CLASSES,
    *,
    learning_rate=0.001,
    dropout_rate=0.5,
    freeze_base=False,
):
    """Build and compile an ImageNet-initialised MobileNetV3-Small classifier.

    Architecture: MobileNetV3Small backbone (no top) -> global average
    pooling -> dropout -> dense softmax over ``NUM_CLASSES``.

    Args:
        INPUT_SHAPE: Image input shape, e.g. ``(224, 224, 3)``.
        NUM_CLASSES: Number of output classes (width of the softmax layer).
        learning_rate: Adam learning rate (default 0.001, as before).
        dropout_rate: Dropout rate before the classifier (default 0.5).
        freeze_base: If True, mark the pretrained backbone non-trainable so
            only the new head is trained. The default False keeps the
            original behaviour of fine-tuning the whole network.

    Returns:
        A compiled ``keras.Sequential`` model using categorical cross-entropy
        (expects one-hot labels) and tracking accuracy.
    """
    base_model = MobileNetV3Small(
        input_shape=INPUT_SHAPE,
        include_top=False,
        weights="imagenet",
    )
    # Freezing is the usual first stage of transfer learning; with a frozen
    # backbone only the pooling/dropout/dense head receives gradient updates.
    base_model.trainable = not freeze_base

    model = models.Sequential([
        keras.Input(shape=INPUT_SHAPE),
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dropout(dropout_rate),
        layers.Dense(NUM_CLASSES, activation="softmax"),
    ])
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model
def plot_hist(history):
    """Display loss and accuracy curves from a training history side by side.

    The left panel plots training vs. validation loss and the right panel
    training vs. validation accuracy. It indexes ``history.history`` with the
    keys ``loss``, ``val_loss``, ``accuracy`` and ``val_accuracy``, so each of
    those series must be present. Finishes with ``plt.show()``.
    """
    # (train key, train label, val key, val label, title, y-axis label)
    panels = (
        ("loss", "Train Loss", "val_loss", "Val Loss",
         "Training and Validation Loss", "Loss"),
        ("accuracy", "Train Accuracy", "val_accuracy", "Val Accuracy",
         "Training and Validation Accuracy", "Accuracy"),
    )
    curves = history.history

    plt.figure(figsize=(12, 5))
    for position, panel in enumerate(panels, start=1):
        train_key, train_label, val_key, val_label, title, y_label = panel
        plt.subplot(1, 2, position)
        plt.plot(curves[train_key], label=train_label)
        plt.plot(curves[val_key], label=val_label)
        plt.title(title)
        plt.xlabel("Epochs")
        plt.ylabel(y_label)
        plt.legend()
    plt.tight_layout()
    plt.show()
def main():
    """Train MobileNetV3-Small on the dataset at ``path``, save it, plot curves.

    The trained model is written to ``models/mobileNet_<num_epochs>.keras``
    relative to the current working directory.
    """
    import os

    # model.save() to a .keras file does not create missing parent folders.
    # Create the output directory up front so a permissions problem surfaces
    # before training instead of losing a finished run at save time.
    os.makedirs("models", exist_ok=True)

    train_ds, val_ds, num_classes = get_datasets(path)
    model = create_MobileNet(INPUT_SHAPE, num_classes)
    history = model.fit(train_ds, validation_data=val_ds, epochs=num_epochs)
    model.save(f"models/mobileNet_{num_epochs}.keras")
    plot_hist(history)
# Start training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()