|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| from lib.dataset.mesh_util import cal_sdf_batch, feat_select, read_smpl_constants |
| from lib.net.NormalNet import NormalNet |
| from lib.net.MLP import MLP |
| from lib.dataset.mesh_util import SMPLX |
| from lib.net.VE import VolumeEncoder |
| from lib.net.HGFilters import * |
| from termcolor import colored |
| from lib.net.BasePIFuNet import BasePIFuNet |
| import torch.nn as nn |
| import torch |
|
|
|
|
# Module-level switch read by HGPIFuNet.filter() and HGPIFuNet.query().
# When True, filter() multiplies every returned feature map by the binary
# mask from get_mask(), and query() (ICON prior with 'vis' features) writes
# -1.0 into channel 0 of the SMPL feature wherever the selected image
# feature sums to exactly zero. Disabled by default.
maskout = False
|
|
|
|
class HGPIFuNet(BasePIFuNet):
    """
    HG PIFu network that uses Hourglass stacks as the image filter.

    It does the following:
        1. Compute image feature stacks (see ``filter``); the last element
           of the returned list is the output stack.
        2. Project the query points with the calibration (see ``query``).
        3. If training, index on every intermediate stack;
           if testing, index on the last stack only.
        4. Classification through the implicit-function MLP.
        5. During training, the error is averaged over all stacks.
    """

    def __init__(self,
                 cfg,
                 projection_mode='orthogonal',
                 error_term=nn.MSELoss()):
        """Build the implicit-function network described by ``cfg``.

        Args:
            cfg: project configuration node. This constructor reads
                ``cfg.net``, ``cfg.root``, ``cfg.overfit``, ``cfg.sdf_clip``,
                ``cfg.test_mode`` and, for the PaMIR prior only,
                ``cfg.batch_size`` and ``cfg.gpus``.
            projection_mode: forwarded unchanged to ``BasePIFuNet``.
            error_term: loss module forwarded to ``BasePIFuNet``; it is used
                by ``get_error`` through ``self.error_term``. The default
                instance is created once at definition time and shared
                between networks (nn.MSELoss holds no parameters).
        """
        super().__init__(projection_mode=projection_mode,
                         error_term=error_term)

        self.l1_loss = nn.SmoothL1Loss()
        self.opt = cfg.net
        self.root = cfg.root
        self.overfit = cfg.overfit

        # Work on a copy: the first entry is overwritten below with the real
        # input width of the MLP, and the caller's config must not be
        # modified as a side effect of building the network.
        channels_IF = list(self.opt.mlp_dim)

        self.use_filter = self.opt.use_filter
        self.prior_type = self.opt.prior_type
        self.smpl_feats = self.opt.smpl_feats

        self.smpl_dim = self.opt.smpl_dim
        self.voxel_dim = self.opt.voxel_dim
        self.hourglass_dim = self.opt.hourglass_dim
        self.sdf_clip = cfg.sdf_clip / 100.0

        # Each entry of opt.in_geo / opt.in_nml is indexed as (name, dim).
        self.in_geo = [item[0] for item in self.opt.in_geo]
        self.in_nml = [item[0] for item in self.opt.in_nml]

        self.in_geo_dim = sum(item[1] for item in self.opt.in_geo)
        self.in_nml_dim = sum(item[1] for item in self.opt.in_nml)

        self.in_total = self.in_geo + self.in_nml
        # Filled by filter(); query() reads it for the ICON / PaMIR priors.
        self.smpl_feat_dict = None
        self.smplx_data = SMPLX()

        # Channel index groups applied to the concatenated input tensor.
        # ICON uses two groups (presumably front / back views, with the
        # image channels shared when 'image' is an input - TODO confirm);
        # every other prior uses a single group covering all channels.
        has_image = 'image' in self.in_geo
        if self.prior_type == 'icon':
            if has_image:
                self.channels_filter = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 6, 7, 8]]
            else:
                self.channels_filter = [[0, 1, 2], [3, 4, 5]]
        else:
            if has_image:
                self.channels_filter = [[0, 1, 2, 3, 4, 5, 6, 7, 8]]
            else:
                self.channels_filter = [[0, 1, 2, 3, 4, 5]]

        # Width of one image-feature group as seen by the MLP.
        if self.use_filter:
            image_feat_dim = self.hourglass_dim
        else:
            image_feat_dim = len(self.channels_filter[0])
        channels_IF[0] = image_feat_dim

        # Without the 'vis' feature, query() feeds the full concatenation of
        # both groups to the MLP, hence the width is counted twice.
        if self.prior_type == 'icon' and 'vis' not in self.smpl_feats:
            channels_IF[0] += image_feat_dim

        if self.prior_type == 'icon':
            channels_IF[0] += self.smpl_dim
        elif self.prior_type == 'pamir':
            channels_IF[0] += self.voxel_dim
            (smpl_vertex_code, smpl_face_code, smpl_faces,
             smpl_tetras) = read_smpl_constants(self.smplx_data.tedra_dir)
            self.voxelization = Voxelization(
                smpl_vertex_code,
                smpl_face_code,
                smpl_faces,
                smpl_tetras,
                volume_res=128,
                sigma=0.05,
                smooth_kernel_size=7,
                batch_size=cfg.batch_size,
                device=torch.device(f"cuda:{cfg.gpus[0]}"))
            self.ve = VolumeEncoder(3, self.voxel_dim, self.opt.num_stack)
        else:
            # Plain PIFu: one extra channel for the z value (see query()).
            channels_IF[0] += 1

        # Keys copied from the input dict into self.smpl_feat_dict by filter().
        self.icon_keys = ["smpl_verts", "smpl_faces", "smpl_vis", "smpl_cmap"]
        self.pamir_keys = [
            "voxel_verts", "voxel_faces", "pad_v_num", "pad_f_num"
        ]

        self.if_regressor = MLP(
            filter_channels=channels_IF,
            name='if',
            res_layers=self.opt.res_layers,
            norm=self.opt.norm_mlp,
            last_op=nn.Sigmoid() if not cfg.test_mode else None)

        # Image encoder
        if self.use_filter:
            if self.opt.gtype == "HGPIFuNet":
                self.F_filter = HGFilter(self.opt, self.opt.num_stack,
                                         len(self.channels_filter[0]))
            else:
                # NOTE: only a message is printed here; self.F_filter is not
                # created, so a later call to filter() will fail.
                print(
                    colored(f"Backbone {self.opt.gtype} is unimplemented",
                            'green'))

        summary_log = (f"{self.prior_type.upper()}:\n"
                       f"w/ Global Image Encoder: {self.use_filter}\n"
                       f"Image Features used by MLP: {self.in_geo}\n")

        if self.prior_type == "icon":
            summary_log += f"Geometry Features used by MLP: {self.smpl_feats}\n"
            summary_log += "Dim of Image Features (local): 6\n"
            summary_log += f"Dim of Geometry Features (ICON): {self.smpl_dim}\n"
        elif self.prior_type == "pamir":
            summary_log += f"Dim of Image Features (global): {self.hourglass_dim}\n"
            summary_log += f"Dim of Geometry Features (PaMIR): {self.voxel_dim}\n"
        else:
            summary_log += f"Dim of Image Features (global): {self.hourglass_dim}\n"
            summary_log += "Dim of Geometry Features (PIFu): 1 (z-value)\n"

        summary_log += f"Dim of MLP's first layer: {channels_IF[0]}\n"

        print(colored(summary_log, "yellow"))

        self.normal_filter = NormalNet(cfg)
        init_net(self)

    @staticmethod
    def _strip_padding(padded, pad_num):
        """Drop the last ``pad_num`` entries along dim 1 of ``padded``.

        A plain ``padded[:, :-pad_num, :]`` returns an empty tensor when
        ``pad_num`` is 0, because ``-0 == 0``; that case is handled here by
        returning the tensor unchanged.
        """
        pad_num = int(pad_num)
        if pad_num == 0:
            return padded
        return padded[:, :-pad_num, :]

    def get_normal(self, in_tensor_dict):
        """Assemble the channel-concatenated input of the image filter.

        Outside of training (and when not overfitting) the tensor is built
        from ``in_tensor_dict['image']`` plus the front/back normal maps,
        which are predicted by ``self.normal_filter`` whenever either
        ``'normal_F'`` or ``'normal_B'`` is missing from ``in_tensor_dict``.
        Otherwise every key listed in ``self.in_geo`` is read directly from
        ``in_tensor_dict``.

        Args:
            in_tensor_dict (dict): input tensors keyed by feature name.

        Returns:
            torch.Tensor: the selected inputs concatenated along dim 1.
        """
        if (not self.training) and (not self.overfit):
            # Inference path: normal prediction must not build a graph.
            with torch.no_grad():
                feat_lst = []
                if "image" in self.in_geo:
                    feat_lst.append(in_tensor_dict['image'])
                if 'normal_F' in self.in_geo and 'normal_B' in self.in_geo:
                    if ('normal_F' not in in_tensor_dict
                            or 'normal_B' not in in_tensor_dict):
                        (nmlF, nmlB) = self.normal_filter(in_tensor_dict)
                    else:
                        nmlF = in_tensor_dict['normal_F']
                        nmlB = in_tensor_dict['normal_B']
                    feat_lst.append(nmlF)
                    feat_lst.append(nmlB)
            in_filter = torch.cat(feat_lst, dim=1)
        else:
            in_filter = torch.cat([in_tensor_dict[key] for key in self.in_geo],
                                  dim=1)

        return in_filter

    def get_mask(self, in_filter, size=128):
        """Return a boolean mask of the non-zero pixels of the input.

        The first channel group of ``in_filter`` is bilinearly resized to
        ``size`` x ``size``; a pixel is True when the sum of absolute values
        over its channels differs from zero.

        Args:
            in_filter (torch.Tensor): tensor returned by ``get_normal``.
            size (int): side length of the square output mask.

        Returns:
            torch.Tensor: boolean mask with a single channel on dim 1.
        """
        resized = F.interpolate(in_filter[:, self.channels_filter[0]],
                                size=(size, size),
                                mode="bilinear",
                                align_corners=True)
        mask = resized.abs().sum(dim=1, keepdim=True) != 0.0

        return mask

    def filter(self, in_tensor_dict, return_inter=False):
        """Compute the image feature stacks for the given inputs.

        As a side effect, the SMPL-related entries needed later by
        ``query`` are stored in ``self.smpl_feat_dict`` (ICON and PaMIR
        priors only).

        Args:
            in_tensor_dict (dict): input tensors; see ``get_normal`` for the
                image/normal keys, ``self.icon_keys`` / ``self.pamir_keys``
                for the prior-specific keys.
            return_inter (bool): when True, also return the concatenated
                filter input.

        Returns:
            list of feature maps (all stacks in training mode, only the last
            stack otherwise), or the tuple ``(features, in_filter)`` when
            ``return_inter`` is True.
        """
        in_filter = self.get_normal(in_tensor_dict)

        if self.prior_type == 'icon':
            front_in = in_filter[:, self.channels_filter[0]]
            back_in = in_filter[:, self.channels_filter[1]]
            if self.use_filter:
                features_F = self.F_filter(front_in)
                features_B = self.F_filter(back_in)
            else:
                features_F = [front_in]
                features_B = [back_in]
            # One fused map per stack: both groups along the channel axis.
            features_G = [
                torch.cat([feat_F, feat_B], dim=1)
                for feat_F, feat_B in zip(features_F, features_B)
            ]
        else:
            if self.use_filter:
                features_G = self.F_filter(
                    in_filter[:, self.channels_filter[0]])
            else:
                features_G = [in_filter[:, self.channels_filter[0]]]

        # Remember the body-prior inputs for the following query() call.
        if self.prior_type == 'icon':
            self.smpl_feat_dict = {
                k: in_tensor_dict[k]
                for k in self.icon_keys
            }
        elif self.prior_type == "pamir":
            self.smpl_feat_dict = {
                k: in_tensor_dict[k]
                for k in self.pamir_keys
            }

        # Intermediate stacks are only needed while training.
        if not self.training:
            features_out = [features_G[-1]]
        else:
            features_out = features_G

        if maskout:
            features_out = [
                feat * self.get_mask(in_filter, size=feat.shape[2])
                for feat in features_out
            ]

        if return_inter:
            return features_out, in_filter
        return features_out

    def query(self, features, points, calibs, transforms=None, regressor=None):
        """Evaluate the implicit function at the given 3D points.

        ``filter`` must have been called before for the ICON and PaMIR
        priors, because this method reads ``self.smpl_feat_dict``.

        Args:
            features (list): feature maps returned by ``filter``.
            points: query points, forwarded to ``self.projection``.
            calibs: calibration, forwarded to ``self.projection``.
            transforms: optional transforms, forwarded to
                ``self.projection``.
            regressor: callable mapping the per-point feature to a
                prediction. Defaults to ``self.if_regressor`` when None.

        Returns:
            list: one prediction tensor per feature map. Predictions of
            points whose projected coordinates are not strictly inside
            (-1, 1) on every axis are multiplied by zero.
        """
        if regressor is None:
            # The signature advertises None as default; without this
            # fallback the call below would fail with a TypeError.
            regressor = self.if_regressor

        xyz = self.projection(points, calibs, transforms)

        (xy, z) = xyz.split([2, 1], dim=1)

        in_cube = (xyz > -1.0) & (xyz < 1.0)
        in_cube = in_cube.all(dim=1, keepdim=True).detach().float()

        preds_list = []

        if self.prior_type == 'icon':
            smpl_sdf, smpl_norm, smpl_cmap, smpl_vis = cal_sdf_batch(
                self.smpl_feat_dict['smpl_verts'],
                self.smpl_feat_dict['smpl_faces'],
                self.smpl_feat_dict['smpl_cmap'],
                self.smpl_feat_dict['smpl_vis'],
                xyz.permute(0, 2, 1).contiguous())

            # The order matters: 'vis', when enabled, must stay last because
            # the loop below addresses it with index -1.
            feat_lst = [smpl_sdf]
            if 'cmap' in self.smpl_feats:
                feat_lst.append(smpl_cmap)
            if 'norm' in self.smpl_feats:
                feat_lst.append(smpl_norm)
            if 'vis' in self.smpl_feats:
                feat_lst.append(smpl_vis)

            smpl_feat = torch.cat(feat_lst, dim=2).permute(0, 2, 1)
            vol_feats = features

        elif self.prior_type == "pamir":
            # Padding counts are taken from the first sample only, i.e. the
            # whole batch is assumed to share them - TODO confirm.
            voxel_verts = self._strip_padding(
                self.smpl_feat_dict['voxel_verts'],
                self.smpl_feat_dict['pad_v_num'][0])
            voxel_faces = self._strip_padding(
                self.smpl_feat_dict['voxel_faces'],
                self.smpl_feat_dict['pad_f_num'][0])

            self.voxelization.update_param(
                batch_size=voxel_faces.shape[0],
                smpl_tetra=voxel_faces[0].detach().cpu().numpy())
            vol = self.voxelization(voxel_verts)
            vol_feats = self.ve(vol, intermediate_output=self.training)
        else:
            vol_feats = features

        for im_feat, vol_feat in zip(features, vol_feats):

            if self.prior_type == 'icon':
                if 'vis' in self.smpl_feats:
                    # The last SMPL channel ('vis') drives the selection
                    # done by feat_select and is not passed to the MLP.
                    point_local_feat = feat_select(self.index(im_feat, xy),
                                                   smpl_feat[:, [-1], :])
                    if maskout:
                        normal_mask = torch.tile(
                            point_local_feat.sum(dim=1, keepdims=True) == 0.0,
                            (1, smpl_feat.shape[1], 1))
                        # Restrict the overwrite to channel 0.
                        normal_mask[:, 1:, :] = False
                        smpl_feat[normal_mask] = -1.0
                    point_feat_list = [point_local_feat, smpl_feat[:, :-1, :]]
                else:
                    point_local_feat = self.index(im_feat, xy)
                    point_feat_list = [point_local_feat, smpl_feat[:, :, :]]

            elif self.prior_type == 'pamir':
                point_feat_list = [
                    self.index(im_feat, xy),
                    self.index(vol_feat, xyz)
                ]

            else:
                point_feat_list = [self.index(im_feat, xy), z]

            point_feat = torch.cat(point_feat_list, 1)

            # Points outside the cube always get a zero prediction.
            preds = regressor(point_feat)
            preds = in_cube * preds

            preds_list.append(preds)

        return preds_list

    def get_error(self, preds_if_list, labels):
        """Calculate the error averaged over all predictions.

        Args:
            preds_if_list (list): predictions as returned by ``query``.
            labels (torch.Tensor): targets; each prediction is compared
                with them through ``self.error_term``.

        Returns:
            torch.Tensor: mean of the per-prediction errors.

        Raises:
            ZeroDivisionError: if ``preds_if_list`` is empty.
        """
        error_if = 0

        for pred_if in preds_if_list:
            error_if += self.error_term(pred_if, labels)

        error_if /= len(preds_if_list)

        return error_if

    def forward(self, in_tensor_dict):
        """Run filter, query and the loss in one pass.

        Tensor shapes noted by the original author (not checked here):
            sample_tensor [B, 3, N]
            calib_tensor [B, 4, 4]
            label_tensor [B, 1, N]
            smpl_feat_tensor [B, 59, N]

        Args:
            in_tensor_dict (dict): must contain ``'sample'``, ``'calib'``
                and ``'label'`` in addition to the keys required by
                ``filter``.

        Returns:
            tuple: (prediction of the last stack, averaged error).
        """
        sample_tensor = in_tensor_dict['sample']
        calib_tensor = in_tensor_dict['calib']
        label_tensor = in_tensor_dict['label']

        in_feat = self.filter(in_tensor_dict)

        preds_if_list = self.query(in_feat,
                                   sample_tensor,
                                   calib_tensor,
                                   regressor=self.if_regressor)

        error = self.get_error(preds_if_list, label_tensor)

        return preds_if_list[-1], error
|
|