Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.utils.checkpoint | |
| import pointops | |
| from pointops import grouping, grouping2 | |
| from einops import rearrange | |
| import time | |
| from ..unimatch.dinov2.layers.block import Block as MultiViewBlock | |
| from ..unimatch.utils import mv_feature_add_position | |
| from ..unimatch.mv_transformer import MultiViewFeatureTransformer | |
# Attention backend switches for the unfused path of KNNAttention.forward:
# USE_PYTORCH_ATTN routes through F.scaled_dot_product_attention and
# USE_FLASH_ATTN3 routes through FlashAttention-3 when it can be imported.
USE_PYTORCH_ATTN = False
USE_FLASH_ATTN3 = False

# Always define these names so KNNAttention.forward can reference them even
# when FlashAttention-3 is disabled or not installed.
FA3_AVAILABLE = False
flash_attn_func = None
if USE_FLASH_ATTN3:
    import warnings

    try:
        from flash_attn_interface import flash_attn_func
        FA3_AVAILABLE = True
        warnings.warn('flash attention 3 is available (point attn)')
    except ImportError:
        warnings.warn('flash attention 3 is not available (point attn)')
| class KNNAttention(nn.Module): | |
| # TODO: multi-head | |
| def __init__(self, channels, knn_samples=16, no_rpe=True, | |
| qk_norm=False, | |
| num_heads=1, | |
| proj_channels=None, | |
| use_fused=False, | |
| ): | |
| super().__init__() | |
| self.proj_channels = proj_channels | |
| self.knn_samples = knn_samples | |
| self.no_rpe = no_rpe | |
| self.num_heads = num_heads | |
| assert self.num_heads == 1 | |
| self.use_fused = use_fused | |
| if use_fused: | |
| try: | |
| import sys | |
| from optgs.paths import PROJECT_DIR | |
| sys.path.append(str(PROJECT_DIR / "submodules")) | |
| from fused_knn_attn import fused_knn_attention, FUSED_KNN_ATTN_CUDA_AVAILABLE | |
| self._fused_knn_attention = fused_knn_attention | |
| if not FUSED_KNN_ATTN_CUDA_AVAILABLE: | |
| import warnings | |
| warnings.warn( | |
| "Fused KNN attention CUDA extension not available, " | |
| "using PyTorch fallback (still avoids [N,K,C] intermediates)" | |
| ) | |
| except ImportError: | |
| import warnings | |
| warnings.warn( | |
| "fused_knn_attn package not found, falling back to unfused attention" | |
| ) | |
| self.use_fused = False | |
| self.qk_norm = qk_norm | |
| if qk_norm: | |
| self.q_norm = nn.RMSNorm(channels) | |
| self.k_norm = nn.RMSNorm(channels) | |
| if self.proj_channels is not None: | |
| self.qkv = nn.Linear(channels, self.proj_channels * 3, bias=False) | |
| self.proj = nn.Linear(self.proj_channels, channels) | |
| else: | |
| self.qkv = nn.Linear(channels, channels * 3, bias=False) | |
| self.proj = nn.Linear(channels, channels) | |
| if not self.no_rpe: | |
| self.rpe = nn.Sequential( | |
| nn.Linear(3, 32), | |
| nn.GELU(), | |
| nn.Linear(32, 1) | |
| ) | |
| def forward(self, pxo, knn_idx=None): | |
| # [N, 3], [N, C], [B] | |
| p, x, o = pxo | |
| c = x.size(1) | |
| if self.proj_channels is not None: | |
| c = self.proj_channels | |
| assert c % self.num_heads == 0 | |
| head_dim = c // self.num_heads | |
| scale_factor = head_dim ** -0.5 | |
| qkv = self.qkv(x) # [N, 3*C] | |
| x_q, x_k, x_v = torch.chunk(qkv, chunks=3, dim=-1) # each [N, C] | |
| # ---- Fused path: gather + attention in one kernel ---- | |
| if self.use_fused and self.no_rpe: | |
| # Ensure we have KNN indices | |
| if knn_idx is None: | |
| knn_idx, _ = pointops.knn_query( | |
| self.knn_samples, p, o, p, o | |
| ) | |
| # qk_norm: RMSNorm normalizes each C-dim vector independently, | |
| # so applying before gather is equivalent to applying after gather. | |
| if self.qk_norm: | |
| x_q = self.q_norm(x_q) | |
| x_k = self.k_norm(x_k) | |
| out = self._fused_knn_attention( | |
| x_q.contiguous(), x_k.contiguous(), x_v.contiguous(), | |
| knn_idx.contiguous(), scale_factor | |
| ) | |
| out = self.proj(out) | |
| return out | |
| # ---- Original unfused path ---- | |
| # # [N, K, C], [N, K] | |
| # x_k, idx = pointops.knn_query_and_group( | |
| # x_k.contiguous(), p, o, new_xyz=p, new_offset=o, | |
| # idx=knn_idx, | |
| # nsample=self.knn_samples, with_xyz=False | |
| # ) # [N, K, C] | |
| # | |
| # # [N, K, C] | |
| # x_v, _ = pointops.knn_query_and_group( | |
| # x_v.contiguous(), | |
| # p, | |
| # o, | |
| # new_xyz=p, | |
| # new_offset=o, | |
| # idx=idx, | |
| # nsample=self.knn_samples, | |
| # with_xyz=False, | |
| # ) | |
| # ---- Initial improved version ---- | |
| x_kv = torch.cat([x_k, x_v], dim=-1) # [N, 2C/3] | |
| x_kv_query, _ = pointops.knn_query_and_group( | |
| x_kv.contiguous(), p, o, new_xyz=p, new_offset=o, | |
| idx=knn_idx, nsample=self.knn_samples, with_xyz=False | |
| ) # [N, K, 2C/3] | |
| x_k, x_v = torch.chunk(x_kv_query, chunks=2, dim=-1) | |
| # [N, K, 3], [N, K, C] | |
| # NOTE: without xyz in knn | |
| # p_r, x_k = x_k[:, :, :3], x_k[:, :, 3:] | |
| # [N, 1, K] | |
| assert self.no_rpe | |
| if not self.no_rpe: | |
| rpe = self.rpe(p_r).permute(0, 2, 1) | |
| else: | |
| rpe = 0 | |
| if self.qk_norm: | |
| x_q = self.q_norm(x_q) | |
| x_k = self.k_norm(x_k) | |
| n, k, c = x_k.shape | |
| # attention | |
| if USE_PYTORCH_ATTN: | |
| out = F.scaled_dot_product_attention( | |
| x_q.view(n, 1, c), | |
| x_k.view(n, k, c), | |
| x_v.view(n, k, c), | |
| ).reshape(n, c) # [N, C] | |
| elif (USE_FLASH_ATTN3 and FA3_AVAILABLE and self.no_rpe): | |
| # no relative pos enc | |
| out = flash_attn_func( | |
| x_q.view(n, 1, self.num_heads, head_dim).to(torch.bfloat16), | |
| x_k.view(n, k, self.num_heads, head_dim).to(torch.bfloat16), | |
| x_v.view(n, k, self.num_heads, head_dim).to(torch.bfloat16), | |
| )[0].reshape(n, c).float() # [N, C] | |
| else: | |
| # [N, 1, K] | |
| scores = torch.matmul(x_q.unsqueeze(1), x_k.permute(0, 2, 1)) * scale_factor + rpe | |
| # [N, C] | |
| out = torch.matmul(torch.softmax(scores, dim=2), x_v).squeeze(1) | |
| out = self.proj(out) | |
| return out | |
class MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> activation -> Linear.

    The hidden layer is four times wider than ``channels``.

    Args:
        channels: input and output width.
        act: ``"gelu"`` (default), ``"tanh"``, or ``None`` / ``"none"`` /
            ``"identity"`` for no non-linearity.

    Raises:
        ValueError: for any other ``act`` value.
    """

    def __init__(
        self,
        channels,
        act="gelu",
    ):
        super().__init__()
        hidden = 4 * channels
        self.fc1 = nn.Linear(channels, hidden)
        self.act = self._build_activation(act)
        self.fc2 = nn.Linear(hidden, channels)

    @staticmethod
    def _build_activation(act):
        """Return a fresh activation module for the name ``act``."""
        if act is None or act in ('none', 'identity'):
            return nn.Identity()
        if act == 'gelu':
            return nn.GELU()
        if act == 'tanh':
            return nn.Tanh()
        raise ValueError(f"unsupported activation {act}")

    def forward(self, x):
        """Apply ``fc2(act(fc1(x)))`` over the last dimension."""
        return self.fc2(self.act(self.fc1(x)))
class TransformerBlock(nn.Module):
    """Point-transformer layer: a token-mixing branch followed by an MLP branch.

    The mixing branch is ``KNNAttention`` or, with ``no_attn``, a per-point
    ``nn.Linear``. With ``post_norm`` each branch output is normalised before
    its residual add (``x + norm(f(x))``); otherwise the branch input is
    normalised (``x + f(norm(x))``). ``norm_pt_block`` adds a final LayerNorm.
    ``no_rpe``, ``qk_norm``, ``num_heads``, ``attn_proj_channels`` and
    ``use_fused_attn`` are forwarded to ``KNNAttention``.
    """

    def __init__(self, channels, knn_samples=16, post_norm=False,
                 no_rpe=False,
                 no_attn=False,
                 no_norm=False,
                 act="gelu",
                 qk_norm=False,
                 norm_pt_block=False,
                 num_heads=1,
                 attn_proj_channels=None,
                 use_fused_attn=False,
                 ):
        super().__init__()
        self.post_norm = post_norm
        self.no_attn = no_attn
        self.norm_pt_block = norm_pt_block

        # nn.Identity accepts and ignores constructor arguments.
        make_norm = nn.Identity if no_norm else nn.LayerNorm
        self.norm1 = make_norm(channels)
        self.norm2 = make_norm(channels)

        if no_attn:
            self.linear = nn.Linear(channels, channels)
        else:
            self.attn = KNNAttention(
                channels,
                knn_samples=knn_samples,
                no_rpe=no_rpe,
                qk_norm=qk_norm,
                num_heads=num_heads,
                proj_channels=attn_proj_channels,
                use_fused=use_fused_attn,
            )
        self.mlp = MLP(channels, act=act)
        if norm_pt_block:
            self.norm3 = nn.LayerNorm(channels)

    def forward(self, pxo, knn_idx=None):
        """Return updated features for ``pxo = (p, x, o)``; positions are unchanged."""
        p, x, o = pxo

        def mix(feats):
            # Token-mixing branch: KNN attention, or a per-point linear map.
            if self.no_attn:
                return self.linear(feats)
            return self.attn((p, feats, o), knn_idx=knn_idx)

        if self.post_norm:
            x = x + self.norm1(mix(x))
            x = x + self.norm2(self.mlp(x))
        else:
            x = x + mix(self.norm1(x))
            x = x + self.mlp(self.norm2(x))
        return self.norm3(x) if self.norm_pt_block else x
class FPSSubsample(nn.Module):
    """Subsample a batched point cloud and aggregate neighbour features onto the kept points.

    Each batch segment (delimited by the cumulative offsets ``o``) keeps
    ``len // stride`` points: farthest point sampling for ``'fps'``, a strided
    ``arange`` for ``'grid'`` (sanity check only), or experimental
    density-weighted sampling for ``'density'`` (disabled by an assert). Each
    kept point then aggregates its ``nsample`` nearest neighbours from the
    full cloud with single-head attention (``agg_func='attn'``) or a mean
    (``agg_func='avgpool'``).

    Args:
        in_planes: input feature width.
        out_planes: output feature width.
        stride: subsampling factor; 1 means no subsampling.
        nsample: neighbourhood size (also used by the density estimate).
        agg_func: ``'attn'`` or ``'avgpool'``.
        subsample_method: ``'fps'``, ``'grid'`` or ``'density'``.
        return_idx: also return the indices of the kept points.
        fps_num_samples: overrides ``nsample`` for the aggregation neighbourhood.
        attn_channels: q/k/v width for ``'attn'``; ``None`` means 64.
    """

    def __init__(self, in_planes, out_planes, stride=4, nsample=16,
                 agg_func='attn',
                 subsample_method='fps',
                 return_idx=False,
                 fps_num_samples=None,
                 attn_channels=64,
                 ):
        super().__init__()
        assert stride > 0
        self.agg_func = agg_func
        self.subsample_method = subsample_method
        self.knn_samples = nsample
        self.return_idx = return_idx
        self.stride, self.nsample = stride, nsample
        if fps_num_samples is not None:
            self.nsample = fps_num_samples
        if attn_channels is None:
            # SubsampleBlock forwards attn_proj_channels, whose default is None;
            # nn.Linear(in, None) would raise, so use the default width.
            attn_channels = 64
        assert agg_func in ['attn', 'avgpool']
        if self.agg_func == 'attn':
            # fewer channels to save memory
            self.q = nn.Linear(in_planes, attn_channels, bias=False)
            self.k = nn.Linear(in_planes, attn_channels, bias=False)
            self.v = nn.Linear(in_planes, attn_channels, bias=False)
            self.proj = nn.Linear(attn_channels, out_planes, bias=True)
            self.residual = nn.Linear(in_planes, out_planes, bias=True)
        else:
            self.proj = nn.Linear(in_planes, out_planes, bias=True)

    @staticmethod
    def _new_offsets(o, stride, device):
        """Cumulative offsets after keeping ``segment_length // stride`` points per segment."""
        counts = [o[0].item() // stride]
        total = counts[0]
        for i in range(1, o.shape[0]):
            total += (o[i].item() - o[i - 1].item()) // stride
            counts.append(total)
        return torch.tensor(counts, dtype=torch.int32, device=device)

    @staticmethod
    def _split_segments(t, o):
        """Split ``t`` along dim 0 into the batch segments given by offsets ``o``."""
        return [t[:o[0]]] + [t[o[i]:o[i + 1]] for i in range(o.shape[0] - 1)]

    def _subsample_density(self, p, x, o):
        """Density-weighted random sampling (experimental; disabled by the assert)."""
        assert False  # not well tested
        n_o = self._new_offsets(o, self.stride, x.device)
        # [N, K, C+3]
        x_k, _ = pointops.knn_query_and_group(
            x.contiguous(), p, o, new_xyz=p, new_offset=o, nsample=self.knn_samples, with_xyz=True
        )
        p_r = x_k[:, :, 0:3]
        density = torch.mean(torch.norm(p_r, dim=-1), dim=-1)  # [N]
        # TODO: normalize the distance
        weights = (density - density.min()) / (density.max() - density.min() + 1e-6)
        weights = torch.stack(self._split_segments(weights, o), dim=0)  # [B, N]
        weights = weights / weights.sum(dim=1, keepdim=True)
        num_samples = o[0].item() // self.stride
        sampled_indices = torch.stack([
            torch.multinomial(weights[b], num_samples, replacement=False)
            for b in range(n_o.shape[0])
        ], dim=0)  # (B, num_samples)
        idx = rearrange(sampled_indices, "b n -> (b n)")
        points = torch.stack(self._split_segments(p, o), dim=0)  # [B, N, 3]
        sampled_points = torch.gather(points, 1, sampled_indices.unsqueeze(-1).expand(-1, -1, 3))
        sampled_points = rearrange(sampled_points, "b m c -> (b m) c")
        # average pooling over the neighbourhood, then pick the sampled rows
        # TODO: try others
        feats = torch.stack(self._split_segments(x_k.mean(dim=1), o), dim=0)  # [B, N, C]
        feats = torch.gather(feats, 1, sampled_indices.unsqueeze(-1).expand(-1, -1, feats.size(-1)))
        feats = rearrange(feats, "b n c -> (b n) c")
        return sampled_points, self.proj(feats), n_o, idx

    def _subsample_fps_or_grid(self, p, x, o):
        """Keep points via FPS (or a strided arange) and aggregate their KNN features."""
        n_o = self._new_offsets(o, self.stride, x.device)
        if self.subsample_method == 'fps':
            idx = pointops.farthest_point_sampling(p, o, n_o)  # (m)
        else:
            # uniform sampling: sanity check only.
            # NOTE(review): this ignores segment boundaries and keeps
            # ceil(N / stride) points, which can differ from the total in n_o.
            # TODO: grid sample in the image space
            idx = torch.arange(0, p.size(0), self.stride).to(x.device)
        n_p = p[idx.long(), :]  # (m, 3)
        x_subsample = x[idx.long(), :]  # [M, C]

        if self.agg_func == 'attn':
            x_q = self.q(x_subsample)  # [M, C]
            x_k = self.k(x)  # [N, C]
        else:
            x_k = x
        # [M, K, C]
        x_k, knn_idx = pointops.knn_query_and_group(
            x_k, p, offset=o, new_xyz=n_p, new_offset=n_o,
            nsample=self.nsample, with_xyz=False,
        )
        if self.agg_func != 'attn':
            return n_p, self.proj(x_k.mean(dim=1)), n_o, idx

        # Values reuse the neighbour indices found for the keys.
        x_v, _ = pointops.knn_query_and_group(
            self.v(x), p, offset=o, new_xyz=n_p, new_offset=n_o,
            idx=knn_idx, nsample=self.nsample, with_xyz=False,
        )  # [M, K, C]
        scale_factor = x_q.shape[-1] ** -0.5
        # [M, 1, K]; no relative position encoding
        scores = torch.matmul(x_q.unsqueeze(1), x_k.permute(0, 2, 1)) * scale_factor
        attended = torch.matmul(torch.softmax(scores, dim=2), x_v).squeeze(1)  # [M, C]
        x = self.residual(x_subsample) + self.proj(attended)
        return n_p, x, n_o, idx

    def forward(self, pxo):
        """Subsample ``pxo = (p [N, 3], x [N, C], o [B])``.

        Returns:
            ``[p, x, o]`` for the kept points, plus the kept indices when
            ``return_idx`` is set.

        Raises:
            ValueError: for an unknown ``subsample_method`` (when stride != 1).
        """
        p, x, o = pxo
        if self.stride == 1:
            # NOTE(review): with agg_func='attn', self.proj expects attn_channels
            # inputs, and the residual add needs in_planes == out_planes; this
            # branch only works when those widths coincide.
            x = x + self.proj(x)
            idx = torch.arange(0, p.size(0)).to(x.device)
        elif self.subsample_method == 'density':
            p, x, o, idx = self._subsample_density(p, x, o)
        elif self.subsample_method in ['fps', 'grid']:
            p, x, o, idx = self._subsample_fps_or_grid(p, x, o)
        else:
            raise ValueError(f"unsupported subsampling method {self.subsample_method}")
        if self.return_idx:
            return [p, x, o], idx
        return [p, x, o]
class SubsampleBlock(nn.Module):
    """Pre-norm downsampling stage: LayerNorm -> FPSSubsample -> residual MLP.

    ``post_norm`` must be False (only the pre-norm layout is implemented).
    With ``return_idx`` the indices of the kept points are returned as well.
    ``attn_proj_channels`` is forwarded to ``FPSSubsample`` as ``attn_channels``.
    """

    def __init__(self, in_channels, out_channels, stride=4, knn_samples=16, post_norm=False,
                 agg_func='attn',
                 subsample_method='fps',
                 return_idx=False,
                 fps_num_samples=None,
                 attn_proj_channels=None,
                 ):
        super().__init__()
        assert not post_norm
        self.return_idx = return_idx
        self.post_norm = post_norm
        self.norm1 = nn.LayerNorm(in_channels)
        self.fps = FPSSubsample(
            in_channels,
            out_channels,
            stride=stride,
            nsample=knn_samples,
            agg_func=agg_func,
            subsample_method=subsample_method,
            return_idx=return_idx,
            fps_num_samples=fps_num_samples,
            attn_channels=attn_proj_channels,
        )
        self.norm2 = nn.LayerNorm(out_channels)
        self.mlp = MLP(out_channels)

    def forward(self, pxo):
        """Downsample ``pxo = (p, x, o)`` and refine the kept features with an MLP."""
        points, feats, offsets = pxo
        result = self.fps([points, self.norm1(feats), offsets])
        if self.return_idx:
            (points, feats, offsets), idx = result
        else:
            points, feats, offsets = result
        feats = feats + self.mlp(self.norm2(feats))
        out = [points, feats, offsets]
        return (out, idx) if self.return_idx else out
class SkipConnect(nn.Module):
    """Fuse fine-level features with coarse-level features interpolated onto the fine points.

    ``proj1`` maps the fine features ``x1`` (width ``out_channels``); ``proj2``
    maps the coarse features ``x2`` (width ``in_channels``) before
    ``pointops.interpolation2`` moves them from points ``p2`` onto ``p1``;
    ``proj3`` mixes the sum.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.proj1 = nn.Linear(out_channels, out_channels)
        self.proj2 = nn.Linear(in_channels, out_channels)
        self.proj3 = nn.Linear(out_channels, out_channels)

    def forward(self, pxo1, pxo2):
        """Return fused features for the points of ``pxo1``.

        Args:
            pxo1: fine level ``(p1, x1, o1)``.
            pxo2: coarse level ``(p2, x2, o2)``.
        """
        p1, x1, o1 = pxo1
        p2, x2, o2 = pxo2
        # CUDA autocast only supports float16/bfloat16; requesting float32 made
        # PyTorch warn and disable it on every call. Disable it explicitly and
        # cast inputs to the weight dtype so mixed-precision callers (whose
        # activations arrive in half precision) do not hit a dtype mismatch.
        # TODO: support half precision
        weight_dtype = self.proj1.weight.dtype
        with torch.amp.autocast(device_type='cuda', enabled=False):
            x = self.proj1(x1.to(weight_dtype)) + pointops.interpolation2(
                p2, p1, self.proj2(x2.to(weight_dtype)), o2, o1
            )
        x = self.proj3(x)
        return x
class PlainPointTransformer(nn.Module):
    """Single-resolution stack of KNN point-transformer blocks.

    One KNN graph is computed per forward call (positions do not move inside
    the stack) and shared by every block. With ``with_mv_attn`` each point
    block is followed by a global multi-view block over the ``(v h w)`` token
    grid, so points must be ordered as ``(b v h w)``.

    Caching: ``cache_knn_idx`` holds the neighbour indices between calls and
    is refreshed whenever it is None or ``iter % knn_idx_update_every == 0``.

    ``no_mv_attn``, ``conv_with_norm``, ``shuffle_attn_no_norm`` and
    ``mv_unimatch_attn`` are accepted but currently unused; ``mv_shuffle_attn``
    is only checked when ``with_pos_enc`` is set.
    """

    def __init__(self, channels, knn_samples=16, num_blocks=4, post_norm=False,
                 no_rpe=False,
                 no_attn=False,
                 no_norm=False,
                 act="gelu",
                 qk_norm=False,
                 norm_pt_block=False,
                 num_heads=1,
                 attn_proj_channels=None,
                 cache_knn_idx=None,
                 knn_idx_update_every=1,
                 with_mv_attn=False,
                 with_mv_attn_lowres=False,
                 mv_attn_first=False,
                 no_mv_attn=False,
                 conv_with_norm=False,
                 mv_shuffle_attn=False,
                 with_pos_enc=False,
                 shuffle_attn_no_norm=False,
                 mv_unimatch_attn=False,
                 use_checkpointing=False,
                 init_use_checkpointing=False,
                 use_fused_attn=False,
                 ):
        super().__init__()
        self.cache_knn_idx = cache_knn_idx
        self.knn_idx_update_every = knn_idx_update_every
        self.knn_samples = knn_samples
        self.use_checkpointing = use_checkpointing
        self.init_use_checkpointing = init_use_checkpointing
        self.with_mv_attn = with_mv_attn
        self.with_mv_attn_lowres = with_mv_attn_lowres
        if with_pos_enc:
            assert mv_shuffle_attn
        self.blocks = nn.ModuleList([
            TransformerBlock(channels, knn_samples=knn_samples,
                             post_norm=post_norm,
                             no_rpe=no_rpe,
                             no_attn=no_attn,
                             no_norm=no_norm,
                             act=act,
                             qk_norm=qk_norm,
                             norm_pt_block=norm_pt_block,
                             num_heads=num_heads,
                             attn_proj_channels=attn_proj_channels,
                             use_fused_attn=use_fused_attn,
                             )
            for _ in range(num_blocks)
        ])
        # multi-view attention, one block per point block
        if self.with_mv_attn:
            self.mv_blocks = nn.ModuleList()
            for _ in range(num_blocks):
                if self.with_mv_attn_lowres:
                    self.mv_blocks.append(MultViewLowresAttn(channels))
                else:
                    self.mv_blocks.append(MultiViewBlock(channels, num_heads=4))

    @staticmethod
    def _checkpoint_fn(blk):
        """Wrap ``blk`` for ``torch.utils.checkpoint`` with ``blk`` bound now.

        A closure over the forward loop variable is late-bound: when backward
        recomputes the checkpointed segments it would run whichever block the
        loop variable held last, producing wrong gradients.
        """
        def custom_forward(p, x, o):
            # NOTE: cannot pass the cached knn_idx here, otherwise index out
            # error; the block recomputes its neighbours instead.
            return blk((p, x, o), knn_idx=None)
        return custom_forward

    def forward(self, pxo, iter=0, b=None, v=None, h=None, w=None):
        """Run all blocks on ``pxo = (p, x, o)`` and return the updated features.

        Args:
            pxo: coordinates, features and cumulative batch offsets.
            iter: caller's iteration counter, used for KNN cache refreshes.
            b, v, h, w: batch/views/height/width of the token grid; required
                when ``with_mv_attn`` is set.
        """
        p, x, o = pxo
        # Positions are not changed inside the blocks, so one KNN graph serves
        # every block (and is reused across calls between refreshes).
        if self.cache_knn_idx is None or (iter % self.knn_idx_update_every) == 0:
            knn_idx, _ = pointops.knn_query(self.knn_samples, p, o, p, o)
            self.cache_knn_idx = knn_idx
        else:
            knn_idx = self.cache_knn_idx

        if self.with_mv_attn:
            assert b is not None and v is not None and h is not None and w is not None
            if self.use_checkpointing:
                raise NotImplementedError
            for blk, mv_blk in zip(self.blocks, self.mv_blocks):
                # local KNN attention
                x = blk([p, x, o], knn_idx=knn_idx)
                # global multi-view attention over the (v h w) token grid
                x = rearrange(x, "(b v h w) c -> b (v h w) c", b=b, v=v, h=h, w=w)
                if self.with_mv_attn_lowres:
                    x = mv_blk(x, v=v, h=h, w=w)
                else:
                    x = mv_blk(x)
                x = rearrange(x, "b (v h w) c -> (b v h w) c",
                              b=b, v=v, h=h, w=w)
            return x

        for blk in self.blocks:
            if self.init_use_checkpointing:
                # Checkpointing the initial reconstruction model; the shared
                # knn_idx is not used on this path.
                x = torch.utils.checkpoint.checkpoint(
                    self._checkpoint_fn(blk), p, x, o,
                    use_reentrant=True,  # explicit form of the historical default
                )
            else:
                x = blk((p, x, o), knn_idx=knn_idx)
        return x
class MultViewUnetAttn(nn.Module):
    """Two-level conv U-Net with global multi-view attention at the bottleneck.

    Input tokens ``[(b v h w), C]`` are reshaped to per-view maps. Two
    stride-2 convs reach 1/4 resolution, ``MultiViewBlock`` attends across all
    views' tokens, and bilinear upsampling with skip connections returns to
    full resolution. The result is added to the input.
    """

    def __init__(self, channels, no_mv_attn=False, conv_with_norm=False):
        super().__init__()
        self.conv_with_norm = conv_with_norm
        self.down1 = nn.Conv2d(channels, channels, 3, 2, 1)
        self.down2 = nn.Conv2d(channels, channels, 3, 2, 1)
        self.up2 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.up1 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.attn = MultiViewBlock(channels, 4, no_attn=no_mv_attn)
        if self.conv_with_norm:
            self.norm1 = nn.LayerNorm(channels)
            self.norm2 = nn.LayerNorm(channels)
            self.norm3 = nn.LayerNorm(channels)
            self.norm4 = nn.LayerNorm(channels)

    @staticmethod
    def _channel_norm(norm, t):
        """Apply a LayerNorm over the channel dimension of an NCHW map."""
        return norm(t.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

    def forward(self, x, v=8, h=256 // 4, w=448 // 4, b=1):
        """Refine tokens ``x`` of shape ``[b*v*h*w, C]``.

        Args:
            x: flattened per-pixel tokens, ordered ``(b v h w)``.
            v, h, w, b: views, per-view height/width and batch size. The
                defaults reproduce the previously hard-coded 8 x 64 x 112 grid
                with batch 1. ``h`` and ``w`` are assumed divisible by 4 so the
                upsampled maps line up with the skip connections.
        """
        assert x.size(0) == b * v * h * w
        residual = x
        x = rearrange(x, "(b v h w) c -> (b v) c h w", b=b, v=v, h=h, w=w)
        x1 = self.down1(x)  # 1/2
        if self.conv_with_norm:
            x1 = self._channel_norm(self.norm1, x1)
        x2 = self.down2(x1)  # 1/4
        if self.conv_with_norm:
            x2 = self._channel_norm(self.norm2, x2)
        x2 = rearrange(x2, "(b v) c h w -> b (v h w) c", b=b, v=v)
        x2 = self.attn(x2)  # attention across all views' 1/4-res tokens
        x2 = rearrange(x2, "b (v h w) c -> (b v) c h w", b=b, v=v, h=h // 4, w=w // 4)
        x2 = self.up2(x1 + F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=True))  # 1/2
        if self.conv_with_norm:
            x2 = self._channel_norm(self.norm3, x2)
        x = self.up1(x + F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=True))  # 1
        if self.conv_with_norm:
            x = self._channel_norm(self.norm4, x)
        x = rearrange(x, "(b v) c h w -> (b v h w) c", b=b, v=v)
        return residual + x
class MultViewShuffleAttn(nn.Module):
    """Multi-view attention on a pixel-unshuffled (4x lower resolution) token grid.

    Per-view maps are folded by ``pixel_unshuffle``, projected to
    ``channels``, attended across all views by ``MultiViewBlock`` (or an
    identity with ``no_mv_attn``), and unfolded back. A 3x3 conv follows, and
    the result is added to the input.
    """

    def __init__(self, channels, no_mv_attn=False, with_pos_enc=False, shuffle_attn_no_norm=False):
        super().__init__()
        self.down_factor = 4
        self.with_pos_enc = with_pos_enc
        self.proj1 = nn.Linear(channels * self.down_factor ** 2, channels)
        if shuffle_attn_no_norm:
            self.norm1 = nn.Identity()
        else:
            self.norm1 = nn.LayerNorm(channels)
        self.proj2 = nn.Linear(channels, channels * self.down_factor ** 2)
        if shuffle_attn_no_norm:
            self.norm2 = nn.Identity()
        else:
            self.norm2 = nn.LayerNorm(channels * self.down_factor ** 2)
        self.conv = nn.Conv2d(channels, channels, 3, 1, 1)
        if no_mv_attn:
            self.attn = nn.Identity()
        else:
            self.attn = MultiViewBlock(channels, 4, no_attn=no_mv_attn)

    def forward(self, x, v=8, h=256 // 4, w=448 // 4, b=1):
        """Refine tokens ``x`` of shape ``[b*v*h*w, C]``.

        Args:
            x: flattened per-pixel tokens, ordered ``(b v h w)``.
            v, h, w, b: views, per-view height/width and batch size. The
                defaults reproduce the previously hard-coded 8 x 64 x 112 grid
                with batch 1. ``h`` and ``w`` must be divisible by
                ``self.down_factor``.
        """
        assert x.size(0) == b * v * h * w
        residual = x
        f = self.down_factor
        x = rearrange(x, "(b v h w) c -> (b v) c h w", b=b, v=v, h=h, w=w)
        if self.with_pos_enc:
            x = mv_feature_add_position(x, attn_splits=1, feature_channels=x.size(1))
        x = F.pixel_unshuffle(x, f)
        x = rearrange(x, "(b v) c h w -> b (v h w) c", b=b)
        x = self.norm1(self.proj1(x))
        x = self.attn(x)
        x = self.norm2(self.proj2(x))
        x = rearrange(x, "b (v h w) c -> (b v) c h w", b=b, v=v, h=h // f, w=w // f)
        x = F.pixel_shuffle(x, f)
        x = self.conv(x)
        x = rearrange(x, "(b v) c h w -> (b v h w) c", b=b, v=v)
        return x + residual
class MultViewLowresAttn(nn.Module):
    """Global multi-view attention computed on a pixel-unshuffled low-resolution grid.

    Tokens ``[B, V*H*W, C]`` are reshaped to per-view maps and folded by
    ``pixel_unshuffle``. For ``down_factor == 8`` the folding is a 0.5x
    bilinear resize followed by a factor-4 unshuffle. The folded tokens are
    projected and normalised, attended by ``MultiViewBlock`` (self-attention,
    or cross-attention against ``y``), then unfolded, smoothed by a 3x3 conv,
    and added to the input.

    Args:
        channels: token width.
        no_mv_attn: use an identity instead of attention.
        with_pos_enc: add ``mv_feature_add_position`` to the query maps.
        shuffle_attn_no_norm: drop the LayerNorms around the attention.
        down_factor: spatial reduction factor (8 = 2x resize + 4x unshuffle).
        attn_proj_channels: optional narrower working width (``proj0`` in,
            ``proj3`` out); it also switches the attention to a single head.
    """

    def __init__(self, channels, no_mv_attn=False, with_pos_enc=False, shuffle_attn_no_norm=False,
                 down_factor=4,
                 attn_proj_channels=None,
                 ):
        super().__init__()
        self.down_factor = down_factor
        self.with_pos_enc = with_pos_enc
        self.attn_proj_channels = attn_proj_channels
        if attn_proj_channels:
            ori_channels = channels
            self.proj0 = nn.Linear(channels, attn_proj_channels)
            channels = attn_proj_channels
        unshuffle = 4 if self.down_factor == 8 else self.down_factor
        self.proj1 = nn.Linear(channels * unshuffle ** 2, channels)
        if shuffle_attn_no_norm:
            self.norm1 = nn.Identity()
        else:
            self.norm1 = nn.LayerNorm(channels)
        self.proj2 = nn.Linear(channels, channels * unshuffle ** 2)
        if shuffle_attn_no_norm:
            self.norm2 = nn.Identity()
        else:
            self.norm2 = nn.LayerNorm(channels * unshuffle ** 2)
        self.conv = nn.Conv2d(channels, channels, 3, 1, 1)
        if attn_proj_channels:
            self.proj3 = nn.Linear(channels, ori_channels)
        if no_mv_attn:
            self.attn = nn.Identity()
        else:
            num_heads = 1 if self.attn_proj_channels else 4
            self.attn = MultiViewBlock(channels, num_heads, no_attn=no_mv_attn)

    def _unshuffle_factor(self):
        """Pixel-unshuffle factor; down_factor 8 is split into a 2x resize and 4x unshuffle."""
        return 4 if self.down_factor == 8 else self.down_factor

    def _fold_maps(self, maps, add_pos):
        """Optionally add positions, then reduce NCHW maps by the down factor."""
        if add_pos and self.with_pos_enc:
            maps = mv_feature_add_position(maps, attn_splits=1, feature_channels=maps.size(1))
        if self.down_factor == 8:
            # bilinear to half first to save channels
            maps = F.interpolate(maps, scale_factor=0.5, mode='bilinear', align_corners=True)
        return F.pixel_unshuffle(maps, self._unshuffle_factor())

    def _unfold_tokens(self, x, v, h, w):
        """Map attended low-res tokens back to full-resolution tokens ``[B, V*H*W, C]``."""
        x = self.norm2(self.proj2(x))
        x = rearrange(x, "b (v h w) c -> (b v) c h w", v=v, h=h // self.down_factor, w=w // self.down_factor)
        x = F.pixel_shuffle(x, self._unshuffle_factor())
        x = self.conv(x)
        if self.down_factor == 8:
            # bilinear to full
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)
        x = rearrange(x, "(b v) c h w -> b (v h w) c", v=v)
        if self.attn_proj_channels:
            x = self.proj3(x)
        return x

    def forward(self, x, v=None, h=None, w=None, y=None):
        """Self-attention over ``x`` (``[B, V*H*W, C]``); cross-attention to ``y`` if given."""
        if y is not None:
            return self.forward_cross_attn(x, y, v, h, w)
        residual = x
        if self.attn_proj_channels:
            x = self.proj0(x)
        x = rearrange(x, "b (v h w) c -> (b v) c h w", v=v, h=h, w=w)
        x = self._fold_maps(x, add_pos=True)
        x = rearrange(x, "(b v) c h w -> b (v h w) c", v=v)
        x = self.norm1(self.proj1(x))
        x = self.attn(x)
        return self._unfold_tokens(x, v, h, w) + residual

    def forward_cross_attn(self, x, y, v=None, h=None, w=None):
        """Attend from ``x`` (``v`` views) to ``y`` (possibly a different view count).

        Both streams share ``proj0``/``proj1``/``norm1``; the view count of
        ``y`` is inferred from its token count and ``h``/``w``.
        """
        residual = x
        assert y is not None
        if self.attn_proj_channels:
            # Project both streams so y matches proj1's expected input width.
            x = self.proj0(x)
            y = self.proj0(y)
        y = rearrange(y, "b (v h w) c -> (b v) c h w", h=h, w=w)  # different v with x
        num_cross_view = y.shape[0] // x.shape[0]
        x = rearrange(x, "b (v h w) c -> (b v) c h w", v=v, h=h, w=w)
        # Positional encoding is added to the queries (x) only.
        x = self._fold_maps(x, add_pos=True)
        y = self._fold_maps(y, add_pos=False)
        x = rearrange(x, "(b v) c h w -> b (v h w) c", v=v)
        y = rearrange(y, "(b v) c h w -> b (v h w) c", v=num_cross_view)
        x = self.norm1(self.proj1(x))
        y = self.norm1(self.proj1(y))
        if isinstance(self.attn, nn.Identity):
            # no_mv_attn: nn.Identity takes a single input, so y is not used.
            x = self.attn(x)
        else:
            # there will be slight diff for self and cross attn caused by flash3
            x = self.attn(x, y)
        return self._unfold_tokens(x, v, h, w) + residual
class GaussianErrorCrossAttn(nn.Module):
    """Cross-attention from Gaussian tokens (queries) to a downsampled error map (keys/values).

    The error tokens ``[B, V*H*W, error_channels]`` are reshaped per view and
    pixel-unshuffled by ``down_factor`` to cut the key/value count. The
    attended context is concatenated with the original Gaussian features and
    projected back to ``gaussian_channels``.

    ``no_mv_attn``, ``with_pos_enc``, ``shuffle_attn_no_norm`` and
    ``attn_proj_channels`` are accepted but unused. ``with_mlp`` builds
    ``mlp_norm``/``mlp``, but ``forward`` does not apply them.
    """

    def __init__(self, gaussian_channels,
                 error_channels,
                 model_channels=256,
                 no_mv_attn=False, with_pos_enc=False, shuffle_attn_no_norm=False,
                 down_factor=4,
                 attn_proj_channels=None,
                 num_heads=4,
                 with_mlp=False,
                 ):
        super().__init__()
        self.num_heads = num_heads
        self.model_channels = model_channels
        self.down_factor = down_factor
        self.with_mlp = with_mlp
        # Queries from the Gaussian features.
        self.q_proj = nn.Linear(gaussian_channels, model_channels)
        # Keys and values from the unshuffled error map, in one projection.
        self.kv_proj = nn.Linear(error_channels * down_factor ** 2, 2 * model_channels)
        # Output fuses [attended context, original Gaussian features].
        self.out_proj = nn.Linear(model_channels + gaussian_channels, gaussian_channels)
        if with_mlp:
            self.mlp_norm = nn.LayerNorm(gaussian_channels)
            self.mlp = MLP(gaussian_channels)

    def forward(self, gaussian, error, v=None, h=None, w=None, mask=None):
        """Return updated Gaussian features with the same shape as ``gaussian``.

        ``v``, ``h`` and ``w`` give the view count and per-view size of
        ``error``; ``mask`` is unused.
        """
        batch = gaussian.size(0)
        width = self.model_channels
        heads = self.num_heads
        head_dim = width // heads

        def split_heads(t):
            # [B, N, C] -> [B, heads, N, head_dim]
            return t.view(batch, -1, heads, head_dim).transpose(1, 2)

        query = split_heads(self.q_proj(gaussian))

        err = rearrange(error, "b (v h w) c -> (b v) c h w", v=v, h=h, w=w)
        err = F.pixel_unshuffle(err, self.down_factor)
        err = rearrange(err, "(b v) c h w -> b (v h w) c", v=v)
        key, value = (split_heads(t) for t in self.kv_proj(err).chunk(2, dim=-1))

        attended = F.scaled_dot_product_attention(query, key, value)
        # [B, heads, N, head_dim] -> [B, N, C]
        attended = attended.transpose(1, 2).contiguous().view(batch, -1, width)
        return self.out_proj(torch.cat([attended, gaussian], dim=-1))
class MultViewUniMatchAttn(nn.Module):
    """Cross-view attention via a one-layer ``MultiViewFeatureTransformer``.

    Tokens ``[B, V*H*W, C]`` are reshaped to per-view maps, given positional
    encodings with 4 attention splits, and passed as a list of views to the
    transformer. The input is not added back here. ``no_mv_attn``,
    ``with_pos_enc`` and ``shuffle_attn_no_norm`` are accepted but unused.
    """

    def __init__(self, channels, no_mv_attn=False, with_pos_enc=False, shuffle_attn_no_norm=False):
        super().__init__()
        self.attn = MultiViewFeatureTransformer(num_layers=1, d_model=channels)

    def forward(self, x, v=None, h=None, w=None):
        """Return transformed tokens with the same ``[B, V*H*W, C]`` layout as ``x``."""
        splits = 4
        maps = rearrange(x, "b (v h w) c -> (b v) c h w", v=v, h=h, w=w)
        maps = mv_feature_add_position(maps, splits, feature_channels=maps.size(1))
        views = torch.unbind(rearrange(maps, "(b v) c h w -> b v c h w", v=v), dim=1)
        views = self.attn(list(views), splits)
        return rearrange(torch.stack(views, dim=1), "b v c h w -> b (v h w) c")
class MultiScalePointTransformer(nn.Module):
    """U-Net style point transformer over ``num_scales`` resolution levels.

    Encoder:
        * A full-resolution ``TransformerBlock`` (``blocks[0]``, knn 4).
        * For each coarser level, a ``SubsampleBlock`` that doubles the
          channels (``down_blocks``), followed by a ``TransformerBlock``
          aggregation (``down_agg``).

    Decoder, from coarsest to finest:
        * A ``SkipConnect`` (``skip_blocks``) fuses the finer encoder features
          with the current coarser decoder features, halving the channels.
        * A ``TransformerBlock`` (``blocks[1:]``) then refines the result.

    Data flows as ``[p, x, o]`` triples, the convention used throughout this
    file (presumably coordinates, features and batch offsets — TODO confirm).
    """

    def __init__(self, channels, knn_samples=16, post_norm=False,
                 no_rpe=True,
                 no_attn=False,
                 qk_norm=False,
                 norm_pt_block=False,
                 num_heads=1,
                 num_scales=3,
                 stride=4,
                 downsample_agg_func='attn',
                 subsample_method='fps',
                 fps_num_samples=None,
                 attn_proj_channels=None,
                 ):
        super().__init__()
        if num_scales < 1:
            raise ValueError(f"num_scales must be >= 1, got {num_scales}")
        self.blocks = nn.ModuleList()
        # blocks[0]: encoder block at full resolution with a fixed knn of 4.
        self.blocks.append(TransformerBlock(channels, knn_samples=4,
                                            post_norm=post_norm,
                                            no_rpe=no_rpe,
                                            no_attn=no_attn,
                                            qk_norm=qk_norm,
                                            norm_pt_block=norm_pt_block,
                                            num_heads=num_heads,
                                            attn_proj_channels=attn_proj_channels,
                                            ))
        # blocks[1:]: decoder refinement blocks, coarsest level first. The block
        # for level i has channels * 2**i channels and knn_samples // 2**i
        # neighbours. With the defaults this gives knn 8 at level 1 and knn 16
        # at level 0 (full resolution).
        for i in range(num_scales - 2, -1, -1):
            self.blocks.append(TransformerBlock(channels * (2 ** i), knn_samples=knn_samples // (2 ** i),
                                                post_norm=post_norm,
                                                no_rpe=no_rpe,
                                                no_attn=no_attn,
                                                qk_norm=qk_norm,
                                                norm_pt_block=norm_pt_block,
                                                num_heads=num_heads,
                                                attn_proj_channels=attn_proj_channels,
                                                ))
        # down_blocks[i]: level i -> level i + 1, doubling the channel count.
        self.down_blocks = nn.ModuleList()
        for i in range(num_scales - 1):
            self.down_blocks.append(
                SubsampleBlock(
                    channels * (2 ** i), channels * (2 ** (i + 1)),
                    stride=stride,
                    knn_samples=knn_samples // (2 ** (num_scales - 1 - i)),
                    subsample_method=subsample_method,
                    agg_func=downsample_agg_func,
                    fps_num_samples=fps_num_samples,
                    attn_proj_channels=attn_proj_channels,
                )
            )
        # down_agg[i]: encoder aggregation at level i + 1.
        self.down_agg = nn.ModuleList()
        for i in range(num_scales - 1):
            self.down_agg.append(
                TransformerBlock(channels * (2 ** (i + 1)), knn_samples=knn_samples // (2 ** (num_scales - 1 - i)),
                                 post_norm=post_norm,
                                 no_rpe=no_rpe,
                                 no_attn=no_attn,
                                 qk_norm=qk_norm,
                                 norm_pt_block=norm_pt_block,
                                 num_heads=num_heads,
                                 attn_proj_channels=attn_proj_channels,
                                 )
            )
        # skip_blocks: coarsest first; skip_blocks[j] maps level
        # (num_scales - 1 - j) back onto level (num_scales - 2 - j).
        self.skip_blocks = nn.ModuleList()
        for i in range(num_scales - 1, 0, -1):
            self.skip_blocks.append(
                SkipConnect(
                    channels * (2 ** i),
                    channels * (2 ** (i - 1))
                )
            )

    def forward(self, pxo):
        """Run the encoder/decoder and return full-resolution features.

        Args:
            pxo: ``[p, x, o]`` at full resolution.

        Returns:
            The refined feature tensor for the input points. It corresponds to
            the input ``p``/``o``, which are therefore not returned.
        """
        # Encoder: remember (p, x, o) at every level, finest first, for skips.
        x = self.blocks[0](pxo)
        levels = [(pxo[0], x, pxo[2])]
        for down, agg in zip(self.down_blocks, self.down_agg):
            p, x, o = down(list(levels[-1]))
            x = agg([p, x, o])
            levels.append((p, x, o))

        # Decoder: walk back up the levels. At each step, fuse the finer
        # encoder level with the current coarser state, then refine at the
        # finer level's points.
        coarse = levels[-1]
        for skip, block, fine in zip(self.skip_blocks, self.blocks[1:], reversed(levels[:-1])):
            x = skip(list(fine), list(coarse))
            p, o = fine[0], fine[2]
            x = block([p, x, o])
            coarse = (p, x, o)
        return coarse[1]
class PointLinearWrapper(nn.Module):
    """Project the features of a ``[p, x, o]`` point triple with ``nn.Linear``.

    Only the feature entry ``x`` is mapped from ``in_channels`` to
    ``out_channels``. The first and last entries are passed through as the
    very same objects.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, pxo, b=None, v=None, h=None, w=None):
        # b, v, h, w are accepted (presumably to match a shared block call
        # signature) but are not used here.
        coords, feats, offsets = pxo
        projected = self.linear(feats)
        return [coords, projected, offsets]
