import os
import sys
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import NewType

import librosa as lr
import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import soundfile as sf
import torch
import torch.nn.functional as F
from matplotlib import cm
from matplotlib.colors import ListedColormap
from mpl_toolkits.mplot3d import axes3d  # noqa: F401 -- import registers the "3d" projection
from pytorch3d.transforms import (axis_angle_to_quaternion, quaternion_apply,
                                  quaternion_multiply)
from tqdm import tqdm  # noqa: F401

try:
    import pickle5 as pickle
except ImportError:
    import pickle

# Alias so method annotations below read `Tensor`, matching the SMPL-X
# reference code this module is derived from.
Tensor = NewType('Tensor', torch.Tensor)

# Joint names of the 24-joint SMPL skeleton, in kinematic-tree order.
smpl_joints = [
    "root",         # 0
    "lhip",         # 1
    "rhip",         # 2
    "belly",        # 3
    "lknee",        # 4
    "rknee",        # 5
    "spine",        # 6
    "lankle",       # 7
    "rankle",       # 8
    "chest",        # 9
    "ltoes",        # 10
    "rtoes",        # 11
    "neck",         # 12
    "linshoulder",  # 13
    "rinshoulder",  # 14
    "head",         # 15
    "lshoulder",    # 16
    "rshoulder",    # 17
    "lelbow",       # 18
    "relbow",       # 19
    "lwrist",       # 20
    "rwrist",       # 21
    "lhand",        # 22
    "rhand",        # 23
]

# Joint names of the 52-joint SMPL-H skeleton (body + articulated fingers).
smplh_joints = [
    'pelvis', 'left_hip', 'right_hip', 'spine1', 'left_knee', 'right_knee',
    'spine2', 'left_ankle', 'right_ankle', 'spine3', 'left_foot',
    'right_foot', 'neck', 'left_collar', 'right_collar', 'head',
    'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
    'left_wrist', 'right_wrist', 'left_index1', 'left_index2', 'left_index3',
    'left_middle1', 'left_middle2', 'left_middle3', 'left_pinky1',
    'left_pinky2', 'left_pinky3', 'left_ring1', 'left_ring2', 'left_ring3',
    'left_thumb1', 'left_thumb2', 'left_thumb3', 'right_index1',
    'right_index2', 'right_index3', 'right_middle1', 'right_middle2',
    'right_middle3', 'right_pinky1', 'right_pinky2', 'right_pinky3',
    'right_ring1', 'right_ring2', 'right_ring3', 'right_thumb1',
    'right_thumb2', 'right_thumb3'
]

# Joint names of the 55-joint SMPL-X skeleton (body + jaw/eyes + fingers).
smplx_joints = [
    'pelvis', 'left_hip', 'right_hip', 'spine1', 'left_knee', 'right_knee',
    'spine2', 'left_ankle', 'right_ankle', 'spine3', 'left_foot',
    'right_foot', 'neck', 'left_collar', 'right_collar', 'head',
    'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
    'left_wrist', 'right_wrist', 'jaw', 'left_eye_smplhf',
    'right_eye_smplhf', 'left_index1', 'left_index2', 'left_index3',
    'left_middle1', 'left_middle2', 'left_middle3', 'left_pinky1',
    'left_pinky2', 'left_pinky3', 'left_ring1', 'left_ring2', 'left_ring3',
    'left_thumb1', 'left_thumb2', 'left_thumb3', 'right_index1',
    'right_index2', 'right_index3', 'right_middle1', 'right_middle2',
    'right_middle3', 'right_pinky1', 'right_pinky2', 'right_pinky3',
    'right_ring1', 'right_ring2', 'right_ring3', 'right_thumb1',
    'right_thumb2', 'right_thumb3'
]

# parents[i] is the parent joint index of joint i (-1 for the root).
smpl_parents = [
    -1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18,
    19, 20, 21,
]

