# Web-page residue from the hosting UI, commented out so the module parses:
# alexnasa's picture
# Update src/pixel3dmm/tracking/tracker.py
# 219e22c verified
import shutil
import mediapy
from PIL import Image, ImageDraw
import os.path
from enum import Enum
from pathlib import Path
import wandb
import time
import cv2
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import trimesh
from pytorch3d.io import load_obj
from pytorch3d.ops import knn_points, knn_gather
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from torchvision.transforms.functional import gaussian_blur
from time import time
import pyvista as pv
import dreifus
from dreifus.matrix import Pose, Intrinsics, CameraCoordinateConvention, PoseType
from dreifus.pyvista import add_camera_frustum, render_from_camera
from pixel3dmm import env_paths
from pixel3dmm.tracking import util
from pixel3dmm.tracking.losses import UVLoss
from pixel3dmm.tracking import nvdiffrast_util
from pixel3dmm.tracking.renderer_nvdiffrast import NVDRenderer
from pixel3dmm import env_paths
from pixel3dmm.tracking.flame.FLAME import FLAME
from pixel3dmm.utils.misc import tensor2im
from pixel3dmm.utils.utils_3d import rotation_6d_to_matrix, matrix_to_rotation_6d, euler_angles_to_matrix
from pixel3dmm.utils.drawing import plot_points
def timeit(t0, tag):
    """Profiling stub: return the current time.

    The elapsed-time print for *tag* is disabled; callers chain the return
    value as the next ``t0``.
    """
    now = time()
    # print(f'[PROFILER]: {tag} took {now - t0} seconds')
    return now
# Allow OpenCV to read/write EXR images (must be set before cv2 touches EXR files).
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
# Fixed seed for reproducible tracking runs.
rank = 42
torch.manual_seed(rank)
torch.cuda.manual_seed(rank)
cudnn.benchmark = True  # autotune conv kernels; input sizes are fixed per run
np.random.seed(rank)
# Identity rotation (1, 3, 3) and its 6D representation, used as regularization targets below.
I = torch.eye(3)[None].cuda().detach()
I6D = matrix_to_rotation_6d(I)
# Iris vertex indices on the FLAME mesh and the matching MediaPipe landmark
# indices (left/right eye); only the first FLAME index per eye is used in losses.
left_iris_flame = [4597, 4542, 4510, 4603, 4570]
right_iris_flame = [4051, 3996, 3964, 3932, 4028]
left_iris_mp = [468, 469, 470, 471, 472]
right_iris_mp = [473, 474, 475, 476, 477]
# Trade a little float32 matmul precision for speed (TF32 on Ampere+).
torch.set_float32_matmul_precision('high')
class View(Enum):
    """Visualization channels for checkpoint rendering.

    Values are powers of two (flag-like) -- presumably so views could be
    combined bitwise, though call sites visible here pass lists of members
    instead; TODO confirm.
    """
    GROUND_TRUTH = 1
    COLOR_OVERLAY = 2
    SHAPE_OVERLAY = 4
    SHAPE = 8
    LANDMARKS = 16
    HEATMAP = 32
    DEPTH = 64
def get_intrinsics(focal_length, principal_point, use_hack : bool = True, size : int = 512):
    """Build a batch of 3x3 pixel-space intrinsics matrices on the GPU.

    Args:
        focal_length: (B, 1) focal length normalized by image size -- assumed
            from ``create_parameters`` init; confirm for other callers.
        principal_point: (B, 2) principal-point offset in normalized units,
            mapped to pixels around the image center.
        use_hack: mirror cx horizontally (``size - cx``); flagged TODO by the
            original authors.
        size: square image resolution in pixels.

    Returns:
        (B, 3, 3) float32 CUDA tensor.
    """
    intrinsics = torch.eye(3)[None, ...].float().cuda().repeat(focal_length.shape[0], 1, 1)
    intrinsics[:, 0, 0] = focal_length.squeeze() * size
    intrinsics[:, 1, 1] = focal_length.squeeze() * size
    # principal point mapped from normalized units to pixel coordinates
    intrinsics[:, :2, 2] = size/2+0.5 + principal_point * (size/2+0.5)
    if use_hack:
        intrinsics[:, 0:1, 2:3] = size - intrinsics[:, 0:1, 2:3]  # TODO fix this hack
    return intrinsics
def get_extrinsics(R_base, t_base, timestep: int = 0):
    """Assemble a 4x4 OpenGL-convention world-to-camera matrix.

    Generalized: the timestep was previously hard-coded to 0; it is now a
    parameter whose default preserves the old behavior for all callers.

    Args:
        R_base: mapping/sequence of (1, 3, 3) rotation tensors keyed by timestep.
        t_base: mapping/sequence of (1, 3) translation tensors keyed by timestep.
        timestep: which camera entry to use.

    Returns:
        (1, 4, 4) float32 CUDA tensor.
    """
    w2c_openGL = torch.eye(4)[None, ...].float().cuda()
    w2c_openGL[:, :3, :3] = R_base[timestep]
    w2c_openGL[:, :3, 3] = t_base[timestep]
    return w2c_openGL
def project_points_screen_space(points3d, focal_length, principal_point, R_base, t_base, size : int = 512):
    """Project world-space 3D points into pixel/screen space.

    Args:
        points3d: (B, N, 3) world-space points.
        focal_length, principal_point: normalized camera intrinsics (see
            ``get_intrinsics``).
        R_base, t_base: camera extrinsics dicts; only timestep 0 is used here.
        size: square image resolution in pixels.

    Returns:
        (B, N, 3) tensor: x, y in pixels (x mirrored to ``size - 1 - x``,
        matching the cx hack in ``get_intrinsics``) and camera-space z in the
        last channel. The sign flips account for the OpenGL-style camera
        looking down -z -- assumed; confirm against the renderer convention.
    """
    # construct camera matrices
    intrinsics = get_intrinsics(focal_length, principal_point, size=size)
    w2c_openGL = get_extrinsics(R_base, t_base).repeat(focal_length.shape[0], 1, 1)
    B = points3d.shape[0]
    # broadcast whichever of extrinsics/intrinsics is unbatched to batch size B
    reps_extr = B if w2c_openGL.shape[0] == 1 else 1
    reps_intr = B if intrinsics.shape[0] == 1 else 1
    # apply w2c transformation (homogeneous coordinates)
    lmk68_cam_space = torch.bmm(
        torch.cat([points3d, torch.ones_like(points3d[..., :1])], dim=-1),
        w2c_openGL.permute(0, 2, 1).repeat(reps_extr, 1, 1))
    # project from cam_space to screen_space (perspective divide by -z)
    lmk68_cam_space_prime = lmk68_cam_space[..., :3] / -lmk68_cam_space[..., [2]]
    lmk68_screen_space = (-1) * torch.bmm(lmk68_cam_space_prime, intrinsics.permute(0, 2, 1).repeat(reps_intr, 1, 1))[..., :2]
    lmk68_screen_space = torch.stack([size - 1 - lmk68_screen_space[..., 0], lmk68_screen_space[..., 1], lmk68_cam_space[..., 2]], dim=-1)
    return lmk68_screen_space
# Index map selecting the 68 iBUG landmarks out of the 98 WFLW landmarks.
WFLW_2_iBUG68 = np.array(
    [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 33, 34, 35, 36, 37, 42, 43, 44, 45, 46, 51,
     52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 63, 64, 65, 67, 68, 69, 71, 72, 73, 75, 76, 77, 78, 79, 80, 81, 82,
     83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95])
WFLW_2_iBUG68 = torch.from_numpy(WFLW_2_iBUG68).cuda()
# Compile the projection hot path with torch.compile when enabled; the same
# flag also gates compilation inside Tracker.
COMPILE = True
if COMPILE:
    project_points_screen_space = torch.compile(project_points_screen_space)
class Tracker(object):
    def __init__(self, config, flame_module, renderer,
                 device='cuda:0',
                 ):
        """Set up a tracking run for one video.

        Args:
            config: run configuration (video_name, batch_size, loss weights,
                output options, ...).
            flame_module: FLAME head-model module used for mesh generation.
            renderer: differentiable (nvdiffrast) renderer instance.
            device: torch device string for the optimization tensors.
        """
        self.config = config
        self.flame = flame_module
        self.diff_renderer = renderer
        self.config = config  # NOTE(review): duplicate of the assignment above; harmless
        self.device = device
        self.actor_name = self.config.video_name
        DATA_FOLDER = f'{env_paths.PREPROCESSED_DATA}/{self.actor_name}'
        # frames available after start_frame, capped at 1000
        self.MAX_STEPS = min(len([f for f in os.listdir(f'{DATA_FOLDER}/cropped/') if f.endswith('.jpg') or f.endswith('.png')]) - self.config.start_frame, 1000)
        self.FRAME_SKIP = 1
        self.BATCH_SIZE = self.config.batch_size
        print(f'''
<<<<<<<< INITIALIZING TRACKER INSTANCE FOR {self.actor_name} >>>>>>>>
''')
        # vertex index permutation that mirrors the FLAME mesh in x
        self.mirror_order = torch.from_numpy(np.load(f'{env_paths.MIRROR_INDEX}')).long().cuda()
        self.uv_loss_fn = UVLoss(stricter_mask=self.config.uv_loss.stricter_uv_mask,
                                 delta_uv=self.config.uv_loss.delta_uv,
                                 dist_uv=self.config.uv_loss.dist_uv)
        if COMPILE:
            self.uv_loss_fn.compute_loss = torch.compile(self.uv_loss_fn.compute_loss)
        # encode the run/ablation options into the output folder name
        self.actor_name = self.actor_name + f'_nV{config.num_views}'
        if config.no_lm:
            self.actor_name = self.actor_name + '_noLM'
        if config.no_pho:
            self.actor_name = self.actor_name + '_noPho'
        if self.config.ignore_mica:
            self.actor_name = self.actor_name + '_noMICA'
        if self.config.flame2023:
            self.actor_name = self.actor_name + '_FLAME23'
        if self.config.uv_map_super > 0:
            self.actor_name = self.actor_name + f'_uv{self.config.uv_map_super}'
        if self.config.normal_super > 0:
            self.actor_name = self.actor_name + f'_n{self.config.normal_super}'
        if self.config.normal_super_can > 0:
            self.actor_name = self.actor_name + f'_nc{self.config.normal_super_can}'
        self.global_step = 0
        self.no_sh = config.no_sh
        self.no_lm = config.no_lm
        self.no_pho = config.no_pho
        # Latter will be set up
        self.frame = 0
        self.is_initializing = False
        self.image_size = torch.tensor([[config.image_size[0], config.image_size[1]]]).cuda()
        if hasattr(self.config, 'output_folder'):
            self.save_folder = self.config.output_folder
        else:
            self.save_folder = env_paths.TRACKING_OUTPUT
        self.output_folder = os.path.join(self.save_folder, self.actor_name)
        self.checkpoint_folder = os.path.join(self.save_folder, self.actor_name, "checkpoint")
        # NOTE(review): mesh_folder uses the raw video_name while checkpoints use
        # the suffixed actor_name -- confirm this asymmetry is intended
        self.mesh_folder = os.path.join(self.save_folder, self.config.video_name, "mesh")
        self.create_output_folders()
        self.writer = SummaryWriter(log_dir=self.save_folder + self.actor_name + '/logs')
        # per-timestep camera poses (nvdiffrast convention) and their R/t parts
        self.cam_pose_nvd = {}
        self.R_base = {}
        self.t_base = {}
        # FLAME region masks; keep the vertex indices of the face region
        flame_mesh_mask = np.load(f'{env_paths.FLAME_MASK_ASSET}/FLAME2020/FLAME_masks/FLAME_masks.pkl', allow_pickle=True, encoding='latin1')
        self.vertex_face_mask = torch.from_numpy(flame_mesh_mask['face']).cuda().long()
        self.setup_renderer()
        # per-frame results accumulated during the sequential stage
        self.intermediate_exprs = []
        self.intermediate_Rs = []
        self.intermediate_ts = []
        self.intermediate_eyes = []
        self.intermediate_eyelids = []
        self.intermediate_jaws = []
        self.intermediate_necks = []
        self.intermediate_fls = []
        self.intermediate_pps = []
        self.cached_data = {}
