Spaces:
Running
Running
| import tensorflow as tf | |
| from EmotionRecognition import logger | |
| from EmotionRecognition.entity.config_entity import ModelTrainerConfig | |
| from EmotionRecognition.utils.common import create_mobilenetv2_model | |
| from pathlib import Path | |
class ModelTrainer:
    """Build, train and save an image classifier on top of MobileNetV2.

    The trainer reads two image directories (train and test) named by
    ``config``, feeds them through a small preprocessing/augmentation
    pipeline, fits the model returned by ``create_mobilenetv2_model`` with
    its base frozen, and writes the trained model to
    ``config.trained_model_path``.
    """

    def __init__(self, config: ModelTrainerConfig, params: dict):
        """Store the configuration; the model itself is built later.

        Args:
            config: Supplies ``train_data_dir``, ``test_data_dir`` and
                ``trained_model_path``.
            params: Hyper-parameter container. NOTE: despite the ``dict``
                annotation it is read with attribute access
                (``params.DATA_PARAMS``, ``params.TRAINING_PARAMS``), so it
                must be an attribute-style mapping rather than a plain dict.
        """
        self.config = config
        self.params = params
        # Populated by build_and_train_model(); None until then.
        self.model = None

    def _load_image_dataset(self, directory, *, shuffle):
        """Load one grayscale, one-hot-labelled image dataset from *directory*.

        Shared by the train and validation loaders so that both are always
        read with the same class order, image size and batch size.
        """
        data_params = self.params.DATA_PARAMS
        return tf.keras.utils.image_dataset_from_directory(
            directory,
            labels='inferred',
            label_mode='categorical',
            class_names=data_params.CLASSES,
            image_size=data_params.IMAGE_SIZE,
            interpolation='nearest',
            batch_size=data_params.BATCH_SIZE,
            shuffle=shuffle,
            color_mode='grayscale',
        )

    def get_datasets(self):
        """Return ``(train_ds, val_ds)`` ready to be passed to ``Model.fit``.

        Both datasets are converted to 3-channel images and divided by 255.
        Random flip/rotation/zoom augmentation is applied to the training
        dataset only. Both datasets are prefetched.

        Returns:
            A tuple ``(train_ds, val_ds)`` of ``tf.data.Dataset`` objects.
        """
        logger.info("Loading prepared train and test datasets...")

        # Training data is shuffled; validation data keeps directory order.
        train_ds = self._load_image_dataset(
            self.config.train_data_dir, shuffle=True
        )
        val_ds = self._load_image_dataset(
            self.config.test_data_dir, shuffle=False
        )

        def preprocess(image, label):
            # Images are loaded as single-channel grayscale; replicate the
            # channel because the model is fed 3-channel input.
            image = tf.image.grayscale_to_rgb(image)
            # Assumes pixel values are in the 0-255 range -- TODO confirm
            # create_mobilenetv2_model does not rescale again internally.
            image = tf.cast(image, tf.float32) / 255.0
            return image, label

        data_augmentation = tf.keras.Sequential([
            tf.keras.layers.RandomFlip("horizontal"),
            tf.keras.layers.RandomRotation(0.1),
            tf.keras.layers.RandomZoom(0.1),
        ])

        def augment(image, label):
            # training=True forces the random layers to be active inside
            # the tf.data pipeline.
            return data_augmentation(image, training=True), label

        train_ds = train_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        val_ds = val_ds.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        train_ds = train_ds.map(augment, num_parallel_calls=tf.data.AUTOTUNE)

        return (
            train_ds.prefetch(tf.data.AUTOTUNE),
            val_ds.prefetch(tf.data.AUTOTUNE),
        )

    def build_and_train_model(self):
        """Create the model, freeze its base, train it and save the result.

        The compiled model is stored on ``self.model``. After fitting for
        ``TRAINING_PARAMS.EPOCHS`` epochs the model is written to disk via
        :meth:`save_model`.
        """
        data_params = self.params.DATA_PARAMS
        training_params = self.params.TRAINING_PARAMS

        logger.info("Building model with a frozen MobileNetV2 base...")

        # Unpacking accepts IMAGE_SIZE as either a list or a tuple, whereas
        # ``IMAGE_SIZE + [CHANNELS]`` fails for tuples.
        input_shape = [*data_params.IMAGE_SIZE, data_params.CHANNELS]
        self.model = create_mobilenetv2_model(
            input_shape=input_shape,
            num_classes=data_params.NUM_CLASSES,
            dropout_rate=training_params.DROPOUT_RATE,
        )

        # NOTE(review): assumes create_mobilenetv2_model places the
        # MobileNetV2 base at layer index 1 -- confirm against that helper.
        base_model = self.model.layers[1]
        base_model.trainable = False

        # Compile after freezing so the trainable flag is taken into account.
        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(
                learning_rate=training_params.LEARNING_RATE
            ),
            loss=training_params.LOSS_FUNCTION,
            metrics=training_params.METRICS,
        )
        self.model.summary(print_fn=logger.info)

        train_ds, val_ds = self.get_datasets()

        logger.info(f"--- Starting training for {training_params.EPOCHS} epochs ---")
        self.model.fit(
            train_ds,
            epochs=training_params.EPOCHS,
            validation_data=val_ds,
        )

        self.save_model()

    def save_model(self):
        """Write ``self.model`` to ``config.trained_model_path``.

        The parent directory of the target path is created if missing.

        Raises:
            RuntimeError: If no model has been built yet.
        """
        if self.model is None:
            raise RuntimeError(
                "No model to save; call build_and_train_model() first."
            )

        model_path = str(self.config.trained_model_path)
        # Make sure the output directory exists before saving into it.
        Path(model_path).parent.mkdir(parents=True, exist_ok=True)

        self.model.save(model_path)
        logger.info(f"Full model saved successfully to: {model_path}")