| | """
|
| | RLHF 2.0: Reinforcement Learning from Everything
|
| | Implements multi-objective optimization with various feedback sources
|
| | """
|
| |
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from torch.distributions import Categorical
|
| | from typing import Dict, List, Tuple, Optional, Any, Union
|
| | from dataclasses import dataclass, field
|
| | from enum import Enum
|
| | import numpy as np
|
| | from collections import deque
|
| | import logging
|
| | from tqdm import tqdm
|
| |
|
| | logger = logging.getLogger(__name__)
|
| |
|
| |
|
class FeedbackSource(Enum):
    """Sources of feedback for RLHF."""
    HUMAN = "human"
    AI_FEEDBACK = "ai_feedback"
    TOOL_EXECUTION = "tool_execution"
    CONSTITUTIONAL = "constitutional"
    SELF_CONSISTENCY = "self_consistency"
    PROCESS_SUPERVISION = "process_supervision"


class RewardObjective(Enum):
    """Different reward objectives to optimize."""
    HELPFULNESS = "helpfulness"
    HARMLESSNESS = "harmlessness"
    HONESTY = "honesty"
    ACCURACY = "accuracy"
    COHERENCE = "coherence"
    CREATIVITY = "creativity"
    EFFICIENCY = "efficiency"


@dataclass
class RLHFConfig:
    """Configuration for RLHF 2.0."""

    # PPO optimization schedule
    ppo_epochs: int = 4
    ppo_batch_size: int = 32
    ppo_mini_batch_size: int = 8
    gradient_accumulation_steps: int = 4

    # PPO clipping, trust region, and GAE parameters
    clip_range: float = 0.2
    value_clip_range: float = 0.2
    max_grad_norm: float = 1.0
    target_kl: float = 0.01
    gae_lambda: float = 0.95
    discount_factor: float = 0.99

    # Relative weights used when combining the per-objective rewards
    reward_weights: Dict[RewardObjective, float] = field(default_factory=lambda: {
        RewardObjective.HELPFULNESS: 1.0,
        RewardObjective.HARMLESSNESS: 1.5,
        RewardObjective.HONESTY: 1.2,
        RewardObjective.ACCURACY: 1.3,
        RewardObjective.COHERENCE: 0.8,
        RewardObjective.CREATIVITY: 0.5,
        RewardObjective.EFFICIENCY: 0.3,
    })

    # Direct Preference Optimization
    use_dpo: bool = True
    dpo_beta: float = 0.1

    # Process (step-level) supervision
    enable_process_supervision: bool = True
    step_reward_weight: float = 0.5

    # Multi-objective preference handling
    use_pareto_optimization: bool = True
    num_preference_samples: int = 100

    # Experience replay buffer
    buffer_size: int = 10000
    min_buffer_size: int = 1000


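# Illustrative sketch (not used by the training code): constructing a config that
# emphasizes harmlessness and accuracy. The override values below are assumptions
# chosen purely for demonstration.
def _example_config_usage() -> RLHFConfig:
    weights = dict(RLHFConfig().reward_weights)
    weights[RewardObjective.HARMLESSNESS] = 2.0
    weights[RewardObjective.ACCURACY] = 1.5
    return RLHFConfig(
        ppo_epochs=2,
        target_kl=0.02,
        reward_weights=weights,
    )

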
@dataclass
class Experience:
    """Single experience in RLHF training."""
    states: torch.Tensor
    actions: torch.Tensor
    rewards: Dict[RewardObjective, float]
    next_states: torch.Tensor
    dones: torch.Tensor
    log_probs: torch.Tensor
    values: torch.Tensor
    advantages: Optional[torch.Tensor] = None
    returns: Optional[torch.Tensor] = None
    feedback_source: FeedbackSource = FeedbackSource.HUMAN
    metadata: Dict[str, Any] = field(default_factory=dict)


class MultiObjectiveRewardModel(nn.Module):
    """Multi-objective reward model scoring different aspects of a response."""

    def __init__(self, hidden_dim: int = 768, num_objectives: int = 7):
        super().__init__()

        self.objectives = list(RewardObjective)

        # Shared encoder over pooled hidden states
        self.shared_encoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU()
        )

        # One reward head per objective
        self.reward_heads = nn.ModuleDict({
            obj.value: nn.Sequential(
                nn.Linear(hidden_dim // 2, 256),
                nn.ReLU(),
                nn.Dropout(0.1),
                nn.Linear(256, 64),
                nn.ReLU(),
                nn.Linear(64, 1)
            )
            for obj in RewardObjective
        })

        # Learns preference weights from pairs of per-objective rewards
        self.preference_net = nn.Sequential(
            nn.Linear(num_objectives * 2, 128),
            nn.ReLU(),
            nn.Linear(128, num_objectives),
            nn.Softmax(dim=-1)
        )

    def forward(
        self,
        states: torch.Tensor,
        return_all_objectives: bool = True
    ) -> Union[torch.Tensor, Dict[RewardObjective, torch.Tensor]]:
        """Compute rewards for all objectives."""
        # Mean-pool over the sequence dimension if token-level states are given
        if len(states.shape) == 3:
            states = states.mean(dim=1)

        encoded = self.shared_encoder(states)

        rewards = {}
        for objective in RewardObjective:
            reward = self.reward_heads[objective.value](encoded)
            rewards[objective] = reward.squeeze(-1)

        if return_all_objectives:
            return rewards
        else:
            return self.combine_rewards(rewards)

    def combine_rewards(
        self,
        rewards: Dict[RewardObjective, torch.Tensor],
        weights: Optional[Dict[RewardObjective, float]] = None
    ) -> torch.Tensor:
        """Combine multi-objective rewards into a single scalar per sample."""
        if weights is None:
            # Default to a uniform weighting over objectives
            weights = {obj: 1.0 / len(RewardObjective) for obj in RewardObjective}

        # Accept either tensors or plain floats (e.g. rewards coming from an env)
        combined = None
        for obj, reward in rewards.items():
            if not torch.is_tensor(reward):
                reward = torch.tensor(float(reward))
            term = weights.get(obj, 0.0) * reward
            combined = term if combined is None else combined + term

        return combined

    def predict_preferences(
        self,
        rewards1: Dict[RewardObjective, torch.Tensor],
        rewards2: Dict[RewardObjective, torch.Tensor]
    ) -> torch.Tensor:
        """Predict human preference between two sets of rewards."""
        r1 = torch.stack([rewards1[obj] for obj in RewardObjective], dim=-1)
        r2 = torch.stack([rewards2[obj] for obj in RewardObjective], dim=-1)

        combined = torch.cat([r1, r2], dim=-1)

        preferences = self.preference_net(combined)

        return preferences


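# Illustrative sketch (an assumption, not part of the pipeline): scoring a batch of
# pooled hidden states with the multi-objective reward model and collapsing the
# per-objective scores with the config weights. Shapes are made up for demonstration.
def _example_reward_model_usage() -> torch.Tensor:
    config = RLHFConfig()
    reward_model = MultiObjectiveRewardModel(hidden_dim=768)

    hidden_states = torch.randn(4, 768)            # batch of pooled representations
    per_objective = reward_model(hidden_states)    # Dict[RewardObjective, (4,)]
    combined = reward_model.combine_rewards(per_objective, config.reward_weights)
    return combined                                # shape (4,)

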
class ProcessSupervisor(nn.Module):
    """Process supervision for step-by-step reasoning evaluation."""

    def __init__(self, hidden_dim: int = 768):
        super().__init__()

        # Scores a single reasoning step given its predecessor
        self.step_evaluator = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

        # Scores coherence of a step given its surrounding context
        self.coherence_checker = nn.Sequential(
            nn.Linear(hidden_dim * 3, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        # Estimates the probability that a step contains an error
        self.error_detector = nn.Sequential(
            nn.Linear(hidden_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )

    def evaluate_step(
        self,
        current_step: torch.Tensor,
        previous_step: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Evaluate the quality of a reasoning step."""
        if previous_step is None:
            previous_step = torch.zeros_like(current_step)

        combined = torch.cat([current_step, previous_step], dim=-1)

        step_quality = self.step_evaluator(combined)

        return step_quality

    def check_coherence(
        self,
        steps: List[torch.Tensor]
    ) -> torch.Tensor:
        """Check coherence across multiple steps."""
        if len(steps) < 2:
            return torch.ones(1)

        coherence_scores = []
        for i in range(len(steps) - 1):
            if i == 0:
                context = torch.zeros_like(steps[0])
            else:
                context = steps[i - 1]

            combined = torch.cat([context, steps[i], steps[i + 1]], dim=-1)
            score = self.coherence_checker(combined)
            coherence_scores.append(score)

        return torch.stack(coherence_scores).mean()

    def detect_errors(self, step: torch.Tensor) -> torch.Tensor:
        """Detect errors in a reasoning step."""
        error_probability = self.error_detector(step)
        return error_probability


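# Illustrative sketch (assumption, demonstration only): scoring a short chain of
# reasoning-step embeddings. Real step embeddings would come from the policy model;
# random vectors stand in for them here.
def _example_process_supervision() -> Dict[str, torch.Tensor]:
    supervisor = ProcessSupervisor(hidden_dim=768)
    steps = [torch.randn(768) for _ in range(3)]

    return {
        "first_step_quality": supervisor.evaluate_step(steps[0]),
        "second_step_quality": supervisor.evaluate_step(steps[1], previous_step=steps[0]),
        "coherence": supervisor.check_coherence(steps),
        "error_probability": supervisor.detect_errors(steps[-1]),
    }

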
class DirectPreferenceOptimization(nn.Module):
    """DPO for direct preference learning without an explicit reward model."""

    def __init__(self, beta: float = 0.1):
        super().__init__()
        self.beta = beta

    def compute_dpo_loss(
        self,
        policy_chosen_logps: torch.Tensor,
        policy_rejected_logps: torch.Tensor,
        reference_chosen_logps: torch.Tensor,
        reference_rejected_logps: torch.Tensor
    ) -> torch.Tensor:
        """Compute the DPO loss from policy and reference log-probabilities."""
        # Log-ratios of policy to reference for chosen and rejected responses
        chosen_ratio = policy_chosen_logps - reference_chosen_logps
        rejected_ratio = policy_rejected_logps - reference_rejected_logps

        # loss = -log sigmoid(beta * (chosen_ratio - rejected_ratio))
        loss = -F.logsigmoid(self.beta * (chosen_ratio - rejected_ratio)).mean()

        return loss


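# Illustrative sketch (assumption, toy numbers): the DPO objective
#   L = -E[ log sigmoid( beta * ( (log pi(y_w|x) - log pi_ref(y_w|x))
#                               - (log pi(y_l|x) - log pi_ref(y_l|x)) ) ) ]
# pushes the policy to widen the margin of the chosen response y_w over the
# rejected response y_l relative to the frozen reference model.
def _example_dpo_loss() -> torch.Tensor:
    dpo = DirectPreferenceOptimization(beta=0.1)

    policy_chosen = torch.tensor([-4.0, -3.5])      # log-probs under the policy
    policy_rejected = torch.tensor([-5.0, -6.0])
    reference_chosen = torch.tensor([-4.2, -3.8])   # log-probs under the reference
    reference_rejected = torch.tensor([-4.8, -5.5])

    return dpo.compute_dpo_loss(
        policy_chosen, policy_rejected, reference_chosen, reference_rejected
    )

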
class ExperienceBuffer:
    """Experience replay buffer for RLHF."""

    def __init__(self, capacity: int = 10000):
        self.capacity = capacity
        self.buffer = deque(maxlen=capacity)
        self.priorities = deque(maxlen=capacity)

    def push(self, experience: Experience, priority: float = 1.0):
        """Add experience to the buffer."""
        self.buffer.append(experience)
        self.priorities.append(priority)

    def sample(self, batch_size: int, prioritized: bool = True) -> List[Experience]:
        """Sample a batch from the buffer."""
        if len(self.buffer) < batch_size:
            return list(self.buffer)

        if prioritized and self.priorities:
            # Sample proportionally to stored priorities
            priorities = np.array(self.priorities)
            probabilities = priorities / priorities.sum()
            indices = np.random.choice(len(self.buffer), batch_size, p=probabilities)
        else:
            # Uniform sampling
            indices = np.random.choice(len(self.buffer), batch_size)

        return [self.buffer[i] for i in indices]

    def __len__(self):
        return len(self.buffer)


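# Illustrative sketch (assumption, dummy tensors): pushing a few experiences with
# increasing priorities and drawing a prioritized sample.
def _example_buffer_usage() -> List[Experience]:
    buffer = ExperienceBuffer(capacity=100)

    for i in range(5):
        experience = Experience(
            states=torch.randn(768),
            actions=torch.tensor(i),
            rewards={obj: 0.0 for obj in RewardObjective},
            next_states=torch.randn(768),
            dones=torch.tensor(0.0),
            log_probs=torch.tensor(-1.0),
            values=torch.tensor(0.0),
        )
        buffer.push(experience, priority=float(i + 1))

    return buffer.sample(batch_size=3, prioritized=True)

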
class PPOTrainer:
    """PPO trainer for RLHF optimization."""

    def __init__(
        self,
        policy_model: nn.Module,
        value_model: nn.Module,
        reward_model: MultiObjectiveRewardModel,
        config: RLHFConfig
    ):
        self.policy = policy_model
        self.value = value_model
        self.reward_model = reward_model
        self.config = config

        self.policy_optimizer = torch.optim.Adam(
            policy_model.parameters(), lr=3e-5, eps=1e-8
        )
        self.value_optimizer = torch.optim.Adam(
            value_model.parameters(), lr=1e-4, eps=1e-8
        )

        self.dpo = DirectPreferenceOptimization(config.dpo_beta) if config.use_dpo else None

        self.process_supervisor = ProcessSupervisor() if config.enable_process_supervision else None

        self.buffer = ExperienceBuffer(config.buffer_size)

        self.stats = {
            'policy_loss': [],
            'value_loss': [],
            'rewards': [],
            'kl_divergence': []
        }

    def compute_advantages(
        self,
        rewards: torch.Tensor,
        values: torch.Tensor,
        next_values: torch.Tensor,
        dones: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute GAE advantages and returns.

        Uses the standard recursion with discount gamma and smoothing lambda:
            delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
            R_t     = A_t + V(s_t)
        """
        advantages = torch.zeros_like(rewards)
        last_advantage = 0

        for t in reversed(range(rewards.shape[0])):
            if t == rewards.shape[0] - 1:
                next_value = next_values[t]
            else:
                next_value = values[t + 1]

            delta = rewards[t] + self.config.discount_factor * next_value * (1 - dones[t]) - values[t]
            advantages[t] = delta + self.config.discount_factor * self.config.gae_lambda * (1 - dones[t]) * last_advantage
            last_advantage = advantages[t]

        returns = advantages + values

        # Normalize advantages for training stability
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        return advantages, returns

    def train_step(self, experiences: List[Experience]) -> Dict[str, float]:
        """Single PPO training step."""
        states = torch.stack([e.states for e in experiences])
        actions = torch.stack([e.actions for e in experiences])
        old_log_probs = torch.stack([e.log_probs for e in experiences])
        old_values = torch.stack([e.values for e in experiences])
        advantages = torch.stack([e.advantages for e in experiences])
        returns = torch.stack([e.returns for e in experiences])

        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_kl = 0.0
        num_updates = 0

        for epoch in range(self.config.ppo_epochs):
            # Re-evaluate the stored actions under the current policy and value function
            policy_output = self.policy(states)
            new_log_probs = self.compute_log_probs(policy_output, actions)
            new_values = self.value(states).squeeze(-1)

            # Probability ratios between current and behavior policy
            ratios = torch.exp(new_log_probs - old_log_probs)

            # Clipped surrogate objective
            surr1 = ratios * advantages
            surr2 = torch.clamp(ratios, 1 - self.config.clip_range, 1 + self.config.clip_range) * advantages
            policy_loss = -torch.min(surr1, surr2).mean()

            # Clipped value loss
            value_pred_clipped = old_values + torch.clamp(
                new_values - old_values,
                -self.config.value_clip_range,
                self.config.value_clip_range
            )
            value_loss1 = F.mse_loss(new_values, returns)
            value_loss2 = F.mse_loss(value_pred_clipped, returns)
            value_loss = torch.max(value_loss1, value_loss2)

            # Approximate KL divergence for early stopping
            kl = (old_log_probs - new_log_probs).mean()

            if kl > self.config.target_kl:
                logger.info(f"Early stopping at epoch {epoch} due to KL divergence")
                break

            self.policy_optimizer.zero_grad()
            policy_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.policy.parameters(), self.config.max_grad_norm)
            self.policy_optimizer.step()

            self.value_optimizer.zero_grad()
            value_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.value.parameters(), self.config.max_grad_norm)
            self.value_optimizer.step()

            total_policy_loss += policy_loss.item()
            total_value_loss += value_loss.item()
            total_kl += kl.item()
            num_updates += 1

        num_updates = max(num_updates, 1)
        return {
            'policy_loss': total_policy_loss / num_updates,
            'value_loss': total_value_loss / num_updates,
            'kl_divergence': total_kl / num_updates
        }

    def compute_log_probs(self, logits: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
        """Compute log probabilities of actions."""
        probs = F.softmax(logits, dim=-1)
        dist = Categorical(probs)
        return dist.log_prob(actions)

    def collect_experience(
        self,
        env,
        num_steps: int = 1000
    ) -> List[Experience]:
        """Collect experience from the environment."""
        experiences = []

        state = env.reset()

        for step in range(num_steps):
            # Sample an action from the current policy and record its value estimate
            with torch.no_grad():
                policy_output = self.policy(state.unsqueeze(0))
                probs = F.softmax(policy_output, dim=-1)
                dist = Categorical(probs)
                action = dist.sample()
                log_prob = dist.log_prob(action)

                value = self.value(state.unsqueeze(0)).squeeze()

            # The environment returns per-objective rewards and feedback metadata
            next_state, reward_dict, done, info = env.step(action.item())

            exp = Experience(
                states=state,
                actions=action,
                rewards=reward_dict,
                next_states=next_state,
                dones=torch.tensor(done, dtype=torch.float),
                log_probs=log_prob,
                values=value,
                feedback_source=info.get('feedback_source', FeedbackSource.HUMAN),
                metadata=info
            )

            experiences.append(exp)

            if done:
                state = env.reset()
            else:
                state = next_state

        # Fill in advantages and returns for the collected trajectory
        self._compute_experience_advantages(experiences)

        return experiences

    def _compute_experience_advantages(self, experiences: List[Experience]):
        """Compute advantages for collected experiences."""
        # Collapse each per-objective reward dict into a single scalar reward
        rewards = []
        for exp in experiences:
            combined_reward = self.reward_model.combine_rewards(
                exp.rewards,
                self.config.reward_weights
            )
            rewards.append(combined_reward)

        rewards = torch.stack(rewards)
        values = torch.stack([e.values for e in experiences])

        # Bootstrap with the next state's value; zero for the final step
        next_values = values.clone()
        next_values[:-1] = values[1:]
        next_values[-1] = 0

        dones = torch.stack([e.dones for e in experiences])

        advantages, returns = self.compute_advantages(rewards, values, next_values, dones)

        for i, exp in enumerate(experiences):
            exp.advantages = advantages[i]
            exp.returns = returns[i]


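# Illustrative sketch (assumption, toy numbers): running the GAE computation on a
# hand-written four-step trajectory with dummy policy/value networks. The layer
# sizes are arbitrary stand-ins for the real models.
def _example_gae_computation() -> Tuple[torch.Tensor, torch.Tensor]:
    config = RLHFConfig()
    trainer = PPOTrainer(
        policy_model=nn.Linear(768, 10),
        value_model=nn.Linear(768, 1),
        reward_model=MultiObjectiveRewardModel(),
        config=config,
    )

    rewards = torch.tensor([1.0, 0.0, 0.5, 1.0])
    values = torch.tensor([0.8, 0.6, 0.7, 0.9])
    next_values = torch.tensor([0.6, 0.7, 0.9, 0.0])
    dones = torch.tensor([0.0, 0.0, 0.0, 1.0])

    advantages, returns = trainer.compute_advantages(rewards, values, next_values, dones)
    return advantages, returns

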
class RLHF2System:
    """Complete RLHF 2.0 system with all components."""

    def __init__(
        self,
        base_model: nn.Module,
        config: RLHFConfig
    ):
        self.base_model = base_model
        self.config = config

        # Multi-objective reward model
        self.reward_model = MultiObjectiveRewardModel()

        # Value head sized to the base model's hidden dimension
        self.value_model = nn.Sequential(
            nn.Linear(base_model.config.n_embd, 512),
            nn.ReLU(),
            nn.Linear(512, 1)
        )

        self.ppo_trainer = PPOTrainer(
            base_model,
            self.value_model,
            self.reward_model,
            config
        )

        if config.enable_process_supervision:
            self.process_supervisor = ProcessSupervisor()

        # Optional AI models used as additional feedback sources
        self.ai_feedback_models = {}

    def add_ai_feedback_model(self, name: str, model: nn.Module):
        """Add an AI model for feedback generation."""
        self.ai_feedback_models[name] = model

    def train_reward_model(
        self,
        preference_data,
        num_epochs: int = 3
    ):
        """Train the reward model on preference data."""
        optimizer = torch.optim.Adam(self.reward_model.parameters(), lr=1e-4)

        for epoch in range(num_epochs):
            total_loss = 0
            num_batches = 0

            for batch in preference_data:
                chosen = batch['chosen']
                rejected = batch['rejected']

                # Score both responses on every objective
                chosen_rewards = self.reward_model(chosen)
                rejected_rewards = self.reward_model(rejected)

                # Collapse to scalar rewards using the configured weights
                combined_chosen = self.reward_model.combine_rewards(
                    chosen_rewards, self.config.reward_weights
                )
                combined_rejected = self.reward_model.combine_rewards(
                    rejected_rewards, self.config.reward_weights
                )

                # Bradley-Terry preference loss: chosen should outscore rejected
                loss = -F.logsigmoid(combined_chosen - combined_rejected).mean()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

            avg_loss = total_loss / num_batches
            logger.info(f"Reward model epoch {epoch}: loss = {avg_loss:.4f}")

    def train_with_dpo(
        self,
        preference_data,
        reference_model: nn.Module,
        num_epochs: int = 3
    ):
        """Train using Direct Preference Optimization."""
        if not self.config.use_dpo:
            logger.warning("DPO not enabled in config")
            return

        optimizer = torch.optim.Adam(self.base_model.parameters(), lr=3e-5)

        for epoch in range(num_epochs):
            total_loss = 0
            num_batches = 0

            for batch in preference_data:
                chosen = batch['chosen']
                rejected = batch['rejected']

                # Reference log-probs are computed with gradients disabled
                with torch.no_grad():
                    ref_chosen_logps = reference_model(chosen)['logits'].log_softmax(-1)
                    ref_rejected_logps = reference_model(rejected)['logits'].log_softmax(-1)

                policy_chosen_logps = self.base_model(chosen)['logits'].log_softmax(-1)
                policy_rejected_logps = self.base_model(rejected)['logits'].log_softmax(-1)

                # Simplification: mean log-probs stand in for the summed per-token
                # log-probs of the chosen/rejected sequences
                loss = self.ppo_trainer.dpo.compute_dpo_loss(
                    policy_chosen_logps.mean(),
                    policy_rejected_logps.mean(),
                    ref_chosen_logps.mean(),
                    ref_rejected_logps.mean()
                )

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1

            avg_loss = total_loss / num_batches
            logger.info(f"DPO epoch {epoch}: loss = {avg_loss:.4f}")

    def run_rlhf_loop(
        self,
        env,
        num_iterations: int = 100,
        steps_per_iteration: int = 1000
    ):
        """Main RLHF training loop."""
        for iteration in range(num_iterations):
            logger.info(f"RLHF iteration {iteration}")

            # Collect fresh rollouts from the environment
            experiences = self.ppo_trainer.collect_experience(env, steps_per_iteration)

            # Prioritize experiences by advantage magnitude
            for exp in experiences:
                priority = abs(exp.advantages.item()) if exp.advantages is not None else 1.0
                self.ppo_trainer.buffer.push(exp, priority)

            # Train once the buffer holds enough experience
            if len(self.ppo_trainer.buffer) >= self.config.min_buffer_size:
                batch = self.ppo_trainer.buffer.sample(self.config.ppo_batch_size)

                metrics = self.ppo_trainer.train_step(batch)

                logger.info(f"Iteration {iteration} metrics: {metrics}")

                for key, value in metrics.items():
                    self.ppo_trainer.stats[key].append(value)

            # Periodic evaluation
            if iteration % 10 == 0:
                self.evaluate(env)

    def evaluate(self, env, num_episodes: int = 10):
        """Evaluate the current policy."""
        total_rewards = {obj: 0.0 for obj in RewardObjective}

        for episode in range(num_episodes):
            state = env.reset()
            done = False
            episode_rewards = {obj: 0.0 for obj in RewardObjective}

            while not done:
                with torch.no_grad():
                    action = self.base_model(state.unsqueeze(0)).argmax(dim=-1)

                next_state, reward_dict, done, _ = env.step(action.item())

                for obj, reward in reward_dict.items():
                    episode_rewards[obj] += reward

                state = next_state

            for obj in RewardObjective:
                total_rewards[obj] += episode_rewards[obj]

        avg_rewards = {obj: total / num_episodes for obj, total in total_rewards.items()}

        logger.info(f"Evaluation results: {avg_rewards}")

        return avg_rewards


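# Illustrative sketch (assumption, not part of the library): wiring RLHF2System with
# a tiny stand-in policy. A real base model would be a transformer; the dummy below
# only provides the `.config.n_embd` attribute and a logits-producing forward pass
# that RLHF2System and PPOTrainer rely on.
def _example_system_setup() -> RLHF2System:
    class _DummyModelConfig:
        n_embd = 768

    class _DummyPolicy(nn.Module):
        def __init__(self, vocab_size: int = 1000):
            super().__init__()
            self.config = _DummyModelConfig()
            self.head = nn.Linear(self.config.n_embd, vocab_size)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            # Returns action logits directly, as PPOTrainer expects
            return self.head(hidden_states)

    config = RLHFConfig(ppo_batch_size=8, min_buffer_size=8)
    system = RLHF2System(_DummyPolicy(), config)
    system.add_ai_feedback_model("critic", nn.Linear(768, 1))
    return system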