|
|
import torch |
|
|
import torch.nn as nn |
|
|
from pointnet2_ops import pointnet2_utils |
|
|
import contextlib |
|
|
import logging |
|
|
import torch.nn.functional as F |
|
|
|
|
|
def fps(data, number):
    '''
    Furthest point sampling via the pointnet2_ops kernels.

    data   : B N 3 point coordinates
    number : int, how many points to sample per cloud
    returns: the sampled points, presumably B number 3 (gather_operation
             output is transposed back to channels-last) — confirm against
             pointnet2_ops
    NOTE(review): pointnet2_ops kernels are presumably CUDA-only — confirm.
    '''
    sampled_idx = pointnet2_utils.furthest_point_sample(data, number)
    # gather_operation works on channels-first input, so transpose to B 3 N first.
    channels_first = data.transpose(1, 2).contiguous()
    sampled = pointnet2_utils.gather_operation(channels_first, sampled_idx)
    return sampled.transpose(1, 2).contiguous()
|
|
|
|
|
|
|
|
def index_points(points, idx):
    """
    Gather points per batch element using integer indices.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]; extra trailing index dims
             (e.g. [B, S, K]) are carried through to the output
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, C])
    """
    batch_size = points.shape[0]
    # [B, 1, ..., 1]: one batch id per row of idx, broadcastable over the rest.
    batch_view = [idx.shape[0]] + [1] * (idx.dim() - 1)
    # [1, S, ...]: tile each batch id across every remaining index position.
    tile = [1] + list(idx.shape[1:])
    batch_ids = torch.arange(batch_size, dtype=torch.long, device=points.device)
    batch_ids = batch_ids.view(batch_view).repeat(tile)
    return points[batch_ids, idx, :]
|
|
|
|
|
|
|
|
|
|
|
def knn_point(nsample, xyz, new_xyz):
    """
    k-nearest-neighbour search by squared Euclidean distance.

    Input:
        nsample: number of neighbours to return per query
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: indices into N of the nsample closest points for every
                   query, [B, S, nsample]; NOT ordered by distance
                   (topk is called with sorted=False)
    """
    pairwise = square_distance(new_xyz, xyz)  # [B, S, N]
    nearest = torch.topk(pairwise, nsample, dim=-1, largest=False, sorted=False)
    return nearest.indices
|
|
|
|
|
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Uses the expansion
        |s - d|^2 = sum(s^2) + sum(d^2) - 2 * s . d
    so the whole [N, M] table is one batched matmul plus two broadcasts.
    Because of floating-point cancellation, entries for (near-)coincident
    points may come out slightly negative.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    src_norm = torch.sum(src ** 2, -1).view(batch, n_src, 1)
    dst_norm = torch.sum(dst ** 2, -1).view(batch, 1, n_dst)
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += src_norm
    dist += dst_norm
    return dist
