| """ |
| PPO trainer: rollout buffer, PPO update loop and training statistics. |
| """ |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.optim as optim |
| import numpy as np |
| from typing import List, Tuple, Optional |
| from dataclasses import dataclass |
| from collections import deque |
| import random |
|
|
|
|
@dataclass
class Transition:
    """One recorded environment step.

    ``PPOTrainer.update`` reads ``state``, ``scores``, ``action``, ``reward``,
    ``done``, ``log_prob``, ``value`` and ``valid_actions``. ``next_state``
    and ``next_scores`` are stored but not read by the trainer in this file.
    """
    # Observation array; stacked across transitions and passed to the model.
    # The demo in this file uses a (4, 4) float32 array.
    state: np.ndarray
    # Auxiliary score features passed to the model next to ``state``.
    # The demo uses 2 float32 values; their meaning is not shown here.
    scores: np.ndarray
    # Index of the action taken (converted to an integer tensor for training).
    action: int
    # Scalar reward used for the GAE / return computation.
    reward: float
    # Presumably the observation after the action - not read in this file.
    next_state: np.ndarray
    # Presumably the score features after the action - not read in this file.
    next_scores: np.ndarray
    # True when the episode ended on this step; stops bootstrapping in GAE.
    done: bool
    # Log-probability recorded at collection time; used as the old-policy
    # term of the PPO probability ratio.
    log_prob: float
    # Value estimate recorded at collection time; used as V(s_t) in GAE.
    value: float
    # Mask converted to a bool tensor and handed to model.evaluate_actions.
    # Presumably True marks a legal action - confirm against the model.
    valid_actions: np.ndarray
