|
|
|
|
|
|
|
|
import tensorflow as tf
|
|
|
from tensorflow.keras.applications import VGG16
|
|
|
from tensorflow.keras.preprocessing.image import ImageDataGenerator
|
|
|
from tensorflow.keras.models import Sequential
|
|
|
from tensorflow.keras.layers import Flatten, Dense, Dropout
|
|
|
|
|
|
|
|
|
# --- Run configuration -------------------------------------------------------
# Root folder read by flow_from_directory (expects one sub-directory per class).
DATA_DIR = 'dataset'

# (height, width) every image is resized to before entering the network.
IMAGE_SIZE = (224, 224)

# Images per batch and number of passes over the training subset.
BATCH_SIZE = 32
EPOCHS = 5
|
|
|
|
|
|
|
|
|
|
|
|
# --- Data pipeline -----------------------------------------------------------
# flow_from_directory derives class labels from the sub-directory names under
# DATA_DIR. class_mode='binary' yields a single 0/1 label per image, which
# pairs with the one-unit sigmoid output of the model.
VALIDATION_SPLIT = 0.2

# Training images: rescaled to [0, 1] and randomly augmented.
# NOTE(review): VGG16's ImageNet weights were trained on inputs prepared by
# tf.keras.applications.vgg16.preprocess_input, not on 1/255 scaling. The
# 1/255 scaling is kept because code that loads the saved model presumably
# applies the same scaling. Confirm before switching both sides.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    validation_split=VALIDATION_SPLIT,
)

# Validation images get the same rescaling but NO augmentation, so validation
# metrics reflect the images as they really are. Keras computes the
# training/validation split deterministically from the file listing. Giving
# this generator the same validation_split therefore yields the same
# partition as train_datagen.
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT,
)

train_generator = train_datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='training',
)

validation_generator = val_datagen.flow_from_directory(
    DATA_DIR,
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='validation',
    shuffle=False,  # evaluation order does not need to be random
)

# A binary head only makes sense for exactly two classes. With more
# sub-directories, 'binary' labels would be 0..k-1 and training would be
# silently meaningless.
if train_generator.num_classes != 2:
    raise ValueError(
        f"Expected exactly 2 class sub-directories in {DATA_DIR!r}, "
        f"found {train_generator.num_classes}: {sorted(train_generator.class_indices)}"
    )

print("Data Generators Ready!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# --- Model -------------------------------------------------------------------
# VGG16 convolutional base with ImageNet weights and without its original
# classifier (include_top=False). input_shape is derived from IMAGE_SIZE so
# the network always matches the size the data generators resize to. The
# trailing 3 is the channel count of the RGB images flow_from_directory
# produces by default.
base_model = VGG16(
    weights='imagenet',
    include_top=False,
    input_shape=(*IMAGE_SIZE, 3),
)

# Freeze the pretrained base: only the new head below receives weight updates.
base_model.trainable = False

# Binary-classification head stacked on the frozen features.
model = Sequential([
    base_model,
    Flatten(),
    Dense(512, activation='relu'),
    Dropout(0.5),
    Dense(1, activation='sigmoid'),  # one probability per image
])

model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

model.summary()
|
|
|
|
|
|
print("Model Training Started...")

# Fail fast with a clear message rather than handing Keras zero steps, which
# would break training or validation and the later 'val_accuracy' lookup.
if train_generator.samples == 0 or validation_generator.samples == 0:
    raise ValueError(
        f"No images found for training ({train_generator.samples}) or "
        f"validation ({validation_generator.samples}) under {DATA_DIR!r}; "
        "expected one sub-directory per class."
    )

# len(generator) is ceil(samples / batch_size), so every image, including the
# final partial batch, is used each epoch. Integer division
# (samples // BATCH_SIZE) would drop the remainder. It would also give 0 steps
# whenever there are fewer than BATCH_SIZE images.
history = model.fit(
    train_generator,
    steps_per_epoch=len(train_generator),
    epochs=EPOCHS,
    validation_data=validation_generator,
    validation_steps=len(validation_generator),
)

print("Model Training Complete!")
|
|
|
|
|
|
|
|
|
# --- Persist and report ------------------------------------------------------
# Write the trained model to disk (the .keras extension selects the Keras
# archive format).
MODEL_FILE_NAME = 'model.keras'
model.save(MODEL_FILE_NAME)
print(f"Model saved as: {MODEL_FILE_NAME}")

# Validation accuracy recorded for the last completed epoch.
val_accuracy_per_epoch = history.history['val_accuracy']
final_val_accuracy = val_accuracy_per_epoch[-1]
print(f"Final Validation Accuracy: {final_val_accuracy:.4f}")