| import numpy as np |
| import torch |
| from os.path import join as pjoin |
| from .humanml.utils.word_vectorizer import WordVectorizer |
| from .humanml.scripts.motion_process import (process_file, recover_from_ric) |
| from . import BASEDataModule |
| from .humanml import Text2MotionDatasetEval, Text2MotionDataset, Text2MotionDatasetCB, MotionDataset, MotionDatasetVQ, Text2MotionDatasetToken, Text2MotionDatasetM2T |
| from .utils import humanml3d_collate |
|
|
|
|
class HumanML3DDataModule(BASEDataModule):
    """Data module wiring the HumanML3D dataset classes to a training stage.

    The constructor records dataset paths, normalisation statistics and
    length limits in ``self.hparams`` and selects the dataset classes
    (``self.Dataset`` / ``self.DatasetEval``) according to
    ``cfg.TRAIN.STAGE``. The remaining methods convert motion features
    between normalised / de-normalised / joint representations and toggle
    the test-set subsampling used by ``mm_mode``.
    """

    def __init__(self, cfg, **kwargs):
        """Populate hyper-parameters and select dataset classes from *cfg*.

        Args:
            cfg: Configuration object. Read: ``DATASET.HUMANML3D.*``,
                ``DATASET.CODE_PATH`` / ``DATASET.TASK_PATH`` (``lm`` stages
                only), ``DEBUG``, ``TRAIN.STAGE`` and
                ``model.params.motion_vae.target`` (``vae`` stage only).
                Written: ``DATASET.JOINT_TYPE`` and ``DATASET.NFEATS``.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        super().__init__(collate_fn=humanml3d_collate)
        self.cfg = cfg
        # Keep this call ahead of any local assignments in this method.
        self.save_hyperparameters(logger=False)

        # Basic description of the dataset.
        cfg.DATASET.JOINT_TYPE = 'humanml3d'
        self.name = "humanml3d"
        self.njoints = 22

        # Dataset locations, all relative to the configured root.
        data_root = cfg.DATASET.HUMANML3D.ROOT
        self.hparams.data_root = data_root
        self.hparams.text_dir = pjoin(data_root, "texts")
        self.hparams.motion_dir = pjoin(data_root, 'new_joint_vecs')

        # Normalisation statistics used by normalize()/denormalize().
        # NOTE: the 'assets/meta' path is relative to the working directory.
        self.hparams.mean = np.load(pjoin('assets/meta', "mean.npy"))
        self.hparams.std = np.load(pjoin('assets/meta', "std.npy"))

        # Second set of statistics, used only by renorm4t2m().
        self.hparams.mean_eval = np.load(pjoin('assets/meta', "mean_eval.npy"))
        self.hparams.std_eval = np.load(pjoin('assets/meta', "std_eval.npy"))

        # Length limits copied straight from the configuration.
        self.hparams.max_motion_length = cfg.DATASET.HUMANML3D.MAX_MOTION_LEN
        self.hparams.min_motion_length = cfg.DATASET.HUMANML3D.MIN_MOTION_LEN
        self.hparams.max_text_len = cfg.DATASET.HUMANML3D.MAX_TEXT_LEN
        self.hparams.unit_length = cfg.DATASET.HUMANML3D.UNIT_LEN

        # Run-level switches.
        self.hparams.debug = cfg.DEBUG
        self.hparams.stage = cfg.TRAIN.STAGE

        # Default evaluation dataset; the "token" and "m2t" stages replace it.
        self.DatasetEval = Text2MotionDatasetEval

        # Pick the training dataset class for the configured stage.
        if cfg.TRAIN.STAGE == "vae":
            # The last dotted component of the VAE target names the model.
            if cfg.model.params.motion_vae.target.split('.')[-1].lower() == "vqvae":
                self.hparams.win_size = 64
                self.Dataset = MotionDatasetVQ
            else:
                self.Dataset = MotionDataset
        elif 'lm' in cfg.TRAIN.STAGE:
            # Substring test: any stage name containing "lm" lands here.
            self.hparams.code_path = cfg.DATASET.CODE_PATH
            self.hparams.task_path = cfg.DATASET.TASK_PATH
            self.hparams.std_text = cfg.DATASET.HUMANML3D.STD_TEXT
            self.Dataset = Text2MotionDatasetCB
        elif cfg.TRAIN.STAGE == "token":
            self.Dataset = Text2MotionDatasetToken
            self.DatasetEval = Text2MotionDatasetToken
        elif cfg.TRAIN.STAGE == "m2t":
            self.Dataset = Text2MotionDatasetM2T
            self.DatasetEval = Text2MotionDatasetM2T
        else:
            self.Dataset = Text2MotionDataset

        # Feature count, mirrored into the config for other components.
        self.nfeats = 263
        cfg.DATASET.NFEATS = self.nfeats

    def feats2joints(self, features):
        """De-normalise *features* and pass them to ``recover_from_ric``.

        Args:
            features: Tensor of normalised motion features; its last
                dimension must broadcast against the stored mean/std.

        Returns:
            The result of ``recover_from_ric(features, self.njoints)`` —
            presumably joint positions; confirm against that helper.
        """
        return recover_from_ric(self.denormalize(features), self.njoints)

    def joints2feats(self, features):
        """Return the first element of ``process_file(features, njoints)``.

        NOTE(review): the exact contents of the tuple returned by
        ``process_file`` are not visible from this file.
        """
        features = process_file(features, self.njoints)[0]
        return features

    def normalize(self, features):
        """Return ``(features - mean) / std`` using the training statistics.

        The statistics are converted with ``.to(features)`` so they follow
        the input tensor's dtype and device.
        """
        mean = torch.tensor(self.hparams.mean).to(features)
        std = torch.tensor(self.hparams.std).to(features)
        features = (features - mean) / std
        return features

    def denormalize(self, features):
        """Return ``features * std + mean`` — the inverse of ``normalize``.

        The statistics are converted with ``.to(features)`` so they follow
        the input tensor's dtype and device.
        """
        mean = torch.tensor(self.hparams.mean).to(features)
        std = torch.tensor(self.hparams.std).to(features)
        features = features * std + mean
        return features

    def renorm4t2m(self, features):
        """Re-express *features* under the evaluation statistics.

        The input is de-normalised with the training mean/std and then
        normalised again with ``mean_eval`` / ``std_eval``.
        """
        eval_mean = torch.tensor(self.hparams.mean_eval).to(features)
        eval_std = torch.tensor(self.hparams.std_eval).to(features)
        features = self.denormalize(features)
        features = (features - eval_mean) / eval_std
        return features

    def mm_mode(self, mm_on=True):
        """Switch the test dataset between its full and subsampled name list.

        With ``mm_on=True`` the test dataset's ``name_list`` is replaced by
        ``cfg.METRIC.MM_NUM_SAMPLES`` names drawn without replacement from
        the full list. With ``mm_on=False`` the full list is restored.

        Args:
            mm_on: ``True`` to install a fresh random subsample, ``False``
                to restore the list saved when the subsample was installed.

        Raises:
            ValueError: From ``np.random.choice`` when more samples are
                requested than there are names.
        """
        if mm_on:
            self.is_mm = True
            current = self.test_dataset.name_list
            # Only snapshot the list when it is not the subsample this module
            # installed itself. Otherwise a repeated mm_mode(True) would
            # overwrite the saved full list and mm_mode(False) could never
            # restore it.
            if current is not getattr(self, "mm_list", None):
                self.name_list = current
            self.mm_list = np.random.choice(self.name_list,
                                            self.cfg.METRIC.MM_NUM_SAMPLES,
                                            replace=False)
            self.test_dataset.name_list = self.mm_list
        else:
            self.is_mm = False
            self.test_dataset.name_list = self.name_list
|
|