# File size: 10,376 Bytes
#-------------------------------------------------------------------------------
# Name: mst_generate.py
# Purpose: Generate skeleton as a tree based on predicted joints.
# RigNet Copyright 2020 University of Massachusetts
# RigNet is made available under General Public License Version 3 (GPLv3), or under a Commercial License.
# Please see the LICENSE README.txt file in the main directory for more information and instruction on using and licensing RigNet.
#-------------------------------------------------------------------------------
import os
import cv2
import argparse
import numpy as np
import open3d as o3d
from utils import binvox_rw
from utils.tree_utils import TreeNode
from utils.rig_parser import Skel
from utils.vis_utils import show_obj_skel, draw_shifted_pts
from utils.io_utils import readPly
from utils.cluster_utils import meanshift_cluster, nms_meanshift
from utils.mst_utils import primMST_symmetry, loadSkel_recur, increase_cost_for_outside_bone, flip, inside_check, sample_on_bone
from gen_dataset import get_geo_edges, get_tpl_edges
from geometric_proc.common_ops import calc_surface_geodesic
import torch
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops
from models.ROOT_GCN import ROOTNET
from models.PairCls_GCN import PairCls
# Run on the first CUDA GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def predict_joints(model_id, args):
    """
    predict joints for a specified model

    Loads the raw joint candidates (``<res_folder>/<id>.ply``), their attention
    weights (``<id>_attn.npy``) and the meanshift bandwidth (``<id>_bandwidth.npy``).
    It then drops candidates outside the voxelized mesh or with attention <= 1e-3,
    mirrors the rest across the x = 0 plane and clusters them with meanshift.
    Finally it keeps the modes whose share of the total density exceeds
    ``args.threshold_best``. That threshold is lowered by 1.3x steps (down to
    1e-7) until at least two joints survive. If it never gets there, NMS is run
    over all clustered points.

    :param model_id: processed model ID number (formatted with ``{:d}``)
    :param args: parsed arguments; reads ``dataset_folder``, ``res_folder`` and ``threshold_best``
    :return: predicted joints, and voxelized mesh
    """
    vox_folder = os.path.join(args.dataset_folder, 'vox/')
    raw_pred = os.path.join(args.res_folder, '{:d}.ply'.format(model_id))
    vox_file = os.path.join(vox_folder, '{:d}.binvox'.format(model_id))
    pred_attn = np.load(os.path.join(args.res_folder, '{:d}_attn.npy'.format(model_id)))
    with open(vox_file, 'rb') as fvox:
        vox = binvox_rw.read_as_3d_array(fvox)

    # Keep only candidates inside the voxelized mesh, with their attention rows.
    pred_joints = readPly(raw_pred)
    pred_joints, index_inside = inside_check(pred_joints, vox)
    pred_attn = pred_attn[index_inside, :]
    bandwidth = np.load(os.path.join(args.res_folder, '{:d}_bandwidth.npy'.format(model_id)))[0]

    # One mask for both arrays keeps them row-aligned. reshape(-1) (not squeeze)
    # still yields a 1-D mask when a single candidate is left; assumes one
    # attention value per candidate row.
    attn_keep = pred_attn.reshape(-1) > 1e-3
    pred_joints = pred_joints[attn_keep]
    pred_attn = pred_attn[attn_keep]

    # Mirror candidates across the x = 0 plane; the copies reuse the same weights.
    pred_joints_reflect = pred_joints * np.array([[-1, 1, 1]])
    pred_joints = np.concatenate((pred_joints, pred_joints_reflect), axis=0)
    pred_attn = np.tile(pred_attn, (2, 1))
    pred_joints = meanshift_cluster(pred_joints, bandwidth, pred_attn, max_iter=20)

    # Density of each clustered point: sum over all points of max(bw^2 - squared distance, 0).
    sq_dist = np.sum((pred_joints[np.newaxis, ...] - pred_joints[:, np.newaxis, :]) ** 2, axis=2)
    density = np.sum(np.maximum(bandwidth ** 2 - sq_dist, 0.0), axis=0)
    density_ratio = density / np.sum(density)

    def _modes_above(threshold):
        # Same mask for points and densities so nms_meanshift gets aligned arrays.
        keep = density_ratio > threshold
        modes = nms_meanshift(pred_joints[keep], density[keep], bandwidth)
        modes, _ = flip(modes)
        return modes

    reduce_threshold = args.threshold_best
    pred_joints_ = _modes_above(reduce_threshold)
    while len(pred_joints_) < 2 and reduce_threshold > 1e-7:
        reduce_threshold = reduce_threshold / 1.3
        pred_joints_ = _modes_above(reduce_threshold)
    if len(pred_joints_) < 2:
        # Threshold exhausted without two joints: run NMS over every clustered point
        # (points and densities now have matching lengths).
        pred_joints_ = nms_meanshift(pred_joints, density, bandwidth)
        pred_joints_, _ = flip(pred_joints_)
    return pred_joints_, vox
