# Spaces: Sleeping  (hosting-page status text captured with the source; not part of the program)
import os
import random

import numpy as np
import torch
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset, random_split

from model.Cube import Cube, TARGET_STATE
class RubikDataset(Dataset):
    """Dataset of scrambled cube states that are generated on access.

    Nothing is stored: every ``__getitem__`` call starts from a copy of
    ``TARGET_STATE`` and applies a freshly sampled sequence of random moves,
    so the requested index does not influence the returned sample.
    """

    def __init__(self, config, num_samples, is_train=True):
        """Store the sampling settings and build the cube helper.

        Args:
            config: Configuration object; ``config.K`` (the maximum number
                of scramble moves) is read here.
            num_samples: Length reported by ``__len__``.
            is_train: When true, sampling is biased toward ``K``-move
                scrambles (see ``__getitem__``).
        """
        super().__init__()
        self.config = config
        self.num_samples = num_samples
        self.is_train = is_train
        self.cube = Cube()
        self.K = config.K  # maximum number of scramble moves
        self.all_actions = list(self.cube.moves.keys())

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def get_neighbors(self, state):
        """Return all neighbor states of ``state``.

        Args:
            state: Current cube state (an ``np.array`` per the original
                documentation).

        Returns:
            Whatever ``self.cube.get_neibor_state(state)`` returns; the
            original documentation describes it as the list of all
            neighbor states.
        """
        return self.cube.get_neibor_state(state)

    def __getitem__(self, idx):
        """Generate one random scramble; ``idx`` is ignored.

        Returns:
            dict with keys ``'state'`` (the scrambled state), ``'steps'``
            (number of random moves applied) and ``'neighbors'`` (result of
            ``get_neighbors`` on a copy of the state).
        """
        # Choose the scramble length in [1, K].
        # Training: with probability 0.5 use exactly K moves (raising the
        # share of K-move scrambles is intended to speed up convergence);
        # otherwise draw uniformly from [1, K] *inclusive* -- randint's upper
        # bound K + 1 is exclusive.
        # Validation: always the uniform draw. The random() call still runs
        # first, so the RNG stream is consumed the same way in both modes.
        if np.random.random() < 0.5 and self.is_train:
            num_moves = self.K
        else:
            num_moves = np.random.randint(1, self.K + 1)

        # Start from the target state and apply the sampled moves in order.
        state = TARGET_STATE.copy()
        # Moves are sampled independently, with replacement.
        actions = np.random.choice(self.all_actions, size=num_moves, replace=True)
        for action in actions:
            state = self.cube.apply_action(state, action)

        # Hand a copy to the neighbor lookup so `state` itself is not shared.
        neighbor_states = self.get_neighbors(state.copy())

        # Shape notes are carried over from the original annotations.
        return {
            'state': state,  # 54
            'steps': num_moves,
            'neighbors': neighbor_states  # 12, 54
        }
class RubikDataModule(LightningDataModule):
    """LightningDataModule that serves on-the-fly ``RubikDataset`` samples."""

    def __init__(self, config):
        """Cache the loader settings read from ``config``.

        Args:
            config: Configuration object providing ``batch_size``,
                ``num_workers``, ``num_train_samples`` and
                ``num_val_samples``; ``config.seed`` is read later by
                ``_worker_init_fn``.
        """
        super().__init__()
        self.config = config
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.num_train_samples = config.num_train_samples
        self.num_val_samples = config.num_val_samples

    def prepare_data(self):
        """Do nothing: the data is generated, so there is nothing to download."""
        pass

    def setup(self, stage=None):
        """Create the training and validation datasets.

        Args:
            stage: Accepted for interface compatibility; not used.
        """
        self.train_dataset = RubikDataset(
            self.config, self.num_train_samples, is_train=True
        )
        self.val_dataset = RubikDataset(
            self.config, self.num_val_samples, is_train=False
        )

    def train_dataloader(self):
        """Return the shuffled training ``DataLoader``."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            worker_init_fn=self._worker_init_fn,
        )

    def val_dataloader(self):
        """Return the unshuffled validation ``DataLoader``."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            worker_init_fn=self._worker_init_fn,
        )

    def _worker_init_fn(self, worker_id):
        """Seed the random generators inside one DataLoader worker.

        Args:
            worker_id: Index of the worker, supplied by the DataLoader.
        """
        # torch.initial_seed() inside a worker is the seed the DataLoader
        # assigned to it (the original comment notes it changes per epoch).
        # The modulo keeps the value in the 32-bit range.
        worker_seed = (self.config.seed + worker_id + torch.initial_seed()) % 2**32
        # Seed Python's `random`, NumPy and torch. The original comment
        # promised all three but `random` was never seeded.
        random.seed(worker_seed)
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)