File size: 3,455 Bytes
b570cf2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset, random_split
from pytorch_lightning import LightningDataModule
from model.Cube import Cube, TARGET_STATE

class RubikDataset(Dataset):
    """On-the-fly Rubik's-cube dataset.

    Each item is produced by scrambling the solved cube with a random
    number of moves; no data is stored on disk.
    """

    def __init__(self, config, num_samples, is_train=True):
        super().__init__()
        self.config = config
        self.num_samples = num_samples
        self.is_train = is_train
        self.cube = Cube()
        self.K = config.K  # maximum scramble depth
        self.all_actions = list(self.cube.moves.keys())

    def __len__(self):
        return self.num_samples

    def get_neighbors(self, state):
        """Return all states reachable from `state` in a single move.

        Args:
            state: current cube state as an np.array.
        Returns:
            List of neighbor states.
        """
        return self.cube.get_neibor_state(state)

    def __getitem__(self, idx):
        # Choose the scramble depth i. During training, with probability
        # 0.5 force i = K (biasing toward the hardest depth to speed up
        # convergence); otherwise draw i uniformly from [1, K].
        # The RNG is consumed in the same order as before: one uniform
        # draw first, then (possibly) one randint.
        coin = np.random.random()
        if self.is_train and coin < 0.5:
            depth = self.K
        else:
            depth = np.random.randint(1, self.K + 1)

        # Start from the solved cube and apply `depth` random moves.
        scrambled = TARGET_STATE.copy()
        chosen_moves = np.random.choice(self.all_actions, size=depth, replace=True)
        for move in chosen_moves:
            scrambled = self.cube.apply_action(scrambled, move)

        # All one-move neighbors of the scrambled state.
        neighbor_states = self.get_neighbors(scrambled.copy())

        return {
            'state': scrambled, # 54
            'steps': depth,
            'neighbors': neighbor_states # 12, 54
        }

class RubikDataModule(LightningDataModule):
    def __init__(self, config):
        """Cache the loader hyper-parameters read from `config`."""
        super().__init__()
        self.config = config
        # Copy the loader-related settings onto the module, in the same
        # attribute order as before.
        for name in ("batch_size", "num_workers",
                     "num_train_samples", "num_val_samples"):
            setattr(self, name, getattr(config, name))
        
    def prepare_data(self):
        """No-op: samples are generated on the fly, nothing to download."""
        
    def setup(self, stage=None):
        """Build the training and validation datasets."""
        self.train_dataset = RubikDataset(self.config,
                                          self.num_train_samples,
                                          is_train=True)
        self.val_dataset = RubikDataset(self.config,
                                        self.num_val_samples,
                                        is_train=False)
        
    def train_dataloader(self):
        """Shuffled loader over the training dataset with per-worker seeding."""
        common = dict(
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=self._worker_init_fn,
        )
        return DataLoader(self.train_dataset, shuffle=True, **common)
        
    def val_dataloader(self):
        """Unshuffled loader over the validation dataset with per-worker seeding."""
        common = dict(
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            worker_init_fn=self._worker_init_fn,
        )
        return DataLoader(self.val_dataset, shuffle=False, **common)
        
    def _worker_init_fn(self, worker_id):
        """Seed each DataLoader worker so sample generation is reproducible
        yet distinct per worker (and per epoch, via torch.initial_seed())."""
        # Derive a per-worker seed; torch.initial_seed() varies with the
        # epoch, so workers get fresh streams each epoch.
        worker_seed = (self.config.seed + worker_id + torch.initial_seed()) % 2**32

        # Seed numpy and torch with the derived seed.
        # NOTE(review): the original comment also mentioned seeding python's
        # `random` module, but no random.seed call is visible here — confirm
        # whether the file is truncated or the seeding was intentionally omitted.
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)