import contextlib
import logging

import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    from pointnet2_ops import pointnet2_utils
except ImportError as _exc:  # compiled extension not installed
    # Defer the failure to `fps` so the pure-PyTorch helpers below stay importable.
    pointnet2_utils = None
    _POINTNET2_IMPORT_ERROR = _exc
else:
    _POINTNET2_IMPORT_ERROR = None


def fps(data, number):
    """Furthest-point-sample `number` points from each point cloud.

    Args:
        data: point coordinates, [B, N, 3]. Handed directly to the pointnet2_ops
            kernels (presumably these need a CUDA float tensor -- confirm).
        number: number of points to keep per cloud.

    Returns:
        The sampled points, [B, number, 3].

    Raises:
        ImportError: if `pointnet2_ops` could not be imported.
    """
    if pointnet2_utils is None:
        raise ImportError(
            "fps() requires the compiled `pointnet2_ops` extension"
        ) from _POINTNET2_IMPORT_ERROR
    fps_idx = pointnet2_utils.furthest_point_sample(data, number)
    # gather_operation is fed channel-first [B, 3, N] data, so transpose around it.
    fps_data = pointnet2_utils.gather_operation(
        data.transpose(1, 2).contiguous(), fps_idx
    ).transpose(1, 2).contiguous()
    return fps_data


def index_points(points, idx):
    """Gather points per batch element.

    Args:
        points: input points data, [B, N, C].
        idx: sample index data, [B, S] or [B, S, K] (any rank >= 1 whose first
            dim is B); values index the N axis.

    Returns:
        Indexed points, [*idx.shape, C].
    """
    batch_size = points.shape[0]
    # Shape [B, 1, ..., 1] so it broadcasts against idx in advanced indexing.
    view_shape = [batch_size] + [1] * (idx.dim() - 1)
    batch_indices = torch.arange(
        batch_size, dtype=torch.long, device=points.device
    ).view(view_shape)
    return points[batch_indices, idx, :]


# https://github.com/Strawberry-Eat-Mango/PCT_Pytorch/blob/main/util.py
def knn_point(nsample, xyz, new_xyz):
    """Find the `nsample` nearest neighbours in `xyz` of every query point.

    Args:
        nsample: max sample number in local region.
        xyz: all points, [B, N, C].
        new_xyz: query points, [B, S, C].

    Returns:
        group_idx: grouped points index, [B, S, nsample]. Neighbours are not
        sorted by distance (`sorted=False`).
    """
    sqrdists = square_distance(new_xyz, xyz)
    _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
    return group_idx


