"""REINFORCE (Monte Carlo Policy Gradient) algorithm implementation."""
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Dict, Any, Optional
import logging
from .algorithm_base import RLAlgorithm
logger = logging.getLogger(__name__)


class REINFORCEAlgorithm(RLAlgorithm):
    """
    REINFORCE algorithm (Monte Carlo Policy Gradient).

    A simple policy gradient method that uses complete episode returns
    to update the policy.
    """

    def __init__(
        self,
        model: nn.Module,
        learning_rate: float = 1e-3,
        gamma: float = 0.99,
        use_baseline: bool = True,
        max_grad_norm: float = 0.5,
        **kwargs
    ):
        """
        Initialize REINFORCE algorithm.

        Args:
            model: The policy network
            learning_rate: Learning rate for optimizer
            gamma: Discount factor
            use_baseline: Whether to use baseline subtraction
            max_grad_norm: Maximum gradient norm for clipping
            **kwargs: Additional hyperparameters
        """
        super().__init__(learning_rate, **kwargs)
        self.model = model
        self.gamma = gamma
        self.use_baseline = use_baseline
        self.max_grad_norm = max_grad_norm
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)

        # Running baseline (mean return)
        self.baseline = 0.0
        self.baseline_momentum = 0.9

        logger.info(
            f"Initialized REINFORCE with gamma={gamma}, use_baseline={use_baseline}"
        )

    def compute_loss(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_states: torch.Tensor,
        **kwargs
    ) -> torch.Tensor:
        """
        Compute REINFORCE loss.

        Args:
            states: Current states
            actions: Actions taken
            rewards: Rewards received
            next_states: Next states (not used in REINFORCE)
            **kwargs: Additional inputs

        Returns:
            Policy gradient loss
        """
        # Get policy outputs
        outputs = self.model(states)

        # Extract log probabilities
        if isinstance(outputs, tuple):
            log_probs = outputs[0]
        else:
            # If model outputs logits, compute log probs
            log_probs = torch.log_softmax(outputs, dim=-1)

        # Gather log probs for taken actions
        log_probs = log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)

        # Compute discounted returns
        returns = self._compute_returns(rewards)

        # Apply baseline subtraction if enabled
        if self.use_baseline:
            advantages = returns - self.baseline
            # Update baseline with exponential moving average
            self.baseline = (
                self.baseline_momentum * self.baseline +
                (1 - self.baseline_momentum) * returns.mean().item()
            )
        else:
            advantages = returns

        # Normalize advantages for stability
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # Compute policy gradient loss
        # Negative because we want to maximize expected return
        policy_loss = -(log_probs * advantages).mean()

        # Store loss components for logging
        self.last_loss_components = {
            'policy_loss': policy_loss.item(),
            'mean_return': returns.mean().item(),
            'baseline': self.baseline,
        }

        return policy_loss
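
    # Note on the objective: REINFORCE performs gradient ascent on
    # J(theta) = E[G_t] using the estimator
    #     grad J(theta) ~ mean_t[ grad log pi_theta(a_t | s_t) * (G_t - b) ],
    # where G_t is the discounted return and b is the (optional) baseline.
    # Minimizing -(log_probs * advantages).mean() in compute_loss is the
    # descent form of this update, which is why the loss is negated.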

    def _compute_returns(self, rewards: torch.Tensor) -> torch.Tensor:
        """
        Compute discounted returns for an episode.

        Args:
            rewards: Rewards tensor

        Returns:
            Discounted returns tensor
        """
        returns = torch.zeros_like(rewards)
        running_return = 0.0

        # Compute returns backwards through the episode:
        # G_t = r_t + gamma * G_{t+1}
        for t in reversed(range(len(rewards))):
            running_return = rewards[t] + self.gamma * running_return
            returns[t] = running_return

        return returns

    def update_policy(self, loss: torch.Tensor) -> Dict[str, Any]:
        """
        Update policy using computed loss.

        Args:
            loss: Computed loss tensor

        Returns:
            Dictionary with update metrics
        """
        # Zero gradients
        self.optimizer.zero_grad()

        # Backward pass
        loss.backward()

        # Clip gradients
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )

        # Update parameters
        self.optimizer.step()

        metrics = {
            'grad_norm': grad_norm.item(),
            'learning_rate': self.learning_rate,
        }

        # Add loss components if available
        if hasattr(self, 'last_loss_components'):
            metrics.update(self.last_loss_components)

        return metrics

    def get_hyperparameters(self) -> Dict[str, Any]:
        """Get all hyperparameters."""
        base_params = super().get_hyperparameters()
        reinforce_params = {
            'gamma': self.gamma,
            'use_baseline': self.use_baseline,
            'max_grad_norm': self.max_grad_norm,
            'baseline': self.baseline,
        }
        return {**base_params, **reinforce_params}
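

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, kept as comments so the module is
# unchanged at import time). It assumes the RLAlgorithm base class only
# needs `learning_rate`, and that the policy network maps state vectors to
# action logits. `TinyPolicy`, the state/action dimensions, and the random
# episode data below are hypothetical stand-ins, not part of this codebase.
#
#   class TinyPolicy(nn.Module):
#       def __init__(self, state_dim: int = 4, num_actions: int = 2):
#           super().__init__()
#           self.net = nn.Sequential(
#               nn.Linear(state_dim, 64), nn.Tanh(), nn.Linear(64, num_actions)
#           )
#
#       def forward(self, states: torch.Tensor) -> torch.Tensor:
#           return self.net(states)  # logits; compute_loss applies log_softmax
#
#   algo = REINFORCEAlgorithm(model=TinyPolicy(), learning_rate=1e-3, gamma=0.99)
#   states = torch.randn(10, 4)             # a single 10-step episode
#   actions = torch.randint(0, 2, (10,))
#   rewards = torch.randn(10)
#   loss = algo.compute_loss(states, actions, rewards, next_states=states)
#   metrics = algo.update_policy(loss)      # grad_norm, policy_loss, ...
# ---------------------------------------------------------------------------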