GitHub Action
Automated deployment from GitHub Actions
f8f5549
import tensorflow as tf
from EmotionRecognition import logger
from EmotionRecognition.entity.config_entity import ModelTrainerConfig
from EmotionRecognition.utils.common import create_mobilenetv2_model
from pathlib import Path
class ModelTrainer:
    """Train an emotion-recognition classifier on top of a frozen MobileNetV2 base.

    Image folders come from ``config`` and hyper-parameters from ``params``.
    The model is built with ``create_mobilenetv2_model``, trained with
    ``Model.fit`` and written to ``config.trained_model_path``.
    """

    def __init__(self, config: "ModelTrainerConfig", params: dict):
        """Store configuration; the model itself is created by build_and_train_model().

        Args:
            config: Trainer paths; this class reads ``train_data_dir``,
                ``test_data_dir`` and ``trained_model_path`` from it.
            params: Hyper-parameters. Despite the ``dict`` annotation it is read
                with attribute access (``params.DATA_PARAMS``,
                ``params.TRAINING_PARAMS``), so it is presumably a
                ConfigBox-like object — confirm against the caller.
        """
        self.config = config
        self.params = params
        # Set by build_and_train_model(); None until a model has been built.
        self.model = None

    @staticmethod
    def _build_input_shape(image_size, channels):
        """Return the model input shape ``[*image_size, channels]`` as a plain list.

        ``image_size`` may be a list, a tuple or a config list type; converting
        it with ``list`` avoids the ``TypeError`` raised by ``tuple + list``.
        """
        return list(image_size) + [channels]

    @staticmethod
    def _to_rgb_unit_range(image, label):
        """Expand a 1-channel image batch to 3 channels and divide pixels by 255."""
        # The datasets are loaded with color_mode='grayscale' (1 channel), while
        # the model consumes 3-channel input.
        image = tf.image.grayscale_to_rgb(image)
        # NOTE(review): this maps 0-255 pixels to [0, 1]; confirm that
        # create_mobilenetv2_model expects this range rather than [-1, 1].
        image = tf.cast(image, tf.float32) / 255.0
        return image, label

    def _load_split(self, directory, shuffle):
        """Load one image folder as batches of grayscale images with one-hot labels."""
        data_params = self.params.DATA_PARAMS
        return tf.keras.utils.image_dataset_from_directory(
            directory,
            labels='inferred',
            label_mode='categorical',
            class_names=data_params.CLASSES,
            image_size=data_params.IMAGE_SIZE,
            interpolation='nearest',
            batch_size=data_params.BATCH_SIZE,
            shuffle=shuffle,
            color_mode='grayscale',
        )

    def get_datasets(self):
        """Build the training and validation ``tf.data`` pipelines.

        Returns:
            ``(train_ds, val_ds)``: prefetched datasets of
            ``(images, one_hot_labels)`` batches. Images are converted to
            3 channels and divided by 255. Only ``train_ds`` is shuffled and
            augmented.
        """
        logger.info("Loading prepared train and test datasets...")
        train_ds = self._load_split(self.config.train_data_dir, shuffle=True)
        # Validation data is read in a fixed (unshuffled) order.
        val_ds = self._load_split(self.config.test_data_dir, shuffle=False)

        data_augmentation = tf.keras.Sequential([
            tf.keras.layers.RandomFlip("horizontal"),
            tf.keras.layers.RandomRotation(0.1),
            tf.keras.layers.RandomZoom(0.1),
        ])

        autotune = tf.data.AUTOTUNE
        train_ds = train_ds.map(self._to_rgb_unit_range, num_parallel_calls=autotune)
        val_ds = val_ds.map(self._to_rgb_unit_range, num_parallel_calls=autotune)
        # training=True keeps the random augmentation layers active inside the
        # tf.data pipeline; augmentation is applied to training data only.
        train_ds = train_ds.map(
            lambda x, y: (data_augmentation(x, training=True), y),
            num_parallel_calls=autotune,
        )
        return train_ds.prefetch(autotune), val_ds.prefetch(autotune)

    def build_and_train_model(self):
        """Build the model, freeze its pretrained base, train it, then save it.

        Side effects: sets ``self.model`` and writes the trained model to
        ``config.trained_model_path`` via save_model().
        """
        data_params = self.params.DATA_PARAMS
        training_params = self.params.TRAINING_PARAMS
        logger.info("Building model with a frozen MobileNetV2 base...")

        if data_params.CHANNELS != 3:
            # get_datasets() always yields 3-channel images (grayscale_to_rgb).
            logger.warning(
                f"DATA_PARAMS.CHANNELS is {data_params.CHANNELS}, but the input "
                "pipeline produces 3-channel images; the model input shape will "
                "not match the data."
            )
        input_shape = self._build_input_shape(data_params.IMAGE_SIZE, data_params.CHANNELS)
        self.model = create_mobilenetv2_model(
            input_shape=input_shape,
            num_classes=data_params.NUM_CLASSES,
            dropout_rate=training_params.DROPOUT_RATE
        )

        # NOTE(review): assumes create_mobilenetv2_model places the MobileNetV2
        # backbone at layers[1] (right after the input layer) — confirm.
        base_model = self.model.layers[1]
        base_model.trainable = False
        # Compile after changing `trainable` so the frozen state takes effect.
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=training_params.LEARNING_RATE),
            loss=training_params.LOSS_FUNCTION,
            metrics=training_params.METRICS
        )
        self.model.summary(print_fn=logger.info)

        train_ds, val_ds = self.get_datasets()
        logger.info(f"--- Starting training for {training_params.EPOCHS} epochs ---")
        self.model.fit(
            train_ds,
            epochs=training_params.EPOCHS,
            validation_data=val_ds
        )
        self.save_model()

    def save_model(self):
        """Save the full model to ``config.trained_model_path``.

        Raises:
            RuntimeError: if no model has been built yet.
        """
        if self.model is None:
            raise RuntimeError(
                "No model to save; call build_and_train_model() first."
            )
        model_path = str(self.config.trained_model_path)
        self.model.save(model_path)
        logger.info(f"Full model saved successfully to: {model_path}")