import shutil
import mediapy
from PIL import Image, ImageDraw
import os.path
from enum import Enum
from pathlib import Path
import wandb
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import trimesh
from pytorch3d.io import load_obj
from pytorch3d.ops import knn_points, knn_gather
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torchvision.transforms.functional import gaussian_blur
from time import time
import pyvista as pv
import dreifus
from dreifus.matrix import Pose, Intrinsics, CameraCoordinateConvention, PoseType
from dreifus.pyvista import add_camera_frustum, render_from_camera

from pixel3dmm import env_paths
from pixel3dmm.tracking import util
from pixel3dmm.tracking.losses import UVLoss
from pixel3dmm.tracking import nvdiffrast_util
from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer
from pixel3dmm.tracking.flame.FLAME import FLAME
from pixel3dmm.utils.misc import tensor2im
from pixel3dmm.utils.utils_3d import rotation_6d_to_matrix, matrix_to_rotation_6d, euler_angles_to_matrix
from pixel3dmm.utils.drawing import plot_points


def timeit(t0, tag):
    t1 = time()
    # print(f'[PROFILER]: {tag} took {t1 - t0} seconds')
    return t1


os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

rank = 42
torch.manual_seed(rank)
torch.cuda.manual_seed(rank)
cudnn.benchmark = True
np.random.seed(rank)

I = torch.eye(3)[None].cuda().detach()
I6D = matrix_to_rotation_6d(I)

left_iris_flame = [4597, 4542, 4510, 4603, 4570]
right_iris_flame = [4051, 3996, 3964, 3932, 4028]
left_iris_mp = [468, 469, 470, 471, 472]
right_iris_mp = [473, 474, 475, 476, 477]

torch.set_float32_matmul_precision('high')


class View(Enum):
    GROUND_TRUTH = 1
    COLOR_OVERLAY = 2
    SHAPE_OVERLAY = 4
    SHAPE = 8
    LANDMARKS = 16
    HEATMAP = 32
    DEPTH = 64


def get_intrinsics(focal_length, principal_point, use_hack: bool = True, size: int = 512):
    intrinsics = torch.eye(3)[None, ...].float().cuda().repeat(focal_length.shape[0], 1, 1)
    intrinsics[:, 0, 0] = focal_length.squeeze() * size
    intrinsics[:, 1, 1] = focal_length.squeeze() * size
    intrinsics[:, :2, 2] = size / 2 + 0.5 + principal_point * (size / 2 + 0.5)
    if use_hack:
        intrinsics[:, 0:1, 2:3] = size - intrinsics[:, 0:1, 2:3]  # TODO fix this hack
    return intrinsics


def get_extrinsics(R_base, t_base):
    timestep = 0
    w2c_openGL = torch.eye(4)[None, ...].float().cuda()
    w2c_openGL[:, :3, :3] = R_base[timestep]
    w2c_openGL[:, :3, 3] = t_base[timestep]
    return w2c_openGL


def project_points_screen_space(points3d, focal_length, principal_point, R_base, t_base, size: int = 512):
    # construct camera matrices
    intrinsics = get_intrinsics(focal_length, principal_point, size=size)
    w2c_openGL = get_extrinsics(R_base, t_base).repeat(focal_length.shape[0], 1, 1)
    B = points3d.shape[0]
    reps_extr = B if w2c_openGL.shape[0] == 1 else 1
    reps_intr = B if intrinsics.shape[0] == 1 else 1
    # apply w2c transformation
    lmk68_cam_space = torch.bmm(torch.cat([points3d, torch.ones_like(points3d[..., :1])], dim=-1), w2c_openGL.permute(0, 2, 1).repeat(reps_extr, 1, 1))
    # project from cam_space to screen_space
    lmk68_cam_space_prime = lmk68_cam_space[..., :3] / -lmk68_cam_space[..., [2]]
    lmk68_screen_space = (-1) * torch.bmm(lmk68_cam_space_prime, intrinsics.permute(0, 2, 1).repeat(reps_intr, 1, 1))[..., :2]
    lmk68_screen_space = torch.stack([size - 1 - lmk68_screen_space[..., 0], lmk68_screen_space[..., 1], lmk68_cam_space[..., 2]], dim=-1)
    return lmk68_screen_space


WFLW_2_iBUG68 = np.array([0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 33, 34, 35, 36, 37, 42, 43, 44, 45, 46, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 63, 64, 65, 67, 68, 69, 71, 72, 73, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
WFLW_2_iBUG68 = torch.from_numpy(WFLW_2_iBUG68).cuda()

COMPILE = True
if COMPILE:
    project_points_screen_space = torch.compile(project_points_screen_space)


class Tracker(object):
    def __init__(self, config, flame_module, renderer, device='cuda:0'):
        self.config = config
        self.flame = flame_module
        self.diff_renderer = renderer
        self.device = device
        self.actor_name = self.config.video_name

        DATA_FOLDER = f'{env_paths.PREPROCESSED_DATA}/{self.actor_name}'
        self.MAX_STEPS = min(len([f for f in os.listdir(f'{DATA_FOLDER}/cropped/') if f.endswith('.jpg') or f.endswith('.png')]) - self.config.start_frame, 1000)
        self.FRAME_SKIP = 1
        self.BATCH_SIZE = self.config.batch_size

        print(f'''
        <<<<<<<< INITIALIZING TRACKER INSTANCE FOR {self.actor_name} >>>>>>>>
        ''')

        self.mirror_order = torch.from_numpy(np.load(f'{env_paths.MIRROR_INDEX}')).long().cuda()
        self.uv_loss_fn = UVLoss(stricter_mask=self.config.uv_loss.stricter_uv_mask, delta_uv=self.config.uv_loss.delta_uv, dist_uv=self.config.uv_loss.dist_uv)
        if COMPILE:
            self.uv_loss_fn.compute_loss = torch.compile(self.uv_loss_fn.compute_loss)

        self.actor_name = self.actor_name + f'_nV{config.num_views}'
        if config.no_lm:
            self.actor_name = self.actor_name + '_noLM'
        if config.no_pho:
            self.actor_name = self.actor_name + '_noPho'
        if self.config.ignore_mica:
            self.actor_name = self.actor_name + '_noMICA'
        if self.config.flame2023:
            self.actor_name = self.actor_name + '_FLAME23'
        if self.config.uv_map_super > 0:
            self.actor_name = self.actor_name + f'_uv{self.config.uv_map_super}'
        if self.config.normal_super > 0:
            self.actor_name = self.actor_name + f'_n{self.config.normal_super}'
        if self.config.normal_super_can > 0:
            self.actor_name = self.actor_name + f'_nc{self.config.normal_super_can}'

        self.global_step = 0
        self.no_sh = config.no_sh
        self.no_lm = config.no_lm
        self.no_pho = config.no_pho

        # The following will be set up later
        self.frame = 0
        self.is_initializing = False
        self.image_size = torch.tensor([[config.image_size[0], config.image_size[1]]]).cuda()
        if hasattr(self.config, 'output_folder'):
            self.save_folder = self.config.output_folder
        else:
            self.save_folder = env_paths.TRACKING_OUTPUT
        self.output_folder = os.path.join(self.save_folder, self.actor_name)
        self.checkpoint_folder = os.path.join(self.save_folder, self.actor_name, "checkpoint")
        self.mesh_folder = os.path.join(self.save_folder, self.config.video_name, "mesh")
        self.create_output_folders()
        self.writer = SummaryWriter(log_dir=self.save_folder + self.actor_name + '/logs')

        self.cam_pose_nvd = {}
        self.R_base = {}
        self.t_base = {}

        flame_mesh_mask = np.load(f'{env_paths.FLAME_MASK_ASSET}/FLAME2020/FLAME_masks/FLAME_masks.pkl', allow_pickle=True, encoding='latin1')
        self.vertex_face_mask = torch.from_numpy(flame_mesh_mask['face']).cuda().long()

        self.setup_renderer()

        self.intermediate_exprs = []
        self.intermediate_Rs = []
        self.intermediate_ts = []
        self.intermediate_eyes = []
        self.intermediate_eyelids = []
        self.intermediate_jaws = []
        self.intermediate_necks = []
        self.intermediate_fls = []
        self.intermediate_pps = []

        self.cached_data = {}

    def get_image_size(self):
        return self.image_size[0][0].item(), self.image_size[0][1].item()

    def create_output_folders(self):
        Path(self.save_folder).mkdir(parents=True, exist_ok=True)
        Path(self.checkpoint_folder).mkdir(parents=True, exist_ok=True)
        Path(self.mesh_folder).mkdir(parents=True, exist_ok=True)

    def setup_renderer(self):
        mesh_file = f'{env_paths.head_template}'
        self.config.image_size = self.get_image_size()
        self.flame.vertex_face_mask = self.vertex_face_mask
        if COMPILE:
            self.flame = torch.compile(self.flame)
            self.opt_pre = torch.compile(self.opt_pre)
            self.opt_post = torch.compile(self.opt_post)
            self.actual_smooth = torch.compile(self.actual_smooth)
        self.renderer = self.diff_renderer  # already global
        self.faces = load_obj(mesh_file)[1]

    def save_checkpoint(self, frame_id, selected_frames=None):
        if selected_frames is None:
            exp = self.exp
            eyes = self.eyes
            eyelids = self.eyelids
            R = self.R
            t = self.t
            jaw = self.jaw
            neck = self.neck
            focal_length = self.focal_length
            principal_point = self.principal_point
        else:
            exp = self.exp(selected_frames)
            eyes = self.eyes(selected_frames)
            eyelids = self.eyelids(selected_frames)
            R = self.R(selected_frames)
            t = self.t(selected_frames)
            jaw = self.jaw(selected_frames)
            neck = self.neck(selected_frames)
            if self.config.global_camera:
                focal_length = self.focal_length
                principal_point = self.principal_point
            else:
                focal_length = self.focal_length(selected_frames)
                principal_point = self.principal_point(selected_frames)

        frame = {
            'flame': {
                'exp': exp.clone().detach().cpu().numpy(),
                'shape': self.shape.clone().detach().cpu().numpy(),
                'eyes': eyes.clone().detach().cpu().numpy(),
                'eyelids': eyelids.clone().detach().cpu().numpy(),
                'jaw': jaw.clone().detach().cpu().numpy(),
                'neck': neck.clone().detach().cpu().numpy(),
                'R': R.clone().detach().cpu().numpy(),
                'R_rotation_matrix': rotation_6d_to_matrix(R).detach().cpu().numpy(),
                't': t.clone().detach().cpu().numpy(),
            },
            'img_size': self.image_size.clone().detach().cpu().numpy()[0],
            'frame_id': frame_id,
            'global_step': self.global_step
        }
        cam_params = {f'R_base_{serial}': self.R_base[serial].clone().detach().cpu().numpy() for serial in self.R_base.keys()}
        cam_pos = {f't_base_{serial}': self.t_base[serial].clone().detach().cpu().numpy() for serial in self.R_base.keys()}
        intr = {
            'fl': focal_length.clone().detach().cpu().numpy(),
            'pp': principal_point.clone().detach().cpu().numpy(),
        }
        cam_params.update(cam_pos)
        cam_params.update(intr)
        frame.update({'camera': cam_params})

        bs = exp.shape[0]
        vertices, lmks, joint_transforms, vertices_can, vertices_noneck = self.flame(cameras=torch.inverse(self.R_base[0])[:1, ...].repeat(bs, 1, 1), shape_params=self.shape[:1, ...].repeat(bs, 1), expression_params=exp, eye_pose_params=eyes, jaw_pose_params=jaw, neck_pose_params=neck, rot_params_lmk_shift=R, eyelid_params=eyelids)
        frame.update({'joint_transforms': joint_transforms.detach().cpu().numpy()})

        f = self.diff_renderer.faces[0].cpu().numpy()
        for b_i in range(bs):
            v = vertices[b_i].cpu().numpy()
            if self.config.save_meshes:
                trimesh.Trimesh(faces=f, vertices=v, process=False).export(f'{self.mesh_folder}/{frame_id:05d}.glb')
            torch.save(frame, f'{self.checkpoint_folder}/{frame_id:05d}.frame')
            selection_indx = np.array([36, 39, 42, 45, 33, 48, 54])
            _lmks = lmks[b_i].detach().squeeze().cpu().numpy()
            if self.config.save_landmarks:
                np.save(f'{self.mesh_folder}/landmarks_{frame_id}_{b_i}.npy', _lmks[selection_indx])

        if frame_id == self.config.start_frame and self.config.save_meshes:
            faces = self.diff_renderer.faces[0].cpu().numpy()
            trimesh.Trimesh(faces=faces, vertices=vertices_can[0].detach().cpu().numpy(), process=False).export(f'{self.mesh_folder}/canonical.glb')

        if self.config.save_landmarks:
            lmks = lmks.detach().squeeze().cpu().numpy()
            np.save(f'{self.mesh_folder}/ibug68_{frame_id}.npy', lmks)
            selection_indx = np.array([36, 39, 42, 45, 33, 48, 54])
            np.save(f'{self.mesh_folder}/now_{frame_id}.npy', lmks[selection_indx])

    def get_heatmap(self, values):
        l2 = tensor2im(values)
        l2 = cv2.cvtColor(l2, cv2.COLOR_RGB2BGR)
        l2 = l2 - 127
        max_err = 25
        l2[l2 > max_err] = max_err
        l2 = ((l2 / max_err) * 255).astype(np.uint8)
        heatmap = cv2.applyColorMap(l2, cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.
        heatmap = torch.from_numpy(heatmap).permute(2, 0, 1)
        return heatmap

    def to_cuda(self, batch, unsqueeze=False):
        for key in batch.keys():
            if torch.is_tensor(batch[key]):
                batch[key] = batch[key].to(self.device)
                if unsqueeze:
                    batch[key] = batch[key][None]
        return batch

    def create_parameters(self, timestep, mica_shape):
        bz = 1
        pose_mat = np.eye(4)
        pose_mat[2, 3] = -1
        opencv_w2c_pose = Pose(pose_mat, camera_coordinate_convention=dreifus.matrix.CameraCoordinateConvention.OPEN_CV)
        opencv_w2c_pose = opencv_w2c_pose.change_pose_type(dreifus.matrix.PoseType.CAM_2_WORLD)
        opencv_w2c_pose.look_at(np.zeros(3), np.array([0, 1, 0]))
        opencv_w2c_pose = opencv_w2c_pose.change_pose_type(dreifus.matrix.PoseType.WORLD_2_CAM)
        self.debug_pose_init = opencv_w2c_pose.change_pose_type(dreifus.matrix.PoseType.WORLD_2_CAM).copy()

        self.shape = mica_shape.detach().clone()
        self.mica_shape = mica_shape.detach().clone()
        if self.config.ignore_mica:
            self.shape = torch.zeros_like(self.shape)
            self.mica_shape = torch.zeros_like(self.mica_shape)

        cam_pose = opencv_w2c_pose
        cam_pose = cam_pose.change_pose_type(dreifus.matrix.PoseType.CAM_2_WORLD)
        cam_pose_nvd = cam_pose.copy()
        cam_pose_nvd = cam_pose_nvd.change_camera_coordinate_convention(new_camera_coordinate_convention=dreifus.matrix.CameraCoordinateConvention.OPEN_GL)
        cam_pose_nvd = cam_pose_nvd.change_pose_type(dreifus.matrix.PoseType.WORLD_2_CAM)
        self.cam_pose_nvd[timestep] = torch.from_numpy(cam_pose_nvd.copy()).float().cuda()

        R = torch.from_numpy(cam_pose_nvd.get_rotation_matrix()).unsqueeze(0).cuda()
        T = torch.from_numpy(cam_pose_nvd.get_translation()).unsqueeze(0).cuda()
        R.requires_grad = True
        T.requires_grad = True
        self.R_base[timestep] = R
        self.t_base[timestep] = T

        init_f = 2000 * self.config.size / 512
        self.focal_length = torch.tensor([[init_f / self.config.size]]).float().to(self.device)
        self.principal_point = torch.tensor([[0, 0]]).float().to(self.device)
        self.focal_length.requires_grad = True
        self.principal_point.requires_grad = True

        intrinsics = torch.tensor([[init_f, 0, self.config.size // 2], [0, init_f, self.config.size // 2], [0, 0, 1]]).float().cuda()
        proj_512 = nvdiffrast_util.intrinsics2projection(intrinsics, znear=0.1, zfar=10, width=self.config.size, height=self.config.size)
        self.r_mvps = {}
        for serial in self.cam_pose_nvd.keys():
            self.r_mvps[serial] = (proj_512 @ self.cam_pose_nvd[serial])[None, ...]
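        # Editorial note: the focal length is stored normalized by image size, so
        # the effective pixel focal length is `focal_length * size` (see
        # get_intrinsics). The initialization above corresponds to f = 2000 px at
        # size = 512, i.e. a fairly narrow horizontal field of view of roughly
        # 2 * atan(256 / 2000) ≈ 14.6 degrees.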
        n_timesteps = 1
        expression_params = np.zeros([n_timesteps, 100])
        jaw_params = np.zeros([n_timesteps, 3])
        neck_params = np.zeros([n_timesteps, 3])
        flame_R = torch.from_numpy(np.stack([np.eye(3) for _ in range(n_timesteps)], axis=0))
        flame_t = torch.from_numpy(np.stack([np.zeros([3]) for _ in range(n_timesteps)], axis=0))
        self.R = nn.Parameter(matrix_to_rotation_6d(flame_R.float().to(self.device)))
        self.t = nn.Parameter(flame_t.float().to(self.device))
        self.expression_params = expression_params
        self.jaw_params = jaw_params.astype(np.float32)
        self.neck_params = neck_params.astype(np.float32)
        self.shape = nn.Parameter(self.mica_shape.detach().clone())
        self.texture_observation_mask = None
        self.exp = nn.Parameter(torch.from_numpy(self.expression_params[[0] + self.config.keyframes, ..., :]).float().to(self.device))
        self.jaw = nn.Parameter(matrix_to_rotation_6d(euler_angles_to_matrix(torch.from_numpy(self.jaw_params[[0] + self.config.keyframes, ..., :]).cuda(), 'XYZ')))
        self.neck = nn.Parameter(matrix_to_rotation_6d(euler_angles_to_matrix(torch.from_numpy(self.neck_params[[0] + self.config.keyframes, ..., :]).cuda(), 'XYZ')))
        self.eyes = nn.Parameter(torch.cat([matrix_to_rotation_6d(I), matrix_to_rotation_6d(I)], dim=1).repeat(1 + len(self.config.keyframes), 1))
        self.eyelids = nn.Parameter(torch.zeros(1 + len(self.config.keyframes), 2).float().to(self.device))

    def parse_mask(self, ops, batch, visualization=False):
        result = ops['mask_images_rendering']
        if visualization:
            result = ops['mask_images']
        return result.detach()

    def clone_params_keyframes_all(self, freeze_id: bool = False, is_joint: bool = False, freeze_cam: bool = False, include_neck: bool = False):
        lr_scale = 1.0
        lr_scale_id_related = 1.0
        if freeze_id:
            lr_scale_id_related = 0.1
        params = [
            {'params': [self.exp], 'lr': self.config.lr_exp * lr_scale, 'name': ['exp']},
            {'params': [self.eyes], 'lr': 0.005 * lr_scale, 'name': ['eyes']},
            {'params': [self.eyelids], 'lr': 0.002 * lr_scale, 'name': ['eyelids']},
            {'params': [self.t], 'lr': self.config.lr_t * lr_scale, 'name': ['t']},
            {'params': [self.R], 'lr': self.config.lr_R * lr_scale, 'name': ['R']},
        ]
        if not freeze_id:
            if is_joint:
                params.append({'params': [self.shape], 'lr': self.config.lr_id * lr_scale * 1, 'name': ['shape']})
            else:
                params.append({'params': [self.shape], 'lr': self.config.lr_id * lr_scale, 'name': ['shape']})
        params.append({'params': [self.jaw], 'lr': self.config.lr_jaw * lr_scale, 'name': ['jaw']})
        if include_neck:
            params.append({'params': [self.neck], 'lr': self.config.lr_neck, 'name': ['neck']})
        if not freeze_cam:
            params.append({'params': [self.focal_length], 'lr': self.config.lr_f * lr_scale_id_related, 'name': ['camera_params']})
            params.append({'params': [self.principal_point], 'lr': self.config.lr_pp * lr_scale_id_related, 'name': ['camera_params']})
        return params

    def clone_params_keyframes_all_joint(self, freeze_id: bool = False, is_joint: bool = False, include_neck: bool = False):
        lr_scale = 1.0
        lr_scale_id_related = 1.0
        if freeze_id:
            lr_scale_id_related = 0.1
        params = [
            {'params': self.exp.parameters(), 'lr': self.config.lr_exp * lr_scale, 'name': ['exp']},
            {'params': self.eyes.parameters(), 'lr': 0.005 * lr_scale, 'name': ['eyes']},
            {'params': self.eyelids.parameters(), 'lr': 0.002 * lr_scale, 'name': ['eyelids']},
            {'params': self.t.parameters(), 'lr': self.config.lr_t * lr_scale, 'name': ['t']},
            {'params': self.R.parameters(), 'lr': self.config.lr_R * lr_scale, 'name': ['R']},
        ]
        params.append({'params': self.jaw.parameters(), 'lr': self.config.lr_jaw * lr_scale, 'name': ['jaw']})
        if include_neck:
            params.append({'params': self.neck.parameters(), 'lr': self.config.lr_neck, 'name': ['neck']})
        if not self.config.global_camera:
            params.append({'params': self.focal_length.parameters(), 'lr': self.config.lr_f * lr_scale_id_related, 'name': ['camera_params']})
            params.append({'params': self.principal_point.parameters(), 'lr': self.config.lr_pp * lr_scale_id_related, 'name': ['camera_params']})
        return params

    def reduce_loss(self, losses):
        all_loss = 0.
        for key in losses.keys():
            all_loss = all_loss + losses[key]
        losses['all_loss'] = all_loss
        return all_loss

    def optimize_camera(self, batch, steps=2000, is_first_frame: bool = False):
        batch = self.to_cuda(batch)
        images, landmarks, lmk_mask = self.parse_landmarks(batch)
        h, w = images.shape[2:4]
        num_keyframes = 1
        uv_mask = batch["uv_mask"]
        uv_map = batch["uv_map"] if "uv_map" in batch else None
        if uv_map is not None:
            uv_map[(1 - uv_mask[:, :, :, :]).bool()] = 0

        self.focal_length.requires_grad = True
        self.principal_point.requires_grad = True
        lr_mult = 1.0
        params = [
            {'params': [self.t], 'lr': lr_mult * 0.001},
            {'params': [self.R], 'lr': lr_mult * 0.005},
        ]
        if is_first_frame:
            params.append({'params': [self.focal_length], 'lr': 0.02})
            params.append({'params': [self.principal_point], 'lr': 0.0001})
        optimizer = torch.optim.Adam(params)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(steps * 0.75), gamma=0.1)

        t = tqdm(range(steps), desc='', leave=True, miniters=100)
        num_views = 1  # len(self.R_base.keys())
        bs = 1  # len(self.cam_serials) * num_keyframes
        for k in t:
            vertices_can, lmk68, lmkMP, vertices_can_can, vertices_noneck = self.flame(cameras=torch.inverse(self.R_base[0]), shape_params=self.shape if self.shape.shape[0] == bs else self.shape.repeat(bs, 1), expression_params=self.exp.repeat_interleave(num_views, dim=0), eye_pose_params=self.eyes.repeat_interleave(num_views, dim=0), jaw_pose_params=self.jaw.repeat_interleave(num_views, dim=0), neck_pose_params=self.neck.repeat_interleave(num_views, dim=0), rot_params_lmk_shift=(matrix_to_rotation_6d(torch.inverse(rotation_6d_to_matrix(self.R)))).repeat_interleave(num_views, dim=0))
            lmk68 = torch.einsum('bny,bxy->bnx', lmk68, rotation_6d_to_matrix(self.R.repeat_interleave(num_views, dim=0))) + self.t.repeat_interleave(num_views, dim=0).unsqueeze(1)
            verts = torch.einsum('bny,bxy->bnx', vertices_can, rotation_6d_to_matrix(self.R.repeat_interleave(num_views, dim=0))) + self.t.repeat_interleave(num_views, dim=0).unsqueeze(1)
            lmk68_screen_space = project_points_screen_space(lmk68, self.focal_length, self.principal_point, self.R_base, self.t_base, size=self.config.size)
            verts_screen_space = project_points_screen_space(verts, self.focal_length, self.principal_point, self.R_base, self.t_base, size=self.config.size)

            losses = {}
            losses['pp_reg'] = torch.sum(self.principal_point ** 2)
            if k <= steps // 2:
                losses['lmk68'] = util.lmk_loss(lmk68_screen_space[..., :2], landmarks[..., :2], [h, w], lmk_mask) * 3000
            if k == 0:
                self.uv_loss_fn.compute_corresp(uv_map)
            if k > steps // 2:
                uv_loss = self.uv_loss_fn.compute_loss(verts_screen_space)
                losses['uv_loss'] = uv_loss * 1000

            all_loss = 0.
            for key in losses.keys():
                all_loss = all_loss + losses[key]
            losses['all_loss'] = all_loss

            optimizer.zero_grad()
            all_loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            intrinsics = get_intrinsics(self.focal_length, self.principal_point, use_hack=False, size=self.config.size)
            proj_512 = nvdiffrast_util.intrinsics2projection(intrinsics[0], znear=0.1, zfar=5, width=self.config.size, height=self.config.size)
            for serial in self.cam_pose_nvd.keys():
                extr = get_extrinsics(self.R_base[serial], self.t_base[serial])
                r_mvps = proj_512 @ extr
                self.r_mvps[serial] = r_mvps

            loss = all_loss.item()
            t.set_description(f'Loss for camera {loss:.4f}')
            self.frame += 1
        self.frame = 0

    @torch.compiler.disable
    def get_vars(self, is_joint, selected_frames):
        if not is_joint:
            exp = self.exp
            eyes = self.eyes
            eyelids = self.eyelids
            _R = self.R
            _t = self.t
            jaw = self.jaw
            neck = self.neck
            focal_length = self.focal_length
            principal_point = self.principal_point
        else:
            selected_frames = torch.from_numpy(selected_frames).long().cuda()
            exp = self.exp(selected_frames)
            eyes = self.eyes(selected_frames)
            eyelids = self.eyelids(selected_frames)
            _R = self.R(selected_frames)
            _t = self.t(selected_frames)
            jaw = self.jaw(selected_frames)
            neck = self.neck(selected_frames)
            if not self.config.global_camera:
                focal_length = self.focal_length(selected_frames)
                principal_point = self.principal_point(selected_frames)
            else:
                focal_length = self.focal_length
                principal_point = self.principal_point
        return exp, eyes, eyelids, _R, _t, jaw, neck, focal_length, principal_point

    @torch.compiler.disable
    def data_stuff(self, is_joint, iters, p, image_lmks68, lmk_mask, normal_map, normal_mask, uv_map, uv_mask, left_iris, right_iris, mask_left_iris, mask_right_iris):
        if is_joint:
            with torch.no_grad():
                if (p < int(iters * 0.15) and (p % 2 == 0)) or not self.config.smooth:
                    all_frames = np.array(range(self.config.start_frame, self.MAX_STEPS + self.config.start_frame, self.FRAME_SKIP))
                    selected_frames = np.sort(np.random.choice(np.arange(len(all_frames)), size=self.BATCH_SIZE, replace=False))
                else:
                    all_frames = np.array(range(self.config.start_frame, self.MAX_STEPS + self.config.start_frame, self.FRAME_SKIP))
                    start = np.min(all_frames)
                    end = np.max(all_frames)
                    rnd_start = np.random.randint(start, end)
                    assert (end - start) >= self.BATCH_SIZE + 1
                    assert self.BATCH_SIZE % 2 == 0
                    if rnd_start - self.BATCH_SIZE // 2 < 0:
                        rnd_start = self.BATCH_SIZE // 2
                    if rnd_start + self.BATCH_SIZE // 2 + 1 > end:
                        rnd_start = end - self.BATCH_SIZE // 2 + 1
                    selected_frames = np.array(list(range(rnd_start - self.BATCH_SIZE // 2, rnd_start + self.BATCH_SIZE // 2)))
                selected_frames_th = torch.from_numpy(selected_frames).long()
                batch = {k: self.cached_data[k][selected_frames_th, ...] for k in self.cached_data.keys()}
                images, landmarks, lmk_mask = self.parse_landmarks(batch)
                uv_mask = batch["uv_mask"]
                normal_mask = batch["normal_mask"]
                normal_map = batch["normals"] if "normals" in batch else None
                uv_map = batch["uv_map"] if "uv_map" in batch else None
                # TODO check if this was important in any way
                if uv_map is not None:
                    uv_map[(1 - uv_mask[:, :, :, :]).bool()] = 0
                num_views = len(self.R_base.keys())
                bs = batch['normals'].shape[0] * num_views
                image_lmks68 = landmarks
                if landmarks is not None:
                    left_iris = batch['left_iris']
                    right_iris = batch['right_iris']
                    mask_left_iris = batch['mask_left_iris']
                    mask_right_iris = batch['mask_right_iris']
        else:
            selected_frames = None
            bs = 1
            num_views = 1
            batch = None
        return selected_frames, batch, bs, num_views, image_lmks68, lmk_mask, normal_map, normal_mask, uv_map, uv_mask, left_iris, right_iris, mask_left_iris, mask_right_iris

    # TODO: could be improved by compiling all the actual smooth-loss computation
    def actual_smooth(self, variables, losses):
        reg_smooth_exp = (variables['exp'][:-1, :] - variables['exp'][1:, :]).square().mean()
        reg_smooth_eyes = (variables['eyes'][:-1, :] - variables['eyes'][1:, :]).square().mean()
        reg_smooth_eyelids = (variables['eyelids'][:-1, :] - variables['eyelids'][1:, :]).square().mean()
        reg_smooth_R = (variables['R'][:-1, :] - variables['R'][1:, :]).square().mean()
        reg_smooth_t = (variables['t'][:-1, :] - variables['t'][1:, :]).square().mean()
        reg_smooth_jaw = (variables['jaw'][:-1, :] - variables['jaw'][1:, :]).square().mean()
        reg_smooth_neck = (variables['neck'][:-1, :] - variables['neck'][1:, :]).square().mean()
        if not self.config.global_camera:
            reg_smooth_principal_point = (variables['principal_point'][:-1, :] - variables['principal_point'][1:, :]).square().mean()
            reg_smooth_focal_length = (variables['focal_length'][:-1, :] - variables['focal_length'][1:, :]).square().mean()
        else:
            reg_smooth_principal_point = torch.zeros_like(reg_smooth_jaw)
            reg_smooth_focal_length = torch.zeros_like(reg_smooth_jaw)
        losses['smooth/exp'] = reg_smooth_exp * self.config.reg_smooth_exp * self.config.reg_smooth_mult
        losses['smooth/eyes'] = reg_smooth_eyes * self.config.reg_smooth_eyes * self.config.reg_smooth_mult
        losses['smooth/eyelids'] = reg_smooth_eyelids * self.config.reg_smooth_eyelids * self.config.reg_smooth_mult
        losses['smooth/jaw'] = reg_smooth_jaw * self.config.reg_smooth_jaw * self.config.reg_smooth_mult
        losses['smooth/neck'] = reg_smooth_neck * self.config.reg_smooth_neck * self.config.reg_smooth_mult
        losses['smooth/R'] = reg_smooth_R * self.config.reg_smooth_R * self.config.reg_smooth_mult
        losses['smooth/t'] = reg_smooth_t * self.config.reg_smooth_t * self.config.reg_smooth_mult
        losses['smooth/principal_point'] = reg_smooth_principal_point * self.config.reg_smooth_pp * self.config.reg_smooth_mult
        losses['smooth/focal_length'] = reg_smooth_focal_length * self.config.reg_smooth_fl * self.config.reg_smooth_mult
        return losses

    @torch.compiler.disable
    def add_smooth_loss(self, losses, is_joint, p, iters, variables):
        if is_joint and self.config.smooth and (p >= int(iters * 0.15) and (p % 2 == 1)):
            losses = self.actual_smooth(variables, losses)
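        # Editorial note: the smoothness terms are only added after the first 15%
        # of iterations and only on odd steps; per data_stuff, those steps are
        # guaranteed (when config.smooth is set) to sample a contiguous temporal
        # window of frames, so adjacent rows in each variable tensor correspond
        # to adjacent frames and the first-order differences above are meaningful.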
        return losses

    def opt_pre(self, is_joint, iters, p, no_lm, image_lmks68, lmk_mask, normal_mask, normal_map, uv_map, uv_mask, left_iris, right_iris, mask_left_iris, mask_right_iris):
        image_size = [self.config.size, self.config.size]
        selected_frames, batch, bs, num_views, image_lmks68, lmk_mask, normal_map, normal_mask, uv_map, uv_mask, left_iris, right_iris, mask_left_iris, mask_right_iris = self.data_stuff(is_joint, iters, p, image_lmks68, lmk_mask, normal_map, normal_mask, uv_map, uv_mask, left_iris, right_iris, mask_left_iris, mask_right_iris)

        self.diff_renderer.reset()
        losses = {}

        exp, eyes, eyelids, _R, _t, jaw, neck, focal_length, principal_point = self.get_vars(is_joint, selected_frames)
        variables = {
            'exp': exp,
            'eyes': eyes,
            'eyelids': eyelids,
            'R': _R,
            't': _t,
            'jaw': jaw,
            'neck': neck,
            'principal_point': principal_point,
            'focal_length': focal_length,
        }

        intrinsics = get_intrinsics(focal_length, principal_point, use_hack=False, size=self.config.size)
        proj_512 = nvdiffrast_util.intrinsics2projection(intrinsics, znear=0.1, zfar=5, width=self.config.size, height=self.config.size)
        for serial in self.cam_pose_nvd.keys():
            extr = get_extrinsics(self.R_base[serial], self.t_base[serial])
            r_mvps = torch.matmul(proj_512, extr.repeat(bs, 1, 1))
            self.r_mvps[serial] = r_mvps

        vertices_can, lmk68, lmkMP, vertices_can_can, vertices_noneck = self.flame(cameras=torch.inverse(self.R_base[0]).repeat(bs, 1, 1), shape_params=self.shape if self.shape.shape[0] == bs else self.shape.repeat(bs, 1).cuda(), expression_params=exp.repeat_interleave(num_views, dim=0), eye_pose_params=eyes.repeat_interleave(num_views, dim=0), jaw_pose_params=jaw.repeat_interleave(num_views, dim=0), neck_pose_params=neck.repeat_interleave(num_views, dim=0), eyelid_params=eyelids.repeat_interleave(num_views, dim=0), rot_params_lmk_shift=(matrix_to_rotation_6d(torch.inverse(rotation_6d_to_matrix(_R)))).repeat_interleave(num_views, dim=0))

        # mirror-consistency loss in canonical space
        verts_can_can_mirrored = vertices_can_can[:, self.mirror_order, :]
        vertices_can_can_mirrored = torch.zeros_like(verts_can_can_mirrored)
        vertices_can_can_mirrored[:, :, 0] = -verts_can_can_mirrored[:, :, 0]
        vertices_can_can_mirrored[:, :, 1:] = verts_can_can_mirrored[:, :, 1:]
        mirror_loss = (vertices_can_can_mirrored - vertices_can_can).square().sum(-1)
        mirror_loss = mirror_loss.mean()

        lmk68 = torch.einsum('bny,bxy->bnx', lmk68, rotation_6d_to_matrix(_R.repeat_interleave(num_views, dim=0))) + _t.repeat_interleave(num_views, dim=0).unsqueeze(1)
        vertices = torch.einsum('bny,bxy->bnx', vertices_can, rotation_6d_to_matrix(_R.repeat_interleave(num_views, dim=0))) + _t.repeat_interleave(num_views, dim=0).unsqueeze(1)
        vertices_noneck = torch.einsum('bny,bxy->bnx', vertices_noneck, rotation_6d_to_matrix(_R.repeat_interleave(num_views, dim=0))) + _t.repeat_interleave(num_views, dim=0).unsqueeze(1)

        proj_lmks68 = project_points_screen_space(lmk68, focal_length, principal_point, self.R_base, self.t_base, size=self.config.size)
        proj_vertices = project_points_screen_space(vertices, focal_length, principal_point, self.R_base, self.t_base, size=self.config.size)

        right_eye, left_eye = eyes[:, :6], eyes[:, 6:]

        # landmark loss
        if not no_lm:
            lmk_scale = 1.0
            losses['loss/lmk_eye2'] = util.lmk_loss(proj_lmks68[:, 36:48, :2], image_lmks68[:, 36:48, :], image_size, lmk_mask[:, 36:48, :]) * self.config.w_lmks * lmk_scale * 5
            if self.config.use_mouth_lmk:
                losses['loss/lmk_mouth'] = util.lmk_loss(proj_lmks68[:, 48:68, :2], image_lmks68[:, 48:68, :], image_size, lmk_mask[:, 48:68, :]) * self.config.w_lmks_mouth * lmk_scale * 0.25
                losses['loss/lmk_mouth_closure'] = util.mouth_closure_lmk_loss(proj_lmks68[..., :2], image_lmks68, image_size, lmk_mask) * self.config.w_lmks_mouth * lmk_scale * 2.5
            losses['loss/lmk_eye'] = util.eye_closure_lmk_loss(proj_lmks68[..., :2], image_lmks68, image_size, lmk_mask) * self.config.w_lmks_lid * lmk_scale * 500
            losses['loss/lmk_iris_left'] = util.lmk_loss(proj_vertices[:, left_iris_flame[:1], ..., :2], left_iris, image_size, mask_left_iris) * self.config.w_lmks_iris * lmk_scale * 50.0
            losses['loss/lmk_iris_right'] = util.lmk_loss(proj_vertices[:, right_iris_flame[:1], ..., :2], right_iris, image_size, mask_right_iris) * self.config.w_lmks_iris * lmk_scale * 50.0

        # Regularizers
        losses['reg/exp'] = torch.sum(exp ** 2, dim=-1).mean() * self.config.w_exp
        losses['reg/sym'] = torch.sum((right_eye - left_eye) ** 2, dim=-1).mean() * 0.1
        losses['reg/jaw'] = torch.sum((I6D - jaw) ** 2, dim=-1).mean() * self.config.w_jaw
        losses['reg/neck'] = torch.sum((I6D - neck) ** 2, dim=-1).mean() * self.config.w_neck
        losses['reg/eye_left'] = torch.sum((I6D - left_eye) ** 2, dim=-1).mean() * 0.01
        losses['reg/eye_right'] = torch.sum((I6D - right_eye) ** 2, dim=-1).mean() * 0.01
        losses['reg/shape'] = torch.sum((self.shape - self.mica_shape) ** 2, dim=-1).mean() * self.config.w_shape
        losses['reg/shape_general'] = torch.sum((self.shape) ** 2, dim=-1).mean() * self.config.w_shape_general
        losses['reg/mirror'] = mirror_loss * 5000
        if not (self.config.n_fine and p >= iters // 2):
            losses['reg/pp'] = torch.sum(principal_point ** 2, dim=-1).mean()

        return batch, losses, vertices, vertices_noneck, vertices_can, vertices_can_can, proj_vertices, proj_lmks68, selected_frames, variables, num_views, normal_mask, normal_map, uv_map, uv_mask

    def opt_post(self, variables, ops, proj_vertices, proj_lmks68, batch, is_joint, is_first_step, losses, uv_map, selected_frames, p, iters, num_views, normal_mask, normal_map):
        grabbed_depth = ops['actual_rendered_depth'][:, 0, torch.clamp(proj_vertices[:, :, 1].long(), 0, self.config.size - 1), torch.clamp(proj_vertices[:, :, 0].long(), 0, self.config.size - 1)][:, 0, :]
        is_visible_verts_idx = grabbed_depth < (proj_vertices[:, :, 2] + 1e-2)
        if not self.config.occ_filter:
            is_visible_verts_idx = torch.ones_like(is_visible_verts_idx)

        valid_bg_classes = batch['valid_bg']  # bg-class or neck-class
        if self.config.sil_super > 0:
            if is_joint or (not is_first_step):
                losses['loss/sil'] = ((valid_bg_classes[:, None, :, :]) * (batch['fg_mask'] - ops['fg_images'])).abs().mean() * self.config.sil_super
            else:
                losses['loss/sil'] = ((valid_bg_classes[:, None, :, :]) * (batch['fg_mask'] - ops['fg_images'])).abs().mean() * self.config.sil_super / 10

        if self.config.uv_map_super:
            gt_uv = uv_map[:, :2, :, :].permute(0, 2, 3, 1)
            if self.config.uv_l2:
                uv_loss = ((gt_uv - ops['uv_images']) * batch["uv_mask"][:, 0, ...].unsqueeze(-1)).square().mean() * 100
            else:
                uv_loss = ((gt_uv - ops['uv_images']) * batch["uv_mask"][:, 0, ...].unsqueeze(-1)).abs().mean()
            # TODO: outlier filtering!!!
            losses['loss/uv_pixel'] = uv_loss * self.config.uv_map_super

        if self.config.uv_map_super > 0.0:
            if self.uv_loss_fn.gt_2_verts is None:
                self.uv_loss_fn.compute_corresp(uv_map, selected_frames=selected_frames)
            uv_loss = self.uv_loss_fn.compute_loss(proj_vertices, selected_frames=selected_frames, uv_map=uv_map, l2_loss=self.config.uv_l2, is_visible_verts_idx=is_visible_verts_idx)
            losses['loss/uv'] = uv_loss * self.config.uv_map_super

        skip_normals = False
        if self.config.n_fine and p < iters // 2:
            skip_normals = True
        if (self.config.normal_super > 0.0 or self.config.normal_super_can > 0.0) and not skip_normals:
            # use the dilated eye mask only; the eye mask could also be applied in the image, not only the rendering
            dilated_eye_mask = 1 - (gaussian_blur(ops['mask_images_eyes'], [self.config.normal_mask_ksize, self.config.normal_mask_ksize], sigma=[self.config.normal_mask_ksize, self.config.normal_mask_ksize]) > 0).float()
            pred_normals = ops['normal_images']  # 1 x 3 x 512 x 512 normals in world space
            rot_mat = rotation_6d_to_matrix(variables["R"].repeat_interleave(num_views, dim=0))  # 1 x 3 x 3
            pred_normals_flame_space = torch.einsum('bxy,bxhw->byhw', rot_mat, pred_normals)
            if normal_map is not None:
                l_map = (normal_map - pred_normals_flame_space)
                valid = ((l_map.abs().sum(dim=1) / 3) < self.config.delta_n).unsqueeze(1)
                normal_loss_map = l_map * valid.float() * normal_mask * dilated_eye_mask
                if self.config.normal_l2:
                    losses['loss/normal'] = normal_loss_map.square().mean() * self.config.normal_super
                else:
                    losses['loss/normal'] = normal_loss_map.abs().mean() * self.config.normal_super
            else:
                losses['loss/normal'] = 0.0

        # smoothness loss
        losses = self.add_smooth_loss(losses, is_joint, p, iters, variables)
        all_loss = self.reduce_loss(losses)
        return all_loss

    def optimize_color(self, batch, params_func, no_lm: bool = False, save_timestep=0, is_joint: bool = False, is_first_step: bool = False):
        iters = self.config.iters
        if not is_joint:
            images, landmarks, lmk_mask = self.parse_landmarks(batch)
            uv_mask = batch["uv_mask"]
            normal_mask = batch["normal_mask"]
            normal_map = batch["normals"] if "normals" in batch else None
            uv_map = batch["uv_map"] if "uv_map" in batch else None
            if uv_map is not None:
                uv_map[(1 - uv_mask[:, :, :, :]).bool()] = 0

        # Optimizer per step
        if is_joint:
            optimizer = torch.optim.SparseAdam(params_func())
            params_global = [{'params': [self.shape], 'lr': self.config.lr_id * 1.0, 'name': ['shape']}]
            if self.config.global_camera:
                params_global.append({'params': [self.focal_length], 'lr': self.config.lr_f * 1.0, 'name': ['camera_params']})
                params_global.append({'params': [self.principal_point], 'lr': self.config.lr_pp * 1.0, 'name': ['camera_params']})
            optimizer_id = torch.optim.Adam(params_global)
            optimizer_id.zero_grad()
        else:
            optimizer = torch.optim.Adam(params_func())
        optimizer.zero_grad()
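        # Editorial note: in the joint stage the per-frame parameters live in
        # sparse nn.Embedding tables (see prepare_global_optimization), so they
        # are stepped with SparseAdam, while the shared identity/camera
        # parameters are plain dense tensors handled by a separate Adam
        # instance (`optimizer_id`).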
if not is_joint: num_views = len(self.R_base.keys()) bs = batch['normals'].shape[0] * num_views image_lmks68 = landmarks if landmarks is not None: left_iris = batch['left_iris'] right_iris = batch['right_iris'] mask_left_iris = batch['mask_left_iris'] mask_right_iris = batch['mask_right_iris'] else: image_lmks68 = None lmk_mask, normal_mask, normal_map, uv_map, uv_mask = None, None, None, None, None left_iris, right_iris, mask_left_iris, mask_right_iris = None, None, None, None self.diff_renderer.reset() best_loss = np.inf n_steps_stagnant = 0 stagnant_window_size = 10 past_k_steps = np.array([100.0 for _ in range(stagnant_window_size)]) iterator = tqdm(range(iters), desc='', leave=True, miniters=100) for p in iterator: if is_joint and p == int(iters*0.5): for pgroup in optimizer.param_groups: if pgroup['name'] in ['t', 'R', 'jaw']: pgroup['lr'] = pgroup['lr'] / 10 print(f'LR Reduce at iter {p}, for pgroup {pgroup["name"]}') else: pgroup['lr'] = pgroup['lr'] / 2 if is_joint and p == int(iters *0.75): for pgroup in optimizer.param_groups: if pgroup['name'] in ['t', 'R', 'jaw']: pgroup['lr'] = pgroup['lr'] / 5 print(f'LR Reduce at iter {p}, for pgroup {pgroup["name"]}') else: pgroup['lr'] = pgroup['lr'] / 2 if is_joint and p == int(iters *0.9): for pgroup in optimizer.param_groups: if pgroup['name'] in ['t', 'R', 'jaw']: pgroup['lr'] = pgroup['lr'] / 2 print(f'LR Reduce at iter {p}, for pgroup {pgroup["name"]}') else: pgroup['lr'] = pgroup['lr'] / 5 batch_joint, losses, vertices, vertices_noneck, vertices_can, vertices_can_can, proj_vertices, proj_lmks68, selected_frames, variables, num_views, normal_mask, normal_map, uv_map, uv_mask = self.opt_pre(is_joint, iters, p, no_lm, image_lmks68, lmk_mask, normal_mask, normal_map, uv_map, uv_mask, left_iris, right_iris, mask_left_iris, mask_right_iris) if is_joint: batch = batch_joint timestep = 0 ops = self.diff_renderer(vertices, None, None, self.r_mvps[timestep], self.R_base[timestep], self.t_base[timestep], texture_observation_mask=self.texture_observation_mask, verts_can=vertices_can, verts_noneck=vertices_noneck, verts_can_can=vertices_can_can, verts_depth=proj_vertices[:, :, 2:3], ) all_loss = self.opt_post(variables, ops, proj_vertices, proj_lmks68, batch, is_joint, is_first_step, losses, uv_map, selected_frames, p, iters, num_views, normal_mask, normal_map) #vertices.retain_grad() #if not self.init_done: all_loss.backward()#retain_graph=True) optimizer.step() optimizer.zero_grad() if is_joint: optimizer_id.step() optimizer_id.zero_grad() #if p == 0 or p == iters-1: #if p == iters-1:# and not self.config.low_overhead and False: #wandb.log(losses) self.global_step += 1 loss_color = all_loss.item() if loss_color < best_loss - 1.0: best_loss = loss_color n_steps_stagnant = 0 elif p > 25: # only start counting after n steps n_steps_stagnant += 1 if p > 0: past_k_steps[p%stagnant_window_size] = np.abs(all_loss.item() - prev_loss) prev_loss = all_loss.item() if (self.frame % 99 == 0 or p < 10) and is_joint: pass #with torch.no_grad(): # intrinsics = get_intrinsics(focal_length, principal_point, use_hack=False) #proj_512 = nvdiffrast_util.intrinsics2projection(intrinsics, # znear=0.1, zfar=5, # width=512, # height=512) #for serial in self.cam_pose_nvd.keys(): # extr = get_extrinsics(self.R_base[serial], self.t_base[serial]) # r_mvps = torch.matmul(proj_512, extr.repeat(bs, 1, 1)) # self.r_mvps[serial] = r_mvps #self.checkpoint(batch, visualizations=[[View.GROUND_TRUTH, View.LANDMARKS, View.SHAPE_OVERLAY]], # frame_dst='/debug_joint', save=False, 
dump_directly=True, timestep=p, selected_frames=selected_frames, is_final=True) self.frame += 1 iterator.set_description(f'Timestep {save_timestep}; Loss {all_loss.item():.4f}') #if n_steps_stagnant > 35 and not is_joint: # print('Early Stopping, go to next frame!') # #break if not is_joint and not is_first_step: if p > stagnant_window_size and np.mean(past_k_steps) < self.config.early_stopping_delta: #3.0: #3.0: print('Early Stopping, go to next frame!') #losses['early_stopping'] = past_k_steps #wandb.log(losses) #wandb.log({'early_stopping': wandb.Histogram(past_k_steps)}) break #print('rate of change', np.mean(past_k_steps)) def render_and_save(self, batch, visualizations=[[View.GROUND_TRUTH, View.LANDMARKS, View.HEATMAP], [View.COLOR_OVERLAY, View.SHAPE_OVERLAY, View.SHAPE]], frame_dst='/video', save=True, dump_directly=False, outer_iter = None, is_camera : bool = False, all_keyframes : bool = False, timestep : int = 0, is_final : bool = False, selected_frames = None, ): batch = self.to_cuda(batch) images, landmarks, _ = self.parse_landmarks(batch) if 'uv_map' in batch: uv_map = batch['uv_map'] uv_mask = batch['uv_mask'] uv_map[(1-uv_mask).bool()] = 0 else: uv_map = None uv_mask = None if 'normals' in batch: normal_map = batch['normals'] else: normal_map = None if 'normal_map_can' in batch: normal_map_can = batch['normal_map_can'] else: normal_map_can = None savefolder = self.save_folder + self.actor_name + frame_dst num_keyframes = 1#1 + len(self.config.keyframes) with torch.no_grad(): self.diff_renderer.reset() num_views = len(self.R_base.keys()) bs = batch['normals'].shape[0] * num_keyframes #self.shape.shape[0] if selected_frames is None: exp = self.exp eyes = self.eyes eyelids = self.eyelids R = self.R t = self.t jaw = self.jaw neck = self.neck focal_length = self.focal_length principal_point = self.principal_point else: exp = self.exp(selected_frames) eyes = self.eyes(selected_frames) eyelids = self.eyelids(selected_frames) R = self.R(selected_frames) t = self.t(selected_frames) jaw = self.jaw(selected_frames) neck = self.neck(selected_frames) if not self.config.global_camera: focal_length = self.focal_length(selected_frames) principal_point = self.principal_point(selected_frames) else: focal_length = self.focal_length principal_point = self.principal_point with torch.no_grad(): intrinsics = get_intrinsics(focal_length, principal_point, use_hack=False, size=self.config.size) proj_512 = nvdiffrast_util.intrinsics2projection(intrinsics, znear=0.1, zfar=5, width=self.config.size, height=self.config.size) for serial in self.cam_pose_nvd.keys(): extr = get_extrinsics(self.R_base[serial], self.t_base[serial]) r_mvps = torch.matmul(proj_512, extr.repeat(bs, 1, 1)) self.r_mvps[serial] = r_mvps vertices_can, _lmk68, lmkMP, vertices_can_can, vertices_noneck = self.flame( #cameras=torch.inverse(self.R_base[0]), cameras=torch.inverse(self.R_base[0]).repeat(bs, 1, 1), shape_params=self.shape.repeat(bs, 1), expression_params=exp.repeat_interleave(num_views, dim=0), #torch.from_numpy(self.expression_params[:1, :]).cuda().repeat(bs, 1), #self.exp, eye_pose_params=eyes.repeat_interleave(num_views, dim=0), #euler_angles_to_matrix(x_opts['rotation'][i], 'XYZ') jaw_pose_params=jaw.repeat_interleave(num_views, dim=0), #matrix_to_rotation_6d(euler_angles_to_matrix(torch.from_numpy(self.jaw_params[:1, :]).cuda(), 'XYZ')).repeat(bs, 1), #self.jaw, neck_pose_params=neck.repeat_interleave(num_views, dim=0), #matrix_to_rotation_6d(euler_angles_to_matrix(torch.from_numpy(self.jaw_params[:1, :]).cuda(), 
'XYZ')).repeat(bs, 1), #self.jaw, eyelid_params=eyelids.repeat_interleave(num_views, dim=0), rot_params_lmk_shift=(matrix_to_rotation_6d(torch.inverse(rotation_6d_to_matrix(R)))).repeat_interleave(num_views, dim=0), ) lmk68 = torch.einsum('bny,bxy->bnx', _lmk68, rotation_6d_to_matrix(R.repeat_interleave(num_views, dim=0))) + t.repeat_interleave(num_views, dim=0).unsqueeze(1) vertices = torch.einsum('bny,bxy->bnx', vertices_can, rotation_6d_to_matrix(R.repeat_interleave(num_views, dim=0))) + t.repeat_interleave(num_views, dim=0).unsqueeze(1) vertices_noneck = torch.einsum('bny,bxy->bnx', vertices_noneck, rotation_6d_to_matrix(R.repeat_interleave(num_views, dim=0))) + t.repeat_interleave(num_views, dim=0).unsqueeze(1) lmk68 = project_points_screen_space(lmk68, focal_length, principal_point, self.R_base, self.t_base, size=self.config.size) proj_vertices = project_points_screen_space(vertices, focal_length, principal_point, self.R_base, self.t_base, size=self.config.size) _timestep = 0 ops = self.diff_renderer(vertices, None, None, self.r_mvps[_timestep], self.R_base[_timestep], self.t_base[_timestep], verts_can=vertices_can, verts_noneck=vertices_noneck, verts_depth=proj_vertices[:, :, 2:3], is_viz=True ) # if they asked *only* for the pure shape mask: if visualizations == [[View.SHAPE]]: # build your normal‐map preview as before normals = ops['normal_images'][0].cpu().numpy() # [3,H,W] normals = (normals + 1.0) / 2.0 # → [0,1] normals = np.transpose(normals, (1,2,0)) # H×W×3 arr = (normals * 255).clip(0,255).astype(np.uint8) # --- export the posed mesh, using the correct face indices field --- os.makedirs(self.mesh_folder, exist_ok=True) frame_id = str(0).zfill(5) ply_path = os.path.join(self.mesh_folder, f"{frame_id}.glb") # pull out the face index tensor faces_np = self.faces.verts_idx.cpu().numpy() # `vertices` is your posed mesh: shape (1, V, 3) verts_np = vertices[0].detach().cpu().numpy() # 1) build your mesh (this will compute smooth normals automatically) mesh = trimesh.Trimesh(vertices=verts_np, faces=faces_np) # 2) fetch those normals: shape is (V,3), each component in [-1,1] normals = mesh.vertex_normals # (V,3) numpy array # 3) convert them to RGB in [0,255]: # (n+1)/2 maps [-1,1]→[0,1], then *255→[0,255] colors = ((normals + 1.0) * 0.5 * 255.0).astype(np.uint8) # (V,3) # 4) you need RGBA for many formats—just set alpha=255 alpha = np.full((colors.shape[0],1), 255, dtype=np.uint8) vertex_colors = np.hstack([colors, alpha]) # (V,4) # 5) assign those as your mesh’s visual colors mesh.visual.vertex_colors = vertex_colors # 6) export—PLY or GLB both support vertex colors out_path = os.path.join(self.mesh_folder, f"{frame_id}.glb") mesh.export(out_path) return arr mask = (self.parse_mask(ops, batch, visualization=True) > 0).float() grabbed_depth = ops['actual_rendered_depth'][0, 0, torch.clamp(proj_vertices[0, :, 1].long(), 0, self.config.size-1), torch.clamp(proj_vertices[0, :, 0].long(), 0, self.config.size-1), ] is_visible_verts_idx = grabbed_depth < proj_vertices[0, :, 2] + 1e-2 if not self.config.occ_filter: is_visible_verts_idx = torch.ones_like(is_visible_verts_idx) if outer_iter is None: frame_id = str(self.frame).zfill(5) else: frame_id = str(self.frame + 10*outer_iter).zfill(5) if uv_map is not None and is_final: # uv losses visualizations proj_vertices = proj_vertices[:, self.uv_loss_fn.valid_vertex_index, :] can_uv = torch.from_numpy(np.load(env_paths.FLAME_UV_COORDS)).cuda().unsqueeze(0).float()[:, self.uv_loss_fn.valid_vertex_index, :] valid_verts_visibility = 
is_visible_verts_idx[self.uv_loss_fn.valid_vertex_index] #can_uv[..., 0] = (can_uv[..., 0] * -1) + 1 can_uv[..., 1] = (can_uv[..., 1] * -1) + 1 #can_uv = can_uv[:, ::50, :] gt_uv = uv_map[:, :2, :, :].permute(0, 2, 3, 1) gt_uv = gt_uv.reshape(gt_uv.shape[0], -1, 2) # B x n_pixel x 2 can_uv = can_uv.repeat(gt_uv.shape[0], 1, 1) knn_result = knn_points(can_uv, gt_uv) pixel_position_width = knn_result.idx % uv_map.shape[-1] pixel_position_height = knn_result.idx // uv_map.shape[-2] dists = knn_result.dists.clone() gt_2_verts = torch.cat([pixel_position_width, pixel_position_height], dim=-1) pred_normals = ops['normal_images'] # 1 3 512 512 normals in world space rot_mat = rotation_6d_to_matrix(R.detach().repeat_interleave(num_views, dim=0)) # 1 3 3 pred_normals_flame_space = torch.einsum('bxy,bxhw->byhw', rot_mat, pred_normals) delta = self.config.uv_loss.delta_uv catted_uv_rows = [] for b_i in range(images.shape[0]): empty = images[b_i].detach().cpu().numpy().copy().transpose(1, 2, 0) is_valid_uv_corresp = (dists[b_i, :, 0] < delta) & valid_verts_visibility valid_pred_2d = proj_vertices[b_i, is_valid_uv_corresp, :] valid_gt_2d = gt_2_verts[b_i, is_valid_uv_corresp, :] pixels_pred = torch.stack( [ torch.clamp(valid_pred_2d[:, 0], 0, images.shape[-1] - 1), torch.clamp(valid_pred_2d[:, 1], 0, images.shape[-2] - 1), ], dim=-1 ).int() pixels_gt = torch.stack( [ torch.clamp(valid_gt_2d[:, 0], 0, images.shape[-1] - 1), torch.clamp(valid_gt_2d[:, 1], 0, images.shape[-2] - 1), ], dim=-1 ).int() if self.config.draw_uv_corresp: empty = plot_points(empty, pts=pixels_pred.detach().cpu().numpy(), pts2=pixels_gt.detach().cpu().numpy()) gt_uv = uv_map[:, :2, :, :].permute(0, 2, 3, 1) upper_forehead = ((uv_map[:, 0, :, :].abs() < 0.85) & (uv_map[:, 0, :, :].abs() > (1 - 0.85)) & (uv_map[:, 1, :, :] < 0.35) & (uv_map[:, 1, :, :] > 0.)).float() upper_forehead = (gaussian_blur(upper_forehead, [self.config.normal_mask_ksize, self.config.normal_mask_ksize], sigma=[self.config.normal_mask_ksize, self.config.normal_mask_ksize]) > 0).float() losses_sil = ((1 - upper_forehead[:, None, :, :]) * (batch['fg_mask'] - ops['fg_images'])).abs().permute(0, 2, 3, 1) uv_loss = ((gt_uv - ops['uv_images']) * ops['mask_images'][:, 0, ...].unsqueeze(-1)).abs() #catted_uv = torch.cat([gt_uv[b_i], ops['uv_images'][b_i], uv_loss[b_i]], dim=1).detach().cpu().numpy() catted_uv = torch.cat([losses_sil[b_i][..., :2], uv_loss[b_i]], dim=1).detach().cpu().numpy() catted_uv_I = np.zeros([catted_uv.shape[0], catted_uv.shape[1], 3]) catted_uv_I[:, :, :2] = catted_uv catted_uv_I = (catted_uv_I * 255).astype(np.uint8) shape_mask = ((ops['alpha_images'] * ops['mask_images_mesh']) > 0.).int()[b_i] shape = (pred_normals_flame_space[b_i]+1)/2 * shape_mask blend = images[b_i] * (1 - shape_mask) + images[b_i] * shape_mask * 0.3 + shape * 0.7 * shape_mask to_be_catted = [(images[b_i].cpu().permute(1, 2, 0).numpy()*255).astype(np.uint8), (blend.permute(1, 2, 0).detach().cpu().numpy()*255).astype(np.uint8), ] if self.config.draw_uv_corresp: to_be_catted.append(catted_uv_I) to_be_catted.append(empty) catted_uv_I = np.concatenate(to_be_catted, axis=1) catted_uv_rows.append(catted_uv_I) if normal_map is None: catted_uv_I = Image.fromarray(np.concatenate(catted_uv_rows, axis=0)) #pl = pv.Plotter() #pl.add_mesh(trim) #pl.add_points(visible_verts) #pl.show() else: catted_uv_I = None catted_uv_rows = [] if normal_map is not None: dilated_eye_mask = 1 - (gaussian_blur(ops['mask_images_eyes'], [self.config.normal_mask_ksize, self.config.normal_mask_ksize], 
sigma=[1, 1]) > 0).float() l_map = (normal_map - pred_normals_flame_space) valid = ((l_map.abs().sum(dim=1)/3) < self.config.delta_n).unsqueeze(1) predicted_normal = ((pred_normals_flame_space.permute(0, 2, 3, 1)[..., :3] + 1) / 2 * 255).detach().cpu().numpy().astype(np.uint8) if self.config.draw_uv_corresp: normal_loss_map = l_map * valid.float() * batch["normal_mask"] * dilated_eye_mask pseudo_normal = ((normal_map.permute(0, 2, 3, 1) + 1) / 2 * 255).detach().cpu().numpy().astype( np.uint8) normal_loss_map = ( (normal_loss_map.abs().permute(0, 2, 3, 1)) / 2 * 255).detach().cpu().numpy().astype( np.uint8) catted = np.concatenate([pseudo_normal, predicted_normal, normal_loss_map], axis=2) else: catted = predicted_normal # Image.fromarray(catted).show() # print('hi') for b_i in range(catted.shape[0]): if len(catted_uv_rows) > 0: catted_uv_rows[b_i] = np.concatenate([catted_uv_rows[b_i], catted[b_i]], axis=1) else: catted_uv_rows.append(catted[b_i]) catted_uv_I = Image.fromarray(np.concatenate(catted_uv_rows, axis=0)) #if catted_uv_I is not None: # save_fodler_uv = f'{savefolder}' # os.makedirs(save_fodler_uv, exist_ok=True) # if is_final: # catted_uv_I.save(f'{save_fodler_uv}/{timestep}.png') # else: # catted_uv_I.save(f'{save_fodler_uv}/{self.frame}.png') if not save: return # CHECKPOINT self.save_checkpoint(timestep, selected_frames=selected_frames) return catted_uv_I def parse_landmarks(self, batch): images = batch['rgb'] if 'lmk' in batch: landmarks = batch['lmk'] lmk68 = landmarks[:, WFLW_2_iBUG68, :] lmk_mask = ~(lmk68.sum(2, keepdim=True) == 0) batch['left_iris'] = landmarks[:, 96:97, :] batch['right_iris'] = landmarks[:, 97:98, :] batch['mask_left_iris'] = ~(landmarks.sum(2, keepdim=True) == 0)[:, 96:97, :] batch['mask_right_iris'] = ~(landmarks.sum(2, keepdim=True) == 0)[:, 97:98, :] landmarks = lmk68 else: landmarks = lmk_mask = None return images, landmarks, lmk_mask, def read_data(self, timestep): DATA_FOLDER = f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}' P3DMM_FOLDER = f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}/p3dmm/' try: rgb = np.array(Image.open(f'{DATA_FOLDER}/cropped/{timestep:05d}.jpg').resize((self.config.size, self.config.size))) / 255 except Exception as ex: rgb = np.array(Image.open(f'{DATA_FOLDER}/cropped/{timestep:05d}.png').resize((self.config.size, self.config.size))) / 255 mica_folder = f'{DATA_FOLDER}/mica' mica_files = os.listdir(mica_folder) mica_shapes = [] for mica_file in mica_files: mica_shape = np.load(f'{mica_folder}/{mica_file}/identity.npy') mica_shapes.append(np.squeeze(mica_shape)) mica_shapes = np.stack(mica_shapes, axis=0) if self.config.early_exit: mica_shape = mica_shapes[0, :] else: mica_shape = np.mean(mica_shapes, axis=0) seg = np.array(Image.open(f'{DATA_FOLDER}/seg_og/{timestep:05d}.png').resize((self.config.size, self.config.size), Image.NEAREST)) if len(seg.shape) == 3: seg = seg[..., 0] uv_mask = ((seg == 2) | (seg == 6) | (seg == 7) | (seg == 10) | (seg == 12) | (seg == 13) | (seg==1) | # neck (seg == 4) | (seg==5) # ears ) normal_mask = ((seg == 2) | (seg == 6) | (seg == 7) | (seg == 10) | (seg == 12) | (seg == 13) ) | (seg == 11) # mouth interior if self.config.big_normal_mask: normal_mask = normal_mask | (seg==1) | (seg == 4) | (seg==5) # add neck and ears fg_mask = ((seg == 2) | (seg == 6) | (seg == 7) | (seg == 8) | (seg == 9) | #(seg == 4) | (seg == 5) | (seg == 10) | (seg == 12) | (seg == 13) ) valid_bg = seg <= 1 try: normals = 
((np.array(Image.open(f'{P3DMM_FOLDER}/normals/{timestep:05d}.png').resize((self.config.size, self.config.size))) / 255).astype(np.float32) - 0.5 )*2 uv_map = (np.array(Image.open(f'{P3DMM_FOLDER}/uv_map/{timestep:05d}png').resize((self.config.size, self.config.size))) / 255).astype(np.float32) except Exception as ex: normals = ((np.array(Image.open(f'{P3DMM_FOLDER}/normals/{timestep:05d}.png').resize((self.config.size, self.config.size))) / 255).astype( np.float32) - 0.5) * 2 uv_map = (np.array(Image.open(f'{P3DMM_FOLDER}/uv_map/{timestep:05d}.png').resize((self.config.size, self.config.size))) / 255).astype(np.float32) try: lms = np.load(f'{DATA_FOLDER}/PIPnet_landmarks/{timestep:05d}.npy') * self.config.size except Exception as ex: lms = np.zeros([98, 2]) ret_dict = { 'rgb': rgb, 'mica_shape': mica_shape, 'normals': normals, 'uv_map': uv_map, 'uv_mask': uv_mask, 'normal_mask': normal_mask, 'fg_mask': fg_mask, 'valid_bg': valid_bg, } if lms is not None: ret_dict['lmk'] = lms ret_dict = {k: torch.from_numpy(v).float().unsqueeze(0).cuda() for k,v in ret_dict.items()} ret_dict['uv_mask'] = ret_dict['uv_mask'][:, :, :, None].repeat(1, 1, 1, 3) ret_dict['normal_mask'] = ret_dict['normal_mask'][:, :, :, None].repeat(1, 1, 1, 3) ret_dict['fg_mask'] = ret_dict['fg_mask'][:, :, :, None].repeat(1, 1, 1, 3) channels_first =['rgb', 'uv_mask', 'normal_mask', 'normals', 'uv_map', 'fg_mask'] for k in channels_first: ret_dict[k] = ret_dict[k].permute(0, 3, 1, 2) return ret_dict def prepare_global_optimization(self, N_FRAMES): is_sparse=True self.exp = nn.Embedding(num_embeddings=N_FRAMES, embedding_dim=100, sparse=is_sparse, ).cuda() self.R = nn.Embedding(num_embeddings=N_FRAMES, embedding_dim=6, sparse=is_sparse).cuda() self.t = nn.Embedding(num_embeddings=N_FRAMES, embedding_dim=3, sparse=is_sparse).cuda() self.eyes = nn.Embedding(num_embeddings=N_FRAMES, embedding_dim=12, sparse=is_sparse).cuda() self.eyelids = nn.Embedding(num_embeddings=N_FRAMES, embedding_dim=12, sparse=is_sparse).cuda() self.jaw = nn.Embedding(num_embeddings=N_FRAMES, embedding_dim=6, sparse=is_sparse).cuda() self.neck = nn.Embedding(num_embeddings=N_FRAMES, embedding_dim=6, sparse=is_sparse).cuda() if not self.config.global_camera: self.focal_length = nn.Embedding(num_embeddings=N_FRAMES, embedding_dim=1, sparse=is_sparse).cuda() self.principal_point = nn.Embedding(num_embeddings=N_FRAMES, embedding_dim=2, sparse=is_sparse).cuda() exp = torch.cat(self.intermediate_exprs, dim=0) R = torch.cat(self.intermediate_Rs, dim=0) t = torch.cat(self.intermediate_ts, dim=0) eyes = torch.cat(self.intermediate_eyes, dim=0) eyelids = torch.cat(self.intermediate_eyelids, dim=0) jaw = torch.cat(self.intermediate_jaws, dim=0) neck = torch.cat(self.intermediate_necks, dim=0) if not self.config.global_camera: focal_length = torch.cat(self.intermediate_fls, dim=0) principal_point = torch.cat(self.intermediate_pps, dim=0) with torch.no_grad(): self.exp.weight = torch.nn.Parameter(exp) self.R.weight = torch.nn.Parameter(R) self.t.weight = torch.nn.Parameter(t) self.eyes.weight = torch.nn.Parameter(eyes) self.eyelids.weight = torch.nn.Parameter(eyelids) self.jaw.weight = torch.nn.Parameter(jaw) self.neck.weight = torch.nn.Parameter(neck) if not self.config.global_camera: self.focal_length.weight = torch.nn.Parameter(focal_length) self.principal_point.weight = torch.nn.Parameter(principal_point) def run(self): timestep = self.config.start_frame batch = self.read_data(timestep=timestep) # Important to initialize self.create_parameters(0, 
    def run(self):
        timestep = self.config.start_frame
        batch = self.read_data(timestep=timestep)
        # Important to initialize identity from the MICA shape estimate first.
        self.create_parameters(0, batch['mica_shape'])
        self.frame = 0

        print('''
            <<<<<<<< STARTING ONLINE TRACKING PHASE >>>>>>>>
        ''')
        for timestep in range(self.config.start_frame, self.MAX_STEPS + self.config.start_frame, self.FRAME_SKIP):
            batch = self.read_data(timestep=timestep)
            for k in batch.keys():
                if k not in self.cached_data:
                    self.cached_data[k] = [batch[k]]
                else:
                    self.cached_data[k].append(batch[k])

            if timestep == self.config.start_frame:
                self.optimize_camera(batch, steps=500, is_first_frame=True)
                params = lambda: self.clone_params_keyframes_all(freeze_id=False,
                                                                 freeze_cam=self.config.global_camera,
                                                                 include_neck=self.config.include_neck)
                is_first_step = True
            else:
                if self.config.extra_cam_steps:
                    self.optimize_camera(batch, steps=10, is_first_frame=False)
                params = lambda: self.clone_params_keyframes_all(freeze_id=True,
                                                                 freeze_cam=self.config.global_camera,
                                                                 include_neck=self.config.include_neck)
                is_first_step = False

            self.optimize_color(batch, params, no_lm=self.no_lm, save_timestep=timestep, is_first_step=is_first_step)
            self.uv_loss_fn.is_next()
            # self.checkpoint(batch, visualizations=[[View.GROUND_TRUTH, View.COLOR_OVERLAY, View.LANDMARKS, View.SHAPE]],
            #                 frame_dst='/initialization', outer_iter=0, timestep=timestep, is_final=True, save=True)
            self.frame += 1

            # save per-frame results for the global optimization later
            self.intermediate_exprs.append(self.exp.detach().clone())
            self.intermediate_Rs.append(self.R.detach().clone())
            self.intermediate_ts.append(self.t.detach().clone())
            self.intermediate_eyes.append(self.eyes.detach().clone())
            self.intermediate_eyelids.append(self.eyelids.detach().clone())
            self.intermediate_jaws.append(self.jaw.detach().clone())
            self.intermediate_necks.append(self.neck.detach().clone())
            if not self.config.global_camera:
                self.intermediate_fls.append(self.focal_length.detach().clone())
                self.intermediate_pps.append(self.principal_point.detach().clone())

            if self.config.early_exit:
                exit()

        for k in self.cached_data.keys():
            self.cached_data[k] = torch.cat(self.cached_data[k], dim=0)

        params = lambda: self.clone_params_keyframes_all_joint(freeze_id=False, is_joint=True,
                                                               include_neck=self.config.include_neck)
        if self.config.uv_map_super > 0.0:
            self.uv_loss_fn.finish_stage1()
        self.config.iters = self.config.global_iters  # self.config.iters * 10

        N_FRAMES = len(self.intermediate_exprs)
        # Build optimization targets for the global optimization, implemented as sparse nn.Embedding tables.
        self.prepare_global_optimization(N_FRAMES=N_FRAMES)

        if COMPILE:
            self.flame = torch.compile(self.flame)
            self.opt_pre = torch.compile(self.opt_pre)
            self.opt_post = torch.compile(self.opt_post)

        print('''
            <<<<<<<< STARTING GLOBAL TRACKING PHASE >>>>>>>>
        ''')
        if N_FRAMES > 1:
            self.optimize_color(None, params, no_lm=self.no_lm,
                                save_timestep=1000,  # timestep
                                is_joint=True)

        # Render the result and save it as a video to get some visual feedback.
        video_frames = []
        for it, timestep in enumerate(range(self.config.start_frame, self.MAX_STEPS + self.config.start_frame, self.FRAME_SKIP)):
            selected_frames = []
            selected_frames_loading = []
            batches = []
            batch = self.read_data(timestep=timestep)
            batches.append(batch)
            selected_frames.append(it)
            selected_frames_loading.append(timestep)
            batches = {k: torch.cat([x[k] for x in batches], dim=0) for k in batch.keys()}
            selected_frames = torch.from_numpy(np.array(selected_frames)).long().cuda()
            result_rendering = self.render_and_save(batch,
                                                    visualizations=[[View.SHAPE]],  # only mesh by default
                                                    frame_dst='/video',
                                                    save=True,
                                                    dump_directly=False,
                                                    outer_iter=0,
                                                    timestep=timestep,
                                                    is_final=True,
                                                    selected_frames=selected_frames)
            video_frames.append(np.array(result_rendering))
            self.frame += 1
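        # The collected frames could also be written directly as a video; a
        # minimal sketch using the already-imported mediapy (output path and
        # fps here are assumptions, the code below only dumps individual JPEGs):
        #
        #   mediapy.write_video(f"{self.save_folder}/{self.config.video_name}/result.mp4",
        #                       video_frames, fps=25)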
        out_dir = f"{self.save_folder}/{self.config.video_name}/frames"
        os.makedirs(out_dir, exist_ok=True)
        for i, frame in enumerate(video_frames):
            # If float in [0, 1], convert to uint8:
            if frame.dtype != np.uint8:
                frame_uint8 = (frame * 255).astype(np.uint8)
            else:
                frame_uint8 = frame
            # OpenCV expects BGR channel ordering:
            bgr = cv2.cvtColor(frame_uint8, cv2.COLOR_RGB2BGR)
            cv2.imwrite(os.path.join(out_dir, f"{i:05d}.jpg"), bgr)
        print(f"✅ Saved {len(video_frames)} frames to `{out_dir}`")

        # Optionally delete all preprocessing artifacts once tracking is done (only keep cropped images).
        if self.config.delete_preprocessing:
            for subfolder in ['mica', 'p3dmm', 'p3dmm_wGT', 'p3dmm_extraViz', 'pipnet',
                              'PIPnet_annotated_images', 'PIPnet_landmarks', 'rgb',
                              'seg_non_crop_annotations', 'seg_og']:
                shutil.rmtree(f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}/{subfolder}')

        print(f'''
            <<<<<<<< DONE WITH TRACKING {self.actor_name} >>>>>>>>
        ''')
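# Minimal usage sketch (illustrative only): the real entry point builds `config`
# via the project's configuration system, and the FLAME / NVDRenderer
# constructor arguments below are assumptions, not this file's API.
#
#   flame = FLAME(config)
#   renderer = NVDRenderer(config)
#   tracker = Tracker(config, flame_module=flame, renderer=renderer)
#   tracker.run()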