File size: 7,332 Bytes
b570cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from pytorch_lightning import LightningModule

from model.DNN import DNN
from model.Cube import TARGET_STATE_ONE_HOT


class RelativeMSELoss(nn.Module):
    """Mean squared error normalized by the target magnitude.

    Each residual is divided by ``target + 1e-8`` (epsilon guards against
    division by zero) before squaring, so errors are penalized relative to
    the size of the target value.
    """

    def forward(self, pred, target):
        relative_error = (pred - target) / (target + 1e-8)
        return (relative_error ** 2).mean()

# Identify which rows of A are (element-wise) approximately equal to b.
def row_allclose_mask(A, b, rtol=1e-4, atol=1e-6):
    """Return a boolean mask over the rows of ``A`` that are close to ``b``.

    A row counts as close when every element satisfies
    ``|A[i, j] - b[j]| <= atol + rtol * |b[j]|`` (the same tolerance rule
    as ``torch.allclose``, applied per row).

    Args:
        A: tensor of shape (B, D).
        b: tensor broadcastable to (D,) (or (1, D)).
        rtol: relative tolerance.
        atol: absolute tolerance.

    Returns:
        Bool tensor of shape (B,); True where the row matches ``b``.
    """
    tolerance = atol + rtol * b.abs()          # broadcasts to (B, D)
    element_ok = (A - b).abs() <= tolerance    # (B, D) bool
    return element_ok.all(dim=1)               # (B,)


