"""KNN point-attention blocks, subsampling layers and multi-view attention wrappers.

Point clouds are passed around as ``pxo`` triples ``(p, x, o)``; following the
inline shape comments of the original code these are point coordinates
``[N, 3]``, point features ``[N, C]`` and per-batch cumulative offsets ``[B]``.
"""
import sys
import time
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import pointops
from pointops import grouping, grouping2
from einops import rearrange

from ..unimatch.dinov2.layers.block import Block as MultiViewBlock
from ..unimatch.utils import mv_feature_add_position
from ..unimatch.mv_transformer import MultiViewFeatureTransformer

USE_PYTORCH_ATTN = False
USE_FLASH_ATTN3 = False

# Safe fallbacks: ``KNNAttention.forward`` references both names, so they must
# exist even while the optional flash-attention-3 import below stays disabled.
FA3_AVAILABLE = False
flash_attn_func = None

# try:
#     from flash_attn_interface import flash_attn_func
#     FA3_AVAILABLE = True
#     warnings.warn('flash attention 3 is available (point attn)')
# except ImportError:
#     FA3_AVAILABLE = False
#     warnings.warn('flash attention 3 is not available (point attn)')


class KNNAttention(nn.Module):
    """Single-head attention of every point over its K nearest neighbours.

    Args:
        channels: input/output feature width.
        knn_samples: number of neighbours K.
        no_rpe: when False an ``rpe`` MLP is created, but ``forward`` currently
            rejects that configuration (relative positions are not gathered).
        qk_norm: apply RMSNorm to queries and keys.
        num_heads: must be 1 (multi-head is a TODO).
        proj_channels: optional inner attention width; defaults to ``channels``.
        use_fused: try to use the optional ``fused_knn_attn`` extension.
    """

    # TODO: multi-head
    def __init__(self,
                 channels,
                 knn_samples=16,
                 no_rpe=True,
                 qk_norm=False,
                 num_heads=1,
                 proj_channels=None,
                 use_fused=False,
                 ):
        super().__init__()

        self.proj_channels = proj_channels
        self.knn_samples = knn_samples
        self.no_rpe = no_rpe
        self.num_heads = num_heads
        assert self.num_heads == 1

        self.use_fused = use_fused
        if use_fused:
            try:
                from optgs.paths import PROJECT_DIR
                submodules_dir = str(PROJECT_DIR / "submodules")
                # Avoid growing sys.path by one entry per instantiated layer.
                if submodules_dir not in sys.path:
                    sys.path.append(submodules_dir)
                from fused_knn_attn import fused_knn_attention, FUSED_KNN_ATTN_CUDA_AVAILABLE
                self._fused_knn_attention = fused_knn_attention
                if not FUSED_KNN_ATTN_CUDA_AVAILABLE:
                    warnings.warn(
                        "Fused KNN attention CUDA extension not available, "
                        "using PyTorch fallback (still avoids [N,K,C] intermediates)"
                    )
            except ImportError:
                warnings.warn(
                    "fused_knn_attn package not found, falling back to unfused attention"
                )
                self.use_fused = False

        # Width of q/k/v after the qkv projection.
        attn_channels = channels if self.proj_channels is None else self.proj_channels

        self.qk_norm = qk_norm
        if qk_norm:
            # The norms act on projected q/k, so they must use the projected
            # width (the original used ``channels`` and failed whenever
            # ``proj_channels != channels``).
            self.q_norm = nn.RMSNorm(attn_channels)
            self.k_norm = nn.RMSNorm(attn_channels)

        self.qkv = nn.Linear(channels, attn_channels * 3, bias=False)
        self.proj = nn.Linear(attn_channels, channels)

        if not self.no_rpe:
            self.rpe = nn.Sequential(
                nn.Linear(3, 32),
                nn.GELU(),
                nn.Linear(32, 1)
            )

    def forward(self, pxo, knn_idx=None):
        """Attend over KNN neighbourhoods.

        Args:
            pxo: ``(p, x, o)`` triple, commented in the original as
                ``[N, 3], [N, C], [B]``.
            knn_idx: optional precomputed neighbour indices; when ``None`` they
                are computed by pointops.

        Returns:
            Tensor produced by the output projection (``[N, channels]``).
        """
        # [N, 3], [N, C], [B]
        p, x, o = pxo

        c = x.size(1)
        if self.proj_channels is not None:
            c = self.proj_channels
        assert c % self.num_heads == 0
        head_dim = c // self.num_heads
        scale_factor = head_dim ** -0.5

        qkv = self.qkv(x)  # [N, 3*C]
        x_q, x_k, x_v = torch.chunk(qkv, chunks=3, dim=-1)  # each [N, C]

        # ---- Fused path: gather + attention in one kernel ----
        if self.use_fused and self.no_rpe:
            # Ensure we have KNN indices
            if knn_idx is None:
                knn_idx, _ = pointops.knn_query(
                    self.knn_samples, p, o, p, o
                )

            # qk_norm: RMSNorm normalizes each C-dim vector independently,
            # so applying before gather is equivalent to applying after gather.
            if self.qk_norm:
                x_q = self.q_norm(x_q)
                x_k = self.k_norm(x_k)

            out = self._fused_knn_attention(
                x_q.contiguous(), x_k.contiguous(), x_v.contiguous(),
                knn_idx.contiguous(), scale_factor
            )
            out = self.proj(out)
            return out

        # ---- Original unfused path ----
        # # [N, K, C], [N, K]
        # x_k, idx = pointops.knn_query_and_group(
        #     x_k.contiguous(), p, o, new_xyz=p, new_offset=o,
        #     idx=knn_idx,
        #     nsample=self.knn_samples, with_xyz=False
        # )  # [N, K, C]
        #
        # # [N, K, C]
        # x_v, _ = pointops.knn_query_and_group(
        #     x_v.contiguous(),
        #     p,
        #     o,
        #     new_xyz=p,
        #     new_offset=o,
        #     idx=idx,
        #     nsample=self.knn_samples,
        #     with_xyz=False,
        # )

        # ---- Initial improved version ----
        # Gather keys and values with a single grouping call.
        x_kv = torch.cat([x_k, x_v], dim=-1)  # [N, 2C/3]
        x_kv_query, _ = pointops.knn_query_and_group(
            x_kv.contiguous(), p, o, new_xyz=p, new_offset=o,
            idx=knn_idx,
            nsample=self.knn_samples, with_xyz=False
        )  # [N, K, 2C/3]
        x_k, x_v = torch.chunk(x_kv_query, chunks=2, dim=-1)

        # [N, K, 3], [N, K, C]
        # NOTE: without xyz in knn
        # p_r, x_k = x_k[:, :, :3], x_k[:, :, 3:]

        # [N, 1, K]
        if not self.no_rpe:
            # Relative positions are not gathered above, so an RPE term cannot
            # be computed. Raised explicitly (instead of a bare ``assert``) so
            # the check also holds under ``python -O``.
            raise AssertionError(
                "relative positional encoding is not supported in KNNAttention.forward"
            )
        rpe = 0

        if self.qk_norm:
            x_q = self.q_norm(x_q)
            x_k = self.k_norm(x_k)

        n, k, c = x_k.shape

        # attention
        if USE_PYTORCH_ATTN:
            out = F.scaled_dot_product_attention(
                x_q.view(n, 1, c),
                x_k.view(n, k, c),
                x_v.view(n, k, c),
            ).reshape(n, c)  # [N, C]
        elif (USE_FLASH_ATTN3 and FA3_AVAILABLE and self.no_rpe):
            # no relative pos enc
            out = flash_attn_func(
                x_q.view(n, 1, self.num_heads, head_dim).to(torch.bfloat16),
                x_k.view(n, k, self.num_heads, head_dim).to(torch.bfloat16),
                x_v.view(n, k, self.num_heads, head_dim).to(torch.bfloat16),
            )[0].reshape(n, c).float()  # [N, C]
        else:
            # [N, 1, K]
            scores = torch.matmul(x_q.unsqueeze(1), x_k.permute(0, 2, 1)) * scale_factor + rpe
            # [N, C]
            out = torch.matmul(torch.softmax(scores, dim=2), x_v).squeeze(1)

        out = self.proj(out)

        return out


