|
|
""" |
|
|
Training Monitoring |
|
|
================== |
|
|
|
|
|
This module provides comprehensive monitoring capabilities for gradient descent |
|
|
training, including gradient tracking, loss monitoring, and performance metrics. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import math |
|
|
import time |
|
|
from typing import Dict, Optional, Any |
|
|
import torch |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
from collections import defaultdict, deque |
|
|
import json |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class GradientMonitor:
    """
    Monitor gradient statistics during training

    Tracks gradient norms, distributions, and anomalies to ensure
    stable training and detect potential issues.

    Args:
        max_history: Maximum number of recorded steps kept per statistic.
        explode_threshold: Global-norm value above which gradients are
            reported as exploding (previously hard-coded to 10.0).
        vanish_threshold: Global-norm value below which gradients are
            reported as vanishing (previously hard-coded to 1e-6).
    """

    def __init__(self, max_history: int = 1000,
                 explode_threshold: float = 10.0,
                 vanish_threshold: float = 1e-6):
        self.max_history = max_history
        self.explode_threshold = explode_threshold
        self.vanish_threshold = vanish_threshold

        # Global (all-parameter) statistics, one entry per update() call.
        self.gradient_norms = deque(maxlen=max_history)
        self.gradient_means = deque(maxlen=max_history)
        self.gradient_stds = deque(maxlen=max_history)
        self.gradient_maxs = deque(maxlen=max_history)
        self.gradient_mins = deque(maxlen=max_history)

        # Per-parameter statistics keyed by parameter name.
        self.parameter_stats = defaultdict(lambda: {
            'norms': deque(maxlen=max_history),
            'means': deque(maxlen=max_history),
            'stds': deque(maxlen=max_history)
        })

        self.anomaly_count = 0         # total anomalies flagged so far
        self.last_anomaly_time = None  # wall-clock time of latest anomaly

        logger.info("Initialized GradientMonitor")

    def update(self, gradients: Dict[str, torch.Tensor]):
        """
        Update gradient statistics

        Args:
            gradients: Dictionary of parameter gradients; entries whose
                gradient is None are skipped. If nothing usable is
                supplied, no statistics row is recorded.
        """
        total_norm = 0.0
        total_mean = 0.0
        total_std = 0.0
        total_max = float('-inf')
        total_min = float('inf')
        param_count = 0

        for name, grad in gradients.items():
            if grad is None:
                continue

            param_norm = grad.data.norm(2).item()
            param_mean = grad.data.mean().item()
            param_std = grad.data.std().item()
            param_max = grad.data.max().item()
            param_min = grad.data.min().item()

            per_param = self.parameter_stats[name]
            per_param['norms'].append(param_norm)
            per_param['means'].append(param_mean)
            per_param['stds'].append(param_std)

            # The global norm is the 2-norm over all parameters, i.e. the
            # square root of the sum of squared per-parameter norms.
            total_norm += param_norm ** 2
            total_mean += param_mean
            total_std += param_std ** 2
            total_max = max(total_max, param_max)
            total_min = min(total_min, param_min)
            param_count += 1

        if param_count == 0:
            # No gradients supplied: record nothing rather than appending
            # misleading 0.0 / +-inf placeholder entries.
            return

        total_norm = math.sqrt(total_norm)
        total_mean /= param_count
        total_std = math.sqrt(total_std / param_count)

        self.gradient_norms.append(total_norm)
        self.gradient_means.append(total_mean)
        self.gradient_stds.append(total_std)
        self.gradient_maxs.append(total_max)
        self.gradient_mins.append(total_min)

    def detect_anomalies(self) -> Dict[str, Any]:
        """
        Detect gradient anomalies

        Returns:
            Dictionary of detected anomalies (all False until at least
            two steps have been recorded).
        """
        anomalies = {
            'exploding_gradients': False,
            'vanishing_gradients': False,
            'gradient_imbalance': False,
            'nan_gradients': False,
            'gradient_spikes': False
        }

        if len(self.gradient_norms) < 2:
            return anomalies

        current_norm = self.gradient_norms[-1]

        if current_norm > self.explode_threshold:
            anomalies['exploding_gradients'] = True
            self.anomaly_count += 1
            self.last_anomaly_time = time.time()
            logger.warning(f"Exploding gradients detected: norm={current_norm:.6f}")

        if current_norm < self.vanish_threshold:
            anomalies['vanishing_gradients'] = True
            self.anomaly_count += 1
            self.last_anomaly_time = time.time()
            logger.warning(f"Vanishing gradients detected: norm={current_norm:.6f}")

        # A spike is a norm more than 3x the average of the previous 9 steps.
        if len(self.gradient_norms) >= 10:
            recent_norms = list(self.gradient_norms)[-10:]
            avg_norm = np.mean(recent_norms[:-1])
            if current_norm > 3 * avg_norm:
                anomalies['gradient_spikes'] = True
                logger.warning(f"Gradient spike detected: {current_norm:.6f} vs avg {avg_norm:.6f}")

        if math.isnan(current_norm) or math.isnan(self.gradient_means[-1]):
            anomalies['nan_gradients'] = True
            self.anomaly_count += 1
            self.last_anomaly_time = time.time()
            logger.warning("NaN gradients detected")

        # Imbalance: the largest per-parameter norm dwarfs the smallest one.
        if len(self.parameter_stats) > 1:
            param_norms = [stats['norms'][-1] for stats in self.parameter_stats.values()
                           if len(stats['norms']) > 0]
            if param_norms:
                smallest = min(param_norms)
                # Guard against ZeroDivisionError when a gradient is all-zero.
                if smallest > 0 and max(param_norms) / smallest > 1000:
                    anomalies['gradient_imbalance'] = True
                    logger.warning("Gradient imbalance detected between parameters")

        return anomalies

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get comprehensive gradient statistics

        Returns:
            Dictionary of gradient statistics, including per-parameter
            norm summaries under 'parameter_stats'; empty dict when no
            steps have been recorded yet.
        """
        if not self.gradient_norms:
            return {}

        stats = {
            'current_norm': self.gradient_norms[-1],
            'mean_norm': np.mean(self.gradient_norms),
            'std_norm': np.std(self.gradient_norms),
            'min_norm': min(self.gradient_norms),
            'max_norm': max(self.gradient_norms),
            'current_mean': self.gradient_means[-1],
            'current_std': self.gradient_stds[-1],
            'current_max': self.gradient_maxs[-1],
            'current_min': self.gradient_mins[-1],
            'anomaly_count': self.anomaly_count,
            'parameter_count': len(self.parameter_stats)
        }

        # Per-parameter norm summaries.
        param_stats = {}
        for name, stats_dict in self.parameter_stats.items():
            if stats_dict['norms']:
                param_stats[name] = {
                    'current_norm': stats_dict['norms'][-1],
                    'mean_norm': np.mean(stats_dict['norms']),
                    'std_norm': np.std(stats_dict['norms'])
                }

        stats['parameter_stats'] = param_stats

        return stats

    def plot_gradients(self, save_path: Optional[str] = None):
        """
        Plot gradient statistics

        Draws a 2x2 figure (norms, means, stds, min/max range) and shows
        it interactively; optionally also writes it to disk.

        Args:
            save_path: Path to save the plot
        """
        if not self.gradient_norms:
            logger.warning("No gradient data to plot")
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Top-left: global gradient norm per step.
        axes[0, 0].plot(self.gradient_norms)
        axes[0, 0].set_title('Gradient Norms')
        axes[0, 0].set_xlabel('Step')
        axes[0, 0].set_ylabel('Norm')
        axes[0, 0].grid(True)

        # Top-right: mean gradient value per step.
        axes[0, 1].plot(self.gradient_means)
        axes[0, 1].set_title('Gradient Means')
        axes[0, 1].set_xlabel('Step')
        axes[0, 1].set_ylabel('Mean')
        axes[0, 1].grid(True)

        # Bottom-left: gradient standard deviation per step.
        axes[1, 0].plot(self.gradient_stds)
        axes[1, 0].set_title('Gradient Standard Deviations')
        axes[1, 0].set_xlabel('Step')
        axes[1, 0].set_ylabel('Std')
        axes[1, 0].grid(True)

        # Bottom-right: min/max gradient values per step.
        axes[1, 1].plot(self.gradient_maxs, label='Max')
        axes[1, 1].plot(self.gradient_mins, label='Min')
        axes[1, 1].set_title('Gradient Range')
        axes[1, 1].set_xlabel('Step')
        axes[1, 1].set_ylabel('Value')
        axes[1, 1].legend()
        axes[1, 1].grid(True)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Gradient plots saved to {save_path}")

        plt.show()