smplh_parents = [
    -1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18,
    19, 20, 22, 23, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34, 35, 21, 37,
    38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50
]

smplx_parents = [
    -1, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18,
    19, 15, 15, 15, 20, 25, 26, 20, 28, 29, 20, 31, 32, 20, 34, 35, 20, 37,
    38, 21, 40, 41, 21, 43, 44, 21, 46, 47, 21, 49, 50, 21, 52, 53
]

# Rest-pose bone offsets (metres) of each SMPL joint relative to its parent.
smpl_offsets = [
    [0.0, 0.0, 0.0],
    [0.05858135, -0.08228004, -0.01766408],
    [-0.06030973, -0.09051332, -0.01354254],
    [0.00443945, 0.12440352, -0.03838522],
    [0.04345142, -0.38646945, 0.008037],
    [-0.04325663, -0.38368791, -0.00484304],
    [0.00448844, 0.1379564, 0.02682033],
    [-0.01479032, -0.42687458, -0.037428],
    [0.01905555, -0.4200455, -0.03456167],
    [-0.00226458, 0.05603239, 0.00285505],
    [0.04105436, -0.06028581, 0.12204243],
    [-0.03483987, -0.06210566, 0.13032329],
    [-0.0133902, 0.21163553, -0.03346758],
    [0.07170245, 0.11399969, -0.01889817],
    [-0.08295366, 0.11247234, -0.02370739],
    [0.01011321, 0.08893734, 0.05040987],
    [0.12292141, 0.04520509, -0.019046],
    [-0.11322832, 0.04685326, -0.00847207],
    [0.2553319, -0.01564902, -0.02294649],
    [-0.26012748, -0.01436928, -0.03126873],
    [0.26570925, 0.01269811, -0.00737473],
    [-0.26910836, 0.00679372, -0.00602676],
    [0.08669055, -0.01063603, -0.01559429],
    [-0.0887537, -0.00865157, -0.01010708],
]


def set_line_data_3d(line, x):
    """Update a 3D matplotlib line with the (J, 3) point array `x`."""
    line.set_data(x[:, :2].T)
    line.set_3d_properties(x[:, 2])


def set_scatter_data_3d(scat, x, c):
    """Update a 3D matplotlib scatter with points `x` (N, 3) and color `c`."""
    scat.set_offsets(x[:, :2])
    scat.set_3d_properties(x[:, 2], "z")
    scat.set_facecolors([c])


def get_axrange(poses):
    """Return the largest axis-aligned extent of the first pose in `poses`.

    Note: despite the name, this returns a single scalar (the biggest of the
    x/y/z extents), not a per-axis range.
    """
    pose = poses[0]
    x_min = pose[:, 0].min()
    x_max = pose[:, 0].max()

    y_min = pose[:, 1].min()
    y_max = pose[:, 1].max()

    z_min = pose[:, 2].min()
    z_max = pose[:, 2].max()

    xdiff = x_max - x_min
    ydiff = y_max - y_min
    zdiff = z_max - z_min

    biggestdiff = max([xdiff, ydiff, zdiff])
    return biggestdiff


