"""Proximal Policy Optimization (PPO) algorithm implementation."""
import torch
import torch.nn as nn
import torch.optim as optim
from typing import Dict, Any, Optional
import logging

from .algorithm_base import RLAlgorithm

logger = logging.getLogger(__name__)


class PPOAlgorithm(RLAlgorithm):
    """
    Proximal Policy Optimization (PPO) algorithm.
    
    PPO is a policy gradient method that uses a clipped objective
    to prevent large policy updates, improving training stability.
    """
    
    def __init__(
        self,
        model: nn.Module,
        learning_rate: float = 3e-4,
        clip_epsilon: float = 0.2,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        value_loss_coef: float = 0.5,
        entropy_coef: float = 0.01,
        max_grad_norm: float = 0.5,
        **kwargs
    ):
        """
        Initialize PPO algorithm.
        
        Args:
            model: The policy/value network
            learning_rate: Learning rate for optimizer
            clip_epsilon: PPO clipping parameter
            gamma: Discount factor
            gae_lambda: GAE lambda parameter for advantage estimation
            value_loss_coef: Coefficient for value loss
            entropy_coef: Coefficient for entropy bonus
            max_grad_norm: Maximum gradient norm for clipping
            **kwargs: Additional hyperparameters
        """
        super().__init__(learning_rate, **kwargs)
        
        self.model = model
        self.clip_epsilon = clip_epsilon
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        
        self.optimizer = optim.Adam(model.parameters(), lr=learning_rate)
        
        logger.info(f"Initialized PPO with clip_epsilon={clip_epsilon}, gamma={gamma}")
    
    def compute_loss(
        self,
        states: torch.Tensor,
        actions: torch.Tensor,
        rewards: torch.Tensor,
        next_states: torch.Tensor,
        old_log_probs: Optional[torch.Tensor] = None,
        values: Optional[torch.Tensor] = None,
        dones: Optional[torch.Tensor] = None,
        **kwargs
    ) -> torch.Tensor:
        """
        Compute PPO loss.

        Args:
            states: Current states
            actions: Actions taken
            rewards: Rewards received
            next_states: Next states
            old_log_probs: Log probabilities from old policy
            values: Value estimates from old policy
            dones: Done flags
            **kwargs: Additional inputs

        Returns:
            Total PPO loss
        """
        # Get current policy outputs (action_logits, values, entropy from the RL model)
        outputs = self.model(states)

        # Extract action outputs and values from the model output
        if isinstance(outputs, tuple) and len(outputs) >= 2:
            # RL-compatible model returns (action_logits, values[, entropy, ...])
            action_logits, new_values = outputs[0], outputs[1]

            # Compute log probs for the taken actions
            if action_logits.shape[-1] > 1:
                if actions.dim() == 1:
                    # Discrete actions: gather the log prob of each chosen action
                    log_probs_dist = torch.log_softmax(action_logits, dim=-1)
                    new_log_probs = log_probs_dist.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
                else:
                    # Continuous actions: unnormalized unit-variance Gaussian log prob around
                    # the predicted action means (constant terms cancel in the ratio)
                    new_log_probs = -0.5 * ((actions - action_logits) ** 2).sum(dim=-1)
            else:
                # Single-output head: treat the output directly as a log prob
                new_log_probs = action_logits.squeeze(-1)
        else:
            # Fallback for models that return raw logits only (no value head)
            new_log_probs = torch.log_softmax(outputs, dim=-1)
            if actions.dim() > 0 and new_log_probs.dim() > 1:
                new_log_probs = new_log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
            new_values = None
        
        # Compute advantages using GAE if we have values
        if values is not None and dones is not None:
            advantages = self._compute_gae(rewards, values, next_states, dones)
            returns = advantages + values
        else:
            # Simple advantage estimation
            advantages = rewards
            returns = rewards
        
        # Normalize advantages
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        
        # Compute policy loss (PPO clipped objective)
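        #   L_clip = E[ min(r_t * A_t, clip(r_t, 1 - eps, 1 + eps) * A_t) ]
        #   with r_t = pi_theta(a_t | s_t) / pi_theta_old(a_t | s_t)  (Schulman et al., 2017)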
        if old_log_probs is not None:
            # Compute probability ratio
            ratio = torch.exp(new_log_probs - old_log_probs)

            # Clipped surrogate loss
            clipped_ratio = torch.clamp(ratio, 1 - self.clip_epsilon, 1 + self.clip_epsilon)
            surrogate1 = ratio * advantages
            surrogate2 = clipped_ratio * advantages
            policy_loss = -torch.min(surrogate1, surrogate2).mean()
        else:
            # Fallback to simple policy gradient if no old log probs
            policy_loss = -(new_log_probs * advantages).mean()
        
        # Compute value loss if we have value predictions
        value_loss = torch.tensor(0.0, device=states.device)
        if new_values is not None:
            # Ensure shapes match for value loss computation
            # new_values typically has shape [batch, 1] or [batch], returns has shape [batch]
            new_values_flat = new_values.squeeze(-1) if new_values.dim() > 1 else new_values
            returns_flat = returns.view(-1) if returns.dim() > 1 else returns
            value_loss = nn.functional.mse_loss(new_values_flat, returns_flat)

        # Compute entropy bonus for exploration
        entropy = torch.tensor(0.0, device=states.device)
        if isinstance(outputs, tuple) and len(outputs) > 2 and outputs[2] is not None:
            # Reduce per-sample entropies to a scalar so the total loss and .item() logging work
            entropy = outputs[2].mean()

        # Total loss
        total_loss = (
            policy_loss +
            self.value_loss_coef * value_loss -
            self.entropy_coef * entropy
        )
        
        # Store loss components for logging
        self.last_loss_components = {
            'policy_loss': policy_loss.item(),
            'value_loss': value_loss.item(),
            'entropy': entropy.item() if isinstance(entropy, torch.Tensor) else entropy,
            'total_loss': total_loss.item()
        }
        
        return total_loss
    
    def _compute_gae(
        self,
        rewards: torch.Tensor,
        values: torch.Tensor,
        next_states: torch.Tensor,
        dones: torch.Tensor
    ) -> torch.Tensor:
        """
        Compute Generalized Advantage Estimation (GAE).

        Args:
            rewards: Rewards tensor [batch_size] or [timesteps, batch_size]
            values: Value estimates [batch_size] or [timesteps, batch_size]
            next_states: Next states
            dones: Done flags [batch_size] or [timesteps, batch_size]

        Returns:
            Advantages tensor
        """
        # Get next values
        with torch.no_grad():
            next_outputs = self.model(next_states)
            if isinstance(next_outputs, tuple):
                next_values = next_outputs[1]
            else:
                next_values = torch.zeros_like(values)

        # Ensure next_values has the same shape as values (squeeze only the trailing
        # dim so a batch of size 1 is not collapsed)
        if next_values.dim() > values.dim():
            next_values = next_values.squeeze(-1)

        # Cast dones so boolean done flags work in the arithmetic below
        dones = dones.float()

        # Compute TD errors (temporal difference)
        deltas = rewards + self.gamma * next_values * (1 - dones) - values

        # For batched data (single timestep), GAE simplifies to TD error
        # For sequential data, we need to iterate backwards through time
        if rewards.dim() == 1:
            # Single timestep batch: advantages = TD errors
            advantages = deltas
        else:
            # Multiple timesteps: compute GAE backwards through time
            advantages = torch.zeros_like(rewards)
            gae = torch.zeros(rewards.shape[1], device=rewards.device)  # [batch_size]

            for t in reversed(range(rewards.shape[0])):
                gae = deltas[t] + self.gamma * self.gae_lambda * (1 - dones[t]) * gae
                advantages[t] = gae

        return advantages
    
    def update_policy(self, loss: torch.Tensor) -> Dict[str, Any]:
        """
        Update policy using computed loss.
        
        Args:
            loss: Computed loss tensor
        
        Returns:
            Dictionary with update metrics
        """
        # Zero gradients
        self.optimizer.zero_grad()
        
        # Backward pass
        loss.backward()
        
        # Clip gradients
        grad_norm = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.max_grad_norm
        )
        
        # Update parameters
        self.optimizer.step()
        
        metrics = {
            'grad_norm': grad_norm.item(),
            'learning_rate': self.learning_rate,
        }
        
        # Add loss components if available
        if hasattr(self, 'last_loss_components'):
            metrics.update(self.last_loss_components)
        
        return metrics
    
    def get_hyperparameters(self) -> Dict[str, Any]:
        """Get all hyperparameters."""
        base_params = super().get_hyperparameters()
        ppo_params = {
            'clip_epsilon': self.clip_epsilon,
            'gamma': self.gamma,
            'gae_lambda': self.gae_lambda,
            'value_loss_coef': self.value_loss_coef,
            'entropy_coef': self.entropy_coef,
            'max_grad_norm': self.max_grad_norm,
        }
        return {**base_params, **ppo_params}
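

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a hypothetical toy actor-critic that follows the
# (action_logits, values, entropy) output contract expected by compute_loss,
# and that RLAlgorithm accepts a bare learning_rate in __init__. Run as a
# module (python -m <package>.<this_module>) so the relative import resolves.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _ToyActorCritic(nn.Module):
        """Hypothetical discrete-action actor-critic used only for this sketch."""

        def __init__(self, obs_dim: int = 4, n_actions: int = 2):
            super().__init__()
            self.policy_head = nn.Linear(obs_dim, n_actions)
            self.value_head = nn.Linear(obs_dim, 1)

        def forward(self, states: torch.Tensor):
            logits = self.policy_head(states)
            values = self.value_head(states)
            entropy = torch.distributions.Categorical(logits=logits).entropy()
            return logits, values, entropy

    model = _ToyActorCritic()
    algo = PPOAlgorithm(model)

    # Fake single-step rollout batch of 8 transitions
    states = torch.randn(8, 4)
    actions = torch.randint(0, 2, (8,))
    rewards = torch.randn(8)
    next_states = torch.randn(8, 4)
    dones = torch.zeros(8)

    # Old-policy statistics, as they would be recorded during rollout collection
    with torch.no_grad():
        old_logits, old_values, _ = model(states)
        old_log_probs = torch.log_softmax(old_logits, dim=-1).gather(
            -1, actions.unsqueeze(-1)
        ).squeeze(-1)

    loss = algo.compute_loss(
        states,
        actions,
        rewards,
        next_states,
        old_log_probs=old_log_probs,
        values=old_values.squeeze(-1),
        dones=dones,
    )
    print(algo.update_policy(loss))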