|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
import cv2
|
|
|
import argparse
|
|
|
import numpy as np
|
|
|
import open3d as o3d
|
|
|
from utils import binvox_rw
|
|
|
from utils.tree_utils import TreeNode
|
|
|
from utils.rig_parser import Skel
|
|
|
from utils.vis_utils import show_obj_skel, draw_shifted_pts
|
|
|
from utils.io_utils import readPly
|
|
|
from utils.cluster_utils import meanshift_cluster, nms_meanshift
|
|
|
from utils.mst_utils import primMST_symmetry, loadSkel_recur, increase_cost_for_outside_bone, flip, inside_check, sample_on_bone
|
|
|
from gen_dataset import get_geo_edges, get_tpl_edges
|
|
|
from geometric_proc.common_ops import calc_surface_geodesic
|
|
|
|
|
|
import torch
|
|
|
from torch_geometric.data import Data
|
|
|
from torch_geometric.utils import add_self_loops
|
|
|
|
|
|
from models.ROOT_GCN import ROOTNET
|
|
|
from models.PairCls_GCN import PairCls
|
|
|
|
|
|
# Run on the first CUDA GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
|
|
|
|
|
|
|
|
|
def _select_joints_by_density(pred_joints, density, bandwidth, threshold):
    """
    Keep high-density cluster centres, run NMS on them, then apply ``flip``.

    The relative-density cut-off starts at ``threshold`` and is divided by 1.3
    until at least two joints survive or it drops to 1e-7. In the latter case
    NMS is run over *all* clustered points with their full densities.

    :param pred_joints: clustered points, one row per point
    :param density: per-point kernel density, aligned with ``pred_joints``
    :param bandwidth: mean-shift bandwidth used by ``nms_meanshift``
    :param threshold: initial threshold on ``density / density.sum()``
    :return: selected joints as returned by ``flip``
    """
    density_ratio = density / np.sum(density)

    def _nms_flip(joints, dens):
        # flip presumably enforces left/right symmetry -- TODO confirm in mst_utils.
        selected, _ = flip(nms_meanshift(joints, dens, bandwidth))
        return selected

    mask = density_ratio > threshold
    selected = _nms_flip(pred_joints[mask], density[mask])

    reduce_threshold = threshold
    while len(selected) < 2 and reduce_threshold > 1e-7:
        reduce_threshold = reduce_threshold / 1.3
        # Use one mask for points and densities so both arrays stay aligned.
        mask = density_ratio >= reduce_threshold
        selected = _nms_flip(pred_joints[mask], density[mask])

    if reduce_threshold <= 1e-7:
        # Fallback: NMS over every clustered point, paired with its own density.
        selected = _nms_flip(pred_joints, density)
    return selected


def predict_joints(model_id, args):
    """
    Predict joints for a specified model.

    Loads the raw joint votes ``<res_folder>/<id>.ply``, their attention weights
    ``<id>_attn.npy`` and the clustering bandwidth ``<id>_bandwidth.npy``. It keeps
    votes inside the voxelized mesh whose attention exceeds 1e-3 and mirrors them
    across the x = 0 plane. It then clusters them with mean-shift and selects
    high-density cluster centres.

    :param model_id: processed model ID number
    :param args: parsed CLI arguments; reads ``dataset_folder``, ``res_folder``
        and ``threshold_best``
    :return: (predicted joints, voxelized mesh)
    """
    vox_file = os.path.join(args.dataset_folder, 'vox', '{:d}.binvox'.format(model_id))
    raw_pred = os.path.join(args.res_folder, '{:d}.ply'.format(model_id))
    pred_attn = np.load(os.path.join(args.res_folder, '{:d}_attn.npy'.format(model_id)))

    with open(vox_file, 'rb') as fvox:
        vox = binvox_rw.read_as_3d_array(fvox)

    # Discard votes that fall outside the voxelized shape.
    pred_joints = readPly(raw_pred)
    pred_joints, index_inside = inside_check(pred_joints, vox)
    pred_attn = pred_attn[index_inside, :]

    bandwidth = np.load(os.path.join(args.res_folder, '{:d}_bandwidth.npy'.format(model_id)))[0]

    # Drop votes with negligible attention. Column 0 is indexed explicitly (instead of
    # squeeze()) so a single surviving vote still gives a 1-D mask.
    # Assumes pred_attn is (N, 1) -- TODO confirm.
    keep = pred_attn[:, 0] > 1e-3
    pred_joints = pred_joints[keep]
    pred_attn = pred_attn[keep]

    # Mirror every vote across the x = 0 plane and duplicate its weight.
    pred_joints_reflect = pred_joints * np.array([[-1, 1, 1]])
    pred_joints = np.concatenate((pred_joints, pred_joints_reflect), axis=0)
    pred_attn = np.tile(pred_attn, (2, 1))

    pred_joints = meanshift_cluster(pred_joints, bandwidth, pred_attn, max_iter=20)

    # Density of each clustered point: sum over all points of max(bw^2 - d^2, 0).
    Y_dist = np.sum((pred_joints[np.newaxis, ...] - pred_joints[:, np.newaxis, :]) ** 2, axis=2)
    density = np.sum(np.maximum(bandwidth ** 2 - Y_dist, 0.0), axis=0)

    pred_joints = _select_joints_by_density(pred_joints, density, bandwidth, args.threshold_best)
    return pred_joints, vox