def plot_single_pose(num, poses, lines, ax, axrange, scat, contact, ske_parents):
    """FuncAnimation callback: draw frame `num` of the skeleton animation.

    Parameters
    ----------
    num : int
        Frame index.
    poses : (L, J, 3) array of global joint positions.
    lines : list of Line3D artists, one per joint (index 0 unused -- root).
    ax : the 3D axes.
    axrange : int or 3-tuple, half-extent of the plot volume.
    scat : list of 4 scatter artists for the foot-contact markers.
    contact : (L, 4) boolean array; True means the foot joint is static.
    ske_parents : parent index list defining the kinematic tree to draw.
    """
    pose = poses[num]
    static = contact[num]
    # Foot joints whose contact state is visualized: ankles and toes.
    indices = [7, 8, 10, 11]

    for i, (point, idx) in enumerate(zip(scat, indices)):
        position = pose[idx : idx + 1]
        # red = in contact (static), green = moving
        color = "r" if static[i] else "g"
        set_scatter_data_3d(point, position, color)

    for i, (p, line) in enumerate(zip(ske_parents, lines)):
        # don't plot root
        if i == 0:
            continue
        # stack to create a line from the joint to its parent
        data = np.stack((pose[i], pose[p]), axis=0)
        set_line_data_3d(line, data)

    if num == 0:
        if isinstance(axrange, int):
            axrange = (axrange, axrange, axrange)
        xcenter, ycenter, zcenter = 0, 0, 2.5
        stepx, stepy, stepz = axrange[0] / 2, axrange[1] / 2, axrange[2] / 2

        x_min, x_max = xcenter - stepx, xcenter + stepx
        y_min, y_max = ycenter - stepy, ycenter + stepy
        z_min, z_max = zcenter - stepz, zcenter + stepz

        ax.set_xlim(x_min, x_max)
        ax.set_ylim(y_min, y_max)
        ax.set_zlim(z_min, z_max)