|
|
|
|
|
|
|
|
class PatchDropout(nn.Module):
    """
    Randomly keep a subset of tokens (PatchDropout).

    Reference: https://arxiv.org/abs/2212.00794

    Args:
        prob: fraction of tokens to drop, must satisfy 0 <= prob < 1.
        exclude_first_token: if True, the first token (e.g. a CLS token) is
            always kept and re-attached in front of the surviving tokens; if
            False, it takes part in the random selection like any other token.

    NOTE: forward has no ``self.training`` check, so tokens are dropped in
    eval mode as well, and even with prob == 0 the kept tokens come back in
    a random order.
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token
        logging.info("patch dropout prob is {}".format(prob))

    def forward(self, x):
        # x: [B, T, C]; returns [B, kept (+1 if the first token is excluded), C].
        if self.exclude_first_token:
            head, tokens = x[:, :1], x[:, 1:]
        else:
            # Annotation only matters for TorchScript; head is unused here.
            head = torch.jit.annotate(torch.Tensor, x[:, :1])
            tokens = x

        n_batch = tokens.size()[0]
        n_tokens = tokens.size()[1]
        # Always keep at least one token.
        n_keep = max(1, int(n_tokens * (1 - self.prob)))

        # Top-k of Gaussian noise = a uniformly random subset of size n_keep.
        scores = torch.randn(n_batch, n_tokens)
        keep_idx = scores.topk(n_keep, dim=-1).indices
        rows = torch.arange(n_batch)[:, None]
        kept = tokens[rows, keep_idx]

        if self.exclude_first_token:
            kept = torch.cat((head, kept), dim=1)
        return kept
|
|
|
|
|
|
|
|
class Group(nn.Module):
    """
    Split a colored point cloud into local patches.

    ``num_group`` centers are chosen with furthest point sampling (``fps``)
    and each center is grouped with its ``group_size`` nearest points
    (``knn_point``).
    """

    def __init__(self, num_group, group_size):
        super().__init__()
        self.num_group = num_group  # G: number of patches / centers
        self.group_size = group_size  # M: points per patch

    def forward(self, xyz, color):
        '''
        input:
            xyz   : B N 3 point coordinates
            color : B N C_color per-point attributes for the same N points
                    (C_color is 3 for the RGB input used in this file)
        ---------------------------
        output (tuple):
            neighborhood : B G M 3, patch coordinates relative to their center
            center       : B G 3, patch centers
            features     : B G M (3 + C_color), relative xyz concatenated
                           with the gathered colors
        '''
        batch_size, num_points, _ = xyz.shape

        center = fps(xyz, self.num_group)  # B G 3
        idx = knn_point(self.group_size, xyz, center)  # B G M
        assert idx.size(1) == self.num_group
        assert idx.size(2) == self.group_size

        # Offset every sample's indices by b * N so that a single gather over
        # the flattened (B*N, ...) tensor reads from the right cloud.
        idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
        idx = (idx + idx_base).view(-1)

        # reshape (not view) so non-contiguous inputs are accepted as well.
        neighborhood = xyz.reshape(batch_size * num_points, -1)[idx, :]
        neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous()

        # Infer the color channel count instead of hard-coding 3.
        neighborhood_color = color.reshape(batch_size * num_points, -1)[idx, :]
        neighborhood_color = neighborhood_color.view(
            batch_size, self.num_group, self.group_size, -1).contiguous()

        # Express each patch in its center's local frame.
        neighborhood = neighborhood - center.unsqueeze(2)

        features = torch.cat((neighborhood, neighborhood_color), dim=-1)
        return neighborhood, center, features
|
|
|
|
|
class Encoder(nn.Module):
    """
    Mini-PointNet that embeds every point patch into one feature vector.

    Each point carries 6 input channels (``Conv1d(6, ...)``). In this file,
    ``Group`` supplies relative xyz + 3 color channels. Per-point features
    are computed by a shared MLP and max-pooled over the patch. The pooled
    vector is concatenated back onto every point, encoded again, and
    max-pooled once more.
    """

    def __init__(self, encoder_channel):
        super().__init__()
        self.encoder_channel = encoder_channel
        self.first_conv = nn.Sequential(
            nn.Conv1d(6, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1)
        )
        # Input is 256 per-point + 256 pooled (broadcast) channels = 512.
        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.encoder_channel, 1)
        )

    def forward(self, point_groups):
        '''
        point_groups : B G N 6
        -----------------
        feature_global : B G C, with C = encoder_channel
        '''
        bs, g, n, c = point_groups.shape
        # Fold groups into the batch. A wrong channel count now fails at
        # first_conv with a channel-mismatch error rather than an opaque
        # reshape error.
        point_groups = point_groups.reshape(bs * g, n, c)
        feature = self.first_conv(point_groups.transpose(2, 1))  # (B*G) 256 N
        feature_global = torch.max(feature, dim=2, keepdim=True)[0]  # (B*G) 256 1
        feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1)  # (B*G) 512 N
        feature = self.second_conv(feature)  # (B*G) C N
        feature_global = torch.max(feature, dim=2, keepdim=False)[0]  # (B*G) C
        return feature_global.reshape(bs, g, self.encoder_channel)
|
|
|
|
|
|
|
|
|
|
|
class skeleton_Group(nn.Module):
    """
    Summarize token features into one "skeleton" token per local group.

    ``num_group`` centers are picked from the token coordinates with furthest
    point sampling (``fps``). Each center gathers its ``group_size`` nearest
    tokens (``knn_point``). Inside a group, the neighbours' features are
    max-pooled into a query vector. Every neighbour is then weighted by the
    softmax of its cosine similarity to that query. The skeleton token is the
    resulting weighted sum of the (unnormalized) neighbour features.
    """

    def __init__(self, num_group=32, group_size=8):
        super().__init__()
        self.num_group = num_group
        self.group_size = group_size

    def forward(self, xyz, token_feat, num_group=None, group_size=None):
        '''
        xyz        : B N 3, coordinates of every token (one row per token)
        token_feat : B N C, features of the same N tokens, in the same order
        num_group  : number of groups G; defaults to self.num_group
        group_size : neighbours per group K; defaults to self.group_size
        ---------------------------
        output     : B G C, one skeleton token per group

        Previously the 32/8 defaults here overwrote the constructor values on
        every call. Now explicit arguments win, otherwise the constructor values
        are used, and module state is never mutated.
        '''
        num_group = self.num_group if num_group is None else num_group
        group_size = self.group_size if group_size is None else group_size

        batch_size, num_points, _ = xyz.shape
        if token_feat.shape[:2] != xyz.shape[:2]:
            # Row i of token_feat must describe point i of xyz; otherwise the
            # flattened gather below would silently mix tokens or fail.
            raise ValueError(
                "token_feat must have one row per xyz point: got xyz {} and token_feat {}".format(
                    tuple(xyz.shape), tuple(token_feat.shape)))
        channels = token_feat.shape[-1]

        center = fps(xyz, num_group)  # B G 3
        idx = knn_point(group_size, xyz, center)  # B G K
        assert idx.size(1) == num_group
        assert idx.size(2) == group_size

        # Offset every sample's indices by b * N for one gather over B*N rows.
        idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
        idx = (idx + idx_base).view(-1)

        neighborhood_token_feat = token_feat.reshape(batch_size * num_points, -1)[idx, :]
        neighborhood_token_feat = neighborhood_token_feat.view(
            batch_size, num_group, group_size, channels).contiguous()  # B G K C

        # Max-pool the K neighbours into one query per group: B G C -> B G C 1.
        center_token, _ = torch.max(neighborhood_token_feat, dim=2)
        center_token = center_token.unsqueeze(2).permute(0, 1, 3, 2)

        # L2-normalize along C so the product below is a cosine similarity.
        neighborhood_token_feat_ = torch.nn.functional.normalize(neighborhood_token_feat, p=2, dim=-1)
        center_token_ = torch.nn.functional.normalize(center_token, p=2, dim=-2)

        # B G K C x B G C 1 -> B G K 1, then B G 1 K and softmax over K.
        router_weights = torch.einsum('btnj,btjm->btnm', [neighborhood_token_feat_, center_token_])
        router_weights = router_weights.permute(0, 1, 3, 2)
        router_weights = F.softmax(router_weights, dim=-1)

        # Weighted sum of the raw neighbour features: B G 1 C -> B G C.
        skeleton_token = torch.einsum('bmok, bmkc->bmoc', [router_weights, neighborhood_token_feat])
        skeleton_token = skeleton_token.squeeze(dim=-2)
        return skeleton_token
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PointcloudEncoder(nn.Module):
    """
    Tokenize a colored point cloud and run the tokens through a transformer.

    Pipeline: group the cloud into patches (Group), embed each patch
    (Encoder -> encoder2trans), prepend a learned CLS token, add positional
    embeddings of the patch centers, apply patch dropout, then run
    ``point_transformer.blocks``.

    ``args`` must provide: pc_feat_dim, embed_dim, group_size, num_group,
    pc_encoder_dim and patch_dropout. ``point_transformer`` must expose
    ``pos_drop``, ``blocks``, ``norm`` and ``fc_norm``.
    """

    def __init__(self, point_transformer, args):
        super().__init__()
        # NOTE: construction order is kept as-is so parameter initialization
        # consumes the RNG in the same order as before.
        self.trans_dim = args.pc_feat_dim
        self.embed_dim = args.embed_dim
        self.group_size = args.group_size
        self.num_group = args.num_group

        # Patch grouping of the input cloud and skeleton pooling of tokens.
        self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
        self.skeleton_Group = skeleton_Group()

        # Per-patch embedding.
        self.encoder_dim = args.pc_encoder_dim
        self.encoder = Encoder(encoder_channel=self.encoder_dim)

        # Bridge encoder -> transformer width, and transformer -> output embedding.
        self.encoder2trans = nn.Linear(self.encoder_dim, self.trans_dim)
        self.trans2embed = nn.Linear(self.trans_dim, self.embed_dim)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
        self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))

        # Positional embedding computed from each patch center's xyz.
        self.pos_embed = nn.Sequential(
            nn.Linear(3, 128),
            nn.GELU(),
            nn.Linear(128, self.trans_dim)
        )

        self.patch_dropout = PatchDropout(args.patch_dropout) if args.patch_dropout > 0. else nn.Identity()
        self.visual = point_transformer

        # Not read anywhere in this class; kept for compatibility.
        self.float_flag = False

    def forward(self, pc, get_pc_tokens_way=None):
        """
        pc: B N (3 + C_color); the first 3 channels are xyz, the rest are
            passed to Group as per-point colors.
        get_pc_tokens_way:
            "CLS"        -> B 1 embed_dim, projected CLS token
            "OM_Pooling" -> B (4 + S) embed_dim: CLS, max, mean and sum of the
                            patch tokens, followed by S skeleton tokens
                            (S = 32 with skeleton_Group's defaults)
            anything else (incl. None) -> raw output of the last transformer
                            block, B (1 + T) trans_dim, without norm/projection
        """
        pc = pc.to(dtype=torch.float)
        pts = pc[:, :, :3].contiguous()
        colors = pc[:, :, 3:].contiguous()

        _, center, features = self.group_divider(pts, colors)  # center: B G 3

        group_input_tokens = self.encoder(features)  # B G encoder_dim
        group_input_tokens = self.encoder2trans(group_input_tokens)  # B G trans_dim

        batch_size = group_input_tokens.size(0)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        cls_pos = self.cls_pos.expand(batch_size, -1, -1)
        pos = self.pos_embed(center)

        x = torch.cat((cls_tokens, group_input_tokens), dim=1)
        pos = torch.cat((cls_pos, pos), dim=1)
        x = x + pos

        x = self.patch_dropout(x)
        x = self.visual.pos_drop(x)

        # last_sec_x = input of the final transformer block (i.e. the output of
        # the second-to-last one). Initializing it here keeps it defined when
        # the transformer has a single block; previously that raised NameError
        # in the OM_Pooling branch.
        last_sec_x = x
        num_blocks = len(self.visual.blocks)
        for i, blk in enumerate(self.visual.blocks):
            x = blk(x)
            if i == num_blocks - 2:
                last_sec_x = x

        if get_pc_tokens_way == "CLS":
            x = self.visual.norm(x[:, 0, :])
            x = self.visual.fc_norm(x)
            x = self.trans2embed(x)
            x = x.unsqueeze(1)

        elif get_pc_tokens_way == "OM_Pooling":
            # NOTE(review): skeleton_Group pairs token i with center i. When
            # self.patch_dropout is a PatchDropout it keeps a random subset of
            # tokens, so the token count generally no longer matches `center`
            # and the skeleton grouping cannot be computed.
            pc_skeleton = self.skeleton_Group(center, last_sec_x[:, 1:])

            patch_tokens = x[:, 1:]
            x = torch.cat([
                x[:, 0].unsqueeze(1),
                patch_tokens.max(1)[0].unsqueeze(1),
                patch_tokens.mean(1).unsqueeze(1),
                torch.sum(patch_tokens, dim=1).unsqueeze(1),
                pc_skeleton,
            ], dim=-2)
            x = self.visual.norm(x)
            x = self.visual.fc_norm(x)
            x = self.trans2embed(x)

        return x