Spaces:
Sleeping
Sleeping
| """ | |
| Training pipeline for emotion recognition models. | |
| """ | |
| import os | |
| import json | |
| from pathlib import Path | |
| from datetime import datetime | |
| from typing import Dict, Optional, Tuple, Callable | |
| import numpy as np | |
| import tensorflow as tf | |
| from tensorflow.keras.models import Model | |
| from tensorflow.keras.optimizers import Adam | |
| from tensorflow.keras.callbacks import ( | |
| EarlyStopping, ModelCheckpoint, ReduceLROnPlateau, | |
| TensorBoard, Callback | |
| ) | |
| import sys | |
| sys.path.append(str(Path(__file__).parent.parent.parent)) | |
| from src.config import ( | |
| EPOCHS, LEARNING_RATE, LEARNING_RATE_FINE_TUNE, | |
| EARLY_STOPPING_PATIENCE, REDUCE_LR_PATIENCE, REDUCE_LR_FACTOR, | |
| MODELS_DIR, CUSTOM_CNN_PATH, MOBILENET_PATH, VGG_PATH | |
| ) | |
class TrainingProgressCallback(Callback):
    """Keras callback that prints a compact per-epoch progress summary."""

    def __init__(self, total_epochs: int):
        """
        Args:
            total_epochs: Total number of epochs requested, shown in the
                "Epoch i/N" header.
        """
        super().__init__()
        self.total_epochs = total_epochs

    def on_epoch_end(self, epoch, logs=None):
        """Print loss/accuracy (train and validation) after each epoch."""
        metrics = logs or {}
        print(f"\nEpoch {epoch + 1}/{self.total_epochs}")
        print(f" Loss: {metrics.get('loss', 0):.4f} - Accuracy: {metrics.get('accuracy', 0):.4f}")
        print(f" Val Loss: {metrics.get('val_loss', 0):.4f} - Val Accuracy: {metrics.get('val_accuracy', 0):.4f}")
