forthezero commited on
Commit
0642513
·
verified ·
1 Parent(s): c473f51

Upload 13 files

Browse files
Files changed (13) hide show
  1. PLAN.md +230 -0
  2. README.md +318 -0
  3. TASK.md +24 -0
  4. checkpoints/model_1500.pt +3 -0
  5. game.py +388 -0
  6. gui.py +835 -0
  7. main.py +302 -0
  8. model.py +341 -0
  9. parallel.py +461 -0
  10. requirements.txt +13 -0
  11. train_simple.py +149 -0
  12. trainer.py +370 -0
  13. utils.py +295 -0
PLAN.md ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 项目计划:基于Transformer的2048游戏AI
2
+
3
+ ## 技术方案概述
4
+
5
+ ### 硬件约束
6
+ - AMD Ryzen 5 PRO 4650U (6核12线程) CPU
7
+ - 8GB RAM
8
+ - 无NVIDIA GPU,纯CPU训练
9
+ - 需要小型高效的模型架构
10
+
11
+ ### 整体架构
12
+
13
+ ```
14
+ ┌─────────────────────────────────────────────────────────┐
15
+ │ GUI主窗口 (PyQt5) │
16
+ ├─────────────────────────────────────────────────────────┤
17
+ │ ┌──────────────┐ ┌──────────────────────────────────┐ │
18
+ │ │ 2048游戏面板 │ │ 训练状态面板 │ │
19
+ │ │ (4x4网格) │ │ - 当前分数/局面分数 │ │
20
+ │ │ │ │ - 训练速度 (games/sec) │ │
21
+ │ │ │ │ - 累积分数变化曲线 │ │
22
+ │ │ │ │ - 局面分数变化曲线 │ │
23
+ │ └──────────────┘ └──────────────────────────────────┘ │
24
+ ├─────────────────────────────────────────────────────────┤
25
+ │ 控制面板: [训练模式] [演示模式] [开始/停止] [AI托管] │
26
+ └─────────────────────────────────────────────────────────┘
27
+ ```
28
+
29
+ ## 模块设计
30
+
31
+ ### 1. 游戏核心模块 (`game.py`)
32
+
33
+ ```python
34
+ class Game2048:
35
+ """2048游戏核心逻辑"""
36
+
37
+ def __init__(self):
38
+ self.board: np.ndarray # 4x4棋盘
39
+ self.accumulated_score: int # 累积分数
40
+ self.situational_score: float # 局面分数
41
+
42
+ def reset(self) -> None: ...
43
+ def move(self, direction: int) -> tuple[bool, bool]: ... # (moved, game_over)
44
+ def get_state(self) -> np.ndarray: ... # 返回当前局面
45
+ def calculate_situational_score(self) -> float: ...
46
+ ```
47
+
48
+ **局面分数计算公式:**
49
+ ```
50
+ situation_score = (
51
+ empty_cells * 10 + # 空格越多越好
52
+ max_consecutive_adjacent * 15 + # 连续相邻数字越多越好
53
+ log2(max_tile) * 5 - # 最高数字的对数
54
+ monotonicity_penalty # 单调性惩罚(避免混乱)
55
+ )
56
+ ```
57
+
58
+ ### 2. Transformer模型 (`model.py`)
59
+
60
+ 考虑到CPU训练的限制,采用小型Transformer:
61
+
62
+ ```python
63
+ class Game2048Transformer(nn.Module):
64
+ """小型Transformer用于2048决策"""
65
+
66
+ def __init__(self):
67
+ # 输入: 4x4棋盘 + 2个分数特征
68
+ # 将棋盘展平为16个token,每个token代表一个格子的状态
69
+
70
+ self.embedding = nn.Embedding(17, 64) # 0-15 表示 log2(value),16表示空
71
+ self.score_embedding = nn.Linear(2, 64) # 两种分数的embedding
72
+
73
+ encoder_layer = nn.TransformerEncoderLayer(
74
+ d_model=64,
75
+ nhead=4,
76
+ dim_feedforward=128,
77
+ dropout=0.1,
78
+ batch_first=True
79
+ )
80
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
81
+
82
+ self.policy_head = nn.Linear(64, 4) # 输出4个动作的概率
83
+ self.value_head = nn.Linear(64, 1) # 输出状态价值
84
+ ```
85
+
86
+ **模型大小估算:**
87
+ - Embedding: 17 * 64 = 1,088 参数
88
+ - Transformer (2层): ~50,000 参数
89
+ - 输出头: ~300 参数
90
+ - **总计: ~52,000 参数** - 非常小,适合CPU训练
91
+
92
+ ### 3. 训练模块 (`trainer.py`)
93
+
94
+ 采用 **Actor-Critic + PPO** 策略:
95
+
96
+ ```python
97
+ class PPOTrainer:
98
+ """PPO训练器"""
99
+
100
+ def __init__(self, model, lr=1e-4):
101
+ self.model = model
102
+ self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
103
+
104
+ def compute_advantage(self, rewards, values, dones):
105
+ # 计算GAE (Generalized Advantage Estimation)
106
+ ...
107
+
108
+ def update(self, trajectories):
109
+ # PPO更新逻辑
110
+ ...
111
+ ```
112
+
113
+ **奖励设计:**
114
+ ```python
115
+ reward = (
116
+ accumulated_score_delta * 0.3 + # 累积分数增量(权重低)
117
+ situational_score * 0.7 + # 局面分数(权重高)
118
+ game_over_penalty * (-100) # 游戏结束惩罚
119
+ )
120
+ ```
121
+
122
+ ### 4. 多进程训练 (`parallel.py`)
123
+
124
+ 利用6核CPU,同时运行多个游戏实例:
125
+
126
+ ```python
127
+ class ParallelGameEnv:
128
+ """并行游戏环境"""
129
+
130
+ def __init__(self, num_envs=4):
131
+ self.num_envs = num_envs
132
+ self.envs = [Game2048() for _ in range(num_envs)]
133
+
134
+ def step(self, actions: list[int]) -> list[Transition]:
135
+ # 并行执行动作,返回状态转移
136
+ ...
137
+ ```
138
+
139
+ ### 5. GUI模块 (`gui.py`)
140
+
141
+ 使用 PyQt5 构建界面:
142
+
143
+ ```python
144
+ class MainWindow(QMainWindow):
145
+ """主窗口"""
146
+
147
+ def __init__(self):
148
+ self.game_widget = GameBoardWidget()
149
+ self.stats_widget = StatsWidget()
150
+ self.control_widget = ControlWidget()
151
+
152
+ # 训练线程
153
+ self.training_thread = TrainingThread()
154
+
155
+ def switch_mode(self, mode: str): ...
156
+ def update_display(self): ...
157
+ ```
158
+
159
+ ## 文件结构
160
+
161
+ ```
162
+ game2048/
163
+ ├── TASK.md # 任务描述
164
+ ├── PLAN.md # 本文件
165
+ ├── main.py # 入口文件
166
+ ├── game.py # 游戏核心逻辑
167
+ ├── model.py # Transformer模型定义
168
+ ├── trainer.py # PPO训练器
169
+ ├── parallel.py # 多进程训练
170
+ ├── gui.py # GUI界面
171
+ ├── utils.py # 工具函数
172
+ ├── requirements.txt # 依赖
173
+ └── checkpoints/ # 模型保存目录
174
+ ```
175
+
176
+ ## 实现步骤
177
+
178
+ ### 阶段1: 核心游戏逻辑
179
+ 1. 实现 `game.py` - 2048游戏规则
180
+ 2. 实现局面分数计算
181
+ 3. 编写游戏逻辑单元测试
182
+
183
+ ### 阶段2: 模型与训练
184
+ 4. 实现 `model.py` - Transformer模型
185
+ 5. 实现 `trainer.py` - PPO训练器
186
+ 6. 实现 `parallel.py` - 多进程环境
187
+ 7. 验证训练流程可以运行
188
+
189
+ ### 阶段3: GUI界面
190
+ 8. 实现 `gui.py` - 主窗口和游戏面板
191
+ 9. 实现训练状态可视化(分数曲线)
192
+ 10. 实现模式切换(训练/演示)
193
+
194
+ ### 阶段4: 整合与优化
195
+ 11. 整合所有模块
196
+ 12. 性能优化
197
+ 13. 模型保存/加载功能
198
+
199
+ ## 依赖
200
+
201
+ ```
202
+ torch>=2.0.0
203
+ numpy>=1.24.0
204
+ PyQt5>=5.15.0
205
+ matplotlib>=3.7.0
206
+ ```
207
+
208
+ ## 训练策略细节
209
+
210
+ ### 状态表示
211
+ - 棋盘状态:将每个格子的值转换为 log2(value),空格为0
212
+ - 分数归一化:累积分数和局面分数归一化到 [0, 1]
213
+
214
+ ### 动作空间
215
+ - 0: 上
216
+ - 1: 下
217
+ - 2: 左
218
+ - 3: 右
219
+
220
+ ### 训练超参数
221
+ - Learning rate: 1e-4
222
+ - Batch size: 64
223
+ - PPO clip ratio: 0.2
224
+ - GAE lambda: 0.95
225
+ - Discount factor (gamma): 0.99
226
+ - 并行环境数: 4 (根据CPU核心数调整)
227
+
228
+ ### 停止条件
229
+ - 连续100局游戏平均分数无提升
230
+ - 用户手动停止
README.md ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 2048 AI Trainer
2
+
3
+ 基于 Transformer 的 2048 游戏人工智能训练器,使用 PPO(Proximal Policy Optimization)强化学习算法,让 AI 学会玩 2048 游戏。
4
+
5
+ ## 项目简介
6
+
7
+ 本项目实现了一个完整的 2048 游戏 AI 训练系统,包括:
8
+
9
+ - **游戏引擎**: 完整的 2048 游戏逻辑实现
10
+ - **深度学习模型**: 基于 Transformer 架构的策略网络
11
+ - **强化学习训练**: PPO 算法实现
12
+ - **可视化界面**: PyQt5 图形界面,支持训练监控和演示
13
+ - **命令行工具**: 支持无 GUI 的训练和演示模式
14
+
15
+ ## 功能特点
16
+
17
+ ### 1. Transformer 模型
18
+
19
+ 采用小型 Transformer 架构,专为 CPU 训练优化:
20
+
21
+ - **参数量**: 约 77,000 个参数(~300KB)
22
+ - **输入处理**:
23
+ - 棋盘状态编码为 16 个 token(每个格子对应一个 token)
24
+ - 分数特征(累积分数、局面分数)作为额外输入
25
+ - 位置编码:行/列位置嵌入
26
+ - **网络结构**:
27
+ - 2 层 Transformer Encoder
28
+ - 4 个注意力头
29
+ - 隐藏维度 64
30
+ - 前馈网络维度 128
31
+ - **输出**:
32
+ - 策略头:4 个动作(上/下/左/右)的概率分布
33
+ - 价值头:当前状态的价值评估
34
+
35
+ ### 2. 双评分机制
36
+
37
+ #### 累积分数(Accumulated Score)
38
+ 传统 2048 计分方式,每次合成砖块获得合成后砖块的数值作为分数。
39
+
40
+ #### 局面分数(Situational Score)
41
+ 综合评估当前局面的质量,鼓励 AI 保持良好局面:
42
+
43
+ ```
44
+ 局面分数 = 空格数 × 10 + 最大连续相邻数 × 15 + log₂(最大砖块) × 5 + 单调性奖励
45
+ ```
46
+
47
+ - **空格数**: 空格越多,操作空间越大
48
+ - **连续相邻数**: 如 512-1024-2048 连续排列,便于后续合并
49
+ - **单调性**: 鼓励数字按方向有序排列
50
+
51
+ ### 3. PPO 训练算法
52
+
53
+ 使用 Proximal Policy Optimization 算法进行训练:
54
+
55
+ - **优势估计**: GAE(Generalized Advantage Estimation)
56
+ - **策略裁剪**: 防止策略更新过大
57
+ - **价值函数**: 辅助训练,提供状态价值估计
58
+ - **熵正则化**: 鼓励探索
59
+
60
+ ### 4. GUI 界面
61
+
62
+ 基于 PyQt5 的图形界面:
63
+
64
+ - **训练模式**:
65
+ - 设置训练局数
66
+ - 实时显示训练进度
67
+ - 分数曲线可视化
68
+ - 训练完成后自动保存模型
69
+
70
+ - **演示模式**:
71
+ - 键盘手动操作
72
+ - AI 托管模式
73
+ - 单步执行
74
+ - 自动连续执行
75
+ - 实时局面分数曲线
76
+
77
+ ## 安装
78
+
79
+ ### 环境要求
80
+
81
+ - Python 3.8+
82
+ - Windows / Linux / macOS
83
+
84
+ ### 安装依赖
85
+
86
+ ```bash
87
+ cd game2048
88
+ pip install -r requirements.txt
89
+ ```
90
+
91
+ ### 依赖列表
92
+
93
+ ```
94
+ torch>=2.0.0 # 深度学习框架
95
+ numpy<2 # 数值计算
96
+ PyQt5>=5.15.0 # GUI 框架
97
+ matplotlib>=3.7.0 # 绘图库
98
+ ```
99
+
100
+ ## 使用方法
101
+
102
+ ### 1. GUI 模式
103
+
104
+ ```bash
105
+ python main.py
106
+ ```
107
+
108
+ 启动图形界面后:
109
+
110
+ **训练模式**:
111
+ 1. 选择 "Training Mode"
112
+ 2. 设置训练局数(默认 500)
113
+ 3. 点击 "Start Training" 开始训练
114
+ 4. 训练完成后自动保存到 `checkpoints/model.pt`
115
+ 5. 可随时点击 "Stop Training" 停止
116
+
117
+ **演示模式**:
118
+ 1. 选择 "Demo Mode"
119
+ 2. 点击 "Load Model" 加载已训练模型
120
+ 3. 使用方式:
121
+ - 键盘方向键:手动操作
122
+ - "AI Mode":切换 AI 托管
123
+ - "Step":AI 单步执行
124
+ - "Auto":AI 自动连续执行
125
+ - "Reset":重新开始游戏
126
+
127
+ ### 2. 命令行训练
128
+
129
+ ```bash
130
+ # 训练 1000 局
131
+ python main.py --train --games 1000
132
+
133
+ # 使用 4 个并行环境
134
+ python main.py --train --games 1000 --envs 4
135
+
136
+ # 设置随机种子
137
+ python main.py --train --games 1000 --seed 42
138
+ ```
139
+
140
+ ### 3. 演示模式
141
+
142
+ ```bash
143
+ # 加载模型并演示 5 局
144
+ python main.py --demo --model checkpoints/model.pt --games 5
145
+
146
+ # 不加载模型(随机权重)
147
+ python main.py --demo --games 3
148
+ ```
149
+
150
+ ### 4. 简单训练脚本
151
+
152
+ ```bash
153
+ python train_simple.py
154
+ ```
155
+
156
+ 修改脚本末尾可调整训练参数:
157
+
158
+ ```python
159
+ train_simple(num_games=500, save_path="checkpoints/model.pt")
160
+ ```
161
+
162
+ ## 项目结构
163
+
164
+ ```
165
+ game2048/
166
+ ├── TASK.md # 任务需求文档
167
+ ├── PLAN.md # 项目计划文档
168
+ ├── README.md # 本文件
169
+ ├── main.py # 程序入口
170
+ ├── game.py # 2048 游戏核心逻辑
171
+ │ ├── Game2048 # 游戏类
172
+ │ ├── move() # 移动操作
173
+ │ ├── get_state() # 获取状态
174
+ │ └── calculate_situational_score() # 计算局面分数
175
+
176
+ ├── model.py # Transformer 模型
177
+ │ ├── Game2048Transformer # Transformer 模型
178
+ │ ├── Game2048CNN # CNN 备选模型
179
+ │ └── get_action() # 动作选择
180
+
181
+ ├── trainer.py # PPO 训练器
182
+ │ ├── PPOTrainer # PPO 训练类
183
+ │ ├── RolloutBuffer # 经验缓冲区
184
+ │ ├── Transition # 状态转移数据结构
185
+ │ └── TrainingStats # 训练统计
186
+
187
+ ├── parallel.py # 并行训练环境
188
+ │ ├── ParallelGameEnv # 并行游戏环境
189
+ │ ├── TrainingWorker # 训练工作器
190
+ │ └── TrainingLoop # 训练循环
191
+
192
+ ├── gui.py # GUI 界面
193
+ │ ├── MainWindow # 主窗口
194
+ │ ├── GameBoardWidget # 游戏面板
195
+ │ ├── ScoreWidget # 分数显示
196
+ │ ├── PlotCanvas # 曲线绘图
197
+ │ └── SimpleTrainingThread # 训练线程
198
+
199
+ ├── train_simple.py # 简化训练脚本
200
+ ├── utils.py # 工具函数
201
+ ├── requirements.txt # 依赖列表
202
+ └── checkpoints/ # 模型保存目录
203
+ └── model.pt # 训练好的模型
204
+ ```
205
+
206
+ ## 模型架构详解
207
+
208
+ ### 输入表示
209
+
210
+ ```python
211
+ # 棋盘状态 (4, 4)
212
+ # 每个格子值转换为 log₂(value),空格为 0
213
+ state = [[0, 1, 2, 0], # 对应 [空, 2, 4, 空]
214
+ [1, 2, 3, 1], # 对应 [2, 4, 8, 2]
215
+ ...]
216
+
217
+ # 分数特征 (2,)
218
+ # [归一化累积分数, 归一化局面分数]
219
+ scores = [0.05, 0.85]
220
+ ```
221
+
222
+ ### 网络结构
223
+
224
+ ```
225
+ Input: (batch, 4, 4) board + (batch, 2) scores
226
+
227
+ Position Embedding: (batch, 16, 64)
228
+ + Spatial Embedding: (batch, 16, 64)
229
+ + Score Embedding: (batch, 1, 64)
230
+
231
+ Transformer Encoder (2 layers)
232
+ - Multi-Head Attention (4 heads)
233
+ - Feed-Forward Network (dim=128)
234
+
235
+ Global Mean Pooling: (batch, 64)
236
+
237
+ ├── Policy Head → (batch, 4) # 动作概率
238
+ └── Value Head → (batch, 1) # 状态价值
239
+ ```
240
+
241
+ ## 训练策略
242
+
243
+ ### 奖励设计
244
+
245
+ ```python
246
+ reward = 局面分数变化 × 0.7 + 累积分数增量 × 0.003
247
+
248
+ # 游戏结束惩罚
249
+ if game_over:
250
+ reward -= 10.0
251
+ ```
252
+
253
+ ### 超参数
254
+
255
+ | 参数 | 值 |
256
+ |------|-----|
257
+ | Learning Rate | 3e-4 |
258
+ | Batch Size | 64 |
259
+ | PPO Clip Ratio | 0.2 |
260
+ | GAE Lambda | 0.95 |
261
+ | Discount Factor (γ) | 0.99 |
262
+ | Entropy Coefficient | 0.01 |
263
+
264
+ ## 训练结果
265
+
266
+ ### 500 局训练后
267
+
268
+ | 指标 | 数值 |
269
+ |------|------|
270
+ | 平均分数 | ~2500 |
271
+ | 最高分数 | 6812 |
272
+ | 最大砖块 | 512 |
273
+ | 训练时间 | ~9 分钟 |
274
+
275
+ ### 分数分布
276
+
277
+ ```
278
+ 随机权重: 平均 ~800, 最高 ~2000
279
+ 训练 500 局: 平均 ~2500, 最高 ~6800
280
+ ```
281
+
282
+ ## 开发说明
283
+
284
+ ### 添加新功能
285
+
286
+ 1. **修改局面评分**: 编辑 `game.py` 中的 `calculate_situational_score()`
287
+ 2. **调整模型**: 修改 `model.py` 中的网络结构
288
+ 3. **优化训练**: 调整 `trainer.py` 中的超参数
289
+
290
+ ### 调试模式
291
+
292
+ ```python
293
+ # 在 game.py 中测试游戏逻辑
294
+ if __name__ == "__main__":
295
+ game = Game2048()
296
+ print(game)
297
+ game.move(Game2048.LEFT)
298
+ print(game)
299
+ ```
300
+
301
+ ## 已知问题
302
+
303
+ - Windows 下 PyTorch 可能需要特定版本以避免 DLL 加载问题
304
+ - NumPy 2.x 与 PyTorch 存在兼容性问题,建议使用 NumPy < 2
305
+
306
+ ## 参考资料
307
+
308
+ - [PPO 论文](https://arxiv.org/abs/1707.06347)
309
+ - [Transformer 论文](https://arxiv.org/abs/1706.03762)
310
+ - [2048 游戏](https://play2048.co/)
311
+
312
+ ## 许可证
313
+
314
+ MIT License
315
+
316
+ ---
317
+
318
+ *本项目由 GLM-5 开发实现*
TASK.md ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 需求
2
+
3
+ 用Python制作一个基于Transformer的2048游戏AI
4
+
5
+ 可以创建一个简单的窗口让我看到实时的训练状态,包括训练时AI的当前分数、训练速度、历史已经结束的游戏的*累积分数*变化、当前游戏的*局面分数*变化,窗口可切换演示模式和训练模式
6
+
7
+ 2048的一部分规定补充:
8
+
9
+ - 开局会在随机位置生成一个2,后续在空格生成的砖块在随机位置且值为2或4
10
+ - 同时存在两个分数机制:累积分数和局面分数
11
+ - 累积分数机制:每合成一个增加合成后的砖块的数值
12
+ - 局面分数机制:根据*空格数量*、*相邻连续数字数量(相邻砖块拥有相邻数字的数量,取最大值,不计算重复,如512 1024 2048为3,256 512 512 2为2)*、*最高数字的对数*设计一个合理的函数,局面分数越高越好
13
+
14
+ ## 第一步:训练模型
15
+
16
+ 为模型提供当前的局面状态和两种分数(累积分数的权重低于其它,目的是局面分数保持更高),输出只有上下左右
17
+
18
+ 游戏结束后再开始新的游戏继续训练,直到训练没有长进或我手动点击停止
19
+
20
+ 性能允许的话可以多开2048游戏进行训练
21
+
22
+ ## 第二步:演示模式
23
+
24
+ 在窗口中可以手动操作或AI托管,AI托管可以单步执行也可以持续下一步,可以看到两种分数及局面分数变化
checkpoints/model_1500.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11ff8939d4676c1ad76ca9c5b0ec2616bfc48ed10788c36ff2b0e50cd7fbcb35
3
+ size 415458
game.py ADDED
@@ -0,0 +1,388 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2048游戏核心逻辑
3
+ """
4
+ import numpy as np
5
+ from typing import Tuple, Optional
6
+ import random
7
+
8
+
9
class Game2048:
    """Core rules and scoring for a 4x4 game of 2048.

    Two scores are tracked side by side:

    * ``accumulated_score`` -- the classic score; every merge adds the value
      of the tile it produces.
    * ``situational_score`` -- a heuristic rating of the current board, see
      :meth:`_update_situational_score`.
    """

    # Action encoding.
    UP = 0
    DOWN = 1
    LEFT = 2
    RIGHT = 3

    def __init__(self):
        self.board: np.ndarray = np.zeros((4, 4), dtype=np.int64)
        self.accumulated_score: int = 0
        self.situational_score: float = 0.0
        self.game_over: bool = False
        self.moves_count: int = 0
        self.reset()

    def reset(self) -> np.ndarray:
        """Start a new game and return its initial state (see ``get_state``)."""
        self.board = np.zeros((4, 4), dtype=np.int64)
        self.accumulated_score = 0
        self.situational_score = 0.0
        self.game_over = False
        self.moves_count = 0

        # A fresh game always opens with a single 2 tile.
        self._spawn_tile(value=2)
        self._update_situational_score()
        return self.get_state()

    def _spawn_tile(self, value: Optional[int] = None) -> bool:
        """Place a new tile on a random empty cell.

        Args:
            value: Tile value to place. When ``None`` a 2 is chosen with
                probability 0.9 and a 4 otherwise.

        Returns:
            ``True`` if a tile was placed, ``False`` if the board is full.
        """
        empty_cells = list(zip(*np.where(self.board == 0)))
        if not empty_cells:
            return False

        row, col = random.choice(empty_cells)
        if value is None:
            value = 2 if random.random() < 0.9 else 4
        self.board[row, col] = value
        return True

    def _compress(self, line: np.ndarray) -> Tuple[np.ndarray, int]:
        """Slide one row/column towards index 0, merging equal neighbours.

        Each tile takes part in at most one merge per call, so
        ``[2, 2, 2, 2]`` becomes ``[4, 4, 0, 0]``.

        Args:
            line: 1-D array of tile values (0 means empty). Not modified.

        Returns:
            ``(new_line, score)`` where ``score`` is the sum of all tiles
            created by merging, as a plain Python ``int``.
        """
        tiles = line[line != 0]
        packed = np.zeros_like(line)
        gained = 0

        write = 0
        read = 0
        while read < len(tiles):
            if read + 1 < len(tiles) and tiles[read] == tiles[read + 1]:
                packed[write] = tiles[read] * 2
                # int() keeps NumPy scalars from leaking into the score.
                gained += int(packed[write])
                read += 2
            else:
                packed[write] = tiles[read]
                read += 1
            write += 1

        return packed, gained

    def _slide(self, by_column: bool, reverse: bool) -> Tuple[bool, int]:
        """Compress every row or column of the board in place.

        Args:
            by_column: Operate on columns when true, on rows otherwise.
            reverse: Slide towards the high-index end instead of index 0.

        Returns:
            ``(moved, score)``: whether any cell changed, and the merge score.
        """
        moved = False
        total_score = 0

        for k in range(4):
            lane = self.board[:, k] if by_column else self.board[k]
            if reverse:
                lane = lane[::-1]
            # ``lane`` is a view, so assigning into it updates ``self.board``.
            packed, gained = self._compress(lane)
            total_score += gained
            if not np.array_equal(lane, packed):
                lane[:] = packed
                moved = True

        return moved, total_score

    def _move_left(self) -> Tuple[bool, int]:
        """Slide all rows to the left; returns ``(moved, score)``."""
        return self._slide(by_column=False, reverse=False)

    def _move_right(self) -> Tuple[bool, int]:
        """Slide all rows to the right; returns ``(moved, score)``."""
        return self._slide(by_column=False, reverse=True)

    def _move_up(self) -> Tuple[bool, int]:
        """Slide all columns upwards; returns ``(moved, score)``."""
        return self._slide(by_column=True, reverse=False)

    def _move_down(self) -> Tuple[bool, int]:
        """Slide all columns downwards; returns ``(moved, score)``."""
        return self._slide(by_column=True, reverse=True)

    def move(self, direction: int) -> Tuple[np.ndarray, float, bool, bool]:
        """Apply one action to the game.

        Args:
            direction: 0 = up, 1 = down, 2 = left, 3 = right.

        Returns:
            ``(state, reward, moved, done)`` where ``state`` is the new
            ``get_state()`` array, ``reward`` combines the change in
            situational score (weight 0.7) with the accumulated-score gain
            (weight 0.003) minus 10 if the game just ended, ``moved`` tells
            whether the board changed, and ``done`` whether the game is over.

        Raises:
            ValueError: If ``direction`` is not one of the four actions.
        """
        if self.game_over:
            return self.get_state(), 0.0, False, True

        old_accumulated = self.accumulated_score
        old_situational = self.situational_score

        if direction == self.UP:
            moved, score = self._move_up()
        elif direction == self.DOWN:
            moved, score = self._move_down()
        elif direction == self.LEFT:
            moved, score = self._move_left()
        elif direction == self.RIGHT:
            moved, score = self._move_right()
        else:
            raise ValueError(f"Invalid direction: {direction}")

        if moved:
            self.accumulated_score += score
            self.moves_count += 1
            self._spawn_tile()
            self._update_situational_score()

        self.game_over = self._check_game_over()

        accumulated_delta = self.accumulated_score - old_accumulated
        situational_delta = self.situational_score - old_situational

        # 0.003 == 0.3 / 100: the accumulated score is weighted lower and
        # scaled down so it does not dominate the situational term.
        reward = situational_delta * 0.7 + accumulated_delta * 0.003

        if self.game_over:
            reward -= 10.0

        return self.get_state(), float(reward), moved, self.game_over

    def _check_game_over(self) -> bool:
        """Return ``True`` when no empty cell and no mergeable pair is left."""
        if np.any(self.board == 0):
            return False

        for i in range(4):
            for j in range(4):
                if i < 3 and self.board[i, j] == self.board[i + 1, j]:
                    return False
                if j < 3 and self.board[i, j] == self.board[i, j + 1]:
                    return False

        return True

    def _update_situational_score(self) -> None:
        """Recompute ``situational_score`` from the current board.

        score = empty_cells * 10 + max_consecutive * 15
                + log2(max_tile) * 5 + monotonicity * 5
        """
        empty_cells = np.sum(self.board == 0)
        max_consecutive = self._calculate_max_consecutive()

        max_tile = np.max(self.board)
        max_tile_log = np.log2(max_tile) if max_tile > 0 else 0

        monotonicity = self._calculate_monotonicity()

        self.situational_score = (
            empty_cells * 10 +
            max_consecutive * 15 +
            max_tile_log * 5 +
            monotonicity * 5
        )

    def _calculate_max_consecutive(self) -> int:
        """Return the longest chain of neighbouring powers of two.

        The maximum is taken over all rows and columns, e.g. a row holding
        512, 1024, 2048 yields 3.
        """
        lanes = list(self.board) + list(self.board.T)
        return max(self._count_consecutive_in_line(lane) for lane in lanes)

    def _count_consecutive_in_line(self, line: np.ndarray) -> int:
        """Return the longest chain of neighbouring powers of two in a line.

        Empty cells are dropped first, so tiles separated only by gaps count
        as neighbours. A line with fewer than two tiles yields 0; a line with
        two or more tiles but no chain yields 1.
        """
        tiles = line[line != 0]
        if len(tiles) < 2:
            return 0

        best = 1
        run = 1
        # Exponents of consecutive powers of two differ by exactly 1.
        for gap in np.abs(np.diff(np.log2(tiles))):
            if gap == 1:
                run += 1
                best = max(best, run)
            else:
                run = 1

        return best

    @staticmethod
    def _monotonic_steps(line: np.ndarray) -> int:
        """Return ``len(tiles) - 1`` if the non-empty tiles are sorted, else 0."""
        tiles = line[line != 0]
        if len(tiles) < 2:
            return 0
        steps = np.diff(tiles)
        if np.all(steps >= 0) or np.all(steps <= 0):
            return len(tiles) - 1
        return 0

    def _calculate_monotonicity(self) -> float:
        """Return the monotonicity bonus summed over all rows and columns.

        Every row/column whose non-empty tiles are non-decreasing or
        non-increasing contributes its number of tile-to-tile steps.
        """
        score = 0.0
        for k in range(4):
            score += self._monotonic_steps(self.board[k])
            score += self._monotonic_steps(self.board[:, k])
        return score

    def get_state(self) -> np.ndarray:
        """Return the board as a (4, 4) float32 array of log2(tile), 0 = empty."""
        state = np.zeros((4, 4), dtype=np.float32)
        non_zero_mask = self.board > 0
        state[non_zero_mask] = np.log2(self.board[non_zero_mask])
        return state

    def get_state_with_scores(self) -> np.ndarray:
        """Return an (18,) float32 vector: 16 board cells plus two scores.

        Board cells are log2(tile) / 15; both scores are scaled by an assumed
        maximum and capped at 1.0.
        """
        board_state = self.get_state().flatten()

        max_accumulated = 50000  # assumed ceiling for the accumulated score
        max_situational = 200    # assumed ceiling for the situational score

        normalized_accumulated = min(self.accumulated_score / max_accumulated, 1.0)
        normalized_situational = min(self.situational_score / max_situational, 1.0)

        return np.concatenate([
            # Dividing by 15 keeps cells within [0, 1] up to a 32768 (2**15) tile.
            board_state / 15.0,
            [normalized_accumulated, normalized_situational]
        ]).astype(np.float32)

    def get_valid_actions(self) -> np.ndarray:
        """Return a boolean mask of length 4 marking the moves that change the board.

        Each move is tried on the live board and undone immediately; scores
        and ``game_over`` are not touched.
        """
        valid = np.zeros(4, dtype=bool)
        snapshot = self.board.copy()

        trials = (
            (self.UP, self._move_up),
            (self.DOWN, self._move_down),
            (self.LEFT, self._move_left),
            (self.RIGHT, self._move_right),
        )
        for direction, mover in trials:
            moved, _ = mover()
            valid[direction] = moved
            # Restore in place so the board array keeps its identity.
            self.board[...] = snapshot

        return valid

    def get_max_tile(self) -> int:
        """Return the value of the largest tile on the board."""
        return int(np.max(self.board))

    def get_empty_cells_count(self) -> int:
        """Return the number of empty cells."""
        return int(np.sum(self.board == 0))

    def __str__(self) -> str:
        """Render the board as four text rows; empty cells are shown as dots."""
        result = []
        for row in self.board:
            result.append(" | ".join(f"{int(x):4d}" if x > 0 else "   ." for x in row))
        return "\n".join(result)
371
+
372
+
373
if __name__ == "__main__":
    # Smoke test: print the opening position, then play ten moves while
    # cycling through the four actions.
    game = Game2048()
    print("Initial state:")
    print(game)
    print(f"Accumulated score: {game.accumulated_score}")
    print(f"Situational score: {game.situational_score}")

    # Labels are keyed by the action constants so the printed name always
    # matches the move that is actually executed (0=UP, 1=DOWN, 2=LEFT, 3=RIGHT).
    move_names = {
        Game2048.UP: 'UP',
        Game2048.DOWN: 'DOWN',
        Game2048.LEFT: 'LEFT',
        Game2048.RIGHT: 'RIGHT',
    }
    for i in range(10):
        direction = i % 4
        state, reward, moved, done = game.move(direction)
        print(f"\nMove {move_names[direction]}: moved={moved}, done={done}")
        print(game)
        print(f"Reward: {reward:.2f}")
gui.py ADDED
@@ -0,0 +1,835 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GUI界面 - PyQt5实现
3
+ """
4
+ import sys
5
+ import os
6
+ import numpy as np
7
+ from PyQt5.QtWidgets import (
8
+ QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
9
+ QGridLayout, QLabel, QPushButton, QComboBox, QGroupBox,
10
+ QFrame, QSplitter, QStatusBar, QProgressBar, QCheckBox,
11
+ QFileDialog, QMessageBox, QSpinBox
12
+ )
13
+ from PyQt5.QtCore import Qt, QTimer, QThread, pyqtSignal
14
+ from PyQt5.QtGui import QFont, QColor, QPalette
15
+ import matplotlib
16
+ matplotlib.use('Qt5Agg')
17
+ from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
18
+ from matplotlib.figure import Figure
19
+ from collections import deque
20
+ import time
21
+ import torch
22
+
23
+ from game import Game2048
24
+ from model import Game2048Transformer
25
+ from trainer import PPOTrainer, RolloutBuffer, Transition
26
+
27
+
28
# Background colour per tile value; key 0 is the empty cell.
TILE_COLORS = {
    0: '#cdc1b4',
    2: '#eee4da', 4: '#ede0c8',
    8: '#f2b179', 16: '#f59563',
    32: '#f67c5f', 64: '#f65e3b',
    128: '#edcf72', 256: '#edcc61',
    512: '#edc850', 1024: '#edc53f',
    2048: '#edc22e',
    4096: '#3c3a32', 8192: '#3c3a32',
}

# Text colour overrides for the two lightest tiles.
TILE_TEXT_COLORS = {2: '#776e65', 4: '#776e65'}
50
+
51
+
52
class GameBoardWidget(QWidget):
    """Fixed-size widget that paints a 4x4 2048 board."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.board = np.zeros((4, 4), dtype=np.int64)
        self.cell_size = 80
        self.padding = 5
        self.init_ui()

    def init_ui(self):
        """Fix the widget size to four cells plus five padding strips per axis."""
        side = self.cell_size * 4 + self.padding * 5
        self.setFixedSize(side, side)

    def set_board(self, board: np.ndarray):
        """Store a copy of ``board`` (raw tile values) and schedule a repaint."""
        self.board = board.copy()
        self.update()

    def paintEvent(self, event):
        """Draw the background, every cell and the number on each non-empty tile."""
        # QPainter is imported here because the module-level imports only
        # bring in QFont, QColor and QPalette from QtGui.
        from PyQt5.QtGui import QPainter

        painter = QPainter(self)
        try:
            painter.setRenderHint(QPainter.Antialiasing)
            painter.fillRect(self.rect(), QColor('#bbada0'))

            stride = self.cell_size + self.padding
            for i in range(4):
                for j in range(4):
                    x = self.padding + j * stride
                    y = self.padding + i * stride
                    value = int(self.board[i, j])

                    # Values without an entry fall back to the darkest colour.
                    color = TILE_COLORS.get(value, '#3c3a32')
                    painter.fillRect(x, y, self.cell_size, self.cell_size, QColor(color))

                    if value <= 0:
                        continue

                    text_color = TILE_TEXT_COLORS.get(value, '#f9f6f2')
                    painter.setPen(QColor(text_color))

                    # Smaller font for values with more digits.
                    if value < 100:
                        font_size = 32
                    elif value < 1000:
                        font_size = 28
                    else:
                        font_size = 22
                    painter.setFont(QFont('Arial', font_size, QFont.Bold))

                    painter.drawText(x, y, self.cell_size, self.cell_size,
                                     Qt.AlignCenter, str(value))
        finally:
            # End painting explicitly instead of relying on garbage collection.
            painter.end()
116
+
117
+
118
class ScoreWidget(QWidget):
    """Small panel showing a caption above a numeric score."""

    def __init__(self, title: str, parent=None):
        super().__init__(parent)
        self.title = title
        self.value = 0
        self.init_ui()

    def init_ui(self):
        """Create the caption label and the value label, stacked vertically."""
        layout = QVBoxLayout(self)
        layout.setContentsMargins(5, 5, 5, 5)

        # Caption
        caption = QLabel(self.title)
        caption.setAlignment(Qt.AlignCenter)
        caption.setStyleSheet("""
            QLabel {
                background-color: #bbada0;
                color: #eee4da;
                font-size: 12px;
                font-weight: bold;
                border-radius: 3px;
                padding: 5px;
            }
        """)
        self.title_label = caption
        layout.addWidget(caption)

        # Value
        number = QLabel('0')
        number.setAlignment(Qt.AlignCenter)
        number.setStyleSheet("""
            QLabel {
                background-color: #8f7a66;
                color: white;
                font-size: 20px;
                font-weight: bold;
                border-radius: 3px;
                padding: 10px;
                min-width: 80px;
            }
        """)
        self.value_label = number
        layout.addWidget(number)

    def set_value(self, value):
        """Store ``value`` and display it: one decimal for floats, integer otherwise."""
        self.value = value
        text = f'{value:.1f}' if isinstance(value, float) else str(int(value))
        self.value_label.setText(text)
170
+
171
+
172
class PlotCanvas(FigureCanvas):
    """Matplotlib canvas with a single axes, embeddable in the Qt layout."""

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        self.axes = self.fig.add_subplot(111)
        super().__init__(self.fig)
        self.setParent(parent)

        self.fig.patch.set_facecolor('#faf8ef')
        self.axes.set_facecolor('#faf8ef')

        # Default line colour cycle for this axes.
        self.axes.set_prop_cycle(color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'])

        self.fig.tight_layout()

    def _finish_plot(self, xlabel, ylabel, title):
        """Apply axis labels, title and grid, then redraw the canvas."""
        self.axes.set_xlabel(xlabel, fontsize=10)
        self.axes.set_ylabel(ylabel, fontsize=10)
        self.axes.set_title(title, fontsize=12)
        self.axes.grid(True, alpha=0.3)

        self.fig.tight_layout()
        self.draw()

    def plot_training_scores(self, scores, title="Training Scores"):
        """Plot the per-game accumulated scores, with a moving average if possible.

        Args:
            scores: Sequence with one score per finished game.
            title: Axes title.
        """
        self.axes.clear()

        if scores:
            games = range(1, len(scores) + 1)
            self.axes.plot(games, scores, '#1f77b4', linewidth=1.5, alpha=0.8)

            # Moving average; the window is a fifth of the history, capped at
            # 50, and is only drawn once it spans at least 5 games.
            if len(scores) >= 10:
                window = min(50, len(scores) // 5)
                if window >= 5:
                    ma = np.convolve(scores, np.ones(window) / window, mode='valid')
                    ma_x = range(window, len(scores) + 1)
                    self.axes.plot(ma_x, ma, '#d62728', linewidth=2, alpha=0.8,
                                   label=f'MA({window})')
                    # Without a legend the label above is never displayed.
                    self.axes.legend(loc='upper left', fontsize=8)

        self._finish_plot('Games', 'Score', title)

    def plot_demo_scores(self, scores, title="Situational Score"):
        """Plot the situational score of the current demo game, one point per step.

        Args:
            scores: Sequence with one situational score per step.
            title: Axes title.
        """
        self.axes.clear()

        if scores:
            steps = range(1, len(scores) + 1)
            self.axes.plot(steps, scores, '#2ca02c', linewidth=1.5, alpha=0.8)

        self._finish_plot('Steps', 'Situational Score', title)
228
+
229
+
230
class SimpleTrainingThread(QThread):
    """Background thread that plays games with the model and trains it.

    Signals:
        game_end_signal: Emitted with a stats dict after every finished game.
        progress_signal: Declared for progress updates (not emitted here).
        finished_signal: Emitted once when the training loop exits.

    Args:
        model: Policy/value network exposing ``get_action``.
        trainer: Object exposing ``update(buffer)``.
        num_games: Maximum number of games to play.
    """

    game_end_signal = pyqtSignal(dict)
    progress_signal = pyqtSignal(dict)
    finished_signal = pyqtSignal()

    # Number of collected transitions that triggers a trainer update.
    UPDATE_BATCH_SIZE = 64

    def __init__(self, model, trainer, num_games=1000):
        super().__init__()
        self.model = model
        self.trainer = trainer
        self.num_games = num_games
        self.running = True
        self.device = "cpu"

    @staticmethod
    def _score_features(game):
        """Return the two score features of ``game``, each capped at 1.0."""
        return np.array([
            min(game.accumulated_score / 50000, 1.0),
            min(game.situational_score / 200, 1.0)
        ], dtype=np.float32)

    def run(self):
        """Play up to ``num_games`` games, updating the trainer as data arrives."""
        scores = []
        best_score = 0
        start_time = time.time()

        # The buffer lives across games. Recreating it per game would throw
        # away every transition collected after the last full batch,
        # including (almost always) the terminal transition of the game.
        buffer = RolloutBuffer(capacity=10000)

        for game_idx in range(self.num_games):
            if not self.running:
                break

            game = Game2048()
            game.reset()

            while not game.game_over and self.running:
                state = game.get_state()
                scores_feat = self._score_features(game)
                valid = game.get_valid_actions()

                state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
                scores_t = torch.FloatTensor(scores_feat).unsqueeze(0).to(self.device)
                valid_t = torch.BoolTensor(valid).unsqueeze(0).to(self.device)

                with torch.no_grad():
                    action, log_prob, value = self.model.get_action(
                        state_t, scores_t, valid_t, deterministic=False
                    )

                # Snapshot the pre-move observation before the game mutates.
                old_state = state.copy()
                old_scores = scores_feat.copy()
                next_state, reward, moved, done = game.move(action)

                transition = Transition(
                    state=old_state,
                    scores=old_scores,
                    action=action,
                    reward=reward,
                    next_state=next_state.copy(),
                    next_scores=self._score_features(game),
                    done=done,
                    log_prob=log_prob,
                    value=value,
                    valid_actions=valid
                )
                buffer.push(transition)

                if len(buffer) >= self.UPDATE_BATCH_SIZE:
                    self.trainer.update(buffer)
                    buffer.clear()

            if not game.game_over:
                # stop() was requested mid-game: do not report a partial
                # game as if it had finished.
                break

            final_score = game.accumulated_score
            scores.append(final_score)
            if final_score > best_score:
                best_score = final_score

            # Report the finished game to the GUI.
            elapsed = time.time() - start_time
            self.game_end_signal.emit({
                'score': final_score,
                'max_tile': game.get_max_tile(),
                'game_idx': game_idx + 1,
                'best_score': best_score,
                'elapsed': elapsed,
                'avg_score': float(np.mean(scores[-100:]))
            })

        self.finished_signal.emit()

    def stop(self):
        """Ask the training loop to exit at the next step boundary."""
        self.running = False
324
+
325
+
326
class MainWindow(QMainWindow):
    """Main application window.

    The left panel holds the mode selector, score widgets, the game board and
    the demo controls. The right panel holds the training controls, the score
    chart and a text summary of the training run.
    """

    def __init__(self):
        super().__init__()
        self.setWindowTitle('2048 AI Trainer')
        self.setMinimumSize(1000, 700)

        # Model and trainer
        self.device = 'cpu'
        self.model = Game2048Transformer()
        self.trainer = PPOTrainer(self.model, lr=3e-4, device=self.device)

        # Game instance used by demo mode
        self.game = Game2048()

        # Training state
        self.is_training = False
        self.training_thread = None

        # Chart data
        self.training_scores = []
        self.demo_situational_scores = []

        # Demo-mode state
        self.ai_mode = False
        self.auto_step = False
        self.auto_timer = QTimer()
        self.auto_timer.timeout.connect(self.ai_step)

        # Path of the most recently loaded/saved model file
        self.current_model_path = None

        self.init_ui()
        self.update_display()

    def init_ui(self):
        """Apply the global stylesheet and assemble the two main panels."""
        self.setStyleSheet("""
            QMainWindow {
                background-color: #faf8ef;
            }
            QLabel {
                color: #776e65;
            }
            QPushButton {
                background-color: #8f7a66;
                color: white;
                border: none;
                border-radius: 3px;
                padding: 10px 20px;
                font-size: 14px;
                font-weight: bold;
                min-width: 80px;
            }
            QPushButton:hover {
                background-color: #9f8b77;
            }
            QPushButton:pressed {
                background-color: #7f6a57;
            }
            QPushButton:disabled {
                background-color: #ccc;
                color: #999;
            }
            QComboBox {
                background-color: #8f7a66;
                color: white;
                border: none;
                border-radius: 3px;
                padding: 5px 10px;
                min-width: 100px;
            }
            QComboBox::drop-down {
                border: none;
            }
            QGroupBox {
                border: 2px solid #bbada0;
                border-radius: 5px;
                margin-top: 10px;
                padding-top: 10px;
                font-weight: bold;
                color: #776e65;
            }
            QGroupBox::title {
                subcontrol-origin: margin;
                left: 10px;
                padding: 0 5px;
            }
            QSpinBox {
                background-color: white;
                border: 1px solid #bbada0;
                border-radius: 3px;
                padding: 5px;
                min-width: 80px;
            }
        """)

        # Central widget and top-level layout
        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        main_layout = QHBoxLayout(central_widget)

        # Left: game area
        left_panel = self.create_left_panel()
        main_layout.addWidget(left_panel)

        # Right: statistics and controls
        right_panel = self.create_right_panel()
        main_layout.addWidget(right_panel)

        # Right panel gets twice the horizontal space of the left one.
        main_layout.setStretch(0, 1)
        main_layout.setStretch(1, 2)

        # Status bar
        self.statusBar = QStatusBar()
        self.setStatusBar(self.statusBar)
        self.statusBar.showMessage('Ready')

    def create_left_panel(self) -> QWidget:
        """Build the panel with mode selector, scores, board and demo buttons."""
        panel = QWidget()
        layout = QVBoxLayout(panel)
        layout.setAlignment(Qt.AlignCenter)

        # Mode selector
        mode_layout = QHBoxLayout()
        mode_label = QLabel('Mode:')
        mode_label.setFont(QFont('Arial', 12, QFont.Bold))

        self.mode_combo = QComboBox()
        self.mode_combo.addItems(['Training Mode', 'Demo Mode'])
        self.mode_combo.currentIndexChanged.connect(self.switch_mode)

        mode_layout.addWidget(mode_label)
        mode_layout.addWidget(self.mode_combo)
        mode_layout.addStretch()
        layout.addLayout(mode_layout)

        # Score widgets
        scores_layout = QHBoxLayout()
        self.accumulated_score_widget = ScoreWidget('Score')
        self.situational_score_widget = ScoreWidget('Situational')
        self.max_tile_widget = ScoreWidget('Max Tile')

        scores_layout.addWidget(self.accumulated_score_widget)
        scores_layout.addWidget(self.situational_score_widget)
        scores_layout.addWidget(self.max_tile_widget)
        layout.addLayout(scores_layout)

        # Game board
        self.game_board = GameBoardWidget()
        layout.addWidget(self.game_board, alignment=Qt.AlignCenter)

        # Demo-mode controls (hidden until demo mode is selected)
        self.demo_controls = QWidget()
        demo_layout = QHBoxLayout(self.demo_controls)

        self.ai_btn = QPushButton('AI Mode')
        self.ai_btn.clicked.connect(self.toggle_ai_mode)

        self.step_btn = QPushButton('Step')
        self.step_btn.clicked.connect(self.ai_step)

        self.auto_btn = QPushButton('Auto')
        self.auto_btn.clicked.connect(self.toggle_auto)

        self.reset_btn = QPushButton('Reset')
        self.reset_btn.clicked.connect(self.reset_game)

        demo_layout.addWidget(self.ai_btn)
        demo_layout.addWidget(self.step_btn)
        demo_layout.addWidget(self.auto_btn)
        demo_layout.addWidget(self.reset_btn)

        self.demo_controls.setVisible(False)
        layout.addWidget(self.demo_controls)

        layout.addStretch()
        return panel

    def create_right_panel(self) -> QWidget:
        """Build the panel with training controls, chart and statistics."""
        panel = QWidget()
        layout = QVBoxLayout(panel)

        # Training controls
        control_group = QGroupBox('Training Control')
        control_layout = QVBoxLayout(control_group)

        # Number of games to train
        games_layout = QHBoxLayout()
        games_label = QLabel('Games:')
        self.games_spinbox = QSpinBox()
        self.games_spinbox.setRange(10, 100000)
        self.games_spinbox.setValue(500)
        games_layout.addWidget(games_label)
        games_layout.addWidget(self.games_spinbox)
        games_layout.addStretch()
        control_layout.addLayout(games_layout)

        # Start / stop buttons
        btn_layout = QHBoxLayout()
        self.start_btn = QPushButton('Start Training')
        self.start_btn.clicked.connect(self.start_training)

        self.stop_btn = QPushButton('Stop Training')
        self.stop_btn.clicked.connect(self.stop_training)
        self.stop_btn.setEnabled(False)

        btn_layout.addWidget(self.start_btn)
        btn_layout.addWidget(self.stop_btn)
        control_layout.addLayout(btn_layout)

        # Model file operations
        model_layout = QHBoxLayout()
        self.load_btn = QPushButton('Load Model')
        self.load_btn.clicked.connect(self.load_model)

        self.save_btn = QPushButton('Save Model')
        self.save_btn.clicked.connect(self.save_model)

        model_layout.addWidget(self.load_btn)
        model_layout.addWidget(self.save_btn)
        control_layout.addLayout(model_layout)

        # Live training figures
        param_layout = QGridLayout()

        self.games_label = QLabel('Games: 0')
        self.avg_score_label = QLabel('Avg Score: 0')
        self.best_score_label = QLabel('Best Score: 0')
        self.speed_label = QLabel('Speed: 0 games/s')

        param_layout.addWidget(self.games_label, 0, 0)
        param_layout.addWidget(self.avg_score_label, 0, 1)
        param_layout.addWidget(self.best_score_label, 1, 0)
        param_layout.addWidget(self.speed_label, 1, 1)

        control_layout.addLayout(param_layout)
        layout.addWidget(control_group)

        # Score chart
        plot_group = QGroupBox('Score Chart')
        plot_layout = QVBoxLayout(plot_group)

        self.plot_canvas = PlotCanvas(self, width=6, height=4, dpi=100)
        plot_layout.addWidget(self.plot_canvas)

        layout.addWidget(plot_group)

        # Text statistics
        stats_group = QGroupBox('Training Stats')
        stats_layout = QVBoxLayout(stats_group)

        self.stats_text = QLabel('Waiting for training...')
        self.stats_text.setStyleSheet('font-family: monospace;')
        stats_layout.addWidget(self.stats_text)

        layout.addWidget(stats_group)

        return panel

    def _stop_auto_play(self):
        """Stop the auto-step timer and restore the Auto button caption."""
        self.auto_timer.stop()
        self.auto_btn.setText('Auto')

    def switch_mode(self, index):
        """Switch between training mode (index 0) and demo mode (any other)."""
        if index == 0:  # training mode
            # Auto-play must not keep stepping the demo game (and redrawing
            # the demo chart) once the demo controls are hidden.
            self._stop_auto_play()
            self.demo_controls.setVisible(False)
            self.demo_situational_scores = []
            self.plot_canvas.plot_training_scores(self.training_scores)
        else:  # demo mode
            self.demo_controls.setVisible(True)
            self.reset_game()
            self.demo_situational_scores = []
            self.plot_canvas.plot_demo_scores([])

    def keyPressEvent(self, event):
        """Handle arrow keys for manual play in demo mode."""
        key_map = {
            Qt.Key_Up: 0,
            Qt.Key_Down: 1,
            Qt.Key_Left: 2,
            Qt.Key_Right: 3,
        }

        manual_play = self.mode_combo.currentIndex() == 1 and not self.ai_mode
        if manual_play and event.key() in key_map:
            direction = key_map[event.key()]
            self.game.move(direction)
            self.demo_situational_scores.append(self.game.situational_score)
            self.update_display()
            self.plot_canvas.plot_demo_scores(self.demo_situational_scores)

            if self.game.game_over:
                self.statusBar.showMessage('Game Over!')
        else:
            # Let Qt process keys this window does not consume.
            super().keyPressEvent(event)

    def toggle_ai_mode(self):
        """Toggle between manual play and AI-controlled play."""
        self.ai_mode = not self.ai_mode
        if self.ai_mode:
            self.ai_btn.setText('Manual')
            self.step_btn.setEnabled(False)
            self.auto_btn.setEnabled(True)
        else:
            self.ai_btn.setText('AI Mode')
            self.step_btn.setEnabled(True)
            self.auto_btn.setEnabled(False)
            self._stop_auto_play()

    def ai_step(self):
        """Let the model play one move; reset instead if the game is over."""
        if self.game.game_over:
            self.reset_game()
            return

        state = self.game.get_state()
        scores = np.array([
            min(self.game.accumulated_score / 50000, 1.0),
            min(self.game.situational_score / 200, 1.0)
        ], dtype=np.float32)
        valid_actions = self.game.get_valid_actions()

        state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
        scores_t = torch.FloatTensor(scores).unsqueeze(0).to(self.device)
        valid_t = torch.BoolTensor(valid_actions).unsqueeze(0).to(self.device)

        # A freshly built nn.Module is in training mode, which keeps dropout
        # active and makes the "deterministic" move noisy. Switch to eval for
        # the forward pass, but never touch the mode while the training
        # thread is using the same model.
        switch_to_eval = self.model.training and not self.is_training
        if switch_to_eval:
            self.model.eval()
        try:
            action, _, _ = self.model.get_action(
                state_t, scores_t, valid_t, deterministic=True
            )
        finally:
            if switch_to_eval:
                self.model.train()

        self.game.move(action)
        self.demo_situational_scores.append(self.game.situational_score)
        self.update_display()
        self.plot_canvas.plot_demo_scores(self.demo_situational_scores)

        if self.game.game_over:
            self.statusBar.showMessage(f'Game Over! Final Score: {self.game.accumulated_score}')
            if self.auto_timer.isActive():
                self._stop_auto_play()

    def toggle_auto(self):
        """Start or stop automatic AI stepping."""
        if self.auto_timer.isActive():
            self._stop_auto_play()
        else:
            self.auto_timer.start(100)  # one step every 100 ms
            self.auto_btn.setText('Stop')

    def reset_game(self):
        """Start a new demo game and reset the demo chart."""
        self.game.reset()
        self.demo_situational_scores = [self.game.situational_score]
        self.update_display()
        self.plot_canvas.plot_demo_scores(self.demo_situational_scores)
        self.statusBar.showMessage('Game Reset')

    def update_display(self):
        """Refresh the board and the three score widgets from ``self.game``."""
        self.game_board.set_board(self.game.board)
        self.accumulated_score_widget.set_value(self.game.accumulated_score)
        self.situational_score_widget.set_value(self.game.situational_score)
        self.max_tile_widget.set_value(self.game.get_max_tile())

    def start_training(self):
        """Disable conflicting controls and launch the training thread."""
        # Auto-play would otherwise keep running inference on the model that
        # the training thread is about to update.
        self._stop_auto_play()

        self.is_training = True
        self.start_btn.setEnabled(False)
        self.stop_btn.setEnabled(True)
        self.mode_combo.setEnabled(False)
        self.games_spinbox.setEnabled(False)
        self.load_btn.setEnabled(False)
        self.save_btn.setEnabled(False)

        # Reset the chart data
        self.training_scores = []

        # Create and start the training thread
        num_games = self.games_spinbox.value()
        self.training_thread = SimpleTrainingThread(
            self.model, self.trainer, num_games
        )
        self.training_thread.game_end_signal.connect(self.on_game_end)
        self.training_thread.finished_signal.connect(self.on_training_finished)
        self.training_thread.start()

        self.statusBar.showMessage('Training started...')

    def stop_training(self):
        """Stop the training thread, wait for it, then finalise."""
        if self.training_thread:
            self.training_thread.stop()
            self.training_thread.wait()
            self.training_thread = None

        self.on_training_finished()

    def on_game_end(self, stats):
        """Update labels and chart after each finished training game.

        Args:
            stats: Dict with the keys 'score', 'game_idx', 'avg_score',
                'best_score' and 'elapsed' read below.
        """
        self.training_scores.append(stats['score'])

        games = stats['game_idx']
        avg_score = stats['avg_score']
        best_score = stats['best_score']
        elapsed = stats['elapsed']
        speed = games / elapsed if elapsed > 0 else 0

        self.games_label.setText(f'Games: {games}')
        self.avg_score_label.setText(f'Avg Score: {avg_score:.0f}')
        self.best_score_label.setText(f'Best Score: {best_score}')
        self.speed_label.setText(f'Speed: {speed:.2f} games/s')

        # Redraw the chart only every fifth game to keep the GUI responsive.
        if games % 5 == 0:
            self.plot_canvas.plot_training_scores(self.training_scores)

        self.stats_text.setText(
            f"Games: {games}\n"
            f"Avg Score: {avg_score:.0f}\n"
            f"Best Score: {best_score}\n"
            f"Last Score: {stats['score']}"
        )

    def on_training_finished(self):
        """Re-enable the controls and auto-save the model after training."""
        # stop_training() calls this directly, and the thread's
        # finished_signal arrives afterwards as well; without this guard the
        # model would be saved twice per stop.
        if not self.is_training:
            return

        self.is_training = False
        self.start_btn.setEnabled(True)
        self.stop_btn.setEnabled(False)
        self.mode_combo.setEnabled(True)
        self.games_spinbox.setEnabled(True)
        self.load_btn.setEnabled(True)
        self.save_btn.setEnabled(True)

        # Auto-save next to this source file.
        save_path = os.path.join(os.path.dirname(__file__), 'checkpoints', 'model.pt')
        try:
            os.makedirs(os.path.dirname(save_path), exist_ok=True)
            torch.save({
                'model_state_dict': self.model.state_dict(),
                'training_scores': self.training_scores,
                'best_score': max(self.training_scores) if self.training_scores else 0
            }, save_path)
        except Exception as e:
            # An exception escaping a Qt slot can abort the application, so
            # report the failure instead (same policy as save_model).
            self.statusBar.showMessage(f'Training finished, but saving the model failed: {e}')
            return

        self.statusBar.showMessage(f'Training finished! Model saved to {save_path}')

    def load_model(self):
        """Ask for a checkpoint file and load its weights into the model."""
        file_path, _ = QFileDialog.getOpenFileName(
            self, 'Load Model',
            os.path.join(os.path.dirname(__file__), 'checkpoints'),
            'PyTorch Model (*.pt);;All Files (*)'
        )

        if file_path:
            try:
                # SECURITY NOTE: weights_only=False unpickles arbitrary
                # objects from the chosen file; only load checkpoints from
                # sources you trust.
                checkpoint = torch.load(file_path, map_location=self.device, weights_only=False)
                self.model.load_state_dict(checkpoint['model_state_dict'])
                self.current_model_path = file_path

                if 'training_scores' in checkpoint:
                    self.training_scores = checkpoint['training_scores']
                    self.plot_canvas.plot_training_scores(self.training_scores)

                self.statusBar.showMessage(f'Model loaded: {os.path.basename(file_path)}')
            except Exception as e:
                QMessageBox.warning(self, 'Error', f'Failed to load model:\n{str(e)}')

    def save_model(self):
        """Ask for a destination file and save the model checkpoint to it."""
        file_path, _ = QFileDialog.getSaveFileName(
            self, 'Save Model',
            os.path.join(os.path.dirname(__file__), 'checkpoints', 'model.pt'),
            'PyTorch Model (*.pt);;All Files (*)'
        )

        if file_path:
            try:
                torch.save({
                    'model_state_dict': self.model.state_dict(),
                    'training_scores': self.training_scores,
                    'best_score': max(self.training_scores) if self.training_scores else 0
                }, file_path)
                self.current_model_path = file_path
                self.statusBar.showMessage(f'Model saved: {os.path.basename(file_path)}')
            except Exception as e:
                QMessageBox.warning(self, 'Error', f'Failed to save model:\n{str(e)}')

    def closeEvent(self, event):
        """Stop background activity before the window is destroyed."""
        self.auto_timer.stop()
        if self.training_thread is not None:
            # Destroying a QThread that is still running crashes the
            # process, so ask it to stop and wait for it first.
            self.training_thread.stop()
            self.training_thread.wait()
        super().closeEvent(event)
817
+
818
+
819
def main():
    """Create the Qt application, show the main window and run the event loop.

    The process exits with the event loop's return code.
    """
    app = QApplication(sys.argv)
    app.setStyle('Fusion')

    # Application-wide default font
    app.setFont(QFont('Arial', 10))

    window = MainWindow()
    window.show()

    exit_code = app.exec_()
    sys.exit(exit_code)
832
+
833
+
834
+ if __name__ == '__main__':
835
+ main()
main.py ADDED
@@ -0,0 +1,302 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 2048 AI Trainer - 主入口
3
+ 基于Transformer的2048游戏AI训练器
4
+
5
+ 使用方法:
6
+ python main.py # 启动GUI界面
7
+ python main.py --train # 命令行训练模式
8
+ python main.py --demo # 演示模式(加载已有模型)
9
+ """
10
+ import sys
11
+ import os
12
+ import argparse
13
+ import torch
14
+ import time
15
+
16
+ # 添加当前目录到路径
17
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
18
+
19
+ from game import Game2048
20
+ from model import Game2048Transformer, count_parameters
21
+ from trainer import PPOTrainer, RolloutBuffer, TrainingStats
22
+ from parallel import TrainingLoop, ParallelGameEnv, TrainingWorker
23
+ from utils import (
24
+ set_seed, get_device, print_model_info, save_checkpoint,
25
+ load_checkpoint, EarlyStopping, format_time, format_number
26
+ )
27
+
28
+
29
def run_gui():
    """Launch the graphical interface.

    The ``gui`` module is imported lazily so the command-line modes do not
    import it.
    """
    from gui import main as launch_gui
    launch_gui()
33
+
34
+
35
def run_training(
    num_games: int = 10000,
    num_envs: int = 4,
    save_interval: int = 100,
    checkpoint_dir: str = "checkpoints",
    seed: int = 42
):
    """
    Train the model from the command line.

    Args:
        num_games: Total number of games to play.
        num_envs: Number of parallel environments.
        save_interval: Save a checkpoint every this many finished games;
            a value <= 0 disables periodic checkpoints.
        checkpoint_dir: Directory that receives the checkpoint files.
        seed: Random seed.
    """
    print("=" * 50)
    print("2048 AI Training")
    print("=" * 50)

    # Seed the random number generators
    set_seed(seed)

    # Select the device
    device = get_device()
    print(f"Device: {device}")

    # Build the model
    model = Game2048Transformer()
    print_model_info(model)
    model.to(device)

    # Build the trainer
    trainer = PPOTrainer(model, lr=1e-4, device=device)

    # Build the training loop
    training_loop = TrainingLoop(
        model=model,
        trainer=trainer,
        num_envs=num_envs,
        device=device,
        steps_per_update=256,
        checkpoint_dir=checkpoint_dir
    )

    # Run statistics
    stats = TrainingStats()
    start_time = time.time()
    best_score = 0

    # Make sure the checkpoint directory exists
    os.makedirs(checkpoint_dir, exist_ok=True)

    print(f"\nStarting training for {num_games} games...")
    print(f"Parallel environments: {num_envs}")
    print("-" * 50)

    try:
        games_completed = 0

        def on_game_end(game_stats):
            """Record one finished game; log and checkpoint periodically."""
            nonlocal games_completed, best_score

            stats.record_game(
                score=game_stats['score'],
                situational_score=game_stats['situational_score'],
                max_tile=game_stats['max_tile'],
                steps=game_stats['moves']
            )

            games_completed += 1

            if game_stats['score'] > best_score:
                best_score = game_stats['score']

            log_due = games_completed % 10 == 0
            # Guarding save_interval avoids a modulo-by-zero when periodic
            # checkpoints are disabled with 0.
            save_due = save_interval > 0 and games_completed % save_interval == 0
            if not (log_due or save_due):
                return

            # Computed once for both branches, so checkpointing does not
            # depend on the logging cadence.
            avg_stats = stats.get_avg_stats(window=100)

            # Periodic console summary
            if log_due:
                elapsed = time.time() - start_time
                speed = games_completed / elapsed if elapsed > 0 else 0.0

                print(
                    f"Games: {games_completed} | "
                    f"Avg Score: {avg_stats['avg_score']:.0f} | "
                    f"Best: {best_score} | "
                    f"Max Tile: {avg_stats['avg_max_tile']:.0f} | "
                    f"Speed: {speed:.2f} games/s"
                )

            # Periodic checkpoint
            if save_due:
                checkpoint_path = os.path.join(
                    checkpoint_dir,
                    f"checkpoint_{games_completed}.pt"
                )
                save_checkpoint(
                    model, trainer.optimizer, games_completed,
                    avg_stats, checkpoint_path
                )
                print(f"Checkpoint saved: {checkpoint_path}")

        training_loop.on_game_end_callback = on_game_end
        training_loop.train(total_games=num_games, stop_threshold=200)

    except KeyboardInterrupt:
        print("\nTraining interrupted by user.")

    # Final summary
    elapsed = time.time() - start_time
    final_stats = stats.get_avg_stats()

    print("\n" + "=" * 50)
    print("Training Complete!")
    print("=" * 50)
    print(f"Total games: {format_number(final_stats['games_played'])}")
    print(f"Total time: {format_time(elapsed)}")
    print(f"Average score: {final_stats['avg_score']:.0f}")
    print(f"Best score: {final_stats['best_score']}")
    print(f"Best max tile: {final_stats['best_max_tile']}")

    # Save the final model
    final_path = os.path.join(checkpoint_dir, "final_model.pt")
    save_checkpoint(
        model, trainer.optimizer, final_stats['games_played'],
        final_stats, final_path
    )
    print(f"Final model saved: {final_path}")
162
+
163
+
164
def run_demo(model_path: str = None, num_games: int = 5):
    """
    Play demo games with the model and print a score summary.

    Args:
        model_path: Path of a checkpoint to load; when it is None or the file
            does not exist, the model keeps its random initial weights.
        num_games: Number of demo games to play.
    """
    print("=" * 50)
    print("2048 AI Demo Mode")
    print("=" * 50)

    device = get_device()
    print(f"Device: {device}")

    # Build the model
    model = Game2048Transformer()

    # Load the checkpoint, if one was given and exists
    if model_path and os.path.exists(model_path):
        print(f"Loading model from: {model_path}")
        checkpoint = load_checkpoint(model_path, model, device=device)
        print(f"Loaded checkpoint: {checkpoint.get('epoch', 'unknown')} games")
    else:
        if model_path:
            # A mistyped path should not silently look like "no model given".
            print(f"Model file not found: {model_path}")
        print("No model loaded, using random weights.")

    model.to(device)
    model.eval()
    print_model_info(model)

    # Play the demo games
    print(f"\nRunning {num_games} demo games...")
    print("-" * 50)

    total_scores = []

    for game_idx in range(num_games):
        game = Game2048()
        game.reset()

        steps = 0
        # The step cap guards against a game that never terminates.
        while not game.game_over and steps < 10000:
            # Current observation
            state = game.get_state()
            scores = game.get_state_with_scores()[-2:]
            valid_actions = game.get_valid_actions()

            # Convert to batched tensors
            state_t = torch.FloatTensor(state).unsqueeze(0).to(device)
            scores_t = torch.FloatTensor(scores).unsqueeze(0).to(device)
            valid_t = torch.BoolTensor(valid_actions).unsqueeze(0).to(device)

            # Pick the action
            with torch.no_grad():
                action, _, _ = model.get_action(state_t, scores_t, valid_t, deterministic=True)

            # Apply it
            game.move(action)
            steps += 1

        total_scores.append(game.accumulated_score)

        print(
            f"Game {game_idx + 1}: "
            f"Score = {game.accumulated_score}, "
            f"Max Tile = {game.get_max_tile()}, "
            f"Steps = {steps}"
        )

    # Summary
    print("\n" + "-" * 50)
    if not total_scores:
        # num_games <= 0: the averages below would divide by zero.
        print("No games were played.")
        return
    print(f"Average score: {sum(total_scores) / len(total_scores):.0f}")
    print(f"Best score: {max(total_scores)}")
    print(f"Worst score: {min(total_scores)}")
238
+
239
+
240
def main():
    """Parse the command line and dispatch to training, demo or GUI mode.

    ``--train`` takes precedence over ``--demo``; with neither flag the GUI
    is started.
    """
    parser = argparse.ArgumentParser(
        description='2048 AI Trainer - 基于Transformer的2048游戏AI'
    )

    # (flag, add_argument keyword arguments), in registration order.
    option_specs = (
        ('--train', {'action': 'store_true', 'help': '运行命令行训练模式'}),
        ('--demo', {'action': 'store_true', 'help': '运行演示模式'}),
        ('--model', {'type': str, 'default': None,
                     'help': '模型路径(用于演示模式或继续训练)'}),
        ('--games', {'type': int, 'default': 10000,
                     'help': '训练游戏数(默认10000)'}),
        ('--envs', {'type': int, 'default': 4,
                    'help': '并行环境数(默认4)'}),
        ('--seed', {'type': int, 'default': 42,
                    'help': '随机种子(默认42)'}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()

    if args.train:
        run_training(
            num_games=args.games,
            num_envs=args.envs,
            seed=args.seed
        )
        return

    if args.demo:
        run_demo(model_path=args.model, num_games=5)
        return

    # No mode flag: start the GUI
    run_gui()
299
+
300
+
301
+ if __name__ == "__main__":
302
+ main()
model.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Transformer模型用于2048游戏决策
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import numpy as np
8
+ from typing import Tuple
9
+
10
+
11
+ class Game2048Transformer(nn.Module):
12
+ """
13
+ 小型Transformer模型用于2048游戏
14
+
15
+ 输入:
16
+ - 棋盘状态 (batch, 4, 4) 或 (batch, 16)
17
+ - 可选: 分数特征 (batch, 2)
18
+
19
+ 输出:
20
+ - policy: (batch, 4) 动作概率
21
+ - value: (batch, 1) 状态价值
22
+ """
23
+
24
+ def __init__(
25
+ self,
26
+ d_model: int = 64,
27
+ nhead: int = 4,
28
+ num_layers: int = 2,
29
+ dim_feedforward: int = 128,
30
+ dropout: float = 0.1
31
+ ):
32
+ super().__init__()
33
+
34
+ self.d_model = d_model
35
+
36
+ # 棋盘位置embedding
37
+ # 每个格子: 0表示空,1-15表示log2(value)
38
+ self.position_embedding = nn.Embedding(16, d_model)
39
+
40
+ # 空间位置编码(4x4棋盘的行列位置)
41
+ self.row_embedding = nn.Embedding(4, d_model // 2)
42
+ self.col_embedding = nn.Embedding(4, d_model // 2)
43
+
44
+ # 分数特征embedding
45
+ self.score_embedding = nn.Linear(2, d_model)
46
+
47
+ # Transformer编码器
48
+ encoder_layer = nn.TransformerEncoderLayer(
49
+ d_model=d_model,
50
+ nhead=nhead,
51
+ dim_feedforward=dim_feedforward,
52
+ dropout=dropout,
53
+ batch_first=True,
54
+ activation='gelu'
55
+ )
56
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
57
+
58
+ # 策略头(输出4个动作的概率)
59
+ self.policy_head = nn.Sequential(
60
+ nn.Linear(d_model, d_model),
61
+ nn.GELU(),
62
+ nn.Linear(d_model, 4)
63
+ )
64
+
65
+ # 价值头(输出状态价值)
66
+ self.value_head = nn.Sequential(
67
+ nn.Linear(d_model, d_model),
68
+ nn.GELU(),
69
+ nn.Linear(d_model, 1),
70
+ nn.Tanh()
71
+ )
72
+
73
+ # 初始化权重
74
+ self._init_weights()
75
+
76
+ def _init_weights(self):
77
+ """初始化权重"""
78
+ for module in self.modules():
79
+ if isinstance(module, nn.Linear):
80
+ nn.init.xavier_uniform_(module.weight)
81
+ if module.bias is not None:
82
+ nn.init.zeros_(module.bias)
83
+ elif isinstance(module, nn.Embedding):
84
+ nn.init.normal_(module.weight, mean=0, std=0.02)
85
+
86
def forward(
    self,
    board: torch.Tensor,
    scores: torch.Tensor = None,
    mask: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute action logits and the state value for a batch of boards.

    Args:
        board: (batch, 4, 4) or (batch, 16) board; each cell holds
            log2(tile value), 0 for an empty cell.
        scores: Optional (batch, 2) score features
            [accumulated score, situational score].
        mask: Optional (batch, 4) action mask; truthy entries mark valid
            actions. Non-bool masks are converted with ``.bool()``.

    Returns:
        policy: (batch, 4) action logits, invalid actions set to -1e9.
        value: (batch, 1) state value.
    """
    batch_size = board.shape[0]

    # Flatten the board. reshape() also accepts non-contiguous input
    # (e.g. a transposed board), on which view() would raise.
    if board.dim() == 3:
        board_flat = board.reshape(batch_size, -1)  # (batch, 16)
    else:
        board_flat = board

    # Cell values become embedding indices, clamped to the table size.
    board_indices = torch.clamp(board_flat.long(), 0, 15)

    # Tile-value embedding
    pos_embeddings = self.position_embedding(board_indices)  # (batch, 16, d_model)

    # Spatial encoding.
    # NOTE(review): for a row-major flattened board, repeat(4) yields
    # i % 4 (the column) and repeat_interleave(4) yields i // 4 (the row),
    # so the two tables are used with swapped roles relative to their
    # names. Both are learned and the encoding is still unique per cell;
    # the assignment is left unchanged so that previously trained
    # checkpoints keep producing the same outputs.
    row_indices = torch.arange(4, device=board.device).repeat(4)
    col_indices = torch.arange(4, device=board.device).repeat_interleave(4)
    row_emb = self.row_embedding(row_indices)  # (16, d_model//2)
    col_emb = self.col_embedding(col_indices)  # (16, d_model//2)
    spatial_emb = torch.cat([row_emb, col_emb], dim=-1)  # (16, d_model)

    # Combine both embeddings
    x = pos_embeddings + spatial_emb.unsqueeze(0)  # (batch, 16, d_model)

    # Score features, when given, are appended as a 17th token.
    if scores is not None:
        score_emb = self.score_embedding(scores).unsqueeze(1)  # (batch, 1, d_model)
        x = torch.cat([x, score_emb], dim=1)  # (batch, 17, d_model)

    # Transformer encoder
    x = self.transformer(x)  # (batch, 17 or 16, d_model)

    # Mean-pool over tokens
    x = x.mean(dim=1)  # (batch, d_model)

    # Output heads
    policy_logits = self.policy_head(x)  # (batch, 4)
    value = self.value_head(x)  # (batch, 1)

    # Push invalid actions to a very low logit.
    if mask is not None:
        # ~ on an integer tensor is a bitwise NOT, not a logical one, so
        # normalise to bool first (a no-op for bool masks).
        policy_logits = policy_logits.masked_fill(~mask.bool(), -1e9)

    return policy_logits, value
149
+
150
+ def get_action(
151
+ self,
152
+ board: torch.Tensor,
153
+ scores: torch.Tensor = None,
154
+ mask: torch.Tensor = None,
155
+ deterministic: bool = False
156
+ ) -> Tuple[int, torch.Tensor, torch.Tensor]:
157
+ """
158
+ 选择动作
159
+
160
+ Args:
161
+ board: (1, 4, 4) 或 (4, 4) 棋盘状态
162
+ scores: (1, 2) 或 (2,) 分数特征
163
+ mask: (1, 4) 或 (4,) 动作mask
164
+ deterministic: 是否确定性选择
165
+
166
+ Returns:
167
+ action: 选择的动作
168
+ log_prob: 动作的log概率
169
+ value: 状态价值
170
+ """
171
+ # 确保维度正确
172
+ if board.dim() == 2:
173
+ board = board.unsqueeze(0)
174
+ if scores is not None and scores.dim() == 1:
175
+ scores = scores.unsqueeze(0)
176
+ if mask is not None and mask.dim() == 1:
177
+ mask = mask.unsqueeze(0)
178
+
179
+ with torch.no_grad():
180
+ policy_logits, value = self.forward(board, scores, mask)
181
+ probs = F.softmax(policy_logits, dim=-1)
182
+
183
+ if deterministic:
184
+ action = torch.argmax(probs, dim=-1).item()
185
+ else:
186
+ # 从概率分布采样
187
+ dist = torch.distributions.Categorical(probs)
188
+ action = dist.sample().item()
189
+
190
+ log_prob = F.log_softmax(policy_logits, dim=-1)[0, action].item()
191
+
192
+ return action, log_prob, value.item()
193
+
194
+ def evaluate_actions(
195
+ self,
196
+ board: torch.Tensor,
197
+ actions: torch.Tensor,
198
+ scores: torch.Tensor = None,
199
+ mask: torch.Tensor = None
200
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
201
+ """
202
+ 评估动作(用于训练)
203
+
204
+ Args:
205
+ board: (batch, 4, 4) 棋盘状态
206
+ actions: (batch,) 采取的动作
207
+ scores: (batch, 2) 分数特征
208
+ mask: (batch, 4) 动作mask
209
+
210
+ Returns:
211
+ log_probs: (batch,) 动作log概率
212
+ values: (batch, 1) 状态价值
213
+ entropy: (batch,) 策略熵
214
+ """
215
+ policy_logits, values = self.forward(board, scores, mask)
216
+
217
+ probs = F.softmax(policy_logits, dim=-1)
218
+ log_probs = F.log_softmax(policy_logits, dim=-1)
219
+
220
+ # 选择动作的log概率
221
+ action_log_probs = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
222
+
223
+ # 计算熵
224
+ entropy = -(probs * log_probs).sum(dim=-1)
225
+
226
+ return action_log_probs, values, entropy
227
+
228
+
229
class ResidualBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages with a skip connection."""

    def __init__(self, channels: int):
        super().__init__()
        # Both convolutions keep the channel count and spatial size.
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return relu(x + bn2(conv2(relu(bn1(conv1(x))))))."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = F.relu(hidden)
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + x)
244
+
245
+
246
class Game2048CNN(nn.Module):
    """
    CNN variant of the 2048 model (kept as a simpler alternative).

    Three 3x3 convolutions over the 4x4 board feed a policy head and a
    tanh-bounded value head. This class only defines ``forward``; it has no
    ``get_action`` / ``evaluate_actions`` helpers of its own.
    """

    def __init__(self, channels: int = 64):
        super().__init__()

        # Input: (batch, 1, 4, 4)
        self.conv1 = nn.Conv2d(1, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        # Projects the 2 score features to `channels` features.
        self.score_fc = nn.Linear(2, channels)

        # Both heads consume the flattened conv features (channels * 16)
        # concatenated with the score features (channels).
        self.policy_head = nn.Sequential(
            nn.Linear(channels * 16 + channels, 256),
            nn.ReLU(),
            nn.Linear(256, 4)
        )

        self.value_head = nn.Sequential(
            nn.Linear(channels * 16 + channels, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Tanh()
        )

    def forward(
        self,
        board: torch.Tensor,
        scores: torch.Tensor = None,
        mask: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            board: (batch, 4, 4) float board tensor
            scores: optional (batch, 2) score features; zeros are used for the
                score slot when omitted
            mask: optional (batch, 4) boolean mask, True = valid action

        Returns:
            policy_logits: (batch, 4), invalid actions set to -1e9
            value: (batch, 1), in [-1, 1]
        """
        # Add the channel dimension.
        x = board.unsqueeze(1)  # (batch, 1, 4, 4)

        # Convolutional feature extraction.
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Flatten.
        x = x.view(x.size(0), -1)  # (batch, channels*16)

        # Append the score features.
        if scores is not None:
            score_feat = F.relu(self.score_fc(scores))
        else:
            # The placeholder must be as wide as score_fc's output; a fixed
            # width of 64 only works for the default channel count.
            score_feat = x.new_zeros(x.size(0), self.score_fc.out_features)
        x = torch.cat([x, score_feat], dim=-1)

        # Heads.
        policy_logits = self.policy_head(x)
        value = self.value_head(x)

        if mask is not None:
            policy_logits = policy_logits.masked_fill(~mask, -1e9)

        return policy_logits, value
310
+
311
+
312
def count_parameters(model: nn.Module) -> int:
    """Return the number of trainable scalar parameters of `model`."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
315
+
316
+
317
if __name__ == "__main__":
    # Smoke test: build both models on the CPU and print sizes / output shapes.
    device = torch.device("cpu")

    # Transformer model
    model = Game2048Transformer().to(device)
    print(f"Transformer参数量: {count_parameters(model):,}")

    # Forward pass on a random batch (tile exponents in [0, 12), all actions valid)
    batch_size = 4
    board = torch.randint(0, 12, (batch_size, 4, 4), dtype=torch.float32).to(device)
    scores = torch.rand(batch_size, 2).to(device)
    mask = torch.ones(batch_size, 4, dtype=torch.bool).to(device)

    policy_logits, value = model(board, scores, mask)
    print(f"Policy logits shape: {policy_logits.shape}")
    print(f"Value shape: {value.shape}")

    # Action selection for the first board of the batch
    action, log_prob, val = model.get_action(board[0], scores[0], mask[0])
    print(f"Action: {action}, Log prob: {log_prob:.4f}, Value: {val:.4f}")

    # CNN model
    cnn_model = Game2048CNN().to(device)
    print(f"\nCNN参数量: {count_parameters(cnn_model):,}")
parallel.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 多进程并行训练环境
3
+ """
4
+ import numpy as np
5
+ import torch
6
+ from typing import List, Tuple, Optional
7
+ from concurrent.futures import ThreadPoolExecutor
8
+ import threading
9
+ from game import Game2048
10
+ from trainer import Transition, TrainingStats
11
+ import time
12
+
13
+
14
class ParallelGameEnv:
    """
    A batch of independent Game2048 instances stepped together.

    The environments are advanced one after another inside ``step``. A
    ThreadPoolExecutor is created and kept on ``self.executor`` (and shut down
    by ``close``), but no method of this class submits work to it.
    """

    # Caps used to normalise the two score features; values above map to 1.0.
    MAX_ACCUMULATED_SCORE = 50000
    MAX_SITUATIONAL_SCORE = 200

    def __init__(self, num_envs: int = 4):
        """
        Create the environments.

        Args:
            num_envs: number of game instances
        """
        self.num_envs = num_envs
        self.envs = [Game2048() for _ in range(num_envs)]
        self.states = [env.get_state() for env in self.envs]
        self.scores = [env.accumulated_score for env in self.envs]
        self.situational_scores = [env.situational_score for env in self.envs]

        # Optional callback invoked with a stats dict whenever a game ends.
        self.on_game_end = None

        # Number of finished games across all environments.
        self.games_completed = 0

        # Kept for interface compatibility (see class docstring).
        self.executor = ThreadPoolExecutor(max_workers=num_envs)

    def reset(self, indices: Optional[List[int]] = None) -> Tuple[np.ndarray, np.ndarray]:
        """
        Reset the selected environments.

        Args:
            indices: environment indices to reset; None resets all of them

        Returns:
            states: (num_envs, 4, 4) state array
            score_features: (num_envs, 2) score feature array
        """
        if indices is None:
            indices = range(self.num_envs)

        for i in indices:
            self.reset_single(i)

        return self._get_batch_state()

    def reset_single(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
        """Reset one environment and return its (state, score features)."""
        self.states[idx] = self.envs[idx].reset()
        self._sync_scores(idx)
        return self.states[idx], self._get_score_features(idx)

    def _sync_scores(self, idx: int) -> None:
        """Copy the score attributes of environment `idx` into the caches."""
        self.scores[idx] = self.envs[idx].accumulated_score
        self.situational_scores[idx] = self.envs[idx].situational_score

    def step(self, actions: List[int]) -> List[Transition]:
        """
        Apply one action per environment.

        Environments whose game ends are reported through ``on_game_end`` and
        reset immediately, so ``self.states`` always holds playable boards.

        Args:
            actions: action for each environment, in environment order

        Returns:
            transitions: one Transition per action, in the same order.
                ``log_prob`` and ``value`` are 0.0 placeholders for the caller
                to fill in.
        """
        transitions = []

        for i, action in enumerate(actions):
            env = self.envs[i]
            old_state = self.states[i].copy()
            old_scores = self._get_score_features(i)
            # The stored mask must describe the state the action was chosen
            # in, because the trainer re-evaluates that state with this mask.
            # Reading it after the move would yield the NEXT state's mask.
            valid_before = np.array(env.get_valid_actions())

            # Execute the action.
            new_state, reward, moved, done = env.move(action)

            # Refresh the cached state and scores.
            self.states[i] = new_state
            self._sync_scores(i)

            transitions.append(Transition(
                state=old_state,
                scores=old_scores,
                action=action,
                reward=reward,
                next_state=new_state.copy(),
                next_scores=self._get_score_features(i),
                done=done,
                log_prob=0.0,  # filled in by the caller
                value=0.0,     # filled in by the caller
                valid_actions=valid_before
            ))

            # On game over: report the final statistics, then reset.
            if done:
                game_stats = {
                    'score': self.scores[i],
                    'situational_score': self.situational_scores[i],
                    'max_tile': env.get_max_tile(),
                    'moves': env.moves_count
                }
                self.games_completed += 1

                if self.on_game_end:
                    self.on_game_end(game_stats)

                self.reset_single(i)

        return transitions

    def _get_batch_state(self) -> Tuple[np.ndarray, np.ndarray]:
        """Return all states and score features as float32 arrays."""
        states = np.array(self.states, dtype=np.float32)
        score_features = np.array([
            self._get_score_features(i) for i in range(self.num_envs)
        ], dtype=np.float32)
        return states, score_features

    def _get_score_features(self, idx: int) -> np.ndarray:
        """Return the two normalised score features of one environment."""
        return np.array([
            min(self.scores[idx] / self.MAX_ACCUMULATED_SCORE, 1.0),
            min(self.situational_scores[idx] / self.MAX_SITUATIONAL_SCORE, 1.0)
        ], dtype=np.float32)

    def get_valid_actions(self) -> np.ndarray:
        """Return the valid-action masks of all environments as one array."""
        return np.array([env.get_valid_actions() for env in self.envs])

    def get_game_stats(self) -> List[dict]:
        """Return a statistics dict for every (running) game."""
        return [
            {
                'score': env.accumulated_score,
                'situational_score': env.situational_score,
                'max_tile': env.get_max_tile(),
                'moves': env.moves_count,
                'game_over': env.game_over
            }
            for env in self.envs
        ]

    def close(self):
        """Shut down the executor without waiting for pending work."""
        self.executor.shutdown(wait=False)
166
+
167
+
168
class TrainingWorker:
    """
    Rollout worker: drives a ParallelGameEnv with the model's policy and
    collects the resulting transitions.
    """

    def __init__(
        self,
        model,
        env: ParallelGameEnv,
        device: str = "cpu"
    ):
        self.model = model.to(device)
        self.model.eval()
        self.env = env
        self.device = device

        self.stats = TrainingStats()

    def collect_trajectories(
        self,
        num_steps: int = 256,
        deterministic: bool = False
    ) -> List[Transition]:
        """
        Collect rollout data from every environment.

        Args:
            num_steps: number of steps taken in each environment
            deterministic: pick the most likely action instead of sampling

        Returns:
            transitions: ``num_steps * num_envs`` transitions, grouped by
                environment (all steps of env 0 in time order, then env 1,
                ...). Keeping each environment's steps contiguous matters
                because advantage estimation walks the list as a time
                sequence; interleaving environments would pair each step with
                another game's value estimate.
        """
        per_env = [[] for _ in range(self.env.num_envs)]

        for _ in range(num_steps):
            # Snapshot the current states, score features and masks.
            states, score_features = self.env._get_batch_state()
            valid_actions = self.env.get_valid_actions()

            states_t = torch.FloatTensor(states).to(self.device)
            scores_t = torch.FloatTensor(score_features).to(self.device)
            valid_t = torch.BoolTensor(valid_actions).to(self.device)

            # Choose one action per environment.
            actions = []
            log_probs = []
            values = []

            with torch.no_grad():
                for i in range(self.env.num_envs):
                    action, log_prob, value = self.model.get_action(
                        states_t[i:i+1],
                        scores_t[i:i+1],
                        valid_t[i:i+1],
                        deterministic=deterministic
                    )
                    actions.append(action)
                    log_probs.append(log_prob)
                    values.append(value)

            # Advance the environments.
            transitions = self.env.step(actions)

            # Attach what only the policy knows: log-prob, value, and the
            # exact mask the action was selected under.
            for i, t in enumerate(transitions):
                t.log_prob = log_probs[i]
                t.value = values[i]
                t.valid_actions = valid_actions[i]
                per_env[i].append(t)

        return [t for env_transitions in per_env for t in env_transitions]

    def run_episode(
        self,
        deterministic: bool = True,
        max_steps: int = 10000
    ) -> dict:
        """
        Play one demonstration game on a fresh Game2048 instance.

        Args:
            deterministic: pick the most likely action instead of sampling
            max_steps: upper bound on the number of moves

        Returns:
            dict with 'score', 'situational_score', 'max_tile', 'steps' and
            'total_reward'
        """
        # Same normalisation caps as ParallelGameEnv._get_score_features.
        max_accumulated = 50000
        max_situational = 200

        env = Game2048()
        state = env.reset()

        total_reward = 0
        steps = 0

        while not env.game_over and steps < max_steps:
            # Normalised score features for the current position.
            scores = np.array([
                min(env.accumulated_score / max_accumulated, 1.0),
                min(env.situational_score / max_situational, 1.0)
            ], dtype=np.float32)

            valid_actions = env.get_valid_actions()

            # Batch-of-one tensors.
            state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
            scores_t = torch.FloatTensor(scores).unsqueeze(0).to(self.device)
            valid_t = torch.BoolTensor(valid_actions).unsqueeze(0).to(self.device)

            with torch.no_grad():
                action, _, _ = self.model.get_action(
                    state_t, scores_t, valid_t, deterministic=deterministic
                )

            state, reward, moved, done = env.move(action)
            total_reward += reward
            steps += 1

        return {
            'score': env.accumulated_score,
            'situational_score': env.situational_score,
            'max_tile': env.get_max_tile(),
            'steps': steps,
            'total_reward': total_reward
        }
303
+
304
+
305
class TrainingLoop:
    """
    Full training loop: collect rollouts, run a PPO update, track game
    statistics and stop early when the average score stalls.

    NOTE: ``save_interval`` and ``checkpoint_dir`` are stored but not used by
    any method of this class; saving is done explicitly via
    ``save_checkpoint``.
    """

    def __init__(
        self,
        model,
        trainer,
        num_envs: int = 4,
        device: str = "cpu",
        steps_per_update: int = 256,
        save_interval: int = 100,
        checkpoint_dir: str = "checkpoints"
    ):
        self.model = model
        self.trainer = trainer
        self.num_envs = num_envs
        self.device = device
        self.steps_per_update = steps_per_update
        self.save_interval = save_interval
        self.checkpoint_dir = checkpoint_dir

        # Environments and rollout worker.
        self.env = ParallelGameEnv(num_envs=num_envs)
        self.worker = TrainingWorker(model, self.env, device)

        # Run state.
        self.training = False
        self.paused = False
        self.stats = TrainingStats()

        # Optional callbacks: on_update_callback(update_stats) after every
        # PPO update, on_game_end_callback(game_stats) after every game.
        self.on_update_callback = None
        self.on_game_end_callback = None

    def train(
        self,
        total_games: int = 10000,
        stop_threshold: int = 100,
        min_improvement: float = 0.01
    ) -> None:
        """
        Run the training loop until stopped.

        The loop ends when ``stop()`` is called, ``total_games`` games have
        finished, the user interrupts it, or the 100-game average score has
        not improved by ``min_improvement`` for ``stop_threshold`` games.

        Args:
            total_games: total number of games to play
            stop_threshold: number of games without improvement that ends
                training
            min_improvement: minimum relative improvement of the average score
        """
        # Imported once here rather than on every loop iteration.
        from trainer import RolloutBuffer

        self.training = True

        best_avg_score = 0
        games_at_best = 0          # games_completed when the best was set
        games_at_last_check = 0    # games_completed at the last stall check

        # Record every finished game, then forward it to the user callback.
        def on_game_end(game_stats):
            self.stats.record_game(
                score=game_stats['score'],
                situational_score=game_stats['situational_score'],
                max_tile=game_stats['max_tile'],
                steps=game_stats['moves']
            )
            if self.on_game_end_callback:
                self.on_game_end_callback(game_stats)

        self.env.on_game_end = on_game_end

        try:
            while self.training and self.env.games_completed < total_games:
                if self.paused:
                    time.sleep(0.1)
                    continue

                # Collect rollouts.
                transitions = self.worker.collect_trajectories(
                    num_steps=self.steps_per_update // self.num_envs,
                    deterministic=False
                )

                # Update the model.
                buffer = RolloutBuffer(capacity=len(transitions))
                buffer.push_batch(transitions)

                update_stats = self.trainer.update(buffer)

                if self.on_update_callback:
                    self.on_update_callback(update_stats)

                # Stall check, at most once per 10 newly finished games.
                # Counting new games (instead of testing for an exact
                # multiple of 10) still fires when several games end within
                # one rollout.
                completed = self.env.games_completed
                if completed - games_at_last_check < 10:
                    continue
                games_at_last_check = completed

                current_avg = self.stats.get_avg_stats(window=100)['avg_score']
                if current_avg > best_avg_score * (1 + min_improvement):
                    best_avg_score = current_avg
                    games_at_best = completed
                elif completed - games_at_best >= stop_threshold:
                    print(f"No improvement for {stop_threshold} games, stopping.")
                    break

        except KeyboardInterrupt:
            print("Training interrupted by user.")
        finally:
            self.training = False
            self.env.close()

    def stop(self) -> None:
        """Ask the training loop to end after the current iteration."""
        self.training = False

    def pause(self) -> None:
        """Pause training; the loop idles until resume() is called."""
        self.paused = True

    def resume(self) -> None:
        """Resume a paused training loop."""
        self.paused = False

    def get_stats(self) -> dict:
        """Return the current averaged game statistics."""
        return self.stats.get_avg_stats()

    def save_checkpoint(self, path: str) -> None:
        """Save model, optimizer and statistics to `path`."""
        import os

        # Create the target directory when the path has one.
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.trainer.optimizer.state_dict(),
            'stats': self.stats.get_avg_stats(),
            'games_played': self.stats.games_played
        }, path)

    def load_checkpoint(self, path: str) -> None:
        """Restore model and optimizer state from a checkpoint file."""
        # The checkpoint carries a plain-Python 'stats' dict next to the
        # tensors, so it is loaded with weights_only=False regardless of the
        # installed torch version's default.
        # SECURITY: this unpickles arbitrary objects; only load checkpoint
        # files from a trusted source.
        checkpoint = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
444
+
445
+
446
if __name__ == "__main__":
    from model import Game2048Transformer
    from trainer import PPOTrainer

    # Smoke test of the parallel environment with uniformly random actions
    env = ParallelGameEnv(num_envs=4)
    env.reset()

    print("Testing parallel environment...")
    for i in range(10):
        actions = [np.random.randint(0, 4) for _ in range(env.num_envs)]
        transitions = env.step(actions)
        print(f"Step {i}: collected {len(transitions)} transitions")

    env.close()
    print("Parallel environment test passed!")
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 2048 AI Trainer 依赖
2
+
3
+ # 深度学习框架
4
+ torch>=2.0.0
5
+
6
+ # 数值计算 (需要 <2 以兼容 torch)
7
+ numpy<2
8
+
9
+ # GUI
10
+ PyQt5>=5.15.0
11
+
12
+ # 绘图
13
+ matplotlib>=3.7.0
train_simple.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 简单训练脚本 - 直接训练并保存模型
3
+ """
4
+ import os
5
+ import sys
6
+ import time
7
+ import torch
8
+ import numpy as np
9
+
10
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
11
+
12
+ from game import Game2048
13
+ from model import Game2048Transformer
14
+ from trainer import PPOTrainer, RolloutBuffer
15
+
16
def _score_features(game) -> np.ndarray:
    """Return the two normalised score features of `game` (each capped at 1.0)."""
    return np.array([
        min(game.accumulated_score / 50000, 1.0),
        min(game.situational_score / 200, 1.0)
    ], dtype=np.float32)


def _save_model(model, save_path: str, game_idx: int, best_score, avg_score) -> None:
    """Write the model weights plus a few summary numbers to `save_path`."""
    directory = os.path.dirname(save_path)
    # dirname is "" for a bare file name; os.makedirs("") would raise.
    if directory:
        os.makedirs(directory, exist_ok=True)
    torch.save({
        'model_state_dict': model.state_dict(),
        'game_idx': game_idx,
        'best_score': best_score,
        'avg_score': avg_score
    }, save_path)


def train_simple(
    num_games: int = 1000,
    save_path: str = "checkpoints/model.pt",
    print_interval: int = 10
):
    """
    Train a Game2048Transformer with PPO by playing games one at a time.

    A PPO update runs every time 64 transitions have accumulated. Progress is
    printed, and the model saved, every `print_interval` games and once more
    at the end.

    Args:
        num_games: number of games to play
        save_path: file the checkpoint is written to
        print_interval: games between progress reports / intermediate saves

    Returns:
        the trained model
    """
    # Local import, as in the original script layout.
    from trainer import Transition

    print("=" * 50)
    print("2048 AI Simple Training")
    print("=" * 50)

    device = "cpu"
    model = Game2048Transformer().to(device)
    trainer = PPOTrainer(model, lr=3e-4, device=device)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Training for {num_games} games...")
    print("-" * 50)

    # Running statistics.
    scores = []
    max_tiles = []
    best_score = 0

    # One buffer shared by all games: recreating it per game would throw away
    # the last (< 64) transitions of every game, including the terminal one.
    buffer = RolloutBuffer(capacity=10000)

    start_time = time.time()

    for game_idx in range(num_games):
        # Play one game.
        game = Game2048()
        game.reset()

        while not game.game_over:
            state = game.get_state()
            scores_feat = _score_features(game)
            valid = game.get_valid_actions()

            # Batch-of-one tensors.
            state_t = torch.FloatTensor(state).unsqueeze(0).to(device)
            scores_t = torch.FloatTensor(scores_feat).unsqueeze(0).to(device)
            valid_t = torch.BoolTensor(valid).unsqueeze(0).to(device)

            # Choose an action.
            with torch.no_grad():
                action, log_prob, value = model.get_action(state_t, scores_t, valid_t, deterministic=False)

            # Keep copies of the pre-move observation.
            old_state = state.copy()
            old_scores = scores_feat.copy()

            next_state, reward, moved, done = game.move(action)

            # Store the transition.
            buffer.push(Transition(
                state=old_state,
                scores=old_scores,
                action=action,
                reward=reward,
                next_state=next_state.copy(),
                next_scores=_score_features(game),
                done=done,
                log_prob=log_prob,
                value=value,
                valid_actions=valid
            ))

            # Update as soon as a full batch is available.
            if len(buffer) >= 64:
                trainer.update(buffer)
                buffer.clear()

        # Record the result of this game.
        scores.append(game.accumulated_score)
        max_tiles.append(game.get_max_tile())

        if game.accumulated_score > best_score:
            best_score = game.accumulated_score

        # Progress report and intermediate save.
        if print_interval > 0 and (game_idx + 1) % print_interval == 0:
            elapsed = time.time() - start_time
            avg_score = np.mean(scores[-print_interval:])
            avg_max_tile = np.mean(max_tiles[-print_interval:])
            speed = (game_idx + 1) / elapsed

            print(
                f"Game {game_idx + 1}/{num_games} | "
                f"Avg Score: {avg_score:.0f} | "
                f"Best: {best_score} | "
                f"Max Tile: {avg_max_tile:.0f} | "
                f"Speed: {speed:.2f} games/s"
            )

            _save_model(model, save_path, game_idx, best_score, avg_score)

    # Final save. Guard the summaries so a run with zero games does not fail.
    final_avg = np.mean(scores[-100:]) if scores else 0.0
    _save_model(model, save_path, num_games, best_score, final_avg)

    elapsed = time.time() - start_time
    print("\n" + "=" * 50)
    print("Training Complete!")
    print("=" * 50)
    print(f"Total games: {num_games}")
    print(f"Total time: {elapsed:.1f}s")
    print(f"Average score (last 100): {final_avg:.0f}")
    print(f"Best score: {best_score}")
    print(f"Best max tile: {max(max_tiles, default=0)}")
    print(f"Model saved to: {save_path}")

    return model
147
+
148
if __name__ == "__main__":
    # Script entry point: train for 500 games with the default save path.
    train_simple(num_games=500)
trainer.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PPO训练器
3
+ """
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ import torch.optim as optim
8
+ import numpy as np
9
+ from typing import List, Tuple, Optional
10
+ from dataclasses import dataclass
11
+ from collections import deque
12
+ import random
13
+
14
+
15
@dataclass
class Transition:
    """A single state transition (one environment step)."""
    state: np.ndarray          # board state (4, 4)
    scores: np.ndarray         # score features (2,)
    action: int                # action that was taken
    reward: float              # reward received
    next_state: np.ndarray     # next board state
    next_scores: np.ndarray    # next score features
    done: bool                 # whether the game ended
    log_prob: float            # log-probability of the action
    value: float               # state-value estimate
    valid_actions: np.ndarray  # valid-action mask
28
+
29
+
30
class RolloutBuffer:
    """
    Fixed-capacity buffer of transitions.

    Once `capacity` items are stored, each new item overwrites the oldest one
    (ring buffer).
    """

    def __init__(self, capacity: int = 10000):
        self.capacity = capacity
        self.buffer: List[Transition] = []
        # Index of the slot the next overwrite goes to; once the buffer is
        # full this is also the index of the oldest stored item.
        self.position = 0

    def push(self, transition: Transition) -> None:
        """Add one transition, overwriting the oldest when full."""
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def push_batch(self, transitions: List[Transition]) -> None:
        """Add several transitions in order."""
        for t in transitions:
            self.push(t)

    def get_all(self) -> List[Transition]:
        """
        Return a new list with all stored transitions, oldest first.

        The trainer walks this list as a time sequence, so after a
        wrap-around the raw storage order (newest items at the front) would
        be wrong; the list is rotated so insertion order is restored.
        """
        if len(self.buffer) < self.capacity:
            # Not wrapped yet: storage order is insertion order.
            return self.buffer.copy()
        return self.buffer[self.position:] + self.buffer[:self.position]

    def clear(self) -> None:
        """Remove all stored transitions."""
        self.buffer = []
        self.position = 0

    def __len__(self) -> int:
        return len(self.buffer)
62
+
63
+
64
class PPOTrainer:
    """
    PPO trainer (clipped surrogate objective with GAE advantages).

    The model must provide ``evaluate_actions(states, actions, scores, mask)``
    returning ``(log_probs, values, entropy)``.
    """

    def __init__(
        self,
        model,
        lr: float = 1e-4,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        clip_ratio: float = 0.2,
        value_coef: float = 0.5,
        entropy_coef: float = 0.01,
        max_grad_norm: float = 0.5,
        update_epochs: int = 4,
        batch_size: int = 64,
        device: str = "cpu"
    ):
        self.model = model.to(device)
        self.device = device

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        self.clip_ratio = clip_ratio
        self.value_coef = value_coef
        self.entropy_coef = entropy_coef
        self.max_grad_norm = max_grad_norm
        self.update_epochs = update_epochs
        self.batch_size = batch_size

        # Bind the optimizer to the model actually kept on `device`.
        self.optimizer = optim.Adam(self.model.parameters(), lr=lr)

        # Rolling history of per-update averages.
        self.stats = {
            'policy_loss': deque(maxlen=100),
            'value_loss': deque(maxlen=100),
            'entropy': deque(maxlen=100),
            'total_loss': deque(maxlen=100)
        }

    def compute_gae(
        self,
        rewards: np.ndarray,
        values: np.ndarray,
        dones: np.ndarray,
        next_value: float = 0.0
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute discounted returns and GAE advantages for one sequence.

        Args:
            rewards: reward sequence (T,)
            values: value estimates (T,)
            dones: episode-end flags (T,)
            next_value: value of the state following the last step

        Returns:
            returns: discounted reward sums, restarted after every done (T,)
            advantages: GAE advantages (T,)
        """
        T = len(rewards)
        advantages = np.zeros(T, dtype=np.float32)
        returns = np.zeros(T, dtype=np.float32)

        last_gae = 0.0
        last_return = next_value

        for t in reversed(range(T)):
            if dones[t]:
                # Nothing is bootstrapped or carried across an episode end.
                next_value_t = 0.0
                not_done = 0.0
            else:
                next_value_t = values[t + 1] if t + 1 < T else next_value
                not_done = 1.0

            delta = rewards[t] + self.gamma * next_value_t - values[t]
            last_gae = delta + self.gamma * self.gae_lambda * not_done * last_gae
            advantages[t] = last_gae

            last_return = rewards[t] + self.gamma * not_done * last_return
            returns[t] = last_return

        return returns, advantages

    def update(self, buffer: RolloutBuffer) -> dict:
        """
        Run the PPO update on everything stored in `buffer`.

        Args:
            buffer: rollout buffer; its contents are treated as one
                time-ordered sequence for advantage estimation

        Returns:
            dict with the mean 'policy_loss', 'value_loss', 'entropy' and
            'total_loss' over all minibatches, or {} when the buffer holds
            fewer than `batch_size` transitions
        """
        if len(buffer) < self.batch_size:
            return {}

        transitions = buffer.get_all()

        # Column-wise arrays.
        states = np.array([t.state for t in transitions])
        scores = np.array([t.scores for t in transitions])
        actions = np.array([t.action for t in transitions])
        rewards = np.array([t.reward for t in transitions])
        dones = np.array([t.done for t in transitions])
        old_log_probs = np.array([t.log_prob for t in transitions])
        old_values = np.array([t.value for t in transitions])
        valid_actions = np.array([t.valid_actions for t in transitions])

        # Returns and advantages.
        returns, advantages = self.compute_gae(rewards, old_values, dones)

        # Normalise the advantages over the whole rollout.
        advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

        # To tensors.
        states_t = torch.FloatTensor(states).to(self.device)
        scores_t = torch.FloatTensor(scores).to(self.device)
        actions_t = torch.LongTensor(actions).to(self.device)
        old_log_probs_t = torch.FloatTensor(old_log_probs).to(self.device)
        returns_t = torch.FloatTensor(returns).to(self.device)
        advantages_t = torch.FloatTensor(advantages).to(self.device)
        valid_actions_t = torch.BoolTensor(valid_actions).to(self.device)

        total_policy_loss = 0
        total_value_loss = 0
        total_entropy = 0
        total_loss = 0
        num_updates = 0

        dataset_size = len(transitions)
        indices = np.arange(dataset_size)

        for _ in range(self.update_epochs):
            np.random.shuffle(indices)

            for start in range(0, dataset_size, self.batch_size):
                batch_indices = indices[start:start + self.batch_size]

                batch_states = states_t[batch_indices]
                batch_scores = scores_t[batch_indices]
                batch_actions = actions_t[batch_indices]
                batch_old_log_probs = old_log_probs_t[batch_indices]
                batch_returns = returns_t[batch_indices]
                batch_advantages = advantages_t[batch_indices]
                batch_valid = valid_actions_t[batch_indices]

                # Re-evaluate the stored actions under the current policy.
                log_probs, values, entropy = self.model.evaluate_actions(
                    batch_states, batch_actions, batch_scores, batch_valid
                )

                # Policy loss (PPO clip).
                ratio = torch.exp(log_probs - batch_old_log_probs)
                surr1 = ratio * batch_advantages
                surr2 = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * batch_advantages
                policy_loss = -torch.min(surr1, surr2).mean()

                # Value loss. squeeze(-1) drops only the trailing size-1 axis,
                # so a final minibatch of one sample keeps shape (1,) and
                # still matches batch_returns.
                value_loss = F.mse_loss(values.squeeze(-1), batch_returns)

                mean_entropy = entropy.mean()

                # Total loss (the entropy bonus is subtracted).
                loss = (
                    policy_loss +
                    self.value_coef * value_loss -
                    self.entropy_coef * mean_entropy
                )

                # Backward pass with gradient clipping.
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optimizer.step()

                total_policy_loss += policy_loss.item()
                total_value_loss += value_loss.item()
                total_entropy += mean_entropy.item()
                total_loss += loss.item()
                num_updates += 1

        # Averages over all minibatches; the max() guards update_epochs == 0.
        denom = max(num_updates, 1)
        stats = {
            'policy_loss': total_policy_loss / denom,
            'value_loss': total_value_loss / denom,
            'entropy': total_entropy / denom,
            'total_loss': total_loss / denom
        }

        for key, value in stats.items():
            self.stats[key].append(value)

        return stats

    def get_recent_stats(self) -> dict:
        """Return the mean of each non-empty statistics history."""
        return {key: np.mean(values) for key, values in self.stats.items() if values}
259
+
260
+
261
class TrainingStats:
    """Collects per-game results and reports windowed averages."""

    def __init__(self):
        self.games_played = 0
        self.total_steps = 0

        # Per-game series.
        self.scores = []               # accumulated score of each game
        self.situational_scores = []   # situational score of each game
        self.max_tiles = []            # largest tile of each game
        self.game_lengths = []         # number of steps of each game

        # Parallel history series, kept for plotting.
        self.score_history = []
        self.situational_history = []
        self.max_tile_history = []
        self.steps_history = []

        # Best results so far.
        self.best_score = 0
        self.best_max_tile = 0

    def record_game(
        self,
        score: int,
        situational_score: float,
        max_tile: int,
        steps: int
    ) -> None:
        """Append the result of one finished game and update the records."""
        self.games_played += 1
        self.total_steps += steps

        # Every value goes into its statistics series and its history series.
        series_pairs = (
            (score, self.scores, self.score_history),
            (situational_score, self.situational_scores, self.situational_history),
            (max_tile, self.max_tiles, self.max_tile_history),
            (steps, self.game_lengths, self.steps_history),
        )
        for item, series, history in series_pairs:
            series.append(item)
            history.append(item)

        if score > self.best_score:
            self.best_score = score
        if max_tile > self.best_max_tile:
            self.best_max_tile = max_tile

    def get_avg_stats(self, window: int = 100) -> dict:
        """Return totals, records and averages over the last `window` games."""
        def tail_mean(series):
            # 0 for an empty series, otherwise the mean of the last `window` items.
            if len(series) == 0:
                return 0
            tail = series[-window:]
            return sum(tail) / len(tail)

        summary = {
            'games_played': self.games_played,
            'total_steps': self.total_steps,
        }
        summary['avg_score'] = tail_mean(self.scores)
        summary['avg_situational'] = tail_mean(self.situational_scores)
        summary['avg_max_tile'] = tail_mean(self.max_tiles)
        summary['avg_game_length'] = tail_mean(self.game_lengths)
        summary['best_score'] = self.best_score
        summary['best_max_tile'] = self.best_max_tile
        summary['recent_scores'] = self.scores[-10:]
        summary['recent_max_tiles'] = self.max_tiles[-10:]
        return summary
328
+
329
+
330
if __name__ == "__main__":
    from model import Game2048Transformer

    # Smoke test of the PPO trainer
    device = torch.device("cpu")
    model = Game2048Transformer().to(device)
    trainer = PPOTrainer(model, device=device)

    # Fill a buffer with 100 random transitions
    buffer = RolloutBuffer(capacity=1000)

    for _ in range(100):
        t = Transition(
            state=np.random.randn(4, 4).astype(np.float32),
            scores=np.random.rand(2).astype(np.float32),
            action=np.random.randint(0, 4),
            reward=np.random.randn(),
            next_state=np.random.randn(4, 4).astype(np.float32),
            next_scores=np.random.rand(2).astype(np.float32),
            done=np.random.rand() < 0.1,
            log_prob=np.random.randn(),
            value=np.random.randn(),
            valid_actions=np.ones(4, dtype=bool)
        )
        buffer.push(t)

    # Run one update
    stats = trainer.update(buffer)
    print(f"Training stats: {stats}")

    # Exercise the statistics recorder
    training_stats = TrainingStats()
    for i in range(10):
        training_stats.record_game(
            score=1000 * (i + 1),
            situational_score=50.0 + i * 5,
            max_tile=2 ** (i + 5),
            steps=100 + i * 10
        )

    print(f"Average stats: {training_stats.get_avg_stats()}")
utils.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ 工具函数
3
+ """
4
+ import os
5
+ import json
6
+ import numpy as np
7
+ import torch
8
+ from datetime import datetime
9
+ from typing import Dict, Any, Optional
10
+ import shutil
11
+
12
+
13
def ensure_dir(path: str) -> str:
    """Create directory *path* (with any missing parents) if it is absent.

    An empty *path* -- which is what ``os.path.dirname`` returns for a bare
    file name such as ``"model.pt"`` -- refers to the current directory, so
    nothing is created for it instead of letting ``os.makedirs("")`` raise.

    Args:
        path: Directory to create.

    Returns:
        ``path`` unchanged, so the call can be used inline.
    """
    if path and not os.path.exists(path):
        # exist_ok keeps this safe when something else creates the directory
        # between the existence check and the creation.
        os.makedirs(path, exist_ok=True)
    return path
18
+
19
+
20
def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    stats: Dict[str, Any],
    path: str
) -> None:
    """Write a training checkpoint to *path* with ``torch.save``.

    The saved dict has the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict``, ``stats`` and ``timestamp`` (ISO-8601 string of
    the local time at which the checkpoint was written).

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Value stored under ``epoch``.
        stats: Value stored under ``stats``.
        path: Destination file. Its parent directory is created when needed.
    """
    parent = os.path.dirname(path)
    # A bare file name has no directory component; there is nothing to create
    # and passing '' on would make os.makedirs fail.
    if parent:
        ensure_dir(parent)

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'stats': stats,
        'timestamp': datetime.now().isoformat()
    }, path)
36
+
37
+
38
def load_checkpoint(
    path: str,
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    device: str = 'cpu'
) -> Dict[str, Any]:
    """Restore a training checkpoint from *path*.

    Args:
        path: Checkpoint file to read.
        model: Module that receives the stored ``model_state_dict``.
        optimizer: If given, and the checkpoint contains an
            ``optimizer_state_dict``, that state is loaded into it.
        device: Passed to ``torch.load`` as ``map_location``.

    Returns:
        The complete checkpoint dict as loaded from disk.
    """
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, so only load checkpoint files from a trusted source.
    payload = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(payload['model_state_dict'])

    has_optimizer_state = 'optimizer_state_dict' in payload
    if optimizer is not None and has_optimizer_state:
        optimizer.load_state_dict(payload['optimizer_state_dict'])

    return payload
52
+
53
+
54
def save_training_log(log_data: Dict[str, Any], path: str) -> None:
    """Append one record to the JSON training log at *path*.

    The log file holds a JSON list. The new record is a copy of *log_data*
    with a ``timestamp`` key (ISO-8601 local time) added, so the caller's
    dict is left untouched.

    Args:
        log_data: Fields of the record to append.
        path: Log file. It is created, together with its parent directory,
            when missing.

    Raises:
        TypeError: If the record is not JSON serialisable. The existing log
            file is left unchanged in that case.
    """
    parent = os.path.dirname(path)
    # A bare file name has no directory component to create.
    if parent:
        ensure_dir(parent)

    # Load the records written so far; an unreadable file starts a fresh log.
    logs = []
    if os.path.exists(path):
        with open(path, 'r', encoding='utf-8') as f:
            try:
                logs = json.load(f)
            except json.JSONDecodeError:
                logs = []
        if not isinstance(logs, list):
            # Keep unexpected top-level content instead of failing on append.
            logs = [logs]

    # Work on a copy so the caller's dict does not gain a 'timestamp' key.
    entry = dict(log_data)
    entry['timestamp'] = datetime.now().isoformat()
    logs.append(entry)

    # Serialise before opening the file for writing: a failure here must not
    # truncate the log that is already on disk.
    serialized = json.dumps(logs, indent=2, ensure_ascii=False)
    with open(path, 'w', encoding='utf-8') as f:
        f.write(serialized)
75
+
76
+
77
def format_time(seconds: float) -> str:
    """Render a duration with one decimal place and a unit suffix.

    Values under 60 are shown in seconds (``s``), values under 3600 in
    minutes (``m``), and everything else in hours (``h``).
    """
    if seconds < 60:
        return f'{seconds:.1f}s'
    if seconds < 3600:
        return f'{seconds / 60:.1f}m'
    return f'{seconds / 3600:.1f}h'
87
+
88
+
89
def format_number(num: int) -> str:
    """Return *num* as text with comma thousands separators."""
    return format(num, ',')
92
+
93
+
94
def calculate_ema(values: list, alpha: float = 0.1) -> list:
    """Compute the exponential moving average of *values*.

    The first output equals the first input; every later output is
    ``alpha * value + (1 - alpha) * previous_output``.

    Args:
        values: Input series.
        alpha: Weight given to the newest value.

    Returns:
        A list as long as *values* (empty when *values* is empty).
    """
    if not values:
        return []

    smoothed = [values[0]]
    for sample in values[1:]:
        previous = smoothed[-1]
        smoothed.append(alpha * sample + (1 - alpha) * previous)
    return smoothed
103
+
104
+
105
def get_tile_color(value: int) -> str:
    """Return the background colour (hex string) for a tile value.

    Values without an entry in the palette fall back to ``'#3c3a32'``.
    """
    palette = dict((
        (0, '#cdc1b4'),
        (2, '#eee4da'),
        (4, '#ede0c8'),
        (8, '#f2b179'),
        (16, '#f59563'),
        (32, '#f67c5f'),
        (64, '#f65e3b'),
        (128, '#edcf72'),
        (256, '#edcc61'),
        (512, '#edc850'),
        (1024, '#edc53f'),
        (2048, '#edc22e'),
    ))
    if value in palette:
        return palette[value]
    return '#3c3a32'
122
+
123
+
124
def get_text_color(value: int) -> str:
    """Return the text colour (hex string) for a tile value.

    Values up to and including 4 use ``'#776e65'``; larger values use
    ``'#f9f6f2'``.
    """
    return '#776e65' if value <= 4 else '#f9f6f2'
129
+
130
+
131
class EarlyStopping:
    """Early-stopping helper that watches a single monitored value.

    Call the instance with each new value. It counts consecutive calls that
    fail to beat the best value seen so far by more than ``min_delta`` and
    sets ``should_stop`` once that count reaches ``patience``. The flag is
    never cleared again.
    """

    def __init__(
        self,
        patience: int = 100,
        min_delta: float = 0.01,
        mode: str = 'max'
    ):
        """
        Args:
            patience: Number of non-improving calls tolerated before stopping.
            min_delta: Margin by which a value must beat the best one to
                count as an improvement.
            mode: 'max' if larger values are better; any other value is
                handled as 'min' (smaller is better).
        """
        self.mode = mode
        self.min_delta = min_delta
        self.patience = patience
        self.best_value = None
        self.counter = 0
        self.should_stop = False

    def _is_improvement(self, value: float) -> bool:
        """Tell whether *value* beats ``best_value`` by more than ``min_delta``."""
        if self.mode == 'max':
            return value > self.best_value + self.min_delta
        return value < self.best_value - self.min_delta

    def __call__(self, value: float) -> bool:
        """
        Feed the next monitored value.

        Args:
            value: Current value of the monitored quantity.

        Returns:
            True once training should stop, False otherwise.
        """
        # The very first value only establishes the baseline.
        if self.best_value is None:
            self.best_value = value
            return False

        if self._is_improvement(value):
            self.best_value = value
            self.counter = 0
            return self.should_stop

        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
        return self.should_stop
181
+
182
+
183
class MetricTracker:
    """Keep the most recent values of named metrics and summarise them."""

    def __init__(self, window_size: int = 100):
        # metrics maps a metric name to the list of its retained values.
        self.metrics = {}
        self.window_size = window_size

    def update(self, name: str, value: float) -> None:
        """Record *value* for metric *name*, keeping only the newest window."""
        history = self.metrics.setdefault(name, [])
        history.append(value)

        # Drop the oldest entries once the history outgrows the window.
        if len(history) > self.window_size:
            self.metrics[name] = history[-self.window_size:]

    def get_mean(self, name: str) -> float:
        """Return the mean of the retained values, or 0.0 when there are none."""
        samples = self.metrics.get(name)
        if not samples:
            return 0.0
        return np.mean(samples)

    def get_std(self, name: str) -> float:
        """Return ``np.std`` of the retained values, or 0.0 with fewer than two."""
        samples = self.metrics.get(name, [])
        if len(samples) < 2:
            return 0.0
        return np.std(samples)

    def get_all_means(self) -> Dict[str, float]:
        """Return a dict mapping every tracked metric name to its mean."""
        means = {}
        for name in self.metrics:
            means[name] = self.get_mean(name)
        return means
215
+
216
+
217
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch with *seed*.

    When CUDA is available the CUDA generators are seeded as well.
    """
    import random

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
226
+
227
+
228
def get_device() -> str:
    """Return ``'cuda'`` when CUDA is available, otherwise ``'cpu'``."""
    return 'cuda' if torch.cuda.is_available() else 'cpu'
233
+
234
+
235
def count_parameters(model: torch.nn.Module) -> int:
    """Return the number of elements in the model's trainable parameters.

    Parameters with ``requires_grad`` set to False are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
238
+
239
+
240
def print_model_info(model: torch.nn.Module) -> None:
    """Print the trainable parameter count and the float32 size of *model*."""
    total_params = count_parameters(model)
    # Each float32 parameter occupies 4 bytes; the size is reported in MiB.
    size_mb = total_params * 4 / 1024 / 1024
    print(f"模型参数数量: {format_number(total_params)}")
    print(f"模型大小: {size_mb:.2f} MB (float32)")
245
+
246
+
247
def export_to_onnx(
    model: torch.nn.Module,
    path: str,
    input_size: tuple = (1, 4, 4)
) -> None:
    """Export *model* to an ONNX file at *path*.

    The model is switched to eval mode and traced with three random dummy
    inputs named ``board``, ``scores`` and ``mask``; the outputs are named
    ``policy`` and ``value``. Axis 0 of every input is marked as a dynamic
    ``batch_size`` axis.

    Args:
        model: Module to export. It is called with ``(board, scores, mask)``.
        path: Destination file. Its parent directory is created when needed.
        input_size: Shape of the dummy board tensor. Its first entry is used
            as the batch size of the dummy ``scores`` and ``mask`` tensors.
    """
    model.eval()

    # All dummy inputs share the board's leading (batch) dimension so that a
    # non-default input_size does not yield mismatched batch sizes.
    batch = input_size[0]
    dummy_input = torch.randn(*input_size)
    dummy_scores = torch.randn(batch, 2)
    dummy_mask = torch.ones(batch, 4, dtype=torch.bool)

    parent = os.path.dirname(path)
    # A bare file name has no directory component to create.
    if parent:
        ensure_dir(parent)

    torch.onnx.export(
        model,
        (dummy_input, dummy_scores, dummy_mask),
        path,
        input_names=['board', 'scores', 'mask'],
        output_names=['policy', 'value'],
        dynamic_axes={
            'board': {0: 'batch_size'},
            'scores': {0: 'batch_size'},
            'mask': {0: 'batch_size'}
        }
    )
    print(f"模型已导出到: {path}")
272
+
273
+
274
if __name__ == "__main__":
    # Quick manual check of the helpers in this module.
    print("Testing utility functions...")

    # Duration formatting: one sample each for seconds, minutes and hours.
    print(f"Format time: {format_time(45.5)}, {format_time(125.3)}, {format_time(3661)}")

    # Thousands separators.
    print(f"Format number: {format_number(1234567)}")

    # Exponential moving average over a short ramp.
    values = list(range(1, 6))
    print(f"EMA: {calculate_ema(values)}")

    # Early stopping: the score plateaus at 12, so stop should turn True.
    early_stop = EarlyStopping(patience=3, min_delta=0.1)
    for i, score in enumerate([10, 11, 12, 12, 12, 12, 12]):
        stop = early_stop(score)
        print(f"Epoch {i}: score={score}, stop={stop}")

    print("All tests passed!")