"""Proximal Policy Optimization (PPO) algorithm implementation."""
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Dict, Any, Optional
import logging
from .algorithm_base import RLAlgorithm
logger = logging.getLogger(__name__)
class PPOAlgorithm(RLAlgorithm):
"""
Proximal Policy Optimization (PPO) algorithm.
PPO is a policy gradient method that uses a clipped objective
to prevent large policy updates, improving training stability.
"""
def __init__(
self,
model: nn.Module,
learning_rate: float = 3e-4,
clip_epsilon: float = 0.2,
gamma: float = 0.99,
gae_lambda: float = 0.95,
value_loss_coef: float = 0.5,
entropy_coef: float = 0.01,
max_grad_norm: float = 0.5,
**kwargs
):
"""
Initialize PPO algorithm.
Args:
model: The policy/value network
learning_rate: Learning rate for optimizer
clip_epsilon: PPO clipping parameter
gamma: Discount factor
gae_lambda: GAE lambda parameter for advantage estimation
value_loss_coef: Coefficient for value loss
entropy_coef: Coefficient for entropy bonus
max_grad_norm: Maximum gradient norm for clipping
**kwargs: Additional hyperparameters
"""
super().__init__(learning_rate, **kwargs)
self.model = model
self.clip_epsilon = clip_epsilon
self.gamma = gamma
self.gae_lambda = gae_lambda
self.value_loss_coef = value_loss_coef
self.entropy_coef = entropy_coef
self.max_grad_norm = max_grad_norm
self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
logger.info(f"Initialized PPO with clip_epsilon={clip_epsilon}, gamma={gamma}")
def compute_loss(
self,
states: torch.Tensor,
actions: torch.Tensor,
rewards: torch.Tensor,
next_states: torch.Tensor,
old_log_probs: Optional[torch.Tensor] = None,
values: Optional[torch.Tensor] = None,
dones: Optional[torch.Tensor] = None,
**kwargs
) -> torch.Tensor:
"""
Compute PPO loss.
Args:
states: Current states
actions: Actions taken
rewards: Rewards received
next_states: Next states
old_log_probs: Log probabilities from old policy
values: Value estimates from old policy
dones: Done flags
**kwargs: Additional inputs
Returns:
Total PPO loss
"""
# Get current policy outputs (log_probs, values, entropy from RL model)
outputs = self.model(states)
# Extract log probs and values from model output
if isinstance(outputs, tuple) and len(outputs) >= 2:
# RL-compatible model returns (log_probs, values, ...)
            action_logits, new_values = outputs[0], outputs[1]
# Compute log probs for taken actions
            if action_logits.shape[-1] > 1:
                if actions.dim() == 1:
                    # Discrete actions: index the log-softmax with the taken action ids
                    log_probs_dist = torch.log_softmax(action_logits, dim=-1)
                    new_log_probs = log_probs_dist.gather(
                        -1, actions.long().unsqueeze(-1)
                    ).squeeze(-1)
                else:
                    # Continuous actions: unit-variance Gaussian log prob (up to a constant),
                    # treating the policy output as the mean of the action distribution
                    new_log_probs = -0.5 * ((actions - action_logits) ** 2).sum(dim=-1)
            else:
                # Single-output models are assumed to emit a log probability directly
                new_log_probs = action_logits.squeeze(-1)
else:
# Fallback for non-RL models
new_log_probs = torch.log_softmax(outputs, dim=-1)
if actions.dim() > 0 and new_log_probs.dim() > 1:
                new_log_probs = new_log_probs.gather(-1, actions.long().unsqueeze(-1)).squeeze(-1)
new_values = None
# Compute advantages using GAE if we have values
if values is not None and dones is not None:
advantages = self._compute_gae(rewards, values, next_states, dones)
returns = advantages + values
else:
# Simple advantage estimation
advantages = rewards
returns = rewards
# Normalize advantages
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Compute policy loss (PPO clipped objective)
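        # Clipped surrogate objective:
        #   L_clip = E[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ]
        # where r_t = pi_new(a_t | s_t) / pi_old(a_t | s_t) is the probability ratio.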
if old_log_probs is not None:
# Compute probability ratio
ratio = torch.exp(new_log_probs - old_log_probs)
# Clipped surrogate loss
clipped_ratio = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)
surrogate1 = ratio * advantages
surrogate2 = clipped_ratio * advantages
policy_loss = -torch.min(surrogate1, surrogate2).mean()
else:
# Fallback to simple policy gradient if no old log probs
policy_loss = -(new_log_probs * advantages).mean()
# Compute value loss if we have value predictions
value_loss = torch.tensor(0.0, device=states.device)
if new_values is not None:
# Ensure shapes match for value loss computation
# new_values typically has shape [batch, 1] or [batch], returns has shape [batch]
new_values_flat = new_values.squeeze(-1) if new_values.dim() > 1 else new_values
returns_flat = returns.view(-1) if returns.dim() > 1 else returns
value_loss = nn.functional.mse_loss(new_values_flat, returns_flat)
# Compute entropy bonus for exploration
entropy = torch.tensor(0.0, device=states.device)
        if isinstance(outputs, tuple) and len(outputs) > 2 and outputs[2] is not None:
            # Reduce the (possibly per-sample) entropy to a scalar so the total loss stays scalar
            entropy = outputs[2].mean()
# Total loss
total_loss = (
policy_loss +
self.value_loss_coef * value_loss -
self.entropy_coef * entropy
)
# Store loss components for logging
self.last_loss_components = {
'policy_loss': policy_loss.item(),
'value_loss': value_loss.item(),
'entropy': entropy.item() if isinstance(entropy, torch.Tensor) else entropy,
'total_loss': total_loss.item()
}
return total_loss
def _compute_gae(
self,
rewards: torch.Tensor,
values: torch.Tensor,
next_states: torch.Tensor,
dones: torch.Tensor
) -> torch.Tensor:
"""
Compute Generalized Advantage Estimation (GAE).
Args:
rewards: Rewards tensor [batch_size] or [timesteps, batch_size]
values: Value estimates [batch_size] or [timesteps, batch_size]
next_states: Next states
dones: Done flags [batch_size] or [timesteps, batch_size]
Returns:
Advantages tensor
"""
# Get next values
with torch.no_grad():
next_outputs = self.model(next_states)
if isinstance(next_outputs, tuple):
next_values = next_outputs[1]
else:
next_values = torch.zeros_like(values)
# Ensure next_values has the same shape as values
if next_values.dim() > values.dim():
            # Only drop the trailing value dimension; squeeze() could also drop a batch of size 1
            next_values = next_values.squeeze(-1)
        # Cast done flags to float so boolean tensors work in the arithmetic below
        dones = dones.float()
        # Compute TD errors (temporal difference)
        deltas = rewards + self.gamma * next_values * (1 - dones) - values
# For batched data (single timestep), GAE simplifies to TD error
# For sequential data, we need to iterate backwards through time
if rewards.dim() == 1:
# Single timestep batch: advantages = TD errors
advantages = deltas
else:
# Multiple timesteps: compute GAE backwards through time
advantages = torch.zeros_like(rewards)
gae = torch.zeros(rewards.shape[1], device=rewards.device) # [batch_size]
for t in reversed(range(rewards.shape[0])):
gae = deltas[t] + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
advantages[t] = gae
return advantages
def update_policy(self, loss: torch.Tensor) -> Dict[str, Any]:
"""
Update policy using computed loss.
Args:
loss: Computed loss tensor
Returns:
Dictionary with update metrics
"""
# Zero gradients
self.optimizer.zero_grad()
# Backward pass
loss.backward()
# Clip gradients
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.max_grad_norm
)
# Update parameters
self.optimizer.step()
metrics = {
'grad_norm': grad_norm.item(),
'learning_rate': self.learning_rate,
}
# Add loss components if available
if hasattr(self, 'last_loss_components'):
metrics.update(self.last_loss_components)
return metrics
def get_hyperparameters(self) -> Dict[str, Any]:
"""Get all hyperparameters."""
base_params = super().get_hyperparameters()
ppo_params = {
'clip_epsilon': self.clip_epsilon,
'gamma': self.gamma,
'gae_lambda': self.gae_lambda,
'value_loss_coef': self.value_loss_coef,
'entropy_coef': self.entropy_coef,
'max_grad_norm': self.max_grad_norm,
}
return {**base_params, **ppo_params}
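

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the library API. It assumes the RLAlgorithm
    # base class stores `learning_rate` (as update_policy above expects) and uses a
    # hypothetical actor-critic module following the (logits, values, entropy) convention
    # that compute_loss expects from RL-compatible models.
    class _TinyActorCritic(nn.Module):
        def __init__(self, obs_dim: int = 4, n_actions: int = 2):
            super().__init__()
            self.policy_head = nn.Linear(obs_dim, n_actions)
            self.value_head = nn.Linear(obs_dim, 1)

        def forward(self, states: torch.Tensor):
            logits = self.policy_head(states)
            values = self.value_head(states)
            entropy = torch.distributions.Categorical(logits=logits).entropy().mean()
            return logits, values, entropy

    algo = PPOAlgorithm(_TinyActorCritic())
    batch, obs_dim = 8, 4
    states = torch.randn(batch, obs_dim)
    actions = torch.randint(0, 2, (batch,))
    rewards = torch.randn(batch)
    next_states = torch.randn(batch, obs_dim)
    dones = torch.zeros(batch)

    # Log probs and values from the "old" policy, as they would be collected during a rollout
    with torch.no_grad():
        old_logits, old_values, _ = algo.model(states)
        old_log_probs = torch.log_softmax(old_logits, dim=-1).gather(
            -1, actions.unsqueeze(-1)
        ).squeeze(-1)

    loss = algo.compute_loss(
        states, actions, rewards, next_states,
        old_log_probs=old_log_probs,
        values=old_values.squeeze(-1),
        dones=dones,
    )
    print(algo.update_policy(loss))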