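# vAIbe_2048 / trainer.py
# Hugging Face file-page header: uploader "forthezero" (avatar alt text "forthezero's picture"),
# commit message "Upload 13 files", commit 0642513 (verified).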
"""
PPO trainer: transition record, rollout buffer, PPO update loop and training statistics.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from typing import List, Tuple, Optional
from dataclasses import dataclass
from collections import deque
import random
@dataclass
class Transition:
    """One environment step as stored in the rollout buffer."""

    # Board state before the action, shape (4, 4).
    state: np.ndarray
    # Score features accompanying `state`, shape (2,).
    scores: np.ndarray
    # Action that was taken.
    action: int
    # Reward received for the action.
    reward: float
    # Board state after the action.
    next_state: np.ndarray
    # Score features accompanying `next_state`.
    next_scores: np.ndarray
    # Whether the episode ended with this step.
    done: bool
    # Log-probability of `action` under the policy that chose it.
    log_prob: float
    # Value estimate of `state`.
    value: float
    # Mask of the actions that were valid in `state`.
    valid_actions: np.ndarray
class RolloutBuffer:
    """Fixed-capacity ring buffer of trajectory transitions.

    Once `capacity` transitions are stored, each new transition overwrites
    the oldest one. `get_all` always returns the stored transitions in
    insertion (chronological) order, which matters for consumers that treat
    neighbouring entries as consecutive time steps.
    """

    def __init__(self, capacity: int = 10000):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions kept; must be positive.

        Raises:
            ValueError: If `capacity` is not positive (such a buffer could
                never accept a transition).
        """
        if capacity <= 0:
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        self.buffer: List["Transition"] = []
        # Index of the slot the next push writes to once the buffer is full;
        # while the buffer is still filling it equals len(self.buffer).
        self.position = 0

    def push(self, transition: "Transition") -> None:
        """Add one transition, overwriting the oldest entry when full."""
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def push_batch(self, transitions: List["Transition"]) -> None:
        """Add several transitions in order."""
        for transition in transitions:
            self.push(transition)

    def get_all(self) -> List["Transition"]:
        """Return a new list of all stored transitions, oldest first.

        After the ring has wrapped, the oldest entry lives at `position`, so
        the storage list is rotated there. Before wrapping, `position` equals
        the length and the rotation is a plain copy.
        """
        return self.buffer[self.position:] + self.buffer[:self.position]

    def clear(self) -> None:
        """Remove all stored transitions."""
        self.buffer = []
        self.position = 0

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.buffer)
class PPOTrainer:
    """PPO trainer using the clipped surrogate objective.

    The model must provide
    `evaluate_actions(states, actions, scores, valid_actions)` returning
    `(log_probs, values, entropy)` for a batch; nothing else about the model
    is required by this class.
    """

    def __init__(
        self,
        model,
        lr: float = 1e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_ratio: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.01,
        max_grad_norm: float = 0.5,
        update_epochs: int = 4,
        batch_size: int = 64,
        device: str = "cpu",
    ):
        """Store hyper-parameters, move the model to `device`, build the optimizer.

        Args:
            model: Policy/value network (see class docstring for its contract).
            lr: Adam learning rate.
            gamma: Discount factor.
            gae_lambda: GAE lambda.
            clip_ratio: PPO clipping range for the probability ratio.
            value_coef: Weight of the value loss in the total loss.
            entropy_coef: Weight of the entropy bonus in the total loss.
            max_grad_norm: Gradient-norm clipping threshold.
            update_epochs: Passes over the data per `update` call.
            batch_size: Minibatch size; also the minimum buffer size for an update.
            device: Device passed to `model.to()` and used for all tensors.
        """
        self.model = model.to(device)
        self.device = device
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_ratio = clip_ratio
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.update_epochs = update_epochs
        self.batch_size = batch_size
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)
        # Rolling history of the per-update averages (last 100 updates).
        self.stats = {
            'policy_loss': deque(maxlen=100),
            'value_loss': deque(maxlen=100),
            'entropy': deque(maxlen=100),
            'total_loss': deque(maxlen=100)
        }

    def compute_gae(
        self,
        rewards: np.ndarray,
        values: np.ndarray,
        dones: np.ndarray,
        next_value: float = 0.0
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute discounted returns and GAE advantages for a step sequence.

        Entries are treated as consecutive time steps; a true `dones[t]`
        cuts both the bootstrap and the accumulation at step t.

        Args:
            rewards: Rewards, shape (T,).
            values: Value estimates of the visited states, shape (T,).
            dones: Episode-end flags, shape (T,).
            next_value: Value estimate of the state following the last step;
                only used when the last step is not terminal.

        Returns:
            `(returns, advantages)`, both float32 arrays of shape (T,).
            `returns` are discounted reward sums (seeded with `next_value`),
            not `advantages + values`.
        """
        steps = len(rewards)
        advantages = np.zeros(steps, dtype=np.float32)
        returns = np.zeros(steps, dtype=np.float32)
        gae = 0.0
        running_return = next_value
        for t in reversed(range(steps)):
            if dones[t]:
                # Episode boundary: nothing after step t belongs to this episode.
                bootstrap = 0.0
                gae = 0.0
                running_return = 0.0
            else:
                bootstrap = values[t + 1] if t + 1 < steps else next_value
            delta = rewards[t] + self.gamma * bootstrap - values[t]
            gae = delta + self.gamma * self.gae_lambda * gae
            advantages[t] = gae
            running_return = rewards[t] + self.gamma * running_return
            returns[t] = running_return
        return returns, advantages

    def update(self, buffer: "RolloutBuffer", next_value: float = 0.0) -> dict:
        """Run PPO updates on every transition currently in `buffer`.

        The buffer is read via `get_all()` and is not cleared. Its
        transitions are expected in chronological order because
        `compute_gae` treats neighbours as consecutive steps.

        Args:
            buffer: Rollout buffer holding the collected transitions.
            next_value: Value estimate of the state following the last stored
                transition, forwarded to `compute_gae`. Defaults to 0.0.

        Returns:
            Averages over all minibatch updates under the keys
            `policy_loss`, `value_loss`, `entropy` and `total_loss`, or an
            empty dict when the buffer holds fewer than `batch_size`
            transitions or no update step was performed.
        """
        if len(buffer) < self.batch_size:
            return {}

        transitions = buffer.get_all()

        # Column-wise numpy views of the transition fields.
        states = np.array([t.state for t in transitions])
        scores = np.array([t.scores for t in transitions])
        actions = np.array([t.action for t in transitions])
        rewards = np.array([t.reward for t in transitions])
        dones = np.array([t.done for t in transitions])
        old_log_probs = np.array([t.log_prob for t in transitions])
        old_values = np.array([t.value for t in transitions])
        valid_actions = np.array([t.valid_actions for t in transitions])

        returns, advantages = self.compute_gae(rewards, old_values, dones, next_value)
        # Normalise advantages over the whole dataset.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        def to_tensor(array, dtype):
            return torch.as_tensor(array, dtype=dtype, device=self.device)

        states_t = to_tensor(states, torch.float32)
        scores_t = to_tensor(scores, torch.float32)
        actions_t = to_tensor(actions, torch.long)
        old_log_probs_t = to_tensor(old_log_probs, torch.float32)
        returns_t = to_tensor(returns, torch.float32)
        advantages_t = to_tensor(advantages, torch.float32)
        valid_actions_t = to_tensor(valid_actions, torch.bool)

        totals = {
            'policy_loss': 0.0,
            'value_loss': 0.0,
            'entropy': 0.0,
            'total_loss': 0.0,
        }
        num_updates = 0
        dataset_size = len(transitions)
        indices = np.arange(dataset_size)

        for _ in range(self.update_epochs):
            np.random.shuffle(indices)
            for start in range(0, dataset_size, self.batch_size):
                batch_indices = indices[start:start + self.batch_size]

                log_probs, values, entropy = self.model.evaluate_actions(
                    states_t[batch_indices],
                    actions_t[batch_indices],
                    scores_t[batch_indices],
                    valid_actions_t[batch_indices],
                )
                batch_advantages = advantages_t[batch_indices]

                # Clipped surrogate policy loss.
                ratio = torch.exp(log_probs - old_log_probs_t[batch_indices])
                unclipped = ratio * batch_advantages
                clipped = torch.clamp(
                    ratio, 1 - self.clip_ratio, 1 + self.clip_ratio
                ) * batch_advantages
                policy_loss = -torch.min(unclipped, clipped).mean()

                # reshape(-1) keeps the batch dimension even for a one-sample
                # minibatch, where squeeze() would yield a 0-d tensor and make
                # mse_loss broadcast against the (1,) target.
                value_loss = F.mse_loss(values.reshape(-1), returns_t[batch_indices])

                mean_entropy = entropy.mean()
                loss = (
                    policy_loss
                    + self.value_coef * value_loss
                    - self.entropy_coef * mean_entropy
                )

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optimizer.step()

                totals['policy_loss'] += policy_loss.item()
                totals['value_loss'] += value_loss.item()
                totals['entropy'] += mean_entropy.item()
                totals['total_loss'] += loss.item()
                num_updates += 1

        if num_updates == 0:
            # e.g. update_epochs == 0: nothing was trained, nothing to average.
            return {}

        stats = {key: total / num_updates for key, total in totals.items()}
        for key, value in stats.items():
            self.stats[key].append(value)
        return stats

    def get_recent_stats(self) -> dict:
        """Return the mean of each recorded statistic, skipping empty histories."""
        return {key: np.mean(values) for key, values in self.stats.items() if values}
class TrainingStats:
    """Collects per-game results and reports running averages and bests."""

    def __init__(self):
        # Counters across all recorded games.
        self.games_played = 0
        self.total_steps = 0

        # Per-game values used for the averages.
        self.scores = []               # cumulative score of each game
        self.situational_scores = []   # average situational score of each game
        self.max_tiles = []            # largest tile of each game
        self.game_lengths = []         # number of steps of each game

        # Parallel histories kept for plotting.
        self.score_history = []
        self.situational_history = []
        self.max_tile_history = []
        self.steps_history = []

        # Best values seen so far.
        self.best_score = 0
        self.best_max_tile = 0

    def record_game(
        self,
        score: int,
        situational_score: float,
        max_tile: int,
        steps: int
    ) -> None:
        """Record the outcome of one finished game.

        Args:
            score: Cumulative score of the game.
            situational_score: Average situational score of the game.
            max_tile: Largest tile reached.
            steps: Number of steps played.
        """
        self.games_played += 1
        self.total_steps += steps

        # Each value goes into its per-game list and its plotting history.
        targets = (
            (score, self.scores, self.score_history),
            (situational_score, self.situational_scores, self.situational_history),
            (max_tile, self.max_tiles, self.max_tile_history),
            (steps, self.game_lengths, self.steps_history),
        )
        for value, per_game, history in targets:
            per_game.append(value)
            history.append(value)

        # max() keeps the current best on ties, like a strict ">" check.
        self.best_score = max(self.best_score, score)
        self.best_max_tile = max(self.best_max_tile, max_tile)

    def get_avg_stats(self, window: int = 100) -> dict:
        """Return counters, bests and averages over the last `window` games.

        Averages are 0 when no game has been recorded. The `recent_*`
        entries always cover the last 10 games regardless of `window`.
        """
        def window_mean(values):
            tail = values[-window:]
            return sum(tail) / len(tail) if values else 0

        summary = {
            'games_played': self.games_played,
            'total_steps': self.total_steps,
        }
        summary['avg_score'] = window_mean(self.scores)
        summary['avg_situational'] = window_mean(self.situational_scores)
        summary['avg_max_tile'] = window_mean(self.max_tiles)
        summary['avg_game_length'] = window_mean(self.game_lengths)
        summary['best_score'] = self.best_score
        summary['best_max_tile'] = self.best_max_tile
        summary['recent_scores'] = self.scores[-10:]
        summary['recent_max_tiles'] = self.max_tiles[-10:]
        return summary
if __name__ == "__main__":
    from model import Game2048Transformer

    # Smoke test 1: run one PPO update on randomly generated transitions.
    device = torch.device("cpu")
    model = Game2048Transformer().to(device)
    trainer = PPOTrainer(model, device=device)

    buffer = RolloutBuffer(capacity=1000)
    for _ in range(100):
        # Dict entries are evaluated top to bottom, so the random draws
        # happen in field order.
        fake_step = {
            "state": np.random.randn(4, 4).astype(np.float32),
            "scores": np.random.rand(2).astype(np.float32),
            "action": np.random.randint(0, 4),
            "reward": np.random.randn(),
            "next_state": np.random.randn(4, 4).astype(np.float32),
            "next_scores": np.random.rand(2).astype(np.float32),
            "done": np.random.rand() < 0.1,
            "log_prob": np.random.randn(),
            "value": np.random.randn(),
            "valid_actions": np.ones(4, dtype=bool),
        }
        buffer.push(Transition(**fake_step))

    stats = trainer.update(buffer)
    print(f"Training stats: {stats}")

    # Smoke test 2: record ten synthetic games and print the averages.
    training_stats = TrainingStats()
    for i in range(10):
        training_stats.record_game(
            1000 * (i + 1),    # score
            50.0 + i * 5,      # situational score
            2 ** (i + 5),      # max tile
            100 + i * 10,      # steps
        )
    print(f"Average stats: {training_stats.get_avg_stats()}")