# Deepcube/dataset/dataloader.py
# hanxiaofeng, commit b570cf2 ("first commit")
import os
import torch
import numpy as np
from torch.utils.data import DataLoader, Dataset, random_split
from pytorch_lightning import LightningDataModule
from model.Cube import Cube, TARGET_STATE
class RubikDataset(Dataset):
    """Procedurally generated dataset of scrambled Rubik's cube states.

    Nothing is stored on disk or in memory: every ``__getitem__`` call
    scrambles a fresh copy of ``TARGET_STATE`` with random moves, so ``idx``
    is ignored and repeated reads of the same index (including across
    validation epochs) generally return different samples. ``num_samples``
    only determines the length reported by ``__len__``.

    Args:
        config: object exposing ``K``, the maximum number of scramble moves
            (must be >= 1).
        num_samples: value returned by ``__len__``.
        is_train: when True, half of the samples use exactly ``K`` moves.

    Raises:
        ValueError: if ``config.K`` is smaller than 1.
    """

    def __init__(self, config, num_samples, is_train=True):
        super().__init__()
        self.config = config
        self.num_samples = num_samples
        self.is_train = is_train
        self.K = config.K  # maximum number of scramble moves
        # Validate eagerly: with K < 1, np.random.randint(1, K + 1) in
        # __getitem__ would otherwise fail later with an opaque
        # "low >= high" error from numpy.
        if self.K < 1:
            raise ValueError(f"config.K must be >= 1, got {self.K!r}")
        self.cube = Cube()
        self.all_actions = list(self.cube.moves.keys())

    def __len__(self):
        return self.num_samples

    def get_neighbors(self, state):
        """Return all neighbor states of ``state``.

        Thin wrapper around ``self.cube.get_neibor_state``.

        Args:
            state: current cube state (np.array).

        Returns:
            Whatever ``Cube.get_neibor_state`` returns; the original author
            describes it as the list of neighbor states.
        """
        return self.cube.get_neibor_state(state)

    def __getitem__(self, idx):
        """Return one freshly scrambled sample (``idx`` is not used).

        The scramble length ``i`` is drawn as follows:
          * training: with probability 0.5, ``i = K``; otherwise ``i`` is
            uniform on ``[1, K]``, so ``K`` itself is over-represented;
          * validation: ``i`` is uniform on ``[1, K]``.

        Returns:
            dict with ``'state'`` (``TARGET_STATE`` after ``i`` random moves),
            ``'steps'`` (``i``), and ``'neighbors'`` (``get_neighbors`` of
            that state).
        """
        # random() is evaluated before the is_train check, so validation
        # also consumes one draw per sample.
        if np.random.random() < 0.5 and self.is_train:
            # Oversample the longest (K-move) scrambles during training to
            # speed up convergence.
            i = self.K
        else:
            # randint's upper bound is exclusive: i is uniform on [1, K].
            i = np.random.randint(1, self.K + 1)
        # Start from the solved state and apply i moves sampled with
        # replacement.
        state = TARGET_STATE.copy()
        actions = np.random.choice(self.all_actions, size=i, replace=True)
        for action in actions:
            state = self.cube.apply_action(state, action)
        # Pass a copy so the neighbor computation cannot mutate `state`.
        neighbor_states = self.get_neighbors(state.copy())
        return {
            'state': state,  # 54 entries per the original author's note — TODO confirm
            'steps': i,
            'neighbors': neighbor_states,  # (12, 54) per original note — TODO confirm
        }
class RubikDataModule(LightningDataModule):
    """Lightning data module serving procedurally generated cube samples.

    Reads ``batch_size``, ``num_workers``, ``num_train_samples`` and
    ``num_val_samples`` from ``config``. The whole ``config`` is forwarded to
    ``RubikDataset``. ``config.seed``, when present and not None, is mixed
    into the per-worker RNG seeds; otherwise 0 is used.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.batch_size = config.batch_size
        self.num_workers = config.num_workers
        self.num_train_samples = config.num_train_samples
        self.num_val_samples = config.num_val_samples

    def prepare_data(self):
        # Nothing to download: the dataset is generated on the fly.
        pass

    def setup(self, stage=None):
        """Create the training and validation datasets (``stage`` is ignored)."""
        self.train_dataset = RubikDataset(
            self.config, self.num_train_samples, is_train=True
        )
        self.val_dataset = RubikDataset(
            self.config, self.num_val_samples, is_train=False
        )

    def train_dataloader(self):
        """Shuffled loader over the training dataset."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Unshuffled loader over the validation dataset."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader using this module's shared batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            worker_init_fn=self._worker_init_fn,
        )

    def _worker_init_fn(self, worker_id):
        """Seed Python ``random``, NumPy and torch inside a DataLoader worker.

        The seed combines ``config.seed`` (0 if missing or None), the worker
        id, and ``torch.initial_seed()``. Inside a worker, PyTorch derives
        ``torch.initial_seed()`` from a base seed drawn when the loader
        iterator is created, so the seed differs across epochs (unless
        workers are persistent). It is reduced modulo 2**32 because
        ``np.random.seed`` only accepts 32-bit seeds.
        """
        import random

        base_seed = getattr(self.config, "seed", None)
        base_seed = 0 if base_seed is None else int(base_seed)
        worker_seed = (base_seed + worker_id + torch.initial_seed()) % 2**32
        random.seed(worker_seed)
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)