# NOTE: removed non-Python page artifacts captured during scraping
# (Hugging Face Space header, "Sleeping" status, file-size line, commit
# hash, and the line-number gutter). They were not part of the module.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule
from model.DNN import DNN
from model.Cube import TARGET_STATE_ONE_HOT
class RelativeMSELoss(nn.Module):
    """Mean squared *relative* error: mean(((pred - target) / target)^2).

    A small epsilon (1e-8) is added to the denominator so targets equal to
    zero do not cause a division by zero.
    """

    def forward(self, pred, target):
        relative_error = (pred - target) / (target + 1e-8)
        return (relative_error ** 2).mean()
# Per-row "allclose" test: which rows of A equal b within tolerance?
def row_allclose_mask(A, b, rtol=1e-4, atol=1e-6):
    """Return a boolean mask over the rows of *A* that are allclose to *b*.

    Args:
        A: (B, D) tensor of candidate rows.
        b: (D,) or (1, D) reference row; broadcast against A.
        rtol: relative tolerance.
        atol: absolute tolerance.

    Returns:
        (B,) bool tensor — True where every element of the row satisfies
        |A - b| <= atol + rtol * |b|.
    """
    tolerance = atol + rtol * torch.abs(b)      # (D,), broadcasts to (B, D)
    within = torch.abs(A - b) <= tolerance      # elementwise pass/fail
    return within.all(dim=1)                    # a row passes iff all elements do