def skeleton_render(
    poses,
    epoch=0,
    out="renders",
    name="",
    sound=True,
    stitch=False,
    sound_folder="ood_sliced",
    contact=None,
    render=True,
    smpl_mode="smpl",  # "smpl" drops the hand joints; "smplx" renders all 55
):
    """Render a joint-position sequence to a gif/mp4, optionally with audio.

    Parameters
    ----------
    poses : (L, J, 3) array of global joint positions. For smpl_mode="smpl"
        the input is assumed to be 55-joint SMPL-X positions, from which a
        24-joint SMPL-like skeleton is extracted.
    epoch : label prefixed to the output file name.
    out : output directory (created if missing).
    name : audio file path (str), or list of paths when stitch=True.
    sound : mux audio into an mp4 via ffmpeg; otherwise save a gif only.
    stitch : cross-stitch multiple sliced wav files into one track.
    sound_folder : unused; kept for backward compatibility.
    contact : optional (L, 4) contact scores; thresholded at 0.95. If None,
        contact is estimated from foot-joint velocity.
    render : when False, skip all drawing (only audio stitching happens).
    smpl_mode : "smpl" or "smplx" (selects the kinematic tree to draw).
    """
    if render:
        if smpl_mode == "smpl":
            # Reduce 55 SMPL-X joints to a 24-joint SMPL-like skeleton:
            # body joints 0..22 plus joint 37 as the 24th (hand) joint.
            poses = np.concatenate(
                (poses[:, :23, :], np.expand_dims(poses[:, 37, :], axis=1)),
                axis=1,
            )
            ske_parents = smpl_parents
        elif smpl_mode == "smplx":
            ske_parents = smplx_parents
        else:
            # FIX: previously fell through with `ske_parents` undefined
            # (NameError); fail loudly instead.
            raise ValueError(f"unknown smpl_mode: {smpl_mode!r}")

        # generate the pose with FK
        Path(out).mkdir(parents=True, exist_ok=True)
        num_steps = poses.shape[0]

        fig = plt.figure()
        ax = fig.add_subplot(projection="3d")

        # Ground plane z = 1 (normal along +z).
        point = np.array([0, 0, 1])
        normal = np.array([0, 0, 1])
        d = -point.dot(normal)
        xx, yy = np.meshgrid(np.linspace(-1.5, 1.5, 2), np.linspace(-1.5, 1.5, 2))
        z = (-normal[0] * xx - normal[1] * yy - d) * 1.0 / normal[2]
        # plot the plane
        ax.plot_surface(xx, yy, z, zorder=-11, cmap=cm.twilight)

        # Create lines initially without data
        lines = [
            ax.plot([], [], [], zorder=10, linewidth=1.5)[0]
            for _ in ske_parents
        ]
        scat = [
            ax.scatter([], [], [], zorder=10, s=0, cmap=ListedColormap(["r", "g", "b"]))
            for _ in range(4)
        ]
        axrange = 3

        # create contact labels from foot-joint velocities when not given
        feet = poses[:, (7, 8, 10, 11)]
        feetv = np.zeros(feet.shape[:2])
        feetv[:-1] = np.linalg.norm(feet[1:] - feet[:-1], axis=-1)
        if contact is None:
            contact = feetv < 0.01
        else:
            contact = contact > 0.95

        # Creating the Animation object
        anim = animation.FuncAnimation(
            fig,
            plot_single_pose,
            num_steps,
            fargs=(poses, lines, ax, axrange, scat, contact, ske_parents),
            interval=1000 // 30,
        )

    if sound:
        # make a temporary directory to save the intermediate gif in
        if render:
            temp_dir = TemporaryDirectory()
            gifname = os.path.join(temp_dir.name, f"{epoch}.gif")
            anim.save(gifname)

        # stitch wavs
        if stitch:
            assert type(name) == list  # must be a list of names to do stitching
            name_ = [os.path.splitext(x)[0] + ".wav" for x in name]
            audio, sr = lr.load(name_[0], sr=None)
            ll, half = len(audio), len(audio) // 2
            total_wav = np.zeros(ll + half * (len(name_) - 1))
            total_wav[:ll] = audio
            idx = ll
            # each subsequent slice overlaps the previous by half its length
            for n_ in name_[1:]:
                audio, sr = lr.load(n_, sr=None)
                total_wav[idx : idx + half] = audio[half:]
                idx += half
            # save a dummy spliced audio
            audioname = f"{temp_dir.name}/tempsound.wav" if render else os.path.join(out, f'{epoch}_{"_".join(os.path.splitext(os.path.basename(name[0]))[0].split("_")[:-1])}.wav')
            sf.write(audioname, total_wav, sr)
            outname = os.path.join(
                out,
                f'{epoch}_{"_".join(os.path.splitext(os.path.basename(name[0]))[0].split("_")[:-1])}.mp4',
            )
        else:
            assert type(name) == str
            assert name != "", "Must provide an audio filename"
            audioname = name
            outname = os.path.join(
                out, f"{epoch}_{os.path.splitext(os.path.basename(name))[0]}.mp4"
            )
        if render:
            print(f"ffmpeg -loglevel error -stream_loop 0 -y -i {gifname} -i {audioname} -shortest -c:v libx264 -crf 26 -c:a aac -q:a 4 {outname}")
            # NOTE(review): hardcoded machine-specific ffmpeg path -- make
            # this configurable or rely on PATH.
            # FIX: result no longer rebinds the `out` parameter.
            os.system(
                f"/home/lrh/Documents/ffmpeg-6.0-amd64-static/ffmpeg -loglevel error -stream_loop 0 -y -i {gifname} -i {audioname} -shortest -c:v libx264 -crf 26 -c:a aac -q:a 4 {outname}"
            )
    else:
        if render:
            # actually save the gif
            path = os.path.normpath(name)
            pathparts = path.split(os.sep)
            gifname = os.path.join(out, f"{pathparts[-1][:-4]}.gif")
            anim.save(
                gifname,
                savefig_kwargs={"transparent": True, "facecolor": "none"},
            )
    plt.close()


