File size: 10,350 Bytes
f7400bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
import smplx
import torch
import numpy as np
from . import rotation_conversions as rc
import os
import wget 

download_path = "./datasets/hub"
smplx_model_dir = os.path.join(download_path, "smplx_models", "smplx")
if not os.path.exists(smplx_model_dir):
    smplx_model_file_path = os.path.join(smplx_model_dir, "SMPLX_NEUTRAL_2020.npz")
    os.makedirs(smplx_model_dir, exist_ok=True)
    if not os.path.exists(smplx_model_file_path):
        print(f"Downloading {smplx_model_file_path}")
        wget.download(
            "https://huggingface.co/spaces/H-Liu1997/EMAGE/resolve/main/EMAGE/smplx_models/smplx/SMPLX_NEUTRAL_2020.npz",
            smplx_model_file_path,
        )

smplx_model = smplx.create(
    "./datasets/hub/smplx_models/",
    model_type='smplx',
    gender='NEUTRAL_2020',
    use_face_contour=False,
    num_betas=300,
    num_expression_coeffs=100,
    ext='npz',
    use_pca=False,
).eval()

def get_motion_rep_tensor(motion_tensor, pose_fps=30, device="cuda", betas=None):
    global smplx_model
    smplx_model = smplx_model.to(device)
    bs, n, _ = motion_tensor.shape
    motion_tensor = motion_tensor.float().to(device)
    motion_tensor_reshaped = motion_tensor.reshape(bs * n, 165)
    betas = torch.zeros(n, 300, device=device) if betas is None else betas.to(device).unsqueeze(0).repeat(n, 1)
    output = smplx_model(
        betas=torch.zeros(bs * n, 300, device=device),
        transl=torch.zeros(bs * n, 3, device=device),
        expression=torch.zeros(bs * n, 100, device=device),
        jaw_pose=torch.zeros(bs * n, 3, device=device),
        global_orient=torch.zeros(bs * n, 3, device=device),
        body_pose=motion_tensor_reshaped[:, 3:21 * 3 + 3],
        left_hand_pose=motion_tensor_reshaped[:, 25 * 3:40 * 3],
        right_hand_pose=motion_tensor_reshaped[:, 40 * 3:55 * 3],
        return_joints=True,
        leye_pose=torch.zeros(bs * n, 3, device=device),
        reye_pose=torch.zeros(bs * n, 3, device=device),
    )
    joints = output['joints'].reshape(bs, n, 127, 3)[:, :, :55, :]
    dt = 1 / pose_fps
    init_vel = (joints[:, 1:2] - joints[:, 0:1]) / dt
    middle_vel = (joints[:, 2:] - joints[:, :-2]) / (2 * dt)
    final_vel = (joints[:, -1:] - joints[:, -2:-1]) / dt
    vel = torch.cat([init_vel, middle_vel, final_vel], dim=1)
    position = joints
    rot_matrices = rc.axis_angle_to_matrix(motion_tensor.reshape(bs, n, 55, 3))
    rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(bs, n, 55, 6)
    init_vel_ang = (motion_tensor[:, 1:2] - motion_tensor[:, 0:1]) / dt
    middle_vel_ang = (motion_tensor[:, 2:] - motion_tensor[:, :-2]) / (2 * dt)
    final_vel_ang = (motion_tensor[:, -1:] - motion_tensor[:, -2:-1]) / dt
    angular_velocity = torch.cat([init_vel_ang, middle_vel_ang, final_vel_ang], dim=1).reshape(bs, n, 55, 3)
    rep15d = torch.cat([position, vel, rot6d, angular_velocity], dim=3).reshape(bs, n, 55 * 15)
    return {
        "position": position,
        "velocity": vel,
        "rotation": rot6d,
        "axis_angle": motion_tensor,
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
    }

