File size: 3,818 Bytes
f8f5549
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import tensorflow as tf
from EmotionRecognition import logger
from EmotionRecognition.entity.config_entity import ModelTrainerConfig
from EmotionRecognition.utils.common import create_mobilenetv2_model
from pathlib import Path

class ModelTrainer:
    """Train and persist the emotion-recognition classifier.

    Builds a model with ``create_mobilenetv2_model``, freezes its MobileNetV2
    backbone, trains on the prepared train/test image directories and saves
    the full model to ``config.trained_model_path``.

    Args:
        config: Stage configuration; this class reads ``train_data_dir``,
            ``test_data_dir`` and ``trained_model_path`` from it.
        params: Hyper-parameters. Despite the ``dict`` annotation it is
            accessed by attribute (``params.DATA_PARAMS.BATCH_SIZE``), so it
            presumably is a ConfigBox-like object — TODO confirm with caller.
    """

    def __init__(self, config: ModelTrainerConfig, params: dict):
        self.config = config
        self.params = params
        # Set by build_and_train_model(); save_model() requires it.
        self.model = None

    @staticmethod
    def _preprocess(image, label):
        """Turn a 1-channel image batch into 3-channel float32 scaled by 1/255."""
        # Images are loaded with color_mode='grayscale' (one channel), but the
        # model is built for 3 input channels, so the channel is replicated.
        image = tf.image.grayscale_to_rgb(image)
        # NOTE(review): this yields values in [0, 1]; confirm this matches the
        # input range create_mobilenetv2_model expects (MobileNetV2's own
        # preprocess_input maps to [-1, 1]).
        image = tf.cast(image, tf.float32) / 255.0
        return image, label

    def _load_split(self, directory, shuffle: bool):
        """Load one image split as a batched grayscale dataset with one-hot labels.

        Args:
            directory: Root folder with one sub-folder per class name.
            shuffle: Whether the loader shuffles file order.
        """
        data_params = self.params.DATA_PARAMS
        return tf.keras.utils.image_dataset_from_directory(
            directory,
            labels='inferred',
            label_mode='categorical',
            class_names=data_params.CLASSES,
            image_size=data_params.IMAGE_SIZE,
            interpolation='nearest',
            batch_size=data_params.BATCH_SIZE,
            shuffle=shuffle,
            color_mode='grayscale',
        )

    def get_datasets(self):
        """Build the training and validation input pipelines.

        Both splits are converted by ``_preprocess`` and prefetched. The
        training split is shuffled and randomly flipped/rotated/zoomed; the
        validation split is neither shuffled nor augmented.

        Returns:
            A ``(train_ds, val_ds)`` tuple of ``tf.data.Dataset`` objects.
        """
        logger.info("Loading prepared train and test datasets...")

        train_ds = self._load_split(self.config.train_data_dir, shuffle=True)
        # Fixed order keeps validation runs comparable.
        val_ds = self._load_split(self.config.test_data_dir, shuffle=False)

        data_augmentation = tf.keras.Sequential([
            tf.keras.layers.RandomFlip("horizontal"),
            tf.keras.layers.RandomRotation(0.1),
            tf.keras.layers.RandomZoom(0.1)
        ])

        train_ds = train_ds.map(self._preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        val_ds = val_ds.map(self._preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        # Augment the training split only; training=True forces the random
        # layers to actually transform the batch inside the tf.data graph.
        train_ds = train_ds.map(
            lambda x, y: (data_augmentation(x, training=True), y),
            num_parallel_calls=tf.data.AUTOTUNE,
        )

        return train_ds.prefetch(tf.data.AUTOTUNE), val_ds.prefetch(tf.data.AUTOTUNE)

    def _find_base_model(self):
        """Return the nested MobileNetV2 backbone of ``self.model``.

        Looks for a nested Keras model whose name starts with "mobilenetv2"
        (the Keras application's default naming). This does not depend on
        the backbone's position in ``layers``, which differs between
        Functional and Sequential models. If none is found, it falls back to
        ``layers[1]``, the position the original code assumed.
        """
        for layer in self.model.layers:
            if isinstance(layer, tf.keras.Model) and layer.name.startswith("mobilenetv2"):
                return layer
        return self.model.layers[1]

    def build_and_train_model(self):
        """Build, compile and train the model, then save it.

        Returns:
            The ``History`` object produced by ``Model.fit``.
        """
        data_params = self.params.DATA_PARAMS
        training_params = self.params.TRAINING_PARAMS

        logger.info("Building model with a frozen MobileNetV2 base...")
        # list() makes this work whether IMAGE_SIZE is a list or a tuple
        # (tuple + list would raise TypeError).
        input_shape = list(data_params.IMAGE_SIZE) + [data_params.CHANNELS]
        if data_params.CHANNELS != 3:
            # _preprocess always feeds 3-channel images to the model.
            logger.warning(
                f"DATA_PARAMS.CHANNELS is {data_params.CHANNELS}, but the input "
                f"pipeline produces 3-channel images; training will likely fail."
            )

        self.model = create_mobilenetv2_model(
            input_shape=input_shape,
            num_classes=data_params.NUM_CLASSES,
            dropout_rate=training_params.DROPOUT_RATE,
        )

        # Freeze the backbone so only the layers outside it are updated.
        # Must happen before compile() for the change to take effect.
        base_model = self._find_base_model()
        base_model.trainable = False

        self.model.compile(
            optimizer=tf.keras.optimizers.Adam(learning_rate=training_params.LEARNING_RATE),
            loss=training_params.LOSS_FUNCTION,
            metrics=training_params.METRICS
        )
        self.model.summary(print_fn=logger.info)

        train_ds, val_ds = self.get_datasets()

        logger.info(f"--- Starting training for {training_params.EPOCHS} epochs ---")
        history = self.model.fit(
            train_ds,
            epochs=training_params.EPOCHS,
            validation_data=val_ds
        )

        self.save_model()
        return history

    def save_model(self):
        """Save the full trained model to ``config.trained_model_path``.

        Raises:
            RuntimeError: If no model has been built yet.
        """
        if self.model is None:
            raise RuntimeError(
                "No model to save; call build_and_train_model() first."
            )
        model_path = str(self.config.trained_model_path)
        # Create the destination folder up front; the save call may not
        # create missing parent directories itself.
        Path(model_path).parent.mkdir(parents=True, exist_ok=True)
        self.model.save(model_path)
        logger.info(f"Full model saved successfully to: {model_path}")