|
|
|
|
|
|
|
|
|
def getInitId(data, model):
    """
    Predict the root joint ID via ROOTNET.

    Calls ``model(data, shuffle=False)`` without gradient tracking. It applies a
    sigmoid to the first of the two returned values and returns the flat index of
    the largest probability.

    :param data: network input (e.g. the ``Data`` built by ``create_single_data``)
    :param model: root-selection network returning a 2-tuple ``(logits, extra)``
    :return: ``np.argmax`` of the sigmoid probabilities (index into the flattened array)
    """
    with torch.no_grad():
        logits, _unused = model(data, shuffle=False)
        probabilities = logits.sigmoid().data.cpu().numpy()
    return np.argmax(probabilities)
|
|
|
|
|
|
|
|
|
def _build_joint_pairs(pred_joints, vox):
    """Return one row ``[i, j, dist, inside_ratio, 1]`` per joint pair i < j, shape (P, 5)."""
    pairs = []
    num_joint = len(pred_joints)
    for joint1_id in range(num_joint):
        for joint2_id in range(joint1_id + 1, num_joint):
            p1, p2 = pred_joints[joint1_id], pred_joints[joint2_id]
            dist = np.linalg.norm(p1 - p2)
            bone_samples = sample_on_bone(p1, p2)
            bone_samples_inside, _ = inside_check(bone_samples, vox)
            # Fraction of bone samples that lie INSIDE the voxelized shape. The value
            # is unchanged; downstream models presumably expect it (confirm vs gen_dataset).
            inside_proportion = len(bone_samples_inside) / (len(bone_samples) + 1e-10)
            pairs.append([joint1_id, joint2_id, dist, inside_proportion, 1])
    # reshape keeps a well-formed (0, 5) array when there are fewer than two joints.
    return np.array(pairs, dtype=np.float64).reshape(-1, 5)


