import numpy as np
from torch.utils.data import DataLoader, Subset
from data_loaders.tensors import collate as all_collate
from data_loaders.tensors import t2m_collate, t2m_prefix_collate
def get_dataset_class(name):
    """Resolve a dataset identifier to its dataset class.

    Imports are deferred until the matching entry is looked up, so datasets
    that are never requested (and their dependencies) are never imported.

    Args:
        name: one of "amass", "uestc", "humanact12", "humanml",
            "humanml_with_images", "kit".

    Returns:
        The dataset class corresponding to `name`.

    Raises:
        ValueError: if `name` is not a known dataset identifier.
    """
    def _amass():
        from .amass import AMASS
        return AMASS

    def _uestc():
        from .a2m.uestc import UESTC
        return UESTC

    def _humanact12():
        from .a2m.humanact12poses import HumanAct12Poses
        return HumanAct12Poses

    def _humanml():
        from data_loaders.humanml.data.dataset import HumanML3D
        return HumanML3D

    def _humanml_with_images():
        from data_loaders.humanml.data.dataset import HumanML3DWithImages
        return HumanML3DWithImages

    def _kit():
        from data_loaders.humanml.data.dataset import KIT
        return KIT

    loaders = {
        "amass": _amass,
        "uestc": _uestc,
        "humanact12": _humanact12,
        "humanml": _humanml,
        "humanml_with_images": _humanml_with_images,
        "kit": _kit,
    }
    try:
        return loaders[name]()
    except KeyError:
        raise ValueError(f"Unsupported dataset name [{name}]") from None
def get_collate_fn(name, hml_mode="train", pred_len=0, batch_size=1):
    """Select the collate function matching the dataset and evaluation mode.

    Args:
        name: dataset identifier (see `get_dataset_class`).
        hml_mode: "gt" selects the HumanML evaluation collate; any other
            value selects the training-style collate.
        pred_len: when > 0, use the prefix collate with this prediction length.
        batch_size: forwarded to the text-to-motion collate.

    Returns:
        A callable suitable for `DataLoader(collate_fn=...)`.
    """
    # Ground-truth evaluation mode overrides everything else.
    if hml_mode == "gt":
        from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate
        return t2m_eval_collate
    # Non-text-to-motion datasets share a single generic collate.
    if name not in ["humanml_with_images", "humanml", "kit"]:
        return all_collate
    if pred_len > 0:
        return lambda batch: t2m_prefix_collate(batch, pred_len=pred_len)
    return lambda batch: t2m_collate(batch, batch_size)
def get_dataset(
    name,
    num_frames,
    split="train",
    hml_mode="train",
    abs_path=".",
    fixed_len=0,
    device=None,
    autoregressive=False,
    use_precomputed_embeddings=False,
    embeddings_dir=None,
    prompt_drop_rate=0.0,
    num_cond_frames=196,
    fast_mode=False,
):
    """Instantiate the dataset identified by `name`.

    Text-to-motion datasets ("humanml", "humanml_with_images", "kit") accept
    the full set of keyword options; all other datasets are constructed with
    only `split` and `num_frames`.

    Returns:
        An instance of the dataset class resolved by `get_dataset_class`.
    """
    dataset_cls = get_dataset_class(name)
    if name not in ["humanml_with_images", "humanml", "kit"]:
        return dataset_cls(split=split, num_frames=num_frames)
    return dataset_cls(
        split=split,
        num_frames=num_frames,
        mode=hml_mode,
        abs_path=abs_path,
        fixed_len=fixed_len,
        device=device,
        autoregressive=autoregressive,
        use_precomputed_embeddings=use_precomputed_embeddings,
        embeddings_dir=embeddings_dir,
        prompt_drop_rate=prompt_drop_rate,
        num_cond_frames=num_cond_frames,
        fast_mode=fast_mode,
    )
def get_dataset_loader(
    name,
    batch_size,
    num_frames,
    split="train",
    hml_mode="train",
    fixed_len=0,
    pred_len=0,
    device=None,
    autoregressive=False,
    num_samples=None,
    train_sample_indices=None,
    use_precomputed_embeddings=False,
    embeddings_dir=None,
    prompt_drop_rate=0.0,
    num_cond_frames=196,
    fast_mode=False,
):
    """Build a `DataLoader` over the dataset identified by `name`.

    Args:
        name: dataset identifier (see `get_dataset_class`).
        batch_size: loader batch size; incomplete final batches are dropped.
        num_frames: forwarded to the dataset constructor.
        split, hml_mode, fixed_len, device, autoregressive,
        use_precomputed_embeddings, embeddings_dir, prompt_drop_rate,
        num_cond_frames, fast_mode: forwarded to `get_dataset`.
        pred_len: when > 0, selects the prefix collate (see `get_collate_fn`).
        num_samples: if given, restrict the dataset to its first
            `num_samples` items.
        train_sample_indices: if given, restrict the dataset to exactly these
            indices (takes precedence over `num_samples`).

    Returns:
        A shuffling `torch.utils.data.DataLoader` over the (possibly
        subsetted) dataset.

    Raises:
        ValueError: if `num_samples` is not in [1, len(dataset)].
    """
    dataset = get_dataset(
        name,
        num_frames,
        split=split,
        hml_mode=hml_mode,
        fixed_len=fixed_len,
        device=device,
        autoregressive=autoregressive,
        use_precomputed_embeddings=use_precomputed_embeddings,
        embeddings_dir=embeddings_dir,
        prompt_drop_rate=prompt_drop_rate,
        num_cond_frames=num_cond_frames,
        fast_mode=fast_mode,
    )
    if train_sample_indices is not None:
        dataset = Subset(dataset, train_sample_indices)
        print(
            f"Using only the samples with indices {train_sample_indices} from the dataset."
        )
    elif num_samples is not None:
        # Validate with explicit exceptions rather than `assert`: asserts are
        # stripped under `python -O`, which would silently disable this check.
        if num_samples <= 0:
            raise ValueError("num_samples must be greater than 0")
        if num_samples > len(dataset):
            raise ValueError(
                f"num_samples {num_samples} is greater than the dataset size {len(dataset)}"
            )
        # Choose the first num_samples samples
        indices = np.arange(num_samples)
        dataset = Subset(dataset, indices)
        print(f"Using a subset of {num_samples} samples from the dataset.")
    collate = get_collate_fn(name, hml_mode, pred_len, batch_size)
    # NOTE(review): shuffle=True and drop_last=True are applied regardless of
    # `split` (including eval/test) — confirm this is intended by the callers.
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        drop_last=True,
        collate_fn=collate,
        pin_memory=True,
        prefetch_factor=4,
        persistent_workers=True,
    )
    return loader