# Model-training component for the cnnClassifier pipeline.
import os
import time
import urllib.request as request
from pathlib import Path
from zipfile import ZipFile

import pandas as pd
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping

from cnnClassifier.entity.config_entity import TrainingConfig
class Training:
    """Run model training for the CNN classifier.

    Responsibilities:
      * load the prepared base model from disk,
      * build train/validation image generators from one directory
        using an 80/20 split,
      * fit the model with best-checkpoint saving and early stopping,
      * persist the per-epoch training history to CSV.
    """

    def __init__(self, config: TrainingConfig):
        """Store the configuration; heavy objects are created lazily."""
        self.config = config
        self.model = None            # set by get_base_model()
        self.train_generator = None  # set by train_valid_generator()
        self.valid_generator = None  # set by train_valid_generator()

    def get_base_model(self):
        """Load the updated base model from disk into ``self.model``."""
        self.model = tf.keras.models.load_model(
            self.config.updated_base_model_path
        )

    def train_valid_generator(self):
        """Create training and validation data flows.

        Both flows read from ``self.config.training_data`` with a fixed
        20% validation split. Augmentation is applied to the training
        split only, and only when ``params_is_augmentation`` is set.
        """
        datagenerator_kwargs = dict(
            rescale=1.0 / 255,
            validation_split=0.20,
        )
        dataflow_kwargs = dict(
            # params_image_size is (H, W, C); generators expect (H, W).
            target_size=self.config.params_image_size[:-1],
            batch_size=self.config.params_batch_size,
            interpolation="bilinear",
        )

        valid_datagenerator = tf.keras.preprocessing.image.ImageDataGenerator(
            **datagenerator_kwargs
        )
        # Validation is not shuffled so predictions can later be matched
        # back to filenames deterministically.
        self.valid_generator = valid_datagenerator.flow_from_directory(
            directory=self.config.training_data,
            subset="validation",
            shuffle=False,
            **dataflow_kwargs,
        )

        if self.config.params_is_augmentation:
            train_datagenerator = tf.keras.preprocessing.image.ImageDataGenerator(
                rotation_range=20,  # modest rotation for stability
                horizontal_flip=True,
                width_shift_range=0.1,
                height_shift_range=0.1,
                shear_range=0.1,
                zoom_range=0.1,
                **datagenerator_kwargs,
            )
        else:
            # No augmentation: reuse the rescale-only generator.
            train_datagenerator = valid_datagenerator

        self.train_generator = train_datagenerator.flow_from_directory(
            directory=self.config.training_data,
            subset="training",
            shuffle=True,
            **dataflow_kwargs,
        )

        # Surface the label -> index mapping so inference code can be
        # verified against the exact mapping used during training.
        print(f"Discovered class indices: {self.train_generator.class_indices}")

    @staticmethod
    def save_model(path: Path, model: tf.keras.Model):
        """Persist *model* to *path*.

        Fix: the original definition had neither ``self`` nor
        ``@staticmethod``, so invoking it as ``self.save_model(...)``
        would have raised ``TypeError``.
        """
        model.save(path)

    def train(self):
        """Fit the model and save the best checkpoint plus the history CSV.

        Requires ``get_base_model()`` and ``train_valid_generator()`` to
        have been called first.
        """
        # Guard against datasets smaller than one batch, which would
        # otherwise produce steps == 0 and make model.fit() fail.
        self.steps_per_epoch = max(
            1, self.train_generator.samples // self.train_generator.batch_size
        )
        self.validation_steps = max(
            1, self.valid_generator.samples // self.valid_generator.batch_size
        )

        # Keep only the best model (by validation accuracy) on disk.
        # str() because older Keras versions reject pathlib.Path here.
        best_model_checkpoint = ModelCheckpoint(
            filepath=str(self.config.trained_model_path),
            save_best_only=True,
            monitor='val_accuracy',
            mode='max',
            verbose=1,
        )
        # Stop after 5 epochs without improvement and roll back to the
        # best weights seen so far.
        early_stopping = EarlyStopping(
            monitor='val_accuracy',
            patience=5,
            restore_best_weights=True,
            verbose=1,
        )
        callbacks_list = [best_model_checkpoint, early_stopping]

        history = self.model.fit(
            self.train_generator,
            epochs=self.config.params_epochs,
            steps_per_epoch=self.steps_per_epoch,
            validation_steps=self.validation_steps,
            validation_data=self.valid_generator,
            callbacks=callbacks_list,
        )

        # Persist per-epoch metrics for offline analysis/plotting.
        history_df = pd.DataFrame(history.history)
        history_path = "training_history.csv"  # saved in the working directory
        history_df.to_csv(history_path, index=False)
        print(f"✅ Training history saved to {history_path}")
        # NOTE: the best model is already written by ModelCheckpoint, so no
        # explicit save_model() call is needed here.