|
|
|
|
|
|
|
|
class TrainingMonitor:
    """
    Monitor training progress and performance

    Tracks loss, accuracy, learning rates, and other training metrics
    to provide comprehensive training insights.

    Args:
        max_history: Maximum number of steps retained per metric.
    """

    def __init__(self, max_history: int = 1000):
        self.max_history = max_history
        self.losses = deque(maxlen=max_history)
        self.accuracies = deque(maxlen=max_history)
        self.learning_rates = deque(maxlen=max_history)
        self.training_times = deque(maxlen=max_history)  # seconds since start

        # Per-epoch metric histories keyed by epoch number.
        self.epoch_metrics = defaultdict(lambda: {
            'loss': deque(maxlen=max_history),
            'accuracy': deque(maxlen=max_history),
            'learning_rate': deque(maxlen=max_history)
        })

        self.best_loss = float('inf')
        self.best_accuracy = 0.0
        self.training_start_time = time.time()

        logger.info("Initialized TrainingMonitor")

    def update(self, loss: float, accuracy: Optional[float] = None,
               learning_rate: Optional[float] = None, epoch: Optional[int] = None):
        """
        Update training metrics

        Args:
            loss: Current loss value
            accuracy: Current accuracy (optional)
            learning_rate: Current learning rate (optional)
            epoch: Current epoch (optional)
        """
        current_time = time.time()

        self.losses.append(loss)
        if accuracy is not None:
            self.accuracies.append(accuracy)
        if learning_rate is not None:
            self.learning_rates.append(learning_rate)
        self.training_times.append(current_time - self.training_start_time)

        # Also bucket the metrics by epoch when one is supplied.
        if epoch is not None:
            self.epoch_metrics[epoch]['loss'].append(loss)
            if accuracy is not None:
                self.epoch_metrics[epoch]['accuracy'].append(accuracy)
            if learning_rate is not None:
                self.epoch_metrics[epoch]['learning_rate'].append(learning_rate)

        # Track the best values seen so far.
        if loss < self.best_loss:
            self.best_loss = loss
        if accuracy is not None and accuracy > self.best_accuracy:
            self.best_accuracy = accuracy

    def get_statistics(self) -> Dict[str, Any]:
        """
        Get comprehensive training statistics

        Returns:
            Dictionary of training statistics; accuracy and learning-rate
            summaries are included only when samples exist. Empty dict
            when no losses have been recorded.
        """
        if not self.losses:
            return {}

        stats = {
            'current_loss': self.losses[-1],
            'best_loss': self.best_loss,
            'mean_loss': np.mean(self.losses),
            'std_loss': np.std(self.losses),
            'min_loss': min(self.losses),
            'max_loss': max(self.losses),
            'best_accuracy': self.best_accuracy,
            'total_steps': len(self.losses),
            'training_time': self.training_times[-1] if self.training_times else 0
        }

        if self.accuracies:
            stats.update({
                'current_accuracy': self.accuracies[-1],
                'mean_accuracy': np.mean(self.accuracies),
                'std_accuracy': np.std(self.accuracies)
            })

        if self.learning_rates:
            stats.update({
                'current_learning_rate': self.learning_rates[-1],
                'mean_learning_rate': np.mean(self.learning_rates),
                'min_learning_rate': min(self.learning_rates),
                'max_learning_rate': max(self.learning_rates)
            })

        return stats

    def detect_convergence(self, patience: int = 10, threshold: float = 1e-4) -> bool:
        """
        Detect if training has converged

        Training is considered converged when the best loss within the
        last `patience` steps failed to improve on the best loss seen
        before that window by at least `threshold`.

        Args:
            patience: Number of steps to wait for improvement
            threshold: Minimum improvement threshold

        Returns:
            True if training has converged
        """
        # Need at least one pre-window step to serve as a baseline.
        if len(self.losses) <= patience:
            return False

        # BUGFIX: comparing against self.best_loss would always report
        # convergence, because best_loss already includes the recent
        # window (improvement could never exceed 0). Compare the recent
        # window against the best loss recorded before it instead.
        history = list(self.losses)
        best_recent = min(history[-patience:])
        best_baseline = min(history[:-patience])

        improvement = best_baseline - best_recent
        return improvement < threshold

    def plot_training_curves(self, save_path: Optional[str] = None):
        """
        Plot training curves

        Draws a 2x2 figure (loss, accuracy, learning rate, elapsed time)
        and shows it interactively; optionally also writes it to disk.

        Args:
            save_path: Path to save the plot
        """
        if not self.losses:
            logger.warning("No training data to plot")
            return

        fig, axes = plt.subplots(2, 2, figsize=(12, 8))

        # Top-left: loss per step.
        axes[0, 0].plot(self.losses)
        axes[0, 0].set_title('Training Loss')
        axes[0, 0].set_xlabel('Step')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].grid(True)

        # Top-right: accuracy per step (only if recorded).
        if self.accuracies:
            axes[0, 1].plot(self.accuracies)
            axes[0, 1].set_title('Training Accuracy')
            axes[0, 1].set_xlabel('Step')
            axes[0, 1].set_ylabel('Accuracy')
            axes[0, 1].grid(True)

        # Bottom-left: learning rate per step (only if recorded).
        if self.learning_rates:
            axes[1, 0].plot(self.learning_rates)
            axes[1, 0].set_title('Learning Rate')
            axes[1, 0].set_xlabel('Step')
            axes[1, 0].set_ylabel('Learning Rate')
            axes[1, 0].grid(True)

        # Bottom-right: cumulative wall-clock training time.
        if self.training_times:
            axes[1, 1].plot(self.training_times)
            axes[1, 1].set_title('Training Time')
            axes[1, 1].set_xlabel('Step')
            axes[1, 1].set_ylabel('Time (seconds)')
            axes[1, 1].grid(True)

        plt.tight_layout()

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            logger.info(f"Training curves saved to {save_path}")

        plt.show()

    def save_metrics(self, file_path: str):
        """
        Save training metrics to file

        Serializes histories and summary statistics as JSON. Note that an
        untouched best_loss of inf is emitted as the non-standard JSON
        token Infinity, which json.load accepts back. epoch_metrics is
        not persisted.

        Args:
            file_path: Path to save the metrics
        """
        metrics = {
            'losses': list(self.losses),
            'accuracies': list(self.accuracies),
            'learning_rates': list(self.learning_rates),
            'training_times': list(self.training_times),
            'best_loss': self.best_loss,
            'best_accuracy': self.best_accuracy,
            'statistics': self.get_statistics()
        }

        with open(file_path, 'w') as f:
            json.dump(metrics, f, indent=2)

        logger.info(f"Training metrics saved to {file_path}")

    def load_metrics(self, file_path: str):
        """
        Load training metrics from file

        Restores the histories and best values written by save_metrics.
        epoch_metrics and training_start_time are not restored.

        Args:
            file_path: Path to load the metrics from
        """
        with open(file_path, 'r') as f:
            metrics = json.load(f)

        self.losses = deque(metrics['losses'], maxlen=self.max_history)
        self.accuracies = deque(metrics['accuracies'], maxlen=self.max_history)
        self.learning_rates = deque(metrics['learning_rates'], maxlen=self.max_history)
        self.training_times = deque(metrics['training_times'], maxlen=self.max_history)
        self.best_loss = metrics['best_loss']
        self.best_accuracy = metrics['best_accuracy']

        logger.info(f"Training metrics loaded from {file_path}")
