# Source: uploaded by khhamid to the Hugging Face Hub via huggingface_hub
# ("Upload folder using huggingface_hub", commit e2a99cd, verified).
import tensorflow as tf
import keras
from keras.applications import MobileNetV3Small, mobilenet_v3
from keras import layers,models
from keras.utils import image_dataset_from_directory
import matplotlib.pyplot as plt
def preprocess(x):
    """Return ``x`` after MobileNetV3 input preprocessing.

    Thin wrapper around ``keras.applications.mobilenet_v3.preprocess_input``
    so callers do not need to reference the Keras module directly.
    """
    prepared = mobilenet_v3.preprocess_input(x)
    return prepared
# Dataset locations. train_dir / val_dir are kept as module-level names for
# any external users; the split directories are also derivable from `path`.
train_dir = "/media/data/plants_diseases_dataset/train"
val_dir = "/media/data/plants_diseases_dataset/valid"
path = "/media/data/plants_diseases_dataset/"

# Image geometry and training hyper-parameters.
img_size = (224, 224)  # (height, width) images are resized to when loaded
batch_size = 32
# Derived from img_size so the model input shape and the dataset resize can
# never drift apart; 3 = RGB channels.
INPUT_SHAPE = (*img_size, 3)
num_epochs = 10
def get_datasets(path):
    """Build the training and validation ``tf.data`` pipelines.

    ``path`` must contain ``train`` and ``valid`` sub-directories, each with
    one sub-directory per class (the layout read by
    ``keras.utils.image_dataset_from_directory``).

    Args:
        path: Root directory of the dataset. A trailing separator is allowed.

    Returns:
        ``(train_ds, val_ds, num_classes)``. Both datasets yield
        ``(images, one_hot_labels)`` batches of ``batch_size`` images resized
        to ``img_size``, passed through ``preprocess`` and prefetched.
    """
    import os

    # os.path.join avoids the "//train" produced by string concatenation
    # when `path` already ends with a separator.
    train_dir = os.path.join(path, "train")
    val_dir = os.path.join(path, "valid")

    train_ds = image_dataset_from_directory(
        train_dir,
        image_size=img_size,
        batch_size=batch_size,
        label_mode="categorical",
    )
    class_names = train_ds.class_names

    # Pin the validation split to the training class order so one-hot index i
    # denotes the same class in both splits. Shuffling is unnecessary for
    # evaluation, so keep the validation order deterministic.
    val_ds = image_dataset_from_directory(
        val_dir,
        image_size=img_size,
        batch_size=batch_size,
        label_mode="categorical",
        class_names=class_names,
        shuffle=False,
    )
    num_classes = len(class_names)

    autotune = tf.data.AUTOTUNE
    train_ds = train_ds.map(
        lambda x, y: (preprocess(x), y), num_parallel_calls=autotune
    ).prefetch(buffer_size=autotune)
    val_ds = val_ds.map(
        lambda x, y: (preprocess(x), y), num_parallel_calls=autotune
    ).prefetch(buffer_size=autotune)
    return train_ds, val_ds, num_classes
def create_MobileNet(INPUT_SHAPE, NUM_CLASSES, *, learning_rate=0.001,
                     dropout_rate=0.5, freeze_base=False):
    """Build and compile a MobileNetV3-Small classifier on ImageNet weights.

    The model is: ImageNet-pretrained MobileNetV3Small backbone (no top) ->
    global average pooling -> dropout -> softmax dense layer, compiled with
    Adam and categorical cross-entropy (so labels must be one-hot).

    Args:
        INPUT_SHAPE: Input image shape, e.g. ``(224, 224, 3)``.
        NUM_CLASSES: Number of output classes.
        learning_rate: Adam learning rate (default keeps the original 0.001).
        dropout_rate: Dropout rate before the classifier (default 0.5).
        freeze_base: If True, the backbone's weights are not trained and only
            the new head learns. The default (False) fine-tunes the whole
            network, as the original code did.

    Returns:
        The compiled ``keras.Sequential`` model.
    """
    base_model = MobileNetV3Small(
        input_shape=INPUT_SHAPE,
        include_top=False,
        weights="imagenet",
    )
    if freeze_base:
        base_model.trainable = False

    model = models.Sequential([
        keras.Input(shape=INPUT_SHAPE),
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dropout(dropout_rate),
        layers.Dense(NUM_CLASSES, activation="softmax"),
    ])
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model
def plot_hist(history):
    """Plot loss and accuracy curves from a Keras ``History`` object.

    Draws two side-by-side panels (loss, accuracy) and then calls
    ``plt.show()``. Training curves are always drawn. Validation curves
    (``val_loss`` / ``val_accuracy``) are drawn only if they are present in
    ``history.history``, so a run without validation data does not raise
    ``KeyError``. Epochs on the x-axis are numbered from 1.
    """
    metrics = history.history
    epochs = range(1, len(metrics["loss"]) + 1)

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(epochs, metrics["loss"], label="Train Loss")
    if "val_loss" in metrics:
        plt.plot(epochs, metrics["val_loss"], label="Val Loss")
    plt.title("Training and Validation Loss")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(epochs, metrics["accuracy"], label="Train Accuracy")
    if "val_accuracy" in metrics:
        plt.plot(epochs, metrics["val_accuracy"], label="Val Accuracy")
    plt.title("Training and Validation Accuracy")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.legend()

    plt.tight_layout()
    plt.show()
def main():
    """Load the dataset, train MobileNetV3-Small, save the model, plot curves.

    The model is saved to ``models/mobileNet_<num_epochs>.keras``.
    """
    import os

    train_ds, val_ds, num_classes = get_datasets(path)
    model = create_MobileNet(INPUT_SHAPE, num_classes)

    # Create the output directory before training. Otherwise a missing
    # directory makes model.save fail only after training has finished,
    # and the trained weights are lost.
    save_dir = "models"
    os.makedirs(save_dir, exist_ok=True)

    history = model.fit(train_ds, validation_data=val_ds, epochs=num_epochs)
    model.save(f"{save_dir}/mobileNet_{num_epochs}.keras")
    plot_hist(history)
# Script entry point: start training only when run directly, not when the
# module is imported.
if __name__ == "__main__":
    main()