"""
Training Monitoring
===================

This module provides comprehensive monitoring capabilities for gradient
descent training, including gradient tracking (``GradientMonitor``),
loss/accuracy/learning-rate tracking (``TrainingMonitor``), and system
performance tracking (``PerformanceMonitor`` with ``StepTimer``).
"""

import json
import logging
import math
import time
from collections import defaultdict, deque
from typing import Any, Dict, Optional

import numpy as np
import torch

# NOTE: matplotlib is imported lazily inside the plotting methods so the
# module stays importable on headless / plot-free installations.

logger = logging.getLogger(__name__)


class GradientMonitor:
    """
    Monitor gradient statistics during training.

    Tracks gradient norms, distributions, and anomalies to ensure stable
    training and detect potential issues.
    """

    def __init__(self, max_history: int = 1000):
        self.max_history = max_history

        # Global (all-parameter) statistics, one entry per update() call.
        self.gradient_norms = deque(maxlen=max_history)
        self.gradient_means = deque(maxlen=max_history)
        self.gradient_stds = deque(maxlen=max_history)
        self.gradient_maxs = deque(maxlen=max_history)
        self.gradient_mins = deque(maxlen=max_history)

        # Per-parameter statistics, keyed by parameter name.
        self.parameter_stats = defaultdict(lambda: {
            'norms': deque(maxlen=max_history),
            'means': deque(maxlen=max_history),
            'stds': deque(maxlen=max_history)
        })

        self.anomaly_count = 0
        self.last_anomaly_time = None

        logger.info("Initialized GradientMonitor")

    def update(self, gradients: Dict[str, torch.Tensor]):
        """
        Update gradient statistics.

        Args:
            gradients: Dictionary of parameter gradients. Entries with a
                ``None`` gradient (e.g. frozen parameters) are skipped.
        """
        total_sq_norm = 0.0
        total_mean = 0.0
        total_sq_std = 0.0
        total_max = float('-inf')
        total_min = float('inf')
        param_count = 0

        for name, grad in gradients.items():
            if grad is None:
                continue

            # Per-parameter statistics.
            # NOTE(review): grad.std() is NaN for single-element tensors;
            # the NaN check in detect_anomalies() will flag that case.
            param_norm = grad.data.norm(2).item()
            param_mean = grad.data.mean().item()
            param_std = grad.data.std().item()
            param_max = grad.data.max().item()
            param_min = grad.data.min().item()

            stats = self.parameter_stats[name]
            stats['norms'].append(param_norm)
            stats['means'].append(param_mean)
            stats['stds'].append(param_std)

            # Accumulate global aggregates. The global norm is the L2 norm
            # over all parameters, i.e. sqrt of the sum of squared norms.
            total_sq_norm += param_norm ** 2
            total_mean += param_mean
            total_sq_std += param_std ** 2
            total_max = max(total_max, param_max)
            total_min = min(total_min, param_min)
            param_count += 1

        # Record global stats only when at least one gradient was seen;
        # otherwise +/-inf placeholders would pollute the max/min history.
        if param_count > 0:
            self.gradient_norms.append(math.sqrt(total_sq_norm))
            self.gradient_means.append(total_mean / param_count)
            # RMS of the per-parameter stds (an approximation of spread,
            # not the exact pooled std — matches original behavior).
            self.gradient_stds.append(math.sqrt(total_sq_std / param_count))
            self.gradient_maxs.append(total_max)
            self.gradient_mins.append(total_min)

    def detect_anomalies(self) -> Dict[str, Any]:
        """
        Detect gradient anomalies.

        Returns:
            Dictionary with boolean flags: ``exploding_gradients``,
            ``vanishing_gradients``, ``gradient_imbalance``,
            ``nan_gradients``, ``gradient_spikes``.
        """
        anomalies = {
            'exploding_gradients': False,
            'vanishing_gradients': False,
            'gradient_imbalance': False,
            'nan_gradients': False,
            'gradient_spikes': False
        }

        if len(self.gradient_norms) < 2:
            return anomalies

        current_norm = self.gradient_norms[-1]

        # Exploding gradients: global norm above a fixed threshold.
        if current_norm > 10.0:
            anomalies['exploding_gradients'] = True
            self.anomaly_count += 1
            self.last_anomaly_time = time.time()
            logger.warning(f"Exploding gradients detected: norm={current_norm:.6f}")

        # Vanishing gradients: global norm effectively zero.
        if current_norm < 1e-6:
            anomalies['vanishing_gradients'] = True
            self.anomaly_count += 1
            self.last_anomaly_time = time.time()
            logger.warning(f"Vanishing gradients detected: norm={current_norm:.6f}")

        # Gradient spikes: current norm far above the recent average.
        if len(self.gradient_norms) >= 10:
            recent_norms = list(self.gradient_norms)[-10:]
            avg_norm = np.mean(recent_norms[:-1])
            if current_norm > 3 * avg_norm:
                anomalies['gradient_spikes'] = True
                logger.warning(f"Gradient spike detected: {current_norm:.6f} vs avg {avg_norm:.6f}")

        # NaN gradients in either the norm or the mean history.
        if math.isnan(current_norm) or math.isnan(self.gradient_means[-1]):
            anomalies['nan_gradients'] = True
            self.anomaly_count += 1
            self.last_anomaly_time = time.time()
            logger.warning("NaN gradients detected")

        # Gradient imbalance between parameters.
        if len(self.parameter_stats) > 1:
            param_norms = [stats['norms'][-1]
                           for stats in self.parameter_stats.values()
                           if len(stats['norms']) > 0]
            if param_norms:
                lo, hi = min(param_norms), max(param_norms)
                # FIX: a zero minimum previously raised ZeroDivisionError.
                # Treat it as extreme imbalance (one parameter gets no signal).
                if hi > 0 and (lo == 0.0 or hi / lo > 1000):
                    anomalies['gradient_imbalance'] = True
                    logger.warning("Gradient imbalance detected between parameters")

        return anomalies

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get comprehensive gradient statistics.

        Returns:
            Dictionary of gradient statistics; empty dict if no data yet.
        """
        if not self.gradient_norms:
            return {}

        stats = {
            'current_norm': self.gradient_norms[-1],
            'mean_norm': np.mean(self.gradient_norms),
            'std_norm': np.std(self.gradient_norms),
            'min_norm': min(self.gradient_norms),
            'max_norm': max(self.gradient_norms),
            'current_mean': self.gradient_means[-1],
            'current_std': self.gradient_stds[-1],
            'current_max': self.gradient_maxs[-1],
            'current_min': self.gradient_mins[-1],
            'anomaly_count': self.anomaly_count,
            'parameter_count': len(self.parameter_stats)
        }

        # Per-parameter summaries.
        param_stats = {}
        for name, stats_dict in self.parameter_stats.items():
            if stats_dict['norms']:
                param_stats[name] = {
                    'current_norm': stats_dict['norms'][-1],
                    'mean_norm': np.mean(stats_dict['norms']),
                    'std_norm': np.std(stats_dict['norms'])
                }
        stats['parameter_stats'] = param_stats

        return stats

    def plot_gradients(self, save_path: Optional[str] = None):
        """
        Plot gradient statistics (norms, means, stds, and min/max range).

        Args:
            save_path: Optional path to save the plot as an image.
        """
        if not self.gradient_norms:
            logger.warning("No gradient data to plot")
            return

        import matplotlib.pyplot as plt  # lazy: only needed for plotting

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        axes[0, 0].plot(self.gradient_norms)
        axes[0, 0].set_title('Gradient Norms')
        axes[0, 0].set_xlabel('Step')
        axes[0, 0].set_ylabel('Norm')
        axes[0, 0].grid(True)

        axes[0, 1].plot(self.gradient_means)
        axes[0, 1].set_title('Gradient Means')
        axes[0, 1].set_xlabel('Step')
        axes[0, 1].set_ylabel('Mean')
        axes[0, 1].grid(True)

        axes[1, 0].plot(self.gradient_stds)
        axes[1, 0].set_title('Gradient Standard Deviations')
        axes[1, 0].set_xlabel('Step')
        axes[1, 0].set_ylabel('Std')
        axes[1, 0].grid(True)

        axes[1, 1].plot(self.gradient_maxs, label='Max')
        axes[1, 1].plot(self.gradient_mins, label='Min')
        axes[1, 1].set_title('Gradient Range')
        axes[1, 1].set_xlabel('Step')
        axes[1, 1].set_ylabel('Value')
        axes[1, 1].legend()
        axes[1, 1].grid(True)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Gradient plots saved to {save_path}")

        plt.show()


class TrainingMonitor:
    """
    Monitor training progress and performance.

    Tracks loss, accuracy, learning rates, and other training metrics to
    provide comprehensive training insights.
    """

    def __init__(self, max_history: int = 1000):
        self.max_history = max_history

        self.losses = deque(maxlen=max_history)
        self.accuracies = deque(maxlen=max_history)
        self.learning_rates = deque(maxlen=max_history)
        self.training_times = deque(maxlen=max_history)

        # Metrics bucketed per epoch (only filled when update() gets one).
        self.epoch_metrics = defaultdict(lambda: {
            'loss': deque(maxlen=max_history),
            'accuracy': deque(maxlen=max_history),
            'learning_rate': deque(maxlen=max_history)
        })

        self.best_loss = float('inf')
        self.best_accuracy = 0.0
        self.training_start_time = time.time()

        logger.info("Initialized TrainingMonitor")

    def update(self, loss: float, accuracy: Optional[float] = None,
               learning_rate: Optional[float] = None,
               epoch: Optional[int] = None):
        """
        Update training metrics.

        Args:
            loss: Current loss value
            accuracy: Current accuracy (optional)
            learning_rate: Current learning rate (optional)
            epoch: Current epoch (optional)
        """
        current_time = time.time()

        self.losses.append(loss)
        if accuracy is not None:
            self.accuracies.append(accuracy)
        if learning_rate is not None:
            self.learning_rates.append(learning_rate)
        # Wall-clock seconds since this monitor was constructed.
        self.training_times.append(current_time - self.training_start_time)

        if epoch is not None:
            self.epoch_metrics[epoch]['loss'].append(loss)
            if accuracy is not None:
                self.epoch_metrics[epoch]['accuracy'].append(accuracy)
            if learning_rate is not None:
                self.epoch_metrics[epoch]['learning_rate'].append(learning_rate)

        # Track the best values seen so far.
        if loss < self.best_loss:
            self.best_loss = loss
        if accuracy is not None and accuracy > self.best_accuracy:
            self.best_accuracy = accuracy

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get comprehensive training statistics.

        Returns:
            Dictionary of training statistics; empty dict if no data yet.
        """
        if not self.losses:
            return {}

        stats = {
            'current_loss': self.losses[-1],
            'best_loss': self.best_loss,
            'mean_loss': np.mean(self.losses),
            'std_loss': np.std(self.losses),
            'min_loss': min(self.losses),
            'max_loss': max(self.losses),
            'best_accuracy': self.best_accuracy,
            'total_steps': len(self.losses),
            'training_time': self.training_times[-1] if self.training_times else 0
        }

        if self.accuracies:
            stats.update({
                'current_accuracy': self.accuracies[-1],
                'mean_accuracy': np.mean(self.accuracies),
                'std_accuracy': np.std(self.accuracies)
            })

        if self.learning_rates:
            stats.update({
                'current_learning_rate': self.learning_rates[-1],
                'mean_learning_rate': np.mean(self.learning_rates),
                'min_learning_rate': min(self.learning_rates),
                'max_learning_rate': max(self.learning_rates)
            })

        return stats

    def detect_convergence(self, patience: int = 10,
                           threshold: float = 1e-4) -> bool:
        """
        Detect if training has converged.

        Convergence means the best loss inside the most recent ``patience``
        steps failed to improve on the best loss seen *before* that window
        by at least ``threshold``.

        FIX: the previous implementation compared ``best_loss`` (which
        already includes the recent window) against the recent minimum, so
        the "improvement" was never positive and this always returned True.

        Args:
            patience: Number of recent steps that must show no improvement
            threshold: Minimum improvement threshold

        Returns:
            True if training has converged
        """
        # Need history both inside and before the patience window.
        if len(self.losses) <= patience:
            return False

        losses = list(self.losses)
        best_before = min(losses[:-patience])
        best_recent = min(losses[-patience:])

        improvement = best_before - best_recent
        return improvement < threshold

    def plot_training_curves(self, save_path: Optional[str] = None):
        """
        Plot training curves (loss, accuracy, learning rate, wall time).

        Args:
            save_path: Optional path to save the plot as an image.
        """
        if not self.losses:
            logger.warning("No training data to plot")
            return

        import matplotlib.pyplot as plt  # lazy: only needed for plotting

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        axes[0, 0].plot(self.losses)
        axes[0, 0].set_title('Training Loss')
        axes[0, 0].set_xlabel('Step')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].grid(True)

        if self.accuracies:
            axes[0, 1].plot(self.accuracies)
            axes[0, 1].set_title('Training Accuracy')
            axes[0, 1].set_xlabel('Step')
            axes[0, 1].set_ylabel('Accuracy')
            axes[0, 1].grid(True)

        if self.learning_rates:
            axes[1, 0].plot(self.learning_rates)
            axes[1, 0].set_title('Learning Rate')
            axes[1, 0].set_xlabel('Step')
            axes[1, 0].set_ylabel('Learning Rate')
            axes[1, 0].grid(True)

        if self.training_times:
            axes[1, 1].plot(self.training_times)
            axes[1, 1].set_title('Training Time')
            axes[1, 1].set_xlabel('Step')
            axes[1, 1].set_ylabel('Time (seconds)')
            axes[1, 1].grid(True)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Training curves saved to {save_path}")

        plt.show()

    def save_metrics(self, file_path: str):
        """
        Save training metrics to a JSON file.

        Args:
            file_path: Path to save the metrics
        """
        metrics = {
            'losses': list(self.losses),
            'accuracies': list(self.accuracies),
            'learning_rates': list(self.learning_rates),
            'training_times': list(self.training_times),
            'best_loss': self.best_loss,
            'best_accuracy': self.best_accuracy,
            'statistics': self.get_statistics()
        }

        with open(file_path, 'w') as f:
            json.dump(metrics, f, indent=2)

        logger.info(f"Training metrics saved to {file_path}")

    def load_metrics(self, file_path: str):
        """
        Load training metrics from a JSON file written by save_metrics().

        Args:
            file_path: Path to load the metrics from
        """
        with open(file_path, 'r') as f:
            metrics = json.load(f)

        self.losses = deque(metrics['losses'], maxlen=self.max_history)
        self.accuracies = deque(metrics['accuracies'], maxlen=self.max_history)
        self.learning_rates = deque(metrics['learning_rates'], maxlen=self.max_history)
        self.training_times = deque(metrics['training_times'], maxlen=self.max_history)
        self.best_loss = metrics['best_loss']
        self.best_accuracy = metrics['best_accuracy']

        logger.info(f"Training metrics loaded from {file_path}")


