Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
| import heapq | |
| import os | |
| from model.Cube import TARGET_STATE | |
| from model.DNN import DNN | |
| # Set the float32 matmul precision; kept consistent with train.py | |
| torch.set_float32_matmul_precision('medium') | |
def load_model(model_path, device):
    """
    Load the pretrained DNN and prepare it for inference.

    Args:
        model_path: Path to the saved weights file (a state dict that
            ``torch.load`` can read with ``weights_only=True``).
        device: Device to run on (cpu or cuda); used as ``map_location``
            when loading and as the device the model is moved to.

    Returns:
        The model in eval mode. If ``torch.jit.trace`` succeeds, the traced
        module is returned; otherwise the untraced module is returned.

    Raises:
        Exception: Any error raised while reading the checkpoint or loading
            the weights is printed and then re-raised unchanged.
    """
    input_dim = 54 * 6  # 54 stickers, 6 colours each, one-hot encoded
    model = DNN(input_dim, num_residual_blocks=4)
    prefix = '_orig_mod.'
    try:
        checkpoint = torch.load(model_path, map_location=device, weights_only=True)
        if any(key.startswith(prefix) for key in checkpoint):
            # Strip the leading "_orig_mod." from the keys that carry it.
            # Keys without the prefix are kept as-is instead of being
            # dropped, and only a true prefix is removed (not an occurrence
            # in the middle of a key).
            model_weights = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in checkpoint.items()
            }
        else:
            # No prefixed keys: load the checkpoint exactly as stored.
            model_weights = checkpoint
        model.load_state_dict(model_weights)
    except Exception as e:
        print(f"加载模型失败: {str(e)}")
        raise
    model = model.to(device)
    # Switch to inference mode.
    model.eval()
    # Optional: trace the model with TorchScript. This is best-effort; on
    # failure the message is printed and the eager model is returned.
    try:
        # Example input used only to drive the trace.
        example_input = torch.randn(1, input_dim).to(device)
        model = torch.jit.trace(model, example_input)
        print("已使用TorchScript优化模型")
    except Exception as e:
        print(f"TorchScript优化失败: {str(e)}")
    return model
def state_to_one_hot(state):
    """
    Convert a cube state into a flat one-hot encoding.

    Each entry of ``state`` is treated as a colour index; entry ``i`` sets
    position ``i * 6 + color`` of the result to 1. The result is a NumPy
    array of length 54 * 6 filled with zeros elsewhere.
    """
    encoded = np.zeros(54 * 6)
    offset = 0  # start of the 6-wide slot belonging to the current sticker
    for color in state:
        encoded[offset + color] = 1
        offset += 6
    return encoded
def h(state, model):
    """
    Heuristic function: use the model to predict the distance from a state
    to the target state.

    Args:
        state: Either a 2-D tensor holding a batch of encoded states (one
            per row) or a single encoded state with any other rank.
        model: Callable applied to the (batched) state tensor.
            Assumed to produce one scalar per state -- TODO confirm against
            the DNN definition.

    Returns:
        For a 2-D ``state``: a list with one prediction per row. A batch of
        exactly one row yields a one-element list, so callers can always
        index the result.
        Otherwise: a single Python float.
    """
    with torch.no_grad():
        if len(state.shape) == 2:
            # Flatten instead of squeeze(): squeeze() collapses a batch of
            # one to a 0-dim tensor, whose tolist() is a bare float.
            return model(state).reshape(-1).tolist()
        # Single state: add the batch dimension the model expects.
        return model(state.unsqueeze(0)).item()
def _reconstruct_path(goal_tuple, goal_state, came_from):
    """Follow parent links back from the goal; return (actions, states) ordered start to goal."""
    action_path = []
    state_path = [goal_state]
    node = goal_tuple
    while node in came_from:
        node, action = came_from[node]
        action_path.append(action)
        state_path.append(node)
    return action_path[::-1], state_path[::-1]