| class SwiGLUFFN(nn.Module): | |
| def __init__( | |
| self, | |
| in_features: int, | |
| hidden_features: int | None = None, | |
| out_features: int | None = None, | |
| bias: bool = True, | |
| ) -> None: | |
| super().__init__() | |
| out_features = out_features or in_features | |
| hidden_features = hidden_features or in_features | |
| self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) | |
| self.w3 = nn.Linear(hidden_features, out_features, bias=bias) | |
| def forward(self, x): | |
| x12 = self.w12(x) | |
| x1, x2 = x12.chunk(2, dim=-1) | |
| hidden = F.silu(x1) * x2 | |
| return self.w3(hidden) | |
def test_fps():
    """Benchmark ``FPSSubsample`` on random points (requires CUDA).

    Prints the model, the shape of the output's second element, and then the
    total seconds for ``count`` forward passes after a 5-iteration warm-up.
    """
    model = FPSSubsample(256, 256,
                         fps_num_samples=16,
                         subsample_method='fps',
                         ).cuda()
    print(model)
    # FPS is significantly slower than grid with many points
    c = 256
    b, n = 2, 40480
    x = torch.randn(b, n, c).cuda()
    # Cumulative end index of each batch element in the flattened point list.
    offset = torch.tensor([n * (i + 1) for i in range(b)]).to(x.device)
    p = torch.randn(b, n, 3).cuda()
    pxo = [p.view(-1, 3), x.view(-1, c), offset]
    y = model(pxo)
    print(y[1].shape)
    count = 100
    for _ in range(5):  # warm-up
        model(pxo)
    torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    start = time.perf_counter()
    for _ in range(count):
        model(pxo)
    torch.cuda.synchronize()
    print(time.perf_counter() - start)
