File size: 12,300 Bytes
0642513
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
"""
PPO trainer: rollout buffer, PPO update loop and training statistics.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
from typing import List, Tuple, Optional
from dataclasses import dataclass
from collections import deque
import random


@dataclass
class Transition:
    """One recorded state transition, as consumed by the PPO update."""
    state: np.ndarray      # board state (4, 4)
    scores: np.ndarray     # score features (2,)
    action: int            # action that was taken
    reward: float          # reward
    next_state: np.ndarray # next board state
    next_scores: np.ndarray # next score features
    done: bool             # whether the episode ended
    log_prob: float        # log-probability of the action
    value: float           # state value
    valid_actions: np.ndarray  # mask of valid actions


class RolloutBuffer:
    """Fixed-capacity ring buffer holding trajectory transitions.

    Once ``capacity`` transitions are stored, every further push overwrites
    the oldest stored transition. ``get_all`` always returns the contents
    oldest-first, so order-dependent consumers see a time-ordered sequence
    even after the buffer has wrapped around.
    """

    def __init__(self, capacity: int = 10000):
        """Create an empty buffer.

        Args:
            capacity: Maximum number of transitions kept; must be positive.

        Raises:
            ValueError: If ``capacity`` is not positive.
        """
        if capacity <= 0:
            # A non-positive capacity would otherwise only fail later, inside
            # push(), with an unrelated ZeroDivisionError / IndexError.
            raise ValueError(f"capacity must be positive, got {capacity}")
        self.capacity = capacity
        self.buffer: List["Transition"] = []
        # Index of the slot the next push writes to once the buffer is full;
        # at that point it is also the index of the oldest stored transition.
        self.position = 0

    def push(self, transition: "Transition") -> None:
        """Add one transition, overwriting the oldest one when full."""
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def push_batch(self, transitions: List["Transition"]) -> None:
        """Add several transitions in the given order."""
        for transition in transitions:
            self.push(transition)

    def get_all(self) -> List["Transition"]:
        """Return a new list with all stored transitions, oldest first."""
        if len(self.buffer) < self.capacity:
            # Not wrapped yet: storage order already is insertion order.
            return self.buffer.copy()
        # Wrapped: the oldest element sits at `position`, so rotate the
        # storage order back into chronological order.
        return self.buffer[self.position:] + self.buffer[:self.position]

    def clear(self) -> None:
        """Remove every stored transition and reset the write position."""
        self.buffer = []
        self.position = 0

    def __len__(self) -> int:
        """Return the number of transitions currently stored."""
        return len(self.buffer)


class PPOTrainer:
    """PPO trainer using the clipped surrogate objective.

    The wrapped model must provide
    ``evaluate_actions(states, actions, scores, valid_actions)`` returning
    ``(log_probs, values, entropy)``; that is the only model method this
    class calls.
    """

    def __init__(
        self,
        model,
        lr: float = 1e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_ratio: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.01,
        max_grad_norm: float = 0.5,
        update_epochs: int = 4,
        batch_size: int = 64,
        device: str = "cpu"
    ):
        """Set up the optimizer and the rolling statistics.

        Args:
            model: Policy/value network; it is moved to ``device``.
            lr: Learning rate of the Adam optimizer.
            gamma: Discount factor.
            gae_lambda: Lambda parameter of GAE.
            clip_ratio: Clipping range of the probability ratio.
            value_coef: Weight of the value loss in the total loss.
            entropy_coef: Weight of the entropy bonus in the total loss.
            max_grad_norm: Gradient-norm clipping threshold.
            update_epochs: Number of passes over the data per ``update``.
            batch_size: Minibatch size; also the minimum buffer size that
                ``update`` requires before doing any work.
            device: Device on which the model and the batches live.
        """
        self.model = model.to(device)
        self.device = device

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_ratio = clip_ratio
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.update_epochs = update_epochs
        self.batch_size = batch_size

        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Rolling window (last 100 updates) of each training statistic.
        self.stats = {
            'policy_loss': deque(maxlen=100),
            'value_loss': deque(maxlen=100),
            'entropy': deque(maxlen=100),
            'total_loss': deque(maxlen=100)
        }

    def compute_gae(
        self,
        rewards: np.ndarray,
        values: np.ndarray,
        dones: np.ndarray,
        next_value: float = 0.0
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute discounted returns and GAE advantages.

        The inputs are treated as one time-ordered sequence; a true entry in
        ``dones`` marks the last step of an episode and stops both the
        advantage and the return from propagating across it.

        Args:
            rewards: Reward sequence, shape (T,).
            values: Value sequence, shape (T,).
            dones: Episode-end flags, shape (T,).
            next_value: Value used to bootstrap after the final step when
                that step is not terminal.

        Returns:
            returns: Discounted returns, shape (T,), float32.
            advantages: GAE advantages, shape (T,), float32.
        """
        T = len(rewards)
        advantages = np.zeros(T, dtype=np.float32)
        returns = np.zeros(T, dtype=np.float32)

        # 1.0 where the episode continues after step t, 0.0 where it ended.
        not_done = 1.0 - np.asarray(dones, dtype=np.float64)

        last_gae = 0.0
        last_return = next_value

        for t in reversed(range(T)):
            if dones[t]:
                # Nothing to bootstrap from past a terminal step.
                bootstrap = 0
            else:
                bootstrap = values[t + 1] if t + 1 < T else next_value

            delta = rewards[t] + self.gamma * bootstrap - values[t]
            last_gae = delta + self.gamma * self.gae_lambda * not_done[t] * last_gae
            advantages[t] = last_gae

            last_return = rewards[t] + self.gamma * not_done[t] * last_return
            returns[t] = last_return

        return returns, advantages

    def _to_tensor(self, array, dtype) -> torch.Tensor:
        """Convert an array to a tensor of ``dtype`` on the trainer's device."""
        return torch.as_tensor(array, dtype=dtype, device=self.device)

    def update(self, buffer: "RolloutBuffer") -> dict:
        """Run the PPO update on everything stored in ``buffer``.

        Args:
            buffer: Buffer holding the collected transitions.

        Returns:
            Mean ``policy_loss``, ``value_loss``, ``entropy`` and
            ``total_loss`` over all minibatch steps, or an empty dict when
            the buffer holds fewer than ``batch_size`` transitions or no
            optimization step was performed.
        """
        if len(buffer) < self.batch_size:
            return {}

        transitions = buffer.get_all()

        # Stack the per-transition fields into arrays.
        states = np.array([t.state for t in transitions])
        scores = np.array([t.scores for t in transitions])
        actions = np.array([t.action for t in transitions])
        rewards = np.array([t.reward for t in transitions])
        dones = np.array([t.done for t in transitions])
        old_log_probs = np.array([t.log_prob for t in transitions])
        old_values = np.array([t.value for t in transitions])
        valid_actions = np.array([t.valid_actions for t in transitions])

        # NOTE: no bootstrap value is passed, so next_value stays at its 0.0
        # default when the last stored transition is not terminal.
        returns, advantages = self.compute_gae(rewards, old_values, dones)

        # Normalize advantages over the whole dataset.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        states_t = self._to_tensor(states, torch.float32)
        scores_t = self._to_tensor(scores, torch.float32)
        actions_t = self._to_tensor(actions, torch.long)
        old_log_probs_t = self._to_tensor(old_log_probs, torch.float32)
        returns_t = self._to_tensor(returns, torch.float32)
        advantages_t = self._to_tensor(advantages, torch.float32)
        valid_actions_t = self._to_tensor(valid_actions, torch.bool)

        totals = {
            'policy_loss': 0.0,
            'value_loss': 0.0,
            'entropy': 0.0,
            'total_loss': 0.0
        }
        num_updates = 0

        dataset_size = len(transitions)
        indices = np.arange(dataset_size)

        for _ in range(self.update_epochs):
            np.random.shuffle(indices)

            for start in range(0, dataset_size, self.batch_size):
                batch_indices = indices[start:start + self.batch_size]

                batch_states = states_t[batch_indices]
                batch_scores = scores_t[batch_indices]
                batch_actions = actions_t[batch_indices]
                batch_old_log_probs = old_log_probs_t[batch_indices]
                batch_returns = returns_t[batch_indices]
                batch_advantages = advantages_t[batch_indices]
                batch_valid = valid_actions_t[batch_indices]

                log_probs, values, entropy = self.model.evaluate_actions(
                    batch_states, batch_actions, batch_scores, batch_valid
                )

                # Policy loss (PPO clip).
                ratio = torch.exp(log_probs - batch_old_log_probs)
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(
                    ratio, 1 - self.clip_ratio, 1 + self.clip_ratio
                ) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()

                # Value loss. reshape(-1) keeps the batch dimension even for
                # a single-sample minibatch, where a bare squeeze() would
                # produce a 0-dim tensor that no longer matches the target.
                value_loss = F.mse_loss(values.reshape(-1), batch_returns)

                mean_entropy = entropy.mean()

                loss = (
                    policy_loss +
                    self.value_coef * value_loss -
                    self.entropy_coef * mean_entropy
                )

                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optimizer.step()

                totals['policy_loss'] += policy_loss.item()
                totals['value_loss'] += value_loss.item()
                totals['entropy'] += mean_entropy.item()
                totals['total_loss'] += loss.item()
                num_updates += 1

        if num_updates == 0:
            # update_epochs == 0: nothing was optimized, nothing to average.
            return {}

        stats = {key: total / num_updates for key, total in totals.items()}

        for key, value in stats.items():
            self.stats[key].append(value)

        return stats

    def get_recent_stats(self) -> dict:
        """Return the mean of each statistic over its rolling window.

        Statistics that have no recorded value yet are omitted.
        """
        return {key: np.mean(values) for key, values in self.stats.items() if values}


