File size: 2,177 Bytes
05f7b3b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
import os
import urllib.request as request
from zipfile import ZipFile
import tensorflow as tf
from cnnClassfier.entity.config_entity import PrepareBaseModelConfig
from pathlib import Path
class PrepareBaseModel:
    """Build, adapt, and persist the VGG16-based model.

    Intended call order: ``get_base_model()`` first (creates ``self.model``
    and saves it to ``config.base_model_path``), then ``update_base_model()``
    (adds a classification head, compiles, and saves the result to
    ``config.updated_base_model_path``).
    """

    def __init__(self, config: PrepareBaseModelConfig):
        """Store the configuration used by the other methods.

        Args:
            config: Provides ``params_image_size``, ``params_weights``,
                ``params_include_top``, ``params_classes``,
                ``params_learning_rate``, ``base_model_path`` and
                ``updated_base_model_path``.
        """
        self.config = config

    def get_base_model(self):
        """Instantiate VGG16 from the config, keep it on ``self.model``, and save it."""
        self.model = tf.keras.applications.vgg16.VGG16(
            input_shape=self.config.params_image_size,
            weights=self.config.params_weights,
            include_top=self.config.params_include_top,
        )
        self.save_model(path=self.config.base_model_path, model=self.model)

    @staticmethod
    def prepare_full_model(model, classes, freeze_all, freeze_till, learinig_rate):
        """Attach a softmax classification head to ``model`` and compile it.

        Args:
            model: Keras model whose output is flattened and fed to the new head.
            classes: Number of output units of the final softmax ``Dense`` layer.
            freeze_all: If true, mark every layer of ``model`` non-trainable.
            freeze_till: When ``freeze_all`` is false and this is a positive int,
                freeze every layer except the last ``freeze_till`` layers.
                ``None`` or a non-positive value leaves all layers trainable.
            learinig_rate: Learning rate for the SGD optimizer. The misspelled
                name is kept for backward compatibility with keyword callers.

        Returns:
            The compiled ``tf.keras.Model`` mapping ``model.input`` to class
            probabilities.
        """
        if freeze_all:
            for layer in model.layers:
                layer.trainable = False
        elif freeze_till is not None and freeze_till > 0:
            # Keep only the last `freeze_till` layers trainable (fine-tuning).
            for layer in model.layers[:-freeze_till]:
                layer.trainable = False

        flatten_in = tf.keras.layers.Flatten()(model.output)
        prediction = tf.keras.layers.Dense(
            units=classes,
            activation="softmax",
        )(flatten_in)

        full_model = tf.keras.models.Model(
            inputs=model.input,
            outputs=prediction,
        )
        full_model.compile(
            # `lr` is deprecated/removed in current Keras; `learning_rate` is the
            # supported keyword.
            optimizer=tf.keras.optimizers.SGD(learning_rate=learinig_rate),
            # Categorical cross-entropy expects one-hot encoded labels.
            loss=tf.keras.losses.CategoricalCrossentropy(),
            metrics=["accuracy"],
        )
        full_model.summary()
        return full_model

    def update_base_model(self):
        """Build the full model from ``self.model`` with all base layers frozen, then save it.

        Raises:
            AttributeError: If ``get_base_model()`` has not been called yet.
        """
        if not hasattr(self, "model"):
            raise AttributeError(
                "PrepareBaseModel.model is not set; call get_base_model() "
                "before update_base_model()."
            )
        self.full_model = self.prepare_full_model(
            model=self.model,
            classes=self.config.params_classes,
            freeze_all=True,
            freeze_till=None,
            learinig_rate=self.config.params_learning_rate,
        )
        self.save_model(path=self.config.updated_base_model_path, model=self.full_model)

    @staticmethod
    def save_model(path: Path, model: tf.keras.Model):
        """Persist ``model`` to ``path`` via ``model.save``."""
        model.save(path)
|