# mdm/data_loaders/get_data.py
# (Hugging Face page residue preserved: uploaded by hassanjbara,
#  commit message "update model", commit 5007d4b)
import numpy as np
from torch.utils.data import DataLoader, Subset
from data_loaders.tensors import collate as all_collate
from data_loaders.tensors import t2m_collate, t2m_prefix_collate
def get_dataset_class(name):
    """Look up the dataset class registered under ``name``.

    Imports happen inside the matching branch so that only the selected
    dataset's dependencies get loaded.

    Args:
        name: One of "amass", "uestc", "humanact12", "humanml",
            "humanml_with_images", or "kit".

    Returns:
        The dataset class itself (not an instance).

    Raises:
        ValueError: If ``name`` is not a recognized dataset.
    """
    if name == "amass":
        from .amass import AMASS as dataset_cls
    elif name == "uestc":
        from .a2m.uestc import UESTC as dataset_cls
    elif name == "humanact12":
        from .a2m.humanact12poses import HumanAct12Poses as dataset_cls
    elif name == "humanml":
        from data_loaders.humanml.data.dataset import HumanML3D as dataset_cls
    elif name == "humanml_with_images":
        from data_loaders.humanml.data.dataset import HumanML3DWithImages as dataset_cls
    elif name == "kit":
        from data_loaders.humanml.data.dataset import KIT as dataset_cls
    else:
        raise ValueError(f"Unsupported dataset name [{name}]")
    return dataset_cls
def get_collate_fn(name, hml_mode="train", pred_len=0, batch_size=1):
    """Select the batch-collate function for the given dataset/mode.

    ``hml_mode == "gt"`` always wins and returns the ground-truth
    evaluation collate, regardless of ``name``. Text-to-motion datasets
    get the t2m collates (prefix variant when ``pred_len > 0``); every
    other dataset falls back to the generic ``all_collate``.
    """
    if hml_mode == "gt":
        from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate
        return t2m_eval_collate
    if name not in ("humanml_with_images", "humanml", "kit"):
        return all_collate
    if pred_len > 0:
        # Bind pred_len now; the batch is supplied later by the DataLoader.
        return lambda batch: t2m_prefix_collate(batch, pred_len=pred_len)
    return lambda batch: t2m_collate(batch, batch_size)
def get_dataset(
    name,
    num_frames,
    split="train",
    hml_mode="train",
    abs_path=".",
    fixed_len=0,
    device=None,
    autoregressive=False,
    use_precomputed_embeddings=False,
    embeddings_dir=None,
    prompt_drop_rate=0.0,
    num_cond_frames=196,
    fast_mode=False,
):
    """Instantiate the dataset registered under ``name``.

    Text-to-motion datasets ("humanml_with_images", "humanml", "kit")
    receive the full configuration surface; every other dataset is
    constructed from ``split`` and ``num_frames`` alone.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: Propagated from ``get_dataset_class`` for an
            unrecognized ``name``.
    """
    dataset_cls = get_dataset_class(name)
    kwargs = {"split": split, "num_frames": num_frames}
    if name in ("humanml_with_images", "humanml", "kit"):
        kwargs.update(
            mode=hml_mode,
            abs_path=abs_path,
            fixed_len=fixed_len,
            device=device,
            autoregressive=autoregressive,
            use_precomputed_embeddings=use_precomputed_embeddings,
            embeddings_dir=embeddings_dir,
            prompt_drop_rate=prompt_drop_rate,
            num_cond_frames=num_cond_frames,
            fast_mode=fast_mode,
        )
    return dataset_cls(**kwargs)
def get_dataset_loader(
    name,
    batch_size,
    num_frames,
    split="train",
    hml_mode="train",
    fixed_len=0,
    pred_len=0,
    device=None,
    autoregressive=False,
    num_samples=None,
    train_sample_indices=None,
    use_precomputed_embeddings=False,
    embeddings_dir=None,
    prompt_drop_rate=0.0,
    num_cond_frames=196,
    fast_mode=False,
):
    """Build a shuffling DataLoader over the requested dataset.

    The dataset may be restricted either to an explicit index list
    (``train_sample_indices``) or to its first ``num_samples`` items;
    ``train_sample_indices`` takes precedence when both are given.

    Returns:
        A ``torch.utils.data.DataLoader`` with shuffling, drop_last,
        8 workers, pinned memory, and persistent workers enabled.
    """
    dataset = get_dataset(
        name,
        num_frames,
        split=split,
        hml_mode=hml_mode,
        fixed_len=fixed_len,
        device=device,
        autoregressive=autoregressive,
        use_precomputed_embeddings=use_precomputed_embeddings,
        embeddings_dir=embeddings_dir,
        prompt_drop_rate=prompt_drop_rate,
        num_cond_frames=num_cond_frames,
        fast_mode=fast_mode,
    )

    if train_sample_indices is not None:
        dataset = Subset(dataset, train_sample_indices)
        print(
            f"Using only the samples with indices {train_sample_indices} from the dataset."
        )
    elif num_samples is not None:
        assert num_samples > 0, "num_samples must be greater than 0"
        assert num_samples <= len(dataset), (
            f"num_samples {num_samples} is greater than the dataset size {len(dataset)}"
        )
        # Deterministically keep the first num_samples items.
        dataset = Subset(dataset, np.arange(num_samples))
        print(f"Using a subset of {num_samples} samples from the dataset.")

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=8,
        drop_last=True,
        collate_fn=get_collate_fn(name, hml_mode, pred_len, batch_size),
        pin_memory=True,
        prefetch_factor=4,
        persistent_workers=True,
    )