class TrainingStats:
    """Collects per-game results and reports windowed averages."""

    def __init__(self):
        # Running counters.
        self.games_played = 0
        self.total_steps = 0

        # Per-game series.
        self.scores = []              # cumulative score of each game
        self.situational_scores = []  # mean situational score of each game
        self.max_tiles = []           # largest tile of each game
        self.game_lengths = []        # number of steps of each game

        # Parallel history series kept for plotting.
        self.score_history = []
        self.situational_history = []
        self.max_tile_history = []
        self.steps_history = []

        # Best values seen so far.
        self.best_score = 0
        self.best_max_tile = 0

    def record_game(
        self,
        score: int,
        situational_score: float,
        max_tile: int,
        steps: int
    ) -> None:
        """Record the outcome of one finished game."""
        self.games_played += 1
        self.total_steps += steps

        # Each sample goes into both its per-game series and its history.
        for series, history, sample in (
            (self.scores, self.score_history, score),
            (self.situational_scores, self.situational_history, situational_score),
            (self.max_tiles, self.max_tile_history, max_tile),
            (self.game_lengths, self.steps_history, steps),
        ):
            series.append(sample)
            history.append(sample)

        self.best_score = max(self.best_score, score)
        self.best_max_tile = max(self.best_max_tile, max_tile)

    def get_avg_stats(self, window: int = 100) -> dict:
        """Return counters, bests and averages over the last ``window`` games."""
        def window_mean(series):
            # An empty series averages to 0 instead of dividing by zero.
            if len(series) == 0:
                return 0
            tail = series[-window:]
            return sum(tail) / len(tail)

        summary = {
            'games_played': self.games_played,
            'total_steps': self.total_steps,
        }
        summary['avg_score'] = window_mean(self.scores)
        summary['avg_situational'] = window_mean(self.situational_scores)
        summary['avg_max_tile'] = window_mean(self.max_tiles)
        summary['avg_game_length'] = window_mean(self.game_lengths)
        summary['best_score'] = self.best_score
        summary['best_max_tile'] = self.best_max_tile
        # Slicing an empty list already yields a fresh empty list.
        summary['recent_scores'] = self.scores[-10:]
        summary['recent_max_tiles'] = self.max_tiles[-10:]
        return summary


