| import random |
| import codecs as cs |
| import numpy as np |
| from torch.utils import data |
| from rich.progress import track |
| from os.path import join as pjoin |
| from .dataset_m import MotionDataset |
| from .dataset_t2m import Text2MotionDataset |
|
|
|
|
class MotionDatasetVQ(Text2MotionDataset):
    """Dataset of random fixed-size motion windows built on Text2MotionDataset.

    The parent class does all the loading. This subclass then removes every
    clip shorter than ``win_size`` frames, so that ``__getitem__`` can always
    cut a crop of exactly ``win_size`` frames from the selected clip.
    """

    def __init__(
        self,
        data_root,
        split,
        mean,
        std,
        max_motion_length,
        min_motion_length,
        win_size,
        unit_length=4,
        fps=20,
        tmpFile=True,
        tiny=False,
        debug=False,
        **kwargs,
    ):
        """Load the split through the parent class, then drop short clips.

        Args:
            data_root, split, mean, std, max_motion_length, min_motion_length,
            unit_length, fps, tmpFile, tiny, debug, **kwargs:
                Forwarded unchanged to ``Text2MotionDataset.__init__``.
            win_size: Number of frames in every crop returned by
                ``__getitem__``. Clips with fewer frames are removed from both
                ``self.name_list`` and ``self.data_dict``.
        """
        super().__init__(data_root, split, mean, std, max_motion_length,
                         min_motion_length, unit_length, fps, tmpFile, tiny,
                         debug, **kwargs)

        self.window_size = win_size

        # Partition in a single pass. The previous per-name list.remove()
        # was O(n) per removal, i.e. O(n^2) overall.
        kept, dropped = [], []
        for name in self.name_list:
            if self.data_dict[name]["motion"].shape[0] >= self.window_size:
                kept.append(name)
            else:
                dropped.append(name)
        for name in dropped:
            # pop with a default so a duplicated name cannot raise KeyError.
            self.data_dict.pop(name, None)
        self.name_list = kept
        # NOTE(review): any index-based state the parent derived from the
        # unfiltered name_list (e.g. self.pointer, used in __getitem__) is not
        # recomputed here -- confirm it is still valid after filtering.

    def __len__(self):
        """Return the number of indices that ``__getitem__`` accepts.

        ``__getitem__`` reads ``self.name_list[self.pointer + item]``, so only
        ``len(self.name_list) - self.pointer`` indices are in range.
        """
        return max(len(self.name_list) - self.pointer, 0)

    def __getitem__(self, item):
        """Return a random, normalized ``window_size``-frame crop of one clip.

        Args:
            item: Index offset, relative to ``self.pointer``, into
                ``self.name_list``.

        Returns:
            A 7-tuple ``(None, motion, length, None, None, None, None)``.
            ``motion`` is the crop normalized as ``(x - self.mean) / self.std``.
            ``length`` is the crop's frame count, i.e. ``self.window_size``.
            The ``None`` entries are placeholders; presumably they keep the
            tuple layout expected by the collate function -- confirm.
        """
        # Named ``entry`` rather than ``data`` so it does not shadow the
        # module-level ``torch.utils.data`` import.
        entry = self.data_dict[self.name_list[self.pointer + item]]
        motion = entry["motion"]

        # random.randint is inclusive on both ends, so ``start`` can take any
        # valid offset. __init__ guarantees motion.shape[0] >= window_size.
        start = random.randint(0, motion.shape[0] - self.window_size)
        motion = motion[start:start + self.window_size]
        motion = (motion - self.mean) / self.std

        # Report the length of the crop actually returned, not the full clip.
        return None, motion, motion.shape[0], None, None, None, None
|
|