Spaces:
Runtime error
Runtime error
| # | |
| # Toyota Motor Europe NV/SA and its affiliated companies retain all intellectual | |
| # property and proprietary rights in and to this software and related documentation. | |
| # Any commercial use, reproduction, disclosure or distribution of this software and | |
| # related documentation without an express license agreement from Toyota Motor Europe NV/SA | |
| # is strictly prohibited. | |
| # | |
| from vhap.config.base import import_module, PhotometricStageConfig, BaseTrackingConfig | |
| from vhap.model.flame import FlameHead, FlameTexPCA, FlameTexPainted, FlameUvMask | |
| from vhap.model.lbs import batch_rodrigues | |
| from vhap.util.mesh import ( | |
| get_mtl_content, | |
| get_obj_content, | |
| normalize_image_points, | |
| ) | |
| from vhap.util.log import get_logger | |
| from vhap.util.visualization import plot_landmarks_2d | |
| from torch.utils.tensorboard import SummaryWriter | |
| import torch | |
| import torchvision | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| import numpy as np | |
| from matplotlib import cm | |
| from typing import Literal | |
| from functools import partial | |
| import tyro | |
| import yaml | |
| from datetime import datetime | |
| import threading | |
| from typing import Optional | |
| from collections import defaultdict | |
| from copy import deepcopy | |
| import time | |
| import os | |
| class FlameTracker: | |
    def __init__(self, cfg: BaseTrackingConfig):
        """
        Sets up the FLAME head model, the texture model, and the differentiable
        renderer according to the given tracking configuration.
        :param cfg: tracking configuration (device, model, and render settings)
        """
        self.cfg = cfg
        self.device = cfg.device
        self.tb_writer = None  # SummaryWriter is created later by the training entry point

        # model
        self.flame = FlameHead(
            cfg.model.n_shape,
            cfg.model.n_expr,
            add_teeth=cfg.model.add_teeth,
            remove_lip_inside=cfg.model.remove_lip_inside,
            face_clusters=cfg.model.tex_clusters,
        ).to(self.device)

        # texture: either a painted texture map or a PCA texture space
        if cfg.model.tex_painted:
            self.flame_tex_painted = FlameTexPainted(tex_size=cfg.model.tex_resolution).to(self.device)
        else:
            self.flame_tex_pca = FlameTexPCA(cfg.model.n_tex, tex_size=cfg.model.tex_resolution).to(self.device)

        self.flame_uvmask = FlameUvMask().to(self.device)

        # renderer for visualization, dense photometric energy
        # (imported lazily so only the configured backend needs to be installed)
        if self.cfg.render.backend == 'nvdiffrast':
            from vhap.util.render_nvdiffrast import NVDiffRenderer
            self.render = NVDiffRenderer(
                use_opengl=self.cfg.render.use_opengl,
                lighting_type=self.cfg.render.lighting_type,
                lighting_space=self.cfg.render.lighting_space,
                disturb_rate_fg=self.cfg.render.disturb_rate_fg,
                disturb_rate_bg=self.cfg.render.disturb_rate_bg,
                fid2cid=self.flame.mask.fid2cid,
            )
        elif self.cfg.render.backend == 'pytorch3d':
            from vhap.util.render_pytorch3d import PyTorch3DRenderer
            self.render = PyTorch3DRenderer()
        else:
            raise NotImplementedError(f"Unknown renderer backend: {self.cfg.render.backend}")
| def load_from_tracked_flame_params(self, fp): | |
| """ | |
| loads checkpoint from tracked_flame_params file. Counterpart to save_result() | |
| :param fp: | |
| :return: | |
| """ | |
| report = np.load(fp) | |
| # LOADING PARAMETERS | |
| def load_param(param, ckpt_array): | |
| param.data[:] = torch.from_numpy(ckpt_array).to(param.device) | |
| def load_param_list(param_list, ckpt_array): | |
| for i in range(min(len(param_list), len(ckpt_array))): | |
| load_param(param_list[i], ckpt_array[i]) | |
| load_param_list(self.rotation, report["rotation"]) | |
| load_param_list(self.translation, report["translation"]) | |
| load_param_list(self.neck_pose, report["neck_pose"]) | |
| load_param_list(self.jaw_pose, report["jaw_pose"]) | |
| load_param_list(self.eyes_pose, report["eyes_pose"]) | |
| load_param(self.shape, report["shape"]) | |
| load_param_list(self.expr, report["expr"]) | |
| load_param(self.lights, report["lights"]) | |
| # self.frame_idx = report["n_processed_frames"] | |
| if not self.calibrated: | |
| load_param(self.focal_length, report["focal_length"]) | |
| if not self.cfg.model.tex_painted: | |
| if "tex" in report: | |
| load_param(self.tex_pca, report["tex"]) | |
| else: | |
| self.logger.warn("No tex_extra found in flame_params!") | |
| if self.cfg.model.tex_extra: | |
| if "tex_extra" in report: | |
| load_param(self.tex_extra, report["tex_extra"]) | |
| else: | |
| self.logger.warn("No tex_extra found in flame_params!") | |
| if self.cfg.model.use_static_offset: | |
| if "static_offset" in report: | |
| load_param(self.static_offset, report["static_offset"]) | |
| else: | |
| self.logger.warn("No static_offset found in flame_params!") | |
| if self.cfg.model.use_dynamic_offset: | |
| if "dynamic_offset" in report: | |
| load_param_list(self.dynamic_offset, report["dynamic_offset"]) | |
| else: | |
| self.logger.warn("No dynamic_offset found in flame_params!") | |
| def trimmed_decays(self, is_init): | |
| decays = {} | |
| for k, v in self.decays.items(): | |
| if is_init and "init" in k or not is_init and "init" not in k: | |
| decays[k.replace("_init", "")] = v | |
| return decays | |
    def clear_cache(self):
        """Clears any cached state held by the render backend."""
        self.render.clear_cache()
| def get_current_frame(self, frame_idx, include_keyframes=False): | |
| """ | |
| Creates a single item batch from the frame data at index frame_idx in the dataset. | |
| If include_keyframes option is set, keyframe data will be appended to the batch. However, | |
| it is guaranteed that the frame data belonging to frame_idx is at position 0 | |
| :param frame_idx: | |
| :return: | |
| """ | |
| indices = [frame_idx] | |
| if include_keyframes: | |
| indices += self.cfg.exp.keyframes | |
| samples = [] | |
| for idx in indices: | |
| sample = self.dataset.getitem_by_timestep(idx) | |
| # sample["timestep_index"] = idx | |
| # for k, v in sample.items(): | |
| # if isinstance(v, torch.Tensor): | |
| # sample[k] = v[None, ...].to(self.device) | |
| samples.append(sample) | |
| # if also keyframes have been loaded, stack all data | |
| sample = {} | |
| for k, v in samples[0].items(): | |
| values = [s[k] for s in samples] | |
| if isinstance(v, torch.Tensor): | |
| values = torch.cat(values, dim=0) | |
| sample[k] = values | |
| if "lmk2d_iris" in sample: | |
| sample["lmk2d"] = torch.cat([sample["lmk2d"], sample["lmk2d_iris"]], dim=1) | |
| return sample | |
| def fill_cam_params_into_sample(self, sample): | |
| """ | |
| Adds intrinsics and extrinics to sample, if data is not calibrated | |
| """ | |
| if self.calibrated: | |
| assert "intrinsic" in sample | |
| assert "extrinsic" in sample | |
| else: | |
| b, _, h, w = sample["rgb"].shape | |
| # K = torch.eye(3, 3).to(self.device) | |
| # denormalize cam params | |
| f = self.focal_length * max(h, w) | |
| cx, cy = torch.tensor([[0.5*w], [0.5*h]]).to(f) | |
| sample["intrinsic"] = torch.stack([f, f, cx, cy], dim=1) | |
| sample["extrinsic"] = self.RT[None, ...].expand(b, -1, -1) | |
| def configure_optimizer(self, params, lr_scale=1.0): | |
| """ | |
| Creates optimizer for the given set of parameters | |
| :param params: | |
| :return: | |
| """ | |
| # copy dict because we will call 'pop' | |
| params = params.copy() | |
| param_groups = [] | |
| default_lr = self.cfg.lr.base | |
| # dict map group name to param dict keys | |
| group_def = { | |
| "translation": ["translation"], | |
| "expr": ["expr"], | |
| "light": ["lights"], | |
| } | |
| if not self.calibrated: | |
| group_def ["cam"] = ["cam"] | |
| if self.cfg.model.use_static_offset: | |
| group_def ["static_offset"] = ["static_offset"] | |
| if self.cfg.model.use_dynamic_offset: | |
| group_def ["dynamic_offset"] = ["dynamic_offset"] | |
| # dict map group name to lr | |
| group_lr = { | |
| "translation": self.cfg.lr.translation, | |
| "expr": self.cfg.lr.expr, | |
| "light": self.cfg.lr.light, | |
| } | |
| if not self.calibrated: | |
| group_lr["cam"] = self.cfg.lr.camera | |
| if self.cfg.model.use_static_offset: | |
| group_lr["static_offset"] = self.cfg.lr.static_offset | |
| if self.cfg.model.use_dynamic_offset: | |
| group_lr["dynamic_offset"] = self.cfg.lr.dynamic_offset | |
| for group_name, param_keys in group_def.items(): | |
| selected = [] | |
| for p in param_keys: | |
| if p in params: | |
| selected += params.pop(p) | |
| if len(selected) > 0: | |
| param_groups.append({"params": selected, "lr": group_lr[group_name] * lr_scale}) | |
| # create default group with remaining params | |
| selected = [] | |
| for _, v in params.items(): | |
| selected += v | |
| param_groups.append({"params": selected}) | |
| optim = torch.optim.Adam(param_groups, lr=default_lr * lr_scale) | |
| return optim | |
    def initialize_frame(self, frame_idx):
        """
        Initializes parameters of frame frame_idx
        :param frame_idx:
        :return:
        """
        # frame 0 keeps its default initialization; subsequent frames warm-start
        # from the previous frame's optimized parameters
        if frame_idx > 0:
            self.initialize_from_previous(frame_idx)
