# model.py: last updated in commit f8fa4fd ("Update model.py") by dpatel9923.
from tensorflow.keras.applications import VGG16
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Dropout
from tensorflow.keras.optimizers import Adam
def build_model(
    input_shape,
    num_classes,
    *,
    weights="imagenet",
    dense_units=1024,
    dropout_rate=0.5,
    learning_rate=1e-4,
):
    """Build and compile a VGG16-based transfer-learning classifier.

    A VGG16 convolutional base (loaded without its top classifier) is frozen
    and topped with a new trainable head:
    GlobalAveragePooling2D -> Dense(relu) -> Dropout -> Dense(softmax).

    The model contains no preprocessing layer. Presumably callers apply
    ``tensorflow.keras.applications.vgg16.preprocess_input`` to their images
    before training/inference; verify against the data pipeline.

    Args:
        input_shape: Image input shape forwarded to ``VGG16``
            (e.g. ``(224, 224, 3)``). Per the Keras documentation, with
            ``include_top=False`` it must have exactly 3 channels and spatial
            dimensions of at least 32.
        num_classes: Number of output classes; must be at least 2.
        weights: Initial weights for the VGG16 base: ``"imagenet"`` (default,
            as before), ``None`` for random initialisation, or a path to a
            weights file.
        dense_units: Width of the hidden Dense layer in the head
            (default 1024).
        dropout_rate: Dropout rate applied after the hidden Dense layer
            (default 0.5).
        learning_rate: Learning rate for the Adam optimizer (default 1e-4).

    Returns:
        A compiled Keras ``Model`` mapping images to class probabilities,
        trained with categorical cross-entropy and reporting accuracy.

    Raises:
        ValueError: If ``num_classes`` is less than 2.
    """
    # A single-unit softmax always outputs 1.0, so the categorical loss is
    # constant and the model can never learn; reject this configuration early.
    if num_classes < 2:
        raise ValueError(f"num_classes must be at least 2, got {num_classes!r}")

    # Pretrained convolutional base without the original ImageNet classifier.
    base_model = VGG16(weights=weights, include_top=False, input_shape=input_shape)

    # New classification head on top of the base's feature maps.
    x = GlobalAveragePooling2D()(base_model.output)
    x = Dense(dense_units, activation="relu")(x)
    x = Dropout(dropout_rate)(x)
    predictions = Dense(num_classes, activation="softmax")(x)

    model = Model(inputs=base_model.input, outputs=predictions)

    # Freeze every layer of the pretrained base; only the head layers created
    # above remain trainable. This must be done before compile() to take effect.
    for layer in base_model.layers:
        layer.trainable = False

    # categorical_crossentropy expects one-hot encoded labels; integer class
    # ids would require sparse_categorical_crossentropy instead.
    model.compile(
        optimizer=Adam(learning_rate=learning_rate),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model