def get_image_size(self):
return self.image_size[0][0].item(), self.image_size[0][1].item()
def create_output_folders(self):
Path(self.save_folder).mkdir(parents=True, exist_ok=True)
Path(self.checkpoint_folder).mkdir(parents=True, exist_ok=True)
Path(self.mesh_folder).mkdir(parents=True, exist_ok=True)
    def setup_renderer(self):
        """Finalize renderer/FLAME wiring and (optionally) torch.compile hot paths."""
        mesh_file = f'{env_paths.head_template}'
        self.config.image_size = self.get_image_size()
        self.flame.vertex_face_mask = self.vertex_face_mask
        if COMPILE:
            # compile the FLAME forward and the two halves of the optimization step
            self.flame = torch.compile(self.flame)
            self.opt_pre = torch.compile(self.opt_pre)
            self.opt_post = torch.compile(self.opt_post)
            self.actual_smooth = torch.compile(self.actual_smooth)
        self.renderer = self.diff_renderer  # already global
        # faces index tensor from the template OBJ (pytorch3d load_obj -> (verts, faces, aux))
        self.faces = load_obj(mesh_file)[1]
    def save_checkpoint(self, frame_id, selected_frames = None):
        """Serialize FLAME + camera state for *frame_id* to a ``.frame`` file,
        plus optional meshes and landmark arrays.

        Args:
            frame_id: frame index used in output filenames.
            selected_frames: None in per-frame mode (attributes are tensors);
                otherwise frame indices used to query the per-frame
                embedding-style modules (joint mode).
        """
        if selected_frames is None:
            exp = self.exp
            eyes = self.eyes
            eyelids = self.eyelids
            R = self.R
            t = self.t
            jaw = self.jaw
            neck = self.neck
            focal_length = self.focal_length
            principal_point = self.principal_point
        else:
            exp = self.exp(selected_frames)
            eyes = self.eyes(selected_frames)
            eyelids = self.eyelids(selected_frames)
            R = self.R(selected_frames)
            t = self.t(selected_frames)
            jaw = self.jaw(selected_frames)
            neck = self.neck(selected_frames)
            if self.config.global_camera:
                focal_length = self.focal_length
                principal_point = self.principal_point
            else:
                focal_length = self.focal_length(selected_frames)
                principal_point = self.principal_point(selected_frames)
        # everything is detached to CPU numpy so the checkpoint is torch-loadable anywhere
        frame = {
            'flame': {
                'exp': exp.clone().detach().cpu().numpy(),
                'shape': self.shape.clone().detach().cpu().numpy(),
                'eyes': eyes.clone().detach().cpu().numpy(),
                'eyelids': eyelids.clone().detach().cpu().numpy(),
                'jaw': jaw.clone().detach().cpu().numpy(),
                'neck': neck.clone().detach().cpu().numpy(),
                'R': R.clone().detach().cpu().numpy(),
                'R_rotation_matrix': rotation_6d_to_matrix(R).detach().cpu().numpy(),
                't': t.clone().detach().cpu().numpy(),
            },
            'img_size': self.image_size.clone().detach().cpu().numpy()[0],
            'frame_id': frame_id,
            'global_step': self.global_step
        }
        cam_params = {
            f'R_base_{serial}': self.R_base[serial].clone().detach().cpu().numpy() for serial in self.R_base.keys()
        }
        cam_pos = {
            f't_base_{serial}': self.t_base[serial].clone().detach().cpu().numpy() for serial in self.R_base.keys()
        }
        intr = {
            'fl': focal_length.clone().detach().cpu().numpy(),
            'pp': principal_point.clone().detach().cpu().numpy(),
        }
        cam_params.update(cam_pos)
        cam_params.update(intr)
        frame.update(
            {
                f'camera': cam_params
            }
        )
        bs = exp.shape[0]
        # re-evaluate FLAME to export posed meshes / landmarks for this state
        vertices, lmks, joint_transforms, vertices_can, vertices_noneck = self.flame(cameras=torch.inverse(self.R_base[0])[:1, ...].repeat(bs, 1, 1),
                                                                                     shape_params=self.shape[:1, ...].repeat(bs, 1),
                                                                                     expression_params=exp,
                                                                                     eye_pose_params=eyes,
                                                                                     jaw_pose_params=jaw,
                                                                                     neck_pose_params=neck,
                                                                                     rot_params_lmk_shift=R,
                                                                                     eyelid_params=eyelids,
                                                                                     )
        frame.update(
            {
                f'joint_transforms': joint_transforms.detach().cpu().numpy(),
            }
        )
        f = self.diff_renderer.faces[0].cpu().numpy()
        for b_i in range(bs):
            v = vertices[b_i].cpu().numpy()
            if self.config.save_meshes:
                trimesh.Trimesh(faces=f, vertices=v, process=False).export(f'{self.mesh_folder}/{frame_id:05d}.glb')
            # NOTE(review): written once per b_i to the same path -- redundant but harmless
            torch.save(frame, f'{self.checkpoint_folder}/{frame_id:05d}.frame')
            # 7-landmark subset (eye corners, nose tip, mouth corners) -- presumably NoW-style; confirm
            selction_indx = np.array([36, 39, 42, 45, 33, 48, 54])
            _lmks = lmks[b_i].detach().squeeze().cpu().numpy()
            if self.config.save_landmarks:
                np.save(f'{self.mesh_folder}/landmarks_{frame_id}_{b_i}.npy', _lmks[selction_indx])
        if frame_id == self.config.start_frame and self.config.save_meshes:
            faces = self.diff_renderer.faces[0].cpu().numpy()
            trimesh.Trimesh(faces=faces, vertices=vertices_can[0].detach().cpu().numpy(), process=False).export(f'{self.mesh_folder}/canonical.glb')
        if self.config.save_landmarks:
            lmks = lmks.detach().squeeze().cpu().numpy()
            # NOTE(review): np.save writes .npy data; the '.glb' suffix here looks like a typo -- confirm
            np.save(f'{self.mesh_folder}/ibug68_{frame_id}.glb', lmks)
            selction_indx = np.array([36, 39, 42, 45, 33, 48, 54])
            np.save(f'{self.mesh_folder}/now_{frame_id}.glb', lmks[selction_indx])
    def get_heatmap(self, values):
        """Convert an error tensor into a JET-colormapped RGB tensor (3, H, W) in [0, 1].

        Assumes ``tensor2im`` maps *values* to a uint8 image centered near 127
        -- TODO confirm; uint8 subtraction below wraps for pixels < 127, which
        may be relied upon to discard negative errors.
        """
        l2 = tensor2im(values)
        l2 = cv2.cvtColor(l2, cv2.COLOR_RGB2BGR)
        # shift so the neutral value (127) maps to zero error
        l2 = l2 - 127
        # clamp errors to max_err, then rescale to the full uint8 range
        max_err = 25
        l2[l2>max_err] = max_err
        l2 = ((l2 / max_err)*255).astype(np.uint8)
        heatmap = cv2.applyColorMap(l2, cv2.COLORMAP_JET)
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB) / 255.
        # (C, H, W) float tensor for logging/visualization
        heatmap = torch.from_numpy(heatmap).permute(2, 0, 1)
        return heatmap