| def initialize_from_previous(self, frame_idx): | |
| """ | |
| Initializes the flame parameters with the optimized ones from the previous frame | |
| :param frame_idx: | |
| :return: | |
| """ | |
| if frame_idx == 0: | |
| return | |
| param_list = [ | |
| self.expr, | |
| self.neck_pose, | |
| self.jaw_pose, | |
| self.translation, | |
| self.rotation, | |
| self.eyes_pose, | |
| ] | |
| for param in param_list: | |
| param[frame_idx].data = param[frame_idx - 1].detach().clone().data | |
| def select_frame_indices(self, frame_idx, include_keyframes): | |
| indices = [frame_idx] | |
| if include_keyframes: | |
| indices += self.cfg.exp.keyframes | |
| return indices | |
| def forward_flame(self, frame_idx, include_keyframes): | |
| """ | |
| Evaluates the flame model using the given parameters | |
| :param flame_params: | |
| :return: | |
| """ | |
| indices = self.select_frame_indices(frame_idx, include_keyframes) | |
| dynamic_offset = self.to_batch(self.dynamic_offset, indices) if self.cfg.model.use_dynamic_offset else None | |
| ret = self.flame( | |
| self.shape[None, ...].expand(len(indices), -1), | |
| self.to_batch(self.expr, indices), | |
| self.to_batch(self.rotation, indices), | |
| self.to_batch(self.neck_pose, indices), | |
| self.to_batch(self.jaw_pose, indices), | |
| self.to_batch(self.eyes_pose, indices), | |
| self.to_batch(self.translation, indices), | |
| return_verts_cano=True, | |
| static_offset=self.static_offset, | |
| dynamic_offset=dynamic_offset, | |
| ) | |
| verts, verts_cano, lmks = ret[0], ret[1], ret[2] | |
| albedos = self.get_albedo().expand(len(indices), -1, -1, -1) | |
| return verts, verts_cano, lmks, albedos | |
| def get_base_texture(self): | |
| if self.cfg.model.tex_extra and not self.cfg.model.residual_tex: | |
| albedos_base = self.tex_extra[None, ...] | |
| else: | |
| if self.cfg.model.tex_painted: | |
| albedos_base = self.flame_tex_painted() | |
| else: | |
| albedos_base = self.flame_tex_pca(self.tex_pca[None, :]) | |
| return albedos_base | |
| def get_albedo(self): | |
| albedos_base = self.get_base_texture() | |
| if self.cfg.model.tex_extra and self.cfg.model.residual_tex: | |
| albedos_res = self.tex_extra[None, :] | |
| if albedos_base.shape[-1] != albedos_res.shape[-1] or albedos_base.shape[-2] != albedos_res.shape[-2]: | |
| albedos_base = F.interpolate(albedos_base, albedos_res.shape[-2:], mode='bilinear') | |
| albedos = albedos_base + albedos_res | |
| else: | |
| albedos = albedos_base | |
| return albedos | |
| def rasterize_flame( | |
| self, sample, verts, faces, camera_index=None, train_mode=False | |
| ): | |
| """ | |
| Rasterizes the flame head mesh | |
| :param verts: | |
| :param albedos: | |
| :param K: | |
| :param RT: | |
| :param resolution: | |
| :param use_cache: | |
| :return: | |
| """ | |
| # cameras parameters | |
| K = sample["intrinsic"].clone().to(self.device) | |
| RT = sample["extrinsic"].to(self.device) | |
| if camera_index is not None: | |
| K = K[[camera_index]] | |
| RT = RT[[camera_index]] | |
| H, W = self.image_size | |
| image_size = H, W | |
| # rasterize fragments | |
| rast_dict = self.render.rasterize(verts, faces, RT, K, image_size, False, train_mode) | |
| return rast_dict | |
| def get_background_color(self, gt_rgb, gt_alpha, stage): | |
| if stage is None: # when stage is None, it means we are in the evaluation mode | |
| background = self.cfg.render.background_eval | |
| else: | |
| background = self.cfg.render.background_train | |
| if background == 'target': | |
| """use gt_rgb as background""" | |
| color = gt_rgb.permute(0, 2, 3, 1) | |
| elif background == 'white': | |
| color = [1, 1, 1] | |
| elif background == 'black': | |
| color = [0, 0, 0] | |
| else: | |
| raise NotImplementedError(f"Unknown background mode: {background}") | |
| return color | |
| def render_rgba( | |
| self, rast_dict, verts, faces, albedos, lights, background_color=[1, 1, 1], | |
| align_texture_except_fid=None, align_boundary_except_vid=None, enable_disturbance=False, | |
| ): | |
| """ | |
| Renders the rgba image from the rasterization result and | |
| the optimized texture + lights | |
| """ | |
| faces_uv = self.flame.textures_idx | |
| if self.cfg.render.backend == 'nvdiffrast': | |
| verts_uv = self.flame.verts_uvs.clone() | |
| verts_uv[:, 1] = 1 - verts_uv[:, 1] | |
| tex = albedos | |
| render_out = self.render.render_rgba( | |
| rast_dict, verts, faces, verts_uv, faces_uv, tex, lights, background_color, | |
| align_texture_except_fid, align_boundary_except_vid, enable_disturbance | |
| ) | |
| render_out = {k: v.permute(0, 3, 1, 2) for k, v in render_out.items()} | |
| elif self.cfg.render.backend == 'pytorch3d': | |
| B = verts.shape[0] # TODO: double check | |
| verts_uv = self.flame.face_uvcoords.repeat(B, 1, 1) | |
| tex = albedos.expand(B, -1, -1, -1) | |
| rgba = self.render.render_rgba( | |
| rast_dict, verts, faces, verts_uv, faces_uv, tex, lights, background_color | |
| ) | |
| render_out = {'rgba': rgba.permute(0, 3, 1, 2)} | |
| else: | |
| raise NotImplementedError(f"Unknown renderer backend: {self.cfg.render.backend}") | |
| return render_out | |
| def render_normal(self, rast_dict, verts, faces): | |
| """ | |
| Renders the rgba image from the rasterization result and | |
| the optimized texture + lights | |
| """ | |
| uv_coords = self.flame.face_uvcoords | |
| uv_coords = uv_coords.repeat(verts.shape[0], 1, 1) | |
| return self.render.render_normal(rast_dict, verts, faces, uv_coords) | |
    def compute_lmk_energy(self, sample, pred_lmks, disable_jawline_landmarks=False):
        """
        Computes the landmark energy loss term between groundtruth landmarks and flame landmarks
        :param sample: batch dict with "rgb", "lmk2d", "intrinsic", "extrinsic"
        :param pred_lmks: predicted 3D landmarks in world space
        :param disable_jawline_landmarks: currently unused in this implementation
        :return: mean confidence-weighted landmark loss, and a dict holding the
                 ground-truth and predicted 2D landmarks
        """
        img_size = sample["rgb"].shape[-2:]

        # ground-truth landmark: (x, y) pixel coordinates plus a confidence value;
        # cloned because the coordinates are normalized (and confidence scaled) in place
        lmk2d = sample["lmk2d"].clone().to(pred_lmks)
        lmk2d, confidence = lmk2d[:, :, :2], lmk2d[:, :, 2]
        lmk2d[:, :, 0], lmk2d[:, :, 1] = normalize_image_points(
            lmk2d[:, :, 0], lmk2d[:, :, 1], img_size
        )

        # predicted landmark projected into normalized device coordinates
        K = sample["intrinsic"].to(self.device)
        RT = sample["extrinsic"].to(self.device)
        pred_lmk_ndc = self.render.world_to_ndc(pred_lmks, RT, K, img_size, flip_y=True)
        pred_lmk2d = pred_lmk_ndc[:, :, :2]

        if (lmk2d.shape[1] == 70):
            # 68 face landmarks + 2 iris landmarks
            diff = lmk2d - pred_lmk2d
            confidence = confidence[:, :70]
            # eyes weighting: iris landmarks count double
            confidence[:, 68:] = confidence[:, 68:] * 2
        else:
            # face landmarks only
            diff = lmk2d[:, :68] - pred_lmk2d[:, :68]
            confidence = confidence[:, :68]

        # compute general landmark term (confidence-weighted L1 distance)
        lmk_loss = torch.norm(diff, dim=2, p=1) * confidence

        result_dict = {
            "gt_lmk2d": lmk2d,
            "pred_lmk2d": pred_lmk2d,
        }
        return lmk_loss.mean(), result_dict
