Spaces:
Sleeping
Sleeping
| import torch | |
| from huggingface_hub import hf_hub_download | |
| # 下载模型文件(会自动缓存到 /root/.cache/huggingface/hub/) | |
| model_path = hf_hub_download( | |
| repo_id="Hanxiaofeng123/Deepcube", # 替换成你自己的仓库 | |
| repo_type="space", | |
| filename="checkpoint/final_model_K_30.pth", | |
| cache_dir="/tmp/hf_cache" # 指定在工作目录下缓存 | |
| ) | |
| import flask | |
| from flask import request, jsonify | |
| import torch | |
| import numpy as np | |
| import os | |
| from config import Config | |
| from model.DNN import DNN | |
| from model.Cube import Cube, TARGET_STATE | |
| from solver_utils import * | |
| # 初始化Flask应用 | |
| app = flask.Flask(__name__, static_folder=None) | |
| app.config['JSON_AS_ASCII'] = False | |
| app.config['DEBUG'] = True | |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # 加载模型和创建Cube对象 | |
| model = load_model(model_path, device) | |
| cube = Cube() | |
| # 初始化状态接口 | |
| def init_state(): | |
| # 初始状态设置为目标状态 | |
| initial_state = TARGET_STATE.copy() | |
| # 生成旋转索引和状态映射 | |
| rotateIdxs_old = {} | |
| rotateIdxs_new = {} | |
| for move_name in cube.moves.keys(): | |
| # 使用Cube类中的实际移动映射 | |
| move_mapping = cube.moves[move_name] | |
| # 构建old到new的映射 | |
| rotateIdxs_old[move_name] = move_mapping.tolist() | |
| rotateIdxs_new[move_name] = list(range(54)) | |
| # 定义状态到特征提取和反向的映射 | |
| # 这里假设状态和特征提取使用相同的顺序 | |
| stateToFE = list(range(54)) | |
| FEToState = list(range(54)) | |
| legalMoves = list(cube.moves.keys()) | |
| response = { | |
| 'state': initial_state.tolist(), | |
| 'rotateIdxs_old': rotateIdxs_old, | |
| 'rotateIdxs_new': rotateIdxs_new, | |
| 'stateToFE': stateToFE, | |
| 'FEToState': FEToState, | |
| 'legalMoves': legalMoves | |
| } | |
| return jsonify(response) | |
| # 求解魔方接口 | |
| def solve(): | |
| try: | |
| data = request.json | |
| if not data or 'state' not in data: | |
| return jsonify({'error': '请求参数错误,缺少state字段'}), 400 | |
| state = np.array(data['state']) | |
| if state.shape != (54,): | |
| return jsonify({'error': 'state参数格式错误,应为长度为54的数组'}), 400 | |
| print("开始求解魔方...") | |
| action_path, solution_state_path = a_star_search(state, model, cube) | |
| if action_path is None: | |
| return jsonify({'error': '未能找到解决方案'}), 404 | |
| # 生成反向动作路径 | |
| solveMoves_rev = [] | |
| for action in action_path: | |
| rev_action = action[:] | |
| # 反转动作方向 | |
| if "inv" in rev_action: | |
| rev_action = rev_action[0] | |
| else: | |
| rev_action += "_inv" | |
| solveMoves_rev.append(rev_action) | |
| print(action_path) | |
| print(solveMoves_rev) | |
| response = { | |
| 'moves': [action for action in action_path], | |
| 'moves_rev': solveMoves_rev, | |
| 'solve_text': action_path | |
| } | |
| return jsonify(response) | |
| except Exception as e: | |
| print(f"求解过程中发生错误: {str(e)}") | |
| return jsonify({'error': f'服务器内部错误: {str(e)}'}), 500 | |
| # 静态文件服务 | |
| def send_static(path): | |
| print("Serving static file:", path) | |
| return flask.send_from_directory('web/deepcube.igb.uci.edu/static', path) | |
| # 主页 | |
| def home(): | |
| return flask.send_from_directory('web/deepcube.igb.uci.edu', 'index.html') | |
| # 处理缺失的heapq模块 | |
| import heapq | |
| if __name__ == '__main__': | |
| # 确保checkpoint目录存在 | |
| if not os.path.exists('checkpoint'): | |
| os.makedirs('checkpoint') | |
| print("创建checkpoint目录,请将模型文件放入该目录") | |
| # 检查模型文件是否存在 | |
| if not os.path.exists(model_path): | |
| print(f"警告:未找到模型文件 {model_path}") | |
| print("请确保模型文件存在于checkpoint目录中") | |
| # 启动服务器 | |
| # 修改为仅监听本地主机 | |
| app.run(host='0.0.0.0', port=7860) |