import tensorflow as tf
from pathlib import Path

from cnnClassifier.entity.config_entity import TrainingConfig

# --- NEW IMPORTS ---
import pandas as pd
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
# --------------------


class Training:
    def __init__(self, config: TrainingConfig):
        self.config = config
        self.model = None
        self.train_generator = None
        self.valid_generator = None

    def get_base_model(self):
        # Load the updated base model saved by the earlier preparation step
        self.model = tf.keras.models.load_model(
            self.config.updated_base_model_path
        )
    def train_valid_generator(self):
        datagenerator_kwargs = dict(
            rescale=1./255,
            validation_split=0.20
        )

        dataflow_kwargs = dict(
            target_size=self.config.params_image_size[:-1],
            batch_size=self.config.params_batch_size,
            interpolation="bilinear"
        )

        valid_datagenerator = tf.keras.preprocessing.image.ImageDataGenerator(
            **datagenerator_kwargs
        )

        self.valid_generator = valid_datagenerator.flow_from_directory(
            directory=self.config.training_data,
            subset="validation",
            shuffle=False,
            **dataflow_kwargs
        )

        if self.config.params_is_augmentation:
            train_datagenerator = tf.keras.preprocessing.image.ImageDataGenerator(
                rotation_range=20,  # Reduced for stability
                horizontal_flip=True,
                width_shift_range=0.1,
                height_shift_range=0.1,
                shear_range=0.1,
                zoom_range=0.1,
                **datagenerator_kwargs
            )
        else:
            train_datagenerator = valid_datagenerator

        self.train_generator = train_datagenerator.flow_from_directory(
            directory=self.config.training_data,
            subset="training",
            shuffle=True,
            **dataflow_kwargs
        )

        # --- ADD THIS ---
        # Print class indices to be 100% sure of the mapping
        print(f"Discovered class indices: {self.train_generator.class_indices}")
        # ----------------
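        # NOTE: flow_from_directory assigns label indices to class subfolders in
        # alphanumeric order, so the mapping printed above is fully determined
        # by the folder names under self.config.training_data.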

    @staticmethod
    def save_model(path: Path, model: tf.keras.Model):
        model.save(path)
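        # NOTE: model.save() infers the format from the path: a ".h5" suffix
        # writes legacy HDF5, while other paths produce a TF SavedModel
        # directory (recent versions also accept the native ".keras" format).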

    def train(self):
        self.steps_per_epoch = self.train_generator.samples // self.train_generator.batch_size
        self.validation_steps = self.valid_generator.samples // self.valid_generator.batch_size

        # --- NEW: DEFINE CALLBACKS FOR SMART TRAINING ---
        # This saves the BEST model seen so far, judged by validation accuracy
        best_model_checkpoint = ModelCheckpoint(
            filepath=str(self.config.trained_model_path),  # best model goes to the configured path; str() in case it is a Path
            save_best_only=True,
            monitor='val_accuracy',
            mode='max',
            verbose=1
        )

        # This stops training early if validation accuracy stops improving
        early_stopping = EarlyStopping(
            monitor='val_accuracy',
            patience=5,  # number of epochs with no improvement to wait
            restore_best_weights=True,
            verbose=1
        )

        callbacks_list = [best_model_checkpoint, early_stopping]
        # -----------------------------------------------

        # --- MODEL.FIT() IS NOW UPGRADED ---
        history = self.model.fit(
            self.train_generator,
            epochs=self.config.params_epochs,
            steps_per_epoch=self.steps_per_epoch,
            validation_steps=self.validation_steps,
            validation_data=self.valid_generator,
            callbacks=callbacks_list  # pass the smart callbacks here
        )
        # -----------------------------------

        # --- NEW: SAVE TRAINING HISTORY FOR ANALYSIS ---
        history_df = pd.DataFrame(history.history)
        history_path = "training_history.csv"  # saved in the working directory
        history_df.to_csv(history_path, index=False)
        print(f"✅ Training history saved to {history_path}")
        # -----------------------------------------------

        # Saving is now handled by ModelCheckpoint, which already persists the
        # BEST model, so the explicit call below is redundant and left disabled.
        # self.save_model(
        #     path=self.config.trained_model_path,
        #     model=self.model
        # )
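

# --- USAGE SKETCH (illustrative; assumes this project's standard template) ---
# A minimal sketch of how this component is typically driven. The import below
# assumes the usual cnnClassifier layout (a ConfigurationManager exposing
# get_training_config()); adjust it to your actual project structure.
if __name__ == "__main__":
    from cnnClassifier.config.configuration import ConfigurationManager

    config_manager = ConfigurationManager()
    training_config = config_manager.get_training_config()

    training = Training(config=training_config)
    training.get_base_model()          # load the prepared base model
    training.train_valid_generator()   # build train/validation generators
    training.train()                   # fit with checkpointing + early stopping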