Spaces:
Sleeping
Sleeping
| import numpy as np | |
| from torch.utils.data import DataLoader, Subset | |
| from data_loaders.tensors import collate as all_collate | |
| from data_loaders.tensors import t2m_collate, t2m_prefix_collate | |
def get_dataset_class(name):
    """Look up the dataset class registered under ``name``.

    Each dataset module is imported only when its name is requested, so
    selecting one dataset does not import the others.

    Args:
        name: one of "amass", "uestc", "humanml12", "humanml",
            "humanml_with_images" or "kit".

    Returns:
        The dataset class (not an instance).

    Raises:
        ValueError: if ``name`` is not a supported dataset name.
    """
    if name == "amass":
        from .amass import AMASS

        return AMASS

    if name == "uestc":
        from .a2m.uestc import UESTC

        return UESTC

    if name == "humanact12":
        from .a2m.humanact12poses import HumanAct12Poses

        return HumanAct12Poses

    if name == "humanml":
        from data_loaders.humanml.data.dataset import HumanML3D

        return HumanML3D

    if name == "humanml_with_images":
        from data_loaders.humanml.data.dataset import HumanML3DWithImages

        return HumanML3DWithImages

    if name == "kit":
        from data_loaders.humanml.data.dataset import KIT

        return KIT

    # Fell through every known name.
    raise ValueError(f"Unsupported dataset name [{name}]")
class _T2MCollateFn:
    """Picklable collate callable for the HumanML3D-family datasets.

    Behaves like ``lambda x: t2m_prefix_collate(x, pred_len=pred_len)`` when
    ``pred_len > 0`` and like ``lambda x: t2m_collate(x, batch_size)``
    otherwise. Unlike a lambda, an instance of this module-level class can
    be pickled. DataLoader worker processes need that under the
    ``spawn``/``forkserver`` multiprocessing start methods.
    """

    def __init__(self, batch_size, pred_len):
        self.batch_size = batch_size
        self.pred_len = pred_len

    def __call__(self, batch):
        # The collate helpers are resolved at call time, as in the closures
        # this class replaces.
        if self.pred_len > 0:
            return t2m_prefix_collate(batch, pred_len=self.pred_len)
        return t2m_collate(batch, self.batch_size)

    def __repr__(self):
        return (
            f"{type(self).__name__}(batch_size={self.batch_size!r}, "
            f"pred_len={self.pred_len!r})"
        )


def get_collate_fn(name, hml_mode="train", pred_len=0, batch_size=1):
    """Return the collate function to use for dataset ``name``.

    Args:
        name: dataset name as accepted by ``get_dataset_class``.
        hml_mode: when "gt", the evaluation ``collate_fn`` from
            ``data_loaders.humanml.data.dataset`` is returned regardless of
            ``name``.
        pred_len: for HumanML3D-family datasets, a value > 0 selects
            ``t2m_prefix_collate`` with this ``pred_len``.
        batch_size: forwarded to ``t2m_collate`` for HumanML3D-family datasets
            when ``pred_len <= 0``.

    Returns:
        A callable taking a list of samples. For "humanml",
        "humanml_with_images" and "kit" it is a picklable ``_T2MCollateFn``
        instance; for every other name it is ``all_collate``.
    """
    if hml_mode == "gt":
        from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate

        return t2m_eval_collate
    if name in ["humanml_with_images", "humanml", "kit"]:
        # A module-level callable instead of a lambda, so the DataLoader can
        # ship it to worker processes under spawn-based multiprocessing.
        return _T2MCollateFn(batch_size=batch_size, pred_len=pred_len)
    return all_collate