def get_motion_rep_numpy(poses_np, pose_fps=30, device="cuda", expressions=None, expression_only=False, betas=None):
    # motion["poses"] is expected to be numpy array of shape (n, 165)
    # (n, 55*3), axis-angle for 55 joints
    global smplx_model
    smplx_model = smplx_model.to(device)
    n = poses_np.shape[0]

    # Convert numpy to torch tensor for SMPL-X forward pass
    poses_ts = torch.from_numpy(poses_np).float().to(device).unsqueeze(0)  # (1, n, 165)
    poses_ts_reshaped = poses_ts.reshape(-1, 165)  # (n, 165)
    betas = torch.zeros(n, 300, device=device) if betas is None else torch.from_numpy(betas).to(device).unsqueeze(0).repeat(n, 1)
    if expressions is not None and expression_only:
        # print("xx")
        expressions = torch.from_numpy(expressions).float().to(device)
        output = smplx_model(
            betas=betas,
            transl=torch.zeros(n, 3, device=device),
            expression=expressions,
            jaw_pose=poses_ts_reshaped[:, 22 * 3:23 * 3],
            global_orient=torch.zeros(n, 3, device=device),
            body_pose=torch.zeros(n, 21*3, device=device),
            left_hand_pose=torch.zeros(n, 15*3, device=device),
            right_hand_pose=torch.zeros(n, 15*3, device=device),
            return_joints=True,
            leye_pose=torch.zeros(n, 3, device=device),
            reye_pose=torch.zeros(n, 3, device=device),
            )
        joints = output["vertices"].detach().cpu().numpy().reshape(n, -1)
        return {"vertices": joints}

    # Run smplx model to get joints
    output = smplx_model(
        betas=betas,
        transl=torch.zeros(n, 3, device=device),
        expression=torch.zeros(n, 100, device=device),
        jaw_pose=torch.zeros(n, 3, device=device),
        global_orient=torch.zeros(n, 3, device=device),
        body_pose=poses_ts_reshaped[:, 3:21 * 3 + 3],
        left_hand_pose=poses_ts_reshaped[:, 25 * 3:40 * 3],
        right_hand_pose=poses_ts_reshaped[:, 40 * 3:55 * 3],
        return_joints=True,
        leye_pose=torch.zeros(n, 3, device=device),
        reye_pose=torch.zeros(n, 3, device=device),
    )
    joints = output["joints"].detach().cpu().numpy().reshape(n, 127, 3)[:, :55, :]

    dt = 1 / pose_fps
    # Compute linear velocity
    init_vel = (joints[1:2] - joints[0:1]) / dt
    middle_vel = (joints[2:] - joints[:-2]) / (2 * dt)
    final_vel = (joints[-1:] - joints[-2:-1]) / dt
    vel = np.concatenate([init_vel, middle_vel, final_vel], axis=0)

    position = joints

    # Compute rotation 6D from axis-angle
    poses_ts_reshaped_aa = poses_ts.reshape(1, n, 55, 3)
    rot_matrices = rc.axis_angle_to_matrix(poses_ts_reshaped_aa)[0]  # (n, 55, 3, 3)
    rot6d = rc.matrix_to_rotation_6d(rot_matrices).reshape(n, 55, 6).cpu().numpy()

    # Compute angular velocity
    init_vel_ang = (poses_np[1:2] - poses_np[0:1]) / dt
    middle_vel_ang = (poses_np[2:] - poses_np[:-2]) / (2 * dt)
    final_vel_ang = (poses_np[-1:] - poses_np[-2:-1]) / dt
    angular_velocity = np.concatenate([init_vel_ang, middle_vel_ang, final_vel_ang], axis=0).reshape(n, 55, 3)

    # rep15d: position(55*3), vel(55*3), rot6d(55*6), angular_velocity(55*3) => total 55*(3+3+6+3)=55*15
    rep15d = np.concatenate([position, vel, rot6d, angular_velocity], axis=2).reshape(n, 55 * 15)

    return {
        "position": position,
        "velocity": vel,
        "rotation": rot6d,
        "axis_angle": poses_np,
        "angular_velocity": angular_velocity,
        "rep15d": rep15d,
    }