def test_knn_query_and_group():
    """Compare three ways of gathering K/V neighbour features with pointops (CUDA).

    All strategies reuse one precomputed KNN index. Each is timed over ``T``
    iterations and reported as mean milliseconds per iteration:

    1. Chunk ``qkv`` into q/k/v, then group k and v with two calls.
    2. Group the full ``qkv`` tensor once, then chunk the result.
    3. Chunk, concatenate k and v, and group the concatenation once.
    """
    device = torch.device('cuda')
    # b, n = 2, 80480
    b, n = 8, 57344
    knn_samples = 16
    # Cumulative end index of each batch element in the flattened point list.
    o = torch.tensor([n * (i + 1) for i in range(b)], device=device)
    p = torch.randn(b * n, 3, device=device)
    knn_idx, _ = pointops.knn_query(knn_samples, p, o, p, o)
    print(knn_idx.shape)
    c_qkv = 192
    qkv = torch.randn(b * n, c_qkv, device=device)
    T = 1000

    def bench_ms(step):
        # Mean wall-clock milliseconds of step() over T GPU-synchronized runs.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(T):
            step()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / T * 1000

    def chunk_then_group_twice():
        _, x_k, x_v = torch.chunk(qkv, chunks=3, dim=-1)
        _, idx = pointops.knn_query_and_group(
            x_k.contiguous(), p, o, new_xyz=p, new_offset=o,
            idx=knn_idx, nsample=knn_samples, with_xyz=False,
        )  # grouped k: [N, K, C/3]
        pointops.knn_query_and_group(
            x_v.contiguous(), p, o, new_xyz=p, new_offset=o,
            idx=idx, nsample=knn_samples, with_xyz=False,
        )  # grouped v: [N, K, C/3]

    def group_then_chunk():
        grouped = pointops.knn_query_and_group(
            qkv.contiguous(), p, o, new_xyz=p, new_offset=o,
            idx=knn_idx, nsample=knn_samples, with_xyz=False,
        )[0]  # [N, K, C]
        torch.chunk(grouped, chunks=3, dim=-1)

    def concat_kv_then_group_once():
        _, x_k, x_v = torch.chunk(qkv, chunks=3, dim=-1)
        x_kv = torch.cat([x_k, x_v], dim=-1)  # [N, 2C/3]
        grouped = pointops.knn_query_and_group(
            x_kv.contiguous(), p, o, new_xyz=p, new_offset=o,
            idx=knn_idx, nsample=knn_samples, with_xyz=False,
        )[0]  # [N, K, 2C/3]
        torch.chunk(grouped, 2, dim=-1)

    for label, step in (
        ("chunk, then group k and v separately", chunk_then_group_twice),
        ("group qkv once, then chunk", group_then_chunk),
        ("chunk, concat k|v, group once", concat_kv_then_group_once),
    ):
        print(f"KNN query and group time ({label}): {bench_ms(step):.2f} ms")