class SMPLSkeleton:
    """Forward kinematics for the 24-joint SMPL skeleton.

    Offsets and parents come from the module-level `smpl_offsets` /
    `smpl_parents` tables.
    """

    def __init__(
        self,
        device=None,  # kept for backward compatibility; offsets follow the
                      # input device in forward()
    ):
        offsets = smpl_offsets
        parents = smpl_parents
        assert len(offsets) == len(parents)

        self._offsets = torch.Tensor(offsets)
        self._parents = np.array(parents)
        self._compute_metadata()

    def _compute_metadata(self):
        """Precompute per-joint children/has-children lookup tables."""
        self._has_children = np.zeros(len(self._parents)).astype(bool)
        for i, parent in enumerate(self._parents):
            if parent != -1:
                self._has_children[parent] = True

        self._children = []
        for i, parent in enumerate(self._parents):
            self._children.append([])
        for i, parent in enumerate(self._parents):
            if parent != -1:
                self._children[parent].append(i)

    def forward(self, rotations, root_positions):
        """
        Perform forward kinematics using the given trajectory and local rotations.
        Arguments (where N = batch size, L = sequence length, J = number of joints):
         -- rotations: (N, L, J, 3) tensor of axis-angle rotations describing the local rotations of each joint.
         -- root_positions: (N, L, 3) tensor describing the root joint positions.

        Returns an (N, L, J, 3) tensor of global joint positions.
        """
        assert len(rotations.shape) == 4
        assert len(root_positions.shape) == 3
        # transform from axis angle to quaternion
        fk_device = rotations.device
        # FIX: Tensor.to() is not in-place -- the original discarded the
        # result, leaving self._offsets on its old device.
        self._offsets = self._offsets.to(fk_device)
        rotations = axis_angle_to_quaternion(rotations)

        positions_world = []
        rotations_world = []

        expanded_offsets = self._offsets.expand(
            rotations.shape[0],
            rotations.shape[1],
            self._offsets.shape[0],
            self._offsets.shape[1],
        ).to(fk_device)

        # Parallelize along the batch and time dimensions
        for i in range(self._offsets.shape[0]):
            if self._parents[i] == -1:
                positions_world.append(root_positions)
                rotations_world.append(rotations[:, :, 0])
            else:
                positions_world.append(
                    quaternion_apply(
                        rotations_world[self._parents[i]], expanded_offsets[:, :, i]
                    )
                    + positions_world[self._parents[i]]
                )
                if self._has_children[i]:
                    rotations_world.append(
                        quaternion_multiply(
                            rotations_world[self._parents[i]], rotations[:, :, i]
                        )
                    )
                else:
                    # This joint is a terminal node -> it would be useless to compute the transformation
                    rotations_world.append(None)

        return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2)


