rsax commited on
Commit
8104cbc
·
verified ·
1 Parent(s): 02c7eaf

Upload 8 files

Browse files
dataset/dataset_TM_eval.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils import data
3
+ import numpy as np
4
+ from os.path import join as pjoin
5
+ import random
6
+ import codecs as cs
7
+ from tqdm import tqdm
8
+
9
+ import utils.paramUtil as paramUtil
10
+ from torch.utils.data._utils.collate import default_collate
11
+
12
+
13
+ def collate_fn(batch):
14
+ batch.sort(key=lambda x: x[3], reverse=True)
15
+ return default_collate(batch)
16
+
17
+
18
+ '''For use of training text-2-motion generative model'''
19
+ class Text2MotionDataset(data.Dataset):
20
+ def __init__(self, dataset_name, is_test, w_vectorizer, feat_bias = 5, max_text_len = 20, unit_length = 4):
21
+
22
+ self.max_length = 20
23
+ self.pointer = 0
24
+ self.dataset_name = dataset_name
25
+ self.is_test = is_test
26
+ self.max_text_len = max_text_len
27
+ self.unit_length = unit_length
28
+ self.w_vectorizer = w_vectorizer
29
+ if dataset_name == 't2m':
30
+ self.data_root = './dataset/Sample1'
31
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
32
+ self.text_dir = pjoin(self.data_root, 'texts')
33
+ self.joints_num = 22
34
+ radius = 4
35
+ fps = 20
36
+ self.max_motion_length = 196
37
+ dim_pose = 263
38
+ kinematic_chain = paramUtil.t2m_kinematic_chain
39
+ self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
40
+ elif dataset_name == 'kit':
41
+ self.data_root = './dataset/KIT-ML'
42
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
43
+ self.text_dir = pjoin(self.data_root, 'texts')
44
+ self.joints_num = 21
45
+ radius = 240 * 8
46
+ fps = 12.5
47
+ dim_pose = 251
48
+ self.max_motion_length = 196
49
+ kinematic_chain = paramUtil.kit_kinematic_chain
50
+ self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
51
+
52
+ mean = np.load(pjoin(self.meta_dir, 'mean.npy'))
53
+ std = np.load(pjoin(self.meta_dir, 'std.npy'))
54
+
55
+ if is_test:
56
+ split_file = pjoin(self.data_root, 'test.txt')
57
+ else:
58
+ split_file = pjoin(self.data_root, 'val.txt')
59
+
60
+ min_motion_len = 40 if self.dataset_name =='t2m' else 24
61
+ # min_motion_len = 64
62
+
63
+ joints_num = self.joints_num
64
+
65
+ data_dict = {}
66
+ id_list = []
67
+ with cs.open(split_file, 'r') as f:
68
+ for line in f.readlines():
69
+ id_list.append(line.strip())
70
+
71
+ new_name_list = []
72
+ length_list = []
73
+ for name in tqdm(id_list):
74
+ try:
75
+ motion = np.load(pjoin(self.motion_dir, name + '.npy'))
76
+ if (len(motion)) < min_motion_len or (len(motion) >= 200):
77
+ continue
78
+ text_data = []
79
+ flag = False
80
+ with cs.open(pjoin(self.text_dir, name + '.txt')) as f:
81
+ for line in f.readlines():
82
+ text_dict = {}
83
+ line_split = line.strip().split('#')
84
+ caption = line_split[0]
85
+ tokens = line_split[1].split(' ')
86
+ f_tag = float(line_split[2])
87
+ to_tag = float(line_split[3])
88
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
89
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
90
+
91
+ text_dict['caption'] = caption
92
+ text_dict['tokens'] = tokens
93
+ if f_tag == 0.0 and to_tag == 0.0:
94
+ flag = True
95
+ text_data.append(text_dict)
96
+ else:
97
+ try:
98
+ n_motion = motion[int(f_tag*fps) : int(to_tag*fps)]
99
+ if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
100
+ continue
101
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
102
+ while new_name in data_dict:
103
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
104
+ data_dict[new_name] = {'motion': n_motion,
105
+ 'length': len(n_motion),
106
+ 'text':[text_dict]}
107
+ new_name_list.append(new_name)
108
+ length_list.append(len(n_motion))
109
+ except:
110
+ print(line_split)
111
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
112
+ # break
113
+
114
+ if flag:
115
+ data_dict[name] = {'motion': motion,
116
+ 'length': len(motion),
117
+ 'text': text_data}
118
+ new_name_list.append(name)
119
+ length_list.append(len(motion))
120
+ except Exception as e:
121
+ # print(e)
122
+ pass
123
+
124
+ name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
125
+ self.mean = mean
126
+ self.std = std
127
+ self.length_arr = np.array(length_list)
128
+ self.data_dict = data_dict
129
+ self.name_list = name_list
130
+ self.reset_max_len(self.max_length)
131
+
132
+ def reset_max_len(self, length):
133
+ assert length <= self.max_motion_length
134
+ self.pointer = np.searchsorted(self.length_arr, length)
135
+ print("Pointer Pointing at %d"%self.pointer)
136
+ self.max_length = length
137
+
138
+ def inv_transform(self, data):
139
+ return data * self.std + self.mean
140
+
141
+ def forward_transform(self, data):
142
+ return (data - self.mean) / self.std
143
+
144
+ def __len__(self):
145
+ return len(self.data_dict) - self.pointer
146
+
147
+ def __getitem__(self, item):
148
+ idx = self.pointer + item
149
+ name = self.name_list[idx]
150
+ data = self.data_dict[name]
151
+ # data = self.data_dict[self.name_list[idx]]
152
+ motion, m_length, text_list = data['motion'], data['length'], data['text']
153
+ # Randomly select a caption
154
+ text_data = random.choice(text_list)
155
+ caption, tokens = text_data['caption'], text_data['tokens']
156
+
157
+ if len(tokens) < self.max_text_len:
158
+ # pad with "unk"
159
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
160
+ sent_len = len(tokens)
161
+ tokens = tokens + ['unk/OTHER'] * (self.max_text_len + 2 - sent_len)
162
+ else:
163
+ # crop
164
+ tokens = tokens[:self.max_text_len]
165
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
166
+ sent_len = len(tokens)
167
+ pos_one_hots = []
168
+ word_embeddings = []
169
+ for token in tokens:
170
+ word_emb, pos_oh = self.w_vectorizer[token]
171
+ pos_one_hots.append(pos_oh[None, :])
172
+ word_embeddings.append(word_emb[None, :])
173
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
174
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
175
+
176
+ if self.unit_length < 10:
177
+ coin2 = np.random.choice(['single', 'single', 'double'])
178
+ else:
179
+ coin2 = 'single'
180
+
181
+ if coin2 == 'double':
182
+ m_length = (m_length // self.unit_length - 1) * self.unit_length
183
+ elif coin2 == 'single':
184
+ m_length = (m_length // self.unit_length) * self.unit_length
185
+ idx = random.randint(0, len(motion) - m_length)
186
+ motion = motion[idx:idx+m_length]
187
+
188
+ "Z Normalization"
189
+ motion = (motion - self.mean) / self.std
190
+
191
+ if m_length < self.max_motion_length:
192
+ motion = np.concatenate([motion,
193
+ np.zeros((self.max_motion_length - m_length, motion.shape[1]))
194
+ ], axis=0)
195
+
196
+ return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens), name
197
+
198
+
199
+
200
+
201
+ def DATALoader(dataset_name, is_test,
202
+ batch_size, w_vectorizer,
203
+ num_workers = 8, unit_length = 4) :
204
+
205
+ val_loader = torch.utils.data.DataLoader(Text2MotionDataset(dataset_name, is_test, w_vectorizer, unit_length=unit_length),
206
+ batch_size,
207
+ shuffle = True,
208
+ num_workers=num_workers,
209
+ collate_fn=collate_fn,
210
+ drop_last = True)
211
+ return val_loader
212
+
213
+
214
+ def cycle(iterable):
215
+ while True:
216
+ for x in iterable:
217
+ yield x
dataset/dataset_TM_train.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils import data
3
+ import numpy as np
4
+ from os.path import join as pjoin
5
+ import random
6
+ import codecs as cs
7
+ from tqdm import tqdm
8
+ import utils.paramUtil as paramUtil
9
+ from torch.utils.data._utils.collate import default_collate
10
+
11
+
12
+ def collate_fn(batch):
13
+ batch.sort(key=lambda x: x[3], reverse=True)
14
+ return default_collate(batch)
15
+
16
+
17
+ '''For use of training text-2-motion generative model'''
18
+ class Text2MotionDataset(data.Dataset):
19
+ def __init__(self, dataset_name, feat_bias=5, unit_length=4, codebook_size=1024, tokenizer_name=None):
20
+ self.max_length = 64
21
+ self.pointer = 0
22
+ self.dataset_name = dataset_name
23
+ self.unit_length = unit_length
24
+ self.mot_end_idx = codebook_size
25
+ self.mot_pad_idx = codebook_size + 1
26
+
27
+ print(f"Loading dataset: {dataset_name}")
28
+
29
+ if dataset_name == 't2m':
30
+ self.data_root = './dataset/Sample1'
31
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
32
+ self.text_dir = pjoin(self.data_root, 'texts')
33
+ self.joints_num = 22
34
+ radius = 4
35
+ fps = 20
36
+ self.max_motion_length = 26 if unit_length == 8 else 51
37
+ dim_pose = 263
38
+ kinematic_chain = paramUtil.t2m_kinematic_chain
39
+ elif dataset_name == 'kit':
40
+ self.data_root = './dataset/KIT-ML'
41
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
42
+ self.text_dir = pjoin(self.data_root, 'texts')
43
+ self.joints_num = 21
44
+ radius = 240 * 8
45
+ fps = 12.5
46
+ dim_pose = 251
47
+ self.max_motion_length = 26 if unit_length == 8 else 51
48
+ kinematic_chain = paramUtil.kit_kinematic_chain
49
+
50
+ split_file = pjoin(self.data_root, 'train.txt')
51
+
52
+ id_list = []
53
+ with cs.open(split_file, 'r') as f:
54
+ for line in f.readlines():
55
+ id_list.append(line.strip())
56
+
57
+ new_name_list = []
58
+ data_dict = {}
59
+ for name in tqdm(id_list):
60
+ try:
61
+ m_token_list = np.load(pjoin(self.data_root, tokenizer_name, f'{name}.npy'))
62
+
63
+ with cs.open(pjoin(self.text_dir, f'{name}.txt')) as f:
64
+ text_data = []
65
+ flag = False
66
+ lines = f.readlines()
67
+
68
+ for line in lines:
69
+ try:
70
+ text_dict = {}
71
+ line_split = line.strip().split('#')
72
+ caption = line_split[0]
73
+ t_tokens = line_split[1].split(' ')
74
+ f_tag = float(line_split[2])
75
+ to_tag = float(line_split[3])
76
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
77
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
78
+
79
+ text_dict['caption'] = caption
80
+ text_dict['tokens'] = t_tokens
81
+ if f_tag == 0.0 and to_tag == 0.0:
82
+ flag = True
83
+ text_data.append(text_dict)
84
+ else:
85
+ m_token_list_new = [tokens[int(f_tag * fps / unit_length): int(to_tag * fps / unit_length)] for tokens in m_token_list if int(f_tag * fps / unit_length) < int(to_tag * fps / unit_length)]
86
+
87
+ if len(m_token_list_new) == 0:
88
+ continue
89
+ new_name = f'{name}_{f_tag}_{to_tag}'
90
+
91
+ data_dict[new_name] = {'m_token_list': m_token_list_new,
92
+ 'text': [text_dict]}
93
+ new_name_list.append(new_name)
94
+ except:
95
+ pass
96
+
97
+ if flag:
98
+ data_dict[name] = {'m_token_list': m_token_list,
99
+ 'text': text_data}
100
+ new_name_list.append(name)
101
+ except:
102
+ pass
103
+
104
+ self.data_dict = data_dict
105
+ self.name_list = new_name_list
106
+
107
+ print(f"Dataset loaded. Number of samples: {len(self.data_dict)}")
108
+
109
+ def __len__(self):
110
+ return len(self.data_dict)
111
+
112
+ def __getitem__(self, item):
113
+ data = self.data_dict[self.name_list[item]]
114
+ m_token_list, text_list = data['m_token_list'], data['text']
115
+ m_tokens = random.choice(m_token_list)
116
+
117
+ text_data = random.choice(text_list)
118
+ caption = text_data['caption']
119
+
120
+ coin = np.random.choice([False, False, True])
121
+ if coin:
122
+ coin2 = np.random.choice([True, False])
123
+ if coin2:
124
+ m_tokens = m_tokens[:-1]
125
+ else:
126
+ m_tokens = m_tokens[1:]
127
+ m_tokens_len = m_tokens.shape[0]
128
+
129
+ if m_tokens_len + 1 < self.max_motion_length:
130
+ m_tokens = np.concatenate([m_tokens, np.ones((1), dtype=int) * self.mot_end_idx, np.ones((self.max_motion_length - 1 - m_tokens_len), dtype=int) * self.mot_pad_idx], axis=0)
131
+ else:
132
+ m_tokens = np.concatenate([m_tokens, np.ones((1), dtype=int) * self.mot_end_idx], axis=0)
133
+
134
+ return caption, m_tokens.reshape(-1), m_tokens_len
135
+
136
+
137
+ def DATALoader(dataset_name,
138
+ batch_size, codebook_size, tokenizer_name, unit_length=4,
139
+ num_workers = 8) :
140
+
141
+ train_loader = torch.utils.data.DataLoader(Text2MotionDataset(dataset_name, codebook_size = codebook_size, tokenizer_name = tokenizer_name, unit_length=unit_length),
142
+ batch_size,
143
+ shuffle=True,
144
+ num_workers=num_workers,
145
+ #collate_fn=collate_fn,
146
+ drop_last = True)
147
+
148
+
149
+ return train_loader
150
+
151
+
152
+ def cycle(iterable):
153
+ while True:
154
+ for x in iterable:
155
+ yield x
156
+
157
+
dataset/dataset_VQ.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils import data
3
+ import numpy as np
4
+ from os.path import join as pjoin
5
+ import random
6
+ import codecs as cs
7
+ from tqdm import tqdm
8
+
9
+
10
+
11
+ class VQMotionDataset(data.Dataset):
12
+ def __init__(self, dataset_name, window_size = 64, unit_length = 4):
13
+ self.window_size = window_size
14
+ self.unit_length = unit_length
15
+ self.dataset_name = dataset_name
16
+
17
+ if dataset_name == 't2m':
18
+ self.data_root = './dataset/Sample1'
19
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
20
+ self.text_dir = pjoin(self.data_root, 'texts')
21
+ self.joints_num = 22
22
+ self.max_motion_length = 196
23
+ self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
24
+
25
+ elif dataset_name == 'kit':
26
+ self.data_root = './dataset/KIT-ML'
27
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
28
+ self.text_dir = pjoin(self.data_root, 'texts')
29
+ self.joints_num = 21
30
+
31
+ self.max_motion_length = 196
32
+ self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
33
+
34
+ joints_num = self.joints_num
35
+
36
+ mean = np.load(pjoin(self.meta_dir, 'mean.npy'))
37
+ std = np.load(pjoin(self.meta_dir, 'std.npy'))
38
+
39
+ split_file = pjoin(self.data_root, 'train.txt')
40
+
41
+ self.data = []
42
+ self.lengths = []
43
+ id_list = []
44
+ with cs.open(split_file, 'r') as f:
45
+ for line in f.readlines():
46
+ id_list.append(line.strip())
47
+
48
+ for name in tqdm(id_list):
49
+ try:
50
+ motion = np.load(pjoin(self.motion_dir, name + '.npy'))
51
+ if motion.shape[0] < self.window_size:
52
+ continue
53
+ self.lengths.append(motion.shape[0] - self.window_size)
54
+ self.data.append(motion)
55
+ except:
56
+ # Some motion may not exist in KIT dataset
57
+ pass
58
+
59
+
60
+ self.mean = mean
61
+ self.std = std
62
+ print("Total number of motions {}".format(len(self.data)))
63
+
64
+ def inv_transform(self, data):
65
+ return data * self.std + self.mean
66
+
67
+ def compute_sampling_prob(self) :
68
+
69
+ prob = np.array(self.lengths, dtype=np.float32)
70
+ prob /= np.sum(prob)
71
+ return prob
72
+
73
+ def __len__(self):
74
+ return len(self.data)
75
+
76
+ def __getitem__(self, item):
77
+ motion = self.data[item]
78
+
79
+ idx = random.randint(0, len(motion) - self.window_size)
80
+
81
+ motion = motion[idx:idx+self.window_size]
82
+ "Z Normalization"
83
+ motion = (motion - self.mean) / self.std
84
+
85
+ return motion
86
+
87
+ def DATALoader(dataset_name,
88
+ batch_size,
89
+ num_workers = 8,
90
+ window_size = 64,
91
+ unit_length = 4):
92
+
93
+ trainSet = VQMotionDataset(dataset_name, window_size=window_size, unit_length=unit_length)
94
+ prob = trainSet.compute_sampling_prob()
95
+ sampler = torch.utils.data.WeightedRandomSampler(prob, num_samples = len(trainSet) * 1000, replacement=True)
96
+ train_loader = torch.utils.data.DataLoader(trainSet,
97
+ batch_size,
98
+ shuffle=True,
99
+ #sampler=sampler,
100
+ num_workers=num_workers,
101
+ #collate_fn=collate_fn,
102
+ drop_last = True)
103
+
104
+ return train_loader
105
+
106
+ def cycle(iterable):
107
+ while True:
108
+ for x in iterable:
109
+ yield x
dataset/dataset_tokenize.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils import data
3
+ import numpy as np
4
+ from os.path import join as pjoin
5
+ import random
6
+ import codecs as cs
7
+ from tqdm import tqdm
8
+
9
+
10
+
11
+ class VQMotionDataset(data.Dataset):
12
+ def __init__(self, dataset_name, feat_bias = 5, window_size = 64, unit_length = 8):
13
+ self.window_size = window_size
14
+ self.unit_length = unit_length
15
+ self.feat_bias = feat_bias
16
+
17
+ self.dataset_name = dataset_name
18
+ min_motion_len = 40 if dataset_name =='t2m' else 24
19
+
20
+ if dataset_name == 't2m':
21
+ self.data_root = './dataset/Sample1'
22
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
23
+ self.text_dir = pjoin(self.data_root, 'texts')
24
+ self.joints_num = 22
25
+ radius = 4
26
+ fps = 20
27
+ self.max_motion_length = 196
28
+ dim_pose = 263
29
+ self.meta_dir = 'checkpoints/t2m/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
30
+ #kinematic_chain = paramUtil.t2m_kinematic_chain
31
+ elif dataset_name == 'kit':
32
+ self.data_root = './dataset/KIT-ML'
33
+ self.motion_dir = pjoin(self.data_root, 'new_joint_vecs')
34
+ self.text_dir = pjoin(self.data_root, 'texts')
35
+ self.joints_num = 21
36
+ radius = 240 * 8
37
+ fps = 12.5
38
+ dim_pose = 251
39
+ self.max_motion_length = 196
40
+ self.meta_dir = 'checkpoints/kit/VQVAEV3_CB1024_CMT_H1024_NRES3/meta'
41
+ #kinematic_chain = paramUtil.kit_kinematic_chain
42
+
43
+ joints_num = self.joints_num
44
+
45
+ mean = np.load(pjoin(self.meta_dir, 'mean.npy'))
46
+ std = np.load(pjoin(self.meta_dir, 'std.npy'))
47
+
48
+ split_file = pjoin(self.data_root, 'train.txt')
49
+
50
+ data_dict = {}
51
+ id_list = []
52
+ with cs.open(split_file, 'r') as f:
53
+ for line in f.readlines():
54
+ id_list.append(line.strip())
55
+
56
+ new_name_list = []
57
+ length_list = []
58
+ for name in tqdm(id_list):
59
+ try:
60
+ motion = np.load(pjoin(self.motion_dir, name + '.npy'))
61
+ if (len(motion)) < min_motion_len or (len(motion) >= 200):
62
+ continue
63
+
64
+ data_dict[name] = {'motion': motion,
65
+ 'length': len(motion),
66
+ 'name': name}
67
+ new_name_list.append(name)
68
+ length_list.append(len(motion))
69
+ except:
70
+ # Some motion may not exist in KIT dataset
71
+ pass
72
+
73
+
74
+ self.mean = mean
75
+ self.std = std
76
+ self.length_arr = np.array(length_list)
77
+ self.data_dict = data_dict
78
+ self.name_list = new_name_list
79
+
80
+ def inv_transform(self, data):
81
+ return data * self.std + self.mean
82
+
83
+ def __len__(self):
84
+ return len(self.data_dict)
85
+
86
+ def __getitem__(self, item):
87
+ name = self.name_list[item]
88
+ data = self.data_dict[name]
89
+ motion, m_length = data['motion'], data['length']
90
+
91
+ m_length = (m_length // self.unit_length) * self.unit_length
92
+
93
+ idx = random.randint(0, len(motion) - m_length)
94
+ motion = motion[idx:idx+m_length]
95
+
96
+ "Z Normalization"
97
+ motion = (motion - self.mean) / self.std
98
+
99
+ return motion, name
100
+
101
+ def DATALoader(dataset_name,
102
+ batch_size = 1,
103
+ num_workers = 8, unit_length = 4) :
104
+
105
+ train_loader = torch.utils.data.DataLoader(VQMotionDataset(dataset_name, unit_length=unit_length),
106
+ batch_size,
107
+ shuffle=True,
108
+ num_workers=num_workers,
109
+ #collate_fn=collate_fn,
110
+ drop_last = True)
111
+
112
+ return train_loader
113
+
114
+ def cycle(iterable):
115
+ while True:
116
+ for x in iterable:
117
+ yield x
dataset/prepare/download_extractor.sh ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ rm -rf checkpoints
2
+ mkdir checkpoints
3
+ cd checkpoints
4
+ echo -e "Downloading extractors"
5
+ gdown --fuzzy https://drive.google.com/file/d/1o7RTDQcToJjTm9_mNWTyzvZvjTWpZfug/view
6
+ gdown --fuzzy https://drive.google.com/file/d/1KNU8CsMAnxFrwopKBBkC8jEULGLPBHQp/view
7
+
8
+
9
+ unzip t2m.zip
10
+ unzip kit.zip
11
+
12
+ echo -e "Cleaning\n"
13
+ rm t2m.zip
14
+ rm kit.zip
15
+ echo -e "Downloading done!"
dataset/prepare/download_glove.sh ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ echo -e "Downloading glove (in use by the evaluators)"
2
+ gdown --fuzzy https://drive.google.com/file/d/1bCeS6Sh_mLVTebxIgiUHgdPrroW06mb6/view?usp=sharing
3
+ rm -rf glove
4
+
5
+ unzip glove.zip
6
+ echo -e "Cleaning\n"
7
+ rm glove.zip
8
+
9
+ echo -e "Downloading done!"
dataset/prepare/download_model.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ mkdir -p pretrained
3
+ cd pretrained/
4
+
5
+ echo -e "The pretrained model files will be stored in the 'pretrained' folder\n"
6
+ gdown 1LaOvwypF-jM2Axnq5dc-Iuvv3w_G-WDE
7
+
8
+ unzip VQTrans_pretrained.zip
9
+ echo -e "Cleaning\n"
10
+ rm VQTrans_pretrained.zip
11
+
12
+ echo -e "Downloading done!"
dataset/prepare/download_smpl.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ mkdir -p body_models
3
+ cd body_models/
4
+
5
+ echo -e "The smpl files will be stored in the 'body_models/smpl/' folder\n"
6
+ gdown 1INYlGA76ak_cKGzvpOV2Pe6RkYTlXTW2
7
+ rm -rf smpl
8
+
9
+ unzip smpl.zip
10
+ echo -e "Cleaning\n"
11
+ rm smpl.zip
12
+
13
+ echo -e "Downloading done!"