Upload 13 files
Browse files- PLAN.md +230 -0
- README.md +318 -0
- TASK.md +24 -0
- checkpoints/model_1500.pt +3 -0
- game.py +388 -0
- gui.py +835 -0
- main.py +302 -0
- model.py +341 -0
- parallel.py +461 -0
- requirements.txt +13 -0
- train_simple.py +149 -0
- trainer.py +370 -0
- utils.py +295 -0
PLAN.md
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 项目计划:基于Transformer的2048游戏AI
|
| 2 |
+
|
| 3 |
+
## 技术方案概述
|
| 4 |
+
|
| 5 |
+
### 硬件约束
|
| 6 |
+
- AMD Ryzen 5 PRO 4650U (6核12线程) CPU
|
| 7 |
+
- 8GB RAM
|
| 8 |
+
- 无NVIDIA GPU,纯CPU训练
|
| 9 |
+
- 需要小型高效的模型架构
|
| 10 |
+
|
| 11 |
+
### 整体架构
|
| 12 |
+
|
| 13 |
+
```
|
| 14 |
+
┌─────────────────────────────────────────────────────────┐
|
| 15 |
+
│ GUI主窗口 (PyQt5) │
|
| 16 |
+
├─────────────────────────────────────────────────────────┤
|
| 17 |
+
│ ┌──────────────┐ ┌──────────────────────────────────┐ │
|
| 18 |
+
│ │ 2048游戏面板 │ │ 训练状态面板 │ │
|
| 19 |
+
│ │ (4x4网格) │ │ - 当前分数/局面分数 │ │
|
| 20 |
+
│ │ │ │ - 训练速度 (games/sec) │ │
|
| 21 |
+
│ │ │ │ - 累积分数变化曲线 │ │
|
| 22 |
+
│ │ │ │ - 局面分数变化曲线 │ │
|
| 23 |
+
│ └──────────────┘ └──────────────────────────────────┘ │
|
| 24 |
+
├─────────────────────────────────────────────────────────┤
|
| 25 |
+
│ 控制面板: [训练模式] [演示模式] [开始/停止] [AI托管] │
|
| 26 |
+
└─────────────────────────────────────────────────────────┘
|
| 27 |
+
```
|
| 28 |
+
|
| 29 |
+
## 模块设计
|
| 30 |
+
|
| 31 |
+
### 1. 游戏核心模块 (`game.py`)
|
| 32 |
+
|
| 33 |
+
```python
|
| 34 |
+
class Game2048:
|
| 35 |
+
"""2048游戏核心逻辑"""
|
| 36 |
+
|
| 37 |
+
def __init__(self):
|
| 38 |
+
self.board: np.ndarray # 4x4棋盘
|
| 39 |
+
self.accumulated_score: int # 累积分数
|
| 40 |
+
self.situational_score: float # 局面分数
|
| 41 |
+
|
| 42 |
+
def reset(self) -> None: ...
|
| 43 |
+
def move(self, direction: int) -> tuple[bool, bool]: ... # (moved, game_over)
|
| 44 |
+
def get_state(self) -> np.ndarray: ... # 返回当前局面
|
| 45 |
+
def calculate_situational_score(self) -> float: ...
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
**局面分数计算公式:**
|
| 49 |
+
```
|
| 50 |
+
situation_score = (
|
| 51 |
+
empty_cells * 10 + # 空格越多越好
|
| 52 |
+
max_consecutive_adjacent * 15 + # 连续相邻数字越多越好
|
| 53 |
+
log2(max_tile) * 5 - # 最高数字的对数
|
| 54 |
+
monotonicity_penalty # 单调性惩罚(避免混乱)
|
| 55 |
+
)
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
### 2. Transformer模型 (`model.py`)
|
| 59 |
+
|
| 60 |
+
考虑到CPU训练的限制,采用小型Transformer:
|
| 61 |
+
|
| 62 |
+
```python
|
| 63 |
+
class Game2048Transformer(nn.Module):
|
| 64 |
+
"""小型Transformer用于2048决策"""
|
| 65 |
+
|
| 66 |
+
def __init__(self):
|
| 67 |
+
# 输入: 4x4棋盘 + 2个分数特征
|
| 68 |
+
# 将棋盘展平为16个token,每个token代表一个格子的状态
|
| 69 |
+
|
| 70 |
+
self.embedding = nn.Embedding(16, 64) # 0 表示空格,1-15 表示 log2(value)
|
| 71 |
+
self.score_embedding = nn.Linear(2, 64) # 两种分数的embedding
|
| 72 |
+
|
| 73 |
+
encoder_layer = nn.TransformerEncoderLayer(
|
| 74 |
+
d_model=64,
|
| 75 |
+
nhead=4,
|
| 76 |
+
dim_feedforward=128,
|
| 77 |
+
dropout=0.1,
|
| 78 |
+
batch_first=True
|
| 79 |
+
)
|
| 80 |
+
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=2)
|
| 81 |
+
|
| 82 |
+
self.policy_head = nn.Linear(64, 4) # 输出4个动作的概率
|
| 83 |
+
self.value_head = nn.Linear(64, 1) # 输出状态价值
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
**模型大小估算:**
|
| 87 |
+
- Embedding: 16 * 64 = 1,024 参数
|
| 88 |
+
- Transformer (2层): ~50,000 参数
|
| 89 |
+
- 输出头: ~300 参数
|
| 90 |
+
- **总计: ~52,000 参数** - 非常小,适合CPU训练
|
| 91 |
+
|
| 92 |
+
### 3. 训练模块 (`trainer.py`)
|
| 93 |
+
|
| 94 |
+
采用 **Actor-Critic + PPO** 策略:
|
| 95 |
+
|
| 96 |
+
```python
|
| 97 |
+
class PPOTrainer:
|
| 98 |
+
"""PPO训练器"""
|
| 99 |
+
|
| 100 |
+
def __init__(self, model, lr=1e-4):
|
| 101 |
+
self.model = model
|
| 102 |
+
self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
|
| 103 |
+
|
| 104 |
+
def compute_advantage(self, rewards, values, dones):
|
| 105 |
+
# 计算GAE (Generalized Advantage Estimation)
|
| 106 |
+
...
|
| 107 |
+
|
| 108 |
+
def update(self, trajectories):
|
| 109 |
+
# PPO更新逻辑
|
| 110 |
+
...
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
**奖励设计:**
|
| 114 |
+
```python
|
| 115 |
+
reward = (
|
| 116 |
+
accumulated_score_delta * 0.3 + # 累积分数增量(权重低)
|
| 117 |
+
situational_score * 0.7 + # 局面分数(权重高)
|
| 118 |
+
game_over_penalty * (-100) # 游戏结束惩罚
|
| 119 |
+
)
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
### 4. 多进程训练 (`parallel.py`)
|
| 123 |
+
|
| 124 |
+
利用6核CPU,同时运行多个游戏实例:
|
| 125 |
+
|
| 126 |
+
```python
|
| 127 |
+
class ParallelGameEnv:
|
| 128 |
+
"""并行游戏环境"""
|
| 129 |
+
|
| 130 |
+
def __init__(self, num_envs=4):
|
| 131 |
+
self.num_envs = num_envs
|
| 132 |
+
self.envs = [Game2048() for _ in range(num_envs)]
|
| 133 |
+
|
| 134 |
+
def step(self, actions: list[int]) -> list[Transition]:
|
| 135 |
+
# 并行执行动作,返回状态转移
|
| 136 |
+
...
|
| 137 |
+
```
|
| 138 |
+
|
| 139 |
+
### 5. GUI模块 (`gui.py`)
|
| 140 |
+
|
| 141 |
+
使用 PyQt5 构建界面:
|
| 142 |
+
|
| 143 |
+
```python
|
| 144 |
+
class MainWindow(QMainWindow):
|
| 145 |
+
"""主窗口"""
|
| 146 |
+
|
| 147 |
+
def __init__(self):
|
| 148 |
+
self.game_widget = GameBoardWidget()
|
| 149 |
+
self.stats_widget = StatsWidget()
|
| 150 |
+
self.control_widget = ControlWidget()
|
| 151 |
+
|
| 152 |
+
# 训练线程
|
| 153 |
+
self.training_thread = TrainingThread()
|
| 154 |
+
|
| 155 |
+
def switch_mode(self, mode: str): ...
|
| 156 |
+
def update_display(self): ...
|
| 157 |
+
```
|
| 158 |
+
|
| 159 |
+
## 文件结构
|
| 160 |
+
|
| 161 |
+
```
|
| 162 |
+
game2048/
|
| 163 |
+
├── TASK.md # 任务描述
|
| 164 |
+
├── PLAN.md # 本文件
|
| 165 |
+
├── main.py # 入口文件
|
| 166 |
+
├── game.py # 游戏核心逻辑
|
| 167 |
+
├── model.py # Transformer模型定义
|
| 168 |
+
├── trainer.py # PPO训练器
|
| 169 |
+
├── parallel.py # 多进程训练
|
| 170 |
+
├── gui.py # GUI界面
|
| 171 |
+
├── utils.py # 工具函数
|
| 172 |
+
├── requirements.txt # 依赖
|
| 173 |
+
└── checkpoints/ # 模型保存目录
|
| 174 |
+
```
|
| 175 |
+
|
| 176 |
+
## 实现步骤
|
| 177 |
+
|
| 178 |
+
### 阶段1: 核心游戏逻辑
|
| 179 |
+
1. 实现 `game.py` - 2048游戏规则
|
| 180 |
+
2. 实现局面分数计算
|
| 181 |
+
3. 编写游戏逻辑单元测试
|
| 182 |
+
|
| 183 |
+
### 阶段2: 模型与训练
|
| 184 |
+
4. 实现 `model.py` - Transformer模型
|
| 185 |
+
5. 实现 `trainer.py` - PPO训练器
|
| 186 |
+
6. 实现 `parallel.py` - 多进程环境
|
| 187 |
+
7. 验证训练流程可以运行
|
| 188 |
+
|
| 189 |
+
### 阶段3: GUI界面
|
| 190 |
+
8. 实现 `gui.py` - 主窗口和游戏面板
|
| 191 |
+
9. 实现训练状态可视化(分数曲线)
|
| 192 |
+
10. 实现模式切换(训练/演示)
|
| 193 |
+
|
| 194 |
+
### 阶段4: 整合与优化
|
| 195 |
+
11. 整合所有模块
|
| 196 |
+
12. 性能优化
|
| 197 |
+
13. 模型保存/加载功能
|
| 198 |
+
|
| 199 |
+
## 依赖
|
| 200 |
+
|
| 201 |
+
```
|
| 202 |
+
torch>=2.0.0
|
| 203 |
+
numpy>=1.24.0
|
| 204 |
+
PyQt5>=5.15.0
|
| 205 |
+
matplotlib>=3.7.0
|
| 206 |
+
```
|
| 207 |
+
|
| 208 |
+
## 训练策略细节
|
| 209 |
+
|
| 210 |
+
### 状态表示
|
| 211 |
+
- 棋盘状态:将每个格子的值转换为 log2(value),空格为0
|
| 212 |
+
- 分数归一化:累积分数和局面分数归一化到 [0, 1]
|
| 213 |
+
|
| 214 |
+
### 动作空间
|
| 215 |
+
- 0: 上
|
| 216 |
+
- 1: 下
|
| 217 |
+
- 2: 左
|
| 218 |
+
- 3: 右
|
| 219 |
+
|
| 220 |
+
### 训练超参数
|
| 221 |
+
- Learning rate: 1e-4
|
| 222 |
+
- Batch size: 64
|
| 223 |
+
- PPO clip ratio: 0.2
|
| 224 |
+
- GAE lambda: 0.95
|
| 225 |
+
- Discount factor (gamma): 0.99
|
| 226 |
+
- 并行环境数: 4 (根据CPU核心数调整)
|
| 227 |
+
|
| 228 |
+
### 停止条件
|
| 229 |
+
- 连续100局游戏平均分数无提升
|
| 230 |
+
- 用户手动停止
|
README.md
ADDED
|
@@ -0,0 +1,318 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 2048 AI Trainer
|
| 2 |
+
|
| 3 |
+
基于 Transformer 的 2048 游戏人工智能训练器,使用 PPO(Proximal Policy Optimization)强化学习算法,让 AI 学会玩 2048 游戏。
|
| 4 |
+
|
| 5 |
+
## 项目简介
|
| 6 |
+
|
| 7 |
+
本项目实现了一个完整的 2048 游戏 AI 训练系统,包括:
|
| 8 |
+
|
| 9 |
+
- **游戏引擎**: 完整的 2048 游戏逻辑实现
|
| 10 |
+
- **深度学习模型**: 基于 Transformer 架构的策略网络
|
| 11 |
+
- **强化学习训练**: PPO 算法实现
|
| 12 |
+
- **可视化界面**: PyQt5 图形界面,支持训练监控和演示
|
| 13 |
+
- **命令行工具**: 支持无 GUI 的训练和演示模式
|
| 14 |
+
|
| 15 |
+
## 功能特点
|
| 16 |
+
|
| 17 |
+
### 1. Transformer 模型
|
| 18 |
+
|
| 19 |
+
采用小型 Transformer 架构,专为 CPU 训练优化:
|
| 20 |
+
|
| 21 |
+
- **参数量**: 约 77,000 个参数(~300KB)
|
| 22 |
+
- **输入处理**:
|
| 23 |
+
- 棋盘状态编码为 16 个 token(每个格子对应一个 token)
|
| 24 |
+
- 分数特征(累积分数、局面分数)作为额外输入
|
| 25 |
+
- 位置编码:行/列位置嵌入
|
| 26 |
+
- **网络结构**:
|
| 27 |
+
- 2 层 Transformer Encoder
|
| 28 |
+
- 4 个注意力头
|
| 29 |
+
- 隐藏维度 64
|
| 30 |
+
- 前馈网络维度 128
|
| 31 |
+
- **输出**:
|
| 32 |
+
- 策略头:4 个动作(上/下/左/右)的概率分布
|
| 33 |
+
- 价值头:当前状态的价值评估
|
| 34 |
+
|
| 35 |
+
### 2. 双评分机制
|
| 36 |
+
|
| 37 |
+
#### 累积分数(Accumulated Score)
|
| 38 |
+
传统 2048 计分方式,每次合成砖块获得合成后砖块的数值作为分数。
|
| 39 |
+
|
| 40 |
+
#### 局面分数(Situational Score)
|
| 41 |
+
综合评估当前局面的质量,鼓励 AI 保持良好局面:
|
| 42 |
+
|
| 43 |
+
```
|
| 44 |
+
局面分数 = 空格数 × 10 + 最大连续相邻数 × 15 + log₂(最大砖块) × 5 + 单调性奖励
|
| 45 |
+
```
|
| 46 |
+
|
| 47 |
+
- **空格数**: 空格越多,操作空间越大
|
| 48 |
+
- **连续相邻数**: 如 512-1024-2048 连续排列,便于后续合并
|
| 49 |
+
- **单调性**: 鼓励数字按方向有序排列
|
| 50 |
+
|
| 51 |
+
### 3. PPO 训练算法
|
| 52 |
+
|
| 53 |
+
使用 Proximal Policy Optimization 算法进行训练:
|
| 54 |
+
|
| 55 |
+
- **优势估计**: GAE(Generalized Advantage Estimation)
|
| 56 |
+
- **策略裁剪**: 防止策略更新过大
|
| 57 |
+
- **价值函数**: 辅助训练,提供状态价值估计
|
| 58 |
+
- **熵正则化**: 鼓励探索
|
| 59 |
+
|
| 60 |
+
### 4. GUI 界面
|
| 61 |
+
|
| 62 |
+
基于 PyQt5 的图形界面:
|
| 63 |
+
|
| 64 |
+
- **训练模式**:
|
| 65 |
+
- 设置训练局数
|
| 66 |
+
- 实时显示训练进度
|
| 67 |
+
- 分数曲线可视化
|
| 68 |
+
- 训练完成后自动保存模型
|
| 69 |
+
|
| 70 |
+
- **演示模式**:
|
| 71 |
+
- 键盘手动操作
|
| 72 |
+
- AI 托管模式
|
| 73 |
+
- 单步执行
|
| 74 |
+
- 自动连续执行
|
| 75 |
+
- 实时局面分数曲线
|
| 76 |
+
|
| 77 |
+
## 安装
|
| 78 |
+
|
| 79 |
+
### 环境要求
|
| 80 |
+
|
| 81 |
+
- Python 3.8+
|
| 82 |
+
- Windows / Linux / macOS
|
| 83 |
+
|
| 84 |
+
### 安装依赖
|
| 85 |
+
|
| 86 |
+
```bash
|
| 87 |
+
cd game2048
|
| 88 |
+
pip install -r requirements.txt
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
### 依赖列表
|
| 92 |
+
|
| 93 |
+
```
|
| 94 |
+
torch>=2.0.0 # 深度学习框架
|
| 95 |
+
numpy<2 # 数值计算
|
| 96 |
+
PyQt5>=5.15.0 # GUI 框架
|
| 97 |
+
matplotlib>=3.7.0 # 绘图库
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
## 使用方法
|
| 101 |
+
|
| 102 |
+
### 1. GUI 模式
|
| 103 |
+
|
| 104 |
+
```bash
|
| 105 |
+
python main.py
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
启动图形界面后:
|
| 109 |
+
|
| 110 |
+
**训练模式**:
|
| 111 |
+
1. 选择 "Training Mode"
|
| 112 |
+
2. 设置训练局数(默认 500)
|
| 113 |
+
3. 点击 "Start Training" 开始训练
|
| 114 |
+
4. 训练完成后自动保存到 `checkpoints/model.pt`
|
| 115 |
+
5. 可随时点击 "Stop Training" 停止
|
| 116 |
+
|
| 117 |
+
**演示模式**:
|
| 118 |
+
1. 选择 "Demo Mode"
|
| 119 |
+
2. 点击 "Load Model" 加载已训练模型
|
| 120 |
+
3. 使用方式:
|
| 121 |
+
- 键盘方向键:手动操作
|
| 122 |
+
- "AI Mode":切换 AI 托管
|
| 123 |
+
- "Step":AI 单步执行
|
| 124 |
+
- "Auto":AI 自动连续执行
|
| 125 |
+
- "Reset":重新开始游戏
|
| 126 |
+
|
| 127 |
+
### 2. 命令行训练
|
| 128 |
+
|
| 129 |
+
```bash
|
| 130 |
+
# 训练 1000 局
|
| 131 |
+
python main.py --train --games 1000
|
| 132 |
+
|
| 133 |
+
# 使用 4 个并行环境
|
| 134 |
+
python main.py --train --games 1000 --envs 4
|
| 135 |
+
|
| 136 |
+
# 设置随机种子
|
| 137 |
+
python main.py --train --games 1000 --seed 42
|
| 138 |
+
```
|
| 139 |
+
|
| 140 |
+
### 3. 演示模式
|
| 141 |
+
|
| 142 |
+
```bash
|
| 143 |
+
# 加载模型并演示 5 局
|
| 144 |
+
python main.py --demo --model checkpoints/model.pt --games 5
|
| 145 |
+
|
| 146 |
+
# 不加载模型(随机权重)
|
| 147 |
+
python main.py --demo --games 3
|
| 148 |
+
```
|
| 149 |
+
|
| 150 |
+
### 4. 简单训练脚本
|
| 151 |
+
|
| 152 |
+
```bash
|
| 153 |
+
python train_simple.py
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
修改脚本末尾可调整训练参数:
|
| 157 |
+
|
| 158 |
+
```python
|
| 159 |
+
train_simple(num_games=500, save_path="checkpoints/model.pt")
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
## 项目结构
|
| 163 |
+
|
| 164 |
+
```
|
| 165 |
+
game2048/
|
| 166 |
+
├── TASK.md # 任务需求文档
|
| 167 |
+
├── PLAN.md # 项目计划文档
|
| 168 |
+
├── README.md # 本文件
|
| 169 |
+
├── main.py # 程序入口
|
| 170 |
+
├── game.py # 2048 游戏核心逻辑
|
| 171 |
+
│ ├── Game2048 # 游戏类
|
| 172 |
+
│ ├── move() # 移动操作
|
| 173 |
+
│ ├── get_state() # 获取状态
|
| 174 |
+
│ └── calculate_situational_score() # 计算局面分数
|
| 175 |
+
│
|
| 176 |
+
├── model.py # Transformer 模型
|
| 177 |
+
│ ├── Game2048Transformer # Transformer 模型
|
| 178 |
+
│ ├── Game2048CNN # CNN 备选模型
|
| 179 |
+
│ └── get_action() # 动作选择
|
| 180 |
+
│
|
| 181 |
+
├── trainer.py # PPO 训练器
|
| 182 |
+
│ ├── PPOTrainer # PPO 训练类
|
| 183 |
+
│ ├── RolloutBuffer # 经验缓冲区
|
| 184 |
+
│ ├── Transition # 状态转移数据结构
|
| 185 |
+
│ └── TrainingStats # 训练统计
|
| 186 |
+
│
|
| 187 |
+
├── parallel.py # 并行训练环境
|
| 188 |
+
│ ├── ParallelGameEnv # 并行游戏环境
|
| 189 |
+
│ ├── TrainingWorker # 训练工作器
|
| 190 |
+
│ └── TrainingLoop # 训练循环
|
| 191 |
+
│
|
| 192 |
+
├── gui.py # GUI 界面
|
| 193 |
+
│ ├── MainWindow # 主窗口
|
| 194 |
+
│ ├── GameBoardWidget # 游戏面板
|
| 195 |
+
│ ├── ScoreWidget # 分数显示
|
| 196 |
+
│ ├── PlotCanvas # 曲线绘图
|
| 197 |
+
│ └── SimpleTrainingThread # 训练线程
|
| 198 |
+
│
|
| 199 |
+
├── train_simple.py # 简化训练脚本
|
| 200 |
+
├── utils.py # 工具函数
|
| 201 |
+
├── requirements.txt # 依赖列表
|
| 202 |
+
└── checkpoints/ # 模型保存目录
|
| 203 |
+
└── model.pt # 训练好的模型
|
| 204 |
+
```
|
| 205 |
+
|
| 206 |
+
## 模型架构详解
|
| 207 |
+
|
| 208 |
+
### 输入表示
|
| 209 |
+
|
| 210 |
+
```python
|
| 211 |
+
# 棋盘状态 (4, 4)
|
| 212 |
+
# 每个格子值转换为 log₂(value),空格为 0
|
| 213 |
+
state = [[0, 1, 2, 0], # 对应 [空, 2, 4, 空]
|
| 214 |
+
[1, 2, 3, 1], # 对应 [2, 4, 8, 2]
|
| 215 |
+
...]
|
| 216 |
+
|
| 217 |
+
# 分数特征 (2,)
|
| 218 |
+
# [归一化累积分数, 归一化局面分数]
|
| 219 |
+
scores = [0.05, 0.85]
|
| 220 |
+
```
|
| 221 |
+
|
| 222 |
+
### 网络结构
|
| 223 |
+
|
| 224 |
+
```
|
| 225 |
+
Input: (batch, 4, 4) board + (batch, 2) scores
|
| 226 |
+
↓
|
| 227 |
+
Position Embedding: (batch, 16, 64)
|
| 228 |
+
+ Spatial Embedding: (batch, 16, 64)
|
| 229 |
+
+ Score Embedding: (batch, 1, 64)
|
| 230 |
+
↓
|
| 231 |
+
Transformer Encoder (2 layers)
|
| 232 |
+
- Multi-Head Attention (4 heads)
|
| 233 |
+
- Feed-Forward Network (dim=128)
|
| 234 |
+
↓
|
| 235 |
+
Global Mean Pooling: (batch, 64)
|
| 236 |
+
↓
|
| 237 |
+
├── Policy Head → (batch, 4) # 动作概率
|
| 238 |
+
└── Value Head → (batch, 1) # 状态价值
|
| 239 |
+
```
|
| 240 |
+
|
| 241 |
+
## 训练策略
|
| 242 |
+
|
| 243 |
+
### 奖励设计
|
| 244 |
+
|
| 245 |
+
```python
|
| 246 |
+
reward = 局面分数变化 × 0.7 + 累积分数增量 × 0.003
|
| 247 |
+
|
| 248 |
+
# 游戏结束惩罚
|
| 249 |
+
if game_over:
|
| 250 |
+
reward -= 10.0
|
| 251 |
+
```
|
| 252 |
+
|
| 253 |
+
### 超参数
|
| 254 |
+
|
| 255 |
+
| 参数 | 值 |
|
| 256 |
+
|------|-----|
|
| 257 |
+
| Learning Rate | 3e-4 |
|
| 258 |
+
| Batch Size | 64 |
|
| 259 |
+
| PPO Clip Ratio | 0.2 |
|
| 260 |
+
| GAE Lambda | 0.95 |
|
| 261 |
+
| Discount Factor (γ) | 0.99 |
|
| 262 |
+
| Entropy Coefficient | 0.01 |
|
| 263 |
+
|
| 264 |
+
## 训练结果
|
| 265 |
+
|
| 266 |
+
### 500 局训练后
|
| 267 |
+
|
| 268 |
+
| 指标 | 数值 |
|
| 269 |
+
|------|------|
|
| 270 |
+
| 平均分数 | ~2500 |
|
| 271 |
+
| 最高分数 | 6812 |
|
| 272 |
+
| 最大砖块 | 512 |
|
| 273 |
+
| 训练时间 | ~9 分钟 |
|
| 274 |
+
|
| 275 |
+
### 分数分布
|
| 276 |
+
|
| 277 |
+
```
|
| 278 |
+
随机权重: 平均 ~800, 最高 ~2000
|
| 279 |
+
训练 500 局: 平均 ~2500, 最高 ~6800
|
| 280 |
+
```
|
| 281 |
+
|
| 282 |
+
## 开发说明
|
| 283 |
+
|
| 284 |
+
### 添加新功能
|
| 285 |
+
|
| 286 |
+
1. **修改局面评分**: 编辑 `game.py` 中的 `calculate_situational_score()`
|
| 287 |
+
2. **调整模型**: 修改 `model.py` 中的网络结构
|
| 288 |
+
3. **优化训练**: 调整 `trainer.py` 中的超参数
|
| 289 |
+
|
| 290 |
+
### 调试模式
|
| 291 |
+
|
| 292 |
+
```python
|
| 293 |
+
# 在 game.py 中测试游戏逻辑
|
| 294 |
+
if __name__ == "__main__":
|
| 295 |
+
game = Game2048()
|
| 296 |
+
print(game)
|
| 297 |
+
game.move(Game2048.LEFT)
|
| 298 |
+
print(game)
|
| 299 |
+
```
|
| 300 |
+
|
| 301 |
+
## 已知问题
|
| 302 |
+
|
| 303 |
+
- Windows 下 PyTorch 可能需要特定版本以避免 DLL 加载问题
|
| 304 |
+
- NumPy 2.x 与 PyTorch 存在兼容性问题,建议使用 NumPy < 2
|
| 305 |
+
|
| 306 |
+
## 参考资料
|
| 307 |
+
|
| 308 |
+
- [PPO 论文](https://arxiv.org/abs/1707.06347)
|
| 309 |
+
- [Transformer 论文](https://arxiv.org/abs/1706.03762)
|
| 310 |
+
- [2048 游戏](https://play2048.co/)
|
| 311 |
+
|
| 312 |
+
## 许可证
|
| 313 |
+
|
| 314 |
+
MIT License
|
| 315 |
+
|
| 316 |
+
---
|
| 317 |
+
|
| 318 |
+
*本项目由 GLM-5 开发实现*
|
TASK.md
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 需求
|
| 2 |
+
|
| 3 |
+
用Python制作一个基于Transformer的2048游戏AI
|
| 4 |
+
|
| 5 |
+
可以创建一个简单的窗口让我看到实时的训练状态,包括训练时AI的当前分数、训练速度、历史已经结束的游戏的*累积分数*变化、当前游戏的*局面分数*变化,窗口可切换演示模式和训练模式
|
| 6 |
+
|
| 7 |
+
2048的一部分规定补充:
|
| 8 |
+
|
| 9 |
+
- 开局会在随机位置生成一个2,后续在空格生成的砖块在随机位置且值为2或4
|
| 10 |
+
- 同时存在两个分数机制:累积分数和局面分数
|
| 11 |
+
- 累积分数机制:每合成一个增加合成后的砖块的数值
|
| 12 |
+
- 局面分数机制:根据*空格数量*、*相邻连续数字数量(相邻砖块拥有相邻数字的数量,取最大值,不计算重复,如512 1024 2048为3,256 512 512 2为2)*、*最高数字的对数*设计一个合理的函数,局面分数越高越好
|
| 13 |
+
|
| 14 |
+
## 第一步:训练模型
|
| 15 |
+
|
| 16 |
+
为模型提供当前的局面状态和两种分数(累积分数的权重低于其它,目的是局面分数保持更高),输出只有上下左右
|
| 17 |
+
|
| 18 |
+
游戏结束后再开始新的游戏继续训练,直到训练没有长进或我手动点击停止
|
| 19 |
+
|
| 20 |
+
性能允许的话可以多开2048游戏进行训练
|
| 21 |
+
|
| 22 |
+
## 第二步:演示模式
|
| 23 |
+
|
| 24 |
+
在窗口中可以手动操作或AI托管,AI托管可以单步执行也可以持续下一步,可以看到两种分数及局面分数变化
|
checkpoints/model_1500.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:11ff8939d4676c1ad76ca9c5b0ec2616bfc48ed10788c36ff2b0e50cd7fbcb35
|
| 3 |
+
size 415458
|
game.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2048游戏核心逻辑
|
| 3 |
+
"""
|
| 4 |
+
import numpy as np
|
| 5 |
+
from typing import Tuple, Optional
|
| 6 |
+
import random
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Game2048:
|
| 10 |
+
"""2048游戏核心类"""
|
| 11 |
+
|
| 12 |
+
# 动作定义
|
| 13 |
+
UP = 0
|
| 14 |
+
DOWN = 1
|
| 15 |
+
LEFT = 2
|
| 16 |
+
RIGHT = 3
|
| 17 |
+
|
| 18 |
+
def __init__(self):
    """Create a game and immediately start a fresh round.

    The attributes are declared here, with their types, so the instance is
    fully formed before ``reset()`` sets up the real starting position.
    """
    # Bookkeeping: number of successful moves and the terminal flag.
    self.moves_count: int = 0
    self.game_over: bool = False
    # The two scores that are tracked side by side.
    self.situational_score: float = 0.0
    self.accumulated_score: int = 0
    # 4x4 grid of raw tile values; 0 marks an empty cell.
    self.board: np.ndarray = np.zeros(shape=(4, 4), dtype=np.int64)
    self.reset()
|
| 25 |
+
|
| 26 |
+
def reset(self) -> np.ndarray:
    """Start a new game and return its initial state.

    Clears the move counter, the game-over flag, both scores and the board,
    then places the single opening tile.

    Returns:
        The initial state as produced by ``get_state()``.
    """
    self.moves_count = 0
    self.game_over = False
    self.accumulated_score = 0
    self.situational_score = 0.0
    self.board = np.zeros(shape=(4, 4), dtype=np.int64)

    # The opening tile is always a 2.
    self._spawn_tile(value=2)
    # Make the situational score reflect the opening position.
    self._update_situational_score()
    return self.get_state()
|
| 38 |
+
|
| 39 |
+
def _spawn_tile(self, value: Optional[int] = None) -> bool:
|
| 40 |
+
"""
|
| 41 |
+
在空格生成新砖块
|
| 42 |
+
开局时value=2,后续随机2或4
|
| 43 |
+
返回是否成功生成
|
| 44 |
+
"""
|
| 45 |
+
empty_cells = list(zip(*np.where(self.board == 0)))
|
| 46 |
+
if not empty_cells:
|
| 47 |
+
return False
|
| 48 |
+
|
| 49 |
+
row, col = random.choice(empty_cells)
|
| 50 |
+
if value is None:
|
| 51 |
+
# 90%概率生成2,10%概率生成4
|
| 52 |
+
value = 2 if random.random() < 0.9 else 4
|
| 53 |
+
self.board[row, col] = value
|
| 54 |
+
return True
|
| 55 |
+
|
| 56 |
+
def _compress(self, line: np.ndarray) -> Tuple[np.ndarray, int]:
|
| 57 |
+
"""
|
| 58 |
+
压缩一行/列,将非零元素移到一端
|
| 59 |
+
返回压缩后的行和合并得分
|
| 60 |
+
"""
|
| 61 |
+
# 移除零,填充到末尾
|
| 62 |
+
non_zero = line[line != 0]
|
| 63 |
+
new_line = np.zeros_like(line)
|
| 64 |
+
score = 0
|
| 65 |
+
|
| 66 |
+
pos = 0
|
| 67 |
+
i = 0
|
| 68 |
+
while i < len(non_zero):
|
| 69 |
+
if i + 1 < len(non_zero) and non_zero[i] == non_zero[i + 1]:
|
| 70 |
+
# 合并
|
| 71 |
+
new_line[pos] = non_zero[i] * 2
|
| 72 |
+
score += new_line[pos]
|
| 73 |
+
i += 2
|
| 74 |
+
else:
|
| 75 |
+
new_line[pos] = non_zero[i]
|
| 76 |
+
i += 1
|
| 77 |
+
pos += 1
|
| 78 |
+
|
| 79 |
+
return new_line, score
|
| 80 |
+
|
| 81 |
+
def _move_left(self) -> Tuple[bool, int]:
    """Slide and merge every row toward column 0.

    Returns:
        ``(moved, score)``: whether any row changed, and the total merge
        score gained across all rows.
    """
    changed = False
    gained = 0

    for r in range(4):
        before = self.board[r].copy()
        squeezed, points = self._compress(self.board[r])
        self.board[r] = squeezed
        gained += points
        changed = changed or not np.array_equal(before, squeezed)

    return changed, gained
|
| 95 |
+
|
| 96 |
+
def _move_right(self) -> Tuple[bool, int]:
    """Slide and merge every row toward the last column.

    Each row is compressed in reversed order and the result is flipped back.

    Returns:
        ``(moved, score)``: whether any row changed, and the total merge
        score gained across all rows.
    """
    changed = False
    gained = 0

    for r in range(4):
        before = self.board[r].copy()
        squeezed, points = self._compress(self.board[r, ::-1])
        self.board[r, :] = squeezed[::-1]
        gained += points
        changed = changed or not np.array_equal(before, self.board[r])

    return changed, gained
|
| 110 |
+
|
| 111 |
+
def _move_up(self) -> Tuple[bool, int]:
    """Slide and merge every column toward row 0.

    Returns:
        ``(moved, score)``: whether any column changed, and the total merge
        score gained across all columns.
    """
    changed = False
    gained = 0

    for c in range(4):
        before = self.board[:, c].copy()
        squeezed, points = self._compress(self.board[:, c])
        self.board[:, c] = squeezed
        gained += points
        changed = changed or not np.array_equal(before, squeezed)

    return changed, gained
|
| 125 |
+
|
| 126 |
+
def _move_down(self) -> Tuple[bool, int]:
    """Slide and merge every column toward the last row.

    Each column is compressed in reversed order and the result is flipped
    back.

    Returns:
        ``(moved, score)``: whether any column changed, and the total merge
        score gained across all columns.
    """
    changed = False
    gained = 0

    for c in range(4):
        before = self.board[:, c].copy()
        squeezed, points = self._compress(self.board[::-1, c])
        self.board[:, c] = squeezed[::-1]
        gained += points
        changed = changed or not np.array_equal(before, self.board[:, c])

    return changed, gained
|
| 140 |
+
|
| 141 |
+
def move(self, direction: int) -> Tuple[np.ndarray, float, bool, bool]:
    """Apply one move and report the outcome.

    Args:
        direction: 0 = up, 1 = down, 2 = left, 3 = right.

    Returns:
        A tuple ``(state, reward, moved, done)``:
            state: the state after the move, as returned by ``get_state()``.
            reward: 0.7 * change in situational score
                + 0.003 * gain in accumulated score,
                minus 10 when the game is over after this move.
            moved: whether the board actually changed.
            done: whether the game is over.

    Raises:
        ValueError: if ``direction`` is not one of the four action codes.
    """
    # A finished game ignores further input and yields no reward.
    if self.game_over:
        return self.get_state(), 0.0, False, True

    accumulated_before = self.accumulated_score
    situational_before = self.situational_score

    # Compare with == (not a dict lookup) so any value that equals an
    # action code is accepted.
    handlers = (
        (self.UP, self._move_up),
        (self.DOWN, self._move_down),
        (self.LEFT, self._move_left),
        (self.RIGHT, self._move_right),
    )
    for code, handler in handlers:
        if direction == code:
            moved, score = handler()
            break
    else:
        raise ValueError(f"Invalid direction: {direction}")

    if moved:
        self.accumulated_score += score
        self.moves_count += 1
        self._spawn_tile()
        self._update_situational_score()
        # Only a move that changed the board can end the game.
        self.game_over = self._check_game_over()

    # The accumulated-score gain is weighted 0.3 / 100 = 0.003, so the
    # situational score dominates the reward.
    reward = (
        (self.situational_score - situational_before) * 0.7
        + (self.accumulated_score - accumulated_before) * 0.003
    )
    if self.game_over:
        reward -= 10.0

    return self.get_state(), reward, moved, self.game_over
|
| 193 |
+
|
| 194 |
+
def _check_game_over(self) -> bool:
|
| 195 |
+
"""检查游戏是否结束"""
|
| 196 |
+
# 还有空格
|
| 197 |
+
if np.any(self.board == 0):
|
| 198 |
+
return False
|
| 199 |
+
|
| 200 |
+
# 检查是否还能合并
|
| 201 |
+
for i in range(4):
|
| 202 |
+
for j in range(4):
|
| 203 |
+
if i < 3 and self.board[i, j] == self.board[i + 1, j]:
|
| 204 |
+
return False
|
| 205 |
+
if j < 3 and self.board[i, j] == self.board[i, j + 1]:
|
| 206 |
+
return False
|
| 207 |
+
|
| 208 |
+
return True
|
| 209 |
+
|
| 210 |
+
def _update_situational_score(self) -> None:
    """Recompute ``self.situational_score`` from the current board.

    score = empty cells * 10
          + longest consecutive run * 15
          + log2(largest tile) * 5
          + monotonicity * 5
    """
    free_cells = np.sum(self.board == 0)
    longest_run = self._calculate_max_consecutive()

    largest = np.max(self.board)
    if largest > 0:
        largest_log = np.log2(largest)
    else:
        # An all-empty board has no largest tile; skip log2(0).
        largest_log = 0

    order_bonus = self._calculate_monotonicity()

    self.situational_score = (
        free_cells * 10
        + longest_run * 15
        + largest_log * 5
        + order_bonus * 5
    )
|
| 234 |
+
|
| 235 |
+
def _calculate_max_consecutive(self) -> int:
    """Return the longest run of neighbouring powers of two on the board.

    Every row and every column is scored with
    ``_count_consecutive_in_line`` and the largest result is returned
    (e.g. 512, 1024, 2048 in one line gives 3). The result is 0 when no
    line scores above 0.
    """
    lines = [self.board[r] for r in range(4)]
    lines += [self.board[:, c] for c in range(4)]
    return max(0, *(self._count_consecutive_in_line(line) for line in lines))
|
| 253 |
+
|
| 254 |
+
def _count_consecutive_in_line(self, line: np.ndarray) -> int:
|
| 255 |
+
"""计算一行/列中的最大连续相邻数字"""
|
| 256 |
+
non_zero = line[line != 0]
|
| 257 |
+
if len(non_zero) < 2:
|
| 258 |
+
return 0
|
| 259 |
+
|
| 260 |
+
max_count = 1
|
| 261 |
+
current_count = 1
|
| 262 |
+
|
| 263 |
+
for i in range(1, len(non_zero)):
|
| 264 |
+
# 相邻数字:2的幂次相邻
|
| 265 |
+
if abs(np.log2(non_zero[i]) - np.log2(non_zero[i-1])) == 1:
|
| 266 |
+
current_count += 1
|
| 267 |
+
max_count = max(max_count, current_count)
|
| 268 |
+
else:
|
| 269 |
+
current_count = 1
|
| 270 |
+
|
| 271 |
+
return max_count
|
| 272 |
+
|
| 273 |
+
def _calculate_monotonicity(self) -> float:
|
| 274 |
+
"""
|
| 275 |
+
计算单调性
|
| 276 |
+
鼓励数字在行/列上递增或递减
|
| 277 |
+
"""
|
| 278 |
+
score = 0.0
|
| 279 |
+
|
| 280 |
+
# 行单调性
|
| 281 |
+
for i in range(4):
|
| 282 |
+
row = self.board[i]
|
| 283 |
+
row = row[row != 0]
|
| 284 |
+
if len(row) >= 2:
|
| 285 |
+
# 检查递增
|
| 286 |
+
if all(row[i] <= row[i+1] for i in range(len(row)-1)):
|
| 287 |
+
score += len(row) - 1
|
| 288 |
+
# 检查递减
|
| 289 |
+
elif all(row[i] >= row[i+1] for i in range(len(row)-1)):
|
| 290 |
+
score += len(row) - 1
|
| 291 |
+
|
| 292 |
+
# 列单调性
|
| 293 |
+
for j in range(4):
|
| 294 |
+
col = self.board[:, j]
|
| 295 |
+
col = col[col != 0]
|
| 296 |
+
if len(col) >= 2:
|
| 297 |
+
if all(col[i] <= col[i+1] for i in range(len(col)-1)):
|
| 298 |
+
score += len(col) - 1
|
| 299 |
+
elif all(col[i] >= col[i+1] for i in range(len(col)-1)):
|
| 300 |
+
score += len(col) - 1
|
| 301 |
+
|
| 302 |
+
return score
|
| 303 |
+
|
| 304 |
+
def get_state(self) -> np.ndarray:
    """Return the board encoded as exponents.

    Returns:
        A (4, 4) float32 array holding log2(tile) for occupied cells and
        0 for empty cells.
    """
    encoded = np.zeros((4, 4), dtype=np.float32)
    # Only occupied cells go through log2, so log2(0) is never evaluated.
    rows, cols = np.nonzero(self.board > 0)
    encoded[rows, cols] = np.log2(self.board[rows, cols])
    return encoded
|
| 313 |
+
|
| 314 |
+
def get_state_with_scores(self) -> np.ndarray:
    """Return the flattened board plus the two normalised score features.

    Returns:
        A float32 vector: the flattened get_state() output divided by 15,
        followed by the accumulated and situational scores, each divided
        by its assumed maximum and capped at 1.0.
    """
    # Exponents are divided by 15, so a 2048 tile (exponent 11) maps to 11/15.
    cells = self.get_state().flatten() / 15.0

    accumulated_cap = 50000  # assumed maximum accumulated score
    situational_cap = 200    # assumed maximum situational score
    score_features = [
        min(self.accumulated_score / accumulated_cap, 1.0),
        min(self.situational_score / situational_cap, 1.0),
    ]

    return np.concatenate([cells, score_features]).astype(np.float32)
|
| 332 |
+
|
| 333 |
+
def get_valid_actions(self) -> np.ndarray:
    """Return a boolean mask of the four directions that would change the board.

    Each direction is tried on the live board; the board is put back after
    every attempt and the accumulated score is put back once at the end.
    """
    saved_board = self.board.copy()
    saved_score = self.accumulated_score

    # Direction code -> mover; any code not listed falls through to "right".
    movers = (
        (self.UP, self._move_up),
        (self.DOWN, self._move_down),
        (self.LEFT, self._move_left),
    )

    mask = np.zeros(4, dtype=bool)
    for direction in range(4):
        mover = next(
            (fn for code, fn in movers if direction == code),
            self._move_right,
        )
        moved, _ = mover()
        mask[direction] = moved
        self.board = saved_board.copy()

    self.accumulated_score = saved_score
    return mask
|
| 356 |
+
|
| 357 |
+
def get_max_tile(self) -> int:
    """Return the largest tile value on the board as a plain int."""
    largest = self.board.max()
    return int(largest)
|
| 360 |
+
|
| 361 |
+
def get_empty_cells_count(self) -> int:
    """Return how many cells of the board are empty (equal to 0)."""
    empty_mask = self.board == 0
    return int(np.count_nonzero(empty_mask))
|
| 364 |
+
|
| 365 |
+
def __str__(self) -> str:
    """Render the board as text, one board row per line.

    Tiles are right-aligned in four-character columns joined by " | ".
    Empty cells are drawn as a dot padded to the same width so that the
    columns stay aligned whatever the board contains.
    """
    rendered_rows = []
    for row in self.board:
        cells = [f"{int(x):4d}" if x > 0 else f"{'.':>4}" for x in row]
        rendered_rows.append(" | ".join(cells))
    return "\n".join(rendered_rows)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
if __name__ == "__main__":
    # Smoke test: print a fresh game, then cycle through the four
    # direction codes for ten moves and print the outcome of each.
    game = Game2048()
    print("Initial state:")
    print(game)
    print(f"Accumulated score: {game.accumulated_score}")
    print(f"Situational score: {game.situational_score}")

    # Label each direction code from the class constants, using the same
    # dispatch as get_valid_actions (UP / DOWN / LEFT, anything else is
    # RIGHT). A positional list would print the wrong name whenever its
    # order differs from the codes.
    direction_names = {game.UP: 'UP', game.DOWN: 'DOWN', game.LEFT: 'LEFT'}

    for step in range(10):
        direction = step % 4
        _state, reward, moved, done = game.move(direction)
        name = direction_names.get(direction, 'RIGHT')
        print(f"\nMove {name}: moved={moved}, done={done}")
        print(game)
        print(f"Reward: {reward:.2f}")
|
gui.py
ADDED
|
@@ -0,0 +1,835 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
GUI界面 - PyQt5实现
|
| 3 |
+
"""
|
| 4 |
+
import sys
|
| 5 |
+
import os
|
| 6 |
+
import numpy as np
|
| 7 |
+
from PyQt5.QtWidgets import (
|
| 8 |
+
QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout,
|
| 9 |
+
QGridLayout, QLabel, QPushButton, QComboBox, QGroupBox,
|
| 10 |
+
QFrame, QSplitter, QStatusBar, QProgressBar, QCheckBox,
|
| 11 |
+
QFileDialog, QMessageBox, QSpinBox
|
| 12 |
+
)
|
| 13 |
+
from PyQt5.QtCore import Qt, QTimer, QThread, pyqtSignal
|
| 14 |
+
from PyQt5.QtGui import QFont, QColor, QPalette
|
| 15 |
+
import matplotlib
|
| 16 |
+
matplotlib.use('Qt5Agg')
|
| 17 |
+
from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
|
| 18 |
+
from matplotlib.figure import Figure
|
| 19 |
+
from collections import deque
|
| 20 |
+
import time
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from game import Game2048
|
| 24 |
+
from model import Game2048Transformer
|
| 25 |
+
from trainer import PPOTrainer, RolloutBuffer, Transition
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# Tile colour configuration.
# Background colour (hex string) looked up by tile value; 0 is an empty cell.
TILE_COLORS = {
    0: '#cdc1b4',
    2: '#eee4da',
    4: '#ede0c8',
    8: '#f2b179',
    16: '#f59563',
    32: '#f67c5f',
    64: '#f65e3b',
    128: '#edcf72',
    256: '#edcc61',
    512: '#edc850',
    1024: '#edc53f',
    2048: '#edc22e',
    4096: '#3c3a32',
    8192: '#3c3a32',
}

# Text colour (hex string) for the tile values listed here; values that are
# not listed get a fallback chosen at the lookup site.
TILE_TEXT_COLORS = {
    2: '#776e65',
    4: '#776e65',
}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class GameBoardWidget(QWidget):
    """Fixed-size widget that paints a 4x4 2048 board."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.board = np.zeros((4, 4), dtype=np.int64)
        self.cell_size = 80
        self.padding = 5
        self.init_ui()

    def init_ui(self):
        """Size the widget to hold four cells plus five gaps per axis."""
        side = self.cell_size * 4 + self.padding * 5
        self.setFixedSize(side, side)

    def set_board(self, board: np.ndarray):
        """Store a copy of *board* and schedule a repaint."""
        self.board = board.copy()
        self.update()

    def paintEvent(self, event):
        """Paint the background, then every cell and its number."""
        from PyQt5.QtGui import QPainter, QColor, QPen, QBrush, QFont

        painter = QPainter(self)
        painter.setRenderHint(QPainter.Antialiasing)

        # Board background.
        painter.fillRect(self.rect(), QColor('#bbada0'))

        size = self.cell_size
        pitch = size + self.padding  # distance between neighbouring cell origins

        for row in range(4):
            top = self.padding + row * pitch
            for col in range(4):
                left = self.padding + col * pitch
                value = self.board[row, col]

                # Cell background; values without an entry use the dark tile.
                background = TILE_COLORS.get(value, '#3c3a32')
                painter.fillRect(left, top, size, size, QColor(background))

                if not value > 0:
                    continue

                foreground = TILE_TEXT_COLORS.get(value, '#f9f6f2')
                painter.setPen(QColor(foreground))

                # Shrink the font as the number gains digits.
                if value < 100:
                    point_size = 32
                elif value < 1000:
                    point_size = 28
                else:
                    point_size = 22
                painter.setFont(QFont('Arial', point_size, QFont.Bold))

                # Centre the number inside the cell.
                painter.drawText(left, top, size, size,
                                 Qt.AlignCenter, str(int(value)))
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class ScoreWidget(QWidget):
    """Small titled box that displays one numeric value."""

    def __init__(self, title: str, parent=None):
        super().__init__(parent)
        self.title = title
        self.value = 0
        self.init_ui()

    def init_ui(self):
        """Create the title label and the value label, stacked vertically."""
        title_style = """
            QLabel {
                background-color: #bbada0;
                color: #eee4da;
                font-size: 12px;
                font-weight: bold;
                border-radius: 3px;
                padding: 5px;
            }
        """
        value_style = """
            QLabel {
                background-color: #8f7a66;
                color: white;
                font-size: 20px;
                font-weight: bold;
                border-radius: 3px;
                padding: 10px;
                min-width: 80px;
            }
        """

        box = QVBoxLayout(self)
        box.setContentsMargins(5, 5, 5, 5)

        # Caption on top.
        self.title_label = QLabel(self.title)
        self.title_label.setAlignment(Qt.AlignCenter)
        self.title_label.setStyleSheet(title_style)
        box.addWidget(self.title_label)

        # Current value underneath.
        self.value_label = QLabel('0')
        self.value_label.setAlignment(Qt.AlignCenter)
        self.value_label.setStyleSheet(value_style)
        box.addWidget(self.value_label)

    def set_value(self, value):
        """Remember *value* and show it: one decimal for floats, else as an int."""
        self.value = value
        text = f'{value:.1f}' if isinstance(value, float) else str(int(value))
        self.value_label.setText(text)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class PlotCanvas(FigureCanvas):
    """Matplotlib canvas holding a single axes for the score charts."""

    def __init__(self, parent=None, width=5, height=4, dpi=100):
        """Create the figure and axes, then attach the canvas to *parent*.

        Args:
            parent: Qt parent widget, or None.
            width: Figure width passed to matplotlib's ``figsize``.
            height: Figure height passed to matplotlib's ``figsize``.
            dpi: Figure resolution.
        """
        self.fig = Figure(figsize=(width, height), dpi=dpi)
        self.axes = self.fig.add_subplot(111)
        super().__init__(self.fig)
        self.setParent(parent)

        self.fig.patch.set_facecolor('#faf8ef')
        self.axes.set_facecolor('#faf8ef')

        # Colour cycle for plotted lines.
        self.axes.set_prop_cycle(color=['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728'])

        self.fig.tight_layout()

    def plot_training_scores(self, scores, title="Training Scores"):
        """Redraw the training chart: one score per game plus a moving average.

        Args:
            scores: Sequence of per-game scores; an empty sequence only
                redraws the empty, labelled axes.
            title: Axes title.
        """
        self.axes.clear()

        if scores:
            x = range(1, len(scores) + 1)
            self.axes.plot(x, scores, '#1f77b4', linewidth=1.5, alpha=0.8)

            # Moving average, drawn once enough games have been played.
            if len(scores) >= 10:
                window = min(50, len(scores) // 5)
                if window >= 5:
                    ma = np.convolve(scores, np.ones(window) / window, mode='valid')
                    # mode='valid' yields len(scores) - window + 1 points;
                    # each is plotted at the last game of its window.
                    ma_x = range(window, len(scores) + 1)
                    self.axes.plot(ma_x, ma, '#d62728', linewidth=2, alpha=0.8,
                                   label=f'MA({window})')
                    # The label is only visible once a legend is drawn.
                    self.axes.legend(loc='upper left')

        self.axes.set_xlabel('Games', fontsize=10)
        self.axes.set_ylabel('Score', fontsize=10)
        self.axes.set_title(title, fontsize=12)
        self.axes.grid(True, alpha=0.3)

        self.fig.tight_layout()
        self.draw()

    def plot_demo_scores(self, scores, title="Situational Score"):
        """Redraw the demo chart: the situational score after each step.

        Args:
            scores: Sequence of per-step situational scores; may be empty.
            title: Axes title.
        """
        self.axes.clear()

        if scores:
            x = range(1, len(scores) + 1)
            self.axes.plot(x, scores, '#2ca02c', linewidth=1.5, alpha=0.8)

        self.axes.set_xlabel('Steps', fontsize=10)
        self.axes.set_ylabel('Situational Score', fontsize=10)
        self.axes.set_title(title, fontsize=12)
        self.axes.grid(True, alpha=0.3)

        self.fig.tight_layout()
        self.draw()
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class SimpleTrainingThread(QThread):
    """Worker thread that plays games and feeds transitions to the trainer.

    Signals:
        game_end_signal: emitted after every game with a dict of statistics.
        progress_signal: declared here but not emitted inside this class.
        finished_signal: emitted once when run() leaves its game loop.
    """

    game_end_signal = pyqtSignal(dict)
    progress_signal = pyqtSignal(dict)
    finished_signal = pyqtSignal()

    def __init__(self, model, trainer, num_games=1000):
        """Store the model, the trainer and the number of games to play."""
        super().__init__()
        self.model = model
        self.trainer = trainer
        self.num_games = num_games
        # Cleared by stop(); checked once per game and once per move.
        self.running = True
        self.device = "cpu"

    def run(self):
        """Play up to num_games games, updating the trainer as data accumulates."""
        scores = []
        max_tiles = []
        best_score = 0
        start_time = time.time()

        for game_idx in range(self.num_games):
            if not self.running:
                break

            game = Game2048()
            game.reset()
            # NOTE(review): a new buffer is created for every game, so any
            # transitions left in it when the game ends (fewer than 64) are
            # never passed to trainer.update() - confirm this is intended.
            buffer = RolloutBuffer(capacity=10000)

            while not game.game_over and self.running:
                state = game.get_state()
                # Same normalisation constants (50000 / 200) as the demo-mode
                # inference code in this file.
                scores_feat = np.array([
                    min(game.accumulated_score / 50000, 1.0),
                    min(game.situational_score / 200, 1.0)
                ], dtype=np.float32)
                valid = game.get_valid_actions()

                # unsqueeze(0) adds a leading dimension of size one.
                state_t = torch.FloatTensor(state).unsqueeze(0).to(self.device)
                scores_t = torch.FloatTensor(scores_feat).unsqueeze(0).to(self.device)
                valid_t = torch.BoolTensor(valid).unsqueeze(0).to(self.device)

                # Action selection needs no gradients.
                with torch.no_grad():
                    action, log_prob, value = self.model.get_action(
                        state_t, scores_t, valid_t, deterministic=False
                    )

                old_state = state.copy()
                old_scores = scores_feat.copy()
                next_state, reward, moved, done = game.move(action)

                transition = Transition(
                    state=old_state,
                    scores=old_scores,
                    action=action,
                    reward=reward,
                    next_state=next_state.copy(),
                    next_scores=np.array([
                        min(game.accumulated_score / 50000, 1.0),
                        min(game.situational_score / 200, 1.0)
                    ], dtype=np.float32),
                    done=done,
                    log_prob=log_prob,
                    value=value,
                    valid_actions=valid
                )
                buffer.push(transition)

                # Update the trainer every 64 collected transitions.
                if len(buffer) >= 64:
                    self.trainer.update(buffer)
                    buffer.clear()

            scores.append(game.accumulated_score)
            max_tiles.append(game.get_max_tile())

            if game.accumulated_score > best_score:
                best_score = game.accumulated_score

            # Emit the end-of-game statistics; avg_score covers at most the
            # last 100 games.
            elapsed = time.time() - start_time
            self.game_end_signal.emit({
                'score': game.accumulated_score,
                'max_tile': game.get_max_tile(),
                'game_idx': game_idx + 1,
                'best_score': best_score,
                'elapsed': elapsed,
                'avg_score': np.mean(scores[-100:]) if scores else 0
            })

        self.finished_signal.emit()

    def stop(self):
        """Ask run() to stop; it checks the flag before each game and each move."""
        self.running = False
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
class MainWindow(QMainWindow):
|
| 327 |
+
"""主窗口"""
|
| 328 |
+
|
| 329 |
+
def __init__(self):
    """Create the model, trainer, demo game and state, then build the UI."""
    super().__init__()
    self.setWindowTitle('2048 AI Trainer')
    self.setMinimumSize(1000, 700)

    # Model and trainer (CPU only).
    self.device = 'cpu'
    model = Game2048Transformer()
    self.model = model
    self.trainer = PPOTrainer(model, lr=3e-4, device=self.device)

    # Game instance used by demo mode.
    self.game = Game2048()

    # Training state.
    self.is_training = False
    self.training_thread = None

    # Data behind the two charts.
    self.training_scores = []
    self.demo_situational_scores = []

    # Demo-mode state; the timer drives repeated AI steps.
    self.ai_mode = False
    self.auto_step = False
    timer = QTimer()
    timer.timeout.connect(self.ai_step)
    self.auto_timer = timer

    # Path of the currently loaded model, if any.
    self.current_model_path = None

    self.init_ui()
    self.update_display()
|
| 361 |
+
|
| 362 |
+
def init_ui(self):
    """Apply the window stylesheet and assemble the two panels and status bar."""
    stylesheet = """
        QMainWindow {
            background-color: #faf8ef;
        }
        QLabel {
            color: #776e65;
        }
        QPushButton {
            background-color: #8f7a66;
            color: white;
            border: none;
            border-radius: 3px;
            padding: 10px 20px;
            font-size: 14px;
            font-weight: bold;
            min-width: 80px;
        }
        QPushButton:hover {
            background-color: #9f8b77;
        }
        QPushButton:pressed {
            background-color: #7f6a57;
        }
        QPushButton:disabled {
            background-color: #ccc;
            color: #999;
        }
        QComboBox {
            background-color: #8f7a66;
            color: white;
            border: none;
            border-radius: 3px;
            padding: 5px 10px;
            min-width: 100px;
        }
        QComboBox::drop-down {
            border: none;
        }
        QGroupBox {
            border: 2px solid #bbada0;
            border-radius: 5px;
            margin-top: 10px;
            padding-top: 10px;
            font-weight: bold;
            color: #776e65;
        }
        QGroupBox::title {
            subcontrol-origin: margin;
            left: 10px;
            padding: 0 5px;
        }
        QSpinBox {
            background-color: white;
            border: 1px solid #bbada0;
            border-radius: 3px;
            padding: 5px;
            min-width: 80px;
        }
    """
    self.setStyleSheet(stylesheet)

    # Central widget with a horizontal split.
    central = QWidget()
    self.setCentralWidget(central)
    columns = QHBoxLayout(central)

    # Left: game area. Right: statistics and controls.
    columns.addWidget(self.create_left_panel())
    columns.addWidget(self.create_right_panel())

    # The right column gets twice the stretch of the left one.
    for index, factor in enumerate((1, 2)):
        columns.setStretch(index, factor)

    # Status bar.
    self.statusBar = QStatusBar()
    self.setStatusBar(self.statusBar)
    self.statusBar.showMessage('Ready')
|
| 448 |
+
|
| 449 |
+
def create_left_panel(self) -> QWidget:
    """Build the left panel: mode selector, score boxes, board, demo buttons."""
    panel = QWidget()
    column = QVBoxLayout(panel)
    column.setAlignment(Qt.AlignCenter)

    # Mode selector.
    caption = QLabel('Mode:')
    caption.setFont(QFont('Arial', 12, QFont.Bold))

    self.mode_combo = QComboBox()
    self.mode_combo.addItems(['Training Mode', 'Demo Mode'])
    self.mode_combo.currentIndexChanged.connect(self.switch_mode)

    mode_row = QHBoxLayout()
    mode_row.addWidget(caption)
    mode_row.addWidget(self.mode_combo)
    mode_row.addStretch()
    column.addLayout(mode_row)

    # Score boxes.
    self.accumulated_score_widget = ScoreWidget('Score')
    self.situational_score_widget = ScoreWidget('Situational')
    self.max_tile_widget = ScoreWidget('Max Tile')

    score_row = QHBoxLayout()
    for box in (self.accumulated_score_widget,
                self.situational_score_widget,
                self.max_tile_widget):
        score_row.addWidget(box)
    column.addLayout(score_row)

    # Game board.
    self.game_board = GameBoardWidget()
    column.addWidget(self.game_board, alignment=Qt.AlignCenter)

    # Demo-mode buttons, hidden until demo mode is selected.
    self.demo_controls = QWidget()
    button_row = QHBoxLayout(self.demo_controls)

    self.ai_btn = QPushButton('AI Mode')
    self.ai_btn.clicked.connect(self.toggle_ai_mode)

    self.step_btn = QPushButton('Step')
    self.step_btn.clicked.connect(self.ai_step)

    self.auto_btn = QPushButton('Auto')
    self.auto_btn.clicked.connect(self.toggle_auto)

    self.reset_btn = QPushButton('Reset')
    self.reset_btn.clicked.connect(self.reset_game)

    for button in (self.ai_btn, self.step_btn, self.auto_btn, self.reset_btn):
        button_row.addWidget(button)

    self.demo_controls.setVisible(False)
    column.addWidget(self.demo_controls)

    column.addStretch()
    return panel
|
| 510 |
+
|
| 511 |
+
def create_right_panel(self) -> QWidget:
    """Build the right panel: training controls, score chart and stats text."""
    panel = QWidget()
    column = QVBoxLayout(panel)

    # --- Training controls -------------------------------------------
    control_group = QGroupBox('Training Control')
    control_box = QVBoxLayout(control_group)

    # Number of games to train for.
    self.games_spinbox = QSpinBox()
    self.games_spinbox.setRange(10, 100000)
    self.games_spinbox.setValue(500)

    games_row = QHBoxLayout()
    games_row.addWidget(QLabel('Games:'))
    games_row.addWidget(self.games_spinbox)
    games_row.addStretch()
    control_box.addLayout(games_row)

    # Start / stop buttons.
    self.start_btn = QPushButton('Start Training')
    self.start_btn.clicked.connect(self.start_training)

    self.stop_btn = QPushButton('Stop Training')
    self.stop_btn.clicked.connect(self.stop_training)
    self.stop_btn.setEnabled(False)

    run_row = QHBoxLayout()
    run_row.addWidget(self.start_btn)
    run_row.addWidget(self.stop_btn)
    control_box.addLayout(run_row)

    # Model file load / save buttons.
    self.load_btn = QPushButton('Load Model')
    self.load_btn.clicked.connect(self.load_model)

    self.save_btn = QPushButton('Save Model')
    self.save_btn.clicked.connect(self.save_model)

    file_row = QHBoxLayout()
    file_row.addWidget(self.load_btn)
    file_row.addWidget(self.save_btn)
    control_box.addLayout(file_row)

    # Live training figures in a 2x2 grid.
    self.games_label = QLabel('Games: 0')
    self.avg_score_label = QLabel('Avg Score: 0')
    self.best_score_label = QLabel('Best Score: 0')
    self.speed_label = QLabel('Speed: 0 games/s')

    figures = QGridLayout()
    figures.addWidget(self.games_label, 0, 0)
    figures.addWidget(self.avg_score_label, 0, 1)
    figures.addWidget(self.best_score_label, 1, 0)
    figures.addWidget(self.speed_label, 1, 1)
    control_box.addLayout(figures)

    column.addWidget(control_group)

    # --- Score chart -------------------------------------------------
    plot_group = QGroupBox('Score Chart')
    plot_box = QVBoxLayout(plot_group)

    self.plot_canvas = PlotCanvas(self, width=6, height=4, dpi=100)
    plot_box.addWidget(self.plot_canvas)

    column.addWidget(plot_group)

    # --- Stats text --------------------------------------------------
    stats_group = QGroupBox('Training Stats')
    stats_box = QVBoxLayout(stats_group)

    self.stats_text = QLabel('Waiting for training...')
    self.stats_text.setStyleSheet('font-family: monospace;')
    stats_box.addWidget(self.stats_text)

    column.addWidget(stats_group)

    return panel
|
| 592 |
+
|
| 593 |
+
def switch_mode(self, index):
    """Switch between training mode (index 0) and demo mode (any other index)."""
    training_mode = index == 0
    self.demo_controls.setVisible(not training_mode)

    if training_mode:
        # Drop the demo history and show the training curve again.
        self.demo_situational_scores = []
        self.plot_canvas.plot_training_scores(self.training_scores)
        return

    # Demo mode: start a fresh game, then show an empty demo chart.
    self.reset_game()
    self.demo_situational_scores = []
    self.plot_canvas.plot_demo_scores([])
|
| 604 |
+
|
| 605 |
+
def keyPressEvent(self, event):
    """Play a manual move when an arrow key is pressed in demo mode.

    Arrow keys are acted on only while demo mode is selected and AI mode
    is off. Every other key press is handed to the base class, as Qt
    requires of handlers that do not act on an event.

    Args:
        event: The key event delivered by Qt.
    """
    key_map = {
        Qt.Key_Up: 0,
        Qt.Key_Down: 1,
        Qt.Key_Left: 2,
        Qt.Key_Right: 3,
    }

    # Demo mode and not under AI control.
    manual_demo = self.mode_combo.currentIndex() == 1 and not self.ai_mode
    direction = key_map.get(event.key()) if manual_demo else None

    if direction is None:
        # Not ours: let QMainWindow handle it.
        super().keyPressEvent(event)
        return

    self.game.move(direction)
    self.demo_situational_scores.append(self.game.situational_score)
    self.update_display()
    self.plot_canvas.plot_demo_scores(self.demo_situational_scores)

    if self.game.game_over:
        self.statusBar.showMessage('Game Over!')
|
| 625 |
+
|
| 626 |
+
def toggle_ai_mode(self):
    """Flip AI mode and bring the demo buttons in line with the new state."""
    self.ai_mode = not self.ai_mode
    ai_active = self.ai_mode

    # The button always names the mode a click would switch to.
    self.ai_btn.setText('Manual' if ai_active else 'AI Mode')
    self.step_btn.setEnabled(not ai_active)
    self.auto_btn.setEnabled(ai_active)

    if not ai_active:
        # Leaving AI mode also cancels any running auto-play.
        self.auto_timer.stop()
        self.auto_btn.setText('Auto')
|
| 639 |
+
|
| 640 |
+
def ai_step(self):
    """Let the model play one move in the demo game.

    If the game is already over it is reset instead and no move is made.
    Otherwise the model picks an action deterministically, the move is
    applied, and the board, scores and demo chart are refreshed. When the
    move ends the game, the final score is shown and auto-play is stopped.
    """
    if self.game.game_over:
        self.reset_game()
        return

    state = self.game.get_state()
    # Same score normalisation (50000 / 200) as the training thread uses.
    scores = np.array([
        min(self.game.accumulated_score / 50000, 1.0),
        min(self.game.situational_score / 200, 1.0)
    ], dtype=np.float32)
    valid_actions = self.game.get_valid_actions()

    state_t = torch.FloatTensor(state).unsqueeze(0)
    scores_t = torch.FloatTensor(scores).unsqueeze(0)
    valid_t = torch.BoolTensor(valid_actions).unsqueeze(0)

    # Pure inference: disable autograd, as the training thread does around
    # get_action, so no graph is built for each demo move.
    with torch.no_grad():
        action, _, _ = self.model.get_action(
            state_t, scores_t, valid_t, deterministic=True
        )

    self.game.move(action)
    self.demo_situational_scores.append(self.game.situational_score)
    self.update_display()
    self.plot_canvas.plot_demo_scores(self.demo_situational_scores)

    if self.game.game_over:
        self.statusBar.showMessage(f'Game Over! Final Score: {self.game.accumulated_score}')
        if self.auto_timer.isActive():
            self.auto_timer.stop()
            self.auto_btn.setText('Auto')
|
| 669 |
+
|
| 670 |
+
def toggle_auto(self):
    """Start auto-play if it is idle, otherwise stop it."""
    timer = self.auto_timer
    if not timer.isActive():
        timer.start(100)  # one AI step every 100 ms
        self.auto_btn.setText('Stop')
        return

    timer.stop()
    self.auto_btn.setText('Auto')
|
| 678 |
+
|
| 679 |
+
def reset_game(self):
    """Restart the demo game and restart the demo chart from its first score."""
    game = self.game
    game.reset()

    # The history begins with the score of the freshly reset board.
    self.demo_situational_scores = [game.situational_score]

    self.update_display()
    self.plot_canvas.plot_demo_scores(self.demo_situational_scores)
    self.statusBar.showMessage('Game Reset')
|
| 686 |
+
|
| 687 |
+
def update_display(self):
    """Push the demo game's board, both scores and max tile into the widgets."""
    game = self.game
    self.game_board.set_board(game.board)

    accumulated = game.accumulated_score
    self.accumulated_score_widget.set_value(accumulated)

    situational = game.situational_score
    self.situational_score_widget.set_value(situational)

    largest = game.get_max_tile()
    self.max_tile_widget.set_value(largest)
|
| 693 |
+
|
| 694 |
+
def start_training(self):
    """Start a background training run.

    Marks the window as training, disables every control except the stop
    button, clears the recorded training scores, then creates and starts
    a SimpleTrainingThread for the number of games chosen in the spin
    box. The thread's per-game and finished signals are wired to
    on_game_end and on_training_finished.
    """
    self.is_training = True

    # Only "stop" stays usable while a run is in progress.
    self.start_btn.setEnabled(False)
    self.stop_btn.setEnabled(True)
    for control in (
        self.mode_combo,
        self.games_spinbox,
        self.load_btn,
        self.save_btn,
    ):
        control.setEnabled(False)

    # Statistics start from scratch for every run.
    self.training_scores = []

    # Build the worker thread and hook up its signals before starting it.
    requested_games = self.games_spinbox.value()
    worker = self.training_thread = SimpleTrainingThread(
        self.model, self.trainer, requested_games
    )
    worker.game_end_signal.connect(self.on_game_end)
    worker.finished_signal.connect(self.on_training_finished)
    worker.start()

    self.statusBar.showMessage('Training started...')
|
| 717 |
+
|
| 718 |
+
def stop_training(self):
    """Stop the running training thread (if any) and finalise the run.

    The thread is asked to stop, waited for, and dropped; then
    on_training_finished is called unconditionally.
    """
    worker = self.training_thread
    if worker:
        worker.stop()
        worker.wait()  # block until the thread has actually exited
        self.training_thread = None

    # NOTE(review): start_training also connects the thread's
    # finished_signal to on_training_finished, so this explicit call may
    # run the handler (and its auto-save) a second time - confirm against
    # SimpleTrainingThread.
    self.on_training_finished()
|
| 726 |
+
|
| 727 |
+
def on_game_end(self, stats):
    """Handle the end of one training game.

    Args:
        stats: Mapping with the keys 'score', 'game_idx', 'avg_score',
            'best_score' and 'elapsed' ('elapsed' is presumably in
            seconds, since it feeds a games/s figure - confirm against
            the training thread).

    Records the game's score, refreshes the statistic labels and the
    summary text, and redraws the training curve on every fifth game.
    """
    last_score = stats['score']
    self.training_scores.append(last_score)

    games = stats['game_idx']
    avg_score = stats['avg_score']
    best_score = stats['best_score']
    elapsed = stats['elapsed']

    # Avoid dividing by zero before any time has been measured.
    if elapsed > 0:
        speed = games / elapsed
    else:
        speed = 0

    # Statistic labels.
    self.games_label.setText(f'Games: {games}')
    self.avg_score_label.setText(f'Avg Score: {avg_score:.0f}')
    self.best_score_label.setText(f'Best Score: {best_score}')
    self.speed_label.setText(f'Speed: {speed:.2f} games/s')

    # Redrawing the curve is throttled to every fifth game.
    if games % 5 == 0:
        self.plot_canvas.plot_training_scores(self.training_scores)

    # Multi-line summary text.
    summary_lines = (
        f"Games: {games}",
        f"Avg Score: {avg_score:.0f}",
        f"Best Score: {best_score}",
        f"Last Score: {last_score}",
    )
    self.stats_text.setText("\n".join(summary_lines))
|
| 754 |
+
|
| 755 |
+
def on_training_finished(self):
    """Finalise a training run: unlock the controls and auto-save.

    Re-enables the controls that start_training disabled, then writes the
    model weights together with the recorded training scores and the best
    score to 'checkpoints/model.pt' next to this file, creating the
    directory when needed.
    """
    self.is_training = False

    # Restore the idle-state controls.
    self.start_btn.setEnabled(True)
    self.stop_btn.setEnabled(False)
    for control in (
        self.mode_combo,
        self.games_spinbox,
        self.load_btn,
        self.save_btn,
    ):
        control.setEnabled(True)

    # Auto-save the model alongside the score history.
    checkpoint_dir = os.path.join(os.path.dirname(__file__), 'checkpoints')
    save_path = os.path.join(checkpoint_dir, 'model.pt')
    os.makedirs(checkpoint_dir, exist_ok=True)

    payload = {'model_state_dict': self.model.state_dict()}
    payload['training_scores'] = self.training_scores
    if self.training_scores:
        payload['best_score'] = max(self.training_scores)
    else:
        payload['best_score'] = 0
    torch.save(payload, save_path)

    self.statusBar.showMessage(f'Training finished! Model saved to {save_path}')
|
| 775 |
+
|
| 776 |
+
def load_model(self):
    """Ask the user for a checkpoint file and load it into the model.

    On success the model weights are replaced, the chosen path is
    remembered in current_model_path and, when the checkpoint carries a
    'training_scores' entry, the training curve is restored and
    replotted. Any failure is reported in a warning dialog. Cancelling
    the file dialog does nothing.
    """
    start_dir = os.path.join(os.path.dirname(__file__), 'checkpoints')
    file_path, _ = QFileDialog.getOpenFileName(
        self,
        'Load Model',
        start_dir,
        'PyTorch Model (*.pt);;All Files (*)'
    )

    if not file_path:
        return

    try:
        # NOTE(security): weights_only=False lets torch.load unpickle
        # arbitrary Python objects, so only open checkpoints that come
        # from a trusted source.
        checkpoint = torch.load(file_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.current_model_path = file_path

        if 'training_scores' in checkpoint:
            self.training_scores = checkpoint['training_scores']
            self.plot_canvas.plot_training_scores(self.training_scores)

        self.statusBar.showMessage(f'Model loaded: {os.path.basename(file_path)}')
    except Exception as e:
        # UI boundary: surface whatever went wrong instead of crashing.
        QMessageBox.warning(self, 'Error', f'Failed to load model:\n{str(e)}')
|
| 797 |
+
|
| 798 |
+
def save_model(self):
    """Ask the user for a destination and save the model checkpoint.

    The checkpoint holds the model weights, the recorded training scores
    and the best score (0 when no scores exist). On success the chosen
    path is remembered in current_model_path; failures are reported in a
    warning dialog. Cancelling the file dialog does nothing.
    """
    suggested_path = os.path.join(os.path.dirname(__file__), 'checkpoints', 'model.pt')
    file_path, _ = QFileDialog.getSaveFileName(
        self,
        'Save Model',
        suggested_path,
        'PyTorch Model (*.pt);;All Files (*)'
    )

    if not file_path:
        return

    try:
        payload = {'model_state_dict': self.model.state_dict()}
        payload['training_scores'] = self.training_scores
        if self.training_scores:
            payload['best_score'] = max(self.training_scores)
        else:
            payload['best_score'] = 0
        torch.save(payload, file_path)

        self.current_model_path = file_path
        self.statusBar.showMessage(f'Model saved: {os.path.basename(file_path)}')
    except Exception as e:
        # UI boundary: surface whatever went wrong instead of crashing.
        QMessageBox.warning(self, 'Error', f'Failed to save model:\n{str(e)}')
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
def main():
    """Create the Qt application, show the main window and run the event loop.

    The process exits with the status returned by the event loop.
    """
    app = QApplication(sys.argv)
    app.setStyle('Fusion')

    # Application-wide default font.
    app.setFont(QFont('Arial', 10))

    window = MainWindow()
    window.show()

    exit_code = app.exec_()
    sys.exit(exit_code)


if __name__ == '__main__':
    main()
|
main.py
ADDED
|
@@ -0,0 +1,302 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
2048 AI Trainer - 主入口
|
| 3 |
+
基于Transformer的2048游戏AI训练器
|
| 4 |
+
|
| 5 |
+
使用方法:
|
| 6 |
+
python main.py # 启动GUI界面
|
| 7 |
+
python main.py --train # 命令行训练模式
|
| 8 |
+
python main.py --demo # 演示模式(加载已有模型)
|
| 9 |
+
"""
|
| 10 |
+
import sys
|
| 11 |
+
import os
|
| 12 |
+
import argparse
|
| 13 |
+
import torch
|
| 14 |
+
import time
|
| 15 |
+
|
| 16 |
+
# 添加当前目录到路径
|
| 17 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 18 |
+
|
| 19 |
+
from game import Game2048
|
| 20 |
+
from model import Game2048Transformer, count_parameters
|
| 21 |
+
from trainer import PPOTrainer, RolloutBuffer, TrainingStats
|
| 22 |
+
from parallel import TrainingLoop, ParallelGameEnv, TrainingWorker
|
| 23 |
+
from utils import (
|
| 24 |
+
set_seed, get_device, print_model_info, save_checkpoint,
|
| 25 |
+
load_checkpoint, EarlyStopping, format_time, format_number
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def run_gui():
    """Launch the graphical interface.

    The gui module is imported inside the function, presumably so the
    command-line modes do not need the GUI toolkit to be importable.
    """
    from gui import main as launch_gui

    launch_gui()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def run_training(
    num_games: int = 10000,
    num_envs: int = 4,
    save_interval: int = 100,
    checkpoint_dir: str = "checkpoints",
    seed: int = 42
):
    """
    Train the 2048 agent from the command line.

    Builds the model, the PPO trainer and the parallel training loop, runs
    the loop until ``num_games`` games are played or the user presses
    Ctrl+C, prints a summary and saves the final model.

    Args:
        num_games: Total number of games to play.
        num_envs: Number of parallel environments.
        save_interval: Save a checkpoint every this many completed games.
            A value <= 0 disables periodic checkpoints.
        checkpoint_dir: Directory that receives the checkpoints and the
            final model.
        seed: Random seed passed to set_seed.
    """
    print("=" * 50)
    print("2048 AI Training")
    print("=" * 50)

    # Seed the random number generators.
    set_seed(seed)

    # Pick the compute device.
    device = get_device()
    print(f"Device: {device}")

    # Build the model.
    model = Game2048Transformer()
    print_model_info(model)
    model.to(device)

    # Build the trainer.
    trainer = PPOTrainer(model, lr=1e-4, device=device)

    # Build the training loop.
    training_loop = TrainingLoop(
        model=model,
        trainer=trainer,
        num_envs=num_envs,
        device=device,
        steps_per_update=256,
        checkpoint_dir=checkpoint_dir
    )

    # Training statistics.
    stats = TrainingStats()
    start_time = time.time()
    best_score = 0

    # Make sure the checkpoint directory exists.
    os.makedirs(checkpoint_dir, exist_ok=True)

    print(f"\nStarting training for {num_games} games...")
    print(f"Parallel environments: {num_envs}")
    print("-" * 50)

    try:
        games_completed = 0

        def on_game_end(game_stats):
            """Record one finished game, then log and checkpoint periodically."""
            nonlocal games_completed, best_score

            stats.record_game(
                score=game_stats['score'],
                situational_score=game_stats['situational_score'],
                max_tile=game_stats['max_tile'],
                steps=game_stats['moves']
            )

            games_completed += 1

            if game_stats['score'] > best_score:
                best_score = game_stats['score']

            # Print running statistics every 10 games.
            if games_completed % 10 == 0:
                elapsed = time.time() - start_time
                avg_stats = stats.get_avg_stats(window=100)
                # Guard against a zero-length interval on coarse clocks.
                speed = games_completed / elapsed if elapsed > 0 else 0.0

                print(
                    f"Games: {games_completed} | "
                    f"Avg Score: {avg_stats['avg_score']:.0f} | "
                    f"Best: {best_score} | "
                    f"Max Tile: {avg_stats['avg_max_tile']:.0f} | "
                    f"Speed: {speed:.2f} games/s"
                )

            # Save a checkpoint. The statistics are fetched here rather
            # than reused from the logging branch, so checkpointing works
            # for any save_interval, not only multiples of 10.
            if save_interval > 0 and games_completed % save_interval == 0:
                checkpoint_stats = stats.get_avg_stats(window=100)
                checkpoint_path = os.path.join(
                    checkpoint_dir,
                    f"checkpoint_{games_completed}.pt"
                )
                save_checkpoint(
                    model, trainer.optimizer, games_completed,
                    checkpoint_stats, checkpoint_path
                )
                print(f"Checkpoint saved: {checkpoint_path}")

        training_loop.on_game_end_callback = on_game_end
        # NOTE(review): the meaning of stop_threshold is defined by
        # TrainingLoop.train, which is not visible here - confirm.
        training_loop.train(total_games=num_games, stop_threshold=200)

    except KeyboardInterrupt:
        # Ctrl+C ends training early but still falls through to the
        # summary and the final save below.
        print("\nTraining interrupted by user.")

    # Final statistics.
    elapsed = time.time() - start_time
    final_stats = stats.get_avg_stats()

    print("\n" + "=" * 50)
    print("Training Complete!")
    print("=" * 50)
    print(f"Total games: {format_number(final_stats['games_played'])}")
    print(f"Total time: {format_time(elapsed)}")
    print(f"Average score: {final_stats['avg_score']:.0f}")
    print(f"Best score: {final_stats['best_score']}")
    print(f"Best max tile: {final_stats['best_max_tile']}")

    # Save the final model.
    final_path = os.path.join(checkpoint_dir, "final_model.pt")
    save_checkpoint(
        model, trainer.optimizer, final_stats['games_played'],
        final_stats, final_path
    )
    print(f"Final model saved: {final_path}")
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def run_demo(model_path: str = None, num_games: int = 5):
    """
    Play demo games with the model choosing moves deterministically.

    Args:
        model_path: Path to a checkpoint. When it is None or the file does
            not exist, the model keeps its random initial weights.
        num_games: Number of demo games to play. With a value <= 0 no game
            is played and the summary is skipped.
    """
    print("=" * 50)
    print("2048 AI Demo Mode")
    print("=" * 50)

    device = get_device()
    print(f"Device: {device}")

    # Build the model.
    model = Game2048Transformer()

    # Load the checkpoint when one is available.
    if model_path and os.path.exists(model_path):
        print(f"Loading model from: {model_path}")
        checkpoint = load_checkpoint(model_path, model, device=device)
        print(f"Loaded checkpoint: {checkpoint.get('epoch', 'unknown')} games")
    else:
        if model_path:
            # Say so explicitly instead of silently ignoring a bad path.
            print(f"Model file not found: {model_path}")
        print("No model loaded, using random weights.")

    model.to(device)
    model.eval()
    print_model_info(model)

    # Run the demo games.
    print(f"\nRunning {num_games} demo games...")
    print("-" * 50)

    total_scores = []

    for game_idx in range(num_games):
        game = Game2048()
        game.reset()

        steps = 0
        # The step cap guards against a game that never terminates.
        while not game.game_over and steps < 10000:
            # Read the current state.
            state = game.get_state()
            # Presumably the last two entries are the score features
            # consumed by the model - confirm against Game2048.
            scores = game.get_state_with_scores()[-2:]
            valid_actions = game.get_valid_actions()

            # Convert to batched tensors on the target device.
            state_t = torch.FloatTensor(state).unsqueeze(0).to(device)
            scores_t = torch.FloatTensor(scores).unsqueeze(0).to(device)
            valid_t = torch.BoolTensor(valid_actions).unsqueeze(0).to(device)

            # Choose the action.
            with torch.no_grad():
                action, _, _ = model.get_action(state_t, scores_t, valid_t, deterministic=True)

            # Apply the action.
            game.move(action)
            steps += 1

        total_scores.append(game.accumulated_score)

        print(
            f"Game {game_idx + 1}: "
            f"Score = {game.accumulated_score}, "
            f"Max Tile = {game.get_max_tile()}, "
            f"Steps = {steps}"
        )

    # Summary. Skipped when nothing was played, because the average and
    # max()/min() are undefined for an empty list.
    if not total_scores:
        return

    print("\n" + "-" * 50)
    print(f"Average score: {sum(total_scores) / len(total_scores):.0f}")
    print(f"Best score: {max(total_scores)}")
    print(f"Worst score: {min(total_scores)}")
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def main():
    """Parse the command line and dispatch to training, demo or GUI mode.

    --train takes precedence over --demo; with neither flag the GUI is
    started.
    """
    parser = argparse.ArgumentParser(
        description='2048 AI Trainer - 基于Transformer的2048游戏AI'
    )

    # Mode switches.
    for flag, help_text in (
        ('--train', '运行命令行训练模式'),
        ('--demo', '运行演示模式'),
    ):
        parser.add_argument(flag, action='store_true', help=help_text)

    # Checkpoint path.
    parser.add_argument(
        '--model',
        type=str,
        default=None,
        help='模型路径(用于演示模式或继续训练)'
    )

    # Integer options.
    for flag, default, help_text in (
        ('--games', 10000, '训练游戏数(默认10000)'),
        ('--envs', 4, '并行环境数(默认4)'),
        ('--seed', 42, '随机种子(默认42)'),
    ):
        parser.add_argument(flag, type=int, default=default, help=help_text)

    args = parser.parse_args()

    if args.train:
        run_training(
            num_games=args.games,
            num_envs=args.envs,
            seed=args.seed
        )
        return

    if args.demo:
        run_demo(model_path=args.model, num_games=5)
        return

    # No mode flag given: start the GUI.
    run_gui()


if __name__ == "__main__":
    main()
|
model.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Transformer模型用于2048游戏决策
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import numpy as np
|
| 8 |
+
from typing import Tuple
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class Game2048Transformer(nn.Module):
    """
    Small Transformer policy/value network for 2048.

    Inputs:
        - board: (batch, 4, 4) or (batch, 16) cell codes
        - optional score features: (batch, 2)

    Outputs:
        - policy: (batch, 4) action logits
        - value: (batch, 1) state value, squashed to [-1, 1] by a tanh
    """

    def __init__(
        self,
        d_model: int = 64,
        nhead: int = 4,
        num_layers: int = 2,
        dim_feedforward: int = 128,
        dropout: float = 0.1
    ):
        """
        Build the network.

        Args:
            d_model: Token embedding width.
            nhead: Number of attention heads.
            num_layers: Number of Transformer encoder layers.
            dim_feedforward: Hidden width of each encoder feed-forward block.
            dropout: Dropout probability inside the encoder layers.
        """
        super().__init__()

        self.d_model = d_model

        # Cell-value embedding: 0 is an empty cell, 1-15 is log2(value).
        # (The attribute name is kept so existing state dicts still load.)
        self.position_embedding = nn.Embedding(16, d_model)

        # Spatial encoding for the 4x4 grid. The two halves are
        # concatenated, so the second half takes the remainder and the
        # result is d_model wide even when d_model is odd.
        self.row_embedding = nn.Embedding(4, d_model // 2)
        self.col_embedding = nn.Embedding(4, d_model - d_model // 2)

        # Projects the two score features to one extra token.
        self.score_embedding = nn.Linear(2, d_model)

        # Transformer encoder.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,
            activation='gelu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        # Policy head: one logit per action.
        self.policy_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, 4)
        )

        # Value head: scalar state value in [-1, 1].
        self.value_head = nn.Sequential(
            nn.Linear(d_model, d_model),
            nn.GELU(),
            nn.Linear(d_model, 1),
            nn.Tanh()
        )

        self._init_weights()

    def _init_weights(self):
        """Xavier-initialise linear layers and use N(0, 0.02) for embeddings."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.Embedding):
                nn.init.normal_(module.weight, mean=0, std=0.02)

    def forward(
        self,
        board: torch.Tensor,
        scores: torch.Tensor = None,
        mask: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            board: (batch, 4, 4) or (batch, 16) board, each cell holding
                log2(value) with 0 for an empty cell.
            scores: Optional (batch, 2) score features
                [accumulated score, situational score].
            mask: Optional (batch, 4) boolean action mask, True for a
                valid action.

        Returns:
            policy: (batch, 4) action logits; masked-out actions are set
                to -1e9.
            value: (batch, 1) state value.
        """
        batch_size = board.shape[0]

        # Flatten the board. reshape (rather than view) also accepts
        # non-contiguous inputs such as transposed boards.
        if board.dim() == 3:
            board_flat = board.reshape(batch_size, -1)  # (batch, 16)
        else:
            board_flat = board

        # Cell codes become embedding indices, clamped to the table size.
        board_indices = torch.clamp(board_flat.long(), 0, 15)

        # Cell-value embedding.
        pos_embeddings = self.position_embedding(board_indices)  # (batch, 16, d_model)

        # Spatial encoding.
        # NOTE(review): assuming dim 1 of the board is the row axis, flat
        # index i is cell (i // 4, i % 4), so `row_indices` below is in
        # fact i % 4 and `col_indices` is i // 4. Both tables are learned
        # and each cell still gets a unique pair, so this only swaps the
        # names; it is left as is because changing it would alter the
        # meaning of weights in already-trained checkpoints.
        row_indices = torch.arange(4, device=board.device).repeat(4)
        col_indices = torch.arange(4, device=board.device).repeat_interleave(4)
        row_emb = self.row_embedding(row_indices)  # (16, d_model // 2)
        col_emb = self.col_embedding(col_indices)  # (16, d_model - d_model // 2)
        spatial_emb = torch.cat([row_emb, col_emb], dim=-1)  # (16, d_model)

        # Combine value and spatial embeddings.
        x = pos_embeddings + spatial_emb.unsqueeze(0)  # (batch, 16, d_model)

        # Score features, when given, are appended as a 17th token.
        if scores is not None:
            score_emb = self.score_embedding(scores).unsqueeze(1)  # (batch, 1, d_model)
            x = torch.cat([x, score_emb], dim=1)  # (batch, 17, d_model)

        # Transformer encoding.
        x = self.transformer(x)  # (batch, 17 or 16, d_model)

        # Mean-pool over the tokens.
        x = x.mean(dim=1)  # (batch, d_model)

        # Output heads.
        policy_logits = self.policy_head(x)  # (batch, 4)
        value = self.value_head(x)  # (batch, 1)

        # Push invalid actions to a very low logit.
        if mask is not None:
            policy_logits = policy_logits.masked_fill(~mask, -1e9)

        return policy_logits, value

    def get_action(
        self,
        board: torch.Tensor,
        scores: torch.Tensor = None,
        mask: torch.Tensor = None,
        deterministic: bool = False
    ) -> Tuple[int, float, float]:
        """
        Choose an action for a single board.

        Args:
            board: (1, 4, 4) or (4, 4) board.
            scores: (1, 2) or (2,) score features.
            mask: (1, 4) or (4,) action mask.
            deterministic: Take the most probable action instead of
                sampling from the policy.

        Returns:
            action: Index of the chosen action.
            log_prob: Log-probability of that action, as a Python float.
            value: State value, as a Python float.
        """
        # Add the batch dimension to unbatched inputs.
        if board.dim() == 2:
            board = board.unsqueeze(0)
        if scores is not None and scores.dim() == 1:
            scores = scores.unsqueeze(0)
        if mask is not None and mask.dim() == 1:
            mask = mask.unsqueeze(0)

        with torch.no_grad():
            policy_logits, value = self.forward(board, scores, mask)
            probs = F.softmax(policy_logits, dim=-1)

            if deterministic:
                action = torch.argmax(probs, dim=-1).item()
            else:
                # Sample from the policy distribution.
                dist = torch.distributions.Categorical(probs)
                action = dist.sample().item()

            log_prob = F.log_softmax(policy_logits, dim=-1)[0, action].item()

        return action, log_prob, value.item()

    def evaluate_actions(
        self,
        board: torch.Tensor,
        actions: torch.Tensor,
        scores: torch.Tensor = None,
        mask: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Evaluate previously taken actions (used during training).

        Args:
            board: (batch, 4, 4) boards.
            actions: (batch,) indices of the actions that were taken.
            scores: (batch, 2) score features.
            mask: (batch, 4) action mask.

        Returns:
            log_probs: (batch,) log-probability of each taken action.
            values: (batch, 1) state values.
            entropy: (batch,) policy entropy.
        """
        policy_logits, values = self.forward(board, scores, mask)

        probs = F.softmax(policy_logits, dim=-1)
        log_probs = F.log_softmax(policy_logits, dim=-1)

        # Log-probability of the action actually taken.
        action_log_probs = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)

        # Entropy of the policy distribution.
        entropy = -(probs * log_probs).sum(dim=-1)

        return action_log_probs, values, entropy
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm stages wrapped in a skip connection."""

    def __init__(self, channels: int):
        """Create the block; input and output both have `channels` channels."""
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.conv2 = nn.Conv2d(channels, channels, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(channels)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + x)."""
        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Skip connection: add the untouched input before the final ReLU.
        out = out + x
        return F.relu(out)
|
| 244 |
+
|
| 245 |
+
|
| 246 |
+
class Game2048CNN(nn.Module):
    """
    CNN variant of the 2048 model (an alternative to the Transformer).

    Simpler, and possibly faster.
    """

    def __init__(self, channels: int = 64):
        """
        Build the network.

        Args:
            channels: Number of feature channels in every conv layer and
                in the score-feature projection.
        """
        super().__init__()

        # Input: (batch, 1, 4, 4)
        self.conv1 = nn.Conv2d(1, channels, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

        # Score-feature projection.
        self.score_fc = nn.Linear(2, channels)

        # Both heads take the flattened conv features (channels * 16)
        # concatenated with the score features (channels).
        self.policy_head = nn.Sequential(
            nn.Linear(channels * 16 + channels, 256),
            nn.ReLU(),
            nn.Linear(256, 4)
        )

        self.value_head = nn.Sequential(
            nn.Linear(channels * 16 + channels, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Tanh()
        )

    def forward(
        self,
        board: torch.Tensor,
        scores: torch.Tensor = None,
        mask: torch.Tensor = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass.

        Args:
            board: (batch, 4, 4) floating-point board.
            scores: Optional (batch, 2) score features; zeros are used in
                their place when omitted.
            mask: Optional (batch, 4) boolean action mask, True for a
                valid action.

        Returns:
            policy_logits: (batch, 4) action logits; masked-out actions
                are set to -1e9.
            value: (batch, 1) state value, squashed to [-1, 1] by a tanh.
        """
        # Add the channel dimension.
        x = board.unsqueeze(1)  # (batch, 1, 4, 4)

        # Convolutional feature extraction.
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Flatten.
        x = x.view(x.size(0), -1)  # (batch, channels * 16)

        # Append the score features.
        if scores is not None:
            score_feat = F.relu(self.score_fc(scores))
        else:
            # The padding must match the width score_fc would produce;
            # a fixed 64 breaks every model built with channels != 64.
            score_feat = torch.zeros(
                x.size(0), self.score_fc.out_features,
                dtype=x.dtype, device=x.device
            )
        x = torch.cat([x, score_feat], dim=-1)

        # Output heads.
        policy_logits = self.policy_head(x)
        value = self.value_head(x)

        if mask is not None:
            policy_logits = policy_logits.masked_fill(~mask, -1e9)

        return policy_logits, value
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
def count_parameters(model: nn.Module) -> int:
    """Return the total number of elements in the model's trainable parameters."""
    total = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted.
        if param.requires_grad:
            total += param.numel()
    return total
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
if __name__ == "__main__":
    # Smoke test: build both models, run one forward pass and one action
    # selection on random inputs, and print shapes and parameter counts.
    device = torch.device("cpu")

    # Transformer model.
    model = Game2048Transformer().to(device)
    print(f"Transformer参数量: {count_parameters(model):,}")

    # Random batch for the forward pass.
    batch_size = 4
    board = torch.randint(0, 12, (batch_size, 4, 4), dtype=torch.float32).to(device)
    scores = torch.rand(batch_size, 2).to(device)
    mask = torch.ones(batch_size, 4, dtype=torch.bool).to(device)

    outputs = model(board, scores, mask)
    policy_logits, value = outputs
    print(f"Policy logits shape: {policy_logits.shape}")
    print(f"Value shape: {value.shape}")

    # Action selection on the first sample of the batch.
    first_board, first_scores, first_mask = board[0], scores[0], mask[0]
    action, log_prob, val = model.get_action(first_board, first_scores, first_mask)
    print(f"Action: {action}, Log prob: {log_prob:.4f}, Value: {val:.4f}")

    # CNN model.
    cnn_model = Game2048CNN().to(device)
    print(f"\nCNN参数量: {count_parameters(cnn_model):,}")
|
parallel.py
ADDED
|
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
多进程并行训练环境
|
| 3 |
+
"""
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from typing import List, Tuple, Optional
|
| 7 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 8 |
+
import threading
|
| 9 |
+
from game import Game2048
|
| 10 |
+
from trainer import Transition, TrainingStats
|
| 11 |
+
import time
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ParallelGameEnv:
    """Run several independent 2048 games side by side.

    NOTE: ``step`` advances the games one after another on the calling
    thread.  The ``ThreadPoolExecutor`` created in ``__init__`` is kept (and
    released by ``close``) but no work is submitted to it in this class.
    """

    # Caps used to squash the raw scores into [0, 1] model features.
    MAX_ACCUMULATED_SCORE = 50000
    MAX_SITUATIONAL_SCORE = 200

    def __init__(self, num_envs: int = 4):
        """Create ``num_envs`` games and cache their initial state.

        Args:
            num_envs: Number of game instances to manage.
        """
        self.num_envs = num_envs
        self.envs = [Game2048() for _ in range(num_envs)]
        self.states = [env.get_state() for env in self.envs]
        self.scores = [env.accumulated_score for env in self.envs]
        self.situational_scores = [env.situational_score for env in self.envs]

        # Optional callback invoked with a stats dict when a game finishes.
        self.on_game_end = None

        # Number of games that have run to completion in ``step``.
        self.games_completed = 0

        # Thread pool (currently unused by the methods of this class).
        self.executor = ThreadPoolExecutor(max_workers=num_envs)

    def reset(self, indices: Optional[List[int]] = None) -> Tuple[np.ndarray, np.ndarray]:
        """Reset the selected environments.

        Args:
            indices: Indices of the environments to reset; ``None`` resets
                all of them.

        Returns:
            states: ``(num_envs, 4, 4)`` array of board states.
            score_features: ``(num_envs, 2)`` array of score features.
        """
        if indices is None:
            indices = range(self.num_envs)

        for i in indices:
            self.reset_single(i)

        return self._get_batch_state()

    def reset_single(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
        """Reset one environment and return its state and score features."""
        env = self.envs[idx]
        self.states[idx] = env.reset()
        self.scores[idx] = env.accumulated_score
        self.situational_scores[idx] = env.situational_score
        return self.states[idx], self._get_score_features(idx)

    def step(self, actions: List[int]) -> List["Transition"]:
        """Apply one action per environment and record the transitions.

        Finished games are reported through ``on_game_end`` and reset
        immediately, so the cached state is always a playable board.

        Args:
            actions: Action for each environment, in environment order.

        Returns:
            One transition per supplied action.  ``log_prob`` and ``value``
            are placeholders (0.0) that the caller is expected to fill in.
        """
        transitions = []

        for i, action in enumerate(actions):
            env = self.envs[i]
            old_state = self.states[i].copy()
            old_scores = self._get_score_features(i)

            # The mask must describe the state the action was chosen in,
            # because the trainer re-evaluates the action against
            # ``state`` with this mask.  Taking it after the move would
            # pair the old state with the next state's mask (which is
            # presumably empty once the game is over).
            valid_before = np.array(env.get_valid_actions())

            new_state, reward, _moved, done = env.move(action)

            # Refresh the cached view of this environment.
            self.states[i] = new_state
            self.scores[i] = env.accumulated_score
            self.situational_scores[i] = env.situational_score

            transitions.append(Transition(
                state=old_state,
                scores=old_scores,
                action=action,
                reward=reward,
                next_state=new_state.copy(),
                next_scores=self._get_score_features(i),
                done=done,
                log_prob=0.0,  # filled in by the caller
                value=0.0,     # filled in by the caller
                valid_actions=valid_before
            ))

            if done:
                # Snapshot the final stats before the reset wipes them.
                game_stats = {
                    'score': self.scores[i],
                    'situational_score': self.situational_scores[i],
                    'max_tile': env.get_max_tile(),
                    'moves': env.moves_count
                }
                self.games_completed += 1

                if self.on_game_end:
                    self.on_game_end(game_stats)

                self.reset_single(i)

        return transitions

    def _get_batch_state(self) -> Tuple[np.ndarray, np.ndarray]:
        """Stack the cached states and score features of all environments."""
        states = np.array(self.states, dtype=np.float32)
        score_features = np.array(
            [self._get_score_features(i) for i in range(self.num_envs)],
            dtype=np.float32
        )
        return states, score_features

    def _get_score_features(self, idx: int) -> np.ndarray:
        """Return the two score features of environment ``idx``.

        Both scores are divided by their cap and clipped at 1.0.
        """
        return np.array([
            min(self.scores[idx] / self.MAX_ACCUMULATED_SCORE, 1.0),
            min(self.situational_scores[idx] / self.MAX_SITUATIONAL_SCORE, 1.0)
        ], dtype=np.float32)

    def get_valid_actions(self) -> np.ndarray:
        """Return the valid-action masks of all environments, stacked."""
        return np.array([env.get_valid_actions() for env in self.envs])

    def get_game_stats(self) -> List[dict]:
        """Return a stats dict for the game currently running in each env."""
        return [
            {
                'score': env.accumulated_score,
                'situational_score': env.situational_score,
                'max_tile': env.get_max_tile(),
                'moves': env.moves_count,
                'game_over': env.game_over
            }
            for env in self.envs
        ]

    def close(self):
        """Release the thread pool without waiting for pending work."""
        self.executor.shutdown(wait=False)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class TrainingWorker:
    """Collect rollout data by driving a model through the environments."""

    def __init__(
        self,
        model,
        env: "ParallelGameEnv",
        device: str = "cpu"
    ):
        """Move ``model`` to ``device``, switch it to eval mode and keep ``env``."""
        self.model = model.to(device)
        self.model.eval()
        self.env = env
        self.device = device

        self.stats = TrainingStats()

    def collect_trajectories(
        self,
        num_steps: int = 256,
        deterministic: bool = False
    ) -> List["Transition"]:
        """Step every environment ``num_steps`` times and return the data.

        The result is grouped by environment (all steps of env 0 in time
        order, then env 1, ...).  The trainer's GAE computation walks the
        list as a single sequence and bootstraps from the following entry,
        so steps of different games must not be interleaved.
        NOTE(review): at the seam between two environments a non-terminal
        last step is still followed by another game's first step.

        Args:
            num_steps: Number of steps to take in each environment.
            deterministic: Forwarded to ``model.get_action``.

        Returns:
            ``num_steps * num_envs`` transitions with ``log_prob`` and
            ``value`` filled in from the model.
        """
        num_envs = self.env.num_envs
        per_env = [[] for _ in range(num_envs)]

        for _ in range(num_steps):
            # Snapshot the current observation of every environment.
            states = np.array(self.env.states, dtype=np.float32)
            score_features = np.array(
                [self.env._get_score_features(i) for i in range(num_envs)],
                dtype=np.float32
            )
            valid_actions = self.env.get_valid_actions()

            states_t = torch.as_tensor(states, dtype=torch.float32, device=self.device)
            scores_t = torch.as_tensor(score_features, dtype=torch.float32, device=self.device)
            valid_t = torch.as_tensor(valid_actions, dtype=torch.bool, device=self.device)

            # Query the model one environment at a time (batch of size 1).
            actions = []
            log_probs = []
            values = []

            with torch.no_grad():
                for i in range(num_envs):
                    action, log_prob, value = self.model.get_action(
                        states_t[i:i+1],
                        scores_t[i:i+1],
                        valid_t[i:i+1],
                        deterministic=deterministic
                    )
                    actions.append(action)
                    log_probs.append(log_prob)
                    values.append(value)

            transitions = self.env.step(actions)

            # The environment leaves log_prob/value as placeholders.
            for i, transition in enumerate(transitions):
                transition.log_prob = log_probs[i]
                transition.value = values[i]
                per_env[i].append(transition)

        return [t for trajectory in per_env for t in trajectory]

    def run_episode(
        self,
        deterministic: bool = True,
        max_steps: int = 10000
    ) -> dict:
        """Play one game on a fresh ``Game2048`` and return its statistics.

        The game is independent of ``self.env``.

        Args:
            deterministic: Forwarded to ``model.get_action``.
            max_steps: Upper bound on the number of moves.

        Returns:
            Dict with ``score``, ``situational_score``, ``max_tile``,
            ``steps`` and ``total_reward``.
        """
        env = Game2048()
        state = env.reset()

        # Same normalisation caps as the training environment uses.
        max_accumulated = 50000
        max_situational = 200

        total_reward = 0
        steps = 0

        while not env.game_over and steps < max_steps:
            scores = np.array([
                min(env.accumulated_score / max_accumulated, 1.0),
                min(env.situational_score / max_situational, 1.0)
            ], dtype=np.float32)
            valid_actions = env.get_valid_actions()

            # Add a leading batch dimension of 1 for the model.
            state_t = torch.as_tensor(
                state, dtype=torch.float32, device=self.device
            ).unsqueeze(0)
            scores_t = torch.as_tensor(
                scores, dtype=torch.float32, device=self.device
            ).unsqueeze(0)
            valid_t = torch.as_tensor(
                valid_actions, dtype=torch.bool, device=self.device
            ).unsqueeze(0)

            with torch.no_grad():
                action, _, _ = self.model.get_action(
                    state_t, scores_t, valid_t, deterministic=deterministic
                )

            state, reward, _moved, _done = env.move(action)
            total_reward += reward
            steps += 1

        return {
            'score': env.accumulated_score,
            'situational_score': env.situational_score,
            'max_tile': env.get_max_tile(),
            'steps': steps,
            'total_reward': total_reward
        }
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class TrainingLoop:
    """Full training loop: collect rollouts, update the model, track stats."""

    # Re-evaluate the early-stop condition once this many new games finished.
    _CHECK_EVERY_GAMES = 10

    def __init__(
        self,
        model,
        trainer,
        num_envs: int = 4,
        device: str = "cpu",
        steps_per_update: int = 256,
        save_interval: int = 100,
        checkpoint_dir: str = "checkpoints"
    ):
        """Store the configuration and build the environment and worker.

        ``save_interval`` and ``checkpoint_dir`` are only stored here; this
        class does not save checkpoints on its own (see ``save_checkpoint``).
        """
        self.model = model
        self.trainer = trainer
        self.num_envs = num_envs
        self.device = device
        self.steps_per_update = steps_per_update
        self.save_interval = save_interval
        self.checkpoint_dir = checkpoint_dir

        # Environment and rollout worker.
        self.env = ParallelGameEnv(num_envs=num_envs)
        self.worker = TrainingWorker(model, self.env, device)

        # Run state, toggled by stop()/pause()/resume().
        self.training = False
        self.paused = False
        self.stats = TrainingStats()

        # Optional user callbacks.
        self.on_update_callback = None
        self.on_game_end_callback = None

    @staticmethod
    def _track_improvement(current_avg, best_avg, games_completed, games_at_best, min_improvement):
        """Return the updated ``(best_avg, games_at_best)`` pair.

        The average counts as improved when it exceeds the previous best by
        more than the relative margin ``min_improvement``.
        """
        if current_avg > best_avg * (1 + min_improvement):
            return current_avg, games_completed
        return best_avg, games_at_best

    def train(
        self,
        total_games: int = 10000,
        stop_threshold: int = 100,
        min_improvement: float = 0.01
    ) -> None:
        """Run the training loop until a stop condition is met.

        The loop ends when ``total_games`` games have completed, when
        ``stop()`` is called, on ``KeyboardInterrupt``, or when the average
        score has not improved for ``stop_threshold`` games.

        Args:
            total_games: Total number of games to play.
            stop_threshold: Number of games without improvement after which
                training stops.
            min_improvement: Minimum relative improvement of the average
                score that counts as progress.
        """
        # Imported once here rather than on every loop iteration.
        from trainer import RolloutBuffer

        self.training = True

        best_avg_score = 0
        games_at_best = self.env.games_completed
        games_at_last_check = self.env.games_completed

        # At least one step per environment, otherwise a rollout would be
        # empty and the loop could never make progress.
        steps_per_env = max(1, self.steps_per_update // self.num_envs)

        def on_game_end(game_stats):
            self.stats.record_game(
                score=game_stats['score'],
                situational_score=game_stats['situational_score'],
                max_tile=game_stats['max_tile'],
                steps=game_stats['moves']
            )
            if self.on_game_end_callback:
                self.on_game_end_callback(game_stats)

        self.env.on_game_end = on_game_end

        try:
            while self.training and self.env.games_completed < total_games:
                if self.paused:
                    time.sleep(0.1)
                    continue

                # Collect a rollout.
                transitions = self.worker.collect_trajectories(
                    num_steps=steps_per_env,
                    deterministic=False
                )

                # Update the model on exactly this rollout.
                buffer = RolloutBuffer(capacity=len(transitions))
                buffer.push_batch(transitions)

                update_stats = self.trainer.update(buffer)

                if self.on_update_callback:
                    self.on_update_callback(update_stats)

                # Early stopping.  Several games may finish inside a single
                # rollout, so compare against the last checked count instead
                # of testing for an exact multiple.
                games_done = self.env.games_completed
                if games_done - games_at_last_check >= self._CHECK_EVERY_GAMES:
                    games_at_last_check = games_done
                    current_avg = self.stats.get_avg_stats(window=100)['avg_score']
                    best_avg_score, games_at_best = self._track_improvement(
                        current_avg, best_avg_score,
                        games_done, games_at_best, min_improvement
                    )

                    if games_done - games_at_best >= stop_threshold:
                        print(f"No improvement for {stop_threshold} games, stopping.")
                        break

        except KeyboardInterrupt:
            print("Training interrupted by user.")
        finally:
            self.training = False
            self.env.close()

    def stop(self) -> None:
        """Ask the training loop to stop after the current iteration."""
        self.training = False

    def pause(self) -> None:
        """Pause training; the loop idles until ``resume`` is called."""
        self.paused = True

    def resume(self) -> None:
        """Resume a paused training loop."""
        self.paused = False

    def get_stats(self) -> dict:
        """Return the current averaged game statistics."""
        return self.stats.get_avg_stats()

    def save_checkpoint(self, path: str) -> None:
        """Save model and optimizer state plus summary stats to ``path``."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.trainer.optimizer.state_dict(),
            'stats': self.stats.get_avg_stats(),
            'games_played': self.stats.games_played
        }, path)

    def load_checkpoint(self, path: str) -> None:
        """Restore model and optimizer state from ``path``.

        The statistics stored in the checkpoint are not restored.
        """
        checkpoint = torch.load(path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
if __name__ == "__main__":
    # Smoke test: drive a few random actions through the parallel environment.
    env = ParallelGameEnv(num_envs=4)
    try:
        env.reset()

        print("Testing parallel environment...")
        for i in range(10):
            actions = [np.random.randint(0, 4) for _ in range(env.num_envs)]
            transitions = env.step(actions)
            print(f"Step {i}: collected {len(transitions)} transitions")
    finally:
        # Release the thread pool even if a step raised.
        env.close()

    print("Parallel environment test passed!")
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 2048 AI Trainer 依赖
|
| 2 |
+
|
| 3 |
+
# 深度学习框架
|
| 4 |
+
torch>=2.0.0
|
| 5 |
+
|
| 6 |
+
# 数值计算 (需要 <2 以兼容 torch)
|
| 7 |
+
numpy<2
|
| 8 |
+
|
| 9 |
+
# GUI
|
| 10 |
+
PyQt5>=5.15.0
|
| 11 |
+
|
| 12 |
+
# 绘图
|
| 13 |
+
matplotlib>=3.7.0
|
train_simple.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
简单训练脚本 - 直接训练并保存模型
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
import time
|
| 7 |
+
import torch
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 11 |
+
|
| 12 |
+
from game import Game2048
|
| 13 |
+
from model import Game2048Transformer
|
| 14 |
+
from trainer import PPOTrainer, RolloutBuffer
|
| 15 |
+
|
| 16 |
+
def train_simple(
    num_games: int = 1000,
    save_path: str = "checkpoints/model.pt",
    print_interval: int = 10
):
    """Train a fresh model by self-play on single games and save it.

    Args:
        num_games: Number of games to play.
        save_path: Checkpoint file; written every ``print_interval`` games
            and once more at the end.
        print_interval: Print progress (and save) every this many games;
            values <= 0 disable the periodic report.

    Returns:
        The trained model.
    """
    # Transition is not among the module-level imports; import it once here
    # instead of on every move.
    from trainer import Transition

    print("=" * 50)
    print("2048 AI Simple Training")
    print("=" * 50)

    device = "cpu"
    model = Game2048Transformer().to(device)
    trainer = PPOTrainer(model, lr=3e-4, device=device)

    print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    print(f"Training for {num_games} games...")
    print("-" * 50)

    # Running statistics.
    scores = []
    max_tiles = []
    best_score = 0

    def score_features(game) -> np.ndarray:
        """Return the two normalised score features of ``game``."""
        return np.array([
            min(game.accumulated_score / 50000, 1.0),
            min(game.situational_score / 200, 1.0)
        ], dtype=np.float32)

    def save_checkpoint(game_idx, avg_score) -> None:
        """Write the model and summary numbers to ``save_path``."""
        directory = os.path.dirname(save_path)
        if directory:
            # dirname is "" for a bare file name; makedirs("") would raise.
            os.makedirs(directory, exist_ok=True)
        torch.save({
            'model_state_dict': model.state_dict(),
            'game_idx': game_idx,
            'best_score': best_score,
            # Plain float so the checkpoint holds no numpy scalar objects.
            'avg_score': float(avg_score)
        }, save_path)

    # One buffer for the whole run: transitions left over at the end of a
    # game (including the terminal one) carry over into the next update
    # instead of being thrown away with a per-game buffer.
    buffer = RolloutBuffer(capacity=10000)

    start_time = time.time()

    for game_idx in range(num_games):
        # Play one game.
        game = Game2048()
        game.reset()

        while not game.game_over:
            state = game.get_state()
            scores_feat = score_features(game)
            valid = game.get_valid_actions()

            # Add a leading batch dimension of 1 for the model.
            state_t = torch.FloatTensor(state).unsqueeze(0).to(device)
            scores_t = torch.FloatTensor(scores_feat).unsqueeze(0).to(device)
            valid_t = torch.BoolTensor(valid).unsqueeze(0).to(device)

            with torch.no_grad():
                action, log_prob, value = model.get_action(
                    state_t, scores_t, valid_t, deterministic=False
                )

            # Keep copies of the pre-move observation for the transition.
            old_state = state.copy()
            old_scores = scores_feat.copy()

            next_state, reward, _moved, done = game.move(action)

            buffer.push(Transition(
                state=old_state,
                scores=old_scores,
                action=action,
                reward=reward,
                next_state=next_state.copy(),
                next_scores=score_features(game),
                done=done,
                log_prob=log_prob,
                value=value,
                valid_actions=valid
            ))

            # Update as soon as 64 transitions are available.
            if len(buffer) >= 64:
                trainer.update(buffer)
                buffer.clear()

        # Record the result of this game.
        scores.append(game.accumulated_score)
        max_tiles.append(game.get_max_tile())

        if game.accumulated_score > best_score:
            best_score = game.accumulated_score

        # Periodic progress report and checkpoint.
        if print_interval > 0 and (game_idx + 1) % print_interval == 0:
            elapsed = time.time() - start_time
            avg_score = np.mean(scores[-print_interval:])
            avg_max_tile = np.mean(max_tiles[-print_interval:])
            speed = (game_idx + 1) / elapsed

            print(
                f"Game {game_idx + 1}/{num_games} | "
                f"Avg Score: {avg_score:.0f} | "
                f"Best: {best_score} | "
                f"Max Tile: {avg_max_tile:.0f} | "
                f"Speed: {speed:.2f} games/s"
            )

            save_checkpoint(game_idx, avg_score)

    # Final checkpoint and summary; guard the empty case (num_games == 0).
    final_avg = float(np.mean(scores[-100:])) if scores else 0.0
    save_checkpoint(num_games, final_avg)

    elapsed = time.time() - start_time
    print("\n" + "=" * 50)
    print("Training Complete!")
    print("=" * 50)
    print(f"Total games: {num_games}")
    print(f"Total time: {elapsed:.1f}s")
    print(f"Average score (last 100): {final_avg:.0f}")
    print(f"Best score: {best_score}")
    print(f"Best max tile: {max(max_tiles) if max_tiles else 0}")
    print(f"Model saved to: {save_path}")

    return model
|
| 147 |
+
|
| 148 |
+
if __name__ == "__main__":
    # Script entry point: run a short 500-game training session.
    train_simple(num_games=500)
|
trainer.py
ADDED
|
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
PPO训练器
|
| 3 |
+
"""
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
import numpy as np
|
| 9 |
+
from typing import List, Tuple, Optional
|
| 10 |
+
from dataclasses import dataclass
|
| 11 |
+
from collections import deque
|
| 12 |
+
import random
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class Transition:
    """One environment step as stored for PPO updates."""
    state: np.ndarray          # board state before the action, (4, 4)
    scores: np.ndarray         # score features before the action, (2,)
    action: int                # action that was taken
    reward: float              # reward received for the action
    next_state: np.ndarray     # board state after the action
    next_scores: np.ndarray    # score features after the action
    done: bool                 # whether the game ended with this step
    log_prob: float            # log-probability of the action
    value: float               # value estimate of ``state``
    valid_actions: np.ndarray  # mask of valid actions
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class RolloutBuffer:
    """Fixed-capacity buffer of transitions.

    Once ``capacity`` entries are stored, each new transition overwrites
    the oldest one (ring buffer).
    """

    def __init__(self, capacity: int = 10000):
        """Create an empty buffer holding at most ``capacity`` transitions."""
        self.capacity = capacity
        self.buffer: List["Transition"] = []
        # Slot the next transition is written to once the buffer is full;
        # this is also the index of the oldest stored entry.
        self.position = 0

    def push(self, transition: "Transition") -> None:
        """Add one transition, overwriting the oldest when full."""
        if len(self.buffer) < self.capacity:
            self.buffer.append(transition)
        else:
            self.buffer[self.position] = transition
        self.position = (self.position + 1) % self.capacity

    def push_batch(self, transitions: List["Transition"]) -> None:
        """Add several transitions in order."""
        for transition in transitions:
            self.push(transition)

    def get_all(self) -> List["Transition"]:
        """Return a new list with all stored transitions, oldest first.

        Consumers such as GAE depend on time order, so after a wrap-around
        the ring is rotated back into chronological order.  While the
        buffer is not yet full ``position`` equals its length, and the same
        expression yields the entries in insertion order.
        """
        return self.buffer[self.position:] + self.buffer[:self.position]

    def clear(self) -> None:
        """Remove all stored transitions."""
        self.buffer = []
        self.position = 0

    def __len__(self) -> int:
        """Return the number of stored transitions."""
        return len(self.buffer)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class PPOTrainer:
|
| 65 |
+
"""PPO训练器"""
|
| 66 |
+
|
| 67 |
+
def __init__(
|
| 68 |
+
self,
|
| 69 |
+
model,
|
| 70 |
+
lr: float = 1e-4,
|
| 71 |
+
gamma: float = 0.99,
|
| 72 |
+
gae_lambda: float = 0.95,
|
| 73 |
+
clip_ratio: float = 0.2,
|
| 74 |
+
value_coef: float = 0.5,
|
| 75 |
+
entropy_coef: float = 0.01,
|
| 76 |
+
max_grad_norm: float = 0.5,
|
| 77 |
+
update_epochs: int = 4,
|
| 78 |
+
batch_size: int = 64,
|
| 79 |
+
device: str = "cpu"
|
| 80 |
+
):
|
| 81 |
+
self.model = model.to(device)
|
| 82 |
+
self.device = device
|
| 83 |
+
|
| 84 |
+
self.gamma = gamma
|
| 85 |
+
self.gae_lambda = gae_lambda
|
| 86 |
+
self.clip_ratio = clip_ratio
|
| 87 |
+
self.value_coef = value_coef
|
| 88 |
+
self.entropy_coef = entropy_coef
|
| 89 |
+
self.max_grad_norm = max_grad_norm
|
| 90 |
+
self.update_epochs = update_epochs
|
| 91 |
+
self.batch_size = batch_size
|
| 92 |
+
|
| 93 |
+
self.optimizer = optim.Adam(model.parameters(), lr=lr)
|
| 94 |
+
|
| 95 |
+
# 训练统计
|
| 96 |
+
self.stats = {
|
| 97 |
+
'policy_loss': deque(maxlen=100),
|
| 98 |
+
'value_loss': deque(maxlen=100),
|
| 99 |
+
'entropy': deque(maxlen=100),
|
| 100 |
+
'total_loss': deque(maxlen=100)
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
def compute_gae(
|
| 104 |
+
self,
|
| 105 |
+
rewards: np.ndarray,
|
| 106 |
+
values: np.ndarray,
|
| 107 |
+
dones: np.ndarray,
|
| 108 |
+
next_value: float = 0.0
|
| 109 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
| 110 |
+
"""
|
| 111 |
+
计算Generalized Advantage Estimation (GAE)
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
rewards: 奖励序列 (T,)
|
| 115 |
+
values: 价值序列 (T,)
|
| 116 |
+
dones: 结束标志序列 (T,)
|
| 117 |
+
next_value: 最后状态的下一个价值
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
returns: 回报 (T,)
|
| 121 |
+
advantages: 优势 (T,)
|
| 122 |
+
"""
|
| 123 |
+
T = len(rewards)
|
| 124 |
+
advantages = np.zeros(T, dtype=np.float32)
|
| 125 |
+
returns = np.zeros(T, dtype=np.float32)
|
| 126 |
+
|
| 127 |
+
last_gae = 0
|
| 128 |
+
last_return = next_value
|
| 129 |
+
|
| 130 |
+
for t in reversed(range(T)):
|
| 131 |
+
if dones[t]:
|
| 132 |
+
next_value_t = 0
|
| 133 |
+
last_gae = 0
|
| 134 |
+
else:
|
| 135 |
+
next_value_t = values[t + 1] if t + 1 < T else next_value
|
| 136 |
+
|
| 137 |
+
delta = rewards[t] + self.gamma * next_value_t - values[t]
|
| 138 |
+
last_gae = delta + self.gamma * self.gae_lambda * (1 - dones[t]) * last_gae
|
| 139 |
+
advantages[t] = last_gae
|
| 140 |
+
|
| 141 |
+
last_return = rewards[t] + self.gamma * (1 - dones[t]) * last_return
|
| 142 |
+
returns[t] = last_return
|
| 143 |
+
|
| 144 |
+
return returns, advantages
|
| 145 |
+
|
| 146 |
+
def update(self, buffer: RolloutBuffer) -> dict:
|
| 147 |
+
"""
|
| 148 |
+
使用PPO更新模型
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
buffer: 存储轨迹数据的缓冲区
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
训练统计信息
|
| 155 |
+
"""
|
| 156 |
+
if len(buffer) < self.batch_size:
|
| 157 |
+
return {}
|
| 158 |
+
|
| 159 |
+
# 获取所有数据
|
| 160 |
+
transitions = buffer.get_all()
|
| 161 |
+
|
| 162 |
+
# 转换为数组
|
| 163 |
+
states = np.array([t.state for t in transitions])
|
| 164 |
+
scores = np.array([t.scores for t in transitions])
|
| 165 |
+
actions = np.array([t.action for t in transitions])
|
| 166 |
+
rewards = np.array([t.reward for t in transitions])
|
| 167 |
+
dones = np.array([t.done for t in transitions])
|
| 168 |
+
old_log_probs = np.array([t.log_prob for t in transitions])
|
| 169 |
+
old_values = np.array([t.value for t in transitions])
|
| 170 |
+
valid_actions = np.array([t.valid_actions for t in transitions])
|
| 171 |
+
|
| 172 |
+
# 计算优势和回报
|
| 173 |
+
returns, advantages = self.compute_gae(rewards, old_values, dones)
|
| 174 |
+
|
| 175 |
+
# 标准化优势
|
| 176 |
+
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
|
| 177 |
+
|
| 178 |
+
# 转换为张量
|
| 179 |
+
states_t = torch.FloatTensor(states).to(self.device)
|
| 180 |
+
scores_t = torch.FloatTensor(scores).to(self.device)
|
| 181 |
+
actions_t = torch.LongTensor(actions).to(self.device)
|
| 182 |
+
old_log_probs_t = torch.FloatTensor(old_log_probs).to(self.device)
|
| 183 |
+
returns_t = torch.FloatTensor(returns).to(self.device)
|
| 184 |
+
advantages_t = torch.FloatTensor(advantages).to(self.device)
|
| 185 |
+
valid_actions_t = torch.BoolTensor(valid_actions).to(self.device)
|
| 186 |
+
|
| 187 |
+
# PPO更新
|
| 188 |
+
total_policy_loss = 0
|
| 189 |
+
total_value_loss = 0
|
| 190 |
+
total_entropy = 0
|
| 191 |
+
num_updates = 0
|
| 192 |
+
|
| 193 |
+
dataset_size = len(transitions)
|
| 194 |
+
indices = np.arange(dataset_size)
|
| 195 |
+
|
| 196 |
+
for _ in range(self.update_epochs):
|
| 197 |
+
np.random.shuffle(indices)
|
| 198 |
+
|
| 199 |
+
for start in range(0, dataset_size, self.batch_size):
|
| 200 |
+
end = start + self.batch_size
|
| 201 |
+
batch_indices = indices[start:end]
|
| 202 |
+
|
| 203 |
+
# 获取批次数据
|
| 204 |
+
batch_states = states_t[batch_indices]
|
| 205 |
+
batch_scores = scores_t[batch_indices]
|
| 206 |
+
batch_actions = actions_t[batch_indices]
|
| 207 |
+
batch_old_log_probs = old_log_probs_t[batch_indices]
|
| 208 |
+
batch_returns = returns_t[batch_indices]
|
| 209 |
+
batch_advantages = advantages_t[batch_indices]
|
| 210 |
+
batch_valid = valid_actions_t[batch_indices]
|
| 211 |
+
|
| 212 |
+
# 前向传播
|
| 213 |
+
log_probs, values, entropy = self.model.evaluate_actions(
|
| 214 |
+
batch_states, batch_actions, batch_scores, batch_valid
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
# 策略损失 (PPO Clip)
|
| 218 |
+
ratio = torch.exp(log_probs - batch_old_log_probs)
|
| 219 |
+
surr1 = ratio * batch_advantages
|
| 220 |
+
surr2 = torch.clamp(ratio, 1 - self.clip_ratio, 1 + self.clip_ratio) * batch_advantages
|
| 221 |
+
policy_loss = -torch.min(surr1, surr2).mean()
|
| 222 |
+
|
| 223 |
+
# 价值损失
|
| 224 |
+
value_loss = F.mse_loss(values.squeeze(), batch_returns)
|
| 225 |
+
|
| 226 |
+
# 总损失
|
| 227 |
+
loss = (
|
| 228 |
+
policy_loss +
|
| 229 |
+
self.value_coef * value_loss -
|
| 230 |
+
self.entropy_coef * entropy.mean()
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
# 反向传播
|
| 234 |
+
self.optimizer.zero_grad()
|
| 235 |
+
loss.backward()
|
| 236 |
+
nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
|
| 237 |
+
self.optimizer.step()
|
| 238 |
+
|
| 239 |
+
total_policy_loss += policy_loss.item()
|
| 240 |
+
total_value_loss += value_loss.item()
|
| 241 |
+
total_entropy += entropy.mean().item()
|
| 242 |
+
num_updates += 1
|
| 243 |
+
|
| 244 |
+
# 记录统计
|
| 245 |
+
stats = {
|
| 246 |
+
'policy_loss': total_policy_loss / num_updates,
|
| 247 |
+
'value_loss': total_value_loss / num_updates,
|
| 248 |
+
'entropy': total_entropy / num_updates
|
| 249 |
+
}
|
| 250 |
+
|
| 251 |
+
for key, value in stats.items():
|
| 252 |
+
self.stats[key].append(value)
|
| 253 |
+
|
| 254 |
+
return stats
|
| 255 |
+
|
| 256 |
+
def get_recent_stats(self) -> dict:
    """Return the mean of every non-empty metric series in ``self.stats``.

    Note: despite the name, the mean is taken over the whole recorded
    history of each metric, not over a trailing window.

    Returns:
        Mapping of metric name to its mean; metrics with no recorded
        values are omitted.
    """
    averages = {}
    for metric, history in self.stats.items():
        # Skip empty series so np.mean never sees an empty input.
        if history:
            averages[metric] = np.mean(history)
    return averages
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
class TrainingStats:
    """Recorder for per-game training results.

    Keeps running totals, one series per tracked quantity, a parallel set
    of history series intended for plotting, and the best score / largest
    tile seen so far.
    """

    def __init__(self):
        # Running totals.
        self.games_played = 0
        self.total_steps = 0

        # One entry per recorded game.
        self.scores = []              # cumulative score of each game
        self.situational_scores = []  # average situational score of each game
        self.max_tiles = []           # largest tile reached in each game
        self.game_lengths = []        # number of steps in each game

        # History series kept for plotting.
        self.score_history = []
        self.situational_history = []
        self.max_tile_history = []
        self.steps_history = []

        # Best results so far.
        self.best_score = 0
        self.best_max_tile = 0

    def record_game(
        self,
        score: int,
        situational_score: float,
        max_tile: int,
        steps: int
    ) -> None:
        """Record the outcome of one finished game.

        Args:
            score: Score of the game.
            situational_score: Situational score of the game.
            max_tile: Largest tile reached.
            steps: Number of steps the game lasted.
        """
        self.games_played += 1
        self.total_steps += steps

        # Each value goes into both its stats series and its history series.
        for series, sample in (
            (self.scores, score),
            (self.situational_scores, situational_score),
            (self.max_tiles, max_tile),
            (self.game_lengths, steps),
            (self.score_history, score),
            (self.situational_history, situational_score),
            (self.max_tile_history, max_tile),
            (self.steps_history, steps),
        ):
            series.append(sample)

        # max() keeps the current record on ties, like a strict '>' check.
        self.best_score = max(self.best_score, score)
        self.best_max_tile = max(self.best_max_tile, max_tile)

    def get_avg_stats(self, window: int = 100) -> dict:
        """Summarise the recorded games.

        Args:
            window: Number of most recent games each average covers.

        Returns:
            Dict with the totals, the windowed averages, the best records
            and the last ten scores / max tiles.
        """
        def windowed_mean(series):
            # An empty series averages to 0 rather than dividing by zero.
            if not series:
                return 0
            tail = series[-window:]
            return sum(tail) / len(tail)

        summary = {
            'games_played': self.games_played,
            'total_steps': self.total_steps,
        }
        summary['avg_score'] = windowed_mean(self.scores)
        summary['avg_situational'] = windowed_mean(self.situational_scores)
        summary['avg_max_tile'] = windowed_mean(self.max_tiles)
        summary['avg_game_length'] = windowed_mean(self.game_lengths)
        summary['best_score'] = self.best_score
        summary['best_max_tile'] = self.best_max_tile
        # Slicing an empty list already yields a fresh empty list.
        summary['recent_scores'] = self.scores[-10:]
        summary['recent_max_tiles'] = self.max_tiles[-10:]
        return summary
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
if __name__ == "__main__":
    from model import Game2048Transformer

    # Smoke test for the PPO trainer, run on CPU.
    device = torch.device("cpu")
    model = Game2048Transformer().to(device)
    trainer = PPOTrainer(model, device=device)

    # Fill a rollout buffer with 100 random transitions.
    buffer = RolloutBuffer(capacity=1000)
    for _ in range(100):
        buffer.push(Transition(
            state=np.random.randn(4, 4).astype(np.float32),
            scores=np.random.rand(2).astype(np.float32),
            action=np.random.randint(0, 4),
            reward=np.random.randn(),
            next_state=np.random.randn(4, 4).astype(np.float32),
            next_scores=np.random.rand(2).astype(np.float32),
            done=np.random.rand() < 0.1,  # roughly one terminal step in ten
            log_prob=np.random.randn(),
            value=np.random.randn(),
            valid_actions=np.ones(4, dtype=bool)
        ))

    # Run one update on the random data and show the returned stats.
    update_stats = trainer.update(buffer)
    print(f"Training stats: {update_stats}")

    # Exercise the statistics recorder with ten synthetic games.
    training_stats = TrainingStats()
    for game_idx in range(10):
        training_stats.record_game(
            score=1000 * (game_idx + 1),
            situational_score=50.0 + game_idx * 5,
            max_tile=2 ** (game_idx + 5),
            steps=100 + game_idx * 10
        )

    print(f"Average stats: {training_stats.get_avg_stats()}")
|
utils.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
工具函数
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import json
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
from datetime import datetime
|
| 9 |
+
from typing import Dict, Any, Optional
|
| 10 |
+
import shutil
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def ensure_dir(path: str) -> str:
    """Make sure directory ``path`` exists, creating missing parents.

    An empty ``path`` is left alone: it is what ``os.path.dirname`` returns
    for a bare file name, and ``os.makedirs('')`` would raise
    ``FileNotFoundError``.

    Args:
        path: Directory to create if it is missing.

    Returns:
        ``path`` unchanged, so the call can be used inline.

    Raises:
        FileExistsError: If ``path`` exists but is not a directory.
    """
    if path:
        # exist_ok=True avoids the race between an existence check and the
        # creation when another process creates the directory first.
        os.makedirs(path, exist_ok=True)
    return path
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def save_checkpoint(
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    stats: Dict[str, Any],
    path: str
) -> None:
    """Save a training checkpoint to ``path`` with ``torch.save``.

    The checkpoint is a dict with the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict``, ``stats`` and ``timestamp`` (ISO-8601 local
    time of the save).

    Args:
        model: Model whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        epoch: Epoch counter stored alongside the weights.
        stats: Arbitrary statistics stored as-is.
        path: Destination file; missing parent directories are created.
    """
    # A bare file name has an empty dirname; there is nothing to create then,
    # and asking for '' to be created would fail.
    directory = os.path.dirname(path)
    if directory:
        ensure_dir(directory)

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'stats': stats,
        'timestamp': datetime.now().isoformat()
    }, path)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def load_checkpoint(
    path: str,
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    device: str = 'cpu'
) -> Dict[str, Any]:
    """Load a training checkpoint and restore it into ``model``.

    Args:
        path: Checkpoint file to read.
        model: Module that receives ``checkpoint['model_state_dict']``.
        optimizer: If given, and the checkpoint has an
            ``'optimizer_state_dict'`` entry, that state is restored too.
        device: ``map_location`` passed to ``torch.load``.

    Returns:
        The full checkpoint dict as loaded.
    """
    # NOTE(security): weights_only=False unpickles arbitrary objects, so only
    # load checkpoint files that come from a trusted source.
    checkpoint = torch.load(path, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Nothing more to restore without an optimizer or without saved state.
    if optimizer is None or 'optimizer_state_dict' not in checkpoint:
        return checkpoint

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def save_training_log(log_data: Dict[str, Any], path: str) -> None:
    """Append one record to the JSON training log at ``path``.

    The log file holds a JSON list. The existing list is read, a copy of
    ``log_data`` with a ``'timestamp'`` key (ISO-8601 local time) is
    appended, and the whole list is written back.

    Args:
        log_data: Record to append. It is not modified.
        path: Log file; missing parent directories are created.
    """
    # A bare file name has an empty dirname; there is nothing to create then.
    directory = os.path.dirname(path)
    if directory:
        ensure_dir(directory)

    # Read the existing log, starting fresh if the file is absent or corrupt.
    logs = []
    if os.path.exists(path):
        with open(path, 'r', encoding='utf-8') as f:
            try:
                logs = json.load(f)
            except json.JSONDecodeError:
                logs = []
        if not isinstance(logs, list):
            # Valid JSON that is not a list: keep it as the first record
            # instead of failing on .append below.
            logs = [logs]

    # Stamp a copy so the caller's dict is left untouched.
    entry = {**log_data, 'timestamp': datetime.now().isoformat()}
    logs.append(entry)

    with open(path, 'w', encoding='utf-8') as f:
        json.dump(logs, f, indent=2, ensure_ascii=False)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def format_time(seconds: float) -> str:
    """Format a duration in seconds as a short string.

    Values under a minute are shown in seconds (``s``), values under an
    hour in minutes (``m``), everything else in hours (``h``); always with
    one decimal place.
    """
    if seconds < 60:
        return f'{seconds:.1f}s'
    if seconds < 3600:
        return f'{seconds / 60:.1f}m'
    return f'{seconds / 3600:.1f}h'
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def format_number(num: int) -> str:
    """Format a number with comma thousands separators."""
    return format(num, ',')
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def calculate_ema(values: list, alpha: float = 0.1) -> list:
    """Compute the exponential moving average of ``values``.

    Args:
        values: Input series.
        alpha: Weight given to each new sample; the previous average
            gets ``1 - alpha``.

    Returns:
        A list as long as ``values``; the first element is ``values[0]``
        itself. Empty input gives an empty list.
    """
    if not values:
        return []

    previous = values[0]
    smoothed = [previous]
    for sample in values[1:]:
        previous = alpha * sample + (1 - alpha) * previous
        smoothed.append(previous)
    return smoothed
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def get_tile_color(value: int) -> str:
    """Return the hex background colour for a tile value.

    Values without an entry in the palette (anything other than 0 and the
    powers of two up to 2048) get the dark fallback colour.
    """
    palette = {
        0: '#cdc1b4',
        2: '#eee4da', 4: '#ede0c8', 8: '#f2b179',
        16: '#f59563', 32: '#f67c5f', 64: '#f65e3b',
        128: '#edcf72', 256: '#edcc61', 512: '#edc850',
        1024: '#edc53f', 2048: '#edc22e',
    }
    if value in palette:
        return palette[value]
    return '#3c3a32'
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def get_text_color(value: int) -> str:
    """Return the hex text colour for a tile value.

    Values up to and including 4 get the dark text colour; larger values
    get the light one.
    """
    return '#776e65' if value <= 4 else '#f9f6f2'
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
class EarlyStopping:
    """Early-stopping helper.

    Call the instance with the monitored value after each epoch. It reports
    ``True`` once the value has failed to improve by more than ``min_delta``
    for ``patience`` consecutive calls, and keeps reporting ``True`` after
    that.
    """

    def __init__(
        self,
        patience: int = 100,
        min_delta: float = 0.01,
        mode: str = 'max'
    ):
        """
        Args:
            patience: Number of consecutive non-improving calls tolerated.
            min_delta: Minimum change that counts as an improvement.
            mode: ``'max'`` if larger values are better, ``'min'`` if
                smaller values are better.

        Raises:
            ValueError: If ``mode`` is neither ``'max'`` nor ``'min'``.
        """
        # Reject unknown modes up front; otherwise a typo such as 'Max'
        # would silently be treated as 'min'.
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")

        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_value = None
        self.should_stop = False

    def __call__(self, value: float) -> bool:
        """Record ``value`` and report whether training should stop.

        Args:
            value: Current value of the monitored metric.

        Returns:
            ``True`` if training should stop.
        """
        # The first observation only sets the baseline.
        if self.best_value is None:
            self.best_value = value
            return False

        if self.mode == 'max':
            improved = value > self.best_value + self.min_delta
        else:
            improved = value < self.best_value - self.min_delta

        if improved:
            self.best_value = value
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True

        return self.should_stop
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class MetricTracker:
    """Tracks named metrics, keeping the most recent values of each."""

    def __init__(self, window_size: int = 100):
        self.window_size = window_size
        # Metric name -> list of its most recent values.
        self.metrics = {}

    def update(self, name: str, value: float) -> None:
        """Append ``value`` to metric ``name`` and trim it to the window."""
        history = self.metrics.setdefault(name, [])
        history.append(value)

        # Keep only the newest window_size values.
        if len(history) > self.window_size:
            self.metrics[name] = history[-self.window_size:]

    def get_mean(self, name: str) -> float:
        """Return the mean of metric ``name``, or 0.0 if it has no values."""
        samples = self.metrics.get(name)
        if not samples:
            return 0.0
        return np.mean(samples)

    def get_std(self, name: str) -> float:
        """Return the standard deviation of metric ``name``.

        Gives 0.0 when the metric is unknown or has fewer than two values.
        """
        samples = self.metrics.get(name, ())
        if len(samples) < 2:
            return 0.0
        return np.std(samples)

    def get_all_means(self) -> Dict[str, float]:
        """Return the mean of every tracked metric, keyed by name."""
        means = {}
        for name in self.metrics:
            means[name] = self.get_mean(name)
        return means
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy and PyTorch with ``seed``.

    When CUDA is available, the CUDA generators are seeded as well.
    """
    import random

    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # The CUDA generators only need seeding when a GPU is present.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def get_device() -> str:
    """Return ``'cuda'`` when CUDA is available, otherwise ``'cpu'``."""
    return 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def count_parameters(model: torch.nn.Module) -> int:
    """Count the trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def print_model_info(model: torch.nn.Module) -> None:
    """Print the trainable parameter count and an estimated model size.

    The size estimate assumes 4 bytes per parameter (float32).
    """
    num_params = count_parameters(model)
    size_mb = num_params * 4 / 1024 / 1024
    print(f"模型参数数量: {format_number(num_params)}")
    print(f"模型大小: {size_mb:.2f} MB (float32)")
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
def export_to_onnx(
    model: torch.nn.Module,
    path: str,
    input_size: tuple = (1, 4, 4)
) -> None:
    """Export ``model`` to an ONNX file at ``path``.

    The model is put in eval mode and traced with three dummy inputs named
    ``board``, ``scores`` and ``mask``; their first axis is exported as a
    dynamic ``batch_size`` axis. The outputs are named ``policy`` and
    ``value``.

    Args:
        model: Module to export; it is called with
            ``(board, scores, mask)``.
        path: Destination file; missing parent directories are created.
        input_size: Shape of the dummy ``board`` input. Its first entry is
            used as the batch size of all three dummy inputs.
    """
    model.eval()

    # All dummy inputs share the board's batch dimension so they stay
    # consistent when input_size[0] is not 1.
    batch_size = input_size[0]
    dummy_input = torch.randn(*input_size)
    # Widths 2 and 4 are fixed here; presumably the model's score-feature
    # and action counts — confirm against the model definition.
    dummy_scores = torch.randn(batch_size, 2)
    dummy_mask = torch.ones(batch_size, 4, dtype=torch.bool)

    # A bare file name has an empty dirname; there is nothing to create then.
    directory = os.path.dirname(path)
    if directory:
        ensure_dir(directory)

    torch.onnx.export(
        model,
        (dummy_input, dummy_scores, dummy_mask),
        path,
        input_names=['board', 'scores', 'mask'],
        output_names=['policy', 'value'],
        dynamic_axes={
            'board': {0: 'batch_size'},
            'scores': {0: 'batch_size'},
            'mask': {0: 'batch_size'}
        }
    )
    print(f"模型已导出到: {path}")
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
if __name__ == "__main__":
    # Manual smoke test of the helpers in this module.
    print("Testing utility functions...")

    # Time formatting: one sample each for seconds, minutes and hours.
    formatted_times = ", ".join(format_time(s) for s in (45.5, 125.3, 3661))
    print(f"Format time: {formatted_times}")

    # Number formatting.
    print(f"Format number: {format_number(1234567)}")

    # Exponential moving average.
    samples = [1, 2, 3, 4, 5]
    print(f"EMA: {calculate_ema(samples)}")

    # Early stopping: the score plateaus at 12, so the stopper should fire.
    stopper = EarlyStopping(patience=3, min_delta=0.1)
    for epoch, score in enumerate([10, 11, 12, 12, 12, 12, 12]):
        print(f"Epoch {epoch}: score={score}, stop={stopper(score)}")

    print("All tests passed!")
|