def test_knn():
    """Benchmark a ``KNNAttention`` forward pass on random points (requires CUDA).

    Prints the model and output shape, then the total seconds for ``count``
    forward passes after a 5-iteration warm-up.
    """
    c = 256
    b, n = 2, 80480
    model = KNNAttention(channels=c,
                         # proj_channels=64,  # optional projection width
                         ).cuda()
    print(model)
    x = torch.randn(b, n, c).cuda()
    # Cumulative end index of each batch element in the flattened point list.
    offset = torch.tensor([n * (i + 1) for i in range(b)]).to(x.device)
    p = torch.randn(b, n, 3).cuda()
    pxo = [p.view(-1, 3), x.view(-1, c), offset]
    y = model(pxo)
    print(y.shape)
    count = 100
    for _ in range(5):  # warm-up
        model(pxo)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(count):
        model(pxo)
    torch.cuda.synchronize()
    print(time.perf_counter() - start)
def test_faiss_knn():
    """Benchmark ``pointops.knn_query`` self-queries with K=16 (requires CUDA).

    Despite the name, FAISS is not used; it could not be installed, so this
    times pointops' KNN instead. Prints the index shape, then the total seconds
    for ``count`` queries after a 5-iteration warm-up.
    """
    # TODO: maybe implement a sliding window knn search later
    device = torch.device('cuda')
    b, n = 2, 80480
    knn_samples = 16
    # Cumulative end index of each batch element in the flattened point list.
    o = torch.tensor([n * (i + 1) for i in range(b)], device=device)
    p = torch.randn(b * n, 3, device=device)
    knn_idx, _ = pointops.knn_query(knn_samples, p, o, p, o)
    print(knn_idx.shape)
    count = 100
    for _ in range(5):  # warm-up
        pointops.knn_query(knn_samples, p, o, p, o)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(count):
        pointops.knn_query(knn_samples, p, o, p, o)
    torch.cuda.synchronize()
    print(time.perf_counter() - start)
