atnikos commited on
Commit
78c7556
·
1 Parent(s): 0edf584

omegaconf req

Browse files
Files changed (1) hide show
  1. feature_extractor.py +76 -0
feature_extractor.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # - "body_transl_delta_pelv"
2
+ # - "body_orient_xy"
3
+ # - "z_orient_delta"
4
+ # - "body_pose"
5
+ # - "body_joints_local_wo_z_rot"
6
+ from transform3d import transform_body_pose, change_for, remove_z_rot, get_z_rot, rot_diff
7
+ from einops import rearrange
8
+ import torch
9
+
10
+ def to_tensor(array):
11
+ if torch.is_tensor(array):
12
+ return array
13
+ else:
14
+ return torch.tensor(array)
15
+
16
+ def _get_body_transl_delta_pelv(data):
17
+ """
18
+ get body pelvis tranlation delta relative to pelvis coord.frame
19
+ v_i = t_i - t_{i-1} relative to R_{i-1}
20
+ """
21
+ trans = to_tensor(data['trans'])
22
+ trans_vel = trans - trans.roll(1, 0) # shift one right and subtract
23
+ pelvis_orient = transform_body_pose(to_tensor(data['rots'][..., :3]), "aa->rot")
24
+ trans_vel_pelv = change_for(trans_vel, pelvis_orient.roll(1, 0))
25
+ trans_vel_pelv[0] = 0 # zero out velocity of first frame
26
+ return trans_vel_pelv
27
+
28
+ def _get_body_orient_xy(data):
29
+ """get body global orientation"""
30
+ # default is axis-angle representation
31
+ pelvis_orient = to_tensor(data['rots'][..., :3])
32
+ # if rot_repr == "6d":
33
+ # axis-angle to rotation matrix & drop last row
34
+ pelvis_orient_xy = remove_z_rot(pelvis_orient, in_format="aa")
35
+ return pelvis_orient_xy
36
+
37
+ def _get_body_pose(data):
38
+ """get body pose"""
39
+ # default is axis-angle representation: Frames x (Jx3) (J=21)
40
+ pose = to_tensor(data['rots'][..., 3:3 + 21*3]) # drop pelvis orientation
41
+ pose = transform_body_pose(pose, f"aa->6d")
42
+ return pose
43
+
44
+ def _get_body_joints_local_wo_z_rot(data):
45
+ """get body joint coordinates relative to the pelvis"""
46
+ joints = to_tensor(data['joint_positions'][:, :22, :])
47
+ pelvis_transl = to_tensor(joints[:, 0, :])
48
+ joints_glob = to_tensor(joints[:, :22, :])
49
+ pelvis_orient = to_tensor(data['rots'][..., :3])
50
+
51
+ pelvis_orient_z = get_z_rot(pelvis_orient, in_format="aa")
52
+ # pelvis_orient_z = transform_body_pose(pelvis_orient_z, "aa->rot").float()
53
+ # relative_joints = R.T @ (p_global - pelvis_translation)
54
+ rel_joints = torch.einsum('fdi,fjd->fji',
55
+ pelvis_orient_z,
56
+ joints_glob - pelvis_transl[:, None, :])
57
+
58
+ return rearrange(rel_joints, '... j c -> ... (j c)')
59
+
60
+ def _get_z_orient_delta(data):
61
+ """get global body orientation delta"""
62
+ # default is axis-angle representation
63
+ pelvis_orient = to_tensor(data['rots'][..., :3])
64
+ pelvis_orient_z = get_z_rot(pelvis_orient, in_format="aa")
65
+ pelvis_orient_z = transform_body_pose(pelvis_orient_z, "rot->aa")
66
+ z_orient_delta = rot_diff(pelvis_orient_z, in_format="aa",
67
+ out_format='6d')
68
+ return z_orient_delta
69
+
70
+ FEAT_GET_METHODS = {
71
+ "body_transl_delta_pelv": _get_body_transl_delta_pelv,
72
+ "body_orient_xy": _get_body_orient_xy,
73
+ "z_orient_delta": _get_z_orient_delta,
74
+ "body_pose": _get_body_pose,
75
+ "body_joints_local_wo_z_rot": _get_body_joints_local_wo_z_rot,
76
+ }