# Rignet / RigNet / mst_generate.py
# Uploaded by ckc99u ("Upload 116 files"), revision 3b8bcb1 (verified).
#-------------------------------------------------------------------------------
# Name: mst_generate.py
# Purpose: Generate skeleton as a tree based on predicted joints.
# RigNet Copyright 2020 University of Massachusetts
# RigNet is made available under General Public License Version 3 (GPLv3), or under a Commercial License.
# Please see the LICENSE README.txt file in the main directory for more information and instruction on using and licensing RigNet.
#-------------------------------------------------------------------------------
import os
import cv2
import argparse
import numpy as np
import open3d as o3d
from utils import binvox_rw
from utils.tree_utils import TreeNode
from utils.rig_parser import Skel
from utils.vis_utils import show_obj_skel, draw_shifted_pts
from utils.io_utils import readPly
from utils.cluster_utils import meanshift_cluster, nms_meanshift
from utils.mst_utils import primMST_symmetry, loadSkel_recur, increase_cost_for_outside_bone, flip, inside_check, sample_on_bone
from gen_dataset import get_geo_edges, get_tpl_edges
from geometric_proc.common_ops import calc_surface_geodesic
import torch
from torch_geometric.data import Data
from torch_geometric.utils import add_self_loops
from models.ROOT_GCN import ROOTNET
from models.PairCls_GCN import PairCls
# Shared compute device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
def predict_joints(model_id, args):
    """
    Predict joints for a specified model.

    Reads ``<res_folder>/<id>.ply`` (raw joint predictions), ``<id>_attn.npy``
    (per-point attention) and ``<id>_bandwidth.npy`` (its first entry is used as
    the clustering bandwidth), plus ``<dataset_folder>/vox/<id>.binvox``.
    The raw points are filtered with ``inside_check``, thresholded on attention,
    mirrored across x, clustered with ``meanshift_cluster`` and finally reduced
    with ``nms_meanshift`` / ``flip``.

    :param model_id: processed model ID number (formatted with ``{:d}``)
    :param args: namespace providing ``dataset_folder``, ``res_folder`` and ``threshold_best``
    :return: predicted joints, and voxelized mesh
    """
    vox_file = os.path.join(args.dataset_folder, 'vox/', '{:d}.binvox'.format(model_id))
    raw_pred = os.path.join(args.res_folder, '{:d}.ply'.format(model_id))
    pred_attn = np.load(os.path.join(args.res_folder, '{:d}_attn.npy'.format(model_id)))
    with open(vox_file, 'rb') as fvox:
        vox = binvox_rw.read_as_3d_array(fvox)

    pred_joints = readPly(raw_pred)
    pred_joints, index_inside = inside_check(pred_joints, vox)
    pred_attn = pred_attn[index_inside, :]
    bandwidth = np.load(os.path.join(args.res_folder, '{:d}_bandwidth.npy'.format(model_id)))[0]

    # Keep only points with non-negligible attention. The mask is taken from the
    # first column rather than from squeeze(): squeeze() collapses to a 0-d value
    # when a single point is left, which would add an axis on indexing.
    # NOTE(review): assumes pred_attn is (N, 1) -- confirm against the producer of *_attn.npy.
    confident = pred_attn[:, 0] > 1e-3
    pred_joints = pred_joints[confident]
    pred_attn = pred_attn[confident]

    # reflect raw points (negate x) and duplicate the attention accordingly
    pred_joints_reflect = pred_joints * np.array([[-1, 1, 1]])
    pred_joints = np.concatenate((pred_joints, pred_joints_reflect), axis=0)
    pred_attn = np.tile(pred_attn, (2, 1))

    pred_joints = meanshift_cluster(pred_joints, bandwidth, pred_attn, max_iter=20)

    # Per-point density: sum over all points of max(bandwidth^2 - squared distance, 0).
    Y_dist = np.sum((pred_joints[np.newaxis, ...] - pred_joints[:, np.newaxis, :]) ** 2, axis=2)
    density = np.sum(np.maximum(bandwidth ** 2 - Y_dist, 0.0), axis=0)
    density_ratio = density / np.sum(density)

    def _select(mask):
        # Run NMS + flip on the masked subset; joints and densities always share one mask,
        # so the two arrays handed to nms_meanshift cannot get out of sync.
        selected = nms_meanshift(pred_joints[mask], density[mask], bandwidth)
        selected, _ = flip(selected)
        return selected

    pred_joints_ = _select(density_ratio > args.threshold_best)

    # Relax the density threshold until at least two joints survive (or it bottoms out).
    reduce_threshold = args.threshold_best
    while len(pred_joints_) < 2 and reduce_threshold > 1e-7:
        reduce_threshold = reduce_threshold / 1.3
        pred_joints_ = _select(density_ratio > reduce_threshold)

    if reduce_threshold <= 1e-7:
        # Threshold exhausted: fall back to NMS over every clustered point. The full
        # density vector must be paired with the full point set, not the filtered one.
        pred_joints_ = nms_meanshift(pred_joints, density, bandwidth)
        pred_joints_, _ = flip(pred_joints_)

    return pred_joints_, vox