if __name__ == "__main__":
    from model import Game2048Transformer

    # Smoke test of the PPO trainer on CPU.
    device = torch.device("cpu")
    model = Game2048Transformer().to(device)
    trainer = PPOTrainer(model, device=device)

    # Fill a buffer with 100 random transitions.
    buffer = RolloutBuffer(capacity=1000)
    remaining = 100
    while remaining > 0:
        buffer.push(
            Transition(
                state=np.random.randn(4, 4).astype(np.float32),
                scores=np.random.rand(2).astype(np.float32),
                action=np.random.randint(0, 4),
                reward=np.random.randn(),
                next_state=np.random.randn(4, 4).astype(np.float32),
                next_scores=np.random.rand(2).astype(np.float32),
                done=np.random.rand() < 0.1,
                log_prob=np.random.randn(),
                value=np.random.randn(),
                valid_actions=np.ones(4, dtype=bool),
            )
        )
        remaining -= 1

    # Run one update and show its statistics.
    stats = trainer.update(buffer)
    print(f"Training stats: {stats}")

    # Exercise the statistics recorder with ten synthetic games.
    training_stats = TrainingStats()
    for game_number in range(1, 11):
        offset = game_number - 1
        training_stats.record_game(
            score=1000 * game_number,
            situational_score=50.0 + offset * 5,
            max_tile=2 ** (offset + 5),
            steps=100 + offset * 10,
        )

    print(f"Average stats: {training_stats.get_avg_stats()}")