File size: 4,193 Bytes
5007d4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import numpy as np
from torch.utils.data import DataLoader, Subset
from data_loaders.tensors import collate as all_collate
from data_loaders.tensors import t2m_collate, t2m_prefix_collate


def get_dataset_class(name):
    """Resolve a dataset name to its class.

    Imports are performed lazily inside each branch so that only the
    selected dataset's dependencies are loaded.

    Args:
        name: one of "amass", "uestc", "humanact12", "humanml",
            "humanml_with_images", "kit".

    Returns:
        The dataset class corresponding to *name*.

    Raises:
        ValueError: if *name* is not a known dataset.
    """
    if name == "amass":
        from .amass import AMASS
        return AMASS
    if name == "uestc":
        from .a2m.uestc import UESTC
        return UESTC
    if name == "humanact12":
        from .a2m.humanact12poses import HumanAct12Poses
        return HumanAct12Poses
    if name == "humanml":
        from data_loaders.humanml.data.dataset import HumanML3D
        return HumanML3D
    if name == "humanml_with_images":
        from data_loaders.humanml.data.dataset import HumanML3DWithImages
        return HumanML3DWithImages
    if name == "kit":
        from data_loaders.humanml.data.dataset import KIT
        return KIT
    raise ValueError(f"Unsupported dataset name [{name}]")


def get_collate_fn(name, hml_mode="train", pred_len=0, batch_size=1):
    """Pick the collate function matching the dataset and mode.

    Args:
        name: dataset name (see ``get_dataset_class``).
        hml_mode: "gt" selects the HumanML evaluation collate.
        pred_len: when > 0, use the prefix collate with this prediction length.
        batch_size: forwarded to ``t2m_collate`` for the t2m datasets.

    Returns:
        A callable suitable for ``DataLoader(collate_fn=...)``.
    """
    if hml_mode == "gt":
        # Evaluation ground-truth mode has its own collate; import lazily.
        from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate
        return t2m_eval_collate

    t2m_datasets = ("humanml_with_images", "humanml", "kit")
    if name not in t2m_datasets:
        return all_collate
    if pred_len > 0:
        return lambda batch: t2m_prefix_collate(batch, pred_len=pred_len)
    return lambda batch: t2m_collate(batch, batch_size)


def get_dataset(
    name,
    num_frames,
    split="train",
    hml_mode="train",
    abs_path=".",
    fixed_len=0,
    device=None,
    autoregressive=False,
    use_precomputed_embeddings=False,
    embeddings_dir=None,
    prompt_drop_rate=0.0,
    num_cond_frames=196,
    fast_mode=False,
):
    """Construct the dataset instance for *name*.

    The HumanML-style datasets ("humanml", "humanml_with_images", "kit")
    accept the full keyword set; every other dataset is constructed from
    ``split`` and ``num_frames`` only.
    """
    dataset_cls = get_dataset_class(name)

    if name not in ("humanml_with_images", "humanml", "kit"):
        return dataset_cls(split=split, num_frames=num_frames)

    return dataset_cls(
        split=split,
        num_frames=num_frames,
        mode=hml_mode,
        abs_path=abs_path,
        fixed_len=fixed_len,
        device=device,
        autoregressive=autoregressive,
        use_precomputed_embeddings=use_precomputed_embeddings,
        embeddings_dir=embeddings_dir,
        prompt_drop_rate=prompt_drop_rate,
        num_cond_frames=num_cond_frames,
        fast_mode=fast_mode,
    )


def get_dataset_loader(
    name,
    batch_size,
    num_frames,
    split="train",
    hml_mode="train",
    fixed_len=0,
    pred_len=0,
    device=None,
    autoregressive=False,
    num_samples=None,
    train_sample_indices=None,
    use_precomputed_embeddings=False,
    embeddings_dir=None,
    prompt_drop_rate=0.0,
    num_cond_frames=196,
    fast_mode=False,
    shuffle=True,
    num_workers=8,
    drop_last=True,
):
    """Build a ``DataLoader`` over the requested dataset.

    Args:
        name: dataset name (see ``get_dataset_class``).
        batch_size: loader batch size.
        num_frames: forwarded to the dataset constructor.
        num_samples: if given, restrict to the first ``num_samples`` items.
        train_sample_indices: if given, restrict to exactly these indices
            (takes precedence over ``num_samples``).
        shuffle / num_workers / drop_last: DataLoader knobs; defaults match
            the previous hard-coded values (True / 8 / True).

    Returns:
        A ``torch.utils.data.DataLoader`` with the appropriate collate_fn.

    Raises:
        ValueError: if ``num_samples`` is non-positive or exceeds the
            dataset size.
    """
    dataset = get_dataset(
        name,
        num_frames,
        split=split,
        hml_mode=hml_mode,
        fixed_len=fixed_len,
        device=device,
        autoregressive=autoregressive,
        use_precomputed_embeddings=use_precomputed_embeddings,
        embeddings_dir=embeddings_dir,
        prompt_drop_rate=prompt_drop_rate,
        num_cond_frames=num_cond_frames,
        fast_mode=fast_mode,
    )

    if train_sample_indices is not None:
        dataset = Subset(dataset, train_sample_indices)
        print(
            f"Using only the samples with indices {train_sample_indices} from the dataset."
        )
    elif num_samples is not None:
        # Explicit exceptions instead of `assert`: asserts are stripped
        # under `python -O`, silently skipping this validation.
        if num_samples <= 0:
            raise ValueError("num_samples must be greater than 0")
        if num_samples > len(dataset):
            raise ValueError(
                f"num_samples {num_samples} is greater than the dataset size {len(dataset)}"
            )

        # Choose the first num_samples samples
        indices = np.arange(num_samples)
        dataset = Subset(dataset, indices)
        print(f"Using a subset of {num_samples} samples from the dataset.")

    collate = get_collate_fn(name, hml_mode, pred_len, batch_size)

    # prefetch_factor and persistent_workers are only legal when workers
    # exist; guard them so num_workers=0 (e.g. debugging) also works.
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        drop_last=drop_last,
        collate_fn=collate,
        pin_memory=True,
        prefetch_factor=4 if num_workers > 0 else None,
        persistent_workers=num_workers > 0,
    )

    return loader