def getInitId(data, model):
    """
    Pick the root joint by running the root-selection network.

    :param data: input passed straight to ``model``
    :param model: callable invoked as ``model(data, shuffle=False)``; its first
                  return value holds one score per candidate joint
    :return: index of the highest-scoring candidate (argmax over the flattened scores)
    """
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        logits, _unused = model(data, shuffle=False)
        scores = torch.sigmoid(logits).data.cpu().numpy()
    return np.argmax(scores)
def create_single_data(mesh, vox, surface_geodesic, pred_joints):
    """
    Create data used as input to networks, wrapped by the Data structure of the pytorch-geometric library.

    :param mesh: input mesh loaded by open3d (``vertices``, ``vertex_normals`` and ``triangles`` are read)
    :param vox: voxelized mesh
    :param surface_geodesic: geodesic distance matrix of all vertices
    :param pred_joints: predicted joints, one row per joint
    :return: wrapped data structure, moved to the module-level ``device``
    :raises ValueError: if ``pred_joints`` is empty
    """
    mesh_v = np.asarray(mesh.vertices)
    mesh_vn = np.asarray(mesh.vertex_normals)
    mesh_f = np.asarray(mesh.triangles)
    pred_joints = np.asarray(pred_joints)
    num_vertex = len(mesh_v)
    num_joint = len(pred_joints)
    if num_joint == 0:
        # Previously surfaced as an opaque ZeroDivisionError in the padding step below.
        raise ValueError('create_single_data needs at least one predicted joint')

    # vertices: positions and normals side by side, as float32
    v = torch.from_numpy(np.concatenate((mesh_v, mesh_vn), axis=1)).float()

    # topology edges
    print(" gathering topological edges.")
    tpl_e = torch.from_numpy(get_tpl_edges(mesh_v, mesh_f).T).long()
    tpl_e, _ = add_self_loops(tpl_e, num_nodes=v.size(0))

    # geodesic edges
    print(" gathering geodesic edges.")
    geo_e = torch.from_numpy(get_geo_edges(surface_geodesic, mesh_v).T).long()
    geo_e, _ = add_self_loops(geo_e, num_nodes=v.size(0))

    # single-sample batch: every vertex belongs to graph 0
    batch = torch.zeros(num_vertex, dtype=torch.long)

    # One row per unordered joint pair: [id1, id2, distance, kept-sample fraction, 1].
    pair_all = []
    for joint1_id in range(num_joint):
        for joint2_id in range(joint1_id + 1, num_joint):
            dist = np.linalg.norm(pred_joints[joint1_id] - pred_joints[joint2_id])
            bone_samples = sample_on_bone(pred_joints[joint1_id], pred_joints[joint2_id])
            bone_samples_inside, _ = inside_check(bone_samples, vox)
            # Fraction of bone samples that inside_check keeps. The original code called this
            # "outside_proportion"; the value itself is unchanged so the feature stays as before.
            kept_fraction = len(bone_samples_inside) / (len(bone_samples) + 1e-10)
            pair_all.append([joint1_id, joint2_id, dist, kept_fraction, 1])
    # reshape keeps the (num_pair, 5) layout even when there is only one joint and no pair
    pair_all = torch.from_numpy(np.array(pair_all, dtype=np.float64).reshape(-1, 5)).float()
    num_pair = len(pair_all)

    # Make the joint array exactly one row per vertex: repeat the joints cyclically when
    # there are fewer joints than vertices, truncate when there are more.
    padded_joints = pred_joints[np.arange(num_vertex) % num_joint]
    padded_joints = torch.from_numpy(padded_joints).float()

    # x (normals) and pos (positions) both come from the float32 tensor `v`, so the
    # two feature tensors share a dtype.
    data = Data(x=v[:, 3:6], pos=v[:, 0:3], batch=batch, y=padded_joints,
                pairs=pair_all, num_pair=[num_pair], tpl_edge_index=tpl_e, geo_edge_index=geo_e,
                num_joint=[num_joint]).to(device)
    return data
