| ''' |
| This script is borrowed and extended from https://github.com/nkolot/SPIN/blob/master/train/fits_dict.py |
| ''' |
| import os |
| import cv2 |
| import torch |
| import numpy as np |
| from torchgeometry import angle_axis_to_rotation_matrix, rotation_matrix_to_angle_axis |
|
|
| from core import path_config, constants |
|
|
| import logging |
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(name=__name__)
|
|
|
|
class FitsDict():
    """Dictionary keeping track of the best SMPL fit per image in the training set.

    For every dataset key in ``train_dataset.dataset_dict`` it holds:

    * ``fits_dict[name]``: a 2-D tensor whose rows are
      ``[pose (72 values) | betas (remaining values)]``;
    * ``valid_fit_state[name]``: a ``torch.uint8`` tensor with one flag per row.
    """

    # Datasets whose fits are stored as a plain ``.npy`` array of concatenated
    # [pose | betas] rows (no validity flags). Every other dataset is read from
    # an ``.npz`` archive with separate 'pose', 'betas' and 'valid_fit' arrays.
    _NPY_FIT_DATASETS = ('h36m',)

    def __init__(self, options, train_dataset):
        """Load the initial fits for every training dataset.

        Args:
            options: run options; this class reads ``single_dataset`` (here)
                and ``checkpoint_dir`` (in :meth:`save`).
            train_dataset: object exposing ``dataset_dict`` (name -> dataset)
                and, when ``options.single_dataset`` is false, an iterable
                ``datasets`` whose items carry a ``dataset`` name attribute.
        """
        self.options = options
        self.train_dataset = train_dataset
        self.fits_dict = {}
        self.valid_fit_state = {}

        # Joint permutation used by flip_pose to mirror SMPL pose parameters.
        self.flipped_parts = torch.tensor(constants.SMPL_POSE_FLIP_PERM,
                                          dtype=torch.int64)

        for ds_name in train_dataset.dataset_dict.keys():
            if ds_name in self._NPY_FIT_DATASETS:
                dict_file = os.path.join(path_config.FINAL_FITS_DIR,
                                         ds_name + '.npy')
                self.fits_dict[ds_name] = torch.from_numpy(np.load(dict_file))
                # No validity flags are stored for these datasets, so every
                # fit is marked valid.
                self.valid_fit_state[ds_name] = torch.ones(
                    len(self.fits_dict[ds_name]), dtype=torch.uint8)
            else:
                dict_file = os.path.join(path_config.FINAL_FITS_DIR,
                                         ds_name + '.npz')
                # np.load on an .npz returns a lazily-read NpzFile that keeps
                # the file open; the context manager releases the handle once
                # the arrays have been materialised.
                with np.load(dict_file) as fits:
                    opt_pose = torch.from_numpy(fits['pose'])
                    opt_betas = torch.from_numpy(fits['betas'])
                    opt_valid_fit = torch.from_numpy(fits['valid_fit']).to(
                        torch.uint8)
                self.fits_dict[ds_name] = torch.cat([opt_pose, opt_betas],
                                                    dim=1)
                self.valid_fit_state[ds_name] = opt_valid_fit

        # Replace the datasets' own SMPL parameters with the loaded fits.
        # Tensor.numpy() shares storage with the tensor, so these arrays are
        # views into self.fits_dict and see in-place updates from __setitem__.
        if not options.single_dataset:
            for ds in train_dataset.datasets:
                if ds.dataset not in self._NPY_FIT_DATASETS:
                    ds.pose = self.fits_dict[ds.dataset][:, :72].numpy()
                    ds.betas = self.fits_dict[ds.dataset][:, 72:].numpy()
                    ds.has_smpl = self.valid_fit_state[ds.dataset].numpy()

    def save(self):
        """Save each dataset's fits to ``<checkpoint_dir>/<name>_fits.npy``.

        Only the [pose | betas] tensors are written; ``valid_fit_state`` is
        not saved.
        """
        for ds_name in self.train_dataset.dataset_dict.keys():
            dict_file = os.path.join(self.options.checkpoint_dir,
                                     ds_name + '_fits.npy')
            np.save(dict_file, self.fits_dict[ds_name].cpu().numpy())

    def __getitem__(self, x):
        """Retrieve the stored fits for a batch, expressed in the augmented frame.

        Args:
            x: tuple ``(dataset_name, ind, rot, is_flipped)``. ``dataset_name``
                and ``ind`` give each sample's dataset key and row index.
                ``rot`` is a tensor of per-sample rotations in degrees.
                ``is_flipped`` is a per-sample flip mask tensor.

        Returns:
            ``(pose, betas)`` float tensors of shape ``(B, 72)`` and
            ``(B, 10)``. The pose is rotated by ``rot`` and then flipped where
            ``is_flipped`` is set (presumably mirroring the image augmentation).
        """
        dataset_name, ind, rot, is_flipped = x
        batch_size = len(dataset_name)
        pose = torch.zeros((batch_size, 72))
        betas = torch.zeros((batch_size, 10))
        for n, (ds, i) in enumerate(zip(dataset_name, ind)):
            params = self.fits_dict[ds][i]
            pose[n, :] = params[:72]
            betas[n, :] = params[72:]
        # Stored fits are un-augmented: apply the rotation first, then the flip.
        pose = self.flip_pose(self.rotate_pose(pose, rot), is_flipped)
        return pose, betas

    def get_valid_state(self, dataset_name, ind):
        """Return the validity flags of the stored fits for a batch.

        Args:
            dataset_name: sequence of dataset keys, one per sample.
            ind: sequence of row indices, one per sample.

        Returns:
            ``torch.uint8`` tensor of shape ``(len(dataset_name),)``.
        """
        batch_size = len(dataset_name)
        valid_fit = torch.zeros(batch_size, dtype=torch.uint8)
        for n, (ds, i) in enumerate(zip(dataset_name, ind)):
            valid_fit[n] = self.valid_fit_state[ds][i]
        return valid_fit

    def get_vaild_state(self, dataset_name, ind):
        """Misspelled alias of :meth:`get_valid_state`, kept for existing callers."""
        return self.get_valid_state(dataset_name, ind)

    def __setitem__(self, x, val):
        """Write back fits for a batch, undoing the augmentation first.

        Args:
            x: tuple ``(dataset_name, ind, rot, is_flipped, update)``. The
                first four have the same meaning as in :meth:`__getitem__`.
                ``update`` is a per-sample flag selecting the rows to overwrite.
            val: tuple ``(pose, betas)`` expressed in the augmented frame.

        ``valid_fit_state`` is left unchanged.
        """
        dataset_name, ind, rot, is_flipped, update = x
        pose, betas = val
        # Inverse of __getitem__: undo the flip, then rotate by -rot.
        pose = self.rotate_pose(self.flip_pose(pose, is_flipped), -rot)
        params = torch.cat((pose, betas), dim=-1).cpu()
        for n, (ds, i) in enumerate(zip(dataset_name, ind)):
            if update[n]:
                self.fits_dict[ds][i] = params[n]

    def flip_pose(self, pose, is_flipped):
        """Mirror SMPL pose parameters for the rows selected by ``is_flipped``.

        Selected rows have their joints permuted by ``self.flipped_parts`` and
        the 2nd and 3rd component of every axis-angle triplet negated. Other
        rows are copied unchanged. ``pose`` itself is not modified.
        """
        # Boolean masks are the supported mask-indexing dtype; uint8 masks are
        # deprecated by PyTorch.
        is_flipped = is_flipped.bool()
        pose_f = pose.clone()
        pose_f[is_flipped, :] = pose[is_flipped][:, self.flipped_parts]
        pose_f[is_flipped, 1::3] *= -1
        pose_f[is_flipped, 2::3] *= -1
        return pose_f

    def rotate_pose(self, pose, rot):
        """Rotate the global orientation of SMPL poses by ``rot`` degrees.

        Only the first three values of each row (the global axis-angle) change.
        They are converted to a rotation matrix, left-multiplied by a rotation
        of ``-rot`` degrees about the z axis, and converted back.

        Args:
            pose: ``(B, 72)`` pose tensor (not modified).
            rot: tensor of ``B`` angles in degrees.

        Returns:
            A rotated copy of ``pose``.
        """
        pose = pose.clone()
        cos = torch.cos(-np.pi * rot / 180.)
        sin = torch.sin(-np.pi * rot / 180.)
        zeros = torch.zeros_like(cos)
        # Third row [0, 0, 1]; created in cos' dtype so torch.cat never mixes
        # dtypes.
        r3 = torch.zeros(cos.shape[0], 1, 3, dtype=cos.dtype,
                         device=cos.device)
        r3[:, 0, -1] = 1
        R = torch.cat([
            torch.stack([cos, -sin, zeros], dim=-1).unsqueeze(1),
            torch.stack([sin, cos, zeros], dim=-1).unsqueeze(1),
            r3,
        ], dim=1)
        global_pose = pose[:, :3]
        # (B, 4, 4) homogeneous rotation matrices from the axis-angle vectors.
        global_pose_rotmat = angle_axis_to_rotation_matrix(global_pose)
        # Match R to the pose matrices so a float64 or differently-placed
        # ``rot`` cannot break the matmul.
        R = R.to(device=global_pose_rotmat.device,
                 dtype=global_pose_rotmat.dtype)
        global_pose_rotmat[:, :3, :3] = torch.matmul(
            R, global_pose_rotmat[:, :3, :3])
        rotmats = global_pose_rotmat[:, :3, :3].cpu().numpy()
        global_pose_np = np.zeros((global_pose.shape[0], 3))
        for i, rotmat in enumerate(rotmats):
            aa, _ = cv2.Rodrigues(rotmat)
            global_pose_np[i, :] = aa.squeeze()
        pose[:, :3] = torch.from_numpy(global_pose_np).to(pose.device)
        return pose
|
|