def count_parameters(model):
    """Return the number of scalar elements across ``model``'s trainable parameters."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
def test_mlp():
    """Benchmark an ``MLP(c)`` forward pass in bfloat16 (requires CUDA).

    Prints the trainable-parameter count and output shape, then the total
    seconds for ``count`` forward passes after a 5-iteration warm-up.
    """
    b, n, c = 2, 40240, 256
    model = MLP(c).cuda()
    x = torch.randn(b, n, c).cuda()
    # To compare against SwiGLU instead: model = SwiGLUFFN(c, c * 3).cuda()
    print('parameters:', count_parameters(model))
    x = x.to(torch.bfloat16)
    model.to(dtype=torch.bfloat16)
    # Only this shape-check call runs under autocast. Weights and inputs are
    # already bf16, and the warm-up and timed loops below run without autocast.
    with torch.autocast('cuda', enabled=True, dtype=torch.bfloat16):
        y = model(x)
    print(y.shape)
    count = 100
    for _ in range(5):  # warm-up
        model(x)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(count):
        model(x)
    torch.cuda.synchronize()
    print(time.perf_counter() - start)
def test_mv_block():
    """Smoke-test one DINOv2-layer ``MultiViewBlock`` on random tokens (requires CUDA)."""
    dim, heads = 256, 4
    block = MultiViewBlock(dim, heads).cuda()
    tokens = torch.rand(2, 256, dim).cuda()
    print(block)
    out = block(tokens)
    print(out.shape)
def test_cross_attn():
    """Smoke-test ``GaussianErrorCrossAttn`` on random inputs (requires CUDA).

    Feeds a 512-channel tensor and a ``c``-channel tensor, both flattened over
    ``v * h * w`` tokens. Prints the model, then the first input's shape and
    the output's shape.
    """
    c = 256
    v, h, w = 8, 64, 128
    model = GaussianErrorCrossAttn(512, c, c).cuda()
    x = torch.rand(2, v * h * w, 512).cuda()
    y_in = torch.rand(2, v * h * w, c).cuda()
    print(model)
    # Keep the second input and the output under separate names.
    y_out = model(x, y_in, v=v, h=h, w=w)
    print(x.shape, y_out.shape)
def test_grouping():
    """Time ``grouping`` vs ``grouping2`` on the same K/V gather (requires CUDA).

    Both gather the concatenated k|v features at a precomputed KNN index. Each
    timing is mean milliseconds per call over ``T`` iterations. The printed
    labels keep the author's naming: "pytorch" is ``grouping`` and "cuda" is
    ``grouping2``.
    """
    device = torch.device('cuda')
    # b, n = 2, 80480
    b, n = 1, 57344
    knn_samples = 16
    # Cumulative end index of each batch element in the flattened point list.
    o = torch.tensor([n * (i + 1) for i in range(b)], device=device)
    p = torch.randn(b * n, 3, device=device)
    knn_idx, _ = pointops.knn_query(knn_samples, p, o, p, o)
    print(knn_idx.shape)
    c_qkv = 192
    qkv = torch.randn(b * n, c_qkv, device=device)
    _, x_k, x_v = torch.chunk(qkv, chunks=3, dim=-1)
    x_kv = torch.cat([x_k, x_v], dim=-1)  # [N, 2C/3]
    T = 1000

    def bench_ms(step):
        # Mean wall-clock milliseconds of step() over T GPU-synchronized runs.
        torch.cuda.synchronize()
        start = time.perf_counter()
        for _ in range(T):
            step()
        torch.cuda.synchronize()
        return (time.perf_counter() - start) / T * 1000

    ms = bench_ms(lambda: grouping(idx=knn_idx, feat=x_kv, xyz=p, new_xyz=p, with_xyz=False))
    print(f"grouping pytorch: {ms:.2f} ms")
    ms = bench_ms(lambda: grouping2(x_kv, knn_idx))
    print(f"grouping cuda: {ms:.2f} ms")
if __name__ == '__main__':
    import sys

    # Developer micro-benchmarks. Pick one or more by name on the command line;
    # with no arguments, test_grouping runs (the previous hard-coded choice).
    _benchmarks = {
        fn.__name__: fn
        for fn in (
            test_fps,
            test_knn,
            test_mlp,
            test_mv_block,
            test_cross_attn,
            test_faiss_knn,
            test_knn_query_and_group,
            test_grouping,
        )
    }
    for _name in sys.argv[1:] or ['test_grouping']:
        if _name not in _benchmarks:
            sys.exit(f"unknown benchmark {_name!r}; choose from: {', '.join(sorted(_benchmarks))}")
        _benchmarks[_name]()