Upload 3 files
- 2048的网页实现.py +293 -0
- game2048.py +200 -0
- main.py +837 -0
2048的网页实现.py
ADDED
@@ -0,0 +1,293 @@
import gradio as gr
import numpy as np
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from game2048 import Game2048

# Create the game instance
game = Game2048(size=4)

# Tile background colors, keyed by tile value
TILE_COLORS = {
    0: "#cdc1b4",     # empty cell
    2: "#eee4da",
    4: "#ede0c8",
    8: "#f2b179",
    16: "#f59563",
    32: "#f67c5f",
    64: "#f65e3b",
    128: "#edcf72",
    256: "#edcc61",
    512: "#edc850",
    1024: "#edc53f",
    2048: "#edc22e",
    4096: "#3c3a32",  # 4096 and above
}

# Text colors, chosen for contrast against the tile background
TEXT_COLORS = {
    0: "#776e65",     # empty cell
    2: "#776e65",
    4: "#776e65",
    8: "#f9f6f2",     # 8 and above use the light color
    16: "#f9f6f2",
    32: "#f9f6f2",
    64: "#f9f6f2",
    128: "#f9f6f2",
    256: "#f9f6f2",
    512: "#f9f6f2",
    1024: "#f9f6f2",
    2048: "#f9f6f2",
    4096: "#f9f6f2",
}

# DQN network definition (identical to the one used during training)
class DQN(nn.Module):
    def __init__(self, input_channels, output_size):
        super(DQN, self).__init__()
        self.input_channels = input_channels

        # Convolutional layers
        self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # Dueling DQN architecture
        # Value stream
        self.value_conv = nn.Conv2d(128, 4, kernel_size=1)
        self.value_fc1 = nn.Linear(4 * 4 * 4, 128)
        self.value_fc2 = nn.Linear(128, 1)

        # Advantage stream
        self.advantage_conv = nn.Conv2d(128, 16, kernel_size=1)
        self.advantage_fc1 = nn.Linear(16 * 4 * 4, 128)
        self.advantage_fc2 = nn.Linear(128, output_size)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Value stream
        value = F.relu(self.value_conv(x))
        value = value.view(value.size(0), -1)
        value = F.relu(self.value_fc1(value))
        value = self.value_fc2(value)

        # Advantage stream
        advantage = F.relu(self.advantage_conv(x))
        advantage = advantage.view(advantage.size(0), -1)
        advantage = F.relu(self.advantage_fc1(advantage))
        advantage = self.advantage_fc2(advantage)

        # Combine the value and advantage streams
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
        return q_values

# Load the trained model
def load_model(model_path):
    model = DQN(4, 4)  # 4 input channels, 4 output actions
    try:
        checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(checkpoint['policy_net_state_dict'])
        model.eval()
        print("Model loaded successfully")
        return model
    except Exception as e:
        print(f"Failed to load model: {e}")
        return None

# Try to load the model
model_path = "models/dqn_2048_best_tile.pth"
model = load_model(model_path)

def render_board(board):
    html = "<div style='background-color:#bbada0; padding:10px; border-radius:6px;'>"
    html += "<table style='border-spacing:10px; border-collapse:separate;'>"

    for i in range(game.size):
        html += "<tr>"
        for j in range(game.size):
            value = board[i][j]
            color = TILE_COLORS.get(value, "#3c3a32")       # default to the dark color
            text_color = TEXT_COLORS.get(value, "#f9f6f2")  # default to the light color
            font_size = "36px" if value < 100 else "30px" if value < 1000 else "24px"

            html += f"""
                <td style='background-color:{color};
                           width:80px; height:80px;
                           border-radius:4px;
                           text-align:center;
                           font-weight:bold;
                           font-size:{font_size};
                           color:{text_color};'>
                    {value if value > 0 else ''}
                </td>
            """
        html += "</tr>"

    html += "</table></div>"
    return html

def make_move(direction):
    """Perform a move and update the interface."""
    direction_names = ["Up", "Right", "Down", "Left"]

    # Execute the move
    new_board, game_over = game.move(direction)

    # Render the board
    board_html = render_board(new_board)

    # Update the status information
    status = f"<b>Move:</b> {direction_names[direction]}"
    status += f"<br><b>Score:</b> {game.score}"
    status += f"<br><b>Max tile:</b> {np.max(game.board)}"

    if game.game_over:
        status += "<br><br><div style='color:#ff0000; font-weight:bold;'>Game over!</div>"
        status += f"<br><b>Final score:</b> {game.score}"

    return board_html, status

def reset_game():
    """Reset the game."""
    global game
    game = Game2048(size=4)
    board = game.reset()

    # Render the board
    board_html = render_board(board)

    # Initial status information
    status = "<b>Game reset!</b>"
    status += f"<br><b>Score:</b> {game.score}"
    status += f"<br><b>Max tile:</b> {np.max(game.board)}"

    return board_html, status

def ai_move():
    """Let the AI model make one move."""
    if model is None:
        return render_board(game.board), "<b>Error:</b> no AI model loaded"

    # Get the current state
    state = game.get_state()

    # Get the valid moves
    valid_moves = game.get_valid_moves()
    if not valid_moves:
        return render_board(game.board), "<b>Game over!</b> No valid moves left"

    # Convert the state to model input
    state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0)

    # Model prediction
    with torch.no_grad():
        q_values = model(state_tensor).numpy().flatten()

    # Only consider valid actions
    valid_q_values = np.full(4, -np.inf)
    for move in valid_moves:
        valid_q_values[move] = q_values[move]

    # Pick the best action
    action = np.argmax(valid_q_values)

    # Execute the move
    direction_names = ["Up", "Right", "Down", "Left"]
    new_board, game_over = game.move(action)

    # Render the board
    board_html = render_board(new_board)

    # Update the status information
    status = f"<b>AI move:</b> {direction_names[action]}"
    status += f"<br><b>Score:</b> {game.score}"
    status += f"<br><b>Max tile:</b> {np.max(game.board)}"

    if game.game_over:
        status += "<br><br><div style='color:#ff0000; font-weight:bold;'>Game over!</div>"
        status += f"<br><b>Final score:</b> {game.score}"

    return board_html, status

# Build the Gradio interface
with gr.Blocks(title="2048 Game", theme="soft") as demo:
    gr.Markdown("# 🎮 2048 Game")
    gr.Markdown("Move the tiles with the arrow keys or the buttons below; tiles with the same number merge when they collide!")
    with gr.Row():
        with gr.Column(scale=2):
            board_html = gr.HTML(render_board(game.board))
            # The status display is created exactly once here. The original
            # defined it twice (first inside a hidden Row), which wired the
            # manual buttons to a component that was never shown.
            status_display = gr.HTML("<b>Score:</b> 0<br><b>Max tile:</b> 2")
        with gr.Column():
            gr.Markdown("## Manual play")
            with gr.Row():
                gr.Button("Up ↑", elem_id="up-btn").click(
                    fn=lambda: make_move(0),
                    outputs=[board_html, status_display]
                )
                gr.Button("Left ←", elem_id="left-btn").click(
                    fn=lambda: make_move(3),
                    outputs=[board_html, status_display]
                )
            with gr.Row():
                gr.Button("Down ↓", elem_id="down-btn").click(
                    fn=lambda: make_move(2),
                    outputs=[board_html, status_display]
                )
                gr.Button("Right →", elem_id="right-btn").click(
                    fn=lambda: make_move(1),
                    outputs=[board_html, status_display]
                )
            with gr.Row():
                gr.Button("🔄 Reset game", elem_id="reset-btn").click(
                    fn=reset_game,
                    outputs=[board_html, status_display]
                )
        with gr.Column():
            gr.Markdown("## AI play")
            gr.Button("🤖 AI move", elem_id="ai-btn").click(
                fn=ai_move,
                outputs=[board_html, status_display]
            )
            gr.Markdown("Powered by a DQN neural network")

    # Keyboard shortcut support
    demo.load(
        fn=None,
        inputs=None,
        outputs=None,
        js="""() => {
            document.addEventListener('keydown', function(e) {
                if (e.key === 'ArrowUp') {
                    document.getElementById('up-btn').click();
                } else if (e.key === 'ArrowRight') {
                    document.getElementById('right-btn').click();
                } else if (e.key === 'ArrowDown') {
                    document.getElementById('down-btn').click();
                } else if (e.key === 'ArrowLeft') {
                    document.getElementById('left-btn').click();
                } else if (e.key === 'r' || e.key === 'R') {
                    document.getElementById('reset-btn').click();
                } else if (e.key === 'a' || e.key === 'A') {
                    document.getElementById('ai-btn').click();
                }
            });
        }"""
    )
    gr.Markdown("### 📚 How to play")
    gr.Markdown("1. Move the tiles with the arrow keys or the buttons.")
    gr.Markdown("2. Tiles with the same number merge when they collide.")
    gr.Markdown("3. Shortcuts: arrow keys move the tiles, R resets the game, A lets the AI move once.")
    gr.Markdown("4. Click the '🤖 AI move' button to have the AI model make a single move.")
    gr.Markdown("5. When the game ends, the final score and the largest tile are shown.")

# Launch the interface
if __name__ == "__main__":
    demo.launch(share=True)
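
A note on running this file outside the Space: because the filename begins with a digit, a plain `import` statement is a syntax error, but importlib can load it by its string name. A minimal sketch (hypothetical driver script, not part of this upload; it assumes the file sits next to game2048.py on sys.path, and the checkpoint under models/ is optional — without it the AI button simply reports that no model is loaded):

# Hypothetical launcher: load the module by name, since "2048的网页实现"
# is not a valid Python identifier, then start the Gradio app locally.
import importlib

app = importlib.import_module("2048的网页实现")  # runs the module top level
app.demo.launch()  # demo is the gr.Blocks object defined in the file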
game2048.py
ADDED
@@ -0,0 +1,200 @@
import numpy as np
import random

class Game2048:
    def __init__(self, size=4):
        self.size = size
        self.reset()

    def reset(self):
        """Reset the game state."""
        self.board = np.zeros((self.size, self.size), dtype=np.int32)
        self.score = 0
        self.add_tile()
        self.add_tile()
        self.game_over = False
        return self.board.copy()

    def add_tile(self):
        """Add a new tile at a random empty cell (90% a 2, 10% a 4)."""
        empty_cells = []
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] == 0:
                    empty_cells.append((i, j))

        if empty_cells:
            i, j = random.choice(empty_cells)
            self.board[i][j] = 2 if random.random() < 0.9 else 4

    def move(self, direction):
        """
        Execute a move.
        0: up, 1: right, 2: down, 3: left
        Returns: (new board state, whether the game is over)
        """
        moved = False
        # Execute the move in the given direction
        if direction == 0:  # up
            for j in range(self.size):
                column = self.board[:, j].copy()
                new_column, moved_col = self.slide(column)
                if moved_col:
                    moved = True
                self.board[:, j] = new_column

        elif direction == 1:  # right
            for i in range(self.size):
                row = self.board[i, :].copy()[::-1]
                new_row, moved_row = self.slide(row)
                if moved_row:
                    moved = True
                self.board[i, :] = new_row[::-1]

        elif direction == 2:  # down
            for j in range(self.size):
                column = self.board[::-1, j].copy()
                new_column, moved_col = self.slide(column)
                if moved_col:
                    moved = True
                self.board[:, j] = new_column[::-1]

        elif direction == 3:  # left
            for i in range(self.size):
                row = self.board[i, :].copy()
                new_row, moved_row = self.slide(row)
                if moved_row:
                    moved = True
                self.board[i, :] = new_row

        # If anything moved, add a new tile and check for game over
        if moved:
            self.add_tile()
            self.check_game_over()

        return self.board.copy(), self.game_over

    def slide(self, line):
        """Slide and merge a single row/column (adds merge points to self.score)."""
        non_zero = line[line != 0]
        new_line = np.zeros_like(line)
        idx = 0
        score_inc = 0
        moved = False

        # Check whether any tile moved
        if not np.array_equal(non_zero, line[:len(non_zero)]):
            moved = True

        # Merge equal numbers
        i = 0
        while i < len(non_zero):
            if i + 1 < len(non_zero) and non_zero[i] == non_zero[i+1]:
                new_val = non_zero[i] * 2
                new_line[idx] = new_val
                score_inc += new_val
                i += 2
                idx += 1
            else:
                new_line[idx] = non_zero[i]
                i += 1
                idx += 1

        self.score += score_inc
        return new_line, moved or (score_inc > 0)

    def check_game_over(self):
        """Check whether the game is over."""
        # Any empty cell left?
        if np.any(self.board == 0):
            self.game_over = False
            return

        # Any mergeable neighbors horizontally?
        for i in range(self.size):
            for j in range(self.size - 1):
                if self.board[i][j] == self.board[i][j+1]:
                    self.game_over = False
                    return

        # Any mergeable neighbors vertically?
        for j in range(self.size):
            for i in range(self.size - 1):
                if self.board[i][j] == self.board[i+1][j]:
                    self.game_over = False
                    return

        self.game_over = True

    def get_valid_moves(self):
        """Return all currently valid move directions."""
        valid_moves = []
        # slide() adds merge points to self.score as a side effect, so snapshot
        # the score and restore it after these trial slides (without this, the
        # score was silently inflated on every call).
        saved_score = self.score

        # Is moving up valid?
        for j in range(self.size):
            column = self.board[:, j].copy()
            new_column, _ = self.slide(column)
            if not np.array_equal(new_column, self.board[:, j]):
                valid_moves.append(0)
                break

        # Is moving right valid?
        for i in range(self.size):
            row = self.board[i, :].copy()[::-1]
            new_row, _ = self.slide(row)
            if not np.array_equal(new_row[::-1], self.board[i, :]):
                valid_moves.append(1)
                break

        # Is moving down valid?
        for j in range(self.size):
            column = self.board[::-1, j].copy()
            new_column, _ = self.slide(column)
            if not np.array_equal(new_column[::-1], self.board[:, j]):
                valid_moves.append(2)
                break

        # Is moving left valid?
        for i in range(self.size):
            row = self.board[i, :].copy()
            new_row, _ = self.slide(row)
            if not np.array_equal(new_row, self.board[i, :]):
                valid_moves.append(3)
                break

        self.score = saved_score
        return valid_moves

    def get_state(self):
        """Return the current state representation (for the AI model)."""
        # Four-channel state representation
        state = np.zeros((4, self.size, self.size), dtype=np.float32)

        # Channel 0: log2 of each tile value (normalized)
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] > 0:
                    state[0, i, j] = np.log2(self.board[i][j]) / 16.0  # supports up to 65536 (2^16)

        # Channel 1: empty-cell indicator
        state[1] = (self.board == 0).astype(np.float32)

        # Channel 2: mergeable-neighbor indicator
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] > 0:
                    # Check the cell to the right
                    if j < self.size - 1 and self.board[i][j] == self.board[i][j+1]:
                        state[2, i, j] = 1.0
                        state[2, i, j+1] = 1.0
                    # Check the cell below
                    if i < self.size - 1 and self.board[i][j] == self.board[i+1][j]:
                        state[2, i, j] = 1.0
                        state[2, i+1, j] = 1.0

        # Channel 3: positions of the maximum tile (one-hot)
        max_value = np.max(self.board)
        if max_value > 0:
            max_positions = np.argwhere(self.board == max_value)
            for pos in max_positions:
                state[3, pos[0], pos[1]] = 1.0

        return state
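
For reference, a minimal sketch of driving this environment directly (hypothetical, not part of the upload): play random valid moves until no move remains, using only the methods defined above.

# Random-rollout smoke test for Game2048 (illustration only).
import random
from game2048 import Game2048

env = Game2048(size=4)
game_over = False
while not game_over:
    moves = env.get_valid_moves()  # subset of {0: up, 1: right, 2: down, 3: left}
    if not moves:
        break
    board, game_over = env.move(random.choice(moves))
print(env.board)
print("score:", env.score, "max tile:", env.board.max())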
main.py
ADDED
@@ -0,0 +1,837 @@
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import random
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import warnings

# Pick the best available device (the hasattr guard keeps this working on
# PyTorch builds that lack the Intel XPU backend)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif hasattr(torch, "xpu") and torch.xpu.is_available():
    device = torch.device("xpu")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# 2048 game environment (improved version)
class Game2048:
    def __init__(self, size=4):
        self.size = size
        self.reset()

    def reset(self):
        self.board = np.zeros((self.size, self.size), dtype=np.int32)
        self.score = 0
        self.prev_score = 0
        self.add_tile()
        self.add_tile()
        self.game_over = False
        return self.get_state()

    def add_tile(self):
        empty_cells = []
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] == 0:
                    empty_cells.append((i, j))

        if empty_cells:
            i, j = random.choice(empty_cells)
            self.board[i][j] = 2 if random.random() < 0.9 else 4

    def move(self, direction):
        # 0: up, 1: right, 2: down, 3: left
        moved = False
        original_board = self.board.copy()
        old_score = self.score

        # Execute the move in the given direction
        if direction == 0:  # up
            for j in range(self.size):
                column = self.board[:, j].copy()
                new_column, moved_col = self.slide(column)
                if moved_col:
                    moved = True
                self.board[:, j] = new_column

        elif direction == 1:  # right
            for i in range(self.size):
                row = self.board[i, :].copy()[::-1]
                new_row, moved_row = self.slide(row)
                if moved_row:
                    moved = True
                self.board[i, :] = new_row[::-1]

        elif direction == 2:  # down
            for j in range(self.size):
                column = self.board[::-1, j].copy()
                new_column, moved_col = self.slide(column)
                if moved_col:
                    moved = True
                self.board[:, j] = new_column[::-1]

        elif direction == 3:  # left
            for i in range(self.size):
                row = self.board[i, :].copy()
                new_row, moved_row = self.slide(row)
                if moved_row:
                    moved = True
                self.board[i, :] = new_row

        # If anything moved, add a new tile
        if moved:
            self.add_tile()
            self.check_game_over()

        reward = self.calculate_reward(old_score, original_board)
        return self.get_state(), reward, self.game_over

    def slide(self, line):
        # Remove zeros and merge equal numbers
        non_zero = line[line != 0]
        new_line = np.zeros_like(line)
        idx = 0
        score_inc = 0
        moved = False

        # Check whether any tile moved
        if not np.array_equal(non_zero, line[:len(non_zero)]):
            moved = True

        # Merge equal numbers
        i = 0
        while i < len(non_zero):
            if i + 1 < len(non_zero) and non_zero[i] == non_zero[i+1]:
                new_val = non_zero[i] * 2
                new_line[idx] = new_val
                score_inc += new_val
                i += 2
                idx += 1
            else:
                new_line[idx] = non_zero[i]
                i += 1
                idx += 1

        self.score += score_inc
        return new_line, moved or (score_inc > 0)

    def calculate_reward(self, old_score, original_board):
        """Improved reward function."""
        # 1. Basic score reward
        score_reward = (self.score - old_score) * 0.1

        # 2. Reward for the change in the number of empty cells
        empty_before = np.count_nonzero(original_board == 0)
        empty_after = np.count_nonzero(self.board == 0)
        empty_reward = (empty_after - empty_before) * 0.15

        # 3. Reward for a new maximum tile
        max_before = np.max(original_board)
        max_after = np.max(self.board)
        max_tile_reward = 0
        if max_after > max_before:
            max_tile_reward = np.log2(max_after) * 0.2

        # 4. Merge reward (encourages merging)
        merge_reward = 0
        if self.score - old_score > 0:
            merge_reward = np.log2(self.score - old_score) * 0.1

        # 5. Monotonicity penalty (encourages ordered rows/columns).
        # calculate_monotonicity_penalty() returns a non-negative disorder
        # measure, so it must enter the total with a minus sign.
        monotonicity_penalty = -self.calculate_monotonicity_penalty() * 0.01

        # 6. Game-over penalty
        game_over_penalty = 0
        if self.game_over:
            game_over_penalty = -10

        # 7. Smoothness reward (encourages neighbors with close values)
        smoothness_reward = self.calculate_smoothness() * 0.01

        # Total reward
        total_reward = (
            score_reward +
            empty_reward +
            max_tile_reward +
            merge_reward +
            smoothness_reward +
            monotonicity_penalty +
            game_over_penalty
        )

        return total_reward

    def calculate_monotonicity_penalty(self):
        """Disorder measure across rows (lower is better)."""
        penalty = 0
        for i in range(self.size):
            for j in range(self.size - 1):
                if self.board[i][j] > self.board[i][j+1]:
                    penalty += self.board[i][j] - self.board[i][j+1]
                else:
                    penalty += self.board[i][j+1] - self.board[i][j]
        return penalty

    def calculate_smoothness(self):
        """Smoothness measure (higher is better; always <= 0)."""
        smoothness = 0
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] != 0:
                    value = np.log2(self.board[i][j])
                    # Check the neighbor to the right
                    if j < self.size - 1 and self.board[i][j+1] != 0:
                        neighbor_value = np.log2(self.board[i][j+1])
                        smoothness -= abs(value - neighbor_value)
                    # Check the neighbor below
                    if i < self.size - 1 and self.board[i+1][j] != 0:
                        neighbor_value = np.log2(self.board[i+1][j])
                        smoothness -= abs(value - neighbor_value)
        return smoothness

    def check_game_over(self):
        # Any empty cell left?
        if np.any(self.board == 0):
            self.game_over = False
            return

        # Any mergeable neighbors horizontally or vertically?
        for i in range(self.size):
            for j in range(self.size - 1):
                if self.board[i][j] == self.board[i][j+1]:
                    self.game_over = False
                    return

        for j in range(self.size):
            for i in range(self.size - 1):
                if self.board[i][j] == self.board[i+1][j]:
                    self.game_over = False
                    return

        self.game_over = True

    def get_state(self):
        """Improved state representation."""
        # Four-channel state representation
        state = np.zeros((4, self.size, self.size), dtype=np.float32)

        # Channel 0: log2 of each tile value (normalized)
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] > 0:
                    state[0, i, j] = np.log2(self.board[i][j]) / 16.0  # supports up to 65536 (2^16)

        # Channel 1: empty-cell indicator
        state[1] = (self.board == 0).astype(np.float32)

        # Channel 2: mergeable-neighbor indicator
        for i in range(self.size):
            for j in range(self.size):
                if self.board[i][j] > 0:
                    # Check the cell to the right
                    if j < self.size - 1 and self.board[i][j] == self.board[i][j+1]:
                        state[2, i, j] = 1.0
                        state[2, i, j+1] = 1.0
                    # Check the cell below
                    if i < self.size - 1 and self.board[i][j] == self.board[i+1][j]:
                        state[2, i, j] = 1.0
                        state[2, i+1, j] = 1.0

        # Channel 3: positions of the maximum tile (one-hot)
        max_value = np.max(self.board)
        if max_value > 0:
            max_positions = np.argwhere(self.board == max_value)
            for pos in max_positions:
                state[3, pos[0], pos[1]] = 1.0

        return state

    def get_valid_moves(self):
        """More efficient valid-move detection."""
        valid_moves = []
        # slide() adds merge points to self.score as a side effect, so snapshot
        # the score and restore it after these trial slides.
        saved_score = self.score

        # Is moving up valid?
        for j in range(self.size):
            column = self.board[:, j].copy()
            new_column, _ = self.slide(column)
            if not np.array_equal(new_column, self.board[:, j]):
                valid_moves.append(0)
                break

        # Is moving right valid?
        for i in range(self.size):
            row = self.board[i, :].copy()[::-1]
            new_row, _ = self.slide(row)
            if not np.array_equal(new_row[::-1], self.board[i, :]):
                valid_moves.append(1)
                break

        # Is moving down valid?
        for j in range(self.size):
            column = self.board[::-1, j].copy()
            new_column, _ = self.slide(column)
            if not np.array_equal(new_column[::-1], self.board[:, j]):
                valid_moves.append(2)
                break

        # Is moving left valid?
        for i in range(self.size):
            row = self.board[i, :].copy()
            new_row, _ = self.slide(row)
            if not np.array_equal(new_row, self.board[i, :]):
                valid_moves.append(3)
                break

        self.score = saved_score
        return valid_moves

# Improved deep Q-network (Dueling DQN architecture)
class DQN(nn.Module):
    def __init__(self, input_channels, output_size):
        super(DQN, self).__init__()
        self.input_channels = input_channels

        # Convolutional layers
        self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1)

        # Dueling DQN architecture
        # Value stream
        self.value_conv = nn.Conv2d(128, 4, kernel_size=1)
        self.value_fc1 = nn.Linear(4 * 4 * 4, 128)
        self.value_fc2 = nn.Linear(128, 1)

        # Advantage stream
        self.advantage_conv = nn.Conv2d(128, 16, kernel_size=1)
        self.advantage_fc1 = nn.Linear(16 * 4 * 4, 128)
        self.advantage_fc2 = nn.Linear(128, output_size)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))

        # Value stream
        value = F.relu(self.value_conv(x))
        value = value.view(value.size(0), -1)
        value = F.relu(self.value_fc1(value))
        value = self.value_fc2(value)

        # Advantage stream
        advantage = F.relu(self.advantage_conv(x))
        advantage = advantage.view(advantage.size(0), -1)
        advantage = F.relu(self.advantage_fc1(advantage))
        advantage = self.advantage_fc2(advantage)

        # Combine the value and advantage streams
        q_values = value + advantage - advantage.mean(dim=1, keepdim=True)
        return q_values

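# Note on the dueling head above: for each action a it computes
#   Q(s, a) = V(s) + A(s, a) - mean_a' A(s, a'),
# i.e. the advantage stream is re-centered to zero mean, which keeps the
# V/A decomposition identifiable (Wang et al., 2016).
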
# Experience replay buffer (with priorities)
class PrioritizedReplayBuffer:
    def __init__(self, capacity, alpha=0.6):
        self.capacity = capacity
        self.alpha = alpha
        self.buffer = []
        self.priorities = np.zeros(capacity)
        self.pos = 0
        self.size = 0

    def push(self, state, action, reward, next_state, done):
        # New transitions enter with the current maximum priority
        max_priority = self.priorities.max() if self.buffer else 1.0

        if len(self.buffer) < self.capacity:
            self.buffer.append((state, action, reward, next_state, done))
        else:
            self.buffer[self.pos] = (state, action, reward, next_state, done)

        self.priorities[self.pos] = max_priority
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size, beta=0.4):
        if self.size == 0:
            # Return a single None so callers can test `sample is None`
            # (the original returned a (None, None, None) tuple, which that
            # test never matched).
            return None

        priorities = self.priorities[:self.size]
        probs = priorities ** self.alpha
        probs /= probs.sum()

        indices = np.random.choice(self.size, batch_size, p=probs)
        samples = [self.buffer[idx] for idx in indices]

        # Importance-sampling weights
        weights = (self.size * probs[indices]) ** (-beta)
        weights /= weights.max()
        weights = np.array(weights, dtype=np.float32)

        states, actions, rewards, next_states, dones = zip(*samples)
        return (
            torch.tensor(np.array(states)),
            torch.tensor(actions, dtype=torch.long),
            torch.tensor(rewards, dtype=torch.float),
            torch.tensor(np.array(next_states)),
            torch.tensor(dones, dtype=torch.float),
            indices,
            torch.tensor(weights)
        )

    def update_priorities(self, indices, priorities):
        # Make sure priorities is index-aligned with indices
        if isinstance(priorities, np.ndarray) and priorities.ndim == 1:
            for idx, priority in zip(indices, priorities):
                self.priorities[idx] = priority
        else:
            # Handle the scalar case (should not normally happen)
            if not isinstance(priorities, (list, np.ndarray)):
                priorities = [priorities] * len(indices)
            for idx, priority in zip(indices, priorities):
                self.priorities[idx] = priority

    def __len__(self):
        return self.size

# 改进的DQN智能体
|
| 401 |
+
class DQNAgent:
|
| 402 |
+
def __init__(self, input_channels, action_size, lr=3e-4, gamma=0.99,
|
| 403 |
+
epsilon_start=1.0, epsilon_end=0.01, epsilon_decay=0.999,
|
| 404 |
+
target_update_freq=1000, batch_size=128):
|
| 405 |
+
self.input_channels = input_channels
|
| 406 |
+
self.action_size = action_size
|
| 407 |
+
self.gamma = gamma
|
| 408 |
+
self.epsilon = epsilon_start
|
| 409 |
+
self.epsilon_end = epsilon_end
|
| 410 |
+
self.epsilon_decay = epsilon_decay
|
| 411 |
+
self.batch_size = batch_size
|
| 412 |
+
self.target_update_freq = target_update_freq
|
| 413 |
+
|
| 414 |
+
# 主网络和目标网络
|
| 415 |
+
self.policy_net = DQN(input_channels, action_size).to(device)
|
| 416 |
+
self.target_net = DQN(input_channels, action_size).to(device)
|
| 417 |
+
self.target_net.load_state_dict(self.policy_net.state_dict())
|
| 418 |
+
self.target_net.eval()
|
| 419 |
+
|
| 420 |
+
self.optimizer = optim.Adam(self.policy_net.parameters(), lr=lr, weight_decay=1e-5)
|
| 421 |
+
self.memory = PrioritizedReplayBuffer(50000)
|
| 422 |
+
self.steps_done = 0
|
| 423 |
+
self.loss_fn = nn.SmoothL1Loss(reduction='none')
|
| 424 |
+
|
| 425 |
+
def select_action(self, state, valid_moves):
|
| 426 |
+
self.steps_done += 1
|
| 427 |
+
self.epsilon = max(self.epsilon_end, self.epsilon * self.epsilon_decay)
|
| 428 |
+
|
| 429 |
+
if random.random() < self.epsilon:
|
| 430 |
+
# 随机选择有效动作
|
| 431 |
+
return random.choice(valid_moves)
|
| 432 |
+
else:
|
| 433 |
+
# 使用策略网络选择动作
|
| 434 |
+
with torch.no_grad():
|
| 435 |
+
state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
|
| 436 |
+
q_values = self.policy_net(state_tensor).cpu().numpy().flatten()
|
| 437 |
+
|
| 438 |
+
# 只考虑有效动作
|
| 439 |
+
valid_q_values = np.full(self.action_size, -np.inf)
|
| 440 |
+
for move in valid_moves:
|
| 441 |
+
valid_q_values[move] = q_values[move]
|
| 442 |
+
|
| 443 |
+
return np.argmax(valid_q_values)
|
| 444 |
+
|
| 445 |
+
def optimize_model(self, beta=0.4):
|
| 446 |
+
if len(self.memory) < self.batch_size:
|
| 447 |
+
return 0
|
| 448 |
+
|
| 449 |
+
# 从回放缓冲区采样
|
| 450 |
+
sample = self.memory.sample(self.batch_size, beta)
|
| 451 |
+
if sample is None:
|
| 452 |
+
return 0
|
| 453 |
+
|
| 454 |
+
states, actions, rewards, next_states, dones, indices, weights = sample
|
| 455 |
+
|
| 456 |
+
states = states.to(device)
|
| 457 |
+
actions = actions.to(device)
|
| 458 |
+
rewards = rewards.to(device)
|
| 459 |
+
next_states = next_states.to(device)
|
| 460 |
+
dones = dones.to(device)
|
| 461 |
+
weights = weights.to(device)
|
| 462 |
+
|
| 463 |
+
# 计算当前Q值
|
| 464 |
+
current_q = self.policy_net(states).gather(1, actions.unsqueeze(1)).squeeze()
|
| 465 |
+
|
| 466 |
+
# 计算目标Q值(Double DQN)
|
| 467 |
+
with torch.no_grad():
|
| 468 |
+
next_actions = self.policy_net(next_states).max(1)[1]
|
| 469 |
+
next_q = self.target_net(next_states).gather(1, next_actions.unsqueeze(1)).squeeze()
|
| 470 |
+
target_q = rewards + (1 - dones) * self.gamma * next_q
|
| 471 |
+
|
| 472 |
+
# 计算损失
|
| 473 |
+
losses = self.loss_fn(current_q, target_q)
|
| 474 |
+
loss = (losses * weights).mean()
|
| 475 |
+
|
| 476 |
+
# 更新优先级(使用每个样本的损失绝对值)
|
| 477 |
+
with torch.no_grad():
|
| 478 |
+
priorities = losses.abs().cpu().numpy() + 1e-5
|
| 479 |
+
self.memory.update_priorities(indices, priorities)
|
| 480 |
+
|
| 481 |
+
# 优化模型
|
| 482 |
+
self.optimizer.zero_grad()
|
| 483 |
+
loss.backward()
|
| 484 |
+
|
| 485 |
+
# 梯度裁剪
|
| 486 |
+
torch.nn.utils.clip_grad_norm_(self.policy_net.parameters(), 10)
|
| 487 |
+
|
| 488 |
+
self.optimizer.step()
|
| 489 |
+
|
| 490 |
+
return loss.item()
|
| 491 |
+
|
| 492 |
+
def update_target_network(self):
|
| 493 |
+
self.target_net.load_state_dict(self.policy_net.state_dict())
|
| 494 |
+
|
| 495 |
+
def save_model(self, path):
|
| 496 |
+
torch.save({
|
| 497 |
+
'policy_net_state_dict': self.policy_net.state_dict(),
|
| 498 |
+
'target_net_state_dict': self.target_net.state_dict(),
|
| 499 |
+
'optimizer_state_dict': self.optimizer.state_dict(),
|
| 500 |
+
'epsilon': self.epsilon,
|
| 501 |
+
'steps_done': self.steps_done
|
| 502 |
+
}, path)
|
| 503 |
+
|
| 504 |
+
def load_model(self, path):
|
| 505 |
+
if not os.path.exists(path):
|
| 506 |
+
print(f"Model file not found: {path}")
|
| 507 |
+
return
|
| 508 |
+
|
| 509 |
+
try:
|
| 510 |
+
# 尝试使用 weights_only=False 加载模型
|
| 511 |
+
checkpoint = torch.load(path, map_location=device, weights_only=False)
|
| 512 |
+
self.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
|
| 513 |
+
self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
|
| 514 |
+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 515 |
+
self.epsilon = checkpoint['epsilon']
|
| 516 |
+
self.steps_done = checkpoint['steps_done']
|
| 517 |
+
self.policy_net.eval()
|
| 518 |
+
self.target_net.eval()
|
| 519 |
+
print(f"Model loaded successfully from {path}")
|
| 520 |
+
except Exception as e:
|
| 521 |
+
print(f"Error loading model: {e}")
|
| 522 |
+
# 尝试使用旧版加载方式作为备选
|
| 523 |
+
try:
|
| 524 |
+
warnings.warn("Trying legacy load method without weights_only")
|
| 525 |
+
checkpoint = torch.load(path, map_location=device)
|
| 526 |
+
self.policy_net.load_state_dict(checkpoint['policy_net_state_dict'])
|
| 527 |
+
self.target_net.load_state_dict(checkpoint['target_net_state_dict'])
|
| 528 |
+
self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 529 |
+
self.epsilon = checkpoint['epsilon']
|
| 530 |
+
self.steps_done = checkpoint['steps_done']
|
| 531 |
+
self.policy_net.eval()
|
| 532 |
+
self.target_net.eval()
|
| 533 |
+
print(f"Model loaded successfully using legacy method")
|
| 534 |
+
except Exception as e2:
|
| 535 |
+
print(f"Failed to load model: {e2}")
|
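# Note on the target computed in optimize_model above: this is Double DQN
# (van Hasselt et al., 2016) -- the online policy_net picks the greedy next
# action while the frozen target_net evaluates it:
#   y = r + gamma * (1 - done) * Q_target(s', argmax_a Q_policy(s', a)),
# which reduces the overestimation bias of vanilla DQN.
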
# Training function (with progress logging)
def train_agent(agent, env, episodes=5000, save_path='models/dqn_2048.pth',
                checkpoint_path='models/checkpoint.pth', resume=False, start_episode=0):
    # Create the directory for saved models
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # Training metrics
    scores = []
    max_tiles = []
    avg_scores = []
    losses = []
    best_score = 0
    best_max_tile = 0

    # When resuming, restore the training state
    if resume and os.path.exists(checkpoint_path):
        try:
            # Load the checkpoint with weights_only=False
            checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
            scores = checkpoint['scores']
            max_tiles = checkpoint['max_tiles']
            avg_scores = checkpoint['avg_scores']
            losses = checkpoint['losses']
            best_score = checkpoint.get('best_score', 0)
            best_max_tile = checkpoint.get('best_max_tile', 0)
            print(f"Resuming training from episode {start_episode}...")
        except Exception as e:
            print(f"Error loading checkpoint: {e}")
            print("Starting training from scratch...")
            resume = False

    if not resume:
        start_episode = 0

    # Progress bar via tqdm
    progress_bar = tqdm(range(start_episode, episodes), desc="Training")

    for episode in progress_bar:
        state = env.reset()
        total_reward = 0
        done = False
        steps = 0
        episode_loss = 0
        loss_count = 0

        while not done:
            valid_moves = env.get_valid_moves()
            if not valid_moves:
                done = True
                continue

            action = agent.select_action(state, valid_moves)
            next_state, reward, done = env.move(action)
            total_reward += reward

            agent.memory.push(state, action, reward, next_state, done)
            state = next_state

            # Optimize the model
            loss = agent.optimize_model(beta=min(1.0, episode / 1000))
            if loss > 0:
                episode_loss += loss
                loss_count += 1

            # Periodically refresh the target network
            if agent.steps_done % agent.target_update_freq == 0:
                agent.update_target_network()

            steps += 1

        # Record the score and the largest tile
        score = env.score
        max_tile = np.max(env.board)
        scores.append(score)
        max_tiles.append(max_tile)

        # Average loss over the episode
        avg_loss = episode_loss / loss_count if loss_count > 0 else 0
        losses.append(avg_loss)

        # Update the best-so-far records
        if score > best_score:
            best_score = score
            agent.save_model(save_path.replace('.pth', '_best_score.pth'))
        if max_tile > best_max_tile:
            best_max_tile = max_tile
            agent.save_model(save_path.replace('.pth', '_best_tile.pth'))

        # Mean score over the last 100 episodes
        recent_scores = scores[-100:] if len(scores) >= 100 else scores
        avg_score = np.mean(recent_scores)
        avg_scores.append(avg_score)

        # Update the progress-bar description
        progress_bar.set_description(
            f"Ep {episode+1}/{episodes} | "
            f"Score: {score} (Avg: {avg_score:.1f}) | "
            f"Max Tile: {max_tile} | "
            f"Loss: {avg_loss:.4f} | "
            f"Epsilon: {agent.epsilon:.4f}"
        )

        # Periodically save the model and the training state
        if (episode + 1) % 100 == 0:
            agent.save_model(save_path)

            # Save the training state
            checkpoint = {
                'scores': scores,
                'max_tiles': max_tiles,
                'avg_scores': avg_scores,
                'losses': losses,
                'best_score': best_score,
                'best_max_tile': best_max_tile,
                'episode': episode + 1,
                'steps_done': agent.steps_done,
                'epsilon': agent.epsilon
            }
            try:
                torch.save(checkpoint, checkpoint_path)
            except Exception as e:
                print(f"Error saving checkpoint: {e}")

            # Plot the training curves
            if episode > 100:  # make sure there is enough data
                plt.figure(figsize=(12, 8))

                # Score curve
                plt.subplot(2, 2, 1)
                plt.plot(scores, label='Score')
                plt.plot(avg_scores, label='Avg Score (100 eps)')
                plt.xlabel('Episode')
                plt.ylabel('Score')
                plt.title('Training Scores')
                plt.legend()

                # Max-tile curve
                plt.subplot(2, 2, 2)
                plt.plot(max_tiles, 'g-')
                plt.xlabel('Episode')
                plt.ylabel('Max Tile')
                plt.title('Max Tile Achieved')

                # Loss curve
                plt.subplot(2, 2, 3)
                plt.plot(losses, 'r-')
                plt.xlabel('Episode')
                plt.ylabel('Loss')
                plt.title('Training Loss')

                # Score histogram
                plt.subplot(2, 2, 4)
                plt.hist(scores, bins=20, alpha=0.7)
                plt.xlabel('Score')
                plt.ylabel('Frequency')
                plt.title('Score Distribution')

                plt.tight_layout()
                plt.savefig('training_progress.png')
                plt.close()

    # Save the final model
    agent.save_model(save_path)

    return scores, max_tiles, losses

# Inference function (with console visualization)
def play_with_model(agent, env, episodes=3):
    agent.epsilon = 0.001  # near-greedy inference

    for episode in range(episodes):
        state = env.reset()
        done = False
        steps = 0

        print(f"\nEpisode {episode+1}")
        print("Initial Board:")
        print(env.board)

        while not done:
            valid_moves = env.get_valid_moves()
            if not valid_moves:
                done = True
                print("No valid moves left!")
                continue

            # Choose an action
            with torch.no_grad():
                state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device)
                q_values = agent.policy_net(state_tensor).cpu().numpy().flatten()

            # Only consider valid actions (one slot per action; the original
            # used env.size here, which only works because the board happens
            # to be 4x4)
            valid_q_values = np.full(agent.action_size, -np.inf)
            for move in valid_moves:
                valid_q_values[move] = q_values[move]

            action = np.argmax(valid_q_values)

            # Execute the action
            next_state, reward, done = env.move(action)
            state = next_state
            steps += 1

            # Render the game
            print(f"\nStep {steps}: Action {['Up', 'Right', 'Down', 'Left'][action]}")
            print(env.board)
            print(f"Score: {env.score}, Max Tile: {np.max(env.board)}")
            # Also append the result to result.txt (the with-block closes the
            # file; no explicit close() is needed)
            with open("result.txt", "a") as f:
                f.write(f"Episode {episode+1}, Step {steps}, Action {['Up', 'Right', 'Down', 'Left'][action]}, Score: {env.score}, Max Tile: {np.max(env.board)}\n{env.board}\n")

        print(f"\nGame Over! Final Score: {env.score}, Max Tile: {np.max(env.board)}")

# Main program
if __name__ == "__main__":
    args = {"train": 0, "resume": 0, "play": 1, "episodes": 50000}
    env = Game2048(size=4)
    input_channels = 4  # number of channels in the state representation
    action_size = 4     # up, right, down, left

    agent = DQNAgent(
        input_channels,
        action_size,
        lr=1e-4,
        epsilon_decay=0.999,  # slower decay
        target_update_freq=1000,
        batch_size=256
    )

    # Train the model
    if args.get('train') or args.get('resume'):
        print("Starting training...")

        # When resuming, load the checkpoint
        start_episode = 0
        checkpoint_path = 'models/checkpoint.pth'
        if args.get('resume') and os.path.exists(checkpoint_path):
            try:
                # Load the checkpoint with weights_only=False
                checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
                start_episode = checkpoint.get('episode', 0)
                agent.steps_done = checkpoint.get('steps_done', 0)
                agent.epsilon = checkpoint.get('epsilon', agent.epsilon)
            except Exception as e:
                print(f"Error loading checkpoint: {e}")
                print("Starting training from scratch...")
                start_episode = 0

        scores, max_tiles, losses = train_agent(
            agent,
            env,
            episodes=args.get('episodes'),
            save_path='models/dqn_2048.pth',
            checkpoint_path=checkpoint_path,
            resume=args.get('resume'),
            start_episode=start_episode
        )
        print("Training completed!")

        # Plot the final training results
        plt.figure(figsize=(15, 10))

        plt.subplot(3, 1, 1)
        plt.plot(scores)
        plt.title('Scores per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Score')

        plt.subplot(3, 1, 2)
        plt.plot(max_tiles)
        plt.title('Max Tile per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Max Tile')

        plt.subplot(3, 1, 3)
        plt.plot(losses)
        plt.title('Training Loss per Episode')
        plt.xlabel('Episode')
        plt.ylabel('Loss')

        plt.tight_layout()
        plt.savefig('final_training_results.png')
        plt.close()

    # Load the model and run inference
    if args.get('play'):
        model_path = 'models/dqn_2048_best_tile.pth'
        if not os.path.exists(model_path):
            model_path = 'models/dqn_2048.pth'

        if os.path.exists(model_path):
            agent.load_model(model_path)
            print("Playing with trained model...")
            if os.path.exists("result.txt"):
                os.remove("result.txt")  # clear the previous log
            play_with_model(agent, env, episodes=1)
        else:
            print("No trained model found. Please train the model first.")
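
To train rather than watch the saved model play, flip the flags in the args dict at the top of the main block (a sketch of the two common configurations; "episodes" is the total episode budget, and resume continues from models/checkpoint.pth when it exists):

args = {"train": 1, "resume": 0, "play": 0, "episodes": 50000}  # train from scratch
args = {"train": 0, "resume": 1, "play": 0, "episodes": 50000}  # resume from the checkpoint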