| def compute_photometric_energy( | |
| self, | |
| sample, | |
| verts, | |
| faces, | |
| albedos, | |
| rast_dict, | |
| step_i=None, | |
| stage=None, | |
| include_keyframes=False, | |
| ): | |
| """ | |
| Computes the dense photometric energy | |
| :param sample: | |
| :param vertices: | |
| :param albedos: | |
| :return: | |
| """ | |
| gt_rgb = sample["rgb"].to(verts) | |
| if "alpha" in sample: | |
| gt_alpha = sample["alpha_map"].to(verts) | |
| else: | |
| gt_alpha = None | |
| lights = self.lights[None] if self.lights is not None else None | |
| bg_color = self.get_background_color(gt_rgb, gt_alpha, stage) | |
| align_texture_except_fid = self.flame.mask.get_fid_by_region( | |
| self.cfg.pipeline[stage].align_texture_except | |
| ) if stage is not None else None | |
| align_boundary_except_vid = self.flame.mask.get_vid_by_region( | |
| self.cfg.pipeline[stage].align_boundary_except | |
| ) if stage is not None else None | |
| render_out = self.render_rgba( | |
| rast_dict, verts, faces, albedos, lights, bg_color, | |
| align_texture_except_fid, align_boundary_except_vid, | |
| enable_disturbance=stage!=None, | |
| ) | |
| pred_rgb = render_out['rgba'][:, :3] | |
| pred_alpha = render_out['rgba'][:, 3:] | |
| pred_mask = render_out['rgba'][:, [3]].detach() > 0 | |
| pred_mask = pred_mask.expand(-1, 3, -1, -1) | |
| results_dict = render_out | |
| # ---- rgb loss ---- | |
| error_rgb = gt_rgb - pred_rgb | |
| color_loss = error_rgb.abs().sum() / pred_mask.detach().sum() | |
| results_dict.update( | |
| { | |
| "gt_rgb": gt_rgb, | |
| "pred_rgb": pred_rgb, | |
| "error_rgb": error_rgb, | |
| "pred_alpha": pred_alpha, | |
| } | |
| ) | |
| # ---- silhouette loss ---- | |
| # error_alpha = gt_alpha - pred_alpha | |
| # mask_loss = error_alpha.abs().sum() | |
| # results_dict.update( | |
| # { | |
| # "gt_alpha": gt_alpha, | |
| # "error_alpha": error_alpha, | |
| # } | |
| # ) | |
| # ---- background loss ---- | |
| # bg_mask = gt_alpha < 0.5 | |
| # error_alpha = gt_alpha - pred_alpha | |
| # error_alpha = torch.where(bg_mask, error_alpha, torch.zeros_like(error_alpha)) | |
| # mask_loss = error_alpha.abs().sum() / bg_mask.sum() | |
| # results_dict.update( | |
| # { | |
| # "gt_alpha": gt_alpha, | |
| # "error_alpha": error_alpha, | |
| # } | |
| # ) | |
| # -------- | |
| # photo_loss = color_loss + mask_loss | |
| photo_loss = color_loss | |
| # photo_loss = mask_loss | |
| return photo_loss, results_dict | |
    def compute_regularization_energy(self, result_dict, verts, verts_cano, lmks, albedos, frame_idx, include_keyframes, stage):
        """
        Computes the energy term that penalizes strong deviations from the flame base model
        :param result_dict: intermediate render results (used by the diffuse term)
        :param verts: posed vertices
        :param verts_cano: canonical vertices (used by the offset Laplacian term)
        :param lmks: predicted landmarks (unused here)
        :param albedos: albedo maps (unused here)
        :param frame_idx: current frame index
        :param include_keyframes: whether keyframes are part of the batch
        :param stage: name of the current pipeline stage
        :return: dict mapping regularizer name to its weighted energy value
        """
        log_dict = {}

        # standard deviations used to scale the L2 priors (identity weighting)
        std_tex = 1
        std_expr = 1
        std_shape = 1

        indices = self.select_frame_indices(frame_idx, include_keyframes)

        # pose smoothness term
        if self.opt_dict['pose'] and 'tracking' in stage:
            E_pose_smooth = self.compute_pose_smooth_energy(frame_idx, stage=='global_tracking')
            log_dict["pose_smooth"] = E_pose_smooth

        # joint regularization term
        if self.opt_dict['joints']:
            if 'tracking' in stage:
                joint_smooth = self.compute_joint_smooth_energy(frame_idx, stage=='global_tracking')
                log_dict["joint_smooth"] = joint_smooth
            joint_prior = self.compute_joint_prior_energy(frame_idx)
            log_dict["joint_prior"] = joint_prior

        # expression regularization: L2 towards the neutral expression
        if self.opt_dict['expr']:
            expr = self.to_batch(self.expr, indices)
            reg_expr = (expr / std_expr) ** 2
            log_dict["reg_expr"] = self.cfg.w.reg_expr * reg_expr.mean()

        # shape regularization: L2 towards the mean shape
        if self.opt_dict['shape']:
            reg_shape = (self.shape / std_shape) ** 2
            log_dict["reg_shape"] = self.cfg.w.reg_shape * reg_shape.mean()

        # texture regularization
        if self.opt_dict['texture']:
            # texture space
            if not self.cfg.model.tex_painted:
                reg_tex_pca = (self.tex_pca / std_tex) ** 2
                log_dict["reg_tex_pca"] = self.cfg.w.reg_tex_pca * reg_tex_pca.mean()

            # texture map
            if self.cfg.model.tex_extra:
                if self.cfg.model.residual_tex:
                    if self.cfg.w.reg_tex_res is not None:
                        # L2 keeps the residual texture small
                        reg_tex_res = self.tex_extra ** 2
                        # reg_tex_res = self.tex_extra.abs()  # L1 loss can create noise textures
                        log_dict["reg_tex_res"] = self.cfg.w.reg_tex_res * reg_tex_res.mean()

                    if self.cfg.w.reg_tex_tv is not None:
                        # total-variation smoothness on the composed albedo map
                        tex = self.get_albedo()[0]  # (3, H, W)
                        tv_y = (tex[..., :-1, :] - tex[..., 1:, :]) ** 2
                        tv_x = (tex[..., :, :-1] - tex[..., :, 1:]) ** 2
                        tv = tv_y.reshape(tv_y.shape[0], -1) + tv_x.reshape(tv_x.shape[0], -1)
                        # compensate for image rescaling so the weight stays resolution-independent
                        w_reg_tex_tv = self.cfg.w.reg_tex_tv * self.cfg.data.scale_factor ** 2
                        if self.cfg.data.n_downsample_rgb is not None:
                            w_reg_tex_tv /= (self.cfg.data.n_downsample_rgb ** 2)
                        log_dict["reg_tex_tv"] = w_reg_tex_tv * tv.mean()

                    if self.cfg.w.reg_tex_res_clusters is not None:
                        # extra penalty on the residual inside selected UV regions (e.g. sclerae)
                        mask_sclerae = self.flame_uvmask.get_uvmask_by_region(self.cfg.w.reg_tex_res_for)[None, :, :]
                        reg_tex_res_clusters = self.tex_extra ** 2 * mask_sclerae
                        log_dict["reg_tex_res_clusters"] = self.cfg.w.reg_tex_res_clusters * reg_tex_res_clusters.mean()

        # lighting parameters regularization
        if self.opt_dict['lights']:
            if self.cfg.w.reg_light is not None and self.lights is not None:
                # keep the lighting close to the uniform initialization
                reg_light = (self.lights - self.lights_uniform) ** 2
                log_dict["reg_light"] = self.cfg.w.reg_light * reg_light.mean()

            if self.cfg.w.reg_diffuse is not None and self.lights is not None:
                # discourage shading from exceeding 1 and from varying across channels
                diffuse = result_dict['diffuse_detach_normal']
                reg_diffuse = F.relu(diffuse.max() - 1) + diffuse.var(dim=1).mean()
                log_dict["reg_diffuse"] = self.cfg.w.reg_diffuse * reg_diffuse

        # offset regularization
        if self.opt_dict['static_offset'] or self.opt_dict['dynamic_offset']:
            if self.static_offset is not None or self.dynamic_offset is not None:
                # total per-vertex offset = static + dynamic (whichever are enabled)
                offset = 0
                if self.static_offset is not None:
                    offset += self.static_offset
                if self.dynamic_offset is not None:
                    offset += self.to_batch(self.dynamic_offset, indices)

                if self.cfg.w.reg_offset_lap is not None:
                    # laplacian loss: offsets should not change local surface curvature
                    vert_wo_offset = (verts_cano - offset).detach()
                    reg_offset_lap = self.compute_laplacian_smoothing_loss(
                        vert_wo_offset, vert_wo_offset + offset
                    )
                    if len(self.cfg.w.reg_offset_lap_relax_for) > 0:
                        # relax the penalty in regions that are allowed to deform more
                        w = self.scale_vertex_weights_by_region(
                            weights=torch.ones_like(verts[:, :, :1]),
                            scale_factor=self.cfg.w.reg_offset_lap_relax_coef,
                            region=self.cfg.w.reg_offset_lap_relax_for,
                        )
                        reg_offset_lap *= w
                    log_dict["reg_offset_lap"] = self.cfg.w.reg_offset_lap * reg_offset_lap.mean()

                if self.cfg.w.reg_offset is not None:
                    # norm loss: keep offsets small overall
                    # reg_offset = offset.norm(dim=-1, keepdim=True)
                    reg_offset = offset.abs()
                    if len(self.cfg.w.reg_offset_relax_for) > 0:
                        w = self.scale_vertex_weights_by_region(
                            weights=torch.ones_like(verts[:, :, :1]),
                            scale_factor=self.cfg.w.reg_offset_relax_coef,
                            region=self.cfg.w.reg_offset_relax_for,
                        )
                        reg_offset *= w
                    log_dict["reg_offset"] = self.cfg.w.reg_offset * reg_offset.mean()

                if self.cfg.w.reg_offset_rigid is not None:
                    # rigidity: offsets inside each listed region should be near-constant
                    reg_offset_rigid = 0
                    for region in self.cfg.w.reg_offset_rigid_for:
                        vids = self.flame.mask.get_vid_by_region([region])
                        reg_offset_rigid += offset[:, vids, :].var(dim=-2).mean()
                    log_dict["reg_offset_rigid"] = self.cfg.w.reg_offset_rigid * reg_offset_rigid

                if self.cfg.w.reg_offset_dynamic is not None and self.dynamic_offset is not None and self.opt_dict['dynamic_offset']:
                    # The dynamic offset is regularized to be temporally smooth
                    # (anchored to frame 0 and the previous frame)
                    if frame_idx == 0:
                        reg_offset_d = torch.zeros_like(self.dynamic_offset[0])
                        offset_d = self.dynamic_offset[0]
                    else:
                        reg_offset_d = torch.stack([self.dynamic_offset[0], self.dynamic_offset[frame_idx - 1]])
                        offset_d = self.dynamic_offset[frame_idx]
                    reg_offset_dynamic = ((offset_d - reg_offset_d) ** 2).mean()
                    log_dict["reg_offset_dynamic"] = self.cfg.w.reg_offset_dynamic * reg_offset_dynamic

        return log_dict
