Deepcube / model /DeepcubeA_module.py
hanxiaofeng
first commit
b570cf2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule
from model.DNN import DNN
from model.Cube import TARGET_STATE_ONE_HOT
class RelativeMSELoss(nn.Module):
def forward(self, pred, target):
return torch.mean(((pred - target) / (target + 1e-8)) ** 2)
# 判断A中是否有和b相同的张量
def row_allclose_mask(A, b, rtol=1e-4, atol=1e-6):
# 计算逐元素误差
diff = torch.abs(A - b) # (B, D)
tol = atol + rtol * torch.abs(b) # (D,), 广播自动扩展到 (B, D)
# 满足误差条件的元素掩码
mask_elements = diff <= tol # (B, D), bool
# 判断每行是否所有元素都满足条件
mask_rows = mask_elements.all(dim=1) # (B,)
return mask_rows
class DeepcubeA(LightningModule):
def __init__(self, config):
super().__init__()
self.config = config
self.learning_rate = config.learning_rate
self.weight_decay = config.weight_decay
self.convergence_threshold = config.convergence_threshold
self.chunk_size = config.chunk_size
self.converged_checkpoint_dir = config.converged_checkpoint_dir
self.compile = config.compile
# 输入维度(54个贴纸,每个有6种可能的颜色,使用one-hot编码)
self.input_dim = 54 * 6
self.model_theta = DNN(self.input_dim, num_residual_blocks=4) # 训练模型
self.model_theta_e = DNN(self.input_dim, num_residual_blocks=4).eval() # 监督模型
self.target_state = torch.tensor(TARGET_STATE_ONE_HOT, dtype=torch.float32).reshape(1, -1)
if self.compile:
self.model_theta = torch.compile(self.model_theta)
self.model_theta_e = torch.compile(self.model_theta_e)
self.K = 1
# 损失函数
self.criterion = nn.MSELoss()
# 保存超参数
self.save_hyperparameters(config)
def transfer_batch_to_tensor(self, batch):
"""
批量将batch中的数据转移到tensor并移动到正确的设备上
参数:
batch: 输入的batch数据
返回:
处理后的batch字典,包含tensor格式的数据
"""
batch_dict = {}
for key, value in batch.items():
if isinstance(value, torch.Tensor):
batch_dict[key] = value.to(self.device)
else:
batch_dict[key] = torch.tensor(value, device=self.device)
return batch_dict
def forward(self, x):
return self.model_theta(x)
def model_step(self, batch):
# 从batch中获取状态和邻居
batch_dict = self.transfer_batch_to_tensor(batch)
states = batch_dict['state']
neighbor_states = batch_dict['neighbors']
B, N, D = neighbor_states.shape
states = F.one_hot(states.long(), num_classes=6).float().view(B, -1)
neighbor_states = F.one_hot(neighbor_states.long(), num_classes=6).float().view(B*N, -1)
# 分块预测以避免显存不足
num_chunks = (B * N + self.chunk_size - 1) // self.chunk_size
chunked_neighbors = torch.chunk(neighbor_states, num_chunks, dim=0)
with torch.no_grad():
neighbor_costs = []
for chunk in chunked_neighbors:
mask = row_allclose_mask(chunk, self.target_state.to(chunk.device))
cost = self.model_theta_e(chunk)
cost[mask] = 0.0
neighbor_costs.append(cost)
# 聚合结果
neighbor_costs = torch.cat(neighbor_costs, dim=0)
neighbor_costs = neighbor_costs.view(B, N)
# 计算min[J_theta_e(A(x_i, a)) + 1]
min_neighbor_cost = neighbor_costs.abs().min(dim=1)[0] + 1
# 使用model_theta预测当前状态的cost
current_cost = self.model_theta(states)
# 总是计算损失
loss = self.criterion(current_cost.squeeze(), min_neighbor_cost)
return loss, current_cost
def training_step(self, batch, batch_idx):
# 调用model_step获取损失
loss, _ = self.model_step(batch)
# 记录指标
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
return loss
def on_validation_epoch_end(self):
# 获取验证损失
val_loss = self.trainer.callback_metrics.get('val_loss')
if val_loss is not None and val_loss < self.convergence_threshold:
self.log('converged', True)
# 保存模型参数到专门的收敛模型目录
import os
os.makedirs(self.converged_checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(self.converged_checkpoint_dir, f"converged_model_K_{self.K}.pth")
torch.save(self.model_theta.state_dict(), checkpoint_path)
print(f'模型已保存到 {checkpoint_path}')
# 如果收敛,更新model_theta_e
self.model_theta_e.load_state_dict(self.model_theta.state_dict())
# 原文中没有找到上一轮训练的模型下一轮是否要继承参数,这里选择完全继承上一轮的参数,因为从头训练开销太大
# self.model_theta = DNN(self.input_dim, num_residual_blocks=4)
# if self.compile:
# self.model_theta = torch.compile(self.model_theta)
# 停止训练
self.trainer.should_stop = True
def on_train_end(self):
# 检查训练是否正常结束(非early stopping)
# 只有当训练不是因为converged而停止时,才执行保存操作
if not self.trainer.callback_metrics.get('converged', False):
# 获取最后一个epoch的验证损失
val_loss = self.trainer.callback_metrics.get('val_loss')
# 保存模型参数到专门的收敛模型目录
import os
os.makedirs(self.converged_checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(self.converged_checkpoint_dir, f"final_model_K_{self.K}.pth")
torch.save(self.model_theta.state_dict(), checkpoint_path)
print(f'训练结束,模型已保存到 {checkpoint_path}')
# 更新model_theta_e
self.model_theta_e.load_state_dict(self.model_theta.state_dict())
def validation_step(self, batch, batch_idx):
# 计算验证损失
loss, current_cost = self.model_step(batch)
self.log('val_loss', loss, on_epoch=True, prog_bar=True)
self.log('val_cost', current_cost.mean(), on_epoch=True)
return loss
def configure_optimizers(self):
optimizer = optim.AdamW(
self.model_theta.parameters(),
lr=self.learning_rate,
weight_decay=self.weight_decay
)
return {'optimizer': optimizer}
def load_state_dict_theta_e(self, checkpoint_path):
state_dict = torch.load(checkpoint_path)
self.model_theta_e.load_state_dict(state_dict)
self.model_theta_e.zero_output = False