# NOTE(review): decorating a *class* with torch.no_grad() only wraps
# construction (__init__) in a no-grad context; it does NOT disable
# gradients inside the methods. Kept as-is to preserve behavior.
@torch.no_grad()
class SMPLX_Skeleton:
    """Forward kinematics for the 55-joint SMPL-X skeleton.

    Rest-pose joint locations are loaded from `smplx_neu_J_1.npy` next to
    this file and pre-replicated `batch` times on `device`.
    """

    def __init__(
        self,
        device=None,
        batch=64,
    ):
        self.device = device
        self.parents = smplx_parents
        # (55, 3) neutral-pose joint locations -> (batch, 55, 3)
        self.J = np.load(os.path.join(os.path.dirname(__file__), "smplx_neu_J_1.npy"))
        self.J = torch.from_numpy(self.J).to(device).unsqueeze(dim=0).repeat(batch, 1, 1)

    def batch_rodrigues(self, rot_vecs: Tensor, epsilon: float = 1e-8,) -> Tensor:
        ''' Calculates the rotation matrices for a batch of rotation vectors
            Parameters
            ----------
            rot_vecs: torch.tensor Nx3
                array of N axis-angle vectors
            Returns
            -------
            R: torch.tensor Nx3x3
                The rotation matrices for the given axis-angle parameters
        '''
        batch_size = rot_vecs.shape[0]
        device, dtype = rot_vecs.device, rot_vecs.dtype

        # Small constant avoids a zero norm (and NaN direction) at identity.
        angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
        rot_dir = rot_vecs / angle

        cos = torch.unsqueeze(torch.cos(angle), dim=1)
        sin = torch.unsqueeze(torch.sin(angle), dim=1)

        # Bx1 arrays
        rx, ry, rz = torch.split(rot_dir, 1, dim=1)
        K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)

        # Skew-symmetric cross-product matrix of the rotation axis.
        zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
        K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
            .view((batch_size, 3, 3))

        # Rodrigues' formula: R = I + sin(a) K + (1 - cos(a)) K^2
        ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
        rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
        return rot_mat

    def batch_rigid_transform(
        self,
        rot_mats: Tensor,
        joints: Tensor,
        parents: Tensor,
        dtype=torch.float32
    ) -> Tensor:
        """
        Applies a batch of rigid transformations to the joints

        Parameters
        ----------
        rot_mats : torch.tensor BxNx3x3
            Tensor of rotation matrices
        joints : torch.tensor BxNx3
            Locations of joints
        parents : torch.tensor BxN
            The kinematic tree of each object
        dtype : torch.dtype, optional:
            The data type of the created tensors, the default is torch.float32

        Returns
        -------
        posed_joints : torch.tensor BxNx3
            The locations of the joints after applying the pose rotations
        """
        joints = torch.unsqueeze(joints, dim=-1)

        # Bone vectors: each joint relative to its parent.
        rel_joints = joints.clone()
        rel_joints[:, 1:] -= joints[:, parents[1:]]

        transforms_mat = self.transform_mat(
            rot_mats.reshape(-1, 3, 3),
            rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)

        # Accumulate transforms down the kinematic chain (root first).
        transform_chain = [transforms_mat[:, 0]]
        for i in range(1, parents.shape[0]):
            # Subtract the joint location at the rest pose
            # No need for rotation, since it's identity when at rest
            curr_res = torch.matmul(transform_chain[parents[i]],
                                    transforms_mat[:, i])
            transform_chain.append(curr_res)

        transforms = torch.stack(transform_chain, dim=1)

        # The last column of the transformations contains the posed joints
        posed_joints = transforms[:, :, :3, 3]
        return posed_joints

    def transform_mat(self, R: Tensor, t: Tensor) -> Tensor:
        ''' Creates a batch of transformation matrices
            Args:
                - R: Bx3x3 array of a batch of rotation matrices
                - t: Bx3x1 array of a batch of translation vectors
            Returns:
                - T: Bx4x4 Transformation matrix
        '''
        # No padding left or right, only add an extra row
        return torch.cat([F.pad(R, [0, 0, 0, 1]),
                          F.pad(t, [0, 0, 0, 1], value=1)], dim=2)

    def motion_data_load_process(self, motionfile):
        """Load a motion .pkl file and return (pose_165, root_positions).

        Two pickle layouts are supported:
          - {"q": (L, 165), "pos": (L, 3)} -- already padded SMPL-X pose;
          - {"smpl_poses": (L, 156), "smpl_trans": (L, 3)} -- AIST++-style,
            padded with 9 zeros (jaw/eyes) to 165 dims.
        Root positions are re-based so the first frame is at the origin.

        Raises
        ------
        ValueError if the file is not a .pkl (previously returned None).
        """
        if motionfile.split(".")[-1] == "pkl":
            # FIX: close the file handle (was pickle.load(open(...))).
            with open(motionfile, "rb") as fh:
                pkl_data = pickle.load(fh)
            if "pos" in pkl_data.keys():
                local_q_165 = torch.from_numpy(pkl_data["q"]).to(self.device).float()
                root_pos = torch.from_numpy(pkl_data["pos"]).to(self.device).float()
                root_pos = root_pos[:, :] - root_pos[0, :]
                return local_q_165, root_pos
            else:
                smpl_poses = pkl_data["smpl_poses"]
                # presumably slices are 150 or 300 frames long -- anything
                # else is reshaped to 150 frames. TODO confirm with dataset.
                if smpl_poses.shape[0] != 150 and smpl_poses.shape[0] != 300:
                    smpl_poses = smpl_poses.reshape(150, -1)
                root_pos = pkl_data["smpl_trans"]
                local_q = torch.from_numpy(smpl_poses).to(self.device).float()
                root_pos = torch.from_numpy(root_pos).to(self.device).float()
                # Insert 9 zero dims (jaw + both eyes) after the 66 body dims.
                local_q_165 = torch.cat(
                    [local_q[:, :66],
                     torch.zeros([local_q.shape[0], 9],
                                 device=local_q.device, dtype=torch.float32),
                     local_q[:, 66:]],
                    dim=1).to(self.device).float()
                root_pos = root_pos[:, :] - root_pos[0, :]
                return local_q_165, root_pos
        raise ValueError(f"unsupported motion file (expected .pkl): {motionfile}")

    def forward(self, rotations, root_positions):
        """
        Perform forward kinematics using the given trajectory and local rotations.
        Arguments:
         -- rotations: (N, 156) or (N, 165) axis-angle pose vectors
            (156-dim input is zero-padded to 165 after the 66 body dims).
         -- root_positions: (N, 3) root joint translations.
        Returns:
         -- (N, 55, 3) global joint positions.
        """
        fk_device = rotations.device
        if rotations.shape[1] == 156:
            local_q_165 = torch.cat(
                [rotations[:, :66],
                 torch.zeros([rotations.shape[0], 9],
                             device=fk_device, dtype=torch.float32),
                 rotations[:, 66:]],
                dim=1).to(fk_device).float()
        elif rotations.shape[1] == 165:
            local_q_165 = rotations.to(fk_device).float()
        else:
            # FIX: raise instead of print + sys.exit(0) (exit code 0 masked
            # the failure).
            raise ValueError(f"rotations shape error: {tuple(rotations.shape)}")
        root_pos = root_positions.to(fk_device).float()
        assert local_q_165.shape[1] == 165

        B, C = local_q_165.shape
        rot_mats = self.batch_rodrigues(local_q_165.view(-1, 3)).view(
            [B, -1, 3, 3])

        # Reuse the pre-replicated rest joints; fall back to repeating when
        # the batch is larger than what was allocated in __init__.
        if self.J.shape[0] >= B:
            J_temp = self.J[:B, :, :]
        else:
            J_temp = self.J[:1, :, :].repeat(B, 1, 1)
            print("warning: self.J size 0 is lower than batchsize x seq_len")

        parents = torch.Tensor(self.parents).long()
        # NOTE(review): self.J lives on self.device while rot_mats follow the
        # input device -- assumes both match; TODO confirm at call sites.
        J_transformed = self.batch_rigid_transform(rot_mats, J_temp, parents,
                                                   dtype=torch.float32)
        J_transformed += root_pos.unsqueeze(dim=1)
        return J_transformed


if __name__ == "__main__":
    print("1")
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    smplx_fk = SMPLX_Skeleton(device=device, batch=150)
    motion_file = "/home/data/lrh/datasets/fine_dance/magicsmpl/sliced/test/dances/012_slice0.pkl"
    # music_file = "/home/data/lrh/datasets/fine_dance/magicsmpl/sliced/test/wavs/012_slice0.wav"
    local_q_165, root_pos = smplx_fk.motion_data_load_process(motion_file)
    print("local_q_165.shape", local_q_165.shape)   # e.g. (150, 165)
    print("root_pos.shape", root_pos.shape)         # e.g. (150, 3)
    joints = smplx_fk.forward(local_q_165, root_pos).detach().cpu().numpy()
    print("joints.shape", joints.shape)
    # skeleton_render(
    #     joints,
    #     epoch=f"e{1}_b{1}",
    #     out="./output/temp",
    #     name=music_file,
    #     render=True,
    #     stitch=False,
    #     sound=True,
    #     smpl_mode="smplx"
    # )