| def scale_vertex_weights_by_region(self, weights, scale_factor, region): | |
| indices = self.flame.mask.get_vid_by_region(region) | |
| weights[:, indices] *= scale_factor | |
| for _ in range(self.cfg.w.blur_iter): | |
| M = self.flame.laplacian_matrix_negate_diag[None, ...] | |
| weights = M.bmm(weights) / 2 | |
| return weights | |
| def compute_pose_smooth_energy(self, frame_idx, use_next_frame=False): | |
| """ | |
| Regularizes the global pose of the flame head model to be temporally smooth | |
| """ | |
| idx = frame_idx | |
| idx_prev = np.clip(idx - 1, 0, self.n_timesteps - 1) | |
| if use_next_frame: | |
| idx_next = np.clip(idx + 1, 0, self.n_timesteps - 1) | |
| ref_indices = [idx_prev, idx_next] | |
| else: | |
| ref_indices = [idx_prev] | |
| E_trans = ((self.translation[[idx]] - self.translation[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_trans | |
| E_rot = ((self.rotation[[idx]] - self.rotation[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_rot | |
| return E_trans + E_rot | |
| def compute_joint_smooth_energy(self, frame_idx, use_next_frame=False): | |
| """ | |
| Regularizes the joints of the flame head model to be temporally smooth | |
| """ | |
| idx = frame_idx | |
| idx_prev = np.clip(idx - 1, 0, self.n_timesteps - 1) | |
| if use_next_frame: | |
| idx_next = np.clip(idx + 1, 0, self.n_timesteps - 1) | |
| ref_indices = [idx_prev, idx_next] | |
| else: | |
| ref_indices = [idx_prev] | |
| E_joint_smooth = 0 | |
| E_joint_smooth += ((self.neck_pose[[idx]] - self.neck_pose[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_neck | |
| E_joint_smooth += ((self.jaw_pose[[idx]] - self.jaw_pose[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_jaw | |
| E_joint_smooth += ((self.eyes_pose[[idx]] - self.eyes_pose[ref_indices].detach()) ** 2).mean() * self.cfg.w.smooth_eyes | |
| return E_joint_smooth | |
    def compute_joint_prior_energy(self, frame_idx):
        """
        Regularizes the joints of the flame head model towards neutral joint locations
        :param frame_idx: frame whose joint poses are regularized
        :return: weighted sum of the per-joint prior energies
        """
        # the two eyes are handled as separate 3-DoF axis-angle entries
        poses = [
            ("neck", self.neck_pose[[frame_idx], :]),
            ("jaw", self.jaw_pose[[frame_idx], :]),
            ("eyes", self.eyes_pose[[frame_idx], :3]),
            ("eyes", self.eyes_pose[[frame_idx], 3:]),
        ]

        # Joints are regularized towards the neutral pose
        E_joint_prior = 0
        for name, pose in poses:
            # L2 regularization for each joint: compare the rotation matrix of the
            # pose against the identity (rotation of the zero axis-angle vector)
            rotmats = batch_rodrigues(torch.cat([torch.zeros_like(pose), pose], dim=0))
            diff = ((rotmats[[0]] - rotmats[1:]) ** 2).mean()

            # Additional regularization for physical plausibility
            if name == 'jaw':
                # penalize negative rotation along x axis of jaw (mouth cannot close past neutral)
                diff += F.relu(-pose[:, 0]).mean() * 10
                # penalize rotation along y and z axis of jaw
                diff += (pose[:, 1:] ** 2).mean() * 3
            elif name == 'eyes':
                # penalize the difference between the two eyes
                diff += ((self.eyes_pose[[frame_idx], :3] - self.eyes_pose[[frame_idx], 3:]) ** 2).mean()

            E_joint_prior += diff * self.cfg.w[f"prior_{name}"]
        return E_joint_prior
| def compute_laplacian_smoothing_loss(self, verts, offset_verts): | |
| L = self.flame.laplacian_matrix[None, ...].detach() # (1, V, V) | |
| basis_lap = L.bmm(verts).detach() #.norm(dim=-1) * weights | |
| offset_lap = L.bmm(offset_verts) #.norm(dim=-1) # * weights | |
| diff = (offset_lap - basis_lap) ** 2 | |
| diff = diff.sum(dim=-1, keepdim=True) | |
| return diff | |
    def compute_energy(
        self,
        sample,
        frame_idx,
        include_keyframes=False,
        step_i=None,
        stage=None,
    ):
        """
        Compute total energy for frame frame_idx
        :param sample: batch dict for the current frame (and keyframes)
        :param frame_idx: current frame index
        :param include_keyframes: if key frames shall be included when predicting the per
               frame energy
        :param step_i: current optimization step, forwarded to the photometric term
        :param stage: current pipeline stage; None means evaluation (no regularizers)
        :return: loss, log dict, predicted vertices, faces, landmarks, albedos, results dict
        """
        log_dict = {}

        gt_rgb = sample["rgb"]
        result_dict = {"gt_rgb": gt_rgb}

        verts, verts_cano, lmks, albedos = self.forward_flame(frame_idx, include_keyframes)
        faces = self.flame.faces

        # number of cameras per timestep (batch entries are per-camera)
        if isinstance(sample["num_cameras"], list):
            num_cameras = sample["num_cameras"][0]
        else:
            num_cameras = sample["num_cameras"]

        # ---- landmark energy ----
        if self.cfg.w.landmark is not None:
            lmks_n = self.repeat_n_times(lmks, num_cameras)
            if not self.cfg.w.always_enable_jawline_landmarks and stage is not None:
                disable_jawline_landmarks = self.cfg.pipeline[stage]['disable_jawline_landmarks']
            else:
                disable_jawline_landmarks = False
            E_lmk, _result_dict = self.compute_lmk_energy(sample, lmks_n, disable_jawline_landmarks)
            log_dict["lmk"] = self.cfg.w.landmark * E_lmk
            result_dict.update(_result_dict)

        # ---- photometric energy (photometric stages and evaluation only) ----
        if stage is None or isinstance(self.cfg.pipeline[stage], PhotometricStageConfig):
            if self.cfg.w.photo is not None:
                verts_n = self.repeat_n_times(verts, num_cameras)
                rast_dict = self.rasterize_flame(
                    sample, verts_n, self.flame.faces, train_mode=True
                )
                photo_energy_func = self.compute_photometric_energy
                E_photo, _result_dict = photo_energy_func(
                    sample,
                    verts,
                    faces,
                    albedos,
                    rast_dict,
                    step_i,
                    stage,
                    include_keyframes,
                )
                result_dict.update(_result_dict)
                log_dict["photo"] = self.cfg.w.photo * E_photo

        # ---- regularization (training only) ----
        if stage is not None:
            _log_dict = self.compute_regularization_energy(
                result_dict, verts, verts_cano, lmks, albedos, frame_idx, include_keyframes, stage
            )
            log_dict.update(_log_dict)

        # total energy is the sum of all logged terms
        E_total = torch.stack([v for k, v in log_dict.items()]).sum()
        log_dict["total"] = E_total
        return E_total, log_dict, verts, faces, lmks, albedos, result_dict
| def to_batch(x, indices): | |
| return torch.stack([x[i] for i in indices]) | |
| def repeat_n_times(x: torch.Tensor, n: int): | |
| """Expand a tensor from shape [F, ...] to [F*n, ...]""" | |
| return x.unsqueeze(1).repeat_interleave(n, dim=1).reshape(-1, *x.shape[1:]) | |
| def log_scalars( | |
| self, | |
| log_dict, | |
| frame_idx, | |
| session: Literal["train", "eval"] = "train", | |
| stage=None, | |
| frame_step=None, | |
| # step_in_stage=None, | |
| ): | |
| """ | |
| Logs scalars in log_dict to tensorboard and self.logger | |
| :param log_dict: | |
| :param frame_idx: | |
| :param step_i: | |
| :return: | |
| """ | |
| if not self.calibrated and stage is not None and 'cam' in self.cfg.pipeline[stage].optimizable_params: | |
| log_dict["focal_length"] = self.focal_length.squeeze(0) | |
| log_msg = "" | |
| if session == "train": | |
| global_step = self.global_step | |
| else: | |
| global_step = frame_idx | |
| for k, v in log_dict.items(): | |
| if not k.startswith("decay"): | |
| log_msg += "{}: {:.4f} ".format(k, v) | |
| if self.tb_writer is not None: | |
| self.tb_writer.add_scalar(f"{session}/{k}", v, global_step) | |
| if session == "train": | |
| assert stage is not None | |
| if frame_step is not None: | |
| msg_prefix = f"[{session}-{stage}] frame {frame_idx} step {frame_step}: " | |
| else: | |
| msg_prefix = f"[{session}-{stage}] frame {frame_idx} step {self.global_step}: " | |
| elif session == "eval": | |
| msg_prefix = f"[{session}] frame {frame_idx}: " | |
| self.logger.info(msg_prefix + log_msg) | |
| def save_obj_with_texture(self, vertices, faces, uv_coordinates, uv_indices, albedos, obj_path, mtl_path, texture_path): | |
| # Save the texture image | |
| torchvision.utils.save_image(albedos.squeeze(0), texture_path) | |
| # Create the MTL file | |
| with open(mtl_path, 'w') as f: | |
| f.write(get_mtl_content(texture_path.name)) | |
| # Create the obj file | |
| with open(obj_path, 'w') as f: | |
| f.write(get_obj_content(vertices, faces, uv_coordinates, uv_indices, mtl_path.name)) | |
| def async_func(func): | |
| """Decorator to run a function asynchronously""" | |
| def wrapper(*args, **kwargs): | |
| self = args[0] | |
| if self.cfg.async_func: | |
| thread = threading.Thread(target=func, args=args, kwargs=kwargs) | |
| thread.start() | |
| else: | |
| func(*args, **kwargs) | |
| return wrapper | |
    def log_media(
        self,
        verts: torch.Tensor,
        faces: torch.Tensor,
        lmks: torch.Tensor,
        albedos: torch.Tensor,
        output_dict: dict,
        sample: dict,
        frame_idx: int,
        session: str,
        stage: Optional[str] = None,
        frame_step: Optional[int] = None,
        epoch=None,
    ):
        """
        Renders the current tracking state and writes it to disk: an image grid
        visualization and a textured mesh (OBJ + MTL + texture image).

        :param verts: FLAME vertices; assumed shape [1, V, 3] given the squeeze(0) below — TODO confirm
        :param faces: face index tensor of the mesh
        :param lmks: predicted landmarks (forwarded to the visualization)
        :param albedos: albedo texture saved as the mesh texture image
        :param output_dict: render/error tensors produced by the energy computation
        :param sample: current data sample (provides ground-truth images)
        :param frame_idx: index of the visualized frame
        :param session: "train" or "eval"; selects the output subfolder
        :param stage: optimization stage name (controls jawline-landmark display)
        :param frame_step: per-frame step used in the output filename
        :param epoch: optional epoch id appended to the session folder name
        """
        tic = time.time()

        # all output paths share the same session/frame/stage/step context
        prepare_output_path = partial(
            self.prepare_output_path,
            session=session,
            frame_idx=frame_idx,
            stage=stage,
            step=frame_step,
            epoch=epoch,
        )

        """images"""
        # jawline landmarks can be disabled per stage unless globally forced on
        if not self.cfg.w.always_enable_jawline_landmarks and stage is not None:
            disable_jawline_landmarks = self.cfg.pipeline[stage]['disable_jawline_landmarks']
        else:
            disable_jawline_landmarks = False
        img = self.visualize_tracking(verts, lmks, albedos, output_dict, sample, disable_jawline_landmarks=disable_jawline_landmarks)
        img_path = prepare_output_path(folder_name="image_grid", file_type=self.cfg.log.image_format)
        torchvision.utils.save_image(img, img_path)

        """meshes"""
        texture_path = prepare_output_path(folder_name="mesh", file_type=self.cfg.log.image_format)
        mtl_path = prepare_output_path(folder_name="mesh", file_type="mtl")
        obj_path = prepare_output_path(folder_name="mesh", file_type="obj")
        vertices = verts.squeeze(0).detach().cpu().numpy()
        faces = faces.detach().cpu().numpy()
        # UV layout comes from the FLAME template, not from the per-frame result
        uv_coordinates = self.flame.verts_uvs.cpu().numpy()
        uv_indices = self.flame.textures_idx.cpu().numpy()
        self.save_obj_with_texture(vertices, faces, uv_coordinates, uv_indices, albedos, obj_path, mtl_path, texture_path)
        """"""

        toc = time.time() - tic
        if stage is not None:
            msg_prefix = f"[{session}-{stage}] frame {frame_idx}"
        else:
            msg_prefix = f"[{session}] frame {frame_idx}"
        if frame_step is not None:
            msg_prefix += f" step {frame_step}"
        self.logger.info(f"{msg_prefix}: Logging media took {toc:.2f}s")
| def visualize_tracking( | |
| self, | |
| verts, | |
| lmks, | |
| albedos, | |
| output_dict, | |
| sample, | |
| return_imgs_seperately=False, | |
| disable_jawline_landmarks=False, | |
| ): | |
| """ | |
| Visualizes the tracking result | |
| """ | |
| if len(self.cfg.log.view_indices) > 0: | |
| view_indices = torch.tensor(self.cfg.log.view_indices) | |
| else: | |
| num_views = sample["rgb"].shape[0] | |
| if num_views > 1: | |
| step = (num_views - 1) // (self.cfg.log.max_num_views - 1) | |
| view_indices = torch.arange(0, num_views, step=step) | |
| else: | |
| view_indices = torch.tensor([0]) | |
| num_views_log = len(view_indices) | |
| imgs = [] | |
| # rgb | |
| gt_rgb = output_dict["gt_rgb"][view_indices].cpu() | |
| transfm = torchvision.transforms.Resize(gt_rgb.shape[-2:]) | |
| imgs += [img[None] for img in gt_rgb] | |
| if "pred_rgb" in output_dict: | |
| pred_rgb = transfm(output_dict["pred_rgb"][view_indices].cpu()) | |
| pred_rgb = torch.clip(pred_rgb, min=0, max=1) | |
| imgs += [img[None] for img in pred_rgb] | |
| if "error_rgb" in output_dict: | |
| error_rgb = transfm(output_dict["error_rgb"][view_indices].cpu()) | |
| error_rgb = error_rgb.mean(dim=1) / 2 + 0.5 | |
| cmap = cm.get_cmap("seismic") | |
| error_rgb = cmap(error_rgb.cpu()) | |
| error_rgb = torch.from_numpy(error_rgb[..., :3]).to(gt_rgb).permute(0, 3, 1, 2) | |
| imgs += [img[None] for img in error_rgb] | |
| # cluster id | |
| if "cid" in output_dict: | |
| cid = transfm(output_dict["cid"][view_indices].cpu()) | |
| cid = cid / cid.max() | |
| cid = cid.expand(-1, 3, -1, -1).clone() | |
| pred_alpha = transfm(output_dict["pred_alpha"][view_indices].cpu()).expand(-1, 3, -1, -1) | |
| bg = pred_alpha == 0 | |
| cid[bg] = 1 | |
| imgs += [img[None] for img in cid] | |
| # albedo | |
| if "albedo" in output_dict: | |
| albedo = transfm(output_dict["albedo"][view_indices].cpu()) | |
| albedo = torch.clip(albedo, min=0, max=1) | |
| pred_alpha = transfm(output_dict["pred_alpha"][view_indices].cpu()).expand(-1, 3, -1, -1) | |
| bg = pred_alpha == 0 | |
| albedo[bg] = 1 | |
| imgs += [img[None] for img in albedo] | |
| # normal | |
| if "normal" in output_dict: | |
| normal = transfm(output_dict["normal"][view_indices].cpu()) | |
| normal = torch.clip(normal/2+0.5, min=0, max=1) | |
| imgs += [img[None] for img in normal] | |
| # diffuse | |
| diffuse = None | |
| if self.cfg.render.lighting_type != 'constant' and "diffuse" in output_dict: | |
| diffuse = transfm(output_dict["diffuse"][view_indices].cpu()) | |
| diffuse = torch.clip(diffuse, min=0, max=1) | |
| imgs += [img[None] for img in diffuse] | |
| # aa | |
| if "aa" in output_dict: | |
| aa = transfm(output_dict["aa"][view_indices].cpu()) | |
| aa = torch.clip(aa, min=0, max=1) | |
| imgs += [img[None] for img in aa] | |
| # alpha | |
| if "gt_alpha" in output_dict: | |
| gt_alpha = transfm(output_dict["gt_alpha"][view_indices].cpu()).expand(-1, 3, -1, -1) | |
| imgs += [img[None] for img in gt_alpha] | |
| if "pred_alpha" in output_dict: | |
| pred_alpha = transfm(output_dict["pred_alpha"][view_indices].cpu()).expand(-1, 3, -1, -1) | |
| color_alpha = torch.tensor([0.2, 0.5, 1])[None, :, None, None] | |
| fg_mask = (pred_alpha > 0).float() | |
| if diffuse is not None: | |
| fg_mask *= diffuse | |
| w = 0.7 | |
| overlay_alpha = fg_mask * (w * color_alpha * pred_alpha + (1-w) * gt_rgb) \ | |
| + (1 - fg_mask) * gt_rgb | |
| imgs += [img[None] for img in overlay_alpha] | |
| if "error_alpha" in output_dict: | |
| error_alpha = transfm(output_dict["error_alpha"][view_indices].cpu()) | |
| error_alpha = error_alpha.mean(dim=1) / 2 + 0.5 | |
| cmap = cm.get_cmap("seismic") | |
| error_alpha = cmap(error_alpha.cpu()) | |
| error_alpha = ( | |
| torch.from_numpy(error_alpha[..., :3]).to(gt_rgb).permute(0, 3, 1, 2) | |
| ) | |
| imgs += [img[None] for img in error_alpha] | |
| else: | |
| error_alpha = None | |
| # landmark | |
| vis_lmk = self.visualize_landmarks(gt_rgb, output_dict, view_indices, disable_jawline_landmarks) | |
| if vis_lmk is not None: | |
| imgs += [img[None] for img in vis_lmk] | |
| # ---------------- | |
| num_types = len(imgs) // len(view_indices) | |
| if return_imgs_seperately: | |
| return imgs | |
| else: | |
| if self.cfg.log.stack_views_in_rows: | |
| imgs = [imgs[j * num_views_log + i] for i in range(num_views_log) for j in range(num_types)] | |
| imgs = torch.cat(imgs, dim=0).cpu() | |
| return torchvision.utils.make_grid(imgs, nrow=num_types) | |
| else: | |
| imgs = torch.cat(imgs, dim=0).cpu() | |
| return torchvision.utils.make_grid(imgs, nrow=num_views_log) | |
| def visualize_landmarks(self, gt_rgb, output_dict, view_indices=torch.tensor([0]), disable_jawline_landmarks=False): | |
| h, w = gt_rgb.shape[-2:] | |
| unit = h / 750 | |
| wh = torch.tensor([[[w, h]]]) | |
| vis_lmk = None | |
| if "gt_lmk2d" in output_dict: | |
| gt_lmk2d = (output_dict['gt_lmk2d'][view_indices].cpu() * 0.5 + 0.5) * wh | |
| if disable_jawline_landmarks: | |
| gt_lmk2d = gt_lmk2d[:, 17:68] | |
| else: | |
| gt_lmk2d = gt_lmk2d[:, :68] | |
| vis_lmk = gt_rgb.clone() if vis_lmk is None else vis_lmk | |
| for i in range(len(view_indices)): | |
| vis_lmk[i] = plot_landmarks_2d( | |
| vis_lmk[i].clone(), | |
| gt_lmk2d[[i]], | |
| colors="green", | |
| unit=unit, | |
| input_float=True, | |
| ).to(vis_lmk[i]) | |
| if "pred_lmk2d" in output_dict: | |
| pred_lmk2d = (output_dict['pred_lmk2d'][view_indices].cpu() * 0.5 + 0.5) * wh | |
| if disable_jawline_landmarks: | |
| pred_lmk2d = pred_lmk2d[:, 17:68] | |
| else: | |
| pred_lmk2d = pred_lmk2d[:, :68] | |
| vis_lmk = gt_rgb.clone() if vis_lmk is None else vis_lmk | |
| for i in range(len(view_indices)): | |
| vis_lmk[i] = plot_landmarks_2d( | |
| vis_lmk[i].clone(), | |
| pred_lmk2d[[i]], | |
| colors="red", | |
| unit=unit, | |
| input_float=True, | |
| ).to(vis_lmk[i]) | |
| return vis_lmk | |
| def evaluate(self, make_visualization=True, epoch=0): | |
| # always save parameters before evaluation | |
| self.save_result(epoch=epoch) | |
| self.logger.info("Started Evaluation") | |
| # vid_frames = [] | |
| photo_loss = [] | |
| for frame_idx in range(self.n_timesteps): | |
| sample = self.get_current_frame(frame_idx, include_keyframes=False) | |
| self.clear_cache() | |
| self.fill_cam_params_into_sample(sample) | |
| ( | |
| E_total, | |
| log_dict, | |
| verts, | |
| faces, | |
| lmks, | |
| albedos, | |
| output_dict, | |
| ) = self.compute_energy(sample, frame_idx) | |
| self.log_scalars(log_dict, frame_idx, session="eval") | |
| photo_loss.append(log_dict["photo"].item()) | |
| if make_visualization: | |
| self.log_media( | |
| verts, | |
| faces, | |
| lmks, | |
| albedos, | |
| output_dict, | |
| sample, | |
| frame_idx, | |
| session="eval", | |
| epoch=epoch, | |
| ) | |
| self.tb_writer.add_scalar(f"eval_mean/photo", np.mean(photo_loss), epoch) | |
| def prepare_output_path(self, session, frame_idx, folder_name, file_type, stage=None, step=None, epoch=None): | |
| if epoch is not None: | |
| output_folder = self.out_dir / f'{session}_{epoch}' / folder_name | |
| else: | |
| output_folder = self.out_dir / session / folder_name | |
| os.makedirs(output_folder, exist_ok=True) | |
| if stage is not None: | |
| assert step is not None | |
| fname = "frame_{:05d}_{:03d}_{}.{}".format(frame_idx, step, stage, file_type) | |
| else: | |
| fname = "frame_{:05d}.{}".format(frame_idx, file_type) | |
| return output_folder / fname | |
| def save_result(self, fname=None, epoch=None): | |
| """ | |
| Saves tracked/optimized flame parameters. | |
| :return: | |
| """ | |
| # save parameters | |
| keys = [ | |
| "rotation", | |
| "translation", | |
| "neck_pose", | |
| "jaw_pose", | |
| "eyes_pose", | |
| "shape", | |
| "expr", | |
| "timestep_id", | |
| "n_processed_frames", | |
| ] | |
| values = [ | |
| self.rotation, | |
| self.translation, | |
| self.neck_pose, | |
| self.jaw_pose, | |
| self.eyes_pose, | |
| self.shape, | |
| self.expr, | |
| np.array(self.dataset.timestep_ids), | |
| self.frame_idx, | |
| ] | |
| if not self.calibrated: | |
| keys += ["focal_length"] | |
| values += [self.focal_length] | |
| if not self.cfg.model.tex_painted: | |
| keys += ["tex"] | |
| values += [self.tex_pca] | |
| if self.cfg.model.tex_extra: | |
| keys += ["tex_extra"] | |
| values += [self.tex_extra] | |
| if self.lights is not None: | |
| keys += ["lights"] | |
| values += [self.lights] | |
| if self.cfg.model.use_static_offset: | |
| keys += ["static_offset"] | |
| values += [self.static_offset] | |
| if self.cfg.model.use_dynamic_offset: | |
| keys += ["dynamic_offset"] | |
| values += [self.dynamic_offset] | |
| export_dict = {} | |
| for k, v in zip(keys, values): | |
| if not isinstance(v, np.ndarray): | |
| if isinstance(v, list): | |
| v = torch.stack(v) | |
| if isinstance(v, torch.Tensor): | |
| v = v.detach().cpu().numpy() | |
| export_dict[k] = v | |
| export_dict["image_size"] = np.array(self.image_size) | |
| fname = fname if fname is not None else "tracked_flame_params" | |
| if epoch is not None: | |
| fname = f"{fname}_{epoch}" | |
| np.savez(self.out_dir / f'{fname}.npz', **export_dict) | |
| class GlobalTracker(FlameTracker): | |
    def __init__(self, cfg: BaseTrackingConfig):
        """Set up output folders, logging, the dataset, and all optimizable parameters."""
        super().__init__(cfg)

        self.calibrated = cfg.data.calibrated

        # logging: each run gets its own timestamped output folder
        out_dir = cfg.exp.output_folder / datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        out_dir.mkdir(parents=True,exist_ok=True)
        self.frame_idx = self.cfg.begin_frame_idx
        self.out_dir = out_dir
        self.tb_writer = SummaryWriter(self.out_dir)
        self.log_interval_scalar = self.cfg.log.interval_scalar
        self.log_interval_media = self.cfg.log.interval_media

        # persist the full config next to the results for reproducibility
        config_yaml_path = out_dir / 'config.yml'
        config_yaml_path.write_text(yaml.dump(cfg), "utf8")
        print(tyro.to_yaml(cfg))

        self.logger = get_logger(__name__, root=True, log_dir=out_dir)

        # data
        self.dataset = import_module(cfg.data._target)(
            cfg=cfg.data,
            img_to_tensor=True,
            batchify_all_views=True,  # important to optimize all views together
        )
        # FlameTracker expects all views of a frame in a batch, which is undertaken by the
        # dataset. Therefore batching is disabled for the dataloader
        self.image_size = self.dataset[0]["rgb"].shape[-2:]
        self.n_timesteps = len(self.dataset)

        # parameters
        self.init_params()
        # optionally warm-start from previously tracked FLAME parameters
        if self.cfg.model.flame_params_path is not None:
            self.load_from_tracked_flame_params(self.cfg.model.flame_params_path)
| def init_params(self): | |
| train_tensors = [] | |
| # flame model params | |
| self.shape = torch.zeros(self.cfg.model.n_shape).to(self.device) | |
| self.expr = torch.zeros(self.n_timesteps, self.cfg.model.n_expr).to(self.device) | |
| # joint axis angles | |
| self.neck_pose = torch.zeros(self.n_timesteps, 3).to(self.device) | |
| self.jaw_pose = torch.zeros(self.n_timesteps, 3).to(self.device) | |
| self.eyes_pose = torch.zeros(self.n_timesteps, 6).to(self.device) | |
| # rigid pose | |
| self.translation = torch.zeros(self.n_timesteps, 3).to(self.device) | |
| self.rotation = torch.zeros(self.n_timesteps, 3).to(self.device) | |
| # texture and lighting params | |
| self.tex_pca = torch.zeros(self.cfg.model.n_tex).to(self.device) | |
| if self.cfg.model.tex_extra: | |
| res = self.cfg.model.tex_resolution | |
| self.tex_extra = torch.zeros(3, res, res).to(self.device) | |
| if self.cfg.render.lighting_type == 'SH': | |
| self.lights_uniform = torch.zeros(9, 3).to(self.device) | |
| self.lights_uniform[0] = torch.tensor([np.sqrt(4 * np.pi)]).expand(3).float().to(self.device) | |
| self.lights = self.lights_uniform.clone() | |
| else: | |
| self.lights = None | |
| train_tensors += ( | |
| [self.shape, self.translation, self.rotation, self.neck_pose, self.jaw_pose, self.eyes_pose, self.expr,] | |
| ) | |
| if not self.cfg.model.tex_painted: | |
| train_tensors += [self.tex_pca] | |
| if self.cfg.model.tex_extra: | |
| train_tensors += [self.tex_extra] | |
| if self.lights is not None: | |
| train_tensors += [self.lights] | |
| if self.cfg.model.use_static_offset: | |
| self.static_offset = torch.zeros(1, self.flame.v_template.shape[0], 3).to(self.device) | |
| train_tensors += [self.static_offset] | |
| else: | |
| self.static_offset = None | |
| if self.cfg.model.use_dynamic_offset: | |
| self.dynamic_offset = torch.zeros(self.n_timesteps, self.flame.v_template.shape[0], 3).to(self.device) | |
| train_tensors += self.dynamic_offset | |
| else: | |
| self.dynamic_offset = None | |
| # camera definition | |
| if not self.calibrated: | |
| # K contains focal length and principle point | |
| self.focal_length = torch.tensor([1.5]).to(self.device) | |
| self.RT = torch.eye(3, 4).to(self.device) | |
| self.RT[2, 3] = -1 # (0, 0, -1) in w2c corresponds to (0, 0, 1) in c2w | |
| train_tensors += [self.focal_length] | |
| for t in train_tensors: | |
| t.requires_grad = True | |
    def optimize(self):
        """
        Optimizes flame parameters on all frames of the dataset with random sampling
        :return:
        """
        self.global_step = 0

        # first initialize frame either from calibration or previous frame
        # with torch.no_grad():
        #     self.initialize_frame(frame_idx)

        # Phase 1: sequential optimization, one timestep at a time
        self.logger.info(f"Start sequential tracking FLAME in {self.n_timesteps} frames")
        dataloader = DataLoader(self.dataset, batch_size=None, shuffle=False, num_workers=0)
        for sample in dataloader:
            timestep = sample["timestep_index"][0].item()
            if timestep == 0:
                # the first frame additionally runs the initialization stages
                self.optimize_stage('lmk_init_rigid', sample)
                self.optimize_stage('lmk_init_all', sample)
                if self.cfg.exp.photometric:
                    self.optimize_stage('rgb_init_texture', sample)
                    self.optimize_stage('rgb_init_all', sample)
                    if self.cfg.model.use_static_offset:
                        self.optimize_stage('rgb_init_offset', sample)
            if self.cfg.exp.photometric:
                self.optimize_stage('rgb_sequential_tracking', sample)
            else:
                self.optimize_stage('lmk_sequential_tracking', sample)
            # warm-start the next timestep with the current solution
            self.initialize_next_timtestep(timestep)
        self.evaluate(make_visualization=False, epoch=0)

        # Phase 2: global optimization with random sampling over all frames
        self.logger.info(f"Start global optimization of all frames")
        dataloader = DataLoader(self.dataset, batch_size=None, shuffle=True, num_workers=0)
        if self.cfg.exp.photometric:
            self.optimize_stage(stage='rgb_global_tracking', dataloader=dataloader, lr_scale=0.1)
        else:
            self.optimize_stage(stage='lmk_global_tracking', dataloader=dataloader, lr_scale=0.1)
        self.logger.info("All done.")
| def optimize_stage( | |
| self, | |
| stage: Literal['lmk_init_rigid', 'lmk_init_all', 'rgb_init_texture', 'rgb_init_all', 'rgb_init_offset', 'rgb_sequential_tracking', 'rgb_global_tracking'], | |
| sample = None, | |
| dataloader = None, | |
| lr_scale = 1.0, | |
| ): | |
| params = self.get_train_parameters(stage) | |
| optimizer = self.configure_optimizer(params, lr_scale=lr_scale) | |
| if sample is not None: | |
| num_steps = self.cfg.pipeline[stage].num_steps | |
| for step_i in range(num_steps): | |
| self.optimize_iter(sample, optimizer, stage) | |
| else: | |
| assert dataloader is not None | |
| num_epochs = self.cfg.pipeline[stage].num_epochs | |
| scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9) | |
| for epoch_i in range(num_epochs): | |
| self.logger.info(f"EPOCH {epoch_i+1} / {num_epochs}") | |
| for step_i, sample in enumerate(dataloader): | |
| self.optimize_iter(sample, optimizer, stage) | |
| scheduler.step() | |
| if (epoch_i + 1) % 10 == 0: | |
| self.evaluate(make_visualization=True, epoch=epoch_i+1) | |
    def optimize_iter(self, sample, optimizer, stage):
        """
        Runs a single optimization step: energy computation, backward pass,
        parameter update, and periodic scalar/media logging.

        :param sample: batch containing all views of one timestep
        :param optimizer: optimizer over the parameters selected for this stage
        :param stage: name of the current pipeline stage
        """
        # compute loss and update parameters
        self.clear_cache()
        timestep_index = sample["timestep_index"][0]
        self.fill_cam_params_into_sample(sample)
        (
            E_total,
            log_dict,
            verts,
            faces,
            lmks,
            albedos,
            output_dict,
        ) = self.compute_energy(
            sample, frame_idx=timestep_index, stage=stage,
        )
        optimizer.zero_grad()
        E_total.backward()
        optimizer.step()

        # log energy terms and visualize
        if (self.global_step+1) % self.log_interval_scalar == 0:
            self.log_scalars(
                log_dict,
                timestep_index,
                session="train",
                stage=stage,
                frame_step=self.global_step,
            )
        if (self.global_step+1) % self.log_interval_media == 0:
            self.log_media(
                verts,
                faces,
                lmks,
                albedos,
                output_dict,
                sample,
                timestep_index,
                session="train",
                stage=stage,
                frame_step=self.global_step,
            )
        # release graph/visualization tensors before the next iteration
        del verts, faces, lmks, albedos, output_dict
        self.global_step += 1
| def get_train_parameters( | |
| self, stage: Literal['lmk_init_rigid', 'lmk_init_all', 'rgb_init_all', 'rgb_init_offset', 'rgb_sequential_tracking', 'rgb_global_tracking'], | |
| ): | |
| """ | |
| Collects the parameters to be optimized for the current frame | |
| :return: dict of parameters | |
| """ | |
| self.opt_dict = defaultdict(bool) # dict to keep track of which parameters are optimized | |
| for p in self.cfg.pipeline[stage].optimizable_params: | |
| self.opt_dict[p] = True | |
| params = defaultdict(list) # dict to collect parameters to be optimized | |
| # shared properties | |
| if self.opt_dict["cam"] and not self.calibrated: | |
| params["cam"] = [self.focal_length] | |
| if self.opt_dict["shape"]: | |
| params["shape"] = [self.shape] | |
| if self.opt_dict["texture"]: | |
| if not self.cfg.model.tex_painted: | |
| params["tex"] = [self.tex_pca] | |
| if self.cfg.model.tex_extra: | |
| params["tex_extra"] = [self.tex_extra] | |
| if self.opt_dict["static_offset"] and self.cfg.model.use_static_offset: | |
| params["static_offset"] = [self.static_offset] | |
| if self.opt_dict["lights"] and self.lights is not None: | |
| params["lights"] = [self.lights] | |
| # per-frame properties | |
| if self.opt_dict["pose"]: | |
| params["translation"].append(self.translation) | |
| params["rotation"].append(self.rotation) | |
| if self.opt_dict["joints"]: | |
| params["eyes"].append(self.eyes_pose) | |
| params["neck"].append(self.neck_pose) | |
| params["jaw"].append(self.jaw_pose) | |
| if self.opt_dict["expr"]: | |
| params["expr"].append(self.expr) | |
| if self.opt_dict["dynamic_offset"] and self.cfg.model.use_dynamic_offset: | |
| params["dynamic_offset"].append(self.dynamic_offset) | |
| return params | |
| def initialize_next_timtestep(self, timestep): | |
| if timestep < self.n_timesteps - 1: | |
| self.translation[timestep + 1].data.copy_(self.translation[timestep]) | |
| self.rotation[timestep + 1].data.copy_(self.rotation[timestep]) | |
| self.neck_pose[timestep + 1].data.copy_(self.neck_pose[timestep]) | |
| self.jaw_pose[timestep + 1].data.copy_(self.jaw_pose[timestep]) | |
| self.eyes_pose[timestep + 1].data.copy_(self.eyes_pose[timestep]) | |
| self.expr[timestep + 1].data.copy_(self.expr[timestep]) | |
| if self.cfg.model.use_dynamic_offset: | |
| self.dynamic_offset[timestep + 1].data.copy_(self.dynamic_offset[timestep]) | |