# NOTE(review): the three lines below ("Spaces:" / "Sleeping" x2) are
# extraction residue, not Python code — kept as comments so the file parses.
# Spaces:
# Sleeping
# Sleeping
# 2022.07.19 - Changed for CLIFF
# Huawei Technologies Co., Ltd.
# Copyright (C) 2022. Huawei Technologies Co., Ltd. All rights reserved.
# Copyright (c) 2019, University of Pennsylvania, Max Planck Institute for Intelligent Systems
# This program is free software; you can redistribute it and/or modify it
# under the terms of the MIT license.
# This program is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
# PARTICULAR PURPOSE. See the MIT License for more details.
# This script is borrowed and extended from SPIN
import math

import numpy as np
import torch
import torch.nn as nn

from common.imutils import rot6d_to_rotmat
from models.backbones.resnet import ResNet
class CLIFF(nn.Module):
    """SMPL iterative regressor with a ResNet-50 backbone (CLIFF variant).

    Encodes an image crop into a global feature vector, concatenates it with
    bounding-box information, and iteratively refines SMPL pose, shape, and
    weak-perspective camera parameters starting from the SMPL mean parameters.
    """

    def __init__(self, smpl_mean_params, img_feat_num=2048):
        """
        Args:
            smpl_mean_params: path to an ``.npz`` file holding the SMPL mean
                parameters under the keys ``'pose'`` (6D rotations),
                ``'shape'``, and ``'cam'``.
            img_feat_num: dimensionality of the backbone's global image
                feature (2048 for ResNet-50).
        """
        super(CLIFF, self).__init__()
        # ResNet-50 layer layout.
        self.encoder = ResNet(layers=[3, 4, 6, 3])

        npose = 24 * 6  # 24 SMPL joints x 6D rotation representation
        nshape = 10     # SMPL shape (beta) coefficients
        ncam = 3        # weak-perspective camera parameters
        nbbox = 3       # bounding-box information concatenated to the features

        fc1_feat_num = 1024
        fc2_feat_num = 1024
        final_feat_num = fc2_feat_num
        # Regressor input: image feature + bbox info + current estimates.
        reg_in_feat_num = img_feat_num + nbbox + npose + nshape + ncam

        self.fc1 = nn.Linear(reg_in_feat_num, fc1_feat_num)
        self.drop1 = nn.Dropout()
        self.fc2 = nn.Linear(fc1_feat_num, fc2_feat_num)
        self.drop2 = nn.Dropout()
        self.decpose = nn.Linear(final_feat_num, npose)
        self.decshape = nn.Linear(final_feat_num, nshape)
        self.deccam = nn.Linear(final_feat_num, ncam)
        # Small-gain init so the first iteration stays close to the mean params.
        nn.init.xavier_uniform_(self.decpose.weight, gain=0.01)
        nn.init.xavier_uniform_(self.decshape.weight, gain=0.01)
        nn.init.xavier_uniform_(self.deccam.weight, gain=0.01)

        # He-style init for conv layers; identity-like init for batch norm.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

        mean_params = np.load(smpl_mean_params)
        # FIX: cast ALL mean parameters to float32 (the original cast only
        # 'shape'). torch.from_numpy preserves the numpy dtype, so a float64
        # npz would otherwise produce float64 init_pose/init_cam buffers and
        # make the torch.cat in forward() fail with a dtype mismatch against
        # the float32 image features.
        init_pose = torch.from_numpy(mean_params['pose'][:].astype('float32')).unsqueeze(0)
        init_shape = torch.from_numpy(mean_params['shape'][:].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam'].astype('float32')).unsqueeze(0)
        # Buffers (not parameters): move with the module but are not trained.
        self.register_buffer('init_pose', init_pose)
        self.register_buffer('init_shape', init_shape)
        self.register_buffer('init_cam', init_cam)

    def forward(self, x, bbox, init_pose=None, init_shape=None, init_cam=None, n_iter=3):
        """Run the iterative regressor.

        Args:
            x: image batch fed to the backbone; shape (B, C, H, W)
                — assumed, TODO confirm against ResNet.forward.
            bbox: per-sample bounding-box information, shape (B, 3).
            init_pose, init_shape, init_cam: optional starting estimates;
                default to the registered SMPL mean parameters.
            n_iter: number of refinement iterations.

        Returns:
            (pred_rotmat, pred_shape, pred_cam):
                rotation matrices of shape (B, 24, 3, 3), shape coefficients
                (B, 10), and camera parameters (B, 3).
        """
        batch_size = x.shape[0]

        if init_pose is None:
            init_pose = self.init_pose.expand(batch_size, -1)
        if init_shape is None:
            init_shape = self.init_shape.expand(batch_size, -1)
        if init_cam is None:
            init_cam = self.init_cam.expand(batch_size, -1)

        xf = self.encoder(x)

        pred_pose = init_pose
        pred_shape = init_shape
        pred_cam = init_cam
        # Iterative error feedback: each pass predicts a residual update.
        for i in range(n_iter):
            xc = torch.cat([xf, bbox, pred_pose, pred_shape, pred_cam], 1)
            xc = self.fc1(xc)
            xc = self.drop1(xc)
            xc = self.fc2(xc)
            xc = self.drop2(xc)
            pred_pose = self.decpose(xc) + pred_pose
            pred_shape = self.decshape(xc) + pred_shape
            pred_cam = self.deccam(xc) + pred_cam

        # Convert the 6D rotation representation to full rotation matrices.
        pred_rotmat = rot6d_to_rotmat(pred_pose).view(batch_size, 24, 3, 3)

        return pred_rotmat, pred_shape, pred_cam