import random

import joblib
import numpy as np
import torch
from mmengine.registry import DATASETS
from torch.utils.data import Dataset


@DATASETS.register_module()
class MotionDataset(Dataset):
    """Dataset of normalized motion sequences loaded from a joblib pickle.

    Each item is normalized with a precomputed mean/std and then either
    cropped/padded to a fixed window (``window_size > 0``) or randomly
    cropped to a length that is a multiple of ``unit_length``.

    Args:
        split_file: Path to a joblib file holding an indexable collection of
            motions. ``joblib.load`` is pickle-based, so only load trusted files.
        window_size: If > 0, every sample is exactly this many frames long.
            Shorter motions are padded by repeating their last frame.
        unit_length: Granularity of the sample length when ``window_size <= 0``.
        mean_path: ``.npy`` file with the normalization mean.
        std_path: ``.npy`` file with the normalization std.
        min_motion_length: Stored but not used by this class.
        max_motion_length: Upper bound on the sample length when
            ``window_size <= 0``.
    """

    def __init__(self, split_file='', window_size=-1, unit_length=4,
                 mean_path='', std_path='', min_motion_length=60,
                 max_motion_length=300):
        super().__init__()
        self.min_motion_length = min_motion_length  # currently unused
        self.max_motion_length = max_motion_length
        self.window_size = window_size
        self.unit_length = unit_length
        # SECURITY NOTE: joblib.load unpickles arbitrary objects; never point
        # this at untrusted data.
        self.motions = joblib.load(split_file)
        self.mean = torch.from_numpy(np.load(mean_path))
        self.std = torch.from_numpy(np.load(std_path))

    def __len__(self):
        return len(self.motions)

    def __getitem__(self, index):
        """Return a normalized, randomly cropped motion sample.

        Returns:
            dict with ``motion`` (cropped/padded tensor), ``motion_length``
            (int number of frames), and the ``mean``/``std`` tensors used for
            normalization.
        """
        # NOTE(review): the preprocessing script at the bottom of this file
        # stores dicts (from .npz files), but this line expects each item to be
        # a numpy array. Confirm which format the split pickle actually holds.
        motion = torch.from_numpy(self.motions[index])
        motion = (motion - self.mean) / self.std
        T = motion.shape[0]
        if self.window_size > 0:
            if T < self.window_size:
                # Pad by repeating the final frame up to the window length.
                # (repeat(k, 1) assumes a 2-D (frames, features) tensor.)
                pad = motion[-1:].repeat(self.window_size - T, 1)
                motion = torch.cat([motion, pad], dim=0)
                T = motion.shape[0]
            # Avoid consuming RNG state when no crop choice exists.
            idx = 0 if T == self.window_size else random.randint(0, T - self.window_size)
            motion = motion[idx:idx + self.window_size]
            motion_length = self.window_size
        else:
            motion_length = (T // self.unit_length) * self.unit_length
            # Clamp to the largest multiple of unit_length within the maximum,
            # so the unit_length invariant holds even when max_motion_length
            # is not itself a multiple of unit_length.
            max_length = (self.max_motion_length // self.unit_length) * self.unit_length
            if motion_length > max_length:
                motion_length = max_length
            idx = random.randint(0, T - motion_length)
            motion = motion[idx:idx + motion_length]
        return dict(
            motion=motion,
            motion_length=motion_length,
            mean=self.mean,
            std=self.std,
        )


if __name__ == '__main__':
    # Preprocessing: bundle the per-sample .npz files listed in
    # {split}.txt into one joblib pickle per split.
    from tqdm import tqdm
    import os.path as osp

    data_root = '/mnt/shenzhen2cephfs/capybarali/codes/humanoid/data/demo_data'
    for split in ['train', 'val']:
        split_file = osp.join(data_root, f'{split}.txt')
        with open(split_file, 'r') as f:
            # Skip blank lines; an empty name would resolve to the directory.
            names = [line.strip() for line in f if line.strip()]
        all_data = []
        for name in tqdm(names):
            # Use a context manager so each NpzFile's file handle is closed.
            with np.load(osp.join(data_root, 'motion', name)) as npz:
                all_data.append(dict(npz))
        joblib.dump(all_data, osp.join(data_root, f'{split}.pkl'))