| | from collections import namedtuple |
| | import os |
| |
|
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import yaml |
| | from torch.nn import functional as F |
| |
|
| | from .layers.Resnet import ResNet |
| | from .layers.smpl.SMPL import SMPL_layer |
| |
|
# Named container for model predictions. Every field defaults to None, so an
# instance can be created from any subset of keyword arguments.
ModelOutput = namedtuple(
    'ModelOutput',
    (
        'pred_shape', 'pred_theta_mats', 'pred_phi', 'pred_delta_shape',
        'pred_leaf', 'pred_uvd_jts', 'pred_xyz_jts_29', 'pred_xyz_jts_24',
        'pred_xyz_jts_24_struct', 'pred_xyz_jts_17', 'pred_vertices',
        'maxvals', 'cam_scale', 'cam_trans', 'cam_root', 'uvd_heatmap',
        'transl', 'img_feat', 'pred_camera', 'pred_aa',
    ),
)
ModelOutput.__new__.__defaults__ = tuple([None] * len(ModelOutput._fields))
| |
|
| |
|
def update_config(config_file):
    """Load a YAML configuration file and return the parsed document.

    Args:
        config_file: path to the YAML file.

    Returns:
        Whatever ``yaml.load`` produces for the file (callers in this module
        index the result with ``['MODEL']``).
    """
    # Open in binary mode so PyYAML detects the encoding itself (UTF-8 unless
    # a UTF-16 BOM is present) instead of using the platform locale encoding,
    # which can fail on non-ASCII configs on non-UTF-8 systems.
    # NOTE(review): FullLoader can construct some Python objects; switch to
    # yaml.safe_load if config files may ever come from untrusted sources.
    with open(config_file, 'rb') as f:
        return yaml.load(f, Loader=yaml.FullLoader)
| |
|
| |
|
def norm_heatmap(norm_type, heatmap):
    """Normalize a heatmap tensor with the requested method.

    For ``'softmax'``, the first two axes of ``heatmap`` are kept (presumably
    batch and joint -- confirm with callers). All remaining axes are flattened
    and a softmax is taken over them jointly, then the original shape is
    restored.

    Args:
        norm_type: normalization to apply; only ``'softmax'`` is supported.
        heatmap: tensor with at least two dimensions.

    Returns:
        Tensor with the same shape as ``heatmap``.

    Raises:
        NotImplementedError: if ``norm_type`` is not ``'softmax'``.
    """
    shape = heatmap.shape
    if norm_type != 'softmax':
        raise NotImplementedError(
            f"Unsupported heatmap norm_type: {norm_type!r}")

    flat = heatmap.reshape(*shape[:2], -1)
    return F.softmax(flat, 2).reshape(*shape)
