from copy import deepcopy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint  # used by Block when enable_checkpoint=True

import einops
import pointops
from timm.models.layers import DropPath  # stochastic depth, used by Block

from pointcept.models.utils import offset2batch, batch2offset

from .transformer_utils import BaseTemperalPointModel


class PointBatchNorm(nn.Module):
    """
    Batch normalization for point cloud data shaped [B*N, C] or [B*N, L, C].
    """

    def __init__(self, embed_channels):
        super().__init__()
        self.norm = nn.BatchNorm1d(embed_channels)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        if input.dim() == 3:
            # BatchNorm1d normalizes the channel dim, which it expects at
            # position 1, so transpose [B*N, L, C] -> [B*N, C, L] and back.
            return (
                self.norm(input.transpose(1, 2).contiguous())
                .transpose(1, 2)
                .contiguous()
            )
        elif input.dim() == 2:
            return self.norm(input)
        else:
            raise NotImplementedError


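# A minimal shape sketch for PointBatchNorm (illustrative only, with made-up sizes):
#
#     norm = PointBatchNorm(64)
#     norm(torch.randn(1024, 64)).shape      # [B*N, C]    -> torch.Size([1024, 64])
#     norm(torch.randn(1024, 16, 64)).shape  # [B*N, L, C] -> torch.Size([1024, 16, 64])

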
class GroupedVectorAttention(nn.Module):
    def __init__(
        self,
        embed_channels,
        groups,
        attn_drop_rate=0.0,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
    ):
        super(GroupedVectorAttention, self).__init__()
        self.embed_channels = embed_channels
        self.groups = groups
        assert embed_channels % groups == 0
        self.attn_drop_rate = attn_drop_rate
        self.qkv_bias = qkv_bias
        self.pe_multiplier = pe_multiplier
        self.pe_bias = pe_bias

        self.linear_q = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.linear_k = nn.Sequential(
            nn.Linear(embed_channels, embed_channels, bias=qkv_bias),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.linear_v = nn.Linear(embed_channels, embed_channels, bias=qkv_bias)

        if self.pe_multiplier:
            self.linear_p_multiplier = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )
        if self.pe_bias:
            self.linear_p_bias = nn.Sequential(
                nn.Linear(3, embed_channels),
                PointBatchNorm(embed_channels),
                nn.ReLU(inplace=True),
                nn.Linear(embed_channels, embed_channels),
            )
        self.weight_encoding = nn.Sequential(
            nn.Linear(embed_channels, groups),
            PointBatchNorm(groups),
            nn.ReLU(inplace=True),
            nn.Linear(groups, groups),
        )
        self.softmax = nn.Softmax(dim=1)
        self.attn_drop = nn.Dropout(attn_drop_rate)

    def forward(self, feat, coord, reference_index):
        query, key, value = (
            self.linear_q(feat),
            self.linear_k(feat),
            self.linear_v(feat),
        )
        # Gather each point's k nearest neighbours; with_xyz=True prepends the
        # relative xyz offsets, so key is [n, ns, 3 + c] and value is [n, ns, c].
        key = pointops.grouping(reference_index, key, coord, with_xyz=True)
        value = pointops.grouping(reference_index, value, coord, with_xyz=False)
        pos, key = key[:, :, 0:3], key[:, :, 3:]
        relation_qk = key - query.unsqueeze(1)
        if self.pe_multiplier:
            pem = self.linear_p_multiplier(pos)
            relation_qk = relation_qk * pem
        if self.pe_bias:
            peb = self.linear_p_bias(pos)
            relation_qk = relation_qk + peb
            value = value + peb

        weight = self.weight_encoding(relation_qk)
        weight = self.attn_drop(self.softmax(weight))

        # reference_index is -1 for padded neighbour slots; sign(idx + 1) maps
        # those to 0 and valid slots to 1, zeroing their attention weight.
        mask = torch.sign(reference_index + 1)
        weight = torch.einsum("n s g, n s -> n s g", weight, mask)
        value = einops.rearrange(value, "n ns (g i) -> n ns g i", g=self.groups)
        feat = torch.einsum("n s g i, n s g -> n g i", value, weight)
        feat = einops.rearrange(feat, "n g i -> n (g i)")
        return feat


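# Shape walk-through of the grouped vector attention above, a sketch with
# hypothetical sizes n=4096 points, ns=16 neighbours, c=96 channels, g=6 groups:
#
#     query              [n, c]
#     key, value         [n, ns, c]      (after pointops.grouping)
#     relation_qk        [n, ns, c]      (key - query, plus positional terms)
#     weight             [n, ns, g]      (one scalar weight per group, not per channel)
#     value (grouped)    [n, ns, g, c/g]
#     output             [n, c]          (weighted sum over neighbours, groups re-flattened)

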
class BlockSequence(nn.Module):
    def __init__(
        self,
        depth,
        embed_channels,
        groups,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
    ):
        super(BlockSequence, self).__init__()

        if isinstance(drop_path_rate, list):
            drop_path_rates = drop_path_rate
            assert len(drop_path_rates) == depth
        elif isinstance(drop_path_rate, float):
            drop_path_rates = [deepcopy(drop_path_rate) for _ in range(depth)]
        else:
            drop_path_rates = [0.0 for _ in range(depth)]

        self.neighbours = neighbours
        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                embed_channels=embed_channels,
                groups=groups,
                qkv_bias=qkv_bias,
                pe_multiplier=pe_multiplier,
                pe_bias=pe_bias,
                attn_drop_rate=attn_drop_rate,
                drop_path_rate=drop_path_rates[i],
                enable_checkpoint=enable_checkpoint,
            )
            self.blocks.append(block)

    def forward(self, points):
        coord, feat, offset = points
        # The kNN graph depends only on coordinates, which no block changes,
        # so it is queried once here and shared by every block in the sequence.
        reference_index, _ = pointops.knn_query(self.neighbours, coord, offset)
        for block in self.blocks:
            points = block(points, reference_index)
        return points


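# The offset tensor encodes batching for variable-sized clouds: it holds the
# cumulative point count per sample. A sketch with two hypothetical clouds of
# 1000 and 1500 points packed into one [2500, 3] coord tensor:
#
#     offset = torch.tensor([1000, 2500])  # sample 0 -> rows [0, 1000), sample 1 -> [1000, 2500)
#
# pointops.knn_query respects these boundaries, so neighbours are never taken
# across samples.

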
class GVAPatchEmbed(nn.Module):
    def __init__(
        self,
        depth,
        in_channels,
        embed_channels,
        groups,
        neighbours=16,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
    ):
        super(GVAPatchEmbed, self).__init__()
        self.in_channels = in_channels
        self.embed_channels = embed_channels
        self.proj = nn.Sequential(
            nn.Linear(in_channels, embed_channels, bias=False),
            PointBatchNorm(embed_channels),
            nn.ReLU(inplace=True),
        )
        self.blocks = BlockSequence(
            depth=depth,
            embed_channels=embed_channels,
            groups=groups,
            neighbours=neighbours,
            qkv_bias=qkv_bias,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
            attn_drop_rate=attn_drop_rate,
            drop_path_rate=drop_path_rate,
            enable_checkpoint=enable_checkpoint,
        )

    def forward(self, points):
        coord, feat, offset = points
        feat = self.proj(feat)
        return self.blocks([coord, feat, offset])


class Block(nn.Module):
    def __init__(
        self,
        embed_channels,
        groups,
        qkv_bias=True,
        pe_multiplier=False,
        pe_bias=True,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
        enable_checkpoint=False,
    ):
        super(Block, self).__init__()
        self.attn = GroupedVectorAttention(
            embed_channels=embed_channels,
            groups=groups,
            qkv_bias=qkv_bias,
            attn_drop_rate=attn_drop_rate,
            pe_multiplier=pe_multiplier,
            pe_bias=pe_bias,
        )
        self.fc1 = nn.Linear(embed_channels, embed_channels, bias=False)
        self.fc3 = nn.Linear(embed_channels, embed_channels, bias=False)
        self.norm1 = PointBatchNorm(embed_channels)
        self.norm2 = PointBatchNorm(embed_channels)
        self.norm3 = PointBatchNorm(embed_channels)
        self.act = nn.ReLU(inplace=True)
        self.enable_checkpoint = enable_checkpoint
        self.drop_path = (
            DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
        )

    def forward(self, points, reference_index):
        coord, feat, offset = points
        identity = feat
        # Bottleneck-style residual branch: fc1 -> attention -> fc3, each
        # followed by BN (+ ReLU), with stochastic depth before the skip add.
        feat = self.act(self.norm1(self.fc1(feat)))
        feat = (
            self.attn(feat, coord, reference_index)
            if not self.enable_checkpoint
            else checkpoint(self.attn, feat, coord, reference_index)
        )
        feat = self.act(self.norm2(feat))
        feat = self.norm3(self.fc3(feat))
        feat = identity + self.drop_path(feat)
        feat = self.act(feat)
        return [coord, feat, offset]
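

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: it assumes pointops is built with CUDA
    # support and a GPU is available, and uses made-up sizes (one cloud of 1024
    # points with 6-channel input features).
    device = torch.device("cuda")
    coord = torch.rand(1024, 3, device=device)
    feat = torch.rand(1024, 6, device=device)
    offset = torch.tensor([1024], dtype=torch.int32, device=device)

    embed = GVAPatchEmbed(depth=2, in_channels=6, embed_channels=48, groups=6).to(device)
    coord_out, feat_out, offset_out = embed([coord, feat, offset])
    print(feat_out.shape)  # expected: torch.Size([1024, 48])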