def square_distance(src, dst):
    """Calculate the squared Euclidean distance between each pair of points.

    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst

    Args:
        src: source points, [B, N, C].
        dst: target points, [B, M, C].

    Returns:
        dist: per-point square distance, [B, N, M]. May contain tiny negative
        values from floating-point cancellation.
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist


class PatchDropout(nn.Module):
    """Randomly keep a subset of tokens.

    https://arxiv.org/abs/2212.00794
    """

    def __init__(self, prob, exclude_first_token=True):
        super().__init__()
        assert 0 <= prob < 1.
        self.prob = prob
        self.exclude_first_token = exclude_first_token  # exclude CLS token
        logging.info("patch dropout prob is {}".format(prob))

    def forward(self, x):
        """Drop tokens of x ([B, T, D]) at random, keeping max(1, int(T' * (1 - prob))).

        T' is T - 1 when the first (CLS) token is excluded; that token is always
        kept and re-attached in front of the survivors.
        """
        # NOTE(review): the usual `if not self.training or self.prob == 0.: return x`
        # guard is disabled here, so tokens are dropped in eval mode as well. Kept
        # as-is because it looks deliberate -- confirm intent.
        if self.exclude_first_token:
            cls_tokens, x = x[:, :1], x[:, 1:]
        else:
            cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])

        batch = x.size()[0]
        num_tokens = x.size()[1]

        keep_prob = 1 - self.prob
        num_patches_keep = max(1, int(num_tokens * keep_prob))

        # Random scores + topk == a uniformly random subset per sample. Allocate on
        # x's device to avoid a host round-trip when x lives on the GPU.
        rand = torch.randn(batch, num_tokens, device=x.device)
        patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
        batch_indices = torch.arange(batch, device=x.device)[..., None]

        x = x[batch_indices, patch_indices_keep]

        if self.exclude_first_token:
            x = torch.cat((cls_tokens, x), dim=1)

        return x


class Group(nn.Module):
    """Split a colored point cloud into `num_group` local patches of `group_size` points."""

    def __init__(self, num_group, group_size):
        super().__init__()
        self.num_group = num_group
        self.group_size = group_size

    def forward(self, xyz, color):
        """Group points around FPS-sampled centers.

        Args:
            xyz: point coordinates, [B, N, 3].
            color: per-point attributes, [B, N, C_color].

        Returns:
            neighborhood: patch coordinates relative to their center, [B, G, M, 3].
            center: patch centers, [B, G, 3].
            features: neighborhood concatenated with its colors, [B, G, M, 3 + C_color].
        """
        # fps the centers out
        center = fps(xyz, self.num_group)  # B G 3
        # knn to get the neighborhood
        idx = knn_point(self.group_size, xyz, center)  # B G M
        assert idx.size(1) == self.num_group
        assert idx.size(2) == self.group_size

        neighborhood = index_points(xyz, idx)  # B G M 3
        neighborhood_color = index_points(color, idx)  # B G M C_color

        # normalize: express each patch relative to its center
        neighborhood = neighborhood - center.unsqueeze(2)
        features = torch.cat((neighborhood, neighborhood_color), dim=-1)
        return neighborhood, center, features


class Encoder(nn.Module):
    """Mini-PointNet embedding each point patch into one `encoder_channel` vector."""

    def __init__(self, encoder_channel):
        super().__init__()
        self.encoder_channel = encoder_channel
        self.first_conv = nn.Sequential(
            nn.Conv1d(6, 128, 1),
            nn.BatchNorm1d(128),
            nn.ReLU(inplace=True),
            nn.Conv1d(128, 256, 1)
        )
        self.second_conv = nn.Sequential(
            nn.Conv1d(512, 512, 1),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Conv1d(512, self.encoder_channel, 1)
        )

    def forward(self, point_groups):
        """Encode patches.

        Args:
            point_groups: [B, G, N, 6] (first_conv takes 6 input channels).

        Returns:
            feature_global: [B, G, encoder_channel].
        """
        bs, g, n, c = point_groups.shape
        point_groups = point_groups.reshape(bs * g, n, c)
        # encoder
        feature = self.first_conv(point_groups.transpose(2, 1))  # BG 256 n
        feature_global = torch.max(feature, dim=2, keepdim=True)[0]  # BG 256 1
        # Concatenate the patch-level max feature back onto every point.
        feature = torch.cat([feature_global.expand(-1, -1, n), feature], dim=1)  # BG 512 n
        feature = self.second_conv(feature)  # BG C n
        feature_global = torch.max(feature, dim=2, keepdim=False)[0]  # BG C
        return feature_global.reshape(bs, g, self.encoder_channel)


class skeleton_Group(nn.Module):
    """Pool per-point tokens into `num_group` "skeleton" tokens.

    Groups tokens by FPS + kNN over their xyz positions. Each group is then
    reduced with softmax routing weights, based on cosine similarity to the
    group's max-pooled token.
    """

    def __init__(self, num_group=32, group_size=8):
        super().__init__()
        self.num_group = num_group
        self.group_size = group_size

    @staticmethod
    def _route(neighborhood_token_feat):
        """Attention-pool each group's K tokens into one: [B, M, K, C] -> [B, M, C]."""
        # T_m: per-group max-pooled token, [B, M, C]  (pooling, no residual/dim norm)
        center_token = neighborhood_token_feat.max(dim=2).values
        # A: cosine similarity of every member token to its group's T_m, [B, M, K]
        similarity = torch.einsum(
            'bmkc,bmc->bmk',
            F.normalize(neighborhood_token_feat, p=2, dim=-1),
            F.normalize(center_token, p=2, dim=-1),
        )
        router_weights = F.softmax(similarity, dim=-1)
        # T^p_pc: routing-weighted sum of the raw (unnormalized) member tokens, [B, M, C]
        return torch.einsum('bmk,bmkc->bmc', router_weights, neighborhood_token_feat)

    def forward(self, xyz, token_feat, num_group=None, group_size=None):
        """Compute skeleton tokens.

        Args:
            xyz: xyz position of every token, [B, N, 3].
            token_feat: token features aligned with xyz, [B, N, C].
            num_group: number of groups M; defaults to the constructor value.
            group_size: tokens per group K; defaults to the constructor value.

        Returns:
            Skeleton tokens, [B, M, C].

        Raises:
            ValueError: if xyz and token_feat disagree on batch or token count
                (e.g. tokens were dropped after the positions were computed).
        """
        num_group = self.num_group if num_group is None else num_group
        group_size = self.group_size if group_size is None else group_size
        if token_feat.shape[:2] != xyz.shape[:2]:
            raise ValueError(
                "skeleton_Group: xyz {} and token_feat {} must share batch and token "
                "dims".format(tuple(xyz.shape), tuple(token_feat.shape))
            )

        # fps the centers out
        center = fps(xyz, num_group)  # B M 3
        # knn to get the neighborhood
        idx = knn_point(group_size, xyz, center)  # B M K
        assert idx.size(1) == num_group
        assert idx.size(2) == group_size

        # T_p: B, M, K, C
        neighborhood_token_feat = index_points(token_feat, idx)
        return self._route(neighborhood_token_feat)


class PointcloudEncoder(nn.Module):
    """Tokenize a colored point cloud into patches and run them through `point_transformer`."""

    def __init__(self, point_transformer, args):
        super().__init__()
        self.trans_dim = args.pc_feat_dim  # e.g. 768
        self.embed_dim = args.embed_dim  # e.g. 512
        self.group_size = args.group_size  # e.g. 32
        self.num_group = args.num_group  # e.g. 512
        # grouper
        self.group_divider = Group(num_group=self.num_group, group_size=self.group_size)
        self.skeleton_Group = skeleton_Group()
        # define the encoder
        self.encoder_dim = args.pc_encoder_dim  # e.g. 256
        self.encoder = Encoder(encoder_channel=self.encoder_dim)

        # bridge encoder and transformer
        self.encoder2trans = nn.Linear(self.encoder_dim, self.trans_dim)

        # bridge transformer and clip embedding
        self.trans2embed = nn.Linear(self.trans_dim, self.embed_dim)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, self.trans_dim))
        self.cls_pos = nn.Parameter(torch.randn(1, 1, self.trans_dim))

        self.pos_embed = nn.Sequential(
            nn.Linear(3, 128),
            nn.GELU(),
            nn.Linear(128, self.trans_dim)
        )
        # setting a patch_dropout of 0. would mean it is disabled and this function would be the identity fn
        self.patch_dropout = PatchDropout(args.patch_dropout) if args.patch_dropout > 0. else nn.Identity()
        self.visual = point_transformer
        self.float_flag = False

    def forward(self, pc, get_pc_tokens_way=None):
        """Encode a batch of point clouds.

        Args:
            pc: [B, N, 3 + C]; columns 0-2 are xyz and the rest are per-point
                colors (the Encoder's 6 input channels imply C == 3).
            get_pc_tokens_way: "CLS" returns the projected CLS token.
                "OM_Pooling" returns the projected CLS, max, mean and sum tokens
                followed by the skeleton tokens. Any other value (including the
                default None) returns the raw transformer tokens,
                [B, 1 + kept patches, trans_dim].

        Returns:
            For "CLS" / "OM_Pooling", tokens projected by `trans2embed` to
            embed_dim (presumably [B, 1, embed_dim] and [B, 4 + 32, embed_dim],
            assuming `visual.norm`/`fc_norm` keep the width); see above otherwise.
        """
        pc = pc.to(dtype=torch.float)
        pts = pc[:, :, :3].contiguous()
        colors = pc[:, :, 3:].contiguous()
        # divide the point cloud in the same form. This is important
        _, center, features = self.group_divider(pts, colors)

        # encoder the input cloud patches
        group_input_tokens = self.encoder(features)  # B G encoder_dim
        group_input_tokens = self.encoder2trans(group_input_tokens)  # B G trans_dim
        batch_size = group_input_tokens.size(0)
        # prepare cls
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        cls_pos = self.cls_pos.expand(batch_size, -1, -1)
        # add pos embedding
        pos = self.pos_embed(center)
        # final input
        x = torch.cat((cls_tokens, group_input_tokens), dim=1)
        pos = torch.cat((cls_pos, pos), dim=1)
        # transformer
        x = x + pos
        # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
        x = self.patch_dropout(x)
        x = self.visual.pos_drop(x)

        # ModuleList not support forward. Keep the input of the final block (== the
        # output of the second-to-last one) for the skeleton branch; with a single
        # block that is simply the block input, so it is always defined.
        last_sec_x = x
        block_len = len(self.visual.blocks)
        for i, blk in enumerate(self.visual.blocks):
            if i == block_len - 1:
                last_sec_x = x
            x = blk(x)

        if get_pc_tokens_way == "CLS":
            x = self.visual.norm(x[:, 0, :])
            x = self.visual.fc_norm(x)
            x = self.trans2embed(x)
            x = x.unsqueeze(1)
        elif get_pc_tokens_way == "OM_Pooling":
            # Skeleton tokens are pooled from the penultimate layer, positioned by the
            # patch centers. Raises ValueError if patch dropout removed tokens.
            pc_skeleton = self.skeleton_Group(center, last_sec_x[:, 1:])
            patch_tokens = x[:, 1:]
            x = torch.cat(
                [
                    x[:, :1],                                      # CLS
                    patch_tokens.max(dim=1, keepdim=True).values,  # max pool
                    patch_tokens.mean(dim=1, keepdim=True),        # mean pool
                    patch_tokens.sum(dim=1, keepdim=True),         # sum pool
                    pc_skeleton,                                   # skeleton tokens
                ],
                dim=-2,
            )
            x = self.visual.norm(x)
            x = self.visual.fc_norm(x)
            x = self.trans2embed(x)
        return x