class PerformanceMonitor:
    """
    Monitor system performance during training.

    Tracks memory usage, compute time, and other system metrics to
    optimize training efficiency.
    """

    def __init__(self):
        self.memory_usage = deque(maxlen=1000)
        self.compute_times = deque(maxlen=1000)
        self.gpu_usage = deque(maxlen=1000)

        # Elapsed-time buckets filled by StepTimer (unbounded lists).
        self.step_times = []
        self.forward_times = []
        self.backward_times = []
        self.optimizer_times = []

        logger.info("Initialized PerformanceMonitor")

    def update_memory(self, memory_mb: float):
        """Record a memory usage sample (in MB)."""
        self.memory_usage.append(memory_mb)

    def update_compute_time(self, time_seconds: float):
        """Record a compute time sample (in seconds)."""
        self.compute_times.append(time_seconds)

    def update_gpu_usage(self, gpu_percent: float):
        """Record a GPU utilization sample (in percent)."""
        self.gpu_usage.append(gpu_percent)

    def time_step(self, step_name: str):
        """Return a context manager that times the named step.

        Known names ('forward', 'backward', 'optimizer') go to their own
        buckets; anything else is recorded in ``step_times``.
        """
        return StepTimer(self, step_name)

    def get_statistics(self) -> Dict[str, Any]:
        """Get performance statistics for memory, compute time, and GPU."""
        stats = {}

        if self.memory_usage:
            stats['memory'] = {
                'current_mb': self.memory_usage[-1],
                'mean_mb': np.mean(self.memory_usage),
                'max_mb': max(self.memory_usage)
            }

        if self.compute_times:
            stats['compute'] = {
                'current_seconds': self.compute_times[-1],
                'mean_seconds': np.mean(self.compute_times),
                'total_seconds': sum(self.compute_times)
            }

        if self.gpu_usage:
            stats['gpu'] = {
                'current_percent': self.gpu_usage[-1],
                'mean_percent': np.mean(self.gpu_usage),
                'max_percent': max(self.gpu_usage)
            }

        return stats


class StepTimer:
    """Context manager that records elapsed wall time for a training step."""

    def __init__(self, monitor: PerformanceMonitor, step_name: str):
        self.monitor = monitor
        self.step_name = step_name
        self.start_time = None

    def __enter__(self):
        self.start_time = time.time()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.time() - self.start_time
        # Route the measurement to the bucket matching the step name.
        if self.step_name == 'forward':
            self.monitor.forward_times.append(elapsed)
        elif self.step_name == 'backward':
            self.monitor.backward_times.append(elapsed)
        elif self.step_name == 'optimizer':
            self.monitor.optimizer_times.append(elapsed)
        else:
            self.monitor.step_times.append(elapsed)