|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
|
|
|
import torch
|
|
|
import numpy as np
|
|
|
import glob
|
|
|
import itertools as it
|
|
|
from utils import binvox_rw
|
|
|
from torch_geometric.data import Data, InMemoryDataset
|
|
|
from torch_geometric.utils import add_self_loops
|
|
|
|
|
|
|
|
|
class SkeletonData(Data):
    """Graph sample holding mesh vertices, joints and joint-pair features.

    Every constructor argument is stored as an attribute of the same name.
    The class only customises how ``pairs`` is offset when several samples
    are collated into one batch (see ``__inc__``).
    """

    def __init__(self, x=None, pos=None, name=None, mask=None, joints=None,
                 tpl_edge_index=None, geo_edge_index=None, pairs=None, pair_attr=None):
        """Store the given fields on the data object.

        Args:
            x: Per-vertex feature tensor.
            pos: Per-vertex position tensor.
            name: Identifier of the sample.
            mask: Per-vertex mask tensor.
            joints: Joint tensor; its first dimension is used as the batching
                increment for ``pairs``.
            tpl_edge_index: First edge-index tensor (file suffix ``_tpl_e``
                in the dataset code; presumably topological edges).
            geo_edge_index: Second edge-index tensor (file suffix ``_geo_e``
                in the dataset code; presumably geodesic edges).
            pairs: Joint-index pairs.
            pair_attr: Per-pair attribute tensor.
        """
        super(SkeletonData, self).__init__()
        self.x = x
        self.pos = pos
        self.name = name
        self.mask = mask
        self.joints = joints
        self.tpl_edge_index = tpl_edge_index
        self.geo_edge_index = geo_edge_index
        self.pairs = pairs
        self.pair_attr = pair_attr

    def __inc__(self, key, value, *args, **kwargs):
        """Return the offset added to attribute ``key`` when batching samples.

        ``pairs`` indexes into ``joints``, so it is shifted by the number of
        joints in this sample; every other key uses the base-class rule.

        Extra positional/keyword arguments are accepted and forwarded
        unchanged, because newer torch_geometric releases call ``__inc__``
        with additional arguments; older releases pass none, so the call
        stays compatible with both.
        """
        if key == 'pairs':
            return self.joints.size(0)
        return super(SkeletonData, self).__inc__(key, value, *args, **kwargs)
