"""Proximal Policy Optimization (PPO) algorithm implementation.""" import torch import torch.nn as nn import torch.optim as optim from typing import Dict, Any, Optional import logging from .algorithm_base import RLAlgorithm logger = logging.getLogger(__name__) class PPOAlgorithm(RLAlgorithm): """ Proximal Policy Optimization (PPO) algorithm. PPO is a policy gradient method that uses a clipped objective to prevent large policy updates, improving training stability. """ def __init__( self, model: nn.Module, learning_rate: float = 3e-4, clip_epsilon: float = 0.2, gamma: float = 0.99, gae_lambda: float = 0.95, value_loss_coef: float = 0.5, entropy_coef: float = 0.01, max_grad_norm: float = 0.5, **kwargs ): """ Initialize PPO algorithm. Args: model: The policy/value network learning_rate: Learning rate for optimizer clip_epsilon: PPO clipping parameter gamma: Discount factor gae_lambda: GAE lambda parameter for advantage estimation value_loss_coef: Coefficient for value loss entropy_coef: Coefficient for entropy bonus max_grad_norm: Maximum gradient norm for clipping **kwargs: Additional hyperparameters """ super().__init__(learning_rate, **kwargs) self.model = model self.clip_epsilon = clip_epsilon self.gamma = gamma self.gae_lambda = gae_lambda self.value_loss_coef = value_loss_coef self.entropy_coef = entropy_coef self.max_grad_norm = max_grad_norm self.optimizer = optim.Adam(model.parameters(), lr=learning_rate) logger.info(f"Initialized PPO with clip_epsilon={clip_epsilon}, gamma={gamma}") def compute_loss( self, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, old_log_probs: Optional[torch.Tensor] = None, values: Optional[torch.Tensor] = None, dones: Optional[torch.Tensor] = None, **kwargs ) -> torch.Tensor: """ Compute PPO loss. Args: states: Current states actions: Actions taken rewards: Rewards received next_states: Next states old_log_probs: Log probabilities from old policy values: Value estimates from old policy dones: Done flags **kwargs: Additional inputs Returns: Total PPO loss """ # Get current policy outputs (log_probs, values, entropy from RL model) outputs = self.model(states) # Extract log probs and values from model output if isinstance(outputs, tuple) and len(outputs) >= 2: # RL-compatible model returns (log_probs, values, ...) 
            action_logits, new_values = outputs[0], outputs[1]

            # Compute log probs for taken actions
            if action_logits.shape[-1] > 1:
                # Multi-dimensional head: discrete logits or continuous action means
                log_probs_dist = torch.log_softmax(action_logits, dim=-1)
                if actions.dim() == 1:
                    # Discrete actions: pick the log prob of each taken action
                    new_log_probs = log_probs_dist.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
                else:
                    # Continuous actions: unnormalized Gaussian log prob (unit variance)
                    new_log_probs = -0.5 * ((actions - action_logits) ** 2).sum(dim=-1)
            else:
                new_log_probs = action_logits.squeeze(-1)
        else:
            # Fallback for non-RL models
            new_log_probs = torch.log_softmax(outputs, dim=-1)
            if actions.dim() > 0 and new_log_probs.dim() > 1:
                new_log_probs = new_log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
            new_values = None

        # Compute advantages using GAE if we have values
        if values is not None and dones is not None:
            advantages = self._compute_gae(rewards, values, next_states, dones)
            returns = advantages + values
        else:
            # Simple advantage estimation
            advantages = rewards
            returns = rewards

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Compute policy loss (PPO clipped objective)
        if old_log_probs is not None:
            # Probability ratio between new and old policy
            ratio = torch.exp(new_log_probs - old_log_probs)

            # Clipped surrogate loss
            clipped_ratio = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)
            surrogate1 = ratio * advantages
            surrogate2 = clipped_ratio * advantages
            policy_loss = -torch.min(surrogate1, surrogate2).mean()
        else:
            # Fallback to simple policy gradient if no old log probs
            policy_loss = -(new_log_probs * advantages).mean()

        # Compute value loss if we have value predictions
        value_loss = torch.tensor(0.0, device=states.device)
        if new_values is not None:
            # Ensure shapes match for value loss computation:
            # new_values typically has shape [batch, 1] or [batch], returns has shape [batch]
            new_values_flat = new_values.squeeze(-1) if new_values.dim() > 1 else new_values
            returns_flat = returns.view(-1) if returns.dim() > 1 else returns
            value_loss = nn.functional.mse_loss(new_values_flat, returns_flat)

        # Compute entropy bonus for exploration
        entropy = torch.tensor(0.0, device=states.device)
        if isinstance(outputs, tuple) and len(outputs) > 2 and outputs[2] is not None:
            entropy = outputs[2]
            if isinstance(entropy, torch.Tensor) and entropy.dim() > 0:
                # Reduce per-sample entropies so the total loss stays a scalar
                entropy = entropy.mean()

        # Total loss
        total_loss = (
            policy_loss
            + self.value_loss_coef * value_loss
            - self.entropy_coef * entropy
        )

        # Store loss components for logging
        self.last_loss_components = {
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'entropy': entropy.item() if isinstance(entropy, torch.Tensor) else entropy,
            'total_loss': total_loss.item()
        }

        return total_loss

    def _compute_gae(
        self,
        rewards: torch.Tensor,
        values: torch.Tensor,
        next_states: torch.Tensor,
        dones: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute Generalized Advantage Estimation (GAE).
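        The estimator follows the standard GAE recursion (Schulman et al., 2016),
        matching the computation below:

            delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}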

        Args:
            rewards: Rewards tensor [batch_size] or [timesteps, batch_size]
            values: Value estimates [batch_size] or [timesteps, batch_size]
            next_states: Next states
            dones: Done flags [batch_size] or [timesteps, batch_size]

        Returns:
            Advantages tensor
        """
        # Get next values from the current model without tracking gradients
        with torch.no_grad():
            next_outputs = self.model(next_states)
            if isinstance(next_outputs, tuple):
                next_values = next_outputs[1]
            else:
                next_values = torch.zeros_like(values)

        # Ensure next_values has the same shape as values
        if next_values.dim() > values.dim():
            next_values = next_values.squeeze(-1)

        # Compute TD errors (temporal difference)
        deltas = rewards + self.gamma * next_values * (1 - dones) - values

        # For batched data (single timestep), GAE reduces to the TD error;
        # for sequential data, iterate backwards through time
        if rewards.dim() == 1:
            # Single timestep batch: advantages = TD errors
            advantages = deltas
        else:
            # Multiple timesteps: compute GAE backwards through time
            advantages = torch.zeros_like(rewards)
            gae = torch.zeros(rewards.shape[1], device=rewards.device)  # [batch_size]
            for t in reversed(range(rewards.shape[0])):
                gae = deltas[t] + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
                advantages[t] = gae

        return advantages

    def update_policy(self, loss: torch.Tensor) -> Dict[str, Any]:
        """
        Update policy using computed loss.

        Args:
            loss: Computed loss tensor

        Returns:
            Dictionary with update metrics
        """
        # Zero gradients
        self.optimizer.zero_grad()

        # Backward pass
        loss.backward()

        # Clip gradients
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.max_grad_norm
        )

        # Update parameters
        self.optimizer.step()

        metrics = {
            'grad_norm': grad_norm.item(),
            'learning_rate': self.learning_rate,
        }

        # Add loss components if available
        if hasattr(self, 'last_loss_components'):
            metrics.update(self.last_loss_components)

        return metrics

    def get_hyperparameters(self) -> Dict[str, Any]:
        """Get all hyperparameters."""
        base_params = super().get_hyperparameters()
        ppo_params = {
            'clip_epsilon': self.clip_epsilon,
            'gamma': self.gamma,
            'gae_lambda': self.gae_lambda,
            'value_loss_coef': self.value_loss_coef,
            'entropy_coef': self.entropy_coef,
            'max_grad_norm': self.max_grad_norm,
        }
        return {**base_params, **ppo_params}
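

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). `ActorCritic` below is a hypothetical
# stand-in for whatever RL-compatible network the surrounding codebase
# provides; the only assumption is that its forward pass returns
# (action_logits, values, entropy) so compute_loss can unpack it as above.
# ---------------------------------------------------------------------------
#
#   class ActorCritic(nn.Module):
#       def __init__(self, obs_dim: int, n_actions: int, hidden: int = 64):
#           super().__init__()
#           self.body = nn.Sequential(nn.Linear(obs_dim, hidden), nn.Tanh())
#           self.policy_head = nn.Linear(hidden, n_actions)
#           self.value_head = nn.Linear(hidden, 1)
#
#       def forward(self, states):
#           h = self.body(states)
#           logits = self.policy_head(h)
#           values = self.value_head(h)
#           entropy = torch.distributions.Categorical(logits=logits).entropy()
#           return logits, values, entropy
#
#   model = ActorCritic(obs_dim=4, n_actions=2)
#   ppo = PPOAlgorithm(model)
#   loss = ppo.compute_loss(
#       states, actions, rewards, next_states,
#       old_log_probs=old_log_probs, values=values, dones=dones,
#   )
#   metrics = ppo.update_policy(loss)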