|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
from models.gcn_basic_modules import MLP, GCU
|
|
|
from torch_scatter import scatter_max
|
|
|
from torch.nn import Sequential, Linear
|
|
|
from torch_geometric.nn import PointConv, fps, radius, global_max_pool, knn_interpolate
|
|
|
# Public API: only the top-level ROOTNET model is exported via `import *`.
__all__ = ['ROOTNET']
|
|
|
|
|
|
class SAModule(torch.nn.Module):
    """Set-abstraction layer: sample centroids, group neighbours, convolve.

    Args:
        ratio: Sampling ratio forwarded to ``fps``.
        r: Radius forwarded to the ``radius`` neighbour query.
        nn: Network handed to ``PointConv``.
        max_num_neighbors: Upper bound on neighbours collected per centroid
            by the radius query. Defaults to 64, the value that used to be
            hard-coded, so existing callers are unaffected.
    """

    def __init__(self, ratio, r, nn, max_num_neighbors=64):
        super().__init__()
        self.ratio = ratio
        self.r = r
        self.max_num_neighbors = max_num_neighbors
        self.conv = PointConv(nn)

    def forward(self, x, pos, batch):
        """Downsample the point set and aggregate features onto the samples.

        Args:
            x: Per-point features passed to the convolution.
            pos: Per-point positions.
            batch: Per-point sample assignment.

        Returns:
            Tuple ``(x, pos, batch)`` where ``x`` is the convolution output
            and ``pos`` / ``batch`` are restricted to the sampled points.
        """
        idx = fps(pos, batch, ratio=self.ratio)
        # Index once and reuse for the query, the convolution and the output.
        sampled_pos, sampled_batch = pos[idx], batch[idx]

        row, col = radius(pos, sampled_pos, self.r, batch, sampled_batch,
                          max_num_neighbors=self.max_num_neighbors)
        # The query result is stacked as [col, row] before being handed to
        # the convolution together with the (all points, sampled points)
        # position pair.
        edge_index = torch.stack([col, row], dim=0)

        x = self.conv(x, (pos, sampled_pos), edge_index)
        return x, sampled_pos, sampled_batch
|
|
|
|
|
|
|
|
|
class GlobalSAModule(torch.nn.Module):
    """Global set abstraction: pool every sample down to one feature row.

    Args:
        nn: Network applied to the per-point concatenation ``[x, pos]``.
    """

    def __init__(self, nn):
        super().__init__()
        self.nn = nn

    def forward(self, x, pos, batch):
        """Encode all points and max-pool them per sample.

        Args:
            x: Per-point features.
            pos: Per-point positions; concatenated to ``x`` along dim 1.
            batch: Per-point sample assignment used for pooling.

        Returns:
            Tuple ``(x, pos, batch)`` with one row per pooled sample:
            pooled features, all-zero placeholder positions, and
            ``arange(num_samples)`` as the new batch vector.
        """
        x = self.nn(torch.cat([x, pos], dim=1))
        x = global_max_pool(x, batch)

        num_samples = x.size(0)
        # Placeholder positions keep the width of the incoming positions
        # rather than assuming 3 columns; identical for 3-D input.
        pos = pos.new_zeros((num_samples, pos.size(1)))
        batch = torch.arange(num_samples, device=batch.device)
        return x, pos, batch
|
|
|
|
|
|
|
|
|
class FPModule(torch.nn.Module):
    """Feature-propagation layer: interpolate coarse features onto a finer
    point set, optionally fuse skip features, then apply a network.

    Args:
        k: Number of neighbours passed to ``knn_interpolate``.
        nn: Network applied to the (optionally concatenated) features.
    """

    def __init__(self, k, nn):
        super().__init__()
        self.k = k
        self.nn = nn

    def forward(self, x, pos, batch, x_skip, pos_skip, batch_skip):
        """Propagate ``x`` from ``pos`` to ``pos_skip``.

        Returns:
            Tuple ``(features, pos_skip, batch_skip)``.
        """
        upsampled = knn_interpolate(x, pos, pos_skip, batch, batch_skip,
                                    k=self.k)
        # Skip features are optional; when present they are appended
        # after the interpolated ones along the channel dimension.
        fused = (upsampled if x_skip is None
                 else torch.cat([upsampled, x_skip], dim=1))
        return self.nn(fused), pos_skip, batch_skip
|
|
|
|
|
|
|
|
|
class ShapeEncoder(torch.nn.Module):
    """Encode a batch of meshes into one global feature vector per sample.

    Three stacked GCU layers are applied to the vertex positions; their
    outputs are concatenated, passed through an MLP and max-reduced over
    ``data.batch``.

    Args:
        aggr: Aggregation mode forwarded to every GCU layer.
    """

    def __init__(self, aggr='max'):
        super().__init__()
        c_in, c_1, c_2, c_3 = 3, 64, 128, 256
        self.gcu_1 = GCU(in_channels=c_in, out_channels=c_1, aggr=aggr)
        self.gcu_2 = GCU(in_channels=c_1, out_channels=c_2, aggr=aggr)
        self.gcu_3 = GCU(in_channels=c_2, out_channels=c_3, aggr=aggr)
        self.mlp_glb = MLP([(c_1 + c_2 + c_3), 128])

    def forward(self, data):
        """Return the per-sample global shape feature.

        Args:
            data: Batch object providing ``pos``, ``tpl_edge_index``,
                ``geo_edge_index`` and ``batch``.
        """
        per_layer = []
        hidden = data.pos
        for layer in (self.gcu_1, self.gcu_2, self.gcu_3):
            hidden = layer(hidden, data.tpl_edge_index, data.geo_edge_index)
            per_layer.append(hidden)

        # Fuse the outputs of all three layers before the global reduction.
        fused = self.mlp_glb(torch.cat(per_layer, dim=1))
        x_global, _ = scatter_max(fused, data.batch, dim=0)
        return x_global
