import os
import time
import urllib.request as request
from pathlib import Path
from zipfile import ZipFile

import tensorflow as tf

from cnnClassfier.entity.config_entity import TrainingConfig


class Training:
    """Train a saved Keras model on an image directory and save the result.

    ``train()`` reads ``self.model`` and the two generators, so the expected
    call order is ``get_base_model()``, then ``train_valid_generator()``,
    then ``train(callback_list)``.
    """

    def __init__(self, config: TrainingConfig):
        """Store the training configuration.

        Args:
            config: Provides model paths, data directory and hyper-parameters
                (``params_*`` attributes) used by the other methods.
        """
        self.config = config

    def get_base_model(self):
        """Load the model at ``config.updated_base_model_path`` into ``self.model``."""
        self.model = tf.keras.models.load_model(
            self.config.updated_base_model_path
        )

    def train_valid_generator(self):
        """Build ``self.train_generator`` and ``self.valid_generator``.

        Both generators read ``config.training_data``. ``validation_split=0.20``
        reserves 20% of the images for the "validation" subset and leaves the
        rest for the "training" subset. Pixel values are rescaled by 1/255.
        """
        datagenerator_kwargs = dict(
            rescale=1.0 / 255,
            validation_split=0.20,
        )

        dataflow_kwargs = dict(
            # Drops the last element of params_image_size; presumably the
            # config is (height, width, channels) -- TODO confirm.
            target_size=self.config.params_image_size[:-1],
            batch_size=self.config.params_batch_size,
            interpolation="bilinear",
        )

        valid_datagenerator = tf.keras.preprocessing.image.ImageDataGenerator(
            **datagenerator_kwargs
        )

        # Validation order is kept fixed (shuffle=False) so metrics are
        # comparable between epochs.
        self.valid_generator = valid_datagenerator.flow_from_directory(
            directory=self.config.training_data,
            subset="validation",
            shuffle=False,
            **dataflow_kwargs,
        )

        if self.config.params_is_augmentation:
            # Augmentation is applied to the training subset only.
            train_datagenerator = tf.keras.preprocessing.image.ImageDataGenerator(
                rotation_range=40,
                horizontal_flip=True,
                width_shift_range=0.2,
                height_shift_range=0.2,
                shear_range=0.2,
                zoom_range=0.2,
                **datagenerator_kwargs,
            )
        else:
            # Same rescale/split settings, so the plain generator is reused.
            train_datagenerator = valid_datagenerator

        self.train_generator = train_datagenerator.flow_from_directory(
            directory=self.config.training_data,
            subset="training",
            shuffle=True,
            **dataflow_kwargs,
        )

    @staticmethod
    def save_model(path: Path, model: tf.keras.Model):
        """Save ``model`` to ``path`` via ``Model.save``."""
        model.save(path)

    @staticmethod
    def _steps_for(iterator) -> int:
        """Return the number of batches needed to visit every sample once.

        Uses integer ceiling division: floor division would drop the trailing
        partial batch and would give 0 steps when a subset holds fewer
        samples than one batch.
        """
        return (iterator.samples + iterator.batch_size - 1) // iterator.batch_size

    def train(self, callback_list: list):
        """Fit ``self.model`` on the training subset, then save it.

        Args:
            callback_list: Keras callbacks forwarded to ``Model.fit``.

        Side effects:
            Sets ``self.steps_per_epoch`` and ``self.validation_steps`` and
            writes the trained model to ``config.trained_model_path``.
        """
        self.steps_per_epoch = self._steps_for(self.train_generator)
        self.validation_steps = self._steps_for(self.valid_generator)

        self.model.fit(
            self.train_generator,
            epochs=self.config.params_epochs,
            steps_per_epoch=self.steps_per_epoch,
            validation_steps=self.validation_steps,
            validation_data=self.valid_generator,
            callbacks=callback_list,
        )

        self.save_model(
            path=self.config.trained_model_path,
            model=self.model,
        )