Spaces:
Sleeping
Sleeping
File size: 5,698 Bytes
b570cf2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import numpy as np
import torch
import heapq
import os
from model.Cube import TARGET_STATE
from model.DNN import DNN
# Set the float32 matrix-multiplication precision to 'medium' so inference
# uses the same setting as train.py. This is a process-wide torch setting
# applied as a side effect of importing this module.
torch.set_float32_matmul_precision('medium')
def load_model(model_path, device):
    """
    Load the pretrained DNN heuristic model.

    Args:
        model_path: Path to a saved state_dict. It is read with
            ``torch.load(..., weights_only=True)``.
        device: Target device (e.g. "cpu" or "cuda"). It is used as
            ``map_location`` for loading and for ``Module.to``.

    Returns:
        The model on ``device`` in eval mode. If TorchScript tracing succeeds,
        this is the traced module; otherwise it is the plain ``DNN`` instance.

    Raises:
        Any exception raised while reading the checkpoint or loading the
        state dict is printed and then re-raised.
    """
    input_dim = 54 * 6  # 54 stickers x 6 colours, one-hot encoded
    model = DNN(input_dim, num_residual_blocks=4)
    try:
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        # Keys saved from a compiled model carry a leading "_orig_mod." prefix
        # (typically added by torch.compile). Strip only that leading prefix
        # and keep every other key as-is, so plain, prefixed and mixed
        # checkpoints all load into the uncompiled DNN.
        prefix = '_orig_mod.'
        model_weights = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint.items()
        }
        model.load_state_dict(model_weights)
    except Exception as e:
        print(f"加载模型失败: {str(e)}")
        raise
    model = model.to(device)
    # Switch to inference mode (disables training-only layer behaviour).
    model.eval()
    # Optional: speed up execution with TorchScript. The trace is recorded
    # with a batch of 1; this assumes DNN has no batch-size-dependent control
    # flow (TODO confirm), because the search later calls it on larger batches.
    try:
        example_input = torch.randn(1, input_dim).to(device)
        model = torch.jit.trace(model, example_input)
        print("已使用TorchScript优化模型")
    except Exception as e:
        # Best effort: fall back to the eager model if tracing fails.
        print(f"TorchScript优化失败: {str(e)}")
    return model
def state_to_one_hot(state):
    """
    One-hot encode a cube state.

    ``state`` is an iterable of colour indices, one per sticker. The result is
    a float array of length 54 * 6. Entry ``6 * i + c`` is 1.0 when sticker
    ``i`` has colour ``c``; all other entries are 0.0.
    """
    encoded = np.zeros(54 * 6)
    base = 0  # start of the current sticker's 6-slot group
    for colour in state:
        encoded[base + colour] = 1
        base += 6
    return encoded
def h(state, model):
    """
    Heuristic: use the model to estimate the distance from ``state`` to the goal.

    Args:
        state: Tensor input for the model. A 2-D tensor is treated as a batch
            (one encoded state per row). Any other rank is treated as a single
            state and gets a leading batch dimension added.
        model: Callable mapping a batched tensor to predictions.

    Returns:
        For a 2-D ``state``, a Python list of predictions. It is always a
        list, even for a batch of one. Otherwise, a single Python float.
    """
    with torch.no_grad():
        if len(state.shape) == 2:
            prediction = model(state)
            # squeeze() alone turns a (1, 1) result into a 0-d tensor, whose
            # tolist() is a bare float. atleast_1d keeps it a list so callers
            # can always index the batch result.
            return torch.atleast_1d(prediction.squeeze()).tolist()
        state = state.unsqueeze(0)
        prediction = model(state)
        return prediction.item()
def _one_hot_states(states, device):
    """One-hot encode integer colour indices (0-5) on ``device``, flattening the last two axes (sticker, colour)."""
    indices = torch.tensor(states, device=device).long()
    encoded = torch.nn.functional.one_hot(indices, num_classes=6).float()
    # (stickers,) -> (stickers*6,); (batch, stickers) -> (batch, stickers*6)
    return encoded.flatten(start_dim=encoded.dim() - 2)