|
|
|
|
|
|
|
|
|
class JointEncoder(torch.nn.Module):
    """Per-joint feature encoder built from set-abstraction and
    feature-propagation modules.

    Two ``SAModule`` stages and one ``GlobalSAModule`` stage encode the
    joint set, then three ``FPModule`` stages propagate features back to
    the original joints, using the intermediate levels as skip inputs.

    An earlier MLP-only variant used to be kept here as string literals;
    it was dead code (and one copy became ``forward.__doc__``), so it has
    been removed.
    """

    def __init__(self):
        super().__init__()
        # Encoder: the first stage's MLP is configured for 4 input channels.
        self.sa1_joint = SAModule(0.999, 0.4, MLP([4, 64, 64, 128]))
        self.sa2_joint = SAModule(0.33, 0.6, MLP([128 + 3, 128, 128, 256]))
        self.sa3_joint = GlobalSAModule(MLP([256 + 3, 256, 256, 512]))
        # Decoder: each stage's MLP input is sized as a sum of two widths.
        self.fp3_joint = FPModule(1, MLP([512 + 256, 256, 256]))
        self.fp2_joint = FPModule(3, MLP([256 + 128, 128, 128]))
        self.fp1_joint = FPModule(3, MLP([128 + 1, 128, 128]))

    def forward(self, x, pos, batch):
        """Compute one feature row per input joint.

        Args:
            x: Per-joint input features.
            pos: Per-joint positions.
            batch: Per-joint sample assignment.

        Returns:
            The feature tensor produced by the last propagation stage.
        """
        level_0 = (x, pos, batch)
        level_1 = self.sa1_joint(*level_0)
        level_2 = self.sa2_joint(*level_1)
        level_3 = self.sa3_joint(*level_2)

        # Walk back up, pairing each coarse level with its finer skip level.
        up_2 = self.fp3_joint(*level_3, *level_2)
        up_1 = self.fp2_joint(*up_2, *level_1)
        x_joint, _, _ = self.fp1_joint(*up_1, *level_0)
        return x_joint
|
|
|
|
|
|
|
|
|
class ROOTNET(torch.nn.Module):
    """Score every joint of every sample with a single scalar.

    A global shape feature (``ShapeEncoder``) is repeated for each joint,
    concatenated with a per-joint feature (``JointEncoder``) and passed
    through a small head ending in ``Linear(64, 1)``.
    """

    def __init__(self):
        super().__init__()
        self.shape_encoder = ShapeEncoder()
        self.joint_encoder = JointEncoder()
        self.back_layers = Sequential(MLP([128 + 128, 200, 64]), Linear(64, 1))

    def forward(self, data, shuffle=True):
        """Run the model on a batch.

        Args:
            data: Batch object providing ``joints``, ``joints_batch`` and
                whatever ``ShapeEncoder`` reads from it.
            shuffle: When true, joints (and their labels) are randomly
                permuted within each sample.

        Returns:
            Tuple ``(x_joint, joints_label)``: the per-joint scores and a
            label column holding 1 for the joint that was first in its
            sample before shuffling (presumably the root joint; confirm
            against the dataset) and 0 elsewhere.
        """
        joints_label = []
        joints_shuffle = []
        # Iterate over the batch ids that actually occur instead of
        # range(count): ids need not be contiguous when a sample has no
        # joints, and every id visited here has at least one joint.
        for i in torch.unique(data.joints_batch).tolist():
            joint_i = data.joints[data.joints_batch == i]
            label_i = joint_i.new_zeros((joint_i.shape[0], 1))
            label_i[0, 0] = 1

            if shuffle:
                idx = torch.randperm(label_i.nelement())
                label_i = label_i[idx]
                joint_i = joint_i[idx]
            joints_shuffle.append(joint_i)
            joints_label.append(label_i)
        joints_shuffle = torch.cat(joints_shuffle, dim=0)
        joints_label = torch.cat(joints_label)

        x_glb_shape = self.shape_encoder(data)
        # minlength pads the counts with zeros so trailing samples that have
        # no joints still line up with the rows of the shape feature.
        joints_per_sample = torch.bincount(data.joints_batch,
                                           minlength=x_glb_shape.size(0))
        shape_feature = torch.repeat_interleave(x_glb_shape,
                                                joints_per_sample, dim=0)

        # The joint encoder's input feature is the absolute value of the
        # first coordinate column of each joint.
        joint_feature = self.joint_encoder(torch.abs(joints_shuffle[:, 0:1]),
                                           joints_shuffle, data.joints_batch)

        x_joint = torch.cat([shape_feature, joint_feature], dim=1)
        x_joint = self.back_layers(x_joint)
        return x_joint, joints_label
|
|
|
|