|
|
import os |
|
|
import urllib.request as request |
|
|
from zipfile import ZipFile |
|
|
import tensorflow as tf |
|
|
import time |
|
|
from cnnClassfier.entity.config_entity import TrainingConfig |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
class Training: |
|
|
def __init__(self, config: TrainingConfig): |
|
|
self.config = config |
|
|
|
|
|
def get_base_model(self): |
|
|
self.model = tf.keras.models.load_model( |
|
|
self.config.updated_base_model_path |
|
|
) |
|
|
|
|
|
def train_valid_generator(self): |
|
|
|
|
|
datagenerator_kwargs = dict( |
|
|
rescale = 1./255, |
|
|
validation_split=0.20 |
|
|
) |
|
|
|
|
|
dataflow_kwargs = dict( |
|
|
target_size=self.config.params_image_size[:-1], |
|
|
batch_size=self.config.params_batch_size, |
|
|
interpolation="bilinear" |
|
|
) |
|
|
|
|
|
valid_datagenerator = tf.keras.preprocessing.image.ImageDataGenerator( |
|
|
**datagenerator_kwargs |
|
|
) |
|
|
|
|
|
self.valid_generator = valid_datagenerator.flow_from_directory( |
|
|
directory=self.config.training_data, |
|
|
subset="validation", |
|
|
shuffle=False, |
|
|
**dataflow_kwargs |
|
|
) |
|
|
|
|
|
if self.config.params_is_augmentation: |
|
|
train_datagenerator = tf.keras.preprocessing.image.ImageDataGenerator( |
|
|
rotation_range=40, |
|
|
horizontal_flip=True, |
|
|
width_shift_range=0.2, |
|
|
height_shift_range=0.2, |
|
|
shear_range=0.2, |
|
|
zoom_range=0.2, |
|
|
**datagenerator_kwargs |
|
|
) |
|
|
else: |
|
|
train_datagenerator = valid_datagenerator |
|
|
|
|
|
self.train_generator = train_datagenerator.flow_from_directory( |
|
|
directory=self.config.training_data, |
|
|
subset="training", |
|
|
shuffle=True, |
|
|
**dataflow_kwargs |
|
|
) |
|
|
|
|
|
@staticmethod |
|
|
def save_model(path: Path, model: tf.keras.Model): |
|
|
model.save(path) |
|
|
|
|
|
|
|
|
def train(self, callback_list: list): |
|
|
self.steps_per_epoch = self.train_generator.samples // self.train_generator.batch_size |
|
|
self.validation_steps = self.valid_generator.samples // self.valid_generator.batch_size |
|
|
|
|
|
self.model.fit( |
|
|
self.train_generator, |
|
|
epochs=self.config.params_epochs, |
|
|
steps_per_epoch=self.steps_per_epoch, |
|
|
validation_steps=self.validation_steps, |
|
|
validation_data=self.valid_generator, |
|
|
callbacks=callback_list |
|
|
) |
|
|
|
|
|
self.save_model( |
|
|
path=self.config.trained_model_path, |
|
|
model=self.model |
|
|
) |
|
|
|
|
|
|
|
|
|