def run_mst_generate(args):
    """
    Generate skeletons in batch.

    For every model ID listed in ``<dataset_folder>/test_final.txt`` this predicts
    joints, selects a root with ROOTNET, scores joint pairs with PairCls, builds a
    tree with ``primMST_symmetry`` and writes ``<id>_skel.jpg`` and ``<id>_skel.txt``
    into ``args.res_folder``.

    :param args: namespace providing ``dataset_folder``, ``res_folder``, ``rootnet``,
                 ``bonenet`` and ``threshold_best``
    """
    # `np.int` was removed in NumPy 1.24; the builtin int is the documented replacement.
    # ndmin=1 keeps a one-line list iterable (loadtxt would otherwise return a 0-d array).
    test_list = np.loadtxt(os.path.join(args.dataset_folder, 'test_final.txt'), dtype=int, ndmin=1)

    root_select_model = ROOTNET()
    root_select_model.to(device)
    root_select_model.eval()
    # map_location lets checkpoints saved on a GPU load on CPU-only machines.
    root_checkpoint = torch.load(args.rootnet, map_location=device)
    root_select_model.load_state_dict(root_checkpoint['state_dict'])

    connectivity_model = PairCls()
    connectivity_model.to(device)
    connectivity_model.eval()
    conn_checkpoint = torch.load(args.bonenet, map_location=device)
    connectivity_model.load_state_dict(conn_checkpoint['state_dict'])

    for model_id in test_list:
        print(model_id)
        pred_joints, vox = predict_joints(model_id, args)
        mesh_filename = os.path.join(args.dataset_folder, 'obj_remesh/{:d}.obj'.format(model_id))
        mesh = o3d.io.read_triangle_mesh(mesh_filename)
        surface_geodesic = calc_surface_geodesic(mesh)
        data = create_single_data(mesh, vox, surface_geodesic, pred_joints)
        root_id = getInitId(data, root_select_model)

        with torch.no_grad():
            # Call the module itself rather than .forward() so registered hooks run.
            pair_logits, _ = connectivity_model(data)
            connect_prob = torch.sigmoid(pair_logits)

        # Scatter the per-pair probabilities into a symmetric joint-by-joint matrix.
        pair_idx = data.pairs.long().data.cpu().numpy()
        num_joint = data.num_joint[0]
        cost_matrix = np.zeros((num_joint, num_joint))
        cost_matrix[pair_idx[:, 0], pair_idx[:, 1]] = connect_prob.data.cpu().numpy().reshape(-1)
        cost_matrix = cost_matrix + cost_matrix.transpose()
        # Negative log turns high connection probability into low edge cost;
        # the epsilon keeps log() finite for zero entries (including the diagonal).
        cost_matrix = -np.log(cost_matrix + 1e-10)
        cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints, vox)

        skel = Skel()
        parent, key, root_id = primMST_symmetry(cost_matrix, root_id, pred_joints)
        # The joint whose parent is -1 is the tree root; the recursion starts from it.
        for joint_idx, parent_idx in enumerate(parent):
            if parent_idx == -1:
                skel.root = TreeNode('root', tuple(pred_joints[joint_idx]))
                break
        loadSkel_recur(skel.root, joint_idx, None, pred_joints, parent)

        img = show_obj_skel(mesh_filename, skel.root)
        # reverse the channel axis before handing the image to OpenCV
        cv2.imwrite(os.path.join(args.res_folder, '{:d}_skel.jpg'.format(model_id)), img[:, :, ::-1])
        skel.save(os.path.join(args.res_folder, '{:d}_skel.txt'.format(model_id)))
if __name__ == '__main__':
    # Script entry point: parse the command-line options, echo them, run the batch.
    parser = argparse.ArgumentParser(description='')
    # (flag, default, type) for each option, registered in this order.
    for flag, default, kind in (
        ('--dataset_folder', '/media/zhanxu/4T1/ModelResource_RigNetv1_preproccessed/', str),
        ('--res_folder', 'results/gcn_meanshift/best_25/', str),
        ('--rootnet', 'checkpoints/rootnet/model_best.pth.tar', str),
        ('--bonenet', 'checkpoints/bonenet/model_best.pth.tar', str),
        ('--threshold_best', 1e-5, float),
    ):
        parser.add_argument(flag, default=default, type=kind)
    args = parser.parse_args()
    print(args)
    run_mst_generate(args)