|
|
|
|
|
|
|
|
|
class GraphDataset(InMemoryDataset):
    """In-memory dataset of ``SkeletonData`` samples built from raw files.

    A sample is discovered through its ``*_v.txt`` file in ``root``; the
    sibling files ``*_attn.txt``, ``*_tpl_e.txt``, ``*_geo_e.txt``,
    ``*_j.txt``, ``*_adj.txt`` and ``*.binvox`` share the same prefix.
    ``process`` converts all samples and stores the collated result in a
    single ``<root-basename>_skeleton_data.pt`` file.
    """

    def __init__(self, root):
        """Initialise the base dataset for ``root`` and load the processed file."""
        super(GraphDataset, self).__init__(root)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Paths of every ``*_v.txt`` file found directly under ``root``."""
        raw_v_filelist = glob.glob(os.path.join(self.root, '*_v.txt'))
        return raw_v_filelist

    @property
    def processed_file_names(self):
        """Name of the processed file, derived from the last component of ``root``."""
        # os.path.basename matches split('/')[-1] on POSIX and also handles
        # platforms whose path separator is not '/'.
        return '{:s}_skeleton_data.pt'.format(os.path.basename(self.root))

    def __len__(self):
        """Number of samples, i.e. the number of raw ``*_v.txt`` paths."""
        return len(self.raw_paths)

    def download(self):
        """Nothing to download; raw files are expected to be present already."""
        pass

    def sample_on_bone(self, p_pos, ch_pos, step=0.01):
        """Sample points along the segment from ``p_pos`` to ``ch_pos``.

        The segment is divided into ``round(length / step)`` equal parts and
        the end of each part is returned, so the start point is excluded and
        the end point is (up to rounding) included.

        Args:
            p_pos: 1-D array, start of the segment.
            ch_pos: 1-D array of the same shape, end of the segment.
            step: Target spacing between consecutive samples.

        Returns:
            Array of shape ``(num_step, len(p_pos))``; it has zero rows when
            the segment is shorter than half a step.
        """
        ray = ch_pos - p_pos
        bone_length = np.sqrt(np.sum((p_pos - ch_pos) ** 2))
        # Must be a Python int: a float step count is rejected by NumPy when
        # used as a repeat/length argument.
        num_step = int(np.round(bone_length / step))
        i_step = np.arange(1, num_step + 1)
        # The tiny epsilon avoids a division by zero when num_step == 0; in
        # that case i_step is empty and no row is produced.
        unit_step = ray / (num_step + 1e-30)
        # Broadcasting (num_step, 1) * (dim,) yields one row per step.
        return p_pos + unit_step * i_step[:, np.newaxis]

    def inside_check(self, pts, vox):
        """Keep only the points that fall on an occupied voxel.

        Args:
            pts: Array of shape ``(n, 3)`` with point coordinates.
            vox: Voxel object exposing ``translate``, ``scale``, ``dims`` and
                a 3-D ``data`` array (presumably a ``binvox_rw`` model;
                verify against caller).

        Returns:
            Tuple ``(inside_pts, inside_idx)``: the retained points and their
            indices in ``pts``. ``inside_idx`` is squeezed, so it is 0-d when
            exactly one point is retained.
        """
        # Only dims[0] is used for every axis, i.e. the grid is treated as cubic.
        vc = (pts - vox.translate) / vox.scale * vox.dims[0]
        vc = np.round(vc).astype(int)
        # In-bounds test must happen before clipping, which hides overflow.
        ind1 = np.logical_and(np.all(vc >= 0, axis=1), np.all(vc < vox.dims[0], axis=1))
        # Clip so the lookup below never indexes outside the grid; points
        # that were out of bounds are rejected through ind1 anyway.
        vc = np.clip(vc, 0, vox.dims[0]-1)
        ind2 = vox.data[vc[:, 0], vc[:, 1], vc[:, 2]]
        ind = np.logical_and(ind1, ind2)
        pts = pts[ind]
        return pts, np.argwhere(ind).squeeze()

    def _pair_features(self, joints, adj, vox):
        """Build all unordered joint pairs and their attribute rows.

        Each attribute row is ``[distance, inside_ratio, adjacency]`` where
        ``inside_ratio`` is the fraction of points sampled on the segment
        between the two joints that lie on occupied voxels.
        """
        pairs = list(it.combinations(range(adj.shape[0]), 2))
        pair_attr = []
        for first, second in pairs:
            dist = np.linalg.norm(joints[first] - joints[second])
            bone_samples = self.sample_on_bone(joints[first], joints[second])
            bone_samples_inside, _ = self.inside_check(bone_samples, vox)
            # Fraction of samples INSIDE the voxels (the epsilon guards the
            # zero-sample case, which then yields 0).
            inside_ratio = len(bone_samples_inside) / (len(bone_samples) + 1e-10)
            pair_attr.append(np.array([dist, inside_ratio, adj[first, second]]))
        return np.array(pairs), np.array(pair_attr)

    def process(self):
        """Convert every raw sample to ``SkeletonData`` and save the collation."""
        data_list = []
        # Read once so the progress line does not re-query it per sample.
        raw_paths = self.raw_paths
        num_files = len(raw_paths)
        for i, v_filename in enumerate(raw_paths):
            print('preprecessing data complete: {:.4f}%'.format(100.0 * i / num_files))
            v = np.loadtxt(v_filename)
            m = np.loadtxt(v_filename.replace('_v.txt', '_attn.txt'))
            # Edge files are transposed after loading, then used as edge indices.
            tpl_e = np.loadtxt(v_filename.replace('_v.txt', '_tpl_e.txt')).T
            geo_e = np.loadtxt(v_filename.replace('_v.txt', '_geo_e.txt')).T
            joints = np.loadtxt(v_filename.replace('_v.txt', '_j.txt'))
            adj = np.loadtxt(v_filename.replace('_v.txt', '_adj.txt'), dtype=np.uint8)

            vox_file = v_filename.replace('_v.txt', '.binvox')
            with open(vox_file, 'rb') as fvox:
                vox = binvox_rw.read_as_3d_array(fvox)

            pairs, pair_attr = self._pair_features(joints, adj, vox)
            # The sample name is the integer prefix before the first '_'.
            name = int(os.path.basename(v_filename).split('_')[0])

            v = torch.from_numpy(v).float()
            m = torch.from_numpy(m).long()
            tpl_e = torch.from_numpy(tpl_e).long()
            geo_e = torch.from_numpy(geo_e).long()
            tpl_e, _ = add_self_loops(tpl_e, num_nodes=v.size(0))
            geo_e, _ = add_self_loops(geo_e, num_nodes=v.size(0))
            joints = torch.from_numpy(joints).float()
            # NOTE(review): pairs holds indices but is stored as float, as in
            # the original code; consumers presumably cast it back - confirm.
            pairs = torch.from_numpy(pairs).float()
            pair_attr = torch.from_numpy(pair_attr).float()
            # Columns 0:3 of the vertex file become pos, columns 3:6 become x.
            data_list.append(SkeletonData(x=v[:, 3:6], pos=v[:, 0:3], name=name, mask=m, joints=joints,
                                          tpl_edge_index=tpl_e, geo_edge_index=geo_e, pairs=pairs,
                                          pair_attr=pair_attr))
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
|
|
|
|