| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | from lib.hybrik.models.simple3dpose import HybrIKBaseSMPLCam |
| | from lib.pixielib.utils.config import cfg as pixie_cfg |
| | from lib.pixielib.pixie import PIXIE |
| | import lib.smplx as smplx |
| | |
| | from lib.pymaf.utils.geometry import rot6d_to_rotmat, batch_rodrigues, rotation_matrix_to_angle_axis |
| | from lib.pymaf.utils.imutils import process_image |
| | from lib.common.imutils import econ_process_image |
| | from lib.pymaf.core import path_config |
| | from lib.pymaf.models import pymaf_net |
| | from lib.common.config import cfg |
| | from lib.common.render import Render |
| | from lib.dataset.body_model import TetraSMPLModel |
| | from lib.dataset.mesh_util import get_visibility |
| | from utils.smpl_util import SMPLX |
| | import os.path as osp |
| | import os |
| | import torch |
| | import numpy as np |
| | import random |
| | from termcolor import colored |
| | from PIL import ImageFile |
| | from torchvision.models import detection |
| |
|
| |
|
# Let PIL decode image files that are truncated (incomplete on disk) instead of
# raising an error when they are loaded.
ImageFile.LOAD_TRUNCATED_IMAGES = True
| |
|
| |
|
class SMPLDataset():
    """Test-time dataset that runs a human pose & shape (HPS) estimator per image.

    Every item corresponds to one image file found in ``cfg['image_dir']``.
    ``__getitem__`` preprocesses that image, runs the configured estimator
    (``pymaf``, ``pare``, ``pixie``, ``hybrik`` or ``bev``) and returns a dict
    with the predicted body parameters, predicted vertices, camera
    scale/translation and the preprocessed images.
    """

    # Image extensions accepted from ``image_dir`` (matched case-insensitively).
    IMG_FMTS = ('jpg', 'png', 'jpeg', 'bmp')

    def __init__(self, cfg, device):
        """Collect input images and build the body model, HPS network and renderer.

        Args:
            cfg: mapping with required keys ``image_dir`` and ``hps_type`` and
                optional keys ``seg_dir`` (directory holding ``<name>.json``
                segmentation files; default ``None``) and ``colab`` (if true,
                only the first image is kept; default ``False``).
            device: torch device (or device string) the models are moved to.

        Raises:
            ValueError: if ``cfg['hps_type']`` is not a supported estimator.
        """
        random.seed(1993)

        self.image_dir = cfg['image_dir']
        self.seg_dir = cfg.get('seg_dir')
        self.hps_type = cfg['hps_type']
        self.smpl_type = 'smpl' if cfg['hps_type'] != 'pixie' else 'smplx'
        self.smpl_gender = 'neutral'
        self.colab = cfg.get('colab', False)

        self.device = device

        self.subject_list = self._collect_images(self.image_dir)
        if self.colab:
            # Slice rather than index [0] so an empty folder does not raise here.
            self.subject_list = self.subject_list[:1]

        self.smpl_data = SMPLX()

        # 24-entry joint index lists: the first 22 joints plus two extra indices,
        # with different offsets for PIXIE. NOTE(review): presumably these pick
        # SMPL's 24 joints out of a larger SMPL-X joint set — confirm.
        self.smpl_joint_ids_24 = np.arange(22).tolist() + [68, 73]
        self.smpl_joint_ids_24_pixie = np.arange(22).tolist() + [68 + 61, 72 + 68]
        self.get_smpl_model = lambda smpl_type, smpl_gender: smplx.create(
            model_path=self.smpl_data.model_dir,
            gender=smpl_gender,
            model_type=smpl_type,
            ext='npz')

        self.smpl_model = self.get_smpl_model(self.smpl_type, self.smpl_gender).to(self.device)
        self.faces = self.smpl_model.faces

        if self.hps_type == 'pymaf':
            self.hps = pymaf_net(path_config.SMPL_MEAN_PARAMS, pretrained=True).to(self.device)
            self.hps.load_state_dict(torch.load(path_config.CHECKPOINT_FILE)['model'], strict=True)
            self.hps.eval()
        elif self.hps_type == 'pare':
            # PARETester is not imported at module level; import it lazily so the
            # other backends do not require PARE. NOTE(review): this path follows
            # upstream ICON's layout (lib/pare) — adjust if PARE lives elsewhere.
            from lib.pare.pare.core.tester import PARETester
            self.hps = PARETester(path_config.CFG, path_config.CKPT).model
        elif self.hps_type == 'pixie':
            self.hps = PIXIE(config=pixie_cfg, device=self.device)
            # Use PIXIE's own SMPL-X layer instead of the smplx.create() model.
            self.smpl_model = self.hps.smplx
        elif self.hps_type == 'hybrik':
            smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl")
            self.hps = HybrIKBaseSMPLCam(cfg_file=path_config.HYBRIK_CFG,
                                         smpl_path=smpl_path,
                                         data_path=path_config.hybrik_data_dir)
            self.hps.load_state_dict(torch.load(path_config.HYBRIK_CKPT, map_location='cpu'),
                                     strict=False)
            self.hps.to(self.device)
        elif self.hps_type == 'bev':
            self.hps = self._build_bev()
        else:
            # Fail fast instead of leaving self.hps unset until __getitem__.
            raise ValueError(f"Unsupported hps_type: {self.hps_type}")

        # Mask R-CNN detector; only handed to econ_process_image in __getitem__.
        self.detector = detection.maskrcnn_resnet50_fpn(pretrained=True)
        self.detector.eval()
        print(colored(f"Using {self.hps_type} as HPS Estimator\n", "green"))

        self.render = Render(size=512, device=device)

    @classmethod
    def _collect_images(cls, image_dir):
        """Return sorted ``image_dir/<name>`` paths whose extension is in ``IMG_FMTS``."""
        return [
            f"{image_dir}/{name}" for name in sorted(os.listdir(image_dir))
            if osp.splitext(name)[1][1:].lower() in cls.IMG_FMTS
        ]

    @staticmethod
    def _bev_gpu_index(device):
        """Map a torch device (or device string) to BEV's integer ``GPU`` setting.

        ``cuda:N`` -> N, bare ``cuda`` -> 0, anything else -> -1.
        NOTE(review): -1 is taken to mean CPU mode in simple-romp (per its
        ``--GPU`` CLI help) — confirm for the pinned 1.0.3 release.
        """
        device = torch.device(device)
        if device.type == 'cuda':
            return 0 if device.index is None else device.index
        return -1

    def _build_bev(self):
        """Import simple-romp's BEV (installing it if missing) and build it for image mode."""
        try:
            import bev
        except ImportError:
            import importlib
            import subprocess
            import sys
            print('Could not find bev, installing via pip install simple-romp==1.0.3')
            # Install into the running interpreter's environment rather than
            # whichever `pip` happens to be first on PATH.
            subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'simple-romp==1.0.3'])
            # Cached path finders may not see a package installed after startup.
            importlib.invalidate_caches()
            import bev
        settings = bev.main.default_settings
        settings.mode = 'image'
        settings.GPU = self._bev_gpu_index(self.device)
        settings.show_largest = True
        return bev.BEV(settings)

    def __len__(self):
        """Number of input images."""
        return len(self.subject_list)

    def compute_vis_cmap(self, smpl_verts, smpl_faces):
        """Compute per-vertex visibility and the body-model color map.

        Args:
            smpl_verts: vertices (tensor or array-like) of shape ``(N, 3)``;
                the first two columns are used as ``xy`` and the last as ``z``.
            smpl_faces: face indices (tensor or array-like).

        Returns:
            dict with ``smpl_vis`` and ``smpl_cmap`` (batched, on ``self.device``)
            and ``smpl_verts`` (batched, left on its original device).
        """
        verts = torch.as_tensor(smpl_verts)
        (xy, z) = verts.split([2, 1], dim=1)
        smpl_vis = get_visibility(xy, -z, torch.as_tensor(smpl_faces).long())
        smpl_cmap = self.smpl_data.cmap_smpl_vids(self.smpl_type)

        return {
            'smpl_vis': smpl_vis.unsqueeze(0).to(self.device),
            'smpl_cmap': smpl_cmap.unsqueeze(0).to(self.device),
            'smpl_verts': verts.unsqueeze(0)
        }

    def compute_voxel_verts(self, body_pose, global_orient, betas, trans, scale):
        """Build padded tetrahedral-SMPL vertices/faces for the given body parameters.

        ``global_orient[0]`` and ``body_pose[0]`` are concatenated and decoded with
        ``rot6d_to_rotmat`` -> ``rotation_matrix_to_angle_axis`` (assumes the 6D
        pose layout produced by ``__getitem__`` — confirm for other callers).
        Vertices are scaled/translated, zero-padded to 8000 rows (faces to 25100),
        multiplied by 0.5 and z-flipped.

        Returns:
            dict of batched tensors on ``self.device``: ``voxel_verts``,
            ``voxel_faces``, ``pad_v_num``, ``pad_f_num``.
        """
        smpl_path = osp.join(self.smpl_data.model_dir, "smpl/SMPL_NEUTRAL.pkl")
        tetra_path = osp.join(self.smpl_data.tedra_dir, 'tetra_neutral_adult_smpl.npz')
        smpl_model = TetraSMPLModel(smpl_path, tetra_path, 'adult')

        pose = torch.cat([global_orient[0], body_pose[0]], dim=0)
        smpl_model.set_params(rotation_matrix_to_angle_axis(rot6d_to_rotmat(pose)), beta=betas[0])

        verts = np.concatenate([smpl_model.verts, smpl_model.verts_added],
                               axis=0) * scale.item() + trans.detach().cpu().numpy()
        faces = np.loadtxt(osp.join(self.smpl_data.tedra_dir, 'tetrahedrons_neutral_adult.txt'),
                           dtype=np.int32) - 1

        # Fixed-size padding; np.pad raises if the mesh exceeds these sizes.
        pad_v_num = int(8000 - verts.shape[0])
        pad_f_num = int(25100 - faces.shape[0])

        verts = np.pad(verts,
                       ((0, pad_v_num),
                        (0, 0)), mode='constant', constant_values=0.0).astype(np.float32) * 0.5
        faces = np.pad(faces, ((0, pad_f_num), (0, 0)), mode='constant',
                       constant_values=0.0).astype(np.int32)

        verts[:, 2] *= -1.0

        voxel_dict = {
            'voxel_verts': torch.from_numpy(verts).to(self.device).unsqueeze(0).float(),
            'voxel_faces': torch.from_numpy(faces).to(self.device).unsqueeze(0).long(),
            'pad_v_num': torch.tensor(pad_v_num).to(self.device).unsqueeze(0).long(),
            'pad_f_num': torch.tensor(pad_f_num).to(self.device).unsqueeze(0).long()
        }

        return voxel_dict

    def __getitem__(self, index):
        """Preprocess image ``index`` and run the HPS estimator on it.

        Returns:
            dict with the image entries (``name``, ``image``, ``ori_image``,
            ``mask``, ``uncrop_param``, plus ``segmentations`` when ``seg_dir``
            is set), ``hands_visibility``, ``smpl_faces``, ``betas``,
            ``body_pose``, ``global_orient``, ``smpl_verts``, ``type``,
            ``scale`` and ``trans``. For PIXIE all raw prediction keys are
            included too. ``body_pose``/``global_orient`` are returned with only
            the first two columns of each rotation matrix kept (6D form).
        """
        img_path = self.subject_list[index]
        img_name = img_path.split("/")[-1].rsplit(".", 1)[0]
        print(img_name)

        if self.seg_dir is None:
            img_icon, img_hps, img_ori, img_mask, uncrop_param = process_image(
                img_path, self.hps_type, 512, self.device)
            segmentations = None
        else:
            img_icon, img_hps, img_ori, img_mask, uncrop_param, segmentations = process_image(
                img_path,
                self.hps_type,
                512,
                self.device,
                seg_path=os.path.join(self.seg_dir, f'{img_name}.json'))

        data_dict = {
            'name': img_name,
            'image': img_icon.to(self.device).unsqueeze(0),
            'ori_image': img_ori,
            'mask': img_mask,
            'uncrop_param': uncrop_param
        }
        if self.seg_dir is not None:
            data_dict['segmentations'] = segmentations

        arr_dict = econ_process_image(img_path, self.hps_type, True, 512, self.detector)
        data_dict['hands_visibility'] = arr_dict['hands_visibility']

        with torch.no_grad():
            preds_dict = self.hps.forward(img_hps)

        # Build the index tensor directly as int64 (no float32 round-trip).
        data_dict['smpl_faces'] = torch.as_tensor(
            self.faces.astype(np.int64), device=self.device).unsqueeze(0)

        if self.hps_type == 'pymaf':
            output = preds_dict['smpl_out'][-1]
            scale, tranX, tranY = output['theta'][0, :3]
            data_dict['betas'] = output['pred_shape']
            data_dict['body_pose'] = output['rotmat'][:, 1:]
            data_dict['global_orient'] = output['rotmat'][:, 0:1]
            data_dict['smpl_verts'] = output['verts']
            data_dict["type"] = "smpl"

        elif self.hps_type == 'pare':
            data_dict['body_pose'] = preds_dict['pred_pose'][:, 1:]
            data_dict['global_orient'] = preds_dict['pred_pose'][:, 0:1]
            data_dict['betas'] = preds_dict['pred_shape']
            data_dict['smpl_verts'] = preds_dict['smpl_vertices']
            scale, tranX, tranY = preds_dict['pred_cam'][0, :3]
            data_dict["type"] = "smpl"

        elif self.hps_type == 'pixie':
            data_dict.update(preds_dict)
            data_dict['body_pose'] = preds_dict['body_pose']
            data_dict['global_orient'] = preds_dict['global_pose']
            data_dict['betas'] = preds_dict['shape']
            data_dict['smpl_verts'] = preds_dict['vertices']
            scale, tranX, tranY = preds_dict['cam'][0, :3]
            data_dict["type"] = "smplx"

        elif self.hps_type == 'hybrik':
            data_dict['body_pose'] = preds_dict['pred_theta_mats'][:, 1:]
            data_dict['global_orient'] = preds_dict['pred_theta_mats'][:, [0]]
            data_dict['betas'] = preds_dict['pred_shape']
            data_dict['smpl_verts'] = preds_dict['pred_vertices']
            scale, tranX, tranY = preds_dict['pred_camera'][0, :3]
            scale = scale * 2
            data_dict["type"] = "smpl"

        elif self.hps_type == 'bev':
            data_dict['betas'] = torch.from_numpy(preds_dict['smpl_betas'])[[0], :10].to(
                self.device).float()
            pred_thetas = batch_rodrigues(
                torch.from_numpy(preds_dict['smpl_thetas'][0]).reshape(-1, 3)).float()
            data_dict['body_pose'] = pred_thetas[1:][None].to(self.device)
            data_dict['global_orient'] = pred_thetas[[0]][None].to(self.device)
            data_dict['smpl_verts'] = torch.from_numpy(preds_dict['verts'][[0]]).to(
                self.device).float()
            # Empirical BEV-specific offsets/scale kept as in the original code.
            tranX = preds_dict['cam_trans'][0, 0]
            tranY = preds_dict['cam'][0, 1] + 0.28
            scale = preds_dict['cam'][0, 0] * 1.1
            data_dict["type"] = "smpl"

        data_dict['scale'] = scale
        data_dict['trans'] = torch.tensor([tranX, tranY, 0.0]).unsqueeze(0).to(self.device).float()

        # Rotation matrices -> keep the first two columns of each and flatten
        # (6D form); rot6d_to_rotmat in this file is used to decode it again.
        N_body = data_dict["body_pose"].shape[1]
        data_dict["body_pose"] = data_dict["body_pose"][:, :, :, :2].reshape(1, N_body, -1)
        data_dict["global_orient"] = data_dict["global_orient"][:, :, :, :2].reshape(1, 1, -1)

        return data_dict

    def render_normal(self, verts, faces):
        """Load the mesh into the renderer and return ``get_rgb_image()``.

        ``visualize_alignment`` unpacks the result into a front and a back image.
        """
        self.render.load_meshes(verts, faces)
        return self.render.get_rgb_image()

    def render_depth(self, verts, faces):
        """Load the mesh into the renderer and return depth maps for cameras 0 and 2."""
        self.render.load_meshes(verts, faces)
        return self.render.get_depth_map(cam_ids=[0, 2])

    @staticmethod
    def _to_rotmat(pose):
        """Return ``pose`` as rotation matrices ``(..., 3, 3)``.

        Inputs whose last dimension is 6 are treated as the 6D form stored by
        ``__getitem__`` and decoded with ``rot6d_to_rotmat``; anything else is
        returned unchanged.
        """
        if pose.shape[-1] == 6:
            return rot6d_to_rotmat(pose.reshape(-1, 6)).reshape(*pose.shape[:-1], 3, 3)
        return pose

    def visualize_alignment(self, data):
        """Show the fitted body mesh next to the input and its normal renders (vedo).

        Args:
            data: a dict as returned by ``__getitem__``.
        """
        import vedo
        import trimesh

        # __getitem__ stores body_pose / global_orient in 6D form, but both body
        # models below are driven with rotation matrices, so decode them first.
        body_pose = self._to_rotmat(data['body_pose'])
        global_orient = self._to_rotmat(data['global_orient'])

        if self.hps_type != 'pixie':
            smpl_out = self.smpl_model(betas=data['betas'],
                                       body_pose=body_pose,
                                       global_orient=global_orient,
                                       pose2rot=False)
            smpl_verts = ((smpl_out.vertices + data['trans']) *
                          data['scale']).detach().cpu().numpy()[0]
        else:
            smpl_verts, _, _ = self.smpl_model(shape_params=data['betas'],
                                               expression_params=data['exp'],
                                               body_pose=body_pose,
                                               global_pose=global_orient,
                                               jaw_pose=data['jaw_pose'],
                                               left_hand_pose=data['left_hand_pose'],
                                               right_hand_pose=data['right_hand_pose'])

            smpl_verts = ((smpl_verts + data['trans']) * data['scale']).detach().cpu().numpy()[0]

        smpl_verts *= np.array([1.0, -1.0, -1.0])
        faces = data['smpl_faces'][0].detach().cpu().numpy()

        image_P = data['image']
        image_F, image_B = self.render_normal(smpl_verts, faces)

        vp = vedo.Plotter(title="", size=(1500, 1500))
        vis_list = []

        # Map images from [-1, 1] to [0, 255] in HWC layout for vedo.
        image_F = (0.5 * (1.0 + image_F[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
        image_B = (0.5 * (1.0 + image_B[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)
        image_P = (0.5 * (1.0 + image_P[0].permute(1, 2, 0).detach().cpu().numpy()) * 255.0)

        vis_list.append(
            vedo.Picture(image_P * 0.5 + image_F * 0.5).scale(2.0 / image_P.shape[0]).pos(
                -1.0, -1.0, 1.0))
        vis_list.append(vedo.Picture(image_F).scale(2.0 / image_F.shape[0]).pos(-1.0, -1.0, -0.5))
        vis_list.append(vedo.Picture(image_B).scale(2.0 / image_B.shape[0]).pos(-1.0, -1.0, -1.0))

        mesh = trimesh.Trimesh(smpl_verts, faces, process=False)
        mesh.visual.vertex_colors = [200, 200, 0]
        vis_list.append(mesh)

        vp.show(*vis_list, bg="white", axes=1, interactive=True)
| |
|
| |
|
if __name__ == '__main__':

    cfg.merge_from_file("./configs/icon-filter.yaml")
    cfg.merge_from_file('./lib/pymaf/configs/pymaf_config.yaml')

    cfg_show_list = ['test_gpus', ['0'], 'mcube_res', 512, 'clean_mesh', False]

    cfg.merge_from_list(cfg_show_list)
    cfg.freeze()

    device = torch.device('cuda:0')

    # SMPLDataset.__init__ reads 'seg_dir' (None = no segmentation JSONs) and
    # 'colab'; pass both explicitly so the demo does not fail on a missing key.
    # 'has_det' is not read by SMPLDataset in this file.
    dataset = SMPLDataset(
        {
            'image_dir': "./examples",
            'seg_dir': None,
            'colab': False,
            'has_det': True,
            'hps_type': 'bev'
        },
        device)

    for i in range(len(dataset)):
        dataset.visualize_alignment(dataset[i])
| |
|