class EmotionModelTrainer:
    """
    Trainer for emotion recognition models.

    Wraps a Keras model with compilation, callback construction, training,
    optional fine-tuning, and persistence of the training history and
    metadata (as JSON next to the saved model).
    """

    def __init__(
        self,
        model: Model,
        model_name: str = "model",
        save_path: Optional[Path] = None,
        logs_dir: Optional[Path] = None
    ):
        """
        Initialize the trainer.

        Args:
            model: Keras model to train
            model_name: Name for the model (used for saving)
            save_path: Path to save the trained model (str or Path accepted)
            logs_dir: Directory for TensorBoard logs
        """
        self.model = model
        self.model_name = model_name
        # Normalize to Path so later calls like .with_suffix() in
        # save_training_history() work even when a plain str is passed.
        self.save_path = Path(save_path) if save_path is not None else MODELS_DIR / f"{model_name}.h5"
        self.logs_dir = Path(logs_dir) if logs_dir is not None else MODELS_DIR / "logs" / model_name
        self.history = None           # keras History object, set by train()
        self.training_metadata = {}   # timings + hyperparameters recorded during runs

        # Create output directories up front so callbacks can write to them.
        self.save_path.parent.mkdir(parents=True, exist_ok=True)
        self.logs_dir.mkdir(parents=True, exist_ok=True)

    def compile(
        self,
        learning_rate: float = LEARNING_RATE,
        optimizer: Optional[tf.keras.optimizers.Optimizer] = None,
        loss: str = 'categorical_crossentropy',
        metrics: Optional[list] = None
    ) -> None:
        """
        Compile the model.

        Args:
            learning_rate: Learning rate for the default Adam optimizer
            optimizer: Custom optimizer (uses Adam if None)
            loss: Loss function
            metrics: Metrics to track (defaults to ['accuracy'])
        """
        # Build the metrics list per call instead of using a mutable
        # default argument (which would be shared across all calls).
        if metrics is None:
            metrics = ['accuracy']
        if optimizer is None:
            optimizer = Adam(learning_rate=learning_rate)
        self.model.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics
        )
        self.training_metadata['learning_rate'] = learning_rate
        self.training_metadata['loss_function'] = loss
        self.training_metadata['metrics'] = metrics

    def get_callbacks(
        self,
        use_early_stopping: bool = True,
        use_reduce_lr: bool = True,
        use_tensorboard: bool = True,
        use_checkpoint: bool = True,
        custom_callbacks: Optional[list] = None
    ) -> list:
        """
        Build the default set of training callbacks.

        Args:
            use_early_stopping: Stop when val_loss stops improving
            use_reduce_lr: Reduce LR on val_loss plateau
            use_tensorboard: Log to TensorBoard under self.logs_dir
            use_checkpoint: Save the best model (by val_accuracy) to save_path
            custom_callbacks: Additional callbacks appended at the end

        Returns:
            List of configured Keras callbacks
        """
        callbacks = []
        if use_early_stopping:
            callbacks.append(EarlyStopping(
                monitor='val_loss',
                patience=EARLY_STOPPING_PATIENCE,
                restore_best_weights=True,
                verbose=1
            ))
        if use_reduce_lr:
            callbacks.append(ReduceLROnPlateau(
                monitor='val_loss',
                factor=REDUCE_LR_FACTOR,
                patience=REDUCE_LR_PATIENCE,
                min_lr=1e-7,
                verbose=1
            ))
        if use_tensorboard:
            callbacks.append(TensorBoard(
                log_dir=str(self.logs_dir),
                histogram_freq=1,
                write_graph=True
            ))
        if use_checkpoint:
            # Checkpoint on val_accuracy (max) while early stopping watches
            # val_loss — the saved file is the best-accuracy epoch.
            callbacks.append(ModelCheckpoint(
                filepath=str(self.save_path),
                monitor='val_accuracy',
                save_best_only=True,
                mode='max',
                verbose=1
            ))
        if custom_callbacks:
            callbacks.extend(custom_callbacks)
        return callbacks

    def train(
        self,
        train_generator,
        val_generator,
        epochs: int = EPOCHS,
        class_weights: Optional[Dict] = None,
        callbacks: Optional[list] = None,
        verbose: int = 1
    ) -> Dict:
        """
        Train the model.

        Args:
            train_generator: Training data generator
            val_generator: Validation data generator
            epochs: Number of epochs
            class_weights: Optional class weights for imbalanced data
            callbacks: Optional custom callbacks (uses defaults if None)
            verbose: Verbosity mode passed to model.fit

        Returns:
            Training history dictionary (metric name -> list of per-epoch values)
        """
        # Copy the list so appending the progress callback never mutates a
        # caller-owned list.
        if callbacks is None:
            callbacks = self.get_callbacks()
        else:
            callbacks = list(callbacks)
        callbacks.append(TrainingProgressCallback(epochs))

        # Record training start
        start_time = datetime.now()
        self.training_metadata['training_started'] = start_time.isoformat()
        self.training_metadata['epochs_requested'] = epochs

        print(f"\n{'='*60}")
        print(f"Training {self.model_name}")
        print(f"{'='*60}")
        print(f"Epochs: {epochs}")
        print(f"Training samples: {train_generator.samples}")
        print(f"Validation samples: {val_generator.samples}")
        print(f"{'='*60}\n")

        # Train
        self.history = self.model.fit(
            train_generator,
            epochs=epochs,
            validation_data=val_generator,
            class_weight=class_weights,
            callbacks=callbacks,
            verbose=verbose
        )

        # Record training end; epochs_completed may be < epochs when early
        # stopping fires.
        end_time = datetime.now()
        duration = (end_time - start_time).total_seconds()
        self.training_metadata['training_ended'] = end_time.isoformat()
        self.training_metadata['training_duration_seconds'] = duration
        self.training_metadata['epochs_completed'] = len(self.history.history['loss'])
        self.training_metadata['final_accuracy'] = float(self.history.history['accuracy'][-1])
        self.training_metadata['final_val_accuracy'] = float(self.history.history['val_accuracy'][-1])
        self.training_metadata['best_val_accuracy'] = float(max(self.history.history['val_accuracy']))

        print(f"\n{'='*60}")
        print(f"Training Complete!")
        print(f"Duration: {duration/60:.2f} minutes")
        print(f"Best Validation Accuracy: {self.training_metadata['best_val_accuracy']:.4f}")
        print(f"{'='*60}\n")
        return self.history.history

    def fine_tune(
        self,
        train_generator,
        val_generator,
        epochs: int = 20,
        learning_rate: float = LEARNING_RATE_FINE_TUNE,
        unfreeze_layers: int = 30
    ) -> Dict:
        """
        Fine-tune a transfer learning model.

        Unfreezes the top `unfreeze_layers` layers, recompiles with a lower
        learning rate, and continues training.

        Args:
            train_generator: Training data generator
            val_generator: Validation data generator
            epochs: Number of fine-tuning epochs
            learning_rate: Learning rate for fine-tuning
            unfreeze_layers: Number of top layers to unfreeze

        Returns:
            Fine-tuning history dictionary
        """
        # Unfreeze top layers
        for layer in self.model.layers[-unfreeze_layers:]:
            layer.trainable = True

        # Recompile with lower learning rate (required for trainable changes
        # to take effect).
        self.compile(learning_rate=learning_rate)
        print(f"\nFine-tuning with learning rate: {learning_rate}")
        print(f"Unfrozen {unfreeze_layers} top layers")

        # Continue training
        return self.train(train_generator, val_generator, epochs=epochs)

    def save_training_history(self) -> None:
        """Save training history and metadata as JSON next to the model file."""
        if self.history is None:
            print("No training history to save.")
            return

        # Coerce values to plain float so json.dump never chokes on
        # numpy/TF scalar types that Keras may put in the history.
        serializable_history = {
            key: [float(v) for v in values]
            for key, values in self.history.history.items()
        }

        # Save history as JSON (e.g. model.h5 -> model.history.json)
        history_path = self.save_path.with_suffix('.history.json')
        with open(history_path, 'w') as f:
            json.dump(serializable_history, f, indent=2)

        # Save metadata
        metadata_path = self.save_path.with_suffix('.meta.json')
        with open(metadata_path, 'w') as f:
            json.dump(self.training_metadata, f, indent=2)

        print(f"Training history saved to: {history_path}")
        print(f"Training metadata saved to: {metadata_path}")

    def get_training_summary(self) -> Dict:
        """
        Get a summary of the training results.

        Returns:
            Dictionary with final/best metrics, or {"status": "Not trained"}
            if train() has not been called yet.
        """
        if self.history is None:
            return {"status": "Not trained"}
        return {
            "model_name": self.model_name,
            "epochs_completed": len(self.history.history['loss']),
            "final_accuracy": self.history.history['accuracy'][-1],
            "final_val_accuracy": self.history.history['val_accuracy'][-1],
            "best_val_accuracy": max(self.history.history['val_accuracy']),
            "final_loss": self.history.history['loss'][-1],
            "final_val_loss": self.history.history['val_loss'][-1],
            "training_duration": self.training_metadata.get('training_duration_seconds', 0)
        }