def getInitId(data, model):
    """
    predict root joint ID via rootnet

    :param data: wrapped input passed straight to ``model``
    :param model: callable invoked as ``model(data, shuffle=False)``, returning a 2-tuple whose first item is a score tensor
    :return: index of the largest sigmoid score (argmax over the flattened array)
    """
    with torch.no_grad():
        logits, _unused = model(data, shuffle=False)
        scores = torch.sigmoid(logits).detach().cpu().numpy()
    return np.argmax(scores)
def create_single_data(mesh, vox, surface_geodesic, pred_joints):
    """
    create data used as input to networks, wrapped by Data structure in pytorch-geometric library

    :param mesh: input mesh loaded by open3d
    :param vox: voxelized mesh
    :param surface_geodesic: geodesic distance matrix of all vertices
    :param pred_joints: predicted joints, one row per joint
    :return: wrapped data structure (moved to ``device``). Its fields are:
        ``x`` vertex normals (float32); ``pos`` vertex positions (float32);
        ``y`` joints tiled/truncated to one row per vertex; ``pairs`` is a
        (num_pair, 5) float tensor of [id1, id2, distance, inside fraction, 1];
        ``tpl_edge_index`` / ``geo_edge_index`` are the edges with self-loops.
    """
    mesh_v = np.asarray(mesh.vertices)
    mesh_vn = np.asarray(mesh.vertex_normals)
    mesh_f = np.asarray(mesh.triangles)
    num_vertices = len(mesh_v)

    # topology edges
    print(" gathering topological edges.")
    tpl_e = torch.from_numpy(get_tpl_edges(mesh_v, mesh_f).T).long()
    tpl_e, _ = add_self_loops(tpl_e, num_nodes=num_vertices)
    # geodesic edges
    print(" gathering geodesic edges.")
    geo_e = torch.from_numpy(get_geo_edges(surface_geodesic, mesh_v).T).long()
    geo_e, _ = add_self_loops(geo_e, num_nodes=num_vertices)
    # single mesh -> every vertex belongs to batch 0
    batch = torch.zeros(num_vertices, dtype=torch.long)

    num_joint = len(pred_joints)
    pair_all = []
    for joint1_id in range(num_joint):
        for joint2_id in range(joint1_id + 1, num_joint):
            joint1, joint2 = pred_joints[joint1_id], pred_joints[joint2_id]
            dist = np.linalg.norm(joint1 - joint2)
            bone_samples = sample_on_bone(joint1, joint2)
            bone_samples_inside, _ = inside_check(bone_samples, vox)
            # NOTE(review): inside_check returns the samples INSIDE the mesh, so this
            # is the inside fraction despite its historical name "outside_proportion".
            # Left as-is; presumably the network was trained on this same feature.
            inside_proportion = len(bone_samples_inside) / (len(bone_samples) + 1e-10)
            pair_all.append([joint1_id, joint2_id, dist, inside_proportion, 1])
    # reshape keeps a (0, 5) shape when fewer than two joints produce no pairs
    pair_all = torch.from_numpy(np.array(pair_all, dtype=np.float64).reshape(-1, 5)).float()
    num_pair = len(pair_all)

    # Tile (then truncate) the joints so y has exactly one row per mesh vertex;
    # the real joint count is carried separately in num_joint.
    y_joints = np.asarray(pred_joints)
    if 0 < num_joint < num_vertices:
        reps = -(-num_vertices // num_joint)  # ceil(num_vertices / num_joint)
        y_joints = np.tile(y_joints, (reps, 1))
    y_joints = torch.from_numpy(y_joints[:num_vertices, :]).float()

    # x is cast to float32 to match pos/y (it was previously left as float64).
    data = Data(x=torch.from_numpy(mesh_vn).float(), pos=torch.from_numpy(mesh_v).float(), batch=batch,
                y=y_joints, pairs=pair_all, num_pair=[num_pair], tpl_edge_index=tpl_e,
                geo_edge_index=geo_e, num_joint=[num_joint]).to(device)
    return data
def _load_eval_model(model, checkpoint_path):
    """Move model to device, switch it to eval mode, load checkpoint['state_dict'] and return it."""
    model.to(device)
    model.eval()
    # map_location lets checkpoints saved on a GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    return model


def _connection_cost(data, connectivity_model):
    """Return the symmetric (num_joint x num_joint) cost matrix -log(p + 1e-10) of pairwise connection probabilities."""
    with torch.no_grad():
        logits, _ = connectivity_model(data)
        connect_prob = torch.sigmoid(logits)
    pair_idx = data.pairs.long().cpu().numpy()
    num_joint = data.num_joint[0]
    prob_matrix = np.zeros((num_joint, num_joint))
    prob_matrix[pair_idx[:, 0], pair_idx[:, 1]] = connect_prob.cpu().numpy().squeeze()
    # pairs only list (i, j) with i < j; mirror to fill the other triangle
    prob_matrix = prob_matrix + prob_matrix.transpose()
    return -np.log(prob_matrix + 1e-10)


def run_mst_generate(args):
    """
    generate skeleton in batch

    For every model id in ``<dataset_folder>/test_final.txt`` this does the following:
    predict joints, build the network input, pick a root with RootNet and score
    joint pairs with BoneNet. It then builds the skeleton with a symmetric Prim
    MST and writes ``<id>_skel.jpg`` and ``<id>_skel.txt`` into ``res_folder``.

    :param args: input folder path and data folder path (reads dataset_folder, res_folder, rootnet, bonenet, threshold_best)
    :raises RuntimeError: if the MST parent array contains no root (-1)
    """
    # dtype=int: the np.int alias was removed in NumPy 1.24.
    # atleast_1d: loadtxt returns a 0-d array when the file holds a single id.
    test_list = np.atleast_1d(np.loadtxt(os.path.join(args.dataset_folder, 'test_final.txt'), dtype=int))
    root_select_model = _load_eval_model(ROOTNET(), args.rootnet)
    connectivity_model = _load_eval_model(PairCls(), args.bonenet)
    for model_id in test_list:
        print(model_id)
        pred_joints, vox = predict_joints(model_id, args)
        mesh_filename = os.path.join(args.dataset_folder, 'obj_remesh/{:d}.obj'.format(model_id))
        mesh = o3d.io.read_triangle_mesh(mesh_filename)
        surface_geodesic = calc_surface_geodesic(mesh)
        data = create_single_data(mesh, vox, surface_geodesic, pred_joints)
        root_id = getInitId(data, root_select_model)
        cost_matrix = _connection_cost(data, connectivity_model)
        cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints, vox)
        parent, _, _ = primMST_symmetry(cost_matrix, root_id, pred_joints)
        root_idx = next((i for i, p in enumerate(parent) if p == -1), None)
        if root_idx is None:
            raise RuntimeError('MST of model {:d} has no root joint'.format(model_id))
        skel = Skel()
        skel.root = TreeNode('root', tuple(pred_joints[root_idx]))
        loadSkel_recur(skel.root, root_idx, None, pred_joints, parent)
        img = show_obj_skel(mesh_filename, skel.root)
        cv2.imwrite(os.path.join(args.res_folder, '{:d}_skel.jpg'.format(model_id)), img[:, :, ::-1])
        skel.save(os.path.join(args.res_folder, '{:d}_skel.txt'.format(model_id)))
if __name__ == '__main__':
    cli = argparse.ArgumentParser(description='')
    # (flag, default, type) for each command-line option, in registration order.
    cli_options = (
        ('--dataset_folder', '/media/zhanxu/4T1/ModelResource_RigNetv1_preproccessed/', str),
        ('--res_folder', 'results/gcn_meanshift/best_25/', str),
        ('--rootnet', 'checkpoints/rootnet/model_best.pth.tar', str),
        ('--bonenet', 'checkpoints/bonenet/model_best.pth.tar', str),
        ('--threshold_best', 1e-5, float),
    )
    for flag, default_value, value_type in cli_options:
        cli.add_argument(flag, default=default_value, type=value_type)
    args = cli.parse_args()
    print(args)
    run_mst_generate(args)
|