"""REINFORCE (Monte Carlo Policy Gradient) algorithm implementation.""" import torch import torch.nn as nn import torch.optim as optim from typing import Dict, Any, Optional import logging from .algorithm_base import RLAlgorithm logger = logging.getLogger(__name__) class REINFORCEAlgorithm(RLAlgorithm): """ REINFORCE algorithm (Monte Carlo Policy Gradient). A simple policy gradient method that uses complete episode returns to update the policy. """ def __init__( self, model: nn.Module, learning_rate: float = 1e-3, gamma: float = 0.99, use_baseline: bool = True, max_grad_norm: float = 0.5, **kwargs ): """ Initialize REINFORCE algorithm. Args: model: The policy network learning_rate: Learning rate for optimizer gamma: Discount factor use_baseline: Whether to use baseline subtraction max_grad_norm: Maximum gradient norm for clipping **kwargs: Additional hyperparameters """ super().__init__(learning_rate, **kwargs) self.model = model self.gamma = gamma self.use_baseline = use_baseline self.max_grad_norm = max_grad_norm self.optimizer = optim.Adam(model.parameters(), lr=learning_rate) # Running baseline (mean return) self.baseline = 0.0 self.baseline_momentum = 0.9 logger.info(f"Initialized REINFORCE with gamma={gamma}, use_baseline={use_baseline}") def compute_loss( self, states: torch.Tensor, actions: torch.Tensor, rewards: torch.Tensor, next_states: torch.Tensor, **kwargs ) -> torch.Tensor: """ Compute REINFORCE loss. Args: states: Current states actions: Actions taken rewards: Rewards received next_states: Next states (not used in REINFORCE) **kwargs: Additional inputs Returns: Policy gradient loss """ # Get policy outputs outputs = self.model(states) # Extract log probabilities if isinstance(outputs, tuple): log_probs = outputs[0] else: # If model outputs logits, compute log probs log_probs = torch.log_softmax(outputs, dim=-1) # Gather log probs for taken actions log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1) # Compute discounted returns returns = self._compute_returns(rewards) # Apply baseline subtraction if enabled if self.use_baseline: advantages = returns - self.baseline # Update baseline with exponential moving average self.baseline = ( self.baseline_momentum * self.baseline + (1 - self.baseline_momentum) * returns.mean().item() ) else: advantages = returns # Normalize advantages for stability advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8) # Compute policy gradient loss # Negative because we want to maximize expected return policy_loss = -(log_probs * advantages).mean() # Store loss components for logging self.last_loss_components = { 'policy_loss': policy_loss.item(), 'mean_return': returns.mean().item(), 'baseline': self.baseline, } return policy_loss def _compute_returns(self, rewards: torch.Tensor) -> torch.Tensor: """ Compute discounted returns for an episode. Args: rewards: Rewards tensor Returns: Discounted returns tensor """ returns = torch.zeros_like(rewards) running_return = 0 # Compute returns backwards through the episode for t in reversed(range(len(rewards))): running_return = rewards[t] + self.gamma * running_return returns[t] = running_return return returns def update_policy(self, loss: torch.Tensor) -> Dict[str, Any]: """ Update policy using computed loss. 
Args: loss: Computed loss tensor Returns: Dictionary with update metrics """ # Zero gradients self.optimizer.zero_grad() # Backward pass loss.backward() # Clip gradients grad_norm = torch.nn.utils.clip_grad_norm_( self.model.parameters(), self.max_grad_norm ) # Update parameters self.optimizer.step() metrics = { 'grad_norm': grad_norm.item(), 'learning_rate': self.learning_rate, } # Add loss components if available if hasattr(self, 'last_loss_components'): metrics.update(self.last_loss_components) return metrics def get_hyperparameters(self) -> Dict[str, Any]: """Get all hyperparameters.""" base_params = super().get_hyperparameters() reinforce_params = { 'gamma': self.gamma, 'use_baseline': self.use_baseline, 'max_grad_norm': self.max_grad_norm, 'baseline': self.baseline, } return {**base_params, **reinforce_params}
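

# Minimal usage sketch of the compute_loss -> update_policy cycle on one fake
# episode. This is illustrative only: the policy network, observation/action
# dimensions, and episode data below are hypothetical, and it assumes the
# RLAlgorithm base class accepts (learning_rate, **kwargs) and sets
# self.learning_rate, as the code above relies on. Because this module uses a
# relative import, run it as part of its package (python -m <package>.<module>).
if __name__ == "__main__":
    # Hypothetical policy: 4-dim observations, 2 discrete actions (logits out).
    policy = nn.Sequential(
        nn.Linear(4, 64),
        nn.ReLU(),
        nn.Linear(64, 2),
    )
    algo = REINFORCEAlgorithm(policy, learning_rate=1e-3, gamma=0.99)

    # Fake single-episode batch: 5 timesteps of states, actions, and rewards.
    states = torch.randn(5, 4)
    actions = torch.randint(0, 2, (5,))
    rewards = torch.ones(5)

    # next_states is unused by REINFORCE, so the states tensor is passed again.
    loss = algo.compute_loss(states, actions, rewards, next_states=states)
    metrics = algo.update_policy(loss)
    print(metrics)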