File size: 4,191 Bytes
f5be61e
 
 
 
 
 
e6564ab
dbb2260
d95a8e3
f5be61e
 
b570cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86543f5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import torch
from huggingface_hub import hf_hub_download

# 下载模型文件(会自动缓存到 /root/.cache/huggingface/hub/)
model_path = hf_hub_download(
    repo_id="Hanxiaofeng123/Deepcube",  # 替换成你自己的仓库
    repo_type="space",
    filename="checkpoint/final_model_K_30.pth",
    cache_dir="/tmp/hf_cache"   # 指定在工作目录下缓存
)

import flask
from flask import request, jsonify
import torch
import numpy as np
import os
from config import Config
from model.DNN import DNN
from model.Cube import Cube, TARGET_STATE
from solver_utils import *

# 初始化Flask应用
app = flask.Flask(__name__, static_folder=None)
app.config['JSON_AS_ASCII'] = False
app.config['DEBUG'] = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 加载模型和创建Cube对象
model = load_model(model_path, device)
cube = Cube()

# 初始化状态接口
@app.route('/initState', methods=['POST'])
def init_state():
    # 初始状态设置为目标状态
    initial_state = TARGET_STATE.copy()

    # 生成旋转索引和状态映射
    rotateIdxs_old = {}
    rotateIdxs_new = {}
    for move_name in cube.moves.keys():
        # 使用Cube类中的实际移动映射
        move_mapping = cube.moves[move_name]
        # 构建old到new的映射
        rotateIdxs_old[move_name] = move_mapping.tolist()
        rotateIdxs_new[move_name] = list(range(54))

    # 定义状态到特征提取和反向的映射
    # 这里假设状态和特征提取使用相同的顺序
    stateToFE = list(range(54))
    FEToState = list(range(54))
    legalMoves = list(cube.moves.keys())

    response = {
        'state': initial_state.tolist(),
        'rotateIdxs_old': rotateIdxs_old,
        'rotateIdxs_new': rotateIdxs_new,
        'stateToFE': stateToFE,
        'FEToState': FEToState,
        'legalMoves': legalMoves
    }

    return jsonify(response)

# 求解魔方接口
@app.route('/solve', methods=['POST'])
def solve():
    try:
        data = request.json
        if not data or 'state' not in data:
            return jsonify({'error': '请求参数错误,缺少state字段'}), 400

        state = np.array(data['state'])
        if state.shape != (54,):
            return jsonify({'error': 'state参数格式错误,应为长度为54的数组'}), 400

        print("开始求解魔方...")
        action_path, solution_state_path = a_star_search(state, model, cube)

        if action_path is None:
            return jsonify({'error': '未能找到解决方案'}), 404

        # 生成反向动作路径
        solveMoves_rev = []
        for action in action_path:
            rev_action = action[:]
            # 反转动作方向
            if "inv" in rev_action:
                rev_action = rev_action[0]
            else:
                rev_action += "_inv"
            solveMoves_rev.append(rev_action)
            
        print(action_path)
        print(solveMoves_rev)

        response = {
            'moves': [action for action in action_path],
            'moves_rev': solveMoves_rev,
            'solve_text': action_path
        }

        return jsonify(response)
    except Exception as e:
        print(f"求解过程中发生错误: {str(e)}")
        return jsonify({'error': f'服务器内部错误: {str(e)}'}), 500

# 静态文件服务
@app.route('/static/<path:path>')
def send_static(path):
    print("Serving static file:", path)
    return flask.send_from_directory('web/deepcube.igb.uci.edu/static', path)

# 主页
@app.route('/')
def home():
    return flask.send_from_directory('web/deepcube.igb.uci.edu', 'index.html')

# 处理缺失的heapq模块
import heapq

if __name__ == '__main__':
    # 确保checkpoint目录存在
    if not os.path.exists('checkpoint'):
        os.makedirs('checkpoint')
        print("创建checkpoint目录,请将模型文件放入该目录")

    # 检查模型文件是否存在
    if not os.path.exists(model_path):
        print(f"警告:未找到模型文件 {model_path}")
        print("请确保模型文件存在于checkpoint目录中")

    # 启动服务器
    # 修改为仅监听本地主机
    app.run(host='0.0.0.0', port=7860)