YuanTang96's picture
1
b30c1d8
import torch
import torch.nn as nn
from pointnet2_ops import pointnet2_utils
import contextlib
import logging
import torch.nn.functional as F
def fps(data, number):
    '''
    Farthest point sampling: pick `number` points from every cloud in `data`.

    data: B N 3
    number: int
    returns: the sampled point coordinates (used by callers in this file as B number 3);
        exact layout follows pointnet2_utils.gather_operation -- confirm against pointnet2_ops.
    '''
    sampled_idx = pointnet2_utils.furthest_point_sample(data, number)
    # gather_operation is fed channels-first data, so swap to B 3 N first.
    channels_first = data.transpose(1, 2).contiguous()
    gathered = pointnet2_utils.gather_operation(channels_first, sampled_idx)
    return gathered.transpose(1, 2).contiguous()
def index_points(points, idx):
    """
    Gather points with per-batch indices.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]; extra trailing index dims (e.g. [B, S, K]) also work
    Return:
        new_points: indexed points data, [B, S, C] (or [B, S, K, C] for a 3-D idx)
    """
    idx_dims = list(idx.shape)
    # Batch index tensor laid out like idx: [B, 1, ..., 1] tiled over the non-batch dims.
    bcast_shape = [idx_dims[0]] + [1] * (len(idx_dims) - 1)
    tile_shape = [1] + idx_dims[1:]
    batch_indices = (
        torch.arange(points.shape[0], dtype=torch.long, device=points.device)
        .view(bcast_shape)
        .repeat(tile_shape)
    )
    return points[batch_indices, idx, :]
# https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py
def knn_point(nsample, xyz, new_xyz):
    """
    Find the `nsample` nearest neighbours in `xyz` for every query point.

    Input:
        nsample: number of neighbours returned per query point
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: indices into the N axis of xyz, [B, S, nsample]; topk is called with
            sorted=False, so the neighbours are not guaranteed to be in distance order.
    """
    pairwise = square_distance(new_xyz, xyz)  # [B, S, N]
    nearest = pairwise.topk(nsample, dim=-1, largest=False, sorted=False)
    return nearest.indices
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Uses the expansion ||s - d||^2 = ||s||^2 + ||d||^2 - 2 * s.d. The whole
    [N, M] table is therefore one batched matmul plus two broadcast adds.
    Floating-point cancellation can leave tiny negative values for
    (near-)coincident points.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    cross = torch.matmul(src, dst.permute(0, 2, 1))  # [B, N, M]
    src_sq = torch.sum(src ** 2, -1).view(batch, n_src, 1)
    dst_sq = torch.sum(dst ** 2, -1).view(batch, 1, n_dst)
    dist = -2 * cross
    dist += src_sq
    dist += dst_sq
    return dist
