# (Page-header residue from the hosting site: "Spaces: Sleeping"; file size 2,985 bytes.)
import os

import tensorflow as tf
import keras
from keras.applications import MobileNetV3Small, mobilenet_v3
from keras import layers, models
from keras.utils import image_dataset_from_directory
import matplotlib.pyplot as plt
def preprocess(x):
    """Return *x* passed through MobileNetV3's ``preprocess_input``.

    Thin wrapper around ``keras.applications.mobilenet_v3.preprocess_input``.
    """
    prepared = mobilenet_v3.preprocess_input(x)
    return prepared
# Dataset root and training hyper-parameters.
path = "/media/data/plants_diseases_dataset/"
# Split directories are derived from ``path`` (which ends with "/") so the
# two locations cannot drift apart.
train_dir = path + "train"
val_dir = path + "valid"
img_size = (224, 224)
batch_size = 32
# Model input shape = image size plus 3 channels; derived from ``img_size``
# so changing the resolution updates both consistently.
INPUT_SHAPE = (*img_size, 3)
num_epochs = 10
print(INPUT_SHAPE)
def get_datasets(path):
    """Build batched, preprocessed train/validation pipelines under *path*.

    Reads ``<path>/train`` and ``<path>/valid`` with
    ``image_dataset_from_directory`` (labels inferred from sub-directory
    names, one-hot encoded), resized to the module-level ``img_size`` and
    batched with the module-level ``batch_size``.

    Args:
        path: Dataset root directory; a trailing separator is tolerated.

    Returns:
        ``(train_ds, val_ds, num_classes)``: the two ``tf.data`` pipelines,
        each yielding ``(images, one_hot_labels)`` batches, and the number of
        classes found under the ``train`` directory.
    """
    # os.path.join avoids the "//train" that string concatenation produced
    # when ``path`` already ends with "/".
    train_dir = os.path.join(path, "train")
    val_dir = os.path.join(path, "valid")

    train_ds = image_dataset_from_directory(
        train_dir,
        image_size=img_size,
        batch_size=batch_size,
        label_mode="categorical",
    )
    class_names = train_ds.class_names

    # Pin the validation split to the training class list so both splits
    # share the same class -> one-hot index mapping. Keras raises instead of
    # silently misaligning labels when a listed class has no matching
    # sub-directory. Evaluation does not need shuffling, so keep it stable.
    val_ds = image_dataset_from_directory(
        val_dir,
        image_size=img_size,
        batch_size=batch_size,
        label_mode="categorical",
        class_names=class_names,
        shuffle=False,
    )
    num_classes = len(class_names)

    AUTOTUNE = tf.data.AUTOTUNE

    def _apply_preprocessing(images, labels):
        # Keras documents mobilenet_v3.preprocess_input as a pass-through
        # (the model rescales internally); kept for API symmetry.
        return mobilenet_v3.preprocess_input(images), labels

    train_ds = train_ds.map(_apply_preprocessing, num_parallel_calls=AUTOTUNE)
    val_ds = val_ds.map(_apply_preprocessing, num_parallel_calls=AUTOTUNE)
    train_ds = train_ds.prefetch(buffer_size=AUTOTUNE)
    val_ds = val_ds.prefetch(buffer_size=AUTOTUNE)
    return train_ds, val_ds, num_classes
def create_MobileNet(INPUT_SHAPE, NUM_CLASSES, *, learning_rate=0.001,
                     dropout_rate=0.5, freeze_base=False):
    """Build and compile a MobileNetV3-Small classifier.

    The ImageNet-pretrained backbone (``include_top=False``) is followed by
    global average pooling, dropout, and a softmax ``Dense`` layer with
    ``NUM_CLASSES`` units. Compiled with Adam, categorical cross-entropy and
    accuracy.

    Args:
        INPUT_SHAPE: Input image shape passed to the backbone and ``Input``.
        NUM_CLASSES: Number of output classes.
        learning_rate: Adam learning rate (default 0.001, the previous value).
        dropout_rate: Dropout rate before the classifier (default 0.5).
        freeze_base: If True, the backbone's weights are not trained and only
            the new head learns. Default False trains everything, as before.

    Returns:
        The compiled ``keras.Sequential`` model.
    """
    base_model = MobileNetV3Small(
        input_shape=INPUT_SHAPE,
        include_top=False,
        weights="imagenet",
    )
    # Layers are trainable by default, so freeze_base=False keeps the old behaviour.
    base_model.trainable = not freeze_base

    model = models.Sequential([
        keras.Input(shape=INPUT_SHAPE),
        base_model,
        layers.GlobalAveragePooling2D(),
        layers.Dropout(dropout_rate),
        layers.Dense(NUM_CLASSES, activation="softmax"),
    ])
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model
def plot_hist(history):
    """Show loss and accuracy curves (train vs. validation) side by side.

    Args:
        history: Object returned by ``model.fit``; its ``history`` dict must
            hold the "loss", "val_loss", "accuracy" and "val_accuracy" series.
    """
    # One row per subplot:
    # (train key, val key, train label, val label, title, y-axis label).
    panels = (
        ("loss", "val_loss", "Train Loss", "Val Loss",
         "Training and Validation Loss", "Loss"),
        ("accuracy", "val_accuracy", "Train Accuracy", "Val Accuracy",
         "Training and Validation Accuracy", "Accuracy"),
    )
    series = history.history

    plt.figure(figsize=(12, 5))
    for slot, panel in enumerate(panels, start=1):
        train_key, val_key, train_label, val_label, title, y_label = panel
        plt.subplot(1, 2, slot)
        plt.plot(series[train_key], label=train_label)
        plt.plot(series[val_key], label=val_label)
        plt.title(title)
        plt.xlabel("Epochs")
        plt.ylabel(y_label)
        plt.legend()
    plt.tight_layout()
    plt.show()
def main():
    """Train the classifier, save it, and plot its training curves.

    Loads data from the module-level ``path`` and builds the model for
    ``INPUT_SHAPE``. Trains for ``num_epochs`` epochs and writes the model to
    ``models/mobileNet_<num_epochs>.keras``, creating ``models/`` if needed.
    """
    train_ds, val_ds, num_classes = get_datasets(path)
    model = create_MobileNet(INPUT_SHAPE, num_classes)
    history = model.fit(train_ds, validation_data=val_ds, epochs=num_epochs)
    # The .keras save path is opened directly and its parent directory is not
    # created for us; without this, a fresh checkout crashes here, after the
    # whole training run has already finished.
    os.makedirs("models", exist_ok=True)
    model.save(f"models/mobileNet_{num_epochs}.keras")
    plot_hist(history)


if __name__ == "__main__":
    main()