class MLP(nn.Module):
    """Two-layer feed-forward network with a 4x hidden expansion.

    Args:
        channels: input and output width.
        act: ``'gelu'``, ``'tanh'``, or ``None``/``'none'``/``'identity'``.

    Raises:
        ValueError: for any other ``act`` value.
    """

    def __init__(
        self,
        channels,
        act="gelu",
    ):
        super().__init__()

        expansion = 4
        self.fc1 = nn.Linear(channels, channels * expansion)
        if act is None or act in ['none', 'identity']:
            self.act = nn.Identity()
        elif act == 'gelu':
            self.act = nn.GELU()
        elif act == 'tanh':
            self.act = nn.Tanh()
        else:
            raise ValueError(f"unsupported activation {act}")
        self.fc2 = nn.Linear(channels * expansion, channels)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x


class TransformerBlock(nn.Module):
    """Residual block: KNN attention (or a plain linear layer) followed by an MLP.

    ``post_norm`` selects whether LayerNorm is applied to the branch output
    (post-norm) or the branch input (pre-norm). ``norm_pt_block`` adds a final
    LayerNorm on the block output.
    """

    def __init__(self,
                 channels,
                 knn_samples=16,
                 post_norm=False,
                 no_rpe=False,
                 no_attn=False,
                 no_norm=False,
                 act="gelu",
                 qk_norm=False,
                 norm_pt_block=False,
                 num_heads=1,
                 attn_proj_channels=None,
                 use_fused_attn=False,
                 ):
        super().__init__()

        self.post_norm = post_norm
        self.no_attn = no_attn
        self.norm_pt_block = norm_pt_block

        if no_norm:
            self.norm1 = nn.Identity()
            self.norm2 = nn.Identity()
        else:
            self.norm1 = nn.LayerNorm(channels)
            self.norm2 = nn.LayerNorm(channels)

        if self.no_attn:
            self.linear = nn.Linear(channels, channels)
        else:
            self.attn = KNNAttention(channels,
                                     knn_samples=knn_samples,
                                     no_rpe=no_rpe,
                                     qk_norm=qk_norm,
                                     num_heads=num_heads,
                                     proj_channels=attn_proj_channels,
                                     use_fused=use_fused_attn,
                                     )

        self.mlp = MLP(channels, act=act)

        if self.norm_pt_block:
            self.norm3 = nn.LayerNorm(channels)

    def forward(self, pxo, knn_idx=None):
        """Return updated features ``x``; ``p`` and ``o`` are only read."""
        p, x, o = pxo

        if self.post_norm:
            if self.no_attn:
                x = x + self.norm1(self.linear(x))
            else:
                x = x + self.norm1(self.attn((p, x, o), knn_idx=knn_idx))
            x = x + self.norm2(self.mlp(x))
        else:
            if self.no_attn:
                x = x + self.linear(self.norm1(x))
            else:
                x = x + self.attn((p, self.norm1(x), o), knn_idx=knn_idx)
            x = x + self.mlp(self.norm2(x))

        if self.norm_pt_block:
            x = self.norm3(x)

        return x


