# NOTE(review): the original lines here read "Spaces: / Sleeping / Sleeping" —
# page-scrape residue from a hosting UI, not source code. Kept as a comment so
# the file remains valid Python.
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.optim as optim | |
| from pytorch_lightning import LightningModule | |
| from model.DNN import DNN | |
| from model.Cube import TARGET_STATE_ONE_HOT | |
class RelativeMSELoss(nn.Module):
    """MSE computed on the relative error (pred - target) / target.

    A small epsilon in the denominator keeps the division finite when a
    target entry is zero.
    """

    def forward(self, pred, target):
        relative_error = (pred - target) / (target + 1e-8)
        return relative_error.pow(2).mean()
# Check which rows of A match the reference vector b within tolerance.
def row_allclose_mask(A, b, rtol=1e-4, atol=1e-6):
    """Return a boolean mask over the rows of ``A`` that are all-close to ``b``.

    Row i is True when every element satisfies
    ``|A[i, j] - b[j]| <= atol + rtol * |b[j]|`` — the same tolerance rule as
    ``torch.isclose``, which this now delegates to instead of re-deriving the
    formula by hand (isclose additionally treats equal infinities as close,
    which the manual subtraction could not).

    Args:
        A: tensor of shape (B, D).
        b: tensor broadcastable against A's rows, typically (D,) or (1, D).
        rtol: relative tolerance.
        atol: absolute tolerance.

    Returns:
        Bool tensor of shape (B,), True where the whole row matches ``b``.
    """
    # b broadcasts across the batch dimension inside isclose.
    return torch.isclose(A, b, rtol=rtol, atol=atol).all(dim=1)