| |
|
| |
|
class HybrIKBaseSMPLCam(nn.Module):
    """ResNet + deconvolution heatmap model that regresses SMPL inputs and a camera.

    Backbone features feed two heads:

    * a deconvolution head producing ``num_joints`` volumetric heatmaps that
      are integrated (soft-argmax) into normalized ``uvd`` joint coordinates;
    * a pooled MLP head regressing a 10-dim SMPL shape (as an offset from a
      mean shape), 23 two-dimensional ``phi`` values and a 3-vector camera
      ``(scale, tx, ty)`` (as an offset from ``init_cam``).

    The joints are lifted to camera space with the predicted camera and passed,
    together with ``phi`` and the shape, to ``SMPL_layer.hybrik``.
    """

    def __init__(self,
                 cfg_file,
                 smpl_path,
                 data_path,
                 norm_layer=nn.BatchNorm2d):
        """Build the network from a YAML config.

        Args:
            cfg_file: path to a YAML config whose ``MODEL`` section supplies
                NUM_DECONV_FILTERS, NUM_JOINTS, POST.NORM_TYPE,
                EXTRA.DEPTH_DIM, HEATMAP_SIZE, NUM_LAYERS and FOCAL_LENGTH.
            smpl_path: model path forwarded to ``SMPL_layer``.
            data_path: directory containing ``J_regressor_h36m.npy`` and
                ``h36m_mean_beta.npy``.
            norm_layer: normalization layer factory used in the deconv head.

        Raises:
            NotImplementedError: if NUM_LAYERS is not 18, 34, 50 or 101.
        """
        super(HybrIKBaseSMPLCam, self).__init__()

        cfg = update_config(cfg_file)['MODEL']

        self.deconv_dim = cfg['NUM_DECONV_FILTERS']
        self._norm_layer = norm_layer
        self.num_joints = cfg['NUM_JOINTS']
        self.norm_type = cfg['POST']['NORM_TYPE']
        self.depth_dim = cfg['EXTRA']['DEPTH_DIM']
        self.height_dim = cfg['HEATMAP_SIZE'][0]
        self.width_dim = cfg['HEATMAP_SIZE'][1]
        self.smpl_dtype = torch.float32

        num_layers = cfg['NUM_LAYERS']
        self.preact = ResNet(f"resnet{num_layers}")

        # Initialize the backbone from torchvision's pretrained ResNet of the
        # same depth, copying only parameters whose name and shape match.
        import torchvision.models as tm
        # depth -> (torchvision constructor, channels of the last stage)
        resnet_specs = {
            101: (tm.resnet101, 2048),
            50: (tm.resnet50, 2048),
            34: (tm.resnet34, 512),
            18: (tm.resnet18, 512),
        }
        if num_layers not in resnet_specs:
            raise NotImplementedError(f"Unsupported NUM_LAYERS: {num_layers}")
        build_resnet, self.feature_channel = resnet_specs[num_layers]
        pretrained = build_resnet(pretrained=True)

        # Fetch the backbone state once instead of once per key.
        model_state = self.preact.state_dict()
        matched = {
            k: v
            for k, v in pretrained.state_dict().items()
            if k in model_state and v.size() == model_state[k].size()
        }
        model_state.update(matched)
        self.preact.load_state_dict(model_state)

        self.deconv_layers = self._make_deconv_layer()
        self.final_layer = nn.Conv2d(self.deconv_dim[2],
                                     self.num_joints * self.depth_dim,
                                     kernel_size=1,
                                     stride=1,
                                     padding=0)

        h36m_jregressor = np.load(
            os.path.join(data_path, 'J_regressor_h36m.npy'))
        self.smpl = SMPL_layer(smpl_path,
                               h36m_jregressor=h36m_jregressor,
                               dtype=self.smpl_dtype)

        # Left/right index pairs swapped by the flip_* methods.
        self.joint_pairs_24 = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14),
                               (16, 17), (18, 19), (20, 21), (22, 23))
        self.joint_pairs_29 = ((1, 2), (4, 5), (7, 8), (10, 11), (13, 14),
                               (16, 17), (18, 19), (20, 21), (22, 23),
                               (25, 26), (27, 28))
        self.leaf_pairs = ((0, 1), (3, 4))
        self.root_idx_smpl = 0

        # Initial shape and camera; the regression head predicts offsets
        # from these buffers in forward().
        init_shape = np.load(os.path.join(data_path, 'h36m_mean_beta.npy'))
        self.register_buffer('init_shape',
                             torch.tensor(init_shape, dtype=torch.float32))
        self.register_buffer('init_cam',
                             torch.tensor([0.9, 0, 0], dtype=torch.float32))

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(self.feature_channel, 1024)
        self.drop1 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout(p=0.5)
        self.decshape = nn.Linear(1024, 10)
        self.decphi = nn.Linear(1024, 23 * 2)
        self.deccam = nn.Linear(1024, 3)

        self.focal_length = cfg['FOCAL_LENGTH']
        self.input_size = 256.0

    def _make_deconv_layer(self):
        """Build three ConvTranspose2d -> norm -> ReLU stages.

        Each stage uses kernel 4, stride 2, padding 1, which doubles the
        spatial size. Output widths are ``self.deconv_dim[0..2]``. Module
        order matches the original hand-written sequence, so state_dict keys
        are unchanged.
        """
        kernel_size = 4
        padding = kernel_size // 2 - 1
        layers = []
        in_channels = self.feature_channel
        for stage in range(3):
            out_channels = self.deconv_dim[stage]
            layers.append(nn.ConvTranspose2d(in_channels,
                                             out_channels,
                                             kernel_size=kernel_size,
                                             stride=2,
                                             padding=padding,
                                             bias=False))
            layers.append(self._norm_layer(out_channels))
            layers.append(nn.ReLU(inplace=True))
            in_channels = out_channels
        return nn.Sequential(*layers)

    def _initialize(self):
        """Re-initialize the deconv head and final 1x1 conv weights."""
        for m in self.deconv_layers.modules():
            if isinstance(m, nn.ConvTranspose2d):
                nn.init.normal_(m.weight, std=0.001)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        for m in self.final_layer.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, std=0.001)
                nn.init.constant_(m.bias, 0)

    @staticmethod
    def _swap_joint_pairs(jts, pairs, offset=0):
        """Swap rows ``a - offset`` and ``b - offset`` along dim 1 of ``jts`` in place."""
        for dim0, dim1 in pairs:
            idx = torch.tensor((dim0 - offset, dim1 - offset), dtype=torch.long)
            inv_idx = torch.tensor((dim1 - offset, dim0 - offset), dtype=torch.long)
            # Advanced indexing on the right-hand side copies, so this swaps.
            jts[:, idx] = jts[:, inv_idx]
        return jts

    def flip_uvd_coord(self, pred_jts, shift=False, flatten=True):
        """Horizontally mirror uvd joints and swap left/right joints (in place).

        Args:
            pred_jts: (B, num_joints * 3) if ``flatten`` else (B, num_joints, 3).
            shift: if true, mirror ``u`` as ``-u``; otherwise as
                ``-1 / width_dim - u``.
            flatten: whether ``pred_jts`` is (and is returned) flattened.

        Returns:
            The mirrored joints, in the same layout as the input.
        """
        assert pred_jts.dim() == (2 if flatten else 3)
        num_batches = pred_jts.shape[0]
        if flatten:
            pred_jts = pred_jts.reshape(num_batches, self.num_joints, 3)

        if shift:
            pred_jts[:, :, 0] = -pred_jts[:, :, 0]
        else:
            pred_jts[:, :, 0] = -1 / self.width_dim - pred_jts[:, :, 0]

        self._swap_joint_pairs(pred_jts, self.joint_pairs_29)

        if flatten:
            pred_jts = pred_jts.reshape(num_batches, self.num_joints * 3)
        return pred_jts

    def flip_xyz_coord(self, pred_jts, flatten=True):
        """Mirror xyz joints (``x -> -x``) and swap left/right joints (in place).

        Args:
            pred_jts: (B, num_joints * 3) if ``flatten`` else (B, num_joints, 3).
            flatten: whether ``pred_jts`` is (and is returned) flattened.

        Returns:
            The mirrored joints, in the same layout as the input.
        """
        assert pred_jts.dim() == (2 if flatten else 3)
        num_batches = pred_jts.shape[0]
        if flatten:
            pred_jts = pred_jts.reshape(num_batches, self.num_joints, 3)

        pred_jts[:, :, 0] = -pred_jts[:, :, 0]
        self._swap_joint_pairs(pred_jts, self.joint_pairs_29)

        if flatten:
            pred_jts = pred_jts.reshape(num_batches, self.num_joints * 3)
        return pred_jts

    def flip_phi(self, pred_phi):
        """Mirror phi (negate its second component) and swap left/right rows, in place.

        Args:
            pred_phi: tensor of shape (B, 23, 2).

        Returns:
            The mirrored tensor (same object as ``pred_phi``).
        """
        pred_phi[:, :, 1] = -1 * pred_phi[:, :, 1]
        # phi rows start at SMPL joint 1, so joint j is stored at row j - 1.
        self._swap_joint_pairs(pred_phi, self.joint_pairs_24, offset=1)
        return pred_phi

    def _soft_argmax_3d(self, heatmaps):
        """Integrate normalized volumetric heatmaps into (x, y, z) coordinates.

        Args:
            heatmaps: (B, num_joints, depth_dim * height_dim * width_dim),
                expected to sum to 1 over the last axis.

        Returns:
            Tensor (B, num_joints, 3) of expected (x, y, z) bin positions, each
            divided by its axis size and shifted by -0.5.
        """
        heatmaps = heatmaps.reshape(
            (heatmaps.shape[0], self.num_joints, self.depth_dim,
             self.height_dim, self.width_dim))

        # Marginal distributions along each axis of the (D, H, W) volume.
        hm_x0 = heatmaps.sum((2, 3))  # (B, J, W)
        hm_y0 = heatmaps.sum((2, 4))  # (B, J, H)
        hm_z0 = heatmaps.sum((3, 4))  # (B, J, D)

        def expectation(marginal):
            # Each axis needs its own bin range; a single width-sized range
            # breaks whenever H or D differs from W.
            size = marginal.shape[-1]
            bins = torch.arange(size, dtype=torch.float32,
                                device=marginal.device)
            coord = (marginal * bins).sum(dim=2, keepdim=True)
            return coord / float(size) - 0.5

        return torch.cat(
            (expectation(hm_x0), expectation(hm_y0), expectation(hm_z0)),
            dim=2)

    def forward(self,
                x,
                flip_item=None,
                flip_output=False,
                gt_uvd=None,
                gt_uvd_weight=None,
                **kwargs):
        """Predict joints, SMPL parameters, mesh and camera for a batch.

        Args:
            x: input batch; presumably (B, 3, H, W) images -- confirm with the
                data pipeline.
            flip_item: optional ``(xyz_jts_29, phi, leaf, shape)`` from a
                previous pass; averaged with this pass when ``flip_output``.
            flip_output: if true, mirror this pass's joints and phi before
                building the SMPL output.
            gt_uvd, gt_uvd_weight, **kwargs: accepted but unused here.

        Returns:
            dict of predictions (keys listed in the ``dict(...)`` at the end).
        """
        batch_size = x.shape[0]

        x0 = self.preact(x)
        out = self.deconv_layers(x0)
        out = self.final_layer(out)

        # (B, num_joints * depth_dim, h, w) -> (B, num_joints, depth_dim * h * w)
        out = out.reshape((out.shape[0], self.num_joints, -1))

        # Per-joint peak of the raw (pre-normalization) output.
        maxvals, _ = torch.max(out, dim=2, keepdim=True)

        out = norm_heatmap(self.norm_type, out)
        assert out.dim() == 3, out.shape

        heatmaps = out / out.sum(dim=2, keepdim=True)
        pred_uvd_jts_29 = self._soft_argmax_3d(heatmaps)

        # Regression head on globally pooled backbone features.
        x0 = self.avg_pool(x0)
        x0 = x0.view(x0.size(0), -1)
        init_shape = self.init_shape.expand(batch_size, -1)
        init_cam = self.init_cam.expand(batch_size, -1)

        xc = self.drop1(self.fc1(x0))
        xc = self.drop2(self.fc2(xc))

        delta_shape = self.decshape(xc)
        pred_shape = delta_shape + init_shape
        pred_phi = self.decphi(xc)
        pred_camera = self.deccam(xc).reshape(batch_size, -1) + init_cam

        camScale = pred_camera[:, :1].unsqueeze(1)  # (B, 1, 1)
        camTrans = pred_camera[:, 1:].unsqueeze(1)  # (B, 1, 2)

        camDepth = self.focal_length / (self.input_size * camScale + 1e-9)

        # Back-project normalized (u, v) to camera-space x, y; depth stays
        # normalized. The factor 2.2 converts between these normalized units
        # and the (presumably metric) units passed to SMPL below.
        pred_xyz_jts_29 = torch.zeros_like(pred_uvd_jts_29)
        pred_xyz_jts_29[:, :, 2:] = pred_uvd_jts_29[:, :, 2:].clone()
        pred_xyz_jts_29_meter = (
            pred_uvd_jts_29[:, :, :2] * self.input_size / self.focal_length
        ) * (pred_xyz_jts_29[:, :, 2:] * 2.2 + camDepth) - camTrans
        pred_xyz_jts_29[:, :, :2] = pred_xyz_jts_29_meter / 2.2

        camera_root = pred_xyz_jts_29[:, [0], ] * 2.2
        camera_root[:, :, :2] += camTrans
        camera_root[:, :, [2]] += camDepth

        if not self.training:
            pred_xyz_jts_29 = pred_xyz_jts_29 - pred_xyz_jts_29[:, [0]]

        if flip_item is not None:
            # NOTE(review): flip_output defaults to False, never None, so this
            # assert cannot fail; `assert flip_output` may have been intended.
            assert flip_output is not None
            pred_xyz_jts_29_orig, pred_phi_orig, _, pred_shape_orig = flip_item

        if flip_output:
            pred_xyz_jts_29 = self.flip_xyz_coord(pred_xyz_jts_29,
                                                  flatten=False)
        if flip_output and flip_item is not None:
            pred_xyz_jts_29 = (pred_xyz_jts_29 + pred_xyz_jts_29_orig.reshape(
                batch_size, 29, 3)) / 2

        pred_xyz_jts_29_flat = pred_xyz_jts_29.reshape(batch_size, -1)

        pred_phi = pred_phi.reshape(batch_size, 23, 2)

        if flip_output:
            pred_phi = self.flip_phi(pred_phi)

        if flip_output and flip_item is not None:
            pred_phi = (pred_phi + pred_phi_orig) / 2
            pred_shape = (pred_shape + pred_shape_orig) / 2

        smpl_out = self.smpl.hybrik(
            pose_skeleton=pred_xyz_jts_29.type(self.smpl_dtype) * 2.2,
            betas=pred_shape.type(self.smpl_dtype),
            phis=pred_phi.type(self.smpl_dtype),
            global_orient=None,
            return_verts=True)
        pred_vertices = smpl_out.vertices.float()
        # NOTE(review): SMPL outputs are divided by 2 here, but the skeleton
        # above was scaled by 2.2 and `transl` below multiplies
        # pred_xyz_jts_17 by 2.2 again -- confirm the intended factor.
        pred_xyz_jts_24_struct = smpl_out.joints.float() / 2
        pred_xyz_jts_17 = smpl_out.joints_from_verts.float() / 2
        pred_theta_mats = smpl_out.rot_mats.float().reshape(batch_size, 24, 3, 3)
        pred_xyz_jts_24 = pred_xyz_jts_29[:, :24, :].reshape(batch_size, 72) / 2
        pred_xyz_jts_24_struct = pred_xyz_jts_24_struct.reshape(batch_size, 72)
        pred_xyz_jts_17_flat = pred_xyz_jts_17.reshape(batch_size, 17 * 3)

        # Root translation: predicted root minus the root of the
        # vertex-regressed 17 joints, offset by camera translation and depth.
        transl = pred_xyz_jts_29[:, 0, :] * 2.2 - pred_xyz_jts_17[:, 0, :] * 2.2
        transl[:, :2] += camTrans[:, 0]
        transl[:, 2] += camDepth[:, 0, 0]

        # Camera (scale, tx, ty) re-derived from transl with the inverse of
        # the camDepth relation above.
        new_cam = torch.zeros_like(transl)
        new_cam[:, 1:] = transl[:, :2]
        new_cam[:, 0] = self.focal_length / \
            (self.input_size * transl[:, 2] + 1e-9)

        return dict(
            pred_phi=pred_phi,
            pred_delta_shape=delta_shape,
            pred_shape=pred_shape,
            pred_theta_mats=pred_theta_mats,
            pred_uvd_jts=pred_uvd_jts_29.reshape(batch_size, -1),
            pred_xyz_jts_29=pred_xyz_jts_29_flat,
            pred_xyz_jts_24=pred_xyz_jts_24,
            pred_xyz_jts_24_struct=pred_xyz_jts_24_struct,
            pred_xyz_jts_17=pred_xyz_jts_17_flat,
            pred_vertices=pred_vertices,
            maxvals=maxvals,
            cam_scale=camScale[:, 0],
            cam_trans=camTrans[:, 0],
            cam_root=camera_root,
            pred_camera=new_cam,
            transl=transl,
        )

    def forward_gt_theta(self, gt_theta, gt_beta):
        """Run the SMPL layer on ground-truth axis-angle pose and shape."""
        return self.smpl(pose_axis_angle=gt_theta,
                         betas=gt_beta,
                         global_orient=None,
                         return_verts=True)
| |
|