class FPSSubsample(nn.Module):
    """Subsample a point cloud by ``stride`` and aggregate neighbour features.

    Args:
        in_planes / out_planes: input and output feature widths.
        stride: subsampling ratio; ``1`` keeps all points.
        nsample: neighbours used for aggregation (also stored as ``knn_samples``).
        agg_func: ``'attn'`` or ``'avgpool'``.
        subsample_method: ``'fps'``, ``'grid'`` or ``'density'`` (the latter is
            disabled by an assertion in ``forward``).
        return_idx: also return the indices of the kept points.
        fps_num_samples: overrides ``nsample`` for the aggregation step.
        attn_channels: inner width of the attention aggregation. ``None`` falls
            back to 64 so that callers forwarding an unset value still work.
    """

    def __init__(self,
                 in_planes,
                 out_planes,
                 stride=4,
                 nsample=16,
                 agg_func='attn',
                 subsample_method='fps',
                 return_idx=False,
                 fps_num_samples=None,
                 attn_channels=64,
                 ):
        super().__init__()

        assert stride > 0

        # SubsampleBlock / MultiScalePointTransformer pass ``None`` by default,
        # which used to crash inside ``nn.Linear``.
        if attn_channels is None:
            attn_channels = 64

        self.agg_func = agg_func
        self.subsample_method = subsample_method
        self.knn_samples = nsample
        self.return_idx = return_idx

        self.stride, self.nsample = stride, nsample

        if fps_num_samples is not None:
            self.nsample = fps_num_samples

        # if stride != 1:
        #     # xyz + feature
        #     # self.linear = nn.Linear(3 + in_planes, out_planes, bias=not post_norm)
        #     # only feature
        #     # TODO: attention aggregation
        #     if agg_func == 'maxpool':
        #         self.agg = nn.MaxPool1d(nsample)
        #     elif agg_func == 'avgpool':
        #         self.agg = nn.AvgPool1d(nsample)
        #     else:
        #         raise ValueError(f"unsupported agg_func {agg_func}")

        # fewer channels to save memory
        assert agg_func in ['attn', 'avgpool']
        if self.agg_func == 'attn':
            self.q = nn.Linear(in_planes, attn_channels, bias=False)
            self.k = nn.Linear(in_planes, attn_channels, bias=False)
            self.v = nn.Linear(in_planes, attn_channels, bias=False)
            self.proj = nn.Linear(attn_channels, out_planes, bias=True)
            self.residual = nn.Linear(in_planes, out_planes, bias=True)
        else:
            self.proj = nn.Linear(in_planes, out_planes, bias=True)

    def _subsampled_offsets(self, o, device):
        """Return cumulative offsets after keeping ``count // stride`` points per batch item."""
        new_offsets = []
        running, previous = 0, 0
        # ``tolist`` performs one host transfer instead of one ``.item()`` per entry.
        for end in o.tolist():
            running += (end - previous) // self.stride
            new_offsets.append(running)
            previous = end
        return torch.tensor(new_offsets, dtype=torch.int32, device=device)

    def forward(self, pxo):
        """Subsample ``pxo``.

        Returns:
            ``[p, x, o]`` for the kept points, plus their indices when
            ``return_idx`` is set.

        Raises:
            ValueError: for an unknown ``subsample_method`` when ``stride != 1``.
        """
        p, x, o = pxo  # (n, 3), (n, c), (b)

        if self.stride != 1:
            if self.subsample_method == 'density':
                assert False  # not well tested
                n_o = self._subsampled_offsets(o, x.device)

                # [N, K, C+3]
                x_k, _ = pointops.knn_query_and_group(
                    x.contiguous(), p, o, new_xyz=p, new_offset=o,
                    nsample=self.knn_samples, with_xyz=True
                )

                p_r = x_k[:, :, 0:3]
                density = torch.mean(torch.norm(p_r, dim=-1), dim=-1)  # [N]

                # TODO: normalize the distance
                weights = (density - density.min()) / (density.max() - density.min() + 1e-6)
                # weights = density
                # weights = 1.0 / (density + 1e-6)  # Inverse density weighting

                # to batch
                lists = [weights[:o[0]]]
                for i in range(o.shape[0] - 1):
                    lists.append(weights[o[i]:o[i+1]])
                weights = torch.stack(lists, dim=0)  # [B, N]

                weights = weights / weights.sum(dim=1, keepdim=True)  # Normalize weights

                # Sample points based on weights
                batch = n_o.shape[0]
                num_samples = o[0].item() // self.stride
                sampled_indices = torch.stack([
                    torch.multinomial(weights[b], num_samples, replacement=False)
                    for b in range(batch)
                ], dim=0)  # (B, num_samples)

                idx = rearrange(sampled_indices, "b n -> (b n)")

                point_list = [p[:o[0]]]
                for i in range(o.shape[0] - 1):
                    point_list.append(p[o[i]:o[i+1], :])
                points = torch.stack(point_list, dim=0)  # [B, N, 3]

                # Gather sampled points
                sampled_points = torch.gather(points, 1, sampled_indices.unsqueeze(-1).expand(-1, -1, 3))
                # print(sampled_points.shape)  # [B, M, 3]
                sampled_points = rearrange(sampled_points, "b m c -> (b m) c")

                # average pooling
                # TODO: try others
                x = x_k.mean(dim=1)  # [N, C]

                x_list = [x[:o[0]]]
                for i in range(o.shape[0] - 1):
                    x_list.append(x[o[i]:o[i+1], :])
                x = torch.stack(x_list, dim=0)  # [B, N, C]

                # Gather sampled points
                x = torch.gather(x, 1, sampled_indices.unsqueeze(-1).expand(-1, -1, x.size(-1)))
                x = rearrange(x, "b n c -> (b n) c")

                # TODO: do we need to add residual to x here?
                # use the index to subsample the initial features
                x = self.proj(x)
                p, o = sampled_points, n_o

            elif self.subsample_method in ['fps', 'grid']:
                n_o = self._subsampled_offsets(o, x.device)

                if self.subsample_method == 'fps':
                    idx = pointops.farthest_point_sampling(p, o, n_o)  # (m)
                else:
                    # uniform sampling: sanity check
                    # first reshape to V, H, W, then do grid sampling
                    # Generate grid indices
                    # TODO: grid sample in the image space
                    idx = torch.arange(0, p.size(0), self.stride).to(x.device)

                n_p = p[idx.long(), :]  # (m, 3)
                x_subsample = x[idx.long(), :]  # [M, C]

                if self.agg_func == 'attn':
                    x_q = self.q(x_subsample)  # [M, C]
                    # [M, K, C]
                    x_k = self.k(x)  # [N, C]
                else:
                    x_k = x

                x_k, knn_idx = pointops.knn_query_and_group(
                    x_k,
                    p,
                    offset=o,
                    new_xyz=n_p,
                    new_offset=n_o,
                    nsample=self.nsample,
                    with_xyz=False,  # remove xyz
                )

                if self.agg_func == 'attn':
                    x_v = self.v(x)
                    # Reuse the neighbour indices found for the keys.
                    x_v, _ = pointops.knn_query_and_group(
                        x_v,
                        p,
                        offset=o,
                        new_xyz=n_p,
                        new_offset=n_o,
                        idx=knn_idx,
                        nsample=self.nsample,
                        with_xyz=False,  # remove xyz
                    )

                    # attention
                    # x_q: [M, C], x_k: [M, K, C], x_v: [M, K, C]
                    scale_factor = x_q.shape[-1] ** -0.5
                    # [M, 1, K]
                    # no relative posenc
                    scores = torch.matmul(x_q.unsqueeze(1), x_k.permute(0, 2, 1)) * scale_factor
                    # [M, C]
                    x = torch.matmul(torch.softmax(scores, dim=2), x_v).squeeze(1)

                    # if self.agg_func in ['avgpool', 'maxpool']:
                    #     x = self.agg(x.transpose(1, 2).contiguous()).squeeze(-1)  # (m, c)
                    # else:
                    #     raise NotImplementedError

                    # add residual to x here?
                    # use the index to subsample the initial features
                    x = self.residual(x_subsample) + self.proj(x)
                else:
                    x = x_k.mean(dim=1)
                    x = self.proj(x)

                p, o = n_p, n_o
            else:
                raise ValueError(f"unsupported subsampling method {self.subsample_method}")
        else:
            # add residual to x here?
            x = x + self.proj(x)
            idx = torch.arange(0, p.size(0)).to(x.device)

        if self.return_idx:
            return [p, x, o], idx

        return [p, x, o]


class SubsampleBlock(nn.Module):
    """Pre-norm block: LayerNorm -> FPSSubsample -> residual MLP.

    ``post_norm`` must be False (asserted). ``attn_proj_channels`` is forwarded
    to ``FPSSubsample`` as ``attn_channels``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=4,
                 knn_samples=16,
                 post_norm=False,
                 agg_func='attn',
                 subsample_method='fps',
                 return_idx=False,
                 fps_num_samples=None,
                 attn_proj_channels=None,
                 ):
        super().__init__()

        assert not post_norm

        self.return_idx = return_idx
        self.post_norm = post_norm

        self.norm1 = nn.LayerNorm(in_channels)
        self.fps = FPSSubsample(in_channels, out_channels,
                                stride=stride,
                                nsample=knn_samples,
                                agg_func=agg_func,
                                subsample_method=subsample_method,
                                return_idx=return_idx,
                                fps_num_samples=fps_num_samples,
                                attn_channels=attn_proj_channels,
                                )
        self.norm2 = nn.LayerNorm(out_channels)
        self.mlp = MLP(out_channels)

    def forward(self, pxo):
        """Return ``[p, x, o]`` of the subsampled cloud (plus indices if ``return_idx``)."""
        # pre norm
        p, x, o = pxo

        x = self.norm1(x)
        if self.return_idx:
            pxo, idx = self.fps([p, x, o])
        else:
            pxo = self.fps([p, x, o])

        p, x, o = pxo
        x = x + self.mlp(self.norm2(x))

        if self.return_idx:
            return [p, x, o], idx

        return [p, x, o]


class SkipConnect(nn.Module):
    """Fuse fine-level features with coarse-level features interpolated by pointops.

    ``in_channels`` is the width of the coarse input (``pxo2``) and
    ``out_channels`` the width of the fine input (``pxo1``) and of the output.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.proj1 = nn.Linear(out_channels, out_channels)
        self.proj2 = nn.Linear(in_channels, out_channels)
        self.proj3 = nn.Linear(out_channels, out_channels)

    def forward(self, pxo1, pxo2):
        p1, x1, o1 = pxo1
        p2, x2, o2 = pxo2

        # TODO: support half precision
        with torch.amp.autocast(device_type='cuda', enabled=True, dtype=torch.float32):
            x = self.proj1(x1) + pointops.interpolation2(
                p2, p1, self.proj2(x2), o2, o1
            )
            x = self.proj3(x)

        return x


