| import random, torch |
| from mmengine.registry import DATASETS |
| from torch.utils.data import Dataset |
| import numpy as np |
| import joblib |
|
|
@DATASETS.register_module()
class MotionDataset(Dataset):
    """Dataset of normalized motion clips loaded from a joblib pickle.

    Each item is one stored sequence, normalized as ``(motion - mean) / std``,
    then either cut/padded to a fixed window or randomly cropped.

    Args:
        split_file: Path to a joblib file holding an indexable collection of
            motions. ``__getitem__`` passes each entry to ``torch.from_numpy``,
            so entries must be NumPy arrays with time along axis 0.
        window_size: If > 0, every item is exactly this many frames long
            (shorter sequences are padded by repeating their last frame,
            longer ones are randomly windowed). If <= 0, variable-length
            crops are returned instead.
        unit_length: In variable-length mode, crop lengths are rounded down
            to a multiple of this value.
        mean_path: Path to a ``.npy`` file with the normalization mean.
        std_path: Path to a ``.npy`` file with the normalization std.
        min_motion_length: Stored on the instance; not used by this class.
        max_motion_length: Upper bound on the crop length in variable-length
            mode (rounded down to a multiple of ``unit_length`` when it is at
            least one unit long).
    """

    def __init__(self, split_file='', window_size=-1, unit_length=4, mean_path='', std_path='',
                 min_motion_length=60, max_motion_length=300):
        super().__init__()
        self.min_motion_length = min_motion_length  # kept for config compatibility; unused here
        self.max_motion_length = max_motion_length
        self.window_size = window_size
        self.unit_length = unit_length

        # NOTE(review): joblib.load unpickles the file -- only load trusted split files.
        self.motions = joblib.load(split_file)
        self.mean = torch.from_numpy(np.load(mean_path))
        self.std = torch.from_numpy(np.load(std_path))

    def __len__(self):
        return len(self.motions)

    def __getitem__(self, index):
        """Return one normalized clip.

        Returns:
            dict with keys ``motion`` (normalized clip, time along dim 0),
            ``motion_length`` (number of frames in ``motion``), and the
            ``mean`` / ``std`` tensors used for normalization.

        Raises:
            ValueError: if ``window_size > 0`` and the stored sequence has
                no frames (there is no last frame to pad with).
        """
        # NOTE(review): the __main__ packing script in this file stores dicts
        # (built from .npz files), but torch.from_numpy requires an ndarray --
        # confirm which split-file format this dataset is meant to read.
        motion = torch.from_numpy(self.motions[index])
        motion = (motion - self.mean) / self.std

        T = motion.shape[0]
        if self.window_size > 0:
            if T == 0:
                raise ValueError(
                    f'motion at index {index} has no frames; cannot fill a window '
                    f'of size {self.window_size}')
            if T < self.window_size:
                # Pad by repeating the last frame. expand() works for tensors of
                # any rank (repeat(n, 1) only handled 2-D motions).
                pad = motion[-1:].expand(self.window_size - T, *motion.shape[1:])
                motion = torch.cat([motion, pad], dim=0)
                T = motion.shape[0]
            idx = 0 if T == self.window_size else random.randint(0, T - self.window_size)
            motion = motion[idx:idx + self.window_size]
            motion_length = self.window_size
        else:
            # Longest whole number of units that fits in the sequence.
            motion_length = (T // self.unit_length) * self.unit_length
            # Round the cap down to whole units too, so clipping cannot produce a
            # length that is not a multiple of unit_length. If the cap is shorter
            # than one unit, fall back to the raw cap.
            max_length = ((self.max_motion_length // self.unit_length) * self.unit_length
                          or self.max_motion_length)
            if motion_length > max_length:
                motion_length = max_length
            idx = random.randint(0, T - motion_length)
            motion = motion[idx:idx + motion_length]

        return dict(
            motion=motion,
            motion_length=motion_length,
            mean=self.mean,
            std=self.std,
        )
|
|
if __name__ == '__main__':
    # Pack the per-sequence motion files listed in <data_root>/{train,val}.txt
    # (paths relative to <data_root>/motion) into one joblib pickle per split,
    # written to <data_root>/{split}.pkl.
    # Usage: python <this file> [data_root]
    import os.path as osp
    import sys

    from tqdm import tqdm

    default_root = '/mnt/shenzhen2cephfs/capybarali/codes/humanoid/data/demo_data'
    data_root = sys.argv[1] if len(sys.argv) > 1 else default_root
    for split in ['train', 'val']:
        split_file = osp.join(data_root, f'{split}.txt')
        with open(split_file, 'r', encoding='utf-8') as f:
            # Skip blank lines: they would join to the directory path itself.
            names = [line.strip() for line in f if line.strip()]
        all_data = []
        for name in tqdm(names):
            # np.load on an .npz returns an NpzFile that keeps the file open;
            # dict() reads every array eagerly, and the `with` closes the
            # handle instead of leaking one per file.
            with np.load(osp.join(data_root, 'motion', name)) as npz:
                all_data.append(dict(npz))
        joblib.dump(all_data, osp.join(data_root, f'{split}.pkl'))
|
|