"""Proximal Policy Optimization (PPO) algorithm implementation."""

import torch
import torch.nn as nn
import torch.optim as optim
from typing import Dict, Any, Optional
import logging

from .algorithm_base import RLAlgorithm

logger = logging.getLogger(__name__)


class PPOAlgorithm(RLAlgorithm):
    """
    Proximal Policy Optimization (PPO) algorithm.

    PPO is a policy gradient method that uses a clipped objective
    to prevent large policy updates, improving training stability.
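
    Example:
        A minimal construction sketch. ``MyActorCritic`` is a hypothetical
        module (not part of this package) whose forward pass is assumed to
        return ``(action_logits, values, entropy)``::

            model = MyActorCritic(obs_dim=4, n_actions=2)
            algo = PPOAlgorithm(model, learning_rate=3e-4, clip_epsilon=0.2)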
    """

    def __init__(
        self,
        model: nn.Module,
        learning_rate: float = 3e-4,
        clip_epsilon: float = 0.2,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        value_loss_coef: float = 0.5,
        entropy_coef: float = 0.01,
        max_grad_norm: float = 0.5,
        **kwargs,
    ):
        """
        Initialize PPO algorithm.

        Args:
            model: The policy/value network
            learning_rate: Learning rate for the optimizer
            clip_epsilon: PPO clipping parameter
            gamma: Discount factor
            gae_lambda: GAE lambda parameter for advantage estimation
            value_loss_coef: Coefficient for value loss
            entropy_coef: Coefficient for entropy bonus
            max_grad_norm: Maximum gradient norm for clipping
            **kwargs: Additional hyperparameters
        """
        super().__init__(learning_rate, **kwargs)
        self.model = model
        self.clip_epsilon = clip_epsilon
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        logger.info(f"Initialized PPO with clip_epsilon={clip_epsilon}, gamma={gamma}")

    def compute_loss(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_states: torch.Tensor,
        old_log_probs: Optional[torch.Tensor] = None,
        values: Optional[torch.Tensor] = None,
        dones: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the PPO loss.

        Args:
            states: Current states
            actions: Actions taken
            rewards: Rewards received
            next_states: Next states
            old_log_probs: Log probabilities from the old policy
            values: Value estimates from the old policy
            dones: Done flags
            **kwargs: Additional inputs

        Returns:
            Total PPO loss
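
        Example:
            A minimal sketch with hypothetical shapes; ``algo`` is an already
            constructed ``PPOAlgorithm`` whose model returns
            ``(action_logits, values, entropy)``::

                states = torch.randn(32, 4)            # [batch, obs_dim]
                actions = torch.randint(0, 2, (32,))   # discrete actions
                rewards = torch.randn(32)
                next_states = torch.randn(32, 4)
                old_log_probs = torch.randn(32)
                values = torch.randn(32)
                dones = torch.zeros(32)
                loss = algo.compute_loss(states, actions, rewards, next_states,
                                         old_log_probs=old_log_probs,
                                         values=values, dones=dones)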
        """
        # Get current policy outputs from the RL model
        outputs = self.model(states)

        # Extract log probs and values from the model output
        if isinstance(outputs, tuple) and len(outputs) >= 2:
            # RL-compatible model returns (action_logits, values[, entropy])
            action_logits, new_values, _ = outputs if len(outputs) == 3 else (*outputs, None)

            if action_logits.shape[-1] > 1:  # Discrete action logits
                log_probs_dist = torch.log_softmax(action_logits, dim=-1)
                if actions.dim() == 1:
                    # Gather the log prob of each taken action
                    new_log_probs = log_probs_dist.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
                else:
                    # Actions carry an extra dim: treat them as continuous and use a
                    # unit-variance Gaussian log prob around the predicted action means
                    new_log_probs = -0.5 * ((actions - action_logits) ** 2).sum(dim=-1)
            else:
                # Single output unit: interpret it directly as a log probability
                new_log_probs = action_logits.squeeze(-1)
        else:
            # Fallback for non-RL models that return raw logits
            new_log_probs = torch.log_softmax(outputs, dim=-1)
            if actions.dim() > 0 and new_log_probs.dim() > 1:
                new_log_probs = new_log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
            new_values = None

        # Compute advantages using GAE if we have old value estimates
        if values is not None and dones is not None:
            advantages = self._compute_gae(rewards, values, next_states, dones)
            returns = advantages + values
        else:
            # Simple advantage estimation
            advantages = rewards
            returns = rewards

        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Compute policy loss (PPO clipped objective)
        if old_log_probs is not None:
            # Probability ratio between the new and old policies
            ratio = torch.exp(new_log_probs - old_log_probs)

            # Clipped surrogate loss
            clipped_ratio = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)
            surrogate1 = ratio * advantages
            surrogate2 = clipped_ratio * advantages
            policy_loss = -torch.min(surrogate1, surrogate2).mean()
        else:
            # Fall back to a simple policy gradient if no old log probs are provided
            policy_loss = -(new_log_probs * advantages).mean()

        # Compute value loss if we have value predictions
        value_loss = torch.tensor(0.0, device=states.device)
        if new_values is not None:
            # Ensure shapes match for the value loss computation:
            # new_values typically has shape [batch, 1] or [batch], returns has shape [batch]
            new_values_flat = new_values.squeeze(-1) if new_values.dim() > 1 else new_values
            returns_flat = returns.view(-1) if returns.dim() > 1 else returns
            value_loss = nn.functional.mse_loss(new_values_flat, returns_flat)

        # Compute entropy bonus for exploration (reduced to a scalar in case the
        # model returns per-sample entropies)
        entropy = torch.tensor(0.0, device=states.device)
        if isinstance(outputs, tuple) and len(outputs) > 2 and outputs[2] is not None:
            entropy = outputs[2].mean()

        # Total loss
        total_loss = (
            policy_loss
            + self.value_loss_coef * value_loss
            - self.entropy_coef * entropy
        )

        # Store loss components for logging
        self.last_loss_components = {
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'entropy': entropy.item() if isinstance(entropy, torch.Tensor) else entropy,
            'total_loss': total_loss.item(),
        }

        return total_loss

    def _compute_gae(
        self,
        rewards: torch.Tensor,
        values: torch.Tensor,
        next_states: torch.Tensor,
        dones: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute Generalized Advantage Estimation (GAE).

        Args:
            rewards: Rewards tensor, [batch_size] or [timesteps, batch_size]
            values: Value estimates, [batch_size] or [timesteps, batch_size]
            next_states: Next states
            dones: Done flags, [batch_size] or [timesteps, batch_size]

        Returns:
            Advantages tensor
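
        Note:
            The standard GAE recursion implemented here is::

                delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
                A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}

            so for a single-timestep batch the advantage reduces to the
            TD error ``delta_t``.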
        """
        # Get value estimates for the next states (no gradients needed)
        with torch.no_grad():
            next_outputs = self.model(next_states)
            if isinstance(next_outputs, tuple):
                next_values = next_outputs[1]
            else:
                next_values = torch.zeros_like(values)

        # Ensure next_values has the same shape as values
        if next_values.dim() > values.dim():
            next_values = next_values.squeeze(-1)

        # Accept bool or float done flags
        dones = dones.float()

        # Compute TD errors (temporal difference)
        deltas = rewards + self.gamma * next_values * (1 - dones) - values

        # For batched data (single timestep), GAE simplifies to the TD error.
        # For sequential data, iterate backwards through time.
        if rewards.dim() == 1:
            # Single-timestep batch: advantages = TD errors
            advantages = deltas
        else:
            # Multiple timesteps: compute GAE backwards through time
            advantages = torch.zeros_like(rewards)
            gae = torch.zeros(rewards.shape[1], device=rewards.device)  # [batch_size]
            for t in reversed(range(rewards.shape[0])):
                gae = deltas[t] + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
                advantages[t] = gae

        return advantages

    def update_policy(self, loss: torch.Tensor) -> Dict[str, Any]:
        """
        Update the policy using a computed loss.

        Args:
            loss: Computed loss tensor

        Returns:
            Dictionary with update metrics
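
        Example:
            A minimal sketch of one training step (``algo`` and the batch
            tensors are assumed to exist as in the ``compute_loss`` example)::

                loss = algo.compute_loss(states, actions, rewards, next_states,
                                         old_log_probs=old_log_probs,
                                         values=values, dones=dones)
                metrics = algo.update_policy(loss)
                print(metrics['grad_norm'], metrics['total_loss'])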
        """
        # Zero gradients
        self.optimizer.zero_grad()

        # Backward pass
        loss.backward()

        # Clip gradients
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm,
        )

        # Update parameters
        self.optimizer.step()

        metrics = {
            'grad_norm': grad_norm.item(),
            'learning_rate': self.learning_rate,
        }

        # Add loss components if available
        if hasattr(self, 'last_loss_components'):
            metrics.update(self.last_loss_components)

        return metrics

    def get_hyperparameters(self) -> Dict[str, Any]:
        """Get all hyperparameters."""
        base_params = super().get_hyperparameters()
        ppo_params = {
            'clip_epsilon': self.clip_epsilon,
            'gamma': self.gamma,
            'gae_lambda': self.gae_lambda,
            'value_loss_coef': self.value_loss_coef,
            'entropy_coef': self.entropy_coef,
            'max_grad_norm': self.max_grad_norm,
        }
        return {**base_params, **ppo_params}
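

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the public API).
# ``TinyActorCritic`` below is a hypothetical toy model used only to exercise
# the loss and update path: it follows the (action_logits, values, entropy)
# convention that ``compute_loss`` expects, and the tensors are random
# placeholders rather than real environment data. It also assumes the
# ``RLAlgorithm`` base class accepts ``learning_rate`` and exposes
# ``self.learning_rate``. Run as a module so the relative import resolves,
# e.g. ``python -m your_package.ppo_algorithm`` (module name is illustrative).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class TinyActorCritic(nn.Module):
        """Toy actor-critic returning (action_logits, values, entropy)."""

        def __init__(self, obs_dim: int = 4, n_actions: int = 2):
            super().__init__()
            self.body = nn.Linear(obs_dim, 32)
            self.policy_head = nn.Linear(32, n_actions)
            self.value_head = nn.Linear(32, 1)

        def forward(self, states: torch.Tensor):
            h = torch.tanh(self.body(states))
            logits = self.policy_head(h)
            values = self.value_head(h)
            # Mean entropy of the categorical policy over the batch
            probs = torch.softmax(logits, dim=-1)
            entropy = -(probs * torch.log_softmax(logits, dim=-1)).sum(dim=-1).mean()
            return logits, values, entropy

    model = TinyActorCritic()
    algo = PPOAlgorithm(model)

    batch = 8
    states = torch.randn(batch, 4)
    actions = torch.randint(0, 2, (batch,))
    rewards = torch.randn(batch)
    next_states = torch.randn(batch, 4)
    old_log_probs = torch.full((batch,), 0.5).log()
    values = torch.randn(batch)
    dones = torch.zeros(batch)

    loss = algo.compute_loss(
        states, actions, rewards, next_states,
        old_log_probs=old_log_probs, values=values, dones=dones,
    )
    print(algo.update_policy(loss))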