def get_dataset(
    name,
    num_frames,
    split="train",
    hml_mode="train",
    abs_path=".",
    fixed_len=0,
    device=None,
    autoregressive=False,
    use_precomputed_embeddings=False,
    embeddings_dir=None,
    prompt_drop_rate=0.0,
    num_cond_frames=196,
    fast_mode=False,
):
    """Instantiate the dataset registered under ``name``.

    HumanML3D-family datasets ("humanml", "humanml_with_images", "kit")
    receive every option below. All other datasets are built with only
    ``split`` and ``num_frames``, and the remaining arguments are ignored.

    Raises:
        ValueError: propagated from ``get_dataset_class`` for unknown names.
    """
    dataset_cls = get_dataset_class(name)

    if name not in ["humanml_with_images", "humanml", "kit"]:
        return dataset_cls(split=split, num_frames=num_frames)

    # HumanML3D-style constructors take the text-to-motion specific options;
    # ``hml_mode`` is passed as ``mode``.
    hml_options = dict(
        split=split,
        num_frames=num_frames,
        mode=hml_mode,
        abs_path=abs_path,
        fixed_len=fixed_len,
        device=device,
        autoregressive=autoregressive,
        use_precomputed_embeddings=use_precomputed_embeddings,
        embeddings_dir=embeddings_dir,
        prompt_drop_rate=prompt_drop_rate,
        num_cond_frames=num_cond_frames,
        fast_mode=fast_mode,
    )
    return dataset_cls(**hml_options)
def get_dataset_loader(
    name,
    batch_size,
    num_frames,
    split="train",
    hml_mode="train",
    fixed_len=0,
    pred_len=0,
    device=None,
    autoregressive=False,
    num_samples=None,
    train_sample_indices=None,
    use_precomputed_embeddings=False,
    embeddings_dir=None,
    prompt_drop_rate=0.0,
    num_cond_frames=196,
    fast_mode=False,
    abs_path=".",
    num_workers=8,
    shuffle=True,
):
    """Build a ``DataLoader`` over the dataset registered under ``name``.

    Dataset-construction arguments are forwarded to ``get_dataset``. The
    collate function comes from ``get_collate_fn(name, hml_mode, pred_len,
    batch_size)``.

    Subsetting:
        * ``train_sample_indices`` (if not None) keeps exactly those indices.
          When it is given, ``num_samples`` is ignored.
        * otherwise ``num_samples`` (if not None) keeps the first
          ``num_samples`` items.

    Args:
        abs_path: forwarded to ``get_dataset`` (default matches its default).
        num_workers: number of DataLoader worker processes. ``prefetch_factor``
            and ``persistent_workers`` are only set when this is > 0.
        shuffle: whether the loader reshuffles every epoch.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``drop_last=True`` and
        ``pin_memory=True``.

    Raises:
        ValueError: if ``num_samples`` is not in ``[1, len(dataset)]``.
    """
    import warnings

    dataset = get_dataset(
        name,
        num_frames,
        split=split,
        hml_mode=hml_mode,
        abs_path=abs_path,
        fixed_len=fixed_len,
        device=device,
        autoregressive=autoregressive,
        use_precomputed_embeddings=use_precomputed_embeddings,
        embeddings_dir=embeddings_dir,
        prompt_drop_rate=prompt_drop_rate,
        num_cond_frames=num_cond_frames,
        fast_mode=fast_mode,
    )
    if train_sample_indices is not None:
        dataset = Subset(dataset, train_sample_indices)
        print(
            f"Using only the samples with indices {train_sample_indices} from the dataset."
        )
    elif num_samples is not None:
        # Explicit raises rather than ``assert`` so validation survives ``python -O``.
        if num_samples <= 0:
            raise ValueError("num_samples must be greater than 0")
        if num_samples > len(dataset):
            raise ValueError(
                f"num_samples {num_samples} is greater than the dataset size {len(dataset)}"
            )
        # Choose the first num_samples samples
        indices = np.arange(num_samples)
        dataset = Subset(dataset, indices)
        print(f"Using a subset of {num_samples} samples from the dataset.")

    # With drop_last=True a dataset smaller than one batch yields no batches
    # at all; surface that instead of silently training on nothing.
    if len(dataset) < batch_size:
        warnings.warn(
            f"Dataset has {len(dataset)} samples, fewer than batch_size "
            f"{batch_size}; with drop_last=True the loader will yield no batches."
        )

    collate = get_collate_fn(name, hml_mode, pred_len, batch_size)

    # DataLoader rejects prefetch_factor / persistent_workers when
    # num_workers == 0, so only pass them when worker processes are used.
    worker_options = {}
    if num_workers > 0:
        worker_options = dict(prefetch_factor=4, persistent_workers=True)

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=True,
        collate_fn=collate,
        pin_memory=True,
        **worker_options,
    )
    return loader