def a_star_search(initial_state, model, cube, lam=0.6, max_iterations=200, N=1000):
    """
    Batched, weighted A* search for solving the cube.

    Each iteration pops up to ``N`` lowest-f states from the open set,
    expands them with every action in ``cube.moves``, scores all new or
    improved neighbours in one model call, and pushes them back with
    priority ``f = lam * g + h``.

    Args:
        initial_state: Start state as a sequence of colour indices.
        model: Heuristic model; must expose ``parameters()`` (used to find
            the device) and be callable on the one-hot encoded batch.
        cube: Object providing ``moves`` (its keys are the actions) and
            ``apply_action(state, action)`` returning the next state.
        lam: Weight applied to the path cost ``g`` in the priority.
        max_iterations: Maximum number of batch iterations.
        N: Maximum number of states popped from the open set per iteration.

    Returns:
        ``(action_path, state_path)`` ordered from start to goal, or
        ``(None, None)`` if no solution is found within ``max_iterations``
        or the open set runs empty. If ``initial_state`` already equals
        ``TARGET_STATE`` the result is ``([], [initial_state])``.
        In ``state_path`` the final (goal) entry is a NumPy array and the
        preceding entries are tuples.
    """
    # Already solved: nothing to search.
    if np.array_equal(initial_state, TARGET_STATE):
        return [], [initial_state]

    # The model's device does not change during the search; look it up once.
    device = next(model.parameters()).device
    start = tuple(initial_state)

    start_tensor = torch.tensor(initial_state, device=device).long()
    start_tensor = torch.nn.functional.one_hot(start_tensor, num_classes=6).float().view(-1)

    g_score = {start: 0}
    came_from = {}
    closed_set = set()
    # A single-element list is already a valid heap.
    open_set = [(lam * g_score[start] + h(start_tensor, model), start)]

    iterations = 0
    while open_set and iterations < max_iterations:
        iterations += 1

        # Pop the batch of best states for this iteration.
        current_states = [
            heapq.heappop(open_set)[1] for _ in range(min(N, len(open_set)))
        ]

        # Collect every neighbour that is new or reached by a shorter path.
        neighbor_states = []
        for current_state_tuple in current_states:
            current_state = np.array(current_state_tuple)
            if np.array_equal(current_state, TARGET_STATE):
                return _reconstruct_path(current_state_tuple, current_state, came_from)

            # Stale heap entry for a state that was already expanded.
            if current_state_tuple in closed_set:
                continue
            closed_set.add(current_state_tuple)

            # Every move costs 1; this cannot change while expanding the
            # state because no neighbour can improve its own parent's g.
            tentative_g_score = g_score[current_state_tuple] + 1
            for action in cube.moves.keys():
                next_state = cube.apply_action(current_state, action)
                next_state_tuple = tuple(next_state)

                # Not an improvement over the best known path: skip.
                if tentative_g_score >= g_score.get(next_state_tuple, float('inf')):
                    continue

                # Better path found: reopen the state if it was closed.
                closed_set.discard(next_state_tuple)
                came_from[next_state_tuple] = (current_state_tuple, action)
                g_score[next_state_tuple] = tentative_g_score
                neighbor_states.append(next_state)

        if not neighbor_states:
            continue

        # Deduplicate so each neighbour is scored and pushed once.
        batch = np.unique(np.stack(neighbor_states), axis=0)
        batch_tensor = torch.tensor(batch, device=device).long()
        batch_tensor = torch.nn.functional.one_hot(batch_tensor, num_classes=6).float().view(-1, 324)

        scores = h(batch_tensor, model)
        # h() may return a bare float when the batch holds exactly one
        # state; normalise so the scores can always be paired with states.
        if not isinstance(scores, list):
            scores = [scores]

        for state, score in zip(batch, scores):
            state_tuple = tuple(state)
            heapq.heappush(open_set, (lam * g_score[state_tuple] + score, state_tuple))

    return None, None  # no solution found