Spaces:
Sleeping
Sleeping
File size: 2,063 Bytes
3e93e14 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 | import os
import urllib.request as request
from pathlib import Path
import tensorflow as tf
from cnnClassifier.entity.config_entity import PrepareBaseModelConfig
class PrepareBaseModel:
    """Build a VGG16 backbone, attach a new classification head, and save both.

    Intended order of use:
        1. ``get_base_model()`` creates the VGG16 backbone, keeps it on
           ``self.model`` and saves it to ``config.base_model_path``.
        2. ``update_base_model()`` adds a Flatten + Dense(softmax) head,
           compiles the result, keeps it on ``self.full_model`` and saves it
           to ``config.updated_base_model_path``.
    """

    def __init__(self, config: PrepareBaseModelConfig):
        """Store the configuration used by the build and save steps.

        Args:
            config: Must provide ``params_image_size``, ``params_weights``,
                ``params_include_top``, ``params_classes``,
                ``params_learning_rate``, ``base_model_path`` and
                ``updated_base_model_path``.
        """
        self.config = config

    def get_base_model(self):
        """Create the VGG16 backbone from the config and save it.

        The model is stored on ``self.model`` (required later by
        ``update_base_model``) and written to ``config.base_model_path``.
        """
        self.model = tf.keras.applications.VGG16(
            input_shape=self.config.params_image_size,
            weights=self.config.params_weights,
            include_top=self.config.params_include_top,
        )
        self.save_model(path=self.config.base_model_path, model=self.model)

    @staticmethod
    def _prepare_full_model(model, classes, freeze_all, freeze_till, learning_rate):
        """Attach a softmax classification head to *model* and compile it.

        Args:
            model: Keras backbone. Its ``output`` is flattened and fed to a
                new ``Dense(classes, activation="softmax")`` layer.
            classes: Number of output units of the new head.
            freeze_all: If true, every layer of *model* is made non-trainable.
            freeze_till: Used only when *freeze_all* is false. A positive int
                freezes every layer except the last ``freeze_till`` ones.
                ``None`` or a value <= 0 leaves the ``trainable`` flags
                unchanged.
            learning_rate: Learning rate for the SGD optimizer.

        Returns:
            The compiled ``tf.keras.models.Model``. Its inputs are
            ``model.input`` and its output is the softmax head.

        Side effects:
            Mutates ``layer.trainable`` on *model*'s layers in place, and
            prints the model summary.
        """
        if freeze_all:
            for layer in model.layers:
                layer.trainable = False
        elif freeze_till is not None and freeze_till > 0:
            # Negative slice end: keep the last `freeze_till` layers trainable
            # and freeze everything before them.
            for layer in model.layers[:-freeze_till]:
                layer.trainable = False

        # NOTE(review): with include_top=True the VGG16 output is already a
        # 1-D class vector, so Flatten is a no-op and a second softmax head is
        # stacked on top. This head presumably assumes include_top=False; confirm.
        flatten_in = tf.keras.layers.Flatten()(model.output)
        prediction = tf.keras.layers.Dense(
            units=classes,
            activation="softmax",
        )(flatten_in)

        full_model = tf.keras.models.Model(
            inputs=model.input,
            outputs=prediction,
        )

        # CategoricalCrossentropy expects one-hot encoded labels. The head
        # already applies softmax, so the default from_logits=False is correct.
        full_model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=learning_rate),
            loss=tf.keras.losses.CategoricalCrossentropy(),
            metrics=["accuracy"],
        )

        full_model.summary()
        return full_model

    def update_base_model(self, *, freeze_all=True, freeze_till=None):
        """Build the full classifier on top of ``self.model`` and save it.

        Args:
            freeze_all: Forwarded to ``_prepare_full_model``. The default
                (True) freezes the whole backbone, as before.
            freeze_till: Forwarded to ``_prepare_full_model``. Only used when
                *freeze_all* is false.

        Raises:
            AttributeError: If ``get_base_model()`` has not been called yet.
        """
        base_model = getattr(self, "model", None)
        if base_model is None:
            # Fail fast with an actionable message instead of an opaque
            # missing-attribute error deep inside the call.
            raise AttributeError(
                "PrepareBaseModel.update_base_model() requires a base model; "
                "call get_base_model() first"
            )

        self.full_model = self._prepare_full_model(
            model=base_model,
            classes=self.config.params_classes,
            freeze_all=freeze_all,
            freeze_till=freeze_till,
            learning_rate=self.config.params_learning_rate,
        )
        self.save_model(path=self.config.updated_base_model_path, model=self.full_model)

    @staticmethod
    def save_model(path: Path, model: tf.keras.Model):
        """Persist *model* to *path* using ``model.save``.

        NOTE(review): Keras 3 requires *path* to end in ``.keras`` or ``.h5``.
        Confirm the configured paths match the installed Keras version.
        """
        model.save(path)
|