def to_cuda(self, batch, unsqueeze=False):
for key in batch.keys():
if torch.is_tensor(batch[key]):
batch[key] = batch[key].to(self.device)
if unsqueeze:
batch[key] = batch[key][None]
return batch
    def create_parameters(self, timestep, mica_shape):
        """Initialize all optimizable parameters: camera pose/intrinsics and the
        FLAME codes (shape, expression, jaw, neck, eyes, eyelids) for the
        keyframes (frame 0 + ``config.keyframes``).

        Args:
            timestep: camera slot to initialize in R_base/t_base/cam_pose_nvd.
            mica_shape: identity code predicted by MICA, used both as the
                initialization and as the regularization target.
        """
        bz = 1  # NOTE(review): unused local
        # camera placed 1 unit back on z, then aimed at the origin (y-up)
        pose_mat = np.eye(4)
        pose_mat[2, 3] = -1
        opencv_w2c_pose = Pose(pose_mat, camera_coordinate_convention=dreifus.matrix.CameraCoordinateConvention.OPEN_CV)
        opencv_w2c_pose = opencv_w2c_pose.change_pose_type(dreifus.matrix.PoseType.CAM_2_WORLD)
        opencv_w2c_pose.look_at(np.zeros(3), np.array([0, 1, 0]))
        opencv_w2c_pose = opencv_w2c_pose.change_pose_type(dreifus.matrix.PoseType.WORLD_2_CAM)
        self.debug_pose_init = opencv_w2c_pose.change_pose_type(dreifus.matrix.PoseType.WORLD_2_CAM).copy()
        # identity code from MICA; zeroed when the ablation ignores MICA
        self.shape = mica_shape.detach().clone()
        self.mica_shape = mica_shape.detach().clone()
        if self.config.ignore_mica:
            self.shape = torch.zeros_like(self.shape)
            self.mica_shape = torch.zeros_like(self.mica_shape)
        # convert the initial pose into nvdiffrast convention (OpenGL, world->cam)
        cam_pose = opencv_w2c_pose
        cam_pose = cam_pose.change_pose_type(dreifus.matrix.PoseType.CAM_2_WORLD)
        cam_pose_nvd = cam_pose.copy()
        cam_pose_nvd = cam_pose_nvd.change_camera_coordinate_convention(new_camera_coordinate_convention=dreifus.matrix.CameraCoordinateConvention.OPEN_GL)
        cam_pose_nvd = cam_pose_nvd.change_pose_type(dreifus.matrix.PoseType.WORLD_2_CAM)
        self.cam_pose_nvd[timestep] = torch.from_numpy(cam_pose_nvd.copy()).float().cuda()
        # camera extrinsics become optimizable leaf tensors
        R = torch.from_numpy(cam_pose_nvd.get_rotation_matrix()).unsqueeze(0).cuda()
        T = torch.from_numpy(cam_pose_nvd.get_translation()).unsqueeze(0).cuda()
        R.requires_grad = True
        T.requires_grad = True
        self.R_base[timestep] = R
        self.t_base[timestep] = T
        # focal length is stored normalized by image size; init ~2000 px at 512
        init_f = 2000 * self.config.size/512
        self.focal_length = torch.tensor([[init_f/self.config.size]]).float().to(self.device)
        self.principal_point = torch.tensor([[0, 0]]).float().to(self.device)
        self.focal_length.requires_grad = True
        self.principal_point.requires_grad = True
        intrinsics = torch.tensor([[init_f, 0, self.config.size//2],
                                   [0, init_f, self.config.size//2],
                                   [0, 0, 1]]).float().cuda()
        proj_512 = nvdiffrast_util.intrinsics2projection(intrinsics,
                                                         znear=0.1, zfar=10,
                                                         width=self.config.size,
                                                         height=self.config.size)
        # initial model-view-projection matrix per camera serial
        self.r_mvps = {}
        for serial in self.cam_pose_nvd.keys():
            self.r_mvps[serial] = ( proj_512 @ self.cam_pose_nvd[serial] )[None, ...]
        # FLAME codes start at neutral (zeros / identity rotations)
        n_timesteps = 1
        expression_params = np.zeros([n_timesteps, 100])
        jaw_params = np.zeros([n_timesteps, 3])
        neck_params = np.zeros([n_timesteps, 3])
        flame_R = torch.from_numpy(np.stack([np.eye(3) for _ in range(n_timesteps)], axis=0))
        flame_t = torch.from_numpy(np.stack([np.zeros([3]) for _ in range(n_timesteps)], axis=0))
        self.R = nn.Parameter(matrix_to_rotation_6d(flame_R.float().to(self.device)))
        self.t = nn.Parameter(flame_t.float().to(self.device))
        self.expression_params = expression_params
        self.jaw_params = jaw_params.astype(np.float32)
        self.neck_params = neck_params.astype(np.float32)
        # promote shape to an optimizable Parameter; mica_shape stays fixed as the reg target
        self.shape = nn.Parameter(self.mica_shape.detach().clone())
        self.texture_observation_mask = None
        # one row per keyframe (frame 0 + config.keyframes)
        self.exp = nn.Parameter(torch.from_numpy(self.expression_params[[0] + self.config.keyframes,..., :]).float().to(self.device))
        self.jaw = nn.Parameter(matrix_to_rotation_6d(euler_angles_to_matrix(torch.from_numpy(self.jaw_params[[0]+ self.config.keyframes,..., :]).cuda(), 'XYZ')))
        self.neck = nn.Parameter(matrix_to_rotation_6d(euler_angles_to_matrix(torch.from_numpy(self.neck_params[[0]+ self.config.keyframes,..., :]).cuda(), 'XYZ')))
        # both eyes start at identity rotation (6D), eyelids fully open
        self.eyes = nn.Parameter(torch.cat([matrix_to_rotation_6d(I), matrix_to_rotation_6d(I)], dim=1).repeat(1+len(self.config.keyframes), 1) )
        self.eyelids = nn.Parameter(torch.zeros(1+len(self.config.keyframes), 2).float().to(self.device))
def parse_mask(self, ops, batch, visualization=False):
result = ops['mask_images_rendering']
if visualization:
result = ops['mask_images']
return result.detach()
def clone_params_keyframes_all(self, freeze_id : bool = False, is_joint : bool = False, freeze_cam : bool = False,
include_neck : bool = False):
lr_scale = 1.0
lr_scale_id_related = 1.0
if freeze_id:
lr_scale_id_related = 0.1
params = [
{'params': [self.exp], 'lr': self.config.lr_exp * lr_scale, 'name': ['exp']}, # 0.025
{'params': [self.eyes], 'lr': 0.005 * lr_scale, 'name': ['eyes']},
# {'params': [self.eyelids.clone())], 'lr': 0.001, 'name': ['eyelids']},
{'params': [self.eyelids], 'lr': 0.002 * lr_scale, 'name': ['eyelids']},
# {'params': [self.sh.clone())], 'lr': 0.01, 'name': ['sh']},
{'params': [self.t], 'lr': self.config.lr_t * lr_scale, 'name': ['t']},
#{'params': [self.t.clone())], 'lr': 0.005 * lr_scale, 'name': ['t']},
{'params': [self.R], 'lr': self.config.lr_R * lr_scale, 'name': ['R']},
#{'params': [self.R.clone())], 'lr': 0.003 * lr_scale, 'name': ['R']},
# {'params': [self.tex.clone())], 'lr': 0.001, 'name': ['tex']},
# {'params': [self.principal_point.clone())], 'lr': 0.001, 'name': ['principal_point']},
# {'params': [self.focal_length.clone())], 'lr': 0.001, 'name': ['focal_length']}
]
#params.append({'params': [self.shape.clone())], 'lr': self.config.lr_id * lr_scale, 'name': ['shape']})
if not freeze_id:
if is_joint:
params.append({'params': [self.shape], 'lr': self.config.lr_id * lr_scale * 1, 'name': ['shape']})
else:
params.append({'params': [self.shape], 'lr': self.config.lr_id * lr_scale, 'name': ['shape']})
#params.append({'params': [self.shape], 'lr': 0.0, 'name': ['shape']})
params.append({'params': [self.jaw], 'lr': self.config.lr_jaw * lr_scale, 'name': ['jaw']})
if include_neck:
params.append({'params': [self.neck], 'lr': self.config.lr_neck, 'name': ['neck']})
# params.append({'params': [self.t], 'lr': 0.001, 'name': ['translation']})
# params.append({'params': [self.R], 'lr': 0.005, 'name': ['rotation']})
# params.append({'params': [self.focal_length, self.principal_point], 'lr': 0.01*lr_scale, 'name': ['camera_params']})
#if not self.config.load_intr:
if not freeze_cam:
params.append({'params': [self.focal_length], 'lr': self.config.lr_f * lr_scale_id_related, 'name': ['camera_params']})
params.append({'params': [self.principal_point], 'lr': self.config.lr_pp * lr_scale_id_related, 'name': ['camera_params']})
return params
def clone_params_keyframes_all_joint(self, freeze_id : bool = False, is_joint : bool = False,
include_neck : bool = False):
lr_scale = 1.0
lr_scale_id_related = 1.0
if freeze_id:
lr_scale_id_related = 0.1
params = [
{'params': self.exp.parameters(), 'lr': self.config.lr_exp * lr_scale, 'name': ['exp']}, # 0.025
{'params': self.eyes.parameters(), 'lr': 0.005 * lr_scale, 'name': ['eyes']},
{'params': self.eyelids.parameters(), 'lr': 0.002 * lr_scale, 'name': ['eyelids']},
{'params': self.t.parameters(), 'lr': self.config.lr_t * lr_scale, 'name': ['t']},
{'params': self.R.parameters(), 'lr': self.config.lr_R * lr_scale, 'name': ['R']},
]
params.append({'params': self.jaw.parameters(), 'lr': self.config.lr_jaw * lr_scale, 'name': ['jaw']})
if include_neck:
params.append({'params': self.neck.parameters(), 'lr': self.config.lr_neck, 'name': ['jaw']})
if not self.config.global_camera:
params.append({'params': self.focal_length.parameters(), 'lr': self.config.lr_f * lr_scale_id_related,
'name': ['camera_params']})
params.append({'params': self.principal_point.parameters(), 'lr': self.config.lr_pp * lr_scale_id_related,
'name': ['camera_params']})
#params.append({'params': [self.shape], 'lr': self.config.lr_id * lr_scale * 1, 'name': ['shape']})
return params
def reduce_loss(self, losses):
all_loss = 0.
for key in losses.keys():
all_loss = all_loss + losses[key]
losses['all_loss'] = all_loss
return all_loss
    def optimize_camera(self, batch, steps=2000, is_first_frame : bool = False
                        ):
        """Optimize rigid head pose (R, t) -- and, on the first frame, camera
        intrinsics -- with Adam: landmark loss for the first half of the
        schedule, dense UV-correspondence loss for the second half.

        Args:
            batch: preprocessed per-frame data (images, landmarks, uv map/mask, ...).
            steps: number of optimization iterations.
            is_first_frame: also optimize focal length / principal point.
        """
        batch = self.to_cuda(batch)
        images, landmarks, lmk_mask = self.parse_landmarks(batch)
        h, w = images.shape[2:4]
        num_keyframes = 1
        uv_mask = batch["uv_mask"]
        uv_map = batch["uv_map"] if "uv_map" in batch else None
        if uv_map is not None:
            # zero out UV predictions outside the valid mask
            uv_map[(1 - uv_mask[:, :, :, :]).bool()] = 0
        self.focal_length.requires_grad = True
        self.principal_point.requires_grad = True
        lr_mult = 1.0
        params = [
            {'params': [self.t], 'lr': lr_mult*0.001},
            {'params': [self.R], 'lr': lr_mult*0.005},
        ]
        if is_first_frame:
            params.append({'params': [self.focal_length], 'lr': 0.02})
            params.append({'params': [self.principal_point], 'lr': 0.0001})
        optimizer = torch.optim.Adam(params)
        # drop learning rates 10x for the final quarter of the schedule
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(steps*0.75),
                                                    gamma=0.1)
        t = tqdm(range(steps), desc='', leave=True, miniters=100)
        num_views = 1
        bs = 1
        for k in t:
            vertices_can, lmk68, lmkMP, vertices_can_can, vertices_noneck = self.flame(cameras=torch.inverse(self.R_base[0]),
                                                                                       shape_params=self.shape if self.shape.shape[0] == bs else self.shape.repeat(bs, 1),
                                                                                       expression_params=self.exp.repeat_interleave(num_views, dim=0),
                                                                                       eye_pose_params=self.eyes.repeat_interleave(num_views, dim=0),
                                                                                       jaw_pose_params=self.jaw.repeat_interleave(num_views, dim=0),
                                                                                       neck_pose_params=self.neck.repeat_interleave(num_views, dim=0),
                                                                                       rot_params_lmk_shift=(matrix_to_rotation_6d(torch.inverse(rotation_6d_to_matrix(self.R)))).repeat_interleave(num_views, dim=0),
                                                                                       )
            # apply rigid pose to landmarks and vertices, then project to screen space
            lmk68 = torch.einsum('bny,bxy->bnx', lmk68,
                                 rotation_6d_to_matrix(self.R.repeat_interleave(num_views, dim=0))) + self.t.repeat_interleave(num_views, dim=0).unsqueeze(1)
            verts = torch.einsum('bny,bxy->bnx', vertices_can,
                                 rotation_6d_to_matrix(
                                     self.R.repeat_interleave(num_views, dim=0))) + self.t.repeat_interleave(num_views,
                                                                                                             dim=0).unsqueeze(
                1)
            lmk68_screen_space = project_points_screen_space(lmk68, self.focal_length, self.principal_point, self.R_base, self.t_base, size=self.config.size)
            verts_screen_space = project_points_screen_space(verts, self.focal_length, self.principal_point, self.R_base, self.t_base, size=self.config.size)
            losses = {}
            losses['pp_reg'] = torch.sum(self.principal_point ** 2)
            # phase 1 (k <= steps/2): sparse landmark alignment
            if k <= steps // 2:
                losses['lmk68'] = util.lmk_loss(lmk68_screen_space[..., :2], landmarks[..., :2], [h, w], lmk_mask) * 3000
            if k == 0:
                # cache UV->vertex correspondences once for the UV loss
                self.uv_loss_fn.compute_corresp(uv_map)
            # phase 2 (k > steps/2): dense UV correspondence
            if k > steps // 2:
                uv_loss = self.uv_loss_fn.compute_loss(verts_screen_space)
                losses['uv_loss'] = uv_loss * 1000
            all_loss = 0.
            for key in losses.keys():
                all_loss = all_loss + losses[key]
            losses['all_loss'] = all_loss
            optimizer.zero_grad()
            all_loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            # keep cached MVP matrices in sync with the current intrinsics/extrinsics
            intrinsics = get_intrinsics(self.focal_length, self.principal_point, use_hack=False, size=self.config.size)
            proj_512 = nvdiffrast_util.intrinsics2projection(intrinsics[0],
                                                             znear=0.1, zfar=5,
                                                             width=self.config.size,
                                                             height=self.config.size)
            for serial in self.cam_pose_nvd.keys():
                extr = get_extrinsics(self.R_base[serial], self.t_base[serial])
                r_mvps = proj_512 @ extr
                self.r_mvps[serial] = r_mvps
            loss = all_loss.item()
            t.set_description(f'Loss for camera {loss:.4f}')
            self.frame += 1
        self.frame = 0
    @torch.compiler.disable
    def get_vars(self, is_joint, selected_frames):
        """Fetch the current optimization variables.

        In per-frame mode the attributes are plain tensors/Parameters and are
        returned directly; in joint mode they are per-frame modules and are
        called with the selected frame indices. Excluded from torch.compile
        because of the data-dependent branching.

        Returns:
            (exp, eyes, eyelids, R, t, jaw, neck, focal_length, principal_point)
        """
        if not is_joint:
            exp = self.exp
            eyes = self.eyes
            eyelids = self.eyelids
            _R = self.R
            _t = self.t
            jaw = self.jaw
            neck = self.neck
            focal_length = self.focal_length
            principal_point = self.principal_point
        else:
            selected_frames = torch.from_numpy(selected_frames).long().cuda()
            exp = self.exp(selected_frames)
            eyes = self.eyes(selected_frames)
            eyelids = self.eyelids(selected_frames)
            _R = self.R(selected_frames)
            _t = self.t(selected_frames)
            jaw = self.jaw(selected_frames)
            neck = self.neck(selected_frames)
            # a single global camera is shared across frames; otherwise query per frame
            if not self.config.global_camera:
                focal_length = self.focal_length(selected_frames)
                principal_point = self.principal_point(selected_frames)
            else:
                focal_length = self.focal_length
                principal_point = self.principal_point
        return exp, eyes, eyelids, _R, _t, jaw, neck, focal_length, principal_point
    @torch.compiler.disable
    def data_stuff(self, is_joint, iters, p, image_lmks68, lmk_mask, normal_map, normal_mask, uv_map, uv_mask, left_iris, right_iris, mask_left_iris, mask_right_iris):
        """In joint mode, sample a batch of frames and pull their cached data,
        overriding the passed-in per-frame tensors; otherwise pass inputs through.

        Sampling alternates between random frame subsets (early iterations, or
        when smoothing is off) and contiguous windows of BATCH_SIZE frames
        (needed for the temporal smoothness losses).

        Returns:
            (selected_frames, batch, bs, num_views, image_lmks68, lmk_mask,
             normal_map, normal_mask, uv_map, uv_mask, left_iris, right_iris,
             mask_left_iris, mask_right_iris)
        """
        if is_joint:
            with torch.no_grad():
                if (p < int(iters * 0.15) and (p % 2 == 0)) or not self.config.smooth:
                    # random subset of frames (no temporal ordering needed)
                    all_frames = np.array(
                        range(self.config.start_frame, self.MAX_STEPS + self.config.start_frame, self.FRAME_SKIP))
                    selected_frames = np.sort(np.random.choice(np.arange(len(all_frames)), size=self.BATCH_SIZE,
                                                               replace=False))
                else:
                    # contiguous window centered on a random frame, clamped to the
                    # valid range, so neighboring rows are consecutive frames
                    all_frames = np.array(
                        range(self.config.start_frame, self.MAX_STEPS + self.config.start_frame, self.FRAME_SKIP))
                    start = np.min(all_frames)
                    end = np.max(all_frames)
                    rnd_start = np.random.randint(start, end)
                    assert (end - start) >= self.BATCH_SIZE + 1
                    assert self.BATCH_SIZE % 2 == 0
                    if rnd_start - self.BATCH_SIZE // 2 < 0:
                        rnd_start = self.BATCH_SIZE // 2
                    if rnd_start + self.BATCH_SIZE // 2 + 1 > end:
                        rnd_start = end - self.BATCH_SIZE // 2 + 1
                    selected_frames = np.array(
                        list(range(rnd_start - self.BATCH_SIZE // 2, rnd_start + self.BATCH_SIZE // 2)))
                # gather the cached per-frame tensors for the sampled indices
                selected_frames_th = torch.from_numpy(selected_frames).long()
                batch = {k: self.cached_data[k][selected_frames_th, ...] for k in self.cached_data.keys()}
                images, landmarks, lmk_mask = self.parse_landmarks(batch)
                uv_mask = batch["uv_mask"]
                normal_mask = batch["normal_mask"]
                normal_map = batch["normals"] if "normals" in batch else None
                uv_map = batch["uv_map"] if "uv_map" in batch else None
                #TODO check if this was important in any way
                if uv_map is not None:
                    uv_map[(1 - uv_mask[:, :, :, :]).bool()] = 0
                num_views = len(self.R_base.keys())
                bs = batch['normals'].shape[0] * num_views
                image_lmks68 = landmarks
                if landmarks is not None:
                    left_iris = batch['left_iris']
                    right_iris = batch['right_iris']
                    mask_left_iris = batch['mask_left_iris']
                    mask_right_iris = batch['mask_right_iris']
        else:
            # per-frame mode: caller-supplied tensors are used unchanged
            selected_frames = None
            bs = 1
            num_views = 1
            batch = None
        return selected_frames, batch, bs, num_views, image_lmks68, lmk_mask, normal_map, normal_mask, uv_map, uv_mask, left_iris, right_iris, mask_left_iris, mask_right_iris
    # TODO: could be improved by compiling all of the actual smooth-loss computation
    # @torch.compile
def actual_smooth(self, variables, losses):
reg_smooth_exp = (variables['exp'][:-1, :] - variables['exp'][1:, :]).square().mean()
reg_smooth_eyes = (variables['eyes'][:-1, :] - variables['eyes'][1:, :]).square().mean()
reg_smooth_eyelids = (variables['eyelids'][:-1, :] - variables['eyelids'][1:, :]).square().mean()
reg_smooth_R = (variables['R'][:-1, :] - variables['R'][1:, :]).square().mean()
reg_smooth_t = (variables['t'][:-1, :] - variables['t'][1:, :]).square().mean()
reg_smooth_jaw = (variables['jaw'][:-1, :] - variables['jaw'][1:, :]).square().mean()
reg_smooth_neck = (variables['neck'][:-1, :] - variables['neck'][1:, :]).square().mean()
if not self.config.global_camera:
reg_smooth_principal_point = (
variables['principal_point'][:-1, :] - variables['principal_point'][1:, :]).square().mean()
reg_smooth_focal_length = (
variables['focal_length'][:-1, :] - variables['focal_length'][1:, :]).square().mean()
else:
reg_smooth_principal_point = torch.zeros_like(reg_smooth_jaw)
reg_smooth_focal_length = torch.zeros_like(reg_smooth_jaw)
losses['smooth/exp'] = reg_smooth_exp * self.config.reg_smooth_exp * self.config.reg_smooth_mult
losses['smooth/eyes'] = reg_smooth_eyes * self.config.reg_smooth_eyes * self.config.reg_smooth_mult
losses['smooth/eyelids'] = reg_smooth_eyelids * self.config.reg_smooth_eyelids * self.config.reg_smooth_mult
losses['smooth/jaw'] = reg_smooth_jaw * self.config.reg_smooth_jaw * self.config.reg_smooth_mult
losses['smooth/neck'] = reg_smooth_neck * self.config.reg_smooth_neck * self.config.reg_smooth_mult
losses['smooth/R'] = reg_smooth_R * self.config.reg_smooth_R * self.config.reg_smooth_mult
losses['smooth/t'] = reg_smooth_t * self.config.reg_smooth_t * self.config.reg_smooth_mult
losses['smooth/principal_point'] = reg_smooth_principal_point * self.config.reg_smooth_pp * self.config.reg_smooth_mult
losses['smooth/focal_length'] = reg_smooth_focal_length * self.config.reg_smooth_fl * self.config.reg_smooth_mult
return losses
@torch.compiler.disable
def add_smooth_loss(self, losses, is_joint, p, iters, variables):
if is_joint and self.config.smooth and ((p >= int(iters * 0.15) and (p % 2 == 1)) ): # and p % 2 != 0 and False:
losses = self.actual_smooth(variables, losses)
return losses
def opt_pre(self, is_joint, iters, p, no_lm, image_lmks68, lmk_mask, normal_mask, normal_map, uv_map, uv_mask, left_iris, right_iris, mask_left_iris, mask_right_iris):
image_size = [self.config.size, self.config.size]
selected_frames, batch, bs, num_views, image_lmks68, lmk_mask, normal_map, normal_mask, uv_map, uv_mask, left_iris, right_iris, mask_left_iris, mask_right_iris = self.data_stuff(is_joint, iters, p, image_lmks68, lmk_mask, normal_map, normal_mask, uv_map, uv_mask, left_iris, right_iris, mask_left_iris, mask_right_iris)
self.diff_renderer.reset()
losses = {}
exp, eyes, eyelids, _R, _t, jaw, neck, focal_length, principal_point = self.get_vars(is_joint, selected_frames)
variables = {
'exp': exp,
'eyes': eyes,
'eyelids': eyelids,
'R': _R,
't': _t,
'jaw': jaw,
'neck': neck,
'principal_point': principal_point,
'focal_lenght': focal_length,
}
intrinsics = get_intrinsics(focal_length, principal_point, use_hack=False, size=self.config.size)
proj_512 = nvdiffrast_util.intrinsics2projection(intrinsics,
znear=0.1, zfar=5,
width=self.config.size,
height=self.config.size)
for serial in self.cam_pose_nvd.keys():
extr = get_extrinsics(self.R_base[serial], self.t_base[serial])
r_mvps = torch.matmul(proj_512, extr.repeat(bs, 1, 1))
self.r_mvps[serial] = r_mvps
vertices_can, lmk68, lmkMP, vertices_can_can, vertices_noneck = self.flame(
cameras=torch.inverse(self.R_base[0]).repeat(bs, 1, 1),
shape_params=self.shape if self.shape.shape[0] == bs else self.shape.repeat(bs, 1).cuda(),
expression_params=exp.repeat_interleave(num_views, dim=0), # .repeat(bs, 1),
eye_pose_params=eyes.repeat_interleave(num_views, dim=0), # .repeat(bs, 1),
jaw_pose_params=jaw.repeat_interleave(num_views, dim=0), # .repeat(bs, 1),
neck_pose_params=neck.repeat_interleave(num_views, dim=0), # .repeat(bs, 1),
eyelid_params=eyelids.repeat_interleave(num_views, dim=0), # .repeat(bs, 1),
rot_params_lmk_shift=(matrix_to_rotation_6d(torch.inverse(rotation_6d_to_matrix(_R)))).repeat_interleave(
num_views, dim=0), # .repeat(bs, 1)
)
verts_can_can_mirrored = vertices_can_can[:, self.mirror_order, :]
vertices_can_can_mirrored = torch.zeros_like(verts_can_can_mirrored)
vertices_can_can_mirrored[:, :, 0] = -verts_can_can_mirrored[:, :, 0]
vertices_can_can_mirrored[:, :, 1:] = verts_can_can_mirrored[:, :, 1:]
mirror_loss = (vertices_can_can_mirrored - vertices_can_can).square().sum(-1)
mirror_loss = mirror_loss.mean()
lmk68 = torch.einsum('bny,bxy->bnx', lmk68,
rotation_6d_to_matrix(_R.repeat_interleave(num_views, dim=0))) + _t.repeat_interleave(
num_views, dim=0).unsqueeze(1)
vertices = torch.einsum('bny,bxy->bnx', vertices_can,
rotation_6d_to_matrix(_R.repeat_interleave(num_views, dim=0))) + _t.repeat_interleave(
num_views, dim=0).unsqueeze(1)
vertices_noneck = torch.einsum('bny,bxy->bnx', vertices_noneck,
rotation_6d_to_matrix(_R.repeat_interleave(num_views, dim=0))) + _t.repeat_interleave(
num_views, dim=0).unsqueeze(1)
proj_lmks68 = project_points_screen_space(lmk68, focal_length, principal_point, self.R_base, self.t_base,
size=self.config.size)
proj_vertices = project_points_screen_space(vertices, focal_length, principal_point, self.R_base, self.t_base,
size=self.config.size)
right_eye, left_eye = eyes[:, :6], eyes[:, 6:]
# landmark loss
if not no_lm:
lmk_scale = 1.0 # 0.0001
# Landmarks sparse term
# losses[('loss/lmk_oval')] = util.oval_lmk_loss(proj_lmks68[..., :2], image_lmks68, image_size, lmk_mask) * self.config.w_lmks_oval * lmk_scale
# losses['loss/lmk_68'] = util.lmk_loss(proj_lmks68[:, 17:, :2], image_lmks68[:, 17:, :], image_size, lmk_mask[:, 17:, :]) * self.config.w_lmks * lmk_scale
# if self.config.use_eyebrows:
# losses['loss/lmk_eyebrows'] = util.lmk_loss(proj_lmks68[:, 17:27, :2], image_lmks68[:, 17:27, :], image_size, lmk_mask[:, 17:27, :]) * self.config.w_lmks * lmk_scale * 5.0
losses['loss/lmk_eye2'] = util.lmk_loss(proj_lmks68[:, 36:48, :2], image_lmks68[:, 36:48, :], image_size,
lmk_mask[:, 36:48,
:]) * self.config.w_lmks * lmk_scale * 5 #10 # 0 #2.0 #0.5 #0.0 #100
if self.config.use_mouth_lmk:
losses['loss/lmk_mouth'] = util.lmk_loss(proj_lmks68[:, 48:68, :2], image_lmks68[:, 48:68, :],
image_size,
lmk_mask[:, 48:68, :]) * self.config.w_lmks_mouth * lmk_scale * 0.25
losses['loss/lmk_mouth_closure'] = util.mouth_closure_lmk_loss(proj_lmks68[..., :2], image_lmks68,
image_size,
lmk_mask) * self.config.w_lmks_mouth * lmk_scale * 2.5
losses['loss/lmk_eye'] = util.eye_closure_lmk_loss(proj_lmks68[..., :2], image_lmks68, image_size,
lmk_mask) * self.config.w_lmks_lid * lmk_scale * 500 # 0 #500 #0.0 #10
losses['loss/lmk_iris_left'] = util.lmk_loss(proj_vertices[:, left_iris_flame[:1], ..., :2], left_iris,
image_size,
mask_left_iris) * self.config.w_lmks_iris * lmk_scale * 50.00
losses['loss/lmk_iris_right'] = util.lmk_loss(proj_vertices[:, right_iris_flame[:1], ..., :2], right_iris,
image_size,
mask_right_iris) * self.config.w_lmks_iris * lmk_scale * 50.0
# Reguralizers
losses['reg/exp'] = torch.sum(exp ** 2, dim=-1).mean() * self.config.w_exp
losses['reg/sym'] = torch.sum((right_eye - left_eye) ** 2, dim=-1).mean() * 0.1 # 8.0 #*5.0
losses['reg/jaw'] = torch.sum((I6D - jaw) ** 2, dim=-1).mean() * self.config.w_jaw
losses['reg/neck'] = torch.sum((I6D - neck) ** 2, dim=-1).mean() * self.config.w_neck
# losses['reg/eye_lids'] = torch.sum((eyelids[:, 0] - eyelids[:, 1]) ** 2, dim=-1).mean() * 0.1
losses['reg/eye_left'] = torch.sum((I6D - left_eye) ** 2, dim=-1).mean() * 0.01
losses['reg/eye_right'] = torch.sum((I6D - right_eye) ** 2, dim=-1).mean() * 0.01
losses['reg/shape'] = torch.sum((self.shape - self.mica_shape) ** 2, dim=-1).mean() * self.config.w_shape
losses['reg/shape_general'] = torch.sum((self.shape) ** 2, dim=-1).mean() * self.config.w_shape_general
losses['reg/mirror'] = mirror_loss * 5000
if not (self.config.n_fine and p >= iters // 2):
losses['reg/pp'] = torch.sum(principal_point ** 2, dim=-1).mean()
return batch, losses, vertices, vertices_noneck, vertices_can, vertices_can_can, proj_vertices, proj_lmks68, selected_frames, variables, num_views, normal_mask, normal_map, uv_map, uv_mask
def opt_post(self, variables, ops, proj_vertices, proj_lmks68, batch, is_joint, is_first_step, losses, uv_map, selected_frames, p, iters, num_views, normal_mask, normal_map):
    """Add the dense image-space terms (silhouette, UV, normals) to ``losses``
    and reduce everything to a single scalar via ``self.reduce_loss``.

    Args:
        variables: dict of current optimization variables; only ``"R"`` (6D
            head rotation) is read here.
        ops: renderer outputs from ``self.diff_renderer`` (depth, fg / uv /
            normal images and masks).
        proj_vertices: projected FLAME vertices in screen space; channel 2 is
            assumed to hold depth -- TODO confirm against the projection helper.
        proj_lmks68: projected 68 landmarks (unused here; kept for interface
            symmetry with ``opt_pre``).
        batch: input data dict (``fg_mask``, ``valid_bg``, ``uv_mask`` ...).
        is_joint / is_first_step: stage flags; they modulate the silhouette term.
        losses: dict of already-computed landmark/regularizer losses, extended
            in place.
        uv_map: predicted UV map tensor or None.
        selected_frames: frame indices for the joint stage (forwarded to the
            UV loss object).
        p / iters: current iteration index and total iteration count.
        num_views: number of camera views.
        normal_mask / normal_map: supervision mask and target normals (or None).

    Returns:
        Scalar total loss.
    """
    # Sample the rendered depth map at each projected vertex position to build
    # a visibility mask: a vertex counts as visible when its projected depth is
    # not (noticeably) behind the rendered surface.
    grabbed_depth = ops['actual_rendered_depth'][:, 0,
                    torch.clamp(proj_vertices[:, :, 1].long(), 0,
                                self.config.size - 1),
                    torch.clamp(proj_vertices[:, :, 0].long(), 0,
                                self.config.size - 1),
                    ][:, 0, :]
    is_visible_verts_idx = grabbed_depth < (proj_vertices[:, :, 2] + 1e-2)
    if not self.config.occ_filter:
        # Occlusion filtering disabled: treat every vertex as visible.
        is_visible_verts_idx = torch.ones_like(is_visible_verts_idx)
    valid_bg_classes = batch['valid_bg']  # bg-class or neck-class
    if self.config.sil_super > 0:
        if is_joint or (not is_first_step):  # and p > 50 and p < int(iters*0.85): # 100
            # losses['loss/sil'] =((1-upper_forehead[:, None, :, :]) * (batch['fg_mask'] - ops['fg_images'])).abs().mean() * self.config.sil_super#0
            losses['loss/sil'] = ((valid_bg_classes[:, None, :, :]) * (
                    batch['fg_mask'] - ops['fg_images'])).abs().mean() * self.config.sil_super  # 0
        else:
            # Very first frame of the online stage: damp the silhouette term 10x.
            losses['loss/sil'] = ((valid_bg_classes[:, None, :, :]) * (
                    batch['fg_mask'] - ops['fg_images'])).abs().mean() * self.config.sil_super / 10  # 0
    if self.config.uv_map_super:  # and p > iters // 2:
        # Dense per-pixel UV supervision: rendered UV image vs. predicted UV map,
        # restricted to the UV mask.
        gt_uv = uv_map[:, :2, :, :].permute(0, 2, 3, 1)
        if self.config.uv_l2:
            uv_loss = ((gt_uv - ops['uv_images']) * batch["uv_mask"][:, 0, ...].unsqueeze(-1)).square().mean() * 100
        else:
            uv_loss = ((gt_uv - ops['uv_images']) * batch["uv_mask"][:, 0, ...].unsqueeze(-1)).abs().mean()
        # TODO: outlier filtering!!!
        losses['loss/uv_pixel'] = uv_loss * self.config.uv_map_super
    if self.config.uv_map_super > 0.0:  # and (p < iters // 2 or self.config.keep_uv) and not self.config.no2d_verts:
        # uv_loss = get_uv_loss(uv_map, proj_vertices)
        # Sparse vertex-to-pixel UV correspondence loss; the correspondences are
        # computed lazily once and cached inside the loss object.
        if self.uv_loss_fn.gt_2_verts is None:
            self.uv_loss_fn.compute_corresp(uv_map, selected_frames=selected_frames)
        uv_loss = self.uv_loss_fn.compute_loss(proj_vertices, selected_frames=selected_frames, uv_map=uv_map,
                                               l2_loss=self.config.uv_l2, is_visible_verts_idx=is_visible_verts_idx)
        losses['loss/uv'] = uv_loss * self.config.uv_map_super  # 000
    skip_normals = False
    if self.config.n_fine and p < iters // 2:
        # "Fine" schedule: normals only supervise the second half of the iterations.
        skip_normals = True
    if (self.config.normal_super > 0.0 or self.config.normal_super_can > 0.0) and not skip_normals:
        # normal_loss_map = normal_loss_map * dilated_eye_mask[:, 0, ...] * (1 - ops['mask_images_eyes_region'][:, 0, ...])
        # use dilated eye mask only
        # maybe also applie eyemask in image not rendering
        # Grow the rendered eye mask via gaussian blur and exclude that region
        # from the normal loss.
        dilated_eye_mask = 1 - (gaussian_blur(ops['mask_images_eyes'],
                                              [self.config.normal_mask_ksize, self.config.normal_mask_ksize],
                                              sigma=[self.config.normal_mask_ksize,
                                                     self.config.normal_mask_ksize]) > 0).float()
        pred_normals = ops['normal_images']  # 1 3 512 512 normals in world space
        rot_mat = rotation_6d_to_matrix(variables["R"].repeat_interleave(num_views, dim=0))  # 1 3 3
        # Rotate world-space normals into FLAME (head-canonical) space.
        pred_normals_flame_space = torch.einsum('bxy,bxhw->byhw', rot_mat, pred_normals)
        if normal_map is not None:
            l_map = (normal_map - pred_normals_flame_space)
            # Robustness: drop pixels whose mean per-channel error exceeds delta_n.
            valid = ((l_map.abs().sum(dim=1) / 3) < self.config.delta_n).unsqueeze(1)
            normal_loss_map = l_map * valid.float() * normal_mask * dilated_eye_mask
            if self.config.normal_l2:
                losses['loss/normal'] = normal_loss_map.square().mean() * self.config.normal_super
            else:
                losses['loss/normal'] = normal_loss_map.abs().mean() * self.config.normal_super
        else:
            losses['loss/normal'] = 0.0
    # smoothness loss
    losses = self.add_smooth_loss(losses, is_joint, p, iters, variables)
    all_loss = self.reduce_loss(losses)
    return all_loss
def optimize_color(self, batch, params_func,
                   no_lm : bool = False,
                   save_timestep=0,
                   is_joint : bool = False,
                   is_first_step : bool = False,
                   ):
    """Main energy-minimization loop over landmark, silhouette, UV and normal terms.

    Runs ``self.config.iters`` optimizer steps. Per-iteration data assembly is
    delegated to ``self.opt_pre`` and loss computation to ``self.opt_post``.

    Args:
        batch: per-frame data dict from ``read_data`` (rebuilt by ``opt_pre``
            when ``is_joint``; the joint stage passes None).
        params_func: zero-arg callable returning optimizer param groups.
        no_lm: disable landmark supervision (forwarded to ``opt_pre``).
        save_timestep: timestep label, used only in the progress-bar text.
        is_joint: True for the global (all-frames) stage; per-frame params are
            sparse embeddings optimized with SparseAdam, while identity (and
            optionally camera) params get a second dense Adam, plus a staged
            LR decay schedule.
        is_first_step: True on the very first frame of the online stage;
            disables early stopping (and weakens the silhouette term in
            ``opt_post``).
    """
    iters = self.config.iters
    if not is_joint:
        images, landmarks, lmk_mask = self.parse_landmarks(batch)
        uv_mask = batch["uv_mask"]
        normal_mask = batch["normal_mask"]
        normal_map = batch["normals"] if "normals" in batch else None
        uv_map = batch["uv_map"] if "uv_map" in batch else None
        if uv_map is not None:
            # Zero out UV predictions outside the UV mask (in place).
            uv_map[(1-uv_mask[:, :, :, :]).bool()] = 0
    # Optimizer per step
    if is_joint:
        # Sparse per-frame parameters -> SparseAdam; shared identity / camera
        # parameters -> dense Adam.
        optimizer = torch.optim.SparseAdam(params_func())
        params_global = [
            {'params': [self.shape], 'lr': self.config.lr_id * 1.0, 'name': ['shape']}
        ]
        if self.config.global_camera:
            params_global.append({'params': [self.focal_length], 'lr': self.config.lr_f * 1.0,
                                  'name': ['camera_params']})
            params_global.append({'params': [self.principal_point], 'lr': self.config.lr_pp * 1.0,
                                  'name': ['camera_params']})
        optimizer_id = torch.optim.Adam(params_global)
        optimizer_id.zero_grad()
    else:
        optimizer = torch.optim.Adam(params_func())
    optimizer.zero_grad()
    if not is_joint:
        num_views = len(self.R_base.keys())
        bs = batch['normals'].shape[0] * num_views
        image_lmks68 = landmarks
        if landmarks is not None:
            left_iris = batch['left_iris']
            right_iris = batch['right_iris']
            mask_left_iris = batch['mask_left_iris']
            mask_right_iris = batch['mask_right_iris']
    else:
        # Joint stage: opt_pre assembles the supervision targets itself.
        image_lmks68 = None
        lmk_mask, normal_mask, normal_map, uv_map, uv_mask = None, None, None, None, None
        left_iris, right_iris, mask_left_iris, mask_right_iris = None, None, None, None
    self.diff_renderer.reset()
    # Early-stopping bookkeeping: ring buffer of per-step loss changes.
    best_loss = np.inf
    n_steps_stagnant = 0
    stagnant_window_size = 10
    past_k_steps = np.array([100.0 for _ in range(stagnant_window_size)])
    iterator = tqdm(range(iters), desc='', leave=True, miniters=100)
    for p in iterator:
        # Staged LR decay (joint stage only) at 50% / 75% / 90% of the schedule;
        # translation / rotation / jaw groups decay at different rates than the rest.
        if is_joint and p == int(iters*0.5):
            for pgroup in optimizer.param_groups:
                if pgroup['name'] in ['t', 'R', 'jaw']:
                    pgroup['lr'] = pgroup['lr'] / 10
                    print(f'LR Reduce at iter {p}, for pgroup {pgroup["name"]}')
                else:
                    pgroup['lr'] = pgroup['lr'] / 2
        if is_joint and p == int(iters *0.75):
            for pgroup in optimizer.param_groups:
                if pgroup['name'] in ['t', 'R', 'jaw']:
                    pgroup['lr'] = pgroup['lr'] / 5
                    print(f'LR Reduce at iter {p}, for pgroup {pgroup["name"]}')
                else:
                    pgroup['lr'] = pgroup['lr'] / 2
        if is_joint and p == int(iters *0.9):
            for pgroup in optimizer.param_groups:
                if pgroup['name'] in ['t', 'R', 'jaw']:
                    pgroup['lr'] = pgroup['lr'] / 2
                    print(f'LR Reduce at iter {p}, for pgroup {pgroup["name"]}')
                else:
                    pgroup['lr'] = pgroup['lr'] / 5
        batch_joint, losses, vertices, vertices_noneck, vertices_can, vertices_can_can, proj_vertices, proj_lmks68, selected_frames, variables, num_views, normal_mask, normal_map, uv_map, uv_mask = self.opt_pre(is_joint, iters, p, no_lm, image_lmks68, lmk_mask, normal_mask, normal_map, uv_map, uv_mask, left_iris, right_iris, mask_left_iris, mask_right_iris)
        if is_joint:
            batch = batch_joint
        timestep = 0
        ops = self.diff_renderer(vertices, None, None,
                                 self.r_mvps[timestep], self.R_base[timestep], self.t_base[timestep],
                                 texture_observation_mask=self.texture_observation_mask,
                                 verts_can=vertices_can,
                                 verts_noneck=vertices_noneck,
                                 verts_can_can=vertices_can_can,
                                 verts_depth=proj_vertices[:, :, 2:3],
                                 )
        all_loss = self.opt_post(variables, ops, proj_vertices, proj_lmks68, batch, is_joint, is_first_step, losses, uv_map, selected_frames, p, iters, num_views, normal_mask, normal_map)
        #vertices.retain_grad()
        #if not self.init_done:
        all_loss.backward()#retain_graph=True)
        optimizer.step()
        optimizer.zero_grad()
        if is_joint:
            optimizer_id.step()
            optimizer_id.zero_grad()
        #if p == 0 or p == iters-1:
        #if p == iters-1:# and not self.config.low_overhead and False:
        #wandb.log(losses)
        self.global_step += 1
        loss_color = all_loss.item()
        if loss_color < best_loss - 1.0:
            best_loss = loss_color
            n_steps_stagnant = 0
        elif p > 25:  # only start counting after n steps
            n_steps_stagnant += 1
        if p > 0:
            # Absolute per-step loss change, written into the ring buffer.
            past_k_steps[p%stagnant_window_size] = np.abs(all_loss.item() - prev_loss)
        prev_loss = all_loss.item()
        if (self.frame % 99 == 0 or p < 10) and is_joint:
            pass
            #with torch.no_grad():
            #    intrinsics = get_intrinsics(focal_length, principal_point, use_hack=False)
            #proj_512 = nvdiffrast_util.intrinsics2projection(intrinsics,
            #                                                 znear=0.1, zfar=5,
            #                                                 width=512,
            #                                                 height=512)
            #for serial in self.cam_pose_nvd.keys():
            #    extr = get_extrinsics(self.R_base[serial], self.t_base[serial])
            #    r_mvps = torch.matmul(proj_512, extr.repeat(bs, 1, 1))
            #    self.r_mvps[serial] = r_mvps
            #self.checkpoint(batch, visualizations=[[View.GROUND_TRUTH, View.LANDMARKS, View.SHAPE_OVERLAY]],
            #                frame_dst='/debug_joint', save=False, dump_directly=True, timestep=p, selected_frames=selected_frames, is_final=True)
        self.frame += 1
        iterator.set_description(f'Timestep {save_timestep}; Loss {all_loss.item():.4f}')
        #if n_steps_stagnant > 35 and not is_joint:
        #    print('Early Stopping, go to next frame!')
        #    #break
        if not is_joint and not is_first_step:
            # Stop early once the mean loss change over the sliding window drops
            # below the configured delta (skipped for the first frame and the
            # joint stage).
            if p > stagnant_window_size and np.mean(past_k_steps) < self.config.early_stopping_delta: #3.0: #3.0:
                print('Early Stopping, go to next frame!')
                #losses['early_stopping'] = past_k_steps
                #wandb.log(losses)
                #wandb.log({'early_stopping': wandb.Histogram(past_k_steps)})
                break
        #print('rate of change', np.mean(past_k_steps))
def render_and_save(self, batch,
                    visualizations=[[View.GROUND_TRUTH, View.LANDMARKS, View.HEATMAP], [View.COLOR_OVERLAY, View.SHAPE_OVERLAY, View.SHAPE]],
                    frame_dst='/video', save=True, dump_directly=False,
                    outer_iter = None,
                    is_camera : bool = False,
                    all_keyframes : bool = False,
                    timestep : int = 0,
                    is_final : bool = False,
                    selected_frames = None,
                    ):
    """Render the current tracking state for visualization / checkpointing.

    When ``visualizations == [[View.SHAPE]]`` this renders a shading-style
    normal-map preview, exports the posed mesh as a vertex-colored ``.glb``
    into ``self.mesh_folder`` and returns the preview as a uint8 HxWx3 array.
    Otherwise it assembles UV-/normal-loss visualizations into a PIL image,
    optionally saves a checkpoint (``save``), and returns that image.

    NOTE(review): the mutable default for ``visualizations`` is shared across
    calls; it is only compared, never mutated here, so this looks harmless.
    ``is_camera``, ``all_keyframes`` and ``dump_directly`` are accepted but
    unused in this body.
    """
    batch = self.to_cuda(batch)
    images, landmarks, _ = self.parse_landmarks(batch)
    if 'uv_map' in batch:
        uv_map = batch['uv_map']
        uv_mask = batch['uv_mask']
        # Zero UV predictions outside the UV mask (in place on the batch tensor).
        uv_map[(1-uv_mask).bool()] = 0
    else:
        uv_map = None
        uv_mask = None
    if 'normals' in batch:
        normal_map = batch['normals']
    else:
        normal_map = None
    if 'normal_map_can' in batch:
        normal_map_can = batch['normal_map_can']
    else:
        normal_map_can = None
    savefolder = self.save_folder + self.actor_name + frame_dst
    num_keyframes = 1#1 + len(self.config.keyframes)
    with torch.no_grad():
        self.diff_renderer.reset()
        num_views = len(self.R_base.keys())
        bs = batch['normals'].shape[0] * num_keyframes #self.shape.shape[0]
        # Fetch parameters either directly (online stage) or by indexing the
        # per-frame embedding tables (joint stage).
        if selected_frames is None:
            exp = self.exp
            eyes = self.eyes
            eyelids = self.eyelids
            R = self.R
            t = self.t
            jaw = self.jaw
            neck = self.neck
            focal_length = self.focal_length
            principal_point = self.principal_point
        else:
            exp = self.exp(selected_frames)
            eyes = self.eyes(selected_frames)
            eyelids = self.eyelids(selected_frames)
            R = self.R(selected_frames)
            t = self.t(selected_frames)
            jaw = self.jaw(selected_frames)
            neck = self.neck(selected_frames)
            if not self.config.global_camera:
                focal_length = self.focal_length(selected_frames)
                principal_point = self.principal_point(selected_frames)
            else:
                focal_length = self.focal_length
                principal_point = self.principal_point
        # Rebuild model-view-projection matrices for the current camera estimate.
        with torch.no_grad():
            intrinsics = get_intrinsics(focal_length, principal_point, use_hack=False, size=self.config.size)
            proj_512 = nvdiffrast_util.intrinsics2projection(intrinsics,
                                                             znear=0.1, zfar=5,
                                                             width=self.config.size,
                                                             height=self.config.size)
            for serial in self.cam_pose_nvd.keys():
                extr = get_extrinsics(self.R_base[serial], self.t_base[serial])
                r_mvps = torch.matmul(proj_512, extr.repeat(bs, 1, 1))
                self.r_mvps[serial] = r_mvps
        # Evaluate FLAME, then apply head rotation/translation and project.
        vertices_can, _lmk68, lmkMP, vertices_can_can, vertices_noneck = self.flame(
            #cameras=torch.inverse(self.R_base[0]),
            cameras=torch.inverse(self.R_base[0]).repeat(bs, 1, 1),
            shape_params=self.shape.repeat(bs, 1),
            expression_params=exp.repeat_interleave(num_views, dim=0), #torch.from_numpy(self.expression_params[:1, :]).cuda().repeat(bs, 1), #self.exp,
            eye_pose_params=eyes.repeat_interleave(num_views, dim=0),
            #euler_angles_to_matrix(x_opts['rotation'][i], 'XYZ')
            jaw_pose_params=jaw.repeat_interleave(num_views, dim=0), #matrix_to_rotation_6d(euler_angles_to_matrix(torch.from_numpy(self.jaw_params[:1, :]).cuda(), 'XYZ')).repeat(bs, 1), #self.jaw,
            neck_pose_params=neck.repeat_interleave(num_views, dim=0), #matrix_to_rotation_6d(euler_angles_to_matrix(torch.from_numpy(self.jaw_params[:1, :]).cuda(), 'XYZ')).repeat(bs, 1), #self.jaw,
            eyelid_params=eyelids.repeat_interleave(num_views, dim=0),
            rot_params_lmk_shift=(matrix_to_rotation_6d(torch.inverse(rotation_6d_to_matrix(R)))).repeat_interleave(num_views, dim=0),
        )
        lmk68 = torch.einsum('bny,bxy->bnx', _lmk68, rotation_6d_to_matrix(R.repeat_interleave(num_views, dim=0))) + t.repeat_interleave(num_views, dim=0).unsqueeze(1)
        vertices = torch.einsum('bny,bxy->bnx', vertices_can, rotation_6d_to_matrix(R.repeat_interleave(num_views, dim=0))) + t.repeat_interleave(num_views, dim=0).unsqueeze(1)
        vertices_noneck = torch.einsum('bny,bxy->bnx', vertices_noneck, rotation_6d_to_matrix(R.repeat_interleave(num_views, dim=0))) + t.repeat_interleave(num_views, dim=0).unsqueeze(1)
        lmk68 = project_points_screen_space(lmk68, focal_length, principal_point, self.R_base, self.t_base, size=self.config.size)
        proj_vertices = project_points_screen_space(vertices, focal_length, principal_point, self.R_base, self.t_base, size=self.config.size)
        _timestep = 0
        ops = self.diff_renderer(vertices, None, None,
                                 self.r_mvps[_timestep], self.R_base[_timestep], self.t_base[_timestep],
                                 verts_can=vertices_can,
                                 verts_noneck=vertices_noneck,
                                 verts_depth=proj_vertices[:, :, 2:3],
                                 is_viz=True
                                 )
        # if they asked *only* for the pure shape mask:
        if visualizations == [[View.SHAPE]]:
            # build your normal‐map preview as before
            normals = ops['normal_images'][0].cpu().numpy()  # [3,H,W]
            normals = (normals + 1.0) / 2.0  # → [0,1]
            normals = np.transpose(normals, (1,2,0))  # H×W×3
            arr = (normals * 255).clip(0,255).astype(np.uint8)
            # --- export the posed mesh, using the correct face indices field ---
            os.makedirs(self.mesh_folder, exist_ok=True)
            frame_id = str(0).zfill(5)
            ply_path = os.path.join(self.mesh_folder, f"{frame_id}.glb")
            # pull out the face index tensor
            faces_np = self.faces.verts_idx.cpu().numpy()
            # `vertices` is your posed mesh: shape (1, V, 3)
            verts_np = vertices[0].detach().cpu().numpy()
            # 1) build your mesh (this will compute smooth normals automatically)
            mesh = trimesh.Trimesh(vertices=verts_np, faces=faces_np)
            # 2) fetch those normals: shape is (V,3), each component in [-1,1]
            normals = mesh.vertex_normals  # (V,3) numpy array
            # 3) convert them to RGB in [0,255]:
            # (n+1)/2 maps [-1,1]→[0,1], then *255→[0,255]
            colors = ((normals + 1.0) * 0.5 * 255.0).astype(np.uint8)  # (V,3)
            # 4) you need RGBA for many formats—just set alpha=255
            alpha = np.full((colors.shape[0],1), 255, dtype=np.uint8)
            vertex_colors = np.hstack([colors, alpha])  # (V,4)
            # 5) assign those as your mesh’s visual colors
            mesh.visual.vertex_colors = vertex_colors
            # 6) export—PLY or GLB both support vertex colors
            out_path = os.path.join(self.mesh_folder, f"{frame_id}.glb")
            mesh.export(out_path)
            return arr
        mask = (self.parse_mask(ops, batch, visualization=True) > 0).float()
        # Per-vertex visibility from the rendered depth (same test as in opt_post).
        grabbed_depth = ops['actual_rendered_depth'][0, 0,
                        torch.clamp(proj_vertices[0, :, 1].long(), 0, self.config.size-1),
                        torch.clamp(proj_vertices[0, :, 0].long(), 0, self.config.size-1),
                        ]
        is_visible_verts_idx = grabbed_depth < proj_vertices[0, :, 2] + 1e-2
        if not self.config.occ_filter:
            is_visible_verts_idx = torch.ones_like(is_visible_verts_idx)
        if outer_iter is None:
            frame_id = str(self.frame).zfill(5)
        else:
            frame_id = str(self.frame + 10*outer_iter).zfill(5)
        if uv_map is not None and is_final:
            # uv losses visualizations
            proj_vertices = proj_vertices[:, self.uv_loss_fn.valid_vertex_index, :]
            can_uv = torch.from_numpy(np.load(env_paths.FLAME_UV_COORDS)).cuda().unsqueeze(0).float()[:, self.uv_loss_fn.valid_vertex_index, :]
            valid_verts_visibility = is_visible_verts_idx[self.uv_loss_fn.valid_vertex_index]
            #can_uv[..., 0] = (can_uv[..., 0] * -1) + 1
            can_uv[..., 1] = (can_uv[..., 1] * -1) + 1
            #can_uv = can_uv[:, ::50, :]
            gt_uv = uv_map[:, :2, :, :].permute(0, 2, 3, 1)
            gt_uv = gt_uv.reshape(gt_uv.shape[0], -1, 2)  # B x n_pixel x 2
            can_uv = can_uv.repeat(gt_uv.shape[0], 1, 1)
            # Nearest-neighbor match each canonical vertex UV against the pixel UVs.
            knn_result = knn_points(can_uv, gt_uv)
            pixel_position_width = knn_result.idx % uv_map.shape[-1]
            pixel_position_height = knn_result.idx // uv_map.shape[-2]
            dists = knn_result.dists.clone()
            gt_2_verts = torch.cat([pixel_position_width, pixel_position_height], dim=-1)
            pred_normals = ops['normal_images']  # 1 3 512 512 normals in world space
            rot_mat = rotation_6d_to_matrix(R.detach().repeat_interleave(num_views, dim=0))  # 1 3 3
            pred_normals_flame_space = torch.einsum('bxy,bxhw->byhw', rot_mat, pred_normals)
            delta = self.config.uv_loss.delta_uv
            catted_uv_rows = []
            for b_i in range(images.shape[0]):
                empty = images[b_i].detach().cpu().numpy().copy().transpose(1, 2, 0)
                # Keep correspondences that are both close in UV space and visible.
                is_valid_uv_corresp = (dists[b_i, :, 0] < delta) & valid_verts_visibility
                valid_pred_2d = proj_vertices[b_i, is_valid_uv_corresp, :]
                valid_gt_2d = gt_2_verts[b_i, is_valid_uv_corresp, :]
                pixels_pred = torch.stack(
                    [
                        torch.clamp(valid_pred_2d[:, 0], 0, images.shape[-1] - 1),
                        torch.clamp(valid_pred_2d[:, 1], 0, images.shape[-2] - 1),
                    ], dim=-1
                ).int()
                pixels_gt = torch.stack(
                    [
                        torch.clamp(valid_gt_2d[:, 0], 0, images.shape[-1] - 1),
                        torch.clamp(valid_gt_2d[:, 1], 0, images.shape[-2] - 1),
                    ], dim=-1
                ).int()
                if self.config.draw_uv_corresp:
                    empty = plot_points(empty, pts=pixels_pred.detach().cpu().numpy(), pts2=pixels_gt.detach().cpu().numpy())
                gt_uv = uv_map[:, :2, :, :].permute(0, 2, 3, 1)
                upper_forehead = ((uv_map[:, 0, :, :].abs() < 0.85) &
                                  (uv_map[:, 0, :, :].abs() > (1 - 0.85)) &
                                  (uv_map[:, 1, :, :] < 0.35) &
                                  (uv_map[:, 1, :, :] > 0.)).float()
                upper_forehead = (gaussian_blur(upper_forehead, [self.config.normal_mask_ksize, self.config.normal_mask_ksize], sigma=[self.config.normal_mask_ksize, self.config.normal_mask_ksize]) > 0).float()
                losses_sil = ((1 - upper_forehead[:, None, :, :]) * (batch['fg_mask'] - ops['fg_images'])).abs().permute(0, 2, 3, 1)
                uv_loss = ((gt_uv - ops['uv_images']) * ops['mask_images'][:, 0, ...].unsqueeze(-1)).abs()
                #catted_uv = torch.cat([gt_uv[b_i], ops['uv_images'][b_i], uv_loss[b_i]], dim=1).detach().cpu().numpy()
                catted_uv = torch.cat([losses_sil[b_i][..., :2], uv_loss[b_i]], dim=1).detach().cpu().numpy()
                catted_uv_I = np.zeros([catted_uv.shape[0], catted_uv.shape[1], 3])
                catted_uv_I[:, :, :2] = catted_uv
                catted_uv_I = (catted_uv_I * 255).astype(np.uint8)
                shape_mask = ((ops['alpha_images'] * ops['mask_images_mesh']) > 0.).int()[b_i]
                shape = (pred_normals_flame_space[b_i]+1)/2 * shape_mask
                # Blend the shaded shape over the input image.
                blend = images[b_i] * (1 - shape_mask) + images[b_i] * shape_mask * 0.3 + shape * 0.7 * shape_mask
                to_be_catted = [(images[b_i].cpu().permute(1, 2, 0).numpy()*255).astype(np.uint8),
                                (blend.permute(1, 2, 0).detach().cpu().numpy()*255).astype(np.uint8),
                                ]
                if self.config.draw_uv_corresp:
                    to_be_catted.append(catted_uv_I)
                    to_be_catted.append(empty)
                catted_uv_I = np.concatenate(to_be_catted, axis=1)
                catted_uv_rows.append(catted_uv_I)
            if normal_map is None:
                catted_uv_I = Image.fromarray(np.concatenate(catted_uv_rows, axis=0))
                #pl = pv.Plotter()
                #pl.add_mesh(trim)
                #pl.add_points(visible_verts)
                #pl.show()
        else:
            catted_uv_I = None
            catted_uv_rows = []
        if normal_map is not None:
            # Normal-supervision visualization (optionally with the loss map).
            # NOTE(review): ``pred_normals_flame_space`` is only assigned in the
            # uv branch above, so this branch appears to rely on uv_map / is_final
            # being set whenever normal_map is present -- confirm.
            dilated_eye_mask = 1 - (gaussian_blur(ops['mask_images_eyes'], [self.config.normal_mask_ksize, self.config.normal_mask_ksize], sigma=[1, 1]) > 0).float()
            l_map = (normal_map - pred_normals_flame_space)
            valid = ((l_map.abs().sum(dim=1)/3) < self.config.delta_n).unsqueeze(1)
            predicted_normal = ((pred_normals_flame_space.permute(0, 2, 3, 1)[...,
                                 :3] + 1) / 2 * 255).detach().cpu().numpy().astype(np.uint8)
            if self.config.draw_uv_corresp:
                normal_loss_map = l_map * valid.float() * batch["normal_mask"] * dilated_eye_mask
                pseudo_normal = ((normal_map.permute(0, 2, 3, 1) + 1) / 2 * 255).detach().cpu().numpy().astype(
                    np.uint8)
                normal_loss_map = (
                        (normal_loss_map.abs().permute(0, 2, 3, 1)) / 2 * 255).detach().cpu().numpy().astype(
                    np.uint8)
                catted = np.concatenate([pseudo_normal, predicted_normal, normal_loss_map], axis=2)
            else:
                catted = predicted_normal
            # Image.fromarray(catted).show()
            # print('hi')
            # Append the normal visualization to the UV rows (or start new rows).
            for b_i in range(catted.shape[0]):
                if len(catted_uv_rows) > 0:
                    catted_uv_rows[b_i] = np.concatenate([catted_uv_rows[b_i], catted[b_i]], axis=1)
                else:
                    catted_uv_rows.append(catted[b_i])
            catted_uv_I = Image.fromarray(np.concatenate(catted_uv_rows, axis=0))
        #if catted_uv_I is not None:
        #    save_fodler_uv = f'{savefolder}'
        #    os.makedirs(save_fodler_uv, exist_ok=True)
        #    if is_final:
        #        catted_uv_I.save(f'{save_fodler_uv}/{timestep}.png')
        #    else:
        #        catted_uv_I.save(f'{save_fodler_uv}/{self.frame}.png')
        if not save:
            return
        # CHECKPOINT
        self.save_checkpoint(timestep, selected_frames=selected_frames)
        return catted_uv_I
def parse_landmarks(self, batch):
    """Split a data batch into images, 68-point landmarks and a validity mask.

    When ``batch`` carries a ``'lmk'`` entry (98-point WFLW layout), the 68
    iBUG points are selected via ``WFLW_2_iBUG68`` and, as a side effect, the
    iris points (indices 96/97) and their validity masks are written back into
    ``batch``. A landmark row summing to zero is treated as invalid.

    Returns:
        (images, landmarks68 or None, landmark_mask or None)
    """
    images = batch['rgb']
    if 'lmk' not in batch:
        # No landmark annotations for this batch.
        return images, None, None
    raw_lmks = batch['lmk']
    lmk68 = raw_lmks[:, WFLW_2_iBUG68, :]
    lmk_mask = ~(lmk68.sum(2, keepdim=True) == 0)
    # All-zero rows mark missing detections; nonzero rows are valid.
    nonzero_rows = ~(raw_lmks.sum(2, keepdim=True) == 0)
    batch['left_iris'] = raw_lmks[:, 96:97, :]
    batch['right_iris'] = raw_lmks[:, 97:98, :]
    batch['mask_left_iris'] = nonzero_rows[:, 96:97, :]
    batch['mask_right_iris'] = nonzero_rows[:, 97:98, :]
    return images, lmk68, lmk_mask
def read_data(self, timestep):
    """Load and preprocess every per-frame input needed by the tracker.

    For frame ``timestep`` this reads: the cropped RGB image, the MICA
    identity code (averaged over all detections, or just the first when
    ``early_exit`` is set), the face-parsing segmentation (from which the
    UV / normal / foreground / background masks are derived), the Pixel3DMM
    normal and UV predictions, and the PIPnet landmarks. All image inputs are
    resized to ``self.config.size``.

    Returns:
        dict of float CUDA tensors with a leading batch dimension of 1;
        image-like entries are channels-first (B, C, H, W).
    """
    DATA_FOLDER = f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}'
    P3DMM_FOLDER = f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}/p3dmm/'
    # Cropped RGB frame in [0, 1]; stored as .jpg or (fallback) .png.
    try:
        rgb = np.array(Image.open(f'{DATA_FOLDER}/cropped/{timestep:05d}.jpg').resize((self.config.size, self.config.size))) / 255
    except Exception as ex:
        rgb = np.array(Image.open(f'{DATA_FOLDER}/cropped/{timestep:05d}.png').resize((self.config.size, self.config.size))) / 255
    # MICA identity code: mean over all detections (first one only when
    # early_exit is set, to save time in debug runs).
    mica_folder = f'{DATA_FOLDER}/mica'
    mica_files = os.listdir(mica_folder)
    mica_shapes = []
    for mica_file in mica_files:
        mica_shape = np.load(f'{mica_folder}/{mica_file}/identity.npy')
        mica_shapes.append(np.squeeze(mica_shape))
    mica_shapes = np.stack(mica_shapes, axis=0)
    if self.config.early_exit:
        mica_shape = mica_shapes[0, :]
    else:
        mica_shape = np.mean(mica_shapes, axis=0)
    # Face-parsing segmentation; NEAREST resize keeps the integer labels intact.
    seg = np.array(Image.open(f'{DATA_FOLDER}/seg_og/{timestep:05d}.png').resize((self.config.size, self.config.size), Image.NEAREST))
    if len(seg.shape) == 3:
        seg = seg[..., 0]
    # Label semantics below follow the original inline comments (1: neck,
    # 4/5: ears, 11: mouth interior) -- TODO confirm the full label map.
    uv_mask = ((seg == 2) | (seg == 6) | (seg == 7) |
               (seg == 10) | (seg == 12) | (seg == 13) |
               (seg == 1) |  # neck
               (seg == 4) | (seg == 5)  # ears
               )
    normal_mask = ((seg == 2) | (seg == 6) | (seg == 7) |
                   (seg == 10) | (seg == 12) | (seg == 13)
                   ) | (seg == 11)  # mouth interior
    if self.config.big_normal_mask:
        normal_mask = normal_mask | (seg == 1) | (seg == 4) | (seg == 5)  # add neck and ears
    fg_mask = ((seg == 2) | (seg == 6) | (seg == 7) | (seg == 8) | (seg == 9) |  # (seg == 4) | (seg == 5) |
               (seg == 10) | (seg == 12) | (seg == 13)
               )
    valid_bg = seg <= 1
    # Pixel3DMM network predictions; normals are mapped from [0, 1] to [-1, 1].
    # BUGFIX: the uv_map filename used to read '{timestep:05d}png' (missing the
    # '.' before the extension), so the original try-branch always failed and
    # fell into a duplicated fallback that re-read the normals a second time.
    normals = ((np.array(Image.open(f'{P3DMM_FOLDER}/normals/{timestep:05d}.png').resize((self.config.size, self.config.size))) / 255).astype(np.float32) - 0.5) * 2
    uv_map = (np.array(Image.open(f'{P3DMM_FOLDER}/uv_map/{timestep:05d}.png').resize((self.config.size, self.config.size))) / 255).astype(np.float32)
    # Landmarks are optional: fall back to zeros, which downstream masking
    # (rows summing to zero) treats as invalid.
    try:
        lms = np.load(f'{DATA_FOLDER}/PIPnet_landmarks/{timestep:05d}.npy') * self.config.size
    except Exception as ex:
        lms = np.zeros([98, 2])
    ret_dict = {
        'rgb': rgb,
        'mica_shape': mica_shape,
        'normals': normals,
        'uv_map': uv_map,
        'uv_mask': uv_mask,
        'normal_mask': normal_mask,
        'fg_mask': fg_mask,
        'valid_bg': valid_bg,
        'lmk': lms,
    }
    # Move everything to CUDA float tensors with a batch dimension of 1.
    ret_dict = {k: torch.from_numpy(v).float().unsqueeze(0).cuda() for k, v in ret_dict.items()}
    # Broadcast single-channel masks to 3 channels to match image tensors.
    ret_dict['uv_mask'] = ret_dict['uv_mask'][:, :, :, None].repeat(1, 1, 1, 3)
    ret_dict['normal_mask'] = ret_dict['normal_mask'][:, :, :, None].repeat(1, 1, 1, 3)
    ret_dict['fg_mask'] = ret_dict['fg_mask'][:, :, :, None].repeat(1, 1, 1, 3)
    channels_first = ['rgb', 'uv_mask', 'normal_mask', 'normals', 'uv_map', 'fg_mask']
    for k in channels_first:
        ret_dict[k] = ret_dict[k].permute(0, 3, 1, 2)
    return ret_dict
def prepare_global_optimization(self, N_FRAMES):
    """Switch from per-frame attributes to sparse per-frame embedding tables.

    Builds one sparse ``nn.Embedding`` (rows = frames) per optimized quantity
    and seeds its weights with the values collected during the online stage
    (``self.intermediate_*``), so the joint stage can optimize all frames at
    once with SparseAdam.

    Args:
        N_FRAMES: number of tracked frames, i.e. rows per table.
    """
    is_sparse = True

    def _table(dim):
        # One row per frame; sparse so SparseAdam only touches sampled rows.
        return nn.Embedding(num_embeddings=N_FRAMES, embedding_dim=dim, sparse=is_sparse).cuda()

    self.exp = _table(100)
    self.R = _table(6)
    self.t = _table(3)
    self.eyes = _table(12)
    self.eyelids = _table(12)
    self.jaw = _table(6)
    self.neck = _table(6)
    per_frame_camera = not self.config.global_camera
    if per_frame_camera:
        self.focal_length = _table(1)
        self.principal_point = _table(2)
    # Stack the per-frame snapshots recorded after each online-stage frame.
    exp = torch.cat(self.intermediate_exprs, dim=0)
    R = torch.cat(self.intermediate_Rs, dim=0)
    t = torch.cat(self.intermediate_ts, dim=0)
    eyes = torch.cat(self.intermediate_eyes, dim=0)
    eyelids = torch.cat(self.intermediate_eyelids, dim=0)
    jaw = torch.cat(self.intermediate_jaws, dim=0)
    neck = torch.cat(self.intermediate_necks, dim=0)
    if per_frame_camera:
        focal_length = torch.cat(self.intermediate_fls, dim=0)
        principal_point = torch.cat(self.intermediate_pps, dim=0)
    # Seed the embedding weights with those values (no autograd history).
    with torch.no_grad():
        self.exp.weight = torch.nn.Parameter(exp)
        self.R.weight = torch.nn.Parameter(R)
        self.t.weight = torch.nn.Parameter(t)
        self.eyes.weight = torch.nn.Parameter(eyes)
        self.eyelids.weight = torch.nn.Parameter(eyelids)
        self.jaw.weight = torch.nn.Parameter(jaw)
        self.neck.weight = torch.nn.Parameter(neck)
        if per_frame_camera:
            self.focal_length.weight = torch.nn.Parameter(focal_length)
            self.principal_point.weight = torch.nn.Parameter(principal_point)
def run(self):
    """Full tracking pipeline.

    Stage 1 (online): iterate over frames, optimize each frame's parameters
    sequentially and cache data plus per-frame results. Stage 2 (global):
    stack all cached frames, re-parameterize as sparse embeddings and run a
    joint refinement. Finally render per-frame shape previews to disk and
    optionally delete preprocessing artifacts.
    """
    timestep = self.config.start_frame
    batch = self.read_data(timestep=timestep)
    # Important to initialize
    self.create_parameters(0, batch['mica_shape'])
    self.frame = 0
    print('''
    <<<<<<<< STARTING ONLINE TRACKING PHASE >>>>>>>>
    ''')
    for timestep in range(self.config.start_frame, self.MAX_STEPS + self.config.start_frame, self.FRAME_SKIP):
        batch = self.read_data(timestep=timestep)
        # Cache the per-frame tensors for the later joint stage.
        for k in batch.keys():
            if k not in self.cached_data:
                self.cached_data[k] = [batch[k]]
            else:
                self.cached_data[k].append(batch[k])
        if timestep == self.config.start_frame:
            # First frame: long camera optimization, identity left free.
            self.optimize_camera(batch, steps=500, is_first_frame=True)
            params = lambda: self.clone_params_keyframes_all(freeze_id=False, freeze_cam=self.config.global_camera, include_neck=self.config.include_neck)
            is_first_step = True
        else:
            if self.config.extra_cam_steps:
                self.optimize_camera(batch, steps=10, is_first_frame=False)
            # Subsequent frames: identity frozen.
            params = lambda: self.clone_params_keyframes_all(freeze_id=True, freeze_cam=self.config.global_camera, include_neck=self.config.include_neck)
            is_first_step = False
        self.optimize_color(batch, params,
                            no_lm=self.no_lm,
                            save_timestep=timestep,
                            is_first_step=is_first_step
                            )
        self.uv_loss_fn.is_next()
        #self.checkpoint(batch, visualizations=[[View.GROUND_TRUTH, View.COLOR_OVERLAY, View.LANDMARKS, View.SHAPE]], frame_dst='/initialization', outer_iter=0, timestep=timestep, is_final=True, save=True)
        self.frame += 1
        # save results for global optimization later
        self.intermediate_exprs.append(self.exp.detach().clone())
        self.intermediate_Rs.append(self.R.detach().clone())
        self.intermediate_ts.append(self.t.detach().clone())
        self.intermediate_eyes.append(self.eyes.detach().clone())
        self.intermediate_eyelids.append(self.eyelids.detach().clone())
        self.intermediate_jaws.append(self.jaw.detach().clone())
        self.intermediate_necks.append(self.neck.detach().clone())
        if not self.config.global_camera:
            self.intermediate_fls.append(self.focal_length.detach().clone())
            self.intermediate_pps.append(self.principal_point.detach().clone())
        if self.config.early_exit:
            # NOTE(review): debugging shortcut -- terminates the whole process
            # right after the first tracked frame; in-loop placement assumed
            # from the stripped source indentation, confirm against upstream.
            exit()
    # Stack the cached per-frame data for the joint stage.
    for k in self.cached_data.keys():
        self.cached_data[k] = torch.cat(self.cached_data[k], dim=0)
    params = lambda: self.clone_params_keyframes_all_joint(freeze_id=False, is_joint=True, include_neck=self.config.include_neck)
    if self.config.uv_map_super > 0.0:
        self.uv_loss_fn.finish_stage1()
    self.config.iters = self.config.global_iters #self.config.iters * 10
    N_FRAMES = len(self.intermediate_exprs)
    #build optimization targets for global optimization, implement as sparse torch.Embedding
    self.prepare_global_optimization(N_FRAMES=N_FRAMES)
    if COMPILE:
        self.flame = torch.compile(self.flame)
        self.opt_pre = torch.compile(self.opt_pre)
        self.opt_post = torch.compile(self.opt_post)
    print('''
    <<<<<<<< STARTING GLOBAL TRACKING PHASE >>>>>>>>
    ''')
    if N_FRAMES > 1:
        self.optimize_color(None, params,
                            no_lm=self.no_lm,
                            save_timestep=1000, #timestep,
                            is_joint=True,
                            )
    # render result and save it as a video to get some viusal feedback
    video_frames = []
    for it, timestep in enumerate(range(self.config.start_frame, self.MAX_STEPS + self.config.start_frame, self.FRAME_SKIP)):
        selected_frames = []
        selected_frames_loading = []
        batches = []
        batch = self.read_data(timestep=timestep)
        batches.append(batch)
        selected_frames.append(it)
        selected_frames_loading.append(timestep)
        batches = {k: torch.cat([x[k] for x in batches], dim=0) for k in batch.keys()}
        selected_frames = torch.from_numpy(np.array(selected_frames)).long().cuda()
        result_rendering = self.render_and_save(batch,
                                                visualizations=[[View.SHAPE]],  # ← only mesh by default
                                                frame_dst='/video', save=True, dump_directly=False, outer_iter=0, timestep=timestep, is_final=True, selected_frames=selected_frames)
        video_frames.append(np.array(result_rendering))
        self.frame += 1
    out_dir = f"{self.save_folder}/{self.config.video_name}/frames"
    os.makedirs(out_dir, exist_ok=True)
    for i, frame in enumerate(video_frames):
        # If float in [0,1], convert:
        if frame.dtype != np.uint8:
            frame_uint8 = (frame * 255).astype(np.uint8)
        else:
            frame_uint8 = frame
        # OpenCV expects BGR ordering:
        bgr = cv2.cvtColor(frame_uint8, cv2.COLOR_RGB2BGR)
        cv2.imwrite(os.path.join(out_dir, f"{i:05d}.jpg"), bgr)
    print(f"✅ Saved {len(video_frames)} frames to `{out_dir}`")
    # Optionally delete all preoprocessing artifacts, once tracking is done (only keep cropped images)
    if self.config.delete_preprocessing:
        shutil.rmtree(f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}/mica')
        shutil.rmtree(f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}/p3dmm')
        shutil.rmtree(f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}/p3dmm_wGT')
        shutil.rmtree(f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}/p3dmm_extraViz')
        shutil.rmtree(f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}/pipnet')
        shutil.rmtree(f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}/PIPnet_annotated_images')
        shutil.rmtree(f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}/PIPnet_landmarks')
        shutil.rmtree(f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}/rgb')
        shutil.rmtree(f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}/seg_non_crop_annotations')
        shutil.rmtree(f'{env_paths.PREPROCESSED_DATA}/{self.config.video_name}/seg_og')
    print(f'''
    <<<<<<<< DONE WITH TRACKING {self.actor_name} >>>>>>>>
    ''')