class DeepcubeA(LightningModule):
    """DeepCubeA-style value-iteration trainer for the 3x3x3 Rubik's cube.

    Two copies of the same network are held:

    * ``model_theta``   -- the model being optimized.
    * ``model_theta_e`` -- a frozen "supervisor" copy that produces the
      bootstrap targets ``min_a [J_theta_e(A(x, a)) + 1]``; it is refreshed
      from ``model_theta`` whenever training converges (or at train end).
    """

    def __init__(self, config):
        """Build both networks and read training hyper-parameters from ``config``."""
        super().__init__()
        self.config = config
        self.learning_rate = config.learning_rate
        self.weight_decay = config.weight_decay
        self.convergence_threshold = config.convergence_threshold
        self.chunk_size = config.chunk_size
        self.converged_checkpoint_dir = config.converged_checkpoint_dir
        # NOTE(review): this attribute shadows nn.Module.compile() (torch >= 2.0).
        # Harmless here because torch.compile is invoked directly below, but a
        # rename (e.g. use_compile) would be cleaner once callers are checked.
        self.compile = config.compile
        # Input dimension: 54 stickers, each one-hot encoded over 6 colors.
        self.input_dim = 54 * 6
        self.model_theta = DNN(self.input_dim, num_residual_blocks=4)  # trained model
        self.model_theta_e = DNN(self.input_dim, num_residual_blocks=4).eval()  # supervisor model
        # Solved-state one-hot encoding kept as a (1, input_dim) row vector.
        # It is not registered as a buffer, so it is moved to the right device
        # explicitly where used (see model_step).
        self.target_state = torch.tensor(TARGET_STATE_ONE_HOT, dtype=torch.float32).reshape(1, -1)
        if self.compile:
            self.model_theta = torch.compile(self.model_theta)
            self.model_theta_e = torch.compile(self.model_theta_e)
        # Current value-iteration round; only used to tag saved checkpoints.
        self.K = 1
        # Loss between theta's prediction and the bootstrap target.
        self.criterion = nn.MSELoss()
        # Persist hyper-parameters into the Lightning checkpoint.
        self.save_hyperparameters(config)

    def transfer_batch_to_tensor(self, batch):
        """Convert every entry of ``batch`` to a tensor on ``self.device``.

        Args:
            batch: mapping from field name to a tensor or array-like value.

        Returns:
            dict with the same keys, all values tensors on ``self.device``.
        """
        batch_dict = {}
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch_dict[key] = value.to(self.device)
            else:
                batch_dict[key] = torch.tensor(value, device=self.device)
        return batch_dict

    def forward(self, x):
        """Predict cost-to-go for one-hot encoded states ``x`` with model_theta."""
        return self.model_theta(x)

    def model_step(self, batch):
        """Compute the bootstrap MSE loss for one batch.

        The target for each state x_i is ``min_a [J_theta_e(A(x_i, a)) + 1]``,
        with solved neighbors clamped to cost 0.

        Args:
            batch: mapping with keys ``'state'`` (B, 54) and ``'neighbors'``
                (B, N, 54) holding sticker color indices — TODO confirm the
                exact layout against the dataset producing it.

        Returns:
            (loss, current_cost): scalar loss and model_theta's raw predictions.
        """
        batch_dict = self.transfer_batch_to_tensor(batch)
        states = batch_dict['state']
        neighbor_states = batch_dict['neighbors']
        B, N, D = neighbor_states.shape
        # One-hot encode sticker colors and flatten: (B, 54*6) and (B*N, 54*6).
        states = F.one_hot(states.long(), num_classes=6).float().view(B, -1)
        neighbor_states = F.one_hot(neighbor_states.long(), num_classes=6).float().view(B * N, -1)
        # Evaluate neighbors in chunks to bound peak memory.
        num_chunks = (B * N + self.chunk_size - 1) // self.chunk_size
        chunked_neighbors = torch.chunk(neighbor_states, num_chunks, dim=0)
        with torch.no_grad():
            neighbor_costs = []
            for chunk in chunked_neighbors:
                # Solved states have cost-to-go 0 by definition.
                mask = row_allclose_mask(chunk, self.target_state.to(chunk.device))
                cost = self.model_theta_e(chunk)
                cost[mask] = 0.0
                neighbor_costs.append(cost)
        # Reassemble per-state neighbor costs: (B, N).
        neighbor_costs = torch.cat(neighbor_costs, dim=0).view(B, N)
        # Bootstrap target min_a [J_theta_e(A(x_i, a)) + 1]; abs() keeps the
        # supervisor's outputs non-negative before taking the minimum.
        min_neighbor_cost = neighbor_costs.abs().min(dim=1)[0] + 1
        current_cost = self.model_theta(states)
        # FIX: squeeze(-1), not squeeze() — with batch size 1 a bare squeeze()
        # collapses (1, 1) to a 0-d scalar, which silently broadcasts against
        # the (1,) target inside MSELoss.
        loss = self.criterion(current_cost.squeeze(-1), min_neighbor_cost)
        return loss, current_cost

    def training_step(self, batch, batch_idx):
        """One optimization step; logs the training loss."""
        loss, _ = self.model_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """One validation step; logs the loss and the mean predicted cost."""
        loss, current_cost = self.model_step(batch)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_cost', current_cost.mean(), on_epoch=True)
        return loss

    def on_validation_epoch_end(self):
        """On convergence: checkpoint model_theta, promote it to supervisor, stop."""
        val_loss = self.trainer.callback_metrics.get('val_loss')
        if val_loss is not None and val_loss < self.convergence_threshold:
            self.log('converged', True)
            # Save the converged weights into the dedicated checkpoint directory.
            import os
            os.makedirs(self.converged_checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(self.converged_checkpoint_dir, f"converged_model_K_{self.K}.pth")
            torch.save(self.model_theta.state_dict(), checkpoint_path)
            print(f'模型已保存到 {checkpoint_path}')
            # Promote the converged model to be the supervisor for the next round.
            self.model_theta_e.load_state_dict(self.model_theta.state_dict())
            # The paper does not specify whether the next round inherits the
            # previous round's parameters; we warm-start from them because
            # retraining from scratch is too expensive.
            self.trainer.should_stop = True

    def on_train_end(self):
        """Fallback save/promote when training ends without having converged."""
        # Only runs when the 'converged' path above did not fire.
        if not self.trainer.callback_metrics.get('converged', False):
            import os
            os.makedirs(self.converged_checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(self.converged_checkpoint_dir, f"final_model_K_{self.K}.pth")
            torch.save(self.model_theta.state_dict(), checkpoint_path)
            print(f'训练结束,模型已保存到 {checkpoint_path}')
            # Still hand the latest weights to the supervisor.
            self.model_theta_e.load_state_dict(self.model_theta.state_dict())

    def configure_optimizers(self):
        """AdamW over model_theta only — the supervisor is never optimized."""
        optimizer = optim.AdamW(
            self.model_theta.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )
        return {'optimizer': optimizer}

    def load_state_dict_theta_e(self, checkpoint_path):
        """Load supervisor weights from ``checkpoint_path``.

        Maps to CPU first so a GPU-saved checkpoint also loads on a CPU-only
        host; load_state_dict then copies onto the module's own device.
        """
        state_dict = torch.load(checkpoint_path, map_location='cpu')
        self.model_theta_e.load_state_dict(state_dict)
        # NOTE(review): presumably DNN.zero_output forces a constant-zero
        # output until real weights are loaded — confirm against model/DNN.
        self.model_theta_e.zero_output = False