class DeepcubeA(LightningModule):
    """DeepCubeA-style value-iteration trainer for Rubik's cube cost-to-go.

    ``model_theta`` (trainable) learns J(x), the estimated cost-to-go of a
    cube state; a frozen copy ``model_theta_e`` supplies the bootstrap
    targets ``min_a [J_theta_e(A(x, a)) + 1]``.  When validation loss drops
    below ``convergence_threshold`` the frozen copy is refreshed from the
    trainable network and the current training round stops.

    Fixes vs. the original: removed a stray trailing ``|`` (scrape artifact,
    a syntax error) in ``load_state_dict_theta_e``; deduplicated the
    checkpoint-saving logic into ``_save_theta_checkpoint``.
    """

    def __init__(self, config):
        """Build both networks and read training settings from *config*.

        Args:
            config: namespace-like object providing ``learning_rate``,
                ``weight_decay``, ``convergence_threshold``, ``chunk_size``,
                ``converged_checkpoint_dir`` and ``compile``.
        """
        super().__init__()
        self.config = config
        self.learning_rate = config.learning_rate
        self.weight_decay = config.weight_decay
        self.convergence_threshold = config.convergence_threshold
        self.chunk_size = config.chunk_size
        self.converged_checkpoint_dir = config.converged_checkpoint_dir
        # NOTE(review): this attribute shadows nn.Module.compile(); renaming
        # it (e.g. use_compile) would be safer but changes the public attribute.
        self.compile = config.compile
        # Input dimension: 54 stickers, each one-hot encoded over 6 colors.
        self.input_dim = 54 * 6
        self.model_theta = DNN(self.input_dim, num_residual_blocks=4)           # trainable network
        self.model_theta_e = DNN(self.input_dim, num_residual_blocks=4).eval()  # frozen target network
        # One-hot encoding of the solved cube; goal states get cost 0.
        self.target_state = torch.tensor(TARGET_STATE_ONE_HOT, dtype=torch.float32).reshape(1, -1)
        if self.compile:
            self.model_theta = torch.compile(self.model_theta)
            self.model_theta_e = torch.compile(self.model_theta_e)
        self.K = 1  # training-round index used in checkpoint filenames
        # Loss between predicted cost and the bootstrapped target cost.
        self.criterion = nn.MSELoss()
        # Persist hyperparameters for checkpointing / logging.
        self.save_hyperparameters(config)

    def transfer_batch_to_tensor(self, batch):
        """Move every value in *batch* to a tensor on ``self.device``.

        Args:
            batch: mapping of field name -> tensor or array-like.

        Returns:
            dict with the same keys and device-resident tensors as values.
        """
        batch_dict = {}
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch_dict[key] = value.to(self.device)
            else:
                batch_dict[key] = torch.tensor(value, device=self.device)
        return batch_dict

    def forward(self, x):
        """Predict the cost-to-go of one-hot encoded states *x*."""
        return self.model_theta(x)

    def model_step(self, batch):
        """Compute one value-iteration training loss for *batch*.

        The target for each state is ``min_a [J_theta_e(A(x, a)) + 1]``
        (computed without gradients), with solved neighbor states clamped
        to cost 0.

        Args:
            batch: mapping with ``'state'`` (B, 54) sticker indices and
                ``'neighbors'`` (B, N, 54) — assumed from the indexing below;
                confirm against the dataloader.

        Returns:
            (loss, current_cost): scalar MSE loss and the raw (B, 1)
            predictions of ``model_theta``.
        """
        batch_dict = self.transfer_batch_to_tensor(batch)
        states = batch_dict['state']
        neighbor_states = batch_dict['neighbors']
        B, N, D = neighbor_states.shape
        states = F.one_hot(states.long(), num_classes=6).float().view(B, -1)
        neighbor_states = F.one_hot(neighbor_states.long(), num_classes=6).float().view(B * N, -1)
        # Evaluate neighbors in chunks to bound peak GPU memory.
        num_chunks = (B * N + self.chunk_size - 1) // self.chunk_size
        chunked_neighbors = torch.chunk(neighbor_states, num_chunks, dim=0)
        with torch.no_grad():
            neighbor_costs = []
            for chunk in chunked_neighbors:
                # Goal states have cost-to-go exactly 0 by definition.
                mask = row_allclose_mask(chunk, self.target_state.to(chunk.device))
                cost = self.model_theta_e(chunk)
                cost[mask] = 0.0
                neighbor_costs.append(cost)
        neighbor_costs = torch.cat(neighbor_costs, dim=0).view(B, N)
        # Bellman target: min over actions of |J_theta_e(neighbor)| + 1.
        min_neighbor_cost = neighbor_costs.abs().min(dim=1)[0] + 1
        # Predict the current states' cost with the trainable network.
        current_cost = self.model_theta(states)
        loss = self.criterion(current_cost.squeeze(), min_neighbor_cost)
        return loss, current_cost

    def training_step(self, batch, batch_idx):
        """One optimization step; logs ``train_loss``."""
        loss, _ = self.model_step(batch)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss

    def _save_theta_checkpoint(self, filename):
        """Save ``model_theta``'s weights under the converged-checkpoint dir.

        Returns the full path the state dict was written to.
        """
        import os
        os.makedirs(self.converged_checkpoint_dir, exist_ok=True)
        checkpoint_path = os.path.join(self.converged_checkpoint_dir, filename)
        torch.save(self.model_theta.state_dict(), checkpoint_path)
        return checkpoint_path

    def on_validation_epoch_end(self):
        """On convergence: save, refresh the target net, stop this round."""
        val_loss = self.trainer.callback_metrics.get('val_loss')
        if val_loss is not None and val_loss < self.convergence_threshold:
            self.log('converged', True)
            checkpoint_path = self._save_theta_checkpoint(f"converged_model_K_{self.K}.pth")
            print(f'模型已保存到 {checkpoint_path}')
            # Refresh the frozen target network with the converged weights.
            self.model_theta_e.load_state_dict(self.model_theta.state_dict())
            # The paper does not say whether the next round inherits weights;
            # we inherit them fully since retraining from scratch is too costly.
            # self.model_theta = DNN(self.input_dim, num_residual_blocks=4)
            # if self.compile:
            #     self.model_theta = torch.compile(self.model_theta)
            # Stop the current training round.
            self.trainer.should_stop = True

    def on_train_end(self):
        """Save final weights when training ended without converging."""
        # Only save here if the converged path above did not already run.
        if not self.trainer.callback_metrics.get('converged', False):
            checkpoint_path = self._save_theta_checkpoint(f"final_model_K_{self.K}.pth")
            print(f'训练结束,模型已保存到 {checkpoint_path}')
            # Keep the target network in sync with the last trained weights.
            self.model_theta_e.load_state_dict(self.model_theta.state_dict())

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` (convergence criterion) and mean predicted cost."""
        loss, current_cost = self.model_step(batch)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_cost', current_cost.mean(), on_epoch=True)
        return loss

    def configure_optimizers(self):
        """AdamW over ``model_theta`` only — the target network is never trained."""
        optimizer = optim.AdamW(
            self.model_theta.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )
        return {'optimizer': optimizer}

    def load_state_dict_theta_e(self, checkpoint_path):
        """Load weights for the frozen target network from *checkpoint_path*."""
        state_dict = torch.load(checkpoint_path)
        self.model_theta_e.load_state_dict(state_dict)
        self.model_theta_e.zero_output = False