def _tile_joints_to_length(pred_joints, length):
    """Repeat the rows of ``pred_joints`` cyclically and truncate to exactly ``length`` rows."""
    reps = -(-length // len(pred_joints))  # ceil(length / num_joints)
    return np.tile(pred_joints, (max(reps, 1), 1))[:length]


def create_single_data(mesh, vox, surface_geodesic, pred_joints):
    """
    Create network input wrapped in a pytorch-geometric ``Data`` object.

    :param mesh: input mesh loaded by open3d (vertices, vertex normals, triangles are read)
    :param vox: voxelized mesh, used to test how much of each candidate bone lies inside
    :param surface_geodesic: geodesic distance matrix of all vertices
    :param pred_joints: predicted joints, one row per joint
    :return: ``Data`` moved to ``device``. It holds float32 vertex normals ``x`` and
        positions ``pos``, plus topological/geodesic edge indices with self-loops.
        ``pairs`` has one ``[i, j, dist, inside_ratio, 1]`` row per joint pair, and
        ``y`` is the joints tiled/truncated to one row per vertex.
    :raises ValueError: if ``pred_joints`` is empty
    """
    mesh_v = np.asarray(mesh.vertices)
    mesh_vn = np.asarray(mesh.vertex_normals)
    mesh_f = np.asarray(mesh.triangles)
    pred_joints = np.asarray(pred_joints)
    num_joint = len(pred_joints)
    if num_joint == 0:
        raise ValueError('create_single_data: no predicted joints')
    num_vertices = len(mesh_v)

    print(" gathering topological edges.")
    tpl_e = torch.from_numpy(get_tpl_edges(mesh_v, mesh_f).T).long()
    tpl_e, _ = add_self_loops(tpl_e, num_nodes=num_vertices)

    print(" gathering geodesic edges.")
    geo_e = torch.from_numpy(get_geo_edges(surface_geodesic, mesh_v).T).long()
    geo_e, _ = add_self_loops(geo_e, num_nodes=num_vertices)

    # Single-graph batch: every vertex belongs to graph 0.
    batch = torch.zeros(num_vertices, dtype=torch.long)

    pair_all = torch.from_numpy(_build_joint_pairs(pred_joints, vox)).float()
    # y is padded to one row per vertex; the real joint count travels in num_joint.
    y = torch.from_numpy(_tile_joints_to_length(pred_joints, num_vertices)).float()

    # x is cast to float32 like every other feature (it was previously left as float64).
    data = Data(x=torch.from_numpy(mesh_vn).float(), pos=torch.from_numpy(mesh_v).float(), batch=batch, y=y,
                pairs=pair_all, num_pair=[len(pair_all)], tpl_edge_index=tpl_e, geo_edge_index=geo_e,
                num_joint=[num_joint]).to(device)
    return data
|
|
|
|
|
|
|
|
|
def _load_eval_model(model, checkpoint_path):
    """Move ``model`` to ``device``, set eval mode and load ``state_dict`` from a checkpoint file."""
    model.to(device)
    model.eval()
    # map_location lets GPU-saved checkpoints load on CPU-only machines.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    return model


def _connectivity_cost(data, connectivity_model):
    """Return a symmetric (J, J) cost matrix ``-log(p + 1e-10)`` from pairwise connection probabilities."""
    with torch.no_grad():
        logits, _ = connectivity_model(data)
    connect_prob = torch.sigmoid(logits).cpu().numpy().reshape(-1)
    pair_idx = data.pairs.long().cpu().numpy()
    num_joint = data.num_joint[0]
    prob_matrix = np.zeros((num_joint, num_joint))
    prob_matrix[pair_idx[:, 0], pair_idx[:, 1]] = connect_prob
    prob_matrix = prob_matrix + prob_matrix.transpose()
    return -np.log(prob_matrix + 1e-10)


def _build_skel(parent, pred_joints):
    """Build a ``Skel`` rooted at the joint whose parent is -1 and attach children recursively."""
    root_idx = next((i for i, p in enumerate(parent) if p == -1), None)
    if root_idx is None:
        raise ValueError('MST result has no root joint (no parent == -1)')
    skel = Skel()
    skel.root = TreeNode('root', tuple(pred_joints[root_idx]))
    loadSkel_recur(skel.root, root_idx, None, pred_joints, parent)
    return skel


def run_mst_generate(args):
    """
    Generate skeletons in batch for every model ID in ``<dataset_folder>/test_final.txt``.

    For each model, the function:
    - clusters joints with ``predict_joints`` and builds the graph input;
    - picks a root with ROOTNET and scores joint pairs with the connectivity network (PairCls);
    - converts the probabilities to costs and runs ``primMST_symmetry``;
    - writes ``<id>_skel.jpg`` and ``<id>_skel.txt`` into ``res_folder``.

    :param args: parsed CLI arguments; reads ``dataset_folder``, ``res_folder``,
        ``rootnet``, ``bonenet`` (and ``threshold_best`` via ``predict_joints``)
    """
    # np.int was removed in NumPy 1.24. atleast_1d keeps a single-entry list iterable.
    test_list = np.atleast_1d(np.loadtxt(os.path.join(args.dataset_folder, 'test_final.txt'), dtype=int))

    root_select_model = _load_eval_model(ROOTNET(), args.rootnet)
    connectivity_model = _load_eval_model(PairCls(), args.bonenet)

    for model_id in test_list:
        print(model_id)
        pred_joints, vox = predict_joints(model_id, args)
        mesh_filename = os.path.join(args.dataset_folder, 'obj_remesh/{:d}.obj'.format(model_id))
        mesh = o3d.io.read_triangle_mesh(mesh_filename)
        surface_geodesic = calc_surface_geodesic(mesh)
        data = create_single_data(mesh, vox, surface_geodesic, pred_joints)
        root_id = getInitId(data, root_select_model)

        cost_matrix = _connectivity_cost(data, connectivity_model)
        cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints, vox)

        parent, _, _ = primMST_symmetry(cost_matrix, root_id, pred_joints)
        skel = _build_skel(parent, pred_joints)

        img = show_obj_skel(mesh_filename, skel.root)
        # Reverse the channel order before writing with cv2 (presumably RGB -> BGR).
        cv2.imwrite(os.path.join(args.res_folder, '{:d}_skel.jpg'.format(model_id)), img[:, :, ::-1])
        skel.save(os.path.join(args.res_folder, '{:d}_skel.txt'.format(model_id)))
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Generate skeletons from clustered joint predictions using ROOTNET, '
                    'the bone connectivity network and a symmetric minimum spanning tree.')
    parser.add_argument('--dataset_folder', default='/media/zhanxu/4T1/ModelResource_RigNetv1_preproccessed/',
                        type=str, help='dataset root containing test_final.txt, vox/ and obj_remesh/')
    parser.add_argument('--res_folder', default='results/gcn_meanshift/best_25/', type=str,
                        help='folder holding <id>.ply, <id>_attn.npy and <id>_bandwidth.npy; '
                             'skeleton images and text files are written here')
    parser.add_argument('--rootnet', default='checkpoints/rootnet/model_best.pth.tar', type=str,
                        help='ROOTNET checkpoint file (must contain a "state_dict" entry)')
    parser.add_argument('--bonenet', default='checkpoints/bonenet/model_best.pth.tar', type=str,
                        help='bone connectivity (PairCls) checkpoint file (must contain a "state_dict" entry)')
    parser.add_argument('--threshold_best', default=1e-5, type=float,
                        help='initial relative-density threshold for keeping joint clusters')
    args = parser.parse_args()
    print(args)
    run_mst_generate(args)
|
|
|
|