import os
import urllib.request as request
from zipfile import ZipFile
import tensorflow as tf
import time
from cnnClassifier.entity.config_entity import TrainingConfig
from pathlib import Path

# --- NEW IMPORTS ---
import pandas as pd
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
# --------------------

class Training:
    """Run model training for the CNN classifier.

    Responsibilities:
      * load the prepared (updated) base model from disk,
      * build train/validation image generators from the data directory,
      * fit the model with best-checkpointing and early stopping,
      * persist per-epoch metrics to ``training_history.csv``.

    All paths and hyperparameters come from the ``TrainingConfig`` entity
    produced by the project's configuration manager.
    """

    def __init__(self, config: TrainingConfig):
        self.config = config
        self.model = None             # set by get_base_model()
        self.train_generator = None   # set by train_valid_generator()
        self.valid_generator = None   # set by train_valid_generator()

    def get_base_model(self):
        """Load the updated base model from disk into ``self.model``."""
        self.model = tf.keras.models.load_model(
            self.config.updated_base_model_path
        )

    def train_valid_generator(self):
        """Create training and validation generators from the image directory.

        Uses a fixed 80/20 train/validation split. Augmentation is applied
        to the training subset only when ``params_is_augmentation`` is set;
        the validation subset is never augmented.
        """
        datagenerator_kwargs = dict(
            rescale=1./255,
            validation_split=0.20
        )

        dataflow_kwargs = dict(
            # params_image_size is (H, W, C); generators want (H, W) only.
            target_size=self.config.params_image_size[:-1],
            batch_size=self.config.params_batch_size,
            interpolation="bilinear"
        )

        valid_datagenerator = tf.keras.preprocessing.image.ImageDataGenerator(
            **datagenerator_kwargs
        )

        # shuffle=False so validation predictions stay aligned with
        # filenames/labels for later evaluation.
        self.valid_generator = valid_datagenerator.flow_from_directory(
            directory=self.config.training_data,
            subset="validation",
            shuffle=False,
            **dataflow_kwargs
        )

        if self.config.params_is_augmentation:
            train_datagenerator = tf.keras.preprocessing.image.ImageDataGenerator(
                rotation_range=20,  # kept moderate for training stability
                horizontal_flip=True,
                width_shift_range=0.1,
                height_shift_range=0.1,
                shear_range=0.1,
                zoom_range=0.1,
                **datagenerator_kwargs
            )
        else:
            train_datagenerator = valid_datagenerator

        self.train_generator = train_datagenerator.flow_from_directory(
            directory=self.config.training_data,
            subset="training",
            shuffle=True,
            **dataflow_kwargs
        )

        # Surface the label -> index mapping so inference code can be
        # checked against the exact mapping used during training.
        print(f"Discovered class indices: {self.train_generator.class_indices}")

    @staticmethod
    def save_model(path: Path, model: tf.keras.Model):
        """Persist a Keras model to ``path``."""
        model.save(path)

    def train(self):
        """Fit the model with checkpointing and early stopping.

        The best model (by ``val_accuracy``) is written to
        ``config.trained_model_path`` by ``ModelCheckpoint``; per-epoch
        metrics are saved to ``training_history.csv`` in the working
        directory.
        """
        # max(1, ...) guards against a zero step count when the dataset is
        # smaller than one batch (samples // batch_size == 0), which would
        # otherwise break model.fit().
        self.steps_per_epoch = max(
            1, self.train_generator.samples // self.train_generator.batch_size
        )
        self.validation_steps = max(
            1, self.valid_generator.samples // self.valid_generator.batch_size
        )

        # Save only the best model (highest validation accuracy) seen so far.
        # str() because some TF versions reject pathlib.Path filepaths.
        best_model_checkpoint = ModelCheckpoint(
            filepath=str(self.config.trained_model_path),
            save_best_only=True,
            monitor='val_accuracy',
            mode='max',
            verbose=1
        )

        # Stop after 5 epochs with no val_accuracy improvement and roll the
        # in-memory model back to its best weights.
        early_stopping = EarlyStopping(
            monitor='val_accuracy',
            patience=5,
            restore_best_weights=True,
            verbose=1
        )

        callbacks_list = [best_model_checkpoint, early_stopping]

        history = self.model.fit(
            self.train_generator,
            epochs=self.config.params_epochs,
            steps_per_epoch=self.steps_per_epoch,
            validation_steps=self.validation_steps,
            validation_data=self.valid_generator,
            callbacks=callbacks_list
        )

        # Persist per-epoch metrics for offline analysis/plotting.
        history_df = pd.DataFrame(history.history)
        history_path = "training_history.csv"
        history_df.to_csv(history_path, index=False)
        print(f"✅ Training history saved to {history_path}")

        # NOTE: no final save_model() call here on purpose — the best model
        # has already been persisted by the ModelCheckpoint callback.