"""
Real PyTorch Training Engine for the ML Training Optimizer environment.

Manages the full training lifecycle:
- Model creation with configurable architecture and dropout
- Optimizer creation (SGD, Adam, AdamW) with configurable hyperparameters
- Learning rate scheduling (constant, step, cosine, warmup_cosine)
- Data augmentation control
- Training loop execution
- Evaluation and metric tracking
- Best model checkpointing
"""

import math
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import torch
import torch.nn as nn
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR, LambdaLR

from .models_nn import create_model
from .datasets import create_dataloaders


@dataclass
class TrainingConfig:
    """Configuration for a training run, set by the agent."""

    optimizer: str = "adam"  # sgd, adam, adamw
    learning_rate: float = 0.001
    batch_size: int = 64
    weight_decay: float = 0.0
    dropout: float = 0.0
    lr_schedule: str = "constant"  # constant, step, cosine, warmup_cosine
    warmup_epochs: int = 5  # only read by the warmup_cosine schedule
    augmentation: bool = False
    augmentation_strength: float = 0.5  # passed to create_dataloaders as aug_strength


@dataclass
class TrainingMetrics:
    """Metrics for a single epoch, as returned by Trainer.train_epochs()."""

    train_loss: float = 0.0
    val_loss: float = 0.0
    train_accuracy: float = 0.0
    val_accuracy: float = 0.0
    current_lr: float = 0.0
    epoch_time_seconds: float = 0.0


@dataclass
class TrainingState:
    """Full internal state of the trainer."""

    current_epoch: int = 0  # epochs completed since the model was (re)built
    total_epochs_run: int = 0
    best_val_accuracy: float = 0.0
    best_val_epoch: int = 0  # value of current_epoch when the best val accuracy was seen; 0 = none yet
    train_loss_history: List[float] = field(default_factory=list)
    val_loss_history: List[float] = field(default_factory=list)
    train_acc_history: List[float] = field(default_factory=list)
    val_acc_history: List[float] = field(default_factory=list)
    lr_history: List[float] = field(default_factory=list)
    is_diverged: bool = False
    config: Optional[TrainingConfig] = None


