| """REINFORCE (Monte Carlo Policy Gradient) algorithm implementation.""" | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from typing import Dict, Any, Optional | |
| import logging | |
| from .algorithm_base import RLAlgorithm | |
| logger = logging.getLogger(__name__) | |


class REINFORCEAlgorithm(RLAlgorithm):
    """
    REINFORCE algorithm (Monte Carlo Policy Gradient).

    A simple policy gradient method that uses complete episode returns
    to update the policy.
    """

    def __init__(
        self,
        model: nn.Module,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        use_baseline: bool = True,
        max_grad_norm: float = 0.5,
        **kwargs
    ):
        """
        Initialize REINFORCE algorithm.

        Args:
            model: The policy network
            learning_rate: Learning rate for the Adam optimizer
            gamma: Discount factor
            use_baseline: Whether to subtract a running baseline from returns
            max_grad_norm: Maximum gradient norm for clipping
            **kwargs: Additional hyperparameters forwarded to the base class
        """
        super().__init__(learning_rate, **kwargs)
        self.model = model
        self.gamma = gamma
        self.use_baseline = use_baseline
        self.max_grad_norm = max_grad_norm
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        # Running baseline: an exponential moving average of mean episode return
        self.baseline = 0.0
        self.baseline_momentum = 0.9

        logger.info(
            f"Initialized REINFORCE with gamma={gamma}, use_baseline={use_baseline}"
        )

    def compute_loss(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_states: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """
        Compute the REINFORCE policy gradient loss for one complete episode.

        Args:
            states: Current states
            actions: Actions taken
            rewards: Rewards received
            next_states: Next states (unused by REINFORCE, kept for a
                uniform interface)
            **kwargs: Additional inputs

        Returns:
            Policy gradient loss
        """
        # Get policy outputs
        outputs = self.model(states)

        # Extract log probabilities; if the model returns a tuple, the first
        # element is assumed to be the log probabilities
        if isinstance(outputs, tuple):
            log_probs = outputs[0]
        else:
            # If the model outputs raw logits, convert them to log probs
            log_probs = torch.log_softmax(outputs, dim=-1)

        # Gather the log probs of the actions actually taken
        log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

        # Compute discounted returns
        returns = self._compute_returns(rewards)

        # Apply baseline subtraction if enabled
        if self.use_baseline:
            advantages = returns - self.baseline
            # Update baseline with an exponential moving average
            self.baseline = (
                self.baseline_momentum * self.baseline +
                (1 - self.baseline_momentum) * returns.mean().item()
            )
        else:
            advantages = returns

        # Normalize advantages for numerical stability
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Policy gradient loss: negated because the optimizer minimizes,
        # while we want to maximize expected return
        policy_loss = -(log_probs * advantages).mean()

        # Store loss components for logging
        self.last_loss_components = {
            'policy_loss': policy_loss.item(),
            'mean_return': returns.mean().item(),
            'baseline': self.baseline,
        }

        return policy_loss
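
    # _compute_returns() below implements the standard backward recurrence
    #
    #     G_t = r_t + gamma * G_{t+1},  with G = 0 after the last step.
    #
    # For example, with rewards [1.0, 0.0, 2.0] and gamma = 0.9:
    #     G_2 = 2.0
    #     G_1 = 0.0 + 0.9 * 2.0 = 1.8
    #     G_0 = 1.0 + 0.9 * 1.8 = 2.62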

    def _compute_returns(self, rewards: torch.Tensor) -> torch.Tensor:
        """
        Compute discounted returns for a complete episode.

        Args:
            rewards: Tensor of per-step rewards, in time order

        Returns:
            Tensor of discounted returns, same shape as `rewards`
        """
        returns = torch.zeros_like(rewards)
        running_return = 0.0

        # Accumulate returns backwards through the episode
        for t in reversed(range(len(rewards))):
            running_return = rewards[t] + self.gamma * running_return
            returns[t] = running_return

        return returns

    def update_policy(self, loss: torch.Tensor) -> Dict[str, Any]:
        """
        Update the policy parameters using the computed loss.

        Args:
            loss: Computed loss tensor

        Returns:
            Dictionary with update metrics
        """
        # Zero gradients
        self.optimizer.zero_grad()

        # Backward pass
        loss.backward()

        # Clip gradients; clip_grad_norm_ returns the total norm measured
        # before clipping, which we report as a training metric
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )

        # Update parameters
        self.optimizer.step()

        metrics = {
            'grad_norm': grad_norm.item(),
            'learning_rate': self.learning_rate,
        }

        # Add loss components if available
        if hasattr(self, 'last_loss_components'):
            metrics.update(self.last_loss_components)

        return metrics

    def get_hyperparameters(self) -> Dict[str, Any]:
        """Get all hyperparameters."""
        base_params = super().get_hyperparameters()
        reinforce_params = {
            'gamma': self.gamma,
            'use_baseline': self.use_baseline,
            'max_grad_norm': self.max_grad_norm,
            'baseline': self.baseline,
        }
        return {**base_params, **reinforce_params}
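

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): how one episode of data might flow through
# this class. The tiny policy network and the 4-dim observation / 2-action
# sizes below are made-up placeholders, and the RLAlgorithm base class (not
# shown in this file) is assumed to accept `learning_rate` in its constructor.
#
#     policy = nn.Sequential(nn.Linear(4, 32), nn.Tanh(), nn.Linear(32, 2))
#     algo = REINFORCEAlgorithm(policy, learning_rate=1e-3, gamma=0.99)
#
#     # REINFORCE is Monte Carlo: the tensors must cover one complete episode.
#     states = torch.randn(10, 4)              # (T, obs_dim)
#     actions = torch.randint(0, 2, (10,))     # (T,)
#     rewards = torch.randn(10)                # (T,)
#
#     loss = algo.compute_loss(states, actions, rewards, next_states=None)
#     metrics = algo.update_policy(loss)       # e.g. {'grad_norm': ..., ...}
# ---------------------------------------------------------------------------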