class PlainPointTransformer(nn.Module):
    """Stack of ``TransformerBlock``s, optionally interleaved with multi-view attention.

    KNN indices are computed in ``forward`` and cached on the module; they are
    refreshed whenever ``iter % knn_idx_update_every == 0``.
    Several constructor flags (``mv_attn_first``, ``no_mv_attn``,
    ``conv_with_norm``, ``shuffle_attn_no_norm``, ``mv_unimatch_attn``) are
    accepted for compatibility but not used by the current implementation.
    """

    def __init__(self,
                 channels,
                 knn_samples=16,
                 num_blocks=4,
                 post_norm=False,
                 no_rpe=False,
                 no_attn=False,
                 no_norm=False,
                 act="gelu",
                 qk_norm=False,
                 norm_pt_block=False,
                 num_heads=1,
                 attn_proj_channels=None,
                 cache_knn_idx=None,
                 knn_idx_update_every=1,
                 with_mv_attn=False,
                 with_mv_attn_lowres=False,
                 mv_attn_first=False,
                 no_mv_attn=False,
                 conv_with_norm=False,
                 mv_shuffle_attn=False,
                 with_pos_enc=False,
                 shuffle_attn_no_norm=False,
                 mv_unimatch_attn=False,
                 use_checkpointing=False,
                 init_use_checkpointing=False,
                 use_fused_attn=False,
                 ):
        super().__init__()

        self.cache_knn_idx = cache_knn_idx
        self.knn_idx_update_every = knn_idx_update_every
        self.knn_samples = knn_samples
        self.use_checkpointing = use_checkpointing
        self.init_use_checkpointing = init_use_checkpointing
        self.with_mv_attn = with_mv_attn
        self.with_mv_attn_lowres = with_mv_attn_lowres

        if with_pos_enc:
            assert mv_shuffle_attn

        self.blocks = nn.ModuleList()
        for _ in range(num_blocks):
            self.blocks.append(TransformerBlock(channels,
                                                knn_samples=knn_samples,
                                                post_norm=post_norm,
                                                no_rpe=no_rpe,
                                                no_attn=no_attn,
                                                no_norm=no_norm,
                                                act=act,
                                                qk_norm=qk_norm,
                                                norm_pt_block=norm_pt_block,
                                                num_heads=num_heads,
                                                attn_proj_channels=attn_proj_channels,
                                                use_fused_attn=use_fused_attn,
                                                ))

        # multi-view attention
        if self.with_mv_attn:
            self.mv_blocks = nn.ModuleList()
            for _ in range(num_blocks):
                # if mv_shuffle_attn:
                if self.with_mv_attn_lowres:
                    self.mv_blocks.append(
                        MultViewLowresAttn(
                            channels,
                        )
                    )
                else:
                    self.mv_blocks.append(
                        MultiViewBlock(
                            channels,
                            num_heads=4,
                        )
                    )
                # elif mv_unimatch_attn:
                #     self.mv_blocks.append(
                #         MultViewUniMatchAttn(
                #             channels,
                #         )
                #     )
                # else:
                #     self.mv_blocks.append(
                #         MultViewUnetAttn(channels,
                #                          no_mv_attn=no_mv_attn,
                #                          conv_with_norm=conv_with_norm,
                #                          )
                #     )

    def forward(self, pxo, iter=0, b=None, v=None, h=None, w=None):
        """Run all blocks and return the updated point features.

        Args:
            pxo: ``(p, x, o)`` triple.
            iter: iteration counter controlling when the KNN cache is refreshed.
            b, v, h, w: batch / view / height / width factors of the flattened
                point dimension; required when ``with_mv_attn`` is set.

        Raises:
            NotImplementedError: ``use_checkpointing`` together with ``with_mv_attn``.
        """
        p, x, o = pxo

        # compute knn idx here only once and pass it to the model
        # the positions are not changed inside the blocks
        if self.cache_knn_idx is None or (iter % self.knn_idx_update_every) == 0:
            knn_idx, _ = pointops.knn_query(self.knn_samples, p, o, p, o)
            self.cache_knn_idx = knn_idx
            # print(knn_idx.float().mean().item())
        else:
            knn_idx = self.cache_knn_idx

        if self.with_mv_attn:
            assert b is not None and v is not None and h is not None and w is not None

            if self.use_checkpointing:
                raise NotImplementedError

            for i in range(len(self.blocks)):
                # knn attention
                x = self.blocks[i]([p, x, o], knn_idx=knn_idx)

                # global multi-view attention
                x = rearrange(x, "(b v h w) c -> b (v h w) c", b=b, v=v, h=h, w=w)
                if self.with_mv_attn_lowres:
                    x = self.mv_blocks[i](x, v=v, h=h, w=w)
                    # # TODO: hard-coded for now
                    # if x.size(1) == 8 * 512 // 4 * 960 // 4:
                    #     x = self.mv_blocks[i](x, v=8, h=512 // 4, w=960 // 4)
                    # elif x.size(1) == 8 * 256 // 4 * 448 // 4:
                    #     x = self.mv_blocks[i](x, v=8, h=256 // 4, w=448 // 4)
                    # else:
                    #     raise ValueError(f"unsupported input size {x.size(1)} for multi-view attention")
                    # # print(x.shape)
                else:
                    x = self.mv_blocks[i](x)
                # x = x.squeeze(0)
                x = rearrange(x, "b (v h w) c -> (b v h w) c", b=b, v=v, h=h, w=w)
        else:
            for blk in self.blocks:
                if self.init_use_checkpointing:
                    # checkpointing the inital reconstruction model
                    # NOTE: cannot cache knn_idx here, otherwise index out error
                    # ``blk`` is bound as a default argument: checkpointing
                    # re-invokes this function during backward, after the loop
                    # has finished, and a plain closure would then see only the
                    # last block. The cached knn_idx is deliberately not used.
                    def custom_forward(p, x, o, blk=blk):
                        return blk((p, x, o), knn_idx=None)

                    x = torch.utils.checkpoint.checkpoint(custom_forward, p, x, o)
                else:
                    x = blk((p, x, o), knn_idx=knn_idx)

        return x