class DeepcubeA(LightningModule):
    """DeepCubeA-style cost-to-go learner for the 3x3 Rubik's cube.

    ``model_theta`` is trained to regress the cost of a state onto the
    Bellman target ``1 + min_a J_theta_e(A(x, a))``, where ``model_theta_e``
    is a frozen copy of the network ("supervising" / target network).
    Whenever validation loss drops below ``convergence_threshold`` the
    frozen copy is refreshed from the trained weights and training stops
    (one round of value iteration).
    """

    def __init__(self, config):
        """Build training and target networks from *config*.

        ``config`` must expose: ``learning_rate``, ``weight_decay``,
        ``convergence_threshold``, ``chunk_size``,
        ``converged_checkpoint_dir`` and ``compile`` (bool flag for
        ``torch.compile``).
        """
        super().__init__()
        self.config = config
        self.learning_rate = config.learning_rate
        self.weight_decay = config.weight_decay
        self.convergence_threshold = config.convergence_threshold
        self.chunk_size = config.chunk_size
        self.converged_checkpoint_dir = config.converged_checkpoint_dir
        # NOTE(review): this attribute shadows nn.Module.compile() (torch >= 2.0);
        # kept for backward compatibility, but consider renaming to ``use_compile``.
        self.compile = config.compile

        # Input dimension: 54 stickers, each one-hot encoded over 6 colors.
        self.input_dim = 54 * 6

        self.model_theta = DNN(self.input_dim, num_residual_blocks=4)    # trained network
        self.model_theta_e = DNN(self.input_dim, num_residual_blocks=4).eval()  # frozen target network

        # One-hot encoding of the solved state, shape (1, input_dim).
        self.target_state = torch.tensor(TARGET_STATE_ONE_HOT, dtype=torch.float32).reshape(1, -1)

        if self.compile:
            self.model_theta = torch.compile(self.model_theta)
            self.model_theta_e = torch.compile(self.model_theta_e)

        # Index of the current value-iteration round; used in checkpoint names.
        self.K = 1

        # Regression loss against the Bellman target.
        self.criterion = nn.MSELoss()

        # Persist hyperparameters with the checkpoint.
        self.save_hyperparameters(config)

    def transfer_batch_to_tensor(self, batch):
        """Move every entry of *batch* to ``self.device`` as a tensor.

        Args:
            batch: mapping of field name -> tensor or array-like.

        Returns:
            Dict with the same keys, every value a tensor on ``self.device``.
        """
        batch_dict = {}
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                batch_dict[key] = value.to(self.device)
            else:
                batch_dict[key] = torch.tensor(value, device=self.device)
        return batch_dict

    def forward(self, x):
        """Predict cost-to-go with the trained network."""
        return self.model_theta(x)

    def model_step(self, batch):
        """Compute the value-iteration regression loss for one batch.

        Expects ``batch['state']`` of shape (B, 54) and
        ``batch['neighbors']`` of shape (B, N, 54) holding sticker color
        indices in [0, 6).

        Returns:
            (loss, current_cost): scalar MSE loss and the raw cost
            predictions of ``model_theta`` for the batch states.
        """
        batch_dict = self.transfer_batch_to_tensor(batch)
        states = batch_dict['state']
        neighbor_states = batch_dict['neighbors']

        B, N, D = neighbor_states.shape

        # One-hot encode sticker colors (6 classes) and flatten per state.
        states = F.one_hot(states.long(), num_classes=6).float().view(B, -1)
        neighbor_states = F.one_hot(neighbor_states.long(), num_classes=6).float().view(B * N, -1)

        # Evaluate neighbors in chunks to bound peak GPU memory.
        num_chunks = (B * N + self.chunk_size - 1) // self.chunk_size
        chunked_neighbors = torch.chunk(neighbor_states, num_chunks, dim=0)

        with torch.no_grad():
            neighbor_costs = []
            for chunk in chunked_neighbors:
                # Solved states have cost-to-go 0 by definition.
                mask = row_allclose_mask(chunk, self.target_state.to(chunk.device))
                cost = self.model_theta_e(chunk)
                cost[mask] = 0.0
                neighbor_costs.append(cost)

            # Reassemble into one (B, N) cost matrix.
            neighbor_costs = torch.cat(neighbor_costs, dim=0)
            neighbor_costs = neighbor_costs.view(B, N)

        # Bellman target: min_a |J_theta_e(A(x_i, a))| + 1.
        min_neighbor_cost = neighbor_costs.abs().min(dim=1)[0] + 1

        # Cost of the current states under the trained network, shape (B, 1).
        current_cost = self.model_theta(states)

        # squeeze(-1) (not squeeze()) so a batch of size 1 keeps shape (1,)
        # and matches min_neighbor_cost instead of collapsing to a scalar.
        loss = self.criterion(current_cost.squeeze(-1), min_neighbor_cost)
        return loss, current_cost

    def training_step(self, batch, batch_idx):
        """One optimization step; logs ``train_loss``."""
        loss, _ = self.model_step(batch)

        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def on_validation_epoch_end(self):
        """Check convergence; if reached, checkpoint, refresh the target net, stop."""
        val_loss = self.trainer.callback_metrics.get('val_loss')

        if val_loss is not None and val_loss < self.convergence_threshold:
            self.log('converged', True)

            # Save the trained weights to the dedicated converged-model directory.
            os.makedirs(self.converged_checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(self.converged_checkpoint_dir, f"converged_model_K_{self.K}.pth")
            torch.save(self.model_theta.state_dict(), checkpoint_path)
            print(f'模型已保存到 {checkpoint_path}')

            # On convergence, refresh the frozen target network.
            self.model_theta_e.load_state_dict(self.model_theta.state_dict())

            # The paper does not specify whether the next round inherits the
            # previous round's weights; we inherit them fully because training
            # from scratch is too expensive.
            # self.model_theta = DNN(self.input_dim, num_residual_blocks=4)
            # if self.compile:
            #     self.model_theta = torch.compile(self.model_theta)

            # Stop this training round.
            self.trainer.should_stop = True

    def on_train_end(self):
        """If training ended without converging, save a final checkpoint anyway."""
        # Only save here when the run did not stop via the convergence path.
        if not self.trainer.callback_metrics.get('converged', False):
            # Save the final weights to the dedicated converged-model directory.
            os.makedirs(self.converged_checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(self.converged_checkpoint_dir, f"final_model_K_{self.K}.pth")
            torch.save(self.model_theta.state_dict(), checkpoint_path)
            print(f'训练结束,模型已保存到 {checkpoint_path}')

            # Refresh the frozen target network from the final weights.
            self.model_theta_e.load_state_dict(self.model_theta.state_dict())

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss and mean predicted cost."""
        loss, current_cost = self.model_step(batch)
        self.log('val_loss', loss, on_epoch=True, prog_bar=True)
        self.log('val_cost', current_cost.mean(), on_epoch=True)
        return loss

    def configure_optimizers(self):
        """AdamW over the trained network only (the target net is frozen)."""
        optimizer = optim.AdamW(
            self.model_theta.parameters(),
            lr=self.learning_rate,
            weight_decay=self.weight_decay
        )

        return {'optimizer': optimizer}

    def load_state_dict_theta_e(self, checkpoint_path):
        """Load weights into the frozen target network from *checkpoint_path*.

        NOTE(review): ``torch.load`` without ``weights_only=True`` unpickles
        arbitrary objects — only load trusted checkpoints.
        """
        state_dict = torch.load(checkpoint_path)
        self.model_theta_e.load_state_dict(state_dict)
        # presumably a DNN flag that forces zero output until first load — verify in model.DNN
        self.model_theta_e.zero_output = False