"""REINFORCE (Monte Carlo Policy Gradient) algorithm implementation."""
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Dict, Any, Optional
import logging
from .algorithm_base import RLAlgorithm
logger = logging.getLogger(__name__)
class REINFORCEAlgorithm(RLAlgorithm):
"""
REINFORCE algorithm (Monte Carlo Policy Gradient).
A simple policy gradient method that uses complete episode returns
to update the policy.
"""
def __init__(
self,
model: nn.Module,
learning_rate: float = 1e-3,
gamma: float = 0.99,
use_baseline: bool = True,
max_grad_norm: float = 0.5,
**kwargs
):
"""
Initialize REINFORCE algorithm.
Args:
model: The policy network
learning_rate: Learning rate for optimizer
gamma: Discount factor
use_baseline: Whether to use baseline subtraction
max_grad_norm: Maximum gradient norm for clipping
**kwargs: Additional hyperparameters
"""
super().__init__(learning_rate, **kwargs)
self.model = model
self.gamma = gamma
self.use_baseline = use_baseline
self.max_grad_norm = max_grad_norm
self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# Running baseline (mean return)
self.baseline = 0.0
self.baseline_momentum = 0.9
logger.info(f"Initialized REINFORCE with gamma={gamma}, use_baseline={use_baseline}")
def compute_loss(
self,
states: torch.Tensor,
actions: torch.Tensor,
rewards: torch.Tensor,
next_states: torch.Tensor,
**kwargs
) -> torch.Tensor:
"""
Compute REINFORCE loss.
Args:
states: Current states
actions: Actions taken
rewards: Per-step rewards for one complete episode, in time order
next_states: Next states (unused; REINFORCE relies only on full-episode returns)
**kwargs: Additional inputs
Returns:
Policy gradient loss
"""
# Get policy outputs
outputs = self.model(states)
# Extract log probabilities
if isinstance(outputs, tuple):
log_probs = outputs[0]
else:
# If model outputs logits, compute log probs
log_probs = torch.log_softmax(outputs, dim=-1)
# Gather log probs for taken actions
log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
# Compute discounted returns
returns = self._compute_returns(rewards)
# Apply baseline subtraction if enabled
if self.use_baseline:
advantages = returns - self.baseline
# Update baseline with exponential moving average
self.baseline = (
self.baseline_momentum * self.baseline +
(1 - self.baseline_momentum) * returns.mean().item()
)
else:
advantages = returns
# Normalize advantages for stability (skipped for a single sample, where std is undefined)
if advantages.numel() > 1:
    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
# Compute policy gradient loss
# Negative because we want to maximize expected return
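# This is the standard Monte Carlo policy-gradient estimator:
#   grad J(theta) ≈ E[ grad log pi_theta(a_t | s_t) * G_t ]
# so minimizing the negated mean performs gradient ascent on expected return.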
policy_loss = -(log_probs * advantages).mean()
# Store loss components for logging
self.last_loss_components = {
'policy_loss': policy_loss.item(),
'mean_return': returns.mean().item(),
'baseline': self.baseline,
}
return policy_loss
def _compute_returns(self, rewards: torch.Tensor) -> torch.Tensor:
"""
Compute discounted returns for an episode.
Args:
rewards: Per-step rewards for a single episode, in time order
Returns:
Discounted returns tensor
"""
returns = torch.zeros_like(rewards)
running_return = 0
# Compute returns backwards through the episode
for t in reversed(range(len(rewards))):
running_return = rewards[t] + self.gamma * running_return
returns[t] = running_return
return returns
def update_policy(self, loss: torch.Tensor) -> Dict[str, Any]:
"""
Update policy using computed loss.
Args:
loss: Computed loss tensor
Returns:
Dictionary with update metrics
"""
# Zero gradients
self.optimizer.zero_grad()
# Backward pass
loss.backward()
# Clip gradients
grad_norm = torch.nn.utils.clip_grad_norm_(
self.model.parameters(),
self.max_grad_norm
)
# Update parameters
self.optimizer.step()
metrics = {
'grad_norm': float(grad_norm),  # clip_grad_norm_ may return a float or a tensor depending on the torch version
'learning_rate': self.learning_rate,
}
# Add loss components if available
if hasattr(self, 'last_loss_components'):
metrics.update(self.last_loss_components)
return metrics
def get_hyperparameters(self) -> Dict[str, Any]:
"""Get all hyperparameters."""
base_params = super().get_hyperparameters()
reinforce_params = {
'gamma': self.gamma,
'use_baseline': self.use_baseline,
'max_grad_norm': self.max_grad_norm,
'baseline': self.baseline,
}
return {**base_params, **reinforce_params}
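# Minimal usage sketch (illustrative only, not part of the library). It assumes
# a small categorical policy that returns raw logits, a data-collection loop
# elsewhere that gathers one complete episode, and that the RLAlgorithm base
# class accepts (learning_rate, **kwargs) as in __init__ above:
#
#   policy = nn.Sequential(nn.Linear(4, 64), nn.Tanh(), nn.Linear(64, 2))
#   algo = REINFORCEAlgorithm(model=policy, learning_rate=1e-3, gamma=0.99)
#
#   # states: (T, 4), actions: (T,), rewards: (T,) for one full episode
#   loss = algo.compute_loss(states, actions, rewards, next_states=None)
#   metrics = algo.update_policy(loss)
#   # metrics contains 'policy_loss', 'mean_return', 'baseline', 'grad_norm', ...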