megalado committed on
Commit f87d582 · 1 Parent(s): e2304d4

Add local model code; tidy requirements

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. motion_diffusion_model/data_loaders/a2m/dataset.py +255 -0
  2. motion_diffusion_model/data_loaders/a2m/humanact12poses.py +57 -0
  3. motion_diffusion_model/data_loaders/a2m/uestc.py +226 -0
  4. motion_diffusion_model/data_loaders/get_data.py +59 -0
  5. motion_diffusion_model/data_loaders/humanml/README.md +1 -0
  6. motion_diffusion_model/data_loaders/humanml/common/quaternion.py +425 -0
  7. motion_diffusion_model/data_loaders/humanml/common/skeleton.py +202 -0
  8. motion_diffusion_model/data_loaders/humanml/data/__init__.py +0 -0
  9. motion_diffusion_model/data_loaders/humanml/data/dataset.py +823 -0
  10. motion_diffusion_model/data_loaders/humanml/motion_loaders/__init__.py +0 -0
  11. motion_diffusion_model/data_loaders/humanml/motion_loaders/comp_v6_model_dataset.py +285 -0
  12. motion_diffusion_model/data_loaders/humanml/motion_loaders/dataset_motion_loader.py +27 -0
  13. motion_diffusion_model/data_loaders/humanml/motion_loaders/model_motion_loaders.py +91 -0
  14. motion_diffusion_model/data_loaders/humanml/networks/__init__.py +0 -0
  15. motion_diffusion_model/data_loaders/humanml/networks/evaluator_wrapper.py +187 -0
  16. motion_diffusion_model/data_loaders/humanml/networks/modules.py +438 -0
  17. motion_diffusion_model/data_loaders/humanml/networks/trainers.py +1089 -0
  18. motion_diffusion_model/data_loaders/humanml/scripts/motion_process.py +669 -0
  19. motion_diffusion_model/data_loaders/humanml/utils/get_opt.py +81 -0
  20. motion_diffusion_model/data_loaders/humanml/utils/metrics.py +146 -0
  21. motion_diffusion_model/data_loaders/humanml/utils/paramUtil.py +63 -0
  22. motion_diffusion_model/data_loaders/humanml/utils/plot_script.py +148 -0
  23. motion_diffusion_model/data_loaders/humanml/utils/utils.py +167 -0
  24. motion_diffusion_model/data_loaders/humanml/utils/word_vectorizer.py +80 -0
  25. motion_diffusion_model/data_loaders/humanml_utils.py +60 -0
  26. motion_diffusion_model/data_loaders/tensors.py +94 -0
  27. motion_diffusion_model/diffusion/fp16_util.py +236 -0
  28. motion_diffusion_model/diffusion/gaussian_diffusion.py +1615 -0
  29. motion_diffusion_model/diffusion/logger.py +495 -0
  30. motion_diffusion_model/diffusion/losses.py +77 -0
  31. motion_diffusion_model/diffusion/nn.py +197 -0
  32. motion_diffusion_model/diffusion/resample.py +154 -0
  33. motion_diffusion_model/diffusion/respace.py +134 -0
  34. motion_diffusion_model/model/BERT/BERT_encoder.py +32 -0
  35. motion_diffusion_model/model/cfg_sampler.py +33 -0
  36. motion_diffusion_model/model/mdm.py +480 -0
  37. motion_diffusion_model/model/rotation2xyz.py +92 -0
  38. motion_diffusion_model/model/smpl.py +97 -0
  39. motion_diffusion_model/sample/edit.py +212 -0
  40. motion_diffusion_model/sample/generate.py +318 -0
  41. motion_diffusion_model/sample/predict.py +167 -0
  42. motion_diffusion_model/utils/PYTORCH3D_LICENSE +30 -0
  43. motion_diffusion_model/utils/config.py +17 -0
  44. motion_diffusion_model/utils/dist_util.py +77 -0
  45. motion_diffusion_model/utils/fixseed.py +18 -0
  46. motion_diffusion_model/utils/loss_util.py +46 -0
  47. motion_diffusion_model/utils/misc.py +74 -0
  48. motion_diffusion_model/utils/model_util.py +132 -0
  49. motion_diffusion_model/utils/parser_util.py +320 -0
  50. motion_diffusion_model/utils/rotation_conversions.py +552 -0
motion_diffusion_model/data_loaders/a2m/dataset.py ADDED
@@ -0,0 +1,255 @@
import random

import numpy as np
import torch
# from utils.action_label_to_idx import action_label_to_idx
from data_loaders.tensors import collate
from utils.misc import to_torch
import utils.rotation_conversions as geometry


class Dataset(torch.utils.data.Dataset):
    def __init__(self, num_frames=1, sampling="conseq", sampling_step=1, split="train",
                 pose_rep="rot6d", translation=True, glob=True, max_len=-1, min_len=-1, num_seq_max=-1, **kwargs):
        self.num_frames = num_frames
        self.sampling = sampling
        self.sampling_step = sampling_step
        self.split = split
        self.pose_rep = pose_rep
        self.translation = translation
        self.glob = glob
        self.max_len = max_len
        self.min_len = min_len
        self.num_seq_max = num_seq_max

        self.align_pose_frontview = kwargs.get('align_pose_frontview', False)
        self.use_action_cat_as_text_labels = kwargs.get('use_action_cat_as_text_labels', False)
        self.only_60_classes = kwargs.get('only_60_classes', False)
        self.leave_out_15_classes = kwargs.get('leave_out_15_classes', False)
        self.use_only_15_classes = kwargs.get('use_only_15_classes', False)

        if self.split not in ["train", "val", "test"]:
            raise ValueError(f"{self.split} is not a valid split")

        super().__init__()

        # to remove shuffling
        self._original_train = None
        self._original_test = None

    def action_to_label(self, action):
        return self._action_to_label[action]

    def label_to_action(self, label):
        import numbers
        if isinstance(label, numbers.Integral):
            return self._label_to_action[label]
        else:  # if it is a one-hot vector
            label = np.argmax(label)
            return self._label_to_action[label]

    def get_pose_data(self, data_index, frame_ix):
        pose = self._load(data_index, frame_ix)
        label = self.get_label(data_index)
        return pose, label

    def get_label(self, ind):
        action = self.get_action(ind)
        return self.action_to_label(action)

    def get_action(self, ind):
        return self._actions[ind]

    def action_to_action_name(self, action):
        return self._action_classes[action]

    def action_name_to_action(self, action_name):
        # self._action_classes is either a list or a dictionary. If it's a dictionary, we first convert it to a list
        all_action_names = self._action_classes
        if isinstance(all_action_names, dict):
            all_action_names = list(all_action_names.values())
            assert list(self._action_classes.keys()) == list(range(len(all_action_names)))  # the keys should be ordered from 0 to num_actions

        sorter = np.argsort(all_action_names)
        actions = sorter[np.searchsorted(all_action_names, action_name, sorter=sorter)]
        return actions

    def __getitem__(self, index):
        if self.split == 'train':
            data_index = self._train[index]
        else:
            data_index = self._test[index]

        # inp, target = self._get_item_data_index(data_index)
        # return inp, target
        return self._get_item_data_index(data_index)

    def _load(self, ind, frame_ix):
        pose_rep = self.pose_rep
        if pose_rep == "xyz" or self.translation:
            if getattr(self, "_load_joints3D", None) is not None:
                # Locate the root joint of the initial pose at the origin
                joints3D = self._load_joints3D(ind, frame_ix)
                joints3D = joints3D - joints3D[0, 0, :]
                ret = to_torch(joints3D)
                if self.translation:
                    ret_tr = ret[:, 0, :]
            else:
                if pose_rep == "xyz":
                    raise ValueError("This representation is not possible.")
                if getattr(self, "_load_translation", None) is None:
                    raise ValueError("Can't extract translations.")
                ret_tr = self._load_translation(ind, frame_ix)
                ret_tr = to_torch(ret_tr - ret_tr[0])

        if pose_rep != "xyz":
            if getattr(self, "_load_rotvec", None) is None:
                raise ValueError("This representation is not possible.")
            else:
                pose = self._load_rotvec(ind, frame_ix)
                if not self.glob:
                    pose = pose[:, 1:, :]
                pose = to_torch(pose)
                if self.align_pose_frontview:
                    first_frame_root_pose_matrix = geometry.axis_angle_to_matrix(pose[0][0])
                    all_root_poses_matrix = geometry.axis_angle_to_matrix(pose[:, 0, :])
                    aligned_root_poses_matrix = torch.matmul(torch.transpose(first_frame_root_pose_matrix, 0, 1),
                                                             all_root_poses_matrix)
                    pose[:, 0, :] = geometry.matrix_to_axis_angle(aligned_root_poses_matrix)

                    if self.translation:
                        ret_tr = torch.matmul(torch.transpose(first_frame_root_pose_matrix, 0, 1).float(),
                                              torch.transpose(ret_tr, 0, 1))
                        ret_tr = torch.transpose(ret_tr, 0, 1)

                if pose_rep == "rotvec":
                    ret = pose
                elif pose_rep == "rotmat":
                    ret = geometry.axis_angle_to_matrix(pose).view(*pose.shape[:2], 9)
                elif pose_rep == "rotquat":
                    ret = geometry.axis_angle_to_quaternion(pose)
                elif pose_rep == "rot6d":
                    ret = geometry.matrix_to_rotation_6d(geometry.axis_angle_to_matrix(pose))
        if pose_rep != "xyz" and self.translation:
            padded_tr = torch.zeros((ret.shape[0], ret.shape[2]), dtype=ret.dtype)
            padded_tr[:, :3] = ret_tr
            ret = torch.cat((ret, padded_tr[:, None]), 1)
        ret = ret.permute(1, 2, 0).contiguous()
        return ret.float()

    def _get_item_data_index(self, data_index):
        nframes = self._num_frames_in_video[data_index]

        if self.num_frames == -1 and (self.max_len == -1 or nframes <= self.max_len):
            frame_ix = np.arange(nframes)
        else:
            if self.num_frames == -2:
                if self.min_len <= 0:
                    raise ValueError("You should put a min_len > 0 for num_frames == -2 mode")
                if self.max_len != -1:
                    max_frame = min(nframes, self.max_len)
                else:
                    max_frame = nframes

                num_frames = random.randint(self.min_len, max(max_frame, self.min_len))
            else:
                num_frames = self.num_frames if self.num_frames != -1 else self.max_len

            if num_frames > nframes:
                fair = False  # True
                if fair:
                    # distills redundancy everywhere
                    choices = np.random.choice(range(nframes),
                                               num_frames,
                                               replace=True)
                    frame_ix = sorted(choices)
                else:
                    # adding the last frame until done
                    ntoadd = max(0, num_frames - nframes)
                    lastframe = nframes - 1
                    padding = lastframe * np.ones(ntoadd, dtype=int)
                    frame_ix = np.concatenate((np.arange(0, nframes),
                                               padding))

            elif self.sampling in ["conseq", "random_conseq"]:
                step_max = (nframes - 1) // (num_frames - 1)
                if self.sampling == "conseq":
                    if self.sampling_step == -1 or self.sampling_step * (num_frames - 1) >= nframes:
                        step = step_max
                    else:
                        step = self.sampling_step
                elif self.sampling == "random_conseq":
                    step = random.randint(1, step_max)

                lastone = step * (num_frames - 1)
                shift_max = nframes - lastone - 1
                shift = random.randint(0, max(0, shift_max - 1))
                frame_ix = shift + np.arange(0, lastone + 1, step)

            elif self.sampling == "random":
                choices = np.random.choice(range(nframes),
                                           num_frames,
                                           replace=False)
                frame_ix = sorted(choices)

            else:
                raise ValueError("Sampling not recognized.")

        inp, action = self.get_pose_data(data_index, frame_ix)

        output = {'inp': inp, 'action': action}

        if hasattr(self, '_actions') and hasattr(self, '_action_classes'):
            output['action_text'] = self.action_to_action_name(self.get_action(data_index))

        return output

    def get_mean_length_label(self, label):
        if self.num_frames != -1:
            return self.num_frames

        if self.split == 'train':
            index = self._train
        else:
            index = self._test

        action = self.label_to_action(label)
        choices = np.argwhere(self._actions[index] == action).squeeze(1)
        lengths = self._num_frames_in_video[np.array(index)[choices]]

        if self.max_len == -1:
            return np.mean(lengths)
        else:
            # make the lengths less than max_len
            lengths[lengths > self.max_len] = self.max_len
            return np.mean(lengths)

    def __len__(self):
        num_seq_max = getattr(self, "num_seq_max", -1)
        if num_seq_max == -1:
            from math import inf
            num_seq_max = inf

        if self.split == 'train':
            return min(len(self._train), num_seq_max)
        else:
            return min(len(self._test), num_seq_max)

    def shuffle(self):
        if self.split == 'train':
            random.shuffle(self._train)
        else:
            random.shuffle(self._test)

    def reset_shuffle(self):
        if self.split == 'train':
            if self._original_train is None:
                self._original_train = self._train
            else:
                self._train = self._original_train
        else:
            if self._original_test is None:
                self._original_test = self._test
            else:
                self._test = self._original_test
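As a reading aid: the heart of `_get_item_data_index` above is the "conseq" sampling branch, which picks `num_frames` evenly spaced frames and then applies a random shift. A minimal stand-alone sketch of that branch (the function name is illustrative, not part of this commit):

import random
import numpy as np

def conseq_frame_indices(nframes, num_frames, sampling_step=-1):
    # Mirrors the "conseq" branch of _get_item_data_index above.
    step_max = (nframes - 1) // (num_frames - 1)
    if sampling_step == -1 or sampling_step * (num_frames - 1) >= nframes:
        step = step_max          # spread the window over the whole clip
    else:
        step = sampling_step
    lastone = step * (num_frames - 1)
    shift_max = nframes - lastone - 1
    shift = random.randint(0, max(0, shift_max - 1))
    return shift + np.arange(0, lastone + 1, step)

print(conseq_frame_indices(nframes=100, num_frames=10))   # [0 11 22 ... 99]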
motion_diffusion_model/data_loaders/a2m/humanact12poses.py ADDED
@@ -0,0 +1,57 @@
import pickle as pkl
import numpy as np
import os
from .dataset import Dataset


class HumanAct12Poses(Dataset):
    dataname = "humanact12"

    def __init__(self, datapath="dataset/HumanAct12Poses", split="train", **kargs):
        self.datapath = datapath

        super().__init__(**kargs)

        pkldatafilepath = os.path.join(datapath, "humanact12poses.pkl")
        data = pkl.load(open(pkldatafilepath, "rb"))

        self._pose = [x for x in data["poses"]]
        self._num_frames_in_video = [p.shape[0] for p in self._pose]
        self._joints = [x for x in data["joints3D"]]

        self._actions = [x for x in data["y"]]

        total_num_actions = 12
        self.num_actions = total_num_actions

        self._train = list(range(len(self._pose)))

        keep_actions = np.arange(0, total_num_actions)

        self._action_to_label = {x: i for i, x in enumerate(keep_actions)}
        self._label_to_action = {i: x for i, x in enumerate(keep_actions)}

        self._action_classes = humanact12_coarse_action_enumerator

    def _load_joints3D(self, ind, frame_ix):
        return self._joints[ind][frame_ix]

    def _load_rotvec(self, ind, frame_ix):
        pose = self._pose[ind][frame_ix].reshape(-1, 24, 3)
        return pose


humanact12_coarse_action_enumerator = {
    0: "warm_up",
    1: "walk",
    2: "run",
    3: "jump",
    4: "drink",
    5: "lift_dumbbell",
    6: "sit",
    7: "eat",
    8: "turn steering wheel",
    9: "phone",
    10: "boxing",
    11: "throw",
}
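A hedged usage sketch for this loader; it assumes dataset/HumanAct12Poses/humanact12poses.pkl has been downloaded to the path the constructor default expects:

from data_loaders.a2m.humanact12poses import HumanAct12Poses

dataset = HumanAct12Poses(num_frames=60, pose_rep="rot6d")
sample = dataset[0]
print(sample['inp'].shape)      # expected (25, 6, 60): 24 rot6d joints plus one zero-padded translation row
print(sample['action'], sample['action_text'])   # e.g. an index in 0..11 and its name from the enumerator above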
motion_diffusion_model/data_loaders/a2m/uestc.py ADDED
@@ -0,0 +1,226 @@
import os
from tqdm import tqdm
import numpy as np
import pickle as pkl
import utils.rotation_conversions as geometry
import torch

from .dataset import Dataset
# from torch.utils.data import Dataset

action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38]


def get_z(cam_s, cam_pos, joints, img_size, flength):
    """
    Solves for the depth offset of the model to approximate orthographic projection with a perspective camera.
    """
    # Translate the model itself: solve for the best z that maps to the orthographically projected points
    joints_orth_target = (cam_s * (joints[:, :2] + cam_pos) + 1) * 0.5 * img_size
    height3d = np.linalg.norm(np.max(joints[:, :2], axis=0) - np.min(joints[:, :2], axis=0))
    height2d = np.linalg.norm(np.max(joints_orth_target, axis=0) - np.min(joints_orth_target, axis=0))
    tz = np.array(flength * (height3d / height2d))
    return float(tz)


def get_trans_from_vibe(vibe, index, use_z=True):
    alltrans = []
    for t in range(vibe["joints3d"][index].shape[0]):
        # Convert crop cam to orig cam
        # No need! Because `convert_crop_cam_to_orig_img` from the VIBE demo utils
        # already does this for us :)
        # Its format is: [sx, sy, tx, ty]
        cam_orig = vibe["orig_cam"][index][t]
        x = cam_orig[2]
        y = cam_orig[3]
        if use_z:
            z = get_z(cam_s=cam_orig[0],  # TODO: There are two scales instead of 1.
                      cam_pos=cam_orig[2:4],
                      joints=vibe['joints3d'][index][t],
                      img_size=540,
                      flength=500)
            # z = 500 / (0.5 * 480 * cam_orig[0])
        else:
            z = 0
        trans = [x, y, z]
        alltrans.append(trans)
    alltrans = np.array(alltrans)
    return alltrans - alltrans[0]


class UESTC(Dataset):
    dataname = "uestc"

    def __init__(self, datapath="dataset/uestc", method_name="vibe", view="all", **kargs):
        self.datapath = datapath
        self.method_name = method_name
        self.view = view
        super().__init__(**kargs)

        # Load pre-computed #frames data
        with open(os.path.join(datapath, 'info', 'num_frames_min.txt'), 'r') as f:
            num_frames_video = np.asarray([int(s) for s in f.read().splitlines()])

        # Out of 118 subjects -> 51 in training, 67 in test
        all_subjects = np.arange(1, 119)
        self._tr_subjects = [
            1, 2, 6, 12, 13, 16, 21, 24, 28, 29, 30, 31, 33, 35, 39, 41, 42, 45, 47, 50,
            52, 54, 55, 57, 59, 61, 63, 64, 67, 69, 70, 71, 73, 77, 81, 84, 86, 87, 88,
            90, 91, 93, 96, 99, 102, 103, 104, 107, 108, 112, 113]
        self._test_subjects = [s for s in all_subjects if s not in self._tr_subjects]

        # Load the names of the 25600 videos
        with open(os.path.join(datapath, 'info', 'names.txt'), 'r') as f:
            videos = f.read().splitlines()

        self._videos = videos

        if self.method_name == "vibe":
            vibe_data_path = os.path.join(datapath, "vibe_cache_refined.pkl")
            vibe_data = pkl.load(open(vibe_data_path, "rb"))

            self._pose = vibe_data["pose"]
            num_frames_method = [p.shape[0] for p in self._pose]
            globpath = os.path.join(datapath, "globtrans_usez.pkl")

            if os.path.exists(globpath):
                self._globtrans = pkl.load(open(globpath, "rb"))
            else:
                self._globtrans = []
                for index in tqdm(range(len(self._pose))):
                    self._globtrans.append(get_trans_from_vibe(vibe_data, index, use_z=True))
                pkl.dump(self._globtrans, open("globtrans_usez.pkl", "wb"))
            self._joints = vibe_data["joints3d"]
            self._jointsIx = action2motion_joints
        else:
            raise ValueError("This method name is not recognized.")

        num_frames_video = np.minimum(num_frames_video, num_frames_method)
        num_frames_video = num_frames_video.astype(int)
        self._num_frames_in_video = [x for x in num_frames_video]

        N = len(videos)
        self._actions = np.zeros(N, dtype=int)
        for ind in range(N):
            self._actions[ind] = self.parse_action(videos[ind])

        self._actions = [x for x in self._actions]

        total_num_actions = 40
        self.num_actions = total_num_actions
        keep_actions = np.arange(0, total_num_actions)

        self._action_to_label = {x: i for i, x in enumerate(keep_actions)}
        self._label_to_action = {i: x for i, x in enumerate(keep_actions)}
        self.num_classes = len(keep_actions)

        self._train = []
        self._test = []

        self.info_actions = []

        def get_rotation(view):
            theta = - view * np.pi / 4
            axis = torch.tensor([0, 1, 0], dtype=torch.float)
            axisangle = theta * axis
            matrix = geometry.axis_angle_to_matrix(axisangle)
            return matrix

        # 0 is the identity if needed
        rotations = {key: get_rotation(key) for key in [0, 1, 2, 3, 4, 5, 6, 7]}

        for index, video in enumerate(tqdm(videos, desc='Preparing UESTC data..')):
            act, view, subject, side = self._get_action_view_subject_side(video)
            self.info_actions.append({"action": act,
                                      "view": view,
                                      "subject": subject,
                                      "side": side})
            if self.view == "frontview":
                if side != 1:
                    continue
            # rotate to front view
            if side != 1:
                # don't take view 8 in side 2
                if view == 8:
                    continue
                rotation = rotations[view]
                global_matrix = geometry.axis_angle_to_matrix(torch.from_numpy(self._pose[index][:, :3]))
                # rotate the global pose
                self._pose[index][:, :3] = geometry.matrix_to_axis_angle(rotation @ global_matrix).numpy()
                # rotate the joints
                self._joints[index] = self._joints[index] @ rotation.T.numpy()
                self._globtrans[index] = (self._globtrans[index] @ rotation.T.numpy())

            # add the global translation to the joints
            self._joints[index] = self._joints[index] + self._globtrans[index][:, None]

            if subject in self._tr_subjects:
                self._train.append(index)
            elif subject in self._test_subjects:
                self._test.append(index)
            else:
                raise ValueError("This subject doesn't belong to any set.")

            # if index > 200:
            #     break

        # Select only sequences which have a minimum number of frames
        if self.num_frames > 0:
            threshold = self.num_frames * 3 / 4
        else:
            threshold = 0

        method_extracted_ix = np.where(num_frames_video >= threshold)[0].tolist()
        self._train = list(set(self._train) & set(method_extracted_ix))
        # keep the test set without modification
        self._test = list(set(self._test))

        action_classes_file = os.path.join(datapath, "info/action_classes.txt")
        with open(action_classes_file, 'r') as f:
            self._action_classes = np.array(f.read().splitlines())

        # with open(processd_path, 'wb') as file:
        #     pkl.dump(xxx, file)

    def _load_joints3D(self, ind, frame_ix):
        if len(self._joints[ind]) == 0:
            raise ValueError(
                f"Cannot load index {ind} in _load_joints3D function.")
        if self._jointsIx is not None:
            joints3D = self._joints[ind][frame_ix][:, self._jointsIx]
        else:
            joints3D = self._joints[ind][frame_ix]

        return joints3D

    def _load_rotvec(self, ind, frame_ix):
        # 72-dim SMPL pose
        pose = self._pose[ind][frame_ix, :].reshape(-1, 24, 3)
        return pose

    def _get_action_view_subject_side(self, videopath):
        # TODO: Can be moved to tools.py
        spl = videopath.split('_')
        action = int(spl[0][1:])
        view = int(spl[1][1:])
        subject = int(spl[2][1:])
        side = int(spl[3][1:])
        return action, view, subject, side

    def _get_videopath(self, action, view, subject, side):
        # Unused function
        return 'a{:d}_d{:d}_p{:03d}_c{:d}_color.avi'.format(
            action, view, subject, side)

    def parse_action(self, path, return_int=True):
        # Override the parent method
        info, _, _, _ = self._get_action_view_subject_side(path)
        if return_int:
            return int(info)
        else:
            return info


if __name__ == "__main__":
    dataset = UESTC()
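The per-view alignment above hinges on `get_rotation`, which maps view v in {0, ..., 7} to a rotation of -v * 45 degrees about the vertical (Y) axis. A self-contained check of that geometry (plain PyTorch, no repository imports):

import math
import torch

def y_rotation(theta):
    # Rotation about the Y (up) axis; the same matrix that
    # geometry.axis_angle_to_matrix(theta * [0, 1, 0]) produces above.
    c, s = math.cos(theta), math.sin(theta)
    return torch.tensor([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])

view = 2                                    # camera 2 sits 90 degrees around the subject
R = y_rotation(-view * math.pi / 4)
print(R @ torch.tensor([1.0, 0.0, 0.0]))    # ~ tensor([0., 0., 1.]): X is rotated onto Z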
motion_diffusion_model/data_loaders/get_data.py ADDED
@@ -0,0 +1,59 @@
from torch.utils.data import DataLoader
from data_loaders.tensors import collate as all_collate
from data_loaders.tensors import t2m_collate, t2m_prefix_collate


def get_dataset_class(name):
    if name == "amass":
        from .amass import AMASS
        return AMASS
    elif name == "uestc":
        from .a2m.uestc import UESTC
        return UESTC
    elif name == "humanact12":
        from .a2m.humanact12poses import HumanAct12Poses
        return HumanAct12Poses
    elif name == "humanml":
        from data_loaders.humanml.data.dataset import HumanML3D
        return HumanML3D
    elif name == "kit":
        from data_loaders.humanml.data.dataset import KIT
        return KIT
    else:
        raise ValueError(f'Unsupported dataset name [{name}]')


def get_collate_fn(name, hml_mode='train', pred_len=0, batch_size=1):
    if hml_mode == 'gt':
        from data_loaders.humanml.data.dataset import collate_fn as t2m_eval_collate
        return t2m_eval_collate
    if name in ["humanml", "kit"]:
        if pred_len > 0:
            return lambda x: t2m_prefix_collate(x, pred_len=pred_len)
        return lambda x: t2m_collate(x, batch_size)
    else:
        return all_collate


def get_dataset(name, num_frames, split='train', hml_mode='train', abs_path='.', fixed_len=0,
                device=None, autoregressive=False, cache_path=None):
    DATA = get_dataset_class(name)
    if name in ["humanml", "kit"]:
        dataset = DATA(split=split, num_frames=num_frames, mode=hml_mode, abs_path=abs_path, fixed_len=fixed_len,
                       device=device, autoregressive=autoregressive)
    else:
        dataset = DATA(split=split, num_frames=num_frames)
    return dataset


def get_dataset_loader(name, batch_size, num_frames, split='train', hml_mode='train', fixed_len=0, pred_len=0,
                       device=None, autoregressive=False):
    dataset = get_dataset(name, num_frames, split=split, hml_mode=hml_mode, fixed_len=fixed_len,
                          device=device, autoregressive=autoregressive)

    collate = get_collate_fn(name, hml_mode, pred_len, batch_size)

    loader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True,
        num_workers=8, drop_last=True, collate_fn=collate
    )

    return loader
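A hedged usage sketch; it assumes the HumanML3D assets are already in place under dataset/, as data_loaders.humanml.data.dataset expects:

from data_loaders.get_data import get_dataset_loader

loader = get_dataset_loader(name='humanml', batch_size=32, num_frames=60, split='train')
motion, cond = next(iter(loader))   # t2m_collate is expected to yield (motion, cond) batches
print(motion.shape)                 # roughly [batch_size, njoints, nfeats, max_frames]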
motion_diffusion_model/data_loaders/humanml/README.md ADDED
@@ -0,0 +1 @@
This code is based on https://github.com/EricGuo5513/text-to-motion.git
motion_diffusion_model/data_loaders/humanml/common/quaternion.py ADDED
@@ -0,0 +1,425 @@
# Copyright (c) 2018-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch
import numpy as np

_EPS4 = np.finfo(float).eps * 4.0

_FLOAT_EPS = np.finfo(float).eps


# PyTorch-backed implementations
def qinv(q):
    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
    mask = torch.ones_like(q)
    mask[..., 1:] = -mask[..., 1:]
    return q * mask


def qinv_np(q):
    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
    return qinv(torch.from_numpy(q).float()).numpy()


def qnormalize(q):
    assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
    q[..., -1] += 1e-4  # Guy - for safety, avoid division by zero
    return q / torch.norm(q, dim=-1, keepdim=True)


def qmul(q, r):
    """
    Multiply quaternion(s) q with quaternion(s) r.
    Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
    Returns q*r as a tensor of shape (*, 4).
    """
    assert q.shape[-1] == 4
    assert r.shape[-1] == 4

    original_shape = q.shape

    # Compute outer product
    # terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
    terms = torch.bmm(r.reshape(-1, 4, 1), q.reshape(-1, 1, 4))

    w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
    x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
    y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
    z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
    return torch.stack((w, x, y, z), dim=1).view(original_shape)


def qrot(q, v):
    """
    Rotate vector(s) v about the rotation described by quaternion(s) q.
    Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
    where * denotes any number of dimensions.
    Returns a tensor of shape (*, 3).
    """
    assert q.shape[-1] == 4
    assert v.shape[-1] == 3
    assert q.shape[:-1] == v.shape[:-1]

    original_shape = list(v.shape)
    # print(q.shape)
    q = q.contiguous().view(-1, 4)
    v = v.contiguous().view(-1, 3)

    qvec = q[:, 1:].to(v.device)
    uv = torch.cross(qvec, v, dim=1)
    uuv = torch.cross(qvec, uv, dim=1)
    return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)


def qeuler(q, order, epsilon=0, deg=True):
    """
    Convert quaternion(s) q to Euler angles.
    Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
    Returns a tensor of shape (*, 3).
    """
    assert q.shape[-1] == 4

    original_shape = list(q.shape)
    original_shape[-1] = 3
    q = q.view(-1, 4)

    q0 = q[:, 0]
    q1 = q[:, 1]
    q2 = q[:, 2]
    q3 = q[:, 3]

    if order == 'xyz':
        x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
        z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
    elif order == 'yzx':
        x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
        y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
        z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
    elif order == 'zxy':
        x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
        y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
    elif order == 'xzy':
        x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
        y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
        z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
    elif order == 'yxz':
        x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
        y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
        z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
    elif order == 'zyx':
        x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
        y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
        z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
    else:
        raise ValueError(f'Unsupported Euler order: {order}')

    if deg:
        return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
    else:
        return torch.stack((x, y, z), dim=1).view(original_shape)


# Numpy-backed implementations

def qmul_np(q, r):
    q = torch.from_numpy(q).contiguous().float()
    r = torch.from_numpy(r).contiguous().float()
    return qmul(q, r).numpy()


def qrot_np(q, v):
    q = torch.from_numpy(q).contiguous().float()
    v = torch.from_numpy(v).contiguous().float()
    return qrot(q, v).numpy()


def qeuler_np(q, order, epsilon=0, use_gpu=False):
    if use_gpu:
        q = torch.from_numpy(q).cuda().float()
        return qeuler(q, order, epsilon).cpu().numpy()
    else:
        q = torch.from_numpy(q).contiguous().float()
        return qeuler(q, order, epsilon).numpy()


def qfix(q):
    """
    Enforce quaternion continuity across the time dimension by selecting
    the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
    between two consecutive frames.

    Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
    Returns a tensor of the same shape.
    """
    assert len(q.shape) == 3
    assert q.shape[-1] == 4

    result = q.copy()
    dot_products = np.sum(q[1:] * q[:-1], axis=2)
    mask = dot_products < 0
    mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
    result[1:][mask] *= -1
    return result


def euler2quat(e, order, deg=True):
    """
    Convert Euler angles to quaternions.
    """
    assert e.shape[-1] == 3

    original_shape = list(e.shape)
    original_shape[-1] = 4

    e = e.view(-1, 3)

    ## if the Euler angles are in degrees
    if deg:
        e = e * np.pi / 180.

    x = e[:, 0]
    y = e[:, 1]
    z = e[:, 2]

    rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
    ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
    rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)

    result = None
    for coord in order:
        if coord == 'x':
            r = rx
        elif coord == 'y':
            r = ry
        elif coord == 'z':
            r = rz
        else:
            raise ValueError(f'Unknown axis: {coord}')
        if result is None:
            result = r
        else:
            result = qmul(result, r)

    # Reverse antipodal representation to have a non-negative "w"
    if order in ['xyz', 'yzx', 'zxy']:
        result *= -1

    return result.view(original_shape)


def expmap_to_quaternion(e):
    """
    Convert axis-angle rotations (aka exponential maps) to quaternions.
    Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
    Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
    Returns a tensor of shape (*, 4).
    """
    assert e.shape[-1] == 3

    original_shape = list(e.shape)
    original_shape[-1] = 4
    e = e.reshape(-1, 3)

    theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
    w = np.cos(0.5 * theta).reshape(-1, 1)
    xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
    return np.concatenate((w, xyz), axis=1).reshape(original_shape)


def euler_to_quaternion(e, order):
    """
    Convert Euler angles to quaternions.
    """
    assert e.shape[-1] == 3

    original_shape = list(e.shape)
    original_shape[-1] = 4

    e = e.reshape(-1, 3)

    x = e[:, 0]
    y = e[:, 1]
    z = e[:, 2]

    rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
    ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
    rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)

    result = None
    for coord in order:
        if coord == 'x':
            r = rx
        elif coord == 'y':
            r = ry
        elif coord == 'z':
            r = rz
        else:
            raise ValueError(f'Unknown axis: {coord}')
        if result is None:
            result = r
        else:
            result = qmul_np(result, r)

    # Reverse antipodal representation to have a non-negative "w"
    if order in ['xyz', 'yzx', 'zxy']:
        result *= -1

    return result.reshape(original_shape)


def quaternion_to_matrix(quaternions):
    """
    Convert rotations given as quaternions to rotation matrices.
    Args:
        quaternions: quaternions with real part first,
            as tensor of shape (..., 4).
    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    r, i, j, k = torch.unbind(quaternions, -1)
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    o = torch.stack(
        (
            1 - two_s * (j * j + k * k),
            two_s * (i * j - k * r),
            two_s * (i * k + j * r),
            two_s * (i * j + k * r),
            1 - two_s * (i * i + k * k),
            two_s * (j * k - i * r),
            two_s * (i * k - j * r),
            two_s * (j * k + i * r),
            1 - two_s * (i * i + j * j),
        ),
        -1,
    )
    return o.reshape(quaternions.shape[:-1] + (3, 3))


def quaternion_to_matrix_np(quaternions):
    q = torch.from_numpy(quaternions).contiguous().float()
    return quaternion_to_matrix(q).numpy()


def quaternion_to_cont6d_np(quaternions):
    rotation_mat = quaternion_to_matrix_np(quaternions)
    cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
    return cont_6d


def quaternion_to_cont6d(quaternions):
    rotation_mat = quaternion_to_matrix(quaternions)
    cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
    return cont_6d


def cont6d_to_matrix(cont6d):
    assert cont6d.shape[-1] == 6, "The last dimension must be 6"
    x_raw = cont6d[..., 0:3]
    y_raw = cont6d[..., 3:6]

    x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
    z = torch.cross(x, y_raw, dim=-1)
    z = z / torch.norm(z, dim=-1, keepdim=True)

    y = torch.cross(z, x, dim=-1)

    x = x[..., None]
    y = y[..., None]
    z = z[..., None]

    mat = torch.cat([x, y, z], dim=-1)
    return mat


def cont6d_to_matrix_np(cont6d):
    q = torch.from_numpy(cont6d).contiguous().float()
    return cont6d_to_matrix(q).numpy()


def qpow(q0, t, dtype=torch.float):
    '''q0 : tensor of quaternions
    t: tensor of powers
    '''
    q0 = qnormalize(q0)
    theta0 = torch.acos(q0[..., 0])

    ## if theta0 is close to zero, add epsilon to avoid NaNs
    mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
    theta0 = (1 - mask) * theta0 + mask * 10e-10
    v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)

    if isinstance(t, torch.Tensor):
        q = torch.zeros(t.shape + q0.shape)
        theta = t.view(-1, 1) * theta0.view(1, -1)
    else:  ## if t is a number
        q = torch.zeros(q0.shape)
        theta = t * theta0

    q[..., 0] = torch.cos(theta)
    q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)

    return q.to(dtype)


def qslerp(q0, q1, t):
    '''
    q0: starting quaternion
    q1: ending quaternion
    t: array of points along the way

    Returns:
    Tensor of Slerps: t.shape + q0.shape
    '''

    q0 = qnormalize(q0)
    q1 = qnormalize(q1)
    q_ = qpow(qmul(q1, qinv(q0)), t)

    return qmul(q_,
                q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())


def qbetween(v0, v1):
    '''
    Find the quaternion used to rotate v0 to v1.
    '''
    assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
    assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'

    v = torch.cross(v0, v1)
    w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
                                                                                                              keepdim=True)
    return qnormalize(torch.cat([w, v], dim=-1))


def qbetween_np(v0, v1):
    '''
    Find the quaternion used to rotate v0 to v1.
    '''
    assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
    assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'

    v0 = torch.from_numpy(v0).float()
    v1 = torch.from_numpy(v1).float()
    return qbetween(v0, v1).numpy()


def lerp(p0, p1, t):
    if not isinstance(t, torch.Tensor):
        t = torch.Tensor([t])

    new_shape = t.shape + p0.shape
    new_view_t = t.shape + torch.Size([1] * len(p0.shape))
    new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
    p0 = p0.view(new_view_p).expand(new_shape)
    p1 = p1.view(new_view_p).expand(new_shape)
    t = t.view(new_view_t).expand(new_shape)

    return p0 + t * (p1 - p0)
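A quick numerical sanity check for the helpers above: rotating with `qrot` should agree with converting the quaternion to a rotation matrix first. A minimal sketch (run from the repository root so the import resolves):

import torch
from data_loaders.humanml.common.quaternion import qrot, quaternion_to_matrix, qnormalize

q = qnormalize(torch.randn(5, 4))    # five random unit quaternions (w, x, y, z)
v = torch.randn(5, 3)
v_quat = qrot(q, v)
v_mat = torch.einsum('bij,bj->bi', quaternion_to_matrix(q), v)
print(torch.allclose(v_quat, v_mat, atol=1e-5))    # expected: True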
motion_diffusion_model/data_loaders/humanml/common/skeleton.py ADDED
@@ -0,0 +1,202 @@
from data_loaders.humanml.common.quaternion import *
import scipy.ndimage.filters as filters


class Skeleton(object):
    def __init__(self, offset, kinematic_tree, device):
        self.device = device
        self._raw_offset_np = offset.numpy()
        self._raw_offset = offset.clone().detach().to(device).float()
        self._kinematic_tree = kinematic_tree
        self._offset = None
        self._parents = [0] * len(self._raw_offset)
        self._parents[0] = -1
        for chain in self._kinematic_tree:
            for j in range(1, len(chain)):
                self._parents[chain[j]] = chain[j - 1]

    def njoints(self):
        return len(self._raw_offset)

    def offset(self):
        return self._offset

    def set_offset(self, offsets):
        self._offset = offsets.clone().detach().to(self.device).float()

    def kinematic_tree(self):
        return self._kinematic_tree

    def parents(self):
        return self._parents

    # joints (batch_size, joints_num, 3)
    def get_offsets_joints_batch(self, joints):
        assert len(joints.shape) == 3
        _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()
        for i in range(1, self._raw_offset.shape[0]):
            _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]

        self._offset = _offsets.detach()
        return _offsets

    # joints (joints_num, 3)
    def get_offsets_joints(self, joints):
        assert len(joints.shape) == 2
        _offsets = self._raw_offset.clone()
        for i in range(1, self._raw_offset.shape[0]):
            # print(joints.shape)
            _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]

        self._offset = _offsets.detach()
        return _offsets

    # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder
    # joints (batch_size, joints_num, 3)
    def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False, fix_bug=False):
        assert len(face_joint_idx) == 4
        '''Get Forward Direction'''
        if fix_bug:
            r_hip, l_hip, sdr_r, sdr_l = face_joint_idx
        else:
            l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
        across1 = joints[:, r_hip] - joints[:, l_hip]
        across2 = joints[:, sdr_r] - joints[:, sdr_l]
        across = across1 + across2
        across = across / np.sqrt((across ** 2).sum(axis=-1))[:, np.newaxis]
        # print(across1.shape, across2.shape)

        if smooth_forward:
            forward = filters.gaussian_filter1d(np.cross(np.array([[0, 1, 0]]), across, axis=-1), 20, axis=0, mode='nearest')
        else:
            forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
        # forward (batch_size, 3)
        forward = forward / np.sqrt((forward ** 2).sum(axis=-1))[..., np.newaxis]

        '''Get Root Rotation'''
        target = np.array([[0, 0, 1]]).repeat(len(forward), axis=0)
        root_quat = qbetween_np(forward, target)  # angle from root to Z+ (= how much to rotate the root so that it faces Z+)

        '''Inverse Kinematics'''
        # quat_params (batch_size, joints_num, 4)
        # print(joints.shape[:-1])
        quat_params = np.zeros(joints.shape[:-1] + (4,))
        # print(quat_params.shape)
        # root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])  # this is a bug: the rotation of the next joint in the chain is computed w.r.t. the root joint, which is now 0, but the next joint was not moved, so it looks like a huge rotation
        quat_params[:, 0] = root_quat
        # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]])
        for chain in self._kinematic_tree:
            R = root_quat
            for j in range(len(chain) - 1):
                # (batch, 3)
                u = self._raw_offset_np[chain[j + 1]][np.newaxis, ...].repeat(len(joints), axis=0)  # rest-pose bone direction for joint j in the chain
                # print(u.shape)
                # (batch, 3)
                v = joints[:, chain[j + 1]] - joints[:, chain[j]]  # data bone direction for joint j+1 in the chain
                v = v / np.sqrt((v ** 2).sum(axis=-1))[:, np.newaxis]
                # print(u.shape, v.shape)
                rot_u_v = qbetween_np(u, v)  # angle between the rest-pose bone and the data bone (the bone goes from j to j+1)

                R_loc = qmul_np(qinv_np(R), rot_u_v)  # bring the angle into the local coordinate system, i.e., relative to the parent bone

                quat_params[:, chain[j + 1], :] = R_loc
                R = qmul_np(R, R_loc)

        return quat_params

    # Be sure the root joint is at the beginning of the kinematic chains
    def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
        # quat_params (batch_size, joints_num, 4)
        # joints (batch_size, joints_num, 3)
        # root_pos (batch_size, 3)
        if skel_joints is not None:
            offsets = self.get_offsets_joints_batch(skel_joints)
        if len(self._offset.shape) == 2:
            offsets = self._offset.expand(quat_params.shape[0], -1, -1)
        joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)
        joints[:, 0] = root_pos
        for chain in self._kinematic_tree:
            if do_root_R:
                R = quat_params[:, 0]
            else:
                R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device)
            for i in range(1, len(chain)):
                R = qmul(R, quat_params[:, chain[i]])
                offset_vec = offsets[:, chain[i]]
                joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i - 1]]
        return joints

    # Be sure the root joint is at the beginning of the kinematic chains
    def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
        # quat_params (batch_size, joints_num, 4)
        # joints (batch_size, joints_num, 3)
        # root_pos (batch_size, 3)
        if skel_joints is not None:
            skel_joints = torch.from_numpy(skel_joints)
            offsets = self.get_offsets_joints_batch(skel_joints)
        if len(self._offset.shape) == 2:
            offsets = self._offset.expand(quat_params.shape[0], -1, -1)
        offsets = offsets.numpy()
        joints = np.zeros(quat_params.shape[:-1] + (3,))
        joints[:, 0] = root_pos
        for chain in self._kinematic_tree:
            if do_root_R:
                R = quat_params[:, 0]
            else:
                R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)
            for i in range(1, len(chain)):
                R = qmul_np(R, quat_params[:, chain[i]])
                offset_vec = offsets[:, chain[i]]
                joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]
        return joints

    def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
        # cont6d_params (batch_size, joints_num, 6)
        # joints (batch_size, joints_num, 3)
        # root_pos (batch_size, 3)
        if skel_joints is not None:
            skel_joints = torch.from_numpy(skel_joints)
            offsets = self.get_offsets_joints_batch(skel_joints)
        if len(self._offset.shape) == 2:
            offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
        offsets = offsets.numpy()
        joints = np.zeros(cont6d_params.shape[:-1] + (3,))
        joints[:, 0] = root_pos
        for chain in self._kinematic_tree:
            if do_root_R:
                matR = cont6d_to_matrix_np(cont6d_params[:, 0])
            else:
                matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)
            for i in range(1, len(chain)):
                matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))
                offset_vec = offsets[:, chain[i]][..., np.newaxis]
                # print(matR.shape, offset_vec.shape)
                joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i - 1]]
        return joints

    def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
        # cont6d_params (batch_size, joints_num, 6)
        # joints (batch_size, joints_num, 3)
        # root_pos (batch_size, 3)
        if skel_joints is not None:
            # skel_joints = torch.from_numpy(skel_joints)
            offsets = self.get_offsets_joints_batch(skel_joints)
        if len(self._offset.shape) == 2:
            offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
        joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)
        joints[..., 0, :] = root_pos
        for chain in self._kinematic_tree:
            if do_root_R:
                matR = cont6d_to_matrix(cont6d_params[:, 0])
            else:
                matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)
            for i in range(1, len(chain)):
                matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))
                offset_vec = offsets[:, chain[i]].unsqueeze(-1)
                # print(matR.shape, offset_vec.shape)
                joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i - 1]]
        return joints
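A minimal smoke test for the forward-kinematics path above. It assumes the HumanML3D skeleton constants t2m_raw_offsets and t2m_kinematic_chain exist in data_loaders/humanml/utils/paramUtil.py (that file is part of this commit); bone lengths are left at the unit raw offsets, which is enough for a shape check:

import numpy as np
import torch
from data_loaders.humanml.common.skeleton import Skeleton
from data_loaders.humanml.utils.paramUtil import t2m_raw_offsets, t2m_kinematic_chain

raw = torch.from_numpy(t2m_raw_offsets).float()
skel = Skeleton(raw, t2m_kinematic_chain, device='cpu')
skel.set_offset(raw)                                   # unit-length bones for the demo

T, J = 4, raw.shape[0]
quats = np.zeros((T, J, 4), dtype=np.float32)
quats[..., 0] = 1.0                                    # identity rotation at every joint
root = np.zeros((T, 3), dtype=np.float32)
joints = skel.forward_kinematics_np(quats, root)       # the rest pose, repeated T times
print(joints.shape)                                    # (T, J, 3), i.e. (4, 22, 3) for HumanML3D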
motion_diffusion_model/data_loaders/humanml/data/__init__.py ADDED
File without changes
motion_diffusion_model/data_loaders/humanml/data/dataset.py ADDED
@@ -0,0 +1,823 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils import data
3
+ import numpy as np
4
+ import os
5
+ from os.path import join as pjoin
6
+ import random
7
+ import codecs as cs
8
+ from tqdm import tqdm
9
+ import spacy
10
+
11
+ from torch.utils.data._utils.collate import default_collate
12
+ from data_loaders.humanml.utils.word_vectorizer import WordVectorizer
13
+ from data_loaders.humanml.utils.get_opt import get_opt
14
+
15
+ # import spacy
16
+
17
+ def collate_fn(batch):
18
+ batch.sort(key=lambda x: x[3], reverse=True)
19
+ return default_collate(batch)
20
+
21
+
22
+ '''For use of training text-2-motion generative model'''
23
+ class Text2MotionDataset(data.Dataset):
24
+ def __init__(self, opt, mean, std, split_file, w_vectorizer):
25
+ self.opt = opt
26
+ self.w_vectorizer = w_vectorizer
27
+ self.max_length = 20
28
+ self.pointer = 0
29
+ min_motion_len = 40 if self.opt.dataset_name =='t2m' else 24
30
+
31
+ joints_num = opt.joints_num
32
+
33
+ data_dict = {}
34
+ id_list = []
35
+ with cs.open(split_file, 'r') as f:
36
+ for line in f.readlines():
37
+ id_list.append(line.strip())
38
+
39
+ new_name_list = []
40
+ length_list = []
41
+ for name in tqdm(id_list):
42
+ try:
43
+ motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
44
+ if (len(motion)) < min_motion_len or (len(motion) >= 200):
45
+ continue
46
+ text_data = []
47
+ flag = False
48
+ with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
49
+ for line in f.readlines():
50
+ text_dict = {}
51
+ line_split = line.strip().split('#')
52
+ caption = line_split[0]
53
+ tokens = line_split[1].split(' ')
54
+ f_tag = float(line_split[2])
55
+ to_tag = float(line_split[3])
56
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
57
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
58
+
59
+ text_dict['caption'] = caption
60
+ text_dict['tokens'] = tokens
61
+ if f_tag == 0.0 and to_tag == 0.0:
62
+ flag = True
63
+ text_data.append(text_dict)
64
+ else:
65
+ try:
66
+ n_motion = motion[int(f_tag*20) : int(to_tag*20)]
67
+ if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
68
+ continue
69
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
70
+ while new_name in data_dict:
71
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
72
+ data_dict[new_name] = {'motion': n_motion,
73
+ 'length': len(n_motion),
74
+ 'text':[text_dict]}
75
+ new_name_list.append(new_name)
76
+ length_list.append(len(n_motion))
77
+ except:
78
+ print(line_split)
79
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
80
+ # break
81
+
82
+ if flag:
83
+ data_dict[name] = {'motion': motion,
84
+ 'length': len(motion),
85
+ 'text':text_data}
86
+ new_name_list.append(name)
87
+ length_list.append(len(motion))
88
+ except:
89
+ # Some motion may not exist in KIT dataset
90
+ pass
91
+
92
+
93
+ name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
94
+
95
+ if opt.is_train:
96
+ # root_rot_velocity (B, seq_len, 1)
97
+ std[0:1] = std[0:1] / opt.feat_bias
98
+ # root_linear_velocity (B, seq_len, 2)
99
+ std[1:3] = std[1:3] / opt.feat_bias
100
+ # root_y (B, seq_len, 1)
101
+ std[3:4] = std[3:4] / opt.feat_bias
102
+ # ric_data (B, seq_len, (joint_num - 1)*3)
103
+ std[4: 4 + (joints_num - 1) * 3] = std[4: 4 + (joints_num - 1) * 3] / 1.0
104
+ # rot_data (B, seq_len, (joint_num - 1)*6)
105
+ std[4 + (joints_num - 1) * 3: 4 + (joints_num - 1) * 9] = std[4 + (joints_num - 1) * 3: 4 + (
106
+ joints_num - 1) * 9] / 1.0
107
+ # local_velocity (B, seq_len, joint_num*3)
108
+ std[4 + (joints_num - 1) * 9: 4 + (joints_num - 1) * 9 + joints_num * 3] = std[
109
+ 4 + (joints_num - 1) * 9: 4 + (
110
+ joints_num - 1) * 9 + joints_num * 3] / 1.0
111
+ # foot contact (B, seq_len, 4)
112
+ std[4 + (joints_num - 1) * 9 + joints_num * 3:] = std[
113
+ 4 + (joints_num - 1) * 9 + joints_num * 3:] / opt.feat_bias
114
+
115
+ assert 4 + (joints_num - 1) * 9 + joints_num * 3 + 4 == mean.shape[-1]
116
+ np.save(pjoin(opt.meta_dir, 'mean.npy'), mean)
117
+ np.save(pjoin(opt.meta_dir, 'std.npy'), std)
118
+
119
+ self.mean = mean
120
+ self.std = std
121
+ self.length_arr = np.array(length_list)
122
+ self.data_dict = data_dict
123
+ self.name_list = name_list
124
+ self.reset_max_len(self.max_length)
125
+
126
+ def reset_max_len(self, length):
127
+ assert length <= self.opt.max_motion_length
128
+ self.pointer = np.searchsorted(self.length_arr, length)
129
+ print("Pointer Pointing at %d"%self.pointer)
130
+ self.max_length = length
131
+
132
+ def inv_transform(self, data):
133
+ return data * self.std + self.mean
134
+
135
+ def __len__(self):
136
+ return len(self.data_dict) - self.pointer
137
+
138
+ def __getitem__(self, item):
139
+ idx = self.pointer + item
140
+ data = self.data_dict[self.name_list[idx]]
141
+ motion, m_length, text_list = data['motion'], data['length'], data['text']
142
+ # Randomly select a caption
143
+ text_data = random.choice(text_list)
144
+ caption, tokens = text_data['caption'], text_data['tokens']
145
+
146
+ if len(tokens) < self.opt.max_text_len:
147
+ # pad with "unk"
148
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
149
+ sent_len = len(tokens)
150
+ tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
151
+ else:
152
+ # crop
153
+ tokens = tokens[:self.opt.max_text_len]
154
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
155
+ sent_len = len(tokens)
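# Either branch yields exactly max_text_len + 2 tokens. For example, with the
# max_text_len = 20 used by the evaluators, a 5-word caption becomes
# sos + 5 words + eos + 15 * 'unk/OTHER' (22 tokens), while a 30-word caption
# is first cropped to 20 words and then wrapped with sos/eos (also 22 tokens).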
156
+ pos_one_hots = []
157
+ word_embeddings = []
158
+ for token in tokens:
159
+ word_emb, pos_oh = self.w_vectorizer[token]
160
+ pos_one_hots.append(pos_oh[None, :])
161
+ word_embeddings.append(word_emb[None, :])
162
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
163
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
164
+
165
+ len_gap = (m_length - self.max_length) // self.opt.unit_length
166
+
167
+ if self.opt.is_train:
168
+ if m_length != self.max_length:
169
+ # print("Motion original length:%d_%d"%(m_length, len(motion)))
170
+ if self.opt.unit_length < 10:
171
+ coin2 = np.random.choice(['single', 'single', 'double'])
172
+ else:
173
+ coin2 = 'single'
174
+ if len_gap == 0 or (len_gap == 1 and coin2 == 'double'):
175
+ m_length = self.max_length
176
+ idx = random.randint(0, m_length - self.max_length)
177
+ motion = motion[idx:idx+self.max_length]
178
+ else:
179
+ if coin2 == 'single':
180
+ n_m_length = self.max_length + self.opt.unit_length * len_gap
181
+ else:
182
+ n_m_length = self.max_length + self.opt.unit_length * (len_gap - 1)
183
+ idx = random.randint(0, m_length - n_m_length)
184
+ motion = motion[idx:idx + self.max_length]
185
+ m_length = n_m_length
186
+ # print(len_gap, idx, coin2)
187
+ else:
188
+ if self.opt.unit_length < 10:
189
+ coin2 = np.random.choice(['single', 'single', 'double'])
190
+ else:
191
+ coin2 = 'single'
192
+
193
+ if coin2 == 'double':
194
+ m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length
195
+ elif coin2 == 'single':
196
+ m_length = (m_length // self.opt.unit_length) * self.opt.unit_length
197
+ idx = random.randint(0, len(motion) - m_length)
198
+ motion = motion[idx:idx+m_length]
199
+
200
+ "Z Normalization"
201
+ motion = (motion - self.mean) / self.std
202
+
203
+ return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length
204
+
205
+
206
+ '''For use of training text motion matching model, and evaluations'''
207
+ class Text2MotionDatasetV2(data.Dataset):
208
+ def __init__(self, opt, mean, std, split_file, w_vectorizer):
209
+ self.opt = opt
210
+ self.w_vectorizer = w_vectorizer
211
+ self.max_length = 20
212
+ if self.opt.fixed_len > 0:
213
+ self.max_length = self.opt.fixed_len
214
+ self.pointer = 0
215
+ self.max_motion_length = opt.max_motion_length
216
+ min_motion_len = 40 if self.opt.dataset_name == 't2m' else 24
217
+
218
+ data_dict = {}
219
+ id_list = []
220
+ with cs.open(split_file, 'r') as f:
221
+ for line in f.readlines():
222
+ id_list.append(line.strip())
223
+ # id_list = id_list[:200]
224
+
225
+ new_name_list = []
226
+ length_list = []
227
+
228
+ _split = os.path.basename(split_file).replace('.txt', '')
229
+ _name = ''
230
+ # cache_path = os.path.join(opt.meta_dir, self.opt.dataset_name + '_' + _split + _name + '.npy')
231
+ cache_path = os.path.join(opt.cache_dir, 'dataset', self.opt.dataset_name + '_' + _split + _name + '.npy')
232
+ if opt.use_cache and os.path.exists(cache_path):
233
+ print(f'Loading motions from cache file [{cache_path}]...')
234
+ _cache = np.load(cache_path, allow_pickle=True)[None][0]
235
+ name_list, length_list, data_dict = _cache['name_list'], _cache['length_list'], _cache['data_dict']
236
+ # name_list = name_list[:15]; length_list = length_list[:15]
237
+ # data_dict = {key: data_dict[key] for key in name_list}
238
+ else:
239
+ for name in tqdm(id_list):
240
+ try:
241
+ motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
242
+ if (len(motion)) < min_motion_len or (len(motion) >= 200):
243
+ continue
244
+ text_data = []
245
+ flag = False
246
+ with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
247
+ for line in f.readlines():
248
+ text_dict = {}
249
+ line_split = line.strip().split('#')
250
+ caption = line_split[0]
251
+ tokens = line_split[1].split(' ')
252
+ f_tag = float(line_split[2])
253
+ to_tag = float(line_split[3])
254
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
255
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
256
+
257
+ text_dict['caption'] = caption
258
+ text_dict['tokens'] = tokens
259
+ if f_tag == 0.0 and to_tag == 0.0:
260
+ flag = True
261
+ text_data.append(text_dict)
262
+ else:
263
+ try:
264
+ n_motion = motion[int(f_tag*20) : int(to_tag*20)]
265
+ if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
266
+ continue
267
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
268
+ while new_name in data_dict:
269
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
270
+ data_dict[new_name] = {'motion': n_motion,
271
+ 'length': len(n_motion),
272
+ 'text':[text_dict]}
273
+ new_name_list.append(new_name)
274
+ length_list.append(len(n_motion))
275
+ except:
276
+ print(line_split)
277
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
278
+ # break
279
+
280
+ if flag:
281
+ data_dict[name] = {'motion': motion,
282
+ 'length': len(motion),
283
+ 'text': text_data}
284
+ new_name_list.append(name)
285
+ length_list.append(len(motion))
286
+ except:
287
+ pass
288
+
289
+ name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
290
+ print(f'Saving motions to cache file [{cache_path}]...')
291
+ np.save(cache_path, {
292
+ 'name_list': name_list,
293
+ 'length_list': length_list,
294
+ 'data_dict': data_dict})
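# The cache stores the sorted name list, the lengths and the motion dict in a
# single .npy blob, so later runs can skip the per-file loading loop above.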
295
+
296
+ self.mean = mean
297
+ self.std = std
298
+ self.length_arr = np.array(length_list)
299
+ self.data_dict = data_dict
300
+ self.name_list = name_list
301
+ self.reset_max_len(self.max_length)
302
+
303
+ def reset_max_len(self, length):
304
+ assert length <= self.max_motion_length
305
+ self.pointer = np.searchsorted(self.length_arr, length)
306
+ print("Pointer Pointing at %d"%self.pointer)
307
+ self.max_length = length
308
+
309
+ def inv_transform(self, data):
310
+ return data * self.std + self.mean
311
+
312
+ def __len__(self):
313
+ return len(self.data_dict) - self.pointer
314
+
315
+ def __getitem__(self, item):
316
+ idx = self.pointer + item
317
+ key = self.name_list[idx]
318
+ data = self.data_dict[key]
319
+ motion, m_length, text_list = data['motion'], data['length'], data['text']
320
+ # Randomly select a caption
321
+ text_data = random.choice(text_list)
322
+ caption, tokens = text_data['caption'], text_data['tokens']
323
+
324
+ if len(tokens) < self.opt.max_text_len:
325
+ # pad with "unk"
326
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
327
+ sent_len = len(tokens)
328
+ tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
329
+ else:
330
+ # crop
331
+ tokens = tokens[:self.opt.max_text_len]
332
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
333
+ sent_len = len(tokens)
334
+ pos_one_hots = []
335
+ word_embeddings = []
336
+ for token in tokens:
337
+ word_emb, pos_oh = self.w_vectorizer[token]
338
+ pos_one_hots.append(pos_oh[None, :])
339
+ word_embeddings.append(word_emb[None, :])
340
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
341
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
342
+
343
+ # Crop the motions into multiples of 4 frames, and introduce small variations
344
+ if self.opt.unit_length < 10:
345
+ coin2 = np.random.choice(['single', 'single', 'double'])
346
+ else:
347
+ coin2 = 'single'
348
+
349
+ if coin2 == 'double':
350
+ m_length = (m_length // self.opt.unit_length - 1) * self.opt.unit_length
351
+ elif coin2 == 'single':
352
+ m_length = (m_length // self.opt.unit_length) * self.opt.unit_length
353
+
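# Worked example of the snapping above (illustrative): with unit_length = 4 and
# m_length = 67, 'single' keeps (67 // 4) * 4 = 64 frames while 'double' keeps
# (67 // 4 - 1) * 4 = 60, before a random temporal offset is drawn below.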
354
+ original_length = None
355
+ if self.opt.fixed_len > 0:
356
+ # Crop fixed_len
357
+ original_length = m_length
358
+ m_length = self.opt.fixed_len
359
+
360
+ idx = random.randint(0, len(motion) - m_length)
361
+ if self.opt.disable_offset_aug:
362
+ idx = random.randint(0, self.opt.unit_length)
363
+ motion = motion[idx:idx+m_length]
364
+
365
+ "Z Normalization"
366
+ motion = (motion - self.mean) / self.std
367
+
368
+ if m_length < self.max_motion_length:
369
+ motion = np.concatenate([motion,
370
+ np.zeros((self.max_motion_length - m_length, motion.shape[1]))
371
+ ], axis=0)
372
+ # print(word_embeddings.shape, motion.shape)
373
+ # print(tokens)
374
+
375
+ length = (original_length, m_length) if self.opt.fixed_len > 0 else m_length
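# With fixed_len active, both the pre-crop length and the fixed crop length are
# reported, so consumers can still recover the original motion duration.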
376
+
377
+ return word_embeddings, pos_one_hots, caption, sent_len, motion, length, '_'.join(tokens)
378
+
379
+
380
+ '''For use of training baseline'''
381
+ class Text2MotionDatasetBaseline(data.Dataset):
382
+ def __init__(self, opt, mean, std, split_file, w_vectorizer):
383
+ self.opt = opt
384
+ self.w_vectorizer = w_vectorizer
385
+ self.max_length = 20
386
+ self.pointer = 0
387
+ self.max_motion_length = opt.max_motion_length
388
+ min_motion_len = 40 if self.opt.dataset_name == 't2m' else 24
389
+
390
+ data_dict = {}
391
+ id_list = []
392
+ with cs.open(split_file, 'r') as f:
393
+ for line in f.readlines():
394
+ id_list.append(line.strip())
395
+ # id_list = id_list[:200]
396
+
397
+ new_name_list = []
398
+ length_list = []
399
+ for name in tqdm(id_list):
400
+ try:
401
+ motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
402
+ if (len(motion)) < min_motion_len or (len(motion) >= 200):
403
+ continue
404
+ text_data = []
405
+ flag = False
406
+ with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
407
+ for line in f.readlines():
408
+ text_dict = {}
409
+ line_split = line.strip().split('#')
410
+ caption = line_split[0]
411
+ tokens = line_split[1].split(' ')
412
+ f_tag = float(line_split[2])
413
+ to_tag = float(line_split[3])
414
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
415
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
416
+
417
+ text_dict['caption'] = caption
418
+ text_dict['tokens'] = tokens
419
+ if f_tag == 0.0 and to_tag == 0.0:
420
+ flag = True
421
+ text_data.append(text_dict)
422
+ else:
423
+ try:
424
+ n_motion = motion[int(f_tag*20) : int(to_tag*20)]
425
+ if (len(n_motion)) < min_motion_len or (len(n_motion) >= 200):
426
+ continue
427
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
428
+ while new_name in data_dict:
429
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
430
+ data_dict[new_name] = {'motion': n_motion,
431
+ 'length': len(n_motion),
432
+ 'text':[text_dict]}
433
+ new_name_list.append(new_name)
434
+ length_list.append(len(n_motion))
435
+ except:
436
+ print(line_split)
437
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
438
+ # break
439
+
440
+ if flag:
441
+ data_dict[name] = {'motion': motion,
442
+ 'length': len(motion),
443
+ 'text': text_data}
444
+ new_name_list.append(name)
445
+ length_list.append(len(motion))
446
+ except:
447
+ pass
448
+
449
+ name_list, length_list = zip(*sorted(zip(new_name_list, length_list), key=lambda x: x[1]))
450
+
451
+ self.mean = mean
452
+ self.std = std
453
+ self.length_arr = np.array(length_list)
454
+ self.data_dict = data_dict
455
+ self.name_list = name_list
456
+ self.reset_max_len(self.max_length)
457
+
458
+ def reset_max_len(self, length):
459
+ assert length <= self.max_motion_length
460
+ self.pointer = np.searchsorted(self.length_arr, length)
461
+ print("Pointer Pointing at %d"%self.pointer)
462
+ self.max_length = length
463
+
464
+ def inv_transform(self, data):
465
+ return data * self.std + self.mean
466
+
467
+ def __len__(self):
468
+ return len(self.data_dict) - self.pointer
469
+
470
+ def __getitem__(self, item):
471
+ idx = self.pointer + item
472
+ data = self.data_dict[self.name_list[idx]]
473
+ motion, m_length, text_list = data['motion'], data['length'], data['text']
474
+ # Randomly select a caption
475
+ text_data = random.choice(text_list)
476
+ caption, tokens = text_data['caption'], text_data['tokens']
477
+
478
+ if len(tokens) < self.opt.max_text_len:
479
+ # pad with "unk"
480
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
481
+ sent_len = len(tokens)
482
+ tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
483
+ else:
484
+ # crop
485
+ tokens = tokens[:self.opt.max_text_len]
486
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
487
+ sent_len = len(tokens)
488
+ pos_one_hots = []
489
+ word_embeddings = []
490
+ for token in tokens:
491
+ word_emb, pos_oh = self.w_vectorizer[token]
492
+ pos_one_hots.append(pos_oh[None, :])
493
+ word_embeddings.append(word_emb[None, :])
494
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
495
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
496
+
497
+ len_gap = (m_length - self.max_length) // self.opt.unit_length
498
+
499
+ if m_length != self.max_length:
500
+ # print("Motion original length:%d_%d"%(m_length, len(motion)))
501
+ if self.opt.unit_length < 10:
502
+ coin2 = np.random.choice(['single', 'single', 'double'])
503
+ else:
504
+ coin2 = 'single'
505
+ if len_gap == 0 or (len_gap == 1 and coin2 == 'double'):
506
+ m_length = self.max_length
507
+ s_idx = random.randint(0, m_length - self.max_length)
508
+ else:
509
+ if coin2 == 'single':
510
+ n_m_length = self.max_length + self.opt.unit_length * len_gap
511
+ else:
512
+ n_m_length = self.max_length + self.opt.unit_length * (len_gap - 1)
513
+ s_idx = random.randint(0, m_length - n_m_length)
514
+ m_length = n_m_length
515
+ else:
516
+ s_idx = 0
517
+
518
+ src_motion = motion[s_idx: s_idx + m_length]
519
+ tgt_motion = motion[s_idx: s_idx + self.max_length]
520
+
521
+ "Z Normalization"
522
+ src_motion = (src_motion - self.mean) / self.std
523
+ tgt_motion = (tgt_motion - self.mean) / self.std
524
+
525
+ if m_length < self.max_motion_length:
526
+ src_motion = np.concatenate([src_motion,
527
+ np.zeros((self.max_motion_length - m_length, motion.shape[1]))
528
+ ], axis=0)
529
+ # print(m_length, src_motion.shape, tgt_motion.shape)
530
+ # print(word_embeddings.shape, motion.shape)
531
+ # print(tokens)
532
+ return word_embeddings, caption, sent_len, src_motion, tgt_motion, m_length
533
+
534
+
535
+ class MotionDatasetV2(data.Dataset):
536
+ def __init__(self, opt, mean, std, split_file):
537
+ self.opt = opt
538
+ joints_num = opt.joints_num
539
+
540
+ self.data = []
541
+ self.lengths = []
542
+ id_list = []
543
+ with cs.open(split_file, 'r') as f:
544
+ for line in f.readlines():
545
+ id_list.append(line.strip())
546
+
547
+ for name in tqdm(id_list):
548
+ try:
549
+ motion = np.load(pjoin(opt.motion_dir, name + '.npy'))
550
+ if motion.shape[0] < opt.window_size:
551
+ continue
552
+ self.lengths.append(motion.shape[0] - opt.window_size)
553
+ self.data.append(motion)
554
+ except:
555
+ # Some motion may not exist in KIT dataset
556
+ pass
557
+
558
+ self.cumsum = np.cumsum([0] + self.lengths)
559
+
560
+ if opt.is_train:
561
+ # root_rot_velocity (B, seq_len, 1)
562
+ std[0:1] = std[0:1] / opt.feat_bias
563
+ # root_linear_velocity (B, seq_len, 2)
564
+ std[1:3] = std[1:3] / opt.feat_bias
565
+ # root_y (B, seq_len, 1)
566
+ std[3:4] = std[3:4] / opt.feat_bias
567
+ # ric_data (B, seq_len, (joint_num - 1)*3)
568
+ std[4: 4 + (joints_num - 1) * 3] = std[4: 4 + (joints_num - 1) * 3] / 1.0
569
+ # rot_data (B, seq_len, (joint_num - 1)*6)
570
+ std[4 + (joints_num - 1) * 3: 4 + (joints_num - 1) * 9] = std[4 + (joints_num - 1) * 3: 4 + (
571
+ joints_num - 1) * 9] / 1.0
572
+ # local_velocity (B, seq_len, joint_num*3)
573
+ std[4 + (joints_num - 1) * 9: 4 + (joints_num - 1) * 9 + joints_num * 3] = std[
574
+ 4 + (joints_num - 1) * 9: 4 + (
575
+ joints_num - 1) * 9 + joints_num * 3] / 1.0
576
+ # foot contact (B, seq_len, 4)
577
+ std[4 + (joints_num - 1) * 9 + joints_num * 3:] = std[
578
+ 4 + (joints_num - 1) * 9 + joints_num * 3:] / opt.feat_bias
579
+
580
+ assert 4 + (joints_num - 1) * 9 + joints_num * 3 + 4 == mean.shape[-1]
581
+ np.save(pjoin(opt.meta_dir, 'mean.npy'), mean)
582
+ np.save(pjoin(opt.meta_dir, 'std.npy'), std)
583
+
584
+ self.mean = mean
585
+ self.std = std
586
+ print("Total number of motions {}, snippets {}".format(len(self.data), self.cumsum[-1]))
587
+
588
+ def inv_transform(self, data):
589
+ return data * self.std + self.mean
590
+
591
+ def __len__(self):
592
+ return self.cumsum[-1]
593
+
594
+ def __getitem__(self, item):
595
+ if item != 0:
596
+ motion_id = np.searchsorted(self.cumsum, item) - 1
597
+ idx = item - self.cumsum[motion_id] - 1
598
+ else:
599
+ motion_id = 0
600
+ idx = 0
601
+ motion = self.data[motion_id][idx:idx+self.opt.window_size]
602
+ "Z Normalization"
603
+ motion = (motion - self.mean) / self.std
604
+
605
+ return motion
606
+
607
+
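A short sketch of the sliding-window indexing that MotionDatasetV2 uses above (illustrative values only, following the cumsum scheme from its __init__ and __getitem__):

import numpy as np

lengths = [5, 3]                    # per-motion window counts (len(motion) - window_size)
cumsum = np.cumsum([0] + lengths)   # -> array([0, 5, 8]); __len__() == 8
item = 6                            # a flat dataset index
motion_id = np.searchsorted(cumsum, item) - 1  # -> 1: the second motion
offset = item - cumsum[motion_id] - 1          # -> 0: window start within it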
608
+ class RawTextDataset(data.Dataset):
609
+ def __init__(self, opt, mean, std, text_file, w_vectorizer):
610
+ self.mean = mean
611
+ self.std = std
612
+ self.opt = opt
613
+ self.data_dict = []
614
+ self.nlp = spacy.load('en_core_web_sm')
615
+
616
+ with cs.open(text_file) as f:
617
+ for line in f.readlines():
618
+ word_list, pos_list = self.process_text(line.strip())
619
+ tokens = ['%s/%s'%(word_list[i], pos_list[i]) for i in range(len(word_list))]
620
+ self.data_dict.append({'caption':line.strip(), "tokens":tokens})
621
+
622
+ self.w_vectorizer = w_vectorizer
623
+ print("Total number of descriptions {}".format(len(self.data_dict)))
624
+
625
+
626
+ def process_text(self, sentence):
627
+ sentence = sentence.replace('-', '')
628
+ doc = self.nlp(sentence)
629
+ word_list = []
630
+ pos_list = []
631
+ for token in doc:
632
+ word = token.text
633
+ if not word.isalpha():
634
+ continue
635
+ if (token.pos_ == 'NOUN' or token.pos_ == 'VERB') and (word != 'left'):
636
+ word_list.append(token.lemma_)
637
+ else:
638
+ word_list.append(word)
639
+ pos_list.append(token.pos_)
640
+ return word_list, pos_list
641
+
642
+ def inv_transform(self, data):
643
+ return data * self.std + self.mean
644
+
645
+ def __len__(self):
646
+ return len(self.data_dict)
647
+
648
+ def __getitem__(self, item):
649
+ data = self.data_dict[item]
650
+ caption, tokens = data['caption'], data['tokens']
651
+
652
+ if len(tokens) < self.opt.max_text_len:
653
+ # pad with "unk"
654
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
655
+ sent_len = len(tokens)
656
+ tokens = tokens + ['unk/OTHER'] * (self.opt.max_text_len + 2 - sent_len)
657
+ else:
658
+ # crop
659
+ tokens = tokens[:self.opt.max_text_len]
660
+ tokens = ['sos/OTHER'] + tokens + ['eos/OTHER']
661
+ sent_len = len(tokens)
662
+ pos_one_hots = []
663
+ word_embeddings = []
664
+ for token in tokens:
665
+ word_emb, pos_oh = self.w_vectorizer[token]
666
+ pos_one_hots.append(pos_oh[None, :])
667
+ word_embeddings.append(word_emb[None, :])
668
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
669
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
670
+
671
+ return word_embeddings, pos_one_hots, caption, sent_len
672
+
673
+ class TextOnlyDataset(data.Dataset):
674
+ def __init__(self, opt, mean, std, split_file):
675
+ self.mean = mean
676
+ self.std = std
677
+ self.opt = opt
678
+ self.data_dict = []
679
+ self.max_length = 20
680
+ self.pointer = 0
681
+ self.fixed_length = 120
682
+
683
+
684
+ data_dict = {}
685
+ id_list = []
686
+ with cs.open(split_file, 'r') as f:
687
+ for line in f.readlines():
688
+ id_list.append(line.strip())
689
+ # id_list = id_list[:200]
690
+
691
+ new_name_list = []
692
+ length_list = []
693
+ for name in tqdm(id_list):
694
+ try:
695
+ text_data = []
696
+ flag = False
697
+ with cs.open(pjoin(opt.text_dir, name + '.txt')) as f:
698
+ for line in f.readlines():
699
+ text_dict = {}
700
+ line_split = line.strip().split('#')
701
+ caption = line_split[0]
702
+ tokens = line_split[1].split(' ')
703
+ f_tag = float(line_split[2])
704
+ to_tag = float(line_split[3])
705
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
706
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
707
+
708
+ text_dict['caption'] = caption
709
+ text_dict['tokens'] = tokens
710
+ if f_tag == 0.0 and to_tag == 0.0:
711
+ flag = True
712
+ text_data.append(text_dict)
713
+ else:
714
+ try:
715
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
716
+ while new_name in data_dict:
717
+ new_name = random.choice('ABCDEFGHIJKLMNOPQRSTUVW') + '_' + name
718
+ data_dict[new_name] = {'text':[text_dict]}
719
+ new_name_list.append(new_name)
720
+ except:
721
+ print(line_split)
722
+ print(line_split[2], line_split[3], f_tag, to_tag, name)
723
+ # break
724
+
725
+ if flag:
726
+ data_dict[name] = {'text': text_data}
727
+ new_name_list.append(name)
728
+ except:
729
+ pass
730
+
731
+ self.length_arr = np.array(length_list)
732
+ self.data_dict = data_dict
733
+ self.name_list = new_name_list
734
+
735
+ def inv_transform(self, data):
736
+ return data * self.std + self.mean
737
+
738
+ def __len__(self):
739
+ return len(self.data_dict)
740
+
741
+ def __getitem__(self, item):
742
+ idx = self.pointer + item
743
+ data = self.data_dict[self.name_list[idx]]
744
+ text_list = data['text']
745
+
746
+ # Randomly select a caption
747
+ text_data = random.choice(text_list)
748
+ caption, tokens = text_data['caption'], text_data['tokens']
749
+ return None, None, caption, None, np.array([0]), self.fixed_length, None
750
+ # fixed_length can be set from outside before sampling
751
+
752
+ # A wrapper class around the original t2m dataset, for MDM purposes
753
+ class HumanML3D(data.Dataset):
754
+ def __init__(self, mode, datapath='./dataset/humanml_opt.txt', split="train", **kwargs):
755
+ self.mode = mode
756
+
757
+ self.dataset_name = 't2m'
758
+ self.dataname = 't2m'
759
+
760
+ # Configurations of the T2M and KIT datasets are almost the same
761
+ abs_base_path = kwargs.get('abs_path', '.')
762
+ dataset_opt_path = pjoin(abs_base_path, datapath)
763
+ device = kwargs.get('device', None)
764
+ opt = get_opt(dataset_opt_path, device)
765
+ # opt.meta_dir = pjoin(abs_base_path, opt.meta_dir)
766
+ opt.cache_dir = kwargs.get('cache_path', '.')
767
+ opt.motion_dir = pjoin(abs_base_path, opt.motion_dir)
768
+ opt.text_dir = pjoin(abs_base_path, opt.text_dir)
769
+ opt.model_dir = pjoin(abs_base_path, opt.model_dir)
770
+ opt.checkpoints_dir = pjoin(abs_base_path, opt.checkpoints_dir)
771
+ opt.data_root = pjoin(abs_base_path, opt.data_root)
772
+ opt.save_root = pjoin(abs_base_path, opt.save_root)
773
+ opt.meta_dir = pjoin(abs_base_path, './dataset')
774
+ opt.use_cache = kwargs.get('use_cache', True)
775
+ opt.fixed_len = kwargs.get('fixed_len', 0)
776
+ if opt.fixed_len > 0:
777
+ opt.max_motion_length = opt.fixed_len
778
+ is_autoregressive = kwargs.get('autoregressive', False)
779
+ opt.disable_offset_aug = is_autoregressive and (opt.fixed_len > 0) and (mode == 'eval') # for autoregressive evaluation, use the start of the motion and not something from the middle
780
+ self.opt = opt
781
+ print('Loading dataset %s ...' % opt.dataset_name)
782
+
783
+ if mode == 'gt':
784
+ # used by T2M models (including evaluators)
785
+ self.mean = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_mean.npy'))
786
+ self.std = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_std.npy'))
787
+ elif mode in ['train', 'eval', 'text_only']:
788
+ # used by our models
789
+ self.mean = np.load(pjoin(opt.data_root, 'Mean.npy'))
790
+ self.std = np.load(pjoin(opt.data_root, 'Std.npy'))
791
+
792
+ if mode == 'eval':
793
+ # used by T2M models (including evaluators)
794
+ # this is to translate their norms to ours
795
+ self.mean_for_eval = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_mean.npy'))
796
+ self.std_for_eval = np.load(pjoin(opt.meta_dir, f'{opt.dataset_name}_std.npy'))
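# Motions are stored under our normalization; mean_for_eval / std_for_eval are
# only used to re-normalize samples into the T2M evaluators' convention.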
797
+
798
+ self.split_file = pjoin(opt.data_root, f'{split}.txt')
799
+ if mode == 'text_only':
800
+ self.t2m_dataset = TextOnlyDataset(self.opt, self.mean, self.std, self.split_file)
801
+ else:
802
+ self.w_vectorizer = WordVectorizer(pjoin(opt.cache_dir, 'glove'), 'our_vab')
803
+ self.t2m_dataset = Text2MotionDatasetV2(self.opt, self.mean, self.std, self.split_file, self.w_vectorizer)
804
+ self.num_actions = 1 # dummy placeholder
805
+
806
+ self.mean_gpu = torch.tensor(self.mean).to(device)[None, :, None, None]
807
+ self.std_gpu = torch.tensor(self.std).to(device)[None, :, None, None]
808
+
809
+ assert len(self.t2m_dataset) > 1, 'You loaded an empty dataset, ' \
810
+ 'it is probably because your data dir has only texts and no motions.\n' \
811
+ 'To train and evaluate MDM you should get the FULL data as described ' \
812
+ 'in the README file.'
813
+
814
+ def __getitem__(self, item):
815
+ return self.t2m_dataset.__getitem__(item)
816
+
817
+ def __len__(self):
818
+ return self.t2m_dataset.__len__()
819
+
820
+ # A wrapper class around the original t2m dataset, for MDM purposes
821
+ class KIT(HumanML3D):
822
+ def __init__(self, mode, datapath='./dataset/kit_opt.txt', split="train", **kwargs):
823
+ super(KIT, self).__init__(mode, datapath, split, **kwargs)
motion_diffusion_model/data_loaders/humanml/motion_loaders/__init__.py ADDED
File without changes
motion_diffusion_model/data_loaders/humanml/motion_loaders/comp_v6_model_dataset.py ADDED
@@ -0,0 +1,285 @@
1
+ import torch
2
+ from data_loaders.humanml.networks.modules import *
3
+ from data_loaders.humanml.networks.trainers import CompTrainerV6
4
+ from torch.utils.data import Dataset, DataLoader
5
+ from os.path import join as pjoin
6
+ from tqdm import tqdm
7
+ from utils import dist_util
8
+ from utils.sampler_util import AutoRegressiveSampler
9
+
10
+
11
+ def build_models(opt):
12
+ if opt.text_enc_mod == 'bigru':
13
+ text_encoder = TextEncoderBiGRU(word_size=opt.dim_word,
14
+ pos_size=opt.dim_pos_ohot,
15
+ hidden_size=opt.dim_text_hidden,
16
+ device=opt.device)
17
+ text_size = opt.dim_text_hidden * 2
18
+ else:
19
+ raise Exception("Text Encoder Mode not Recognized!!!")
20
+
21
+ seq_prior = TextDecoder(text_size=text_size,
22
+ input_size=opt.dim_att_vec + opt.dim_movement_latent,
23
+ output_size=opt.dim_z,
24
+ hidden_size=opt.dim_pri_hidden,
25
+ n_layers=opt.n_layers_pri)
26
+
27
+
28
+ seq_decoder = TextVAEDecoder(text_size=text_size,
29
+ input_size=opt.dim_att_vec + opt.dim_z + opt.dim_movement_latent,
30
+ output_size=opt.dim_movement_latent,
31
+ hidden_size=opt.dim_dec_hidden,
32
+ n_layers=opt.n_layers_dec)
33
+
34
+ att_layer = AttLayer(query_dim=opt.dim_pos_hidden,
35
+ key_dim=text_size,
36
+ value_dim=opt.dim_att_vec)
37
+
38
+ movement_enc = MovementConvEncoder(opt.dim_pose - 4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
39
+ movement_dec = MovementConvDecoder(opt.dim_movement_latent, opt.dim_movement_dec_hidden, opt.dim_pose)
40
+
41
+ len_estimator = MotionLenEstimatorBiGRU(opt.dim_word, opt.dim_pos_ohot, 512, opt.num_classes)
42
+
43
+ # latent_dis = LatentDis(input_size=opt.dim_z * 2)
44
+ checkpoints = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'length_est_bigru', 'model', 'latest.tar'), map_location=opt.device)
45
+ len_estimator.load_state_dict(checkpoints['estimator'])
46
+ len_estimator.to(opt.device)
47
+ len_estimator.eval()
48
+
49
+ # return text_encoder, text_decoder, att_layer, vae_pri, vae_dec, vae_pos, motion_dis, movement_dis, latent_dis
50
+ return text_encoder, seq_prior, seq_decoder, att_layer, movement_enc, movement_dec, len_estimator
51
+
52
+ class CompV6GeneratedDataset(Dataset):
53
+
54
+ def __init__(self, opt, dataset, w_vectorizer, mm_num_samples, mm_num_repeats):
55
+ assert mm_num_samples < len(dataset)
56
+ print(opt.model_dir)
57
+
58
+ dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
59
+ text_enc, seq_pri, seq_dec, att_layer, mov_enc, mov_dec, len_estimator = build_models(opt)
60
+ trainer = CompTrainerV6(opt, text_enc, seq_pri, seq_dec, att_layer, mov_dec, mov_enc=mov_enc)
61
+ epoch, it, sub_ep, schedule_len = trainer.load(pjoin(opt.model_dir, opt.which_epoch + '.tar'))
62
+ generated_motion = []
63
+ mm_generated_motions = []
64
+ mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False)
65
+ mm_idxs = np.sort(mm_idxs)
66
+ min_mov_length = 10 if opt.dataset_name == 't2m' else 6
67
+ # print(mm_idxs)
68
+
69
+ print('Loading model: Epoch %03d Schedule_len %03d' % (epoch, schedule_len))
70
+ trainer.eval_mode()
71
+ trainer.to(opt.device)
72
+ with torch.no_grad():
73
+ for i, data in tqdm(enumerate(dataloader)):
74
+ word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
75
+ tokens = tokens[0].split('_')
76
+ word_emb = word_emb.detach().to(opt.device).float()
77
+ pos_ohot = pos_ohot.detach().to(opt.device).float()
78
+
79
+ pred_dis = len_estimator(word_emb, pos_ohot, cap_lens)
80
+ pred_dis = nn.Softmax(-1)(pred_dis).squeeze()
81
+
82
+ mm_num_now = len(mm_generated_motions)
83
+ is_mm = True if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])) else False
84
+
85
+ repeat_times = mm_num_repeats if is_mm else 1
86
+ mm_motions = []
87
+ for t in range(repeat_times):
88
+ mov_length = torch.multinomial(pred_dis, 1, replacement=True)
89
+ if mov_length < min_mov_length:
90
+ mov_length = torch.multinomial(pred_dis, 1, replacement=True)
91
+ if mov_length < min_mov_length:
92
+ mov_length = torch.multinomial(pred_dis, 1, replacement=True)
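# i.e. resample the predicted movement length up to two more times when the
# draw falls below min_mov_length; after three tries the short draw is kept.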
93
+
94
+ m_lens = mov_length * opt.unit_length
95
+ pred_motions, _, _ = trainer.generate(word_emb, pos_ohot, cap_lens, m_lens,
96
+ m_lens[0]//opt.unit_length, opt.dim_pose)
97
+ if t == 0:
98
+ # print(m_lens)
99
+ # print(text_data)
100
+ sub_dict = {'motion': pred_motions[0].cpu().numpy(),
101
+ 'length': m_lens[0].item(),
102
+ 'cap_len': cap_lens[0].item(),
103
+ 'caption': caption[0],
104
+ 'tokens': tokens}
105
+ generated_motion.append(sub_dict)
106
+
107
+ if is_mm:
108
+ mm_motions.append({
109
+ 'motion': pred_motions[0].cpu().numpy(),
110
+ 'length': m_lens[0].item()
111
+ })
112
+ if is_mm:
113
+ mm_generated_motions.append({'caption': caption[0],
114
+ 'tokens': tokens,
115
+ 'cap_len': cap_lens[0].item(),
116
+ 'mm_motions': mm_motions})
117
+
118
+ self.generated_motion = generated_motion
119
+ self.mm_generated_motion = mm_generated_motions
120
+ self.opt = opt
121
+ self.w_vectorizer = w_vectorizer
122
+
123
+
124
+ def __len__(self):
125
+ return len(self.generated_motion)
126
+
127
+
128
+ def __getitem__(self, item):
129
+ data = self.generated_motion[item]
130
+ motion, m_length, caption, tokens = data['motion'], data['length'], data['caption'], data['tokens']
131
+ sent_len = data['cap_len']
132
+
133
+ pos_one_hots = []
134
+ word_embeddings = []
135
+ for token in tokens:
136
+ word_emb, pos_oh = self.w_vectorizer[token]
137
+ pos_one_hots.append(pos_oh[None, :])
138
+ word_embeddings.append(word_emb[None, :])
139
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
140
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
141
+
142
+ if m_length < self.opt.max_motion_length:
143
+ motion = np.concatenate([motion,
144
+ np.zeros((self.opt.max_motion_length - m_length, motion.shape[1]))
145
+ ], axis=0)
146
+ return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
147
+
148
+ class CompMDMGeneratedDataset(Dataset):
149
+
150
+ def __init__(self, args, model, diffusion, dataloader, mm_num_samples, mm_num_repeats, max_motion_length, num_samples_limit, scale=1.):
151
+ self.args = args
152
+ self.dataloader = dataloader
153
+ self.dataset = dataloader.dataset
154
+ self.model = model
155
+ assert mm_num_samples < len(dataloader.dataset)
156
+ use_ddim = False # FIXME - hardcoded
157
+ clip_denoised = False # FIXME - hardcoded
158
+ self.max_motion_length = max_motion_length
159
+ sample_fn = (
160
+ diffusion.p_sample_loop if not use_ddim else diffusion.ddim_sample_loop
161
+ )
162
+ if self.args.autoregressive:
163
+ sample_cls = AutoRegressiveSampler(args, sample_fn)
164
+ sample_fn = sample_cls.sample
165
+
166
+
167
+ real_num_batches = len(dataloader)
168
+ if num_samples_limit is not None:
169
+ real_num_batches = min(num_samples_limit // dataloader.batch_size + 1, real_num_batches)
170
+ print('real_num_batches', real_num_batches)
171
+
172
+ generated_motion = []
173
+ mm_generated_motions = []
174
+ if mm_num_samples > 0:
175
+ mm_idxs = np.random.choice(real_num_batches, mm_num_samples // dataloader.batch_size +1, replace=False)
176
+ mm_idxs = np.sort(mm_idxs)
177
+ else:
178
+ mm_idxs = []
179
+ print('mm_idxs', mm_idxs)
180
+
181
+ model.eval()
182
+
183
+
184
+ with torch.no_grad():
185
+ for i, (motion, model_kwargs) in tqdm(enumerate(dataloader)):
186
+
187
+ if num_samples_limit is not None and len(generated_motion) >= num_samples_limit:
188
+ break
189
+
190
+ model_kwargs['y'] = {key: val.to(dist_util.dev()) if torch.is_tensor(val) else val for key, val in model_kwargs['y'].items()}
191
+ motion = motion.to(dist_util.dev())
192
+
193
+ tokens = [t.split('_') for t in model_kwargs['y']['tokens']]
194
+
195
+ # add CFG scale to batch
196
+ if scale != 1.:
197
+ model_kwargs['y']['scale'] = torch.ones(motion.shape[0],
198
+ device=dist_util.dev()) * scale
199
+
200
+ mm_num_now = len(mm_generated_motions) // dataloader.batch_size
201
+ is_mm = i in mm_idxs
202
+ repeat_times = mm_num_repeats if is_mm else 1
203
+ mm_motions = []
204
+ for t in range(repeat_times):
205
+
206
+ sample = sample_fn(
207
+ model,
208
+ motion.shape,
209
+ clip_denoised=clip_denoised,
210
+ model_kwargs=model_kwargs,
211
+ skip_timesteps=0, # 0 is the default value - i.e. don't skip any step
212
+ init_image=None,
213
+ progress=False,
214
+ dump_steps=None,
215
+ noise=None,
216
+ const_noise=False,
217
+ # when experimenting with guidance_scale we want to neutralize the effect of noise on generation
218
+ )
219
+
220
+ if 'prefix' in model_kwargs['y'].keys():
221
+ model_kwargs['y']['lengths'] = model_kwargs['y']['orig_lengths']
222
+
223
+ if t == 0:
224
+ sub_dicts = [{
225
+ 'motion': sample[bs_i].squeeze().permute(1, 0).cpu().numpy(),
226
+ 'length': model_kwargs['y']['lengths'][bs_i].cpu().numpy(),
227
+ 'caption': model_kwargs['y']['text'][bs_i],
228
+ 'tokens': tokens[bs_i],
229
+ # Fixed cap_len calculation, changed from len(tokens[bs_i])
230
+ # Lead to improved R-precision and Multimodal Dist.
231
+ # issue: https://github.com/GuyTevet/motion-diffusion-model/issues/182
232
+ 'cap_len': tokens[bs_i].index('eos/OTHER') + 1,
233
+ } for bs_i in range(dataloader.batch_size)]
234
+ generated_motion += sub_dicts
235
+
236
+ if is_mm:
237
+ for bs_i in range(dataloader.batch_size):
238
+ mm_motion = sample[bs_i].squeeze().permute(1, 0).cpu().numpy()
239
+ if self.dataset.mode == 'eval':
240
+ mm_motion = self.dataset.t2m_dataset.inv_transform(mm_motion)
241
+ mm_motion = (mm_motion - self.dataset.mean_for_eval) / self.dataset.std_for_eval # according to T2M norms
242
+
243
+ mm_motions.append({'motion': mm_motion,
244
+ 'length': model_kwargs['y']['lengths'][bs_i].cpu().numpy(),
245
+ })
246
+ if is_mm:
247
+ mm_generated_motions += [{
248
+ 'caption': model_kwargs['y']['text'][bs_i],
249
+ 'tokens': tokens[bs_i],
250
+ 'cap_len': len(tokens[bs_i]),
251
+ 'mm_motions': mm_motions[bs_i::dataloader.batch_size], # collect all 10 repeats from the (32*10) generated motions
252
+ } for bs_i in range(dataloader.batch_size)]
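# The strided slice above works because each repeat appends one full batch, so
# the repeats of sample bs_i sit at positions bs_i, bs_i + batch_size,
# bs_i + 2 * batch_size, ... of the mm_motions list.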
253
+
254
+
255
+ self.generated_motion = generated_motion
256
+ self.mm_generated_motion = mm_generated_motions
257
+ self.w_vectorizer = dataloader.dataset.w_vectorizer
258
+
259
+
260
+ def __len__(self):
261
+ return len(self.generated_motion)
262
+
263
+
264
+ def __getitem__(self, item):
265
+ data = self.generated_motion[item]
266
+ motion, m_length, caption, tokens = data['motion'], data['length'], data['caption'], data['tokens']
267
+ sent_len = data['cap_len']
268
+
269
+ if self.dataset.mode == 'eval':
270
+ normed_motion = motion
271
+ denormed_motion = self.dataset.t2m_dataset.inv_transform(normed_motion)
272
+ renormed_motion = (denormed_motion - self.dataset.mean_for_eval) / self.dataset.std_for_eval # according to T2M norms
273
+ motion = renormed_motion
274
+ # This step is needed because T2M evaluators expect their norm convention
275
+
276
+ pos_one_hots = []
277
+ word_embeddings = []
278
+ for token in tokens:
279
+ word_emb, pos_oh = self.w_vectorizer[token]
280
+ pos_one_hots.append(pos_oh[None, :])
281
+ word_embeddings.append(word_emb[None, :])
282
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
283
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
284
+
285
+ return word_embeddings, pos_one_hots, caption, sent_len, motion, m_length, '_'.join(tokens)
motion_diffusion_model/data_loaders/humanml/motion_loaders/dataset_motion_loader.py ADDED
@@ -0,0 +1,27 @@
1
+ from t2m.data.dataset import Text2MotionDatasetV2, collate_fn
2
+ from t2m.utils.word_vectorizer import WordVectorizer
3
+ import numpy as np
4
+ from os.path import join as pjoin
5
+ from torch.utils.data import DataLoader
6
+ from t2m.utils.get_opt import get_opt
7
+
8
+ def get_dataset_motion_loader(opt_path, batch_size, device):
9
+ opt = get_opt(opt_path, device)
10
+
11
+ # Configurations of the T2M and KIT datasets are almost the same
12
+ if opt.dataset_name == 't2m' or opt.dataset_name == 'kit':
13
+ print('Loading dataset %s ...' % opt.dataset_name)
14
+
15
+ mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
16
+ std = np.load(pjoin(opt.meta_dir, 'std.npy'))
17
+
18
+ w_vectorizer = WordVectorizer('./glove', 'our_vab')
19
+ split_file = pjoin(opt.data_root, 'test.txt')
20
+ dataset = Text2MotionDatasetV2(opt, mean, std, split_file, w_vectorizer)
21
+ dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=4, drop_last=True,
22
+ collate_fn=collate_fn, shuffle=True)
23
+ else:
24
+ raise KeyError('Dataset not Recognized !!')
25
+
26
+ print('Ground Truth Dataset Loading Completed!!!')
27
+ return dataloader, dataset
motion_diffusion_model/data_loaders/humanml/motion_loaders/model_motion_loaders.py ADDED
@@ -0,0 +1,91 @@
1
+ from torch.utils.data import DataLoader, Dataset
2
+ from data_loaders.humanml.utils.get_opt import get_opt
3
+ from data_loaders.humanml.motion_loaders.comp_v6_model_dataset import CompMDMGeneratedDataset
4
+ from data_loaders.humanml.utils.word_vectorizer import WordVectorizer
5
+ import numpy as np
6
+ from torch.utils.data._utils.collate import default_collate
7
+
8
+
9
+ def collate_fn(batch):
10
+ batch.sort(key=lambda x: x[3], reverse=True)
11
+ return default_collate(batch)
12
+
13
+
14
+ class MMGeneratedDataset(Dataset):
15
+ def __init__(self, opt, motion_dataset, w_vectorizer):
16
+ self.opt = opt
17
+ self.dataset = motion_dataset.mm_generated_motion
18
+ self.w_vectorizer = w_vectorizer
19
+
20
+ def __len__(self):
21
+ return len(self.dataset)
22
+
23
+ def __getitem__(self, item):
24
+ data = self.dataset[item]
25
+ mm_motions = data['mm_motions']
26
+ m_lens = []
27
+ motions = []
28
+ for mm_motion in mm_motions:
29
+ m_lens.append(mm_motion['length'])
30
+ motion = mm_motion['motion']
31
+ # We don't need the following logic because our sample func generates the full tensor anyway:
32
+ # if len(motion) < self.opt.max_motion_length:
33
+ # motion = np.concatenate([motion,
34
+ # np.zeros((self.opt.max_motion_length - len(motion), motion.shape[1]))
35
+ # ], axis=0)
36
+ motion = motion[None, :]
37
+ motions.append(motion)
38
+ m_lens = np.array(m_lens, dtype=int)  # np.int was removed from NumPy; the builtin int is equivalent
39
+ motions = np.concatenate(motions, axis=0)
40
+ sort_indx = np.argsort(m_lens)[::-1].copy()
41
+ # print(m_lens)
42
+ # print(sort_indx)
43
+ # print(m_lens[sort_indx])
44
+ m_lens = m_lens[sort_indx]
45
+ motions = motions[sort_indx]
46
+ return motions, m_lens
47
+
48
+
49
+
50
+ def get_motion_loader(opt_path, batch_size, ground_truth_dataset, mm_num_samples, mm_num_repeats, device):
51
+ opt = get_opt(opt_path, device)
52
+
53
+ # Currently the configurations of two datasets are almost the same
54
+ if opt.dataset_name == 't2m' or opt.dataset_name == 'kit':
55
+ w_vectorizer = WordVectorizer('./glove', 'our_vab')
56
+ else:
57
+ raise KeyError('Dataset not recognized!!')
58
+ print('Generating %s ...' % opt.name)
59
+
60
+ if 'v6' in opt.name:
61
+ dataset = CompV6GeneratedDataset(opt, ground_truth_dataset, w_vectorizer, mm_num_samples, mm_num_repeats)
62
+ else:
63
+ raise KeyError('Dataset not recognized!!')
64
+
65
+ mm_dataset = MMGeneratedDataset(opt, dataset, w_vectorizer)
66
+
67
+ motion_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, drop_last=True, num_workers=4)
68
+ mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1)
69
+
70
+ print('Generated Dataset Loading Completed!!!')
71
+
72
+ return motion_loader, mm_motion_loader
73
+
74
+ # our loader
75
+ def get_mdm_loader(args, model, diffusion, batch_size, ground_truth_loader, mm_num_samples, mm_num_repeats, max_motion_length, num_samples_limit, scale):
76
+ opt = {
77
+ 'name': 'test', # FIXME
78
+ }
79
+ print('Generating %s ...' % opt['name'])
80
+ # dataset = CompMDMGeneratedDataset(opt, ground_truth_dataset, ground_truth_dataset.w_vectorizer, mm_num_samples, mm_num_repeats)
81
+ dataset = CompMDMGeneratedDataset(args, model, diffusion, ground_truth_loader, mm_num_samples, mm_num_repeats, max_motion_length, num_samples_limit, scale)
82
+
83
+ mm_dataset = MMGeneratedDataset(opt, dataset, ground_truth_loader.dataset.w_vectorizer)
84
+
85
+ # NOTE: bs must not be changed! this will cause a bug in R precision calc!
86
+ motion_loader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, drop_last=True, num_workers=4)
87
+ mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1)
88
+
89
+ print('Generated Dataset Loading Completed!!!')
90
+
91
+ return motion_loader, mm_motion_loader
motion_diffusion_model/data_loaders/humanml/networks/__init__.py ADDED
File without changes
motion_diffusion_model/data_loaders/humanml/networks/evaluator_wrapper.py ADDED
@@ -0,0 +1,187 @@
1
+ from data_loaders.humanml.networks.modules import *
2
+ from data_loaders.humanml.utils.word_vectorizer import POS_enumerator
3
+ from os.path import join as pjoin
4
+
5
+ def build_models(opt):
6
+ movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
7
+ text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
8
+ pos_size=opt.dim_pos_ohot,
9
+ hidden_size=opt.dim_text_hidden,
10
+ output_size=opt.dim_coemb_hidden,
11
+ device=opt.device)
12
+
13
+ motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
14
+ hidden_size=opt.dim_motion_hidden,
15
+ output_size=opt.dim_coemb_hidden,
16
+ device=opt.device)
17
+
18
+ checkpoint = torch.load(pjoin(opt.checkpoints_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'),
19
+ map_location=opt.device)
20
+ movement_enc.load_state_dict(checkpoint['movement_encoder'])
21
+ text_enc.load_state_dict(checkpoint['text_encoder'])
22
+ motion_enc.load_state_dict(checkpoint['motion_encoder'])
23
+ print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
24
+ return text_enc, motion_enc, movement_enc
25
+
26
+
27
+ class EvaluatorModelWrapper(object):
28
+
29
+ def __init__(self, opt):
30
+
31
+ if opt.dataset_name == 't2m':
32
+ opt.dim_pose = 263
33
+ elif opt.dataset_name == 'kit':
34
+ opt.dim_pose = 251
35
+ else:
36
+ raise KeyError('Dataset not Recognized!!!')
37
+
38
+ opt.dim_word = 300
39
+ opt.max_motion_length = 196
40
+ opt.dim_pos_ohot = len(POS_enumerator)
41
+ opt.dim_motion_hidden = 1024
42
+ opt.max_text_len = 20
43
+ opt.dim_text_hidden = 512
44
+ opt.dim_coemb_hidden = 512
45
+
46
+ self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
47
+ self.opt = opt
48
+ self.device = opt.device
49
+
50
+ self.text_encoder.to(opt.device)
51
+ self.motion_encoder.to(opt.device)
52
+ self.movement_encoder.to(opt.device)
53
+
54
+ self.text_encoder.eval()
55
+ self.motion_encoder.eval()
56
+ self.movement_encoder.eval()
57
+
58
+ # Please note that the results do not follow the order of the inputs
59
+ def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
60
+ with torch.no_grad():
61
+ word_embs = word_embs.detach().to(self.device).float()
62
+ pos_ohot = pos_ohot.detach().to(self.device).float()
63
+ motions = motions.detach().to(self.device).float()
64
+
65
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
66
+ motions = motions[align_idx]
67
+ m_lens = m_lens[align_idx]
68
+
69
+ '''Movement Encoding'''
70
+ movements = self.movement_encoder(motions[..., :-4]).detach()
71
+ m_lens = m_lens // self.opt.unit_length
72
+ motion_embedding = self.motion_encoder(movements, m_lens)
73
+
74
+ '''Text Encoding'''
75
+ text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
76
+ text_embedding = text_embedding[align_idx]
77
+ return text_embedding, motion_embedding
78
+
79
+ # Please note that the results do not follow the order of the inputs
80
+ def get_motion_embeddings(self, motions, m_lens):
81
+ with torch.no_grad():
82
+ motions = motions.detach().to(self.device).float()
83
+
84
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
85
+ motions = motions[align_idx]
86
+ m_lens = m_lens[align_idx]
87
+
88
+ '''Movement Encoding'''
89
+ movements = self.movement_encoder(motions[..., :-4]).detach()
90
+ m_lens = m_lens // self.opt.unit_length
91
+ motion_embedding = self.motion_encoder(movements, m_lens)
92
+ return motion_embedding
93
+
94
+ # our version
95
+ def build_evaluators(opt):
96
+ movement_enc = MovementConvEncoder(opt['dim_pose']-4, opt['dim_movement_enc_hidden'], opt['dim_movement_latent'])
97
+ text_enc = TextEncoderBiGRUCo(word_size=opt['dim_word'],
98
+ pos_size=opt['dim_pos_ohot'],
99
+ hidden_size=opt['dim_text_hidden'],
100
+ output_size=opt['dim_coemb_hidden'],
101
+ device=opt['device'])
102
+
103
+ motion_enc = MotionEncoderBiGRUCo(input_size=opt['dim_movement_latent'],
104
+ hidden_size=opt['dim_motion_hidden'],
105
+ output_size=opt['dim_coemb_hidden'],
106
+ device=opt['device'])
107
+
108
+ ckpt_dir = opt['dataset_name']
109
+ if opt['dataset_name'] == 'humanml':
110
+ ckpt_dir = 't2m'
111
+
112
+ checkpoint = torch.load(pjoin(opt['checkpoints_dir'], ckpt_dir, 'text_mot_match', 'model', 'finest.tar'),
113
+ map_location=opt['device'])
114
+ movement_enc.load_state_dict(checkpoint['movement_encoder'])
115
+ text_enc.load_state_dict(checkpoint['text_encoder'])
116
+ motion_enc.load_state_dict(checkpoint['motion_encoder'])
117
+ print('Loading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
118
+ return text_enc, motion_enc, movement_enc
119
+
120
+ # our wrapper
121
+ class EvaluatorMDMWrapper(object):
122
+
123
+ def __init__(self, dataset_name, device):
124
+ opt = {
125
+ 'dataset_name': dataset_name,
126
+ 'device': device,
127
+ 'dim_word': 300,
128
+ 'max_motion_length': 196,
129
+ 'dim_pos_ohot': len(POS_enumerator),
130
+ 'dim_motion_hidden': 1024,
131
+ 'max_text_len': 20,
132
+ 'dim_text_hidden': 512,
133
+ 'dim_coemb_hidden': 512,
134
+ 'dim_pose': 263 if dataset_name == 'humanml' else 251,
135
+ 'dim_movement_enc_hidden': 512,
136
+ 'dim_movement_latent': 512,
137
+ 'checkpoints_dir': '.',
138
+ 'unit_length': 4,
139
+ }
140
+
141
+ self.text_encoder, self.motion_encoder, self.movement_encoder = build_evaluators(opt)
142
+ self.opt = opt
143
+ self.device = opt['device']
144
+
145
+ self.text_encoder.to(opt['device'])
146
+ self.motion_encoder.to(opt['device'])
147
+ self.movement_encoder.to(opt['device'])
148
+
149
+ self.text_encoder.eval()
150
+ self.motion_encoder.eval()
151
+ self.movement_encoder.eval()
152
+
153
+ # Please note that the results do not follow the order of the inputs
154
+ def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
155
+ with torch.no_grad():
156
+ word_embs = word_embs.detach().to(self.device).float()
157
+ pos_ohot = pos_ohot.detach().to(self.device).float()
158
+ motions = motions.detach().to(self.device).float()
159
+
160
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
161
+ motions = motions[align_idx]
162
+ m_lens = m_lens[align_idx]
163
+
164
+ '''Movement Encoding'''
165
+ movements = self.movement_encoder(motions[..., :-4]).detach()
166
+ m_lens = m_lens // self.opt['unit_length']
167
+ motion_embedding = self.motion_encoder(movements, m_lens)
168
+
169
+ '''Text Encoding'''
170
+ text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
171
+ text_embedding = text_embedding[align_idx]
172
+ return text_embedding, motion_embedding
173
+
174
+ # Please note that the results do not follow the order of the inputs
175
+ def get_motion_embeddings(self, motions, m_lens):
176
+ with torch.no_grad():
177
+ motions = motions.detach().to(self.device).float()
178
+
179
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
180
+ motions = motions[align_idx]
181
+ m_lens = m_lens[align_idx]
182
+
183
+ '''Movement Encoding'''
184
+ movements = self.movement_encoder(motions[..., :-4]).detach()
185
+ m_lens = m_lens // self.opt['unit_length']
186
+ motion_embedding = self.motion_encoder(movements, m_lens)
187
+ return motion_embedding
motion_diffusion_model/data_loaders/humanml/networks/modules.py ADDED
@@ -0,0 +1,438 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import time
5
+ import math
6
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
7
+ # from networks.layers import *
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class ContrastiveLoss(torch.nn.Module):
12
+ """
13
+ Contrastive loss function.
14
+ Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
15
+ """
16
+ def __init__(self, margin=3.0):
17
+ super(ContrastiveLoss, self).__init__()
18
+ self.margin = margin
19
+
20
+ def forward(self, output1, output2, label):
21
+ euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
22
+ loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
23
+ (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
24
+ return loss_contrastive
25
+
26
+
27
+ def init_weight(m):
28
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
29
+ nn.init.xavier_normal_(m.weight)
30
+ # m.bias.data.fill_(0.01)
31
+ if m.bias is not None:
32
+ nn.init.constant_(m.bias, 0)
33
+
34
+
35
+ def reparameterize(mu, logvar):
36
+ s_var = logvar.mul(0.5).exp_()
37
+ eps = s_var.data.new(s_var.size()).normal_()
38
+ return eps.mul(s_var).add_(mu)
39
+
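# reparameterize above implements the standard VAE reparameterization trick:
# z = mu + sigma * eps with sigma = exp(0.5 * logvar) and eps ~ N(0, I),
# which keeps sampling differentiable with respect to mu and logvar.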
40
+
41
+ # batch_size, dimension and position
42
+ # output: (batch_size, dim)
43
+ def positional_encoding(batch_size, dim, pos):
44
+ assert batch_size == pos.shape[0]
45
+ positions_enc = np.array([
46
+ [pos[j] / np.power(10000, (i-i%2)/dim) for i in range(dim)]
47
+ for j in range(batch_size)
48
+ ], dtype=np.float32)
49
+ positions_enc[:, 0::2] = np.sin(positions_enc[:, 0::2])
50
+ positions_enc[:, 1::2] = np.cos(positions_enc[:, 1::2])
51
+ return torch.from_numpy(positions_enc).float()
52
+
53
+
54
+ def get_padding_mask(batch_size, seq_len, cap_lens):
55
+ cap_lens = cap_lens.data.tolist()
56
+ mask_2d = torch.ones((batch_size, seq_len, seq_len), dtype=torch.float32)
57
+ for i, cap_len in enumerate(cap_lens):
58
+ mask_2d[i, :, :cap_len] = 0
59
+ return mask_2d.bool(), 1 - mask_2d[:, :, 0].clone()
60
+
61
+
62
+ class PositionalEncoding(nn.Module):
63
+
64
+ def __init__(self, d_model, max_len=300):
65
+ super(PositionalEncoding, self).__init__()
66
+
67
+ pe = torch.zeros(max_len, d_model)
68
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
69
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
70
+ pe[:, 0::2] = torch.sin(position * div_term)
71
+ pe[:, 1::2] = torch.cos(position * div_term)
72
+ # pe = pe.unsqueeze(0).transpose(0, 1)
73
+ self.register_buffer('pe', pe)
74
+
75
+ def forward(self, pos):
76
+ return self.pe[pos]
77
+
78
+
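A minimal usage sketch (not part of the commit) for the lookup-style PositionalEncoding defined above, which simply indexes a precomputed sinusoidal table:

import torch

pe = PositionalEncoding(d_model=8, max_len=300)
positions = torch.tensor([0, 1, 2])
enc = pe(positions)        # shape (3, 8); even channels sin, odd channels cos
assert enc.shape == (3, 8)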
79
+ class MovementConvEncoder(nn.Module):
80
+ def __init__(self, input_size, hidden_size, output_size):
81
+ super(MovementConvEncoder, self).__init__()
82
+ self.main = nn.Sequential(
83
+ nn.Conv1d(input_size, hidden_size, 4, 2, 1),
84
+ nn.Dropout(0.2, inplace=True),
85
+ nn.LeakyReLU(0.2, inplace=True),
86
+ nn.Conv1d(hidden_size, output_size, 4, 2, 1),
87
+ nn.Dropout(0.2, inplace=True),
88
+ nn.LeakyReLU(0.2, inplace=True),
89
+ )
90
+ self.out_net = nn.Linear(output_size, output_size)
91
+ self.main.apply(init_weight)
92
+ self.out_net.apply(init_weight)
93
+
94
+ def forward(self, inputs):
95
+ inputs = inputs.permute(0, 2, 1)
96
+ outputs = self.main(inputs).permute(0, 2, 1)
97
+ # print(outputs.shape)
98
+ return self.out_net(outputs)
99
+
100
+
101
+ class MovementConvDecoder(nn.Module):
102
+ def __init__(self, input_size, hidden_size, output_size):
103
+ super(MovementConvDecoder, self).__init__()
104
+ self.main = nn.Sequential(
105
+ nn.ConvTranspose1d(input_size, hidden_size, 4, 2, 1),
106
+ # nn.Dropout(0.2, inplace=True),
107
+ nn.LeakyReLU(0.2, inplace=True),
108
+ nn.ConvTranspose1d(hidden_size, output_size, 4, 2, 1),
109
+ # nn.Dropout(0.2, inplace=True),
110
+ nn.LeakyReLU(0.2, inplace=True),
111
+ )
112
+ self.out_net = nn.Linear(output_size, output_size)
113
+
114
+ self.main.apply(init_weight)
115
+ self.out_net.apply(init_weight)
116
+
117
+ def forward(self, inputs):
118
+ inputs = inputs.permute(0, 2, 1)
119
+ outputs = self.main(inputs).permute(0, 2, 1)
120
+ return self.out_net(outputs)
121
+
122
+
123
+ class TextVAEDecoder(nn.Module):
124
+ def __init__(self, text_size, input_size, output_size, hidden_size, n_layers):
125
+ super(TextVAEDecoder, self).__init__()
126
+ self.input_size = input_size
127
+ self.output_size = output_size
128
+ self.hidden_size = hidden_size
129
+ self.n_layers = n_layers
130
+ self.emb = nn.Sequential(
131
+ nn.Linear(input_size, hidden_size),
132
+ nn.LayerNorm(hidden_size),
133
+ nn.LeakyReLU(0.2, inplace=True))
134
+
135
+ self.z2init = nn.Linear(text_size, hidden_size * n_layers)
136
+ self.gru = nn.ModuleList([nn.GRUCell(hidden_size, hidden_size) for i in range(self.n_layers)])
137
+ self.positional_encoder = PositionalEncoding(hidden_size)
138
+
139
+
140
+ self.output = nn.Sequential(
141
+ nn.Linear(hidden_size, hidden_size),
142
+ nn.LayerNorm(hidden_size),
143
+ nn.LeakyReLU(0.2, inplace=True),
144
+ nn.Linear(hidden_size, output_size)
145
+ )
146
+
147
+ #
148
+ # self.output = nn.Sequential(
149
+ # nn.Linear(hidden_size, hidden_size),
150
+ # nn.LayerNorm(hidden_size),
151
+ # nn.LeakyReLU(0.2, inplace=True),
152
+ # nn.Linear(hidden_size, output_size-4)
153
+ # )
154
+
155
+ # self.contact_net = nn.Sequential(
156
+ # nn.Linear(output_size-4, 64),
157
+ # nn.LayerNorm(64),
158
+ # nn.LeakyReLU(0.2, inplace=True),
159
+ # nn.Linear(64, 4)
160
+ # )
161
+
162
+ self.output.apply(init_weight)
163
+ self.emb.apply(init_weight)
164
+ self.z2init.apply(init_weight)
165
+ # self.contact_net.apply(init_weight)
166
+
167
+ def get_init_hidden(self, latent):
168
+ hidden = self.z2init(latent)
169
+ hidden = torch.split(hidden, self.hidden_size, dim=-1)
170
+ return list(hidden)
171
+
172
+ def forward(self, inputs, last_pred, hidden, p):
173
+ h_in = self.emb(inputs)
174
+ pos_enc = self.positional_encoder(p).to(inputs.device).detach()
175
+ h_in = h_in + pos_enc
176
+ for i in range(self.n_layers):
177
+ # print(h_in.shape)
178
+ hidden[i] = self.gru[i](h_in, hidden[i])
179
+ h_in = hidden[i]
180
+ pose_pred = self.output(h_in)
181
+ # pose_pred = self.output(h_in) + last_pred.detach()
182
+ # contact = self.contact_net(pose_pred)
183
+ # return torch.cat([pose_pred, contact], dim=-1), hidden
184
+ return pose_pred, hidden
185
+
186
+
187
+ class TextDecoder(nn.Module):
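+     """Same GRU-cell stack as TextVAEDecoder, but the head outputs (mu, logvar)
+     and a reparameterized sample z; used as the sequence prior/posterior."""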
188
+ def __init__(self, text_size, input_size, output_size, hidden_size, n_layers):
189
+ super(TextDecoder, self).__init__()
190
+ self.input_size = input_size
191
+ self.output_size = output_size
192
+ self.hidden_size = hidden_size
193
+ self.n_layers = n_layers
194
+ self.emb = nn.Sequential(
195
+ nn.Linear(input_size, hidden_size),
196
+ nn.LayerNorm(hidden_size),
197
+ nn.LeakyReLU(0.2, inplace=True))
198
+
199
+ self.gru = nn.ModuleList([nn.GRUCell(hidden_size, hidden_size) for i in range(self.n_layers)])
200
+ self.z2init = nn.Linear(text_size, hidden_size * n_layers)
201
+ self.positional_encoder = PositionalEncoding(hidden_size)
202
+
203
+ self.mu_net = nn.Linear(hidden_size, output_size)
204
+ self.logvar_net = nn.Linear(hidden_size, output_size)
205
+
206
+ self.emb.apply(init_weight)
207
+ self.z2init.apply(init_weight)
208
+ self.mu_net.apply(init_weight)
209
+ self.logvar_net.apply(init_weight)
210
+
211
+ def get_init_hidden(self, latent):
212
+
213
+ hidden = self.z2init(latent)
214
+ hidden = torch.split(hidden, self.hidden_size, dim=-1)
215
+
216
+ return list(hidden)
217
+
218
+ def forward(self, inputs, hidden, p):
219
+ # print(inputs.shape)
220
+ x_in = self.emb(inputs)
221
+ pos_enc = self.positional_encoder(p).to(inputs.device).detach()
222
+ x_in = x_in + pos_enc
223
+
224
+ for i in range(self.n_layers):
225
+ hidden[i] = self.gru[i](x_in, hidden[i])
226
+ h_in = hidden[i]
227
+ mu = self.mu_net(h_in)
228
+ logvar = self.logvar_net(h_in)
229
+ z = reparameterize(mu, logvar)
230
+ return z, mu, logvar, hidden
231
+
232
+ class AttLayer(nn.Module):
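+     """Single-query scaled dot-product attention: one query vector attends over
+     a sequence of keys/values (here, the per-word hidden states)."""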
233
+ def __init__(self, query_dim, key_dim, value_dim):
234
+ super(AttLayer, self).__init__()
235
+ self.W_q = nn.Linear(query_dim, value_dim)
236
+ self.W_k = nn.Linear(key_dim, value_dim, bias=False)
237
+ self.W_v = nn.Linear(key_dim, value_dim)
238
+
239
+ self.softmax = nn.Softmax(dim=1)
240
+ self.dim = value_dim
241
+
242
+ self.W_q.apply(init_weight)
243
+ self.W_k.apply(init_weight)
244
+ self.W_v.apply(init_weight)
245
+
246
+ def forward(self, query, key_mat):
247
+ '''
248
+ query (batch, query_dim)
249
+ key (batch, seq_len, key_dim)
250
+ '''
251
+ # print(query.shape)
252
+ query_vec = self.W_q(query).unsqueeze(-1) # (batch, value_dim, 1)
253
+ val_set = self.W_v(key_mat) # (batch, seq_len, value_dim)
254
+ key_set = self.W_k(key_mat) # (batch, seq_len, value_dim)
255
+
256
+ weights = torch.matmul(key_set, query_vec) / np.sqrt(self.dim)
257
+
258
+ co_weights = self.softmax(weights) # (batch, seq_len, 1)
259
+ values = val_set * co_weights # (batch, seq_len, value_dim)
260
+ pred = values.sum(dim=1) # (batch, value_dim)
261
+ return pred, co_weights
262
+
263
+ def short_cut(self, querys, keys):
264
+ return self.W_q(querys), self.W_k(keys)
265
+
266
+
267
+ class TextEncoderBiGRU(nn.Module):
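+     """Bidirectional GRU over word embeddings summed with projected POS one-hots.
+     Returns the per-word hidden sequence (backward half flipped back into forward
+     time order) and the concatenated final hidden state."""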
268
+ def __init__(self, word_size, pos_size, hidden_size, device):
269
+ super(TextEncoderBiGRU, self).__init__()
270
+ self.device = device
271
+
272
+ self.pos_emb = nn.Linear(pos_size, word_size)
273
+ self.input_emb = nn.Linear(word_size, hidden_size)
274
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
275
+ # self.linear2 = nn.Linear(hidden_size, output_size)
276
+
277
+ self.input_emb.apply(init_weight)
278
+ self.pos_emb.apply(init_weight)
279
+ # self.linear2.apply(init_weight)
280
+ # self.batch_size = batch_size
281
+ self.hidden_size = hidden_size
282
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
283
+
284
+ # input(batch_size, seq_len, dim)
285
+ def forward(self, word_embs, pos_onehot, cap_lens):
286
+ num_samples = word_embs.shape[0]
287
+
288
+ pos_embs = self.pos_emb(pos_onehot)
289
+ inputs = word_embs + pos_embs
290
+ input_embs = self.input_emb(inputs)
291
+ hidden = self.hidden.repeat(1, num_samples, 1)
292
+
293
+ cap_lens = cap_lens.data.tolist()
294
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
295
+
296
+ gru_seq, gru_last = self.gru(emb, hidden)
297
+
298
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
299
+ gru_seq = pad_packed_sequence(gru_seq, batch_first=True)[0]
300
+ forward_seq = gru_seq[..., :self.hidden_size]
301
+ backward_seq = gru_seq[..., self.hidden_size:].clone()
302
+
303
+         # Concatenate the forward and backward word embeddings
304
+ for i, length in enumerate(cap_lens):
305
+ backward_seq[i:i+1, :length] = torch.flip(backward_seq[i:i+1, :length].clone(), dims=[1])
306
+ gru_seq = torch.cat([forward_seq, backward_seq], dim=-1)
307
+
308
+ return gru_seq, gru_last
309
+
310
+
311
+ class TextEncoderBiGRUCo(nn.Module):
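+     """Bidirectional GRU text encoder that maps a caption to a single
+     fixed-size co-embedding vector (used for text-motion matching)."""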
312
+ def __init__(self, word_size, pos_size, hidden_size, output_size, device):
313
+ super(TextEncoderBiGRUCo, self).__init__()
314
+ self.device = device
315
+
316
+ self.pos_emb = nn.Linear(pos_size, word_size)
317
+ self.input_emb = nn.Linear(word_size, hidden_size)
318
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
319
+ self.output_net = nn.Sequential(
320
+ nn.Linear(hidden_size * 2, hidden_size),
321
+ nn.LayerNorm(hidden_size),
322
+ nn.LeakyReLU(0.2, inplace=True),
323
+ nn.Linear(hidden_size, output_size)
324
+ )
325
+
326
+ self.input_emb.apply(init_weight)
327
+ self.pos_emb.apply(init_weight)
328
+ self.output_net.apply(init_weight)
329
+ # self.linear2.apply(init_weight)
330
+ # self.batch_size = batch_size
331
+ self.hidden_size = hidden_size
332
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
333
+
334
+ # input(batch_size, seq_len, dim)
335
+ def forward(self, word_embs, pos_onehot, cap_lens):
336
+ num_samples = word_embs.shape[0]
337
+
338
+ pos_embs = self.pos_emb(pos_onehot)
339
+ inputs = word_embs + pos_embs
340
+ input_embs = self.input_emb(inputs)
341
+ hidden = self.hidden.repeat(1, num_samples, 1)
342
+
343
+ cap_lens = cap_lens.data.tolist()
344
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
345
+
346
+ gru_seq, gru_last = self.gru(emb, hidden)
347
+
348
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
349
+
350
+ return self.output_net(gru_last)
351
+
352
+
353
+ class MotionEncoderBiGRUCo(nn.Module):
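+     """Bidirectional GRU motion encoder that maps a packed movement sequence
+     to a single fixed-size co-embedding vector."""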
354
+ def __init__(self, input_size, hidden_size, output_size, device):
355
+ super(MotionEncoderBiGRUCo, self).__init__()
356
+ self.device = device
357
+
358
+ self.input_emb = nn.Linear(input_size, hidden_size)
359
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
360
+ self.output_net = nn.Sequential(
361
+ nn.Linear(hidden_size*2, hidden_size),
362
+ nn.LayerNorm(hidden_size),
363
+ nn.LeakyReLU(0.2, inplace=True),
364
+ nn.Linear(hidden_size, output_size)
365
+ )
366
+
367
+ self.input_emb.apply(init_weight)
368
+ self.output_net.apply(init_weight)
369
+ self.hidden_size = hidden_size
370
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
371
+
372
+ # input(batch_size, seq_len, dim)
373
+ def forward(self, inputs, m_lens):
374
+ num_samples = inputs.shape[0]
375
+
376
+ input_embs = self.input_emb(inputs)
377
+ hidden = self.hidden.repeat(1, num_samples, 1)
378
+
379
+ cap_lens = m_lens.data.tolist()
380
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
381
+
382
+ gru_seq, gru_last = self.gru(emb, hidden)
383
+
384
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
385
+
386
+ return self.output_net(gru_last)
387
+
388
+
389
+ class MotionLenEstimatorBiGRU(nn.Module):
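+     """Predicts logits over motion-length bins (multiples of unit_length) from
+     the caption; trained with cross-entropy in LengthEstTrainer."""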
390
+ def __init__(self, word_size, pos_size, hidden_size, output_size):
391
+ super(MotionLenEstimatorBiGRU, self).__init__()
392
+
393
+ self.pos_emb = nn.Linear(pos_size, word_size)
394
+ self.input_emb = nn.Linear(word_size, hidden_size)
395
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
396
+ nd = 512
397
+ self.output = nn.Sequential(
398
+ nn.Linear(hidden_size*2, nd),
399
+ nn.LayerNorm(nd),
400
+ nn.LeakyReLU(0.2, inplace=True),
401
+
402
+ nn.Linear(nd, nd // 2),
403
+ nn.LayerNorm(nd // 2),
404
+ nn.LeakyReLU(0.2, inplace=True),
405
+
406
+ nn.Linear(nd // 2, nd // 4),
407
+ nn.LayerNorm(nd // 4),
408
+ nn.LeakyReLU(0.2, inplace=True),
409
+
410
+ nn.Linear(nd // 4, output_size)
411
+ )
412
+ # self.linear2 = nn.Linear(hidden_size, output_size)
413
+
414
+ self.input_emb.apply(init_weight)
415
+ self.pos_emb.apply(init_weight)
416
+ self.output.apply(init_weight)
417
+ # self.linear2.apply(init_weight)
418
+ # self.batch_size = batch_size
419
+ self.hidden_size = hidden_size
420
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
421
+
422
+ # input(batch_size, seq_len, dim)
423
+ def forward(self, word_embs, pos_onehot, cap_lens):
424
+ num_samples = word_embs.shape[0]
425
+
426
+ pos_embs = self.pos_emb(pos_onehot)
427
+ inputs = word_embs + pos_embs
428
+ input_embs = self.input_emb(inputs)
429
+ hidden = self.hidden.repeat(1, num_samples, 1)
430
+
431
+ cap_lens = cap_lens.data.tolist()
432
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
433
+
434
+ gru_seq, gru_last = self.gru(emb, hidden)
435
+
436
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
437
+
438
+ return self.output(gru_last)
motion_diffusion_model/data_loaders/humanml/networks/trainers.py ADDED
@@ -0,0 +1,1089 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import random
+ import os
+ import time
+
+ import numpy as np
4
+ from data_loaders.humanml.networks.modules import *
5
+ from torch.utils.data import DataLoader
6
+ import torch.optim as optim
7
+ from torch.nn.utils import clip_grad_norm_
8
+ # tensorflow is imported lazily inside Logger below, so it stays an optional dependency
9
+ from collections import OrderedDict
10
+ from data_loaders.humanml.utils.utils import *
11
+ from os.path import join as pjoin
12
+ from data_loaders.humanml.data.dataset import collate_fn
13
+ import codecs as cs
14
+
15
+
16
+ class Logger(object):
17
+ def __init__(self, log_dir):
18
+         # TensorFlow is only needed when training; import it lazily so this
+         # module can be used without it installed.
+         import tensorflow as tf
+         self.tf = tf
+         self.writer = tf.summary.create_file_writer(log_dir)
19
+
20
+ def scalar_summary(self, tag, value, step):
21
+ with self.writer.as_default():
22
+             self.tf.summary.scalar(tag, value, step=step)
23
+ self.writer.flush()
24
+
25
+ class DecompTrainerV3(object):
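+     """Trainer for the movement encoder/decoder pair: L1 reconstruction loss
+     plus latent sparsity and temporal smoothness penalties."""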
26
+ def __init__(self, args, movement_enc, movement_dec):
27
+ self.opt = args
28
+ self.movement_enc = movement_enc
29
+ self.movement_dec = movement_dec
30
+ self.device = args.device
31
+
32
+ if args.is_train:
33
+ self.logger = Logger(args.log_dir)
34
+ self.sml1_criterion = torch.nn.SmoothL1Loss()
35
+ self.l1_criterion = torch.nn.L1Loss()
36
+ self.mse_criterion = torch.nn.MSELoss()
37
+
38
+
39
+ @staticmethod
40
+ def zero_grad(opt_list):
41
+ for opt in opt_list:
42
+ opt.zero_grad()
43
+
44
+ @staticmethod
45
+ def clip_norm(network_list):
46
+ for network in network_list:
47
+ clip_grad_norm_(network.parameters(), 0.5)
48
+
49
+ @staticmethod
50
+ def step(opt_list):
51
+ for opt in opt_list:
52
+ opt.step()
53
+
54
+ def forward(self, batch_data):
55
+ motions = batch_data
56
+ self.motions = motions.detach().to(self.device).float()
57
+ self.latents = self.movement_enc(self.motions[..., :-4])
58
+ self.recon_motions = self.movement_dec(self.latents)
59
+
60
+ def backward(self):
61
+ self.loss_rec = self.l1_criterion(self.recon_motions, self.motions)
62
+ # self.sml1_criterion(self.recon_motions[:, 1:] - self.recon_motions[:, :-1],
63
+ # self.motions[:, 1:] - self.recon_motions[:, :-1])
64
+ self.loss_sparsity = torch.mean(torch.abs(self.latents))
65
+ self.loss_smooth = self.l1_criterion(self.latents[:, 1:], self.latents[:, :-1])
66
+ self.loss = self.loss_rec + self.loss_sparsity * self.opt.lambda_sparsity +\
67
+ self.loss_smooth*self.opt.lambda_smooth
68
+
69
+ def update(self):
70
+ # time0 = time.time()
71
+ self.zero_grad([self.opt_movement_enc, self.opt_movement_dec])
72
+ # time1 = time.time()
73
+ # print('\t Zero_grad Time: %.5f s' % (time1 - time0))
74
+ self.backward()
75
+ # time2 = time.time()
76
+ # print('\t Backward Time: %.5f s' % (time2 - time1))
77
+ self.loss.backward()
78
+ # time3 = time.time()
79
+ # print('\t Loss backward Time: %.5f s' % (time3 - time2))
80
+ # self.clip_norm([self.movement_enc, self.movement_dec])
81
+ # time4 = time.time()
82
+ # print('\t Clip_norm Time: %.5f s' % (time4 - time3))
83
+ self.step([self.opt_movement_enc, self.opt_movement_dec])
84
+ # time5 = time.time()
85
+ # print('\t Step Time: %.5f s' % (time5 - time4))
86
+
87
+ loss_logs = OrderedDict({})
88
+         loss_logs['loss'] = self.loss.item()  # total loss; loss_rec is logged separately below
89
+ loss_logs['loss_rec'] = self.loss_rec.item()
90
+ loss_logs['loss_sparsity'] = self.loss_sparsity.item()
91
+ loss_logs['loss_smooth'] = self.loss_smooth.item()
92
+ return loss_logs
93
+
94
+ def save(self, file_name, ep, total_it):
95
+ state = {
96
+ 'movement_enc': self.movement_enc.state_dict(),
97
+ 'movement_dec': self.movement_dec.state_dict(),
98
+
99
+ 'opt_movement_enc': self.opt_movement_enc.state_dict(),
100
+ 'opt_movement_dec': self.opt_movement_dec.state_dict(),
101
+
102
+ 'ep': ep,
103
+ 'total_it': total_it,
104
+ }
105
+ torch.save(state, file_name)
106
+ return
107
+
108
+ def resume(self, model_dir):
109
+ checkpoint = torch.load(model_dir, map_location=self.device)
110
+
111
+ self.movement_dec.load_state_dict(checkpoint['movement_dec'])
112
+ self.movement_enc.load_state_dict(checkpoint['movement_enc'])
113
+
114
+ self.opt_movement_enc.load_state_dict(checkpoint['opt_movement_enc'])
115
+ self.opt_movement_dec.load_state_dict(checkpoint['opt_movement_dec'])
116
+
117
+ return checkpoint['ep'], checkpoint['total_it']
118
+
119
+ def train(self, train_dataloader, val_dataloader, plot_eval):
120
+ self.movement_enc.to(self.device)
121
+ self.movement_dec.to(self.device)
122
+
123
+ self.opt_movement_enc = optim.Adam(self.movement_enc.parameters(), lr=self.opt.lr)
124
+ self.opt_movement_dec = optim.Adam(self.movement_dec.parameters(), lr=self.opt.lr)
125
+
126
+ epoch = 0
127
+ it = 0
128
+
129
+ if self.opt.is_continue:
130
+ model_dir = pjoin(self.opt.model_dir, 'latest.tar')
131
+ epoch, it = self.resume(model_dir)
132
+
133
+ start_time = time.time()
134
+ total_iters = self.opt.max_epoch * len(train_dataloader)
135
+ print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_dataloader), len(val_dataloader)))
136
+ val_loss = 0
137
+ logs = OrderedDict()
138
+ while epoch < self.opt.max_epoch:
139
+ # time0 = time.time()
140
+ for i, batch_data in enumerate(train_dataloader):
141
+ self.movement_dec.train()
142
+ self.movement_enc.train()
143
+
144
+ # time1 = time.time()
145
+ # print('DataLoader Time: %.5f s'%(time1-time0) )
146
+ self.forward(batch_data)
147
+ # time2 = time.time()
148
+ # print('Forward Time: %.5f s'%(time2-time1))
149
+ log_dict = self.update()
150
+ # time3 = time.time()
151
+ # print('Update Time: %.5f s' % (time3 - time2))
152
+ # time0 = time3
153
+ for k, v in log_dict.items():
154
+ if k not in logs:
155
+ logs[k] = v
156
+ else:
157
+ logs[k] += v
158
+
159
+ it += 1
160
+ if it % self.opt.log_every == 0:
161
+ mean_loss = OrderedDict({'val_loss': val_loss})
162
+ self.logger.scalar_summary('val_loss', val_loss, it)
163
+
164
+ for tag, value in logs.items():
165
+ self.logger.scalar_summary(tag, value / self.opt.log_every, it)
166
+ mean_loss[tag] = value / self.opt.log_every
167
+ logs = OrderedDict()
168
+ print_current_loss_decomp(start_time, it, total_iters, mean_loss, epoch, i)
169
+
170
+ if it % self.opt.save_latest == 0:
171
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
172
+
173
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
174
+
175
+ epoch += 1
176
+ if epoch % self.opt.save_every_e == 0:
177
+ self.save(pjoin(self.opt.model_dir, 'E%04d.tar' % (epoch)), epoch, total_it=it)
178
+
179
+ print('Validation time:')
180
+
181
+             val_loss = 0
+             val_rec_loss = 0
+             val_sparsity_loss = 0
+             val_smooth_loss = 0
+             with torch.no_grad():
+                 for i, batch_data in enumerate(val_dataloader):
+                     self.forward(batch_data)
+                     self.backward()
+                     val_rec_loss += self.loss_rec.item()
+                     val_sparsity_loss += self.loss_sparsity.item()
+                     val_smooth_loss += self.loss_smooth.item()
+                     val_loss += self.loss.item()
+
+             val_loss = val_loss / (len(val_dataloader) + 1)
+             val_rec_loss = val_rec_loss / (len(val_dataloader) + 1)
+             val_sparsity_loss = val_sparsity_loss / (len(val_dataloader) + 1)
+             val_smooth_loss = val_smooth_loss / (len(val_dataloader) + 1)
+             print('Validation Loss: %.5f Reconstruction Loss: %.5f '
+                   'Sparsity Loss: %.5f Smooth Loss: %.5f' % (val_loss, val_rec_loss,
+                                                              val_sparsity_loss, val_smooth_loss))
202
+
203
+ if epoch % self.opt.eval_every_e == 0:
204
+ data = torch.cat([self.recon_motions[:4], self.motions[:4]], dim=0).detach().cpu().numpy()
205
+ save_dir = pjoin(self.opt.eval_dir, 'E%04d' % (epoch))
206
+ os.makedirs(save_dir, exist_ok=True)
207
+ plot_eval(data, save_dir)
208
+
209
+
210
+ # Composite trainer: VAE sequence decoder with per-step prior/posterior latents
211
+ class CompTrainerV6(object):
212
+
213
+ def __init__(self, args, text_enc, seq_pri, seq_dec, att_layer, mov_dec, mov_enc=None, seq_post=None):
214
+ self.opt = args
215
+ self.text_enc = text_enc
216
+ self.seq_pri = seq_pri
217
+ self.att_layer = att_layer
218
+ self.device = args.device
219
+ self.seq_dec = seq_dec
220
+ self.mov_dec = mov_dec
221
+ self.mov_enc = mov_enc
222
+
223
+ if args.is_train:
224
+ self.seq_post = seq_post
225
+ # self.motion_dis
226
+ self.logger = Logger(args.log_dir)
227
+ self.l1_criterion = torch.nn.SmoothL1Loss()
228
+ self.gan_criterion = torch.nn.BCEWithLogitsLoss()
229
+ self.mse_criterion = torch.nn.MSELoss()
230
+
231
+ @staticmethod
232
+ def reparametrize(mu, logvar):
233
+ s_var = logvar.mul(0.5).exp_()
234
+ eps = s_var.data.new(s_var.size()).normal_()
235
+ return eps.mul(s_var).add_(mu)
236
+
237
+ @staticmethod
238
+ def ones_like(tensor, val=1.):
239
+ return torch.FloatTensor(tensor.size()).fill_(val).to(tensor.device).requires_grad_(False)
240
+
241
+ @staticmethod
242
+ def zeros_like(tensor, val=0.):
243
+ return torch.FloatTensor(tensor.size()).fill_(val).to(tensor.device).requires_grad_(False)
244
+
245
+ @staticmethod
246
+ def zero_grad(opt_list):
247
+ for opt in opt_list:
248
+ opt.zero_grad()
249
+
250
+ @staticmethod
251
+ def clip_norm(network_list):
252
+ for network in network_list:
253
+ clip_grad_norm_(network.parameters(), 0.5)
254
+
255
+ @staticmethod
256
+ def step(opt_list):
257
+ for opt in opt_list:
258
+ opt.step()
259
+
260
+ @staticmethod
261
+ def kl_criterion(mu1, logvar1, mu2, logvar2):
262
+         # KL( N(mu1, sigma1^2) || N(mu2, sigma2^2) )
263
+ # loss = log(sigma2/sigma1) + (sigma1^2 + (mu1 - mu2)^2)/(2*sigma2^2) - 1/2
264
+ sigma1 = logvar1.mul(0.5).exp()
265
+ sigma2 = logvar2.mul(0.5).exp()
266
+ kld = torch.log(sigma2 / sigma1) + (torch.exp(logvar1) + (mu1 - mu2) ** 2) / (
267
+ 2 * torch.exp(logvar2)) - 1 / 2
268
+ return kld.sum() / mu1.shape[0]
269
+
270
+ @staticmethod
271
+ def kl_criterion_unit(mu, logvar):
272
+         # KL( N(mu, sigma^2) || N(0, 1) ) = (sigma^2 + mu^2 - log(sigma^2) - 1) / 2
274
+ kld = ((torch.exp(logvar) + mu ** 2) - logvar - 1) / 2
275
+ return kld.sum() / mu.shape[0]
276
+
277
+ def forward(self, batch_data, tf_ratio, mov_len, eval_mode=False):
278
+ word_emb, pos_ohot, caption, cap_lens, motions, m_lens = batch_data
279
+ word_emb = word_emb.detach().to(self.device).float()
280
+ pos_ohot = pos_ohot.detach().to(self.device).float()
281
+ motions = motions.detach().to(self.device).float()
282
+ self.cap_lens = cap_lens
283
+ self.caption = caption
284
+
285
+ # print(motions.shape)
286
+ # (batch_size, motion_len, pose_dim)
287
+ self.motions = motions
288
+
289
+ '''Movement Encoding'''
290
+ self.movements = self.mov_enc(self.motions[..., :-4]).detach()
291
+ # Initially input a mean vector
292
+ mov_in = self.mov_enc(
293
+ torch.zeros((self.motions.shape[0], self.opt.unit_length, self.motions.shape[-1] - 4), device=self.device)
294
+ ).squeeze(1).detach()
295
+ assert self.movements.shape[1] == mov_len
296
+
297
+         teacher_force = random.random() < tf_ratio
298
+
299
+ '''Text Encoding'''
300
+ # time0 = time.time()
301
+ # text_input = torch.cat([word_emb, pos_ohot], dim=-1)
302
+ word_hids, hidden = self.text_enc(word_emb, pos_ohot, cap_lens)
303
+ # print(word_hids.shape, hidden.shape)
304
+
305
+ if self.opt.text_enc_mod == 'bigru':
306
+ hidden_pos = self.seq_post.get_init_hidden(hidden)
307
+ hidden_pri = self.seq_pri.get_init_hidden(hidden)
308
+ hidden_dec = self.seq_dec.get_init_hidden(hidden)
309
+ elif self.opt.text_enc_mod == 'transformer':
310
+ hidden_pos = self.seq_post.get_init_hidden(hidden.detach())
311
+ hidden_pri = self.seq_pri.get_init_hidden(hidden.detach())
312
+ hidden_dec = self.seq_dec.get_init_hidden(hidden)
313
+
314
+ mus_pri = []
315
+ logvars_pri = []
316
+ mus_post = []
317
+ logvars_post = []
318
+ fake_mov_batch = []
319
+
320
+ query_input = []
321
+
322
+ # time1 = time.time()
323
+ # print("\t Text Encoder Cost:%5f" % (time1 - time0))
324
+ # print(self.movements.shape)
325
+
326
+ for i in range(mov_len):
327
+ # print("\t Sequence Measure")
328
+ # print(mov_in.shape)
329
+ mov_tgt = self.movements[:, i]
330
+ '''Local Attention Vector'''
331
+ att_vec, _ = self.att_layer(hidden_dec[-1], word_hids)
332
+ query_input.append(hidden_dec[-1])
333
+
334
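+             # tta: time-to-arrival, i.e. movement units remaining until the target length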
+ tta = m_lens // self.opt.unit_length - i
335
+
336
+ if self.opt.text_enc_mod == 'bigru':
337
+ pos_in = torch.cat([mov_in, mov_tgt, att_vec], dim=-1)
338
+ pri_in = torch.cat([mov_in, att_vec], dim=-1)
339
+
340
+ elif self.opt.text_enc_mod == 'transformer':
341
+ pos_in = torch.cat([mov_in, mov_tgt, att_vec.detach()], dim=-1)
342
+ pri_in = torch.cat([mov_in, att_vec.detach()], dim=-1)
343
+
344
+ '''Posterior'''
345
+ z_pos, mu_pos, logvar_pos, hidden_pos = self.seq_post(pos_in, hidden_pos, tta)
346
+
347
+ '''Prior'''
348
+ z_pri, mu_pri, logvar_pri, hidden_pri = self.seq_pri(pri_in, hidden_pri, tta)
349
+
350
+ '''Decoder'''
351
+ if eval_mode:
352
+ dec_in = torch.cat([mov_in, att_vec, z_pri], dim=-1)
353
+ else:
354
+ dec_in = torch.cat([mov_in, att_vec, z_pos], dim=-1)
355
+ fake_mov, hidden_dec = self.seq_dec(dec_in, mov_in, hidden_dec, tta)
356
+
357
+ # print(fake_mov.shape)
358
+
359
+ mus_post.append(mu_pos)
360
+ logvars_post.append(logvar_pos)
361
+ mus_pri.append(mu_pri)
362
+ logvars_pri.append(logvar_pri)
363
+ fake_mov_batch.append(fake_mov.unsqueeze(1))
364
+
365
+ if teacher_force:
366
+ mov_in = self.movements[:, i].detach()
367
+ else:
368
+ mov_in = fake_mov.detach()
369
+
370
+
371
+ self.fake_movements = torch.cat(fake_mov_batch, dim=1)
372
+
373
+ # print(self.fake_movements.shape)
374
+
375
+ self.fake_motions = self.mov_dec(self.fake_movements)
376
+
377
+ self.mus_post = torch.cat(mus_post, dim=0)
378
+ self.mus_pri = torch.cat(mus_pri, dim=0)
379
+ self.logvars_post = torch.cat(logvars_post, dim=0)
380
+ self.logvars_pri = torch.cat(logvars_pri, dim=0)
381
+
382
+ def generate(self, word_emb, pos_ohot, cap_lens, m_lens, mov_len, dim_pose):
383
+ word_emb = word_emb.detach().to(self.device).float()
384
+ pos_ohot = pos_ohot.detach().to(self.device).float()
385
+ self.cap_lens = cap_lens
386
+
387
+ # print(motions.shape)
388
+ # (batch_size, motion_len, pose_dim)
389
+
390
+ '''Movement Encoding'''
391
+ # Initially input a mean vector
392
+ mov_in = self.mov_enc(
393
+ torch.zeros((word_emb.shape[0], self.opt.unit_length, dim_pose - 4), device=self.device)
394
+ ).squeeze(1).detach()
395
+
396
+ '''Text Encoding'''
397
+ # time0 = time.time()
398
+ # text_input = torch.cat([word_emb, pos_ohot], dim=-1)
399
+ word_hids, hidden = self.text_enc(word_emb, pos_ohot, cap_lens)
400
+ # print(word_hids.shape, hidden.shape)
401
+
402
+ hidden_pri = self.seq_pri.get_init_hidden(hidden)
403
+ hidden_dec = self.seq_dec.get_init_hidden(hidden)
404
+
405
+ mus_pri = []
406
+ logvars_pri = []
407
+ fake_mov_batch = []
408
+ att_wgt = []
409
+
410
+ # time1 = time.time()
411
+ # print("\t Text Encoder Cost:%5f" % (time1 - time0))
412
+ # print(self.movements.shape)
413
+
414
+ for i in range(mov_len):
415
+ # print("\t Sequence Measure")
416
+ # print(mov_in.shape)
417
+ '''Local Attention Vector'''
418
+ att_vec, co_weights = self.att_layer(hidden_dec[-1], word_hids)
419
+
420
+ tta = m_lens // self.opt.unit_length - i
421
+ # tta = m_lens - i
422
+
423
+ '''Prior'''
424
+ pri_in = torch.cat([mov_in, att_vec], dim=-1)
425
+ z_pri, mu_pri, logvar_pri, hidden_pri = self.seq_pri(pri_in, hidden_pri, tta)
426
+
427
+ '''Decoder'''
428
+ dec_in = torch.cat([mov_in, att_vec, z_pri], dim=-1)
429
+
430
+ fake_mov, hidden_dec = self.seq_dec(dec_in, mov_in, hidden_dec, tta)
431
+
432
+ # print(fake_mov.shape)
433
+ mus_pri.append(mu_pri)
434
+ logvars_pri.append(logvar_pri)
435
+ fake_mov_batch.append(fake_mov.unsqueeze(1))
436
+ att_wgt.append(co_weights)
437
+
438
+ mov_in = fake_mov.detach()
439
+
440
+ fake_movements = torch.cat(fake_mov_batch, dim=1)
441
+ att_wgts = torch.cat(att_wgt, dim=-1)
442
+
443
+ # print(self.fake_movements.shape)
444
+
445
+ fake_motions = self.mov_dec(fake_movements)
446
+
447
+ mus_pri = torch.cat(mus_pri, dim=0)
448
+ logvars_pri = torch.cat(logvars_pri, dim=0)
449
+
450
+ return fake_motions, mus_pri, att_wgts
451
+
452
+ def backward_G(self):
453
+ self.loss_mot_rec = self.l1_criterion(self.fake_motions, self.motions)
454
+ self.loss_mov_rec = self.l1_criterion(self.fake_movements, self.movements)
455
+
456
+ self.loss_kld = self.kl_criterion(self.mus_post, self.logvars_post, self.mus_pri, self.logvars_pri)
457
+
458
+ self.loss_gen = self.loss_mot_rec * self.opt.lambda_rec_mov + self.loss_mov_rec * self.opt.lambda_rec_mot + \
459
+ self.loss_kld * self.opt.lambda_kld
460
+ loss_logs = OrderedDict({})
461
+ loss_logs['loss_gen'] = self.loss_gen.item()
462
+ loss_logs['loss_mot_rec'] = self.loss_mot_rec.item()
463
+ loss_logs['loss_mov_rec'] = self.loss_mov_rec.item()
464
+ loss_logs['loss_kld'] = self.loss_kld.item()
465
+
466
+ return loss_logs
467
+ # self.loss_gen = self.loss_rec_mov
468
+
469
+ # self.loss_gen = self.loss_rec_mov * self.opt.lambda_rec_mov + self.loss_rec_mot + \
470
+ # self.loss_kld * self.opt.lambda_kld + \
471
+ # self.loss_mtgan_G * self.opt.lambda_gan_mt + self.loss_mvgan_G * self.opt.lambda_gan_mv
472
+
473
+
474
+ def update(self):
475
+
476
+ self.zero_grad([self.opt_text_enc, self.opt_seq_dec, self.opt_seq_post,
477
+ self.opt_seq_pri, self.opt_att_layer, self.opt_mov_dec])
478
+ # time2_0 = time.time()
479
+ # print("\t\t Zero Grad:%5f" % (time2_0 - time1))
480
+ loss_logs = self.backward_G()
481
+
482
+ # time2_1 = time.time()
483
+ # print("\t\t Backward_G :%5f" % (time2_1 - time2_0))
484
+ self.loss_gen.backward()
485
+
486
+ # time2_2 = time.time()
487
+ # print("\t\t Backward :%5f" % (time2_2 - time2_1))
488
+ self.clip_norm([self.text_enc, self.seq_dec, self.seq_post, self.seq_pri,
489
+ self.att_layer, self.mov_dec])
490
+
491
+ # time2_3 = time.time()
492
+ # print("\t\t Clip Norm :%5f" % (time2_3 - time2_2))
493
+ self.step([self.opt_text_enc, self.opt_seq_dec, self.opt_seq_post,
494
+ self.opt_seq_pri, self.opt_att_layer, self.opt_mov_dec])
495
+
496
+ # time2_4 = time.time()
497
+ # print("\t\t Step :%5f" % (time2_4 - time2_3))
498
+
499
+ # time2 = time.time()
500
+ # print("\t Update Generator Cost:%5f" % (time2 - time1))
501
+
502
+ # self.zero_grad([self.opt_att_layer])
503
+ # self.backward_Att()
504
+ # self.loss_lgan_G_.backward()
505
+ # self.clip_norm([self.att_layer])
506
+ # self.step([self.opt_att_layer])
507
+ # # time3 = time.time()
508
+ # # print("\t Update Att Cost:%5f" % (time3 - time2))
509
+
510
+ # self.loss_gen += self.loss_lgan_G_
511
+
512
+ return loss_logs
513
+
514
+ def to(self, device):
515
+ if self.opt.is_train:
516
+ self.gan_criterion.to(device)
517
+ self.mse_criterion.to(device)
518
+ self.l1_criterion.to(device)
519
+ self.seq_post.to(device)
520
+ self.mov_enc.to(device)
521
+ self.text_enc.to(device)
522
+ self.mov_dec.to(device)
523
+ self.seq_pri.to(device)
524
+ self.att_layer.to(device)
525
+ self.seq_dec.to(device)
526
+
527
+ def train_mode(self):
528
+ if self.opt.is_train:
529
+ self.seq_post.train()
530
+ self.mov_enc.eval()
531
+ # self.motion_dis.train()
532
+ # self.movement_dis.train()
533
+ self.mov_dec.train()
534
+ self.text_enc.train()
535
+ self.seq_pri.train()
536
+ self.att_layer.train()
537
+ self.seq_dec.train()
538
+
539
+
540
+ def eval_mode(self):
541
+ if self.opt.is_train:
542
+ self.seq_post.eval()
543
+ self.mov_enc.eval()
544
+ # self.motion_dis.train()
545
+ # self.movement_dis.train()
546
+ self.mov_dec.eval()
547
+ self.text_enc.eval()
548
+ self.seq_pri.eval()
549
+ self.att_layer.eval()
550
+ self.seq_dec.eval()
551
+
552
+
553
+ def save(self, file_name, ep, total_it, sub_ep, sl_len):
554
+ state = {
555
+ # 'latent_dis': self.latent_dis.state_dict(),
556
+ # 'motion_dis': self.motion_dis.state_dict(),
557
+ 'text_enc': self.text_enc.state_dict(),
558
+ 'seq_post': self.seq_post.state_dict(),
559
+ 'att_layer': self.att_layer.state_dict(),
560
+ 'seq_dec': self.seq_dec.state_dict(),
561
+ 'seq_pri': self.seq_pri.state_dict(),
562
+ 'mov_enc': self.mov_enc.state_dict(),
563
+ 'mov_dec': self.mov_dec.state_dict(),
564
+
565
+ # 'opt_motion_dis': self.opt_motion_dis.state_dict(),
566
+ 'opt_mov_dec': self.opt_mov_dec.state_dict(),
567
+ 'opt_text_enc': self.opt_text_enc.state_dict(),
568
+ 'opt_seq_pri': self.opt_seq_pri.state_dict(),
569
+ 'opt_att_layer': self.opt_att_layer.state_dict(),
570
+ 'opt_seq_post': self.opt_seq_post.state_dict(),
571
+ 'opt_seq_dec': self.opt_seq_dec.state_dict(),
572
+ # 'opt_movement_dis': self.opt_movement_dis.state_dict(),
573
+
574
+ 'ep': ep,
575
+ 'total_it': total_it,
576
+ 'sub_ep': sub_ep,
577
+ 'sl_len': sl_len
578
+ }
579
+ torch.save(state, file_name)
580
+ return
581
+
582
+ def load(self, model_dir):
583
+ checkpoint = torch.load(model_dir, map_location=self.device)
584
+ if self.opt.is_train:
585
+ self.seq_post.load_state_dict(checkpoint['seq_post'])
586
+ # self.opt_latent_dis.load_state_dict(checkpoint['opt_latent_dis'])
587
+
588
+ self.opt_text_enc.load_state_dict(checkpoint['opt_text_enc'])
589
+ self.opt_seq_post.load_state_dict(checkpoint['opt_seq_post'])
590
+ self.opt_att_layer.load_state_dict(checkpoint['opt_att_layer'])
591
+ self.opt_seq_pri.load_state_dict(checkpoint['opt_seq_pri'])
592
+ self.opt_seq_dec.load_state_dict(checkpoint['opt_seq_dec'])
593
+ self.opt_mov_dec.load_state_dict(checkpoint['opt_mov_dec'])
594
+
595
+ self.text_enc.load_state_dict(checkpoint['text_enc'])
596
+ self.mov_dec.load_state_dict(checkpoint['mov_dec'])
597
+ self.seq_pri.load_state_dict(checkpoint['seq_pri'])
598
+ self.att_layer.load_state_dict(checkpoint['att_layer'])
599
+ self.seq_dec.load_state_dict(checkpoint['seq_dec'])
600
+ self.mov_enc.load_state_dict(checkpoint['mov_enc'])
601
+
602
+ return checkpoint['ep'], checkpoint['total_it'], checkpoint['sub_ep'], checkpoint['sl_len']
603
+
604
+ def train(self, train_dataset, val_dataset, plot_eval):
605
+ self.to(self.device)
606
+
607
+ self.opt_text_enc = optim.Adam(self.text_enc.parameters(), lr=self.opt.lr)
608
+ self.opt_seq_post = optim.Adam(self.seq_post.parameters(), lr=self.opt.lr)
609
+ self.opt_seq_pri = optim.Adam(self.seq_pri.parameters(), lr=self.opt.lr)
610
+ self.opt_att_layer = optim.Adam(self.att_layer.parameters(), lr=self.opt.lr)
611
+ self.opt_seq_dec = optim.Adam(self.seq_dec.parameters(), lr=self.opt.lr)
612
+
613
+ self.opt_mov_dec = optim.Adam(self.mov_dec.parameters(), lr=self.opt.lr*0.1)
614
+
615
+ epoch = 0
616
+ it = 0
617
+ if self.opt.dataset_name == 't2m':
618
+ schedule_len = 10
619
+ elif self.opt.dataset_name == 'kit':
620
+ schedule_len = 6
621
+ sub_ep = 0
622
+
623
+ if self.opt.is_continue:
624
+ model_dir = pjoin(self.opt.model_dir, 'latest.tar')
625
+ epoch, it, sub_ep, schedule_len = self.load(model_dir)
626
+
627
+ invalid = True
628
+ start_time = time.time()
629
+ val_loss = 0
630
+ is_continue_and_first = self.opt.is_continue
631
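+         # Curriculum over sequence length: train on clips capped at schedule_len
+         # movement units; after each sub-epoch sweep (with early stopping on the
+         # validation loss) the cap grows by one, until it exceeds 49.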
+ while invalid:
632
+ train_dataset.reset_max_len(schedule_len * self.opt.unit_length)
633
+ val_dataset.reset_max_len(schedule_len * self.opt.unit_length)
634
+
635
+ train_loader = DataLoader(train_dataset, batch_size=self.opt.batch_size, drop_last=True, num_workers=4,
636
+ shuffle=True, collate_fn=collate_fn, pin_memory=True)
637
+ val_loader = DataLoader(val_dataset, batch_size=self.opt.batch_size, drop_last=True, num_workers=4,
638
+ shuffle=True, collate_fn=collate_fn, pin_memory=True)
639
+ print("Max_Length:%03d Training Split:%05d Validation Split:%04d" % (schedule_len, len(train_loader), len(val_loader)))
640
+
641
+ min_val_loss = np.inf
642
+ stop_cnt = 0
643
+ logs = OrderedDict()
644
+ for sub_epoch in range(sub_ep, self.opt.max_sub_epoch):
645
+ self.train_mode()
646
+
647
+ if is_continue_and_first:
648
+ sub_ep = 0
649
+ is_continue_and_first = False
650
+
651
+ tf_ratio = self.opt.tf_ratio
652
+
653
+ time1 = time.time()
654
+ for i, batch_data in enumerate(train_loader):
655
+ time2 = time.time()
656
+ self.forward(batch_data, tf_ratio, schedule_len)
657
+ time3 = time.time()
658
+ log_dict = self.update()
659
+ for k, v in log_dict.items():
660
+ if k not in logs:
661
+ logs[k] = v
662
+ else:
663
+ logs[k] += v
664
+ time4 = time.time()
665
+
666
+
667
+ it += 1
668
+ if it % self.opt.log_every == 0:
669
+ mean_loss = OrderedDict({'val_loss': val_loss})
670
+ self.logger.scalar_summary('val_loss', val_loss, it)
671
+ self.logger.scalar_summary('scheduled_length', schedule_len, it)
672
+
673
+ for tag, value in logs.items():
674
+ self.logger.scalar_summary(tag, value/self.opt.log_every, it)
675
+ mean_loss[tag] = value / self.opt.log_every
676
+ logs = OrderedDict()
677
+ print_current_loss(start_time, it, mean_loss, epoch, sub_epoch=sub_epoch, inner_iter=i,
678
+ tf_ratio=tf_ratio, sl_steps=schedule_len)
679
+
680
+ if it % self.opt.save_latest == 0:
681
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it, sub_epoch, schedule_len)
682
+
683
+ time5 = time.time()
684
+ # print("Data Loader Time: %5f s" % ((time2 - time1)))
685
+ # print("Forward Time: %5f s" % ((time3 - time2)))
686
+ # print("Update Time: %5f s" % ((time4 - time3)))
687
+ # print('Per Iteration: %5f s' % ((time5 - time1)))
688
+ time1 = time5
689
+
690
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it, sub_epoch, schedule_len)
691
+
692
+ epoch += 1
693
+ if epoch % self.opt.save_every_e == 0:
694
+ self.save(pjoin(self.opt.model_dir, 'E%03d_SE%02d_SL%02d.tar'%(epoch, sub_epoch, schedule_len)),
695
+ epoch, total_it=it, sub_ep=sub_epoch, sl_len=schedule_len)
696
+
697
+ print('Validation time:')
698
+
699
+ loss_mot_rec = 0
700
+ loss_mov_rec = 0
701
+ loss_kld = 0
702
+ val_loss = 0
703
+ with torch.no_grad():
704
+ for i, batch_data in enumerate(val_loader):
705
+ self.forward(batch_data, 0, schedule_len)
706
+ self.backward_G()
707
+ loss_mot_rec += self.loss_mot_rec.item()
708
+ loss_mov_rec += self.loss_mov_rec.item()
709
+ loss_kld += self.loss_kld.item()
710
+ val_loss += self.loss_gen.item()
711
+
712
+ loss_mot_rec /= len(val_loader) + 1
713
+ loss_mov_rec /= len(val_loader) + 1
714
+ loss_kld /= len(val_loader) + 1
715
+ val_loss /= len(val_loader) + 1
716
+                 print('Validation Loss: %.5f Movement Recon Loss: %.5f Motion Recon Loss: %.5f KLD Loss: %.5f' %
717
+ (val_loss, loss_mov_rec, loss_mot_rec, loss_kld))
718
+
719
+ if epoch % self.opt.eval_every_e == 0:
720
+ reco_data = self.fake_motions[:4]
721
+ with torch.no_grad():
722
+ self.forward(batch_data, 0, schedule_len, eval_mode=True)
723
+ fake_data = self.fake_motions[:4]
724
+ gt_data = self.motions[:4]
725
+ data = torch.cat([fake_data, reco_data, gt_data], dim=0).cpu().numpy()
726
+ captions = self.caption[:4] * 3
727
+ save_dir = pjoin(self.opt.eval_dir, 'E%03d_SE%02d_SL%02d'%(epoch, sub_epoch, schedule_len))
728
+ os.makedirs(save_dir, exist_ok=True)
729
+ plot_eval(data, save_dir, captions)
730
+
731
+ # if cl_ratio == 1:
732
+ if val_loss < min_val_loss:
733
+ min_val_loss = val_loss
734
+ stop_cnt = 0
735
+ elif stop_cnt < self.opt.early_stop_count:
736
+ stop_cnt += 1
737
+ elif stop_cnt >= self.opt.early_stop_count:
738
+ break
739
+ if val_loss - min_val_loss >= 0.1:
740
+ break
741
+
742
+ schedule_len += 1
743
+
744
+ if schedule_len > 49:
745
+ invalid = False
746
+
747
+
748
+ class LengthEstTrainer(object):
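+     """Trains MotionLenEstimatorBiGRU to classify a caption into motion-length
+     bins of unit_length frames (cross-entropy over m_lens // unit_length)."""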
749
+
750
+ def __init__(self, args, estimator):
751
+ self.opt = args
752
+ self.estimator = estimator
753
+ self.device = args.device
754
+
755
+ if args.is_train:
756
+ # self.motion_dis
757
+ self.logger = Logger(args.log_dir)
758
+ self.mul_cls_criterion = torch.nn.CrossEntropyLoss()
759
+
760
+ def resume(self, model_dir):
761
+ checkpoints = torch.load(model_dir, map_location=self.device)
762
+ self.estimator.load_state_dict(checkpoints['estimator'])
763
+ self.opt_estimator.load_state_dict(checkpoints['opt_estimator'])
764
+ return checkpoints['epoch'], checkpoints['iter']
765
+
766
+ def save(self, model_dir, epoch, niter):
767
+ state = {
768
+ 'estimator': self.estimator.state_dict(),
769
+ 'opt_estimator': self.opt_estimator.state_dict(),
770
+ 'epoch': epoch,
771
+ 'niter': niter,
772
+ }
773
+ torch.save(state, model_dir)
774
+
775
+ @staticmethod
776
+ def zero_grad(opt_list):
777
+ for opt in opt_list:
778
+ opt.zero_grad()
779
+
780
+ @staticmethod
781
+ def clip_norm(network_list):
782
+ for network in network_list:
783
+ clip_grad_norm_(network.parameters(), 0.5)
784
+
785
+ @staticmethod
786
+ def step(opt_list):
787
+ for opt in opt_list:
788
+ opt.step()
789
+
790
+ def train(self, train_dataloader, val_dataloader):
791
+ self.estimator.to(self.device)
792
+
793
+ self.opt_estimator = optim.Adam(self.estimator.parameters(), lr=self.opt.lr)
794
+
795
+ epoch = 0
796
+ it = 0
797
+
798
+ if self.opt.is_continue:
799
+ model_dir = pjoin(self.opt.model_dir, 'latest.tar')
800
+ epoch, it = self.resume(model_dir)
801
+
802
+ start_time = time.time()
803
+ total_iters = self.opt.max_epoch * len(train_dataloader)
804
+ print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_dataloader), len(val_dataloader)))
805
+ val_loss = 0
806
+ min_val_loss = np.inf
807
+ logs = OrderedDict({'loss': 0})
808
+ while epoch < self.opt.max_epoch:
809
+ # time0 = time.time()
810
+ for i, batch_data in enumerate(train_dataloader):
811
+ self.estimator.train()
812
+
813
+ word_emb, pos_ohot, _, cap_lens, _, m_lens = batch_data
814
+ word_emb = word_emb.detach().to(self.device).float()
815
+ pos_ohot = pos_ohot.detach().to(self.device).float()
816
+
817
+ pred_dis = self.estimator(word_emb, pos_ohot, cap_lens)
818
+
819
+ self.zero_grad([self.opt_estimator])
820
+
821
+ gt_labels = m_lens // self.opt.unit_length
822
+ gt_labels = gt_labels.long().to(self.device)
823
+ # print(gt_labels)
824
+ # print(pred_dis)
825
+ loss = self.mul_cls_criterion(pred_dis, gt_labels)
826
+
827
+ loss.backward()
828
+
829
+ self.clip_norm([self.estimator])
830
+ self.step([self.opt_estimator])
831
+
832
+ logs['loss'] += loss.item()
833
+
834
+ it += 1
835
+ if it % self.opt.log_every == 0:
836
+ mean_loss = OrderedDict({'val_loss': val_loss})
837
+ self.logger.scalar_summary('val_loss', val_loss, it)
838
+
839
+ for tag, value in logs.items():
840
+ self.logger.scalar_summary(tag, value / self.opt.log_every, it)
841
+ mean_loss[tag] = value / self.opt.log_every
842
+ logs = OrderedDict({'loss': 0})
843
+ print_current_loss_decomp(start_time, it, total_iters, mean_loss, epoch, i)
844
+
845
+ if it % self.opt.save_latest == 0:
846
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
847
+
848
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
849
+
850
+ epoch += 1
851
+ if epoch % self.opt.save_every_e == 0:
852
+ self.save(pjoin(self.opt.model_dir, 'E%04d.tar' % (epoch)), epoch, it)
853
+
854
+ print('Validation time:')
855
+
856
+ val_loss = 0
857
+ with torch.no_grad():
858
+ for i, batch_data in enumerate(val_dataloader):
859
+ word_emb, pos_ohot, _, cap_lens, _, m_lens = batch_data
860
+ word_emb = word_emb.detach().to(self.device).float()
861
+ pos_ohot = pos_ohot.detach().to(self.device).float()
862
+
863
+ pred_dis = self.estimator(word_emb, pos_ohot, cap_lens)
864
+
865
+ gt_labels = m_lens // self.opt.unit_length
866
+ gt_labels = gt_labels.long().to(self.device)
867
+ loss = self.mul_cls_criterion(pred_dis, gt_labels)
868
+
869
+ val_loss += loss.item()
870
+
871
+ val_loss = val_loss / (len(val_dataloader) + 1)
872
+ print('Validation Loss: %.5f' % (val_loss))
873
+
874
+ if val_loss < min_val_loss:
875
+ self.save(pjoin(self.opt.model_dir, 'finest.tar'), epoch, it)
876
+ min_val_loss = val_loss
877
+
878
+
879
+ class TextMotionMatchTrainer(object):
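+     """Trains the text and motion co-embedding encoders with a contrastive loss:
+     matched pairs are pulled together, index-shifted mismatched pairs pushed apart."""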
880
+
881
+ def __init__(self, args, text_encoder, motion_encoder, movement_encoder):
882
+ self.opt = args
883
+ self.text_encoder = text_encoder
884
+ self.motion_encoder = motion_encoder
885
+ self.movement_encoder = movement_encoder
886
+ self.device = args.device
887
+
888
+ if args.is_train:
889
+ # self.motion_dis
890
+ self.logger = Logger(args.log_dir)
891
+ self.contrastive_loss = ContrastiveLoss(self.opt.negative_margin)
892
+
893
+ def resume(self, model_dir):
894
+ checkpoints = torch.load(model_dir, map_location=self.device)
895
+ self.text_encoder.load_state_dict(checkpoints['text_encoder'])
896
+ self.motion_encoder.load_state_dict(checkpoints['motion_encoder'])
897
+ self.movement_encoder.load_state_dict(checkpoints['movement_encoder'])
898
+
899
+ self.opt_text_encoder.load_state_dict(checkpoints['opt_text_encoder'])
900
+ self.opt_motion_encoder.load_state_dict(checkpoints['opt_motion_encoder'])
901
+ return checkpoints['epoch'], checkpoints['iter']
902
+
903
+ def save(self, model_dir, epoch, niter):
904
+ state = {
905
+ 'text_encoder': self.text_encoder.state_dict(),
906
+ 'motion_encoder': self.motion_encoder.state_dict(),
907
+ 'movement_encoder': self.movement_encoder.state_dict(),
908
+
909
+ 'opt_text_encoder': self.opt_text_encoder.state_dict(),
910
+ 'opt_motion_encoder': self.opt_motion_encoder.state_dict(),
911
+ 'epoch': epoch,
912
+ 'iter': niter,
913
+ }
914
+ torch.save(state, model_dir)
915
+
916
+ @staticmethod
917
+ def zero_grad(opt_list):
918
+ for opt in opt_list:
919
+ opt.zero_grad()
920
+
921
+ @staticmethod
922
+ def clip_norm(network_list):
923
+ for network in network_list:
924
+ clip_grad_norm_(network.parameters(), 0.5)
925
+
926
+ @staticmethod
927
+ def step(opt_list):
928
+ for opt in opt_list:
929
+ opt.step()
930
+
931
+ def to(self, device):
932
+ self.text_encoder.to(device)
933
+ self.motion_encoder.to(device)
934
+ self.movement_encoder.to(device)
935
+
936
+ def train_mode(self):
937
+ self.text_encoder.train()
938
+ self.motion_encoder.train()
939
+ self.movement_encoder.eval()
940
+
941
+ def forward(self, batch_data):
942
+ word_emb, pos_ohot, caption, cap_lens, motions, m_lens, _ = batch_data
943
+ word_emb = word_emb.detach().to(self.device).float()
944
+ pos_ohot = pos_ohot.detach().to(self.device).float()
945
+ motions = motions.detach().to(self.device).float()
946
+
947
+         # Sort motions by length in descending order (captions are already sorted by length)
948
+ self.align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
949
+ # print(self.align_idx)
950
+ # print(m_lens[self.align_idx])
951
+ motions = motions[self.align_idx]
952
+ m_lens = m_lens[self.align_idx]
953
+
954
+ '''Movement Encoding'''
955
+ movements = self.movement_encoder(motions[..., :-4]).detach()
956
+ m_lens = m_lens // self.opt.unit_length
957
+ self.motion_embedding = self.motion_encoder(movements, m_lens)
958
+
959
+ '''Text Encoding'''
960
+ # time0 = time.time()
961
+ # text_input = torch.cat([word_emb, pos_ohot], dim=-1)
962
+ self.text_embedding = self.text_encoder(word_emb, pos_ohot, cap_lens)
963
+ self.text_embedding = self.text_embedding.clone()[self.align_idx]
964
+
965
+
966
+ def backward(self):
967
+
968
+ batch_size = self.text_embedding.shape[0]
969
+ '''Positive pairs'''
970
+ pos_labels = torch.zeros(batch_size).to(self.text_embedding.device)
971
+ self.loss_pos = self.contrastive_loss(self.text_embedding, self.motion_embedding, pos_labels)
972
+
973
+ '''Negative Pairs, shifting index'''
974
+ neg_labels = torch.ones(batch_size).to(self.text_embedding.device)
975
+         shift = np.random.randint(1, batch_size)  # shift by at least 1 so no motion is paired with its own caption
976
+ new_idx = np.arange(shift, batch_size + shift) % batch_size
977
+ self.mis_motion_embedding = self.motion_embedding.clone()[new_idx]
978
+ self.loss_neg = self.contrastive_loss(self.text_embedding, self.mis_motion_embedding, neg_labels)
979
+ self.loss = self.loss_pos + self.loss_neg
980
+
981
+ loss_logs = OrderedDict({})
982
+ loss_logs['loss'] = self.loss.item()
983
+ loss_logs['loss_pos'] = self.loss_pos.item()
984
+ loss_logs['loss_neg'] = self.loss_neg.item()
985
+ return loss_logs
986
+
987
+
988
+ def update(self):
989
+
990
+ self.zero_grad([self.opt_motion_encoder, self.opt_text_encoder])
991
+ loss_logs = self.backward()
992
+ self.loss.backward()
993
+ self.clip_norm([self.text_encoder, self.motion_encoder])
994
+ self.step([self.opt_text_encoder, self.opt_motion_encoder])
995
+
996
+ return loss_logs
997
+
998
+
999
+ def train(self, train_dataloader, val_dataloader):
1000
+ self.to(self.device)
1001
+
1002
+ self.opt_motion_encoder = optim.Adam(self.motion_encoder.parameters(), lr=self.opt.lr)
1003
+ self.opt_text_encoder = optim.Adam(self.text_encoder.parameters(), lr=self.opt.lr)
1004
+
1005
+ epoch = 0
1006
+ it = 0
1007
+
1008
+ if self.opt.is_continue:
1009
+ model_dir = pjoin(self.opt.model_dir, 'latest.tar')
1010
+ epoch, it = self.resume(model_dir)
1011
+
1012
+ start_time = time.time()
1013
+ total_iters = self.opt.max_epoch * len(train_dataloader)
1014
+ print('Iters Per Epoch, Training: %04d, Validation: %03d' % (len(train_dataloader), len(val_dataloader)))
1015
+ val_loss = 0
1016
+ logs = OrderedDict()
1017
+
1018
+ min_val_loss = np.inf
1019
+ while epoch < self.opt.max_epoch:
1020
+ # time0 = time.time()
1021
+ for i, batch_data in enumerate(train_dataloader):
1022
+ self.train_mode()
1023
+
1024
+ self.forward(batch_data)
1025
+ # time3 = time.time()
1026
+ log_dict = self.update()
1027
+ for k, v in log_dict.items():
1028
+ if k not in logs:
1029
+ logs[k] = v
1030
+ else:
1031
+ logs[k] += v
1032
+
1033
+
1034
+ it += 1
1035
+ if it % self.opt.log_every == 0:
1036
+ mean_loss = OrderedDict({'val_loss': val_loss})
1037
+ self.logger.scalar_summary('val_loss', val_loss, it)
1038
+
1039
+ for tag, value in logs.items():
1040
+ self.logger.scalar_summary(tag, value / self.opt.log_every, it)
1041
+ mean_loss[tag] = value / self.opt.log_every
1042
+ logs = OrderedDict()
1043
+ print_current_loss_decomp(start_time, it, total_iters, mean_loss, epoch, i)
1044
+
1045
+ if it % self.opt.save_latest == 0:
1046
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
1047
+
1048
+ self.save(pjoin(self.opt.model_dir, 'latest.tar'), epoch, it)
1049
+
1050
+ epoch += 1
1051
+ if epoch % self.opt.save_every_e == 0:
1052
+ self.save(pjoin(self.opt.model_dir, 'E%04d.tar' % (epoch)), epoch, it)
1053
+
1054
+ print('Validation time:')
1055
+
1056
+ loss_pos_pair = 0
1057
+ loss_neg_pair = 0
1058
+ val_loss = 0
1059
+ with torch.no_grad():
1060
+ for i, batch_data in enumerate(val_dataloader):
1061
+ self.forward(batch_data)
1062
+ self.backward()
1063
+ loss_pos_pair += self.loss_pos.item()
1064
+ loss_neg_pair += self.loss_neg.item()
1065
+ val_loss += self.loss.item()
1066
+
1067
+ loss_pos_pair /= len(val_dataloader) + 1
1068
+ loss_neg_pair /= len(val_dataloader) + 1
1069
+ val_loss /= len(val_dataloader) + 1
1070
+ print('Validation Loss: %.5f Positive Loss: %.5f Negative Loss: %.5f' %
1071
+ (val_loss, loss_pos_pair, loss_neg_pair))
1072
+
1073
+ if val_loss < min_val_loss:
1074
+ self.save(pjoin(self.opt.model_dir, 'finest.tar'), epoch, it)
1075
+ min_val_loss = val_loss
1076
+
1077
+ if epoch % self.opt.eval_every_e == 0:
1078
+ pos_dist = F.pairwise_distance(self.text_embedding, self.motion_embedding)
1079
+ neg_dist = F.pairwise_distance(self.text_embedding, self.mis_motion_embedding)
1080
+
1081
+ pos_str = ' '.join(['%.3f' % (pos_dist[i]) for i in range(pos_dist.shape[0])])
1082
+ neg_str = ' '.join(['%.3f' % (neg_dist[i]) for i in range(neg_dist.shape[0])])
1083
+
1084
+ save_path = pjoin(self.opt.eval_dir, 'E%03d.txt' % (epoch))
1085
+ with cs.open(save_path, 'w') as f:
1086
+ f.write('Positive Pairs Distance\n')
1087
+ f.write(pos_str + '\n')
1088
+ f.write('Negative Pairs Distance\n')
1089
+ f.write(neg_str + '\n')
motion_diffusion_model/data_loaders/humanml/scripts/motion_process.py ADDED
@@ -0,0 +1,669 @@
1
+ from os.path import join as pjoin
2
+
3
+ from data_loaders.humanml.common.skeleton import Skeleton
4
+ import numpy as np
5
+ import os
6
+ from data_loaders.humanml.common.quaternion import *
7
+ from data_loaders.humanml.utils.paramUtil import *
8
+
9
+ import torch
10
+ from tqdm import tqdm
11
+ from data_loaders.humanml_utils import HML_JOINT_NAMES, HML_EE_JOINT_NAMES
12
+
13
+ import random
14
+ from copy import copy, deepcopy
15
+
16
+ # positions (batch, joint_num, 3)
17
+ def uniform_skeleton(positions, target_offset):
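+     """Retarget a joint-position sequence to the bone lengths in target_offset:
+     scale the root trajectory by the leg-length ratio, run inverse kinematics on
+     the source skeleton, then forward kinematics with the target offsets."""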
18
+ src_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
19
+ src_offset = src_skel.get_offsets_joints(torch.from_numpy(positions[0]))
20
+ src_offset = src_offset.numpy()
21
+ tgt_offset = target_offset.numpy()
22
+ # print(src_offset)
23
+ # print(tgt_offset)
24
+ '''Calculate Scale Ratio as the ratio of legs'''
25
+ src_leg_len = np.abs(src_offset[l_idx1]).max() + np.abs(src_offset[l_idx2]).max()
26
+ tgt_leg_len = np.abs(tgt_offset[l_idx1]).max() + np.abs(tgt_offset[l_idx2]).max()
27
+
28
+ scale_rt = tgt_leg_len / src_leg_len
29
+ # print(scale_rt)
30
+ src_root_pos = positions[:, 0]
31
+ tgt_root_pos = src_root_pos * scale_rt
32
+
33
+ '''Inverse Kinematics'''
34
+ quat_params = src_skel.inverse_kinematics_np(positions, face_joint_indx)
35
+ # print(quat_params.shape)
36
+
37
+ '''Forward Kinematics'''
38
+ src_skel.set_offset(target_offset)
39
+ new_joints = src_skel.forward_kinematics_np(quat_params, tgt_root_pos)
40
+ return new_joints
41
+
42
+
43
+ def extract_features(positions, feet_thre, n_raw_offsets, kinematic_chain, face_joint_indx, fid_r, fid_l):
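+     """Assemble the HumanML3D-style per-frame feature vector: root rotation
+     velocity, root xz linear velocity and height, local joint positions,
+     continuous-6D joint rotations, local joint velocities, and foot contacts."""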
44
+ global_positions = positions.copy()
45
+ """ Get Foot Contacts """
46
+
47
+ def foot_detect(positions, thres):
48
+ velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0])
49
+
50
+ feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2
51
+ feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2
52
+ feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2
53
+ # feet_l_h = positions[:-1,fid_l,1]
54
+ # feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float)
55
+         feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(float)  # np.float was removed in NumPy >= 1.24
56
+
57
+ feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2
58
+ feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2
59
+ feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2
60
+ # feet_r_h = positions[:-1,fid_r,1]
61
+ # feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float)
62
+         feet_r = ((feet_r_x + feet_r_y + feet_r_z) < velfactor).astype(float)  # np.float was removed in NumPy >= 1.24
63
+ return feet_l, feet_r
64
+
65
+ #
66
+ feet_l, feet_r = foot_detect(positions, feet_thre)
67
+ # feet_l, feet_r = foot_detect(positions, 0.002)
68
+
69
+ '''Quaternion and Cartesian representation'''
70
+ r_rot = None
71
+
72
+ def get_rifke(positions):
73
+ '''Local pose'''
74
+ positions[..., 0] -= positions[:, 0:1, 0]
75
+ positions[..., 2] -= positions[:, 0:1, 2]
76
+ '''All pose face Z+'''
77
+ positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions)
78
+ return positions
79
+
80
+ def get_quaternion(positions):
81
+ skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
82
+ # (seq_len, joints_num, 4)
83
+ quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False)
84
+
85
+ '''Fix Quaternion Discontinuity'''
86
+ quat_params = qfix(quat_params)
87
+ # (seq_len, 4)
88
+ r_rot = quat_params[:, 0].copy()
89
+ # print(r_rot[0])
90
+ '''Root Linear Velocity'''
91
+ # (seq_len - 1, 3)
92
+ velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
93
+ # print(r_rot.shape, velocity.shape)
94
+ velocity = qrot_np(r_rot[1:], velocity)
95
+ '''Root Angular Velocity'''
96
+ # (seq_len - 1, 4)
97
+ r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
98
+ quat_params[1:, 0] = r_velocity
99
+ # (seq_len, joints_num, 4)
100
+ return quat_params, r_velocity, velocity, r_rot
101
+
102
+ def get_cont6d_params(positions):
103
+ skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
104
+ # (seq_len, joints_num, 4)
105
+ quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True)
106
+
107
+ '''Quaternion to continuous 6D'''
108
+ cont_6d_params = quaternion_to_cont6d_np(quat_params)
109
+ # (seq_len, 4)
110
+ r_rot = quat_params[:, 0].copy()
111
+ # print(r_rot[0])
112
+ '''Root Linear Velocity'''
113
+ # (seq_len - 1, 3)
114
+ velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
115
+ # print(r_rot.shape, velocity.shape)
116
+ velocity = qrot_np(r_rot[1:], velocity)
117
+ '''Root Angular Velocity'''
118
+ # (seq_len - 1, 4)
119
+ r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
120
+ # (seq_len, joints_num, 4)
121
+ return cont_6d_params, r_velocity, velocity, r_rot
122
+
123
+ cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions)
124
+ positions = get_rifke(positions)
125
+
126
+ # trejec = np.cumsum(np.concatenate([np.array([[0, 0, 0]]), velocity], axis=0), axis=0)
127
+ # r_rotations, r_pos = recover_ric_glo_np(r_velocity, velocity[:, [0, 2]])
128
+
129
+ # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*')
130
+ # plt.plot(ground_positions[:, 0, 0], ground_positions[:, 0, 2], marker='o', color='r')
131
+ # plt.plot(trejec[:, 0], trejec[:, 2], marker='^', color='g')
132
+ # plt.plot(r_pos[:, 0], r_pos[:, 2], marker='s', color='y')
133
+ # plt.xlabel('x')
134
+ # plt.ylabel('z')
135
+ # plt.axis('equal')
136
+ # plt.show()
137
+
138
+ '''Root height'''
139
+ root_y = positions[:, 0, 1:2]
140
+
141
+ '''Root rotation and linear velocity'''
142
+ # (seq_len-1, 1) rotation velocity along y-axis
143
+ # (seq_len-1, 2) linear velocity on xz plane
144
+ r_velocity = np.arcsin(r_velocity[:, 2:3])
145
+ l_velocity = velocity[:, [0, 2]]
146
+ # print(r_velocity.shape, l_velocity.shape, root_y.shape)
147
+ root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1)
148
+
149
+ '''Get Joint Rotation Representation'''
150
+ # (seq_len, (joints_num-1)*6) continuous 6D rotation for skeleton joints
151
+ rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1)
152
+
153
+ '''Get Joint Rotation Invariant Position Representation'''
154
+ # (seq_len, (joints_num-1)*3) local joint position
155
+ ric_data = positions[:, 1:].reshape(len(positions), -1)
156
+
157
+ '''Get Joint Velocity Representation'''
158
+ # (seq_len-1, joints_num*3)
159
+ local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1),
160
+ global_positions[1:] - global_positions[:-1])
161
+ local_vel = local_vel.reshape(len(local_vel), -1)
162
+
163
+ data = root_data
164
+ data = np.concatenate([data, ric_data[:-1]], axis=-1)
165
+ data = np.concatenate([data, rot_data[:-1]], axis=-1)
166
+ # print(dataset.shape, local_vel.shape)
167
+ data = np.concatenate([data, local_vel], axis=-1)
168
+ data = np.concatenate([data, feet_l, feet_r], axis=-1)
169
+
170
+ return data
171
+
172
+
173
+ def process_file(positions, feet_thre):
174
+ # (seq_len, joints_num, 3)
175
+ # '''Down Sample'''
176
+ # positions = positions[::ds_num]
177
+
178
+ '''Uniform Skeleton'''
179
+ positions = uniform_skeleton(positions, tgt_offsets)
180
+
181
+ '''Put on Floor'''
182
+ floor_height = positions.min(axis=0).min(axis=0)[1]
183
+ positions[:, :, 1] -= floor_height
184
+ # print(floor_height)
185
+
186
+ # plot_3d_motion("./positions_1.mp4", kinematic_chain, positions, 'title', fps=20)
187
+
188
+ '''XZ at origin'''
189
+ root_pos_init = positions[0]
190
+ root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1])
191
+ positions = positions - root_pose_init_xz
192
+
193
+ # '''Move the first pose to origin '''
194
+ # root_pos_init = positions[0]
195
+ # positions = positions - root_pos_init[0]
196
+
197
+ '''All initially face Z+'''
198
+ r_hip, l_hip, sdr_r, sdr_l = face_joint_indx
199
+ across1 = root_pos_init[r_hip] - root_pos_init[l_hip]
200
+ across2 = root_pos_init[sdr_r] - root_pos_init[sdr_l]
201
+ across = across1 + across2
202
+ across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis]
203
+
204
+ # forward (3,), rotate around y-axis
205
+ forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
206
+ # forward (3,)
207
+ forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis]
208
+
209
+ # print(forward_init)
210
+
211
+ target = np.array([[0, 0, 1]])
212
+ root_quat_init = qbetween_np(forward_init, target)
213
+ root_quat_init = np.ones(positions.shape[:-1] + (4,)) * root_quat_init
214
+
215
+ positions_b = positions.copy()
216
+
217
+ positions = qrot_np(root_quat_init, positions)
218
+
219
+ # plot_3d_motion("./positions_2.mp4", kinematic_chain, positions, 'title', fps=20)
220
+
221
+ '''New ground truth positions'''
222
+ global_positions = positions.copy()
223
+
224
+ # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*')
225
+ # plt.plot(positions[:, 0, 0], positions[:, 0, 2], marker='o', color='r')
226
+ # plt.xlabel('x')
227
+ # plt.ylabel('z')
228
+ # plt.axis('equal')
229
+ # plt.show()
230
+
231
+ """ Get Foot Contacts """
232
+
233
+ def foot_detect(positions, thres):
234
+ velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0])
235
+
236
+ feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2
237
+ feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2
238
+ feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2
239
+ # feet_l_h = positions[:-1,fid_l,1]
240
+ # feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float)
241
+ feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(float)
242
+
243
+ feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2
244
+ feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2
245
+ feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2
246
+ # feet_r_h = positions[:-1,fid_r,1]
247
+ # feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float)
248
+ feet_r = ((feet_r_x + feet_r_y + feet_r_z) < velfactor).astype(float)
249
+ return feet_l, feet_r
250
+ #
251
+ feet_l, feet_r = foot_detect(positions, feet_thre)
252
+ # feet_l, feet_r = foot_detect(positions, 0.002)
253
+
254
+ '''Quaternion and Cartesian representation'''
255
+ r_rot = None
256
+
257
+ def get_rifke(positions):
258
+ '''Local pose'''
259
+ positions[..., 0] -= positions[:, 0:1, 0]
260
+ positions[..., 2] -= positions[:, 0:1, 2]
261
+ '''All pose face Z+'''
262
+ positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions)
263
+ return positions
264
+
265
+ def get_quaternion(positions):
266
+ skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
267
+ # (seq_len, joints_num, 4)
268
+ quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False)
269
+
270
+ '''Fix Quaternion Discontinuity'''
271
+ quat_params = qfix(quat_params)
272
+ # (seq_len, 4)
273
+ r_rot = quat_params[:, 0].copy()
274
+ # print(r_rot[0])
275
+ '''Root Linear Velocity'''
276
+ # (seq_len - 1, 3)
277
+ velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
278
+ # print(r_rot.shape, velocity.shape)
279
+ velocity = qrot_np(r_rot[1:], velocity)
280
+ '''Root Angular Velocity'''
281
+ # (seq_len - 1, 4)
282
+ r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
283
+ quat_params[1:, 0] = r_velocity
284
+ # (seq_len, joints_num, 4)
285
+ return quat_params, r_velocity, velocity, r_rot
286
+
287
+ def get_cont6d_params(positions):
288
+ skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
289
+ # (seq_len, joints_num, 4)
290
+ quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True)
291
+
292
+ '''Quaternion to continuous 6D'''
293
+ cont_6d_params = quaternion_to_cont6d_np(quat_params)
294
+ # (seq_len, 4)
295
+ r_rot = quat_params[:, 0].copy()
296
+ # print(r_rot[0])
297
+ '''Root Linear Velocity'''
298
+ # (seq_len - 1, 3)
299
+ velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
300
+ # print(r_rot.shape, velocity.shape)
301
+ velocity = qrot_np(r_rot[1:], velocity)
302
+ '''Root Angular Velocity'''
303
+ # (seq_len - 1, 4)
304
+ r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
305
+ # (seq_len, joints_num, 4)
306
+ return cont_6d_params, r_velocity, velocity, r_rot
307
+
308
+ cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions)
309
+ positions = get_rifke(positions)
310
+
311
+ # trejec = np.cumsum(np.concatenate([np.array([[0, 0, 0]]), velocity], axis=0), axis=0)
312
+ # r_rotations, r_pos = recover_ric_glo_np(r_velocity, velocity[:, [0, 2]])
313
+
314
+ # plt.plot(positions_b[:, 0, 0], positions_b[:, 0, 2], marker='*')
315
+ # plt.plot(ground_positions[:, 0, 0], ground_positions[:, 0, 2], marker='o', color='r')
316
+ # plt.plot(trejec[:, 0], trejec[:, 2], marker='^', color='g')
317
+ # plt.plot(r_pos[:, 0], r_pos[:, 2], marker='s', color='y')
318
+ # plt.xlabel('x')
319
+ # plt.ylabel('z')
320
+ # plt.axis('equal')
321
+ # plt.show()
322
+
323
+ '''Root height'''
324
+ root_y = positions[:, 0, 1:2]
325
+
326
+ '''Root rotation and linear velocity'''
327
+ # (seq_len-1, 1) rotation velocity along y-axis
328
+ # (seq_len-1, 2) linear velocity on xz plane
329
+ r_velocity = np.arcsin(r_velocity[:, 2:3])
330
+ l_velocity = velocity[:, [0, 2]]
331
+ # print(r_velocity.shape, l_velocity.shape, root_y.shape)
332
+ root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1)
333
+
334
+ '''Get Joint Rotation Representation'''
335
+ # (seq_len, (joints_num-1)*6) continuous 6D rotation for skeleton joints
336
+ rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1)
337
+
338
+ '''Get Joint Rotation Invariant Position Representation'''
339
+ # (seq_len, (joints_num-1)*3) local joint position
340
+ ric_data = positions[:, 1:].reshape(len(positions), -1)
341
+
342
+ '''Get Joint Velocity Representation'''
343
+ # (seq_len-1, joints_num*3)
344
+ local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1),
345
+ global_positions[1:] - global_positions[:-1])
346
+ local_vel = local_vel.reshape(len(local_vel), -1)
347
+
348
+ data = root_data
349
+ data = np.concatenate([data, ric_data[:-1]], axis=-1)
350
+ data = np.concatenate([data, rot_data[:-1]], axis=-1)
351
+ # print(dataset.shape, local_vel.shape)
352
+ data = np.concatenate([data, local_vel], axis=-1)
353
+ data = np.concatenate([data, feet_l, feet_r], axis=-1)
354
+
355
+ return data, global_positions, positions, l_velocity
356
+
357
+
358
+ # Recover global angle and positions for rotation dataset
359
+ # root_rot_velocity (B, seq_len, 1)
360
+ # root_linear_velocity (B, seq_len, 2)
361
+ # root_y (B, seq_len, 1)
362
+ # ric_data (B, seq_len, (joint_num - 1)*3)
363
+ # rot_data (B, seq_len, (joint_num - 1)*6)
364
+ # local_velocity (B, seq_len, joint_num*3)
365
+ # foot contact (B, seq_len, 4)
366
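+ # Editor's sketch (not in the original file): the layout above fixes the
+ # per-frame feature size that recover_rot() below checks against:
+ #   22 joints (HumanML3D): 1 + 2 + 1 + 21*3 + 21*6 + 22*3 + 4 = 263
+ #   21 joints (KIT):       1 + 2 + 1 + 20*3 + 20*6 + 21*3 + 4 = 251
+ def _hml_feature_dim(joints_num):
+     return 1 + 2 + 1 + (joints_num - 1) * 3 + (joints_num - 1) * 6 + joints_num * 3 + 4
+ assert _hml_feature_dim(22) == 263 and _hml_feature_dim(21) == 251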
+ def recover_root_rot_pos(data):
367
+ rot_vel = data[..., 0]
368
+ r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
369
+ '''Get Y-axis rotation from rotation velocity'''
370
+ r_rot_ang[..., 1:] = rot_vel[..., :-1]
371
+ r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
372
+
373
+ r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
374
+ r_rot_quat[..., 0] = torch.cos(r_rot_ang)
375
+ r_rot_quat[..., 2] = torch.sin(r_rot_ang)
376
+
377
+ r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
378
+ r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
379
+ '''Add Y-axis rotation to root position'''
380
+ r_pos = qrot(qinv(r_rot_quat), r_pos)
381
+
382
+ r_pos = torch.cumsum(r_pos, dim=-2)
383
+
384
+ r_pos[..., 1] = data[..., 3]
385
+ return r_rot_quat, r_pos
386
+
387
+
388
+ def recover_root_rot_heading_ang(joints):
389
+
390
+ '''Get Forward Direction'''
391
+ face_joint_idx = [2, 1, 17, 16]
392
+ # l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
393
+ r_hip, l_hip, sdr_r, sdr_l = face_joint_idx # Note the bugfix
394
+ across1 = joints[:, r_hip] - joints[:, l_hip]
395
+ across2 = joints[:, sdr_r] - joints[:, sdr_l]
396
+ across = across1 + across2
397
+ across = torch.nn.functional.normalize(across, dim=1)
398
+ # print(across1.shape, across2.shape)
399
+
400
+ # forward (batch_size, 3)
401
+ forward = torch.cross(torch.tensor([[[0], [1], [0]]], dtype=across.dtype, device=across.device), across, dim=1)
402
+ forward = torch.nn.functional.normalize(forward, dim=1)
403
+
404
+ return torch.atan2(forward[:, 0], forward[:, 2])[:, None]
405
+
406
+ def recover_from_rot(data, joints_num, skeleton):
407
+ r_rot_quat, r_pos = recover_root_rot_pos(data)
408
+
409
+ r_rot_cont6d = quaternion_to_cont6d(r_rot_quat)
410
+
411
+ start_indx = 1 + 2 + 1 + (joints_num - 1) * 3
412
+ end_indx = start_indx + (joints_num - 1) * 6
413
+ cont6d_params = data[..., start_indx:end_indx]
414
+ # print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape)
415
+ cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1)
416
+ cont6d_params = cont6d_params.view(-1, joints_num, 6)
417
+
418
+ positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos)
419
+
420
+ return positions
421
+
422
+ def recover_rot(data):
423
+ # data [bs, seqlen, 263/251] HumanML3D/KIT
424
+ joints_num = 22 if data.shape[-1] == 263 else 21
425
+ r_rot_quat, r_pos = recover_root_rot_pos(data)
426
+ r_pos_pad = torch.cat([r_pos, torch.zeros_like(r_pos)], dim=-1).unsqueeze(-2)
427
+ r_rot_cont6d = quaternion_to_cont6d(r_rot_quat)
428
+ start_indx = 1 + 2 + 1 + (joints_num - 1) * 3
429
+ end_indx = start_indx + (joints_num - 1) * 6
430
+ cont6d_params = data[..., start_indx:end_indx]
431
+ cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1)
432
+ cont6d_params = cont6d_params.view(-1, joints_num, 6)
433
+ cont6d_params = torch.cat([cont6d_params, r_pos_pad], dim=-2)
434
+ return cont6d_params
435
+
436
+
437
+ def recover_from_ric(data, joints_num):
438
+ r_rot_quat, r_pos = recover_root_rot_pos(data)
439
+ positions = data[..., 4:(joints_num - 1) * 3 + 4]
440
+ positions = positions.view(positions.shape[:-1] + (-1, 3))
441
+
442
+ '''Add Y-axis rotation to local joints'''
443
+ positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)
444
+
445
+ '''Add root XZ to joints'''
446
+ positions[..., 0] += r_pos[..., 0:1]
447
+ positions[..., 2] += r_pos[..., 2:3]
448
+
449
+ '''Concatenate root and joints'''
450
+ positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
451
+
452
+ return positions
453
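+ # Usage sketch (editor's note): for a denormalized HumanML3D feature tensor
+ # `sample` of shape [..., seq_len, 263]:
+ #     joints = recover_from_ric(sample, joints_num=22)  # [..., seq_len, 22, 3]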
+ '''
454
+ For Text2Motion Dataset
455
+ '''
456
+ '''
457
+ if __name__ == "__main__":
458
+ example_id = "000021"
459
+ # Lower legs
460
+ l_idx1, l_idx2 = 5, 8
461
+ # Right/Left foot
462
+ fid_r, fid_l = [8, 11], [7, 10]
463
+ # Face direction, r_hip, l_hip, sdr_r, sdr_l
464
+ face_joint_indx = [2, 1, 17, 16]
465
+ # l_hip, r_hip
466
+ r_hip, l_hip = 2, 1
467
+ joints_num = 22
468
+ # ds_num = 8
469
+ data_dir = '../dataset/pose_data_raw/joints/'
470
+ save_dir1 = '../dataset/pose_data_raw/new_joints/'
471
+ save_dir2 = '../dataset/pose_data_raw/new_joint_vecs/'
472
+
473
+ n_raw_offsets = torch.from_numpy(t2m_raw_offsets)
474
+ kinematic_chain = t2m_kinematic_chain
475
+
476
+ # Get offsets of target skeleton
477
+ example_data = np.load(os.path.join(data_dir, example_id + '.npy'))
478
+ example_data = example_data.reshape(len(example_data), -1, 3)
479
+ example_data = torch.from_numpy(example_data)
480
+ tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
481
+ # (joints_num, 3)
482
+ tgt_offsets = tgt_skel.get_offsets_joints(example_data[0])
483
+ # print(tgt_offsets)
484
+
485
+ source_list = os.listdir(data_dir)
486
+ frame_num = 0
487
+ for source_file in tqdm(source_list):
488
+ source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num]
489
+ try:
490
+ dataset, ground_positions, positions, l_velocity = process_file(source_data, 0.002)
491
+ rec_ric_data = recover_from_ric(torch.from_numpy(dataset).unsqueeze(0).float(), joints_num)
492
+ np.save(pjoin(save_dir1, source_file), rec_ric_data.squeeze().numpy())
493
+ np.save(pjoin(save_dir2, source_file), dataset)
494
+ frame_num += dataset.shape[0]
495
+ except Exception as e:
496
+ print(source_file)
497
+ print(e)
498
+
499
+ print('Total clips: %d, Frames: %d, Duration: %fm' %
500
+ (len(source_list), frame_num, frame_num / 20 / 60))
501
+ '''
502
+
503
+ if __name__ == "__main__":
504
+ example_id = "03950_gt"
505
+ # Lower legs
506
+ l_idx1, l_idx2 = 17, 18
507
+ # Right/Left foot
508
+ fid_r, fid_l = [14, 15], [19, 20]
509
+ # Face direction, r_hip, l_hip, sdr_r, sdr_l
510
+ face_joint_indx = [11, 16, 5, 8]
511
+ # l_hip, r_hip
512
+ r_hip, l_hip = 11, 16
513
+ joints_num = 21
514
+ # ds_num = 8
515
+ data_dir = '../dataset/kit_mocap_dataset/joints/'
516
+ save_dir1 = '../dataset/kit_mocap_dataset/new_joints/'
517
+ save_dir2 = '../dataset/kit_mocap_dataset/new_joint_vecs/'
518
+
519
+ n_raw_offsets = torch.from_numpy(kit_raw_offsets)
520
+ kinematic_chain = kit_kinematic_chain
521
+
522
+ '''Get offsets of target skeleton'''
523
+ example_data = np.load(os.path.join(data_dir, example_id + '.npy'))
524
+ example_data = example_data.reshape(len(example_data), -1, 3)
525
+ example_data = torch.from_numpy(example_data)
526
+ tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
527
+ # (joints_num, 3)
528
+ tgt_offsets = tgt_skel.get_offsets_joints(example_data[0])
529
+ # print(tgt_offsets)
530
+
531
+ source_list = os.listdir(data_dir)
532
+ frame_num = 0
533
+ '''Read source dataset'''
534
+ for source_file in tqdm(source_list):
535
+ source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num]
536
+ try:
537
+ name = ''.join(source_file[:-7].split('_')) + '.npy'
538
+ data, ground_positions, positions, l_velocity = process_file(source_data, 0.05)
539
+ rec_ric_data = recover_from_ric(torch.from_numpy(data).unsqueeze(0).float(), joints_num)
540
+ if np.isnan(rec_ric_data.numpy()).any():
541
+ print(source_file)
542
+ continue
543
+ np.save(pjoin(save_dir1, name), rec_ric_data.squeeze().numpy())
544
+ np.save(pjoin(save_dir2, name), data)
545
+ frame_num += data.shape[0]
546
+ except Exception as e:
547
+ print(source_file)
548
+ print(e)
549
+
550
+ print('Total clips: %d, Frames: %d, Duration: %fm' %
551
+ (len(source_list), frame_num, frame_num / 12.5 / 60))
552
+
553
+
554
+ def traj_global2vel(traj_positions, traj_rot):
555
+
556
+ # traj_positions [bs, 2 (x,z), seqlen]
557
+ # traj_rot [bs, 1 (heading about y+, rad), seqlen]
558
+ # return the first 3 hml entries [bs, 3, seqlen-1]
559
+
560
+ # skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
561
+ # # (seq_len, joints_num, 4)
562
+ # quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True)
563
+
564
+ bs, _, seqlen = traj_positions.shape
565
+ traj_positions = traj_positions.permute(0, 2, 1)
566
+ euler = torch.zeros([bs, 3, seqlen], dtype=traj_rot.dtype, device=traj_rot.device)
567
+ euler[:, 1:2] = traj_rot
568
+ euler = euler.permute(0, 2, 1).contiguous()
569
+ traj_rot_quat = euler2quat(euler, 'yxz', deg=False)
570
+
571
+ # '''Quaternion to continuous 6D'''
572
+ # cont_6d_params = quaternion_to_cont6d_np(quat_params)
573
+ # # (seq_len, 4)
574
+ r_rot = traj_rot_quat.clone()
575
+ # print(r_rot[0])
576
+ '''Root Linear Velocity'''
577
+ # (seq_len - 1, 3)
578
+ velocity = torch.zeros_like(euler[:, 1:, :])
579
+ velocity[:, :, [0,2]] = (traj_positions[:, 1:, :] - traj_positions[:, :-1, :]).clone()
580
+ # print(r_rot.shape, velocity.shape)
581
+ velocity = qrot(r_rot[:, 1:], velocity)
582
+ '''Root Angular Velocity'''
583
+ # (seq_len - 1, 4)
584
+ r_velocity = qmul(r_rot[:, 1:].contiguous(), qinv(r_rot[:, :-1]))
585
+ # (seq_len, joints_num, 4)
586
+
587
+ r_velocity = torch.arcsin(r_velocity[:, :, 2:3])
588
+ l_velocity = velocity[:, :, [0, 2]]
589
+ # print(r_velocity.shape, l_velocity.shape, root_y.shape)
590
+ root_data = torch.cat([r_velocity, l_velocity], axis=-1).permute(0, 2, 1)[:, :, None]
591
+
592
+ return root_data
593
+
594
+ def get_target_location(motion, mean, std, lengths, joints_num, all_goal_joint_names, target_joint_names, is_heading):
595
+ assert (lengths == lengths[0]).all(), 'currently supporting only fixed length'
596
+ batch_size = motion.shape[0]
597
+ extended_goal_joint_names = all_goal_joint_names + ['traj', 'heading'] # todo: fix hardcoded indexing that assumes traj and heading are last
598
+
599
+ # output tensor
600
+ target_loc = torch.zeros((batch_size, len(extended_goal_joint_names), 3, lengths[0]), dtype=motion.dtype, device=motion.device) # n_samples x (n_target_joints+1) x 3 x n_frames
601
+
602
+ # hml to abs loc (all joints, not only the requested ones)
603
+ joints_loc = hml_to_abs_loc(motion, mean, std, joints_num)
604
+ pelvis_loc = HML_JOINT_NAMES.index('pelvis')
605
+ joints_loc = torch.concat([joints_loc, joints_loc[:, pelvis_loc:pelvis_loc+1]], dim=1) # concatenate the pelvis location to be used for traj
606
+
607
+ # joint names to indices
608
+ HML_JOINT_NAMES_w_traj = HML_JOINT_NAMES + ['traj']
609
+ for sample_idx in range(batch_size):
610
+ req_joint_idx_in = [HML_JOINT_NAMES_w_traj.index(name) for name in target_joint_names[sample_idx]]
611
+ req_joint_idx_out = [extended_goal_joint_names.index(name) for name in target_joint_names[sample_idx]]
612
+
613
+ target_loc[sample_idx, req_joint_idx_out] = joints_loc[sample_idx, req_joint_idx_in] # assign joints loc to output tensor
614
+
615
+ target_loc[:, -2, 1] = 0 # zero the y axis for the trajectory
616
+
617
+ # last entry is the heading
618
+ heading = recover_root_rot_heading_ang(joints_loc)
619
+ target_loc[:, -1:, 0][is_heading] = heading[is_heading]
620
+
621
+ return target_loc[..., -1] # return last frame only
622
+
623
+
624
+ def hml_to_abs_loc(motion, mean, std, joints_num):
625
+ # hml to abs loc (all joints, not only the requested ones)
626
+ unnormed_motion = (motion * std + mean).permute(0, 2, 3, 1).float()
627
+ joints_loc = recover_from_ric(unnormed_motion, joints_num)
628
+ joints_loc = joints_loc.view(-1, *joints_loc.shape[2:]).permute(0, 2, 3, 1) # n_samples x n_joints x 3 x n_frames
629
+ return joints_loc
630
+
631
+
632
+ def sample_goal(batch_size, device, force_joints=None):
633
+ if force_joints is None:
634
+ choices = np.array(['None', 'traj', 'pelvis'] + HML_EE_JOINT_NAMES) # todo: fix hardcoded 'pelvis' ('traj' is ok because it's our convention)
635
+ none_prob = 0.5 # todo: maybe convert to an argument
636
+ probabilities = torch.ones(len(choices)) * (1-none_prob) / (len(choices) -1)
637
+ probabilities[0] = none_prob # None's probability
638
+ assert abs(probabilities.sum() - 1) < 1e-6, 'probabilities should sum to 1'
639
+ max_goal_joints_per_sample = 2
640
+ # target_cond_idx = torch.randint(low=0, high=len(choices), size=(batch_size,max_goal_joints_per_sample))
641
+ target_cond_idx = torch.multinomial(probabilities, max_goal_joints_per_sample * batch_size, replacement=True).view(batch_size, max_goal_joints_per_sample)
642
+ names = choices[target_cond_idx]
643
+ names = np.array([np.unique(name) for name in names], dtype=object)  # dtype=object: per-sample lists may be ragged
644
+ names = np.array([np.delete(name, np.argwhere(name=='None')) for name in names], dtype=object)
645
+ is_heading = torch.bernoulli(torch.ones(batch_size, device=device) * .5).to(bool)
646
+ else:
647
+ options = get_allowed_joint_options(force_joints)
648
+ names = [copy(random.choice(options)) for _ in range(batch_size)]
649
+ is_heading = torch.zeros(batch_size, device=device).to(bool)
650
+ for i, n in enumerate(names):
651
+ if 'heading' in n:
652
+ is_heading[i] = True
653
+ del n[n.index('heading')]
654
+ return names, is_heading
655
+
656
+ def get_allowed_joint_options(config_name):
657
+ if config_name == 'DIMP_FULL':
658
+ return [['pelvis', 'heading'], ['pelvis', 'head'], ['traj', 'heading'], ['right_wrist', 'heading'], ['left_wrist', 'heading'], ['right_foot', 'heading'], ['left_foot', 'heading']]
659
+ elif config_name == 'DIMP_FINAL':
660
+ return [['pelvis', 'heading'], ['traj', 'heading'], ['right_wrist', 'heading'], ['left_wrist', 'heading'], ['right_foot', 'heading'], ['left_foot', 'heading'], []]
661
+ elif config_name == 'DIMP_SLIM':
662
+ return [['pelvis', 'heading'], ['pelvis', 'head'], ['traj', 'heading'], ['left_wrist', 'heading'], ['left_foot', 'heading']]
663
+ elif config_name == 'DIMP_BENCH':
664
+ return [['pelvis', 'heading'], ['pelvis', 'head']]
665
+ elif config_name == 'PURE_T2M':
666
+ return [[]]
667
+ else:
668
+ return [config_name.split(',')]
669
+
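+ # Usage sketch (editor's note): sample_goal draws a per-sample set of goal
+ # joints, e.g.
+ #     names, is_heading = sample_goal(4, 'cpu', force_joints='DIMP_FINAL')
+ # where each entry of `names` is a joint-name list with 'heading' stripped out
+ # into the boolean `is_heading` mask.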
motion_diffusion_model/data_loaders/humanml/utils/get_opt.py ADDED
@@ -0,0 +1,81 @@
1
+ import os
2
+ from argparse import Namespace
3
+ import re
4
+ from os.path import join as pjoin
5
+ from data_loaders.humanml.utils.word_vectorizer import POS_enumerator
6
+
7
+
8
+ def is_float(numStr):
9
+ flag = False
10
+ numStr = str(numStr).strip().lstrip('-').lstrip('+')  # strip leading plus/minus signs
11
+ try:
12
+ reg = re.compile(r'^[-+]?[0-9]+\.[0-9]+$')
13
+ res = reg.match(str(numStr))
14
+ if res:
15
+ flag = True
16
+ except Exception as ex:
17
+ print("is_float() - error: " + str(ex))
18
+ return flag
19
+
20
+
21
+ def is_number(numStr):
22
+ flag = False
23
+ numStr = str(numStr).strip().lstrip('-').lstrip('+')  # strip leading plus/minus signs
24
+ if str(numStr).isdigit():
25
+ flag = True
26
+ return flag
27
+
28
+
29
+ def get_opt(opt_path, device):
30
+ opt = Namespace()
31
+ opt_dict = vars(opt)
32
+
33
+ skip = ('-------------- End ----------------',
34
+ '------------ Options -------------',
35
+ '\n')
36
+ print('Reading', opt_path)
37
+ with open(opt_path) as f:
38
+ for line in f:
39
+ if line.strip() not in skip:
40
+ # print(line.strip())
41
+ key, value = line.strip().split(': ')
42
+ if value in ('True', 'False'):
43
+ opt_dict[key] = (value == 'True')  # note: bool('False') would evaluate to True
44
+ elif is_float(value):
45
+ opt_dict[key] = float(value)
46
+ elif is_number(value):
47
+ opt_dict[key] = int(value)
48
+ else:
49
+ opt_dict[key] = str(value)
50
+
51
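+ # Editor's note: the loop above parses the opt.txt files written by the
+ # text-to-motion training code, e.g. (hypothetical excerpt):
+ #     ------------ Options -------------
+ #     dataset_name: t2m
+ #     unit_length: 4
+ #     -------------- End ----------------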
+ # print(opt)
52
+ opt_dict['which_epoch'] = 'latest'
53
+ opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
54
+ opt.model_dir = pjoin(opt.save_root, 'model')
55
+ opt.meta_dir = pjoin(opt.save_root, 'meta')
56
+
57
+ if opt.dataset_name == 't2m':
58
+ opt.data_root = './dataset/HumanML3D'
59
+ opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
60
+ opt.text_dir = pjoin(opt.data_root, 'texts')
61
+ opt.joints_num = 22
62
+ opt.dim_pose = 263
63
+ opt.max_motion_length = 196
64
+ elif opt.dataset_name == 'kit':
65
+ opt.data_root = './dataset/KIT-ML'
66
+ opt.motion_dir = pjoin(opt.data_root, 'new_joint_vecs')
67
+ opt.text_dir = pjoin(opt.data_root, 'texts')
68
+ opt.joints_num = 21
69
+ opt.dim_pose = 251
70
+ opt.max_motion_length = 196
71
+ else:
72
+ raise KeyError('Dataset not recognized')
73
+
74
+ opt.dim_word = 300
75
+ opt.num_classes = 200 // opt.unit_length
76
+ opt.dim_pos_ohot = len(POS_enumerator)
77
+ opt.is_train = False
78
+ opt.is_continue = False
79
+ opt.device = device
80
+
81
+ return opt
motion_diffusion_model/data_loaders/humanml/utils/metrics.py ADDED
@@ -0,0 +1,146 @@
1
+ import numpy as np
2
+ from scipy import linalg
3
+
4
+
5
+ # (X - X_train)*(X - X_train) = -2X*X_train + X*X + X_train*X_train
6
+ def euclidean_distance_matrix(matrix1, matrix2):
7
+ """
8
+ Params:
9
+ -- matrix1: N1 x D
10
+ -- matrix2: N2 x D
11
+ Returns:
12
+ -- dist: N1 x N2
13
+ dist[i, j] == distance(matrix1[i], matrix2[j])
14
+ """
15
+ assert matrix1.shape[1] == matrix2.shape[1]
16
+ d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train)
17
+ d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1)
18
+ d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, )
19
+ dists = np.sqrt(d1 + d2 + d3) # broadcasting
20
+ return dists
21
+
22
+ def calculate_top_k(mat, top_k):
23
+ size = mat.shape[0]
24
+ gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
25
+ bool_mat = (mat == gt_mat)
26
+ correct_vec = False
27
+ top_k_list = []
28
+ for i in range(top_k):
29
+ # print(correct_vec, bool_mat[:, i])
30
+ correct_vec = (correct_vec | bool_mat[:, i])
31
+ # print(correct_vec)
32
+ top_k_list.append(correct_vec[:, None])
33
+ top_k_mat = np.concatenate(top_k_list, axis=1)
34
+ return top_k_mat
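+ # Editor's note: column k of top_k_mat marks whether each sample's ground-truth
+ # index (the diagonal of the distance matrix) lies among its k+1 nearest
+ # neighbours, so averaging column k gives the top-(k+1) R-precision.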
35
+
36
+
37
+ def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
38
+ dist_mat = euclidean_distance_matrix(embedding1, embedding2)
39
+ argmax = np.argsort(dist_mat, axis=1)
40
+ top_k_mat = calculate_top_k(argmax, top_k)
41
+ if sum_all:
42
+ return top_k_mat.sum(axis=0)
43
+ else:
44
+ return top_k_mat
45
+
46
+
47
+ def calculate_matching_score(embedding1, embedding2, sum_all=False):
48
+ assert len(embedding1.shape) == 2
49
+ assert embedding1.shape[0] == embedding2.shape[0]
50
+ assert embedding1.shape[1] == embedding2.shape[1]
51
+
52
+ dist = linalg.norm(embedding1 - embedding2, axis=1)
53
+ if sum_all:
54
+ return dist.sum(axis=0)
55
+ else:
56
+ return dist
57
+
58
+
59
+
60
+ def calculate_activation_statistics(activations):
61
+ """
62
+ Params:
63
+ -- activation: num_samples x dim_feat
64
+ Returns:
65
+ -- mu: dim_feat
66
+ -- sigma: dim_feat x dim_feat
67
+ """
68
+ mu = np.mean(activations, axis=0)
69
+ cov = np.cov(activations, rowvar=False)
70
+ return mu, cov
71
+
72
+
73
+ def calculate_diversity(activation, diversity_times):
74
+ assert len(activation.shape) == 2
75
+ assert activation.shape[0] > diversity_times
76
+ num_samples = activation.shape[0]
77
+
78
+ first_indices = np.random.choice(num_samples, diversity_times, replace=False)
79
+ second_indices = np.random.choice(num_samples, diversity_times, replace=False)
80
+ dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1)
81
+ return dist.mean()
82
+
83
+
84
+ def calculate_multimodality(activation, multimodality_times):
85
+ assert len(activation.shape) == 3
86
+ assert activation.shape[1] > multimodality_times
87
+ num_per_sent = activation.shape[1]
88
+
89
+ first_indices = np.random.choice(num_per_sent, multimodality_times, replace=False)
90
+ second_indices = np.random.choice(num_per_sent, multimodality_times, replace=False)
91
+ dist = linalg.norm(activation[:, first_indices] - activation[:, second_indices], axis=2)
92
+ return dist.mean()
93
+
94
+
95
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
96
+ """Numpy implementation of the Frechet Distance.
97
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
98
+ and X_2 ~ N(mu_2, C_2) is
99
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
100
+ Stable version by Dougal J. Sutherland.
101
+ Params:
102
+ -- mu1 : Numpy array containing the activations of a layer of the
103
+ inception net (like returned by the function 'get_predictions')
104
+ for generated samples.
105
+ -- mu2 : The sample mean over activations, precalculated on a
106
+ representative data set.
107
+ -- sigma1: The covariance matrix over activations for generated samples.
108
+ -- sigma2: The covariance matrix over activations, precalculated on a
109
+ representative data set.
110
+ Returns:
111
+ -- : The Frechet Distance.
112
+ """
113
+
114
+ mu1 = np.atleast_1d(mu1)
115
+ mu2 = np.atleast_1d(mu2)
116
+
117
+ sigma1 = np.atleast_2d(sigma1)
118
+ sigma2 = np.atleast_2d(sigma2)
119
+
120
+ assert mu1.shape == mu2.shape, \
121
+ 'Training and test mean vectors have different lengths'
122
+ assert sigma1.shape == sigma2.shape, \
123
+ 'Training and test covariances have different dimensions'
124
+
125
+ diff = mu1 - mu2
126
+
127
+ # Product might be almost singular
128
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
129
+ if not np.isfinite(covmean).all():
130
+ msg = ('fid calculation produces singular product; '
131
+ 'adding %s to diagonal of cov estimates') % eps
132
+ print(msg)
133
+ offset = np.eye(sigma1.shape[0]) * eps
134
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
135
+
136
+ # Numerical error might give slight imaginary component
137
+ if np.iscomplexobj(covmean):
138
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
139
+ m = np.max(np.abs(covmean.imag))
140
+ raise ValueError('Imaginary component {}'.format(m))
141
+ covmean = covmean.real
142
+
143
+ tr_covmean = np.trace(covmean)
144
+
145
+ return (diff.dot(diff) + np.trace(sigma1) +
146
+ np.trace(sigma2) - 2 * tr_covmean)
motion_diffusion_model/data_loaders/humanml/utils/paramUtil.py ADDED
@@ -0,0 +1,63 @@
1
+ import numpy as np
2
+
3
+ # Define a kinematic tree for the skeletal structure
4
+ kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
5
+
6
+ kit_raw_offsets = np.array(
7
+ [
8
+ [0, 0, 0],
9
+ [0, 1, 0],
10
+ [0, 1, 0],
11
+ [0, 1, 0],
12
+ [0, 1, 0],
13
+ [1, 0, 0],
14
+ [0, -1, 0],
15
+ [0, -1, 0],
16
+ [-1, 0, 0],
17
+ [0, -1, 0],
18
+ [0, -1, 0],
19
+ [1, 0, 0],
20
+ [0, -1, 0],
21
+ [0, -1, 0],
22
+ [0, 0, 1],
23
+ [0, 0, 1],
24
+ [-1, 0, 0],
25
+ [0, -1, 0],
26
+ [0, -1, 0],
27
+ [0, 0, 1],
28
+ [0, 0, 1]
29
+ ]
30
+ )
31
+
32
+ t2m_raw_offsets = np.array([[0,0,0],
33
+ [1,0,0],
34
+ [-1,0,0],
35
+ [0,1,0],
36
+ [0,-1,0],
37
+ [0,-1,0],
38
+ [0,1,0],
39
+ [0,-1,0],
40
+ [0,-1,0],
41
+ [0,1,0],
42
+ [0,0,1],
43
+ [0,0,1],
44
+ [0,1,0],
45
+ [1,0,0],
46
+ [-1,0,0],
47
+ [0,0,1],
48
+ [0,-1,0],
49
+ [0,-1,0],
50
+ [0,-1,0],
51
+ [0,-1,0],
52
+ [0,-1,0],
53
+ [0,-1,0]])
54
+
55
+ t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
56
+ t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
57
+ t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
58
+
59
+
60
+ kit_tgt_skel_id = '03950'
61
+
62
+ t2m_tgt_skel_id = '000021'
63
+
motion_diffusion_model/data_loaders/humanml/utils/plot_script.py ADDED
@@ -0,0 +1,148 @@
1
+ import math
2
+ import numpy as np
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ from mpl_toolkits.mplot3d import Axes3D
6
+ from matplotlib.animation import FuncAnimation, FFMpegFileWriter
7
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
8
+ import mpl_toolkits.mplot3d.axes3d as p3
9
+ # import cv2
10
+ from textwrap import wrap
11
+ from moviepy.editor import VideoClip
12
+ from moviepy.video.io.bindings import mplfig_to_npimage
13
+
14
+ def list_cut_average(ll, intervals):
15
+ if intervals == 1:
16
+ return ll
17
+
18
+ bins = math.ceil(len(ll) * 1.0 / intervals)
19
+ ll_new = []
20
+ for i in range(bins):
21
+ l_low = intervals * i
22
+ l_high = l_low + intervals
23
+ l_high = l_high if l_high < len(ll) else len(ll)
24
+ ll_new.append(np.mean(ll[l_low:l_high]))
25
+ return ll_new
26
+
27
+
28
+ def plot_3d_motion(save_path, kinematic_tree, joints, title, dataset, figsize=(3, 3), fps=120, radius=3,
29
+ vis_mode='default', gt_frames=[]):
30
+ matplotlib.use('Agg')
31
+
32
+ title_per_frame = type(title) == list
33
+ if title_per_frame:
34
+ assert len(title) == len(joints), 'Title length should match the number of frames'
35
+ title = ['\n'.join(wrap(s, 20)) for s in title]
36
+ else:
37
+ title = '\n'.join(wrap(title, 20))
38
+
39
+ def init():
40
+ ax.set_xlim3d([-radius / 2, radius / 2])
41
+ ax.set_ylim3d([0, radius])
42
+ ax.set_zlim3d([-radius / 3., radius * 2 / 3.])
43
+ # print(title)
44
+ # fig.suptitle(title, fontsize=10) # Using dynamic title instead
45
+ ax.grid(False)  # the 'b' keyword was removed in Matplotlib 3.7
46
+
47
+ def plot_xzPlane(minx, maxx, miny, minz, maxz):
48
+ ## Plot a plane XZ
49
+ verts = [
50
+ [minx, miny, minz],
51
+ [minx, miny, maxz],
52
+ [maxx, miny, maxz],
53
+ [maxx, miny, minz]
54
+ ]
55
+ xz_plane = Poly3DCollection([verts])
56
+ xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
57
+ ax.add_collection3d(xz_plane)
58
+
59
+ # return ax
60
+
61
+ # (seq_len, joints_num, 3)
62
+ data = joints.copy().reshape(len(joints), -1, 3)
63
+
64
+ # preparation related to specific datasets
65
+ if dataset == 'kit':
66
+ data *= 0.003 # scale for visualization
67
+ elif dataset == 'humanml':
68
+ data *= 1.3 # scale for visualization
69
+ elif dataset in ['humanact12', 'uestc']:
70
+ data *= -1.5 # reverse axes, scale for visualization
71
+
72
+ fig = plt.figure(figsize=figsize)
73
+ plt.tight_layout()
74
+ ax = p3.Axes3D(fig)
75
+ init()
76
+ MINS = data.min(axis=0).min(axis=0)
77
+ MAXS = data.max(axis=0).max(axis=0)
78
+ colors_blue = ["#4D84AA", "#5B9965", "#61CEB9", "#34C1E2", "#80B79A"] # GT color
79
+ colors_orange = ["#DD5A37", "#D69E00", "#B75A39", "#FF6D00", "#DDB50E"] # Generation color
80
+ colors = colors_orange
81
+ if vis_mode == 'upper_body':  # lower body is kept fixed to the input motion
82
+ colors[0] = colors_blue[0]
83
+ colors[1] = colors_blue[1]
84
+ elif vis_mode == 'gt':
85
+ colors = colors_blue
86
+
87
+ n_frames = data.shape[0]
88
+ # print(dataset.shape)
89
+
90
+ height_offset = MINS[1]
91
+ data[:, :, 1] -= height_offset
92
+ trajec = data[:, 0, [0, 2]] # memorize original x,z pelvis values
93
+
94
+ # locate x,z pelvis values of ** each frame ** at zero
95
+ data[..., 0] -= data[:, 0:1, 0]
96
+ data[..., 2] -= data[:, 0:1, 2]
97
+
98
+ # print(trajec.shape)
99
+
100
+ def update(index):
101
+ # sometimes index is equal to n_frames/fps due to floating point issues. in such case, we duplicate the last frame
102
+ index = min(n_frames-1, int(index*fps))
103
+ ax.clear()
104
+ ax.view_init(elev=120, azim=-90)
105
+ ax.dist = 7.5
106
+
107
+ # Dynamic title
108
+ if title_per_frame:
109
+ _title = title[index]
110
+ else:
111
+ _title = title
112
+ _title += f' [{index}]'
113
+ fig.suptitle(_title, fontsize=10)
114
+
115
+ plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1],
116
+ MAXS[2] - trajec[index, 1])
117
+
118
+ used_colors = colors_blue if index in gt_frames else colors
119
+ for i, (chain, color) in enumerate(zip(kinematic_tree, used_colors)):
120
+ if i < 5:
121
+ linewidth = 4.0
122
+ else:
123
+ linewidth = 2.0
124
+ ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth,
125
+ color=color)
126
+ # print(trajec[:index, 0].shape)
127
+
128
+ plt.axis('off')
129
+ ax.set_axis_off()
130
+ ax.set_xticklabels([])
131
+ ax.set_yticklabels([])
132
+ ax.set_zticklabels([])
133
+
134
+ # Hide grid lines
135
+ ax.grid(False)
136
+
137
+ # Hide axes ticks
138
+ ax.set_xticks([])
139
+ ax.set_yticks([])
140
+ ax.set_zticks([])
141
+
142
+
143
+ return mplfig_to_npimage(fig)
144
+
145
+ ani = VideoClip(update)
146
+
147
+ plt.close()
148
+ return ani
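+ # Usage sketch (editor's note, hypothetical call): the returned moviepy clip
+ # needs a duration before rendering; update() maps time t to frame int(t*fps):
+ #     clip = plot_3d_motion('out.mp4', t2m_kinematic_chain, joints,
+ #                           'a person walks', 'humanml', fps=20)
+ #     clip.set_duration(len(joints) / 20).write_videofile('out.mp4', fps=20)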
motion_diffusion_model/data_loaders/humanml/utils/utils.py ADDED
@@ -0,0 +1,167 @@
1
+ import os
2
+ import numpy as np
3
+ # import cv2
4
+ from PIL import Image
5
+ from data_loaders.humanml.utils import paramUtil
6
+ import math
7
+ import time
8
+ import matplotlib.pyplot as plt
9
+ from scipy.ndimage import gaussian_filter
10
+
11
+
12
+ def mkdir(path):
13
+ if not os.path.exists(path):
14
+ os.makedirs(path)
15
+
16
+ COLORS = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
17
+ [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
18
+ [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
19
+
20
+ MISSING_VALUE = -1
21
+
22
+ def save_image(image_numpy, image_path):
23
+ img_pil = Image.fromarray(image_numpy)
24
+ img_pil.save(image_path)
25
+
26
+
27
+ def save_logfile(log_loss, save_path):
28
+ with open(save_path, 'wt') as f:
29
+ for k, v in log_loss.items():
30
+ w_line = k
31
+ for digit in v:
32
+ w_line += ' %.3f' % digit
33
+ f.write(w_line + '\n')
34
+
35
+
36
+ def print_current_loss(start_time, niter_state, losses, epoch=None, sub_epoch=None,
37
+ inner_iter=None, tf_ratio=None, sl_steps=None):
38
+
39
+ def as_minutes(s):
40
+ m = math.floor(s / 60)
41
+ s -= m * 60
42
+ return '%dm %ds' % (m, s)
43
+
44
+ def time_since(since, percent):
45
+ now = time.time()
46
+ s = now - since
47
+ es = s / percent
48
+ rs = es - s
49
+ return '%s (- %s)' % (as_minutes(s), as_minutes(rs))
50
+
51
+ if epoch is not None:
52
+ print('epoch: %3d niter: %6d sub_epoch: %2d inner_iter: %4d' % (epoch, niter_state, sub_epoch, inner_iter), end=" ")
53
+
54
+ # message = '%s niter: %d completed: %3d%%)' % (time_since(start_time, niter_state / total_niters),
55
+ # niter_state, niter_state / total_niters * 100)
56
+ now = time.time()
57
+ message = '%s'%(as_minutes(now - start_time))
58
+
59
+ for k, v in losses.items():
60
+ message += ' %s: %.4f ' % (k, v)
61
+ message += ' sl_length:%2d tf_ratio:%.2f'%(sl_steps, tf_ratio)
62
+ print(message)
63
+
64
+ def print_current_loss_decomp(start_time, niter_state, total_niters, losses, epoch=None, inner_iter=None):
65
+
66
+ def as_minutes(s):
67
+ m = math.floor(s / 60)
68
+ s -= m * 60
69
+ return '%dm %ds' % (m, s)
70
+
71
+ def time_since(since, percent):
72
+ now = time.time()
73
+ s = now - since
74
+ es = s / percent
75
+ rs = es - s
76
+ return '%s (- %s)' % (as_minutes(s), as_minutes(rs))
77
+
78
+ print('epoch: %03d inner_iter: %5d' % (epoch, inner_iter), end=" ")
79
+ # now = time.time()
80
+ message = '%s niter: %07d completed: %3d%%' % (time_since(start_time, niter_state / total_niters), niter_state, niter_state / total_niters * 100)
81
+ for k, v in losses.items():
82
+ message += ' %s: %.4f ' % (k, v)
83
+ print(message)
84
+
85
+
86
+ def compose_gif_img_list(img_list, fp_out, duration):
87
+ img, *imgs = [Image.fromarray(np.array(image)) for image in img_list]
88
+ img.save(fp=fp_out, format='GIF', append_images=imgs, optimize=False,
89
+ save_all=True, loop=0, duration=duration)
90
+
91
+
92
+ def save_images(visuals, image_path):
93
+ if not os.path.exists(image_path):
94
+ os.makedirs(image_path)
95
+
96
+ for i, (label, img_numpy) in enumerate(visuals.items()):
97
+ img_name = '%d_%s.jpg' % (i, label)
98
+ save_path = os.path.join(image_path, img_name)
99
+ save_image(img_numpy, save_path)
100
+
101
+
102
+ def save_images_test(visuals, image_path, from_name, to_name):
103
+ if not os.path.exists(image_path):
104
+ os.makedirs(image_path)
105
+
106
+ for i, (label, img_numpy) in enumerate(visuals.items()):
107
+ img_name = "%s_%s_%s" % (from_name, to_name, label)
108
+ save_path = os.path.join(image_path, img_name)
109
+ save_image(img_numpy, save_path)
110
+
111
+
112
+ def compose_and_save_img(img_list, save_dir, img_name, col=4, row=1, img_size=(256, 200)):
113
+ # print(col, row)
114
+ compose_img = compose_image(img_list, col, row, img_size)
115
+ if not os.path.exists(save_dir):
116
+ os.makedirs(save_dir)
117
+ img_path = os.path.join(save_dir, img_name)
118
+ # print(img_path)
119
+ compose_img.save(img_path)
120
+
121
+
122
+ def compose_image(img_list, col, row, img_size):
123
+ to_image = Image.new('RGB', (col * img_size[0], row * img_size[1]))
124
+ for y in range(0, row):
125
+ for x in range(0, col):
126
+ from_img = Image.fromarray(img_list[y * col + x])
127
+ # print((x * img_size[0], y*img_size[1],
128
+ # (x + 1) * img_size[0], (y + 1) * img_size[1]))
129
+ paste_area = (x * img_size[0], y*img_size[1],
130
+ (x + 1) * img_size[0], (y + 1) * img_size[1])
131
+ to_image.paste(from_img, paste_area)
132
+ # to_image[y*img_size[1]:(y + 1) * img_size[1], x * img_size[0] :(x + 1) * img_size[0]] = from_img
133
+ return to_image
134
+
135
+
136
+ def plot_loss_curve(losses, save_path, intervals=500):
137
+ plt.figure(figsize=(10, 5))
138
+ plt.title("Loss During Training")
139
+ for key in losses.keys():
140
+ plt.plot(list_cut_average(losses[key], intervals), label=key)
141
+ plt.xlabel("Iterations/" + str(intervals))
142
+ plt.ylabel("Loss")
143
+ plt.legend()
144
+ plt.savefig(save_path)
145
+ plt.show()
146
+
147
+
148
+ def list_cut_average(ll, intervals):
149
+ if intervals == 1:
150
+ return ll
151
+
152
+ bins = math.ceil(len(ll) * 1.0 / intervals)
153
+ ll_new = []
154
+ for i in range(bins):
155
+ l_low = intervals * i
156
+ l_high = l_low + intervals
157
+ l_high = l_high if l_high < len(ll) else len(ll)
158
+ ll_new.append(np.mean(ll[l_low:l_high]))
159
+ return ll_new
160
+
161
+
162
+ def motion_temporal_filter(motion, sigma=1):
163
+ motion = motion.reshape(motion.shape[0], -1)
164
+ for i in range(motion.shape[1]):
165
+ motion[:, i] = gaussian_filter(motion[:, i], sigma=sigma, mode="nearest")
166
+ return motion.reshape(motion.shape[0], -1, 3)
167
+
motion_diffusion_model/data_loaders/humanml/utils/word_vectorizer.py ADDED
@@ -0,0 +1,80 @@
1
+ import numpy as np
2
+ import pickle
3
+ from os.path import join as pjoin
4
+
5
+ POS_enumerator = {
6
+ 'VERB': 0,
7
+ 'NOUN': 1,
8
+ 'DET': 2,
9
+ 'ADP': 3,
10
+ 'NUM': 4,
11
+ 'AUX': 5,
12
+ 'PRON': 6,
13
+ 'ADJ': 7,
14
+ 'ADV': 8,
15
+ 'Loc_VIP': 9,
16
+ 'Body_VIP': 10,
17
+ 'Obj_VIP': 11,
18
+ 'Act_VIP': 12,
19
+ 'Desc_VIP': 13,
20
+ 'OTHER': 14,
21
+ }
22
+
23
+ Loc_list = ('left', 'right', 'clockwise', 'counterclockwise', 'anticlockwise', 'forward', 'back', 'backward',
24
+ 'up', 'down', 'straight', 'curve')
25
+
26
+ Body_list = ('arm', 'chin', 'foot', 'feet', 'face', 'hand', 'mouth', 'leg', 'waist', 'eye', 'knee', 'shoulder', 'thigh')
27
+
28
+ Obj_List = ('stair', 'dumbbell', 'chair', 'window', 'floor', 'car', 'ball', 'handrail', 'baseball', 'basketball')
29
+
30
+ Act_list = ('walk', 'run', 'swing', 'pick', 'bring', 'kick', 'put', 'squat', 'throw', 'hop', 'dance', 'jump', 'turn',
31
+ 'stumble', 'dance', 'stop', 'sit', 'lift', 'lower', 'raise', 'wash', 'stand', 'kneel', 'stroll',
32
+ 'rub', 'bend', 'balance', 'flap', 'jog', 'shuffle', 'lean', 'rotate', 'spin', 'spread', 'climb')
33
+
34
+ Desc_list = ('slowly', 'carefully', 'fast', 'careful', 'slow', 'quickly', 'happy', 'angry', 'sad', 'happily',
35
+ 'angrily', 'sadly')
36
+
37
+ VIP_dict = {
38
+ 'Loc_VIP': Loc_list,
39
+ 'Body_VIP': Body_list,
40
+ 'Obj_VIP': Obj_List,
41
+ 'Act_VIP': Act_list,
42
+ 'Desc_VIP': Desc_list,
43
+ }
44
+
45
+
46
+ class WordVectorizer(object):
47
+ def __init__(self, meta_root, prefix):
48
+ vectors = np.load(pjoin(meta_root, '%s_data.npy'%prefix))
49
+ words = pickle.load(open(pjoin(meta_root, '%s_words.pkl'%prefix), 'rb'))
50
+ word2idx = pickle.load(open(pjoin(meta_root, '%s_idx.pkl'%prefix), 'rb'))
51
+ self.word2vec = {w: vectors[word2idx[w]] for w in words}
52
+
53
+ def _get_pos_ohot(self, pos):
54
+ pos_vec = np.zeros(len(POS_enumerator))
55
+ if pos in POS_enumerator:
56
+ pos_vec[POS_enumerator[pos]] = 1
57
+ else:
58
+ pos_vec[POS_enumerator['OTHER']] = 1
59
+ return pos_vec
60
+
61
+ def __len__(self):
62
+ return len(self.word2vec)
63
+
64
+ def __getitem__(self, item):
65
+ word, pos = item.split('/')
66
+ if word in self.word2vec:
67
+ word_vec = self.word2vec[word]
68
+ vip_pos = None
69
+ for key, values in VIP_dict.items():
70
+ if word in values:
71
+ vip_pos = key
72
+ break
73
+ if vip_pos is not None:
74
+ pos_vec = self._get_pos_ohot(vip_pos)
75
+ else:
76
+ pos_vec = self._get_pos_ohot(pos)
77
+ else:
78
+ word_vec = self.word2vec['unk']
79
+ pos_vec = self._get_pos_ohot('OTHER')
80
+ return word_vec, pos_vec
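+ # Usage sketch (editor's note; the './glove' path and 'our_vab' prefix follow
+ # the HumanML3D release and are assumptions here):
+ #     wv = WordVectorizer('./glove', 'our_vab')
+ #     word_vec, pos_vec = wv['walk/VERB']  # (300,) GloVe vector, (15,) POS one-hot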
motion_diffusion_model/data_loaders/humanml_utils.py ADDED
@@ -0,0 +1,60 @@
1
+ import numpy as np
2
+
3
+ HML_JOINT_NAMES = [
4
+ 'pelvis',
5
+ 'left_hip',
6
+ 'right_hip',
7
+ 'spine1',
8
+ 'left_knee',
9
+ 'right_knee',
10
+ 'spine2',
11
+ 'left_ankle',
12
+ 'right_ankle',
13
+ 'spine3',
14
+ 'left_foot',
15
+ 'right_foot',
16
+ 'neck',
17
+ 'left_collar',
18
+ 'right_collar',
19
+ 'head',
20
+ 'left_shoulder',
21
+ 'right_shoulder',
22
+ 'left_elbow',
23
+ 'right_elbow',
24
+ 'left_wrist',
25
+ 'right_wrist',
26
+ ]
27
+
28
+ NUM_HML_JOINTS = len(HML_JOINT_NAMES) # 22 SMPLH body joints
29
+
30
+ HML_EE_JOINT_NAMES = ['left_foot', 'right_foot', 'left_wrist', 'right_wrist', 'head']
31
+ HML_LOWER_BODY_JOINTS = [HML_JOINT_NAMES.index(name) for name in ['pelvis', 'left_hip', 'right_hip', 'left_knee', 'right_knee', 'left_ankle', 'right_ankle', 'left_foot', 'right_foot',]]
32
+ SMPL_UPPER_BODY_JOINTS = [i for i in range(len(HML_JOINT_NAMES)) if i not in HML_LOWER_BODY_JOINTS]
33
+
34
+
35
+ # Recover global angle and positions for rotation data
36
+ # root_rot_velocity (B, seq_len, 1)
37
+ # root_linear_velocity (B, seq_len, 2)
38
+ # root_y (B, seq_len, 1)
39
+ # ric_data (B, seq_len, (joint_num - 1)*3)
40
+ # rot_data (B, seq_len, (joint_num - 1)*6)
41
+ # local_velocity (B, seq_len, joint_num*3)
42
+ # foot contact (B, seq_len, 4)
43
+ HML_ROOT_BINARY = np.array([True] + [False] * (NUM_HML_JOINTS-1))
44
+ HML_ROOT_MASK = np.concatenate(([True]*(1+2+1),
45
+ HML_ROOT_BINARY[1:].repeat(3),
46
+ HML_ROOT_BINARY[1:].repeat(6),
47
+ HML_ROOT_BINARY.repeat(3),
48
+ [False] * 4))
49
+ HML_ROOT_HORIZONTAL_MASK = np.concatenate(([True]*(1+2) + [False],
50
+ np.zeros_like(HML_ROOT_BINARY[1:].repeat(3)),
51
+ np.zeros_like(HML_ROOT_BINARY[1:].repeat(6)),
52
+ np.zeros_like(HML_ROOT_BINARY.repeat(3)),
53
+ [False] * 4))
54
+ HML_LOWER_BODY_JOINTS_BINARY = np.array([i in HML_LOWER_BODY_JOINTS for i in range(NUM_HML_JOINTS)])
55
+ HML_LOWER_BODY_MASK = np.concatenate(([True]*(1+2+1),
56
+ HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(3),
57
+ HML_LOWER_BODY_JOINTS_BINARY[1:].repeat(6),
58
+ HML_LOWER_BODY_JOINTS_BINARY.repeat(3),
59
+ [True]*4))
60
+ HML_UPPER_BODY_MASK = ~HML_LOWER_BODY_MASK
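+
+ # Editor's sketch (not in the original file): every mask spans one 263-dim
+ # HumanML3D frame: 4 (root) + 21*3 (ric) + 21*6 (rot) + 22*3 (vel) + 4 (feet).
+ assert HML_ROOT_MASK.shape == HML_LOWER_BODY_MASK.shape == (263,)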
motion_diffusion_model/data_loaders/tensors.py ADDED
@@ -0,0 +1,94 @@
1
+ import torch
2
+
3
+ def lengths_to_mask(lengths, max_len):
4
+ # max_len = max(lengths)
5
+ mask = torch.arange(max_len, device=lengths.device).expand(len(lengths), max_len) < lengths.unsqueeze(1)
6
+ return mask
7
+
8
+
9
+ def collate_tensors(batch):
10
+ dims = batch[0].dim()
11
+ max_size = [max([b.size(i) for b in batch]) for i in range(dims)]
12
+ size = (len(batch),) + tuple(max_size)
13
+ canvas = batch[0].new_zeros(size=size)
14
+ for i, b in enumerate(batch):
15
+ sub_tensor = canvas[i]
16
+ for d in range(dims):
17
+ sub_tensor = sub_tensor.narrow(d, 0, b.size(d))
18
+ sub_tensor.add_(b)
19
+ return canvas
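+ # Example (editor's sketch): collate_tensors zero-pads each tensor up to the
+ # per-dimension maximum across the batch, e.g.
+ #     collate_tensors([torch.ones(2, 3), torch.ones(2, 5)]).shape  # [2, 2, 5]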
20
+
21
+
22
+ def collate(batch):
23
+ notnone_batches = [b for b in batch if b is not None]
24
+ databatch = [b['inp'] for b in notnone_batches]
25
+ if 'lengths' in notnone_batches[0]:
26
+ lenbatch = [b['lengths'] for b in notnone_batches]
27
+ else:
28
+ lenbatch = [len(b['inp'][0][0]) for b in notnone_batches]
29
+
30
+
31
+ databatchTensor = collate_tensors(databatch)
32
+ lenbatchTensor = torch.as_tensor(lenbatch)
33
+ maskbatchTensor = lengths_to_mask(lenbatchTensor, databatchTensor.shape[-1]).unsqueeze(1).unsqueeze(1)  # unsqueeze for broadcasting
34
+
35
+ motion = databatchTensor
36
+ cond = {'y': {'mask': maskbatchTensor, 'lengths': lenbatchTensor}}
37
+
38
+ if 'text' in notnone_batches[0]:
39
+ textbatch = [b['text'] for b in notnone_batches]
40
+ cond['y'].update({'text': textbatch})
41
+
42
+ if 'tokens' in notnone_batches[0]:
43
+ textbatch = [b['tokens'] for b in notnone_batches]
44
+ cond['y'].update({'tokens': textbatch})
45
+
46
+ if 'action' in notnone_batches[0]:
47
+ actionbatch = [b['action'] for b in notnone_batches]
48
+ cond['y'].update({'action': torch.as_tensor(actionbatch).unsqueeze(1)})
49
+
50
+ # collate action textual names
51
+ if 'action_text' in notnone_batches[0]:
52
+ action_text = [b['action_text']for b in notnone_batches]
53
+ cond['y'].update({'action_text': action_text})
54
+
55
+ if 'prefix' in notnone_batches[0]:
56
+ cond['y'].update({'prefix': collate_tensors([b['prefix'] for b in notnone_batches])})
57
+
58
+ if 'orig_lengths' in notnone_batches[0]:
59
+ cond['y'].update({'orig_lengths': torch.as_tensor([b['orig_lengths'] for b in notnone_batches])})
60
+
61
+ if 'key' in notnone_batches[0]:
62
+ cond['y'].update({'db_key': [b['key'] for b in notnone_batches]})
63
+
64
+ return motion, cond
65
+
66
+ # an adapter to our collate func
67
+ def t2m_collate(batch, target_batch_size):
68
+ repeat_factor = -(-target_batch_size // len(batch)) # Ceiling division
69
+ repeated_batch = batch * repeat_factor
70
+ full_batch = repeated_batch[:target_batch_size] # Truncate to the target batch size
71
+ # batch.sort(key=lambda x: x[3], reverse=True)
72
+ adapted_batch = [{
73
+ 'inp': torch.tensor(b[4].T).float().unsqueeze(1), # [seqlen, J] -> [J, 1, seqlen]
74
+ 'text': b[2], #b[0]['caption']
75
+ 'tokens': b[6],
76
+ 'lengths': b[5],
77
+ 'key': b[7] if len(b) > 7 else None,
78
+ } for b in full_batch]
79
+ return collate(adapted_batch)
80
+
81
+
82
+ def t2m_prefix_collate(batch, pred_len):
83
+ # batch.sort(key=lambda x: x[3], reverse=True)
84
+ adapted_batch = [{
85
+ 'inp': torch.tensor(b[4].T).float().unsqueeze(1)[..., -pred_len:], # [seqlen, J] -> [J, 1, seqlen]
86
+ 'prefix': torch.tensor(b[4].T).float().unsqueeze(1)[..., :-pred_len],
87
+ 'text': b[2], #b[0]['caption']
88
+ 'tokens': b[6],
89
+ 'lengths': pred_len, # b[5],
90
+ 'orig_lengths': b[5][0], # For evaluation
91
+ 'key': b[7] if len(b) > 7 else None,
92
+ } for b in batch]
93
+ return collate(adapted_batch)
94
+
motion_diffusion_model/diffusion/fp16_util.py ADDED
@@ -0,0 +1,236 @@
1
+ """
2
+ Helpers to train with 16-bit precision.
3
+ """
4
+
5
+ import numpy as np
6
+ import torch as th
7
+ import torch.nn as nn
8
+ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
9
+
10
+ from diffusion import logger
11
+
12
+ INITIAL_LOG_LOSS_SCALE = 20.0
13
+
14
+
15
+ def convert_module_to_f16(l):
16
+ """
17
+ Convert primitive modules to float16.
18
+ """
19
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
20
+ l.weight.data = l.weight.data.half()
21
+ if l.bias is not None:
22
+ l.bias.data = l.bias.data.half()
23
+
24
+
25
+ def convert_module_to_f32(l):
26
+ """
27
+ Convert primitive modules to float32, undoing convert_module_to_f16().
28
+ """
29
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
30
+ l.weight.data = l.weight.data.float()
31
+ if l.bias is not None:
32
+ l.bias.data = l.bias.data.float()
33
+
34
+
35
+ def make_master_params(param_groups_and_shapes):
36
+ """
37
+ Copy model parameters into a (differently-shaped) list of full-precision
38
+ parameters.
39
+ """
40
+ master_params = []
41
+ for param_group, shape in param_groups_and_shapes:
42
+ master_param = nn.Parameter(
43
+ _flatten_dense_tensors(
44
+ [param.detach().float() for (_, param) in param_group]
45
+ ).view(shape)
46
+ )
47
+ master_param.requires_grad = True
48
+ master_params.append(master_param)
49
+ return master_params
50
+
51
+
52
+ def model_grads_to_master_grads(param_groups_and_shapes, master_params):
53
+ """
54
+ Copy the gradients from the model parameters into the master parameters
55
+ from make_master_params().
56
+ """
57
+ for master_param, (param_group, shape) in zip(
58
+ master_params, param_groups_and_shapes
59
+ ):
60
+ master_param.grad = _flatten_dense_tensors(
61
+ [param_grad_or_zeros(param) for (_, param) in param_group]
62
+ ).view(shape)
63
+
64
+
65
+ def master_params_to_model_params(param_groups_and_shapes, master_params):
66
+ """
67
+ Copy the master parameter data back into the model parameters.
68
+ """
69
+ # Without copying to a list, if a generator is passed, this will
70
+ # silently not copy any parameters.
71
+ for master_param, (param_group, _) in zip(master_params, param_groups_and_shapes):
72
+ for (_, param), unflat_master_param in zip(
73
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
74
+ ):
75
+ param.detach().copy_(unflat_master_param)
76
+
77
+
78
+ def unflatten_master_params(param_group, master_param):
79
+ return _unflatten_dense_tensors(master_param, [param for (_, param) in param_group])
80
+
81
+
82
+ def get_param_groups_and_shapes(named_model_params):
83
+ named_model_params = list(named_model_params)
84
+ scalar_vector_named_params = (
85
+ [(n, p) for (n, p) in named_model_params if p.ndim <= 1],
86
+ (-1),
87
+ )
88
+ matrix_named_params = (
89
+ [(n, p) for (n, p) in named_model_params if p.ndim > 1],
90
+ (1, -1),
91
+ )
92
+ return [scalar_vector_named_params, matrix_named_params]
93
+
94
+
95
+ def master_params_to_state_dict(
96
+ model, param_groups_and_shapes, master_params, use_fp16
97
+ ):
98
+ if use_fp16:
99
+ state_dict = model.state_dict()
100
+ for master_param, (param_group, _) in zip(
101
+ master_params, param_groups_and_shapes
102
+ ):
103
+ for (name, _), unflat_master_param in zip(
104
+ param_group, unflatten_master_params(param_group, master_param.view(-1))
105
+ ):
106
+ assert name in state_dict
107
+ state_dict[name] = unflat_master_param
108
+ else:
109
+ state_dict = model.state_dict()
110
+ for i, (name, _value) in enumerate(model.named_parameters()):
111
+ assert name in state_dict
112
+ state_dict[name] = master_params[i]
113
+ return state_dict
114
+
115
+
116
+ def state_dict_to_master_params(model, state_dict, use_fp16):
117
+ if use_fp16:
118
+ named_model_params = [
119
+ (name, state_dict[name]) for name, _ in model.named_parameters()
120
+ ]
121
+ param_groups_and_shapes = get_param_groups_and_shapes(named_model_params)
122
+ master_params = make_master_params(param_groups_and_shapes)
123
+ else:
124
+ master_params = [state_dict[name] for name, _ in model.named_parameters()]
125
+ return master_params
126
+
127
+
128
+ def zero_master_grads(master_params):
129
+ for param in master_params:
130
+ param.grad = None
131
+
132
+
133
+ def zero_grad(model_params):
134
+ for param in model_params:
135
+ # Taken from https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.add_param_group
136
+ if param.grad is not None:
137
+ param.grad.detach_()
138
+ param.grad.zero_()
139
+
140
+
141
+ def param_grad_or_zeros(param):
142
+ if param.grad is not None:
143
+ return param.grad.data.detach()
144
+ else:
145
+ return th.zeros_like(param)
146
+
147
+
148
+ class MixedPrecisionTrainer:
149
+ def __init__(
150
+ self,
151
+ *,
152
+ model,
153
+ use_fp16=False,
154
+ fp16_scale_growth=1e-3,
155
+ initial_lg_loss_scale=INITIAL_LOG_LOSS_SCALE,
156
+ ):
157
+ self.model = model
158
+ self.use_fp16 = use_fp16
159
+ self.fp16_scale_growth = fp16_scale_growth
160
+
161
+ self.model_params = list(self.model.parameters())
162
+ self.master_params = self.model_params
163
+ self.param_groups_and_shapes = None
164
+ self.lg_loss_scale = initial_lg_loss_scale
165
+
166
+ if self.use_fp16:
167
+ self.param_groups_and_shapes = get_param_groups_and_shapes(
168
+ self.model.named_parameters()
169
+ )
170
+ self.master_params = make_master_params(self.param_groups_and_shapes)
171
+ self.model.convert_to_fp16()
172
+
173
+ def zero_grad(self):
174
+ zero_grad(self.model_params)
175
+
176
+ def backward(self, loss: th.Tensor):
177
+ if self.use_fp16:
178
+ loss_scale = 2 ** self.lg_loss_scale
179
+ (loss * loss_scale).backward()
180
+ else:
181
+ loss.backward()
182
+
183
+ def optimize(self, opt: th.optim.Optimizer):
184
+ if self.use_fp16:
185
+ return self._optimize_fp16(opt)
186
+ else:
187
+ return self._optimize_normal(opt)
188
+
189
+ def _optimize_fp16(self, opt: th.optim.Optimizer):
190
+ logger.logkv_mean("lg_loss_scale", self.lg_loss_scale)
191
+ model_grads_to_master_grads(self.param_groups_and_shapes, self.master_params)
192
+ grad_norm, param_norm = self._compute_norms(grad_scale=2 ** self.lg_loss_scale)
193
+ if check_overflow(grad_norm):
194
+ self.lg_loss_scale -= 1
195
+ logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")
196
+ zero_master_grads(self.master_params)
197
+ return False
198
+
199
+ logger.logkv_mean("grad_norm", grad_norm)
200
+ logger.logkv_mean("param_norm", param_norm)
201
+
202
+ self.master_params[0].grad.mul_(1.0 / (2 ** self.lg_loss_scale))
203
+ opt.step()
204
+ zero_master_grads(self.master_params)
205
+ master_params_to_model_params(self.param_groups_and_shapes, self.master_params)
206
+ self.lg_loss_scale += self.fp16_scale_growth
207
+ return True
208
+
209
+ def _optimize_normal(self, opt: th.optim.Optimizer):
210
+ grad_norm, param_norm = self._compute_norms()
211
+ logger.logkv_mean("grad_norm", grad_norm)
212
+ logger.logkv_mean("param_norm", param_norm)
213
+ opt.step()
214
+ return True
215
+
216
+ def _compute_norms(self, grad_scale=1.0):
217
+ grad_norm = 0.0
218
+ param_norm = 0.0
219
+ for p in self.master_params:
220
+ with th.no_grad():
221
+ param_norm += th.norm(p, p=2, dtype=th.float32).item() ** 2
222
+ if p.grad is not None:
223
+ grad_norm += th.norm(p.grad, p=2, dtype=th.float32).item() ** 2
224
+ return np.sqrt(grad_norm) / grad_scale, np.sqrt(param_norm)
225
+
226
+ def master_params_to_state_dict(self, master_params):
227
+ return master_params_to_state_dict(
228
+ self.model, self.param_groups_and_shapes, master_params, self.use_fp16
229
+ )
230
+
231
+ def state_dict_to_master_params(self, state_dict):
232
+ return state_dict_to_master_params(self.model, state_dict, self.use_fp16)
233
+
234
+
235
+ def check_overflow(value):
236
+ return (value == float("inf")) or (value == -float("inf")) or (value != value)
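A minimal training-step sketch for MixedPrecisionTrainer (an illustration, not part of the commit; fp32 mode is shown because use_fp16=True additionally requires the wrapped module to implement convert_to_fp16()):

import torch as th
import torch.nn as nn
from diffusion.fp16_util import MixedPrecisionTrainer

net = nn.Linear(10, 1)
trainer = MixedPrecisionTrainer(model=net, use_fp16=False)
opt = th.optim.AdamW(trainer.master_params, lr=1e-4)

x, y = th.randn(4, 10), th.randn(4, 1)
trainer.zero_grad()
loss = ((net(x) - y) ** 2).mean()
trainer.backward(loss)             # in fp16 mode this scales the loss by 2**lg_loss_scale
took_step = trainer.optimize(opt)  # in fp16 mode, an inf/NaN gradient norm skips the
                                   # step and decrements lg_loss_scale instead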
motion_diffusion_model/diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,1615 @@
1
+ # This code is based on https://github.com/openai/guided-diffusion
2
+ """
3
+ This code started out as a PyTorch port of Ho et al's diffusion models:
4
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py
5
+
6
+ Docstrings have been added, as well as DDIM sampling and a new collection of beta schedules.
7
+ """
8
+
9
+ import enum
10
+ import math
11
+
12
+ import numpy as np
13
+ import torch
14
+ import torch as th
15
+ from copy import deepcopy
16
+ from diffusion.nn import mean_flat, sum_flat
17
+ from diffusion.losses import normal_kl, discretized_gaussian_log_likelihood
18
+ from data_loaders.humanml.scripts import motion_process
19
+ from utils.loss_util import masked_l2, masked_goal_l2
20
+ from data_loaders.humanml.scripts.motion_process import get_target_location
21
+
22
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, scale_betas=1.):
23
+ """
24
+ Get a pre-defined beta schedule for the given name.
25
+
26
+ The beta schedule library consists of beta schedules which remain similar
27
+ in the limit of num_diffusion_timesteps.
28
+ Beta schedules may be added, but should not be removed or changed once
29
+ they are committed, to maintain backwards compatibility.
30
+ """
31
+ if schedule_name == "linear":
32
+ # Linear schedule from Ho et al, extended to work for any number of
33
+ # diffusion steps.
34
+ scale = scale_betas * 1000 / num_diffusion_timesteps
35
+ beta_start = scale * 0.0001
36
+ beta_end = scale * 0.02
37
+ return np.linspace(
38
+ beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
39
+ )
40
+ elif schedule_name == "cosine":
41
+ return betas_for_alpha_bar(
42
+ num_diffusion_timesteps,
43
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
44
+ )
45
+ else:
46
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
47
+
48
+
49
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
50
+ """
51
+ Create a beta schedule that discretizes the given alpha_t_bar function,
52
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
53
+
54
+ :param num_diffusion_timesteps: the number of betas to produce.
55
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
56
+ produces the cumulative product of (1-beta) up to that
57
+ part of the diffusion process.
58
+ :param max_beta: the maximum beta to use; use values lower than 1 to
59
+ prevent singularities.
60
+ """
61
+ betas = []
62
+ for i in range(num_diffusion_timesteps):
63
+ t1 = i / num_diffusion_timesteps
64
+ t2 = (i + 1) / num_diffusion_timesteps
65
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
66
+ return np.array(betas)
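Written out, the cosine branch above discretizes

\bar{\alpha}(t) = \cos^2\!\left(\frac{t + 0.008}{1.008} \cdot \frac{\pi}{2}\right),
\qquad
\beta_i = \min\!\left(1 - \frac{\bar{\alpha}((i+1)/T)}{\bar{\alpha}(i/T)},\ 0.999\right),
\quad i = 0, \ldots, T-1,

so the cumulative product of (1 - beta) tracks \bar{\alpha}(t).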
67
+
68
+
69
+ class ModelMeanType(enum.Enum):
70
+ """
71
+ Which type of output the model predicts.
72
+ """
73
+
74
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
75
+ START_X = enum.auto() # the model predicts x_0
76
+ EPSILON = enum.auto() # the model predicts epsilon
77
+
78
+
79
+ class ModelVarType(enum.Enum):
80
+ """
81
+ What is used as the model's output variance.
82
+
83
+ The LEARNED_RANGE option has been added to allow the model to predict
84
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
85
+ """
86
+
87
+ LEARNED = enum.auto()
88
+ FIXED_SMALL = enum.auto()
89
+ FIXED_LARGE = enum.auto()
90
+ LEARNED_RANGE = enum.auto()
91
+
92
+
93
+ class LossType(enum.Enum):
94
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
95
+ RESCALED_MSE = (
96
+ enum.auto()
97
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
98
+ KL = enum.auto() # use the variational lower-bound
99
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
100
+
101
+ def is_vb(self):
102
+ return self == LossType.KL or self == LossType.RESCALED_KL
103
+
104
+
105
+ class GaussianDiffusion:
106
+ """
107
+ Utilities for training and sampling diffusion models.
108
+
109
+ Ported directly from here, and then adapted over time for further experimentation.
110
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
111
+
112
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
113
+ starting at T and going to 1.
114
+ :param model_mean_type: a ModelMeanType determining what the model outputs.
115
+ :param model_var_type: a ModelVarType determining how variance is output.
116
+ :param loss_type: a LossType determining the loss function to use.
117
+ :param rescale_timesteps: if True, pass floating point timesteps into the
118
+ model so that they are always scaled like in the
119
+ original paper (0 to 1000).
120
+ """
121
+
122
+ def __init__(
123
+ self,
124
+ *,
125
+ betas,
126
+ model_mean_type,
127
+ model_var_type,
128
+ loss_type,
129
+ rescale_timesteps=False,
130
+ lambda_rcxyz=0.,
131
+ lambda_vel=0.,
132
+ lambda_pose=1.,
133
+ lambda_orient=1.,
134
+ lambda_loc=1.,
135
+ data_rep='rot6d',
136
+ lambda_root_vel=0.,
137
+ lambda_vel_rcxyz=0.,
138
+ lambda_fc=0.,
139
+ lambda_target_loc=0.,
140
+ **kargs,
141
+ ):
142
+ self.model_mean_type = model_mean_type
143
+ self.model_var_type = model_var_type
144
+ self.loss_type = loss_type
145
+ self.rescale_timesteps = rescale_timesteps
146
+ self.data_rep = data_rep
147
+
148
+ if data_rep != 'rot_vel' and lambda_pose != 1.:
149
+ raise ValueError('lambda_pose is relevant only when training on velocities!')
150
+ self.lambda_pose = lambda_pose
151
+ self.lambda_orient = lambda_orient
152
+ self.lambda_loc = lambda_loc
153
+
154
+ self.lambda_rcxyz = lambda_rcxyz
155
+ self.lambda_target_loc = lambda_target_loc
156
+ self.lambda_vel = lambda_vel
157
+ self.lambda_root_vel = lambda_root_vel
158
+ self.lambda_vel_rcxyz = lambda_vel_rcxyz
159
+ self.lambda_fc = lambda_fc
160
+
161
+ if self.lambda_rcxyz > 0. or self.lambda_vel > 0. or self.lambda_root_vel > 0. or \
162
+ self.lambda_vel_rcxyz > 0. or self.lambda_fc > 0. or self.lambda_target_loc > 0.:
163
+ assert self.loss_type == LossType.MSE, 'Geometric losses are supported by MSE loss type only!'
164
+
165
+ # Use float64 for accuracy.
166
+ betas = np.array(betas, dtype=np.float64)
167
+ self.betas = betas
168
+ assert len(betas.shape) == 1, "betas must be 1-D"
169
+ assert (betas > 0).all() and (betas <= 1).all()
170
+
171
+ self.num_timesteps = int(betas.shape[0])
172
+
173
+ alphas = 1.0 - betas
174
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
175
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
176
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
177
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
178
+
179
+ # calculations for diffusion q(x_t | x_{t-1}) and others
180
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
181
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
182
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
183
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
184
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
185
+
186
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
187
+ self.posterior_variance = (
188
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
189
+ )
190
+ # log calculation clipped because the posterior variance is 0 at the
191
+ # beginning of the diffusion chain.
192
+ self.posterior_log_variance_clipped = np.log(
193
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
194
+ )
195
+ self.posterior_mean_coef1 = (
196
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
197
+ )
198
+ self.posterior_mean_coef2 = (
199
+ (1.0 - self.alphas_cumprod_prev)
200
+ * np.sqrt(alphas)
201
+ / (1.0 - self.alphas_cumprod)
202
+ )
203
+
204
+ # self.l2_loss = lambda a, b: (a - b) ** 2 # th.nn.MSELoss(reduction='none') # must be None for handling mask later on.
205
+ self.masked_l2 = masked_l2
206
+
207
+
208
+
209
+ def q_mean_variance(self, x_start, t):
210
+ """
211
+ Get the distribution q(x_t | x_0).
212
+
213
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
214
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
215
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
216
+ """
217
+ mean = (
218
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
219
+ )
220
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
221
+ log_variance = _extract_into_tensor(
222
+ self.log_one_minus_alphas_cumprod, t, x_start.shape
223
+ )
224
+ return mean, variance, log_variance
225
+
226
+ def q_sample(self, x_start, t, noise=None):
227
+ """
228
+ Diffuse the dataset for a given number of diffusion steps.
229
+
230
+ In other words, sample from q(x_t | x_0).
231
+
232
+ :param x_start: the initial dataset batch.
233
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
234
+ :param noise: if specified, the split-out normal noise.
235
+ :return: A noisy version of x_start.
236
+ """
237
+ if noise is None:
238
+ noise = th.randn_like(x_start)
239
+ assert noise.shape == x_start.shape
240
+ return (
241
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
242
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
243
+ * noise
244
+ )
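In closed form, q_sample draws

x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon,
\qquad \epsilon \sim \mathcal{N}(0, I),

i.e. q(x_t \mid x_0) = \mathcal{N}(\sqrt{\bar{\alpha}_t}\, x_0,\ (1 - \bar{\alpha}_t) I), matching the two _extract_into_tensor terms above.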
245
+
246
+ def q_posterior_mean_variance(self, x_start, x_t, t):
247
+ """
248
+ Compute the mean and variance of the diffusion posterior:
249
+
250
+ q(x_{t-1} | x_t, x_0)
251
+
252
+ """
253
+ assert x_start.shape == x_t.shape
254
+ posterior_mean = (
255
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
256
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
257
+ )
258
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
259
+ posterior_log_variance_clipped = _extract_into_tensor(
260
+ self.posterior_log_variance_clipped, t, x_t.shape
261
+ )
262
+ assert (
263
+ posterior_mean.shape[0]
264
+ == posterior_variance.shape[0]
265
+ == posterior_log_variance_clipped.shape[0]
266
+ == x_start.shape[0]
267
+ )
268
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
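The coefficients above implement the standard Gaussian posterior of the forward process:

\tilde{\mu}_t(x_t, x_0) = \frac{\sqrt{\bar{\alpha}_{t-1}}\, \beta_t}{1 - \bar{\alpha}_t}\, x_0
+ \frac{\sqrt{\alpha_t}\,(1 - \bar{\alpha}_{t-1})}{1 - \bar{\alpha}_t}\, x_t,
\qquad
\tilde{\beta}_t = \frac{1 - \bar{\alpha}_{t-1}}{1 - \bar{\alpha}_t}\, \beta_t,

i.e. exactly posterior_mean_coef1, posterior_mean_coef2, and posterior_variance from __init__.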
269
+
270
+ def p_mean_variance(
271
+ self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None
272
+ ):
273
+ """
274
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
275
+ the initial x, x_0.
276
+
277
+ :param model: the model, which takes a signal and a batch of timesteps
278
+ as input.
279
+ :param x: the [N x C x ...] tensor at time t.
280
+ :param t: a 1-D Tensor of timesteps.
281
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
282
+ :param denoised_fn: if not None, a function which applies to the
283
+ x_start prediction before it is used to sample. Applies before
284
+ clip_denoised.
285
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
286
+ pass to the model. This can be used for conditioning.
287
+ :return: a dict with the following keys:
288
+ - 'mean': the model mean output.
289
+ - 'variance': the model variance output.
290
+ - 'log_variance': the log of 'variance'.
291
+ - 'pred_xstart': the prediction for x_0.
292
+ """
293
+ if model_kwargs is None:
294
+ model_kwargs = {}
295
+
296
+ B, C = x.shape[:2]
297
+ assert t.shape == (B,)
298
+ model_output = model(x, self._scale_timesteps(t), **model_kwargs)
299
+
300
+ if 'inpainting_mask' in model_kwargs['y'].keys() and 'inpainted_motion' in model_kwargs['y'].keys():
301
+ inpainting_mask, inpainted_motion = model_kwargs['y']['inpainting_mask'], model_kwargs['y']['inpainted_motion']
302
+ assert self.model_mean_type == ModelMeanType.START_X, 'This feature supports only X_start pred for now!'
303
+ assert model_output.shape == inpainting_mask.shape == inpainted_motion.shape
304
+ model_output = (model_output * ~inpainting_mask) + (inpainted_motion * inpainting_mask)
305
+ # print('model_output', model_output.shape, model_output)
306
+ # print('inpainting_mask', inpainting_mask.shape, inpainting_mask[0,0,0,:])
307
+ # print('inpainted_motion', inpainted_motion.shape, inpainted_motion)
308
+
309
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
310
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
311
+ model_output, model_var_values = th.split(model_output, C, dim=1)
312
+ if self.model_var_type == ModelVarType.LEARNED:
313
+ model_log_variance = model_var_values
314
+ model_variance = th.exp(model_log_variance)
315
+ else:
316
+ min_log = _extract_into_tensor(
317
+ self.posterior_log_variance_clipped, t, x.shape
318
+ )
319
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
320
+ # The model_var_values is [-1, 1] for [min_var, max_var].
321
+ frac = (model_var_values + 1) / 2
322
+ model_log_variance = frac * max_log + (1 - frac) * min_log
323
+ model_variance = th.exp(model_log_variance)
324
+ else:
325
+ model_variance, model_log_variance = {
326
+ # for fixedlarge, we set the initial (log-)variance like so
327
+ # to get a better decoder log likelihood.
328
+ ModelVarType.FIXED_LARGE: (
329
+ np.append(self.posterior_variance[1], self.betas[1:]),
330
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
331
+ ),
332
+ ModelVarType.FIXED_SMALL: (
333
+ self.posterior_variance,
334
+ self.posterior_log_variance_clipped,
335
+ ),
336
+ }[self.model_var_type]
337
+ # print('model_variance', model_variance)
338
+ # print('model_log_variance',model_log_variance)
339
+ # print('self.posterior_variance', self.posterior_variance)
340
+ # print('self.posterior_log_variance_clipped', self.posterior_log_variance_clipped)
341
+ # print('self.model_var_type', self.model_var_type)
342
+
343
+
344
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
345
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
346
+
347
+ def process_xstart(x):
348
+ if denoised_fn is not None:
349
+ x = denoised_fn(x)
350
+ if clip_denoised:
351
+ # print('clip_denoised', clip_denoised)
352
+ return x.clamp(-1, 1)
353
+ return x
354
+
355
+ if self.model_mean_type == ModelMeanType.PREVIOUS_X:
356
+ pred_xstart = process_xstart(
357
+ self._predict_xstart_from_xprev(x_t=x, t=t, xprev=model_output)
358
+ )
359
+ model_mean = model_output
360
+ elif self.model_mean_type in [ModelMeanType.START_X, ModelMeanType.EPSILON]: # THIS IS US!
361
+ if self.model_mean_type == ModelMeanType.START_X:
362
+ pred_xstart = process_xstart(model_output)
363
+ else:
364
+ pred_xstart = process_xstart(
365
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
366
+ )
367
+ model_mean, _, _ = self.q_posterior_mean_variance(
368
+ x_start=pred_xstart, x_t=x, t=t
369
+ )
370
+ else:
371
+ raise NotImplementedError(self.model_mean_type)
372
+
373
+ assert (
374
+ model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
375
+ )
376
+ return {
377
+ "mean": model_mean,
378
+ "variance": model_variance,
379
+ "log_variance": model_log_variance,
380
+ "pred_xstart": pred_xstart,
381
+ }
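A minimal sketch of driving the inpainting branch above (shapes are hypothetical; bs=1, njoints=263, nfeats=1, nframes=60 follow the HumanML3D convention used elsewhere in this commit, and the branch requires ModelMeanType.START_X per the assert). Entries where the mask is True are overwritten with the given motion at every denoising step:

import torch as th

bs, njoints, nfeats, nframes = 1, 263, 1, 60
observed = th.zeros(bs, njoints, nfeats, nframes)              # known motion values
mask = th.zeros(bs, njoints, nfeats, nframes, dtype=th.bool)
mask[..., :20] = True                                          # pin the first 20 frames
model_kwargs = {'y': {'inpainted_motion': observed,
                      'inpainting_mask': mask}}
# pass model_kwargs to p_sample_loop(...) together with the usual conditioning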
382
+
383
+ def _predict_xstart_from_eps(self, x_t, t, eps):
384
+ assert x_t.shape == eps.shape
385
+ return (
386
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
387
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
388
+ )
389
+
390
+ def _predict_xstart_from_xprev(self, x_t, t, xprev):
391
+ assert x_t.shape == xprev.shape
392
+ return ( # (xprev - coef2*x_t) / coef1
393
+ _extract_into_tensor(1.0 / self.posterior_mean_coef1, t, x_t.shape) * xprev
394
+ - _extract_into_tensor(
395
+ self.posterior_mean_coef2 / self.posterior_mean_coef1, t, x_t.shape
396
+ )
397
+ * x_t
398
+ )
399
+
400
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
401
+ return (
402
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
403
+ - pred_xstart
404
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
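The two helpers above invert the same identity:

\hat{x}_0 = \frac{x_t}{\sqrt{\bar{\alpha}_t}} - \sqrt{\frac{1}{\bar{\alpha}_t} - 1}\; \epsilon
\quad\Longleftrightarrow\quad
\epsilon = \frac{x_t / \sqrt{\bar{\alpha}_t} - \hat{x}_0}{\sqrt{1/\bar{\alpha}_t - 1}},

using the sqrt_recip_alphas_cumprod and sqrt_recipm1_alphas_cumprod buffers.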
405
+
406
+ def _scale_timesteps(self, t):
407
+ if self.rescale_timesteps:
408
+ return t.float() * (1000.0 / self.num_timesteps)
409
+ return t
410
+
411
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
412
+ """
413
+ Compute the mean for the previous step, given a function cond_fn that
414
+ computes the gradient of a conditional log probability with respect to
415
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
416
+ condition on y.
417
+
418
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
419
+ """
420
+ gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
421
+ new_mean = (
422
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
423
+ )
424
+ return new_mean
425
+
426
+ def condition_mean_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
427
+ """
428
+ Compute the mean for the previous step, given a function cond_fn that
429
+ computes the gradient of a conditional log probability with respect to
430
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
431
+ condition on y.
432
+
433
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
434
+ """
435
+ gradient = cond_fn(x, t, p_mean_var, **model_kwargs)
436
+ new_mean = (
437
+ p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
438
+ )
439
+ return new_mean
440
+
441
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
442
+ """
443
+ Compute what the p_mean_variance output would have been, should the
444
+ model's score function be conditioned by cond_fn.
445
+
446
+ See condition_mean() for details on cond_fn.
447
+
448
+ Unlike condition_mean(), this instead uses the conditioning strategy
449
+ from Song et al (2020).
450
+ """
451
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
452
+
453
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
454
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
455
+ x, self._scale_timesteps(t), **model_kwargs
456
+ )
457
+
458
+ out = p_mean_var.copy()
459
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
460
+ out["mean"], _, _ = self.q_posterior_mean_variance(
461
+ x_start=out["pred_xstart"], x_t=x, t=t
462
+ )
463
+ return out
464
+
465
+ def condition_score_with_grad(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
466
+ """
467
+ Compute what the p_mean_variance output would have been, should the
468
+ model's score function be conditioned by cond_fn.
469
+
470
+ See condition_mean() for details on cond_fn.
471
+
472
+ Unlike condition_mean(), this instead uses the conditioning strategy
473
+ from Song et al (2020).
474
+ """
475
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
476
+
477
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
478
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(
479
+ x, t, p_mean_var, **model_kwargs
480
+ )
481
+
482
+ out = p_mean_var.copy()
483
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
484
+ out["mean"], _, _ = self.q_posterior_mean_variance(
485
+ x_start=out["pred_xstart"], x_t=x, t=t
486
+ )
487
+ return out
488
+
489
+ def p_sample(
490
+ self,
491
+ model,
492
+ x,
493
+ t,
494
+ clip_denoised=True,
495
+ denoised_fn=None,
496
+ cond_fn=None,
497
+ model_kwargs=None,
498
+ const_noise=False,
499
+ ):
500
+ """
501
+ Sample x_{t-1} from the model at the given timestep.
502
+
503
+ :param model: the model to sample from.
504
+ :param x: the current tensor at x_{t-1}.
505
+ :param t: the value of t, starting at 0 for the first diffusion step.
506
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
507
+ :param denoised_fn: if not None, a function which applies to the
508
+ x_start prediction before it is used to sample.
509
+ :param cond_fn: if not None, this is a gradient function that acts
510
+ similarly to the model.
511
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
512
+ pass to the model. This can be used for conditioning.
513
+ :return: a dict containing the following keys:
514
+ - 'sample': a random sample from the model.
515
+ - 'pred_xstart': a prediction of x_0.
516
+ """
517
+ out = self.p_mean_variance(
518
+ model,
519
+ x,
520
+ t,
521
+ clip_denoised=clip_denoised,
522
+ denoised_fn=denoised_fn,
523
+ model_kwargs=model_kwargs,
524
+ )
525
+ noise = th.randn_like(x)
526
+ # print('const_noise', const_noise)
527
+ if const_noise:
528
+ noise = noise[[0]].repeat(x.shape[0], 1, 1, 1)
529
+
530
+ nonzero_mask = (
531
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
532
+ ) # no noise when t == 0
533
+ if cond_fn is not None:
534
+ out["mean"] = self.condition_mean(
535
+ cond_fn, out, x, t, model_kwargs=model_kwargs
536
+ )
537
+ # print('mean', out["mean"].shape, out["mean"])
538
+ # print('log_variance', out["log_variance"].shape, out["log_variance"])
539
+ # print('nonzero_mask', nonzero_mask.shape, nonzero_mask)
540
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
541
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
542
+
543
+ def p_sample_with_grad(
544
+ self,
545
+ model,
546
+ x,
547
+ t,
548
+ clip_denoised=True,
549
+ denoised_fn=None,
550
+ cond_fn=None,
551
+ model_kwargs=None,
552
+ ):
553
+ """
554
+ Sample x_{t-1} from the model at the given timestep.
555
+
556
+ :param model: the model to sample from.
557
+ :param x: the current tensor at x_{t-1}.
558
+ :param t: the value of t, starting at 0 for the first diffusion step.
559
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
560
+ :param denoised_fn: if not None, a function which applies to the
561
+ x_start prediction before it is used to sample.
562
+ :param cond_fn: if not None, this is a gradient function that acts
563
+ similarly to the model.
564
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
565
+ pass to the model. This can be used for conditioning.
566
+ :return: a dict containing the following keys:
567
+ - 'sample': a random sample from the model.
568
+ - 'pred_xstart': a prediction of x_0.
569
+ """
570
+ with th.enable_grad():
571
+ x = x.detach().requires_grad_()
572
+ out = self.p_mean_variance(
573
+ model,
574
+ x,
575
+ t,
576
+ clip_denoised=clip_denoised,
577
+ denoised_fn=denoised_fn,
578
+ model_kwargs=model_kwargs,
579
+ )
580
+ noise = th.randn_like(x)
581
+ nonzero_mask = (
582
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
583
+ ) # no noise when t == 0
584
+ if cond_fn is not None:
585
+ out["mean"] = self.condition_mean_with_grad(
586
+ cond_fn, out, x, t, model_kwargs=model_kwargs
587
+ )
588
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
589
+ return {"sample": sample, "pred_xstart": out["pred_xstart"].detach()}
590
+
591
+ def p_sample_loop(
592
+ self,
593
+ model,
594
+ shape,
595
+ noise=None,
596
+ clip_denoised=True,
597
+ denoised_fn=None,
598
+ cond_fn=None,
599
+ model_kwargs=None,
600
+ device=None,
601
+ progress=False,
602
+ skip_timesteps=0,
603
+ init_image=None,
604
+ randomize_class=False,
605
+ cond_fn_with_grad=False,
606
+ dump_steps=None,
607
+ const_noise=False,
608
+ ):
609
+ """
610
+ Generate samples from the model.
611
+
612
+ :param model: the model module.
613
+ :param shape: the shape of the samples, (N, C, H, W).
614
+ :param noise: if specified, the noise from the encoder to sample.
615
+ Should be of the same shape as `shape`.
616
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
617
+ :param denoised_fn: if not None, a function which applies to the
618
+ x_start prediction before it is used to sample.
619
+ :param cond_fn: if not None, this is a gradient function that acts
620
+ similarly to the model.
621
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
622
+ pass to the model. This can be used for conditioning.
623
+ :param device: if specified, the device to create the samples on.
624
+ If not specified, use a model parameter's device.
625
+ :param progress: if True, show a tqdm progress bar.
626
+ :param const_noise: if True, apply the same noise to all samples throughout sampling.
627
+ :return: a non-differentiable batch of samples.
628
+ """
629
+ final = None
630
+ if dump_steps is not None:
631
+ dump = []
632
+
633
+ if 'text' in model_kwargs['y'].keys():
634
+ # encoding the text once, instead of once per iteration, saves a lot of time
635
+ model_kwargs['y']['text_embed'] = model.encode_text(model_kwargs['y']['text'])
636
+
637
+ for i, sample in enumerate(self.p_sample_loop_progressive(
638
+ model,
639
+ shape,
640
+ noise=noise,
641
+ clip_denoised=clip_denoised,
642
+ denoised_fn=denoised_fn,
643
+ cond_fn=cond_fn,
644
+ model_kwargs=model_kwargs,
645
+ device=device,
646
+ progress=progress,
647
+ skip_timesteps=skip_timesteps,
648
+ init_image=init_image,
649
+ randomize_class=randomize_class,
650
+ cond_fn_with_grad=cond_fn_with_grad,
651
+ const_noise=const_noise,
652
+ )):
653
+ if dump_steps is not None and i in dump_steps:
654
+ dump.append(deepcopy(sample["sample"]))
655
+ final = sample
656
+ if dump_steps is not None:
657
+ return dump
658
+ return final["sample"]
659
+
660
+ def p_sample_loop_progressive(
661
+ self,
662
+ model,
663
+ shape,
664
+ noise=None,
665
+ clip_denoised=True,
666
+ denoised_fn=None,
667
+ cond_fn=None,
668
+ model_kwargs=None,
669
+ device=None,
670
+ progress=False,
671
+ skip_timesteps=0,
672
+ init_image=None,
673
+ randomize_class=False,
674
+ cond_fn_with_grad=False,
675
+ const_noise=False,
676
+ ):
677
+ """
678
+ Generate samples from the model and yield intermediate samples from
679
+ each timestep of diffusion.
680
+
681
+ Arguments are the same as p_sample_loop().
682
+ Returns a generator over dicts, where each dict is the return value of
683
+ p_sample().
684
+ """
685
+ if device is None:
686
+ device = next(model.parameters()).device
687
+ assert isinstance(shape, (tuple, list))
688
+ if noise is not None:
689
+ img = noise
690
+ else:
691
+ img = th.randn(*shape, device=device)
692
+
693
+ if skip_timesteps and init_image is None:
694
+ init_image = th.zeros_like(img)
695
+
696
+ indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
697
+
698
+ if init_image is not None:
699
+ my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
700
+ img = self.q_sample(init_image, my_t, img)
701
+
702
+ if progress:
703
+ # Lazy import so that we don't depend on tqdm.
704
+ from tqdm.auto import tqdm
705
+
706
+ indices = tqdm(indices)
707
+
708
+ for i in indices:
709
+ t = th.tensor([i] * shape[0], device=device)
710
+ if randomize_class and 'y' in model_kwargs:
711
+ model_kwargs['y'] = th.randint(low=0, high=model.num_classes,
712
+ size=model_kwargs['y'].shape,
713
+ device=model_kwargs['y'].device)
714
+ with th.no_grad():
715
+ sample_fn = self.p_sample_with_grad if cond_fn_with_grad else self.p_sample
716
+ out = sample_fn(
717
+ model,
718
+ img,
719
+ t,
720
+ clip_denoised=clip_denoised,
721
+ denoised_fn=denoised_fn,
722
+ cond_fn=cond_fn,
723
+ model_kwargs=model_kwargs,
724
+ const_noise=const_noise,
725
+ )
726
+ yield out
727
+ img = out["sample"]
728
+
729
+ def ddim_sample(
730
+ self,
731
+ model,
732
+ x,
733
+ t,
734
+ clip_denoised=True,
735
+ denoised_fn=None,
736
+ cond_fn=None,
737
+ model_kwargs=None,
738
+ eta=0.0,
739
+ ):
740
+ """
741
+ Sample x_{t-1} from the model using DDIM.
742
+
743
+ Same usage as p_sample().
744
+ """
745
+ out_orig = self.p_mean_variance(
746
+ model,
747
+ x,
748
+ t,
749
+ clip_denoised=clip_denoised,
750
+ denoised_fn=denoised_fn,
751
+ model_kwargs=model_kwargs,
752
+ )
753
+ if cond_fn is not None:
754
+ out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs)
755
+ else:
756
+ out = out_orig
757
+
758
+ # Usually our model outputs epsilon, but we re-derive it
759
+ # in case we used x_start or x_prev prediction.
760
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
761
+
762
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
763
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
764
+ sigma = (
765
+ eta
766
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
767
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
768
+ )
769
+ # Equation 12.
770
+ noise = th.randn_like(x)
771
+ mean_pred = (
772
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
773
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
774
+ )
775
+ nonzero_mask = (
776
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
777
+ ) # no noise when t == 0
778
+ sample = mean_pred + nonzero_mask * sigma * noise
779
+ return {"sample": sample, "pred_xstart": out_orig["pred_xstart"]}
780
+
781
+ def ddim_sample_with_grad(
782
+ self,
783
+ model,
784
+ x,
785
+ t,
786
+ clip_denoised=True,
787
+ denoised_fn=None,
788
+ cond_fn=None,
789
+ model_kwargs=None,
790
+ eta=0.0,
791
+ ):
792
+ """
793
+ Sample x_{t-1} from the model using DDIM.
794
+
795
+ Same usage as p_sample().
796
+ """
797
+ with th.enable_grad():
798
+ x = x.detach().requires_grad_()
799
+ out_orig = self.p_mean_variance(
800
+ model,
801
+ x,
802
+ t,
803
+ clip_denoised=clip_denoised,
804
+ denoised_fn=denoised_fn,
805
+ model_kwargs=model_kwargs,
806
+ )
807
+ if cond_fn is not None:
808
+ out = self.condition_score_with_grad(cond_fn, out_orig, x, t,
809
+ model_kwargs=model_kwargs)
810
+ else:
811
+ out = out_orig
812
+
813
+ out["pred_xstart"] = out["pred_xstart"].detach()
814
+
815
+ # Usually our model outputs epsilon, but we re-derive it
816
+ # in case we used x_start or x_prev prediction.
817
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
818
+
819
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
820
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
821
+ sigma = (
822
+ eta
823
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
824
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
825
+ )
826
+ # Equation 12.
827
+ noise = th.randn_like(x)
828
+ mean_pred = (
829
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
830
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
831
+ )
832
+ nonzero_mask = (
833
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
834
+ ) # no noise when t == 0
835
+ sample = mean_pred + nonzero_mask * sigma * noise
836
+ return {"sample": sample, "pred_xstart": out_orig["pred_xstart"].detach()}
837
+
838
+ def ddim_reverse_sample(
839
+ self,
840
+ model,
841
+ x,
842
+ t,
843
+ clip_denoised=True,
844
+ denoised_fn=None,
845
+ model_kwargs=None,
846
+ eta=0.0,
847
+ ):
848
+ """
849
+ Sample x_{t+1} from the model using DDIM reverse ODE.
850
+ """
851
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
852
+ out = self.p_mean_variance(
853
+ model,
854
+ x,
855
+ t,
856
+ clip_denoised=clip_denoised,
857
+ denoised_fn=denoised_fn,
858
+ model_kwargs=model_kwargs,
859
+ )
860
+ # Usually our model outputs epsilon, but we re-derive it
861
+ # in case we used x_start or x_prev prediction.
862
+ eps = (
863
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
864
+ - out["pred_xstart"]
865
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
866
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
867
+
868
+ # Equation 12. reversed
869
+ mean_pred = (
870
+ out["pred_xstart"] * th.sqrt(alpha_bar_next)
871
+ + th.sqrt(1 - alpha_bar_next) * eps
872
+ )
873
+
874
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
875
+
876
+ def ddim_sample_loop(
877
+ self,
878
+ model,
879
+ shape,
880
+ noise=None,
881
+ clip_denoised=True,
882
+ denoised_fn=None,
883
+ cond_fn=None,
884
+ model_kwargs=None,
885
+ device=None,
886
+ progress=False,
887
+ eta=0.0,
888
+ skip_timesteps=0,
889
+ init_image=None,
890
+ randomize_class=False,
891
+ cond_fn_with_grad=False,
892
+ dump_steps=None,
893
+ const_noise=False,
894
+ ):
895
+ """
896
+ Generate samples from the model using DDIM.
897
+
898
+ Same usage as p_sample_loop().
899
+ """
900
+ if dump_steps is not None:
901
+ raise NotImplementedError()
902
+ if const_noise:
903
+ raise NotImplementedError()
904
+
905
+ final = None
906
+ for sample in self.ddim_sample_loop_progressive(
907
+ model,
908
+ shape,
909
+ noise=noise,
910
+ clip_denoised=clip_denoised,
911
+ denoised_fn=denoised_fn,
912
+ cond_fn=cond_fn,
913
+ model_kwargs=model_kwargs,
914
+ device=device,
915
+ progress=progress,
916
+ eta=eta,
917
+ skip_timesteps=skip_timesteps,
918
+ init_image=init_image,
919
+ randomize_class=randomize_class,
920
+ cond_fn_with_grad=cond_fn_with_grad,
921
+ ):
922
+ final = sample
923
+ return final["sample"]
924
+
925
+ def ddim_sample_loop_progressive(
926
+ self,
927
+ model,
928
+ shape,
929
+ noise=None,
930
+ clip_denoised=True,
931
+ denoised_fn=None,
932
+ cond_fn=None,
933
+ model_kwargs=None,
934
+ device=None,
935
+ progress=False,
936
+ eta=0.0,
937
+ skip_timesteps=0,
938
+ init_image=None,
939
+ randomize_class=False,
940
+ cond_fn_with_grad=False,
941
+ ):
942
+ """
943
+ Use DDIM to sample from the model and yield intermediate samples from
944
+ each timestep of DDIM.
945
+
946
+ Same usage as p_sample_loop_progressive().
947
+ """
948
+ if device is None:
949
+ device = next(model.parameters()).device
950
+ assert isinstance(shape, (tuple, list))
951
+ if noise is not None:
952
+ img = noise
953
+ else:
954
+ img = th.randn(*shape, device=device)
955
+
956
+ if skip_timesteps and init_image is None:
957
+ init_image = th.zeros_like(img)
958
+
959
+ indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
960
+
961
+ if init_image is not None:
962
+ my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
963
+ img = self.q_sample(init_image, my_t, img)
964
+
965
+ if progress:
966
+ # Lazy import so that we don't depend on tqdm.
967
+ from tqdm.auto import tqdm
968
+
969
+ indices = tqdm(indices)
970
+
971
+ for i in indices:
972
+ t = th.tensor([i] * shape[0], device=device)
973
+ if randomize_class and 'y' in model_kwargs:
974
+ model_kwargs['y'] = th.randint(low=0, high=model.num_classes,
975
+ size=model_kwargs['y'].shape,
976
+ device=model_kwargs['y'].device)
977
+ with th.no_grad():
978
+ sample_fn = self.ddim_sample_with_grad if cond_fn_with_grad else self.ddim_sample
979
+ out = sample_fn(
980
+ model,
981
+ img,
982
+ t,
983
+ clip_denoised=clip_denoised,
984
+ denoised_fn=denoised_fn,
985
+ cond_fn=cond_fn,
986
+ model_kwargs=model_kwargs,
987
+ eta=eta,
988
+ )
989
+ yield out
990
+ img = out["sample"]
991
+
992
+ def plms_sample(
993
+ self,
994
+ model,
995
+ x,
996
+ t,
997
+ clip_denoised=True,
998
+ denoised_fn=None,
999
+ cond_fn=None,
1000
+ model_kwargs=None,
1001
+ cond_fn_with_grad=False,
1002
+ order=2,
1003
+ old_out=None,
1004
+ ):
1005
+ """
1006
+ Sample x_{t-1} from the model using Pseudo Linear Multistep.
1007
+
1008
+ Same usage as p_sample().
1009
+ """
1010
+ if not isinstance(order, int) or not 1 <= order <= 4:
1011
+ raise ValueError('order is invalid (should be int from 1-4).')
1012
+
1013
+ def get_model_output(x, t):
1014
+ with th.set_grad_enabled(cond_fn_with_grad and cond_fn is not None):
1015
+ x = x.detach().requires_grad_() if cond_fn_with_grad else x
1016
+ out_orig = self.p_mean_variance(
1017
+ model,
1018
+ x,
1019
+ t,
1020
+ clip_denoised=clip_denoised,
1021
+ denoised_fn=denoised_fn,
1022
+ model_kwargs=model_kwargs,
1023
+ )
1024
+ if cond_fn is not None:
1025
+ if cond_fn_with_grad:
1026
+ out = self.condition_score_with_grad(cond_fn, out_orig, x, t, model_kwargs=model_kwargs)
1027
+ x = x.detach()
1028
+ else:
1029
+ out = self.condition_score(cond_fn, out_orig, x, t, model_kwargs=model_kwargs)
1030
+ else:
1031
+ out = out_orig
1032
+
1033
+ # Usually our model outputs epsilon, but we re-derive it
1034
+ # in case we used x_start or x_prev prediction.
1035
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
1036
+ return eps, out, out_orig
1037
+
1038
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
1039
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
1040
+ eps, out, out_orig = get_model_output(x, t)
1041
+
1042
+ if order > 1 and old_out is None:
1043
+ # Pseudo Improved Euler
1044
+ old_eps = [eps]
1045
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps
1046
+ eps_2, _, _ = get_model_output(mean_pred, t - 1)
1047
+ eps_prime = (eps + eps_2) / 2
1048
+ pred_prime = self._predict_xstart_from_eps(x, t, eps_prime)
1049
+ mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime
1050
+ else:
1051
+ # Pseudo Linear Multistep (Adams-Bashforth)
1052
+ old_eps = old_out["old_eps"]
1053
+ old_eps.append(eps)
1054
+ cur_order = min(order, len(old_eps))
1055
+ if cur_order == 1:
1056
+ eps_prime = old_eps[-1]
1057
+ elif cur_order == 2:
1058
+ eps_prime = (3 * old_eps[-1] - old_eps[-2]) / 2
1059
+ elif cur_order == 3:
1060
+ eps_prime = (23 * old_eps[-1] - 16 * old_eps[-2] + 5 * old_eps[-3]) / 12
1061
+ elif cur_order == 4:
1062
+ eps_prime = (55 * old_eps[-1] - 59 * old_eps[-2] + 37 * old_eps[-3] - 9 * old_eps[-4]) / 24
1063
+ else:
1064
+ raise RuntimeError('cur_order is invalid.')
1065
+ pred_prime = self._predict_xstart_from_eps(x, t, eps_prime)
1066
+ mean_pred = pred_prime * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev) * eps_prime
1067
+
1068
+ if len(old_eps) >= order:
1069
+ old_eps.pop(0)
1070
+
1071
+ nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
1072
+ sample = mean_pred * nonzero_mask + out["pred_xstart"] * (1 - nonzero_mask)
1073
+
1074
+ return {"sample": sample, "pred_xstart": out_orig["pred_xstart"], "old_eps": old_eps}
1075
+
1076
+ def plms_sample_loop(
1077
+ self,
1078
+ model,
1079
+ shape,
1080
+ noise=None,
1081
+ clip_denoised=True,
1082
+ denoised_fn=None,
1083
+ cond_fn=None,
1084
+ model_kwargs=None,
1085
+ device=None,
1086
+ progress=False,
1087
+ skip_timesteps=0,
1088
+ init_image=None,
1089
+ randomize_class=False,
1090
+ cond_fn_with_grad=False,
1091
+ order=2,
1092
+ ):
1093
+ """
1094
+ Generate samples from the model using Pseudo Linear Multistep.
1095
+
1096
+ Same usage as p_sample_loop().
1097
+ """
1098
+ final = None
1099
+ for sample in self.plms_sample_loop_progressive(
1100
+ model,
1101
+ shape,
1102
+ noise=noise,
1103
+ clip_denoised=clip_denoised,
1104
+ denoised_fn=denoised_fn,
1105
+ cond_fn=cond_fn,
1106
+ model_kwargs=model_kwargs,
1107
+ device=device,
1108
+ progress=progress,
1109
+ skip_timesteps=skip_timesteps,
1110
+ init_image=init_image,
1111
+ randomize_class=randomize_class,
1112
+ cond_fn_with_grad=cond_fn_with_grad,
1113
+ order=order,
1114
+ ):
1115
+ final = sample
1116
+ return final["sample"]
1117
+
1118
+ def plms_sample_loop_progressive(
1119
+ self,
1120
+ model,
1121
+ shape,
1122
+ noise=None,
1123
+ clip_denoised=True,
1124
+ denoised_fn=None,
1125
+ cond_fn=None,
1126
+ model_kwargs=None,
1127
+ device=None,
1128
+ progress=False,
1129
+ skip_timesteps=0,
1130
+ init_image=None,
1131
+ randomize_class=False,
1132
+ cond_fn_with_grad=False,
1133
+ order=2,
1134
+ ):
1135
+ """
1136
+ Use PLMS to sample from the model and yield intermediate samples from each
1137
+ timestep of PLMS.
1138
+
1139
+ Same usage as p_sample_loop_progressive().
1140
+ """
1141
+ if device is None:
1142
+ device = next(model.parameters()).device
1143
+ assert isinstance(shape, (tuple, list))
1144
+ if noise is not None:
1145
+ img = noise
1146
+ else:
1147
+ img = th.randn(*shape, device=device)
1148
+
1149
+ if skip_timesteps and init_image is None:
1150
+ init_image = th.zeros_like(img)
1151
+
1152
+ indices = list(range(self.num_timesteps - skip_timesteps))[::-1]
1153
+
1154
+ if init_image is not None:
1155
+ my_t = th.ones([shape[0]], device=device, dtype=th.long) * indices[0]
1156
+ img = self.q_sample(init_image, my_t, img)
1157
+
1158
+ if progress:
1159
+ # Lazy import so that we don't depend on tqdm.
1160
+ from tqdm.auto import tqdm
1161
+
1162
+ indices = tqdm(indices)
1163
+
1164
+ old_out = None
1165
+
1166
+ for i in indices:
1167
+ t = th.tensor([i] * shape[0], device=device)
1168
+ if randomize_class and 'y' in model_kwargs:
1169
+ model_kwargs['y'] = th.randint(low=0, high=model.num_classes,
1170
+ size=model_kwargs['y'].shape,
1171
+ device=model_kwargs['y'].device)
1172
+ with th.no_grad():
1173
+ out = self.plms_sample(
1174
+ model,
1175
+ img,
1176
+ t,
1177
+ clip_denoised=clip_denoised,
1178
+ denoised_fn=denoised_fn,
1179
+ cond_fn=cond_fn,
1180
+ model_kwargs=model_kwargs,
1181
+ cond_fn_with_grad=cond_fn_with_grad,
1182
+ order=order,
1183
+ old_out=old_out,
1184
+ )
1185
+ yield out
1186
+ old_out = out
1187
+ img = out["sample"]
1188
+
1189
+ def _vb_terms_bpd(
1190
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
1191
+ ):
1192
+ """
1193
+ Get a term for the variational lower-bound.
1194
+
1195
+ The resulting units are bits (rather than nats, as one might expect).
1196
+ This allows for comparison to other papers.
1197
+
1198
+ :return: a dict with the following keys:
1199
+ - 'output': a shape [N] tensor of NLLs or KLs.
1200
+ - 'pred_xstart': the x_0 predictions.
1201
+ """
1202
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
1203
+ x_start=x_start, x_t=x_t, t=t
1204
+ )
1205
+ out = self.p_mean_variance(
1206
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
1207
+ )
1208
+ kl = normal_kl(
1209
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
1210
+ )
1211
+ kl = mean_flat(kl) / np.log(2.0)
1212
+
1213
+ decoder_nll = -discretized_gaussian_log_likelihood(
1214
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
1215
+ )
1216
+ assert decoder_nll.shape == x_start.shape
1217
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
1218
+
1219
+ # At the first timestep return the decoder NLL,
1220
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
1221
+ output = th.where((t == 0), decoder_nll, kl)
1222
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
1223
+
1224
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, dataset=None):
1225
+ """
1226
+ Compute training losses for a single timestep.
1227
+
1228
+ :param model: the model to evaluate loss on.
1229
+ :param x_start: the [N x C x ...] tensor of inputs.
1230
+ :param t: a batch of timestep indices.
1231
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
1232
+ pass to the model. This can be used for conditioning.
1233
+ :param noise: if specified, the specific Gaussian noise to try to remove.
1234
+ :return: a dict with the key "loss" containing a tensor of shape [N].
1235
+ Some mean or variance settings may also have other keys.
1236
+ """
1237
+
1238
+ # enc = model.model._modules['module']
1239
+ enc = model.model
1240
+ if model_kwargs is None:
+ model_kwargs = {}
+ mask = model_kwargs['y']['mask']
1241
+ get_xyz = lambda sample: enc.rot2xyz(sample, mask=None, pose_rep=enc.pose_rep, translation=enc.translation,
1242
+ glob=enc.glob,
1243
+ # jointstype='vertices', # 3.4 iter/sec # USED ALSO IN MotionCLIP
1244
+ jointstype='smpl', # 3.4 iter/sec
1245
+ vertstrans=False)
1246
+
1247
1249
+ if noise is None:
1250
+ noise = th.randn_like(x_start)
1251
+ x_t = self.q_sample(x_start, t, noise=noise)
1252
+
1253
+ terms = {}
1254
+
1255
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
1256
+ terms["loss"] = self._vb_terms_bpd(
1257
+ model=model,
1258
+ x_start=x_start,
1259
+ x_t=x_t,
1260
+ t=t,
1261
+ clip_denoised=False,
1262
+ model_kwargs=model_kwargs,
1263
+ )["output"]
1264
+ if self.loss_type == LossType.RESCALED_KL:
1265
+ terms["loss"] *= self.num_timesteps
1266
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
1267
+ model_output = model(x_t, self._scale_timesteps(t), **model_kwargs)
1268
+
1269
+ if self.model_var_type in [
1270
+ ModelVarType.LEARNED,
1271
+ ModelVarType.LEARNED_RANGE,
1272
+ ]:
1273
+ B, C = x_t.shape[:2]
1274
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
1275
+ model_output, model_var_values = th.split(model_output, C, dim=1)
1276
+ # Learn the variance using the variational bound, but don't let
1277
+ # it affect our mean prediction.
1278
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
1279
+ terms["vb"] = self._vb_terms_bpd(
1280
+ model=lambda *args, r=frozen_out: r,
1281
+ x_start=x_start,
1282
+ x_t=x_t,
1283
+ t=t,
1284
+ clip_denoised=False,
1285
+ )["output"]
1286
+ if self.loss_type == LossType.RESCALED_MSE:
1287
+ # Divide by 1000 for equivalence with initial implementation.
1288
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
1289
+ terms["vb"] *= self.num_timesteps / 1000.0
1290
+
1291
+ target = {
1292
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
1293
+ x_start=x_start, x_t=x_t, t=t
1294
+ )[0],
1295
+ ModelMeanType.START_X: x_start,
1296
+ ModelMeanType.EPSILON: noise,
1297
+ }[self.model_mean_type]
1298
+ assert model_output.shape == target.shape == x_start.shape # [bs, njoints, nfeats, nframes]
1299
+
1300
+ terms["rot_mse"] = self.masked_l2(target, model_output, mask) # mean_flat(rot_mse)
1301
+
1302
+ target_xyz, model_output_xyz = None, None
1303
+
1304
+ if self.lambda_rcxyz > 0.:
1305
+ target_xyz = get_xyz(target) # [bs, nvertices(vertices)/njoints(smpl), 3, nframes]
1306
+ model_output_xyz = get_xyz(model_output) # [bs, nvertices, 3, nframes]
1307
+ terms["rcxyz_mse"] = self.masked_l2(target_xyz, model_output_xyz, mask) # mean_flat((target_xyz - model_output_xyz) ** 2)
1308
+
1309
+ if self.lambda_vel_rcxyz > 0.:
1310
+ if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']:
1311
+ target_xyz = get_xyz(target) if target_xyz is None else target_xyz
1312
+ model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz
1313
+ target_xyz_vel = (target_xyz[:, :, :, 1:] - target_xyz[:, :, :, :-1])
1314
+ model_output_xyz_vel = (model_output_xyz[:, :, :, 1:] - model_output_xyz[:, :, :, :-1])
1315
+ terms["vel_xyz_mse"] = self.masked_l2(target_xyz_vel, model_output_xyz_vel, mask[:, :, :, 1:])
1316
+
1317
+ if self.lambda_fc > 0.:
1318
+ torch.autograd.set_detect_anomaly(True)
1319
+ if self.data_rep == 'rot6d' and dataset.dataname in ['humanact12', 'uestc']:
1320
+ target_xyz = get_xyz(target) if target_xyz is None else target_xyz
1321
+ model_output_xyz = get_xyz(model_output) if model_output_xyz is None else model_output_xyz
1322
+ # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11
1323
+ l_ankle_idx, r_ankle_idx, l_foot_idx, r_foot_idx = 7, 8, 10, 11
1324
+ relevant_joints = [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx]
1325
+ gt_joint_xyz = target_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames]
1326
+ gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames]
1327
+ fc_mask = torch.unsqueeze((gt_joint_vel <= 0.01), dim=2).repeat(1, 1, 3, 1)
1328
+ pred_joint_xyz = model_output_xyz[:, relevant_joints, :, :] # [BatchSize, 4, 3, Frames]
1329
+ pred_vel = pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1]
1330
+ pred_vel[~fc_mask] = 0
1331
+ terms["fc"] = self.masked_l2(pred_vel,
1332
+ torch.zeros(pred_vel.shape, device=pred_vel.device),
1333
+ mask[:, :, :, 1:])
1334
+ if self.lambda_vel > 0.:
1335
+ target_vel = (target[..., 1:] - target[..., :-1])
1336
+ model_output_vel = (model_output[..., 1:] - model_output[..., :-1])
1337
+ terms["vel_mse"] = self.masked_l2(target_vel[:, :-1, :, :], # Remove last joint, is the root location!
1338
+ model_output_vel[:, :-1, :, :],
1339
+ mask[:, :, :, 1:]) # mean_flat((target_vel - model_output_vel) ** 2)
1340
+
1341
+ if self.lambda_target_loc > 0.:
1342
+ assert self.model_mean_type == ModelMeanType.START_X, 'This feature supports only X_start pred for now!'
1343
+ ref_target = model_kwargs['y']['target_cond']
1344
+ pred_target = get_target_location(model_output, dataset.mean_gpu, dataset.std_gpu,
1345
+ model_kwargs['y']['lengths'], dataset.t2m_dataset.opt.joints_num, model.all_goal_joint_names,
1346
+ model_kwargs['y']['target_joint_names'], model_kwargs['y']['is_heading'])
1347
+ terms["target_loc"] = masked_goal_l2(pred_target, ref_target, model_kwargs['y'], model.all_goal_joint_names)
1348
+
1349
+
1350
+ terms["loss"] = terms["rot_mse"] + terms.get('vb', 0.) +\
1351
+ (self.lambda_vel * terms.get('vel_mse', 0.)) +\
1352
+ (self.lambda_rcxyz * terms.get('rcxyz_mse', 0.)) + \
1353
+ (self.lambda_target_loc * terms.get('target_loc', 0.)) + \
1354
+ (self.lambda_fc * terms.get('fc', 0.))
1355
+
1356
+ else:
1357
+ raise NotImplementedError(self.loss_type)
1358
+
1359
+ return terms
1360
+
1361
+ def fc_loss_rot_repr(self, gt_xyz, pred_xyz, mask):
+ """
+ gt_xyz / pred_xyz: SMPL batch tensors of shape [BatchSize, 24, 3, Frames]
+ """
+ def to_np_cpu(x):
+ return x.detach().cpu().numpy()
1367
+ # 'L_Ankle', # 7, 'R_Ankle', # 8 , 'L_Foot', # 10, 'R_Foot', # 11
1368
+
1369
+ l_ankle_idx, r_ankle_idx = 7, 8
1370
+ l_foot_idx, r_foot_idx = 10, 11
1371
+ """ Contact calculated by 'Kfir Method' Commented code)"""
1372
+ # contact_signal = torch.zeros((pose_xyz.shape[0], pose_xyz.shape[3], 2), device=pose_xyz.device) # [BatchSize, Frames, 2]
1373
+ # left_xyz = 0.5 * (pose_xyz[:, l_ankle_idx, :, :] + pose_xyz[:, l_foot_idx, :, :]) # [BatchSize, 3, Frames]
1374
+ # right_xyz = 0.5 * (pose_xyz[:, r_ankle_idx, :, :] + pose_xyz[:, r_foot_idx, :, :])
1375
+ # left_z, right_z = left_xyz[:, 2, :], right_xyz[:, 2, :] # [BatchSize, Frames]
1376
+ # left_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1) # [BatchSize, Frames]
1377
+ # right_velocity = torch.linalg.norm(left_xyz[:, :, 2:] - left_xyz[:, :, :-2], axis=1)
1378
+ #
1379
+ # left_z_mask = left_z <= torch.mean(torch.sort(left_z)[0][:, :left_z.shape[1] // 5], axis=-1)
1380
+ # left_z_mask = torch.stack([left_z_mask, left_z_mask], dim=-1) # [BatchSize, Frames, 2]
1381
+ # left_z_mask[:, :, 1] = False # Blank right side
1382
+ # contact_signal[left_z_mask] = 0.4
1383
+ #
1384
+ # right_z_mask = right_z <= torch.mean(torch.sort(right_z)[0][:, :right_z.shape[1] // 5], axis=-1)
1385
+ # right_z_mask = torch.stack([right_z_mask, right_z_mask], dim=-1) # [BatchSize, Frames, 2]
1386
+ # right_z_mask[:, :, 0] = False # Blank left side
1387
+ # contact_signal[right_z_mask] = 0.4
1388
+ # contact_signal[left_z <= (torch.mean(torch.sort(left_z)[:left_z.shape[0] // 5]) + 20), 0] = 1
1389
+ # contact_signal[right_z <= (torch.mean(torch.sort(right_z)[:right_z.shape[0] // 5]) + 20), 1] = 1
1390
+
1391
+ # plt.plot(to_np_cpu(left_z[0]), label='left_z')
1392
+ # plt.plot(to_np_cpu(left_velocity[0]), label='left_velocity')
1393
+ # plt.plot(to_np_cpu(contact_signal[0, :, 0]), label='left_fc')
1394
+ # plt.grid()
1395
+ # plt.legend()
1396
+ # plt.show()
1397
+ # plt.plot(to_np_cpu(right_z[0]), label='right_z')
1398
+ # plt.plot(to_np_cpu(right_velocity[0]), label='right_velocity')
1399
+ # plt.plot(to_np_cpu(contact_signal[0, :, 1]), label='right_fc')
1400
+ # plt.grid()
1401
+ # plt.legend()
1402
+ # plt.show()
1403
+
1404
+ gt_joint_xyz = gt_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames]
1405
+ gt_joint_vel = torch.linalg.norm(gt_joint_xyz[:, :, :, 1:] - gt_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames]
1406
+ fc_mask = (gt_joint_vel <= 0.01)
1407
+ pred_joint_xyz = pred_xyz[:, [l_ankle_idx, l_foot_idx, r_ankle_idx, r_foot_idx], :, :] # [BatchSize, 4, 3, Frames]
1408
+ pred_joint_vel = torch.linalg.norm(pred_joint_xyz[:, :, :, 1:] - pred_joint_xyz[:, :, :, :-1], axis=2) # [BatchSize, 4, Frames]
1409
+ pred_joint_vel[~fc_mask] = 0 # Blank non-contact velocities frames. [BS,4,FRAMES]
1410
+ pred_joint_vel = torch.unsqueeze(pred_joint_vel, dim=2)
1411
+
1412
+ """DEBUG CODE"""
1413
+ # print(f'mask: {mask.shape}')
1414
+ # print(f'pred_joint_vel: {pred_joint_vel.shape}')
1415
+ # plt.title(f'Joint: {joint_idx}')
1416
+ # plt.plot(to_np_cpu(gt_joint_vel[0]), label='velocity')
1417
+ # plt.plot(to_np_cpu(fc_mask[0]), label='fc')
1418
+ # plt.grid()
1419
+ # plt.legend()
1420
+ # plt.show()
1421
+ return self.masked_l2(pred_joint_vel, torch.zeros(pred_joint_vel.shape, device=pred_joint_vel.device),
1422
+ mask[:, :, :, 1:])
1423
+ # TODO - NOT USED YET; COMMITTED ONLY TO PRESERVE THE INITIAL IMPLEMENTATION. NOT DONE!
1424
+ def foot_contact_loss_humanml3d(self, target, model_output):
1425
+ # root_rot_velocity (B, seq_len, 1)
1426
+ # root_linear_velocity (B, seq_len, 2)
1427
+ # root_y (B, seq_len, 1)
1428
+ # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ
1429
+ # rot_data (B, seq_len, (joint_num - 1)*6) , 6D
1430
+ # local_velocity (B, seq_len, joint_num*3) , XYZ
1431
+ # foot contact (B, seq_len, 4) ,
1432
+
1433
+ target_fc = target[:, -4:, :, :]
1434
+ root_rot_velocity = target[:, :1, :, :]
1435
+ root_linear_velocity = target[:, 1:3, :, :]
1436
+ root_y = target[:, 3:4, :, :]
1437
+ ric_data = target[:, 4:67, :, :] # 4+(3*21)=67
1438
+ rot_data = target[:, 67:193, :, :] # 67+(6*21)=193
1439
+ local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259
1440
+ contact = target[:, 259:, :, :] # 259:263, foot contact
1441
+ contact_mask_gt = contact > 0.5 # contact mask order for indexes are fid_l [7, 10], fid_r [8, 11]
1442
+ vel_lf_7 = local_velocity[:, 7 * 3:8 * 3, :, :]
1443
+ vel_rf_8 = local_velocity[:, 8 * 3:9 * 3, :, :]
1444
+ vel_lf_10 = local_velocity[:, 10 * 3:11 * 3, :, :]
1445
+ vel_rf_11 = local_velocity[:, 11 * 3:12 * 3, :, :]
1446
+
1447
+ calc_vel_lf_7 = ric_data[:, 6 * 3:7 * 3, :, 1:] - ric_data[:, 6 * 3:7 * 3, :, :-1]
1448
+ calc_vel_rf_8 = ric_data[:, 7 * 3:8 * 3, :, 1:] - ric_data[:, 7 * 3:8 * 3, :, :-1]
1449
+ calc_vel_lf_10 = ric_data[:, 9 * 3:10 * 3, :, 1:] - ric_data[:, 9 * 3:10 * 3, :, :-1]
1450
+ calc_vel_rf_11 = ric_data[:, 10 * 3:11 * 3, :, 1:] - ric_data[:, 10 * 3:11 * 3, :, :-1]
1451
+
1452
+ # vel_foots = torch.stack([vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11], dim=1)
1453
+ for chosen_vel_foot_calc, chosen_vel_foot, joint_idx, contact_mask_idx in zip(
1454
+ [calc_vel_lf_7, calc_vel_lf_10, calc_vel_rf_8, calc_vel_rf_11], # reordered to match the velocity, joint, and mask lists below
1455
+ [vel_lf_7, vel_lf_10, vel_rf_8, vel_rf_11],
1456
+ [7, 10, 8, 11],
1457
+ [0, 1, 2, 3]):
1458
+ tmp_mask_gt = contact_mask_gt[:, contact_mask_idx, :, :].cpu().detach().numpy().reshape(-1).astype(int)
1459
+ chosen_vel_norm = np.linalg.norm(chosen_vel_foot.cpu().detach().numpy().reshape((3, -1)), axis=0)
1460
+ chosen_vel_calc_norm = np.linalg.norm(chosen_vel_foot_calc.cpu().detach().numpy().reshape((3, -1)),
1461
+ axis=0)
1462
+
1463
+ print(tmp_mask_gt.shape)
1464
+ print(chosen_vel_foot.shape)
1465
+ print(chosen_vel_calc_norm.shape)
1466
+ import matplotlib.pyplot as plt
1467
+ plt.plot(tmp_mask_gt, label='FC mask')
1468
+ plt.plot(chosen_vel_norm, label='Vel. XYZ norm (from vector)')
1469
+ plt.plot(chosen_vel_calc_norm, label='Vel. XYZ norm (calculated diff XYZ)')
1470
+
1471
+ plt.title(f'FC idx {contact_mask_idx}, Joint Index {joint_idx}')
1472
+ plt.legend()
1473
+ plt.show()
1474
+ # print(vel_foots.shape)
1475
+ return 0
1476
+ # TODO - NOT USED YET; COMMITTED ONLY TO PRESERVE THE INITIAL IMPLEMENTATION. NOT DONE!
1477
+ def velocity_consistency_loss_humanml3d(self, target, model_output):
1478
+ # root_rot_velocity (B, seq_len, 1)
1479
+ # root_linear_velocity (B, seq_len, 2)
1480
+ # root_y (B, seq_len, 1)
1481
+ # ric_data (B, seq_len, (joint_num - 1)*3) , XYZ
1482
+ # rot_data (B, seq_len, (joint_num - 1)*6) , 6D
1483
+ # local_velocity (B, seq_len, joint_num*3) , XYZ
1484
+ # foot contact (B, seq_len, 4) ,
1485
+
1486
+ target_fc = target[:, -4:, :, :]
1487
+ root_rot_velocity = target[:, :1, :, :]
1488
+ root_linear_velocity = target[:, 1:3, :, :]
1489
+ root_y = target[:, 3:4, :, :]
1490
+ ric_data = target[:, 4:67, :, :] # 4+(3*21)=67
1491
+ rot_data = target[:, 67:193, :, :] # 67+(6*21)=193
1492
+ local_velocity = target[:, 193:259, :, :] # 193+(3*22)=259
1493
+ contact = target[:, 259:, :, :] # 259:263, foot contact
1494
+
1495
+ calc_vel_from_xyz = ric_data[:, :, :, 1:] - ric_data[:, :, :, :-1]
1496
+ velocity_from_vector = local_velocity[:, 3:, :, 1:] # Slicing out root
1497
+ r_rot_quat, r_pos = motion_process.recover_root_rot_pos(target.permute(0, 2, 3, 1).type(th.FloatTensor))
1498
+ print(f'r_rot_quat: {r_rot_quat.shape}')
1499
+ print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape}')
1500
+ calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 2, 3, 1)
1501
+ calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21, 3)).type(th.FloatTensor)
1502
+ r_rot_quat_adapted = r_rot_quat[..., :-1, None, :].repeat((1,1,1,21,1)).to(calc_vel_from_xyz.device)
1503
+ print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}')
1504
+ print(f'r_rot_quat_adapted: {r_rot_quat_adapted.shape}, {r_rot_quat_adapted.device}')
1505
+
1506
+ calc_vel_from_xyz = motion_process.qrot(r_rot_quat_adapted, calc_vel_from_xyz)
1507
+ calc_vel_from_xyz = calc_vel_from_xyz.reshape((1, 1, -1, 21 * 3))
1508
+ calc_vel_from_xyz = calc_vel_from_xyz.permute(0, 3, 1, 2)
1509
+ print(f'calc_vel_from_xyz: {calc_vel_from_xyz.shape} , {calc_vel_from_xyz.device}')
1510
+
1511
+ import matplotlib.pyplot as plt
1512
+ for i in range(21):
1513
+ plt.plot(np.linalg.norm(calc_vel_from_xyz[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Calc Vel')
1514
+ plt.plot(np.linalg.norm(velocity_from_vector[:,i*3:(i+1)*3,:,:].cpu().detach().numpy().reshape((3, -1)), axis=0), label='Vector Vel')
1515
+ plt.title(f'Joint idx: {i}')
1516
+ plt.legend()
1517
+ plt.show()
1518
+ print(calc_vel_from_xyz.shape)
1519
+ print(velocity_from_vector.shape)
1520
+ diff = calc_vel_from_xyz-velocity_from_vector
1521
+ print(np.linalg.norm(diff.cpu().detach().numpy().reshape((63, -1)), axis=0))
1522
+
1523
+ return 0
1524
+
1525
+
1526
+ def _prior_bpd(self, x_start):
1527
+ """
1528
+ Get the prior KL term for the variational lower-bound, measured in
1529
+ bits-per-dim.
1530
+
1531
+ This term can't be optimized, as it only depends on the encoder.
1532
+
1533
+ :param x_start: the [N x C x ...] tensor of inputs.
1534
+ :return: a batch of [N] KL values (in bits), one per batch element.
1535
+ """
1536
+ batch_size = x_start.shape[0]
1537
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
1538
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
1539
+ kl_prior = normal_kl(
1540
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
1541
+ )
1542
+ return mean_flat(kl_prior) / np.log(2.0)
1543
+
1544
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
1545
+ """
1546
+ Compute the entire variational lower-bound, measured in bits-per-dim,
1547
+ as well as other related quantities.
1548
+
1549
+ :param model: the model to evaluate loss on.
1550
+ :param x_start: the [N x C x ...] tensor of inputs.
1551
+ :param clip_denoised: if True, clip denoised samples.
1552
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
1553
+ pass to the model. This can be used for conditioning.
1554
+
1555
+ :return: a dict containing the following keys:
1556
+ - total_bpd: the total variational lower-bound, per batch element.
1557
+ - prior_bpd: the prior term in the lower-bound.
1558
+ - vb: an [N x T] tensor of terms in the lower-bound.
1559
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
1560
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
1561
+ """
1562
+ device = x_start.device
1563
+ batch_size = x_start.shape[0]
1564
+
1565
+ vb = []
1566
+ xstart_mse = []
1567
+ mse = []
1568
+ for t in list(range(self.num_timesteps))[::-1]:
1569
+ t_batch = th.tensor([t] * batch_size, device=device)
1570
+ noise = th.randn_like(x_start)
1571
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
1572
+ # Calculate VLB term at the current timestep
1573
+ with th.no_grad():
1574
+ out = self._vb_terms_bpd(
1575
+ model,
1576
+ x_start=x_start,
1577
+ x_t=x_t,
1578
+ t=t_batch,
1579
+ clip_denoised=clip_denoised,
1580
+ model_kwargs=model_kwargs,
1581
+ )
1582
+ vb.append(out["output"])
1583
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
1584
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
1585
+ mse.append(mean_flat((eps - noise) ** 2))
1586
+
1587
+ vb = th.stack(vb, dim=1)
1588
+ xstart_mse = th.stack(xstart_mse, dim=1)
1589
+ mse = th.stack(mse, dim=1)
1590
+
1591
+ prior_bpd = self._prior_bpd(x_start)
1592
+ total_bpd = vb.sum(dim=1) + prior_bpd
1593
+ return {
1594
+ "total_bpd": total_bpd,
1595
+ "prior_bpd": prior_bpd,
1596
+ "vb": vb,
1597
+ "xstart_mse": xstart_mse,
1598
+ "mse": mse,
1599
+ }
1600
+
1601
+
1602
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
1603
+ """
1604
+ Extract values from a 1-D numpy array for a batch of indices.
1605
+
1606
+ :param arr: the 1-D numpy array.
1607
+ :param timesteps: a tensor of indices into the array to extract.
1608
+ :param broadcast_shape: a larger shape of K dimensions with the batch
1609
+ dimension equal to the length of timesteps.
1610
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
1611
+ """
1612
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
1613
+ while len(res.shape) < len(broadcast_shape):
1614
+ res = res[..., None]
1615
+ return res.expand(broadcast_shape)
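A minimal sanity sketch of `_extract_into_tensor` (the shapes here are illustrative, not from the repo): it gathers one schedule coefficient per batch element, then appends singleton dims so the result broadcasts over the motion layout [bs, njoints, nfeats, nframes]:

alphas = np.linspace(0.9, 0.1, 1000)      # any 1-D schedule array
t = th.tensor([0, 499, 999])              # one timestep index per batch element
coef = _extract_into_tensor(alphas, t, (3, 25, 6, 60))
assert coef.shape == (3, 25, 6, 60)       # constant value per batch element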
motion_diffusion_model/diffusion/logger.py ADDED
@@ -0,0 +1,495 @@
1
+ """
2
+ Logger copied from OpenAI baselines to avoid extra RL-based dependencies:
3
+ https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/logger.py
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import shutil
9
+ import os.path as osp
10
+ import json
11
+ import time
12
+ import datetime
13
+ import tempfile
14
+ import warnings
15
+ from collections import defaultdict
16
+ from contextlib import contextmanager
17
+
18
+ DEBUG = 10
19
+ INFO = 20
20
+ WARN = 30
21
+ ERROR = 40
22
+
23
+ DISABLED = 50
24
+
25
+
26
+ class KVWriter(object):
27
+ def writekvs(self, kvs):
28
+ raise NotImplementedError
29
+
30
+
31
+ class SeqWriter(object):
32
+ def writeseq(self, seq):
33
+ raise NotImplementedError
34
+
35
+
36
+ class HumanOutputFormat(KVWriter, SeqWriter):
37
+ def __init__(self, filename_or_file):
38
+ if isinstance(filename_or_file, str):
39
+ self.file = open(filename_or_file, "wt")
40
+ self.own_file = True
41
+ else:
42
+ assert hasattr(filename_or_file, "write"), (
43
+ "expected file or str, got %s" % filename_or_file
44
+ )
45
+ self.file = filename_or_file
46
+ self.own_file = False
47
+
48
+ def writekvs(self, kvs):
49
+ # Create strings for printing
50
+ key2str = {}
51
+ for (key, val) in sorted(kvs.items()):
52
+ if hasattr(val, "__float__"):
53
+ valstr = "%-8.3g" % val
54
+ else:
55
+ valstr = str(val)
56
+ key2str[self._truncate(key)] = self._truncate(valstr)
57
+
58
+ # Find max widths
59
+ if len(key2str) == 0:
60
+ print("WARNING: tried to write empty key-value dict")
61
+ return
62
+ else:
63
+ keywidth = max(map(len, key2str.keys()))
64
+ valwidth = max(map(len, key2str.values()))
65
+
66
+ # Write out the data
67
+ dashes = "-" * (keywidth + valwidth + 7)
68
+ lines = [dashes]
69
+ for (key, val) in sorted(key2str.items(), key=lambda kv: kv[0].lower()):
70
+ lines.append(
71
+ "| %s%s | %s%s |"
72
+ % (key, " " * (keywidth - len(key)), val, " " * (valwidth - len(val)))
73
+ )
74
+ lines.append(dashes)
75
+ self.file.write("\n".join(lines) + "\n")
76
+
77
+ # Flush the output to the file
78
+ self.file.flush()
79
+
80
+ def _truncate(self, s):
81
+ maxlen = 30
82
+ return s[: maxlen - 3] + "..." if len(s) > maxlen else s
83
+
84
+ def writeseq(self, seq):
85
+ seq = list(seq)
86
+ for (i, elem) in enumerate(seq):
87
+ self.file.write(elem)
88
+ if i < len(seq) - 1: # add space unless this is the last one
89
+ self.file.write(" ")
90
+ self.file.write("\n")
91
+ self.file.flush()
92
+
93
+ def close(self):
94
+ if self.own_file:
95
+ self.file.close()
96
+
97
+
98
+ class JSONOutputFormat(KVWriter):
99
+ def __init__(self, filename):
100
+ self.file = open(filename, "wt")
101
+
102
+ def writekvs(self, kvs):
103
+ for k, v in sorted(kvs.items()):
104
+ if hasattr(v, "dtype"):
105
+ kvs[k] = float(v)
106
+ self.file.write(json.dumps(kvs) + "\n")
107
+ self.file.flush()
108
+
109
+ def close(self):
110
+ self.file.close()
111
+
112
+
113
+ class CSVOutputFormat(KVWriter):
114
+ def __init__(self, filename):
115
+ self.file = open(filename, "w+t")
116
+ self.keys = []
117
+ self.sep = ","
118
+
119
+ def writekvs(self, kvs):
120
+ # Add our current row to the history
121
+ extra_keys = list(kvs.keys() - self.keys)
122
+ extra_keys.sort()
123
+ if extra_keys:
124
+ self.keys.extend(extra_keys)
125
+ self.file.seek(0)
126
+ lines = self.file.readlines()
127
+ self.file.seek(0)
128
+ for (i, k) in enumerate(self.keys):
129
+ if i > 0:
130
+ self.file.write(",")
131
+ self.file.write(k)
132
+ self.file.write("\n")
133
+ for line in lines[1:]:
134
+ self.file.write(line[:-1])
135
+ self.file.write(self.sep * len(extra_keys))
136
+ self.file.write("\n")
137
+ for (i, k) in enumerate(self.keys):
138
+ if i > 0:
139
+ self.file.write(",")
140
+ v = kvs.get(k)
141
+ if v is not None:
142
+ self.file.write(str(v))
143
+ self.file.write("\n")
144
+ self.file.flush()
145
+
146
+ def close(self):
147
+ self.file.close()
148
+
149
+
150
+ class TensorBoardOutputFormat(KVWriter):
151
+ """
152
+ Dumps key/value pairs into TensorBoard's numeric format.
153
+ """
154
+
155
+ def __init__(self, dir):
156
+ os.makedirs(dir, exist_ok=True)
157
+ self.dir = dir
158
+ self.step = 1
159
+ prefix = "events"
160
+ path = osp.join(osp.abspath(dir), prefix)
161
+ import tensorflow as tf
162
+ from tensorflow.python import pywrap_tensorflow
163
+ from tensorflow.core.util import event_pb2
164
+ from tensorflow.python.util import compat
165
+
166
+ self.tf = tf
167
+ self.event_pb2 = event_pb2
168
+ self.pywrap_tensorflow = pywrap_tensorflow
169
+ self.writer = pywrap_tensorflow.EventsWriter(compat.as_bytes(path))
170
+
171
+ def writekvs(self, kvs):
172
+ def summary_val(k, v):
173
+ kwargs = {"tag": k, "simple_value": float(v)}
174
+ return self.tf.Summary.Value(**kwargs)
175
+
176
+ summary = self.tf.Summary(value=[summary_val(k, v) for k, v in kvs.items()])
177
+ event = self.event_pb2.Event(wall_time=time.time(), summary=summary)
178
+ event.step = (
179
+ self.step
180
+ ) # is there any reason why you'd want to specify the step?
181
+ self.writer.WriteEvent(event)
182
+ self.writer.Flush()
183
+ self.step += 1
184
+
185
+ def close(self):
186
+ if self.writer:
187
+ self.writer.Close()
188
+ self.writer = None
189
+
190
+
191
+ def make_output_format(format, ev_dir, log_suffix=""):
192
+ os.makedirs(ev_dir, exist_ok=True)
193
+ if format == "stdout":
194
+ return HumanOutputFormat(sys.stdout)
195
+ elif format == "log":
196
+ return HumanOutputFormat(osp.join(ev_dir, "log%s.txt" % log_suffix))
197
+ elif format == "json":
198
+ return JSONOutputFormat(osp.join(ev_dir, "progress%s.json" % log_suffix))
199
+ elif format == "csv":
200
+ return CSVOutputFormat(osp.join(ev_dir, "progress%s.csv" % log_suffix))
201
+ elif format == "tensorboard":
202
+ return TensorBoardOutputFormat(osp.join(ev_dir, "tb%s" % log_suffix))
203
+ else:
204
+ raise ValueError("Unknown format specified: %s" % (format,))
205
+
206
+
207
+ # ================================================================
208
+ # API
209
+ # ================================================================
210
+
211
+
212
+ def logkv(key, val):
213
+ """
214
+ Log a value of some diagnostic
215
+ Call this once for each diagnostic quantity, each iteration
216
+ If called many times, last value will be used.
217
+ """
218
+ get_current().logkv(key, val)
219
+
220
+
221
+ def logkv_mean(key, val):
222
+ """
223
+ The same as logkv(), but if called many times, values averaged.
224
+ """
225
+ get_current().logkv_mean(key, val)
226
+
227
+
228
+ def logkvs(d):
229
+ """
230
+ Log a dictionary of key-value pairs
231
+ """
232
+ for (k, v) in d.items():
233
+ logkv(k, v)
234
+
235
+
236
+ def dumpkvs():
237
+ """
238
+ Write all of the diagnostics from the current iteration
239
+ """
240
+ return get_current().dumpkvs()
241
+
242
+
243
+ def getkvs():
244
+ return get_current().name2val
245
+
246
+
247
+ def log(*args, level=INFO):
248
+ """
249
+ Write the sequence of args, separated by single spaces, to the console and output files (if you've configured an output file).
250
+ """
251
+ get_current().log(*args, level=level)
252
+
253
+
254
+ def debug(*args):
255
+ log(*args, level=DEBUG)
256
+
257
+
258
+ def info(*args):
259
+ log(*args, level=INFO)
260
+
261
+
262
+ def warn(*args):
263
+ log(*args, level=WARN)
264
+
265
+
266
+ def error(*args):
267
+ log(*args, level=ERROR)
268
+
269
+
270
+ def set_level(level):
271
+ """
272
+ Set logging threshold on current logger.
273
+ """
274
+ get_current().set_level(level)
275
+
276
+
277
+ def set_comm(comm):
278
+ get_current().set_comm(comm)
279
+
280
+
281
+ def get_dir():
282
+ """
283
+ Get directory that log files are being written to.
284
+ will be None if there is no output directory (i.e., if you didn't call start)
285
+ """
286
+ return get_current().get_dir()
287
+
288
+
289
+ record_tabular = logkv
290
+ dump_tabular = dumpkvs
291
+
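For orientation, a minimal usage sketch of this logging API (the directory and values are hypothetical):

configure(dir="/tmp/mdm_logs", format_strs=["stdout", "csv"])  # hypothetical dir
for step in range(3):
    logkv("step", step)
    logkv_mean("loss", 0.5 / (step + 1))  # averaged if logged several times per iteration
    dumpkvs()  # flushes one row per iteration to every configured writer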
292
+
293
+ @contextmanager
294
+ def profile_kv(scopename):
295
+ logkey = "wait_" + scopename
296
+ tstart = time.time()
297
+ try:
298
+ yield
299
+ finally:
300
+ get_current().name2val[logkey] += time.time() - tstart
301
+
302
+
303
+ def profile(n):
304
+ """
305
+ Usage:
306
+ @profile("my_func")
307
+ def my_func(): code
308
+ """
309
+
310
+ def decorator_with_name(func):
311
+ def func_wrapper(*args, **kwargs):
312
+ with profile_kv(n):
313
+ return func(*args, **kwargs)
314
+
315
+ return func_wrapper
316
+
317
+ return decorator_with_name
318
+
319
+
320
+ # ================================================================
321
+ # Backend
322
+ # ================================================================
323
+
324
+
325
+ def get_current():
326
+ if Logger.CURRENT is None:
327
+ _configure_default_logger()
328
+
329
+ return Logger.CURRENT
330
+
331
+
332
+ class Logger(object):
333
+ DEFAULT = None # A logger with no output files. (See right below class definition)
334
+ # So that you can still log to the terminal without setting up any output files
335
+ CURRENT = None # Current logger being used by the free functions above
336
+
337
+ def __init__(self, dir, output_formats, comm=None):
338
+ self.name2val = defaultdict(float) # values this iteration
339
+ self.name2cnt = defaultdict(int)
340
+ self.level = INFO
341
+ self.dir = dir
342
+ self.output_formats = output_formats
343
+ self.comm = comm
344
+
345
+ # Logging API, forwarded
346
+ # ----------------------------------------
347
+ def logkv(self, key, val):
348
+ self.name2val[key] = val
349
+
350
+ def logkv_mean(self, key, val):
351
+ oldval, cnt = self.name2val[key], self.name2cnt[key]
352
+ self.name2val[key] = oldval * cnt / (cnt + 1) + val / (cnt + 1)
353
+ self.name2cnt[key] = cnt + 1
354
+
355
+ def dumpkvs(self):
356
+ if self.comm is None:
357
+ d = self.name2val
358
+ else:
359
+ d = mpi_weighted_mean(
360
+ self.comm,
361
+ {
362
+ name: (val, self.name2cnt.get(name, 1))
363
+ for (name, val) in self.name2val.items()
364
+ },
365
+ )
366
+ if self.comm.rank != 0:
367
+ d["dummy"] = 1 # so we don't get a warning about empty dict
368
+ out = d.copy() # Return the dict for unit testing purposes
369
+ for fmt in self.output_formats:
370
+ if isinstance(fmt, KVWriter):
371
+ fmt.writekvs(d)
372
+ self.name2val.clear()
373
+ self.name2cnt.clear()
374
+ return out
375
+
376
+ def log(self, *args, level=INFO):
377
+ if self.level <= level:
378
+ self._do_log(args)
379
+
380
+ # Configuration
381
+ # ----------------------------------------
382
+ def set_level(self, level):
383
+ self.level = level
384
+
385
+ def set_comm(self, comm):
386
+ self.comm = comm
387
+
388
+ def get_dir(self):
389
+ return self.dir
390
+
391
+ def close(self):
392
+ for fmt in self.output_formats:
393
+ fmt.close()
394
+
395
+ # Misc
396
+ # ----------------------------------------
397
+ def _do_log(self, args):
398
+ for fmt in self.output_formats:
399
+ if isinstance(fmt, SeqWriter):
400
+ fmt.writeseq(map(str, args))
401
+
402
+
403
+ def get_rank_without_mpi_import():
404
+ # check environment variables here instead of importing mpi4py
405
+ # to avoid calling MPI_Init() when this module is imported
406
+ for varname in ["PMI_RANK", "OMPI_COMM_WORLD_RANK"]:
407
+ if varname in os.environ:
408
+ return int(os.environ[varname])
409
+ return 0
410
+
411
+
412
+ def mpi_weighted_mean(comm, local_name2valcount):
413
+ """
414
+ Copied from: https://github.com/openai/baselines/blob/ea25b9e8b234e6ee1bca43083f8f3cf974143998/baselines/common/mpi_util.py#L110
415
+ Perform a weighted average over dicts that are each on a different node
416
+ Input: local_name2valcount: dict mapping key -> (value, count)
417
+ Returns: key -> mean
418
+ """
419
+ all_name2valcount = comm.gather(local_name2valcount)
420
+ if comm.rank == 0:
421
+ name2sum = defaultdict(float)
422
+ name2count = defaultdict(float)
423
+ for n2vc in all_name2valcount:
424
+ for (name, (val, count)) in n2vc.items():
425
+ try:
426
+ val = float(val)
427
+ except ValueError:
428
+ if comm.rank == 0:
429
+ warnings.warn(
430
+ "WARNING: tried to compute mean on non-float {}={}".format(
431
+ name, val
432
+ )
433
+ )
434
+ else:
435
+ name2sum[name] += val * count
436
+ name2count[name] += count
437
+ return {name: name2sum[name] / name2count[name] for name in name2sum}
438
+ else:
439
+ return {}
440
+
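As a concrete check of the arithmetic (no MPI needed): with rank 0 contributing ("loss", (1.0, 2)) and rank 1 contributing ("loss", (2.0, 1)), the weighted mean is (1.0*2 + 2.0*1) / 3 = 4/3:

gathered = [{"loss": (1.0, 2)}, {"loss": (2.0, 1)}]  # stand-in for comm.gather(...)
total = sum(val * cnt for d in gathered for (val, cnt) in [d["loss"]])
count = sum(cnt for d in gathered for (val, cnt) in [d["loss"]])
assert abs(total / count - 4.0 / 3.0) < 1e-9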
441
+
442
+ def configure(dir=None, format_strs=None, comm=None, log_suffix=""):
443
+ """
444
+ If comm is provided, average all numerical stats across that comm
445
+ """
446
+ if dir is None:
447
+ dir = os.getenv("OPENAI_LOGDIR")
448
+ if dir is None:
449
+ dir = osp.join(
450
+ tempfile.gettempdir(),
451
+ datetime.datetime.now().strftime("openai-%Y-%m-%d-%H-%M-%S-%f"),
452
+ )
453
+ assert isinstance(dir, str)
454
+ dir = os.path.expanduser(dir)
455
+ os.makedirs(os.path.expanduser(dir), exist_ok=True)
456
+
457
+ rank = get_rank_without_mpi_import()
458
+ if rank > 0:
459
+ log_suffix = log_suffix + "-rank%03i" % rank
460
+
461
+ if format_strs is None:
462
+ if rank == 0:
463
+ format_strs = os.getenv("OPENAI_LOG_FORMAT", "stdout,log,csv").split(",")
464
+ else:
465
+ format_strs = os.getenv("OPENAI_LOG_FORMAT_MPI", "log").split(",")
466
+ format_strs = filter(None, format_strs)
467
+ output_formats = [make_output_format(f, dir, log_suffix) for f in format_strs]
468
+
469
+ Logger.CURRENT = Logger(dir=dir, output_formats=output_formats, comm=comm)
470
+ if output_formats:
471
+ log("Logging to %s" % dir)
472
+
473
+
474
+ def _configure_default_logger():
475
+ configure()
476
+ Logger.DEFAULT = Logger.CURRENT
477
+
478
+
479
+ def reset():
480
+ if Logger.CURRENT is not Logger.DEFAULT:
481
+ Logger.CURRENT.close()
482
+ Logger.CURRENT = Logger.DEFAULT
483
+ log("Reset logger")
484
+
485
+
486
+ @contextmanager
487
+ def scoped_configure(dir=None, format_strs=None, comm=None):
488
+ prevlogger = Logger.CURRENT
489
+ configure(dir=dir, format_strs=format_strs, comm=comm)
490
+ try:
491
+ yield
492
+ finally:
493
+ Logger.CURRENT.close()
494
+ Logger.CURRENT = prevlogger
495
+
motion_diffusion_model/diffusion/losses.py ADDED
@@ -0,0 +1,77 @@
1
+ # This code is based on https://github.com/openai/guided-diffusion
2
+ """
3
+ Helpers for various likelihood-based losses. These are ported from the original
4
+ Ho et al. diffusion models codebase:
5
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/utils.py
6
+ """
7
+
8
+ import numpy as np
9
+ import torch as th
10
+
11
+
12
+ def normal_kl(mean1, logvar1, mean2, logvar2):
13
+ """
14
+ Compute the KL divergence between two gaussians.
15
+
16
+ Shapes are automatically broadcasted, so batches can be compared to
17
+ scalars, among other use cases.
18
+ """
19
+ tensor = None
20
+ for obj in (mean1, logvar1, mean2, logvar2):
21
+ if isinstance(obj, th.Tensor):
22
+ tensor = obj
23
+ break
24
+ assert tensor is not None, "at least one argument must be a Tensor"
25
+
26
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
27
+ # Tensors, but it does not work for th.exp().
28
+ logvar1, logvar2 = [
29
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
30
+ for x in (logvar1, logvar2)
31
+ ]
32
+
33
+ return 0.5 * (
34
+ -1.0
35
+ + logvar2
36
+ - logvar1
37
+ + th.exp(logvar1 - logvar2)
38
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
39
+ )
40
+
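Two quick sanity checks of normal_kl: the KL of a Gaussian with itself is zero, and KL(N(1,1) || N(0,1)) = 0.5 by the closed form above:

m = th.zeros(4)
assert th.allclose(normal_kl(m, m, m, m), th.zeros(4))
assert th.allclose(normal_kl(m + 1.0, m, m, m), th.full((4,), 0.5))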
41
+
42
+ def approx_standard_normal_cdf(x):
43
+ """
44
+ A fast approximation of the cumulative distribution function of the
45
+ standard normal.
46
+ """
47
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
48
+
49
+
50
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
51
+ """
52
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
53
+ given image.
54
+
55
+ :param x: the target images. It is assumed that this was uint8 values,
56
+ rescaled to the range [-1, 1].
57
+ :param means: the Gaussian mean Tensor.
58
+ :param log_scales: the Gaussian log stddev Tensor.
59
+ :return: a tensor like x of log probabilities (in nats).
60
+ """
61
+ assert x.shape == means.shape == log_scales.shape
62
+ centered_x = x - means
63
+ inv_stdv = th.exp(-log_scales)
64
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
65
+ cdf_plus = approx_standard_normal_cdf(plus_in)
66
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
67
+ cdf_min = approx_standard_normal_cdf(min_in)
68
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
69
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
70
+ cdf_delta = cdf_plus - cdf_min
71
+ log_probs = th.where(
72
+ x < -0.999,
73
+ log_cdf_plus,
74
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
75
+ )
76
+ assert log_probs.shape == x.shape
77
+ return log_probs
motion_diffusion_model/diffusion/nn.py ADDED
@@ -0,0 +1,197 @@
1
+ # This code is based on https://github.com/openai/guided-diffusion
2
+ """
3
+ Various utilities for neural networks.
4
+ """
5
+
6
+ import math
7
+
8
+ import torch as th
9
+ import torch.nn as nn
10
+
11
+
12
+ # PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
13
+ class SiLU(nn.Module):
14
+ def forward(self, x):
15
+ return x * th.sigmoid(x)
16
+
17
+
18
+ class GroupNorm32(nn.GroupNorm):
19
+ def forward(self, x):
20
+ return super().forward(x.float()).type(x.dtype)
21
+
22
+
23
+ def conv_nd(dims, *args, **kwargs):
24
+ """
25
+ Create a 1D, 2D, or 3D convolution module.
26
+ """
27
+ if dims == 1:
28
+ return nn.Conv1d(*args, **kwargs)
29
+ elif dims == 2:
30
+ return nn.Conv2d(*args, **kwargs)
31
+ elif dims == 3:
32
+ return nn.Conv3d(*args, **kwargs)
33
+ raise ValueError(f"unsupported dimensions: {dims}")
34
+
35
+
36
+ def linear(*args, **kwargs):
37
+ """
38
+ Create a linear module.
39
+ """
40
+ return nn.Linear(*args, **kwargs)
41
+
42
+
43
+ def avg_pool_nd(dims, *args, **kwargs):
44
+ """
45
+ Create a 1D, 2D, or 3D average pooling module.
46
+ """
47
+ if dims == 1:
48
+ return nn.AvgPool1d(*args, **kwargs)
49
+ elif dims == 2:
50
+ return nn.AvgPool2d(*args, **kwargs)
51
+ elif dims == 3:
52
+ return nn.AvgPool3d(*args, **kwargs)
53
+ raise ValueError(f"unsupported dimensions: {dims}")
54
+
55
+
56
+ def update_ema(target_params, source_params, rate=0.99):
57
+ """
58
+ Update target parameters to be closer to those of source parameters using
59
+ an exponential moving average.
60
+
61
+ :param target_params: the target parameter sequence.
62
+ :param source_params: the source parameter sequence.
63
+ :param rate: the EMA rate (closer to 1 means slower).
64
+ """
65
+ for targ, src in zip(target_params, source_params):
66
+ targ.detach().mul_(rate).add_(src, alpha=1 - rate)
67
+
68
+
69
+ def zero_module(module):
70
+ """
71
+ Zero out the parameters of a module and return it.
72
+ """
73
+ for p in module.parameters():
74
+ p.detach().zero_()
75
+ return module
76
+
77
+
78
+ def scale_module(module, scale):
79
+ """
80
+ Scale the parameters of a module and return it.
81
+ """
82
+ for p in module.parameters():
83
+ p.detach().mul_(scale)
84
+ return module
85
+
86
+
87
+ def mean_flat(tensor):
88
+ """
89
+ Take the mean over all non-batch dimensions.
90
+ """
91
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
92
+
93
+ def sum_flat(tensor):
94
+ """
95
+ Take the sum over all non-batch dimensions.
96
+ """
97
+ return tensor.sum(dim=list(range(1, len(tensor.shape))))
98
+
99
+
100
+ def normalization(channels):
101
+ """
102
+ Make a standard normalization layer.
103
+
104
+ :param channels: number of input channels.
105
+ :return: an nn.Module for normalization.
106
+ """
107
+ return GroupNorm32(32, channels)
108
+
109
+
110
+ def timestep_embedding(timesteps, dim, max_period=10000):
111
+ """
112
+ Create sinusoidal timestep embeddings.
113
+
114
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
115
+ These may be fractional.
116
+ :param dim: the dimension of the output.
117
+ :param max_period: controls the minimum frequency of the embeddings.
118
+ :return: an [N x dim] Tensor of positional embeddings.
119
+ """
120
+ half = dim // 2
121
+ freqs = th.exp(
122
+ -math.log(max_period) * th.arange(start=0, end=half, dtype=th.float32) / half
123
+ ).to(device=timesteps.device)
124
+ args = timesteps[:, None].float() * freqs[None]
125
+ embedding = th.cat([th.cos(args), th.sin(args)], dim=-1)
126
+ if dim % 2:
127
+ embedding = th.cat([embedding, th.zeros_like(embedding[:, :1])], dim=-1)
128
+ return embedding
129
+
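A shape check for timestep_embedding; at t = 0 the cosine half is all ones and the sine half all zeros:

ts = th.arange(8)                      # timesteps may also be fractional
emb = timestep_embedding(ts, dim=128)
assert emb.shape == (8, 128)
assert th.allclose(emb[0, :64], th.ones(64))   # cos(0) = 1 in the first half
assert th.allclose(emb[0, 64:], th.zeros(64))  # sin(0) = 0 in the second half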
130
+
131
+ def checkpoint(func, inputs, params, flag):
132
+ """
133
+ Evaluate a function without caching intermediate activations, allowing for
134
+ reduced memory at the expense of extra compute in the backward pass.
135
+ :param func: the function to evaluate.
136
+ :param inputs: the argument sequence to pass to `func`.
137
+ :param params: a sequence of parameters `func` depends on but does not
138
+ explicitly take as arguments.
139
+ :param flag: if False, disable gradient checkpointing.
140
+ """
141
+ if flag:
142
+ args = tuple(inputs) + tuple(params)
143
+ return CheckpointFunction.apply(func, len(inputs), *args)
144
+ else:
145
+ return func(*inputs)
146
+
147
+
148
+ class CheckpointFunction(th.autograd.Function):
149
+ @staticmethod
150
+ @th.cuda.amp.custom_fwd
151
+ def forward(ctx, run_function, length, *args):
152
+ ctx.run_function = run_function
153
+ ctx.input_length = length
154
+ ctx.save_for_backward(*args)
155
+ with th.no_grad():
156
+ output_tensors = ctx.run_function(*args[:length])
157
+ return output_tensors
158
+
159
+ @staticmethod
160
+ @th.cuda.amp.custom_bwd
161
+ def backward(ctx, *output_grads):
162
+ args = list(ctx.saved_tensors)
163
+
164
+ # Filter for inputs that require grad. If none, exit early.
165
+ input_indices = [i for (i, x) in enumerate(args) if x.requires_grad]
166
+ if not input_indices:
167
+ return (None, None) + tuple(None for _ in args)
168
+
169
+ with th.enable_grad():
170
+ for i in input_indices:
171
+ if i < ctx.input_length:
172
+ # Not sure why the OAI code does this little
173
+ # dance. It might not be necessary.
174
+ args[i] = args[i].detach().requires_grad_()
175
+ args[i] = args[i].view_as(args[i])
176
+ output_tensors = ctx.run_function(*args[:ctx.input_length])
177
+
178
+ if isinstance(output_tensors, th.Tensor):
179
+ output_tensors = [output_tensors]
180
+
181
+ # Filter for outputs that require grad. If none, exit early.
182
+ out_and_grads = [(o, g) for (o, g) in zip(output_tensors, output_grads) if o.requires_grad]
183
+ if not out_and_grads:
184
+ return (None, None) + tuple(None for _ in args)
185
+
186
+ # Compute gradients on the filtered tensors.
187
+ computed_grads = th.autograd.grad(
188
+ [o for (o, g) in out_and_grads],
189
+ [args[i] for i in input_indices],
190
+ [g for (o, g) in out_and_grads]
191
+ )
192
+
193
+ # Reassemble the complete gradient tuple.
194
+ input_grads = [None for _ in args]
195
+ for (i, g) in zip(input_indices, computed_grads):
196
+ input_grads[i] = g
197
+ return (None, None) + tuple(input_grads)
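A minimal sketch of trading compute for memory with checkpoint (the tiny MLP is illustrative, not from the repo):

mlp = nn.Sequential(nn.Linear(16, 16), SiLU(), nn.Linear(16, 16))
x = th.randn(4, 16, requires_grad=True)
# flag=True discards activations in the forward pass and recomputes them in backward
y = checkpoint(lambda inp: mlp(inp), (x,), tuple(mlp.parameters()), True)
y.sum().backward()
assert x.grad is not None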
motion_diffusion_model/diffusion/resample.py ADDED
@@ -0,0 +1,154 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+ import numpy as np
4
+ import torch as th
5
+ import torch.distributed as dist
6
+
7
+
8
+ def create_named_schedule_sampler(name, diffusion):
9
+ """
10
+ Create a ScheduleSampler from a library of pre-defined samplers.
11
+
12
+ :param name: the name of the sampler.
13
+ :param diffusion: the diffusion object to sample for.
14
+ """
15
+ if name == "uniform":
16
+ return UniformSampler(diffusion)
17
+ elif name == "loss-second-moment":
18
+ return LossSecondMomentResampler(diffusion)
19
+ else:
20
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
21
+
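A usage sketch (the diffusion object is assumed to be constructed elsewhere): with the uniform sampler every timestep has p = 1/T, so the importance weights come back as ones:

sampler = create_named_schedule_sampler("uniform", diffusion)  # `diffusion` assumed
t, weights = sampler.sample(batch_size=32, device=th.device("cpu"))
assert t.shape == (32,) and th.allclose(weights, th.ones(32))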
22
+
23
+ class ScheduleSampler(ABC):
24
+ """
25
+ A distribution over timesteps in the diffusion process, intended to reduce
26
+ variance of the objective.
27
+
28
+ By default, samplers perform unbiased importance sampling, in which the
29
+ objective's mean is unchanged.
30
+ However, subclasses may override sample() to change how the resampled
31
+ terms are reweighted, allowing for actual changes in the objective.
32
+ """
33
+
34
+ @abstractmethod
35
+ def weights(self):
36
+ """
37
+ Get a numpy array of weights, one per diffusion step.
38
+
39
+ The weights needn't be normalized, but must be positive.
40
+ """
41
+
42
+ def sample(self, batch_size, device):
43
+ """
44
+ Importance-sample timesteps for a batch.
45
+
46
+ :param batch_size: the number of timesteps.
47
+ :param device: the torch device to save to.
48
+ :return: a tuple (timesteps, weights):
49
+ - timesteps: a tensor of timestep indices.
50
+ - weights: a tensor of weights to scale the resulting losses.
51
+ """
52
+ w = self.weights()
53
+ p = w / np.sum(w)
54
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
55
+ indices = th.from_numpy(indices_np).long().to(device)
56
+ weights_np = 1 / (len(p) * p[indices_np])
57
+ weights = th.from_numpy(weights_np).float().to(device)
58
+ return indices, weights
59
+
60
+
61
+ class UniformSampler(ScheduleSampler):
62
+ def __init__(self, diffusion):
63
+ self.diffusion = diffusion
64
+ self._weights = np.ones([diffusion.num_timesteps])
65
+
66
+ def weights(self):
67
+ return self._weights
68
+
69
+
70
+ class LossAwareSampler(ScheduleSampler):
71
+ def update_with_local_losses(self, local_ts, local_losses):
72
+ """
73
+ Update the reweighting using losses from a model.
74
+
75
+ Call this method from each rank with a batch of timesteps and the
76
+ corresponding losses for each of those timesteps.
77
+ This method will perform synchronization to make sure all of the ranks
78
+ maintain the exact same reweighting.
79
+
80
+ :param local_ts: an integer Tensor of timesteps.
81
+ :param local_losses: a 1D Tensor of losses.
82
+ """
83
+ batch_sizes = [
84
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
85
+ for _ in range(dist.get_world_size())
86
+ ]
87
+ dist.all_gather(
88
+ batch_sizes,
89
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
90
+ )
91
+
92
+ # Pad all_gather batches to be the maximum batch size.
93
+ batch_sizes = [x.item() for x in batch_sizes]
94
+ max_bs = max(batch_sizes)
95
+
96
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
97
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
98
+ dist.all_gather(timestep_batches, local_ts)
99
+ dist.all_gather(loss_batches, local_losses)
100
+ timesteps = [
101
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
102
+ ]
103
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
104
+ self.update_with_all_losses(timesteps, losses)
105
+
106
+ @abstractmethod
107
+ def update_with_all_losses(self, ts, losses):
108
+ """
109
+ Update the reweighting using losses from a model.
110
+
111
+ Sub-classes should override this method to update the reweighting
112
+ using losses from the model.
113
+
114
+ This method directly updates the reweighting without synchronizing
115
+ between workers. It is called by update_with_local_losses from all
116
+ ranks with identical arguments. Thus, it should have deterministic
117
+ behavior to maintain state across workers.
118
+
119
+ :param ts: a list of int timesteps.
120
+ :param losses: a list of float losses, one per timestep.
121
+ """
122
+
123
+
124
+ class LossSecondMomentResampler(LossAwareSampler):
125
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
126
+ self.diffusion = diffusion
127
+ self.history_per_term = history_per_term
128
+ self.uniform_prob = uniform_prob
129
+ self._loss_history = np.zeros(
130
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
131
+ )
132
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64) # np.int was removed in NumPy 1.24
133
+
134
+ def weights(self):
135
+ if not self._warmed_up():
136
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
137
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
138
+ weights /= np.sum(weights)
139
+ weights *= 1 - self.uniform_prob
140
+ weights += self.uniform_prob / len(weights)
141
+ return weights
142
+
143
+ def update_with_all_losses(self, ts, losses):
144
+ for t, loss in zip(ts, losses):
145
+ if self._loss_counts[t] == self.history_per_term:
146
+ # Shift out the oldest loss term.
147
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
148
+ self._loss_history[t, -1] = loss
149
+ else:
150
+ self._loss_history[t, self._loss_counts[t]] = loss
151
+ self._loss_counts[t] += 1
152
+
153
+ def _warmed_up(self):
154
+ return (self._loss_counts == self.history_per_term).all()
motion_diffusion_model/diffusion/respace.py ADDED
@@ -0,0 +1,134 @@
1
+ # This code is based on https://github.com/openai/guided-diffusion
2
+ import numpy as np
3
+ import torch as th
4
+
5
+ from .gaussian_diffusion import GaussianDiffusion
6
+ from utils.misc import wrapped_getattr
7
+
8
+
9
+ def space_timesteps(num_timesteps, section_counts):
10
+ """
11
+ Create a list of timesteps to use from an original diffusion process,
12
+ given the number of timesteps we want to take from equally-sized portions
13
+ of the original process.
14
+
15
+ For example, if there are 300 timesteps and the section counts are [10,15,20]
16
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
17
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
18
+
19
+ If the stride is a string starting with "ddim", then the fixed striding
20
+ from the DDIM paper is used, and only one section is allowed.
21
+
22
+ :param num_timesteps: the number of diffusion steps in the original
23
+ process to divide up.
24
+ :param section_counts: either a list of numbers, or a string containing
25
+ comma-separated numbers, indicating the step count
26
+ per section. As a special case, use "ddimN" where N
27
+ is a number of steps to use the striding from the
28
+ DDIM paper.
29
+ :return: a set of diffusion steps from the original process to use.
30
+ """
31
+ if isinstance(section_counts, str):
32
+ if section_counts.startswith("ddim"):
33
+ desired_count = int(section_counts[len("ddim") :])
34
+ for i in range(1, num_timesteps):
35
+ if len(range(0, num_timesteps, i)) == desired_count:
36
+ return set(range(0, num_timesteps, i))
37
+ raise ValueError(
38
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
39
+ )
40
+ section_counts = [int(x) for x in section_counts.split(",")]
41
+ size_per = num_timesteps // len(section_counts)
42
+ extra = num_timesteps % len(section_counts)
43
+ start_idx = 0
44
+ all_steps = []
45
+ for i, section_count in enumerate(section_counts):
46
+ size = size_per + (1 if i < extra else 0)
47
+ if size < section_count:
48
+ raise ValueError(
49
+ f"cannot divide section of {size} steps into {section_count}"
50
+ )
51
+ if section_count <= 1:
52
+ frac_stride = 1
53
+ else:
54
+ frac_stride = (size - 1) / (section_count - 1)
55
+ cur_idx = 0.0
56
+ taken_steps = []
57
+ for _ in range(section_count):
58
+ taken_steps.append(start_idx + round(cur_idx))
59
+ cur_idx += frac_stride
60
+ all_steps += taken_steps
61
+ start_idx += size
62
+ return set(all_steps)
63
+
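A worked example of the striding logic:

steps = space_timesteps(100, [10])             # one section, 10 evenly strided steps
assert steps == {0, 11, 22, 33, 44, 55, 66, 77, 88, 99}
assert space_timesteps(100, "ddim25") == set(range(0, 100, 4))  # fixed DDIM stride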
64
+
65
+ class SpacedDiffusion(GaussianDiffusion):
66
+ """
67
+ A diffusion process which can skip steps in a base diffusion process.
68
+
69
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
70
+ original diffusion process to retain.
71
+ :param kwargs: the kwargs to create the base diffusion process.
72
+ """
73
+
74
+ def __init__(self, use_timesteps, **kwargs):
75
+ self.use_timesteps = set(use_timesteps)
76
+ self.timestep_map = []
77
+ self.original_num_steps = len(kwargs["betas"])
78
+
79
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
80
+ last_alpha_cumprod = 1.0
81
+ new_betas = []
82
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
83
+ if i in self.use_timesteps:
84
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
85
+ last_alpha_cumprod = alpha_cumprod
86
+ self.timestep_map.append(i)
87
+ kwargs["betas"] = np.array(new_betas)
88
+ super().__init__(**kwargs)
89
+
90
+ def p_mean_variance(
91
+ self, model, *args, **kwargs
92
+ ): # pylint: disable=signature-differs
93
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
94
+
95
+ def training_losses(
96
+ self, model, *args, **kwargs
97
+ ): # pylint: disable=signature-differs
98
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
99
+
100
+ def condition_mean(self, cond_fn, *args, **kwargs):
101
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
102
+
103
+ def condition_score(self, cond_fn, *args, **kwargs):
104
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
105
+
106
+ def _wrap_model(self, model):
107
+ if isinstance(model, _WrappedModel):
108
+ return model
109
+ return _WrappedModel(
110
+ model, self.timestep_map, self.rescale_timesteps, self.original_num_steps
111
+ )
112
+
113
+ def _scale_timesteps(self, t):
114
+ # Scaling is done by the wrapped model.
115
+ return t
116
+
117
+
118
+ class _WrappedModel:
119
+ def __init__(self, model, timestep_map, rescale_timesteps, original_num_steps):
120
+ self.model = model
121
+ self.timestep_map = timestep_map
122
+ self.rescale_timesteps = rescale_timesteps
123
+ self.original_num_steps = original_num_steps
124
+
125
+ def __call__(self, x, ts, **kwargs):
126
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
127
+ new_ts = map_tensor[ts]
128
+ if self.rescale_timesteps:
129
+ new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
130
+ return self.model(x, new_ts, **kwargs)
131
+
132
+ def __getattr__(self, name, default=None):
133
+ # this method is reached only if name is not in self.__dict__.
134
+ return wrapped_getattr(self, name, default)
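A hedged wiring sketch for the two pieces above; the beta schedule is an assumption, and the remaining GaussianDiffusion kwargs are elided:

diffusion_kwargs = dict(betas=np.linspace(1e-4, 2e-2, 1000))  # assumed linear schedule;
# the real constructor also needs the other GaussianDiffusion kwargs (model_mean_type, loss_type, ...)
spaced = SpacedDiffusion(use_timesteps=space_timesteps(1000, "ddim50"), **diffusion_kwargs)

SpacedDiffusion keeps only the selected timesteps and renormalizes the betas so the shortened chain matches the original marginals.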
motion_diffusion_model/model/BERT/BERT_encoder.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch.nn as nn
2
+ import os
3
+
4
+ def load_bert(model_path):
5
+ bert = BERT(model_path)
6
+ bert.eval()
7
+ bert.text_model.training = False
8
+ for p in bert.parameters():
9
+ p.requires_grad = False
10
+ return bert
11
+
12
+ class BERT(nn.Module):
13
+ def __init__(self, modelpath: str):
14
+ super().__init__()
15
+
16
+ from transformers import AutoTokenizer, AutoModel
17
+ from transformers import logging
18
+ logging.set_verbosity_error()
19
+ # Disable tokenizer fork-parallelism warnings
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+ # Tokenizer
+ self.tokenizer = AutoTokenizer.from_pretrained(modelpath)
23
+ # Text model
24
+ self.text_model = AutoModel.from_pretrained(modelpath)
25
+
26
+
27
+ def forward(self, texts):
28
+ encoded_inputs = self.tokenizer(texts, return_tensors="pt", padding=True)
29
+ output = self.text_model(**encoded_inputs.to(self.text_model.device)).last_hidden_state
30
+ mask = encoded_inputs.attention_mask.to(dtype=bool)
31
+ # output = output * mask.unsqueeze(-1)
32
+ return output, mask
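A hedged usage sketch; the checkpoint path matches the one used for DistilBERT in mdm.py below:

bert = load_bert('distilbert/distilbert-base-uncased')
feats, mask = bert(["a person walks forward", "a person jumps"])
# feats: [batch, max_tokens, 768]; mask: [batch, max_tokens], True marks real tokens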
motion_diffusion_model/model/cfg_sampler.py ADDED
@@ -0,0 +1,33 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ from copy import deepcopy
5
+
6
+ # A wrapper model for Classifier-free guidance **SAMPLING** only
7
+ # https://arxiv.org/abs/2207.12598
8
+ class ClassifierFreeSampleModel(nn.Module):
9
+
10
+ def __init__(self, model):
11
+ super().__init__()
12
+ self.model = model # model is the actual model to run
13
+
14
+ assert self.model.cond_mask_prob > 0, 'Cannot run classifier-free guidance on a model that was trained without condition masking (cond_mask_prob must be > 0)'
15
+
16
+ # pointers to inner model
17
+ self.rot2xyz = self.model.rot2xyz
18
+ self.translation = self.model.translation
19
+ self.njoints = self.model.njoints
20
+ self.nfeats = self.model.nfeats
21
+ self.data_rep = self.model.data_rep
22
+ self.cond_mode = self.model.cond_mode
23
+ self.encode_text = self.model.encode_text
24
+
25
+ def forward(self, x, timesteps, y=None):
26
+ cond_mode = self.model.cond_mode
27
+ assert cond_mode in ['text', 'action']
28
+ y_uncond = deepcopy(y)
29
+ y_uncond['uncond'] = True
30
+ out = self.model(x, timesteps, y)
31
+ out_uncond = self.model(x, timesteps, y_uncond)
32
+ return out_uncond + (y['scale'].view(-1, 1, 1, 1) * (out - out_uncond))
33
+
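The return line is the standard classifier-free guidance combination; an equivalent form makes the role of the scale s explicit:

# out_uncond + s*(out - out_uncond) == (1 - s)*out_uncond + s*out
s, out, out_uncond = 2.5, torch.tensor(1.0), torch.tensor(0.2)
assert torch.isclose(out_uncond + s * (out - out_uncond), (1 - s) * out_uncond + s * out)

Here s = 1 recovers the conditional model, s = 0 the unconditional one, and s > 1 extrapolates past the condition (the usual CFG setting).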
motion_diffusion_model/model/mdm.py ADDED
@@ -0,0 +1,480 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import clip
+ from model.rotation2xyz import Rotation2xyz
+ from model.BERT.BERT_encoder import load_bert
+ from utils.misc import WeightedSum
+
+
+ class MDM(nn.Module):
+     def __init__(self, modeltype, njoints, nfeats, num_actions, translation, pose_rep, glob, glob_rot,
+                  latent_dim=256, ff_size=1024, num_layers=8, num_heads=4, dropout=0.1,
+                  ablation=None, activation="gelu", legacy=False, data_rep='rot6d', dataset='amass', clip_dim=512,
+                  arch='trans_enc', emb_trans_dec=False, clip_version=None, **kargs):
+         super().__init__()
+
+         self.legacy = legacy
+         self.modeltype = modeltype
+         self.njoints = njoints
+         self.nfeats = nfeats
+         self.num_actions = num_actions
+         self.data_rep = data_rep
+         self.dataset = dataset
+
+         self.pose_rep = pose_rep
+         self.glob = glob
+         self.glob_rot = glob_rot
+         self.translation = translation
+
+         self.latent_dim = latent_dim
+
+         self.ff_size = ff_size
+         self.num_layers = num_layers
+         self.num_heads = num_heads
+         self.dropout = dropout
+
+         self.ablation = ablation
+         self.activation = activation
+         self.clip_dim = clip_dim
+         self.action_emb = kargs.get('action_emb', None)
+         self.input_feats = self.njoints * self.nfeats
+
+         self.normalize_output = kargs.get('normalize_encoder_output', False)
+
+         self.cond_mode = kargs.get('cond_mode', 'no_cond')
+         self.cond_mask_prob = kargs.get('cond_mask_prob', 0.)
+         self.mask_frames = kargs.get('mask_frames', False)
+         self.arch = arch
+         self.gru_emb_dim = self.latent_dim if self.arch == 'gru' else 0
+         self.input_process = InputProcess(self.data_rep, self.input_feats + self.gru_emb_dim, self.latent_dim)
+
+         self.emb_policy = kargs.get('emb_policy', 'add')
+
+         self.sequence_pos_encoder = PositionalEncoding(self.latent_dim, self.dropout, max_len=kargs.get('pos_embed_max_len', 5000))
+         self.emb_trans_dec = emb_trans_dec
+
+         self.pred_len = kargs.get('pred_len', 0)
+         self.context_len = kargs.get('context_len', 0)
+         self.total_len = self.pred_len + self.context_len
+         self.is_prefix_comp = self.total_len > 0
+         self.all_goal_joint_names = kargs.get('all_goal_joint_names', [])
+
+         self.multi_target_cond = kargs.get('multi_target_cond', False)
+         self.multi_encoder_type = kargs.get('multi_encoder_type', 'multi')
+         self.target_enc_layers = kargs.get('target_enc_layers', 1)
+         if self.multi_target_cond:
+             if self.multi_encoder_type == 'multi':
+                 self.embed_target_cond = EmbedTargetLocMulti(self.all_goal_joint_names, self.latent_dim)
+             elif self.multi_encoder_type == 'single':
+                 self.embed_target_cond = EmbedTargetLocSingle(self.all_goal_joint_names, self.latent_dim, self.target_enc_layers)
+             elif self.multi_encoder_type == 'split':
+                 self.embed_target_cond = EmbedTargetLocSplit(self.all_goal_joint_names, self.latent_dim, self.target_enc_layers)
+
+         if self.arch == 'trans_enc':
+             print("TRANS_ENC init")
+             seqTransEncoderLayer = nn.TransformerEncoderLayer(d_model=self.latent_dim,
+                                                               nhead=self.num_heads,
+                                                               dim_feedforward=self.ff_size,
+                                                               dropout=self.dropout,
+                                                               activation=self.activation)
+
+             self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer,
+                                                          num_layers=self.num_layers)
+         elif self.arch == 'trans_dec':
+             print("TRANS_DEC init")
+             seqTransDecoderLayer = nn.TransformerDecoderLayer(d_model=self.latent_dim,
+                                                               nhead=self.num_heads,
+                                                               dim_feedforward=self.ff_size,
+                                                               dropout=self.dropout,
+                                                               activation=activation)
+             self.seqTransDecoder = nn.TransformerDecoder(seqTransDecoderLayer,
+                                                          num_layers=self.num_layers)
+         elif self.arch == 'gru':
+             print("GRU init")
+             self.gru = nn.GRU(self.latent_dim, self.latent_dim, num_layers=self.num_layers, batch_first=True)
+         else:
+             raise ValueError('Please choose a valid architecture [trans_enc, trans_dec, gru]')
+
+         self.embed_timestep = TimestepEmbedder(self.latent_dim, self.sequence_pos_encoder)
+
+         if self.cond_mode != 'no_cond':
+             if 'text' in self.cond_mode:
+                 # We support the CLIP and DistilBERT text encoders
+                 print('EMBED TEXT')
+
+                 self.text_encoder_type = kargs.get('text_encoder_type', 'clip')
+
+                 if self.text_encoder_type == "clip":
+                     print('Loading CLIP...')
+                     self.clip_version = clip_version
+                     self.clip_model = self.load_and_freeze_clip(clip_version)
+                     self.encode_text = self.clip_encode_text
+                 elif self.text_encoder_type == 'bert':
+                     assert self.arch == 'trans_dec'
+                     print("Loading BERT...")
+                     bert_model_path = 'distilbert/distilbert-base-uncased'
+                     self.clip_model = load_bert(bert_model_path)  # the attribute name is kept for backward compatibility
+                     self.encode_text = self.bert_encode_text
+                     self.clip_dim = 768
+                 else:
+                     raise ValueError('We only support [CLIP, BERT] text encoders')
+
+                 self.embed_text = nn.Linear(self.clip_dim, self.latent_dim)
+
+             if 'action' in self.cond_mode:
+                 self.embed_action = EmbedAction(self.num_actions, self.latent_dim)
+                 print('EMBED ACTION')
+
+         self.output_process = OutputProcess(self.data_rep, self.input_feats, self.latent_dim, self.njoints,
+                                             self.nfeats)
+
+         self.rot2xyz = Rotation2xyz(device='cpu', dataset=self.dataset)
+
+     def parameters_wo_clip(self):
+         return [p for name, p in self.named_parameters() if not name.startswith('clip_model.')]
+
+     def load_and_freeze_clip(self, clip_version):
+         clip_model, clip_preprocess = clip.load(clip_version, device='cpu',
+                                                 jit=False)  # Must set jit=False for training
+         clip.model.convert_weights(
+             clip_model)  # strictly unnecessary - CLIP weights are float16 by default
+
+         # Freeze CLIP weights
+         clip_model.eval()
+         for p in clip_model.parameters():
+             p.requires_grad = False
+
+         return clip_model
+
+     def mask_cond(self, cond, force_mask=False):
+         bs = cond.shape[-2]
+         if force_mask:
+             return torch.zeros_like(cond)
+         elif self.training and self.cond_mask_prob > 0.:
+             mask = torch.bernoulli(torch.ones(bs, device=cond.device) * self.cond_mask_prob).view(1, bs, 1)  # 1 -> use null_cond, 0 -> use real cond
+             return cond * (1. - mask)
+         else:
+             return cond
+
+     def clip_encode_text(self, raw_text):
+         # raw_text - list (batch_size length) of strings with input text prompts
+         device = next(self.parameters()).device
+         max_text_len = 20 if self.dataset in ['humanml', 'kit'] else None  # specific hardcoding for the humanml dataset
+         if max_text_len is not None:
+             default_context_length = 77
+             context_length = max_text_len + 2  # start_token + 20 + end_token
+             assert context_length < default_context_length
+             texts = clip.tokenize(raw_text, context_length=context_length, truncate=True).to(device)  # [bs, context_length]; if n_tokens > context_length -> will truncate
+             zero_pad = torch.zeros([texts.shape[0], default_context_length - context_length], dtype=texts.dtype, device=texts.device)
+             texts = torch.cat([texts, zero_pad], dim=1)
+         else:
+             texts = clip.tokenize(raw_text, truncate=True).to(device)  # [bs, context_length]; if n_tokens > 77 -> will truncate
+         return self.clip_model.encode_text(texts).float().unsqueeze(0)
+
+     def bert_encode_text(self, raw_text):
+         enc_text, mask = self.clip_model(raw_text)  # mask: False means no token there
+         enc_text = enc_text.permute(1, 0, 2)
+         mask = ~mask  # invert: the transformer's key_padding_mask expects True for padding, see https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
+         return enc_text, mask
+
+     def forward(self, x, timesteps, y=None):
+         """
+         x: [batch_size, njoints, nfeats, max_frames], denoted x_t in the paper
+         timesteps: [batch_size] (int)
+         """
+         bs, njoints, nfeats, nframes = x.shape
+         time_emb = self.embed_timestep(timesteps)  # [1, bs, d]
+
+         if 'target_cond' in y.keys():
+             # NOTE: We don't use CFG for joints - but we do want to support uncond sampling for generation and eval!
+             time_emb += self.mask_cond(self.embed_target_cond(y['target_cond'], y['target_joint_names'], y['is_heading'])[None], force_mask=y.get('target_uncond', False))  # For uncond support and CFG
+
+         # Build input for prefix completion
+         if self.is_prefix_comp:
+             x = torch.cat([y['prefix'], x], dim=-1)
+             y['mask'] = torch.cat([torch.ones([bs, 1, 1, self.context_len], dtype=y['mask'].dtype, device=y['mask'].device),
+                                    y['mask']], dim=-1)
+
+         force_mask = y.get('uncond', False)
+         if 'text' in self.cond_mode:
+             if 'text_embed' in y.keys():  # caching option
+                 enc_text = y['text_embed']
+             else:
+                 enc_text = self.encode_text(y['text'])
+             if type(enc_text) == tuple:
+                 enc_text, text_mask = enc_text
+                 if text_mask.shape[0] == 1 and bs > 1:  # casting mask for the single-prompt-for-all case
+                     text_mask = torch.repeat_interleave(text_mask, bs, dim=0)
+             text_emb = self.embed_text(self.mask_cond(enc_text, force_mask=force_mask))
+             if self.emb_policy == 'add':
+                 emb = text_emb + time_emb
+             else:
+                 emb = torch.cat([time_emb, text_emb], dim=0)
+                 text_mask = torch.cat([torch.zeros_like(text_mask[:, 0:1]), text_mask], dim=1)
+         if 'action' in self.cond_mode:
+             action_emb = self.embed_action(y['action'])
+             emb = time_emb + self.mask_cond(action_emb, force_mask=force_mask)
+         if self.cond_mode == 'no_cond':
+             # unconstrained
+             emb = time_emb
+
+         if self.arch == 'gru':
+             x_reshaped = x.reshape(bs, njoints * nfeats, 1, nframes)
+             emb_gru = emb.repeat(nframes, 1, 1)  # [#frames, bs, d]
+             emb_gru = emb_gru.permute(1, 2, 0)  # [bs, d, #frames]
+             emb_gru = emb_gru.reshape(bs, self.latent_dim, 1, nframes)  # [bs, d, 1, #frames]
+             x = torch.cat((x_reshaped, emb_gru), axis=1)  # [bs, d+joints*feat, 1, #frames]
+
+         x = self.input_process(x)
+
+         # TODO - move to collate
+         frames_mask = None
+         is_valid_mask = y['mask'].shape[-1] > 1  # Don't use the mask with the generate script
+         if self.mask_frames and is_valid_mask:
+             frames_mask = torch.logical_not(y['mask'][..., :x.shape[0]].squeeze(1).squeeze(1)).to(device=x.device)
+             if self.emb_trans_dec or self.arch == 'trans_enc':
+                 step_mask = torch.zeros((bs, 1), dtype=torch.bool, device=x.device)
+                 frames_mask = torch.cat([step_mask, frames_mask], dim=1)
+
+         if self.arch == 'trans_enc':
+             # adding the timestep embed
+             xseq = torch.cat((emb, x), axis=0)  # [seqlen+1, bs, d]
+             xseq = self.sequence_pos_encoder(xseq)  # [seqlen+1, bs, d]
+             output = self.seqTransEncoder(xseq, src_key_padding_mask=frames_mask)[1:]  # [seqlen, bs, d]
+
+         elif self.arch == 'trans_dec':
+             if self.emb_trans_dec:
+                 xseq = torch.cat((time_emb, x), axis=0)
+             else:
+                 xseq = x
+             xseq = self.sequence_pos_encoder(xseq)  # [seqlen+1, bs, d]
+
+             if self.text_encoder_type == 'clip':
+                 output = self.seqTransDecoder(tgt=xseq, memory=emb, tgt_key_padding_mask=frames_mask)
+             elif self.text_encoder_type == 'bert':
+                 output = self.seqTransDecoder(tgt=xseq, memory=emb, memory_key_padding_mask=text_mask, tgt_key_padding_mask=frames_mask)  # Rotem's bug fix
+             else:
+                 raise ValueError()
+
+             if self.emb_trans_dec:
+                 output = output[1:]  # [seqlen, bs, d]
+
+         elif self.arch == 'gru':
+             xseq = x
+             xseq = self.sequence_pos_encoder(xseq)  # [seqlen, bs, d]
+             output, _ = self.gru(xseq)
+
+         # Extract the completed suffix
+         if self.is_prefix_comp:
+             output = output[self.context_len:]
+             y['mask'] = y['mask'][..., self.context_len:]
+
+         output = self.output_process(output)  # [bs, njoints, nfeats, nframes]
+         return output
+
+     def _apply(self, fn):
+         super()._apply(fn)
+         self.rot2xyz.smpl_model._apply(fn)
+
+     def train(self, *args, **kwargs):
+         super().train(*args, **kwargs)
+         self.rot2xyz.smpl_model.train(*args, **kwargs)
+
+
+ class PositionalEncoding(nn.Module):
+     def __init__(self, d_model, dropout=0.1, max_len=5000):
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(p=dropout)
+
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0).transpose(0, 1)
+
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         # not used in the final model
+         x = x + self.pe[:x.shape[0], :]
+         return self.dropout(x)
+
+
+ class TimestepEmbedder(nn.Module):
+     def __init__(self, latent_dim, sequence_pos_encoder):
+         super().__init__()
+         self.latent_dim = latent_dim
+         self.sequence_pos_encoder = sequence_pos_encoder
+
+         time_embed_dim = self.latent_dim
+         self.time_embed = nn.Sequential(
+             nn.Linear(self.latent_dim, time_embed_dim),
+             nn.SiLU(),
+             nn.Linear(time_embed_dim, time_embed_dim),
+         )
+
+     def forward(self, timesteps):
+         return self.time_embed(self.sequence_pos_encoder.pe[timesteps]).permute(1, 0, 2)
+
+
+ class InputProcess(nn.Module):
+     def __init__(self, data_rep, input_feats, latent_dim):
+         super().__init__()
+         self.data_rep = data_rep
+         self.input_feats = input_feats
+         self.latent_dim = latent_dim
+         self.poseEmbedding = nn.Linear(self.input_feats, self.latent_dim)
+         if self.data_rep == 'rot_vel':
+             self.velEmbedding = nn.Linear(self.input_feats, self.latent_dim)
+
+     def forward(self, x):
+         bs, njoints, nfeats, nframes = x.shape
+         x = x.permute((3, 0, 1, 2)).reshape(nframes, bs, njoints * nfeats)
+
+         if self.data_rep in ['rot6d', 'xyz', 'hml_vec']:
+             x = self.poseEmbedding(x)  # [seqlen, bs, d]
+             return x
+         elif self.data_rep == 'rot_vel':
+             first_pose = x[[0]]  # [1, bs, 150]
+             first_pose = self.poseEmbedding(first_pose)  # [1, bs, d]
+             vel = x[1:]  # [seqlen-1, bs, 150]
+             vel = self.velEmbedding(vel)  # [seqlen-1, bs, d]
+             return torch.cat((first_pose, vel), axis=0)  # [seqlen, bs, d]
+         else:
+             raise ValueError
+
+
+ class OutputProcess(nn.Module):
+     def __init__(self, data_rep, input_feats, latent_dim, njoints, nfeats):
+         super().__init__()
+         self.data_rep = data_rep
+         self.input_feats = input_feats
+         self.latent_dim = latent_dim
+         self.njoints = njoints
+         self.nfeats = nfeats
+         self.poseFinal = nn.Linear(self.latent_dim, self.input_feats)
+         if self.data_rep == 'rot_vel':
+             self.velFinal = nn.Linear(self.latent_dim, self.input_feats)
+
+     def forward(self, output):
+         nframes, bs, d = output.shape
+         if self.data_rep in ['rot6d', 'xyz', 'hml_vec']:
+             output = self.poseFinal(output)  # [seqlen, bs, 150]
+         elif self.data_rep == 'rot_vel':
+             first_pose = output[[0]]  # [1, bs, d]
+             first_pose = self.poseFinal(first_pose)  # [1, bs, 150]
+             vel = output[1:]  # [seqlen-1, bs, d]
+             vel = self.velFinal(vel)  # [seqlen-1, bs, 150]
+             output = torch.cat((first_pose, vel), axis=0)  # [seqlen, bs, 150]
+         else:
+             raise ValueError
+         output = output.reshape(nframes, bs, self.njoints, self.nfeats)
+         output = output.permute(1, 2, 3, 0)  # [bs, njoints, nfeats, nframes]
+         return output
+
+
+ class EmbedAction(nn.Module):
+     def __init__(self, num_actions, latent_dim):
+         super().__init__()
+         self.action_embedding = nn.Parameter(torch.randn(num_actions, latent_dim))
+
+     def forward(self, input):
+         idx = input[:, 0].to(torch.long)  # an index array must be long
+         output = self.action_embedding[idx]
+         return output
+
+
+ class EmbedTargetLocSingle(nn.Module):
+     def __init__(self, all_goal_joint_names, latent_dim, num_layers=1):
+         super().__init__()
+         self.extended_goal_joint_names = all_goal_joint_names + ['traj', 'heading']
+         self.target_cond_dim = len(self.extended_goal_joint_names) * 4  # 4 => (x, y, z, is_valid)
+         self.latent_dim = latent_dim
+         _layers = [nn.Linear(self.target_cond_dim, self.latent_dim)]
+         for _ in range(num_layers):
+             _layers += [nn.SiLU(), nn.Linear(self.latent_dim, self.latent_dim)]
+         self.mlp = nn.Sequential(*_layers)
+
+     def forward(self, input, target_joint_names, target_heading):
+         # TODO - generate validity from outside the model
+         validity = torch.zeros_like(input)[..., :1]
+         for sample_idx, sample_joint_names in enumerate(target_joint_names):
+             sample_joint_names_w_heading = np.append(sample_joint_names, 'heading') if target_heading[sample_idx] else sample_joint_names
+             for j in sample_joint_names_w_heading:
+                 validity[sample_idx, self.extended_goal_joint_names.index(j)] = 1.
+
+         mlp_input = torch.cat([input, validity], dim=-1).view(input.shape[0], -1)
+         return self.mlp(mlp_input)
+
+
+ class EmbedTargetLocSplit(nn.Module):
+     def __init__(self, all_goal_joint_names, latent_dim, num_layers=1):
+         super().__init__()
+         self.extended_goal_joint_names = all_goal_joint_names + ['traj', 'heading']
+         self.target_cond_dim = 4
+         self.latent_dim = latent_dim
+         self.splited_dim = self.latent_dim // len(self.extended_goal_joint_names)
+         assert self.latent_dim % len(self.extended_goal_joint_names) == 0
+         self.mini_mlps = nn.ModuleList()
+         for _ in self.extended_goal_joint_names:
+             _layers = [nn.Linear(self.target_cond_dim, self.splited_dim)]
+             for _ in range(num_layers):
+                 _layers += [nn.SiLU(), nn.Linear(self.splited_dim, self.splited_dim)]
+             self.mini_mlps.append(nn.Sequential(*_layers))
+
+     def forward(self, input, target_joint_names, target_heading):
+         # TODO - generate validity from outside the model
+         validity = torch.zeros_like(input)[..., :1]
+         for sample_idx, sample_joint_names in enumerate(target_joint_names):
+             sample_joint_names_w_heading = np.append(sample_joint_names, 'heading') if target_heading[sample_idx] else sample_joint_names
+             for j in sample_joint_names_w_heading:
+                 validity[sample_idx, self.extended_goal_joint_names.index(j)] = 1.
+
+         mlp_input = torch.cat([input, validity], dim=-1)
+         mlp_splits = [self.mini_mlps[i](mlp_input[:, i]) for i in range(mlp_input.shape[1])]
+         return torch.cat(mlp_splits, dim=-1)
+
+
+ class EmbedTargetLocMulti(nn.Module):
+     def __init__(self, all_goal_joint_names, latent_dim):
+         super().__init__()
+
+         # TODO: use a tensor of weights per joint and another of biases, then apply the selection in one go, as we do for actions
+         self.extended_goal_joint_names = all_goal_joint_names + ['traj', 'heading']
+         self.extended_goal_joint_idx = {joint_name: idx for idx, joint_name in enumerate(self.extended_goal_joint_names)}
+         self.n_extended_goal_joints = len(self.extended_goal_joint_names)
+         self.target_loc_emb = nn.ParameterDict({joint_name:
+                                                     nn.Sequential(
+                                                         nn.Linear(3, latent_dim),
+                                                         nn.SiLU(),
+                                                         nn.Linear(latent_dim, latent_dim))
+                                                 for joint_name in self.extended_goal_joint_names})  # TODO: check if 3 works for heading and traj
+         self.target_all_loc_emb = WeightedSum(self.n_extended_goal_joints)
+         self.latent_dim = latent_dim
+
+     def forward(self, input, target_joint_names, target_heading):
+         output = torch.zeros((input.shape[0], self.latent_dim), dtype=input.dtype, device=input.device)
+
+         # Iterate over the batch and apply the appropriate filter for each joint
+         for sample_idx, sample_joint_names in enumerate(target_joint_names):
+             sample_joint_names_w_heading = np.append(sample_joint_names, 'heading') if target_heading[sample_idx] else sample_joint_names
+             output_one_sample = torch.zeros((self.n_extended_goal_joints, self.latent_dim), dtype=input.dtype, device=input.device)
+             for joint_name in sample_joint_names_w_heading:
+                 layer = self.target_loc_emb[joint_name]
+                 output_one_sample[self.extended_goal_joint_idx[joint_name]] = layer(input[sample_idx, self.extended_goal_joint_idx[joint_name]])
+             output[sample_idx] = self.target_all_loc_emb(output_one_sample)
+
+         return output
motion_diffusion_model/model/rotation2xyz.py ADDED
@@ -0,0 +1,92 @@
+ # This code is based on https://github.com/Mathux/ACTOR.git
+ import torch
+ import utils.rotation_conversions as geometry
+
+
+ from model.smpl import SMPL, JOINTSTYPE_ROOT
+ JOINTSTYPES = ["a2m", "a2mpl", "smpl", "vibe", "vertices"]
+
+
+ class Rotation2xyz:
+     def __init__(self, device, dataset='amass'):
+         self.device = device
+         self.dataset = dataset
+         self.smpl_model = SMPL().eval().to(device)
+
+     def __call__(self, x, mask, pose_rep, translation, glob,
+                  jointstype, vertstrans, betas=None, beta=0,
+                  glob_rot=None, get_rotations_back=False, **kwargs):
+         if pose_rep == "xyz":
+             return x
+
+         if mask is None:
+             mask = torch.ones((x.shape[0], x.shape[-1]), dtype=bool, device=x.device)
+
+         if not glob and glob_rot is None:
+             raise TypeError("You must specify global rotation if glob is False")
+
+         if jointstype not in JOINTSTYPES:
+             raise NotImplementedError("This jointstype is not implemented.")
+
+         if translation:
+             x_translations = x[:, -1, :3]
+             x_rotations = x[:, :-1]
+         else:
+             x_rotations = x
+
+         x_rotations = x_rotations.permute(0, 3, 1, 2)
+         nsamples, time, njoints, feats = x_rotations.shape
+
+         # Compute rotations (convert only the masked frames)
+         if pose_rep == "rotvec":
+             rotations = geometry.axis_angle_to_matrix(x_rotations[mask])
+         elif pose_rep == "rotmat":
+             rotations = x_rotations[mask].view(-1, njoints, 3, 3)
+         elif pose_rep == "rotquat":
+             rotations = geometry.quaternion_to_matrix(x_rotations[mask])
+         elif pose_rep == "rot6d":
+             rotations = geometry.rotation_6d_to_matrix(x_rotations[mask])
+         else:
+             raise NotImplementedError("No geometry for this one.")
+
+         if not glob:
+             global_orient = torch.tensor(glob_rot, device=x.device)
+             global_orient = geometry.axis_angle_to_matrix(global_orient).view(1, 1, 3, 3)
+             global_orient = global_orient.repeat(len(rotations), 1, 1, 1)
+         else:
+             global_orient = rotations[:, 0]
+             rotations = rotations[:, 1:]
+
+         if betas is None:
+             betas = torch.zeros([rotations.shape[0], self.smpl_model.num_betas],
+                                 dtype=rotations.dtype, device=rotations.device)
+             betas[:, 1] = beta
+         out = self.smpl_model(body_pose=rotations, global_orient=global_orient, betas=betas)
+
+         # get the desired joints
+         joints = out[jointstype]
+
+         x_xyz = torch.empty(nsamples, time, joints.shape[1], 3, device=x.device, dtype=x.dtype)
+         x_xyz[~mask] = 0
+         x_xyz[mask] = joints
+
+         x_xyz = x_xyz.permute(0, 2, 3, 1).contiguous()
+
+         # center the root joint of the prediction at the origin
+         if jointstype != "vertices":
+             rootindex = JOINTSTYPE_ROOT[jointstype]
+             x_xyz = x_xyz - x_xyz[:, [rootindex], :, :]
+
+         if translation and vertstrans:
+             # put the root translation of the first frame at the origin
+             x_translations = x_translations - x_translations[:, :, [0]]
+
+             # add the translation to all the joints
+             x_xyz = x_xyz + x_translations[:, None, :, :]
+
+         if get_rotations_back:
+             return x_xyz, rotations, global_orient
+         else:
+             return x_xyz
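Note on rotation2xyz.py: `rotation_6d_to_matrix` lives in utils/rotation_conversions (ported from PyTorch3D). The 6D representation stores the first two rows of the rotation matrix; a Gram-Schmidt step recovers the third. A minimal sketch of that conversion (the `_sketch` suffix marks it as illustrative, not the project's function):

import torch
import torch.nn.functional as F

def rotation_6d_to_matrix_sketch(d6: torch.Tensor) -> torch.Tensor:
    # d6: [..., 6] -> rotation matrices [..., 3, 3]
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)  # orthogonalize against b1
    b3 = torch.cross(b1, b2, dim=-1)                                     # third row completes the basis
    return torch.stack((b1, b2, b3), dim=-2)

R = rotation_6d_to_matrix_sketch(torch.randn(4, 6))
assert torch.allclose(R @ R.transpose(-1, -2), torch.eye(3).expand(4, 3, 3), atol=1e-5)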
motion_diffusion_model/model/smpl.py ADDED
@@ -0,0 +1,97 @@
+ # This code is based on https://github.com/Mathux/ACTOR.git
+ import numpy as np
+ import torch
+
+ import contextlib
+
+ from smplx import SMPLLayer as _SMPLLayer
+ from smplx.lbs import vertices2joints
+
+
+ # action2motion_joints = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 21, 24, 38]
+ # swap joints 0 and 8
+ action2motion_joints = [8, 1, 2, 3, 4, 5, 6, 7, 0, 9, 10, 11, 12, 13, 14, 21, 24, 38]
+
+ from utils.config import SMPL_MODEL_PATH, JOINT_REGRESSOR_TRAIN_EXTRA
+
+ JOINTSTYPE_ROOT = {"a2m": 0,  # action2motion
+                    "smpl": 0,
+                    "a2mpl": 0,  # union of smpl and a2m
+                    "vibe": 8}  # the root is index 8: OP MidHip below
+
+ JOINT_MAP = {
+     'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17,
+     'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16,
+     'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0,
+     'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8,
+     'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7,
+     'OP REye': 25, 'OP LEye': 26, 'OP REar': 27,
+     'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30,
+     'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34,
+     'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45,
+     'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7,
+     'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17,
+     'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20,
+     'Neck (LSP)': 47, 'Top of Head (LSP)': 48,
+     'Pelvis (MPII)': 49, 'Thorax (MPII)': 50,
+     'Spine (H36M)': 51, 'Jaw (H36M)': 52,
+     'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26,
+     'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27
+ }
+
+ JOINT_NAMES = [
+     'OP Nose', 'OP Neck', 'OP RShoulder',
+     'OP RElbow', 'OP RWrist', 'OP LShoulder',
+     'OP LElbow', 'OP LWrist', 'OP MidHip',
+     'OP RHip', 'OP RKnee', 'OP RAnkle',
+     'OP LHip', 'OP LKnee', 'OP LAnkle',
+     'OP REye', 'OP LEye', 'OP REar',
+     'OP LEar', 'OP LBigToe', 'OP LSmallToe',
+     'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel',
+     'Right Ankle', 'Right Knee', 'Right Hip',
+     'Left Hip', 'Left Knee', 'Left Ankle',
+     'Right Wrist', 'Right Elbow', 'Right Shoulder',
+     'Left Shoulder', 'Left Elbow', 'Left Wrist',
+     'Neck (LSP)', 'Top of Head (LSP)',
+     'Pelvis (MPII)', 'Thorax (MPII)',
+     'Spine (H36M)', 'Jaw (H36M)',
+     'Head (H36M)', 'Nose', 'Left Eye',
+     'Right Eye', 'Left Ear', 'Right Ear'
+ ]
+
+
+ # adapted from VIBE/SPIN to output smpl_joints, vibe joints and action2motion joints
+ class SMPL(_SMPLLayer):
+     """ Extension of the official SMPL implementation to support more joints """
+
+     def __init__(self, model_path=SMPL_MODEL_PATH, **kwargs):
+         kwargs["model_path"] = model_path
+
+         # remove the verbosity for the 10-shapes beta parameters
+         with contextlib.redirect_stdout(None):
+             super(SMPL, self).__init__(**kwargs)
+
+         J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA)
+         self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))
+         vibe_indexes = np.array([JOINT_MAP[i] for i in JOINT_NAMES])
+         a2m_indexes = vibe_indexes[action2motion_joints]
+         smpl_indexes = np.arange(24)
+         a2mpl_indexes = np.unique(np.r_[smpl_indexes, a2m_indexes])
+
+         self.maps = {"vibe": vibe_indexes,
+                      "a2m": a2m_indexes,
+                      "smpl": smpl_indexes,
+                      "a2mpl": a2mpl_indexes}
+
+     def forward(self, *args, **kwargs):
+         smpl_output = super(SMPL, self).forward(*args, **kwargs)
+
+         extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices)
+         all_joints = torch.cat([smpl_output.joints, extra_joints], dim=1)
+
+         output = {"vertices": smpl_output.vertices}
+
+         for jointstype, indexes in self.maps.items():
+             output[jointstype] = all_joints[:, indexes]
+
+         return output
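Note on smpl.py: `vertices2joints` (from smplx.lbs) regresses joint positions as a linear combination of mesh vertices, which is how the extra VIBE/SPIN joints are obtained above. A minimal equivalent, with a hypothetical 9-joint regressor for illustration:

import torch

def vertices2joints_sketch(J_regressor: torch.Tensor, vertices: torch.Tensor) -> torch.Tensor:
    # J_regressor: [J, V] per-joint vertex weights; vertices: [B, V, 3] -> joints: [B, J, 3]
    return torch.einsum('bvc,jv->bjc', vertices, J_regressor)

verts = torch.randn(2, 6890, 3)   # an SMPL mesh has 6890 vertices
regressor = torch.rand(9, 6890)   # hypothetical extra-joint regressor
print(vertices2joints_sketch(regressor, verts).shape)  # torch.Size([2, 9, 3])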
motion_diffusion_model/sample/edit.py ADDED
@@ -0,0 +1,212 @@
+ # This code is based on https://github.com/openai/guided-diffusion
+ """
+ Generate a large batch of image samples from a model and save them as a large
+ numpy array. This can be used to produce samples for FID evaluation.
+ """
+ from utils.fixseed import fixseed
+ import os
+ import numpy as np
+ import torch
+ from utils.parser_util import edit_args
+ from sample.generate import save_multiple_samples, construct_template_variables
+ from utils.model_util import create_model_and_diffusion, load_saved_model
+ from utils import dist_util
+ from utils.sampler_util import ClassifierFreeSampleModel
+ from data_loaders.get_data import get_dataset_loader
+ from data_loaders.humanml.scripts.motion_process import recover_from_ric
+ from data_loaders import humanml_utils
+ import data_loaders.humanml.utils.paramUtil as paramUtil
+ from data_loaders.humanml.utils.plot_script import plot_3d_motion
+ import shutil
+
+
+ def main():
+     args = edit_args()
+     fixseed(args.seed)
+     out_path = args.output_dir
+     name = os.path.basename(os.path.dirname(args.model_path))
+     niter = os.path.basename(args.model_path).replace('model', '').replace('.pt', '')
+     max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60
+     fps = 12.5 if args.dataset == 'kit' else 20
+     n_frames = 120  # min(max_frames, int(args.motion_length*fps))
+
+     dist_util.setup_dist(args.device)
+     if out_path == '':
+         out_path = os.path.join(os.path.dirname(args.model_path),
+                                 'edit_{}_{}_{}_seed{}'.format(name, niter, args.edit_mode, args.seed))
+         if args.text_condition != '':
+             out_path += '_' + args.text_condition.replace(' ', '_').replace('.', '')
+
+     print('Loading dataset...')
+     assert args.num_samples <= args.batch_size, \
+         f'Please either increase batch_size({args.batch_size}) or reduce num_samples({args.num_samples})'
+     # This check protects the GPU from a memory overload in the following line.
+     # If your GPU can handle a batch size larger than the default, you can specify it through the --batch_size flag.
+     # If it can't, and you still want to sample more prompts, run this script with different seeds
+     # (specify through the --seed flag)
+     args.batch_size = args.num_samples  # Sampling a single batch from the testset, with exactly args.num_samples
+     data = get_dataset_loader(name=args.dataset,
+                               batch_size=args.batch_size,
+                               num_frames=max_frames,
+                               split='test',
+                               hml_mode='train')  # in train mode, you get both text and motion.
+     # data.fixed_length = n_frames
+     total_num_samples = args.num_samples * args.num_repetitions
+
+     print("Creating model and diffusion...")
+     model, diffusion = create_model_and_diffusion(args, data)
+
+     print(f"Loading checkpoints from [{args.model_path}]...")
+     load_saved_model(model, args.model_path, use_avg=args.use_ema)
+
+     model = ClassifierFreeSampleModel(model)  # wrapping model with the classifier-free sampler
+     model.to(dist_util.dev())
+     model.eval()  # disable random masking
+
+     iterator = iter(data)
+     input_motions, model_kwargs = next(iterator)
+     input_motions = input_motions.to(dist_util.dev())
+     texts = [args.text_condition] * args.num_samples
+     model_kwargs['y']['text'] = texts
+     if args.text_condition == '':
+         args.guidance_param = 0.  # Force unconditioned generation
+
+     # add inpainting mask according to args
+     assert max_frames == input_motions.shape[-1]
+     gt_frames_per_sample = {}
+     model_kwargs['y']['inpainted_motion'] = input_motions
+     if args.edit_mode == 'in_between':
+         model_kwargs['y']['inpainting_mask'] = torch.ones_like(input_motions, dtype=torch.bool,
+                                                                device=input_motions.device)  # True means use gt motion
+         for i, length in enumerate(model_kwargs['y']['lengths'].cpu().numpy()):
+             start_idx, end_idx = int(args.prefix_end * length), int(args.suffix_start * length)
+             gt_frames_per_sample[i] = list(range(0, start_idx)) + list(range(end_idx, max_frames))
+             model_kwargs['y']['inpainting_mask'][i, :, :,
+                                                  start_idx: end_idx] = False  # do inpainting in those frames
+     elif args.edit_mode == 'upper_body':
+         model_kwargs['y']['inpainting_mask'] = torch.tensor(humanml_utils.HML_LOWER_BODY_MASK, dtype=torch.bool,
+                                                             device=input_motions.device)  # True is lower body data
+         model_kwargs['y']['inpainting_mask'] = model_kwargs['y']['inpainting_mask'].unsqueeze(0).unsqueeze(
+             -1).unsqueeze(-1).repeat(input_motions.shape[0], 1, input_motions.shape[2], input_motions.shape[3])
+
+     all_motions = []
+     all_lengths = []
+     all_text = []
+
+     for rep_i in range(args.num_repetitions):
+         print(f'### Start sampling [repetitions #{rep_i}]')
+
+         # add CFG scale to batch
+         model_kwargs['y']['scale'] = torch.ones(args.batch_size, device=dist_util.dev()) * args.guidance_param
+
+         sample_fn = diffusion.p_sample_loop
+
+         sample = sample_fn(
+             model,
+             (args.batch_size, model.njoints, model.nfeats, max_frames),
+             clip_denoised=False,
+             model_kwargs=model_kwargs,
+             skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
+             init_image=None,
+             progress=True,
+             dump_steps=None,
+             noise=None,
+             const_noise=False,
+         )
+
+         # Recover XYZ *positions* from HumanML3D vector representation
+         if model.data_rep == 'hml_vec':
+             n_joints = 22 if sample.shape[1] == 263 else 21
+             sample = data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float()
+             sample = recover_from_ric(sample, n_joints)
+             sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)
+
+         all_text += model_kwargs['y']['text']
+         all_motions.append(sample.cpu().numpy())
+         all_lengths.append(model_kwargs['y']['lengths'].cpu().numpy())
+
+         print(f"created {len(all_motions) * args.batch_size} samples")
+
+     all_motions = np.concatenate(all_motions, axis=0)
+     all_motions = all_motions[:total_num_samples]  # [bs, njoints, 6, seqlen]
+     all_text = all_text[:total_num_samples]
+     all_lengths = np.concatenate(all_lengths, axis=0)[:total_num_samples]
+
+     if os.path.exists(out_path):
+         shutil.rmtree(out_path)
+     os.makedirs(out_path)
+
+     npy_path = os.path.join(out_path, 'results.npy')
+     print(f"saving results file to [{npy_path}]")
+     np.save(npy_path,
+             {'motion': all_motions, 'text': all_text, 'lengths': all_lengths,
+              'num_samples': args.num_samples, 'num_repetitions': args.num_repetitions})
+     with open(npy_path.replace('.npy', '.txt'), 'w') as fw:
+         fw.write('\n'.join(all_text))
+     with open(npy_path.replace('.npy', '_len.txt'), 'w') as fw:
+         fw.write('\n'.join([str(l) for l in all_lengths]))
+
+     print(f"saving visualizations to [{out_path}]...")
+     skeleton = paramUtil.kit_kinematic_chain if args.dataset == 'kit' else paramUtil.t2m_kinematic_chain
+
+     # Recover XYZ *positions* from HumanML3D vector representation
+     if model.data_rep == 'hml_vec':
+         input_motions = data.dataset.t2m_dataset.inv_transform(input_motions.cpu().permute(0, 2, 3, 1)).float()
+         input_motions = recover_from_ric(input_motions, n_joints)
+         input_motions = input_motions.view(-1, *input_motions.shape[2:]).permute(0, 2, 3, 1).cpu().numpy()
+
+     sample_print_template, row_print_template, all_print_template, \
+         sample_file_template, row_file_template, all_file_template = construct_template_variables(args.unconstrained)
+     max_vis_samples = 6
+     num_vis_samples = min(args.num_samples, max_vis_samples)
+     animations = np.empty(shape=(args.num_samples, args.num_repetitions), dtype=object)
+     max_length = max(all_lengths)
+
+     for sample_i in range(args.num_samples):
+         caption = 'Input Motion'
+         length = model_kwargs['y']['lengths'][sample_i]
+         motion = input_motions[sample_i].transpose(2, 0, 1)[:length]
+         save_file = 'input_motion{:02d}.mp4'.format(sample_i)
+         animation_save_path = os.path.join(out_path, save_file)
+         rep_files = [animation_save_path]
+         # FIXME - fix and bring back the following:
+         # print(f'[({sample_i}) "{caption}" | -> {save_file}]')
+         # plot_3d_motion(animation_save_path, skeleton, motion, title=caption,
+         #                dataset=args.dataset, fps=fps, vis_mode='gt',
+         #                gt_frames=gt_frames_per_sample.get(sample_i, []))
+         for rep_i in range(args.num_repetitions):
+             caption = all_text[rep_i*args.batch_size + sample_i]
+             if caption == '':
+                 caption = 'Edit [{}] unconditioned'.format(args.edit_mode)
+             else:
+                 caption = 'Edit [{}]: {}'.format(args.edit_mode, caption)
+             length = all_lengths[rep_i*args.batch_size + sample_i]
+             motion = all_motions[rep_i*args.batch_size + sample_i].transpose(2, 0, 1)[:length]
+             save_file = 'sample{:02d}_rep{:02d}.mp4'.format(sample_i, rep_i)
+             animation_save_path = os.path.join(out_path, save_file)
+             rep_files.append(animation_save_path)
+             gt_frames = gt_frames_per_sample.get(sample_i, [])
+             print(f'[({sample_i}) "{caption}" | Rep #{rep_i} | -> {save_file}]')
+             animations[sample_i, rep_i] = plot_3d_motion(animation_save_path,
+                                                          skeleton, motion, dataset=args.dataset, title=caption,
+                                                          fps=fps, gt_frames=gt_frames)
+             # Credit for visualization: https://github.com/EricGuo5513/text-to-motion
+
+         all_rep_save_file = os.path.join(out_path, 'sample{:02d}.mp4'.format(sample_i))
+         ffmpeg_rep_files = [f' -i {f} ' for f in rep_files]
+         hstack_args = f' -filter_complex hstack=inputs={args.num_repetitions+1}'
+         ffmpeg_rep_cmd = f'ffmpeg -y -loglevel warning ' + ''.join(ffmpeg_rep_files) + f'{hstack_args} {all_rep_save_file}'
+         os.system(ffmpeg_rep_cmd)
+         print(f'[({sample_i}) "{caption}" | all repetitions | -> {all_rep_save_file}]')
+
+     save_multiple_samples(out_path, {'all': all_file_template}, animations, fps, max(list(all_lengths) + [n_frames]))
+
+     abs_path = os.path.abspath(out_path)
+     print(f'[Done] Results are at [{abs_path}]')
+
+
+ if __name__ == "__main__":
+     main()
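Note on edit.py: for `in_between` editing, every frame outside [prefix_end * length, suffix_start * length) is kept as ground truth and only the middle segment is diffused. A toy sketch of the per-sample mask construction (values chosen for illustration):

import torch

max_frames, prefix_end, suffix_start = 196, 0.25, 0.75
length = 120  # number of valid frames in this sample

mask = torch.ones(max_frames, dtype=torch.bool)          # True -> keep the ground-truth frame
start_idx, end_idx = int(prefix_end * length), int(suffix_start * length)
mask[start_idx:end_idx] = False                          # False -> inpaint these frames
print(start_idx, end_idx, int(mask.sum()))               # 30 90 136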
motion_diffusion_model/sample/generate.py ADDED
@@ -0,0 +1,318 @@
+ # This code is based on https://github.com/openai/guided-diffusion
+ """
+ Generate a large batch of image samples from a model and save them as a large
+ numpy array. This can be used to produce samples for FID evaluation.
+ """
+ from utils.fixseed import fixseed
+ import os
+ import numpy as np
+ import torch
+ from utils.parser_util import generate_args
+ from utils.model_util import create_model_and_diffusion, load_saved_model
+ from utils import dist_util
+ from utils.sampler_util import ClassifierFreeSampleModel, AutoRegressiveSampler
+ from data_loaders.get_data import get_dataset_loader
+ from data_loaders.humanml.scripts.motion_process import recover_from_ric, get_target_location, sample_goal
+ import data_loaders.humanml.utils.paramUtil as paramUtil
+ from data_loaders.humanml.utils.plot_script import plot_3d_motion
+ import shutil
+ from data_loaders.tensors import collate
+ from moviepy.editor import clips_array
+
+
+ def main(args=None):
+     if args is None:
+         # args is None unless this method is called from another function (e.g. during training)
+         args = generate_args()
+     fixseed(args.seed)
+     out_path = args.output_dir
+     n_joints = 22 if args.dataset == 'humanml' else 21
+     name = os.path.basename(os.path.dirname(args.model_path))
+     niter = os.path.basename(args.model_path).replace('model', '').replace('.pt', '')
+     max_frames = 196 if args.dataset in ['kit', 'humanml'] else 60
+     fps = 12.5 if args.dataset == 'kit' else 20
+     n_frames = min(max_frames, int(args.motion_length*fps))
+     is_using_data = not any([args.input_text, args.text_prompt, args.action_file, args.action_name])
+     if args.context_len > 0:
+         is_using_data = True  # For prefix completion, we need to sample a prefix
+     dist_util.setup_dist(args.device)
+     if out_path == '':
+         out_path = os.path.join(os.path.dirname(args.model_path),
+                                 'samples_{}_{}_seed{}'.format(name, niter, args.seed))
+         if args.text_prompt != '':
+             out_path += '_' + args.text_prompt.replace(' ', '_').replace('.', '')
+         elif args.input_text != '':
+             out_path += '_' + os.path.basename(args.input_text).replace('.txt', '').replace(' ', '_').replace('.', '')
+         elif args.dynamic_text_path != '':
+             out_path += '_' + os.path.basename(args.dynamic_text_path).replace('.txt', '').replace(' ', '_').replace('.', '')
+
+     # this block must be called BEFORE the dataset is loaded
+     texts = None
+     if args.text_prompt != '':
+         texts = [args.text_prompt] * args.num_samples
+     elif args.input_text != '':
+         assert os.path.exists(args.input_text)
+         with open(args.input_text, 'r') as fr:
+             texts = fr.readlines()
+         texts = [s.replace('\n', '') for s in texts]
+         args.num_samples = len(texts)
+     elif args.dynamic_text_path != '':
+         assert os.path.exists(args.dynamic_text_path)
+         assert args.autoregressive, "Dynamic text sampling is only supported with autoregressive sampling."
+         with open(args.dynamic_text_path, 'r') as fr:
+             texts = fr.readlines()
+         texts = [s.replace('\n', '') for s in texts]
+         n_frames = len(texts) * args.pred_len  # each text prompt is for a single prediction
+     elif args.action_name:
+         action_text = [args.action_name]
+         args.num_samples = 1
+     elif args.action_file != '':
+         assert os.path.exists(args.action_file)
+         with open(args.action_file, 'r') as fr:
+             action_text = fr.readlines()
+         action_text = [s.replace('\n', '') for s in action_text]
+         args.num_samples = len(action_text)
+
+     args.batch_size = args.num_samples  # Sampling a single batch from the testset, with exactly args.num_samples
+
+     print('Loading dataset...')
+     data = load_dataset(args, max_frames, n_frames)
+     total_num_samples = args.num_samples * args.num_repetitions
+
+     print("Creating model and diffusion...")
+     model, diffusion = create_model_and_diffusion(args, data)
+
+     sample_fn = diffusion.p_sample_loop
+     if args.autoregressive:
+         sample_cls = AutoRegressiveSampler(args, sample_fn, n_frames)
+         sample_fn = sample_cls.sample
+
+     print(f"Loading checkpoints from [{args.model_path}]...")
+     load_saved_model(model, args.model_path, use_avg=args.use_ema)
+
+     if args.guidance_param != 1:
+         model = ClassifierFreeSampleModel(model)  # wrapping model with the classifier-free sampler
+     model.to(dist_util.dev())
+     model.eval()  # disable random masking
+
+     motion_shape = (args.batch_size, model.njoints, model.nfeats, n_frames)
+
+     if is_using_data:
+         iterator = iter(data)
+         input_motion, model_kwargs = next(iterator)
+         input_motion = input_motion.to(dist_util.dev())
+         if texts is not None:
+             model_kwargs['y']['text'] = texts
+     else:
+         collate_args = [{'inp': torch.zeros(n_frames), 'tokens': None, 'lengths': n_frames}] * args.num_samples
+         is_t2m = any([args.input_text, args.text_prompt])
+         if is_t2m:
+             # t2m
+             collate_args = [dict(arg, text=txt) for arg, txt in zip(collate_args, texts)]
+         else:
+             # a2m
+             action = data.dataset.action_name_to_action(action_text)
+             collate_args = [dict(arg, action=one_action, action_text=one_action_text) for
+                             arg, one_action, one_action_text in zip(collate_args, action, action_text)]
+         _, model_kwargs = collate(collate_args)
+
+     model_kwargs['y'] = {key: val.to(dist_util.dev()) if torch.is_tensor(val) else val for key, val in model_kwargs['y'].items()}
+     init_image = None
+
+     all_motions = []
+     all_lengths = []
+     all_text = []
+
+     # add CFG scale to batch
+     if args.guidance_param != 1:
+         model_kwargs['y']['scale'] = torch.ones(args.batch_size, device=dist_util.dev()) * args.guidance_param
+
+     if 'text' in model_kwargs['y'].keys():
+         # encoding once instead of each iteration saves lots of time
+         model_kwargs['y']['text_embed'] = model.encode_text(model_kwargs['y']['text'])
+
+     if args.dynamic_text_path != '':
+         # Rearrange the text to match the autoregressive sampling - each prompt fits a single prediction,
+         # which is 2 seconds of motion by default
+         model_kwargs['y']['text'] = [model_kwargs['y']['text']] * args.num_samples
+         if args.text_encoder_type == 'bert':
+             model_kwargs['y']['text_embed'] = (model_kwargs['y']['text_embed'][0].unsqueeze(0).repeat(args.num_samples, 1, 1, 1),
+                                                model_kwargs['y']['text_embed'][1].unsqueeze(0).repeat(args.num_samples, 1, 1))
+         else:
+             raise NotImplementedError('DiP model only supports the BERT text encoder at the moment. If you implement this, please send a PR!')
+
+     for rep_i in range(args.num_repetitions):
+         print(f'### Sampling [repetitions #{rep_i}]')
+
+         sample = sample_fn(
+             model,
+             motion_shape,
+             clip_denoised=False,
+             model_kwargs=model_kwargs,
+             skip_timesteps=0,  # 0 is the default value - i.e. don't skip any step
+             init_image=init_image,
+             progress=True,
+             dump_steps=None,
+             noise=None,
+             const_noise=False,
+         )
+
+         # Recover XYZ *positions* from HumanML3D vector representation
+         if model.data_rep == 'hml_vec':
+             n_joints = 22 if sample.shape[1] == 263 else 21
+             sample = data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float()
+             sample = recover_from_ric(sample, n_joints)
+             sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)
+
+         rot2xyz_pose_rep = 'xyz' if model.data_rep in ['xyz', 'hml_vec'] else model.data_rep
+         rot2xyz_mask = None if rot2xyz_pose_rep == 'xyz' else model_kwargs['y']['mask'].reshape(args.batch_size, n_frames).bool()
+         sample = model.rot2xyz(x=sample, mask=rot2xyz_mask, pose_rep=rot2xyz_pose_rep, glob=True, translation=True,
+                                jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None,
+                                get_rotations_back=False)
+
+         if args.unconstrained:
+             all_text += ['unconstrained'] * args.num_samples
+         else:
+             text_key = 'text' if 'text' in model_kwargs['y'] else 'action_text'
+             all_text += model_kwargs['y'][text_key]
+
+         all_motions.append(sample.cpu().numpy())
+         _len = model_kwargs['y']['lengths'].cpu().numpy()
+         if 'prefix' in model_kwargs['y'].keys():
+             _len[:] = sample.shape[-1]
+         all_lengths.append(_len)
+
+         print(f"created {len(all_motions) * args.batch_size} samples")
+
+     all_motions = np.concatenate(all_motions, axis=0)
+     all_motions = all_motions[:total_num_samples]  # [bs, njoints, 6, seqlen]
+     all_text = all_text[:total_num_samples]
+     all_lengths = np.concatenate(all_lengths, axis=0)[:total_num_samples]
+
+     if os.path.exists(out_path):
+         shutil.rmtree(out_path)
+     os.makedirs(out_path)
+
+     npy_path = os.path.join(out_path, 'results.npy')
+     print(f"saving results file to [{npy_path}]")
+     np.save(npy_path,
+             {'motion': all_motions, 'text': all_text, 'lengths': all_lengths,
+              'num_samples': args.num_samples, 'num_repetitions': args.num_repetitions})
+     if args.dynamic_text_path != '':
+         text_file_content = '\n'.join(['#'.join(s) for s in all_text])
+     else:
+         text_file_content = '\n'.join(all_text)
+     with open(npy_path.replace('.npy', '.txt'), 'w') as fw:
+         fw.write(text_file_content)
+     with open(npy_path.replace('.npy', '_len.txt'), 'w') as fw:
+         fw.write('\n'.join([str(l) for l in all_lengths]))
+
+     print(f"saving visualizations to [{out_path}]...")
+     skeleton = paramUtil.kit_kinematic_chain if args.dataset == 'kit' else paramUtil.t2m_kinematic_chain
+
+     sample_print_template, row_print_template, all_print_template, \
+         sample_file_template, row_file_template, all_file_template = construct_template_variables(args.unconstrained)
+     max_vis_samples = 6
+     num_vis_samples = min(args.num_samples, max_vis_samples)
+     animations = np.empty(shape=(args.num_samples, args.num_repetitions), dtype=object)
+     max_length = max(all_lengths)
+
+     for sample_i in range(args.num_samples):
+         rep_files = []
+         for rep_i in range(args.num_repetitions):
+             caption = all_text[rep_i*args.batch_size + sample_i]
+             if args.dynamic_text_path != '':  # caption per frame
+                 assert type(caption) == list
+                 caption_per_frame = []
+                 for c in caption:
+                     caption_per_frame += [c] * args.pred_len
+                 caption = caption_per_frame
+
+             # Trim / freeze motion if needed
+             length = all_lengths[rep_i*args.batch_size + sample_i]
+             motion = all_motions[rep_i*args.batch_size + sample_i].transpose(2, 0, 1)[:max_length]
+             if motion.shape[0] > length:
+                 motion[length:-1] = motion[length-1]  # duplicate the last valid frame to the end, so all motions have equal length
+
+             save_file = sample_file_template.format(sample_i, rep_i)
+             animation_save_path = os.path.join(out_path, save_file)
+             gt_frames = np.arange(args.context_len) if args.context_len > 0 and not args.autoregressive else []
+             animations[sample_i, rep_i] = plot_3d_motion(animation_save_path,
+                                                          skeleton, motion, dataset=args.dataset, title=caption,
+                                                          fps=fps, gt_frames=gt_frames)
+             rep_files.append(animation_save_path)
+
+     save_multiple_samples(out_path, {'all': all_file_template}, animations, fps, max(list(all_lengths) + [n_frames]))
+
+     abs_path = os.path.abspath(out_path)
+     print(f'[Done] Results are at [{abs_path}]')
+
+     return out_path
+
+
+ def save_multiple_samples(out_path, file_templates, animations, fps, max_frames, no_dir=False):
+     num_samples_in_out_file = 3
+     n_samples = animations.shape[0]
+
+     for sample_i in range(0, n_samples, num_samples_in_out_file):
+         last_sample_i = min(sample_i + num_samples_in_out_file, n_samples)
+         all_sample_save_file = file_templates['all'].format(sample_i, last_sample_i - 1)
+         if no_dir and n_samples <= num_samples_in_out_file:
+             all_sample_save_path = out_path
+         else:
+             all_sample_save_path = os.path.join(out_path, all_sample_save_file)
+         print(f'saving {os.path.split(out_path)[1]}/{all_sample_save_file}')
+
+         clips = clips_array(animations[sample_i:last_sample_i])
+         clips.duration = max_frames / fps
+
+         clips.write_videofile(all_sample_save_path, fps=fps, threads=4, logger=None)
+
+         for clip in clips.clips:
+             # close internal clips; currently a no-op, but safer in case it does something one day
+             clip.close()
+         clips.close()  # important
+
+
+ def construct_template_variables(unconstrained):
+     row_file_template = 'sample{:02d}.mp4'
+     all_file_template = 'samples_{:02d}_to_{:02d}.mp4'
+     if unconstrained:
+         sample_file_template = 'row{:02d}_col{:02d}.mp4'
+         sample_print_template = '[{} row #{:02d} column #{:02d} | -> {}]'
+         row_file_template = row_file_template.replace('sample', 'row')
+         row_print_template = '[{} row #{:02d} | all columns | -> {}]'
+         all_file_template = all_file_template.replace('samples', 'rows')
+         all_print_template = '[rows {:02d} to {:02d} | -> {}]'
+     else:
+         sample_file_template = 'sample{:02d}_rep{:02d}.mp4'
+         sample_print_template = '["{}" ({:02d}) | Rep #{:02d} | -> {}]'
+         row_print_template = '[ "{}" ({:02d}) | all repetitions | -> {}]'
+         all_print_template = '[samples {:02d} to {:02d} | all repetitions | -> {}]'
+
+     return sample_print_template, row_print_template, all_print_template, \
+         sample_file_template, row_file_template, all_file_template
+
+
+ def load_dataset(args, max_frames, n_frames):
+     data = get_dataset_loader(name=args.dataset,
+                               batch_size=args.batch_size,
+                               num_frames=max_frames,
+                               split='test',
+                               hml_mode='train' if args.pred_len > 0 else 'text_only',  # We need to sample a prefix from the dataset
+                               fixed_len=args.pred_len + args.context_len, pred_len=args.pred_len, device=dist_util.dev())
+     data.fixed_length = n_frames
+     return data
+
+
+ def is_substr_in_list(substr, list_of_strs):
+     return np.char.find(list_of_strs, substr) != -1
+
+
+ if __name__ == "__main__":
+     main()
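Note on generate.py: with --dynamic_text_path, each prompt drives exactly one autoregressive prediction window of pred_len frames, so both the total length and the per-frame captions are derived from the prompt list. A toy sketch with assumed values:

pred_len = 40  # frames per prediction (2 seconds at 20 fps)
prompts = ['walk forward', 'turn left', 'sit down']
n_frames = len(prompts) * pred_len  # one window per prompt

caption_per_frame = []
for c in prompts:
    caption_per_frame += [c] * pred_len  # same expansion as in the visualization loop
assert len(caption_per_frame) == n_frames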
motion_diffusion_model/sample/predict.py ADDED
@@ -0,0 +1,167 @@
1
+ import os
2
+ import subprocess
3
+ from typing import Any, List, Optional
4
+ from argparse import Namespace
5
+
6
+ import torch
7
+ from cog import BasePredictor, Input, Path, BaseModel
8
+
9
+ import data_loaders.humanml.utils.paramUtil as paramUtil
10
+ from data_loaders.get_data import get_dataset_loader
11
+ from data_loaders.humanml.scripts.motion_process import recover_from_ric
12
+ from data_loaders.humanml.utils.plot_script import plot_3d_motion
13
+ from data_loaders.tensors import collate
14
+ from utils.sampler_util import ClassifierFreeSampleModel
15
+ from utils import dist_util
16
+ from utils.model_util import create_model_and_diffusion, load_model_wo_clip
17
+ from visualize.motions2hik import motions2hik
18
+ from sample.generate import construct_template_variables
19
+
20
+ """
21
+ In case of matplot lib issues it may be needed to delete model/data_loaders/humanml/utils/plot_script.py" in lines 89~92 as
22
+ suggested in https://github.com/GuyTevet/motion-diffusion-model/issues/6
23
+ """
24
+
25
+
26
+ class ModelOutput(BaseModel):
27
+ json_file: Optional[Any]
28
+ animation: Optional[List[Path]]
29
+
30
+
31
+ def get_args():
32
+ args = Namespace()
33
+ args.fps = 20
34
+ args.model_path = './save/humanml_trans_enc_512/model000200000.pt'
35
+ args.guidance_param = 2.5
36
+ args.unconstrained = False
37
+ args.dataset = 'humanml'
38
+
39
+ args.cond_mask_prob = 1
40
+ args.emb_trans_dec = False
41
+ args.latent_dim = 512
42
+ args.layers = 8
43
+ args.arch = 'trans_enc'
44
+
45
+ args.noise_schedule = 'cosine'
46
+ args.sigma_small = True
47
+ args.lambda_vel = 0.0
48
+ args.lambda_rcxyz = 0.0
49
+ args.lambda_fc = 0.0
50
+ return args
51
+
52
+
53
+ class Predictor(BasePredictor):
54
+ def setup(self):
55
+ subprocess.run(["mkdir", "/root/.cache/clip"])
56
+ subprocess.run(["cp", "-r", "ViT-B-32.pt", "/root/.cache/clip"])
57
+
58
+ self.args = get_args()
59
+ self.num_frames = self.args.fps * 6
60
+ print('Loading dataset...')
61
+
62
+ # temporary data
63
+ self.data = get_dataset_loader(name=self.args.dataset,
64
+ batch_size=1,
65
+ num_frames=196,
66
+ split='test',
67
+ hml_mode='text_only')
68
+
69
+ self.data.fixed_length = float(self.num_frames)
70
+
71
+ print("Creating model and diffusion...")
72
+ self.model, self.diffusion = create_model_and_diffusion(self.args, self.data)
73
+
74
+ print(f"Loading checkpoints from...")
75
+ state_dict = torch.load(self.args.model_path, map_location='cpu')
76
+ load_model_wo_clip(self.model, state_dict)
77
+
78
+ if self.args.guidance_param != 1:
79
+ self.model = ClassifierFreeSampleModel(self.model) # wrapping model with the classifier-free sampler
80
+ self.model.to(dist_util.dev())
81
+ self.model.eval() # disable random masking
82
+
83
+ def predict(
84
+ self,
85
+ prompt: str = Input(default="the person walked forward and is picking up his toolbox."),
86
+ num_repetitions: int = Input(default=3, description="How many"),
87
+ output_format: str = Input(
88
+ description='Choose the format of the output, either an animation or a json file of the animation data.\
89
+ The json format is: {"thetas": [...], "root_translation": [...], "joint_map": [...]}, where "thetas" \
90
+ is an [nframes x njoints x 3] array of joint rotations in degrees, "root_translation" is an [nframes x 3] \
91
+ array of (X, Y, Z) positions of the root, and "joint_map" is a list mapping the SMPL joint index to the\
92
+ corresponding HumanIK joint name',
93
+ default="animation",
94
+ choices=["animation", "json_file"],
95
+ ),
96
+ ) -> ModelOutput:
97
+ args = self.args
98
+ args.num_repetitions = int(num_repetitions)
99
+
100
+ self.data = get_dataset_loader(name=self.args.dataset,
101
+ batch_size=args.num_repetitions,
102
+ num_frames=self.num_frames,
103
+ split='test',
104
+ hml_mode='text_only')
105
+
106
+ collate_args = [{'inp': torch.zeros(self.num_frames), 'tokens': None, 'lengths': self.num_frames, 'text': str(prompt)}] * args.num_repetitions  # one entry per repetition, so the conditioning batch matches the sample batch
107
+ _, model_kwargs = collate(collate_args)
108
+
109
+ # add CFG scale to batch
110
+ if args.guidance_param != 1:
111
+ model_kwargs['y']['scale'] = torch.ones(args.num_repetitions, device=dist_util.dev()) * args.guidance_param
112
+
113
+ sample_fn = self.diffusion.p_sample_loop
114
+ sample = sample_fn(
115
+ self.model,
116
+ (args.num_repetitions, self.model.njoints, self.model.nfeats, self.num_frames),
117
+ clip_denoised=False,
118
+ model_kwargs=model_kwargs,
119
+ skip_timesteps=0, # 0 is the default value - i.e. don't skip any step
120
+ init_image=None,
121
+ progress=True,
122
+ dump_steps=None,
123
+ noise=None,
124
+ const_noise=False,
125
+ )
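+ # `sample` has the shape requested above: [num_repetitions, njoints, nfeats, num_frames]
+ # (for HumanML3D that is [N, 263, 1, T] until the inverse transform below).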
126
+
127
+ # Recover XYZ *positions* from HumanML3D vector representation
128
+ if self.model.data_rep == 'hml_vec':
129
+ n_joints = 22 if sample.shape[1] == 263 else 21
130
+ sample = self.data.dataset.t2m_dataset.inv_transform(sample.cpu().permute(0, 2, 3, 1)).float()
131
+ sample = recover_from_ric(sample, n_joints)
132
+ sample = sample.view(-1, *sample.shape[2:]).permute(0, 2, 3, 1)
133
+
134
+ rot2xyz_pose_rep = 'xyz' if self.model.data_rep in ['xyz', 'hml_vec'] else self.model.data_rep
135
+ rot2xyz_mask = None if rot2xyz_pose_rep == 'xyz' else model_kwargs['y']['mask'].reshape(args.num_repetitions,
136
+ self.num_frames).bool()
137
+ sample = self.model.rot2xyz(x=sample, mask=rot2xyz_mask, pose_rep=rot2xyz_pose_rep, glob=True, translation=True,
138
+ jointstype='smpl', vertstrans=True, betas=None, beta=0, glob_rot=None,
139
+ get_rotations_back=False)
140
+
141
+ all_motions = sample.cpu().numpy()
142
+
143
+ if output_format == 'json_file':
144
+ data_dict = motions2hik(all_motions)
145
+ return ModelOutput(json_file=data_dict)
146
+
147
+ caption = str(prompt)
148
+
149
+ skeleton = paramUtil.t2m_kinematic_chain
150
+
151
+ sample_print_template, row_print_template, all_print_template, \
152
+ sample_file_template, row_file_template, all_file_template = construct_template_variables(
153
+ args.unconstrained)
154
+
155
+ rep_files = []
156
+ replicate_fnames = []
157
+ for rep_i in range(args.num_repetitions):
158
+ motion = all_motions[rep_i].transpose(2, 0, 1)[:self.num_frames]
159
+ save_file = sample_file_template.format(1, rep_i)
160
+ print(sample_print_template.format(caption, 1, rep_i, save_file))
161
+ plot_3d_motion(save_file, skeleton, motion, dataset=args.dataset, title=caption, fps=args.fps)
162
+ # Credit for visualization: https://github.com/EricGuo5513/text-to-motion
163
+ rep_files.append(save_file)
164
+
165
+ replicate_fnames.append(Path(save_file))
166
+
167
+ return ModelOutput(animation=replicate_fnames)
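
A minimal local-usage sketch for the predictor above (a hypothetical direct invocation; in practice Cog calls setup() and predict() itself, and the checkpoint under ./save/ must be in place):

    predictor = Predictor()
    predictor.setup()
    out = predictor.predict(prompt="a person walks forward and waves",
                            num_repetitions=1,
                            output_format="json_file")
    # out.json_file is the dict produced by motions2hik(all_motions)
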
motion_diffusion_model/utils/PYTORCH3D_LICENSE ADDED
@@ -0,0 +1,30 @@
1
+ BSD License
2
+
3
+ For PyTorch3D software
4
+
5
+ Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
6
+
7
+ Redistribution and use in source and binary forms, with or without modification,
8
+ are permitted provided that the following conditions are met:
9
+
10
+ * Redistributions of source code must retain the above copyright notice, this
11
+ list of conditions and the following disclaimer.
12
+
13
+ * Redistributions in binary form must reproduce the above copyright notice,
14
+ this list of conditions and the following disclaimer in the documentation
15
+ and/or other materials provided with the distribution.
16
+
17
+ * Neither the name Facebook nor the names of its contributors may be used to
18
+ endorse or promote products derived from this software without specific
19
+ prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
motion_diffusion_model/utils/config.py ADDED
@@ -0,0 +1,17 @@
1
+ import os
2
+
3
+ SMPL_DATA_PATH = "./body_models/smpl"
4
+
5
+ SMPL_KINTREE_PATH = os.path.join(SMPL_DATA_PATH, "kintree_table.pkl")
6
+ SMPL_MODEL_PATH = os.path.join(SMPL_DATA_PATH, "SMPL_NEUTRAL.pkl")
7
+ JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(SMPL_DATA_PATH, 'J_regressor_extra.npy')
8
+
9
+ ROT_CONVENTION_TO_ROT_NUMBER = {
10
+ 'legacy': 23,
11
+ 'no_hands': 21,
12
+ 'full_hands': 51,
13
+ 'mitten_hands': 33,
14
+ }
15
+
16
+ GENDERS = ['neutral', 'male', 'female']
17
+ NUM_BETAS = 10
motion_diffusion_model/utils/dist_util.py ADDED
@@ -0,0 +1,77 @@
1
+ """
2
+ Helpers for distributed training.
3
+ """
4
+
5
+ import socket
6
+
7
+ import torch as th
8
+ import torch.distributed as dist
9
+
10
+ # Change this to reflect your cluster layout.
11
+ # The GPU for a given rank is (rank % GPUS_PER_NODE).
12
+ GPUS_PER_NODE = 8
13
+
14
+ SETUP_RETRY_COUNT = 3
15
+
16
+ used_device = 0
17
+
18
+ def setup_dist(device=0):
19
+ """
20
+ Set up the device for this process. The original distributed process-group initialization is retained below, commented out.
21
+ """
22
+ global used_device
23
+ used_device = device
24
+ if dist.is_initialized():
25
+ return
26
+ # os.environ["CUDA_VISIBLE_DEVICES"] = str(device) # f"{MPI.COMM_WORLD.Get_rank() % GPUS_PER_NODE}"
27
+
28
+ # comm = MPI.COMM_WORLD
29
+ # backend = "gloo" if not th.cuda.is_available() else "nccl"
30
+
31
+ # if backend == "gloo":
32
+ # hostname = "localhost"
33
+ # else:
34
+ # hostname = socket.gethostbyname(socket.getfqdn())
35
+ # os.environ["MASTER_ADDR"] = comm.bcast(hostname, root=0)
36
+ # os.environ["RANK"] = str(comm.rank)
37
+ # os.environ["WORLD_SIZE"] = str(comm.size)
38
+
39
+ # port = comm.bcast(_find_free_port(), root=used_device)
40
+ # os.environ["MASTER_PORT"] = str(port)
41
+ # dist.init_process_group(backend=backend, init_method="env://")
42
+
43
+
44
+ def dev():
45
+ """
46
+ Get the device to use for torch.distributed.
47
+ """
48
+ global used_device
49
+ if th.cuda.is_available() and used_device >= 0:
50
+ return th.device(f"cuda:{used_device}")
51
+ return th.device("cpu")
52
+
53
+
54
+ def load_state_dict(path, **kwargs):
55
+ """
56
+ Load a PyTorch checkpoint file. The original MPI-aware fetching was removed; this is now a thin torch.load wrapper.
57
+ """
58
+ return th.load(path, **kwargs)
59
+
60
+
61
+ def sync_params(params):
62
+ """
63
+ Synchronize a sequence of Tensors across ranks from rank 0.
64
+ """
65
+ for p in params:
66
+ with th.no_grad():
67
+ dist.broadcast(p, 0)
68
+
69
+
70
+ def _find_free_port():
71
+ try:
72
+ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
73
+ s.bind(("", 0))
74
+ s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
75
+ return s.getsockname()[1]
76
+ finally:
77
+ s.close()
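
A short usage sketch for the helpers above (single-machine case; note that setup_dist currently only records the device id, since the process-group code is commented out):

    import torch
    from utils import dist_util
    dist_util.setup_dist(device=0)
    x = torch.zeros(3, device=dist_util.dev())  # cuda:0 when available, else cpu
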
motion_diffusion_model/utils/fixseed.py ADDED
@@ -0,0 +1,18 @@
1
+ import numpy as np
2
+ import torch
3
+ import random
4
+
5
+
6
+ def fixseed(seed):
7
+ torch.backends.cudnn.benchmark = False
8
+ random.seed(seed)
9
+ np.random.seed(seed)
10
+ torch.manual_seed(seed)
11
+
12
+
13
+ # SEED = 10
14
+ # EVALSEED = 0
15
+ # # Triggers a warning: not fully functional yet
16
+ # # torch.set_deterministic(True)
17
+ # torch.backends.cudnn.benchmark = False
18
+ # fixseed(SEED)
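
Note that torch.backends.cudnn.deterministic is not set here, so GPU kernels may still introduce slight run-to-run variation. A typical call is simply:

    fixseed(10)
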
motion_diffusion_model/utils/loss_util.py ADDED
@@ -0,0 +1,46 @@
1
+ from diffusion.nn import mean_flat, sum_flat
2
+ import torch
3
+ import numpy as np
4
+
5
+ def angle_l2(angle1, angle2):
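+ # Squared angular error with period-pi wrapping: the difference is mapped to
+ # [-pi/2, pi/2), so headings that differ by a multiple of pi compare as equal.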
6
+ a = angle1 - angle2
7
+ a = (a + (torch.pi/2)) % torch.pi - (torch.pi/2)
8
+ return a ** 2
9
+
10
+ def diff_l2(a, b):
11
+ return (a - b) ** 2
12
+
13
+ def masked_l2(a, b, mask, loss_fn=diff_l2, epsilon=1e-8, entries_norm=True):
14
+ # assuming a.shape == b.shape == bs, J, Jdim, seqlen
15
+ # assuming mask.shape == bs, 1, 1, seqlen
16
+ loss = loss_fn(a, b)
17
+ loss = sum_flat(loss * mask.float()) # sums the squared error over unmasked elements
18
+ n_entries = a.shape[1]
19
+ if len(a.shape) > 3:
20
+ n_entries *= a.shape[2]
21
+ non_zero_elements = sum_flat(mask)
22
+ if entries_norm:
23
+ # When the mask is per frame and does not account for the number of entries per frame, this normalization is needed;
24
+ # otherwise set entries_norm to False.
25
+ non_zero_elements *= n_entries
26
+ # print('mask', mask.shape)
27
+ # print('non_zero_elements', non_zero_elements)
28
+ # print('loss', loss)
29
+ mse_loss_val = loss / (non_zero_elements + epsilon) # Add epsilon to avoid division by zero
30
+ # print('mse_loss_val', mse_loss_val)
31
+ return mse_loss_val
32
+
33
+
34
+ def masked_goal_l2(pred_goal, ref_goal, cond, all_goal_joint_names):
35
+ all_goal_joint_names_w_traj = np.append(all_goal_joint_names, 'traj')
36
+ target_joint_idx = [[np.where(all_goal_joint_names_w_traj == j)[0][0] for j in sample_joints] for sample_joints in cond['target_joint_names']]
37
+ loc_mask = torch.zeros_like(pred_goal[:,:-1], dtype=torch.bool)
38
+ for sample_idx in range(loc_mask.shape[0]):
39
+ loc_mask[sample_idx, target_joint_idx[sample_idx]] = True
40
+ loc_mask[:, -1, 1] = False # vertical joint of 'traj' is always masked out
41
+ loc_loss = masked_l2(pred_goal[:,:-1], ref_goal[:,:-1], loc_mask, entries_norm=False)
42
+
43
+ heading_loss = masked_l2(pred_goal[:,-1:, :1], ref_goal[:,-1:, :1], cond['is_heading'].unsqueeze(1).unsqueeze(1), loss_fn=angle_l2, entries_norm=False)
44
+
45
+ loss = loc_loss + heading_loss
46
+ return loss
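
A toy sanity check for masked_l2 above (shapes follow the in-code comments; the expected value is a sketch, not a test from the repo):

    bs, J, Jdim, T = 2, 22, 3, 8
    a = torch.zeros(bs, J, Jdim, T)
    b = torch.ones(bs, J, Jdim, T)
    mask = torch.ones(bs, 1, 1, T)
    mask[..., 4:] = 0.                    # mark the last four frames invalid
    print(masked_l2(a, b, mask))          # ~tensor([1., 1.]): MSE over valid entries only
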
motion_diffusion_model/utils/misc.py ADDED
@@ -0,0 +1,74 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class WeightedSum(nn.Module):
6
+ def __init__(self, num_rows):
7
+ super(WeightedSum, self).__init__()
8
+ # Initialize learnable weights
9
+ self.weights = nn.Parameter(torch.randn(num_rows))
10
+
11
+ def forward(self, x):
12
+ # Ensure weights are normalized (optional)
13
+ normalized_weights = self.weights / self.weights.sum() # torch.softmax(self.weights, dim=0)
14
+ # Compute the weighted sum of the rows
15
+ weighted_sum = torch.matmul(normalized_weights, x)
16
+ return weighted_sum
17
+
18
+
19
+ def wrapped_getattr(self, name, default=None, wrapped_member_name='model'):
20
+ '''Should be called from wrappers of model classes such as ClassifierFreeSampleModel.'''
21
+
22
+ if isinstance(self, torch.nn.Module):
23
+ # for descendants of nn.Module, name may be in self.__dict__[_parameters/_buffers/_modules]
24
+ # so we activate nn.Module.__getattr__ first.
25
+ # Otherwise, we might encounter an infinite loop
26
+ try:
27
+ attr = torch.nn.Module.__getattr__(self, name)
28
+ except AttributeError:
29
+ wrapped_member = torch.nn.Module.__getattr__(self, wrapped_member_name)
30
+ attr = getattr(wrapped_member, name, default)
31
+ else:
32
+ # the easy case, where self is not derived from nn.Module
33
+ wrapped_member = getattr(self, wrapped_member_name)
34
+ attr = getattr(wrapped_member, name, default)
35
+ return attr
36
+
37
+
38
+ def to_numpy(tensor):
39
+ if torch.is_tensor(tensor):
40
+ return tensor.cpu().numpy()
41
+ elif type(tensor).__module__ != 'numpy':
42
+ raise ValueError("Cannot convert {} to numpy array".format(
43
+ type(tensor)))
44
+ return tensor
45
+
46
+
47
+ def to_torch(ndarray):
48
+ if type(ndarray).__module__ == 'numpy':
49
+ return torch.from_numpy(ndarray)
50
+ elif not torch.is_tensor(ndarray):
51
+ raise ValueError("Cannot convert {} to torch tensor".format(
52
+ type(ndarray)))
53
+ return ndarray
54
+
55
+
56
+ def cleanexit():
57
+ import sys
58
+ import os
59
+ try:
60
+ sys.exit(0)
61
+ except SystemExit:
62
+ os._exit(0)
63
+
64
+ def load_model_wo_clip(model, state_dict):
65
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
66
+ assert len(unexpected_keys) == 0
67
+ assert all([k.startswith('clip_model.') for k in missing_keys])
68
+
69
+ def freeze_joints(x, joints_to_freeze):
70
+ # Freezes selected joint *rotations* as they appear in the first frame
71
+ # x [bs, [root+n_joints], joint_dim(6), seqlen]
72
+ frozen = x.detach().clone()
73
+ frozen[:, joints_to_freeze, :, :] = frozen[:, joints_to_freeze, :, :1]
74
+ return frozen
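
A quick sketch of freeze_joints (layout as in the comment above):

    x = torch.randn(1, 25, 6, 4)             # [bs, root+joints, rot6d, seqlen]
    frozen = freeze_joints(x, [0])            # pin the root rotation to its first frame
    assert (frozen[:, 0] == x[:, 0, :, :1]).all()
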
motion_diffusion_model/utils/model_util.py ADDED
@@ -0,0 +1,132 @@
1
+ import torch
2
+ from model.mdm import MDM
3
+ from diffusion import gaussian_diffusion as gd
4
+ from diffusion.respace import SpacedDiffusion, space_timesteps
5
+ from utils.parser_util import get_cond_mode
6
+ from data_loaders.humanml_utils import HML_EE_JOINT_NAMES
7
+
8
+ def load_model_wo_clip(model, state_dict):
9
+ # assert (state_dict['sequence_pos_encoder.pe'][:model.sequence_pos_encoder.pe.shape[0]] == model.sequence_pos_encoder.pe).all() # TEST
10
+ # assert (state_dict['embed_timestep.sequence_pos_encoder.pe'][:model.embed_timestep.sequence_pos_encoder.pe.shape[0]] == model.embed_timestep.sequence_pos_encoder.pe).all() # TEST
11
+ del state_dict['sequence_pos_encoder.pe'] # no need to load it (fixed), and causes size mismatch for older models
12
+ del state_dict['embed_timestep.sequence_pos_encoder.pe'] # no need to load it (fixed), and causes size mismatch for older models
13
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
14
+ assert len(unexpected_keys) == 0
15
+ assert all([k.startswith('clip_model.') or 'sequence_pos_encoder' in k for k in missing_keys])
16
+
17
+
18
+ def create_model_and_diffusion(args, data):
19
+ model = MDM(**get_model_args(args, data))
20
+ diffusion = create_gaussian_diffusion(args)
21
+ return model, diffusion
22
+
23
+
24
+ def get_model_args(args, data):
25
+
26
+ # default args
27
+ clip_version = 'ViT-B/32'
28
+ action_emb = 'tensor'
29
+ cond_mode = get_cond_mode(args)
30
+ if hasattr(data.dataset, 'num_actions'):
31
+ num_actions = data.dataset.num_actions
32
+ else:
33
+ num_actions = 1
34
+
35
+ # SMPL defaults
36
+ data_rep = 'rot6d'
37
+ njoints = 25
38
+ nfeats = 6
39
+ all_goal_joint_names = []
40
+
41
+ if args.dataset == 'humanml':
42
+ data_rep = 'hml_vec'
43
+ njoints = 263
44
+ nfeats = 1
45
+ all_goal_joint_names = ['pelvis'] + HML_EE_JOINT_NAMES
46
+ elif args.dataset == 'kit':
47
+ data_rep = 'hml_vec'
48
+ njoints = 251
49
+ nfeats = 1
50
+
51
+ # Compatibility with old models
52
+ if not hasattr(args, 'pred_len'):
53
+ args.pred_len = 0
54
+ args.context_len = 0
55
+
56
+ emb_policy = args.__dict__.get('emb_policy', 'add')
57
+ multi_target_cond = args.__dict__.get('multi_target_cond', False)
58
+ multi_encoder_type = args.__dict__.get('multi_encoder_type', 'multi')
59
+ target_enc_layers = args.__dict__.get('target_enc_layers', 1)
60
+
61
+ return {'modeltype': '', 'njoints': njoints, 'nfeats': nfeats, 'num_actions': num_actions,
62
+ 'translation': True, 'pose_rep': 'rot6d', 'glob': True, 'glob_rot': True,
63
+ 'latent_dim': args.latent_dim, 'ff_size': 1024, 'num_layers': args.layers, 'num_heads': 4,
64
+ 'dropout': 0.1, 'activation': "gelu", 'data_rep': data_rep, 'cond_mode': cond_mode,
65
+ 'cond_mask_prob': args.cond_mask_prob, 'action_emb': action_emb, 'arch': args.arch,
66
+ 'emb_trans_dec': args.emb_trans_dec, 'clip_version': clip_version, 'dataset': args.dataset,
67
+ 'text_encoder_type': args.text_encoder_type,
68
+ 'pos_embed_max_len': args.pos_embed_max_len, 'mask_frames': args.mask_frames,
69
+ 'pred_len': args.pred_len, 'context_len': args.context_len, 'emb_policy': emb_policy,
70
+ 'all_goal_joint_names': all_goal_joint_names, 'multi_target_cond': multi_target_cond, 'multi_encoder_type': multi_encoder_type, 'target_enc_layers': target_enc_layers,
71
+ }
72
+
73
+
74
+
75
+ def create_gaussian_diffusion(args):
76
+ # default params
77
+ predict_xstart = True # we always predict x_start (a.k.a. x0), that's our deal!
78
+ steps = args.diffusion_steps
79
+ scale_beta = 1. # no scaling
80
+ timestep_respacing = '' # can be used for ddim sampling, we don't use it.
81
+ learn_sigma = False
82
+ rescale_timesteps = False
83
+
84
+ betas = gd.get_named_beta_schedule(args.noise_schedule, steps, scale_beta)
85
+ loss_type = gd.LossType.MSE
86
+
87
+ if not timestep_respacing:
88
+ timestep_respacing = [steps]
89
+
90
+ if hasattr(args, 'lambda_target_loc'):
91
+ lambda_target_loc = args.lambda_target_loc
92
+ else:
93
+ lambda_target_loc = 0.
94
+
95
+ return SpacedDiffusion(
96
+ use_timesteps=space_timesteps(steps, timestep_respacing),
97
+ betas=betas,
98
+ model_mean_type=(
99
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
100
+ ),
101
+ model_var_type=(
102
+ (
103
+ gd.ModelVarType.FIXED_LARGE
104
+ if not args.sigma_small
105
+ else gd.ModelVarType.FIXED_SMALL
106
+ )
107
+ if not learn_sigma
108
+ else gd.ModelVarType.LEARNED_RANGE
109
+ ),
110
+ loss_type=loss_type,
111
+ rescale_timesteps=rescale_timesteps,
112
+ lambda_vel=args.lambda_vel,
113
+ lambda_rcxyz=args.lambda_rcxyz,
114
+ lambda_fc=args.lambda_fc,
115
+ lambda_target_loc=lambda_target_loc,
116
+ )
117
+
118
+ def load_saved_model(model, model_path, use_avg: bool=False): # use_avg_model
119
+ state_dict = torch.load(model_path, map_location='cpu')
120
+ # Use average model when possible
121
+ if use_avg and 'model_avg' in state_dict.keys():
122
+ # if use_avg_model:
123
+ print('loading avg model')
124
+ state_dict = state_dict['model_avg']
125
+ else:
126
+ if 'model' in state_dict:
127
+ print('loading model without avg')
128
+ state_dict = state_dict['model']
129
+ else:
130
+ print('checkpoint has no avg model, loading as usual.')
131
+ load_model_wo_clip(model, state_dict)
132
+ return model
motion_diffusion_model/utils/parser_util.py ADDED
@@ -0,0 +1,320 @@
 
1
+ from argparse import ArgumentParser
2
+ import argparse
3
+ import os
4
+ import json
5
+
6
+
7
+ def parse_and_load_from_model(parser):
8
+ # args according to the loaded model
9
+ # do not try to specify them from cmd line since they will be overwritten
10
+ add_data_options(parser)
11
+ add_model_options(parser)
12
+ add_diffusion_options(parser)
13
+ args = parser.parse_args()
14
+ args_to_overwrite = []
15
+ for group_name in ['dataset', 'model', 'diffusion']:
16
+ args_to_overwrite += get_args_per_group_name(parser, args, group_name)
17
+
18
+ # load args from model
19
+ if args.model_path != '': # if not using external results file
20
+ args = load_args_from_model(args, args_to_overwrite)
21
+
22
+ if args.cond_mask_prob == 0:
23
+ args.guidance_param = 1
24
+
25
+ return apply_rules(args)
26
+
27
+ def load_args_from_model(args, args_to_overwrite):
28
+ model_path = get_model_path_from_args()
29
+ args_path = os.path.join(os.path.dirname(model_path), 'args.json')
30
+ assert os.path.exists(args_path), 'Arguments json file was not found!'
31
+ with open(args_path, 'r') as fr:
32
+ model_args = json.load(fr)
33
+
34
+ for a in args_to_overwrite:
35
+ if a in model_args.keys():
36
+ setattr(args, a, model_args[a])
37
+
38
+ elif 'cond_mode' in model_args: # backward compatibility
39
+ unconstrained = (model_args['cond_mode'] == 'no_cond')
40
+ setattr(args, 'unconstrained', unconstrained)
41
+
42
+ else:
43
+ print('Warning: was not able to load [{}], using default value [{}] instead.'.format(a, args.__dict__[a]))
44
+ return args
45
+
46
+ def apply_rules(args):
47
+ # For prefix completion
48
+ if args.pred_len == 0:
49
+ args.pred_len = args.context_len
50
+
51
+ # For target conditioning
52
+ if args.lambda_target_loc > 0.:
53
+ args.multi_target_cond = True
54
+ return args
55
+
56
+
57
+ def get_args_per_group_name(parser, args, group_name):
58
+ for group in parser._action_groups:
59
+ if group.title == group_name:
60
+ group_dict = {a.dest: getattr(args, a.dest, None) for a in group._group_actions}
61
+ return list(argparse.Namespace(**group_dict).__dict__.keys())
62
+ raise ValueError('group_name was not found.')
63
+
64
+ def get_model_path_from_args():
65
+ try:
66
+ dummy_parser = ArgumentParser()
67
+ dummy_parser.add_argument('--model_path')
68
+ dummy_args, _ = dummy_parser.parse_known_args()
69
+ return dummy_args.model_path
70
+ except:
71
+ raise ValueError('model_path argument must be specified.')
72
+
73
+
74
+ def add_base_options(parser):
75
+ group = parser.add_argument_group('base')
76
+ group.add_argument("--cuda", default=True, type=bool, help="Use cuda device, otherwise use CPU.")
77
+ group.add_argument("--device", default=0, type=int, help="Device id to use.")
78
+ group.add_argument("--seed", default=10, type=int, help="For fixing random seed.")
79
+ group.add_argument("--batch_size", default=64, type=int, help="Batch size during training.")
80
+ group.add_argument("--train_platform_type", default='NoPlatform', choices=['NoPlatform', 'ClearmlPlatform', 'TensorboardPlatform', 'WandBPlatform'], type=str,
81
+ help="Choose platform to log results. NoPlatform means no logging.")
82
+ group.add_argument("--external_mode", default=False, type=bool, help="For backward cometability, do not change or delete.")
83
+
84
+
85
+ def add_diffusion_options(parser):
86
+ group = parser.add_argument_group('diffusion')
87
+ group.add_argument("--noise_schedule", default='cosine', choices=['linear', 'cosine'], type=str,
88
+ help="Noise schedule type")
89
+ group.add_argument("--diffusion_steps", default=1000, type=int,
90
+ help="Number of diffusion steps (denoted T in the paper)")
91
+ group.add_argument("--sigma_small", default=True, type=bool, help="Use smaller sigma values.")
92
+
93
+
94
+ def add_model_options(parser):
95
+ group = parser.add_argument_group('model')
96
+ group.add_argument("--arch", default='trans_enc',
97
+ choices=['trans_enc', 'trans_dec', 'gru'], type=str,
98
+ help="Architecture types as reported in the paper.")
99
+ group.add_argument("--text_encoder_type", default='clip',
100
+ choices=['clip', 'bert'], type=str, help="Text encoder type.")
101
+ group.add_argument("--emb_trans_dec", action='store_true',
102
+ help="For trans_dec architecture only, if true, will inject condition as a class token"
103
+ " (in addition to cross-attention).")
104
+ group.add_argument("--layers", default=8, type=int,
105
+ help="Number of layers.")
106
+ group.add_argument("--latent_dim", default=512, type=int,
107
+ help="Transformer/GRU width.")
108
+ group.add_argument("--cond_mask_prob", default=.1, type=float,
109
+ help="The probability of masking the condition during training."
110
+ " For classifier-free guidance learning.")
111
+ group.add_argument("--mask_frames", action='store_true', help="If true, will fix Rotem's bug and mask invalid frames.")
112
+ group.add_argument("--lambda_rcxyz", default=0.0, type=float, help="Joint positions loss.")
113
+ group.add_argument("--lambda_vel", default=0.0, type=float, help="Joint velocity loss.")
114
+ group.add_argument("--lambda_fc", default=0.0, type=float, help="Foot contact loss.")
115
+ group.add_argument("--lambda_target_loc", default=0.0, type=float, help="For HumanML only, when . L2 with target location.")
116
+ group.add_argument("--unconstrained", action='store_true',
117
+ help="Model is trained unconditionally. That is, it is constrained by neither text nor action. "
118
+ "Currently tested on HumanAct12 only.")
119
+ group.add_argument("--pos_embed_max_len", default=5000, type=int,
120
+ help="Pose embedding max length.")
121
+ group.add_argument("--use_ema", action='store_true',
122
+ help="If True, will use EMA model averaging.")
123
+
124
+
125
+ group.add_argument("--multi_target_cond", action='store_true', help="If true, enable multi-target conditioning (aka Sigal's model).")
126
+ group.add_argument("--multi_encoder_type", default='single', choices=['single', 'multi', 'split'], type=str, help="Specifies the encoder type to be used for the multi joint condition.")
127
+ group.add_argument("--target_enc_layers", default=1, type=int, help="Num target encoder layers")
128
+
129
+
130
+ # Prefix completion model
131
+ group.add_argument("--context_len", default=0, type=int, help="If larger than 0, will do prefix completion.")
132
+ group.add_argument("--pred_len", default=0, type=int, help="If context_len larger than 0, will do prefix completion. If pred_len will not be specified - will use the same length as context_len")
133
+
134
+
135
+
136
+
137
+ def add_data_options(parser):
138
+ group = parser.add_argument_group('dataset')
139
+ group.add_argument("--dataset", default='humanml', choices=['humanml', 'kit', 'humanact12', 'uestc'], type=str,
140
+ help="Dataset name (choose from list).")
141
+ group.add_argument("--data_dir", default="", type=str,
142
+ help="If empty, will use defaults according to the specified dataset.")
143
+
144
+
145
+ def add_training_options(parser):
146
+ group = parser.add_argument_group('training')
147
+ group.add_argument("--save_dir", required=True, type=str,
148
+ help="Path to save checkpoints and results.")
149
+ group.add_argument("--overwrite", action='store_true',
150
+ help="If True, will enable to use an already existing save_dir.")
151
+ group.add_argument("--lr", default=1e-4, type=float, help="Learning rate.")
152
+ group.add_argument("--weight_decay", default=0.0, type=float, help="Optimizer weight decay.")
153
+ group.add_argument("--lr_anneal_steps", default=0, type=int, help="Number of learning rate anneal steps.")
154
+ group.add_argument("--eval_batch_size", default=32, type=int,
155
+ help="Batch size during evaluation loop. Do not change this unless you know what you are doing. "
156
+ "T2m precision calculation is based on fixed batch size 32.")
157
+ group.add_argument("--eval_split", default='test', choices=['val', 'test'], type=str,
158
+ help="Which split to evaluate on during training.")
159
+ group.add_argument("--eval_during_training", action='store_true',
160
+ help="If True, will run evaluation during training.")
161
+ group.add_argument("--eval_rep_times", default=3, type=int,
162
+ help="Number of repetitions for evaluation loop during training.")
163
+ group.add_argument("--eval_num_samples", default=1_000, type=int,
164
+ help="If -1, will use all samples in the specified split.")
165
+ group.add_argument("--log_interval", default=1_000, type=int,
166
+ help="Log losses each N steps")
167
+ group.add_argument("--save_interval", default=50_000, type=int,
168
+ help="Save checkpoints and run evaluation each N steps")
169
+ group.add_argument("--num_steps", default=600_000, type=int,
170
+ help="Training will stop after the specified number of steps.")
171
+ group.add_argument("--num_frames", default=60, type=int,
172
+ help="Limit for the maximal number of frames. In HumanML3D and KIT this field is ignored.")
173
+ group.add_argument("--resume_checkpoint", default="", type=str,
174
+ help="If not empty, will start from the specified checkpoint (path to model###.pt file).")
175
+
176
+ group.add_argument("--gen_during_training", action='store_true',
177
+ help="If True, will generate motions during training, on each save interval.")
178
+ group.add_argument("--gen_num_samples", default=3, type=int,
179
+ help="Number of samples to sample while generating")
180
+ group.add_argument("--gen_num_repetitions", default=2, type=int,
181
+ help="Number of repetitions, per sample (text prompt/action)")
182
+ group.add_argument("--gen_guidance_param", default=2.5, type=float,
183
+ help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")
184
+
185
+ group.add_argument("--avg_model_beta", default=0.9999, type=float, help="Average model beta (for EMA).")
186
+ group.add_argument("--adam_beta2", default=0.999, type=float, help="Adam beta2.")
187
+
188
+ group.add_argument("--target_joint_names", default='DIMP_FINAL', type=str, help="Force single joint configuration by specifing the joints (coma separated). If None - will use the random mode for all end effectors.")
189
+ group.add_argument("--autoregressive", action='store_true', help="If true, and we use a prefix model will generate motions in an autoregressive loop.")
190
+ group.add_argument("--autoregressive_include_prefix", action='store_true', help="If true, include the init prefix in the output, otherwise, will drop it.")
191
+ group.add_argument("--autoregressive_init", default='data', type=str, choices=['data', 'isaac'],
192
+ help="Sets the source of the init frames, either from the dataset or isaac init poses.")
193
+
194
+
195
+ def add_sampling_options(parser):
196
+ group = parser.add_argument_group('sampling')
197
+ group.add_argument("--model_path", required=True, type=str,
198
+ help="Path to model####.pt file to be sampled.")
199
+ group.add_argument("--output_dir", default='', type=str,
200
+ help="Path to results dir (auto created by the script). "
201
+ "If empty, will create dir in parallel to checkpoint.")
202
+ group.add_argument("--num_samples", default=6, type=int,
203
+ help="Maximal number of prompts to sample, "
204
+ "if loading dataset from file, this field will be ignored.")
205
+ group.add_argument("--num_repetitions", default=3, type=int,
206
+ help="Number of repetitions, per sample (text prompt/action)")
207
+ group.add_argument("--guidance_param", default=2.5, type=float,
208
+ help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")
209
+
210
+ group.add_argument("--autoregressive", action='store_true', help="If true, and we use a prefix model will generate motions in an autoregressive loop.")
211
+ group.add_argument("--autoregressive_include_prefix", action='store_true', help="If true, include the init prefix in the output, otherwise, will drop it.")
212
+ group.add_argument("--autoregressive_init", default='data', type=str, choices=['data', 'isaac'],
213
+ help="Sets the source of the init frames, either from the dataset or isaac init poses.")
214
+
215
+ def add_generate_options(parser):
216
+ group = parser.add_argument_group('generate')
217
+ group.add_argument("--motion_length", default=6.0, type=float,
218
+ help="The length of the sampled motion [in seconds]. "
219
+ "Maximum is 9.8 for HumanML3D (text-to-motion), and 2.0 for HumanAct12 (action-to-motion)")
220
+ group.add_argument("--input_text", default='', type=str,
221
+ help="Path to a text file lists text prompts to be synthesized. If empty, will take text prompts from dataset.")
222
+ group.add_argument("--dynamic_text_path", default='', type=str,
223
+ help="For the autoregressive mode only! Path to a text file lists text prompts to be synthesized. If empty, will take text prompts from dataset.")
224
+ group.add_argument("--action_file", default='', type=str,
225
+ help="Path to a text file that lists names of actions to be synthesized. Names must be a subset of dataset/uestc/info/action_classes.txt if sampling from uestc, "
226
+ "or a subset of [warm_up,walk,run,jump,drink,lift_dumbbell,sit,eat,turn steering wheel,phone,boxing,throw] if sampling from humanact12. "
227
+ "If no file is specified, will take action names from dataset.")
228
+ group.add_argument("--text_prompt", default='', type=str,
229
+ help="A text prompt to be generated. If empty, will take text prompts from dataset.")
230
+ group.add_argument("--action_name", default='', type=str,
231
+ help="An action name to be generated. If empty, will take text prompts from dataset.")
232
+ group.add_argument("--target_joint_names", default='DIMP_FINAL', type=str, help="Force single joint configuration by specifing the joints (coma separated). If None - will use the random mode for all end effectors.")
233
+
234
+
235
+ def add_edit_options(parser):
236
+ group = parser.add_argument_group('edit')
237
+ group.add_argument("--edit_mode", default='in_between', choices=['in_between', 'upper_body'], type=str,
238
+ help="Defines which parts of the input motion will be edited.\n"
239
+ "(1) in_between - suffix and prefix motion taken from input motion, "
240
+ "middle motion is generated.\n"
241
+ "(2) upper_body - lower body joints taken from input motion, "
242
+ "upper body is generated.")
243
+ group.add_argument("--text_condition", default='', type=str,
244
+ help="Editing will be conditioned on this text prompt. "
245
+ "If empty, will perform unconditioned editing.")
246
+ group.add_argument("--prefix_end", default=0.25, type=float,
247
+ help="For in_between editing - Defines the end of input prefix (ratio from all frames).")
248
+ group.add_argument("--suffix_start", default=0.75, type=float,
249
+ help="For in_between editing - Defines the start of input suffix (ratio from all frames).")
250
+
251
+
252
+ def add_evaluation_options(parser):
253
+ group = parser.add_argument_group('eval')
254
+ group.add_argument("--model_path", required=True, type=str,
255
+ help="Path to model####.pt file to be sampled.")
256
+ group.add_argument("--eval_mode", default='wo_mm', choices=['wo_mm', 'mm_short', 'debug', 'full'], type=str,
257
+ help="wo_mm (t2m only) - 20 repetitions without multi-modality metric; "
258
+ "mm_short (t2m only) - 5 repetitions with multi-modality metric; "
259
+ "debug - short run, less accurate results."
260
+ "full (a2m only) - 20 repetitions.")
261
+ group.add_argument("--autoregressive", action='store_true', help="If true, and we use a prefix model will generate motions in an autoregressive loop.")
262
+ group.add_argument("--autoregressive_include_prefix", action='store_true', help="If true, include the init prefix in the output, otherwise, will drop it.")
263
+ group.add_argument("--autoregressive_init", default='data', type=str, choices=['data', 'isaac'],
264
+ help="Sets the source of the init frames, either from the dataset or isaac init poses.")
265
+ group.add_argument("--guidance_param", default=2.5, type=float,
266
+ help="For classifier-free sampling - specifies the s parameter, as defined in the paper.")
267
+
268
+
269
+ def get_cond_mode(args):
270
+ if args.unconstrained:
271
+ cond_mode = 'no_cond'
272
+ elif args.dataset in ['kit', 'humanml']:
273
+ cond_mode = 'text'
274
+ else:
275
+ cond_mode = 'action'
276
+ return cond_mode
277
+
278
+
279
+ def train_args():
280
+ parser = ArgumentParser()
281
+ add_base_options(parser)
282
+ add_data_options(parser)
283
+ add_model_options(parser)
284
+ add_diffusion_options(parser)
285
+ add_training_options(parser)
286
+ return apply_rules(parser.parse_args())
287
+
288
+
289
+ def generate_args():
290
+ parser = ArgumentParser()
291
+ # args specified by the user: (all other will be loaded from the model)
292
+ add_base_options(parser)
293
+ add_sampling_options(parser)
294
+ add_generate_options(parser)
295
+ args = parse_and_load_from_model(parser)
296
+ cond_mode = get_cond_mode(args)
297
+
298
+ if (args.input_text or args.text_prompt) and cond_mode != 'text':
299
+ raise Exception('Arguments input_text and text_prompt should not be used for an action condition. Please use action_file or action_name.')
300
+ elif (args.action_file or args.action_name) and cond_mode != 'action':
301
+ raise Exception('Arguments action_file and action_name should not be used for a text condition. Please use input_text or text_prompt.')
302
+
303
+ return args
304
+
305
+
306
+ def edit_args():
307
+ parser = ArgumentParser()
308
+ # args specified by the user: (all other will be loaded from the model)
309
+ add_base_options(parser)
310
+ add_sampling_options(parser)
311
+ add_edit_options(parser)
312
+ return parse_and_load_from_model(parser)
313
+
314
+
315
+ def evaluation_parser():
316
+ parser = ArgumentParser()
317
+ # args specified by the user: (all other will be loaded from the model)
318
+ add_base_options(parser)
319
+ add_evaluation_options(parser)
320
+ return parse_and_load_from_model(parser)
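
A usage sketch for these parsers (flags as defined above; the exact script invocation is an assumption based on the repo layout):

    # python -m sample.generate --model_path ./save/humanml_trans_enc_512/model000200000.pt \
    #     --text_prompt "a person jumps" --num_repetitions 3
    args = generate_args()  # remaining dataset/model/diffusion args load from args.json next to the checkpoint
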
motion_diffusion_model/utils/rotation_conversions.py ADDED
@@ -0,0 +1,552 @@
1
+ # This code is based on https://github.com/Mathux/ACTOR.git
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
3
+ # Check PYTORCH3D_LICENSE before use
4
+
5
+ import functools
6
+ from typing import Optional
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+
12
+ """
13
+ The transformation matrices returned from the functions in this file assume
14
+ the points on which the transformation will be applied are column vectors.
15
+ i.e. the R matrix is structured as
16
+
17
+ R = [
18
+ [Rxx, Rxy, Rxz],
19
+ [Ryx, Ryy, Ryz],
20
+ [Rzx, Rzy, Rzz],
21
+ ] # (3, 3)
22
+
23
+ This matrix can be applied to column vectors by post multiplication
24
+ by the points e.g.
25
+
26
+ points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
27
+ transformed_points = R * points
28
+
29
+ To apply the same matrix to points which are row vectors, the R matrix
30
+ can be transposed and pre multiplied by the points:
31
+
32
+ e.g.
33
+ points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
34
+ transformed_points = points * R.transpose(1, 0)
35
+ """
36
+
37
+
38
+ def quaternion_to_matrix(quaternions):
39
+ """
40
+ Convert rotations given as quaternions to rotation matrices.
41
+
42
+ Args:
43
+ quaternions: quaternions with real part first,
44
+ as tensor of shape (..., 4).
45
+
46
+ Returns:
47
+ Rotation matrices as tensor of shape (..., 3, 3).
48
+ """
49
+ r, i, j, k = torch.unbind(quaternions, -1)
50
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
51
+
52
+ o = torch.stack(
53
+ (
54
+ 1 - two_s * (j * j + k * k),
55
+ two_s * (i * j - k * r),
56
+ two_s * (i * k + j * r),
57
+ two_s * (i * j + k * r),
58
+ 1 - two_s * (i * i + k * k),
59
+ two_s * (j * k - i * r),
60
+ two_s * (i * k - j * r),
61
+ two_s * (j * k + i * r),
62
+ 1 - two_s * (i * i + j * j),
63
+ ),
64
+ -1,
65
+ )
66
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
67
+
68
+
69
+ def _copysign(a, b):
70
+ """
71
+ Return a tensor where each element has the absolute value taken from the
72
+ corresponding element of a, with sign taken from the corresponding
73
+ element of b. This is like the standard copysign floating-point operation,
74
+ but is not careful about negative 0 and NaN.
75
+
76
+ Args:
77
+ a: source tensor.
78
+ b: tensor whose signs will be used, of the same shape as a.
79
+
80
+ Returns:
81
+ Tensor of the same shape as a with the signs of b.
82
+ """
83
+ signs_differ = (a < 0) != (b < 0)
84
+ return torch.where(signs_differ, -a, a)
85
+
86
+
87
+ def _sqrt_positive_part(x):
88
+ """
89
+ Returns torch.sqrt(torch.max(0, x))
90
+ but with a zero subgradient where x is 0.
91
+ """
92
+ ret = torch.zeros_like(x)
93
+ positive_mask = x > 0
94
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
95
+ return ret
96
+
97
+
98
+ def matrix_to_quaternion(matrix):
99
+ """
100
+ Convert rotations given as rotation matrices to quaternions.
101
+
102
+ Args:
103
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
104
+
105
+ Returns:
106
+ quaternions with real part first, as tensor of shape (..., 4).
107
+ """
108
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
109
+ raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
110
+ m00 = matrix[..., 0, 0]
111
+ m11 = matrix[..., 1, 1]
112
+ m22 = matrix[..., 2, 2]
113
+ o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
114
+ x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
115
+ y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
116
+ z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
117
+ o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
118
+ o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
119
+ o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
120
+ return torch.stack((o0, o1, o2, o3), -1)
121
+
122
+
123
+ def _axis_angle_rotation(axis: str, angle):
124
+ """
125
+ Return the rotation matrices for one of the rotations about an axis
126
+ of which Euler angles describe, for each value of the angle given.
127
+
128
+ Args:
129
+ axis: Axis label "X" or "Y or "Z".
130
+ angle: any shape tensor of Euler angles in radians
131
+
132
+ Returns:
133
+ Rotation matrices as tensor of shape (..., 3, 3).
134
+ """
135
+
136
+ cos = torch.cos(angle)
137
+ sin = torch.sin(angle)
138
+ one = torch.ones_like(angle)
139
+ zero = torch.zeros_like(angle)
140
+
141
+ if axis == "X":
142
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
143
+ if axis == "Y":
144
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
145
+ if axis == "Z":
146
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
147
+
148
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
149
+
150
+
151
+ def euler_angles_to_matrix(euler_angles, convention: str):
152
+ """
153
+ Convert rotations given as Euler angles in radians to rotation matrices.
154
+
155
+ Args:
156
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
157
+ convention: Convention string of three uppercase letters from
158
+ {"X", "Y", and "Z"}.
159
+
160
+ Returns:
161
+ Rotation matrices as tensor of shape (..., 3, 3).
162
+ """
163
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
164
+ raise ValueError("Invalid input euler angles.")
165
+ if len(convention) != 3:
166
+ raise ValueError("Convention must have 3 letters.")
167
+ if convention[1] in (convention[0], convention[2]):
168
+ raise ValueError(f"Invalid convention {convention}.")
169
+ for letter in convention:
170
+ if letter not in ("X", "Y", "Z"):
171
+ raise ValueError(f"Invalid letter {letter} in convention string.")
172
+ matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
173
+ return functools.reduce(torch.matmul, matrices)
174
+
175
+
176
+ def _angle_from_tan(
177
+ axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
178
+ ):
179
+ """
180
+ Extract the first or third Euler angle from the two members of
181
+ the matrix which are positive constant times its sine and cosine.
182
+
183
+ Args:
184
+ axis: Axis label "X" or "Y or "Z" for the angle we are finding.
185
+ other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
186
+ convention.
187
+ data: Rotation matrices as tensor of shape (..., 3, 3).
188
+ horizontal: Whether we are looking for the angle for the third axis,
189
+ which means the relevant entries are in the same row of the
190
+ rotation matrix. If not, they are in the same column.
191
+ tait_bryan: Whether the first and third axes in the convention differ.
192
+
193
+ Returns:
194
+ Euler Angles in radians for each matrix in dataset as a tensor
195
+ of shape (...).
196
+ """
197
+
198
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
199
+ if horizontal:
200
+ i2, i1 = i1, i2
201
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
202
+ if horizontal == even:
203
+ return torch.atan2(data[..., i1], data[..., i2])
204
+ if tait_bryan:
205
+ return torch.atan2(-data[..., i2], data[..., i1])
206
+ return torch.atan2(data[..., i2], -data[..., i1])
207
+
208
+
209
+ def _index_from_letter(letter: str):
210
+ if letter == "X":
211
+ return 0
212
+ if letter == "Y":
213
+ return 1
214
+ if letter == "Z":
215
+ return 2
216
+
217
+
218
+ def matrix_to_euler_angles(matrix, convention: str):
219
+ """
220
+ Convert rotations given as rotation matrices to Euler angles in radians.
221
+
222
+ Args:
223
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
224
+ convention: Convention string of three uppercase letters.
225
+
226
+ Returns:
227
+ Euler angles in radians as tensor of shape (..., 3).
228
+ """
229
+ if len(convention) != 3:
230
+ raise ValueError("Convention must have 3 letters.")
231
+ if convention[1] in (convention[0], convention[2]):
232
+ raise ValueError(f"Invalid convention {convention}.")
233
+ for letter in convention:
234
+ if letter not in ("X", "Y", "Z"):
235
+ raise ValueError(f"Invalid letter {letter} in convention string.")
236
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
237
+ raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
238
+ i0 = _index_from_letter(convention[0])
239
+ i2 = _index_from_letter(convention[2])
240
+ tait_bryan = i0 != i2
241
+ if tait_bryan:
242
+ central_angle = torch.asin(
243
+ matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
244
+ )
245
+ else:
246
+ central_angle = torch.acos(matrix[..., i0, i0])
247
+
248
+ o = (
249
+ _angle_from_tan(
250
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
251
+ ),
252
+ central_angle,
253
+ _angle_from_tan(
254
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
255
+ ),
256
+ )
257
+ return torch.stack(o, -1)
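+ # A round-trip sketch for the two Euler conversions above; for convention "XYZ"
+ # the composition is R = X(a0) @ Y(a1) @ Z(a2):
+ #   angles = torch.tensor([0.1, 0.2, 0.3])
+ #   R = euler_angles_to_matrix(angles, "XYZ")
+ #   back = matrix_to_euler_angles(R, "XYZ")   # recovers `angles` up to numerical noise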
258
+
259
+
260
+ def random_quaternions(
261
+ n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
262
+ ):
263
+ """
264
+ Generate random quaternions representing rotations,
265
+ i.e. versors with nonnegative real part.
266
+
267
+ Args:
268
+ n: Number of quaternions in a batch to return.
269
+ dtype: Type to return.
270
+ device: Desired device of returned tensor. Default:
271
+ uses the current device for the default tensor type.
272
+ requires_grad: Whether the resulting tensor should have the gradient
273
+ flag set.
274
+
275
+ Returns:
276
+ Quaternions as tensor of shape (N, 4).
277
+ """
278
+ o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
279
+ s = (o * o).sum(1)
280
+ o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
281
+ return o
282
+
283
+
284
+ def random_rotations(
285
+ n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
286
+ ):
287
+ """
288
+ Generate random rotations as 3x3 rotation matrices.
289
+
290
+ Args:
291
+ n: Number of rotation matrices in a batch to return.
292
+ dtype: Type to return.
293
+ device: Device of returned tensor. Default: if None,
294
+ uses the current device for the default tensor type.
295
+ requires_grad: Whether the resulting tensor should have the gradient
296
+ flag set.
297
+
298
+ Returns:
299
+ Rotation matrices as tensor of shape (n, 3, 3).
300
+ """
301
+ quaternions = random_quaternions(
302
+ n, dtype=dtype, device=device, requires_grad=requires_grad
303
+ )
304
+ return quaternion_to_matrix(quaternions)
305
+
306
+
307
+ def random_rotation(
308
+ dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
309
+ ):
310
+ """
311
+ Generate a single random 3x3 rotation matrix.
312
+
313
+ Args:
314
+ dtype: Type to return
315
+ device: Device of returned tensor. Default: if None,
316
+ uses the current device for the default tensor type
317
+ requires_grad: Whether the resulting tensor should have the gradient
318
+ flag set
319
+
320
+ Returns:
321
+ Rotation matrix as tensor of shape (3, 3).
322
+ """
323
+ return random_rotations(1, dtype, device, requires_grad)[0]
324
+
325
+
326
+ def standardize_quaternion(quaternions):
327
+ """
328
+ Convert a unit quaternion to a standard form: one in which the real
329
+ part is nonnegative.
330
+
331
+ Args:
332
+ quaternions: Quaternions with real part first,
333
+ as tensor of shape (..., 4).
334
+
335
+ Returns:
336
+ Standardized quaternions as tensor of shape (..., 4).
337
+ """
338
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
339
+
340
+
341
+ def quaternion_raw_multiply(a, b):
342
+ """
343
+ Multiply two quaternions.
344
+ Usual torch rules for broadcasting apply.
345
+
346
+ Args:
347
+ a: Quaternions as tensor of shape (..., 4), real part first.
348
+ b: Quaternions as tensor of shape (..., 4), real part first.
349
+
350
+ Returns:
351
+ The product of a and b, a tensor of quaternions shape (..., 4).
352
+ """
353
+ aw, ax, ay, az = torch.unbind(a, -1)
354
+ bw, bx, by, bz = torch.unbind(b, -1)
355
+ ow = aw * bw - ax * bx - ay * by - az * bz
356
+ ox = aw * bx + ax * bw + ay * bz - az * by
357
+ oy = aw * by - ax * bz + ay * bw + az * bx
358
+ oz = aw * bz + ax * by - ay * bx + az * bw
359
+ return torch.stack((ow, ox, oy, oz), -1)
360
+
361
+
362
+ def quaternion_multiply(a, b):
363
+ """
364
+ Multiply two quaternions representing rotations, returning the quaternion
365
+ representing their composition, i.e. the versor with nonnegative real part.
366
+ Usual torch rules for broadcasting apply.
367
+
368
+ Args:
369
+ a: Quaternions as tensor of shape (..., 4), real part first.
370
+ b: Quaternions as tensor of shape (..., 4), real part first.
371
+
372
+ Returns:
373
+ The product of a and b, a tensor of quaternions of shape (..., 4).
374
+ """
375
+ ab = quaternion_raw_multiply(a, b)
376
+ return standardize_quaternion(ab)
377
+
378
+
379
+ def quaternion_invert(quaternion):
380
+ """
381
+ Given a quaternion representing rotation, get the quaternion representing
382
+ its inverse.
383
+
384
+ Args:
385
+ quaternion: Quaternions as tensor of shape (..., 4), with real part
386
+ first, which must be versors (unit quaternions).
387
+
388
+ Returns:
389
+ The inverse, a tensor of quaternions of shape (..., 4).
390
+ """
391
+
392
+ return quaternion * quaternion.new_tensor([1, -1, -1, -1])
393
+
394
+
395
+ def quaternion_apply(quaternion, point):
396
+ """
397
+ Apply the rotation given by a quaternion to a 3D point.
398
+ Usual torch rules for broadcasting apply.
399
+
400
+ Args:
401
+ quaternion: Tensor of quaternions, real part first, of shape (..., 4).
402
+ point: Tensor of 3D points of shape (..., 3).
403
+
404
+ Returns:
405
+ Tensor of rotated points of shape (..., 3).
406
+ """
407
+ if point.size(-1) != 3:
408
+ raise ValueError(f"Points are not in 3D, f{point.shape}.")
409
+ real_parts = point.new_zeros(point.shape[:-1] + (1,))
410
+ point_as_quaternion = torch.cat((real_parts, point), -1)
411
+ out = quaternion_raw_multiply(
412
+ quaternion_raw_multiply(quaternion, point_as_quaternion),
413
+ quaternion_invert(quaternion),
414
+ )
415
+ return out[..., 1:]
416
+
417
+
418
+ def axis_angle_to_matrix(axis_angle):
419
+ """
420
+ Convert rotations given as axis/angle to rotation matrices.
421
+
422
+ Args:
423
+ axis_angle: Rotations given as a vector in axis angle form,
424
+ as a tensor of shape (..., 3), where the magnitude is
425
+ the angle turned anticlockwise in radians around the
426
+ vector's direction.
427
+
428
+ Returns:
429
+ Rotation matrices as tensor of shape (..., 3, 3).
430
+ """
431
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
432
+
433
+
434
+ def matrix_to_axis_angle(matrix):
435
+ """
436
+ Convert rotations given as rotation matrices to axis/angle.
437
+
438
+ Args:
439
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
440
+
441
+ Returns:
442
+ Rotations given as a vector in axis angle form, as a tensor
443
+ of shape (..., 3), where the magnitude is the angle
444
+ turned anticlockwise in radians around the vector's
445
+ direction.
446
+ """
447
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
448
+
449
+
450
+ def axis_angle_to_quaternion(axis_angle):
451
+ """
452
+ Convert rotations given as axis/angle to quaternions.
453
+
454
+ Args:
455
+ axis_angle: Rotations given as a vector in axis angle form,
456
+ as a tensor of shape (..., 3), where the magnitude is
457
+ the angle turned anticlockwise in radians around the
458
+ vector's direction.
459
+
460
+ Returns:
461
+ quaternions with real part first, as tensor of shape (..., 4).
462
+ """
463
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
464
+ half_angles = 0.5 * angles
465
+ eps = 1e-6
466
+ small_angles = angles.abs() < eps
467
+ sin_half_angles_over_angles = torch.empty_like(angles)
468
+ sin_half_angles_over_angles[~small_angles] = (
469
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
470
+ )
471
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
472
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
473
+ sin_half_angles_over_angles[small_angles] = (
474
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
475
+ )
476
+ quaternions = torch.cat(
477
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
478
+ )
479
+ return quaternions
480
+
481
+
482
+ def quaternion_to_axis_angle(quaternions):
483
+ """
484
+ Convert rotations given as quaternions to axis/angle.
485
+
486
+ Args:
487
+ quaternions: quaternions with real part first,
488
+ as tensor of shape (..., 4).
489
+
490
+ Returns:
491
+ Rotations given as a vector in axis angle form, as a tensor
492
+ of shape (..., 3), where the magnitude is the angle
493
+ turned anticlockwise in radians around the vector's
494
+ direction.
495
+ """
496
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
497
+ half_angles = torch.atan2(norms, quaternions[..., :1])
498
+ angles = 2 * half_angles
499
+ eps = 1e-6
500
+ small_angles = angles.abs() < eps
501
+ sin_half_angles_over_angles = torch.empty_like(angles)
502
+ sin_half_angles_over_angles[~small_angles] = (
503
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
504
+ )
505
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
506
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
507
+ sin_half_angles_over_angles[small_angles] = (
508
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
509
+ )
510
+ return quaternions[..., 1:] / sin_half_angles_over_angles
511
+
512
+
513
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
514
+ """
515
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
516
+ using Gram--Schmidt orthogonalisation per Section B of [1].
517
+ Args:
518
+ d6: 6D rotation representation, of size (*, 6)
519
+
520
+ Returns:
521
+ batch of rotation matrices of size (*, 3, 3)
522
+
523
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
524
+ On the Continuity of Rotation Representations in Neural Networks.
525
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
526
+ Retrieved from http://arxiv.org/abs/1812.07035
527
+ """
528
+
529
+ a1, a2 = d6[..., :3], d6[..., 3:]
530
+ b1 = F.normalize(a1, dim=-1)
531
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
532
+ b2 = F.normalize(b2, dim=-1)
533
+ b3 = torch.cross(b1, b2, dim=-1)
534
+ return torch.stack((b1, b2, b3), dim=-2)
535
+
536
+
537
+ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
538
+ """
539
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
540
+ by dropping the last row. Note that 6D representation is not unique.
541
+ Args:
542
+ matrix: batch of rotation matrices of size (*, 3, 3)
543
+
544
+ Returns:
545
+ 6D rotation representation, of size (*, 6)
546
+
547
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
548
+ On the Continuity of Rotation Representations in Neural Networks.
549
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
550
+ Retrieved from http://arxiv.org/abs/1812.07035
551
+ """
552
+ return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
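
A quick sanity sketch for the 6D representation above (the round-trip is exact because a rotation matrix's rows are already orthonormal):

    R = random_rotations(4)
    d6 = matrix_to_rotation_6d(R)
    assert torch.allclose(rotation_6d_to_matrix(d6), R, atol=1e-5)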