|
|
|
|
class RolloutBuffer:
    """Fixed-capacity ring buffer of rollout transitions.

    Once ``capacity`` items are stored, every new item overwrites the oldest
    one. ``get_all`` returns the stored items oldest-first, which is what a
    backwards-in-time computation such as GAE needs.
    """

    def __init__(self, capacity: int = 10000):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions kept; must be positive.

        Raises:
            ValueError: If ``capacity`` is not positive.
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        self.buffer: List["Transition"] = []
        # Slot the next push writes to. Once the buffer is full this is also
        # the index of the oldest stored item.
        self.position = 0

    def push(self, transition: "Transition") -> None:
        """Append a transition, overwriting the oldest one when full."""
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def push_batch(self, transitions: List["Transition"]) -> None:
        """Append several transitions in the given order."""
        for transition in transitions:
            self.push(transition)

    def get_all(self) -> List["Transition"]:
        """Return a new list with every stored transition, oldest first.

        While the buffer is not yet full the write cursor equals the length,
        so the rotation below is a plain copy. After wrap-around it undoes
        the ring layout so that insertion order is restored.
        """
        return self.buffer[self.position:] + self.buffer[:self.position]

    def clear(self) -> None:
        """Drop all stored transitions and reset the write cursor."""
        self.buffer = []
        self.position = 0

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.buffer)
|
|
|
|
class PPOTrainer:
    """Clipped-surrogate PPO trainer.

    The model must be movable with ``.to(device)``, expose ``parameters()``
    and provide ``evaluate_actions(states, actions, scores, valid_actions)``
    returning ``(log_probs, values, entropy)``.
    """

    def __init__(
        self,
        model,
        lr: float = 1e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_ratio: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.01,
        max_grad_norm: float = 0.5,
        update_epochs: int = 4,
        batch_size: int = 64,
        device: str = "cpu"
    ):
        """Store hyper-parameters and build the Adam optimizer.

        Args:
            model: Policy/value network (see class docstring).
            lr: Adam learning rate.
            gamma: Discount factor.
            gae_lambda: GAE smoothing factor.
            clip_ratio: PPO ratio clipping range (epsilon).
            value_coef: Weight of the value loss in the total loss.
            entropy_coef: Weight of the entropy bonus in the total loss.
            max_grad_norm: Gradient-norm clipping threshold.
            update_epochs: Passes over the data per ``update`` call.
            batch_size: Minibatch size; also the minimum buffer size for an
                update to happen.
            device: Device the model and the training tensors live on.
        """
        self.model = model.to(device)
        self.device = device

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_ratio = clip_ratio
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.update_epochs = update_epochs
        self.batch_size = batch_size

        # Optimize the same object that is used for the forward pass and for
        # gradient clipping.
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Rolling history of the per-update averages (last 100 updates).
        self.stats = {
            'policy_loss': deque(maxlen=100),
            'value_loss': deque(maxlen=100),
            'entropy': deque(maxlen=100),
            'total_loss': deque(maxlen=100)
        }

    def compute_gae(
        self,
        rewards: np.ndarray,
        values: np.ndarray,
        dones: np.ndarray,
        next_value: float = 0.0
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute discounted returns and GAE advantages.

        The inputs must be in time order. A true ``dones[t]`` ends an episode
        at step ``t``: nothing is bootstrapped or accumulated across it.

        Args:
            rewards: Reward sequence, shape (T,).
            values: Value estimates for the visited states, shape (T,).
            dones: Episode-end flags, shape (T,).
            next_value: Value of the state after the last step; only used
                when the last step is not terminal.

        Returns:
            ``(returns, advantages)``, both float32 arrays of shape (T,).
            ``returns`` are discounted reward sums (bootstrapped with
            ``next_value`` at the tail), not ``advantages + values``.
        """
        T = len(rewards)
        advantages = np.zeros(T, dtype=np.float32)
        returns = np.zeros(T, dtype=np.float32)

        last_gae = 0.0
        last_return = next_value

        for t in reversed(range(T)):
            if dones[t]:
                # Terminal step: no future value and no carried-over terms.
                not_done = 0.0
                next_value_t = 0.0
            else:
                not_done = 1.0
                next_value_t = values[t + 1] if t + 1 < T else next_value

            delta = rewards[t] + self.gamma * next_value_t - values[t]
            last_gae = delta + self.gamma * self.gae_lambda * not_done * last_gae
            advantages[t] = last_gae

            last_return = rewards[t] + self.gamma * not_done * last_return
            returns[t] = last_return

        return returns, advantages

    def update(self, buffer: "RolloutBuffer") -> dict:
        """Run PPO updates on the transitions held by ``buffer``.

        Args:
            buffer: Object supporting ``len()`` and ``get_all()``; the
                returned transitions must be in time order.

        Returns:
            Averages of ``policy_loss``, ``value_loss``, ``entropy`` and
            ``total_loss`` over all minibatch steps, or ``{}`` when the
            buffer holds fewer than ``batch_size`` transitions or no
            optimization step was taken.
        """
        if len(buffer) < self.batch_size:
            return {}

        transitions = buffer.get_all()

        states = np.array([t.state for t in transitions])
        scores = np.array([t.scores for t in transitions])
        actions = np.array([t.action for t in transitions])
        rewards = np.array([t.reward for t in transitions])
        dones = np.array([t.done for t in transitions])
        old_log_probs = np.array([t.log_prob for t in transitions])
        old_values = np.array([t.value for t in transitions])
        valid_actions = np.array([t.valid_actions for t in transitions])

        # NOTE: next_value is left at its default of 0, so a rollout that
        # stops mid-episode is treated as if the value after its last step
        # were 0.
        returns, advantages = self.compute_gae(rewards, old_values, dones)

        # Normalize advantages over the whole rollout.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        def to_tensor(array, dtype):
            return torch.as_tensor(array, dtype=dtype, device=self.device)

        states_t = to_tensor(states, torch.float32)
        scores_t = to_tensor(scores, torch.float32)
        actions_t = to_tensor(actions, torch.long)
        old_log_probs_t = to_tensor(old_log_probs, torch.float32)
        returns_t = to_tensor(returns, torch.float32)
        advantages_t = to_tensor(advantages, torch.float32)
        valid_actions_t = to_tensor(valid_actions, torch.bool)

        total_policy_loss = 0.0
        total_value_loss = 0.0
        total_entropy = 0.0
        total_loss = 0.0
        num_updates = 0

        dataset_size = len(transitions)
        indices = np.arange(dataset_size)

        for _ in range(self.update_epochs):
            np.random.shuffle(indices)

            for start in range(0, dataset_size, self.batch_size):
                batch_indices = indices[start:start + self.batch_size]

                batch_advantages = advantages_t[batch_indices]
                batch_returns = returns_t[batch_indices]

                log_probs, values, entropy = self.model.evaluate_actions(
                    states_t[batch_indices],
                    actions_t[batch_indices],
                    scores_t[batch_indices],
                    valid_actions_t[batch_indices],
                )

                # Clipped surrogate objective.
                ratio = torch.exp(log_probs - old_log_probs_t[batch_indices])
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(
                    ratio, 1 - self.clip_ratio, 1 + self.clip_ratio
                ) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()

                # Flatten rather than squeeze(): squeeze() would turn a
                # single-sample minibatch into a 0-d tensor that then has to
                # broadcast against the (1,) target. Assumes the model emits
                # one value per sample - confirm against evaluate_actions.
                value_loss = F.mse_loss(values.reshape(-1), batch_returns)

                mean_entropy = entropy.mean()
                loss = (
                    policy_loss +
                    self.value_coef * value_loss -
                    self.entropy_coef * mean_entropy
                )

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optimizer.step()

                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += mean_entropy.item()
                total_loss += loss.item()
                num_updates += 1

        if num_updates == 0:
            # e.g. update_epochs == 0: nothing to average.
            return {}

        stats = {
            'policy_loss': total_policy_loss / num_updates,
            'value_loss': total_value_loss / num_updates,
            'entropy': total_entropy / num_updates,
            'total_loss': total_loss / num_updates
        }

        for key, value in stats.items():
            self.stats[key].append(value)

        return stats

    def get_recent_stats(self) -> dict:
        """Return the mean of each non-empty rolling statistic."""
        return {key: np.mean(values) for key, values in self.stats.items() if values}
|
|
|
|
class TrainingStats:
    """Collects per-game results and reports windowed averages."""

    def __init__(self):
        # Running counters.
        self.games_played = 0
        self.total_steps = 0

        # Per-game series used for the averages.
        self.scores = []
        self.situational_scores = []
        self.max_tiles = []
        self.game_lengths = []

        # Full histories, kept as separate lists from the series above.
        self.score_history = []
        self.situational_history = []
        self.max_tile_history = []
        self.steps_history = []

        # Best results seen so far.
        self.best_score = 0
        self.best_max_tile = 0

    def record_game(
        self,
        score: int,
        situational_score: float,
        max_tile: int,
        steps: int
    ) -> None:
        """Record the outcome of one finished game.

        Args:
            score: Final game score.
            situational_score: Situational score of the game.
            max_tile: Largest tile reached.
            steps: Number of steps the game lasted.
        """
        self.games_played += 1
        self.total_steps += steps

        # Each sample goes into both its averaging series and its history.
        for series, sample in (
            (self.scores, score),
            (self.situational_scores, situational_score),
            (self.max_tiles, max_tile),
            (self.game_lengths, steps),
            (self.score_history, score),
            (self.situational_history, situational_score),
            (self.max_tile_history, max_tile),
            (self.steps_history, steps),
        ):
            series.append(sample)

        # max() keeps the current best on ties, like a strict ">" check.
        self.best_score = max(self.best_score, score)
        self.best_max_tile = max(self.best_max_tile, max_tile)

    def get_avg_stats(self, window: int = 100) -> dict:
        """Return counters, bests and averages over the last ``window`` games.

        Averages of empty series are reported as 0. ``recent_scores`` and
        ``recent_max_tiles`` always hold at most the last 10 entries.
        """
        def tail_mean(series):
            if series:
                tail = series[-window:]
                return sum(tail) / len(tail)
            return 0

        summary = {
            'games_played': self.games_played,
            'total_steps': self.total_steps,
        }
        summary['avg_score'] = tail_mean(self.scores)
        summary['avg_situational'] = tail_mean(self.situational_scores)
        summary['avg_max_tile'] = tail_mean(self.max_tiles)
        summary['avg_game_length'] = tail_mean(self.game_lengths)
        summary['best_score'] = self.best_score
        summary['best_max_tile'] = self.best_max_tile
        # Slicing an empty list already yields a fresh empty list.
        summary['recent_scores'] = self.scores[-10:]
        summary['recent_max_tiles'] = self.max_tiles[-10:]
        return summary
|
|
|
|
if __name__ == "__main__":
    from model import Game2048Transformer

    # Smoke test: build a trainer on CPU around the project model.
    device = torch.device("cpu")
    model = Game2048Transformer().to(device)
    trainer = PPOTrainer(model, device=device)

    # Fill a buffer with 100 random transitions.
    buffer = RolloutBuffer(capacity=1000)
    buffer.push_batch([
        Transition(
            state=np.random.randn(4, 4).astype(np.float32),
            scores=np.random.rand(2).astype(np.float32),
            action=np.random.randint(0, 4),
            reward=np.random.randn(),
            next_state=np.random.randn(4, 4).astype(np.float32),
            next_scores=np.random.rand(2).astype(np.float32),
            done=np.random.rand() < 0.1,
            log_prob=np.random.randn(),
            value=np.random.randn(),
            valid_actions=np.ones(4, dtype=bool)
        )
        for _ in range(100)
    ])

    # One PPO update on the random data.
    stats = trainer.update(buffer)
    print(f"Training stats: {stats}")

    # Exercise the statistics recorder with ten synthetic games.
    training_stats = TrainingStats()
    for game_idx in range(10):
        training_stats.record_game(
            score=1000 * (game_idx + 1),
            situational_score=50.0 + game_idx * 5,
            max_tile=2 ** (game_idx + 5),
            steps=100 + game_idx * 10
        )

    print(f"Average stats: {training_stats.get_avg_stats()}")
|
|