|
|
import tensorflow as tf |
|
|
import os |
|
|
import matplotlib.pyplot as plt |
|
|
from dotenv import load_dotenv |
|
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator |
|
|
|
|
|
# Load variables from a local .env file into the process environment at import
# time, before the configuration constants are read with os.getenv().
load_dotenv()
|
|
|
|
|
|
|
|
# Number of images per batch produced by the data generators.
BATCH_SIZE = 32

# Target size every image is resized to; also used for the model input shape.
IMG_SIZE = (224, 224)

# Root directory of the image dataset, taken from the environment.
# This is None when the TRAIN_DATASET variable is not defined.
TRAIN_DATASET = os.getenv("TRAIN_DATASET")

# Number of training epochs passed to model.fit().
EPOCHS = 8

# Optimizer and loss identifiers passed to model.compile().
OPTIMIZER = 'adam'
LOSS_FUNC = 'binary_crossentropy'
|
|
|
|
|
|
|
|
def load_data():
    """Build the training and validation image iterators.

    Both iterators read from the ``TRAIN_DATASET`` directory and use the same
    ``validation_split`` so they select complementary subsets of the files.
    Only the training images are augmented (horizontal flip and zoom); the
    validation images are merely rescaled, so validation metrics are measured
    on undistorted data.

    Returns:
        tuple: ``(train_data, val_data)`` as returned by
        ``ImageDataGenerator.flow_from_directory`` with ``class_mode="binary"``.

    Raises:
        ValueError: If ``TRAIN_DATASET`` is unset or empty.
    """
    if not TRAIN_DATASET:
        # Fail early with a readable message instead of letting a None path
        # reach Keras and surface as an unrelated error.
        raise ValueError(
            "TRAIN_DATASET is not set; define it in the environment or in a .env file"
        )

    validation_split = 0.2
    rescale = 1.0 / 255  # map pixel values from [0, 255] to [0, 1]

    # Augmentation is applied to the training subset only.
    train_datagen = ImageDataGenerator(
        validation_split=validation_split,
        rescale=rescale,
        horizontal_flip=True,
        zoom_range=0.2,
    )
    # Same split and rescaling, but no random transforms for validation.
    val_datagen = ImageDataGenerator(
        validation_split=validation_split,
        rescale=rescale,
    )

    # Arguments shared by both iterators.
    flow_kwargs = dict(
        directory=TRAIN_DATASET,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode="binary",
    )

    train_data = train_datagen.flow_from_directory(
        subset="training",
        shuffle=True,
        **flow_kwargs,
    )
    # Shuffling does not change validation metrics; a fixed order keeps
    # predictions aligned with the iterator's file/label listing.
    val_data = val_datagen.flow_from_directory(
        subset="validation",
        shuffle=False,
        **flow_kwargs,
    )

    return train_data, val_data
|
|
|
|
|
|
|
|
def build_model():
    """Build and compile the convolutional binary classifier.

    The network is three Conv2D/MaxPooling2D stages (32, 64 and 128 filters)
    followed by a 512-unit dense layer and a single sigmoid output unit. It
    is compiled with ``OPTIMIZER`` and ``LOSS_FUNC`` and tracks accuracy.

    Returns:
        tf.keras.Sequential: The compiled model, expecting inputs of shape
        ``(*IMG_SIZE, 3)``.
    """
    model = tf.keras.Sequential([
        # Declare the input explicitly rather than passing ``input_shape`` to
        # the first layer, which newer Keras releases warn about.
        tf.keras.Input(shape=(*IMG_SIZE, 3)),

        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),

        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),

        tf.keras.layers.Conv2D(128, (3, 3), activation='relu'),
        tf.keras.layers.MaxPooling2D(2, 2),

        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(512, activation='relu'),
        # One sigmoid unit to match the binary labels and binary cross-entropy.
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])

    model.compile(
        optimizer=OPTIMIZER,
        loss=LOSS_FUNC,
        metrics=['accuracy'],
    )

    return model
|
|
|
|
|
|
|
|
def _plot_history(history):
    """Plot training/validation accuracy and loss on separate axes."""
    metrics = history.history

    # Accuracy and loss live on different scales, so each gets its own axis
    # and its own title instead of sharing one plot.
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))

    acc_ax.plot(metrics['accuracy'], label='Train Accuracy')
    acc_ax.plot(metrics['val_accuracy'], label='Validation Accuracy')
    acc_ax.set_title('Training and Validation Accuracy')
    acc_ax.set_xlabel('Epoch')
    acc_ax.legend()

    loss_ax.plot(metrics['loss'], label='Train Loss')
    loss_ax.plot(metrics['val_loss'], label='Validation Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.legend()

    fig.tight_layout()
    plt.show()


def main():
    """Train the classifier, save it to disk and plot the training curves.

    Loads the data iterators, builds the model, fits it for ``EPOCHS`` epochs
    with validation, writes the trained model to ``cat_dog_model.h5`` and
    finally displays the accuracy and loss history.
    """
    train_data, val_data = load_data()
    model = build_model()

    history = model.fit(
        train_data,
        epochs=EPOCHS,
        validation_data=val_data,
    )

    model.save("cat_dog_model.h5")

    _plot_history(history)
|
|
|
|
|
# Run training only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|
|