def process_smplx_motion(pose_file, smplx_model, pose_fps, facial_rep=None):
    """Process SMPLX pose and facial data together."""
    pose_data = np.load(pose_file, allow_pickle=True)
    stride = int(30/pose_fps)
    
    # Extract pose and facial data with same stride
    pose_frames = pose_data["poses"][::stride]
    facial_frames = pose_data["expressions"][::stride] if facial_rep is not None else None
    
    # Process translations
    trans = pose_data["trans"][::stride]
    trans[:,0] = trans[:,0] - trans[0,0]
    trans[:,2] = trans[:,2] - trans[0,2]
    
    # Calculate translation velocities
    trans_v = np.zeros_like(trans)
    trans_v[1:,0] = trans[1:,0] - trans[:-1,0]
    trans_v[0,0] = trans_v[1,0]
    trans_v[1:,2] = trans[1:,2] - trans[:-1,2]
    trans_v[0,2] = trans_v[1,2]
    trans_v[:,1] = trans[:,1]
    
    # Process shape data
    shape = np.repeat(pose_data["betas"].reshape(1, 300), pose_frames.shape[0], axis=0)
    
    # # Calculate contacts
    # contacts = calculate_foot_contacts(pose_data, smplx_model)
    
    # if contacts is not None:
    #     pose_data = np.concatenate([pose_data, contacts], axis=1)
    
    return {
        'pose': pose_frames,
        'trans': trans,
        'trans_v': trans_v,
        'shape': shape,
        'facial': facial_frames if facial_frames is not None else np.array([-1])
    }

def calculate_foot_contacts(pose_data, smplx_model):
    """Calculate foot contacts from pose data."""
    max_length = 128
    all_tensor = []
    n = pose_data["poses"].shape[0]
    
    # Process in batches
    for i in range(n // max_length):
        joints = process_joints_batch(pose_data, i, max_length, smplx_model)
        all_tensor.append(joints)
    
    # Process remaining frames
    if n % max_length != 0:
        r = n % max_length
        joints = process_joints_batch(pose_data, n // max_length, r, smplx_model, remainder=True)
        all_tensor.append(joints)
    
    # Calculate velocities and contacts
    joints = torch.cat(all_tensor, axis=0)
    feetv = torch.zeros(joints.shape[1], joints.shape[0])
    joints = joints.permute(1, 0, 2)
    feetv[:, :-1] = (joints[:, 1:] - joints[:, :-1]).norm(dim=-1)
    contacts = (feetv < 0.01).numpy().astype(float)
    
    return contacts.transpose(1, 0)

def process_joints_batch(pose_data, batch_idx, batch_size, smplx_model, remainder=False):
    """Process a batch of joints for contact calculation."""
    start_idx = batch_idx * batch_size
    end_idx = start_idx + batch_size
    
    with torch.no_grad():
        return smplx_model(
            betas=torch.from_numpy(pose_data["betas"]).cuda().float().repeat(batch_size, 1),
            transl=torch.from_numpy(pose_data["trans"][start_idx:end_idx]).cuda().float(),
            expression=torch.from_numpy(pose_data["expressions"][start_idx:end_idx]).cuda().float(),
            jaw_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 66:69]).cuda().float(),
            global_orient=torch.from_numpy(pose_data["poses"][start_idx:end_idx, :3]).cuda().float(),
            body_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 3:21*3+3]).cuda().float(),
            left_hand_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 25*3:40*3]).cuda().float(),
            right_hand_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 40*3:55*3]).cuda().float(),
            leye_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 69:72]).cuda().float(),
            reye_pose=torch.from_numpy(pose_data["poses"][start_idx:end_idx, 72:75]).cuda().float(),
            return_verts=True,
            return_joints=True
        )['joints'][:, (7,8,10,11), :].reshape(batch_size, 4, 3).cpu()