import os
import gc
import logging

from lib.common.config import cfg
from lib.dataset.mesh_util import (
    load_checkpoint,
    update_mesh_shape_prior_losses,
    blend_rgb_norm,
    unwrap,
    remesh,
    tensor2variable,
    rot6d_to_rotmat,
)
from lib.dataset.TestDataset import TestDataset
from lib.common.render import query_color
from lib.net.local_affine import LocalAffine
from pytorch3d.structures import Meshes
from apps.ICON import ICON

from termcolor import colored
import numpy as np
from PIL import Image
import trimesh
from tqdm import tqdm

import torch

torch.backends.cudnn.benchmark = True

logging.getLogger("trimesh").setLevel(logging.ERROR)


def generate_model(in_path, model_type):
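    """Reconstruct a clothed 3D human mesh from a single image.

    Pipeline: fit SMPL body parameters to the image (normal-map and
    silhouette losses), reconstruct a clothed mesh with the implicit
    network, refine it with a local-affine deformation, and export
    meshes, normal-overlap images, and a self-rotating video.

    Args:
        in_path: path to the input RGB image.
        model_type: 'ICON' (mapped to the icon-filter config) or another
            model name matching a ./configs/<name>.yaml file.

    Returns:
        [smpl_glb, smpl_obj, smpl_npy, refine_glb, refine_obj,
         video, video, overlap_png] file paths.
    """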

    torch.cuda.empty_cache()

    # 'ICON' selects the icon-filter config; other model names map to
    # ./configs/<name>.yaml
    if model_type == 'ICON':
        model_type = 'icon-filter'
    else:
        model_type = model_type.lower()

    config_dict = {'loop_smpl': 100,
                   'loop_cloth': 200,
                   'patience': 5,
                   'out_dir': './results',
                   'hps_type': 'pymaf',
                   'config': f"./configs/{model_type}.yaml"}

    # merge the model config with the PyMAF body-regressor config
    cfg.merge_from_file(config_dict['config'])
    cfg.merge_from_file("./lib/pymaf/configs/pymaf_config.yaml")

    os.makedirs(config_dict['out_dir'], exist_ok=True)

    # demo-time overrides: single GPU, 256^3 marching-cubes grid, mesh cleanup
    cfg_show_list = [
        "test_gpus",
        [0],
        "mcube_res",
        256,
        "clean_mesh",
        True,
    ]

    cfg.merge_from_list(cfg_show_list)
    cfg.freeze()

    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    device = torch.device("cuda:0")

    # load the ICON model and its pretrained checkpoint
    model = ICON(cfg)
    model = load_checkpoint(model, cfg)

    # TestDataset handles detection, cropping, and the HPS body regressor
    dataset_param = {
        'image_path': in_path,
        'seg_dir': None,     # optional segmentation directory
        'has_det': True,     # use a person detector
        'hps_type': 'pymaf'  # pymaf / pare / pixie
    }

    if config_dict['hps_type'] == "pixie" and "pamir" in config_dict['config']:
        print(colored("PIXIE isn't compatible with PaMIR, thus switching to PyMAF", "red"))
        dataset_param["hps_type"] = "pymaf"

    dataset = TestDataset(dataset_param, device)

    print(colored(f"Dataset Size: {len(dataset)}", "green"))

    pbar = tqdm(dataset)

    for data in pbar:

        pbar.set_description(f"{data['name']}")

        in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["image"]}
        optimed_pose = torch.tensor(
            data["body_pose"], device=device, requires_grad=True
        )
        optimed_trans = torch.tensor(
            data["trans"], device=device, requires_grad=True
        )
        optimed_betas = torch.tensor(
            data["betas"], device=device, requires_grad=True
        )
        optimed_orient = torch.tensor(
            data["global_orient"], device=device, requires_grad=True
        )

        optimizer_smpl = torch.optim.Adam(
            [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
            lr=1e-3,
            amsgrad=True,
        )
        scheduler_smpl = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_smpl,
            mode="min",
            factor=0.5,
            verbose=0,
            min_lr=1e-5,
            patience=config_dict['patience'],
        )

        # loss terms: "normal" and "silhouette" drive the SMPL fit,
        # the remaining terms drive the cloth refinement
        losses = {
            # Cloth: recon normal map vs. predicted normal map
            "cloth": {"weight": 1e1, "value": 0.0},
            # Cloth: stiffness of the local affine transforms across edges
            "stiffness": {"weight": 1e5, "value": 0.0},
            # Cloth: rigidity (det(R) = 1) of the local affine transforms
            "rigid": {"weight": 1e5, "value": 0.0},
            # Cloth: edge length regularizer (disabled)
            "edge": {"weight": 0, "value": 0.0},
            # Cloth: normal consistency (disabled)
            "nc": {"weight": 0, "value": 0.0},
            # Cloth: Laplacian smoothing
            "laplacian": {"weight": 1e2, "value": 0.0},
            # Body: predicted normal map vs. rendered SMPL normal map
            "normal": {"weight": 1e0, "value": 0.0},
            # Body: predicted silhouette vs. rendered SMPL silhouette
            "silhouette": {"weight": 1e0, "value": 0.0},
        }

        # SMPL fitting loop
        loop_smpl = tqdm(range(config_dict['loop_smpl']))

        for _ in loop_smpl:

            optimizer_smpl.zero_grad()

            # 6D rotation representation -> rotation matrices
            optimed_orient_mat = rot6d_to_rotmat(
                optimed_orient.view(-1, 6)).unsqueeze(0)
            optimed_pose_mat = rot6d_to_rotmat(
                optimed_pose.view(-1, 6)).unsqueeze(0)

            if dataset_param["hps_type"] != "pixie":
                smpl_out = dataset.smpl_model(
                    betas=optimed_betas,
                    body_pose=optimed_pose_mat,
                    global_orient=optimed_orient_mat,
                    pose2rot=False,
                )
                smpl_verts = (smpl_out.vertices +
                              optimed_trans) * data["scale"]
            else:
                # PIXIE uses an SMPL-X body model, which also takes
                # expression, jaw, and hand parameters
                smpl_verts, _, _ = dataset.smpl_model(
                    shape_params=optimed_betas,
                    expression_params=tensor2variable(data["exp"], device),
                    body_pose=optimed_pose_mat,
                    global_pose=optimed_orient_mat,
                    jaw_pose=tensor2variable(data["jaw_pose"], device),
                    left_hand_pose=tensor2variable(
                        data["left_hand_pose"], device),
                    right_hand_pose=tensor2variable(
                        data["right_hand_pose"], device),
                )
                smpl_verts = (smpl_verts + optimed_trans) * data["scale"]
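
            # render front/back normal maps of the current SMPL body;
            # the axis sign flips match the renderer's camera convention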
            in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
                smpl_verts * torch.tensor([1.0, -1.0, -1.0]).to(device),
                in_tensor["smpl_faces"],
            )
            T_mask_F, T_mask_B = dataset.render.get_silhouette_image()

            # predicted clothed-body normal maps from the normal network
            with torch.no_grad():
                in_tensor["normal_F"], in_tensor["normal_B"] = model.netG.normal_filter(
                    in_tensor
                )

            diff_F_smpl = torch.abs(
                in_tensor["T_normal_F"] - in_tensor["normal_F"])
            diff_B_smpl = torch.abs(
                in_tensor["T_normal_B"] - in_tensor["normal_B"])

            losses["normal"]["value"] = (diff_F_smpl + diff_B_smpl).mean()

            # silhouette loss: binarize the predicted normal maps against
            # the gray background, compare with the rendered SMPL masks
            smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
            gt_arr = torch.cat(
                [in_tensor["normal_F"][0], in_tensor["normal_B"][0]], dim=2
            ).permute(1, 2, 0)
            gt_arr = ((gt_arr + 1.0) * 0.5).to(device)
            bg_color = (
                torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(0).unsqueeze(0).to(device)
            )
            gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
            diff_S = torch.abs(smpl_arr - gt_arr)
            losses["silhouette"]["value"] = diff_S.mean()

            # weighted sum of the body-fitting losses
            smpl_loss = 0.0
            pbar_desc = "Body Fitting --- "
            for k in ["normal", "silhouette"]:
                pbar_desc += f"{k}: {losses[k]['value'] * losses[k]['weight']:.3f} | "
                smpl_loss += losses[k]["value"] * losses[k]["weight"]
            pbar_desc += f"Total: {smpl_loss:.3f}"
            loop_smpl.set_description(pbar_desc)

            smpl_loss.backward()
            optimizer_smpl.step()
            scheduler_smpl.step(smpl_loss)
            in_tensor["smpl_verts"] = smpl_verts * \
                torch.tensor([1.0, 1.0, -1.0]).to(device)

        # output folders for intermediate renders, videos, images, and meshes
        for sub_dir in ["refinement", "vid", "png", "obj"]:
            os.makedirs(os.path.join(config_dict['out_dir'], cfg.name, sub_dir),
                        exist_ok=True)
        norm_pred_F = (
            ((in_tensor["normal_F"][0].permute(1, 2, 0) + 1.0) * 255.0 / 2.0)
            .detach()
            .cpu()
            .numpy()
            .astype(np.uint8)
        )
        norm_pred_B = (
            ((in_tensor["normal_B"][0].permute(1, 2, 0) + 1.0) * 255.0 / 2.0)
            .detach()
            .cpu()
            .numpy()
            .astype(np.uint8)
        )

        # map the cropped predictions back into the original image frame
        norm_orig_F = unwrap(norm_pred_F, data)
        norm_orig_B = unwrap(norm_pred_B, data)
        mask_orig = unwrap(
            np.repeat(
                data["mask"].permute(1, 2, 0).detach().cpu().numpy(), 3, axis=2
            ).astype(np.uint8),
            data,
        )
        rgb_norm_F = blend_rgb_norm(data["ori_image"], norm_orig_F, mask_orig)
        rgb_norm_B = blend_rgb_norm(data["ori_image"], norm_orig_B, mask_orig)

        Image.fromarray(
            np.concatenate(
                [data["ori_image"].astype(np.uint8), rgb_norm_F, rgb_norm_B],
                axis=1)
        ).save(os.path.join(config_dict['out_dir'], cfg.name,
                            f"png/{data['name']}_overlap.png"))
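
        # export the fitted SMPL body (colored by vertex normals) and its
        # optimized parameters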
        smpl_obj = trimesh.Trimesh(
            in_tensor["smpl_verts"].detach().cpu()[0] *
            torch.tensor([1.0, -1.0, 1.0]),
            in_tensor['smpl_faces'].detach().cpu()[0],
            process=False,
            maintains_order=True
        )
        smpl_obj.visual.vertex_colors = (smpl_obj.vertex_normals + 1.0) * 255.0 * 0.5
        smpl_obj.export(
            f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.obj")
        smpl_obj.export(
            f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.glb")

        smpl_info = {'betas': optimed_betas,
                     'pose': optimed_pose_mat,
                     'orient': optimed_orient_mat,
                     'trans': optimed_trans}

        np.save(
            f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.npy",
            smpl_info, allow_pickle=True)
        in_tensor.update(
            dataset.compute_vis_cmap(
                in_tensor["smpl_verts"][0], in_tensor["smpl_faces"][0]
            )
        )

        # PaMIR additionally conditions on a voxelized SMPL body
        if cfg.net.prior_type == "pamir":
            in_tensor.update(
                dataset.compute_voxel_verts(
                    optimed_pose,
                    optimed_orient,
                    optimed_betas,
                    optimed_trans,
                    data["scale"],
                )
            )

        with torch.no_grad():
            verts_pr, faces_pr, _ = model.test_single(in_tensor)

        recon_obj = trimesh.Trimesh(
            verts_pr, faces_pr, process=False, maintains_order=True
        )
        recon_obj.visual.vertex_colors = (
            recon_obj.vertex_normals + 1.0) * 255.0 * 0.5
        recon_obj.export(
            os.path.join(config_dict['out_dir'], cfg.name,
                         f"obj/{data['name']}_recon.obj")
        )
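
        # Stage 3: remesh to a uniform triangulation, then refine with a
        # per-vertex local affine deformation so the rendered normals match
        # the predicted cloth normals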
        verts_refine, faces_refine = remesh(
            os.path.join(config_dict['out_dir'], cfg.name,
                         f"obj/{data['name']}_recon.obj"), 0.5, device)

        mesh_pr = Meshes(verts_refine, faces_refine).to(device)
        local_affine_model = LocalAffine(
            mesh_pr.verts_padded().shape[1],
            mesh_pr.verts_padded().shape[0],
            mesh_pr.edges_packed()).to(device)
        optimizer_cloth = torch.optim.Adam(
            [{'params': local_affine_model.parameters()}], lr=1e-4, amsgrad=True)

        scheduler_cloth = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer_cloth,
            mode="min",
            factor=0.1,
            verbose=0,
            min_lr=1e-5,
            patience=config_dict['patience'],
        )

        final = None

        if config_dict['loop_cloth'] > 0:

            loop_cloth = tqdm(range(config_dict['loop_cloth']))

            for _ in loop_cloth:

                optimizer_cloth.zero_grad()

                deformed_verts, stiffness, rigid = local_affine_model(
                    verts_refine.to(device), return_stiff=True)
                mesh_pr = mesh_pr.update_padded(deformed_verts)

                # mesh priors: Laplacian, edge, and normal-consistency losses
                update_mesh_shape_prior_losses(mesh_pr, losses)

                in_tensor["P_normal_F"], in_tensor["P_normal_B"] = dataset.render_normal(
                    mesh_pr.verts_padded(), mesh_pr.faces_padded())

                diff_F_cloth = torch.abs(
                    in_tensor["P_normal_F"] - in_tensor["normal_F"])
                diff_B_cloth = torch.abs(
                    in_tensor["P_normal_B"] - in_tensor["normal_B"])

                losses["cloth"]["value"] = (diff_F_cloth + diff_B_cloth).mean()
                losses["stiffness"]["value"] = torch.mean(stiffness)
                losses["rigid"]["value"] = torch.mean(rigid)

                # weighted sum of the cloth-refinement losses
                cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
                pbar_desc = "Cloth Refinement --- "

                for k in losses.keys():
                    if k not in ["normal", "silhouette"] and losses[k]["weight"] > 0.0:
                        cloth_loss = cloth_loss + \
                            losses[k]["value"] * losses[k]["weight"]
                        pbar_desc += f"{k}: {losses[k]['value'] * losses[k]['weight']:.5f} | "

                pbar_desc += f"Total: {cloth_loss:.5f}"
                loop_cloth.set_description(pbar_desc)

                cloth_loss.backward()
                optimizer_cloth.step()
                scheduler_cloth.step(cloth_loss)
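
            # export the refined mesh twice: the obj carries image-sampled
            # texture colors, the glb carries normal-map colors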
            final = trimesh.Trimesh(
                mesh_pr.verts_packed().detach().squeeze(0).cpu(),
                mesh_pr.faces_packed().detach().squeeze(0).cpu(),
                process=False, maintains_order=True
            )

            # per-vertex colors sampled from the input image
            tex_colors = query_color(
                mesh_pr.verts_packed().detach().squeeze(0).cpu(),
                mesh_pr.faces_packed().detach().squeeze(0).cpu(),
                in_tensor["image"],
                device=device,
            )

            # per-vertex colors encoding the surface normals
            norm_colors = (mesh_pr.verts_normals_padded().squeeze(0)
                           .detach().cpu() + 1.0) * 0.5 * 255.0

            final.visual.vertex_colors = tex_colors
            final.export(
                f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.obj")

            final.visual.vertex_colors = norm_colors
            final.export(
                f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.glb")

        # render a self-rotating video of the SMPL body and the refined mesh
        verts_lst = [smpl_obj.vertices, final.vertices]
        faces_lst = [smpl_obj.faces, final.faces]

        dataset.render.load_meshes(verts_lst, faces_lst)
        dataset.render.get_rendered_video(
            [data["ori_image"], rgb_norm_F, rgb_norm_B],
            os.path.join(config_dict['out_dir'], cfg.name,
                         f"vid/{data['name']}_cloth.mp4"),
        )
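
        # collect the export paths returned to the caller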
        smpl_obj_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.obj"
        smpl_glb_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.glb"
        smpl_npy_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_smpl.npy"
        refine_obj_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.obj"
        refine_glb_path = f"{config_dict['out_dir']}/{cfg.name}/obj/{data['name']}_refine.glb"

        video_path = os.path.join(
            config_dict['out_dir'], cfg.name, f"vid/{data['name']}_cloth.mp4")
        overlap_path = os.path.join(
            config_dict['out_dir'], cfg.name, f"png/{data['name']}_overlap.png")

        # best-effort cleanup: deleting entries from locals() has no effect
        # inside a CPython function, so release the heavy references
        # explicitly before returning
        del model, dataset, in_tensor, mesh_pr, local_affine_model
        del smpl_obj, recon_obj, final
        gc.collect()
        torch.cuda.empty_cache()

        return [smpl_glb_path, smpl_obj_path, smpl_npy_path,
                refine_glb_path, refine_obj_path,
                video_path, video_path, overlap_path]