def _reconstruct_path(came_from, goal_tuple):
    """Follow ``came_from`` back from ``goal_tuple``; return (actions, states) ordered start -> goal, states as ndarrays."""
    action_path = []
    state_path = [np.array(goal_tuple)]
    node = goal_tuple
    while node in came_from:
        node, action = came_from[node]
        action_path.append(action)
        state_path.append(np.array(node))
    return action_path[::-1], state_path[::-1]


def a_star_search(initial_state, model, cube, lam=0.6, max_iterations=200, N=1000):
    """
    Batched, weighted A* search for a cube solution guided by ``model``.

    Nodes are ranked by f = lam * g + h. Here g is the number of moves from
    ``initial_state`` and h is the model's prediction. The heuristic is
    learned, so the returned path is not guaranteed to be the shortest.

    Args:
        initial_state: Sequence of sticker colour indices for the start state.
        model: torch module used by ``h``. Its first parameter's device is
            used for the encoded inputs.
        cube: Object providing ``moves`` (whose keys are the actions) and
            ``apply_action(state, action)``, which returns the next state.
        lam: Weight applied to the path cost g.
        max_iterations: Maximum number of batch expansions.
        N: Maximum number of open-set nodes popped per iteration.

    Returns:
        ``(actions, states)`` for the path from the start to TARGET_STATE, or
        ``(None, None)`` if no solution was found within ``max_iterations``.
        If the start already equals TARGET_STATE, returns
        ``([], [initial_state])``.
    """
    if np.array_equal(initial_state, TARGET_STATE):
        return [], [initial_state]

    device = next(model.parameters()).device
    start = tuple(initial_state)
    open_set = []
    closed_set = set()
    came_from = {}
    g_score = {start: 0}
    h_score = {start: h(_one_hot_states(initial_state, device), model)}
    f_score = {start: lam * g_score[start] + h_score[start]}
    heapq.heappush(open_set, (f_score[start], start))

    iterations = 0
    while open_set and iterations < max_iterations:
        iterations += 1
        # Pop up to N of the best-ranked nodes and expand them as one batch.
        current_states = [heapq.heappop(open_set)[1] for _ in range(min(N, len(open_set)))]

        neighbor_states = []
        for current_state_tuple in current_states:
            current_state = np.array(current_state_tuple)
            if np.array_equal(current_state, TARGET_STATE):
                return _reconstruct_path(came_from, current_state_tuple)
            # Stale heap entries for already-expanded nodes are skipped here.
            if current_state_tuple in closed_set:
                continue
            closed_set.add(current_state_tuple)
            for action in cube.moves.keys():
                next_state = cube.apply_action(current_state, action)
                next_state_tuple = tuple(next_state)
                tentative_g_score = g_score[current_state_tuple] + 1
                if next_state_tuple in closed_set:
                    # Re-open a closed node only if a cheaper path was found.
                    if tentative_g_score < g_score.get(next_state_tuple, float('inf')):
                        closed_set.remove(next_state_tuple)
                    else:
                        continue
                if next_state_tuple not in g_score or tentative_g_score < g_score[next_state_tuple]:
                    came_from[next_state_tuple] = (current_state_tuple, action)
                    g_score[next_state_tuple] = tentative_g_score
                    neighbor_states.append(next_state)

        if neighbor_states:
            # Score each distinct neighbour once, in a single model call.
            unique_states = np.unique(np.stack(neighbor_states), axis=0)
            scores = h(_one_hot_states(unique_states, device), model)
            if not isinstance(scores, list):
                # A batch of one may come back as a bare float; normalise it.
                scores = [scores]
            for state, score in zip(unique_states, scores):
                state_tuple = tuple(state)
                h_score[state_tuple] = score
                f_score[state_tuple] = lam * g_score[state_tuple] + score
                heapq.heappush(open_set, (f_score[state_tuple], state_tuple))

    return None, None  # no solution found