TunaDance / render.py
NikhilMarisetty's picture
Upload folder using huggingface_hub
eb71a72 verified
import pickle
import numpy as np
import torch
import cv2
import os
# os.environ["PYOPENGL_PLATFORM"] = "osmesa" # Not available on macOS
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
from tqdm import tqdm
from smplx import SMPL, SMPLX, SMPLH
import pyrender
import trimesh
import subprocess
import pickle
from pytorch3d.transforms import (axis_angle_to_matrix, matrix_to_axis_angle,
matrix_to_quaternion, matrix_to_rotation_6d,
quaternion_to_matrix, rotation_6d_to_matrix)
import sys
sys.path.append('.')
import argparse
def quat_to_6v(q):
assert q.shape[-1] == 4
mat = quaternion_to_matrix(q)
mat = matrix_to_rotation_6d(mat)
return mat
def quat_from_6v(q):
assert q.shape[-1] == 6
mat = rotation_6d_to_matrix(q)
quat = matrix_to_quaternion(mat)
return quat
def ax_to_6v(q):
assert q.shape[-1] == 3
mat = axis_angle_to_matrix(q)
mat = matrix_to_rotation_6d(mat)
return mat
def ax_from_6v(q):
assert q.shape[-1] == 6
mat = rotation_6d_to_matrix(q)
ax = matrix_to_axis_angle(mat)
return ax
class MovieMaker():
def __init__(self, save_path) -> None:
self.mag = 2
self.eyes = np.array([[3,-3,2], [0,0,-2], [0,0,4], [-8,-8,1], [0,-2,4], [0,2,4]])
self.centers = np.array([[0,0,0],[0,0,0],[0,0.5,0],[0,0,-1], [0,0.5,0], [0,0.5,0]])
self.ups = np.array([[0,0,1],[0,1,0],[0,1,0],[0,0,-1], [0,1,0], [0,1,0]])
self.save_path = save_path
self.fps = args.fps
self.img_size = (1200,1200)
# SMPLH_path = "assets/smpl_model/smplh/SMPLH_MALE.pkl"
# SMPL_path = "assets/smpl_model/smpl/SMPL_MALE.pkl"
SMPLX_path = "assets/smpl_model/smplx/SMPLX_NEUTRAL.npz"
trimesh_path = 'assets/NORMAL_new.obj'
# self.smplh = SMPLH(SMPLH_path, use_pca=False, flat_hand_mean=True)
# self.smplh.to(f'cuda:{args.gpu}').eval()
# self.smpl = SMPL(SMPL_path)
# self.smpl.to(f'cuda:{args.gpu}').eval()
self.smplx = SMPLX(SMPLX_path, use_pca=False, flat_hand_mean=True).eval()
_device = "mps" if torch.backends.mps.is_available() else "cpu"
self.smplx.to(_device).eval()
self.scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 1.0])
camera = pyrender.PerspectiveCamera(yfov=np.pi / 3.0)
camera_pose = look_at(self.eyes[5], self.centers[5], self.ups[5]) # 2
self.scene.add(camera, pose=camera_pose)
light = pyrender.DirectionalLight(color=np.ones(3), intensity=3.0)
self.scene.add(light, pose=camera_pose)
self.r = pyrender.OffscreenRenderer(self.img_size[0], self.img_size[1])
# self.mesh = trimesh.load(trimesh_path)
# floor_mesh = pyrender.Mesh.from_trimesh(self.mesh)
# floor_node = self.scene.add(floor_mesh)
def save_video(self, save_path, color_list):
# save_path = os.path.join(save_path,'move.mp4')
f = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
videowriter = cv2.VideoWriter(save_path,f,self.fps,self.img_size)
for i in range(len(color_list)):
videowriter.write(color_list[i][:,:,::-1])
videowriter.release()
def get_imgs(self, motion):
meshes = self.motion2mesh(motion)
imgs = self.render_imgs(meshes)
return np.concatenate(imgs, axis=1)
def motion2mesh(self, motion):
if args.mode == "smpl":
output = self.smpl.forward(
betas = torch.zeros([motion.shape[0], 10]).to(motion.device),
transl = motion[:,:3],
global_orient = motion[:,3:6],
body_pose = torch.cat([motion[:,6:69], motion[:,69:72], motion[:,114:117]], dim=1)
)
elif args.mode == "smplh":
output = self.smplh.forward(
betas = torch.zeros([motion.shape[0], 10]).to(motion.device),
# transl = motion[:,:3],
transl = torch.tensor([[0,0,-1]]).expand(motion.shape[0],-1).to(motion.device) ,
global_orient = motion[:,3:6],
body_pose = motion[:,6:69],
left_hand_pose = motion[:,69:114],
right_hand_pose = motion[:,114:159],
)
elif args.mode == "smplx":
output = self.smplx.forward(
betas = torch.zeros([motion.shape[0], 10]).to(motion.device),
# transl = motion[:,:3],
transl = motion[:,:3],
global_orient = motion[:,3:6],
body_pose = motion[:,6:69],
jaw_pose = torch.zeros([motion.shape[0], 3]).to(motion),
leye_pose = torch.zeros([motion.shape[0], 3]).to(motion),
reye_pose = torch.zeros([motion.shape[0], 3]).to(motion),
left_hand_pose = motion[:,69:69+45],
right_hand_pose = motion[:,69+45:],
expression= torch.zeros([motion.shape[0], 10]).to(motion),
)
meshes = []
for i in range(output.vertices.shape[0]):
if args.mode == 'smplh':
mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smplh.faces)
elif args.mode == 'smplx':
mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smplx.faces)
elif args.mode == 'smpl':
mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smpl.faces)
# mesh.export(os.path.join(self.save_path, f'{i}.obj'))
meshes.append(mesh)
return meshes
def render_multi_view(self, meshes, music_file, tab='', eyes=None, centers=None, ups=None, views=1):
if eyes and centers and ups:
assert eyes.shape == centers.shape == ups.shape
else:
eyes = self.eyes
centers = self.centers
ups = self.ups
for i in range(views):
color_list = self.render_single_view(meshes, eyes[1], centers[1], ups[1])
movie_file = os.path.join(self.save_path, tab + '-' + str(i) + '.mp4')
output_file = os.path.join(self.save_path, tab + '-' + str(i) + '-music.mp4')
self.save_video(movie_file, color_list)
if music_file is not None:
subprocess.run(['ffmpeg','-i',movie_file,'-i',music_file,'-shortest',output_file])
else:
subprocess.run(['ffmpeg','-i',movie_file,output_file])
# if music_file is not None:
# subprocess.run(['ffmpeg','-i',movie_file,'-i',music_file,'-shortest',output_file])
# else:
# subprocess.run(['ffmpeg','-i',movie_file,output_file])
os.remove(movie_file)
def render_single_view(self, meshes):
num = len(meshes)
color_list = []
for i in tqdm(range(num)):
mesh_nodes = []
for mesh in meshes[i]:
render_mesh = pyrender.Mesh.from_trimesh(mesh)
mesh_node = self.scene.add(render_mesh)
mesh_nodes.append(mesh_node)
color, _ = self.r.render(self.scene, flags=pyrender.RenderFlags.SHADOWS_DIRECTIONAL)
color = color.copy()
color_list.append(color)
for mesh_node in mesh_nodes:
self.scene.remove_node(mesh_node)
return color_list
def render_imgs(self, meshes):
colors = []
for mesh in meshes:
render_mesh = pyrender.Mesh.from_trimesh(mesh)
mesh_node = self.scene.add(render_mesh)
color, _ = self.r.render(self.scene, flags=pyrender.RenderFlags.SHADOWS_DIRECTIONAL)
colors.append(color)
self.scene.remove_node(mesh_node)
return colors
# cv2.imwrite(os.path.join(self.save_path, 'test.jpg'), color[:,:,::-1])
def run(self, seq_rot, music_file=None, tab='', save_pt=False):
if isinstance(seq_rot, np.ndarray):
seq_rot = torch.tensor(seq_rot, dtype=torch.float32, device="mps" if torch.backends.mps.is_available() else "cpu")
if save_pt:
torch.save(seq_rot.detach().cpu(), os.path.join(self.save_path, tab +'_pose.pt'))
B, D = seq_rot.shape
if args.mode == "smpl":
print("using smpl!!!")
output = self.smpl.forward(
betas = torch.zeros([seq_rot.shape[0], 10]).to(seq_rot.device),
transl = seq_rot[:,:3],
global_orient = seq_rot[:,3:6],
body_pose = torch.cat([seq_rot[:,6:69], seq_rot[:,69:72], seq_rot[:,114:117]], dim=1)
)
elif args.mode == "smplh":
print("using smplh!!!")
output = self.smplh.forward(
betas = torch.zeros([seq_rot.shape[0], 10]).to(seq_rot.device),
transl = seq_rot[:,:3],
global_orient = seq_rot[:,3:6],
body_pose = seq_rot[:,6:69],
left_hand_pose = seq_rot[:,69:114], # torch.zeros([seq_rot.shape[0], 45]).to(seq_rot.device), # seq_rot[:,69:114],
right_hand_pose = seq_rot[:,114:], # torch.zeros([seq_rot.shape[0], 45]).to(seq_rot.device), #
expression = torch.zeros([seq_rot.shape[0], 10]).to(seq_rot.device),
)
elif args.mode == "smplx":
output = self.smplx.forward(
betas = torch.zeros([seq_rot.shape[0], 10]).to(seq_rot.device),
# transl = motion[:,:3],
transl = seq_rot[:,:3],
global_orient = seq_rot[:,3:6],
body_pose = seq_rot[:,6:69],
jaw_pose = torch.zeros([seq_rot.shape[0], 3]).to(seq_rot),
leye_pose = torch.zeros([seq_rot.shape[0], 3]).to(seq_rot),
reye_pose = torch.zeros([seq_rot.shape[0], 3]).to(seq_rot),
left_hand_pose = seq_rot[:,69:69+45],
right_hand_pose = seq_rot[:,69+45:],
expression= torch.zeros([seq_rot.shape[0], 10]).to(seq_rot),
)
N, V, DD = output.vertices.shape # 150, 6890, 3
vertices = output.vertices.reshape((B, -1, V, DD)) # # 150, 1, 6890, 3
meshes = []
for i in range(B):
# if int(i) > 20:
# break
view = []
for v in vertices[i]:
# vertices[:,2] *= -1
if args.mode == 'smplh':
mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smplh.faces)
elif args.mode == 'smplx':
mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smplx.faces)
elif args.mode == 'smpl':
mesh = trimesh.Trimesh(output.vertices[i].cpu(), self.smpl.faces)
# mesh.export(os.path.join(self.save_path, 'test.obj'))
view.append(mesh)
meshes.append(view)
color_list = self.render_single_view(meshes)
movie_file = os.path.join(self.save_path, tab + 'tmp.mp4')
output_file = os.path.join(self.save_path, tab + 'z.mp4')
self.save_video(movie_file, color_list)
if music_file is not None:
subprocess.run(['ffmpeg','-i',movie_file,'-i',music_file,'-shortest',output_file])
else:
subprocess.run(['ffmpeg','-i',movie_file,output_file])
# if music_file is not None:
# subprocess.run(['ffmpeg','-i',movie_file,'-i',music_file,'-shortest',output_file])
# else:
# subprocess.run(['ffmpeg','-i',movie_file,output_file])
os.remove(movie_file)
def look_at(eye, center, up):
front = eye - center
front = front / np.linalg.norm(front)
right = np.cross(up, front)
right = right/ np.linalg.norm(right)
up_new = np.cross(front, right)
camera_pose = np.eye(4)
camera_pose[:3,:3] = np.stack([right, up_new, front]).transpose()
camera_pose[:3,3] = eye
return camera_pose
def motion_data_load_process(motionfile):
if motionfile.split(".")[-1] == "pkl":
pkl_data = pickle.load(open(motionfile, "rb"))
smpl_poses = pkl_data["smpl_poses"]
modata = np.concatenate((pkl_data["smpl_trans"], smpl_poses), axis=1)
if modata.shape[1] == 69:
hand_zeros = np.zeros([modata.shape[0], 90], dtype=np.float32)
modata = np.concatenate((modata, hand_zeros), axis=1)
assert modata.shape[1] == 159
modata[:, 1] = modata[:, 1] + 1.3
return modata
elif motionfile.split(".")[-1] == "npy":
modata = np.load(motionfile)
print("modata.shape", modata.shape)
if modata.shape[-1] == 315: # first 3-dim is root translation
print("modata.shape is:", modata.shape)
rot6d = torch.from_numpy(modata[:,3:])
T,C = rot6d.shape
rot6d = rot6d.reshape(-1,6)
axis = ax_from_6v(rot6d).view(T,-1).detach().cpu().numpy()
modata = np.concatenate((modata[:,:3], axis), axis=1)
print("modata.shape is:", modata.shape)
elif modata.shape[-1] == 319:
print("modata.shape is:", modata.shape)
modata = modata[:,4:]
rot6d = torch.from_numpy(modata[:,3:])
T,C = rot6d.shape
rot6d = rot6d.reshape(-1,6)
axis = ax_from_6v(rot6d).view(T,-1).detach().cpu().numpy()
modata = np.concatenate((modata[:,:3], axis), axis=1)
print("modata.shape is:", modata.shape)
elif modata.shape[-1] == 168:
modata = np.concatenate( [modata[:,:21*3+1], modata[:,25*3:]] , axis=1)
elif modata.shape[-1] == 159:
print("modata.shape is:", modata.shape)
print("modata.shape is:", modata.shape)
elif modata.shape[-1] == 135:
print("modata.shape is:", modata.shape)
if len(modata.shape) == 3 and modata.shape[0] ==1:
modata = modata.squeeze(0)
rot6d = torch.from_numpy(modata[:,3:])
T,C = rot6d.shape
rot6d = rot6d.reshape(-1,6)
axis = ax_from_6v(rot6d).view(T,-1).detach().cpu().numpy()
hand_zeros = torch.zeros([T, 90]).to(rot6d).detach().cpu().numpy()
modata = np.concatenate((modata[:,:3], axis, hand_zeros), axis=1)
print("modata.shape is:", modata.shape)
else:
raise("shape error!")
modata[:, 1] = modata[:, 1] + 1.3
return modata
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--gpu", type=str, default="2")
parser.add_argument("--modir", type=str, default="")
parser.add_argument("--mode", type=str, default="smplx", choices=['smpl','smplh','smplx'])
parser.add_argument("--fps", type=int, default=30)
parser.add_argument("--save_path", type=str, default=None)
args = parser.parse_args()
print(args.gpu)
motion_dir = args.modir
if args.save_path is not None:
save_path = args.save_path
if not os.path.exists(save_path):
os.makedirs(save_path)
else:
save_path = os.path.join(motion_dir, 'video')
os.makedirs(save_path, exist_ok=True)
music_dir = "experiments/DanceDiffuse_module/debug--0517_Norm_512len_315_transloss/val1640/samples_2023-05-17-20-54-05"
for file in os.listdir(motion_dir):
if file[-3:] in ["npy", "pkl"]:
# if there have exist rendered video, continue
flag = False
for exists_file in os.listdir(save_path):
if file[:-4] in exists_file:
flag = True
break
else:
flag = False
if flag:
print("exist", file)
continue
print(file)
motion_file = os.path.join(motion_dir, file)
visualizer = MovieMaker(save_path=save_path)
modata = motion_data_load_process(motion_file)
visualizer.run(modata, tab=os.path.basename(motion_file).split(".")[0], music_file=None)
print('done')