| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
# --- Standard library ---
import glob
import os
import os.path
import os.path as osp
from functools import reduce
from pathlib import Path

# --- Third-party ---
import cv2
import face_alignment
import mmcv
import numpy as np
import torch
import torch.backends.cudnn
import torch.nn as nn
import torch.nn.functional as F
from loguru import logger
from pytorch3d.io import load_obj
from pytorch3d.renderer import (MeshRasterizer, MeshRenderer,
                                PerspectiveCameras, PointLights,
                                RasterizationSettings, SoftPhongShader,
                                TexturesVertex, look_at_view_transform)
from pytorch3d.transforms import axis_angle_to_matrix
from torch.utils.data import DataLoader
from torchvision.transforms.functional import gaussian_blur
from tqdm import tqdm

# --- Project-local ---
import SHOW
from configs.cfg_ins import condor_cfg
from SHOW.datasets import op_base
from SHOW.detector.face_detector import FaceDetector
from SHOW.flame.FLAME import FLAMETex
from SHOW.load_assets import load_assets
from SHOW.load_models import load_save_pkl, load_smplx_model, load_vposer_model
from SHOW.loggers.logger import setup_logger
from SHOW.renderer import Renderer
from SHOW.save_results import save_one_results
from SHOW.save_tracker import save_tracker
from SHOW.smplx_dataset import ImagesDataset
from SHOW.utils import default_timers, is_valid_json
from SHOW.utils.video import images_to_video
| |
|
| |
|
@logger.catch
def SHOW_stage2(*args, **kwargs):
    """Stage-2 ("MICA"-style) face refinement on top of a stage-1 SMPL-X fit.

    Pipeline (all paths/weights come from the merged ``tracker_cfg``):

    1. Load the merged tracker config, the stage-1 result pkl
       (``ours_pkl_file_path``) and identity assets (face embedding,
       SMPL-X-to-FLAME vertex index map, FLAME texture space).
    2. Split the sequence into chunks of ``bs_at_a_time`` frames. Per chunk,
       optimize expression / jaw / eye poses (and optionally two body-pose
       joints) plus a shared FLAME texture and per-frame SH lighting, using
       2D-landmark and photometric losses over a coarse-to-fine image
       pyramid (``sampling`` scales, ``iters`` Adam steps each). Each chunk
       is dumped to ``w_mica_part_*.pkl``; existing chunk pkls are skipped,
       which makes the loop resumable.
    3. Re-render the refined meshes over the original frames with pyrender,
       merge the refined expressions back into the stage-1 pkl
       (``mica_merge_pkl``) and assemble the result videos.

    ``*args`` is accepted but unused; ``**kwargs`` overrides the tracker
    config and may carry ``loggers`` (an experiment-logging facade).
    Returns ``None`` if assets fail to load, ``False`` if stage-1 outputs
    are missing/invalid, otherwise falls off the end (``None``).
    """
    machine_info = SHOW.get_machine_info()
    import pprint
    # NOTE(review): pprint of an f-string is just print(); pprint(machine_info)
    # was probably intended.
    pprint.pprint(f'machine_info: {machine_info}')

    loggers = kwargs.get('loggers', None)

    # Base config sits next to this script; kwargs then condor_cfg override it.
    tracker_cfg = SHOW.from_rela_path(__file__,
                                      './configs/mmcv_tracker_config.py')
    tracker_cfg.update(**kwargs)
    tracker_cfg.merge_from_dict(condor_cfg)

    # Optional final-say overrides packed inside the config itself.
    if tracker_cfg.get('over_write_cfg', None):
        tracker_cfg.update(tracker_cfg.over_write_cfg)

    # Persist the fully-merged config for reproducibility.
    mmcv.dump(tracker_cfg, tracker_cfg.tracker_cfg_path)

    try:
        gpu_mem = machine_info['gpu_info']['gpu_Total']

        import platform
        if platform.system() == 'Linux':
            # Scale chunk size linearly with GPU memory: 50 frames per
            # 80*1024 units of gpu_Total (presumably MiB, i.e. 80 GiB
            # -> 50 frames — TODO confirm units).
            tracker_cfg.bs_at_a_time = int(50.0 * gpu_mem / (80.0 * 1024))
            logger.warning(f'bs_at_a_time: {tracker_cfg.bs_at_a_time}')
    except:
        # NOTE(review): bare except silently degrades to the configured
        # bs_at_a_time on any error; prefer `except Exception:`.
        import traceback
        traceback.print_exc()

    Path(tracker_cfg.mica_save_path).mkdir(exist_ok=True, parents=True)
    Path(tracker_cfg.mica_org_out_path).mkdir(exist_ok=True, parents=True)

    iters = tracker_cfg.iters          # Adam steps per pyramid scale
    sampling = tracker_cfg.sampling    # image-pyramid scale factors
    device = tracker_cfg.device
    tracker_cfg.dtype = dtype = SHOW.str_to_torch_dtype(tracker_cfg.dtype)

    # Face identifier + a template image used to pick the target person.
    face_ider = SHOW.build_ider(tracker_cfg.ider_cfg)
    img_folder = tracker_cfg.img_folder
    template_im = os.listdir(img_folder)[0]
    template_im = os.path.join(img_folder, template_im)

    assets = load_assets(
        tracker_cfg,
        face_ider=face_ider,
        template_im=template_im,
    )
    if assets is None:
        return

    setup_logger(tracker_cfg.mica_all_dir, filename='mica.log', mode='o')

    # Stage-2 strictly requires a finished, valid stage-1 run.
    if not Path(tracker_cfg.ours_pkl_file_path).exists():
        logger.warning(
            f'ours_pkl_file_path not exists: {tracker_cfg.ours_pkl_file_path}')
        return False

    if not is_valid_json(tracker_cfg.final_losses_json_path):
        logger.warning(
            f'final_losses_json_path not valid: {tracker_cfg.final_losses_json_path}'
        )
        return False

    with default_timers['build_vars_stage']:
        # NOTE(review): face_ider is rebuilt here although already built above.
        face_ider = SHOW.build_ider(tracker_cfg.ider_cfg)
        person_face_emb = assets.person_face_emb
        face_detector_mediapipe = FaceDetector('google', device=device)
        face_detector = face_alignment.FaceAlignment(
            face_alignment.LandmarksType._2D, device=device)

        body_model = load_smplx_model(dtype=dtype, **tracker_cfg.smplx_cfg)
        body_params_dict = load_save_pkl(tracker_cfg.ours_pkl_file_path,
                                         device)

        # Camera / sequence metadata recovered from the stage-1 fit.
        width = body_params_dict['width']
        height = body_params_dict['height']
        center = body_params_dict['center']
        camera_transl = body_params_dict['camera_transl']
        focal_length = body_params_dict['focal_length']
        total_batch_size = body_params_dict['batch_size']

        # Split the sequence into (start, end) chunks of bs_at_a_time frames.
        # NOTE(review): `total_batch_size // opt_bs` drops any remainder
        # frames, and clamping `et` to total_batch_size - 1 excludes the
        # final frame from the (end-exclusive) slices — the tail of the
        # sequence is never optimized. TODO confirm whether intentional.
        opt_bs = tracker_cfg.bs_at_a_time
        opt_iters = total_batch_size // opt_bs
        st_et_list = []
        for i in range(opt_iters):
            st = i * opt_bs
            et = (i + 1) * opt_bs
            if et > total_batch_size - 1:
                et = total_batch_size - 1
            st_et_list.append((st, et))

        op = op_base()
        # Vertex indices mapping the SMPL-X mesh to the FLAME head region.
        smplx2flame_idx = assets.smplx2flame_idx

        mesh_file = Path(__file__).parent.joinpath(
            '../data/head_template_mesh.obj')

        # Differentiable renderer used for the photometric loss (512x512 base).
        diff_renderer = Renderer(torch.Tensor([[512, 512]]),
                                 obj_filename=mesh_file)

        flame_faces = load_obj(mesh_file)[1]
        flametex = FLAMETex(tracker_cfg.flame_cfg).to(device)

        mesh_rasterizer = MeshRasterizer(
            raster_settings=RasterizationSettings(image_size=[512, 512],
                                                  faces_per_pixel=1,
                                                  cull_backfaces=True,
                                                  perspective_correct=True))

        # Phong-shaded renderer used only for debug/visualization output.
        debug_renderer = MeshRenderer(
            rasterizer=mesh_rasterizer,
            shader=SoftPhongShader(device=device,
                                   lights=PointLights(
                                       device=device,
                                       location=((0.0, 0.0, -5.0), ),
                                       ambient_color=((0.5, 0.5, 0.5), ),
                                       diffuse_color=((0.5, 0.5, 0.5), ))))

    # Last optimized expression of the previous chunk, used as a temporal
    # anchor for the first frame of the next chunk.
    pre_frame_exp = None
    for opt_idx, (start_frame, end_frame) in enumerate(st_et_list):
        # Without a reference face embedding there is nothing to track.
        if assets.person_face_emb is not None:
            mica_part_file_path = f'w_mica_part_{start_frame}_{end_frame}_{opt_idx}_{opt_iters}.pkl'
            mica_part_pkl_path = os.path.join(tracker_cfg.mica_all_dir,
                                              mica_part_file_path)

            # Resume support: a finished chunk only contributes its last
            # expression as the anchor for the next chunk.
            if Path(mica_part_pkl_path).exists():
                logger.info(
                    f'mica_part_pkl_path exists,skipping: {mica_part_pkl_path}'
                )
                pre_con = mmcv.load(mica_part_pkl_path)
                pre_frame_exp = pre_con['expression'][-1]
                pre_frame_exp = torch.Tensor(pre_frame_exp).to(device)
                continue

            opt_bs = end_frame - start_frame

            # Zero-initialized FLAME texture (150 coeffs) and SH lighting
            # (9 bands x RGB). NOTE(review): hard-coded 'cuda' instead of
            # the configured `device`.
            com_tex = torch.zeros(1, 150).to('cuda')
            com_sh = torch.zeros(1, 9, 3).to('cuda')

            # Shared texture across the chunk (one row) vs per-frame texture.
            use_shared_tex = 1
            if not use_shared_tex:
                opt_bs_tex = nn.Parameter(com_tex).expand(opt_bs, -1).detach()
            else:
                opt_bs_tex = nn.Parameter(com_tex).expand(1, -1).detach()

            # Lighting is always per-frame.
            opt_bs_sh = nn.Parameter(com_sh).expand(opt_bs, -1, -1).detach()

            logger.info(f'origin input data frame batchsize:{opt_bs}')

            with default_timers['load_dataset_stage']:

                debug = 0
                if debug: opt_bs = 30

                # One DataLoader batch == the whole chunk (batch_size=opt_bs,
                # single next() below).
                dataset = ImagesDataset(
                    tracker_cfg,
                    start_frame=start_frame,
                    face_ider=face_ider,
                    person_face_emb=person_face_emb,
                    face_detector_mediapipe=face_detector_mediapipe,
                    face_detector=face_detector)
                dataloader = DataLoader(dataset,
                                        batch_size=opt_bs,
                                        num_workers=0,
                                        shuffle=False,
                                        pin_memory=True,
                                        drop_last=False)
                iterator = iter(dataloader)
                batch = next(iterator)
                if not debug:
                    batch = SHOW.utils.to_cuda(batch)
                # Mask of frames where the target person's face was detected;
                # losses are computed only on these frames.
                valid_bool = batch['is_person_deted'].bool()
                valid_bs = valid_bool.count_nonzero()
                logger.info(f'valid input data frame batchsize:{valid_bs}')
                logger.info(f'valid_bool: {valid_bool}')

                if valid_bs == 0:
                    # Marker file so the chunk is not retried forever.
                    logger.warning('valid bs == 0, skipping')
                    open(mica_part_pkl_path + '.empty', 'a').close()
                    continue

                bbox = batch['bbox']                 # face crop box in the source frame
                images = batch['cropped_image']      # face crops (512-based; confirm in ImagesDataset)
                landmarks = batch['cropped_lmk']     # 2D landmarks in crop coords
                h = batch['h']
                w = batch['w']
                py = batch['py']                     # crop paddings — TODO confirm semantics
                px = batch['px']

                diff_renderer.masking.set_bs(valid_bs)
                diff_renderer = diff_renderer.to(device)

            # Hard-wired run options for the optimization stage.
            debug = 0
            report_wandb = 0
            use_opt_pose = 1          # also optimize two body-pose joints
            save_traing_img = 0       # dump per-iteration debug renders
            observe_idx_list = [4, 8] # global frame indices to snapshot

            with default_timers['optimize_stage']:
                model_output = None

                def get_pose_opt(start_frame, end_frame):
                    # Extract the two body-pose joints to be optimized:
                    # joints 12 and 15 in 1-based numbering (indices 11/14) —
                    # presumably neck and head in SMPL-X body-pose order;
                    # TODO confirm against the SMPL-X joint table.
                    tmp = body_params_dict['body_pose_axis'][
                        start_frame:end_frame, ...].clone().detach()
                    tmp = tmp.reshape(tmp.shape[0], -1, 3)
                    return torch.stack([tmp[:, 12 - 1, :],
                                        tmp[:, 15 - 1, :]],
                                       dim=1)

                def clone_params_color(start_frame, end_frame):
                    # Build Adam param groups from fresh, detached copies of
                    # the stage-1 parameters for this chunk. opt_bs_sh /
                    # opt_bs_tex are closed over and carry the best result
                    # from the previous pyramid scale.
                    opt_var_clone_detach = [
                        {
                            'params': [
                                nn.Parameter(
                                    body_params_dict['expression']
                                    [start_frame:end_frame].clone().detach())
                            ],
                            'lr': 0.025,
                            'name': ['exp']
                        },
                        {
                            # NOTE(review): redundant double .clone() kept as-is.
                            'params': [
                                nn.Parameter(
                                    body_params_dict['leye_pose']
                                    [start_frame:end_frame].clone(
                                    ).clone().detach())
                            ],
                            'lr': 0.001,
                            'name': ['leyes']
                        },
                        {
                            'params': [
                                nn.Parameter(
                                    body_params_dict['reye_pose']
                                    [start_frame:end_frame].clone().detach())
                            ],
                            'lr': 0.001,
                            'name': ['reyes']
                        },
                        {
                            'params': [
                                nn.Parameter(
                                    body_params_dict['jaw_pose']
                                    [start_frame:end_frame].clone().detach())
                            ],
                            'lr': 0.001,
                            'name': ['jaw']
                        },
                        {
                            'params':
                            [nn.Parameter(opt_bs_sh.clone().detach())],
                            'lr': 0.01,
                            'name': ['sh']
                        },
                        {
                            'params':
                            [nn.Parameter(opt_bs_tex.clone().detach())],
                            'lr': 0.005,
                            'name': ['tex']
                        },
                    ]
                    if use_opt_pose:
                        opt_var_clone_detach.append({
                            'params': [
                                nn.Parameter(
                                    get_pose_opt(start_frame, end_frame))
                            ],
                            'lr': 0.005,
                            'name': ['body_pose']
                        })
                    return opt_var_clone_detach

                save_traing_img_dir = tracker_cfg.mica_process_path + f'_{start_frame}_{end_frame}'

                if save_traing_img:
                    Path(save_traing_img_dir).mkdir(parents=True,
                                                    exist_ok=True)

                # Progress bar budget: iters steps per pyramid scale
                # (total=iters*3 — presumably len(sampling) == 3).
                with tqdm(total=iters * 3,
                          position=0,
                          leave=True,
                          bar_format="{percentage:3.0f}%|{bar}{r_bar}{desc}"
                          ) as pbar:
                    # Coarse-to-fine loop over pyramid scales.
                    for k, scale in enumerate(sampling):

                        size = [int(512 * scale), int(512 * scale)]
                        img = F.interpolate(images.float().clone(),
                                            size,
                                            mode='bilinear',
                                            align_corners=False)

                        # Blur targets on the later scales — presumably to
                        # soften the photometric objective; TODO confirm.
                        if k > 0:
                            img = gaussian_blur(img, [9, 9]).detach()

                        # Flip H and W; used as the grid_sample source below —
                        # presumably to match the renderer's coordinate
                        # convention. Only valid frames are kept.
                        flipped = torch.flip(img, [2, 3])
                        flipped = flipped[valid_bool.bool(), ...]

                        best_loss = np.inf
                        prev_loss = np.inf  # NOTE(review): never used.

                        # Assumes dict insertion order x_min, x_max, y_min,
                        # y_max — TODO confirm against ImagesDataset.
                        xb_min, xb_max, yb_min, yb_max = bbox.values()
                        box_w = xb_max - xb_min
                        box_h = yb_max - yb_min
                        box_w = box_w.int()
                        box_h = box_h.int()

                        image_size = size

                        diff_renderer.rasterizer.reset()
                        diff_renderer.set_size(image_size)
                        debug_renderer.rasterizer.raster_settings.image_size = size

                        # Landmarks live in 512-crop coords; rescale to the
                        # current pyramid size and keep valid frames only.
                        image_lmks = landmarks * size[0] / 512
                        image_lmks = image_lmks[valid_bool.bool(), ...]

                        # Fresh optimizer per scale, seeded from the current
                        # best parameters.
                        optimizer = torch.optim.Adam(
                            clone_params_color(start_frame, end_frame))
                        params = optimizer.param_groups
                        get_param = SHOW.utils.get_param

                        cur_tex = get_param('tex', params)
                        cur_sh = get_param('sh', params)

                        cur_exp = get_param('exp', params)
                        cur_leyes = get_param('leyes', params)
                        cur_reyes = get_param('reyes', params)
                        cur_jaw = get_param('jaw', params)

                        if use_opt_pose:
                            # Splice the two optimized joints (indices 11 and
                            # 14) back into the fixed stage-1 body pose.
                            two_opt = get_param('body_pose', params)
                            frame_pose = body_params_dict['body_pose_axis'][
                                start_frame:end_frame]
                            bs = frame_pose.shape[0]
                            frame_pose = frame_pose.reshape(bs, -1, 3)
                            cur_pose = torch.cat(
                                [
                                    frame_pose[:, :11, :],
                                    two_opt[:, 0:1],
                                    frame_pose[:, 12:14, :],
                                    two_opt[:, 1:2],
                                    frame_pose[:, 15:, :]
                                ],
                                dim=1).reshape(bs, 1, -1)
                        else:
                            frame_pose = body_params_dict['body_pose_axis'][
                                start_frame:end_frame]
                            bs = frame_pose.shape[0]
                            cur_pose = frame_pose.reshape(bs, 1, -1)

                        cur_transl = body_params_dict['transl'][
                            start_frame:end_frame]
                        cur_global_orient = body_params_dict['global_orient'][
                            start_frame:end_frame]
                        cur_left_hand_pose = body_params_dict[
                            'left_hand_pose'][start_frame:end_frame]
                        cur_right_hand_pose = body_params_dict[
                            'right_hand_pose'][start_frame:end_frame]

                        # 180-degree rotation about z — presumably converts
                        # between the SMPL-X and PyTorch3D camera conventions.
                        R = torch.Tensor([[[-1, 0, 0], [0, -1, 0], [0, 0, 1]]])
                        bs_image_size = torch.Tensor(image_size).repeat(
                            opt_bs, 1).to(device)
                        bs_camera_transl = (camera_transl).repeat(opt_bs,
                                                                  1).to(device)

                        bs_center = torch.Tensor(center).repeat(opt_bs,
                                                                1).to(device)
                        bs_box_min = torch.stack([xb_min, yb_min],
                                                 dim=-1).to(device)
                        bs_R = R.repeat(opt_bs, 1, 1).to(device)

                        # Per-frame scale from the face crop to the render
                        # size; principal point mapped from full-frame coords
                        # into the (padded) crop, then into the 512 base.
                        s_w = size[0] / box_w
                        s_h = size[1] / box_h
                        s_1to2 = torch.stack([s_w, s_h], dim=1)
                        s_1to2 = s_1to2.to(device)
                        bs_pp = (bs_center - bs_box_min) * (s_1to2)

                        bs_pp[:, 0] = (bs_pp[:, 0] + px) * 512 / (w + 2 * px)
                        bs_pp[:, 1] = (bs_pp[:, 1] + py) * 512 / (h + 2 * py)

                        s_1to2[:, 0] = s_1to2[:, 0] * 512 / (w + 2 * px)
                        s_1to2[:, 1] = s_1to2[:, 1] * 512 / (h + 2 * py)

                        cam_cfg = dict(
                            principal_point=bs_pp,
                            focal_length=focal_length * s_1to2,
                            R=bs_R,
                            T=bs_camera_transl,
                            image_size=bs_image_size,
                        )

                        # Keep cameras only for the valid (detected) frames.
                        for key, val in cam_cfg.items():
                            cam_cfg[key] = val[valid_bool.bool(), ...]

                        cameras = PerspectiveCameras(**cam_cfg,
                                                     device=device,
                                                     in_ndc=False)

                        if True:
                            # Inner optimization loop at this scale.
                            for p in range(iters):
                                # Periodically reset the rasterizer cache.
                                if (p + 1) % 32 == 0:
                                    diff_renderer.rasterizer.reset()
                                losses = {}
                                model_output = body_model(
                                    return_verts=True,
                                    jaw_pose=cur_jaw,
                                    leye_pose=cur_leyes,
                                    reye_pose=cur_reyes,
                                    expression=cur_exp,
                                    betas=body_params_dict['betas'],
                                    transl=cur_transl,
                                    body_pose=cur_pose,
                                    global_orient=cur_global_orient,
                                    left_hand_pose=cur_left_hand_pose,
                                    right_hand_pose=cur_right_hand_pose,
                                )

                                # Head-region vertices only, valid frames only.
                                vertices = model_output.vertices[:,
                                                                 smplx2flame_idx
                                                                 .long(), :]
                                vertices = vertices[valid_bool.bool(), ...]

                                # 51 inner-face + 17 contour joints starting
                                # at joint 67 — TODO confirm joint layout.
                                lmk68_all = model_output.joints[:, 67:67 + 51 +
                                                                17, :]
                                lmk68 = lmk68_all[valid_bool.bool(), ...]

                                # Project to screen and move the 17 contour
                                # points to the front to match the standard
                                # 68-landmark ordering.
                                proj_lmks = cameras.transform_points_screen(
                                    lmk68)[:, :, :2]
                                proj_lmks = torch.cat([
                                    proj_lmks[:, -17:, :],
                                    proj_lmks[:, :-17, :]
                                ],
                                                      dim=1)

                                I = torch.eye(3)[None].to(device)

                                # Anchor the chunk's first expression to the
                                # last one of the previous chunk.
                                if pre_frame_exp is not None and start_frame != 0:
                                    losses['pre_exp'] = 0.001 * torch.sum(
                                        (pre_frame_exp - cur_exp[0])**2)

                                # Disabled eye-pose regularizers (dead code
                                # kept intentionally behind `if False`).
                                if False:
                                    linear_rot_left = (axis_angle_to_matrix(
                                        cur_leyes[valid_bool.bool(), ...]))
                                    linear_rot_right = (axis_angle_to_matrix(
                                        cur_reyes[valid_bool.bool(), ...]))
                                    losses['eyes_sym_reg'] = torch.sum(
                                        (linear_rot_right - linear_rot_left)**
                                        2) / opt_bs
                                    losses['eyes_left_reg'] = torch.sum(
                                        (I - linear_rot_left)**2) / opt_bs
                                    losses['eyes_right_reg'] = torch.sum(
                                        (I - linear_rot_right)**2) / opt_bs

                                # Landmark losses: all points, mouth region
                                # (key 'lmk_mount' is a historical typo kept
                                # for pkl compatibility), and face oval.
                                w_lmks = tracker_cfg.w_lmks
                                losses['lmk'] = SHOW.utils.lmk_loss(
                                    proj_lmks, image_lmks,
                                    image_size) * w_lmks * 8.0
                                losses[
                                    'lmk_mount'] = SHOW.utils.mouth_loss(
                                        proj_lmks, image_lmks,
                                        image_size) * w_lmks * 4.0 * 4
                                losses['lmk_oval'] = SHOW.utils.lmk_loss(
                                    proj_lmks[:, :17, ...], image_lmks[:, :17,
                                                                       ...],
                                    image_size) * w_lmks

                                # Keep jaw rotation near identity, expressions
                                # near zero.
                                losses['jaw_reg'] = torch.sum(
                                    (I - axis_angle_to_matrix(
                                        cur_jaw[valid_bool.bool(), ...]))**
                                    2) * 16.0 / opt_bs
                                losses['exp_reg'] = torch.sum(
                                    cur_exp[valid_bool.bool(),
                                            ...]**2) * 0.01 / opt_bs

                                if use_shared_tex:
                                    losses['tex_reg'] = torch.sum(cur_tex**
                                                                  2) * 0.02
                                else:
                                    losses['tex_reg'] = torch.sum(
                                        cur_tex[valid_bool.bool(),
                                                ...]**2) * 0.02 / opt_bs

                                def temporary_loss(o_w, i_w, gmof, param):
                                    # Second-order temporal smoothness:
                                    # penalizes x[t+1] + x[t-1] - 2*x[t].
                                    assert param.shape[
                                        0] > 2, f'optimize batchsize must > 2 to enable temporary smooth'
                                    return (o_w**2) * (gmof(
                                        i_w *
                                        (param[2:, ...] + param[:-2, ...] -
                                         2 * param[1:-1, ...]))).mean()

                                def pow(x):
                                    # NOTE(review): shadows the builtin pow.
                                    return x.pow(2)

                                if cur_exp.shape[0] > 2:
                                    losses['loss_sexp'] = temporary_loss(
                                        1.0, 2.0, pow, cur_exp)
                                    losses['loss_sjaw'] = temporary_loss(
                                        1.0, 2.0, pow, cur_jaw)

                                def k_fun(k):
                                    # Photometric weight grows 32x after the
                                    # coarsest scale.
                                    return tracker_cfg.w_pho * 32.0 if k > 0 else tracker_cfg.w_pho

                                # FLAME albedo in [0, 1].
                                albedos = flametex(cur_tex) / 255.

                                if use_shared_tex:
                                    albedos = albedos.expand(
                                        valid_bs, -1, -1, -1)
                                else:
                                    albedos = albedos[valid_bool.bool(), ...]

                                ops = diff_renderer(
                                    vertices, albedos,
                                    cur_sh[valid_bool.bool(), ...], cameras)

                                # Photometric loss: sample the (flipped)
                                # target image at the rendered surface
                                # positions, compare inside the render mask.
                                grid = ops['position_images'].permute(
                                    0, 2, 3, 1)[:, :, :, :2]
                                sampled_image = F.grid_sample(
                                    flipped, grid, align_corners=False)
                                ops_mask = SHOW.utils.parse_mask(ops)
                                tmp_img = ops['images']

                                losses['pho'] = SHOW.utils.pixel_loss(
                                    tmp_img, sampled_image,
                                    ops_mask) * k_fun(k)

                                all_loss = 0.
                                for key in losses.keys():
                                    all_loss = all_loss + losses[key]
                                losses['all_loss'] = all_loss

                                log_str = SHOW.print_dict_losses(losses)

                                # Lazy one-time wandb init, cached in globals.
                                if report_wandb:
                                    if globals().get('wandb', None) is None:
                                        # NOTE(review): placeholder API key.
                                        os.environ[
                                            'WANDB_API_KEY'] = 'xxx'
                                        os.environ['WANDB_NAME'] = 'tracker'
                                        import wandb
                                        wandb.init(
                                            reinit=True,
                                            resume='allow',
                                            project='tracker',
                                        )
                                        globals()['wandb'] = wandb

                                    if globals().get('wandb',
                                                     None) is not None:
                                        globals()['wandb'].log(losses)

                                if save_traing_img:

                                    def save_callback(frame, final_views):
                                        # `frame` is chunk-local; convert to a
                                        # global index before filtering.
                                        cur_idx = (frame + opt_bs * opt_idx)
                                        if cur_idx in observe_idx_list:

                                            observe_idx_frame_dir = os.path.join(
                                                save_traing_img_dir,
                                                f'{cur_idx:03d}')
                                            Path(observe_idx_frame_dir).mkdir(
                                                parents=True, exist_ok=True)

                                            cv2.imwrite(
                                                os.path.join(
                                                    observe_idx_frame_dir,
                                                    f'{k}_{p}.jpg'),
                                                final_views)

                                    save_tracker(
                                        img,
                                        valid_bool,
                                        valid_bs,
                                        ops,
                                        vertices,
                                        cameras,
                                        image_lmks,
                                        proj_lmks,
                                        flame_faces,
                                        mesh_rasterizer,
                                        debug_renderer,
                                        save_callback,
                                    )

                                # NOTE(review): the pbar is only advanced in
                                # the no-loggers branch — confirm intended.
                                if loggers is not None:
                                    loggers.log_bs(losses)
                                    # Abort this scale on NaN, leaving a
                                    # marker file behind.
                                    if torch.isnan(all_loss).sum():
                                        loggers.alert(
                                            title='Nan error',
                                            msg=
                                            f'tracker nan in: {tracker_cfg.ours_output_folder}'
                                        )
                                        open(
                                            tracker_cfg.ours_output_folder +
                                            '/mica_opt_nan.info', 'a').close()
                                        break

                                else:
                                    pbar.set_description(log_str)
                                    pbar.update(1)

                                optimizer.zero_grad()
                                all_loss.backward()
                                optimizer.step()

                                # Track the best iterate: write it back into
                                # body_params_dict and into the closed-over
                                # opt_bs_tex / opt_bs_sh so the next scale
                                # (and the final save) starts from it.
                                if all_loss.item() < best_loss:
                                    best_loss = all_loss.item()
                                    opt_bs_tex = cur_tex.clone().detach()
                                    opt_bs_sh = cur_sh.clone().detach()
                                    body_params_dict['expression'][
                                        start_frame:end_frame] = cur_exp.clone(
                                        ).detach()
                                    body_params_dict['leye_pose'][
                                        start_frame:
                                        end_frame] = cur_leyes.clone().detach(
                                        )
                                    body_params_dict['reye_pose'][
                                        start_frame:
                                        end_frame] = cur_reyes.clone().detach(
                                        )
                                    body_params_dict['jaw_pose'][
                                        start_frame:end_frame] = cur_jaw.clone(
                                        ).detach()
                                    body_params_dict['body_pose_axis'][
                                        start_frame:
                                        end_frame] = cur_pose.clone().detach(
                                        ).squeeze()

            with default_timers['saving_stage']:
                # Turn per-iteration debug snapshots into per-frame videos.
                if save_traing_img:
                    for idx in observe_idx_list:
                        observe_idx_frame_dir = os.path.join(
                            save_traing_img_dir, f'{idx:03d}')
                        Path(observe_idx_frame_dir).mkdir(parents=True,
                                                          exist_ok=True)

                        if not SHOW.is_empty_dir(observe_idx_frame_dir):
                            images_to_video(
                                input_folder=observe_idx_frame_dir,
                                output_path=observe_idx_frame_dir + '.mp4',
                                img_format=None,
                                fps=30,
                            )

                # Persist the chunk's refined parameters (as numpy).
                dict_to_save = dict(
                    expression=body_params_dict['expression']
                    [start_frame:end_frame].clone().detach().cpu().numpy(),
                    leye_pose=body_params_dict['leye_pose']
                    [start_frame:end_frame].clone().detach().cpu().numpy(),
                    reye_pose=body_params_dict['reye_pose']
                    [start_frame:end_frame].clone().detach().cpu().numpy(),
                    jaw_pose=body_params_dict['jaw_pose']
                    [start_frame:end_frame].clone().detach().cpu().numpy(),
                    body_pose_axis=body_params_dict['body_pose_axis']
                    [start_frame:end_frame].clone().detach().cpu().numpy(),
                    tex=opt_bs_tex.clone().detach().cpu().numpy(),
                    sh=opt_bs_sh.clone().detach().cpu().numpy(),
                )
                mmcv.dump(dict_to_save, mica_part_pkl_path)
                logger.info(f'mica pkl part path: {mica_part_pkl_path}')
                # Anchor for the next chunk's first frame.
                pre_frame_exp = dict_to_save['expression'][-1]
                pre_frame_exp = torch.Tensor(pre_frame_exp).to(device)
                # Vertices from the last forward pass (best-effort; not
                # necessarily the best-loss iterate).
                vertices_ = model_output.vertices.clone().detach().cpu()
                logger.info(
                    f'mica render to origin path: {tracker_cfg.mica_org_out_path}')

                # Headless EGL rendering on Linux; default backend elsewhere.
                import platform
                if platform.system() == "Linux":
                    os.environ['PYOPENGL_PLATFORM'] = 'egl'
                else:
                    if 'PYOPENGL_PLATFORM' in os.environ:
                        os.environ.__delitem__('PYOPENGL_PLATFORM')

                import pyrender
                input_renderer = pyrender.OffscreenRenderer(viewport_width=width,
                                                            viewport_height=height,
                                                            point_size=1.0)

                # Overlay each refined mesh on its original frame.
                for idx in tqdm(range(vertices_.shape[0]),
                                desc='saving ours final pyrender images'):

                    # Source images are 1-based 6-digit filenames.
                    cur_idx = idx + start_frame + 1

                    input_img = SHOW.find_full_impath_by_name(
                        root=tracker_cfg.img_folder, name=f'{cur_idx:06d}')
                    output_name = os.path.join(
                        tracker_cfg.mica_org_out_path,
                        f"{cur_idx:06}.{tracker_cfg.output_img_ext}")

                    camera_pose = op.get_smplx_to_pyrender_K(camera_transl)

                    meta_data = dict(
                        input_img=input_img,
                        output_name=output_name,
                    )

                    save_one_results(
                        vertices_[idx],
                        body_model.faces,
                        img_size=(height, width),
                        center=center,
                        focal_length=[focal_length, focal_length],
                        camera_pose=camera_pose,
                        meta_data=meta_data,
                        input_renderer=input_renderer,
                    )
                input_renderer.delete()

            # Optional final visualization grid for this chunk, built from
            # the last optimization state (img/ops/vertices/cameras above).
            if tracker_cfg.save_final_vis:

                def save_callback(frame, final_views):
                    cur_idx = (frame + opt_bs * opt_idx)

                    if loggers is not None:
                        loggers.log_image(f"final_mica_img/{cur_idx:03d}",
                                          final_views / 255.0)

                    cv2.imwrite(
                        os.path.join(tracker_cfg.mica_save_path,
                                     f'{cur_idx:03d}.jpg'), final_views)

                if True:
                    save_tracker(
                        img,
                        valid_bool,
                        valid_bs,
                        ops,
                        vertices,
                        cameras,
                        image_lmks,
                        proj_lmks,
                        flame_faces,
                        mesh_rasterizer,
                        debug_renderer,
                        save_callback,
                    )

    # Merge the refined (MICA) expressions back into the stage-1 result and
    # dump the combined pkl.
    load_data = mmcv.load(tracker_cfg.ours_pkl_file_path)[0]
    load_data = SHOW.replace_mica_exp(tracker_cfg.mica_all_dir, load_data)
    mmcv.dump([load_data], tracker_cfg.mica_merge_pkl)

    # Assemble output videos (skipped if already present or inputs empty).
    if not Path(tracker_cfg.mica_org_out_video).exists():
        if not SHOW.is_empty_dir(tracker_cfg.mica_org_out_path):
            images_to_video(
                input_folder=tracker_cfg.mica_org_out_path,
                output_path=tracker_cfg.mica_org_out_video,
                img_format=None,
                fps=30,
            )
    if not Path(tracker_cfg.mica_grid_video).exists():
        if not SHOW.is_empty_dir(tracker_cfg.mica_save_path):
            images_to_video(
                input_folder=tracker_cfg.mica_save_path,
                output_path=tracker_cfg.mica_grid_video,
                img_format=None,
                fps=30,
            )
| |
|