class Trainer:
    """
    Manages real PyTorch model training on CPU.

    Designed to be used by the environment: the agent configures
    hyperparameters, then calls train_epochs() to actually train the model.
    The trainer tracks all metrics and maintains the best model checkpoint.
    """

    # Names accepted by configure(); must match the branches in
    # _build_optimizer() and _build_scheduler().
    _OPTIMIZERS = ("sgd", "adam", "adamw")
    _LR_SCHEDULES = ("constant", "step", "cosine", "warmup_cosine")
    # A mean training loss above this value (or a non-finite one) marks the run as diverged.
    DIVERGENCE_LOSS_THRESHOLD = 100.0

    def __init__(
        self,
        model_type: str,
        dataset_name: str,
        max_epochs: int,
        seed: int = 42,
    ):
        """
        Args:
            model_type: Architecture name forwarded to create_model().
            dataset_name: Dataset name forwarded to create_dataloaders().
            max_epochs: Epoch budget; train_epochs() never runs past it.
            seed: Seed for torch.manual_seed (also re-applied on every model
                rebuild) and for create_dataloaders().
        """
        self.model_type = model_type
        self.dataset_name = dataset_name
        self.max_epochs = max_epochs
        self.seed = seed

        # Set deterministic behavior (these torch settings are process-global).
        torch.manual_seed(seed)
        torch.set_num_threads(2)  # Match 2 vCPU constraint
        if hasattr(torch, 'use_deterministic_algorithms'):
            try:
                torch.use_deterministic_algorithms(True)
            except Exception:
                pass  # Best effort: some ops/builds have no deterministic impl

        self.state = TrainingState()
        self.model: Optional[nn.Module] = None
        self.optimizer: Optional[torch.optim.Optimizer] = None
        self.scheduler = None
        self.train_loader = None
        self.val_loader = None
        self.best_model_state: Optional[dict] = None
        self.device = torch.device("cpu")
        self.criterion = nn.CrossEntropyLoss()
        self._initialized = False

    def configure(self, config: TrainingConfig) -> str:
        """
        Configure (or reconfigure) the training setup.

        Creates/rebuilds model, optimizer, scheduler, and data loaders based
        on the provided configuration. The model (and all training state) is
        rebuilt on first use, when dropout changes, or when the previous run
        diverged. Data loaders are rebuilt when batch size or augmentation
        settings change. Optimizer and scheduler are always rebuilt.

        Returns:
            Status message describing what was configured.

        Raises:
            ValueError: unknown optimizer / LR schedule, or a negative
                learning rate / weight decay. Raised before any state is
                modified, so a rejected config leaves the trainer untouched.
        """
        # Validate first: previously a bad value was only detected inside
        # _build_optimizer/_build_scheduler, after the model may already have
        # been rebuilt, leaving an optimizer bound to the old model's params.
        self._validate_config(config)

        old_config = self.state.config
        was_diverged = self.state.is_diverged
        self.state.config = config

        # Rebuild model if dropout changed, first time, or after divergence
        # (diverged weights are NaN/inf and is_diverged is only cleared here).
        rebuild_model = (
            not self._initialized
            or old_config is None
            or old_config.dropout != config.dropout
            or was_diverged
        )
        if rebuild_model:
            torch.manual_seed(self.seed)
            self.model = create_model(self.model_type, dropout=config.dropout)
            self.model.to(self.device)
            self._reset_training_state()

        # Rebuild data loaders if batch size or augmentation changed
        rebuild_data = (
            not self._initialized
            or old_config is None
            or old_config.batch_size != config.batch_size
            or old_config.augmentation != config.augmentation
            or old_config.augmentation_strength != config.augmentation_strength
        )
        if rebuild_data:
            self.train_loader, self.val_loader, _, _ = create_dataloaders(
                self.dataset_name,
                batch_size=config.batch_size,
                seed=self.seed,
                augment=config.augmentation,
                aug_strength=config.augmentation_strength,
            )

        # Always rebuild optimizer (LR, weight decay, or optimizer type may change)
        self._build_optimizer(config)

        # Always rebuild scheduler, spanning only the remaining epoch budget
        remaining = max(1, self.max_epochs - self.state.current_epoch)
        self._build_scheduler(config, remaining)

        self._initialized = True

        status_parts = []
        if rebuild_model:
            if was_diverged:
                status_parts.append("Previous run diverged; model re-initialized")
            param_count = sum(p.numel() for p in self.model.parameters())
            status_parts.append(f"Model created: {self.model_type} ({param_count:,} params, dropout={config.dropout})")
        if rebuild_data:
            status_parts.append(f"Data loaded: {self.dataset_name} (batch_size={config.batch_size}, aug={'on' if config.augmentation else 'off'})")
        status_parts.append(f"Optimizer: {config.optimizer} (lr={config.learning_rate}, wd={config.weight_decay})")
        status_parts.append(f"Schedule: {config.lr_schedule}")

        return " | ".join(status_parts)

    def _validate_config(self, config: TrainingConfig) -> None:
        """Raise ValueError for config values the optimizer/scheduler builders would reject."""
        if config.optimizer not in self._OPTIMIZERS:
            raise ValueError(f"Unknown optimizer: {config.optimizer}")
        if config.lr_schedule not in self._LR_SCHEDULES:
            raise ValueError(f"Unknown LR schedule: {config.lr_schedule}")
        # `not (x >= 0)` also rejects NaN.
        if not config.learning_rate >= 0.0:
            raise ValueError(f"Invalid learning rate: {config.learning_rate}")
        if not config.weight_decay >= 0.0:
            raise ValueError(f"Invalid weight_decay value: {config.weight_decay}")

    def _reset_training_state(self) -> None:
        """Clear epoch counters, histories, divergence flag and best checkpoint (config is kept)."""
        # Reset in place rather than replacing self.state, in case callers hold a reference to it.
        self.state.current_epoch = 0
        self.state.total_epochs_run = 0
        self.state.best_val_accuracy = 0.0
        self.state.best_val_epoch = 0
        self.state.train_loss_history = []
        self.state.val_loss_history = []
        self.state.train_acc_history = []
        self.state.val_acc_history = []
        self.state.lr_history = []
        self.state.is_diverged = False
        self.best_model_state = None

    def _build_optimizer(self, config: TrainingConfig):
        """Create a fresh optimizer over the current model's parameters from config."""
        params = self.model.parameters()
        if config.optimizer == "sgd":
            self.optimizer = torch.optim.SGD(
                params,
                lr=config.learning_rate,
                weight_decay=config.weight_decay,
                momentum=0.9,
            )
        elif config.optimizer == "adam":
            self.optimizer = torch.optim.Adam(
                params,
                lr=config.learning_rate,
                weight_decay=config.weight_decay,
            )
        elif config.optimizer == "adamw":
            self.optimizer = torch.optim.AdamW(
                params,
                lr=config.learning_rate,
                weight_decay=config.weight_decay,
            )
        else:
            raise ValueError(f"Unknown optimizer: {config.optimizer}")

    def _build_scheduler(self, config: TrainingConfig, total_epochs: int):
        """
        Create the learning rate scheduler from config.

        Args:
            config: Source of lr_schedule and warmup_epochs.
            total_epochs: Number of epochs the schedule should span; the
                scheduler is stepped once per epoch in train_epochs().
        """
        if config.lr_schedule == "constant":
            self.scheduler = None
        elif config.lr_schedule == "step":
            # Three decay steps of 10x across the remaining budget.
            self.scheduler = StepLR(self.optimizer, step_size=max(1, total_epochs // 3), gamma=0.1)
        elif config.lr_schedule == "cosine":
            self.scheduler = CosineAnnealingLR(self.optimizer, T_max=total_epochs)
        elif config.lr_schedule == "warmup_cosine":
            warmup = config.warmup_epochs

            def lr_lambda(epoch):
                # Linear warmup to 1.0, then cosine decay to 0 at total_epochs.
                if epoch < warmup:
                    return (epoch + 1) / warmup
                progress = (epoch - warmup) / max(1, total_epochs - warmup)
                return 0.5 * (1.0 + math.cos(math.pi * progress))

            self.scheduler = LambdaLR(self.optimizer, lr_lambda)
        else:
            raise ValueError(f"Unknown LR schedule: {config.lr_schedule}")

    def adjust_learning_rate(self, new_lr: float) -> str:
        """
        Adjust the learning rate on the live optimizer.

        If a scheduler is active, its base learning rates are rescaled by the
        same factor so the rest of the schedule continues from the adjusted
        level. Without this, LambdaLR (warmup_cosine) would recompute the LR
        from its original base on the next step() and silently undo the change.

        Returns:
            A status string; starts with "Error:" if nothing was changed.
        """
        if self.optimizer is None:
            return "Error: No optimizer configured. Call configure_training first."
        if not math.isfinite(new_lr) or new_lr < 0:
            return f"Error: Invalid learning rate {new_lr}; it must be a finite, non-negative number."

        old_lrs = [param_group['lr'] for param_group in self.optimizer.param_groups]
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr

        if self.scheduler is not None:
            # If the old LR was 0 there is no ratio to apply; use new_lr as the base.
            self.scheduler.base_lrs = [
                base_lr * (new_lr / old_lr) if old_lr > 0 else new_lr
                for base_lr, old_lr in zip(self.scheduler.base_lrs, old_lrs)
            ]

        if self.state.config:
            self.state.config.learning_rate = new_lr
        return f"Learning rate adjusted to {new_lr}"

    def train_epochs(self, num_epochs: int) -> List[TrainingMetrics]:
        """
        Train for num_epochs epochs (clamped to the remaining budget).

        This runs REAL PyTorch training — forward pass, loss computation,
        backward pass, optimizer step — on actual data.

        Returns:
            One TrainingMetrics per completed epoch. If the run diverges, the
            last entry has infinite losses and training stops; once diverged,
            every call returns a single such entry until configure() is called
            again. Returns [] if the epoch budget is exhausted.

        Raises:
            RuntimeError: if configure() has not been called.
        """
        if not self._initialized or self.model is None:
            raise RuntimeError("Trainer not configured. Call configure() first.")

        if self.state.is_diverged:
            return [TrainingMetrics(
                train_loss=float('inf'),
                val_loss=float('inf'),
                train_accuracy=0.0,
                val_accuracy=0.0,
                current_lr=0.0,
                epoch_time_seconds=0.0,
            )]

        # Clamp to budget
        remaining = self.max_epochs - self.state.current_epoch
        actual_epochs = min(num_epochs, remaining)
        if actual_epochs <= 0:
            return []

        all_metrics: List[TrainingMetrics] = []

        for _ in range(actual_epochs):
            epoch_start = time.perf_counter()  # monotonic, unlike time.time()

            train_loss, train_acc = self._train_one_epoch()

            # Check for divergence. A diverged epoch is not counted against
            # the budget and is not recorded in the histories.
            if not math.isfinite(train_loss) or train_loss > self.DIVERGENCE_LOSS_THRESHOLD:
                self.state.is_diverged = True
                all_metrics.append(TrainingMetrics(
                    train_loss=float('inf'),
                    val_loss=float('inf'),
                    train_accuracy=0.0,
                    val_accuracy=0.0,
                    current_lr=self.optimizer.param_groups[0]['lr'],
                    epoch_time_seconds=time.perf_counter() - epoch_start,
                ))
                break

            val_loss, val_acc = self._evaluate()

            # Step scheduler once per epoch. NOTE: current_lr is read after
            # the step, so it is the LR the *next* epoch will use.
            if self.scheduler is not None:
                self.scheduler.step()
            current_lr = self.optimizer.param_groups[0]['lr']

            # Update state
            self.state.current_epoch += 1
            self.state.total_epochs_run += 1
            self.state.train_loss_history.append(train_loss)
            self.state.val_loss_history.append(val_loss)
            self.state.train_acc_history.append(train_acc)
            self.state.val_acc_history.append(val_acc)
            self.state.lr_history.append(current_lr)

            # Track best model (strict improvement keeps the earliest best epoch)
            if val_acc > self.state.best_val_accuracy:
                self.state.best_val_accuracy = val_acc
                self.state.best_val_epoch = self.state.current_epoch
                self.best_model_state = {
                    k: v.clone() for k, v in self.model.state_dict().items()
                }

            all_metrics.append(TrainingMetrics(
                train_loss=train_loss,
                val_loss=val_loss,
                train_accuracy=train_acc,
                val_accuracy=val_acc,
                current_lr=current_lr,
                epoch_time_seconds=time.perf_counter() - epoch_start,
            ))

        return all_metrics

    def _train_one_epoch(self) -> Tuple[float, float]:
        """Run one optimization pass over the training set. Returns (mean_loss, accuracy)."""
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, targets in self.train_loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            loss = self.criterion(outputs, targets)
            loss.backward()
            self.optimizer.step()

            batch_loss = loss.item()
            running_loss += batch_loss * inputs.size(0)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            if not math.isfinite(batch_loss):
                # Cross-entropy is non-negative, so one inf/NaN batch makes the
                # epoch mean non-finite and the run is flagged as diverged
                # regardless; stop early instead of wasting the rest of the epoch.
                break

        if total == 0:
            raise RuntimeError("Training data loader yielded no samples.")
        return running_loss / total, correct / total

    def _evaluate(self) -> Tuple[float, float]:
        """Evaluate model on validation set. Returns (val_loss, val_accuracy)."""
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets in self.val_loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = self.model(inputs)
                loss = self.criterion(outputs, targets)

                running_loss += loss.item() * inputs.size(0)
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        if total == 0:
            raise RuntimeError("Validation data loader yielded no samples.")
        return running_loss / total, correct / total

    def get_metrics_summary(self) -> Dict:
        """Get current training metrics as a dictionary."""
        if not self.state.train_loss_history:
            # No completed epoch yet, but the very first epoch may already
            # have diverged (diverged epochs are not recorded in history).
            return {
                "current_epoch": self.state.current_epoch,
                "max_epochs": self.max_epochs,
                "remaining_budget": self.max_epochs - self.state.current_epoch,
                "train_loss": 0.0,
                "val_loss": 0.0,
                "train_accuracy": 0.0,
                "val_accuracy": 0.0,
                "best_val_accuracy": 0.0,
                "best_val_epoch": 0,
                "loss_history_last_10": [],
                "val_loss_history_last_10": [],
                "convergence_signal": "diverged" if self.state.is_diverged else "not_started",
                "is_diverged": self.state.is_diverged,
            }

        signal = self._get_convergence_signal()

        return {
            "current_epoch": self.state.current_epoch,
            "max_epochs": self.max_epochs,
            "remaining_budget": self.max_epochs - self.state.current_epoch,
            "train_loss": round(self.state.train_loss_history[-1], 6),
            "val_loss": round(self.state.val_loss_history[-1], 6),
            "train_accuracy": round(self.state.train_acc_history[-1], 4),
            "val_accuracy": round(self.state.val_acc_history[-1], 4),
            "best_val_accuracy": round(self.state.best_val_accuracy, 4),
            "best_val_epoch": self.state.best_val_epoch,
            "loss_history_last_10": [round(x, 4) for x in self.state.train_loss_history[-10:]],
            "val_loss_history_last_10": [round(x, 4) for x in self.state.val_loss_history[-10:]],
            "convergence_signal": signal,
            "is_diverged": self.state.is_diverged,
        }

    def _get_convergence_signal(self) -> str:
        """
        Analyze recent training dynamics to produce a human-readable signal.

        Returns one of: "diverged", "warming_up", "overfitting", "plateaued",
        "improving", "stalling" (checked in that order over the last 5 epochs).
        """
        if self.state.is_diverged:
            return "diverged"

        if len(self.state.val_loss_history) < 3:
            return "warming_up"

        # The guard above guarantees at least 3 points in each window
        # (train and val histories are always appended together).
        recent_val = self.state.val_loss_history[-5:]
        recent_train = self.state.train_loss_history[-5:]

        # Overfitting: train loss improving while val loss gets worse
        val_trend = recent_val[-1] - recent_val[0]
        train_trend = recent_train[-1] - recent_train[0]
        if val_trend > 0 and train_trend < 0:
            return "overfitting"

        # Plateau: val loss barely changing across the window
        if max(recent_val) - min(recent_val) < 0.005:
            return "plateaued"

        if recent_val[-1] < recent_val[0]:
            return "improving"

        return "stalling"

    def get_overfitting_gap(self) -> float:
        """Returns train_accuracy - val_accuracy for the last epoch (overfitting indicator)."""
        if not self.state.train_acc_history:
            return 0.0
        return self.state.train_acc_history[-1] - self.state.val_acc_history[-1]

    def get_wasted_epochs(self) -> int:
        """Epochs trained after best val accuracy was achieved (0 if no best yet)."""
        if self.state.best_val_epoch == 0:
            return 0
        return max(0, self.state.current_epoch - self.state.best_val_epoch)

    def get_loss_variance(self, window: int = 10) -> float:
        """Population variance of the last `window` validation losses (0.0 with fewer than 2)."""
        if len(self.state.val_loss_history) < 2:
            return 0.0
        recent = self.state.val_loss_history[-window:]
        mean = sum(recent) / len(recent)
        return sum((x - mean) ** 2 for x in recent) / len(recent)