File size: 5,698 Bytes
b570cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import numpy as np
import torch
import heapq
import os
from model.Cube import TARGET_STATE
from model.DNN import DNN

# Use the same float32 matmul precision setting ("medium") as train.py.
torch.set_float32_matmul_precision(precision='medium')


def load_model(model_path, device):
    """
    Load the pretrained DNN heuristic and prepare it for inference.

    Args:
        model_path: path to a saved state dict (loaded with
            ``weights_only=True``).
        device: device to load onto and run on (e.g. ``'cpu'`` or ``'cuda'``).

    Returns:
        The model on ``device`` in eval mode. It is a TorchScript-traced module
        if tracing succeeded, otherwise the plain ``DNN`` module.

    Raises:
        Any exception raised while loading the checkpoint or the state dict.
        It is printed and then re-raised.
    """
    input_dim = 54 * 6  # 54 stickers, 6 colours each, one-hot encoded
    model = DNN(input_dim, num_residual_blocks=4)

    prefix = '_orig_mod.'
    try:
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        # Strip a leading '_orig_mod.' (the prefix a torch.compile()-wrapped
        # module puts on its keys) so names match the plain module. Keys
        # without the prefix are kept as-is instead of being dropped.
        model_weights = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint.items()
        }
        model.load_state_dict(model_weights)
    except Exception as e:
        print(f"加载模型失败: {str(e)}")
        raise

    model = model.to(device)

    # Switch to eval mode before tracing so the trace captures inference behaviour.
    model.eval()

    # Optional, best-effort TorchScript tracing; on failure keep the eager model.
    try:
        example_input = torch.randn(1, input_dim).to(device)
        model = torch.jit.trace(model, example_input)
        print("已使用TorchScript优化模型")
    except Exception as e:
        print(f"TorchScript优化失败: {str(e)}")

    return model


def state_to_one_hot(state):
    """
    Encode a cube state as a flat one-hot NumPy vector.

    Each sticker occupies a slot of 6 entries. The entry matching the
    sticker's colour index is set to 1; all others stay 0.

    Args:
        state: iterable of integer colour indices, one per sticker.

    Returns:
        A float ``np.ndarray`` of length ``54 * 6``.
    """
    colours_per_sticker = 6
    encoded = np.zeros(54 * colours_per_sticker)
    slot_start = 0
    for colour in state:
        encoded[slot_start + colour] = 1
        slot_start += colours_per_sticker
    return encoded


def h(state, model):
    """
    Heuristic: the model's predicted distance from ``state`` to the goal.

    Args:
        state: model input tensor. It is 1-D for a single state, or 2-D with
            one state per row (in this file, one-hot encoded cube states).
        model: callable that maps a 2-D batch to one prediction per row.

    Returns:
        A float for a 1-D ``state``. A list of floats, one per row, for a 2-D
        ``state``; a one-row batch also yields a length-1 list.
    """
    with torch.no_grad():
        if state.ndim == 2:
            prediction = model(state)
            # squeeze() turns a single-row batch into a 0-d tensor, and
            # tolist() would then return a bare float. atleast_1d keeps it a
            # list so callers can always index the result.
            return torch.atleast_1d(prediction.squeeze()).tolist()
        prediction = model(state.unsqueeze(0))
        return prediction.item()


def _encode_states(states, device):
    """One-hot encode integer sticker colours (6 classes) as float model input.

    A 1-D state becomes a flat tensor. A 2-D batch (one state per row) becomes
    a ``(rows, stickers * 6)`` tensor.
    """
    indices = torch.tensor(states, device=device).long()
    return torch.nn.functional.one_hot(indices, num_classes=6).float().flatten(start_dim=-2)


def _reconstruct_path(came_from, goal_state, goal_tuple):
    """Follow ``came_from`` back from the goal and return (actions, states) start-to-goal.

    Every state except the goal is a tuple. The goal is the ndarray passed in
    as ``goal_state``.
    """
    action_path = []
    state_path = [goal_state]
    node = goal_tuple
    while node in came_from:
        node, action = came_from[node]
        action_path.append(action)
        state_path.append(node)
    return action_path[::-1], state_path[::-1]


def a_star_search(initial_state, model, cube, lam=0.6, max_iterations=200, N=1000):
    """
    Solve the cube with a batched, weighted A* search.

    Each iteration pops up to ``N`` lowest-priority entries from the open set
    and expands them with every move in ``cube.moves``. All neighbours whose
    path cost improved are then scored by the model in one batch. A state's
    priority is ``lam * g + h``, where ``g`` is its move count from the start
    and ``h`` is the model prediction.

    Args:
        initial_state: array of sticker colour indices for the start state.
        model: heuristic network; must expose ``parameters()`` (used to pick
            the device) and be accepted by ``h``.
        cube: object with a ``moves`` mapping (its keys are the actions) and
            ``apply_action(state, action)`` returning the next state.
        lam: weight on the path cost ``g``.
        max_iterations: maximum number of batch iterations.
        N: maximum number of entries popped per iteration.

    Returns:
        ``(action_path, state_path)`` in start-to-goal order on success. In
        ``state_path`` the intermediate states are tuples and the goal is an
        ndarray. If the start is already solved, returns
        ``([], [initial_state])``. Returns ``(None, None)`` when the open set
        empties or ``max_iterations`` is reached.
    """
    if np.array_equal(initial_state, TARGET_STATE):
        return [], [initial_state]

    # Loop-invariant: look up the model's device once.
    device = next(model.parameters()).device

    start = tuple(initial_state)
    g_score = {start: 0}
    came_from = {}
    closed_set = set()
    start_h = h(_encode_states(initial_state, device), model)
    open_set = [(lam * g_score[start] + start_h, start)]

    iterations = 0
    while open_set and iterations < max_iterations:
        iterations += 1

        # Pop a batch. It may contain stale duplicates (lazy deletion); those
        # are skipped by the closed-set check below.
        batch = [heapq.heappop(open_set)[1] for _ in range(min(N, len(open_set)))]

        improved = []
        for state_tuple in batch:
            state = np.array(state_tuple)

            if np.array_equal(state, TARGET_STATE):
                return _reconstruct_path(came_from, state, state_tuple)

            if state_tuple in closed_set:
                continue
            closed_set.add(state_tuple)

            tentative_g = g_score[state_tuple] + 1
            for action in cube.moves.keys():
                next_state = cube.apply_action(state, action)
                next_tuple = tuple(next_state)
                if tentative_g < g_score.get(next_tuple, float('inf')):
                    # A cheaper path reopens a closed state (no-op if not closed).
                    closed_set.discard(next_tuple)
                    came_from[next_tuple] = (state_tuple, action)
                    g_score[next_tuple] = tentative_g
                    improved.append(next_state)

        if improved:
            unique_states = np.unique(np.stack(improved), axis=0)
            scores = h(_encode_states(unique_states, device), model)
            if not isinstance(scores, list):
                # Guard against h() returning a bare float for a one-row batch.
                scores = [scores]
            for row, score in zip(unique_states, scores):
                row_tuple = tuple(row)
                heapq.heappush(open_set, (lam * g_score[row_tuple] + score, row_tuple))

    return None, None  # no solution found