|
|
|
|
|
|
|
|
class PerformanceMonitor:
    """
    Monitor system performance during training

    Collects memory-usage, compute-time and GPU-utilization samples,
    plus per-phase step timings, so training efficiency can be analyzed.
    """

    def __init__(self):
        # Rolling windows holding the most recent 1000 samples each.
        self.memory_usage = deque(maxlen=1000)
        self.compute_times = deque(maxlen=1000)
        self.gpu_usage = deque(maxlen=1000)

        # Unbounded per-phase timing lists, filled via time_step().
        self.step_times = []
        self.forward_times = []
        self.backward_times = []
        self.optimizer_times = []

        logger.info("Initialized PerformanceMonitor")

    def update_memory(self, memory_mb: float):
        """Record a memory-usage sample, in megabytes."""
        self.memory_usage.append(memory_mb)

    def update_compute_time(self, time_seconds: float):
        """Record a compute-time sample, in seconds."""
        self.compute_times.append(time_seconds)

    def update_gpu_usage(self, gpu_percent: float):
        """Record a GPU-utilization sample, as a percentage."""
        self.gpu_usage.append(gpu_percent)

    def time_step(self, step_name: str):
        """Return a StepTimer context manager for the named phase."""
        return StepTimer(self, step_name)

    def get_statistics(self) -> Dict[str, Any]:
        """Summarize collected samples; categories with no data are omitted."""
        summary = {}

        if self.memory_usage:
            summary['memory'] = {
                'current_mb': self.memory_usage[-1],
                'mean_mb': np.mean(self.memory_usage),
                'max_mb': max(self.memory_usage),
            }

        if self.compute_times:
            summary['compute'] = {
                'current_seconds': self.compute_times[-1],
                'mean_seconds': np.mean(self.compute_times),
                'total_seconds': sum(self.compute_times),
            }

        if self.gpu_usage:
            summary['gpu'] = {
                'current_percent': self.gpu_usage[-1],
                'mean_percent': np.mean(self.gpu_usage),
                'max_percent': max(self.gpu_usage),
            }

        return summary
|
|
|
|
|
|
|
|
class StepTimer:
    """Context manager for timing training steps

    Records the elapsed duration of a named training phase into the
    owning PerformanceMonitor on exit. Durations are measured with
    time.perf_counter(), which is monotonic — time.time() (previously
    used) can jump under system clock adjustments and yield negative or
    wrong intervals.
    """

    # Forward-reference annotation so this class does not depend on
    # definition order relative to PerformanceMonitor at import time.
    def __init__(self, monitor: "PerformanceMonitor", step_name: str):
        self.monitor = monitor
        self.step_name = step_name
        self.start_time = None  # set on __enter__

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        elapsed = time.perf_counter() - self.start_time

        # Route the measurement to the matching per-phase bucket; any
        # unrecognized name falls back to the generic step_times list.
        if self.step_name == 'forward':
            self.monitor.forward_times.append(elapsed)
        elif self.step_name == 'backward':
            self.monitor.backward_times.append(elapsed)
        elif self.step_name == 'optimizer':
            self.monitor.optimizer_times.append(elapsed)
        else:
            self.monitor.step_times.append(elapsed)
|
|
|