|
|
""" |
|
|
Gradient Descent Training Loop |
|
|
==============================
|
|
|
|
|
This module implements the main training loop that orchestrates gradient descent |
|
|
optimization with backpropagation for the MangoMAS multi-agent system. |
|
|
|
|
|
The training loop includes: |
|
|
- Forward and backward passes |
|
|
- Gradient computation and optimization |
|
|
- Learning rate scheduling |
|
|
- Comprehensive monitoring and logging |
|
|
- Model checkpointing and validation |
|
|
""" |
|
|
|
|
|
import logging |
|
|
import time |
|
|
import math |
|
|
from typing import Dict, List, Optional, Tuple, Any |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from pathlib import Path |
|
|
|
|
|
from .optimizers import OptimizerFactory |
|
|
from .backpropagation import BackpropagationEngine, LoRABackpropagationEngine |
|
|
from .loss_functions import LossFunctionFactory |
|
|
from .schedulers import SchedulerFactory |
|
|
from .monitoring import GradientMonitor, TrainingMonitor, PerformanceMonitor |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class GradientDescentTrainer:
    """Orchestrate gradient-descent optimization for MangoMAS models.

    Wires together an optimizer, LR scheduler, loss function and
    backpropagation engine (all produced by project factories) and provides:

    - real gradient descent with backpropagation,
    - gradient accumulation and gradient-norm clipping,
    - comprehensive gradient / training / performance monitoring,
    - model checkpointing and validation.
    """

    def __init__(self,
                 optimizer_type: str = 'adam',
                 learning_rate: float = 1e-3,
                 scheduler_type: str = 'cosine',
                 loss_function_type: str = 'cross_entropy',
                 device: torch.device = None,
                 max_grad_norm: float = 1.0,
                 gradient_accumulation_steps: int = 1,
                 mixed_precision: bool = False,
                 **kwargs):
        """Store configuration; heavy objects are built in ``setup_training``.

        Args:
            optimizer_type: Key understood by ``OptimizerFactory``.
            learning_rate: Initial learning rate.
            scheduler_type: Key understood by ``SchedulerFactory``.
            loss_function_type: Key understood by ``LossFunctionFactory``.
            device: Target device; defaults to CUDA when available.
            max_grad_norm: Gradient-norm clipping threshold.
            gradient_accumulation_steps: Number of batches whose gradients are
                accumulated before each optimizer step.
            mixed_precision: Recorded in the config; not otherwise used by
                this class (NOTE(review): no autocast/scaler here — confirm
                whether a caller consumes this flag).
            **kwargs: Extra options recorded in ``self.config``.
        """
        self.optimizer_type = optimizer_type
        self.learning_rate = learning_rate
        self.scheduler_type = scheduler_type
        self.loss_function_type = loss_function_type
        self.device = device or torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.max_grad_norm = max_grad_norm
        self.gradient_accumulation_steps = gradient_accumulation_steps
        self.mixed_precision = mixed_precision

        # Built lazily by setup_training() once the model is known.
        self.optimizer = None
        self.scheduler = None
        self.loss_function = None
        self.backprop_engine = None

        # Monitors accumulate state across the whole training run.
        self.gradient_monitor = GradientMonitor()
        self.training_monitor = TrainingMonitor()
        self.performance_monitor = PerformanceMonitor()

        # Training-progress state.
        self.current_epoch = 0
        self.current_step = 0
        self.best_loss = float('inf')
        self.training_start_time = None

        # Full configuration snapshot (also written into checkpoints).
        self.config = {
            'optimizer_type': optimizer_type,
            'learning_rate': learning_rate,
            'scheduler_type': scheduler_type,
            'loss_function_type': loss_function_type,
            'max_grad_norm': max_grad_norm,
            'gradient_accumulation_steps': gradient_accumulation_steps,
            'mixed_precision': mixed_precision,
            **kwargs
        }

        logger.info(f"Initialized GradientDescentTrainer with config: {self.config}")

    def setup_training(self, model: nn.Module, training_data: List[Dict[str, Any]]):
        """Build optimizer, scheduler, loss function and backprop engine.

        Args:
            model: The neural network model to train (moved to ``self.device``).
            training_data: Training dataset; its length seeds the scheduler's
                ``total_steps`` (NOTE(review): this is samples, not optimizer
                steps — confirm what SchedulerFactory expects).
        """
        logger.info("Setting up training components...")

        model.to(self.device)

        trainable_params = [p for p in model.parameters() if p.requires_grad]
        # len() counts parameter *tensors*, not individual scalar weights.
        logger.info(f"Found {len(trainable_params)} trainable parameter tensors")

        optimizer_config = OptimizerFactory.get_default_config(self.optimizer_type)
        optimizer_config.update({'lr': self.learning_rate})

        self.optimizer = OptimizerFactory.create_optimizer(
            self.optimizer_type, trainable_params, **optimizer_config
        )

        scheduler_config = SchedulerFactory.get_default_config(self.scheduler_type)
        scheduler_config.update({'total_steps': len(training_data)})

        self.scheduler = SchedulerFactory.create_scheduler(
            self.scheduler_type, self.optimizer, **scheduler_config
        )

        loss_config = LossFunctionFactory.get_default_config(self.loss_function_type)
        self.loss_function = LossFunctionFactory.create_loss_function(
            self.loss_function_type, **loss_config
        )

        # LoRA models expose their adapter tensors via `lora_params`; use the
        # specialized engine so only those receive gradient treatment.
        if hasattr(model, 'lora_params'):
            self.backprop_engine = LoRABackpropagationEngine(
                model, model.lora_params, self.device
            )
        else:
            self.backprop_engine = BackpropagationEngine(model, self.device)

        logger.info("Training setup complete")

    def train_epoch(self, model: nn.Module, training_data: List[Dict[str, Any]],
                    epoch: int) -> Dict[str, float]:
        """Train for one epoch using gradient descent and backpropagation.

        Args:
            model: The neural network model.
            training_data: Training dataset (list of example dicts).
            epoch: Current epoch number.

        Returns:
            Dict with 'loss', 'accuracy', 'learning_rate' and 'num_batches'.
        """
        logger.info(f"Starting epoch {epoch}")

        model.train()
        epoch_loss = 0.0
        epoch_accuracy = 0.0

        batch_size = 32
        num_batches = math.ceil(len(training_data) / batch_size)

        # Guard the empty-dataset case: the original divided by num_batches
        # unconditionally, raising ZeroDivisionError on empty input.
        if num_batches == 0:
            logger.warning(f"Epoch {epoch}: no training data, skipping epoch")
            return {
                'loss': 0.0,
                'accuracy': 0.0,
                'learning_rate': getattr(self.optimizer, 'lr', 0.0),
                'num_batches': 0
            }

        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min(start_idx + batch_size, len(training_data))
            batch_data = training_data[start_idx:end_idx]

            batch_metrics = self.train_batch(model, batch_data, epoch, batch_idx)

            epoch_loss += batch_metrics['loss']
            epoch_accuracy += batch_metrics.get('accuracy', 0.0)

            self.current_step += 1

            if batch_idx % 10 == 0:
                logger.info(f"Epoch {epoch}, Batch {batch_idx}/{num_batches}, "
                            f"Loss: {batch_metrics['loss']:.4f}")

        avg_loss = epoch_loss / num_batches
        avg_accuracy = epoch_accuracy / num_batches

        # NOTE(review): assumes the factory optimizer exposes `.lr`; stock
        # torch optimizers use param_groups[0]['lr'] — confirm the wrapper.
        self.training_monitor.update(
            loss=avg_loss,
            accuracy=avg_accuracy,
            learning_rate=self.optimizer.lr,
            epoch=epoch
        )

        # Project scheduler API: stepped once per epoch with the epoch loss
        # (e.g. for plateau-style schedules).
        self.scheduler.step(epoch=epoch, metrics={'loss': avg_loss})

        logger.info(f"Epoch {epoch} complete - Loss: {avg_loss:.4f}, "
                    f"Accuracy: {avg_accuracy:.4f}, LR: {self.optimizer.lr:.6f}")

        return {
            'loss': avg_loss,
            'accuracy': avg_accuracy,
            'learning_rate': self.optimizer.lr,
            'num_batches': num_batches
        }

    def train_batch(self, model: nn.Module, batch_data: List[Dict[str, Any]],
                    epoch: int, batch_idx: int) -> Dict[str, float]:
        """Train on a single batch using gradient descent and backpropagation.

        Args:
            model: The neural network model.
            batch_data: Batch of training data.
            epoch: Current epoch number.
            batch_idx: Current batch index (drives gradient accumulation).

        Returns:
            Dict with 'loss' (unscaled), 'accuracy' and 'grad_norm' (0.0 on
            non-boundary accumulation batches).
        """
        batch_start = time.time()

        inputs, targets = self._prepare_batch(batch_data)

        with self.performance_monitor.time_step('forward'):
            outputs = model(inputs)

        loss = self.loss_function(outputs, targets)

        # Scale so accumulated gradients match one large batch.
        if self.gradient_accumulation_steps > 1:
            loss = loss / self.gradient_accumulation_steps

        # Single backward pass through the engine, which also returns the
        # per-parameter gradients for monitoring.  The original called
        # loss.backward() AND compute_gradients() on the same loss, which is
        # a second backward through an already-freed autograd graph.
        with self.performance_monitor.time_step('backward'):
            gradients = self.backprop_engine.compute_gradients(loss, retain_graph=False)

        grad_norm = 0.0
        # Step the optimizer only on accumulation boundaries.
        if (batch_idx + 1) % self.gradient_accumulation_steps == 0:
            grad_norm = self.backprop_engine.apply_gradient_clipping(self.max_grad_norm)
            self.gradient_monitor.update(gradients)

            with self.performance_monitor.time_step('optimizer'):
                self.optimizer.step()

            self.optimizer.zero_grad()

        # Per-batch wall-clock time.  The original passed the total elapsed
        # time since training start (and crashed if train() was never called,
        # leaving training_start_time as None).
        self.performance_monitor.update_compute_time(time.time() - batch_start)

        accuracy = self._compute_accuracy(outputs, targets)

        return {
            'loss': loss.item() * self.gradient_accumulation_steps,
            'accuracy': accuracy,
            'grad_norm': grad_norm
        }

    def _prepare_batch(self, batch_data: List[Dict[str, Any]]) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize and pad a batch into (inputs, targets) LongTensors.

        Items need 'instruction' and 'response' keys; sequences are padded
        with token id 0 to the longest sequence in the batch.  If no item is
        usable, random token tensors are returned as a fallback so the
        training loop can still exercise the model.

        Args:
            batch_data: Raw batch data.

        Returns:
            Tuple of (inputs, targets) tensors on ``self.device``.
        """
        inputs = []
        targets = []

        for item in batch_data:
            if 'instruction' in item and 'response' in item:
                input_tokens = self._simple_tokenize(item['instruction'])
                target_tokens = self._simple_tokenize(item['response'])

                inputs.append(input_tokens)
                targets.append(target_tokens)

        if inputs and targets:
            # Pad every sequence (inputs AND targets) to the global max length
            # so both tensors share one shape.
            max_len = max(len(seq) for seq in inputs + targets)
            inputs = [seq + [0] * (max_len - len(seq)) for seq in inputs]
            targets = [seq + [0] * (max_len - len(seq)) for seq in targets]

            inputs_tensor = torch.tensor(inputs, dtype=torch.long, device=self.device)
            targets_tensor = torch.tensor(targets, dtype=torch.long, device=self.device)
        else:
            # Fallback: random token ids (vocab 1000, seq_len 128).
            batch_size = len(batch_data)
            seq_len = 128
            inputs_tensor = torch.randint(0, 1000, (batch_size, seq_len), device=self.device)
            targets_tensor = torch.randint(0, 1000, (batch_size, seq_len), device=self.device)

        return inputs_tensor, targets_tensor

    def _simple_tokenize(self, text: str) -> List[int]:
        """Character-level demo tokenizer: ord(char) % 1000, max 100 chars.

        Args:
            text: Input text.

        Returns:
            List of token IDs (possibly empty).
        """
        return [ord(char) % 1000 for char in text[:100]]

    def _compute_accuracy(self, outputs: torch.Tensor, targets: torch.Tensor) -> float:
        """Compute classification accuracy for the batch.

        Only handles (batch, classes) outputs against 1-D class-index
        targets; every other shape yields 0.0.

        Args:
            outputs: Model outputs.
            targets: Target values.

        Returns:
            Accuracy as a Python float.
        """
        if outputs.dim() > 1 and outputs.size(1) > 1:
            predictions = torch.argmax(outputs, dim=1)
            if targets.dim() == 1:
                correct = (predictions == targets).float().sum()
                accuracy = correct / targets.size(0)
            else:
                # Sequence/2-D targets: accuracy not defined here.
                accuracy = 0.0
        else:
            # Regression-style outputs: accuracy not defined here.
            accuracy = 0.0

        return accuracy.item() if isinstance(accuracy, torch.Tensor) else accuracy

    def validate(self, model: nn.Module, validation_data: List[Dict[str, Any]]) -> Dict[str, float]:
        """Evaluate the model on validation data without gradients.

        Args:
            model: The neural network model.
            validation_data: Validation dataset.

        Returns:
            Dict with 'val_loss' and 'val_accuracy' (0.0 on empty data).
        """
        logger.info("Running validation...")

        model.eval()
        total_loss = 0.0
        total_accuracy = 0.0
        num_batches = 0

        with torch.no_grad():
            batch_size = 32
            num_batches = math.ceil(len(validation_data) / batch_size)

            for batch_idx in range(num_batches):
                start_idx = batch_idx * batch_size
                end_idx = min(start_idx + batch_size, len(validation_data))
                batch_data = validation_data[start_idx:end_idx]

                inputs, targets = self._prepare_batch(batch_data)
                outputs = model(inputs)

                loss = self.loss_function(outputs, targets)
                total_loss += loss.item()

                total_accuracy += self._compute_accuracy(outputs, targets)

        avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
        avg_accuracy = total_accuracy / num_batches if num_batches > 0 else 0.0

        logger.info(f"Validation - Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.4f}")

        return {
            'val_loss': avg_loss,
            'val_accuracy': avg_accuracy
        }

    def train(self, model: nn.Module, training_data: List[Dict[str, Any]],
              validation_data: Optional[List[Dict[str, Any]]] = None,
              num_epochs: int = 10, save_dir: Optional[str] = None) -> Dict[str, Any]:
        """Complete training loop with gradient descent and backpropagation.

        Args:
            model: The neural network model to train.
            training_data: Training dataset.
            validation_data: Validation dataset (optional).  Note that
                ``best_loss`` is only tracked when validation data is given.
            num_epochs: Number of training epochs.
            save_dir: Directory to save per-epoch checkpoints (optional).

        Returns:
            Dict with training/validation history, final metrics, monitor
            statistics and the trainer config.
        """
        logger.info(f"Starting training for {num_epochs} epochs")

        self.setup_training(model, training_data)

        self.training_start_time = time.time()
        self.current_epoch = 0
        self.current_step = 0

        training_history = []
        validation_history = []

        for epoch in range(num_epochs):
            self.current_epoch = epoch

            epoch_metrics = self.train_epoch(model, training_data, epoch)
            training_history.append(epoch_metrics)

            if validation_data:
                val_metrics = self.validate(model, validation_data)
                validation_history.append(val_metrics)

                if val_metrics['val_loss'] < self.best_loss:
                    self.best_loss = val_metrics['val_loss']

            # Checkpoint every epoch (not only on improvement).
            if save_dir:
                self.save_checkpoint(model, save_dir, epoch, epoch_metrics)

            if self.training_monitor.detect_convergence():
                logger.info("Training converged, stopping early")
                break

            logger.info(f"Epoch {epoch} Summary:")
            logger.info(f"  Training Loss: {epoch_metrics['loss']:.4f}")
            logger.info(f"  Training Accuracy: {epoch_metrics['accuracy']:.4f}")
            if validation_data:
                logger.info(f"  Validation Loss: {val_metrics['val_loss']:.4f}")
                logger.info(f"  Validation Accuracy: {val_metrics['val_accuracy']:.4f}")
            logger.info(f"  Learning Rate: {self.optimizer.lr:.6f}")

        training_time = time.time() - self.training_start_time

        gradient_stats = self.gradient_monitor.get_statistics()
        training_stats = self.training_monitor.get_statistics()
        performance_stats = self.performance_monitor.get_statistics()

        results = {
            'training_history': training_history,
            'validation_history': validation_history,
            'final_metrics': {
                'best_loss': self.best_loss,
                'final_loss': training_history[-1]['loss'] if training_history else 0.0,
                'final_accuracy': training_history[-1]['accuracy'] if training_history else 0.0,
                'training_time': training_time,
                'total_steps': self.current_step,
                'total_epochs': self.current_epoch + 1
            },
            'gradient_stats': gradient_stats,
            'training_stats': training_stats,
            'performance_stats': performance_stats,
            'config': self.config
        }

        logger.info("Training complete!")
        logger.info(f"Final Loss: {results['final_metrics']['final_loss']:.4f}")
        logger.info(f"Best Loss: {results['final_metrics']['best_loss']:.4f}")
        logger.info(f"Training Time: {training_time:.2f} seconds")

        return results

    def save_checkpoint(self, model: nn.Module, save_dir: str, epoch: int,
                        metrics: Dict[str, float]):
        """Save a full training checkpoint to ``save_dir``.

        Args:
            model: The neural network model.
            save_dir: Directory to save the checkpoint (created if missing).
            epoch: Current epoch (encoded in the filename).
            metrics: Training metrics to embed in the checkpoint.
        """
        save_path = Path(save_dir)
        save_path.mkdir(parents=True, exist_ok=True)

        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_loss': self.best_loss,
            'metrics': metrics,
            'config': self.config
        }

        checkpoint_path = save_path / f'checkpoint_epoch_{epoch}.pt'
        torch.save(checkpoint, checkpoint_path)

        logger.info(f"Checkpoint saved to {checkpoint_path}")

    def load_checkpoint(self, model: nn.Module, checkpoint_path: str):
        """Restore model/optimizer/scheduler state from a checkpoint file.

        Requires ``setup_training`` to have been called first so that
        ``self.optimizer`` and ``self.scheduler`` exist.

        Args:
            model: The neural network model.
            checkpoint_path: Path to the checkpoint file.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.best_loss = checkpoint['best_loss']

        logger.info(f"Checkpoint loaded from {checkpoint_path}")

    def get_training_summary(self) -> Dict[str, Any]:
        """Aggregate monitor statistics into one summary dict.

        Returns:
            Dict with gradient/training/performance statistics, detected
            anomalies and the current convergence verdict.
        """
        return {
            'gradient_stats': self.gradient_monitor.get_statistics(),
            'training_stats': self.training_monitor.get_statistics(),
            'performance_stats': self.performance_monitor.get_statistics(),
            'anomalies': self.gradient_monitor.detect_anomalies(),
            'convergence': self.training_monitor.detect_convergence()
        }
|
|
|