Spaces: Running on Zero
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import math | |
| from ...utils.geometry import rot6d_to_rotmat, aa_to_rotmat | |
| from typing import Optional | |
def make_linear_layers(feat_dims, relu_final=True, use_bn=False):
    """Build an MLP as ``nn.Sequential`` from a list of layer widths.

    Args:
        feat_dims: feature dimensions, e.g. ``[in, hidden, out]``.
        relu_final: if False, the last Linear is not followed by (BN+)ReLU,
            so the network can emit raw regression outputs.
        use_bn: insert ``BatchNorm1d`` between each Linear and its ReLU.
    """
    modules = []
    last = len(feat_dims) - 2  # index of the final Linear layer
    for idx, (d_in, d_out) in enumerate(zip(feat_dims[:-1], feat_dims[1:])):
        modules.append(nn.Linear(d_in, d_out))
        # Skip the activation after the final layer unless relu_final is set.
        if idx != last or relu_final:
            if use_bn:
                modules.append(nn.BatchNorm1d(d_out))
            modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)
def make_conv_layers(feat_dims, kernel=3, stride=1, padding=1, bnrelu_final=True):
    """Build a stack of Conv2d(+BN+ReLU) blocks from a list of channel dims.

    Args:
        feat_dims: channel dimensions, e.g. ``[in, mid, out]``.
        kernel, stride, padding: Conv2d hyper-parameters shared by every layer.
        bnrelu_final: if False, omit BN+ReLU after the last convolution so it
            can serve as a raw prediction head.
    """
    modules = []
    n_convs = len(feat_dims) - 1
    for idx in range(n_convs):
        modules.append(nn.Conv2d(feat_dims[idx], feat_dims[idx + 1],
                                 kernel_size=kernel, stride=stride,
                                 padding=padding))
        # No BN/ReLU after the final conv when bnrelu_final is False.
        if idx != n_convs - 1 or bnrelu_final:
            modules.append(nn.BatchNorm2d(feat_dims[idx + 1]))
            modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)
def make_deconv_layers(feat_dims, bnrelu_final=True):
    """Build a stack of 2x-upsampling ConvTranspose2d(+BN+ReLU) blocks.

    Args:
        feat_dims: channel dimensions, e.g. ``[in, mid, out]``.
        bnrelu_final: if False, omit BN+ReLU after the last deconvolution.
    """
    modules = []
    final = len(feat_dims) - 2  # index of the last deconv layer
    for idx in range(len(feat_dims) - 1):
        # kernel=4 / stride=2 / padding=1 doubles the spatial resolution.
        modules.append(nn.ConvTranspose2d(feat_dims[idx], feat_dims[idx + 1],
                                          kernel_size=4, stride=2, padding=1,
                                          output_padding=0, bias=False))
        if idx != final or bnrelu_final:
            modules.append(nn.BatchNorm2d(feat_dims[idx + 1]))
            modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)
def sample_joint_features(img_feat, joint_xy):
    """Bilinearly sample a per-joint feature vector from a feature map.

    Args:
        img_feat: (B, C, H, W) feature map.
        joint_xy: (B, J, 2) joint coordinates (x, y) in feature-map pixels.

    Returns:
        (B, J, C) tensor of sampled features.
    """
    h, w = img_feat.shape[2], img_feat.shape[3]
    # Map pixel coordinates into grid_sample's [-1, 1] range; with
    # align_corners=True, pixel 0 maps to -1 and pixel (size-1) maps to +1.
    norm_x = joint_xy[:, :, 0] / (w - 1) * 2 - 1
    norm_y = joint_xy[:, :, 1] / (h - 1) * 2 - 1
    grid = torch.stack((norm_x, norm_y), dim=2).unsqueeze(2)  # (B, J, 1, 2)
    sampled = F.grid_sample(img_feat, grid, align_corners=True)  # (B, C, J, 1)
    return sampled[..., 0].permute(0, 2, 1).contiguous()  # (B, J, C)