class PatchDropout(nn.Module):
    """
    Randomly keep a subset of patch tokens (https://arxiv.org/abs/2212.00794).

    Args:
        prob: fraction of tokens to drop, in [0, 1).
        exclude_first_token: if True, token 0 (the CLS token) is never dropped and
            is re-prepended to the output.
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token
        logging.info("patch dropout prob is {}".format(prob))

    def forward(self, x):
        """
        Keep max(1, int(T * (1 - prob))) randomly chosen tokens along dim 1.

        Args:
            x: token tensor indexed as (B, T, ...), presumably (B, T, C).
        Returns:
            The kept tokens, in random order. When exclude_first_token is set, T excludes
            the CLS token, and the CLS token is prepended back at position 0.
        """
        # NOTE(review): there is no `if not self.training: return x` early exit, so tokens
        # are also dropped (and shuffled) in eval mode -- confirm this is intended.
        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            # Presumably kept so TorchScript sees cls_tokens defined on every path.
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
        batch = x.size()[0]
        num_tokens = x.size()[1]
        keep_prob = 1 - self.prob
        num_patches_keep = max(1, int(num_tokens * keep_prob))
        # Draw the random scores and batch indices on x's device. This avoids a host-side
        # RNG call and a CPU->device index copy on every forward pass.
        rand = torch.randn(batch, num_tokens, device=x.device)
        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
        batch_indices = torch.arange(batch, device=x.device)[..., None]
        x = x[batch_indices, patch_indices_keep]
        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)
        return x
class Group(nn.Module):
    """
    Split a point cloud into local patches: FPS centers plus kNN neighbourhoods.

    Args:
        num_group: number of patch centers G sampled with farthest point sampling.
        group_size: number of nearest neighbours M gathered around each center.
    """

    def __init__(self, num_group, group_size):
        super().__init__()
        self.num_group = num_group
        self.group_size = group_size

    def forward(self, xyz, color):
        '''
        Args:
            xyz: point coordinates, B N 3
            color: per-point attributes, B N Cc; Cc is read from the tensor (3 for RGB)
        ---------------------------
        Returns:
            neighborhood: B G M 3, neighbour coordinates relative to their center
            center: B G 3, sampled center coordinates
            features: B G M (3 + Cc), centered coordinates concatenated with neighbour attributes
        '''
        batch_size, num_points, _ = xyz.shape
        color_channels = color.shape[-1]
        # fps the centers out
        center = fps(xyz, self.num_group)  # B G 3
        # knn to get the neighborhood
        idx = knn_point(self.group_size, xyz, center)  # B G M
        assert idx.size(1) == self.num_group
        assert idx.size(2) == self.group_size
        # Offset each batch's indices so they address rows of the flattened (B*N, ...) tensors.
        idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
        flat_idx = (idx + idx_base).view(-1)
        # reshape (not view) so non-contiguous inputs are accepted too.
        neighborhood = xyz.reshape(batch_size * num_points, -1)[flat_idx, :]
        neighborhood = neighborhood.view(batch_size, self.num_group, self.group_size, 3).contiguous()
        neighborhood_color = color.reshape(batch_size * num_points, -1)[flat_idx, :]
        neighborhood_color = neighborhood_color.view(
            batch_size, self.num_group, self.group_size, color_channels
        ).contiguous()
        # normalize: express neighbour coordinates relative to their patch center
        neighborhood = neighborhood - center.unsqueeze(2)
        features = torch.cat((neighborhood, neighborhood_color), dim=-1)
        return neighborhood, center, features
class Encoder(nn.Module):
    """
    PointNet-style patch encoder: maps every point patch to a single feature vector.

    Args:
        encoder_channel: output feature size C per patch.
        in_channel: per-point input feature size. The default of 6 matches the previously
            hard-coded value (e.g. 3 centered coordinates + 3 color channels).
    """

    def __init__(self, encoder_channel, in_channel=6):
        super().__init__()
        self.encoder_channel = encoder_channel
        self.in_channel = in_channel
        self.first_conv = nn.Sequential(
            nn.Conv1d(in_channel, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1),
        )
        # 512 input channels = 256 per-point features + 256 broadcast patch-global features.
        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.encoder_channel, 1),
        )

    def forward(self, point_groups):
        '''
        point_groups : B G N in_channel
        -----------------
        feature_global : B G encoder_channel
        '''
        bs, g, n, c = point_groups.shape
        # Use the real channel count instead of a hard-coded 6; a mismatch with
        # in_channel is reported by the first Conv1d.
        point_groups = point_groups.reshape(bs * g, n, c)
        # encoder
        feature = self.first_conv(point_groups.transpose(2, 1))  # BG 256 n
        feature_global = torch.max(feature, dim=2, keepdim=True)[0]  # BG 256 1
        feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1)  # BG 512 n
        feature = self.second_conv(feature)  # BG encoder_channel n
        feature_global = torch.max(feature, dim=2, keepdim=False)[0]  # BG encoder_channel
        return feature_global.reshape(bs, g, self.encoder_channel)
class skeleton_Group(nn.Module):
    """
    Pool per-token features into "skeleton" tokens over FPS/kNN neighbourhoods.

    Args:
        num_group: default number of FPS centers M.
        group_size: default number of neighbours K per center.
    """

    def __init__(self, num_group=32, group_size=8):
        super().__init__()
        self.num_group = num_group
        self.group_size = group_size

    def forward(self, xyz, token_feat, num_group=None, group_size=None):
        '''
        Args:
            xyz: xyz coordinates of all tokens, B N 3
            token_feat: features of the same tokens, B N C (row i belongs to xyz[:, i])
            num_group: centers M for this call; defaults to self.num_group
            group_size: neighbours K per center for this call; defaults to self.group_size
        ---------------------------
        Returns:
            skeleton_token: B M C. For every center, this is the softmax-weighted sum of its
            K neighbour features; the weights are the cosine similarity of each neighbour
            to the neighbourhood's channel-wise max.
        '''
        # Fall back to the constructor values instead of silently overriding them with
        # fixed defaults (and without mutating module state inside forward).
        num_group = self.num_group if num_group is None else num_group
        group_size = self.group_size if group_size is None else group_size
        batch_size, num_points, _ = xyz.shape
        channels = token_feat.shape[-1]
        # fps the centers out
        center = fps(xyz, num_group)  # B M 3
        # knn to get the neighborhood
        idx = knn_point(group_size, xyz, center)  # B M K
        assert idx.size(1) == num_group
        assert idx.size(2) == group_size
        # Offset each batch's indices so they address rows of the flattened (B*N, C) features.
        idx_base = torch.arange(0, batch_size, device=xyz.device).view(-1, 1, 1) * num_points
        flat_idx = (idx + idx_base).view(-1)
        # T_p: gathered neighbour features, B M K C
        neighborhood_token_feat = token_feat.reshape(batch_size * num_points, -1)[flat_idx, :]
        neighborhood_token_feat = neighborhood_token_feat.view(batch_size, num_group, group_size, channels)
        # T_m: channel-wise max over each neighbourhood, B M C (no residual, no extra norm)
        center_token = neighborhood_token_feat.max(dim=2).values
        # A: cosine similarity between every neighbour and its T_m, B M K, softmaxed over K
        neighbor_unit = F.normalize(neighborhood_token_feat, p=2, dim=-1)
        center_unit = F.normalize(center_token, p=2, dim=-1)
        router_weights = torch.einsum('bmkc,bmc->bmk', neighbor_unit, center_unit)
        router_weights = F.softmax(router_weights, dim=-1)
        # T^p_pc: attention-weighted sum of the (unnormalized) neighbour features, B M C
        skeleton_token = torch.einsum('bmk,bmkc->bmc', router_weights, neighborhood_token_feat)
        return skeleton_token
class PointcloudEncoder(nn.Module):
    """
    Tokenize a colored point cloud into patch tokens and run them through a transformer.

    Args:
        point_transformer: transformer whose ``pos_drop``, ``blocks``, ``norm`` and
            ``fc_norm`` attributes are used directly in ``forward``.
        args: config object read for ``pc_feat_dim``, ``embed_dim``, ``group_size``,
            ``num_group``, ``pc_encoder_dim`` and ``patch_dropout``.
    """

    def __init__(self, point_transformer, args):
        super().__init__()
        self.trans_dim = args.pc_feat_dim  # 768
        self.embed_dim = args.embed_dim  # 512
        self.group_size = args.group_size  # 32
        self.num_group = args.num_group  # 512
        # grouper
        self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
        self.skeleton_Group = skeleton_Group()
        # define the encoder
        self.encoder_dim = args.pc_encoder_dim  # 256
        self.encoder = Encoder(encoder_channel=self.encoder_dim)
        # bridge encoder and transformer
        self.encoder2trans = nn.Linear(self.encoder_dim, self.trans_dim)
        # bridge transformer and clip embedding
        self.trans2embed = nn.Linear(self.trans_dim, self.embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
        self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))
        self.pos_embed = nn.Sequential(
            nn.Linear(3, 128),
            nn.GELU(),
            nn.Linear(128, self.trans_dim),
        )
        # A patch_dropout of 0. disables it: the module becomes the identity.
        self.patch_dropout = PatchDropout(args.patch_dropout) if args.patch_dropout > 0. else nn.Identity()
        self.visual = point_transformer
        self.float_flag = False

    def forward(self, pc, get_pc_tokens_way=None):
        """
        Encode a batch of point clouds.

        Args:
            pc: B N (3 + Cc) tensor; channels 0..2 are xyz, the rest are per-point colors.
            get_pc_tokens_way:
                "CLS": return the normed and projected CLS token, B 1 embed_dim.
                "OM_Pooling": return [CLS, max-, mean- and sum-pooled patch tokens,
                    skeleton tokens] concatenated along dim 1, normed and projected.
                anything else: return the raw output of the last transformer block,
                    without the final norms or projection.
        """
        pc = pc.to(dtype=torch.float)
        pts = pc[:, :, :3].contiguous()
        colors = pc[:, :, 3:].contiguous()
        # divide the point cloud in the same form. This is important
        _, center, features = self.group_divider(pts, colors)
        # encode the input cloud patches
        group_input_tokens = self.encoder(features)  # B G encoder_dim
        group_input_tokens = self.encoder2trans(group_input_tokens)  # B G trans_dim
        # prepare cls
        batch_size = group_input_tokens.size(0)
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        cls_pos = self.cls_pos.expand(batch_size, -1, -1)
        # add pos embedding
        pos = self.pos_embed(center)
        x = torch.cat((cls_tokens, group_input_tokens), dim=1)
        pos = torch.cat((cls_pos, pos), dim=1)
        x = x + pos
        x = self.patch_dropout(x)
        x = self.visual.pos_drop(x)
        # Blocks are applied one at a time (nn.ModuleList has no forward). last_sec_x
        # ends up as the input to the final block, i.e. the second-to-last block's
        # output. Initializing it here keeps it defined even with fewer than two blocks.
        last_sec_x = x
        for blk in self.visual.blocks:
            last_sec_x = x
            x = blk(x)
        if get_pc_tokens_way == "CLS":
            x = self.visual.norm(x[:, 0, :])
            x = self.visual.fc_norm(x)
            x = self.trans2embed(x)
            x = x.unsqueeze(1)
        elif get_pc_tokens_way == "OM_Pooling":
            # NOTE(review): when patch dropout is enabled it drops and shuffles patch tokens.
            # last_sec_x[:, 1:] then no longer lines up row-for-row with `center`, which
            # skeleton_Group assumes -- presumably OM_Pooling expects patch_dropout == 0; confirm.
            pc_skeleton = self.skeleton_Group(center, last_sec_x[:, 1:])
            patch_tokens = x[:, 1:]
            x = torch.cat(
                [
                    x[:, :1],
                    patch_tokens.max(dim=1, keepdim=True).values,
                    patch_tokens.mean(dim=1, keepdim=True),
                    patch_tokens.sum(dim=1, keepdim=True),
                    pc_skeleton,
                ],
                dim=-2,
            )
            x = self.visual.norm(x)
            x = self.visual.fc_norm(x)
            x = self.trans2embed(x)
        return x