class MultViewUnetAttn(nn.Module):
    """Two-level conv down/up path with multi-view attention at 1/4 resolution.

    Operates on flattened ``(b v h w) c`` features; ``h`` and ``w`` must be
    divisible by 4 for the up-sampled maps to line up with the skip tensors.
    """

    def __init__(self, channels, no_mv_attn=False, conv_with_norm=False):
        super().__init__()

        self.conv_with_norm = conv_with_norm

        self.down1 = nn.Conv2d(channels, channels, 3, 2, 1)
        self.down2 = nn.Conv2d(channels, channels, 3, 2, 1)
        self.up2 = nn.Conv2d(channels, channels, 3, 1, 1)
        self.up1 = nn.Conv2d(channels, channels, 3, 1, 1)

        self.attn = MultiViewBlock(channels, 4, no_attn=no_mv_attn)

        if self.conv_with_norm:
            self.norm1 = nn.LayerNorm(channels)
            self.norm2 = nn.LayerNorm(channels)
            self.norm3 = nn.LayerNorm(channels)
            self.norm4 = nn.LayerNorm(channels)

    def forward(self, x, b=1, v=8, h=256 // 4, w=448 // 4):
        """Apply the block; defaults reproduce the previously hard-coded layout."""
        assert x.size(0) == b * v * h * w

        residual = x

        x = rearrange(x, "(b v h w) c -> (b v) c h w", b=b, v=v, h=h, w=w)

        x1 = self.down1(x)  # 1/2
        if self.conv_with_norm:
            # LayerNorm acts on the last dim, so move channels there and back.
            x1 = self.norm1(x1.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        x2 = self.down2(x1)  # 1/4
        if self.conv_with_norm:
            x2 = self.norm2(x2.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        x2 = rearrange(x2, "(b v) c h w -> b (v h w) c", b=b, v=v)
        x2 = self.attn(x2)  # 1/4
        x2 = rearrange(x2, "b (v h w) c -> (b v) c h w", b=b, v=v, h=h//4, w=w//4)

        x2 = self.up2(x1 + F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=True))  # 1/2
        if self.conv_with_norm:
            x2 = self.norm3(x2.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        x = self.up1(x + F.interpolate(x2, scale_factor=2, mode='bilinear', align_corners=True))  # 1
        if self.conv_with_norm:
            x = self.norm4(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        x = rearrange(x, "(b v) c h w -> (b v h w) c", b=b, v=v)

        x = residual + x

        return x


class MultViewShuffleAttn(nn.Module):
    """Multi-view attention at 1/4 resolution using pixel (un)shuffle to change scale."""

    def __init__(self, channels, no_mv_attn=False, with_pos_enc=False, shuffle_attn_no_norm=False):
        super().__init__()

        self.down_factor = 4
        self.with_pos_enc = with_pos_enc

        self.proj1 = nn.Linear(channels * self.down_factor ** 2, channels)
        if shuffle_attn_no_norm:
            self.norm1 = nn.Identity()
        else:
            self.norm1 = nn.LayerNorm(channels)

        self.proj2 = nn.Linear(channels, channels * self.down_factor ** 2)
        if shuffle_attn_no_norm:
            self.norm2 = nn.Identity()
        else:
            self.norm2 = nn.LayerNorm(channels * self.down_factor ** 2)

        self.conv = nn.Conv2d(channels, channels, 3, 1, 1)

        if no_mv_attn:
            self.attn = nn.Identity()
        else:
            self.attn = MultiViewBlock(channels, 4, no_attn=no_mv_attn)

    def forward(self, x, b=1, v=8, h=256 // 4, w=448 // 4):
        """Apply the block; defaults reproduce the previously hard-coded layout."""
        assert x.size(0) == b * v * h * w

        residual = x

        x = rearrange(x, "(b v h w) c -> (b v) c h w", b=b, v=v, h=h, w=w)

        # TODO: add positional encoding to x
        if self.with_pos_enc:
            x = mv_feature_add_position(x, attn_splits=1, feature_channels=x.size(1))
            # print(x.shape)

        x = F.pixel_unshuffle(x, self.down_factor)
        x = rearrange(x, "(b v) c h w -> b (v h w) c", b=b)

        x = self.proj1(x)
        x = self.norm1(x)
        x = self.attn(x)
        x = self.proj2(x)
        x = self.norm2(x)

        x = rearrange(x, "b (v h w) c -> (b v) c h w", b=b, v=v,
                      h=h // self.down_factor, w=w // self.down_factor)
        x = F.pixel_shuffle(x, self.down_factor)
        x = self.conv(x)

        x = rearrange(x, "(b v) c h w -> (b v h w) c", b=b, v=v)

        x = x + residual

        return x


class MultViewLowresAttn(nn.Module):
    """Multi-view attention at reduced resolution on ``b (v h w) c`` features.

    With ``down_factor == 8`` the input is first bilinearly halved and then
    pixel-unshuffled by 4 (to limit the channel blow-up); otherwise it is
    pixel-unshuffled by ``down_factor`` directly. ``attn_proj_channels``
    optionally narrows the working width via ``proj0`` / ``proj3``.
    """

    def __init__(self,
                 channels,
                 no_mv_attn=False,
                 with_pos_enc=False,
                 shuffle_attn_no_norm=False,
                 down_factor=4,
                 attn_proj_channels=None,
                 ):
        super().__init__()

        self.down_factor = down_factor
        self.with_pos_enc = with_pos_enc
        self.attn_proj_channels = attn_proj_channels

        if attn_proj_channels:
            ori_channels = channels
            self.proj0 = nn.Linear(channels, attn_proj_channels)
            channels = attn_proj_channels

        if self.down_factor == 8:
            down_factor = 4
        else:
            down_factor = self.down_factor

        self.proj1 = nn.Linear(channels * down_factor ** 2, channels)
        if shuffle_attn_no_norm:
            self.norm1 = nn.Identity()
        else:
            self.norm1 = nn.LayerNorm(channels)

        self.proj2 = nn.Linear(channels, channels * down_factor ** 2)
        if shuffle_attn_no_norm:
            self.norm2 = nn.Identity()
        else:
            self.norm2 = nn.LayerNorm(channels * down_factor ** 2)

        self.conv = nn.Conv2d(channels, channels, 3, 1, 1)

        if attn_proj_channels:
            self.proj3 = nn.Linear(channels, ori_channels)

        if no_mv_attn:
            self.attn = nn.Identity()
        else:
            num_heads = 1 if self.attn_proj_channels else 4
            self.attn = MultiViewBlock(channels, num_heads, no_attn=no_mv_attn)

    def forward(self, x, v=None, h=None, w=None, y=None):
        """Self-attention over ``x``; dispatches to ``forward_cross_attn`` when ``y`` is given."""
        if y is not None:
            return self.forward_cross_attn(x, y, v, h, w)

        residual = x

        if self.attn_proj_channels:
            x = self.proj0(x)

        x = rearrange(x, "b (v h w) c -> (b v) c h w", v=v, h=h, w=w)

        # TODO: add positional encoding to x
        if self.with_pos_enc:
            x = mv_feature_add_position(x, attn_splits=1, feature_channels=x.size(1))
            # print(x.shape)

        if self.down_factor == 8:
            # bilinear to half first to save channels
            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
            down_factor = 4
        else:
            down_factor = self.down_factor

        x = F.pixel_unshuffle(x, down_factor)
        x = rearrange(x, "(b v) c h w -> b (v h w) c", v=v)

        x = self.proj1(x)
        x = self.norm1(x)
        x = self.attn(x)
        x = self.proj2(x)
        x = self.norm2(x)

        x = rearrange(x, "b (v h w) c -> (b v) c h w", v=v,
                      h=h // self.down_factor, w=w // self.down_factor)
        x = F.pixel_shuffle(x, down_factor)
        x = self.conv(x)

        if self.down_factor == 8:
            # bilinear to full
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)

        x = rearrange(x, "(b v) c h w -> b (v h w) c", v=v)

        if self.attn_proj_channels:
            x = self.proj3(x)

        x = x + residual

        return x

    def forward_cross_attn(self, x, y, v=None, h=None, w=None):
        """Cross-attention from ``x`` (queries) to ``y``; ``y`` may have a different view count."""
        residual = x

        if self.attn_proj_channels:
            # NOTE(review): ``y`` is not passed through ``proj0`` although it
            # shares ``proj1`` with ``x`` below - confirm callers supply ``y``
            # already at the projected width when ``attn_proj_channels`` is set.
            x = self.proj0(x)

        assert y is not None
        y = rearrange(y, "b (v h w) c -> (b v) c h w", h=h, w=w)  # different v with x
        # ``x`` is still [B, ...] here, so the ratio is the number of views in ``y``.
        num_cross_view = y.shape[0] // x.shape[0]

        x = rearrange(x, "b (v h w) c -> (b v) c h w", v=v, h=h, w=w)

        # TODO: add positional encoding to x
        if self.with_pos_enc:
            x = mv_feature_add_position(x, attn_splits=1, feature_channels=x.size(1))
            # print(x.shape)

        if self.down_factor == 8:
            # bilinear to half first to save channels
            x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=True)
            y = F.interpolate(y, scale_factor=0.5, mode='bilinear', align_corners=True)
            down_factor = 4
        else:
            down_factor = self.down_factor

        x = F.pixel_unshuffle(x, down_factor)
        y = F.pixel_unshuffle(y, down_factor)
        x = rearrange(x, "(b v) c h w -> b (v h w) c", v=v)
        y = rearrange(y, "(b v) c h w -> b (v h w) c", v=num_cross_view)

        x = self.proj1(x)
        x = self.norm1(x)
        y = self.proj1(y)
        y = self.norm1(y)

        # x_tmp = self.attn(x)
        # print((x - y).abs().max().item())
        x = self.attn(x, y)
        # there will be slight diff for self and cross attn caused by flash3
        # print((x_tmp - x).abs().max().item())

        x = self.proj2(x)
        x = self.norm2(x)

        x = rearrange(x, "b (v h w) c -> (b v) c h w", v=v,
                      h=h // self.down_factor, w=w // self.down_factor)
        x = F.pixel_shuffle(x, down_factor)
        x = self.conv(x)

        if self.down_factor == 8:
            # bilinear to full
            x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=True)

        x = rearrange(x, "(b v) c h w -> b (v h w) c", v=v)

        if self.attn_proj_channels:
            x = self.proj3(x)

        x = x + residual

        return x


class GaussianErrorCrossAttn(nn.Module):
    """Cross-attention from Gaussian features (queries) to a pixel-unshuffled error map.

    The attention output is concatenated with the input Gaussian features and
    projected back to ``gaussian_channels``. ``with_mlp`` creates an MLP that
    the current ``forward`` does not apply; several other constructor flags are
    accepted but unused.
    """

    def __init__(self,
                 gaussian_channels,
                 error_channels,
                 model_channels=256,
                 no_mv_attn=False,
                 with_pos_enc=False,
                 shuffle_attn_no_norm=False,
                 down_factor=4,
                 attn_proj_channels=None,
                 num_heads=4,
                 with_mlp=False,
                 ):
        super().__init__()

        self.num_heads = num_heads
        self.model_channels = model_channels
        self.down_factor = down_factor
        self.with_mlp = with_mlp

        # self.q_norm = nn.LayerNorm(gaussian_channels)
        self.q_proj = nn.Linear(gaussian_channels, model_channels)

        kv_channels = error_channels * (down_factor ** 2)
        # self.kv_norm = nn.LayerNorm(kv_channels)
        self.kv_proj = nn.Linear(kv_channels, 2 * model_channels)

        # self.out_proj = nn.Linear(model_channels, gaussian_channels)
        # concat
        self.out_proj = nn.Linear(model_channels + gaussian_channels, gaussian_channels)

        if with_mlp:
            self.mlp_norm = nn.LayerNorm(gaussian_channels)
            self.mlp = MLP(gaussian_channels)

    def forward(self, gaussian, error, v=None, h=None, w=None, mask=None):
        """Return fused Gaussian features.

        Args:
            gaussian: query features, commented in the original as ``[B, VHW, C]``.
            error: error features laid out as ``b (v h w) c``.
            v, h, w: view count and spatial size of ``error``.
            mask: accepted but currently unused.
        """
        # [B, VHW, C]
        b = gaussian.size(0)

        # x = self.q_norm(gaussian)
        q = self.q_proj(gaussian)  # [B, VHW, C]

        # spatial reshape to save computation
        error = rearrange(error, "b (v h w) c -> (b v) c h w", v=v, h=h, w=w)
        error = F.pixel_unshuffle(error, self.down_factor)
        error = rearrange(error, "(b v) c h w -> b (v h w) c", v=v)

        # error = self.kv_norm(error)
        kv = self.kv_proj(error)
        # Named key/value so the view-count argument ``v`` is not shadowed.
        key, value = kv.chunk(2, dim=-1)  # [B, VHW, C]

        # attention
        c = self.model_channels
        head_dim = c // self.num_heads

        # [B, N, C] -> [B, num_heads, N, head_dim]
        def split_heads(t):
            return t.view(b, -1, self.num_heads, head_dim).transpose(1, 2)  # [B, H, N, D]

        q = split_heads(q)
        key = split_heads(key)
        value = split_heads(value)

        # Fast fused attention
        out = F.scaled_dot_product_attention(q, key, value)

        # [B, H, N, D] -> [B, N, C]
        out = out.transpose(1, 2).contiguous().view(b, -1, c)

        # return self.out_proj(out)
        # out = residual + self.out_proj(out)

        # concat
        out = self.out_proj(torch.cat([out, gaussian], dim=-1))

        # if self.with_mlp:
        #     out = out + self.mlp(self.mlp_norm(out))

        return out


class MultViewUniMatchAttn(nn.Module):
    """Wrapper running a one-layer ``MultiViewFeatureTransformer`` on ``b (v h w) c`` features.

    No residual connection is added by this wrapper.
    """

    def __init__(self, channels, no_mv_attn=False, with_pos_enc=False, shuffle_attn_no_norm=False):
        super().__init__()

        self.attn = MultiViewFeatureTransformer(num_layers=1,
                                                d_model=channels,
                                                )

    def forward(self, x, v=None, h=None, w=None):
        x = rearrange(x, "b (v h w) c -> (b v) c h w", v=v, h=h, w=w)

        attn_splits = 4
        # add pos enc
        x = mv_feature_add_position(x, attn_splits, feature_channels=x.size(1))

        x = rearrange(x, "(b v) c h w -> b v c h w", v=v)
        # The transformer is called with a list holding one tensor per view.
        x_list = list(torch.unbind(x, dim=1))
        x_list = self.attn(x_list, attn_splits)

        x = torch.stack(x_list, dim=1)
        x = rearrange(x, "b v c h w -> b (v h w) c")

        return x


class MultiScalePointTransformer(nn.Module):
    """U-shaped point transformer over three resolutions (1, 1/stride, 1/stride^2).

    Channel width doubles at each coarser level. ``forward`` indexes exactly
    two down blocks and two skip blocks, i.e. it is written for
    ``num_scales == 3``.
    """

    def __init__(self,
                 channels,
                 knn_samples=16,
                 post_norm=False,
                 no_rpe=True,
                 no_attn=False,
                 qk_norm=False,
                 norm_pt_block=False,
                 num_heads=1,
                 num_scales=3,
                 stride=4,
                 downsample_agg_func='attn',
                 subsample_method='fps',
                 fps_num_samples=None,
                 attn_proj_channels=None,
                 ):
        super().__init__()

        self.blocks = nn.ModuleList()

        # knn 4 at 1
        self.blocks.append(TransformerBlock(channels,
                                            knn_samples=4,
                                            post_norm=post_norm,
                                            no_rpe=no_rpe,
                                            no_attn=no_attn,
                                            qk_norm=qk_norm,
                                            norm_pt_block=norm_pt_block,
                                            num_heads=num_heads,
                                            attn_proj_channels=attn_proj_channels,
                                            ))

        for i in range(num_scales - 2, -1, -1):
            # knn 8 at 1/4
            # knn 16 at 1/16
            self.blocks.append(TransformerBlock(channels * (2 ** i),
                                                knn_samples=knn_samples // (2 ** i),
                                                post_norm=post_norm,
                                                no_rpe=no_rpe,
                                                no_attn=no_attn,
                                                qk_norm=qk_norm,
                                                norm_pt_block=norm_pt_block,
                                                num_heads=num_heads,
                                                attn_proj_channels=attn_proj_channels,
                                                ))

        self.down_blocks = nn.ModuleList()
        for i in range(num_scales - 1):
            self.down_blocks.append(
                SubsampleBlock(
                    channels * (2 ** i),
                    channels * (2 ** (i + 1)),
                    stride=stride,
                    knn_samples=knn_samples // (2 ** (num_scales - 1 - i)),
                    subsample_method=subsample_method,
                    agg_func=downsample_agg_func,
                    fps_num_samples=fps_num_samples,
                    attn_proj_channels=attn_proj_channels,
                )
            )

        self.down_agg = nn.ModuleList()
        for i in range(num_scales - 1):
            self.down_agg.append(
                TransformerBlock(channels * (2 ** (i + 1)),
                                 knn_samples=knn_samples // (2 ** (num_scales - 1 - i)),
                                 post_norm=post_norm,
                                 no_rpe=no_rpe,
                                 no_attn=no_attn,
                                 qk_norm=qk_norm,
                                 norm_pt_block=norm_pt_block,
                                 num_heads=num_heads,
                                 attn_proj_channels=attn_proj_channels,
                                 )
            )

        self.skip_blocks = nn.ModuleList()
        for i in range(num_scales - 1, 0, -1):
            self.skip_blocks.append(
                SkipConnect(
                    channels * (2 ** i),
                    channels * (2 ** (i - 1))
                )
            )

    def forward(self, pxo):
        """Return full-resolution features after the down/up pass."""
        x1 = self.blocks[0](pxo)  # 1
        p1, o1 = pxo[0], pxo[2]

        p2, x2, o2 = self.down_blocks[0]([p1, x1, o1])  # 1/4
        x2 = self.down_agg[0]([p2, x2, o2])  # 1/4

        p3, x3, o3 = self.down_blocks[1]([p2, x2, o2])  # 1/16
        x3 = self.down_agg[1]([p3, x3, o3])  # 1/16

        x4 = self.skip_blocks[0]([p2, x2, o2], [p3, x3, o3])  # 1/4
        p4, o4 = p2, o2
        x4 = self.blocks[1]([p4, x4, o4])

        x5 = self.skip_blocks[1]([p1, x1, o1], [p4, x4, o4])  # 1
        p5, o5 = p1, o1
        x5 = self.blocks[2]([p5, x5, o5])

        return x5


class PointLinearWrapper(nn.Module):
    """Apply a linear layer to the features of a ``pxo`` triple.

    ``b``, ``v``, ``h`` and ``w`` are accepted and ignored so the wrapper can
    be called like the attention models.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        self.linear = nn.Linear(in_channels, out_channels)

    def forward(self, pxo, b=None, v=None, h=None, w=None):
        p, x, o = pxo
        x = self.linear(x)
        return [p, x, o]


class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward layer: ``w3(silu(x1) * x2)`` with ``x1, x2`` from one fused projection."""

    def __init__(
        self,
        in_features: int,
        hidden_features: int | None = None,
        out_features: int | None = None,
        bias: bool = True,
    ) -> None:
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x):
        x12 = self.w12(x)
        x1, x2 = x12.chunk(2, dim=-1)
        hidden = F.silu(x1) * x2
        return self.w3(hidden)


def _cuda_time(fn, iters, warmup=0):
    """Return wall-clock seconds for ``iters`` calls of ``fn``, synchronising CUDA around the loop."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return time.perf_counter() - start


def test_fps():
    """Smoke-test and time ``FPSSubsample`` (requires CUDA)."""
    model = FPSSubsample(256, 256,
                         fps_num_samples=16,
                         subsample_method='fps',
                         ).cuda()
    print(model)

    # FPS is significantly slower than grid with many points
    c = 256
    b, n = 2, 40480
    x = torch.randn(b, n, c).cuda()
    offset = torch.tensor([n * (i + 1) for i in range(b)]).to(x.device)
    p = torch.randn(b, n, 3).cuda()

    pxo = [p.view(-1, 3), x.view(-1, c), offset]

    y = model(pxo)
    print(y[1].shape)

    print(_cuda_time(lambda: model(pxo), iters=100, warmup=5))


def test_knn_query_and_group():
    """Compare three ways of gathering K/V neighbour features (requires CUDA)."""
    c = 256
    # b, n = 2, 80480
    b, n = 8, 57344
    knn_samples = 16
    x = torch.randn(b, n, c).cuda()
    offset = torch.tensor([n * (i + 1) for i in range(b)]).to(x.device)
    o = offset
    p = torch.randn(b, n, 3).cuda()
    p = p.view(-1, 3)
    knn_idx, _ = pointops.knn_query(knn_samples, p, o, p, o)
    print(knn_idx.shape)

    c_qkv = 192
    qkv = torch.randn(b*n, c_qkv).cuda()

    T = 1000

    def group(feat, idx):
        return pointops.knn_query_and_group(
            feat.contiguous(), p, o, new_xyz=p, new_offset=o,
            idx=idx,
            nsample=knn_samples, with_xyz=False
        )

    # chunk first, then query twice
    def chunk_then_query_twice():
        x_q, x_k, x_v = torch.chunk(qkv, chunks=3, dim=-1)
        x_k_query, idx = group(x_k, knn_idx)  # [N, K, C/3]
        x_v_query, _ = group(x_v, idx)

    # query first, then chunk
    def query_then_chunk():
        x_qkv_query = group(qkv, knn_idx)[0]  # [N, K, C*3]
        x_q, x_k, x_v = torch.chunk(x_qkv_query, chunks=3, dim=-1)

    # chunk first, then query once
    def chunk_then_query_once():
        x_q, x_k, x_v = torch.chunk(qkv, chunks=3, dim=-1)
        x_kv = torch.cat([x_k, x_v], dim=-1)  # [N, 2C/3]
        x_kv_query = group(x_kv, knn_idx)[0]  # [N, K, 2C/3]
        x_k_query, x_v_query = torch.chunk(x_kv_query, 2, dim=-1)

    for variant in (chunk_then_query_twice, query_then_chunk, chunk_then_query_once):
        elapsed = _cuda_time(variant, iters=T)
        print(f"KNN query and group time: {elapsed / T * 1000:.2f} ms")


def test_knn():
    """Smoke-test and time ``KNNAttention`` (requires CUDA)."""
    c = 256
    b, n = 2, 80480
    model = KNNAttention(channels=c,
                         # proj_feature=64,
                         ).cuda()
    print(model)

    x = torch.randn(b, n, c).cuda()
    offset = torch.tensor([n * (i + 1) for i in range(b)]).to(x.device)
    p = torch.randn(b, n, 3).cuda()

    pxo = [p.view(-1, 3), x.view(-1, c), offset]

    y = model(pxo)
    print(y.shape)

    print(_cuda_time(lambda: model(pxo), iters=100, warmup=5))


def test_faiss_knn():
    """Time ``pointops.knn_query`` (requires CUDA)."""
    # cannot install faiss unfortunately
    # TODO: maybe implement a sliding window knn search later
    c = 256
    b, n = 2, 80480
    knn_samples = 16
    x = torch.randn(b, n, c).cuda()
    offset = torch.tensor([n * (i + 1) for i in range(b)]).to(x.device)
    o = offset
    p = torch.randn(b, n, 3).cuda()
    p = p.view(-1, 3)

    # pxo = [p.view(-1, 3), x.view(-1, c), offset]
    # print(p.shape, o.shape)
    # print(o)

    knn_idx, _ = pointops.knn_query(knn_samples, p, o, p, o)
    print(knn_idx.shape)

    print(_cuda_time(lambda: pointops.knn_query(knn_samples, p, o, p, o), iters=100, warmup=5))


def count_parameters(model):
    """Return the number of trainable parameters of ``model``."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def test_mlp():
    """Time ``MLP`` in bfloat16 (requires CUDA)."""
    b, n, c = 2, 40240, 256
    model = MLP(c).cuda()
    x = torch.randn(b, n, c).cuda()

    # model = SwiGLUFFN(c, c * 3).cuda()

    print('parameters:', count_parameters(model))

    x = x.to(torch.bfloat16)
    model.to(dtype=torch.bfloat16)

    with torch.autocast('cuda', enabled=True, dtype=torch.bfloat16):
        y = model(x)
        print(y.shape)

        print(_cuda_time(lambda: model(x), iters=100, warmup=5))


def test_mv_block():
    """Smoke-test ``MultiViewBlock`` (requires CUDA)."""
    c = 256
    num_heads = 4
    model = MultiViewBlock(c, num_heads).cuda()
    x = torch.rand(2, 256, c).cuda()
    print(model)
    y = model(x)
    print(y.shape)


def test_cross_attn():
    """Smoke-test ``GaussianErrorCrossAttn`` (requires CUDA)."""
    c = 256
    v, h, w = 8, 64, 128
    model = GaussianErrorCrossAttn(512, c, c).cuda()
    x = torch.rand(2, v * h * w, 512).cuda()
    y = torch.rand(2, v * h * w, c).cuda()
    print(model)
    y = model(x, y, v=v, h=h, w=w)
    print(x.shape, y.shape)


def test_grouping():
    """Compare ``pointops.grouping`` with ``pointops.grouping2`` (requires CUDA)."""
    c = 256
    # b, n = 2, 80480
    b, n = 1, 57344
    knn_samples = 16
    x = torch.randn(b, n, c).cuda()
    offset = torch.tensor([n * (i + 1) for i in range(b)]).to(x.device)
    o = offset
    p = torch.randn(b, n, 3).cuda()
    p = p.view(-1, 3)
    knn_idx, _ = pointops.knn_query(knn_samples, p, o, p, o)
    print(knn_idx.shape)

    c_qkv = 192
    qkv = torch.randn(b*n, c_qkv).cuda()

    x_q, x_k, x_v = torch.chunk(qkv, chunks=3, dim=-1)
    x_kv = torch.cat([x_k, x_v], dim=-1)  # [N, 2C/3]

    # Alternatives that were tried (kept for reference):
    # m, nsample, c = knn_idx.shape[0], knn_idx.shape[1], x_kv.shape[1]
    # feat = torch.cat([x_kv, torch.zeros([1, c]).to(x_kv.device)], dim=0)
    # grouping_idx = feat[knn_idx.view(-1).long(), :].view(m, nsample, c)  # (m, num_sample, c)
    # grouping_embed = torch.nn.functional.embedding(knn_idx, feat)  # [m,num_sample,c]

    T = 1000

    elapsed = _cuda_time(
        lambda: grouping(idx=knn_idx, feat=x_kv, xyz=p, new_xyz=p, with_xyz=False),
        iters=T,
    )
    # print(f"Grouping via indexing: {elapsed / T * 1000:.2f} ms")
    print(f"grouping pytorch: {elapsed / T * 1000:.2f} ms")

    elapsed = _cuda_time(lambda: grouping2(x_kv, knn_idx), iters=T)
    # print(f"Grouping via embedding: {elapsed / T * 1000:.2f} ms")
    print(f"grouping cuda: {elapsed / T * 1000:.2f} ms")


if __name__ == '__main__':
    # test_fps()
    # test_knn()
    # test_mlp()
    # test_mv_block()
    # test_cross_attn()
    # test_faiss_knn()
    # test_knn_query_and_group()
    test_grouping()