def train_custom_cnn(
    train_generator,
    val_generator,
    epochs: int = EPOCHS,
    class_weights: Optional[Dict] = None
) -> Tuple[Model, Dict]:
    """
    Train the custom CNN model.

    Args:
        train_generator: Training data generator
        val_generator: Validation data generator
        epochs: Number of epochs
        class_weights: Optional class weights

    Returns:
        Tuple of (trained model, training history)
    """
    # Imported lazily to avoid pulling in model-building deps at module load.
    from src.models.custom_cnn import build_custom_cnn

    cnn = build_custom_cnn()
    trainer = EmotionModelTrainer(cnn, "custom_cnn", CUSTOM_CNN_PATH)
    trainer.compile()

    training_history = trainer.train(
        train_generator, val_generator, epochs, class_weights
    )
    trainer.save_training_history()

    return cnn, training_history
def train_mobilenet(
    train_generator,
    val_generator,
    epochs: int = EPOCHS,
    fine_tune_epochs: int = 20,
    class_weights: Optional[Dict] = None
) -> Tuple[Model, Dict]:
    """
    Train the MobileNetV2 model with fine-tuning.

    Args:
        train_generator: Training data generator (RGB, 96x96)
        val_generator: Validation data generator
        epochs: Initial training epochs
        fine_tune_epochs: Fine-tuning epochs (0 skips fine-tuning)
        class_weights: Optional class weights

    Returns:
        Tuple of (trained model, merged training history)
    """
    from src.models.mobilenet_model import build_mobilenet_model

    model = build_mobilenet_model()
    trainer = EmotionModelTrainer(model, "mobilenet_v2", MOBILENET_PATH)

    # Initial training with frozen base
    trainer.compile()
    history = trainer.train(train_generator, val_generator, epochs, class_weights)

    # Fine-tuning
    if fine_tune_epochs > 0:
        fine_tune_history = trainer.fine_tune(
            train_generator, val_generator, fine_tune_epochs
        )
        # Merge histories keyed on the fine-tune run: setdefault avoids a
        # KeyError when a metric (e.g. an LR trace added by ReduceLROnPlateau)
        # exists in only one of the two phases, and keeps fine-tune-only keys.
        for key, values in fine_tune_history.items():
            history.setdefault(key, []).extend(values)

    trainer.save_training_history()
    return model, history
def train_vgg(
    train_generator,
    val_generator,
    epochs: int = EPOCHS,
    fine_tune_epochs: int = 15,
    class_weights: Optional[Dict] = None
) -> Tuple[Model, Dict]:
    """
    Train the VGG-19 model with fine-tuning.

    Args:
        train_generator: Training data generator (RGB, 96x96)
        val_generator: Validation data generator
        epochs: Initial training epochs
        fine_tune_epochs: Fine-tuning epochs (0 skips fine-tuning)
        class_weights: Optional class weights

    Returns:
        Tuple of (trained model, merged training history)
    """
    from src.models.vgg_model import build_vgg_model

    model = build_vgg_model()
    trainer = EmotionModelTrainer(model, "vgg19", VGG_PATH)

    # Initial training with frozen base
    trainer.compile()
    history = trainer.train(train_generator, val_generator, epochs, class_weights)

    # Fine-tuning
    if fine_tune_epochs > 0:
        fine_tune_history = trainer.fine_tune(
            train_generator, val_generator, fine_tune_epochs
        )
        # Merge histories keyed on the fine-tune run: setdefault avoids a
        # KeyError when a metric exists in only one phase and preserves
        # fine-tune-only keys instead of silently dropping them.
        for key, values in fine_tune_history.items():
            history.setdefault(key, []).extend(values)

    trainer.save_training_history()
    return model, history