Spaces:
Sleeping
Sleeping
added missing files
Browse files- utils/__init__.py +0 -0
- utils/cluster.py +99 -0
- utils/colorwheel.py +22 -0
- utils/config.py +196 -0
- utils/default_hparams.py +45 -0
- utils/diff_renderer.py +287 -0
- utils/get_cfg.py +17 -0
- utils/hrnet.py +625 -0
- utils/image_utils.py +444 -0
- utils/kp_utils.py +1114 -0
- utils/loss.py +207 -0
- utils/mesh_utils.py +6 -0
- utils/metrics.py +106 -0
- utils/smpl_uv.py +167 -0
utils/__init__.py
ADDED
|
File without changes
|
utils/cluster.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import stat
|
| 4 |
+
import shutil
|
| 5 |
+
import subprocess
|
| 6 |
+
|
| 7 |
+
from loguru import logger
|
| 8 |
+
|
| 9 |
+
GPUS = {
|
| 10 |
+
'v100-v16': ('\"Tesla V100-PCIE-16GB\"', 'tesla', 16000),
|
| 11 |
+
'v100-p32': ('\"Tesla V100-PCIE-32GB\"', 'tesla', 32000),
|
| 12 |
+
'v100-s32': ('\"Tesla V100-SXM2-32GB\"', 'tesla', 32000),
|
| 13 |
+
'v100-p16': ('\"Tesla P100-PCIE-16GB\"', 'tesla', 16000),
|
| 14 |
+
}
|
| 15 |
+
|
| 16 |
+
def get_gpus(min_mem=10000, arch=('tesla', 'quadro', 'rtx')):
|
| 17 |
+
gpu_names = []
|
| 18 |
+
for k, (gpu_name, gpu_arch, gpu_mem) in GPUS.items():
|
| 19 |
+
if gpu_mem >= min_mem and gpu_arch in arch:
|
| 20 |
+
gpu_names.append(gpu_name)
|
| 21 |
+
|
| 22 |
+
assert len(gpu_names) > 0, 'Suitable GPU model could not be found'
|
| 23 |
+
|
| 24 |
+
return gpu_names
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def execute_task_on_cluster(
|
| 28 |
+
script,
|
| 29 |
+
exp_name,
|
| 30 |
+
output_dir,
|
| 31 |
+
condor_dir,
|
| 32 |
+
cfg_file,
|
| 33 |
+
num_exp=1,
|
| 34 |
+
exp_opts=None,
|
| 35 |
+
bid_amount=10,
|
| 36 |
+
num_workers=2,
|
| 37 |
+
memory=64000,
|
| 38 |
+
gpu_min_mem=10000,
|
| 39 |
+
gpu_arch=('tesla', 'quadro', 'rtx'),
|
| 40 |
+
num_gpus=1
|
| 41 |
+
):
|
| 42 |
+
# copy config to a new experiment directory and source from there.
|
| 43 |
+
# this makes sure the correct config is copied even if you change the config file
|
| 44 |
+
# after starting the experiment and before the first job is submitted
|
| 45 |
+
temp_config_dir = os.path.join(os.path.dirname(condor_dir), 'temp_configs', exp_name)
|
| 46 |
+
os.makedirs(temp_config_dir, exist_ok=True)
|
| 47 |
+
new_cfg_file = os.path.join(temp_config_dir, 'config.yaml')
|
| 48 |
+
shutil.copy(src=cfg_file, dst=new_cfg_file)
|
| 49 |
+
|
| 50 |
+
gpus = get_gpus(min_mem=gpu_min_mem, arch=gpu_arch)
|
| 51 |
+
|
| 52 |
+
gpus = ' || '.join([f'CUDADeviceName=={x}' for x in gpus])
|
| 53 |
+
|
| 54 |
+
condor_log_dir = os.path.join(condor_dir, 'condorlog', exp_name)
|
| 55 |
+
os.makedirs(condor_log_dir, exist_ok=True)
|
| 56 |
+
submission = f'executable = {condor_log_dir}/{exp_name}_run.sh\n' \
|
| 57 |
+
'arguments = $(Process) $(Cluster)\n' \
|
| 58 |
+
f'error = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).err\n' \
|
| 59 |
+
f'output = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).out\n' \
|
| 60 |
+
f'log = {condor_log_dir}/{exp_name}_$(Cluster).$(Process).log\n' \
|
| 61 |
+
f'request_memory = {memory}\n' \
|
| 62 |
+
f'request_cpus={int(num_workers)}\n' \
|
| 63 |
+
f'request_gpus={num_gpus}\n' \
|
| 64 |
+
f'requirements={gpus}\n' \
|
| 65 |
+
f'+MaxRunningPrice = 500\n' \
|
| 66 |
+
f'queue {num_exp}'
|
| 67 |
+
# f'request_cpus={int(num_workers/2)}\n' \
|
| 68 |
+
# f'+RunningPriceExceededAction = \"kill\"\n' \
|
| 69 |
+
print('<<< Condor Submission >>> ')
|
| 70 |
+
print(submission)
|
| 71 |
+
|
| 72 |
+
with open(f'{condor_log_dir}/{exp_name}_submit.sub', 'w') as f:
|
| 73 |
+
f.write(submission)
|
| 74 |
+
|
| 75 |
+
# output_dir = os.path.join(output_dir, exp_name)
|
| 76 |
+
logger.info(f'The logs for this experiments can be found under: {condor_log_dir}')
|
| 77 |
+
logger.info(f'The outputs for this experiments can be found under: {output_dir}')
|
| 78 |
+
## This is the trick. Notice there is no --cluster here
|
| 79 |
+
bash = 'export PYTHONBUFFERED=1\n export PATH=$PATH\n ' \
|
| 80 |
+
f'{sys.executable} {script} --cfg {new_cfg_file} --cfg_id $1'
|
| 81 |
+
|
| 82 |
+
if exp_opts is not None:
|
| 83 |
+
bash += ' --opts '
|
| 84 |
+
for opt in exp_opts:
|
| 85 |
+
bash += f'{opt} '
|
| 86 |
+
bash += 'SYSTEM.CLUSTER_NODE $2.$1'
|
| 87 |
+
else:
|
| 88 |
+
bash += ' --opts SYSTEM.CLUSTER_NODE $2.$1'
|
| 89 |
+
|
| 90 |
+
executable_path = f'{condor_log_dir}/{exp_name}_run.sh'
|
| 91 |
+
|
| 92 |
+
with open(executable_path, 'w') as f:
|
| 93 |
+
f.write(bash)
|
| 94 |
+
|
| 95 |
+
os.chmod(executable_path, stat.S_IRWXU)
|
| 96 |
+
|
| 97 |
+
cmd = ['condor_submit_bid', f'{bid_amount}', f'{condor_log_dir}/{exp_name}_submit.sub']
|
| 98 |
+
logger.info('Executing ' + ' '.join(cmd))
|
| 99 |
+
subprocess.call(cmd)
|
utils/colorwheel.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import cv2
|
| 2 |
+
import numpy as np
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def make_color_wheel_image(img_width, img_height):
|
| 6 |
+
"""
|
| 7 |
+
Creates a color wheel based image of given width and height
|
| 8 |
+
Args:
|
| 9 |
+
img_width (int):
|
| 10 |
+
img_height (int):
|
| 11 |
+
|
| 12 |
+
Returns:
|
| 13 |
+
opencv image (numpy array): color wheel based image
|
| 14 |
+
"""
|
| 15 |
+
hue = np.fromfunction(lambda i, j: (np.arctan2(i-img_height/2, img_width/2-j) + np.pi)*(180/np.pi)/2,
|
| 16 |
+
(img_height, img_width), dtype=np.float)
|
| 17 |
+
saturation = np.ones((img_height, img_width)) * 255
|
| 18 |
+
value = np.ones((img_height, img_width)) * 255
|
| 19 |
+
hsl = np.dstack((hue, saturation, value))
|
| 20 |
+
color_map = cv2.cvtColor(np.array(hsl, dtype=np.uint8), cv2.COLOR_HSV2BGR)
|
| 21 |
+
return color_map
|
| 22 |
+
|
utils/config.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import operator
|
| 3 |
+
import os
|
| 4 |
+
import shutil
|
| 5 |
+
import time
|
| 6 |
+
from functools import reduce
|
| 7 |
+
from typing import List, Union
|
| 8 |
+
|
| 9 |
+
import configargparse
|
| 10 |
+
import yaml
|
| 11 |
+
from flatten_dict import flatten, unflatten
|
| 12 |
+
from loguru import logger
|
| 13 |
+
from yacs.config import CfgNode as CN
|
| 14 |
+
|
| 15 |
+
from utils.cluster import execute_task_on_cluster
|
| 16 |
+
from utils.default_hparams import hparams
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def parse_args():
|
| 20 |
+
def add_common_cmdline_args(parser):
|
| 21 |
+
# for cluster runs
|
| 22 |
+
parser.add_argument('--cfg', required=True, type=str, help='cfg file path')
|
| 23 |
+
parser.add_argument('--opts', default=[], nargs='*', help='additional options to update config')
|
| 24 |
+
parser.add_argument('--cfg_id', type=int, default=0, help='cfg id to run when multiple experiments are spawned')
|
| 25 |
+
parser.add_argument('--cluster', default=False, action='store_true', help='creates submission files for cluster')
|
| 26 |
+
parser.add_argument('--bid', type=int, default=10, help='amount of bid for cluster')
|
| 27 |
+
parser.add_argument('--memory', type=int, default=64000, help='memory amount for cluster')
|
| 28 |
+
parser.add_argument('--gpu_min_mem', type=int, default=12000, help='minimum amount of GPU memory')
|
| 29 |
+
parser.add_argument('--gpu_arch', default=['tesla', 'quadro', 'rtx'],
|
| 30 |
+
nargs='*', help='additional options to update config')
|
| 31 |
+
parser.add_argument('--num_cpus', type=int, default=8, help='num cpus for cluster')
|
| 32 |
+
return parser
|
| 33 |
+
|
| 34 |
+
# For Blender main parser
|
| 35 |
+
arg_formatter = configargparse.ArgumentDefaultsHelpFormatter
|
| 36 |
+
cfg_parser = configargparse.YAMLConfigFileParser
|
| 37 |
+
description = 'PyTorch implementation of DECO'
|
| 38 |
+
|
| 39 |
+
parser = configargparse.ArgumentParser(formatter_class=arg_formatter,
|
| 40 |
+
config_file_parser_class=cfg_parser,
|
| 41 |
+
description=description,
|
| 42 |
+
prog='deco')
|
| 43 |
+
|
| 44 |
+
parser = add_common_cmdline_args(parser)
|
| 45 |
+
|
| 46 |
+
args = parser.parse_args()
|
| 47 |
+
print(args, end='\n\n')
|
| 48 |
+
|
| 49 |
+
return args
|
| 50 |
+
|
| 51 |
+
def get_hparams_defaults():
|
| 52 |
+
"""Get a yacs hparamsNode object with default values for my_project."""
|
| 53 |
+
# Return a clone so that the defaults will not be altered
|
| 54 |
+
# This is for the "local variable" use pattern
|
| 55 |
+
return hparams.clone()
|
| 56 |
+
|
| 57 |
+
def update_hparams(hparams_file):
|
| 58 |
+
hparams = get_hparams_defaults()
|
| 59 |
+
hparams.merge_from_file(hparams_file)
|
| 60 |
+
return hparams.clone()
|
| 61 |
+
|
| 62 |
+
def update_hparams_from_dict(cfg_dict):
|
| 63 |
+
hparams = get_hparams_defaults()
|
| 64 |
+
cfg = hparams.load_cfg(str(cfg_dict))
|
| 65 |
+
hparams.merge_from_other_cfg(cfg)
|
| 66 |
+
return hparams.clone()
|
| 67 |
+
|
| 68 |
+
def get_grid_search_configs(config, excluded_keys=[]):
|
| 69 |
+
"""
|
| 70 |
+
:param config: dictionary with the configurations
|
| 71 |
+
:return: The different configurations
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
def bool_to_string(x: Union[List[bool], bool]) -> Union[List[str], str]:
|
| 75 |
+
"""
|
| 76 |
+
boolean to string conversion
|
| 77 |
+
:param x: list or bool to be converted
|
| 78 |
+
:return: string converted thinghat
|
| 79 |
+
"""
|
| 80 |
+
if isinstance(x, bool):
|
| 81 |
+
return [str(x)]
|
| 82 |
+
for i, j in enumerate(x):
|
| 83 |
+
x[i] = str(j)
|
| 84 |
+
return x
|
| 85 |
+
|
| 86 |
+
# exclude from grid search
|
| 87 |
+
|
| 88 |
+
flattened_config_dict = flatten(config, reducer='path')
|
| 89 |
+
hyper_params = []
|
| 90 |
+
|
| 91 |
+
for k,v in flattened_config_dict.items():
|
| 92 |
+
if isinstance(v,list):
|
| 93 |
+
if k in excluded_keys:
|
| 94 |
+
flattened_config_dict[k] = ['+'.join(v)]
|
| 95 |
+
elif len(v) > 1:
|
| 96 |
+
hyper_params += [k]
|
| 97 |
+
|
| 98 |
+
if isinstance(v, list) and isinstance(v[0], bool) :
|
| 99 |
+
flattened_config_dict[k] = bool_to_string(v)
|
| 100 |
+
|
| 101 |
+
if not isinstance(v,list):
|
| 102 |
+
if isinstance(v, bool):
|
| 103 |
+
flattened_config_dict[k] = bool_to_string(v)
|
| 104 |
+
else:
|
| 105 |
+
flattened_config_dict[k] = [v]
|
| 106 |
+
|
| 107 |
+
keys, values = zip(*flattened_config_dict.items())
|
| 108 |
+
experiments = [dict(zip(keys, v)) for v in itertools.product(*values)]
|
| 109 |
+
|
| 110 |
+
for exp_id, exp in enumerate(experiments):
|
| 111 |
+
for param in excluded_keys:
|
| 112 |
+
exp[param] = exp[param].strip().split('+')
|
| 113 |
+
for param_name, param_value in exp.items():
|
| 114 |
+
# print(param_name,type(param_value))
|
| 115 |
+
if isinstance(param_value, list) and (param_value[0] in ['True', 'False']):
|
| 116 |
+
exp[param_name] = [True if x == 'True' else False for x in param_value]
|
| 117 |
+
if param_value in ['True', 'False']:
|
| 118 |
+
if param_value == 'True':
|
| 119 |
+
exp[param_name] = True
|
| 120 |
+
else:
|
| 121 |
+
exp[param_name] = False
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
experiments[exp_id] = unflatten(exp, splitter='path')
|
| 125 |
+
|
| 126 |
+
return experiments, hyper_params
|
| 127 |
+
|
| 128 |
+
def get_from_dict(dict, keys):
|
| 129 |
+
return reduce(operator.getitem, keys, dict)
|
| 130 |
+
|
| 131 |
+
def save_dict_to_yaml(obj, filename, mode='w'):
|
| 132 |
+
with open(filename, mode) as f:
|
| 133 |
+
yaml.dump(obj, f, default_flow_style=False)
|
| 134 |
+
|
| 135 |
+
def run_grid_search_experiments(
|
| 136 |
+
args,
|
| 137 |
+
script='train.py',
|
| 138 |
+
change_wt_name=True
|
| 139 |
+
):
|
| 140 |
+
cfg = yaml.safe_load(open(args.cfg))
|
| 141 |
+
# parse config file to split into a list of configs with tuning hyperparameters separated
|
| 142 |
+
# Also return the names of tuned hyperparameters hyperparameters
|
| 143 |
+
different_configs, hyperparams = get_grid_search_configs(
|
| 144 |
+
cfg,
|
| 145 |
+
excluded_keys=['TRAINING/DATASETS', 'TRAINING/DATASET_MIX_PDF', 'VALIDATION/DATASETS'],
|
| 146 |
+
)
|
| 147 |
+
logger.info(f'Grid search hparams: \n {hyperparams}')
|
| 148 |
+
|
| 149 |
+
# The config file may be missing some default values, so we need to add them
|
| 150 |
+
different_configs = [update_hparams_from_dict(c) for c in different_configs]
|
| 151 |
+
logger.info(f'======> Number of experiment configurations is {len(different_configs)}')
|
| 152 |
+
|
| 153 |
+
config_to_run = CN(different_configs[args.cfg_id])
|
| 154 |
+
|
| 155 |
+
if args.cluster:
|
| 156 |
+
execute_task_on_cluster(
|
| 157 |
+
script=script,
|
| 158 |
+
exp_name=config_to_run.EXP_NAME,
|
| 159 |
+
output_dir=config_to_run.OUTPUT_DIR,
|
| 160 |
+
condor_dir=config_to_run.CONDOR_DIR,
|
| 161 |
+
cfg_file=args.cfg,
|
| 162 |
+
num_exp=len(different_configs),
|
| 163 |
+
bid_amount=args.bid,
|
| 164 |
+
num_workers=config_to_run.DATASET.NUM_WORKERS,
|
| 165 |
+
memory=args.memory,
|
| 166 |
+
exp_opts=args.opts,
|
| 167 |
+
gpu_min_mem=args.gpu_min_mem,
|
| 168 |
+
gpu_arch=args.gpu_arch,
|
| 169 |
+
)
|
| 170 |
+
exit()
|
| 171 |
+
|
| 172 |
+
# ==== create logdir using hyperparam settings
|
| 173 |
+
logtime = time.strftime('%d-%m-%Y_%H-%M-%S')
|
| 174 |
+
logdir = f'{logtime}_{config_to_run.EXP_NAME}'
|
| 175 |
+
wt_file = config_to_run.EXP_NAME + '_'
|
| 176 |
+
for hp in hyperparams:
|
| 177 |
+
v = get_from_dict(different_configs[args.cfg_id], hp.split('/'))
|
| 178 |
+
logdir += f'_{hp.replace("/", ".").replace("_", "").lower()}-{v}'
|
| 179 |
+
wt_file += f'{hp.replace("/", ".").replace("_", "").lower()}-{v}_'
|
| 180 |
+
logdir = os.path.join(config_to_run.OUTPUT_DIR, logdir)
|
| 181 |
+
os.makedirs(logdir, exist_ok=True)
|
| 182 |
+
config_to_run.LOGDIR = logdir
|
| 183 |
+
|
| 184 |
+
wt_file += 'best.pth'
|
| 185 |
+
wt_path = os.path.join(os.path.dirname(config_to_run.TRAINING.BEST_MODEL_PATH), wt_file)
|
| 186 |
+
if change_wt_name: config_to_run.TRAINING.BEST_MODEL_PATH = wt_path
|
| 187 |
+
|
| 188 |
+
shutil.copy(src=args.cfg, dst=os.path.join(logdir, 'config.yaml'))
|
| 189 |
+
|
| 190 |
+
# save config
|
| 191 |
+
save_dict_to_yaml(
|
| 192 |
+
unflatten(flatten(config_to_run)),
|
| 193 |
+
os.path.join(config_to_run.LOGDIR, 'config_to_run.yaml')
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
return config_to_run
|
utils/default_hparams.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from yacs.config import CfgNode as CN
|
| 2 |
+
|
| 3 |
+
# Set default hparams to construct new default config
|
| 4 |
+
# Make sure the defaults are same as in parser
|
| 5 |
+
hparams = CN()
|
| 6 |
+
|
| 7 |
+
# General settings
|
| 8 |
+
hparams.EXP_NAME = 'default'
|
| 9 |
+
hparams.PROJECT_NAME = 'default'
|
| 10 |
+
hparams.OUTPUT_DIR = 'deco_results/'
|
| 11 |
+
hparams.CONDOR_DIR = '/is/cluster/work/achatterjee/condor/rich/'
|
| 12 |
+
hparams.LOGDIR = ''
|
| 13 |
+
|
| 14 |
+
# Dataset hparams
|
| 15 |
+
hparams.DATASET = CN()
|
| 16 |
+
hparams.DATASET.BATCH_SIZE = 64
|
| 17 |
+
hparams.DATASET.NUM_WORKERS = 4
|
| 18 |
+
hparams.DATASET.NORMALIZE_IMAGES = True
|
| 19 |
+
|
| 20 |
+
# Optimizer hparams
|
| 21 |
+
hparams.OPTIMIZER = CN()
|
| 22 |
+
hparams.OPTIMIZER.TYPE = 'adam'
|
| 23 |
+
hparams.OPTIMIZER.LR = 5e-5
|
| 24 |
+
hparams.OPTIMIZER.NUM_UPDATE_LR = 10
|
| 25 |
+
|
| 26 |
+
# Training hparams
|
| 27 |
+
hparams.TRAINING = CN()
|
| 28 |
+
hparams.TRAINING.ENCODER = 'hrnet'
|
| 29 |
+
hparams.TRAINING.CONTEXT = True
|
| 30 |
+
hparams.TRAINING.NUM_EPOCHS = 50
|
| 31 |
+
hparams.TRAINING.SUMMARY_STEPS = 100
|
| 32 |
+
hparams.TRAINING.CHECKPOINT_EPOCHS = 5
|
| 33 |
+
hparams.TRAINING.NUM_EARLY_STOP = 10
|
| 34 |
+
hparams.TRAINING.DATASETS = ['rich']
|
| 35 |
+
hparams.TRAINING.DATASET_MIX_PDF = ['1.']
|
| 36 |
+
hparams.TRAINING.DATASET_ROOT_PATH = '/is/cluster/work/achatterjee/rich/npzs'
|
| 37 |
+
hparams.TRAINING.BEST_MODEL_PATH = '/is/cluster/work/achatterjee/weights/rich/exp/rich_exp.pth'
|
| 38 |
+
hparams.TRAINING.LOSS_WEIGHTS = 1.
|
| 39 |
+
hparams.TRAINING.PAL_LOSS_WEIGHTS = 1.
|
| 40 |
+
|
| 41 |
+
# Training hparams
|
| 42 |
+
hparams.VALIDATION = CN()
|
| 43 |
+
hparams.VALIDATION.SUMMARY_STEPS = 100
|
| 44 |
+
hparams.VALIDATION.DATASETS = ['rich']
|
| 45 |
+
hparams.VALIDATION.MAIN_DATASET = 'rich'
|
utils/diff_renderer.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# from https://gitlab.tuebingen.mpg.de/mkocabas/projects/-/blob/master/pare/pare/utils/diff_renderer.py
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from pytorch3d.renderer import (
|
| 8 |
+
PerspectiveCameras,
|
| 9 |
+
RasterizationSettings,
|
| 10 |
+
DirectionalLights,
|
| 11 |
+
BlendParams,
|
| 12 |
+
HardFlatShader,
|
| 13 |
+
MeshRasterizer,
|
| 14 |
+
TexturesVertex,
|
| 15 |
+
TexturesAtlas
|
| 16 |
+
)
|
| 17 |
+
from pytorch3d.structures import Meshes
|
| 18 |
+
|
| 19 |
+
from .image_utils import get_default_camera
|
| 20 |
+
from .smpl_uv import get_tenet_texture
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class MeshRendererWithDepth(nn.Module):
|
| 24 |
+
"""
|
| 25 |
+
A class for rendering a batch of heterogeneous meshes. The class should
|
| 26 |
+
be initialized with a rasterizer and shader class which each have a forward
|
| 27 |
+
function.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, rasterizer, shader):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.rasterizer = rasterizer
|
| 33 |
+
self.shader = shader
|
| 34 |
+
|
| 35 |
+
def forward(self, meshes_world, **kwargs) -> torch.Tensor:
|
| 36 |
+
"""
|
| 37 |
+
Render a batch of images from a batch of meshes by rasterizing and then
|
| 38 |
+
shading.
|
| 39 |
+
|
| 40 |
+
NOTE: If the blur radius for rasterization is > 0.0, some pixels can
|
| 41 |
+
have one or more barycentric coordinates lying outside the range [0, 1].
|
| 42 |
+
For a pixel with out of bounds barycentric coordinates with respect to a
|
| 43 |
+
face f, clipping is required before interpolating the texture uv
|
| 44 |
+
coordinates and z buffer so that the colors and depths are limited to
|
| 45 |
+
the range for the corresponding face.
|
| 46 |
+
"""
|
| 47 |
+
fragments = self.rasterizer(meshes_world, **kwargs)
|
| 48 |
+
images = self.shader(fragments, meshes_world, **kwargs)
|
| 49 |
+
|
| 50 |
+
mask = (fragments.zbuf > -1).float()
|
| 51 |
+
|
| 52 |
+
zbuf = fragments.zbuf.view(images.shape[0], -1)
|
| 53 |
+
# print(images.shape, zbuf.shape)
|
| 54 |
+
depth = (zbuf - zbuf.min(-1, keepdims=True).values) / \
|
| 55 |
+
(zbuf.max(-1, keepdims=True).values - zbuf.min(-1, keepdims=True).values)
|
| 56 |
+
depth = depth.reshape(*images.shape[:3] + (1,))
|
| 57 |
+
|
| 58 |
+
images = torch.cat([images[:, :, :, :3], mask, depth], dim=-1)
|
| 59 |
+
return images
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class DifferentiableRenderer(nn.Module):
|
| 63 |
+
def __init__(
|
| 64 |
+
self,
|
| 65 |
+
img_h,
|
| 66 |
+
img_w,
|
| 67 |
+
focal_length,
|
| 68 |
+
device='cuda',
|
| 69 |
+
background_color=(0.0, 0.0, 0.0),
|
| 70 |
+
texture_mode='smplpix',
|
| 71 |
+
vertex_colors=None,
|
| 72 |
+
face_textures=None,
|
| 73 |
+
smpl_faces=None,
|
| 74 |
+
is_train=False,
|
| 75 |
+
is_cam_batch=False,
|
| 76 |
+
):
|
| 77 |
+
super(DifferentiableRenderer, self).__init__()
|
| 78 |
+
self.x = 'a'
|
| 79 |
+
self.img_h = img_h
|
| 80 |
+
self.img_w = img_w
|
| 81 |
+
self.device = device
|
| 82 |
+
self.focal_length = focal_length
|
| 83 |
+
K, R = get_default_camera(focal_length, img_h, img_w, is_cam_batch=is_cam_batch)
|
| 84 |
+
K, R = K.to(device), R.to(device)
|
| 85 |
+
|
| 86 |
+
# T = torch.tensor([[0, 0, 2.5 * self.focal_length / max(self.img_h, self.img_w)]]).to(device)
|
| 87 |
+
if is_cam_batch:
|
| 88 |
+
T = torch.zeros((K.shape[0], 3)).to(device)
|
| 89 |
+
else:
|
| 90 |
+
T = torch.tensor([[0.0, 0.0, 0.0]]).to(device)
|
| 91 |
+
self.background_color = background_color
|
| 92 |
+
self.renderer = None
|
| 93 |
+
smpl_faces = smpl_faces
|
| 94 |
+
|
| 95 |
+
if texture_mode == 'smplpix':
|
| 96 |
+
face_colors = get_tenet_texture(mode=texture_mode).to(device).float()
|
| 97 |
+
vertex_colors = torch.from_numpy(
|
| 98 |
+
np.load(f'data/smpl/{texture_mode}_vertex_colors.npy')[:,:3]
|
| 99 |
+
).unsqueeze(0).to(device).float()
|
| 100 |
+
if texture_mode == 'partseg':
|
| 101 |
+
vertex_colors = vertex_colors[..., :3].unsqueeze(0).to(device)
|
| 102 |
+
face_colors = face_textures.to(device)
|
| 103 |
+
if texture_mode == 'deco':
|
| 104 |
+
vertex_colors = vertex_colors[..., :3].to(device)
|
| 105 |
+
face_colors = face_textures.to(device)
|
| 106 |
+
|
| 107 |
+
self.register_buffer('K', K)
|
| 108 |
+
self.register_buffer('R', R)
|
| 109 |
+
self.register_buffer('T', T)
|
| 110 |
+
self.register_buffer('face_colors', face_colors)
|
| 111 |
+
self.register_buffer('vertex_colors', vertex_colors)
|
| 112 |
+
self.register_buffer('smpl_faces', smpl_faces)
|
| 113 |
+
|
| 114 |
+
self.set_requires_grad(is_train)
|
| 115 |
+
|
| 116 |
+
def set_requires_grad(self, val=False):
|
| 117 |
+
self.K.requires_grad_(val)
|
| 118 |
+
self.R.requires_grad_(val)
|
| 119 |
+
self.T.requires_grad_(val)
|
| 120 |
+
self.face_colors.requires_grad_(val)
|
| 121 |
+
self.vertex_colors.requires_grad_(val)
|
| 122 |
+
# check if smpl_faces is a FloatTensor as requires_grad_ is not defined for LongTensor
|
| 123 |
+
if isinstance(self.smpl_faces, torch.FloatTensor):
|
| 124 |
+
self.smpl_faces.requires_grad_(val)
|
| 125 |
+
|
| 126 |
+
def forward(self, vertices, faces=None, R=None, T=None):
|
| 127 |
+
raise NotImplementedError
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class Pytorch3D(DifferentiableRenderer):
|
| 131 |
+
def __init__(
|
| 132 |
+
self,
|
| 133 |
+
img_h,
|
| 134 |
+
img_w,
|
| 135 |
+
focal_length,
|
| 136 |
+
device='cuda',
|
| 137 |
+
background_color=(0.0, 0.0, 0.0),
|
| 138 |
+
texture_mode='smplpix',
|
| 139 |
+
vertex_colors=None,
|
| 140 |
+
face_textures=None,
|
| 141 |
+
smpl_faces=None,
|
| 142 |
+
model_type='smpl',
|
| 143 |
+
is_train=False,
|
| 144 |
+
is_cam_batch=False,
|
| 145 |
+
):
|
| 146 |
+
super(Pytorch3D, self).__init__(
|
| 147 |
+
img_h,
|
| 148 |
+
img_w,
|
| 149 |
+
focal_length,
|
| 150 |
+
device=device,
|
| 151 |
+
background_color=background_color,
|
| 152 |
+
texture_mode=texture_mode,
|
| 153 |
+
vertex_colors=vertex_colors,
|
| 154 |
+
face_textures=face_textures,
|
| 155 |
+
smpl_faces=smpl_faces,
|
| 156 |
+
is_train=is_train,
|
| 157 |
+
is_cam_batch=is_cam_batch,
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# this R converts the camera from pyrender NDC to
|
| 161 |
+
# OpenGL coordinate frame. It is basicall R(180, X) x R(180, Y)
|
| 162 |
+
# I manually defined it here for convenience
|
| 163 |
+
self.R = self.R @ torch.tensor(
|
| 164 |
+
[[[ -1.0, 0.0, 0.0],
|
| 165 |
+
[ 0.0, -1.0, 0.0],
|
| 166 |
+
[ 0.0, 0.0, 1.0]]],
|
| 167 |
+
dtype=self.R.dtype, device=self.R.device,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
if is_cam_batch:
|
| 171 |
+
focal_length = self.focal_length
|
| 172 |
+
else:
|
| 173 |
+
focal_length = self.focal_length[None, :]
|
| 174 |
+
|
| 175 |
+
principal_point = ((self.img_w // 2, self.img_h // 2),)
|
| 176 |
+
image_size = ((self.img_h, self.img_w),)
|
| 177 |
+
|
| 178 |
+
cameras = PerspectiveCameras(
|
| 179 |
+
device=self.device,
|
| 180 |
+
focal_length=focal_length,
|
| 181 |
+
principal_point=principal_point,
|
| 182 |
+
R=self.R,
|
| 183 |
+
T=self.T,
|
| 184 |
+
in_ndc=False,
|
| 185 |
+
image_size=image_size,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
for param in cameras.parameters():
|
| 189 |
+
param.requires_grad_(False)
|
| 190 |
+
|
| 191 |
+
raster_settings = RasterizationSettings(
|
| 192 |
+
image_size=(self.img_h, self.img_w),
|
| 193 |
+
blur_radius=0.0,
|
| 194 |
+
max_faces_per_bin=20000,
|
| 195 |
+
faces_per_pixel=1,
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
lights = DirectionalLights(
|
| 199 |
+
device=self.device,
|
| 200 |
+
ambient_color=((1.0, 1.0, 1.0),),
|
| 201 |
+
diffuse_color=((0.0, 0.0, 0.0),),
|
| 202 |
+
specular_color=((0.0, 0.0, 0.0),),
|
| 203 |
+
direction=((0, 1, 0),),
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
blend_params = BlendParams(background_color=self.background_color)
|
| 207 |
+
|
| 208 |
+
shader = HardFlatShader(device=self.device,
|
| 209 |
+
cameras=cameras,
|
| 210 |
+
blend_params=blend_params,
|
| 211 |
+
lights=lights)
|
| 212 |
+
|
| 213 |
+
self.textures = TexturesVertex(verts_features=self.vertex_colors)
|
| 214 |
+
|
| 215 |
+
self.renderer = MeshRendererWithDepth(
|
| 216 |
+
rasterizer=MeshRasterizer(
|
| 217 |
+
cameras=cameras,
|
| 218 |
+
raster_settings=raster_settings
|
| 219 |
+
),
|
| 220 |
+
shader=shader,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
def forward(self, vertices, faces=None, R=None, T=None, face_atlas=None):
|
| 224 |
+
batch_size = vertices.shape[0]
|
| 225 |
+
if faces is None:
|
| 226 |
+
faces = self.smpl_faces.expand(batch_size, -1, -1)
|
| 227 |
+
|
| 228 |
+
if R is None:
|
| 229 |
+
R = self.R.expand(batch_size, -1, -1)
|
| 230 |
+
|
| 231 |
+
if T is None:
|
| 232 |
+
T = self.T.expand(batch_size, -1)
|
| 233 |
+
|
| 234 |
+
# convert camera translation to pytorch3d coordinate frame
|
| 235 |
+
T = torch.bmm(R, T.unsqueeze(-1)).squeeze(-1)
|
| 236 |
+
|
| 237 |
+
vertex_textures = TexturesVertex(
|
| 238 |
+
verts_features=self.vertex_colors.expand(batch_size, -1, -1)
|
| 239 |
+
)
|
| 240 |
+
|
| 241 |
+
# face_textures needed because vertex_texture cause interpolation at boundaries
|
| 242 |
+
if face_atlas:
|
| 243 |
+
face_textures = TexturesAtlas(atlas=face_atlas)
|
| 244 |
+
else:
|
| 245 |
+
face_textures = TexturesAtlas(atlas=self.face_colors)
|
| 246 |
+
|
| 247 |
+
# we may need to rotate the mesh
|
| 248 |
+
meshes = Meshes(verts=vertices, faces=faces, textures=face_textures)
|
| 249 |
+
images = self.renderer(meshes, R=R, T=T)
|
| 250 |
+
images = images.permute(0, 3, 1, 2)
|
| 251 |
+
return images
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class NeuralMeshRenderer(DifferentiableRenderer):
|
| 255 |
+
def __init__(self, *args, **kwargs):
|
| 256 |
+
import neural_renderer as nr
|
| 257 |
+
|
| 258 |
+
super(NeuralMeshRenderer, self).__init__(*args, **kwargs)
|
| 259 |
+
|
| 260 |
+
self.neural_renderer = nr.Renderer(
|
| 261 |
+
dist_coeffs=None,
|
| 262 |
+
orig_size=self.img_size,
|
| 263 |
+
image_size=self.img_size,
|
| 264 |
+
light_intensity_ambient=1,
|
| 265 |
+
light_intensity_directional=0,
|
| 266 |
+
anti_aliasing=False,
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
def forward(self, vertices, faces=None, R=None, T=None):
|
| 270 |
+
batch_size = vertices.shape[0]
|
| 271 |
+
if faces is None:
|
| 272 |
+
faces = self.smpl_faces.expand(batch_size, -1, -1)
|
| 273 |
+
|
| 274 |
+
if R is None:
|
| 275 |
+
R = self.R.expand(batch_size, -1, -1)
|
| 276 |
+
|
| 277 |
+
if T is None:
|
| 278 |
+
T = self.T.expand(batch_size, -1)
|
| 279 |
+
rgb, depth, mask = self.neural_renderer(
|
| 280 |
+
vertices,
|
| 281 |
+
faces,
|
| 282 |
+
textures=self.face_colors.expand(batch_size, -1, -1, -1, -1, -1),
|
| 283 |
+
K=self.K.expand(batch_size, -1, -1),
|
| 284 |
+
R=R,
|
| 285 |
+
t=T.unsqueeze(1),
|
| 286 |
+
)
|
| 287 |
+
return torch.cat([rgb, depth.unsqueeze(1), mask.unsqueeze(1)], dim=1)
|
utils/get_cfg.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from yacs.config import CfgNode
|
| 2 |
+
|
| 3 |
+
_VALID_TYPES = {tuple, list, str, int, float, bool}
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def convert_to_dict(cfg_node, key_list=[]):
|
| 7 |
+
""" Convert a config node to dictionary """
|
| 8 |
+
if not isinstance(cfg_node, CfgNode):
|
| 9 |
+
if type(cfg_node) not in _VALID_TYPES:
|
| 10 |
+
print("Key {} with value {} is not a valid type; valid types: {}".format(
|
| 11 |
+
".".join(key_list), type(cfg_node), _VALID_TYPES), )
|
| 12 |
+
return cfg_node
|
| 13 |
+
else:
|
| 14 |
+
cfg_dict = dict(cfg_node)
|
| 15 |
+
for k, v in cfg_dict.items():
|
| 16 |
+
cfg_dict[k] = convert_to_dict(v, key_list + [k])
|
| 17 |
+
return cfg_dict
|
utils/hrnet.py
ADDED
|
@@ -0,0 +1,625 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from loguru import logger
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from yacs.config import CfgNode as CN
|
| 8 |
+
|
| 9 |
+
models = [
|
| 10 |
+
'hrnet_w32',
|
| 11 |
+
'hrnet_w48',
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
BN_MOMENTUM = 0.1
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def conv3x3(in_planes, out_planes, stride=1):
|
| 18 |
+
"""3x3 convolution with padding"""
|
| 19 |
+
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
| 20 |
+
padding=1, bias=False)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class BasicBlock(nn.Module):
|
| 24 |
+
expansion = 1
|
| 25 |
+
|
| 26 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 27 |
+
super(BasicBlock, self).__init__()
|
| 28 |
+
self.conv1 = conv3x3(inplanes, planes, stride)
|
| 29 |
+
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
| 30 |
+
self.relu = nn.ReLU(inplace=True)
|
| 31 |
+
self.conv2 = conv3x3(planes, planes)
|
| 32 |
+
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
| 33 |
+
self.downsample = downsample
|
| 34 |
+
self.stride = stride
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
residual = x
|
| 38 |
+
|
| 39 |
+
out = self.conv1(x)
|
| 40 |
+
out = self.bn1(out)
|
| 41 |
+
out = self.relu(out)
|
| 42 |
+
|
| 43 |
+
out = self.conv2(out)
|
| 44 |
+
out = self.bn2(out)
|
| 45 |
+
|
| 46 |
+
if self.downsample is not None:
|
| 47 |
+
residual = self.downsample(x)
|
| 48 |
+
|
| 49 |
+
out += residual
|
| 50 |
+
out = self.relu(out)
|
| 51 |
+
|
| 52 |
+
return out
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Bottleneck(nn.Module):
|
| 56 |
+
expansion = 4
|
| 57 |
+
|
| 58 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
| 59 |
+
super(Bottleneck, self).__init__()
|
| 60 |
+
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
| 61 |
+
self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
| 62 |
+
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
| 63 |
+
padding=1, bias=False)
|
| 64 |
+
self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM)
|
| 65 |
+
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1,
|
| 66 |
+
bias=False)
|
| 67 |
+
self.bn3 = nn.BatchNorm2d(planes * self.expansion,
|
| 68 |
+
momentum=BN_MOMENTUM)
|
| 69 |
+
self.relu = nn.ReLU(inplace=True)
|
| 70 |
+
self.downsample = downsample
|
| 71 |
+
self.stride = stride
|
| 72 |
+
|
| 73 |
+
def forward(self, x):
|
| 74 |
+
residual = x
|
| 75 |
+
|
| 76 |
+
out = self.conv1(x)
|
| 77 |
+
out = self.bn1(out)
|
| 78 |
+
out = self.relu(out)
|
| 79 |
+
|
| 80 |
+
out = self.conv2(out)
|
| 81 |
+
out = self.bn2(out)
|
| 82 |
+
out = self.relu(out)
|
| 83 |
+
|
| 84 |
+
out = self.conv3(out)
|
| 85 |
+
out = self.bn3(out)
|
| 86 |
+
|
| 87 |
+
if self.downsample is not None:
|
| 88 |
+
residual = self.downsample(x)
|
| 89 |
+
|
| 90 |
+
out += residual
|
| 91 |
+
out = self.relu(out)
|
| 92 |
+
|
| 93 |
+
return out
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
class HighResolutionModule(nn.Module):
|
| 97 |
+
def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
|
| 98 |
+
num_channels, fuse_method, multi_scale_output=True):
|
| 99 |
+
super(HighResolutionModule, self).__init__()
|
| 100 |
+
self._check_branches(
|
| 101 |
+
num_branches, blocks, num_blocks, num_inchannels, num_channels)
|
| 102 |
+
|
| 103 |
+
self.num_inchannels = num_inchannels
|
| 104 |
+
self.fuse_method = fuse_method
|
| 105 |
+
self.num_branches = num_branches
|
| 106 |
+
|
| 107 |
+
self.multi_scale_output = multi_scale_output
|
| 108 |
+
|
| 109 |
+
self.branches = self._make_branches(
|
| 110 |
+
num_branches, blocks, num_blocks, num_channels)
|
| 111 |
+
self.fuse_layers = self._make_fuse_layers()
|
| 112 |
+
self.relu = nn.ReLU(True)
|
| 113 |
+
|
| 114 |
+
def _check_branches(self, num_branches, blocks, num_blocks,
|
| 115 |
+
num_inchannels, num_channels):
|
| 116 |
+
if num_branches != len(num_blocks):
|
| 117 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
|
| 118 |
+
num_branches, len(num_blocks))
|
| 119 |
+
logger.error(error_msg)
|
| 120 |
+
raise ValueError(error_msg)
|
| 121 |
+
|
| 122 |
+
if num_branches != len(num_channels):
|
| 123 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
|
| 124 |
+
num_branches, len(num_channels))
|
| 125 |
+
logger.error(error_msg)
|
| 126 |
+
raise ValueError(error_msg)
|
| 127 |
+
|
| 128 |
+
if num_branches != len(num_inchannels):
|
| 129 |
+
error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
|
| 130 |
+
num_branches, len(num_inchannels))
|
| 131 |
+
logger.error(error_msg)
|
| 132 |
+
raise ValueError(error_msg)
|
| 133 |
+
|
| 134 |
+
def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
|
| 135 |
+
stride=1):
|
| 136 |
+
downsample = None
|
| 137 |
+
if stride != 1 or \
|
| 138 |
+
self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
|
| 139 |
+
downsample = nn.Sequential(
|
| 140 |
+
nn.Conv2d(
|
| 141 |
+
self.num_inchannels[branch_index],
|
| 142 |
+
num_channels[branch_index] * block.expansion,
|
| 143 |
+
kernel_size=1, stride=stride, bias=False
|
| 144 |
+
),
|
| 145 |
+
nn.BatchNorm2d(
|
| 146 |
+
num_channels[branch_index] * block.expansion,
|
| 147 |
+
momentum=BN_MOMENTUM
|
| 148 |
+
),
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
layers = []
|
| 152 |
+
layers.append(
|
| 153 |
+
block(
|
| 154 |
+
self.num_inchannels[branch_index],
|
| 155 |
+
num_channels[branch_index],
|
| 156 |
+
stride,
|
| 157 |
+
downsample
|
| 158 |
+
)
|
| 159 |
+
)
|
| 160 |
+
self.num_inchannels[branch_index] = \
|
| 161 |
+
num_channels[branch_index] * block.expansion
|
| 162 |
+
for i in range(1, num_blocks[branch_index]):
|
| 163 |
+
layers.append(
|
| 164 |
+
block(
|
| 165 |
+
self.num_inchannels[branch_index],
|
| 166 |
+
num_channels[branch_index]
|
| 167 |
+
)
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
return nn.Sequential(*layers)
|
| 171 |
+
|
| 172 |
+
def _make_branches(self, num_branches, block, num_blocks, num_channels):
|
| 173 |
+
branches = []
|
| 174 |
+
|
| 175 |
+
for i in range(num_branches):
|
| 176 |
+
branches.append(
|
| 177 |
+
self._make_one_branch(i, block, num_blocks, num_channels)
|
| 178 |
+
)
|
| 179 |
+
|
| 180 |
+
return nn.ModuleList(branches)
|
| 181 |
+
|
| 182 |
+
def _make_fuse_layers(self):
|
| 183 |
+
if self.num_branches == 1:
|
| 184 |
+
return None
|
| 185 |
+
|
| 186 |
+
num_branches = self.num_branches
|
| 187 |
+
num_inchannels = self.num_inchannels
|
| 188 |
+
fuse_layers = []
|
| 189 |
+
for i in range(num_branches if self.multi_scale_output else 1):
|
| 190 |
+
fuse_layer = []
|
| 191 |
+
for j in range(num_branches):
|
| 192 |
+
if j > i:
|
| 193 |
+
fuse_layer.append(
|
| 194 |
+
nn.Sequential(
|
| 195 |
+
nn.Conv2d(
|
| 196 |
+
num_inchannels[j],
|
| 197 |
+
num_inchannels[i],
|
| 198 |
+
1, 1, 0, bias=False
|
| 199 |
+
),
|
| 200 |
+
nn.BatchNorm2d(num_inchannels[i]),
|
| 201 |
+
nn.Upsample(scale_factor=2**(j-i), mode='nearest')
|
| 202 |
+
)
|
| 203 |
+
)
|
| 204 |
+
elif j == i:
|
| 205 |
+
fuse_layer.append(None)
|
| 206 |
+
else:
|
| 207 |
+
conv3x3s = []
|
| 208 |
+
for k in range(i-j):
|
| 209 |
+
if k == i - j - 1:
|
| 210 |
+
num_outchannels_conv3x3 = num_inchannels[i]
|
| 211 |
+
conv3x3s.append(
|
| 212 |
+
nn.Sequential(
|
| 213 |
+
nn.Conv2d(
|
| 214 |
+
num_inchannels[j],
|
| 215 |
+
num_outchannels_conv3x3,
|
| 216 |
+
3, 2, 1, bias=False
|
| 217 |
+
),
|
| 218 |
+
nn.BatchNorm2d(num_outchannels_conv3x3)
|
| 219 |
+
)
|
| 220 |
+
)
|
| 221 |
+
else:
|
| 222 |
+
num_outchannels_conv3x3 = num_inchannels[j]
|
| 223 |
+
conv3x3s.append(
|
| 224 |
+
nn.Sequential(
|
| 225 |
+
nn.Conv2d(
|
| 226 |
+
num_inchannels[j],
|
| 227 |
+
num_outchannels_conv3x3,
|
| 228 |
+
3, 2, 1, bias=False
|
| 229 |
+
),
|
| 230 |
+
nn.BatchNorm2d(num_outchannels_conv3x3),
|
| 231 |
+
nn.ReLU(True)
|
| 232 |
+
)
|
| 233 |
+
)
|
| 234 |
+
fuse_layer.append(nn.Sequential(*conv3x3s))
|
| 235 |
+
fuse_layers.append(nn.ModuleList(fuse_layer))
|
| 236 |
+
|
| 237 |
+
return nn.ModuleList(fuse_layers)
|
| 238 |
+
|
| 239 |
+
def get_num_inchannels(self):
|
| 240 |
+
return self.num_inchannels
|
| 241 |
+
|
| 242 |
+
def forward(self, x):
|
| 243 |
+
if self.num_branches == 1:
|
| 244 |
+
return [self.branches[0](x[0])]
|
| 245 |
+
|
| 246 |
+
for i in range(self.num_branches):
|
| 247 |
+
x[i] = self.branches[i](x[i])
|
| 248 |
+
|
| 249 |
+
x_fuse = []
|
| 250 |
+
|
| 251 |
+
for i in range(len(self.fuse_layers)):
|
| 252 |
+
y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
|
| 253 |
+
for j in range(1, self.num_branches):
|
| 254 |
+
if i == j:
|
| 255 |
+
y = y + x[j]
|
| 256 |
+
else:
|
| 257 |
+
y = y + self.fuse_layers[i][j](x[j])
|
| 258 |
+
x_fuse.append(self.relu(y))
|
| 259 |
+
|
| 260 |
+
return x_fuse
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
blocks_dict = {
|
| 264 |
+
'BASIC': BasicBlock,
|
| 265 |
+
'BOTTLENECK': Bottleneck
|
| 266 |
+
}
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class PoseHighResolutionNet(nn.Module):
|
| 270 |
+
|
| 271 |
+
def __init__(self, cfg):
|
| 272 |
+
self.inplanes = 64
|
| 273 |
+
extra = cfg['MODEL']['EXTRA']
|
| 274 |
+
super(PoseHighResolutionNet, self).__init__()
|
| 275 |
+
|
| 276 |
+
self.cfg = extra
|
| 277 |
+
|
| 278 |
+
# stem net
|
| 279 |
+
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1,
|
| 280 |
+
bias=False)
|
| 281 |
+
self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
|
| 282 |
+
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1,
|
| 283 |
+
bias=False)
|
| 284 |
+
self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM)
|
| 285 |
+
self.relu = nn.ReLU(inplace=True)
|
| 286 |
+
self.layer1 = self._make_layer(Bottleneck, 64, 4)
|
| 287 |
+
|
| 288 |
+
self.stage2_cfg = extra['STAGE2']
|
| 289 |
+
num_channels = self.stage2_cfg['NUM_CHANNELS']
|
| 290 |
+
block = blocks_dict[self.stage2_cfg['BLOCK']]
|
| 291 |
+
num_channels = [
|
| 292 |
+
num_channels[i] * block.expansion for i in range(len(num_channels))
|
| 293 |
+
]
|
| 294 |
+
self.transition1 = self._make_transition_layer([256], num_channels)
|
| 295 |
+
self.stage2, pre_stage_channels = self._make_stage(
|
| 296 |
+
self.stage2_cfg, num_channels)
|
| 297 |
+
|
| 298 |
+
self.stage3_cfg = extra['STAGE3']
|
| 299 |
+
num_channels = self.stage3_cfg['NUM_CHANNELS']
|
| 300 |
+
block = blocks_dict[self.stage3_cfg['BLOCK']]
|
| 301 |
+
num_channels = [
|
| 302 |
+
num_channels[i] * block.expansion for i in range(len(num_channels))
|
| 303 |
+
]
|
| 304 |
+
self.transition2 = self._make_transition_layer(
|
| 305 |
+
pre_stage_channels, num_channels)
|
| 306 |
+
self.stage3, pre_stage_channels = self._make_stage(
|
| 307 |
+
self.stage3_cfg, num_channels)
|
| 308 |
+
|
| 309 |
+
self.stage4_cfg = extra['STAGE4']
|
| 310 |
+
num_channels = self.stage4_cfg['NUM_CHANNELS']
|
| 311 |
+
block = blocks_dict[self.stage4_cfg['BLOCK']]
|
| 312 |
+
num_channels = [
|
| 313 |
+
num_channels[i] * block.expansion for i in range(len(num_channels))
|
| 314 |
+
]
|
| 315 |
+
self.transition3 = self._make_transition_layer(
|
| 316 |
+
pre_stage_channels, num_channels)
|
| 317 |
+
self.stage4, pre_stage_channels = self._make_stage(
|
| 318 |
+
self.stage4_cfg, num_channels, multi_scale_output=True)
|
| 319 |
+
|
| 320 |
+
self.final_layer = nn.Conv2d(
|
| 321 |
+
in_channels=pre_stage_channels[0],
|
| 322 |
+
out_channels=cfg['MODEL']['NUM_JOINTS'],
|
| 323 |
+
kernel_size=extra['FINAL_CONV_KERNEL'],
|
| 324 |
+
stride=1,
|
| 325 |
+
padding=1 if extra['FINAL_CONV_KERNEL'] == 3 else 0
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
self.pretrained_layers = extra['PRETRAINED_LAYERS']
|
| 329 |
+
|
| 330 |
+
if extra.DOWNSAMPLE and extra.USE_CONV:
|
| 331 |
+
self.downsample_stage_1 = self._make_downsample_layer(3, num_channel=self.stage2_cfg['NUM_CHANNELS'][0])
|
| 332 |
+
self.downsample_stage_2 = self._make_downsample_layer(2, num_channel=self.stage2_cfg['NUM_CHANNELS'][-1])
|
| 333 |
+
self.downsample_stage_3 = self._make_downsample_layer(1, num_channel=self.stage3_cfg['NUM_CHANNELS'][-1])
|
| 334 |
+
elif not extra.DOWNSAMPLE and extra.USE_CONV:
|
| 335 |
+
self.upsample_stage_2 = self._make_upsample_layer(1, num_channel=self.stage2_cfg['NUM_CHANNELS'][-1])
|
| 336 |
+
self.upsample_stage_3 = self._make_upsample_layer(2, num_channel=self.stage3_cfg['NUM_CHANNELS'][-1])
|
| 337 |
+
self.upsample_stage_4 = self._make_upsample_layer(3, num_channel=self.stage4_cfg['NUM_CHANNELS'][-1])
|
| 338 |
+
|
| 339 |
+
def _make_transition_layer(
|
| 340 |
+
self, num_channels_pre_layer, num_channels_cur_layer):
|
| 341 |
+
num_branches_cur = len(num_channels_cur_layer)
|
| 342 |
+
num_branches_pre = len(num_channels_pre_layer)
|
| 343 |
+
|
| 344 |
+
transition_layers = []
|
| 345 |
+
for i in range(num_branches_cur):
|
| 346 |
+
if i < num_branches_pre:
|
| 347 |
+
if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
|
| 348 |
+
transition_layers.append(
|
| 349 |
+
nn.Sequential(
|
| 350 |
+
nn.Conv2d(
|
| 351 |
+
num_channels_pre_layer[i],
|
| 352 |
+
num_channels_cur_layer[i],
|
| 353 |
+
3, 1, 1, bias=False
|
| 354 |
+
),
|
| 355 |
+
nn.BatchNorm2d(num_channels_cur_layer[i]),
|
| 356 |
+
nn.ReLU(inplace=True)
|
| 357 |
+
)
|
| 358 |
+
)
|
| 359 |
+
else:
|
| 360 |
+
transition_layers.append(None)
|
| 361 |
+
else:
|
| 362 |
+
conv3x3s = []
|
| 363 |
+
for j in range(i+1-num_branches_pre):
|
| 364 |
+
inchannels = num_channels_pre_layer[-1]
|
| 365 |
+
outchannels = num_channels_cur_layer[i] \
|
| 366 |
+
if j == i-num_branches_pre else inchannels
|
| 367 |
+
conv3x3s.append(
|
| 368 |
+
nn.Sequential(
|
| 369 |
+
nn.Conv2d(
|
| 370 |
+
inchannels, outchannels, 3, 2, 1, bias=False
|
| 371 |
+
),
|
| 372 |
+
nn.BatchNorm2d(outchannels),
|
| 373 |
+
nn.ReLU(inplace=True)
|
| 374 |
+
)
|
| 375 |
+
)
|
| 376 |
+
transition_layers.append(nn.Sequential(*conv3x3s))
|
| 377 |
+
|
| 378 |
+
return nn.ModuleList(transition_layers)
|
| 379 |
+
|
| 380 |
+
def _make_layer(self, block, planes, blocks, stride=1):
|
| 381 |
+
downsample = None
|
| 382 |
+
if stride != 1 or self.inplanes != planes * block.expansion:
|
| 383 |
+
downsample = nn.Sequential(
|
| 384 |
+
nn.Conv2d(
|
| 385 |
+
self.inplanes, planes * block.expansion,
|
| 386 |
+
kernel_size=1, stride=stride, bias=False
|
| 387 |
+
),
|
| 388 |
+
nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
|
| 389 |
+
)
|
| 390 |
+
|
| 391 |
+
layers = []
|
| 392 |
+
layers.append(block(self.inplanes, planes, stride, downsample))
|
| 393 |
+
self.inplanes = planes * block.expansion
|
| 394 |
+
for i in range(1, blocks):
|
| 395 |
+
layers.append(block(self.inplanes, planes))
|
| 396 |
+
|
| 397 |
+
return nn.Sequential(*layers)
|
| 398 |
+
|
| 399 |
+
def _make_stage(self, layer_config, num_inchannels,
|
| 400 |
+
multi_scale_output=True):
|
| 401 |
+
num_modules = layer_config['NUM_MODULES']
|
| 402 |
+
num_branches = layer_config['NUM_BRANCHES']
|
| 403 |
+
num_blocks = layer_config['NUM_BLOCKS']
|
| 404 |
+
num_channels = layer_config['NUM_CHANNELS']
|
| 405 |
+
block = blocks_dict[layer_config['BLOCK']]
|
| 406 |
+
fuse_method = layer_config['FUSE_METHOD']
|
| 407 |
+
|
| 408 |
+
modules = []
|
| 409 |
+
for i in range(num_modules):
|
| 410 |
+
# multi_scale_output is only used last module
|
| 411 |
+
if not multi_scale_output and i == num_modules - 1:
|
| 412 |
+
reset_multi_scale_output = False
|
| 413 |
+
else:
|
| 414 |
+
reset_multi_scale_output = True
|
| 415 |
+
|
| 416 |
+
modules.append(
|
| 417 |
+
HighResolutionModule(
|
| 418 |
+
num_branches,
|
| 419 |
+
block,
|
| 420 |
+
num_blocks,
|
| 421 |
+
num_inchannels,
|
| 422 |
+
num_channels,
|
| 423 |
+
fuse_method,
|
| 424 |
+
reset_multi_scale_output
|
| 425 |
+
)
|
| 426 |
+
)
|
| 427 |
+
num_inchannels = modules[-1].get_num_inchannels()
|
| 428 |
+
|
| 429 |
+
return nn.Sequential(*modules), num_inchannels
|
| 430 |
+
|
| 431 |
+
def _make_upsample_layer(self, num_layers, num_channel, kernel_size=3):
|
| 432 |
+
layers = []
|
| 433 |
+
for i in range(num_layers):
|
| 434 |
+
layers.append(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True))
|
| 435 |
+
layers.append(
|
| 436 |
+
nn.Conv2d(
|
| 437 |
+
in_channels=num_channel, out_channels=num_channel,
|
| 438 |
+
kernel_size=kernel_size, stride=1, padding=1, bias=False,
|
| 439 |
+
)
|
| 440 |
+
)
|
| 441 |
+
layers.append(nn.BatchNorm2d(num_channel, momentum=BN_MOMENTUM))
|
| 442 |
+
layers.append(nn.ReLU(inplace=True))
|
| 443 |
+
|
| 444 |
+
return nn.Sequential(*layers)
|
| 445 |
+
|
| 446 |
+
def _make_downsample_layer(self, num_layers, num_channel, kernel_size=3):
|
| 447 |
+
layers = []
|
| 448 |
+
for i in range(num_layers):
|
| 449 |
+
layers.append(
|
| 450 |
+
nn.Conv2d(
|
| 451 |
+
in_channels=num_channel, out_channels=num_channel,
|
| 452 |
+
kernel_size=kernel_size, stride=2, padding=1, bias=False,
|
| 453 |
+
)
|
| 454 |
+
)
|
| 455 |
+
layers.append(nn.BatchNorm2d(num_channel, momentum=BN_MOMENTUM))
|
| 456 |
+
layers.append(nn.ReLU(inplace=True))
|
| 457 |
+
|
| 458 |
+
return nn.Sequential(*layers)
|
| 459 |
+
|
| 460 |
+
def forward(self, x):
|
| 461 |
+
x = self.conv1(x)
|
| 462 |
+
x = self.bn1(x)
|
| 463 |
+
x = self.relu(x)
|
| 464 |
+
x = self.conv2(x)
|
| 465 |
+
x = self.bn2(x)
|
| 466 |
+
x = self.relu(x)
|
| 467 |
+
x = self.layer1(x)
|
| 468 |
+
|
| 469 |
+
x_list = []
|
| 470 |
+
for i in range(self.stage2_cfg['NUM_BRANCHES']):
|
| 471 |
+
if self.transition1[i] is not None:
|
| 472 |
+
x_list.append(self.transition1[i](x))
|
| 473 |
+
else:
|
| 474 |
+
x_list.append(x)
|
| 475 |
+
y_list = self.stage2(x_list)
|
| 476 |
+
|
| 477 |
+
x_list = []
|
| 478 |
+
for i in range(self.stage3_cfg['NUM_BRANCHES']):
|
| 479 |
+
if self.transition2[i] is not None:
|
| 480 |
+
x_list.append(self.transition2[i](y_list[-1]))
|
| 481 |
+
else:
|
| 482 |
+
x_list.append(y_list[i])
|
| 483 |
+
y_list = self.stage3(x_list)
|
| 484 |
+
|
| 485 |
+
x_list = []
|
| 486 |
+
for i in range(self.stage4_cfg['NUM_BRANCHES']):
|
| 487 |
+
if self.transition3[i] is not None:
|
| 488 |
+
x_list.append(self.transition3[i](y_list[-1]))
|
| 489 |
+
else:
|
| 490 |
+
x_list.append(y_list[i])
|
| 491 |
+
x = self.stage4(x_list)
|
| 492 |
+
|
| 493 |
+
if self.cfg.DOWNSAMPLE:
|
| 494 |
+
if self.cfg.USE_CONV:
|
| 495 |
+
# Downsampling with strided convolutions
|
| 496 |
+
x1 = self.downsample_stage_1(x[0])
|
| 497 |
+
x2 = self.downsample_stage_2(x[1])
|
| 498 |
+
x3 = self.downsample_stage_3(x[2])
|
| 499 |
+
x = torch.cat([x1, x2, x3, x[3]], 1)
|
| 500 |
+
else:
|
| 501 |
+
# Downsampling with interpolation
|
| 502 |
+
x0_h, x0_w = x[3].size(2), x[3].size(3)
|
| 503 |
+
x1 = F.interpolate(x[0], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
|
| 504 |
+
x2 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
|
| 505 |
+
x3 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
|
| 506 |
+
x = torch.cat([x1, x2, x3, x[3]], 1)
|
| 507 |
+
else:
|
| 508 |
+
if self.cfg.USE_CONV:
|
| 509 |
+
# Upsampling with interpolations + convolutions
|
| 510 |
+
x1 = self.upsample_stage_2(x[1])
|
| 511 |
+
x2 = self.upsample_stage_3(x[2])
|
| 512 |
+
x3 = self.upsample_stage_4(x[3])
|
| 513 |
+
x = torch.cat([x[0], x1, x2, x3], 1)
|
| 514 |
+
else:
|
| 515 |
+
# Upsampling with interpolation
|
| 516 |
+
x0_h, x0_w = x[0].size(2), x[0].size(3)
|
| 517 |
+
x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
|
| 518 |
+
x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
|
| 519 |
+
x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
|
| 520 |
+
x = torch.cat([x[0], x1, x2, x3], 1)
|
| 521 |
+
|
| 522 |
+
return x
|
| 523 |
+
|
| 524 |
+
def init_weights(self, pretrained=''):
|
| 525 |
+
logger.info('=> init weights from normal distribution')
|
| 526 |
+
for m in self.modules():
|
| 527 |
+
if isinstance(m, nn.Conv2d):
|
| 528 |
+
# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 529 |
+
nn.init.normal_(m.weight, std=0.001)
|
| 530 |
+
for name, _ in m.named_parameters():
|
| 531 |
+
if name in ['bias']:
|
| 532 |
+
nn.init.constant_(m.bias, 0)
|
| 533 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 534 |
+
nn.init.constant_(m.weight, 1)
|
| 535 |
+
nn.init.constant_(m.bias, 0)
|
| 536 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
| 537 |
+
nn.init.normal_(m.weight, std=0.001)
|
| 538 |
+
for name, _ in m.named_parameters():
|
| 539 |
+
if name in ['bias']:
|
| 540 |
+
nn.init.constant_(m.bias, 0)
|
| 541 |
+
|
| 542 |
+
if os.path.isfile(pretrained):
|
| 543 |
+
pretrained_state_dict = torch.load(pretrained)
|
| 544 |
+
logger.info('=> loading pretrained model {}'.format(pretrained))
|
| 545 |
+
|
| 546 |
+
need_init_state_dict = {}
|
| 547 |
+
for name, m in pretrained_state_dict.items():
|
| 548 |
+
if name.split('.')[0] in self.pretrained_layers \
|
| 549 |
+
or self.pretrained_layers[0] is '*':
|
| 550 |
+
need_init_state_dict[name] = m
|
| 551 |
+
self.load_state_dict(need_init_state_dict, strict=False)
|
| 552 |
+
elif pretrained:
|
| 553 |
+
logger.warning('IMPORTANT WARNING!! Please download pre-trained models if you are in TRAINING mode!')
|
| 554 |
+
# raise ValueError('{} is not exist!'.format(pretrained))
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
def get_pose_net(cfg, is_train):
|
| 558 |
+
model = PoseHighResolutionNet(cfg)
|
| 559 |
+
|
| 560 |
+
if is_train and cfg['MODEL']['INIT_WEIGHTS']:
|
| 561 |
+
model.init_weights(cfg['MODEL']['PRETRAINED'])
|
| 562 |
+
|
| 563 |
+
return model
|
| 564 |
+
|
| 565 |
+
|
| 566 |
+
def get_cfg_defaults(pretrained, width=32, downsample=False, use_conv=False):
|
| 567 |
+
# pose_multi_resoluton_net related params
|
| 568 |
+
HRNET = CN()
|
| 569 |
+
HRNET.PRETRAINED_LAYERS = [
|
| 570 |
+
'conv1', 'bn1', 'conv2', 'bn2', 'layer1', 'transition1',
|
| 571 |
+
'stage2', 'transition2', 'stage3', 'transition3', 'stage4',
|
| 572 |
+
]
|
| 573 |
+
HRNET.STEM_INPLANES = 64
|
| 574 |
+
HRNET.FINAL_CONV_KERNEL = 1
|
| 575 |
+
HRNET.STAGE2 = CN()
|
| 576 |
+
HRNET.STAGE2.NUM_MODULES = 1
|
| 577 |
+
HRNET.STAGE2.NUM_BRANCHES = 2
|
| 578 |
+
HRNET.STAGE2.NUM_BLOCKS = [4, 4]
|
| 579 |
+
HRNET.STAGE2.NUM_CHANNELS = [width, width*2]
|
| 580 |
+
HRNET.STAGE2.BLOCK = 'BASIC'
|
| 581 |
+
HRNET.STAGE2.FUSE_METHOD = 'SUM'
|
| 582 |
+
HRNET.STAGE3 = CN()
|
| 583 |
+
HRNET.STAGE3.NUM_MODULES = 4
|
| 584 |
+
HRNET.STAGE3.NUM_BRANCHES = 3
|
| 585 |
+
HRNET.STAGE3.NUM_BLOCKS = [4, 4, 4]
|
| 586 |
+
HRNET.STAGE3.NUM_CHANNELS = [width, width*2, width*4]
|
| 587 |
+
HRNET.STAGE3.BLOCK = 'BASIC'
|
| 588 |
+
HRNET.STAGE3.FUSE_METHOD = 'SUM'
|
| 589 |
+
HRNET.STAGE4 = CN()
|
| 590 |
+
HRNET.STAGE4.NUM_MODULES = 3
|
| 591 |
+
HRNET.STAGE4.NUM_BRANCHES = 4
|
| 592 |
+
HRNET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4]
|
| 593 |
+
HRNET.STAGE4.NUM_CHANNELS = [width, width*2, width*4, width*8]
|
| 594 |
+
HRNET.STAGE4.BLOCK = 'BASIC'
|
| 595 |
+
HRNET.STAGE4.FUSE_METHOD = 'SUM'
|
| 596 |
+
HRNET.DOWNSAMPLE = downsample
|
| 597 |
+
HRNET.USE_CONV = use_conv
|
| 598 |
+
|
| 599 |
+
cfg = CN()
|
| 600 |
+
cfg.MODEL = CN()
|
| 601 |
+
cfg.MODEL.INIT_WEIGHTS = True
|
| 602 |
+
cfg.MODEL.PRETRAINED = pretrained # 'data/pretrained_models/hrnet_w32-36af842e.pth'
|
| 603 |
+
cfg.MODEL.EXTRA = HRNET
|
| 604 |
+
cfg.MODEL.NUM_JOINTS = 24
|
| 605 |
+
return cfg
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def hrnet_w32(
|
| 609 |
+
pretrained=True,
|
| 610 |
+
pretrained_ckpt='data/weights/pose_hrnet_w32_256x192.pth',
|
| 611 |
+
downsample=False,
|
| 612 |
+
use_conv=False,
|
| 613 |
+
):
|
| 614 |
+
cfg = get_cfg_defaults(pretrained_ckpt, width=32, downsample=downsample, use_conv=use_conv)
|
| 615 |
+
return get_pose_net(cfg, is_train=True)
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
def hrnet_w48(
|
| 619 |
+
pretrained=True,
|
| 620 |
+
pretrained_ckpt='data/weights/pose_hrnet_w48_256x192.pth',
|
| 621 |
+
downsample=False,
|
| 622 |
+
use_conv=False,
|
| 623 |
+
):
|
| 624 |
+
cfg = get_cfg_defaults(pretrained_ckpt, width=48, downsample=downsample, use_conv=use_conv)
|
| 625 |
+
return get_pose_net(cfg, is_train=True)
|
utils/image_utils.py
ADDED
|
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This file contains functions that are used to perform data augmentation.
|
| 3 |
+
"""
|
| 4 |
+
import cv2
|
| 5 |
+
import torch
|
| 6 |
+
import json
|
| 7 |
+
from skimage.transform import rotate, resize
|
| 8 |
+
import numpy as np
|
| 9 |
+
import jpeg4py as jpeg
|
| 10 |
+
from trimesh.visual import color
|
| 11 |
+
|
| 12 |
+
# from ..core import constants
|
| 13 |
+
# from .vibe_image_utils import gen_trans_from_patch_cv
|
| 14 |
+
from .kp_utils import map_smpl_to_common, get_smpl_joint_names
|
| 15 |
+
|
| 16 |
+
def get_transform(center, scale, res, rot=0):
|
| 17 |
+
"""Generate transformation matrix."""
|
| 18 |
+
h = 200 * scale
|
| 19 |
+
t = np.zeros((3, 3))
|
| 20 |
+
t[0, 0] = float(res[1]) / h
|
| 21 |
+
t[1, 1] = float(res[0]) / h
|
| 22 |
+
t[0, 2] = res[1] * (-float(center[0]) / h + .5)
|
| 23 |
+
t[1, 2] = res[0] * (-float(center[1]) / h + .5)
|
| 24 |
+
t[2, 2] = 1
|
| 25 |
+
if not rot == 0:
|
| 26 |
+
rot = -rot # To match direction of rotation from cropping
|
| 27 |
+
rot_mat = np.zeros((3, 3))
|
| 28 |
+
rot_rad = rot * np.pi / 180
|
| 29 |
+
sn, cs = np.sin(rot_rad), np.cos(rot_rad)
|
| 30 |
+
rot_mat[0, :2] = [cs, -sn]
|
| 31 |
+
rot_mat[1, :2] = [sn, cs]
|
| 32 |
+
rot_mat[2, 2] = 1
|
| 33 |
+
# Need to rotate around center
|
| 34 |
+
t_mat = np.eye(3)
|
| 35 |
+
t_mat[0, 2] = -res[1] / 2
|
| 36 |
+
t_mat[1, 2] = -res[0] / 2
|
| 37 |
+
t_inv = t_mat.copy()
|
| 38 |
+
t_inv[:2, 2] *= -1
|
| 39 |
+
t = np.dot(t_inv, np.dot(rot_mat, np.dot(t_mat, t)))
|
| 40 |
+
return t
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def transform(pt, center, scale, res, invert=0, rot=0):
|
| 44 |
+
"""Transform pixel location to different reference."""
|
| 45 |
+
t = get_transform(center, scale, res, rot=rot)
|
| 46 |
+
if invert:
|
| 47 |
+
t = np.linalg.inv(t)
|
| 48 |
+
new_pt = np.array([pt[0] - 1, pt[1] - 1, 1.]).T
|
| 49 |
+
new_pt = np.dot(t, new_pt)
|
| 50 |
+
return new_pt[:2].astype(int) + 1
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def crop(img, center, scale, res, rot=0):
|
| 54 |
+
"""Crop image according to the supplied bounding box."""
|
| 55 |
+
# Upper left point
|
| 56 |
+
ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
|
| 57 |
+
# Bottom right point
|
| 58 |
+
br = np.array(transform([res[0] + 1,
|
| 59 |
+
res[1] + 1], center, scale, res, invert=1)) - 1
|
| 60 |
+
|
| 61 |
+
# Padding so that when rotated proper amount of context is included
|
| 62 |
+
pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2)
|
| 63 |
+
if not rot == 0:
|
| 64 |
+
ul -= pad
|
| 65 |
+
br += pad
|
| 66 |
+
|
| 67 |
+
new_shape = [br[1] - ul[1], br[0] - ul[0]]
|
| 68 |
+
if len(img.shape) > 2:
|
| 69 |
+
new_shape += [img.shape[2]]
|
| 70 |
+
new_img = np.zeros(new_shape)
|
| 71 |
+
|
| 72 |
+
# Range to fill new array
|
| 73 |
+
new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0]
|
| 74 |
+
new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1]
|
| 75 |
+
# Range to sample from original image
|
| 76 |
+
old_x = max(0, ul[0]), min(len(img[0]), br[0])
|
| 77 |
+
old_y = max(0, ul[1]), min(len(img), br[1])
|
| 78 |
+
new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1],
|
| 79 |
+
old_x[0]:old_x[1]]
|
| 80 |
+
|
| 81 |
+
if not rot == 0:
|
| 82 |
+
# Remove padding
|
| 83 |
+
|
| 84 |
+
new_img = rotate(new_img, rot) # scipy.misc.imrotate(new_img, rot)
|
| 85 |
+
new_img = new_img[pad:-pad, pad:-pad]
|
| 86 |
+
|
| 87 |
+
# resize image
|
| 88 |
+
new_img = resize(new_img, res) # scipy.misc.imresize(new_img, res)
|
| 89 |
+
return new_img
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def crop_cv2(img, center, scale, res, rot=0):
|
| 93 |
+
c_x, c_y = center
|
| 94 |
+
c_x, c_y = int(round(c_x)), int(round(c_y))
|
| 95 |
+
patch_width, patch_height = int(round(res[0])), int(round(res[1]))
|
| 96 |
+
bb_width = bb_height = int(round(scale * 200.))
|
| 97 |
+
|
| 98 |
+
trans = gen_trans_from_patch_cv(
|
| 99 |
+
c_x, c_y, bb_width, bb_height,
|
| 100 |
+
patch_width, patch_height,
|
| 101 |
+
scale=1.0, rot=rot, inv=False,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
crop_img = cv2.warpAffine(
|
| 105 |
+
img, trans, (int(patch_width), int(patch_height)),
|
| 106 |
+
flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_CONSTANT
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
return crop_img
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_random_crop_coords(height, width, crop_height, crop_width, h_start, w_start):
|
| 113 |
+
y1 = int((height - crop_height) * h_start)
|
| 114 |
+
y2 = y1 + crop_height
|
| 115 |
+
x1 = int((width - crop_width) * w_start)
|
| 116 |
+
x2 = x1 + crop_width
|
| 117 |
+
return x1, y1, x2, y2
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def random_crop(center, scale, crop_scale_factor, axis='all'):
|
| 121 |
+
'''
|
| 122 |
+
center: bbox center [x,y]
|
| 123 |
+
scale: bbox height / 200
|
| 124 |
+
crop_scale_factor: amount of cropping to be applied
|
| 125 |
+
axis: axis which cropping will be applied
|
| 126 |
+
"x": center the y axis and get random crops in x
|
| 127 |
+
"y": center the x axis and get random crops in y
|
| 128 |
+
"all": randomly crop from all locations
|
| 129 |
+
'''
|
| 130 |
+
orig_size = int(scale * 200.)
|
| 131 |
+
ul = (center - (orig_size / 2.)).astype(int)
|
| 132 |
+
|
| 133 |
+
crop_size = int(orig_size * crop_scale_factor)
|
| 134 |
+
|
| 135 |
+
if axis == 'all':
|
| 136 |
+
h_start = np.random.rand()
|
| 137 |
+
w_start = np.random.rand()
|
| 138 |
+
elif axis == 'x':
|
| 139 |
+
h_start = np.random.rand()
|
| 140 |
+
w_start = 0.5
|
| 141 |
+
elif axis == 'y':
|
| 142 |
+
h_start = 0.5
|
| 143 |
+
w_start = np.random.rand()
|
| 144 |
+
else:
|
| 145 |
+
raise ValueError(f'axis {axis} is undefined!')
|
| 146 |
+
|
| 147 |
+
x1, y1, x2, y2 = get_random_crop_coords(
|
| 148 |
+
height=orig_size,
|
| 149 |
+
width=orig_size,
|
| 150 |
+
crop_height=crop_size,
|
| 151 |
+
crop_width=crop_size,
|
| 152 |
+
h_start=h_start,
|
| 153 |
+
w_start=w_start,
|
| 154 |
+
)
|
| 155 |
+
scale = (y2 - y1) / 200.
|
| 156 |
+
center = ul + np.array([(y1 + y2) / 2, (x1 + x2) / 2])
|
| 157 |
+
return center, scale
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True):
|
| 161 |
+
"""'Undo' the image cropping/resizing.
|
| 162 |
+
This function is used when evaluating mask/part segmentation.
|
| 163 |
+
"""
|
| 164 |
+
res = img.shape[:2]
|
| 165 |
+
# Upper left point
|
| 166 |
+
ul = np.array(transform([1, 1], center, scale, res, invert=1)) - 1
|
| 167 |
+
# Bottom right point
|
| 168 |
+
br = np.array(transform([res[0] + 1, res[1] + 1], center, scale, res, invert=1)) - 1
|
| 169 |
+
# size of cropped image
|
| 170 |
+
crop_shape = [br[1] - ul[1], br[0] - ul[0]]
|
| 171 |
+
|
| 172 |
+
new_shape = [br[1] - ul[1], br[0] - ul[0]]
|
| 173 |
+
if len(img.shape) > 2:
|
| 174 |
+
new_shape += [img.shape[2]]
|
| 175 |
+
new_img = np.zeros(orig_shape, dtype=np.uint8)
|
| 176 |
+
# Range to fill new array
|
| 177 |
+
new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0]
|
| 178 |
+
new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1]
|
| 179 |
+
# Range to sample from original image
|
| 180 |
+
old_x = max(0, ul[0]), min(orig_shape[1], br[0])
|
| 181 |
+
old_y = max(0, ul[1]), min(orig_shape[0], br[1])
|
| 182 |
+
img = resize(img, crop_shape) #, interp='nearest') # scipy.misc.imresize(img, crop_shape, interp='nearest')
|
| 183 |
+
new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]]
|
| 184 |
+
return new_img
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def rot_aa(aa, rot):
|
| 188 |
+
"""Rotate axis angle parameters."""
|
| 189 |
+
# pose parameters
|
| 190 |
+
R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0],
|
| 191 |
+
[np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0],
|
| 192 |
+
[0, 0, 1]])
|
| 193 |
+
# find the rotation of the body in camera frame
|
| 194 |
+
per_rdg, _ = cv2.Rodrigues(aa)
|
| 195 |
+
# apply the global rotation to the global orientation
|
| 196 |
+
resrot, _ = cv2.Rodrigues(np.dot(R, per_rdg))
|
| 197 |
+
aa = (resrot.T)[0]
|
| 198 |
+
return aa
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def flip_img(img):
|
| 202 |
+
"""Flip rgb images or masks.
|
| 203 |
+
channels come last, e.g. (256,256,3).
|
| 204 |
+
"""
|
| 205 |
+
img = np.fliplr(img)
|
| 206 |
+
return img
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def flip_kp(kp):
|
| 210 |
+
"""Flip keypoints."""
|
| 211 |
+
if len(kp) == 24:
|
| 212 |
+
flipped_parts = constants.J24_FLIP_PERM
|
| 213 |
+
elif len(kp) == 49:
|
| 214 |
+
flipped_parts = constants.J49_FLIP_PERM
|
| 215 |
+
kp = kp[flipped_parts]
|
| 216 |
+
kp[:, 0] = - kp[:, 0]
|
| 217 |
+
return kp
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def flip_pose(pose):
|
| 221 |
+
"""Flip pose.
|
| 222 |
+
The flipping is based on SMPL parameters.
|
| 223 |
+
"""
|
| 224 |
+
flipped_parts = constants.SMPL_POSE_FLIP_PERM
|
| 225 |
+
pose = pose[flipped_parts]
|
| 226 |
+
# we also negate the second and the third dimension of the axis-angle
|
| 227 |
+
pose[1::3] = -pose[1::3]
|
| 228 |
+
pose[2::3] = -pose[2::3]
|
| 229 |
+
return pose
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def denormalize_images(images):
|
| 233 |
+
images = images * torch.tensor([0.229, 0.224, 0.225], device=images.device).reshape(1, 3, 1, 1)
|
| 234 |
+
images = images + torch.tensor([0.485, 0.456, 0.406], device=images.device).reshape(1, 3, 1, 1)
|
| 235 |
+
return images
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def read_img(img_fn):
|
| 239 |
+
# return pil_img.fromarray(
|
| 240 |
+
# cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB))
|
| 241 |
+
# with open(img_fn, 'rb') as f:
|
| 242 |
+
# img = pil_img.open(f).convert('RGB')
|
| 243 |
+
# return img
|
| 244 |
+
if img_fn.endswith('jpeg') or img_fn.endswith('jpg'):
|
| 245 |
+
try:
|
| 246 |
+
with open(img_fn, 'rb') as f:
|
| 247 |
+
img = np.array(jpeg.JPEG(f).decode())
|
| 248 |
+
except jpeg.JPEGRuntimeError:
|
| 249 |
+
# logger.warning('{} produced a JPEGRuntimeError', img_fn)
|
| 250 |
+
img = cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB)
|
| 251 |
+
else:
|
| 252 |
+
# elif img_fn.endswith('png') or img_fn.endswith('JPG') or img_fn.endswith(''):
|
| 253 |
+
img = cv2.cvtColor(cv2.imread(img_fn), cv2.COLOR_BGR2RGB)
|
| 254 |
+
return img.astype(np.float32)
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def generate_heatmaps_2d(joints, joints_vis, num_joints=24, heatmap_size=56, image_size=224, sigma=1.75):
|
| 258 |
+
'''
|
| 259 |
+
:param joints: [num_joints, 3]
|
| 260 |
+
:param joints_vis: [num_joints, 3]
|
| 261 |
+
:return: target, target_weight(1: visible, 0: invisible)
|
| 262 |
+
'''
|
| 263 |
+
target_weight = np.ones((num_joints, 1), dtype=np.float32)
|
| 264 |
+
target_weight[:, 0] = joints_vis[:, 0]
|
| 265 |
+
|
| 266 |
+
target = np.zeros((num_joints, heatmap_size, heatmap_size), dtype=np.float32)
|
| 267 |
+
|
| 268 |
+
tmp_size = sigma * 3
|
| 269 |
+
|
| 270 |
+
# denormalize joint into heatmap coordinates
|
| 271 |
+
joints = (joints + 1.) * (image_size / 2.)
|
| 272 |
+
|
| 273 |
+
for joint_id in range(num_joints):
|
| 274 |
+
feat_stride = image_size / heatmap_size
|
| 275 |
+
mu_x = int(joints[joint_id][0] / feat_stride + 0.5)
|
| 276 |
+
mu_y = int(joints[joint_id][1] / feat_stride + 0.5)
|
| 277 |
+
# Check that any part of the gaussian is in-bounds
|
| 278 |
+
ul = [int(mu_x - tmp_size), int(mu_y - tmp_size)]
|
| 279 |
+
br = [int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)]
|
| 280 |
+
if ul[0] >= heatmap_size or ul[1] >= heatmap_size \
|
| 281 |
+
or br[0] < 0 or br[1] < 0:
|
| 282 |
+
# If not, just return the image as is
|
| 283 |
+
target_weight[joint_id] = 0
|
| 284 |
+
continue
|
| 285 |
+
|
| 286 |
+
# # Generate gaussian
|
| 287 |
+
size = 2 * tmp_size + 1
|
| 288 |
+
x = np.arange(0, size, 1, np.float32)
|
| 289 |
+
y = x[:, np.newaxis]
|
| 290 |
+
x0 = y0 = size // 2
|
| 291 |
+
# The gaussian is not normalized, we want the center value to equal 1
|
| 292 |
+
g = np.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
|
| 293 |
+
|
| 294 |
+
# Usable gaussian range
|
| 295 |
+
g_x = max(0, -ul[0]), min(br[0], heatmap_size) - ul[0]
|
| 296 |
+
g_y = max(0, -ul[1]), min(br[1], heatmap_size) - ul[1]
|
| 297 |
+
# Image range
|
| 298 |
+
img_x = max(0, ul[0]), min(br[0], heatmap_size)
|
| 299 |
+
img_y = max(0, ul[1]), min(br[1], heatmap_size)
|
| 300 |
+
|
| 301 |
+
v = target_weight[joint_id]
|
| 302 |
+
if v > 0.5:
|
| 303 |
+
target[joint_id][img_y[0]:img_y[1], img_x[0]:img_x[1]] = \
|
| 304 |
+
g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
|
| 305 |
+
|
| 306 |
+
return target, target_weight
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def generate_part_labels(vertices, faces, cam_t, neural_renderer, body_part_texture, K, R, part_bins):
|
| 310 |
+
batch_size = vertices.shape[0]
|
| 311 |
+
|
| 312 |
+
body_parts, depth, mask = neural_renderer(
|
| 313 |
+
vertices,
|
| 314 |
+
faces.expand(batch_size, -1, -1),
|
| 315 |
+
textures=body_part_texture.expand(batch_size, -1, -1, -1, -1, -1),
|
| 316 |
+
K=K.expand(batch_size, -1, -1),
|
| 317 |
+
R=R.expand(batch_size, -1, -1),
|
| 318 |
+
t=cam_t.unsqueeze(1),
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
render_rgb = body_parts.clone()
|
| 322 |
+
|
| 323 |
+
body_parts = body_parts.permute(0, 2, 3, 1)
|
| 324 |
+
body_parts *= 255. # multiply it with 255 to make labels distant
|
| 325 |
+
body_parts, _ = body_parts.max(-1) # reduce to single channel
|
| 326 |
+
|
| 327 |
+
body_parts = torch.bucketize(body_parts.detach(), part_bins, right=True) # np.digitize(body_parts, bins, right=True)
|
| 328 |
+
|
| 329 |
+
# add 1 to make background label 0
|
| 330 |
+
body_parts = body_parts.long() + 1
|
| 331 |
+
body_parts = body_parts * mask.detach()
|
| 332 |
+
|
| 333 |
+
return body_parts.long(), render_rgb
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def generate_heatmaps_2d_batch(joints, num_joints=24, heatmap_size=56, image_size=224, sigma=1.75):
|
| 337 |
+
batch_size = joints.shape[0]
|
| 338 |
+
|
| 339 |
+
joints = joints.detach().cpu().numpy()
|
| 340 |
+
joints_vis = np.ones_like(joints)
|
| 341 |
+
|
| 342 |
+
heatmaps = []
|
| 343 |
+
heatmaps_vis = []
|
| 344 |
+
for i in range(batch_size):
|
| 345 |
+
hm, hm_vis = generate_heatmaps_2d(joints[i], joints_vis[i], num_joints, heatmap_size, image_size, sigma)
|
| 346 |
+
heatmaps.append(hm)
|
| 347 |
+
heatmaps_vis.append(hm_vis)
|
| 348 |
+
|
| 349 |
+
return torch.from_numpy(np.stack(heatmaps)).float().to('cuda'), \
|
| 350 |
+
torch.from_numpy(np.stack(heatmaps_vis)).float().to('cuda')
|
| 351 |
+
|
| 352 |
+
|
| 353 |
+
def get_body_part_texture(faces, model_type='smpl', non_parametric=False):
|
| 354 |
+
if model_type == 'smpl':
|
| 355 |
+
n_vertices = 6890
|
| 356 |
+
segmentation_path = 'data/smpl_vert_segmentation.json'
|
| 357 |
+
if model_type == 'smplx':
|
| 358 |
+
n_vertices = 10475
|
| 359 |
+
segmentation_path = 'data/smplx_vert_segmentation.json'
|
| 360 |
+
|
| 361 |
+
with open(segmentation_path, 'rb') as f:
|
| 362 |
+
part_segmentation = json.load(f)
|
| 363 |
+
|
| 364 |
+
# map all vertex ids to the joint ids
|
| 365 |
+
joint_names = get_smpl_joint_names()
|
| 366 |
+
smplx_extra_joint_names = ['leftEye', 'eyeballs', 'rightEye']
|
| 367 |
+
body_vert_idx = np.zeros((n_vertices), dtype=np.int32) - 1 # -1 for missing label
|
| 368 |
+
for i, (k, v) in enumerate(part_segmentation.items()):
|
| 369 |
+
if k in smplx_extra_joint_names and model_type == 'smplx':
|
| 370 |
+
k = 'head' # map all extra smplx face joints to head
|
| 371 |
+
body_joint_idx = joint_names.index(k)
|
| 372 |
+
body_vert_idx[v] = body_joint_idx
|
| 373 |
+
|
| 374 |
+
# pare implementation
|
| 375 |
+
# import joblib
|
| 376 |
+
# part_segmentation = joblib.load('data/smpl_partSegmentation_mapping.pkl')
|
| 377 |
+
# body_vert_idx = part_segmentation['smpl_index']
|
| 378 |
+
|
| 379 |
+
n_parts = 24.
|
| 380 |
+
|
| 381 |
+
if non_parametric:
|
| 382 |
+
# reduce the number of body_parts to 14
|
| 383 |
+
# by mapping some joints to others
|
| 384 |
+
n_parts = 14.
|
| 385 |
+
joint_mapping = map_smpl_to_common()
|
| 386 |
+
|
| 387 |
+
for jm in joint_mapping:
|
| 388 |
+
for j in jm[0]:
|
| 389 |
+
body_vert_idx[body_vert_idx==j] = jm[1]
|
| 390 |
+
|
| 391 |
+
vertex_colors = np.ones((n_vertices, 4))
|
| 392 |
+
vertex_colors[:, :3] = body_vert_idx[..., None]
|
| 393 |
+
|
| 394 |
+
vertex_colors = color.to_rgba(vertex_colors)
|
| 395 |
+
vertex_colors = vertex_colors[:, :3]/255.
|
| 396 |
+
|
| 397 |
+
face_colors = vertex_colors[faces].min(axis=1)
|
| 398 |
+
texture = np.zeros((1, faces.shape[0], 1, 1, 3), dtype=np.float32)
|
| 399 |
+
# texture[0, :, 0, 0, :] = face_colors[:, :3] / n_parts
|
| 400 |
+
texture[0, :, 0, 0, :] = face_colors[:, :3]
|
| 401 |
+
|
| 402 |
+
vertex_colors = torch.from_numpy(vertex_colors).float()
|
| 403 |
+
texture = torch.from_numpy(texture).float()
|
| 404 |
+
return vertex_colors, texture
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def get_default_camera(focal_length, img_h, img_w, is_cam_batch=False):
|
| 408 |
+
if not is_cam_batch:
|
| 409 |
+
K = torch.eye(3)
|
| 410 |
+
K[0, 0] = focal_length
|
| 411 |
+
K[1, 1] = focal_length
|
| 412 |
+
K[2, 2] = 1
|
| 413 |
+
K[0, 2] = img_w / 2.
|
| 414 |
+
K[1, 2] = img_h / 2.
|
| 415 |
+
K = K[None, :, :]
|
| 416 |
+
R = torch.eye(3)[None, :, :]
|
| 417 |
+
else:
|
| 418 |
+
bs = focal_length.shape[0]
|
| 419 |
+
K = torch.eye(3)[None, :, :].repeat(bs, 1, 1)
|
| 420 |
+
K[:, 0, 0] = focal_length[:, 0]
|
| 421 |
+
K[:, 1, 1] = focal_length[:, 1]
|
| 422 |
+
K[:, 2, 2] = 1
|
| 423 |
+
K[:, 0, 2] = img_w / 2.
|
| 424 |
+
K[:, 1, 2] = img_h / 2.
|
| 425 |
+
R = torch.eye(3)[None, :, :].repeat(bs, 1, 1)
|
| 426 |
+
return K, R
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
def read_exif_data(img_fname):
|
| 430 |
+
import PIL.Image
|
| 431 |
+
import PIL.ExifTags
|
| 432 |
+
|
| 433 |
+
img = PIL.Image.open(img_fname)
|
| 434 |
+
exif_data = img._getexif()
|
| 435 |
+
|
| 436 |
+
if exif_data == None:
|
| 437 |
+
return None
|
| 438 |
+
|
| 439 |
+
exif = {
|
| 440 |
+
PIL.ExifTags.TAGS[k]: v
|
| 441 |
+
for k, v in exif_data.items()
|
| 442 |
+
if k in PIL.ExifTags.TAGS
|
| 443 |
+
}
|
| 444 |
+
return exif
|
utils/kp_utils.py
ADDED
|
@@ -0,0 +1,1114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def keypoint_hflip(kp, img_width):
|
| 5 |
+
# Flip a keypoint horizontally around the y-axis
|
| 6 |
+
# kp N,2
|
| 7 |
+
if len(kp.shape) == 2:
|
| 8 |
+
kp[:,0] = (img_width - 1.) - kp[:,0]
|
| 9 |
+
elif len(kp.shape) == 3:
|
| 10 |
+
kp[:, :, 0] = (img_width - 1.) - kp[:, :, 0]
|
| 11 |
+
return kp
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def convert_kps(joints2d, src, dst):
|
| 15 |
+
src_names = eval(f'get_{src}_joint_names')()
|
| 16 |
+
dst_names = eval(f'get_{dst}_joint_names')()
|
| 17 |
+
|
| 18 |
+
out_joints2d = np.zeros((joints2d.shape[0], len(dst_names), joints2d.shape[-1]))
|
| 19 |
+
|
| 20 |
+
for idx, jn in enumerate(dst_names):
|
| 21 |
+
if jn in src_names:
|
| 22 |
+
out_joints2d[:, idx] = joints2d[:, src_names.index(jn)]
|
| 23 |
+
|
| 24 |
+
return out_joints2d
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_perm_idxs(src, dst):
|
| 28 |
+
src_names = eval(f'get_{src}_joint_names')()
|
| 29 |
+
dst_names = eval(f'get_{dst}_joint_names')()
|
| 30 |
+
idxs = [src_names.index(h) for h in dst_names if h in src_names]
|
| 31 |
+
return idxs
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_mpii3d_test_joint_names():
|
| 35 |
+
return [
|
| 36 |
+
'headtop', # 'head_top',
|
| 37 |
+
'neck',
|
| 38 |
+
'rshoulder',# 'right_shoulder',
|
| 39 |
+
'relbow',# 'right_elbow',
|
| 40 |
+
'rwrist',# 'right_wrist',
|
| 41 |
+
'lshoulder',# 'left_shoulder',
|
| 42 |
+
'lelbow', # 'left_elbow',
|
| 43 |
+
'lwrist', # 'left_wrist',
|
| 44 |
+
'rhip', # 'right_hip',
|
| 45 |
+
'rknee', # 'right_knee',
|
| 46 |
+
'rankle',# 'right_ankle',
|
| 47 |
+
'lhip',# 'left_hip',
|
| 48 |
+
'lknee',# 'left_knee',
|
| 49 |
+
'lankle',# 'left_ankle'
|
| 50 |
+
'hip',# 'pelvis',
|
| 51 |
+
'Spine (H36M)',# 'spine',
|
| 52 |
+
'Head (H36M)',# 'head'
|
| 53 |
+
]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_mpii3d_joint_names():
|
| 57 |
+
return [
|
| 58 |
+
'spine3', # 0,
|
| 59 |
+
'spine4', # 1,
|
| 60 |
+
'spine2', # 2,
|
| 61 |
+
'Spine (H36M)', #'spine', # 3,
|
| 62 |
+
'hip', # 'pelvis', # 4,
|
| 63 |
+
'neck', # 5,
|
| 64 |
+
'Head (H36M)', # 'head', # 6,
|
| 65 |
+
"headtop", # 'head_top', # 7,
|
| 66 |
+
'left_clavicle', # 8,
|
| 67 |
+
"lshoulder", # 'left_shoulder', # 9,
|
| 68 |
+
"lelbow", # 'left_elbow',# 10,
|
| 69 |
+
"lwrist", # 'left_wrist',# 11,
|
| 70 |
+
'left_hand',# 12,
|
| 71 |
+
'right_clavicle',# 13,
|
| 72 |
+
'rshoulder',# 'right_shoulder',# 14,
|
| 73 |
+
'relbow',# 'right_elbow',# 15,
|
| 74 |
+
'rwrist',# 'right_wrist',# 16,
|
| 75 |
+
'right_hand',# 17,
|
| 76 |
+
'lhip', # left_hip',# 18,
|
| 77 |
+
'lknee', # 'left_knee',# 19,
|
| 78 |
+
'lankle', #left ankle # 20
|
| 79 |
+
'left_foot', # 21
|
| 80 |
+
'left_toe', # 22
|
| 81 |
+
"rhip", # 'right_hip',# 23
|
| 82 |
+
"rknee", # 'right_knee',# 24
|
| 83 |
+
"rankle", #'right_ankle', # 25
|
| 84 |
+
'right_foot',# 26
|
| 85 |
+
'right_toe' # 27
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
# def get_insta_joint_names():
|
| 90 |
+
# return [
|
| 91 |
+
# 'rheel' , # 0
|
| 92 |
+
# 'rknee' , # 1
|
| 93 |
+
# 'rhip' , # 2
|
| 94 |
+
# 'lhip' , # 3
|
| 95 |
+
# 'lknee' , # 4
|
| 96 |
+
# 'lheel' , # 5
|
| 97 |
+
# 'rwrist' , # 6
|
| 98 |
+
# 'relbow' , # 7
|
| 99 |
+
# 'rshoulder' , # 8
|
| 100 |
+
# 'lshoulder' , # 9
|
| 101 |
+
# 'lelbow' , # 10
|
| 102 |
+
# 'lwrist' , # 11
|
| 103 |
+
# 'neck' , # 12
|
| 104 |
+
# 'headtop' , # 13
|
| 105 |
+
# 'nose' , # 14
|
| 106 |
+
# 'leye' , # 15
|
| 107 |
+
# 'reye' , # 16
|
| 108 |
+
# 'lear' , # 17
|
| 109 |
+
# 'rear' , # 18
|
| 110 |
+
# 'lbigtoe' , # 19
|
| 111 |
+
# 'rbigtoe' , # 20
|
| 112 |
+
# 'lsmalltoe' , # 21
|
| 113 |
+
# 'rsmalltoe' , # 22
|
| 114 |
+
# 'lankle' , # 23
|
| 115 |
+
# 'rankle' , # 24
|
| 116 |
+
# ]
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def get_insta_joint_names():
|
| 120 |
+
return [
|
| 121 |
+
'OP RHeel',
|
| 122 |
+
'OP RKnee',
|
| 123 |
+
'OP RHip',
|
| 124 |
+
'OP LHip',
|
| 125 |
+
'OP LKnee',
|
| 126 |
+
'OP LHeel',
|
| 127 |
+
'OP RWrist',
|
| 128 |
+
'OP RElbow',
|
| 129 |
+
'OP RShoulder',
|
| 130 |
+
'OP LShoulder',
|
| 131 |
+
'OP LElbow',
|
| 132 |
+
'OP LWrist',
|
| 133 |
+
'OP Neck',
|
| 134 |
+
'headtop',
|
| 135 |
+
'OP Nose',
|
| 136 |
+
'OP LEye',
|
| 137 |
+
'OP REye',
|
| 138 |
+
'OP LEar',
|
| 139 |
+
'OP REar',
|
| 140 |
+
'OP LBigToe',
|
| 141 |
+
'OP RBigToe',
|
| 142 |
+
'OP LSmallToe',
|
| 143 |
+
'OP RSmallToe',
|
| 144 |
+
'OP LAnkle',
|
| 145 |
+
'OP RAnkle',
|
| 146 |
+
]
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def get_mmpose_joint_names():
|
| 150 |
+
# this naming is for the first 23 joints of MMPose
|
| 151 |
+
# does not include hands and face
|
| 152 |
+
return [
|
| 153 |
+
'OP Nose', # 1
|
| 154 |
+
'OP LEye', # 2
|
| 155 |
+
'OP REye', # 3
|
| 156 |
+
'OP LEar', # 4
|
| 157 |
+
'OP REar', # 5
|
| 158 |
+
'OP LShoulder', # 6
|
| 159 |
+
'OP RShoulder', # 7
|
| 160 |
+
'OP LElbow', # 8
|
| 161 |
+
'OP RElbow', # 9
|
| 162 |
+
'OP LWrist', # 10
|
| 163 |
+
'OP RWrist', # 11
|
| 164 |
+
'OP LHip', # 12
|
| 165 |
+
'OP RHip', # 13
|
| 166 |
+
'OP LKnee', # 14
|
| 167 |
+
'OP RKnee', # 15
|
| 168 |
+
'OP LAnkle', # 16
|
| 169 |
+
'OP RAnkle', # 17
|
| 170 |
+
'OP LBigToe', # 18
|
| 171 |
+
'OP LSmallToe', # 19
|
| 172 |
+
'OP LHeel', # 20
|
| 173 |
+
'OP RBigToe', # 21
|
| 174 |
+
'OP RSmallToe', # 22
|
| 175 |
+
'OP RHeel', # 23
|
| 176 |
+
]
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def get_insta_skeleton():
|
| 180 |
+
return np.array(
|
| 181 |
+
[
|
| 182 |
+
[0 , 1],
|
| 183 |
+
[1 , 2],
|
| 184 |
+
[2 , 3],
|
| 185 |
+
[3 , 4],
|
| 186 |
+
[4 , 5],
|
| 187 |
+
[6 , 7],
|
| 188 |
+
[7 , 8],
|
| 189 |
+
[8 , 9],
|
| 190 |
+
[9 ,10],
|
| 191 |
+
[2 , 8],
|
| 192 |
+
[3 , 9],
|
| 193 |
+
[10,11],
|
| 194 |
+
[8 ,12],
|
| 195 |
+
[9 ,12],
|
| 196 |
+
[12,13],
|
| 197 |
+
[12,14],
|
| 198 |
+
[14,15],
|
| 199 |
+
[14,16],
|
| 200 |
+
[15,17],
|
| 201 |
+
[16,18],
|
| 202 |
+
[0 ,20],
|
| 203 |
+
[20,22],
|
| 204 |
+
[5 ,19],
|
| 205 |
+
[19,21],
|
| 206 |
+
[5 ,23],
|
| 207 |
+
[0 ,24],
|
| 208 |
+
])
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def get_staf_skeleton():
|
| 212 |
+
return np.array(
|
| 213 |
+
[
|
| 214 |
+
[0, 1],
|
| 215 |
+
[1, 2],
|
| 216 |
+
[2, 3],
|
| 217 |
+
[3, 4],
|
| 218 |
+
[1, 5],
|
| 219 |
+
[5, 6],
|
| 220 |
+
[6, 7],
|
| 221 |
+
[1, 8],
|
| 222 |
+
[8, 9],
|
| 223 |
+
[9, 10],
|
| 224 |
+
[10, 11],
|
| 225 |
+
[8, 12],
|
| 226 |
+
[12, 13],
|
| 227 |
+
[13, 14],
|
| 228 |
+
[0, 15],
|
| 229 |
+
[0, 16],
|
| 230 |
+
[15, 17],
|
| 231 |
+
[16, 18],
|
| 232 |
+
[2, 9],
|
| 233 |
+
[5, 12],
|
| 234 |
+
[1, 19],
|
| 235 |
+
[20, 19],
|
| 236 |
+
]
|
| 237 |
+
)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def get_staf_joint_names():
|
| 241 |
+
return [
|
| 242 |
+
'OP Nose', # 0,
|
| 243 |
+
'OP Neck', # 1,
|
| 244 |
+
'OP RShoulder', # 2,
|
| 245 |
+
'OP RElbow', # 3,
|
| 246 |
+
'OP RWrist', # 4,
|
| 247 |
+
'OP LShoulder', # 5,
|
| 248 |
+
'OP LElbow', # 6,
|
| 249 |
+
'OP LWrist', # 7,
|
| 250 |
+
'OP MidHip', # 8,
|
| 251 |
+
'OP RHip', # 9,
|
| 252 |
+
'OP RKnee', # 10,
|
| 253 |
+
'OP RAnkle', # 11,
|
| 254 |
+
'OP LHip', # 12,
|
| 255 |
+
'OP LKnee', # 13,
|
| 256 |
+
'OP LAnkle', # 14,
|
| 257 |
+
'OP REye', # 15,
|
| 258 |
+
'OP LEye', # 16,
|
| 259 |
+
'OP REar', # 17,
|
| 260 |
+
'OP LEar', # 18,
|
| 261 |
+
'Neck (LSP)', # 19,
|
| 262 |
+
'Top of Head (LSP)', # 20,
|
| 263 |
+
]
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def get_spin_op_joint_names():
|
| 267 |
+
return [
|
| 268 |
+
'OP Nose', # 0
|
| 269 |
+
'OP Neck', # 1
|
| 270 |
+
'OP RShoulder', # 2
|
| 271 |
+
'OP RElbow', # 3
|
| 272 |
+
'OP RWrist', # 4
|
| 273 |
+
'OP LShoulder', # 5
|
| 274 |
+
'OP LElbow', # 6
|
| 275 |
+
'OP LWrist', # 7
|
| 276 |
+
'OP MidHip', # 8
|
| 277 |
+
'OP RHip', # 9
|
| 278 |
+
'OP RKnee', # 10
|
| 279 |
+
'OP RAnkle', # 11
|
| 280 |
+
'OP LHip', # 12
|
| 281 |
+
'OP LKnee', # 13
|
| 282 |
+
'OP LAnkle', # 14
|
| 283 |
+
'OP REye', # 15
|
| 284 |
+
'OP LEye', # 16
|
| 285 |
+
'OP REar', # 17
|
| 286 |
+
'OP LEar', # 18
|
| 287 |
+
'OP LBigToe', # 19
|
| 288 |
+
'OP LSmallToe', # 20
|
| 289 |
+
'OP LHeel', # 21
|
| 290 |
+
'OP RBigToe', # 22
|
| 291 |
+
'OP RSmallToe', # 23
|
| 292 |
+
'OP RHeel', # 24
|
| 293 |
+
]
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def get_openpose_joint_names():
|
| 297 |
+
return [
|
| 298 |
+
'OP Nose', # 0
|
| 299 |
+
'OP Neck', # 1
|
| 300 |
+
'OP RShoulder', # 2
|
| 301 |
+
'OP RElbow', # 3
|
| 302 |
+
'OP RWrist', # 4
|
| 303 |
+
'OP LShoulder', # 5
|
| 304 |
+
'OP LElbow', # 6
|
| 305 |
+
'OP LWrist', # 7
|
| 306 |
+
'OP MidHip', # 8
|
| 307 |
+
'OP RHip', # 9
|
| 308 |
+
'OP RKnee', # 10
|
| 309 |
+
'OP RAnkle', # 11
|
| 310 |
+
'OP LHip', # 12
|
| 311 |
+
'OP LKnee', # 13
|
| 312 |
+
'OP LAnkle', # 14
|
| 313 |
+
'OP REye', # 15
|
| 314 |
+
'OP LEye', # 16
|
| 315 |
+
'OP REar', # 17
|
| 316 |
+
'OP LEar', # 18
|
| 317 |
+
'OP LBigToe', # 19
|
| 318 |
+
'OP LSmallToe', # 20
|
| 319 |
+
'OP LHeel', # 21
|
| 320 |
+
'OP RBigToe', # 22
|
| 321 |
+
'OP RSmallToe', # 23
|
| 322 |
+
'OP RHeel', # 24
|
| 323 |
+
]
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def get_spin_joint_names():
|
| 327 |
+
return [
|
| 328 |
+
'OP Nose', # 0
|
| 329 |
+
'OP Neck', # 1
|
| 330 |
+
'OP RShoulder', # 2
|
| 331 |
+
'OP RElbow', # 3
|
| 332 |
+
'OP RWrist', # 4
|
| 333 |
+
'OP LShoulder', # 5
|
| 334 |
+
'OP LElbow', # 6
|
| 335 |
+
'OP LWrist', # 7
|
| 336 |
+
'OP MidHip', # 8
|
| 337 |
+
'OP RHip', # 9
|
| 338 |
+
'OP RKnee', # 10
|
| 339 |
+
'OP RAnkle', # 11
|
| 340 |
+
'OP LHip', # 12
|
| 341 |
+
'OP LKnee', # 13
|
| 342 |
+
'OP LAnkle', # 14
|
| 343 |
+
'OP REye', # 15
|
| 344 |
+
'OP LEye', # 16
|
| 345 |
+
'OP REar', # 17
|
| 346 |
+
'OP LEar', # 18
|
| 347 |
+
'OP LBigToe', # 19
|
| 348 |
+
'OP LSmallToe', # 20
|
| 349 |
+
'OP LHeel', # 21
|
| 350 |
+
'OP RBigToe', # 22
|
| 351 |
+
'OP RSmallToe', # 23
|
| 352 |
+
'OP RHeel', # 24
|
| 353 |
+
'rankle', # 25
|
| 354 |
+
'rknee', # 26
|
| 355 |
+
'rhip', # 27
|
| 356 |
+
'lhip', # 28
|
| 357 |
+
'lknee', # 29
|
| 358 |
+
'lankle', # 30
|
| 359 |
+
'rwrist', # 31
|
| 360 |
+
'relbow', # 32
|
| 361 |
+
'rshoulder', # 33
|
| 362 |
+
'lshoulder', # 34
|
| 363 |
+
'lelbow', # 35
|
| 364 |
+
'lwrist', # 36
|
| 365 |
+
'neck', # 37
|
| 366 |
+
'headtop', # 38
|
| 367 |
+
'hip', # 39 'Pelvis (MPII)', # 39
|
| 368 |
+
'thorax', # 40 'Thorax (MPII)', # 40
|
| 369 |
+
'Spine (H36M)', # 41
|
| 370 |
+
'Jaw (H36M)', # 42
|
| 371 |
+
'Head (H36M)', # 43
|
| 372 |
+
'nose', # 44
|
| 373 |
+
'leye', # 45 'Left Eye', # 45
|
| 374 |
+
'reye', # 46 'Right Eye', # 46
|
| 375 |
+
'lear', # 47 'Left Ear', # 47
|
| 376 |
+
'rear', # 48 'Right Ear', # 48
|
| 377 |
+
]
|
| 378 |
+
|
| 379 |
+
def get_muco3dhp_joint_names():
|
| 380 |
+
return [
|
| 381 |
+
'headtop',
|
| 382 |
+
'thorax',
|
| 383 |
+
'rshoulder',
|
| 384 |
+
'relbow',
|
| 385 |
+
'rwrist',
|
| 386 |
+
'lshoulder',
|
| 387 |
+
'lelbow',
|
| 388 |
+
'lwrist',
|
| 389 |
+
'rhip',
|
| 390 |
+
'rknee',
|
| 391 |
+
'rankle',
|
| 392 |
+
'lhip',
|
| 393 |
+
'lknee',
|
| 394 |
+
'lankle',
|
| 395 |
+
'hip',
|
| 396 |
+
'Spine (H36M)',
|
| 397 |
+
'Head (H36M)',
|
| 398 |
+
'R_Hand',
|
| 399 |
+
'L_Hand',
|
| 400 |
+
'R_Toe',
|
| 401 |
+
'L_Toe'
|
| 402 |
+
]
|
| 403 |
+
|
| 404 |
+
def get_h36m_joint_names():
|
| 405 |
+
return [
|
| 406 |
+
'hip', # 0
|
| 407 |
+
'lhip', # 1
|
| 408 |
+
'lknee', # 2
|
| 409 |
+
'lankle', # 3
|
| 410 |
+
'rhip', # 4
|
| 411 |
+
'rknee', # 5
|
| 412 |
+
'rankle', # 6
|
| 413 |
+
'Spine (H36M)', # 7
|
| 414 |
+
'neck', # 8
|
| 415 |
+
'Head (H36M)', # 9
|
| 416 |
+
'headtop', # 10
|
| 417 |
+
'lshoulder', # 11
|
| 418 |
+
'lelbow', # 12
|
| 419 |
+
'lwrist', # 13
|
| 420 |
+
'rshoulder', # 14
|
| 421 |
+
'relbow', # 15
|
| 422 |
+
'rwrist', # 16
|
| 423 |
+
]
|
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def get_spin_skeleton():
|
| 427 |
+
return np.array(
|
| 428 |
+
[
|
| 429 |
+
[0 , 1],
|
| 430 |
+
[1 , 2],
|
| 431 |
+
[2 , 3],
|
| 432 |
+
[3 , 4],
|
| 433 |
+
[1 , 5],
|
| 434 |
+
[5 , 6],
|
| 435 |
+
[6 , 7],
|
| 436 |
+
[1 , 8],
|
| 437 |
+
[8 , 9],
|
| 438 |
+
[9 ,10],
|
| 439 |
+
[10,11],
|
| 440 |
+
[8 ,12],
|
| 441 |
+
[12,13],
|
| 442 |
+
[13,14],
|
| 443 |
+
[0 ,15],
|
| 444 |
+
[0 ,16],
|
| 445 |
+
[15,17],
|
| 446 |
+
[16,18],
|
| 447 |
+
[21,19],
|
| 448 |
+
[19,20],
|
| 449 |
+
[14,21],
|
| 450 |
+
[11,24],
|
| 451 |
+
[24,22],
|
| 452 |
+
[22,23],
|
| 453 |
+
[0 ,38],
|
| 454 |
+
]
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def get_openpose_skeleton():
|
| 459 |
+
return np.array(
|
| 460 |
+
[
|
| 461 |
+
[0 , 1],
|
| 462 |
+
[1 , 2],
|
| 463 |
+
[2 , 3],
|
| 464 |
+
[3 , 4],
|
| 465 |
+
[1 , 5],
|
| 466 |
+
[5 , 6],
|
| 467 |
+
[6 , 7],
|
| 468 |
+
[1 , 8],
|
| 469 |
+
[8 , 9],
|
| 470 |
+
[9 ,10],
|
| 471 |
+
[10,11],
|
| 472 |
+
[8 ,12],
|
| 473 |
+
[12,13],
|
| 474 |
+
[13,14],
|
| 475 |
+
[0 ,15],
|
| 476 |
+
[0 ,16],
|
| 477 |
+
[15,17],
|
| 478 |
+
[16,18],
|
| 479 |
+
[21,19],
|
| 480 |
+
[19,20],
|
| 481 |
+
[14,21],
|
| 482 |
+
[11,24],
|
| 483 |
+
[24,22],
|
| 484 |
+
[22,23],
|
| 485 |
+
]
|
| 486 |
+
)
|
| 487 |
+
|
| 488 |
+
|
| 489 |
+
def get_posetrack_joint_names():
|
| 490 |
+
return [
|
| 491 |
+
"nose",
|
| 492 |
+
"neck",
|
| 493 |
+
"headtop",
|
| 494 |
+
"lear",
|
| 495 |
+
"rear",
|
| 496 |
+
"lshoulder",
|
| 497 |
+
"rshoulder",
|
| 498 |
+
"lelbow",
|
| 499 |
+
"relbow",
|
| 500 |
+
"lwrist",
|
| 501 |
+
"rwrist",
|
| 502 |
+
"lhip",
|
| 503 |
+
"rhip",
|
| 504 |
+
"lknee",
|
| 505 |
+
"rknee",
|
| 506 |
+
"lankle",
|
| 507 |
+
"rankle"
|
| 508 |
+
]
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def get_posetrack_original_kp_names():
|
| 512 |
+
return [
|
| 513 |
+
'nose',
|
| 514 |
+
'head_bottom',
|
| 515 |
+
'head_top',
|
| 516 |
+
'left_ear',
|
| 517 |
+
'right_ear',
|
| 518 |
+
'left_shoulder',
|
| 519 |
+
'right_shoulder',
|
| 520 |
+
'left_elbow',
|
| 521 |
+
'right_elbow',
|
| 522 |
+
'left_wrist',
|
| 523 |
+
'right_wrist',
|
| 524 |
+
'left_hip',
|
| 525 |
+
'right_hip',
|
| 526 |
+
'left_knee',
|
| 527 |
+
'right_knee',
|
| 528 |
+
'left_ankle',
|
| 529 |
+
'right_ankle'
|
| 530 |
+
]
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def get_pennaction_joint_names():
|
| 534 |
+
return [
|
| 535 |
+
"headtop", # 0
|
| 536 |
+
"lshoulder", # 1
|
| 537 |
+
"rshoulder", # 2
|
| 538 |
+
"lelbow", # 3
|
| 539 |
+
"relbow", # 4
|
| 540 |
+
"lwrist", # 5
|
| 541 |
+
"rwrist", # 6
|
| 542 |
+
"lhip" , # 7
|
| 543 |
+
"rhip" , # 8
|
| 544 |
+
"lknee", # 9
|
| 545 |
+
"rknee" , # 10
|
| 546 |
+
"lankle", # 11
|
| 547 |
+
"rankle" # 12
|
| 548 |
+
]
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def get_common_joint_names():
|
| 552 |
+
return [
|
| 553 |
+
"rankle", # 0 "lankle", # 0
|
| 554 |
+
"rknee", # 1 "lknee", # 1
|
| 555 |
+
"rhip", # 2 "lhip", # 2
|
| 556 |
+
"lhip", # 3 "rhip", # 3
|
| 557 |
+
"lknee", # 4 "rknee", # 4
|
| 558 |
+
"lankle", # 5 "rankle", # 5
|
| 559 |
+
"rwrist", # 6 "lwrist", # 6
|
| 560 |
+
"relbow", # 7 "lelbow", # 7
|
| 561 |
+
"rshoulder", # 8 "lshoulder", # 8
|
| 562 |
+
"lshoulder", # 9 "rshoulder", # 9
|
| 563 |
+
"lelbow", # 10 "relbow", # 10
|
| 564 |
+
"lwrist", # 11 "rwrist", # 11
|
| 565 |
+
"neck", # 12 "neck", # 12
|
| 566 |
+
"headtop", # 13 "headtop", # 13
|
| 567 |
+
]
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def get_common_paper_joint_names():
|
| 571 |
+
return [
|
| 572 |
+
"Right Ankle", # 0 "lankle", # 0
|
| 573 |
+
"Right Knee", # 1 "lknee", # 1
|
| 574 |
+
"Right Hip", # 2 "lhip", # 2
|
| 575 |
+
"Left Hip", # 3 "rhip", # 3
|
| 576 |
+
"Left Knee", # 4 "rknee", # 4
|
| 577 |
+
"Left Ankle", # 5 "rankle", # 5
|
| 578 |
+
"Right Wrist", # 6 "lwrist", # 6
|
| 579 |
+
"Right Elbow", # 7 "lelbow", # 7
|
| 580 |
+
"Right Shoulder", # 8 "lshoulder", # 8
|
| 581 |
+
"Left Shoulder", # 9 "rshoulder", # 9
|
| 582 |
+
"Left Elbow", # 10 "relbow", # 10
|
| 583 |
+
"Left Wrist", # 11 "rwrist", # 11
|
| 584 |
+
"Neck", # 12 "neck", # 12
|
| 585 |
+
"Head", # 13 "headtop", # 13
|
| 586 |
+
]
|
| 587 |
+
|
| 588 |
+
|
| 589 |
+
def get_common_skeleton():
|
| 590 |
+
return np.array(
|
| 591 |
+
[
|
| 592 |
+
[ 0, 1 ],
|
| 593 |
+
[ 1, 2 ],
|
| 594 |
+
[ 3, 4 ],
|
| 595 |
+
[ 4, 5 ],
|
| 596 |
+
[ 6, 7 ],
|
| 597 |
+
[ 7, 8 ],
|
| 598 |
+
[ 8, 2 ],
|
| 599 |
+
[ 8, 9 ],
|
| 600 |
+
[ 9, 3 ],
|
| 601 |
+
[ 2, 3 ],
|
| 602 |
+
[ 8, 12],
|
| 603 |
+
[ 9, 10],
|
| 604 |
+
[12, 9 ],
|
| 605 |
+
[10, 11],
|
| 606 |
+
[12, 13],
|
| 607 |
+
]
|
| 608 |
+
)
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def get_coco_joint_names():
|
| 612 |
+
return [
|
| 613 |
+
"nose", # 0
|
| 614 |
+
"leye", # 1
|
| 615 |
+
"reye", # 2
|
| 616 |
+
"lear", # 3
|
| 617 |
+
"rear", # 4
|
| 618 |
+
"lshoulder", # 5
|
| 619 |
+
"rshoulder", # 6
|
| 620 |
+
"lelbow", # 7
|
| 621 |
+
"relbow", # 8
|
| 622 |
+
"lwrist", # 9
|
| 623 |
+
"rwrist", # 10
|
| 624 |
+
"lhip", # 11
|
| 625 |
+
"rhip", # 12
|
| 626 |
+
"lknee", # 13
|
| 627 |
+
"rknee", # 14
|
| 628 |
+
"lankle", # 15
|
| 629 |
+
"rankle", # 16
|
| 630 |
+
]
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
def get_ochuman_joint_names():
|
| 634 |
+
return [
|
| 635 |
+
'rshoulder',
|
| 636 |
+
'relbow',
|
| 637 |
+
'rwrist',
|
| 638 |
+
'lshoulder',
|
| 639 |
+
'lelbow',
|
| 640 |
+
'lwrist',
|
| 641 |
+
'rhip',
|
| 642 |
+
'rknee',
|
| 643 |
+
'rankle',
|
| 644 |
+
'lhip',
|
| 645 |
+
'lknee',
|
| 646 |
+
'lankle',
|
| 647 |
+
'headtop',
|
| 648 |
+
'neck',
|
| 649 |
+
'rear',
|
| 650 |
+
'lear',
|
| 651 |
+
'nose',
|
| 652 |
+
'reye',
|
| 653 |
+
'leye'
|
| 654 |
+
]
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
def get_crowdpose_joint_names():
|
| 658 |
+
return [
|
| 659 |
+
'lshoulder',
|
| 660 |
+
'rshoulder',
|
| 661 |
+
'lelbow',
|
| 662 |
+
'relbow',
|
| 663 |
+
'lwrist',
|
| 664 |
+
'rwrist',
|
| 665 |
+
'lhip',
|
| 666 |
+
'rhip',
|
| 667 |
+
'lknee',
|
| 668 |
+
'rknee',
|
| 669 |
+
'lankle',
|
| 670 |
+
'rankle',
|
| 671 |
+
'headtop',
|
| 672 |
+
'neck'
|
| 673 |
+
]
|
| 674 |
+
|
| 675 |
+
def get_coco_skeleton():
|
| 676 |
+
# 0 - nose,
|
| 677 |
+
# 1 - leye,
|
| 678 |
+
# 2 - reye,
|
| 679 |
+
# 3 - lear,
|
| 680 |
+
# 4 - rear,
|
| 681 |
+
# 5 - lshoulder,
|
| 682 |
+
# 6 - rshoulder,
|
| 683 |
+
# 7 - lelbow,
|
| 684 |
+
# 8 - relbow,
|
| 685 |
+
# 9 - lwrist,
|
| 686 |
+
# 10 - rwrist,
|
| 687 |
+
# 11 - lhip,
|
| 688 |
+
# 12 - rhip,
|
| 689 |
+
# 13 - lknee,
|
| 690 |
+
# 14 - rknee,
|
| 691 |
+
# 15 - lankle,
|
| 692 |
+
# 16 - rankle,
|
| 693 |
+
return np.array(
|
| 694 |
+
[
|
| 695 |
+
[15, 13],
|
| 696 |
+
[13, 11],
|
| 697 |
+
[16, 14],
|
| 698 |
+
[14, 12],
|
| 699 |
+
[11, 12],
|
| 700 |
+
[ 5, 11],
|
| 701 |
+
[ 6, 12],
|
| 702 |
+
[ 5, 6 ],
|
| 703 |
+
[ 5, 7 ],
|
| 704 |
+
[ 6, 8 ],
|
| 705 |
+
[ 7, 9 ],
|
| 706 |
+
[ 8, 10],
|
| 707 |
+
[ 1, 2 ],
|
| 708 |
+
[ 0, 1 ],
|
| 709 |
+
[ 0, 2 ],
|
| 710 |
+
[ 1, 3 ],
|
| 711 |
+
[ 2, 4 ],
|
| 712 |
+
[ 3, 5 ],
|
| 713 |
+
[ 4, 6 ]
|
| 714 |
+
]
|
| 715 |
+
)
|
| 716 |
+
|
| 717 |
+
|
| 718 |
+
def get_mpii_joint_names():
|
| 719 |
+
return [
|
| 720 |
+
"rankle", # 0
|
| 721 |
+
"rknee", # 1
|
| 722 |
+
"rhip", # 2
|
| 723 |
+
"lhip", # 3
|
| 724 |
+
"lknee", # 4
|
| 725 |
+
"lankle", # 5
|
| 726 |
+
"hip", # 6
|
| 727 |
+
"thorax", # 7
|
| 728 |
+
"neck", # 8
|
| 729 |
+
"headtop", # 9
|
| 730 |
+
"rwrist", # 10
|
| 731 |
+
"relbow", # 11
|
| 732 |
+
"rshoulder", # 12
|
| 733 |
+
"lshoulder", # 13
|
| 734 |
+
"lelbow", # 14
|
| 735 |
+
"lwrist", # 15
|
| 736 |
+
]
|
| 737 |
+
|
| 738 |
+
|
| 739 |
+
def get_mpii_skeleton():
|
| 740 |
+
# 0 - rankle,
|
| 741 |
+
# 1 - rknee,
|
| 742 |
+
# 2 - rhip,
|
| 743 |
+
# 3 - lhip,
|
| 744 |
+
# 4 - lknee,
|
| 745 |
+
# 5 - lankle,
|
| 746 |
+
# 6 - hip,
|
| 747 |
+
# 7 - thorax,
|
| 748 |
+
# 8 - neck,
|
| 749 |
+
# 9 - headtop,
|
| 750 |
+
# 10 - rwrist,
|
| 751 |
+
# 11 - relbow,
|
| 752 |
+
# 12 - rshoulder,
|
| 753 |
+
# 13 - lshoulder,
|
| 754 |
+
# 14 - lelbow,
|
| 755 |
+
# 15 - lwrist,
|
| 756 |
+
return np.array(
|
| 757 |
+
[
|
| 758 |
+
[ 0, 1 ],
|
| 759 |
+
[ 1, 2 ],
|
| 760 |
+
[ 2, 6 ],
|
| 761 |
+
[ 6, 3 ],
|
| 762 |
+
[ 3, 4 ],
|
| 763 |
+
[ 4, 5 ],
|
| 764 |
+
[ 6, 7 ],
|
| 765 |
+
[ 7, 8 ],
|
| 766 |
+
[ 8, 9 ],
|
| 767 |
+
[ 7, 12],
|
| 768 |
+
[12, 11],
|
| 769 |
+
[11, 10],
|
| 770 |
+
[ 7, 13],
|
| 771 |
+
[13, 14],
|
| 772 |
+
[14, 15]
|
| 773 |
+
]
|
| 774 |
+
)
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
def get_aich_joint_names():
|
| 778 |
+
return [
|
| 779 |
+
"rshoulder", # 0
|
| 780 |
+
"relbow", # 1
|
| 781 |
+
"rwrist", # 2
|
| 782 |
+
"lshoulder", # 3
|
| 783 |
+
"lelbow", # 4
|
| 784 |
+
"lwrist", # 5
|
| 785 |
+
"rhip", # 6
|
| 786 |
+
"rknee", # 7
|
| 787 |
+
"rankle", # 8
|
| 788 |
+
"lhip", # 9
|
| 789 |
+
"lknee", # 10
|
| 790 |
+
"lankle", # 11
|
| 791 |
+
"headtop", # 12
|
| 792 |
+
"neck", # 13
|
| 793 |
+
]
|
| 794 |
+
|
| 795 |
+
|
| 796 |
+
def get_aich_skeleton():
|
| 797 |
+
# 0 - rshoulder,
|
| 798 |
+
# 1 - relbow,
|
| 799 |
+
# 2 - rwrist,
|
| 800 |
+
# 3 - lshoulder,
|
| 801 |
+
# 4 - lelbow,
|
| 802 |
+
# 5 - lwrist,
|
| 803 |
+
# 6 - rhip,
|
| 804 |
+
# 7 - rknee,
|
| 805 |
+
# 8 - rankle,
|
| 806 |
+
# 9 - lhip,
|
| 807 |
+
# 10 - lknee,
|
| 808 |
+
# 11 - lankle,
|
| 809 |
+
# 12 - headtop,
|
| 810 |
+
# 13 - neck,
|
| 811 |
+
return np.array(
|
| 812 |
+
[
|
| 813 |
+
[ 0, 1 ],
|
| 814 |
+
[ 1, 2 ],
|
| 815 |
+
[ 3, 4 ],
|
| 816 |
+
[ 4, 5 ],
|
| 817 |
+
[ 6, 7 ],
|
| 818 |
+
[ 7, 8 ],
|
| 819 |
+
[ 9, 10],
|
| 820 |
+
[10, 11],
|
| 821 |
+
[12, 13],
|
| 822 |
+
[13, 0 ],
|
| 823 |
+
[13, 3 ],
|
| 824 |
+
[ 0, 6 ],
|
| 825 |
+
[ 3, 9 ]
|
| 826 |
+
]
|
| 827 |
+
)
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
def get_3dpw_joint_names():
|
| 831 |
+
return [
|
| 832 |
+
"nose", # 0
|
| 833 |
+
"thorax", # 1
|
| 834 |
+
"rshoulder", # 2
|
| 835 |
+
"relbow", # 3
|
| 836 |
+
"rwrist", # 4
|
| 837 |
+
"lshoulder", # 5
|
| 838 |
+
"lelbow", # 6
|
| 839 |
+
"lwrist", # 7
|
| 840 |
+
"rhip", # 8
|
| 841 |
+
"rknee", # 9
|
| 842 |
+
"rankle", # 10
|
| 843 |
+
"lhip", # 11
|
| 844 |
+
"lknee", # 12
|
| 845 |
+
"lankle", # 13
|
| 846 |
+
]
|
| 847 |
+
|
| 848 |
+
|
| 849 |
+
def get_3dpw_skeleton():
|
| 850 |
+
return np.array(
|
| 851 |
+
[
|
| 852 |
+
[ 0, 1 ],
|
| 853 |
+
[ 1, 2 ],
|
| 854 |
+
[ 2, 3 ],
|
| 855 |
+
[ 3, 4 ],
|
| 856 |
+
[ 1, 5 ],
|
| 857 |
+
[ 5, 6 ],
|
| 858 |
+
[ 6, 7 ],
|
| 859 |
+
[ 2, 8 ],
|
| 860 |
+
[ 5, 11],
|
| 861 |
+
[ 8, 11],
|
| 862 |
+
[ 8, 9 ],
|
| 863 |
+
[ 9, 10],
|
| 864 |
+
[11, 12],
|
| 865 |
+
[12, 13]
|
| 866 |
+
]
|
| 867 |
+
)
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
def get_smplcoco_joint_names():
|
| 871 |
+
return [
|
| 872 |
+
"rankle", # 0
|
| 873 |
+
"rknee", # 1
|
| 874 |
+
"rhip", # 2
|
| 875 |
+
"lhip", # 3
|
| 876 |
+
"lknee", # 4
|
| 877 |
+
"lankle", # 5
|
| 878 |
+
"rwrist", # 6
|
| 879 |
+
"relbow", # 7
|
| 880 |
+
"rshoulder", # 8
|
| 881 |
+
"lshoulder", # 9
|
| 882 |
+
"lelbow", # 10
|
| 883 |
+
"lwrist", # 11
|
| 884 |
+
"neck", # 12
|
| 885 |
+
"headtop", # 13
|
| 886 |
+
"nose", # 14
|
| 887 |
+
"leye", # 15
|
| 888 |
+
"reye", # 16
|
| 889 |
+
"lear", # 17
|
| 890 |
+
"rear", # 18
|
| 891 |
+
]
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
def get_smplcoco_skeleton():
|
| 895 |
+
return np.array(
|
| 896 |
+
[
|
| 897 |
+
[ 0, 1 ],
|
| 898 |
+
[ 1, 2 ],
|
| 899 |
+
[ 3, 4 ],
|
| 900 |
+
[ 4, 5 ],
|
| 901 |
+
[ 6, 7 ],
|
| 902 |
+
[ 7, 8 ],
|
| 903 |
+
[ 8, 12],
|
| 904 |
+
[12, 9 ],
|
| 905 |
+
[ 9, 10],
|
| 906 |
+
[10, 11],
|
| 907 |
+
[12, 13],
|
| 908 |
+
[14, 15],
|
| 909 |
+
[15, 17],
|
| 910 |
+
[16, 18],
|
| 911 |
+
[14, 16],
|
| 912 |
+
[ 8, 2 ],
|
| 913 |
+
[ 9, 3 ],
|
| 914 |
+
[ 2, 3 ],
|
| 915 |
+
]
|
| 916 |
+
)
|
| 917 |
+
|
| 918 |
+
|
| 919 |
+
def get_smpl_joint_names():
|
| 920 |
+
return [
|
| 921 |
+
'hips', # 0
|
| 922 |
+
'leftUpLeg', # 1
|
| 923 |
+
'rightUpLeg', # 2
|
| 924 |
+
'spine', # 3
|
| 925 |
+
'leftLeg', # 4
|
| 926 |
+
'rightLeg', # 5
|
| 927 |
+
'spine1', # 6
|
| 928 |
+
'leftFoot', # 7
|
| 929 |
+
'rightFoot', # 8
|
| 930 |
+
'spine2', # 9
|
| 931 |
+
'leftToeBase', # 10
|
| 932 |
+
'rightToeBase', # 11
|
| 933 |
+
'neck', # 12
|
| 934 |
+
'leftShoulder', # 13
|
| 935 |
+
'rightShoulder', # 14
|
| 936 |
+
'head', # 15
|
| 937 |
+
'leftArm', # 16
|
| 938 |
+
'rightArm', # 17
|
| 939 |
+
'leftForeArm', # 18
|
| 940 |
+
'rightForeArm', # 19
|
| 941 |
+
'leftHand', # 20
|
| 942 |
+
'rightHand', # 21
|
| 943 |
+
'leftHandIndex1', # 22
|
| 944 |
+
'rightHandIndex1', # 23
|
| 945 |
+
]
|
| 946 |
+
|
| 947 |
+
|
| 948 |
+
def get_smpl_paper_joint_names():
|
| 949 |
+
return [
|
| 950 |
+
'Hips', # 0
|
| 951 |
+
'Left Hip', # 1
|
| 952 |
+
'Right Hip', # 2
|
| 953 |
+
'Spine', # 3
|
| 954 |
+
'Left Knee', # 4
|
| 955 |
+
'Right Knee', # 5
|
| 956 |
+
'Spine_1', # 6
|
| 957 |
+
'Left Ankle', # 7
|
| 958 |
+
'Right Ankle', # 8
|
| 959 |
+
'Spine_2', # 9
|
| 960 |
+
'Left Toe', # 10
|
| 961 |
+
'Right Toe', # 11
|
| 962 |
+
'Neck', # 12
|
| 963 |
+
'Left Shoulder', # 13
|
| 964 |
+
'Right Shoulder', # 14
|
| 965 |
+
'Head', # 15
|
| 966 |
+
'Left Arm', # 16
|
| 967 |
+
'Right Arm', # 17
|
| 968 |
+
'Left Elbow', # 18
|
| 969 |
+
'Right Elbow', # 19
|
| 970 |
+
'Left Hand', # 20
|
| 971 |
+
'Right Hand', # 21
|
| 972 |
+
'Left Thumb', # 22
|
| 973 |
+
'Right Thumb', # 23
|
| 974 |
+
]
|
| 975 |
+
|
| 976 |
+
|
| 977 |
+
def get_smpl_neighbor_triplets():
|
| 978 |
+
return [
|
| 979 |
+
[ 0, 1, 2 ], # 0
|
| 980 |
+
[ 1, 4, 0 ], # 1
|
| 981 |
+
[ 2, 0, 5 ], # 2
|
| 982 |
+
[ 3, 0, 6 ], # 3
|
| 983 |
+
[ 4, 7, 1 ], # 4
|
| 984 |
+
[ 5, 2, 8 ], # 5
|
| 985 |
+
[ 6, 3, 9 ], # 6
|
| 986 |
+
[ 7, 10, 4 ], # 7
|
| 987 |
+
[ 8, 5, 11], # 8
|
| 988 |
+
[ 9, 13, 14], # 9
|
| 989 |
+
[10, 7, 4 ], # 10
|
| 990 |
+
[11, 8, 5 ], # 11
|
| 991 |
+
[12, 9, 15], # 12
|
| 992 |
+
[13, 16, 9 ], # 13
|
| 993 |
+
[14, 9, 17], # 14
|
| 994 |
+
[15, 9, 12], # 15
|
| 995 |
+
[16, 18, 13], # 16
|
| 996 |
+
[17, 14, 19], # 17
|
| 997 |
+
[18, 20, 16], # 18
|
| 998 |
+
[19, 17, 21], # 19
|
| 999 |
+
[20, 22, 18], # 20
|
| 1000 |
+
[21, 19, 23], # 21
|
| 1001 |
+
[22, 20, 18], # 22
|
| 1002 |
+
[23, 19, 21], # 23
|
| 1003 |
+
]
|
| 1004 |
+
|
| 1005 |
+
|
| 1006 |
+
def get_smpl_skeleton():
|
| 1007 |
+
return np.array(
|
| 1008 |
+
[
|
| 1009 |
+
[ 0, 1 ],
|
| 1010 |
+
[ 0, 2 ],
|
| 1011 |
+
[ 0, 3 ],
|
| 1012 |
+
[ 1, 4 ],
|
| 1013 |
+
[ 2, 5 ],
|
| 1014 |
+
[ 3, 6 ],
|
| 1015 |
+
[ 4, 7 ],
|
| 1016 |
+
[ 5, 8 ],
|
| 1017 |
+
[ 6, 9 ],
|
| 1018 |
+
[ 7, 10],
|
| 1019 |
+
[ 8, 11],
|
| 1020 |
+
[ 9, 12],
|
| 1021 |
+
[ 9, 13],
|
| 1022 |
+
[ 9, 14],
|
| 1023 |
+
[12, 15],
|
| 1024 |
+
[13, 16],
|
| 1025 |
+
[14, 17],
|
| 1026 |
+
[16, 18],
|
| 1027 |
+
[17, 19],
|
| 1028 |
+
[18, 20],
|
| 1029 |
+
[19, 21],
|
| 1030 |
+
[20, 22],
|
| 1031 |
+
[21, 23],
|
| 1032 |
+
]
|
| 1033 |
+
)
|
| 1034 |
+
|
| 1035 |
+
|
| 1036 |
+
def map_spin_joints_to_smpl():
|
| 1037 |
+
# this function primarily will be used to copy 2D keypoint
|
| 1038 |
+
# confidences to pose parameters
|
| 1039 |
+
return [
|
| 1040 |
+
[(39, 27, 28), 0], # hip,lhip,rhip->hips
|
| 1041 |
+
[(28,), 1], # lhip->leftUpLeg
|
| 1042 |
+
[(27,), 2], # rhip->rightUpLeg
|
| 1043 |
+
[(41, 27, 28, 39), 3], # Spine->spine
|
| 1044 |
+
[(29,), 4], # lknee->leftLeg
|
| 1045 |
+
[(26,), 5], # rknee->rightLeg
|
| 1046 |
+
[(41, 40, 33, 34,), 6], # spine, thorax ->spine1
|
| 1047 |
+
[(30,), 7], # lankle->leftFoot
|
| 1048 |
+
[(25,), 8], # rankle->rightFoot
|
| 1049 |
+
[(40, 33, 34), 9], # thorax,shoulders->spine2
|
| 1050 |
+
[(30,), 10], # lankle -> leftToe
|
| 1051 |
+
[(25,), 11], # rankle -> rightToe
|
| 1052 |
+
[(37, 42, 33, 34), 12], # neck, shoulders -> neck
|
| 1053 |
+
[(34,), 13], # lshoulder->leftShoulder
|
| 1054 |
+
[(33,), 14], # rshoulder->rightShoulder
|
| 1055 |
+
[(33, 34, 38, 43, 44, 45, 46, 47, 48,), 15], # nose, eyes, ears, headtop, shoulders->head
|
| 1056 |
+
[(34,), 16], # lshoulder->leftArm
|
| 1057 |
+
[(33,), 17], # rshoulder->rightArm
|
| 1058 |
+
[(35,), 18], # lelbow->leftForeArm
|
| 1059 |
+
[(32,), 19], # relbow->rightForeArm
|
| 1060 |
+
[(36,), 20], # lwrist->leftHand
|
| 1061 |
+
[(31,), 21], # rwrist->rightHand
|
| 1062 |
+
[(36,), 22], # lhand -> leftHandIndex
|
| 1063 |
+
[(31,), 23], # rhand -> rightHandIndex
|
| 1064 |
+
]
|
| 1065 |
+
|
| 1066 |
+
|
| 1067 |
+
def map_smpl_to_common():
|
| 1068 |
+
return [
|
| 1069 |
+
[(11, 8), 0], # rightToe, rightFoot -> rankle
|
| 1070 |
+
[(5,), 1], # rightleg -> rknee,
|
| 1071 |
+
[(2,), 2], # rhip
|
| 1072 |
+
[(1,), 3], # lhip
|
| 1073 |
+
[(4,), 4], # leftLeg -> lknee
|
| 1074 |
+
[(10, 7), 5], # lefttoe, leftfoot -> lankle
|
| 1075 |
+
[(21, 23), 6], # rwrist
|
| 1076 |
+
[(18,), 7], # relbow
|
| 1077 |
+
[(17, 14), 8], # rshoulder
|
| 1078 |
+
[(16, 13), 9], # lshoulder
|
| 1079 |
+
[(19,), 10], # lelbow
|
| 1080 |
+
[(20, 22), 11], # lwrist
|
| 1081 |
+
[(0, 3, 6, 9, 12), 12], # neck
|
| 1082 |
+
[(15,), 13], # headtop
|
| 1083 |
+
]
|
| 1084 |
+
|
| 1085 |
+
|
| 1086 |
+
def relation_among_spin_joints():
|
| 1087 |
+
# this function primarily will be used to copy 2D keypoint
|
| 1088 |
+
# confidences to 3D joints
|
| 1089 |
+
return [
|
| 1090 |
+
[(), 25],
|
| 1091 |
+
[(), 26],
|
| 1092 |
+
[(39,), 27],
|
| 1093 |
+
[(39,), 28],
|
| 1094 |
+
[(), 29],
|
| 1095 |
+
[(), 30],
|
| 1096 |
+
[(), 31],
|
| 1097 |
+
[(), 32],
|
| 1098 |
+
[(), 33],
|
| 1099 |
+
[(), 34],
|
| 1100 |
+
[(), 35],
|
| 1101 |
+
[(), 36],
|
| 1102 |
+
[(40,42,44,43,38,33,34,), 37],
|
| 1103 |
+
[(43,44,45,46,47,48,33,34,), 38],
|
| 1104 |
+
[(27,28,), 39],
|
| 1105 |
+
[(27,28,37,41,42,), 40],
|
| 1106 |
+
[(27,28,39,40,), 41],
|
| 1107 |
+
[(37,38,44,45,46,47,48,), 42],
|
| 1108 |
+
[(44,45,46,47,48,38,42,37,33,34,), 43],
|
| 1109 |
+
[(44,45,46,47,48,38,42,37,33,34), 44],
|
| 1110 |
+
[(44,45,46,47,48,38,42,37,33,34), 45],
|
| 1111 |
+
[(44,45,46,47,48,38,42,37,33,34), 46],
|
| 1112 |
+
[(44,45,46,47,48,38,42,37,33,34), 47],
|
| 1113 |
+
[(44,45,46,47,48,38,42,37,33,34), 48],
|
| 1114 |
+
]
|
utils/loss.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from common import constants
|
| 4 |
+
from models.smpl import SMPL
|
| 5 |
+
from smplx import SMPLX
|
| 6 |
+
import pickle as pkl
|
| 7 |
+
import numpy as np
|
| 8 |
+
from utils.mesh_utils import save_results_mesh
|
| 9 |
+
from utils.diff_renderer import Pytorch3D
|
| 10 |
+
import os
|
| 11 |
+
import cv2
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class sem_loss_function(nn.Module):
|
| 15 |
+
def __init__(self):
|
| 16 |
+
super(sem_loss_function, self).__init__()
|
| 17 |
+
self.ce = nn.BCELoss()
|
| 18 |
+
|
| 19 |
+
def forward(self, y_true, y_pred):
|
| 20 |
+
loss = self.ce(y_pred, y_true)
|
| 21 |
+
return loss
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class class_loss_function(nn.Module):
|
| 25 |
+
def __init__(self):
|
| 26 |
+
super(class_loss_function, self).__init__()
|
| 27 |
+
self.ce_loss = nn.BCELoss()
|
| 28 |
+
# self.ce_loss = nn.MultiLabelSoftMarginLoss()
|
| 29 |
+
# self.ce_loss = nn.MultiLabelMarginLoss()
|
| 30 |
+
|
| 31 |
+
def forward(self, y_true, y_pred, valid_mask):
|
| 32 |
+
# y_true = torch.squeeze(y_true, 1).long()
|
| 33 |
+
# y_true = torch.squeeze(y_true, 1)
|
| 34 |
+
# y_pred = torch.squeeze(y_pred, 1)
|
| 35 |
+
bs = y_true.shape[0]
|
| 36 |
+
if bs != 1:
|
| 37 |
+
y_pred = y_pred[valid_mask == 1]
|
| 38 |
+
y_true = y_true[valid_mask == 1]
|
| 39 |
+
if len(y_pred) > 0:
|
| 40 |
+
return self.ce_loss(y_pred, y_true)
|
| 41 |
+
else:
|
| 42 |
+
return torch.tensor(0.0).to(y_pred.device)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class pixel_anchoring_function(nn.Module):
|
| 46 |
+
def __init__(self, model_type, device='cuda'):
|
| 47 |
+
super(pixel_anchoring_function, self).__init__()
|
| 48 |
+
|
| 49 |
+
self.device = device
|
| 50 |
+
|
| 51 |
+
self.model_type = model_type
|
| 52 |
+
|
| 53 |
+
if self.model_type == 'smplx':
|
| 54 |
+
# load mapping from smpl vertices to smplx vertices
|
| 55 |
+
mapping_pkl = os.path.join(constants.CONTACT_MAPPING_PATH, "smpl_to_smplx.pkl")
|
| 56 |
+
with open(mapping_pkl, 'rb') as f:
|
| 57 |
+
smpl_to_smplx_mapping = pkl.load(f)
|
| 58 |
+
smpl_to_smplx_mapping = smpl_to_smplx_mapping["matrix"]
|
| 59 |
+
self.smpl_to_smplx_mapping = torch.from_numpy(smpl_to_smplx_mapping).float().to(self.device)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Setup the SMPL model
|
| 63 |
+
if self.model_type == 'smpl':
|
| 64 |
+
self.n_vertices = 6890
|
| 65 |
+
self.body_model = SMPL(constants.SMPL_MODEL_DIR).to(self.device)
|
| 66 |
+
if self.model_type == 'smplx':
|
| 67 |
+
self.n_vertices = 10475
|
| 68 |
+
self.body_model = SMPLX(constants.SMPLX_MODEL_DIR,
|
| 69 |
+
num_betas=10,
|
| 70 |
+
use_pca=False).to(self.device)
|
| 71 |
+
self.body_faces = torch.LongTensor(self.body_model.faces.astype(np.int32)).to(self.device)
|
| 72 |
+
|
| 73 |
+
self.ce_loss = nn.BCELoss()
|
| 74 |
+
|
| 75 |
+
def get_posed_mesh(self, body_params, debug=False):
|
| 76 |
+
betas = body_params['betas']
|
| 77 |
+
pose = body_params['pose']
|
| 78 |
+
transl = body_params['transl']
|
| 79 |
+
|
| 80 |
+
# extra smplx params
|
| 81 |
+
extra_args = {'jaw_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
|
| 82 |
+
'leye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
|
| 83 |
+
'reye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
|
| 84 |
+
'expression': torch.zeros((betas.shape[0], 10)).float().to(self.device),
|
| 85 |
+
'left_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device),
|
| 86 |
+
'right_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device)}
|
| 87 |
+
|
| 88 |
+
smpl_output = self.body_model(betas=betas,
|
| 89 |
+
body_pose=pose[:, 3:],
|
| 90 |
+
global_orient=pose[:, :3],
|
| 91 |
+
pose2rot=True,
|
| 92 |
+
transl=transl,
|
| 93 |
+
**extra_args)
|
| 94 |
+
smpl_verts = smpl_output.vertices
|
| 95 |
+
smpl_joints = smpl_output.joints
|
| 96 |
+
|
| 97 |
+
if debug:
|
| 98 |
+
for mesh_i in range(smpl_verts.shape[0]):
|
| 99 |
+
out_dir = 'temp_meshes'
|
| 100 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 101 |
+
out_file = os.path.join(out_dir, f'temp_mesh_{mesh_i:04d}.obj')
|
| 102 |
+
save_results_mesh(smpl_verts[mesh_i], self.body_model.faces, out_file)
|
| 103 |
+
return smpl_verts, smpl_joints
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def render_batch(self, smpl_verts, cam_k, img_scale_factor, vertex_colors=None, face_textures=None, debug=False):
|
| 107 |
+
|
| 108 |
+
bs = smpl_verts.shape[0]
|
| 109 |
+
|
| 110 |
+
# Incorporate resizing factor into the camera
|
| 111 |
+
img_w = 256 # TODO: Remove hardcoding
|
| 112 |
+
img_h = 256 # TODO: Remove hardcoding
|
| 113 |
+
focal_length_x = cam_k[:, 0, 0] * img_scale_factor[:, 0]
|
| 114 |
+
focal_length_y = cam_k[:, 1, 1] * img_scale_factor[:, 1]
|
| 115 |
+
# convert to float for pytorch3d
|
| 116 |
+
focal_length_x, focal_length_y = focal_length_x.float(), focal_length_y.float()
|
| 117 |
+
|
| 118 |
+
# concatenate focal length
|
| 119 |
+
focal_length = torch.stack([focal_length_x, focal_length_y], dim=1)
|
| 120 |
+
|
| 121 |
+
# Setup renderer
|
| 122 |
+
renderer = Pytorch3D(img_h=img_h,
|
| 123 |
+
img_w=img_w,
|
| 124 |
+
focal_length=focal_length,
|
| 125 |
+
smpl_faces=self.body_faces,
|
| 126 |
+
texture_mode='deco',
|
| 127 |
+
vertex_colors=vertex_colors,
|
| 128 |
+
face_textures=face_textures,
|
| 129 |
+
is_train=True,
|
| 130 |
+
is_cam_batch=True)
|
| 131 |
+
front_view = renderer(smpl_verts)
|
| 132 |
+
if debug:
|
| 133 |
+
# visualize the front view as images in a temp_image folder
|
| 134 |
+
for i in range(bs):
|
| 135 |
+
front_view_rgb = front_view[i, :3, :, :].permute(1, 2, 0).detach().cpu()
|
| 136 |
+
front_view_mask = front_view[i, 3, :, :].detach().cpu()
|
| 137 |
+
out_dir = 'temp_images'
|
| 138 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 139 |
+
out_file_rgb = os.path.join(out_dir, f'{i:04d}_rgb.png')
|
| 140 |
+
out_file_mask = os.path.join(out_dir, f'{i:04d}_mask.png')
|
| 141 |
+
cv2.imwrite(out_file_rgb, front_view_rgb.numpy()*255)
|
| 142 |
+
cv2.imwrite(out_file_mask, front_view_mask.numpy()*255)
|
| 143 |
+
|
| 144 |
+
return front_view
|
| 145 |
+
|
| 146 |
+
def paint_contact(self, pred_contact):
|
| 147 |
+
"""
|
| 148 |
+
Paints the contact vertices on the SMPL mesh
|
| 149 |
+
|
| 150 |
+
Args:
|
| 151 |
+
pred_contact: prbabilities of contact vertices
|
| 152 |
+
|
| 153 |
+
Returns:
|
| 154 |
+
pred_rgb: RGB colors for the contact vertices
|
| 155 |
+
"""
|
| 156 |
+
bs = pred_contact.shape[0]
|
| 157 |
+
|
| 158 |
+
# initialize black and while colors
|
| 159 |
+
colors = torch.tensor([[0, 0, 0], [1, 1, 1]]).float().to(self.device)
|
| 160 |
+
colors = torch.unsqueeze(colors, 0).expand(bs, -1, -1)
|
| 161 |
+
|
| 162 |
+
# add another dimension to the contact probabilities for inverse probabilities
|
| 163 |
+
pred_contact = torch.unsqueeze(pred_contact, 2)
|
| 164 |
+
pred_contact = torch.cat((1 - pred_contact, pred_contact), 2)
|
| 165 |
+
|
| 166 |
+
# get pred_rgb colors
|
| 167 |
+
pred_vert_rgb = torch.bmm(pred_contact, colors)
|
| 168 |
+
pred_face_rgb = pred_vert_rgb[:, self.body_faces, :][:, :, 0, :] # take the first vertex color
|
| 169 |
+
pred_face_texture = torch.zeros((bs, self.body_faces.shape[0], 1, 1, 3), dtype=torch.float32).to(self.device)
|
| 170 |
+
pred_face_texture[:, :, 0, 0, :] = pred_face_rgb
|
| 171 |
+
return pred_vert_rgb, pred_face_texture
|
| 172 |
+
|
| 173 |
+
def forward(self, pred_contact, body_params, cam_k, img_scale_factor, gt_contact_polygon, valid_mask):
|
| 174 |
+
"""
|
| 175 |
+
Takes predicted contact labels (probabilities), transfers them to the posed mesh and
|
| 176 |
+
renders to the image. Loss is computed between the rendered contact and the ground truth
|
| 177 |
+
polygons from HOT.
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
pred_contact: predicted contact labels (probabilities)
|
| 181 |
+
body_params: SMPL parameters in camera coords
|
| 182 |
+
cam_k: camera intrinsics
|
| 183 |
+
gt_contact_polygon: ground truth polygons from HOT
|
| 184 |
+
"""
|
| 185 |
+
# convert pred_contact to smplx
|
| 186 |
+
bs = pred_contact.shape[0]
|
| 187 |
+
if self.model_type == 'smplx':
|
| 188 |
+
smpl_to_smplx_mapping = self.smpl_to_smplx_mapping[None].expand(bs, -1, -1)
|
| 189 |
+
pred_contact = torch.bmm(smpl_to_smplx_mapping, pred_contact[..., None])
|
| 190 |
+
pred_contact = pred_contact.squeeze()
|
| 191 |
+
|
| 192 |
+
# get the posed mesh
|
| 193 |
+
smpl_verts, smpl_joints = self.get_posed_mesh(body_params)
|
| 194 |
+
|
| 195 |
+
# paint the contact vertices on the mesh
|
| 196 |
+
vertex_colors, face_textures = self.paint_contact(pred_contact)
|
| 197 |
+
|
| 198 |
+
# render the mesh
|
| 199 |
+
front_view = self.render_batch(smpl_verts, cam_k, img_scale_factor, vertex_colors, face_textures)
|
| 200 |
+
front_view_rgb = front_view[:, :3, :, :].permute(0, 2, 3, 1)
|
| 201 |
+
front_view_mask = front_view[:, 3, :, :]
|
| 202 |
+
|
| 203 |
+
# compute segmentation loss between rendered contact mask and ground truth contact mask
|
| 204 |
+
front_view_rgb = front_view_rgb[valid_mask == 1]
|
| 205 |
+
gt_contact_polygon = gt_contact_polygon[valid_mask == 1]
|
| 206 |
+
loss = self.ce_loss(front_view_rgb, gt_contact_polygon)
|
| 207 |
+
return loss, front_view_rgb, front_view_mask
|
utils/mesh_utils.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import trimesh
|
| 2 |
+
|
| 3 |
+
def save_results_mesh(vertices, faces, filename):
|
| 4 |
+
mesh = trimesh.Trimesh(vertices, faces, process=False)
|
| 5 |
+
mesh.export(filename)
|
| 6 |
+
print(f'save results to {filename}')
|
utils/metrics.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import monai.metrics as metrics
|
| 4 |
+
from common.constants import DIST_MATRIX_PATH
|
| 5 |
+
|
| 6 |
+
DIST_MATRIX = np.load(DIST_MATRIX_PATH)
|
| 7 |
+
|
| 8 |
+
def metric(mask, pred, back=True):
|
| 9 |
+
iou = metrics.compute_meaniou(pred, mask, back, False)
|
| 10 |
+
iou = iou.mean()
|
| 11 |
+
|
| 12 |
+
return iou
|
| 13 |
+
|
| 14 |
+
def precision_recall_f1score(gt, pred):
|
| 15 |
+
"""
|
| 16 |
+
Compute precision, recall, and f1
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
# gt = gt.numpy()
|
| 20 |
+
# pred = pred.numpy()
|
| 21 |
+
|
| 22 |
+
precision = torch.zeros(gt.shape[0])
|
| 23 |
+
recall = torch.zeros(gt.shape[0])
|
| 24 |
+
f1 = torch.zeros(gt.shape[0])
|
| 25 |
+
|
| 26 |
+
for b in range(gt.shape[0]):
|
| 27 |
+
tp_num = gt[b, pred[b, :] >= 0.5].sum()
|
| 28 |
+
precision_denominator = (pred[b, :] >= 0.5).sum()
|
| 29 |
+
recall_denominator = (gt[b, :]).sum()
|
| 30 |
+
|
| 31 |
+
precision_ = tp_num / precision_denominator
|
| 32 |
+
recall_ = tp_num / recall_denominator
|
| 33 |
+
if precision_denominator == 0: # if no pred
|
| 34 |
+
precision_ = 1.
|
| 35 |
+
recall_ = 0.
|
| 36 |
+
f1_ = 0.
|
| 37 |
+
elif recall_denominator == 0: # if no GT
|
| 38 |
+
precision_ = 0.
|
| 39 |
+
recall_ = 1.
|
| 40 |
+
f1_ = 0.
|
| 41 |
+
elif (precision_ + recall_) <= 1e-10: # to avoid precision issues
|
| 42 |
+
precision_= 0.
|
| 43 |
+
recall_= 0.
|
| 44 |
+
f1_ = 0.
|
| 45 |
+
else:
|
| 46 |
+
f1_ = 2 * precision_ * recall_ / (precision_ + recall_)
|
| 47 |
+
|
| 48 |
+
precision[b] = precision_
|
| 49 |
+
recall[b] = recall_
|
| 50 |
+
f1[b] = f1_
|
| 51 |
+
|
| 52 |
+
# return precision, recall, f1
|
| 53 |
+
return precision, recall, f1
|
| 54 |
+
|
| 55 |
+
def acc_precision_recall_f1score(gt, pred):
|
| 56 |
+
"""
|
| 57 |
+
Compute acc, precision, recall, and f1
|
| 58 |
+
"""
|
| 59 |
+
|
| 60 |
+
# gt = gt.numpy()
|
| 61 |
+
# pred = pred.numpy()
|
| 62 |
+
|
| 63 |
+
acc = torch.zeros(gt.shape[0])
|
| 64 |
+
precision = torch.zeros(gt.shape[0])
|
| 65 |
+
recall = torch.zeros(gt.shape[0])
|
| 66 |
+
f1 = torch.zeros(gt.shape[0])
|
| 67 |
+
|
| 68 |
+
for b in range(gt.shape[0]):
|
| 69 |
+
tp_num = gt[b, pred[b, :] >= 0.5].sum()
|
| 70 |
+
precision_denominator = (pred[b, :] >= 0.5).sum()
|
| 71 |
+
recall_denominator = (gt[b, :]).sum()
|
| 72 |
+
tn_num = gt.shape[-1] - precision_denominator - recall_denominator + tp_num
|
| 73 |
+
|
| 74 |
+
acc_ = (tp_num + tn_num) / gt.shape[-1]
|
| 75 |
+
precision_ = tp_num / (precision_denominator + 1e-10)
|
| 76 |
+
recall_ = tp_num / (recall_denominator + 1e-10)
|
| 77 |
+
f1_ = 2 * precision_ * recall_ / (precision_ + recall_ + 1e-10)
|
| 78 |
+
|
| 79 |
+
acc[b] = acc_
|
| 80 |
+
precision[b] = precision_
|
| 81 |
+
recall[b] = recall_
|
| 82 |
+
|
| 83 |
+
# return precision, recall, f1
|
| 84 |
+
return acc, precision, recall, f1
|
| 85 |
+
|
| 86 |
+
def det_error_metric(pred, gt):
|
| 87 |
+
|
| 88 |
+
gt = gt.detach().cpu()
|
| 89 |
+
pred = pred.detach().cpu()
|
| 90 |
+
|
| 91 |
+
dist_matrix = torch.tensor(DIST_MATRIX)
|
| 92 |
+
|
| 93 |
+
false_positive_dist = torch.zeros(gt.shape[0])
|
| 94 |
+
false_negative_dist = torch.zeros(gt.shape[0])
|
| 95 |
+
|
| 96 |
+
for b in range(gt.shape[0]):
|
| 97 |
+
gt_columns = dist_matrix[:, gt[b, :]==1] if any(gt[b, :]==1) else dist_matrix
|
| 98 |
+
error_matrix = gt_columns[pred[b, :] >= 0.5, :] if any(pred[b, :] >= 0.5) else gt_columns
|
| 99 |
+
|
| 100 |
+
false_positive_dist_ = error_matrix.min(dim=1)[0].mean()
|
| 101 |
+
false_negative_dist_ = error_matrix.min(dim=0)[0].mean()
|
| 102 |
+
|
| 103 |
+
false_positive_dist[b] = false_positive_dist_
|
| 104 |
+
false_negative_dist[b] = false_negative_dist_
|
| 105 |
+
|
| 106 |
+
return false_positive_dist, false_negative_dist
|
utils/smpl_uv.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import trimesh
|
| 3 |
+
import numpy as np
|
| 4 |
+
import skimage.io as io
|
| 5 |
+
from PIL import Image
|
| 6 |
+
from smplx import SMPL
|
| 7 |
+
from matplotlib import cm as mpl_cm, colors as mpl_colors
|
| 8 |
+
from trimesh.visual.color import face_to_vertex_color, vertex_to_face_color, to_rgba
|
| 9 |
+
|
| 10 |
+
from common import constants
|
| 11 |
+
from .colorwheel import make_color_wheel_image
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def get_smpl_uv():
|
| 15 |
+
uv_obj = 'data/body_models/smpl_uv_20200910/smpl_uv.obj'
|
| 16 |
+
|
| 17 |
+
uv_map = []
|
| 18 |
+
with open(uv_obj) as f:
|
| 19 |
+
for line in f.readlines():
|
| 20 |
+
if line.startswith('vt'):
|
| 21 |
+
coords = [float(x) for x in line.split(' ')[1:]]
|
| 22 |
+
uv_map.append(coords)
|
| 23 |
+
|
| 24 |
+
uv_map = np.array(uv_map)
|
| 25 |
+
|
| 26 |
+
return uv_map
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def show_uv_texture():
|
| 30 |
+
# image = io.imread('data/body_models/smpl_uv_20200910/smpl_uv_20200910.png')
|
| 31 |
+
image = make_color_wheel_image(1024, 1024)
|
| 32 |
+
image = Image.fromarray(image)
|
| 33 |
+
|
| 34 |
+
uv = np.load('data/body_models/smpl_uv_20200910/uv_table.npy') # get_smpl_uv()
|
| 35 |
+
material = trimesh.visual.texture.SimpleMaterial(image=image)
|
| 36 |
+
tex_visuals = trimesh.visual.TextureVisuals(uv=uv, image=image, material=material)
|
| 37 |
+
|
| 38 |
+
smpl = SMPL(constants.SMPL_MODEL_DIR)
|
| 39 |
+
|
| 40 |
+
faces = smpl.faces
|
| 41 |
+
verts = smpl().vertices[0].detach().numpy()
|
| 42 |
+
|
| 43 |
+
# assert(len(uv) == len(verts))
|
| 44 |
+
print(uv.shape)
|
| 45 |
+
vc = tex_visuals.to_color().vertex_colors
|
| 46 |
+
fc = trimesh.visual.color.vertex_to_face_color(vc, faces)
|
| 47 |
+
face_colors = fc.copy()
|
| 48 |
+
fc = fc.astype(float)
|
| 49 |
+
vc = vc.astype(float)
|
| 50 |
+
fc[:,:3] = fc[:,:3] / 255.
|
| 51 |
+
vc[:,:3] = vc[:,:3] / 255.
|
| 52 |
+
print(fc[:,:3].max(), fc[:,:3].min(), fc[:,:3].mean())
|
| 53 |
+
print(vc[:, :3].max(), vc[:, :3].min(), vc[:, :3].mean())
|
| 54 |
+
np.save('data/body_models/smpl/color_wheel_face_colors.npy', fc)
|
| 55 |
+
np.save('data/body_models/smpl/color_wheel_vertex_colors.npy', vc)
|
| 56 |
+
print(fc.shape)
|
| 57 |
+
mesh = trimesh.Trimesh(verts, faces, validate=True, process=False, face_colors=face_colors)
|
| 58 |
+
# mesh = trimesh.load('data/body_models/smpl_uv_20200910/smpl_uv.obj', process=False)
|
| 59 |
+
# mesh.visual = tex_visuals
|
| 60 |
+
|
| 61 |
+
# import ipdb; ipdb.set_trace()
|
| 62 |
+
# print(vc.shape)
|
| 63 |
+
mesh.show()
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def show_colored_mesh():
|
| 67 |
+
cm = mpl_cm.get_cmap('jet')
|
| 68 |
+
norm_gt = mpl_colors.Normalize()
|
| 69 |
+
|
| 70 |
+
smpl = SMPL(constants.SMPL_MODEL_DIR)
|
| 71 |
+
|
| 72 |
+
faces = smpl.faces
|
| 73 |
+
verts = smpl().vertices[0].detach().numpy()
|
| 74 |
+
|
| 75 |
+
m = trimesh.Trimesh(verts, faces, process=False)
|
| 76 |
+
|
| 77 |
+
mode = 1
|
| 78 |
+
if mode == 0:
|
| 79 |
+
# mano_segm_labels = m.triangles_center
|
| 80 |
+
face_labels = m.triangles_center
|
| 81 |
+
face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
|
| 82 |
+
|
| 83 |
+
elif mode == 1:
|
| 84 |
+
# print(face_labels.shape)
|
| 85 |
+
face_labels = m.triangles_center
|
| 86 |
+
face_labels = np.argsort(np.linalg.norm(face_labels, axis=-1))
|
| 87 |
+
face_colors = np.ones((13776, 4))
|
| 88 |
+
face_colors[:, 3] = 1.0
|
| 89 |
+
face_colors[:, :3] = cm(norm_gt(face_labels))[:, :3]
|
| 90 |
+
elif mode == 2:
|
| 91 |
+
# breakpoint()
|
| 92 |
+
fc = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')[0, :, 0, 0, 0, :]
|
| 93 |
+
face_colors = np.ones((13776, 4))
|
| 94 |
+
face_colors[:, :3] = fc
|
| 95 |
+
mesh = trimesh.Trimesh(verts, faces, process=False, face_colors=face_colors)
|
| 96 |
+
mesh.show()
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def get_tenet_texture(mode='smplpix'):
|
| 100 |
+
# mode = 'smplpix', 'decomr'
|
| 101 |
+
|
| 102 |
+
smpl = SMPL(constants.SMPL_MODEL_DIR)
|
| 103 |
+
|
| 104 |
+
faces = smpl.faces
|
| 105 |
+
verts = smpl().vertices[0].detach().numpy()
|
| 106 |
+
|
| 107 |
+
m = trimesh.Trimesh(verts, faces, process=False)
|
| 108 |
+
if mode == 'smplpix':
|
| 109 |
+
# mano_segm_labels = m.triangles_center
|
| 110 |
+
face_labels = m.triangles_center
|
| 111 |
+
face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
|
| 112 |
+
texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
|
| 113 |
+
texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
|
| 114 |
+
texture = torch.from_numpy(texture).float()
|
| 115 |
+
elif mode == 'decomr':
|
| 116 |
+
texture = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')
|
| 117 |
+
texture = torch.from_numpy(texture).float()
|
| 118 |
+
elif mode == 'colorwheel':
|
| 119 |
+
face_colors = np.load('data/body_models/smpl/color_wheel_face_colors.npy')
|
| 120 |
+
texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
|
| 121 |
+
texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
|
| 122 |
+
texture = torch.from_numpy(texture).float()
|
| 123 |
+
else:
|
| 124 |
+
raise ValueError(f'{mode} is not defined!')
|
| 125 |
+
|
| 126 |
+
return texture
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def save_tenet_textures(mode='smplpix'):
|
| 130 |
+
# mode = 'smplpix', 'decomr'
|
| 131 |
+
|
| 132 |
+
smpl = SMPL(constants.SMPL_MODEL_DIR)
|
| 133 |
+
|
| 134 |
+
faces = smpl.faces
|
| 135 |
+
verts = smpl().vertices[0].detach().numpy()
|
| 136 |
+
|
| 137 |
+
m = trimesh.Trimesh(verts, faces, process=False)
|
| 138 |
+
|
| 139 |
+
if mode == 'smplpix':
|
| 140 |
+
# mano_segm_labels = m.triangles_center
|
| 141 |
+
face_labels = m.triangles_center
|
| 142 |
+
face_colors = (face_labels - face_labels.min()) / np.ptp(face_labels)
|
| 143 |
+
texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
|
| 144 |
+
texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
|
| 145 |
+
texture = torch.from_numpy(texture).float()
|
| 146 |
+
|
| 147 |
+
vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0
|
| 148 |
+
|
| 149 |
+
elif mode == 'decomr':
|
| 150 |
+
texture = np.load('data/body_models/smpl_uv_20200910/data/vertex_texture.npy')
|
| 151 |
+
texture = torch.from_numpy(texture).float()
|
| 152 |
+
face_colors = texture[0, :, 0, 0, 0, :]
|
| 153 |
+
vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0
|
| 154 |
+
|
| 155 |
+
elif mode == 'colorwheel':
|
| 156 |
+
face_colors = np.load('data/body_models/smpl/color_wheel_face_colors.npy')
|
| 157 |
+
texture = np.zeros((1, faces.shape[0], 1, 1, 1, 3), dtype=np.float32)
|
| 158 |
+
texture[0, :, 0, 0, 0, :] = face_colors[:, :3]
|
| 159 |
+
texture = torch.from_numpy(texture).float()
|
| 160 |
+
face_colors[:, :3] *= 255
|
| 161 |
+
vert_colors = face_to_vertex_color(m, face_colors).astype(float) / 255.0
|
| 162 |
+
else:
|
| 163 |
+
raise ValueError(f'{mode} is not defined!')
|
| 164 |
+
|
| 165 |
+
print(vert_colors.shape, vert_colors.max())
|
| 166 |
+
np.save(f'data/body_models/smpl/{mode}_vertex_colors.npy', vert_colors)
|
| 167 |
+
return texture
|