Spaces:
Running
Running
| import gradio as gr | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from game2048 import Game2048 | |
| # 检测可用设备 | |
| if torch.cuda.is_available(): | |
| device = torch.device("cuda") | |
| elif torch.xpu.is_available(): | |
| device = torch.device("xpu") | |
| else: | |
| device = torch.device("cpu") | |
| # 创建游戏实例 | |
| game = Game2048(size=4) | |
| # 方块颜色映射(根据数字值) | |
| TILE_COLORS = { | |
| 0: "#cdc1b4", # 空白格子 | |
| 2: "#eee4da", # 2 | |
| 4: "#ede0c8", # 4 | |
| 8: "#f2b179", # 8 | |
| 16: "#f59563", # 16 | |
| 32: "#f67c5f", # 32 | |
| 64: "#f65e3b", # 64 | |
| 128: "#edcf72", # 128 | |
| 256: "#edcc61", # 256 | |
| 512: "#edc850", # 512 | |
| 1024: "#edc53f", # 1024 | |
| 2048: "#edc22e", # 2048 | |
| 4096: "#3c3a32", # 4096+, | |
| 8192: "#3c3a32", # 8192+ | |
| 16384: "#3c3a32", # 16384+ | |
| } | |
| # 文本颜色映射(根据背景深浅) | |
| TEXT_COLORS = { | |
| 0: "#776e65", # 空白格子 | |
| 2: "#776e65", # 2 | |
| 4: "#776e65", # 4 | |
| 8: "#f9f6f2", # 8+ | |
| 16: "#f9f6f2", # 16+ | |
| 32: "#f9f6f2", # 32+ | |
| 64: "#f9f6f2", # 64+ | |
| 128: "#f9f6f2", # 128+ | |
| 256: "#f9f6f2", # 256+ | |
| 512: "#f9f6f2", # 512+ | |
| 1024: "#f9f6f2", # 1024+ | |
| 2048: "#f9f6f2", # 2048+ | |
| 4096: "#f9f6f2", # 4096+ | |
| 8192: "#f9f6f2", # 8192+ | |
| 16384: "#f9f6f2", # 16384+ | |
| } | |
| # 定义DQN网络结构(与训练时相同) | |
| class DQN(nn.Module): | |
| def __init__(self, input_channels, output_size): | |
| super(DQN, self).__init__() | |
| self.input_channels = input_channels | |
| # 卷积层 | |
| self.conv1 = nn.Conv2d(input_channels, 128, kernel_size=3, padding=1) | |
| self.conv2 = nn.Conv2d(128, 128, kernel_size=3, padding=1) | |
| self.conv3 = nn.Conv2d(128, 128, kernel_size=3, padding=1) | |
| # Dueling DQN架构 | |
| # 价值流 | |
| self.value_conv = nn.Conv2d(128, 4, kernel_size=1) | |
| self.value_fc1 = nn.Linear(4 * 4 * 4, 128) | |
| self.value_fc2 = nn.Linear(128, 1) | |
| # 优势流 | |
| self.advantage_conv = nn.Conv2d(128, 16, kernel_size=1) | |
| self.advantage_fc1 = nn.Linear(16 * 4 * 4, 128) | |
| self.advantage_fc2 = nn.Linear(128, output_size) | |
| def forward(self, x): | |
| x = F.relu(self.conv1(x)) | |
| x = F.relu(self.conv2(x)) | |
| x = F.relu(self.conv3(x)) | |
| # 价值流 | |
| value = F.relu(self.value_conv(x)) | |
| value = value.view(value.size(0), -1) | |
| value = F.relu(self.value_fc1(value)) | |
| value = self.value_fc2(value) | |
| # 优势流 | |
| advantage = F.relu(self.advantage_conv(x)) | |
| advantage = advantage.view(advantage.size(0), -1) | |
| advantage = F.relu(self.advantage_fc1(advantage)) | |
| advantage = self.advantage_fc2(advantage) | |
| # 合并价值流和优势流 | |
| q_values = value + advantage - advantage.mean(dim=1, keepdim=True) | |
| return q_values | |
| # 加载模型(支持CUDA) | |
| def load_model(model_path, device): | |
| model = DQN(4, 4).to(device) # 将模型移动到指定设备 | |
| try: | |
| # 尝试加载模型 | |
| checkpoint = torch.load(model_path, map_location=device) | |
| # 检查检查点是否包含完整的模型状态 | |
| if 'policy_net_state_dict' in checkpoint: | |
| model.load_state_dict(checkpoint['policy_net_state_dict']) | |
| else: | |
| # 如果检查点不包含policy_net_state_dict,尝试直接加载 | |
| model.load_state_dict(checkpoint) | |
| model.eval() | |
| print(f"模型成功加载到: {device}") | |
| return model | |
| except Exception as e: | |
| print(f"模型加载失败: {e}") | |
| # 尝试备选模型路径 | |
| alt_path = model_path.replace('_best_tile', '') | |
| try: | |
| checkpoint = torch.load(alt_path, map_location=device) | |
| if 'policy_net_state_dict' in checkpoint: | |
| model.load_state_dict(checkpoint['policy_net_state_dict']) | |
| else: | |
| model.load_state_dict(checkpoint) | |
| model.eval() | |
| print(f"备选模型成功加载: {alt_path}") | |
| return model | |
| except Exception as e2: | |
| print(f"备选模型加载失败: {e2}") | |
| return None | |
| # 尝试加载模型 | |
| model_paths = [ | |
| "models/dqn_2048_best_tile.pth", | |
| "models/dqn_2048.pth", | |
| "dqn_2048_best_tile.pth", | |
| "dqn_2048.pth" | |
| ] | |
| model = None | |
| for path in model_paths: | |
| model = load_model(path, device) | |
| if model: | |
| break | |
| if not model: | |
| print("警告: 未加载任何模型,AI功能将不可用") | |
| def render_board(board): | |
| html = "<div style='background-color:#bbada0; padding:10px; border-radius:6px;'>" | |
| html += "<table style='border-spacing:10px; border-collapse:separate;'>" | |
| for i in range(game.size): | |
| html += "<tr>" | |
| for j in range(game.size): | |
| value = board[i][j] | |
| color = TILE_COLORS.get(value, "#3c3a32") # 默认深色 | |
| text_color = TEXT_COLORS.get(value, "#f9f6f2") # 默认浅色 | |
| font_size = "36px" if value < 100 else "30px" if value < 1000 else "24px" | |
| html += f""" | |
| <td style='background-color:{color}; | |
| width:80px; height:80px; | |
| border-radius:4px; | |
| text-align:center; | |
| font-weight:bold; | |
| font-size:{font_size}; | |
| color:{text_color};'> | |
| {value if value > 0 else ''} | |
| </td> | |
| """ | |
| html += "</tr>" | |
| html += "</table></div>" | |
| return html | |
| def make_move(direction): | |
| """执行移动操作并更新界面""" | |
| direction_names = ["上", "右", "下", "左"] | |
| # 执行移动 | |
| new_board, game_over = game.move(direction) | |
| # 渲染棋盘 | |
| board_html = render_board(new_board) | |
| # 更新状态信息 | |
| status = f"<b>移动方向:</b> {direction_names[direction]}" | |
| status += f"<br><b>当前分数:</b> {game.score}" | |
| status += f"<br><b>最大方块:</b> {np.max(game.board)}" | |
| if game.game_over: | |
| status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>" | |
| status += f"<br><b>最终分数:</b> {game.score}" | |
| return board_html, status | |
| def reset_game(): | |
| """重置游戏""" | |
| global game | |
| game = Game2048(size=4) | |
| board = game.reset() | |
| # 渲染棋盘 | |
| board_html = render_board(board) | |
| # 初始状态信息 | |
| status = "<b>游戏已重置!</b>" | |
| status += f"<br><b>当前分数:</b> {game.score}" | |
| status += f"<br><b>最大方块:</b> {np.max(game.board)}" | |
| return board_html, status | |
| def ai_move(): | |
| """使用AI模型进行一步移动""" | |
| if model is None: | |
| return render_board(game.board), "<b>错误:</b> 未加载AI模型" | |
| # 获取当前状态 | |
| state = game.get_state() | |
| # 获取有效移动 | |
| valid_moves = game.get_valid_moves() | |
| if not valid_moves: | |
| return render_board(game.board), "<b>游戏结束!</b> 没有有效移动" | |
| try: | |
| # 转换状态为模型输入并移动到设备 | |
| state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device) | |
| # 模型预测 | |
| with torch.no_grad(): | |
| q_values = model(state_tensor).cpu().numpy().flatten() | |
| # 只考虑有效动作 | |
| valid_q_values = np.full(4, -np.inf) | |
| for move in valid_moves: | |
| valid_q_values[move] = q_values[move] | |
| # 选择最佳动作 | |
| action = np.argmax(valid_q_values) | |
| # 执行移动 | |
| direction_names = ["上", "右", "下", "左"] | |
| new_board, game_over = game.move(action) | |
| # 渲染棋盘 | |
| board_html = render_board(new_board) | |
| # 更新状态信息 | |
| status = f"<b>AI移动方向:</b> {direction_names[action]}" | |
| status += f"<br><b>当前分数:</b> {game.score}" | |
| status += f"<br><b>最大方块:</b> {np.max(game.board)}" | |
| status += f"<br><b>设备:</b> {'GPU' if device.type == 'cuda' else 'CPU'}" | |
| if game.game_over: | |
| status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>" | |
| status += f"<br><b>最终分数:</b> {game.score}" | |
| return board_html, status | |
| except Exception as e: | |
| print(f"AI移动出错: {e}") | |
| return render_board(game.board), f"<b>错误:</b> AI移动失败 - {str(e)}" | |
| # 创建Gradio界面 | |
| with gr.Blocks(title="2048游戏", theme="soft") as demo: | |
| gr.Markdown("# 🎮 2048游戏") | |
| gr.Markdown(f"当前运行设备: **{'GPU' if device.type == 'cuda' else 'CPU'}**") | |
| gr.Markdown("使用方向键或下方的按钮移动方块,相同数字的方块相撞时会合并!") | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
| board_html = gr.HTML(render_board(game.board)) | |
| status_display = gr.HTML("<b>当前分数:</b> 0<br><b>最大方块:</b> 2") | |
| with gr.Column(): | |
| gr.Markdown("## 手动操作") | |
| with gr.Row(): | |
| up_btn = gr.Button("上 ↑", elem_id="up-btn") | |
| left_btn = gr.Button("左 ←", elem_id="left-btn") | |
| with gr.Row(): | |
| down_btn = gr.Button("下 ↓", elem_id="down-btn") | |
| right_btn = gr.Button("右 →", elem_id="right-btn") | |
| with gr.Row(): | |
| reset_btn = gr.Button("🔄 重置游戏", elem_id="reset-btn") | |
| gr.Markdown("## AI操作") | |
| with gr.Row(): | |
| ai_btn = gr.Button("🤖 AI移动一步", elem_id="ai-btn") | |
| auto_btn = gr.Button("🚀 连续AI模式", elem_id="auto-btn") | |
| # 连接按钮事件 | |
| up_btn.click(lambda: make_move(0), outputs=[board_html, status_display]) | |
| right_btn.click(lambda: make_move(1), outputs=[board_html, status_display]) | |
| down_btn.click(lambda: make_move(2), outputs=[board_html, status_display]) | |
| left_btn.click(lambda: make_move(3), outputs=[board_html, status_display]) | |
| reset_btn.click(reset_game, outputs=[board_html, status_display]) | |
| ai_btn.click(ai_move, outputs=[board_html, status_display]) | |
| # 连续AI模式 | |
| def auto_play(): | |
| """连续AI移动直到游戏结束""" | |
| if model is None: | |
| return render_board(game.board), "<b>错误:</b> 未加载AI模型" | |
| moves = 0 | |
| max_moves = 200 # 防止无限循环 | |
| while not game.game_over and moves < max_moves: | |
| # 获取当前状态 | |
| state = game.get_state() | |
| # 获取有效移动 | |
| valid_moves = game.get_valid_moves() | |
| if not valid_moves: | |
| break | |
| # 转换状态为模型输入并移动到设备 | |
| state_tensor = torch.tensor(state, dtype=torch.float).unsqueeze(0).to(device) | |
| # 模型预测 | |
| with torch.no_grad(): | |
| q_values = model(state_tensor).cpu().numpy().flatten() | |
| # 只考虑有效动作 | |
| valid_q_values = np.full(4, -np.inf) | |
| for move in valid_moves: | |
| valid_q_values[move] = q_values[move] | |
| # 选择最佳动作 | |
| action = np.argmax(valid_q_values) | |
| # 执行移动 | |
| game.move(action) | |
| moves += 1 | |
| # 渲染棋盘 | |
| board_html = render_board(game.board) | |
| # 更新状态信息 | |
| status = f"<b>连续AI完成!</b>" | |
| status += f"<br><b>移动次数:</b> {moves}" | |
| status += f"<br><b>当前分数:</b> {game.score}" | |
| status += f"<br><b>最大方块:</b> {np.max(game.board)}" | |
| status += f"<br><b>设备:</b> {'GPU' if device.type == 'cuda' else 'CPU'}" | |
| if game.game_over: | |
| status += "<br><br><div style='color:#ff0000; font-weight:bold;'>游戏结束!</div>" | |
| status += f"<br><b>最终分数:</b> {game.score}" | |
| return board_html, status | |
| auto_btn.click(auto_play, outputs=[board_html, status_display]) | |
| # 添加键盘快捷键支持 | |
| demo.load( | |
| fn=None, | |
| inputs=None, | |
| outputs=None, | |
| js="""() => { | |
| document.addEventListener('keydown', function(e) { | |
| if (e.key === 'ArrowUp') { | |
| document.getElementById('up-btn').click(); | |
| } else if (e.key === 'ArrowRight') { | |
| document.getElementById('right-btn').click(); | |
| } else if (e.key === 'ArrowDown') { | |
| document.getElementById('down-btn').click(); | |
| } else if (e.key === 'ArrowLeft') { | |
| document.getElementById('left-btn').click(); | |
| } else if (e.key === 'r' || e.key === 'R') { | |
| document.getElementById('reset-btn').click(); | |
| } else if (e.key === 'a' || e.key === 'A') { | |
| document.getElementById('ai-btn').click(); | |
| } else if (e.key === 's' || e.key === 'S') { | |
| document.getElementById('auto-btn').click(); | |
| } | |
| }); | |
| }""" | |
| ) | |
| gr.Markdown("### 📚 使用说明") | |
| gr.Markdown("1. 使用方向键或下方的按钮移动方块") | |
| gr.Markdown("2. 相同数字的方块相撞时会合并") | |
| gr.Markdown("3. **快捷键说明**:") | |
| gr.Markdown(" - ↑/↓/←/→: 移动方块") | |
| gr.Markdown(" - R: 重置游戏") | |
| gr.Markdown(" - A: AI移动一步") | |
| gr.Markdown(" - S: 连续AI模式(自动玩到游戏结束)") | |
| # 启动界面 | |
| if __name__ == "__main__": | |
| demo.launch() |