import numpy as np
from torch.utils.data import DataLoader, Subset

from data_loaders.tensors import collate as all_collate
from data_loaders.tensors import t2m_collate, t2m_prefix_collate

# Dataset names that share the HumanML3D-style pipeline: they take the extra
# constructor kwargs in get_dataset() and use the t2m_* collate functions.
# Kept in one place so get_collate_fn() and get_dataset() cannot drift apart.
_HML_STYLE_DATASETS = ("humanml_with_images", "humanml", "kit")


def get_dataset_class(name):
    """Resolve a dataset name to its dataset class.

    Imports are done lazily inside each branch so that selecting one dataset
    does not require the dependencies of the others to be installed.

    Args:
        name: One of "amass", "uestc", "humanact12", "humanml",
            "humanml_with_images", "kit".

    Returns:
        The dataset class (not an instance).

    Raises:
        ValueError: If ``name`` is not a known dataset.
    """
    if name == "amass":
        from .amass import AMASS

        return AMASS
    elif name == "uestc":
        from .a2m.uestc import UESTC

        return UESTC
    elif name == "humanact12":
        from .a2m.humanact12poses import HumanAct12Poses

        return HumanAct12Poses
    elif name == "humanml":
        from data_loaders.humanml.data.dataset import HumanML3D

        return HumanML3D
    elif name == "humanml_with_images":
        from data_loaders.humanml.data.dataset import HumanML3DWithImages

        return HumanML3DWithImages
    elif name == "kit":
        from data_loaders.humanml.data.dataset import KIT

        return KIT
    else:
        raise ValueError(f"Unsupported dataset name [{name}]")


def get_collate_fn(name, hml_mode="train", pred_len=0, batch_size=1):
    """Pick the collate function matching the dataset/mode combination.

    Args:
        name: Dataset name (see :func:`get_dataset_class`).
        hml_mode: HumanML mode; "gt" selects the evaluation collate.
        pred_len: If > 0, use the prefix collate (autoregressive prediction).
        batch_size: Forwarded to ``t2m_collate`` for the default HumanML path.

    Returns:
        A callable suitable for ``DataLoader(collate_fn=...)``.
    """
    if hml_mode == "gt":
        from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate

        return t2m_eval_collate
    if name in _HML_STYLE_DATASETS:
        # pred_len and batch_size are evaluated at call time of this factory,
        # so the lambdas capture fixed values — no late-binding issue.
        if pred_len > 0:
            return lambda x: t2m_prefix_collate(x, pred_len=pred_len)
        return lambda x: t2m_collate(x, batch_size)
    return all_collate


def get_dataset(
    name,
    num_frames,
    split="train",
    hml_mode="train",
    abs_path=".",
    fixed_len=0,
    device=None,
    autoregressive=False,
    use_precomputed_embeddings=False,
    embeddings_dir=None,
    prompt_drop_rate=0.0,
    num_cond_frames=196,
    fast_mode=False,
):
    """Instantiate a dataset by name.

    HumanML-style datasets receive the full keyword set; the action-to-motion
    datasets (amass/uestc/humanact12) only accept ``split`` and ``num_frames``.

    Returns:
        The constructed dataset instance.
    """
    dataset_cls = get_dataset_class(name)
    if name in _HML_STYLE_DATASETS:
        return dataset_cls(
            split=split,
            num_frames=num_frames,
            mode=hml_mode,
            abs_path=abs_path,
            fixed_len=fixed_len,
            device=device,
            autoregressive=autoregressive,
            use_precomputed_embeddings=use_precomputed_embeddings,
            embeddings_dir=embeddings_dir,
            prompt_drop_rate=prompt_drop_rate,
            num_cond_frames=num_cond_frames,
            fast_mode=fast_mode,
        )
    return dataset_cls(split=split, num_frames=num_frames)


def get_dataset_loader(
    name,
    batch_size,
    num_frames,
    split="train",
    hml_mode="train",
    fixed_len=0,
    pred_len=0,
    device=None,
    autoregressive=False,
    num_samples=None,
    train_sample_indices=None,
    use_precomputed_embeddings=False,
    embeddings_dir=None,
    prompt_drop_rate=0.0,
    num_cond_frames=196,
    fast_mode=False,
    abs_path=".",
):
    """Build a shuffling DataLoader over the requested dataset.

    Optionally restricts the dataset to an explicit index list
    (``train_sample_indices``) or to the first ``num_samples`` samples.

    Args:
        name: Dataset name (see :func:`get_dataset_class`).
        batch_size: Batch size; also forwarded to the collate function.
        num_frames: Motion length passed to the dataset constructor.
        num_samples: If given, keep only the first ``num_samples`` samples.
        train_sample_indices: If given, keep exactly these sample indices
            (takes precedence over ``num_samples``).
        abs_path: Dataset root passed through to :func:`get_dataset`
            (new keyword, default preserves previous behavior).
        Remaining keywords are forwarded to :func:`get_dataset` /
        :func:`get_collate_fn` unchanged.

    Returns:
        A ``torch.utils.data.DataLoader`` with ``shuffle=True`` and
        ``drop_last=True``. NOTE(review): shuffling/drop_last also apply to
        non-train splits — confirm that is intended for evaluation callers.

    Raises:
        ValueError: If ``num_samples`` is not in ``[1, len(dataset)]``.
    """
    dataset = get_dataset(
        name,
        num_frames,
        split=split,
        hml_mode=hml_mode,
        abs_path=abs_path,
        fixed_len=fixed_len,
        device=device,
        autoregressive=autoregressive,
        use_precomputed_embeddings=use_precomputed_embeddings,
        embeddings_dir=embeddings_dir,
        prompt_drop_rate=prompt_drop_rate,
        num_cond_frames=num_cond_frames,
        fast_mode=fast_mode,
    )

    if train_sample_indices is not None:
        dataset = Subset(dataset, train_sample_indices)
        print(
            f"Using only the samples with indices {train_sample_indices} from the dataset."
        )
    elif num_samples is not None:
        # Explicit raises instead of asserts: asserts are stripped under
        # `python -O`, which would silently disable this validation.
        if num_samples <= 0:
            raise ValueError("num_samples must be greater than 0")
        if num_samples > len(dataset):
            raise ValueError(
                f"num_samples {num_samples} is greater than the dataset size {len(dataset)}"
            )
        # Choose the first num_samples samples
        indices = np.arange(num_samples)
        dataset = Subset(dataset, indices)
        print(f"Using a subset of {num_samples} samples from the dataset.")

    collate = get_collate_fn(name, hml_mode, pred_len, batch_size)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        drop_last=True,
        collate_fn=collate,
        pin_memory=True,
        prefetch_factor=4,
        persistent_workers=True,
    )