def perspective_projection(points: torch.Tensor,
                           translation: torch.Tensor,
                           focal_length: torch.Tensor,
                           camera_center: Optional[torch.Tensor] = None,
                           rotation: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Project 3D points onto the image plane with a pinhole camera model.

    Args:
        points: (B, N, 3) input 3D points.
        translation: (B, 3) camera translation.
        focal_length: (B, 2) focal lengths (fx, fy) in pixels.
        camera_center: (B, 2) principal point in pixels; defaults to (0, 0).
        rotation: (B, 3, 3) camera rotation; defaults to identity.

    Returns:
        (B, N, 2) projected 2D points.
    """
    batch = points.shape[0]
    dev, dt = points.device, points.dtype
    if rotation is None:
        rotation = torch.eye(3, device=dev, dtype=dt).unsqueeze(0).expand(batch, -1, -1)
    if camera_center is None:
        camera_center = torch.zeros(batch, 2, device=dev, dtype=dt)

    # Intrinsic matrix K = [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
    K = torch.zeros(batch, 3, 3, device=dev, dtype=dt)
    K[:, 0, 0] = focal_length[:, 0]
    K[:, 1, 1] = focal_length[:, 1]
    K[:, 2, 2] = 1.0
    K[:, :2, 2] = camera_center

    # Rigid transform into the camera frame: R @ p + t.
    cam_points = points @ rotation.transpose(1, 2) + translation.unsqueeze(1)
    # Perspective divide by depth.
    cam_points = cam_points / cam_points[:, :, 2:3]
    # Apply intrinsics, then drop the homogeneous coordinate.
    projected = cam_points @ K.transpose(1, 2)
    return projected[:, :, :2]
class DeConvNet(nn.Module):
    """Multi-scale deconvolution pyramid over a backbone feature map.

    A 1x1 conv first halves the channel dim, then ``log2(upscale)+1``
    independent deconv branches upsample that shared map by growing factors
    (2x, 4x, 8x). ``forward`` returns all maps, high resolution first.
    """

    def __init__(self, feat_dim=768, upscale=4):
        """
        Args:
            feat_dim: input channel count; assumed divisible by 8.
            upscale: target upsampling factor; only 2, 4, 8 produce branches
                (loop indices beyond 2 are silently ignored).
        """
        super(DeConvNet, self).__init__()
        # 1x1 conv (no BN/ReLU) to halve the channel dimension.
        self.first_conv = make_conv_layers([feat_dim, feat_dim // 2], kernel=1,
                                           stride=1, padding=0, bnrelu_final=False)
        self.deconv = nn.ModuleList([])
        for i in range(int(math.log2(upscale)) + 1):
            if i == 0:
                self.deconv.append(make_deconv_layers([feat_dim // 2, feat_dim // 4]))
            elif i == 1:
                self.deconv.append(make_deconv_layers(
                    [feat_dim // 2, feat_dim // 4, feat_dim // 8]))
            elif i == 2:
                self.deconv.append(make_deconv_layers(
                    [feat_dim // 2, feat_dim // 4, feat_dim // 8, feat_dim // 8]))

    def forward(self, img_feat):
        """Return feature maps ordered high resolution -> low resolution."""
        face_img_feats = []
        img_feat = self.first_conv(img_feat)
        face_img_feats.append(img_feat)
        # Each branch consumes the shared first_conv output (not chained).
        # Fixed: the loop previously computed an unused `scale = 2**i` and a
        # redundant alias of the branch output.
        for deconv in self.deconv:
            face_img_feats.append(deconv(img_feat))
        return face_img_feats[::-1]  # high resolution -> low resolution
class DeConvNet_v2(nn.Module):
    """Single-shot 4x upsampling head: 1x1 conv then one stride-4 deconv."""

    def __init__(self, feat_dim=768):
        """
        Args:
            feat_dim: input channel count; assumed divisible by 4.
        """
        super(DeConvNet_v2, self).__init__()
        # 1x1 conv (no BN/ReLU) to halve the channel dimension.
        self.first_conv = make_conv_layers([feat_dim, feat_dim // 2], kernel=1,
                                           stride=1, padding=0, bnrelu_final=False)
        # kernel=4 / stride=4 upsamples spatial resolution 4x in one step.
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=feat_dim // 2,
                               out_channels=feat_dim // 4,
                               kernel_size=4, stride=4, padding=0,
                               output_padding=0, bias=False),
            nn.BatchNorm2d(feat_dim // 4),
            nn.ReLU(inplace=True))

    def forward(self, img_feat):
        """Return a one-element list holding the 4x-upsampled feature map.

        (A list, to match DeConvNet's multi-scale return type.)
        """
        # Fixed: removed an unused `face_img_feats = []` local.
        img_feat = self.first_conv(img_feat)
        img_feat = self.deconv(img_feat)
        return [img_feat]
class RefineNet(nn.Module):
    """Refine MANO parameters by sampling multi-scale image features at the
    projected mesh vertices and regressing parameter deltas."""

    def __init__(self, cfg, feat_dim=1280, upscale=3):
        super(RefineNet, self).__init__()
        self.deconv = DeConvNet(feat_dim=feat_dim, upscale=upscale)
        # Per-vertex feature dim: channels of the three pyramid levels
        # concatenated (feat/8 + feat/4 + feat/2).
        self.out_dim = feat_dim // 8 + feat_dim // 4 + feat_dim // 2
        self.dec_pose = nn.Linear(self.out_dim, 96)
        self.dec_cam = nn.Linear(self.out_dim, 3)
        self.dec_shape = nn.Linear(self.out_dim, 10)
        self.cfg = cfg
        self.joint_rep_type = cfg.MODEL.MANO_HEAD.get('JOINT_REP', '6d')
        self.joint_rep_dim = {'6d': 6, 'aa': 3}[self.joint_rep_type]

    def forward(self, img_feat, verts_3d, pred_cam, pred_mano_feats, focal_length):
        """Regress refined MANO params and camera from vertex-sampled features.

        Returns:
            (pred_mano_params dict with 'global_orient'/'hand_pose'/'betas',
             refined camera parameters).
        """
        batch = img_feat.shape[0]
        img_feats = self.deconv(img_feat)  # high -> low resolution

        per_scale = []
        for feat_map in img_feats:
            size = feat_map.shape[2]
            # Weak-perspective -> perspective translation at this scale;
            # 1e-9 guards the division when the scale parameter is zero.
            cam_t = torch.stack(
                [pred_cam[:, 1],
                 pred_cam[:, 2],
                 2 * focal_length[:, 0] / (size * pred_cam[:, 0] + 1e-9)],
                dim=-1)
            verts_2d = perspective_projection(
                verts_3d, translation=cam_t, focal_length=focal_length / size)
            # Max-pool the sampled per-vertex features over all vertices.
            per_scale.append(sample_joint_features(feat_map, verts_2d).max(1).values)
        vert_feats = torch.cat(per_scale, dim=-1)

        delta_pose = self.dec_pose(vert_feats)
        delta_betas = self.dec_shape(vert_feats)
        delta_cam = self.dec_cam(vert_feats)

        # Residual refinement on top of the initial MANO estimates.
        pred_hand_pose = pred_mano_feats['hand_pose'] + delta_pose
        pred_betas = pred_mano_feats['betas'] + delta_betas
        pred_cam = pred_mano_feats['cam'] + delta_cam

        joint_conversion_fn = {
            '6d': rot6d_to_rotmat,
            'aa': lambda x: aa_to_rotmat(x.view(-1, 3).contiguous())
        }[self.joint_rep_type]
        pred_hand_pose = joint_conversion_fn(pred_hand_pose).view(
            batch, self.cfg.MANO.NUM_HAND_JOINTS + 1, 3, 3)

        pred_mano_params = {'global_orient': pred_hand_pose[:, [0]],
                            'hand_pose': pred_hand_pose[:, 1:],
                            'betas': pred_betas}
        return pred_mano_params, pred_cam