"""
Training pipeline for emotion recognition models.

Provides ``EmotionModelTrainer`` (compile / train / fine-tune / persist a
Keras model) and three convenience entry points that build and train the
custom CNN, MobileNetV2 and VGG-19 models.
"""

import os
import json
import sys
from pathlib import Path
from datetime import datetime
from typing import Dict, Optional, Tuple, Callable

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    ReduceLROnPlateau,
    TensorBoard,
    Callback
)

# Put the directory three levels above this file on sys.path so the
# ``src`` package import below resolves (presumably the project root).
sys.path.append(str(Path(__file__).parent.parent.parent))

from src.config import (
    EPOCHS,
    LEARNING_RATE,
    LEARNING_RATE_FINE_TUNE,
    EARLY_STOPPING_PATIENCE,
    REDUCE_LR_PATIENCE,
    REDUCE_LR_FACTOR,
    MODELS_DIR,
    CUSTOM_CNN_PATH,
    MOBILENET_PATH,
    VGG_PATH
)


def _json_default(obj):
    """
    JSON fallback encoder for values the ``json`` module cannot serialize.

    Keras histories can contain NumPy scalars (for example the learning
    rate logged by ``ReduceLROnPlateau`` in some Keras versions), which
    ``json.dump`` rejects. Only NumPy scalars/arrays and ``Path`` objects
    are converted; anything else still raises ``TypeError`` so unexpected
    objects are not silently stringified.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, Path):
        return str(obj)
    raise TypeError(
        f"Object of type {type(obj).__name__} is not JSON serializable"
    )


class TrainingProgressCallback(Callback):
    """Custom callback to track and display training progress."""

    def __init__(self, total_epochs: int):
        """
        Args:
            total_epochs: Number of epochs requested, shown as the
                denominator in the progress line.
        """
        super().__init__()
        self.total_epochs = total_epochs

    def on_epoch_end(self, epoch, logs=None):
        """Print loss/accuracy for the epoch that just finished.

        Metrics missing from ``logs`` are printed as 0.
        """
        logs = logs or {}
        # ``epoch`` is zero-based, hence the +1 for display.
        print(f"\nEpoch {epoch + 1}/{self.total_epochs}")
        print(f" Loss: {logs.get('loss', 0):.4f} - Accuracy: {logs.get('accuracy', 0):.4f}")
        print(f" Val Loss: {logs.get('val_loss', 0):.4f} - Val Accuracy: {logs.get('val_accuracy', 0):.4f}")


class EmotionModelTrainer:
    """
    Trainer class for emotion recognition models.

    Wraps a Keras model and keeps the last ``History`` object
    (``self.history``) plus a JSON-friendly ``self.training_metadata``
    dictionary describing the most recent compile/train calls.
    """

    def __init__(
        self,
        model: Model,
        model_name: str = "model",
        save_path: Optional[Path] = None,
        logs_dir: Optional[Path] = None
    ):
        """
        Initialize the trainer.

        Args:
            model: Keras model to train
            model_name: Name for the model (used for saving)
            save_path: Path to save the trained model
                (defaults to ``MODELS_DIR / "<model_name>.h5"``)
            logs_dir: Directory for TensorBoard logs
                (defaults to ``MODELS_DIR / "logs" / <model_name>``)
        """
        self.model = model
        self.model_name = model_name
        self.save_path = save_path or MODELS_DIR / f"{model_name}.h5"
        self.logs_dir = logs_dir or MODELS_DIR / "logs" / model_name
        self.history = None
        self.training_metadata = {}

        # Create directories (paths are wrapped in Path because callers
        # may pass plain strings).
        Path(self.save_path).parent.mkdir(parents=True, exist_ok=True)
        Path(self.logs_dir).mkdir(parents=True, exist_ok=True)

    def compile(
        self,
        learning_rate: float = LEARNING_RATE,
        optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
        loss: str = 'categorical_crossentropy',
        metrics: Optional[list] = None
    ) -> None:
        """
        Compile the model.

        Args:
            learning_rate: Learning rate for optimizer. Only applied when
                ``optimizer`` is None; it is recorded in the metadata
                either way.
            optimizer: Custom optimizer (uses Adam if None)
            loss: Loss function
            metrics: Metrics to track (defaults to ``['accuracy']``)
        """
        # Build a fresh list per call instead of sharing a mutable default.
        metrics = ['accuracy'] if metrics is None else list(metrics)

        if optimizer is None:
            optimizer = Adam(learning_rate=learning_rate)

        self.model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics
        )

        self.training_metadata['learning_rate'] = learning_rate
        self.training_metadata['loss_function'] = loss
        self.training_metadata['metrics'] = metrics

    def get_callbacks(
        self,
        use_early_stopping: bool = True,
        use_reduce_lr: bool = True,
        use_tensorboard: bool = True,
        use_checkpoint: bool = True,
        custom_callbacks: Optional[list] = None
    ) -> list:
        """
        Get training callbacks.

        Args:
            use_early_stopping: Whether to use early stopping
            use_reduce_lr: Whether to reduce LR on plateau
            use_tensorboard: Whether to log to TensorBoard
            use_checkpoint: Whether to save best model
            custom_callbacks: Additional custom callbacks

        Returns:
            List of callbacks (a new list on every call)
        """
        callbacks = []

        if use_early_stopping:
            callbacks.append(EarlyStopping(
                monitor='val_loss',
                patience=EARLY_STOPPING_PATIENCE,
                restore_best_weights=True,
                verbose=1
            ))

        if use_reduce_lr:
            callbacks.append(ReduceLROnPlateau(
                monitor='val_loss',
                factor=REDUCE_LR_FACTOR,
                patience=REDUCE_LR_PATIENCE,
                min_lr=1e-7,
                verbose=1
            ))

        if use_tensorboard:
            callbacks.append(TensorBoard(
                log_dir=str(self.logs_dir),
                histogram_freq=1,
                write_graph=True
            ))

        if use_checkpoint:
            # Checkpoints track validation accuracy, while early stopping
            # and LR reduction above track validation loss.
            callbacks.append(ModelCheckpoint(
                filepath=str(self.save_path),
                monitor='val_accuracy',
                save_best_only=True,
                mode='max',
                verbose=1
            ))

        if custom_callbacks:
            callbacks.extend(custom_callbacks)

        return callbacks

    def train(
        self,
        train_generator,
        val_generator,
        epochs: int = EPOCHS,
        class_weights: Optional[Dict] = None,
        callbacks: Optional[list] = None,
        verbose: int = 1
    ) -> Dict:
        """
        Train the model.

        Args:
            train_generator: Training data generator
            val_generator: Validation data generator
            epochs: Number of epochs
            class_weights: Optional class weights for imbalanced data
            callbacks: Optional custom callbacks (uses defaults if None).
                The list is copied, so the caller's list is not modified.
            verbose: Verbosity mode

        Returns:
            Training history dictionary
        """
        if callbacks is None:
            callbacks = self.get_callbacks()
        else:
            # Copy so appending the progress callback below does not
            # mutate the list owned by the caller.
            callbacks = list(callbacks)

        # Add progress callback
        callbacks.append(TrainingProgressCallback(epochs))

        # Record training start
        start_time = datetime.now()
        self.training_metadata['training_started'] = start_time.isoformat()
        self.training_metadata['epochs_requested'] = epochs

        # Not every data object passed to ``fit`` exposes ``samples``;
        # fall back to a placeholder rather than failing before training.
        train_samples = getattr(train_generator, 'samples', 'unknown')
        val_samples = getattr(val_generator, 'samples', 'unknown')

        print(f"\n{'='*60}")
        print(f"Training {self.model_name}")
        print(f"{'='*60}")
        print(f"Epochs: {epochs}")
        print(f"Training samples: {train_samples}")
        print(f"Validation samples: {val_samples}")
        print(f"{'='*60}\n")

        # Train
        self.history = self.model.fit(
            train_generator,
            epochs=epochs,
            validation_data=val_generator,
            class_weight=class_weights,
            callbacks=callbacks,
            verbose=verbose
        )

        # Record training end
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        history = self.history.history

        # Accuracy entries exist only if the model was compiled with an
        # 'accuracy' metric; record None instead of raising KeyError.
        accuracy = history.get('accuracy') or []
        val_accuracy = history.get('val_accuracy') or []

        self.training_metadata['training_ended'] = end_time.isoformat()
        self.training_metadata['training_duration_seconds'] = duration
        self.training_metadata['epochs_completed'] = len(history.get('loss', []))
        self.training_metadata['final_accuracy'] = (
            float(accuracy[-1]) if accuracy else None
        )
        self.training_metadata['final_val_accuracy'] = (
            float(val_accuracy[-1]) if val_accuracy else None
        )
        self.training_metadata['best_val_accuracy'] = (
            float(max(val_accuracy)) if val_accuracy else None
        )

        print(f"\n{'='*60}")
        print("Training Complete!")
        print(f"Duration: {duration/60:.2f} minutes")
        best_val_accuracy = self.training_metadata['best_val_accuracy']
        if best_val_accuracy is not None:
            print(f"Best Validation Accuracy: {best_val_accuracy:.4f}")
        print(f"{'='*60}\n")

        return history

    def fine_tune(
        self,
        train_generator,
        val_generator,
        epochs: int = 20,
        learning_rate: float = LEARNING_RATE_FINE_TUNE,
        unfreeze_layers: int = 30,
        class_weights: Optional[Dict] = None
    ) -> Dict:
        """
        Fine-tune a transfer learning model.

        Args:
            train_generator: Training data generator
            val_generator: Validation data generator
            epochs: Number of fine-tuning epochs
            learning_rate: Learning rate for fine-tuning
            unfreeze_layers: Number of layers to unfreeze, counted from
                the end of ``model.layers``. Values <= 0 unfreeze nothing.
            class_weights: Optional class weights for imbalanced data

        Returns:
            Fine-tuning history
        """
        # Unfreeze top layers. Guard against 0: ``layers[-0:]`` is the
        # whole list and would unfreeze every layer.
        layers_to_unfreeze = (
            self.model.layers[-unfreeze_layers:] if unfreeze_layers > 0 else []
        )
        for layer in layers_to_unfreeze:
            layer.trainable = True

        # Recompile with lower learning rate, keeping the loss and metrics
        # of the previous compile() call (defaults if never compiled).
        self.compile(
            learning_rate=learning_rate,
            loss=self.training_metadata.get(
                'loss_function', 'categorical_crossentropy'
            ),
            metrics=self.training_metadata.get('metrics')
        )

        print(f"\nFine-tuning with learning rate: {learning_rate}")
        print(f"Unfrozen {len(layers_to_unfreeze)} top layers")

        # Continue training
        return self.train(
            train_generator,
            val_generator,
            epochs=epochs,
            class_weights=class_weights
        )

    def save_training_history(self, history: Optional[Dict] = None) -> None:
        """
        Save training history and metadata to files.

        The files are written next to ``save_path`` with the suffixes
        ``.history.json`` and ``.meta.json``.

        Args:
            history: History dictionary to write. Defaults to the history
                of the most recent ``train`` call; pass a merged
                dictionary to persist several training phases.
        """
        if history is None:
            if self.history is None:
                print("No training history to save.")
                return
            history = self.history.history

        # ``save_path`` may have been supplied as a plain string.
        save_path = Path(self.save_path)

        # Save history as JSON
        history_path = save_path.with_suffix('.history.json')
        with open(history_path, 'w', encoding='utf-8') as f:
            json.dump(history, f, indent=2, default=_json_default)

        # Save metadata
        metadata_path = save_path.with_suffix('.meta.json')
        with open(metadata_path, 'w', encoding='utf-8') as f:
            json.dump(
                self.training_metadata, f, indent=2, default=_json_default
            )

        print(f"Training history saved to: {history_path}")
        print(f"Training metadata saved to: {metadata_path}")

    def get_training_summary(self) -> Dict:
        """
        Get a summary of the training results.

        Returns:
            Dictionary with training summary, or ``{"status": "Not trained"}``
            when ``train`` has not been called. Metrics that were not
            tracked are reported as None.
        """
        if self.history is None:
            return {"status": "Not trained"}

        history = self.history.history

        def last(key):
            # Final value of a tracked metric, or None if it is absent.
            values = history.get(key)
            return values[-1] if values else None

        val_accuracy = history.get('val_accuracy')

        return {
            "model_name": self.model_name,
            "epochs_completed": len(history.get('loss', [])),
            "final_accuracy": last('accuracy'),
            "final_val_accuracy": last('val_accuracy'),
            "best_val_accuracy": max(val_accuracy) if val_accuracy else None,
            "final_loss": last('loss'),
            "final_val_loss": last('val_loss'),
            "training_duration": self.training_metadata.get('training_duration_seconds', 0)
        }


def _train_with_fine_tuning(
    trainer: EmotionModelTrainer,
    train_generator,
    val_generator,
    epochs: int,
    fine_tune_epochs: int,
    class_weights: Optional[Dict]
) -> Dict:
    """Run initial training plus optional fine-tuning; save and return the merged history."""
    # Initial training with frozen base
    trainer.compile()
    history = trainer.train(train_generator, val_generator, epochs, class_weights)

    # Fine-tuning
    if fine_tune_epochs > 0:
        fine_tune_history = trainer.fine_tune(
            train_generator,
            val_generator,
            fine_tune_epochs,
            class_weights=class_weights
        )

        # Merge histories. Iterating the fine-tune keys with setdefault
        # avoids a KeyError when the two phases logged different keys.
        for key, values in fine_tune_history.items():
            history.setdefault(key, []).extend(values)

    # Persist the merged history; the trainer itself only remembers the
    # most recent phase.
    trainer.save_training_history(history)

    return history


def train_custom_cnn(
    train_generator,
    val_generator,
    epochs: int = EPOCHS,
    class_weights: Optional[Dict] = None
) -> Tuple[Model, Dict]:
    """
    Train the custom CNN model.

    Args:
        train_generator: Training data generator
        val_generator: Validation data generator
        epochs: Number of epochs
        class_weights: Optional class weights

    Returns:
        Tuple of (trained model, training history)
    """
    # Imported lazily so the model module is loaded only when needed.
    from src.models.custom_cnn import build_custom_cnn

    model = build_custom_cnn()
    trainer = EmotionModelTrainer(model, "custom_cnn", CUSTOM_CNN_PATH)

    trainer.compile()
    history = trainer.train(train_generator, val_generator, epochs, class_weights)
    trainer.save_training_history()

    return model, history


def train_mobilenet(
    train_generator,
    val_generator,
    epochs: int = EPOCHS,
    fine_tune_epochs: int = 20,
    class_weights: Optional[Dict] = None
) -> Tuple[Model, Dict]:
    """
    Train the MobileNetV2 model with fine-tuning.

    Args:
        train_generator: Training data generator (RGB, 96x96)
        val_generator: Validation data generator
        epochs: Initial training epochs
        fine_tune_epochs: Fine-tuning epochs (0 skips fine-tuning)
        class_weights: Optional class weights, applied to both phases

    Returns:
        Tuple of (trained model, training history); the history covers
        the initial phase followed by the fine-tuning phase.
    """
    from src.models.mobilenet_model import build_mobilenet_model

    model = build_mobilenet_model()
    trainer = EmotionModelTrainer(model, "mobilenet_v2", MOBILENET_PATH)

    history = _train_with_fine_tuning(
        trainer,
        train_generator,
        val_generator,
        epochs,
        fine_tune_epochs,
        class_weights
    )

    return model, history


def train_vgg(
    train_generator,
    val_generator,
    epochs: int = EPOCHS,
    fine_tune_epochs: int = 15,
    class_weights: Optional[Dict] = None
) -> Tuple[Model, Dict]:
    """
    Train the VGG-19 model with fine-tuning.

    Args:
        train_generator: Training data generator (RGB, 96x96)
        val_generator: Validation data generator
        epochs: Initial training epochs
        fine_tune_epochs: Fine-tuning epochs (0 skips fine-tuning)
        class_weights: Optional class weights, applied to both phases

    Returns:
        Tuple of (trained model, training history); the history covers
        the initial phase followed by the fine-tuning phase.
    """
    from src.models.vgg_model import build_vgg_model

    model = build_vgg_model()
    trainer = EmotionModelTrainer(model, "vgg19", VGG_PATH)

    history = _train_with_fine_tuning(
        trainer,
        train_generator,
        val_generator,
        epochs,
        fine_tune_epochs,
        class_weights
    )

    return model, history