# NMR/src/datasets/motion_dataset.py
# Uploaded by Xxx999 (commit 45950ff, message: "upload").
import random, torch
from mmengine.registry import DATASETS
from torch.utils.data import Dataset
import numpy as np
import joblib
@DATASETS.register_module()
class MotionDataset(Dataset):
    """Dataset of normalized motion sequences loaded from a joblib file.

    Each sample is z-score normalized with the ``mean``/``std`` arrays, then
    either padded/cropped to exactly ``window_size`` frames (fixed-window
    mode, ``window_size > 0``) or randomly cropped to a multiple of
    ``unit_length`` frames capped at ``max_motion_length`` (variable-length
    mode, ``window_size <= 0``).

    Args:
        split_file: Path passed to ``joblib.load``; its content is indexed by
            integer and its length defines the dataset size.
        window_size: Fixed sample length in frames; ``<= 0`` selects the
            variable-length mode.
        unit_length: In variable-length mode the sample length is rounded
            down to a multiple of this value.
        mean_path: ``.npy`` file with the normalization mean.
        std_path: ``.npy`` file with the normalization std.
        min_motion_length: Stored on the instance; not used by this class.
        max_motion_length: Upper bound on the sample length in
            variable-length mode.
        motion_key: If not ``None``, each loaded entry is indexed with this
            key to obtain the motion array (e.g. for the list of dicts written
            by this module's ``__main__`` converter). If ``None`` (default),
            each entry is passed to ``torch.from_numpy`` as-is.
    """

    def __init__(self, split_file='', window_size=-1, unit_length=4, mean_path='', std_path='',
                 min_motion_length=60, max_motion_length=300, motion_key=None):
        super().__init__()
        # NOTE(review): min_motion_length is stored but never read in this class.
        self.min_motion_length = min_motion_length
        self.max_motion_length = max_motion_length
        self.window_size = window_size
        self.unit_length = unit_length
        self.motion_key = motion_key
        # joblib.load unpickles the file: only point split_file at trusted data.
        self.motions = joblib.load(split_file)
        self.mean = torch.from_numpy(np.load(mean_path))
        self.std = torch.from_numpy(np.load(std_path))

    def __len__(self):
        """Return the number of motions loaded from ``split_file``."""
        return len(self.motions)

    def __getitem__(self, index):
        """Return one normalized, cropped/padded motion sample.

        Returns:
            dict with ``motion`` (tensor, first dimension is frames),
            ``motion_length`` (int, number of frames in ``motion``) and the
            ``mean``/``std`` tensors used for normalization.

        Raises:
            ValueError: in fixed-window mode when the motion has no frames.
        """
        entry = self.motions[index]
        if self.motion_key is not None:
            entry = entry[self.motion_key]
        motion = torch.from_numpy(entry)
        # NOTE(review): no epsilon is added, so a zero entry in std yields inf/NaN.
        motion = (motion - self.mean) / self.std
        T = motion.shape[0]
        if self.window_size > 0:
            if T == 0:
                raise ValueError(
                    f'motion at index {index} has no frames; '
                    f'cannot build a window of size {self.window_size}')
            if T < self.window_size:
                # Pad by repeating the last frame. expand() keeps every trailing
                # dimension, so this works for (T,), (T, D) and (T, J, C) motions
                # (repeat(n, 1) only handled the 2-D case).
                pad = motion[-1:].expand(self.window_size - T, *motion.shape[1:])
                motion = torch.cat([motion, pad], dim=0)
                T = motion.shape[0]
            # randint is inclusive on both ends, so every valid start is reachable.
            idx = 0 if T == self.window_size else random.randint(0, T - self.window_size)
            motion = motion[idx:idx + self.window_size]
            motion_length = self.window_size
        else:
            motion_length = (T // self.unit_length) * self.unit_length
            if motion_length > self.max_motion_length:
                motion_length = self.max_motion_length
            idx = random.randint(0, T - motion_length)
            motion = motion[idx:idx + motion_length]
        return dict(
            motion=motion,
            motion_length=motion_length,
            mean=self.mean,
            std=self.std,
            # caption=self.motions[index]['caption'].item()
        )
if __name__ == '__main__':
    # One-off converter: for each split, read the clip file names listed in
    # {data_root}/{split}.txt, load each {data_root}/motion/<name> .npz into a
    # dict of arrays, and joblib.dump the resulting list to {split}.pkl.
    # Usage: python motion_dataset.py [DATA_ROOT]
    import sys
    import os.path as osp
    from tqdm import tqdm

    default_root = '/mnt/shenzhen2cephfs/capybarali/codes/humanoid/data/demo_data'
    data_root = sys.argv[1] if len(sys.argv) > 1 else default_root
    for split in ['train', 'val']:
        split_file = osp.join(data_root, f'{split}.txt')
        with open(split_file, 'r') as f:
            # Skip blank lines: an empty name would make np.load open the
            # motion/ directory itself and crash.
            names = [line.strip() for line in f if line.strip()]
        all_data = []
        for name in tqdm(names):
            # np.load on an .npz returns an NpzFile that holds the file open;
            # use it as a context manager so handles are not leaked. dict()
            # reads every array into memory before the file is closed.
            with np.load(osp.join(data_root, 'motion', name)) as npz:
                all_data.append(dict(npz))
        # NOTE(review): each entry is a dict of arrays, not a bare array;
        # confirm this matches the entry format MotionDataset expects.
        joblib.dump(all_data, osp.join(data_root, f'{split}.pkl'))