| from typing import Optional |
| from os.path import join as pjoin |
|
|
| import numpy as np |
|
|
| from omegaconf import DictConfig |
|
|
| from .data import DataModule |
| from .base import BaseDataModule |
| from .utils import mld_collate, mld_collate_motion_only |
| from .humanml.utils.word_vectorizer import WordVectorizer |
|
|
|
|
def get_mean_std(phase: str, cfg: DictConfig, dataset_name: str) -> tuple[np.ndarray, np.ndarray]:
    """Load the mean / standard-deviation arrays used to normalise a dataset.

    Args:
        phase: ``"val"`` loads ``mean.npy`` / ``std.npy`` from
            ``<cfg.model.t2m_path>/<t2m|kit>/<Comp_v6_KLD01|Comp_v6_KLD005>/meta``.
            Any other value loads ``Mean.npy`` / ``Std.npy`` from
            ``cfg.DATASET.<DATASET_NAME upper-cased>.ROOT``.
        cfg: Project configuration.
        dataset_name: ``"humanml3d"`` or ``"kit"`` (matched case-insensitively).

    Returns:
        ``(mean, std)`` exactly as returned by ``np.load``.

    Raises:
        ValueError: if ``dataset_name`` is neither HumanML3D nor KIT.
    """
    lowered = dataset_name.lower()
    # The "val" assets for HumanML3D are stored under the directory name "t2m".
    name = "t2m" if lowered == "humanml3d" else lowered
    # Per-dataset sub-directory holding the "val" statistics.
    val_subdirs = {"t2m": "Comp_v6_KLD01", "kit": "Comp_v6_KLD005"}
    if name not in val_subdirs:
        # Explicit raise instead of `assert`, which is stripped under `python -O`.
        raise ValueError("Only support t2m and kit")

    if phase == "val":
        data_root = pjoin(cfg.model.t2m_path, name, val_subdirs[name], "meta")
        mean = np.load(pjoin(data_root, "mean.npy"))
        std = np.load(pjoin(data_root, "std.npy"))
    else:
        # Same attribute lookup the former eval() string performed,
        # without evaluating text assembled from the config.
        data_root = getattr(cfg.DATASET, dataset_name.upper()).ROOT
        mean = np.load(pjoin(data_root, "Mean.npy"))
        std = np.load(pjoin(data_root, "Std.npy"))

    return mean, std
|
|
|
|
def get_WordVectorizer(cfg: DictConfig, dataset_name: str) -> Optional[WordVectorizer]:
    """Create the WordVectorizer for a text-motion dataset.

    The vectorizer is built from ``cfg.DATASET.WORD_VERTILIZER_PATH``, with
    ``"our_vab"`` passed as its second argument.

    Args:
        cfg: Project configuration.
        dataset_name: Dataset name, matched case-insensitively against
            ``"humanml3d"`` and ``"kit"``.

    Raises:
        ValueError: for any other dataset name.
    """
    supported_names = ("humanml3d", "kit")
    if dataset_name.lower() not in supported_names:
        raise ValueError("Only support WordVectorizer for HumanML3D and KIT")
    vectorizer_path = cfg.DATASET.WORD_VERTILIZER_PATH
    return WordVectorizer(vectorizer_path, "our_vab")
|
|
|
|
# Lower-case dataset name -> data module class used to build it.
dataset_module_map = dict.fromkeys(("humanml3d", "kit"), DataModule)
# Lower-case dataset name -> motion sub-directory joined onto the dataset root.
motion_subdir = dict.fromkeys(("humanml3d", "kit"), "new_joint_vecs")
|
|
|
|
def get_dataset(cfg: DictConfig, motion_only: bool = False) -> BaseDataModule:
    """Build the data module selected by ``cfg.DATASET.NAME``.

    Only HumanML3D and KIT are supported (case-insensitive). Normalisation
    statistics are loaded for both the ``"train"`` and ``"val"`` phases, and a
    WordVectorizer is created unless ``motion_only`` is set.

    Args:
        cfg: Project configuration. ``cfg.DATASET.<NAME>`` (dataset name
            upper-cased) must provide ``ROOT``, ``UNIT_LEN``, ``FRAME_RATE``
            and ``CONTROL_ARGS``.
        motion_only: If True, no word vectorizer is built and
            ``mld_collate_motion_only`` is used as the collate function.

    Returns:
        The constructed data module.

    Side effects:
        Sets ``cfg.DATASET.NFEATS`` and ``cfg.DATASET.NJOINTS`` from the
        constructed module's ``nfeats`` / ``njoints``.

    Raises:
        NotImplementedError: for ``humanact12``, ``uestc`` and ``amass``.
        ValueError: for any other unsupported dataset name.
    """
    dataset_name = cfg.DATASET.NAME
    # Canonical lower-case key, used for every lookup below.
    key = dataset_name.lower()
    if key in ("humanact12", "uestc", "amass"):
        raise NotImplementedError
    if key not in ("humanml3d", "kit"):
        # Fail loudly instead of implicitly returning None.
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    # getattr performs the same lookup the former eval() strings did,
    # without evaluating text built from the config.
    dataset_cfg = getattr(cfg.DATASET, dataset_name.upper())
    data_root = dataset_cfg.ROOT

    mean, std = get_mean_std("train", cfg, key)
    mean_eval, std_eval = get_mean_std("val", cfg, key)
    word_vectorizer = None if motion_only else get_WordVectorizer(cfg, key)
    collate_fn = mld_collate_motion_only if motion_only else mld_collate

    dataset = dataset_module_map[key](
        name=key,
        cfg=cfg,
        motion_only=motion_only,
        collate_fn=collate_fn,
        mean=mean,
        std=std,
        mean_eval=mean_eval,
        std_eval=std_eval,
        w_vectorizer=word_vectorizer,
        text_dir=pjoin(data_root, "texts"),
        # Lower-case key: the un-lowered name raised KeyError for e.g. "HumanML3D".
        motion_dir=pjoin(data_root, motion_subdir[key]),
        max_motion_length=cfg.DATASET.SAMPLER.MAX_LEN,
        min_motion_length=cfg.DATASET.SAMPLER.MIN_LEN,
        max_text_len=cfg.DATASET.SAMPLER.MAX_TEXT_LEN,
        unit_length=dataset_cfg.UNIT_LEN,
        fps=dataset_cfg.FRAME_RATE,
        padding_to_max=cfg.DATASET.PADDING_TO_MAX,
        window_size=cfg.DATASET.WINDOW_SIZE,
        control_args=dataset_cfg.CONTROL_ARGS,
    )

    # Expose the feature / joint counts to the rest of the config.
    cfg.DATASET.NFEATS = dataset.nfeats
    cfg.DATASET.NJOINTS = dataset.njoints
    return dataset
|
|