# MIT License
#
# Copyright (c) Microsoft Corporation.
# Copyright (c) 2025 VAST-AI-Research and contributors.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE

from typing import *
import torch
import torch.nn as nn
from . import SparseTensor
import torch.nn.functional as F
import spconv.pytorch as spconv
from typing import Optional
from query_point import PE_NeRF
from ...modules.sparse.transformer.blocks import SparseTransformerCrossBlock

__all__ = [
    'SparseDownsample',
    'SparseUpsample',
    'SparseSubdivide',
    'SparseSubdivide_attn'
]


class SparseDownsample(nn.Module):
    """
    Downsample a sparse tensor by a factor of `factor`.
    Implemented as max pooling: every input voxel is mapped to the coarse cell
    ``coord // factor`` and each coarse cell keeps the element-wise maximum of
    the features that fall into it.

    To let a matching ``SparseUpsample`` undo the operation, the fine-level
    coordinates, layout and inverse index are pushed onto list-valued spatial
    cache entries ``upsample_{factor}_coords`` / ``_layout`` / ``_idx``.
    Entries are appended (a stack) rather than overwritten so that nested
    downsample/upsample pairs with the same factor each recover their own
    level; ``SparseUpsample`` pops the newest entry.

    Args:
        factor: per-axis pooling factor, or a single int used for every axis.
    """

    def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
        super(SparseDownsample, self).__init__()
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor

    def forward(self, input: SparseTensor) -> SparseTensor:
        # coords are [batch_idx, x_1, ..., x_DIM]; the first column is the batch index.
        DIM = input.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
        assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'

        # Coarse coordinates per axis (batch column untouched).
        coord = list(input.coords.unbind(dim=-1))
        for i, f in enumerate(factor):
            coord[i+1] = coord[i+1] // f

        # Linearise (batch, coarse coords) into a single integer key so that
        # `unique` can group voxels sharing a coarse cell.
        MAX = [coord[i+1].max().item() + 1 for i in range(DIM)]
        OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
        code = sum([c * o for c, o in zip(coord, OFFSET)])
        code, idx = code.unique(return_inverse=True)

        new_feats = torch.scatter_reduce(
            torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype),
            dim=0,
            index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
            src=input.feats,
            reduce='amax',
            # Exclude the zero-initialised buffer from the max; otherwise every
            # pooled feature would be clamped to >= 0. Every coarse cell has at
            # least one contributor, so the initial value never survives.
            include_self=False,
        )

        # Decode the linear keys back into [batch, x_1, ..., x_DIM].
        new_coords = torch.stack(
            [code // OFFSET[0]] +
            [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
            dim=-1
        )
        out = SparseTensor(new_feats, new_coords, input.shape,)
        out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
        out._spatial_cache = input._spatial_cache

        # Push this level's restore information onto the stacks consumed by
        # SparseUpsample. A fresh list is built instead of mutating the old one.
        for name, value in (('coords', input.coords), ('layout', input.layout), ('idx', idx)):
            key = f'upsample_{factor}_{name}'
            stack = out.get_spatial_cache(key)
            stack = list(stack) if isinstance(stack, list) else []
            out.register_spatial_cache(key, stack + [value])
        return out
class SparseUpsample(nn.Module):
    """
    Upsample a sparse tensor by a factor of `factor`.
    Implemented as nearest neighbor interpolation.

    The fine-level coordinates, layout and inverse index are read from the
    spatial cache entries ``upsample_{factor}_coords`` / ``_layout`` / ``_idx``
    registered by a preceding ``SparseDownsample`` with the same factor.
    Each entry may hold either a single value or a list used as a stack (one
    item per downsample call). In the stacked case the newest item is popped,
    i.e. consumed, from the cached list.

    Raises:
        ValueError: if no matching downsample cache is available.
    """

    def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
        super(SparseUpsample, self).__init__()
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
        assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.'

        new_coords = input.get_spatial_cache(f'upsample_{factor}_coords')
        new_layout = input.get_spatial_cache(f'upsample_{factor}_layout')
        idx = input.get_spatial_cache(f'upsample_{factor}_idx')

        # Validate before popping: popping from a missing (None) entry would
        # raise an unhelpful AttributeError instead of the ValueError below.
        missing = any(x is None for x in (new_coords, new_layout, idx))
        if not missing and isinstance(new_coords, list):
            # Stacked cache: consume the entry pushed by the most recent downsample.
            if any(len(x) == 0 for x in (new_coords, new_layout, idx)):
                missing = True
            else:
                new_coords = new_coords.pop(-1)
                new_layout = new_layout.pop(-1)
                idx = idx.pop(-1)
        if missing:
            raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.')

        # Every fine voxel copies the feature of the coarse voxel it was pooled into.
        new_feats = input.feats[idx]
        out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
        out._scale = tuple([s * f for s, f in zip(input._scale, factor)])
        out._spatial_cache = input._spatial_cache
        return out


def _child_offsets(dim: int, device) -> torch.Tensor:
    """Return the 2**dim child offsets of a voxel as rows [0, o_1, ..., o_dim], o_i in {0, 1}.

    The leading zero column leaves the batch index untouched when added to coords.
    """
    n_cube = torch.ones([2] * dim, device=device, dtype=torch.int)
    offsets = torch.nonzero(n_cube)
    return torch.cat([torch.zeros_like(offsets[:, :1]), offsets], dim=-1)


def _subdivide_coords(coords: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
    """Return child coords [N, 2**dim, dim+1]: parent spatial coords doubled plus each offset."""
    doubled = coords.clone()
    doubled[:, 1:] *= 2
    return doubled.unsqueeze(1) + offsets.unsqueeze(0).to(doubled.dtype)


def _subdivided_scale(scale):
    """Return the per-axis scale tuple doubled element-wise.

    `_scale` is a per-axis tuple (the down/upsample modules zip it with the
    factor), so `scale * 2` would repeat the tuple instead of scaling it.
    """
    return tuple(s * 2 for s in scale)


class SparseSubdivide(nn.Module):
    """
    Upsample a sparse tensor by a factor of 2 on every axis.
    Implemented as nearest neighbor interpolation: each voxel is replaced by
    its 2**DIM children, all carrying a copy of the parent feature.
    """

    def __init__(self):
        super(SparseSubdivide, self).__init__()

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        n_coords = _child_offsets(DIM, input.device)
        factor = n_coords.shape[0]
        assert factor == 2 ** DIM

        new_coords = _subdivide_coords(input.coords, n_coords)
        new_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], factor, *input.feats.shape[1:])
        out = SparseTensor(new_feats.flatten(0, 1), new_coords.flatten(0, 1), input.shape)
        out._scale = _subdivided_scale(input._scale)
        out._spatial_cache = input._spatial_cache
        return out


class SparseSubdivide_attn_nerf(nn.Module):
    """
    Attention-based subdivision with NeRF-style position encoding of the
    child offsets, feature enhancement, output normalisation/projection and a
    residual connection to the enhanced parent features.

    This was the first of three ``SparseSubdivide_attn`` definitions in this
    module. It was shadowed by the later definitions and therefore unreachable,
    so it now has its own name.

    NOTE(review): K and V come from a single parent token repeated for every
    child, so all keys are identical, the softmax is uniform, and the attention
    output equals V for every child. As a result q_proj/k_proj receive no
    gradient. Confirm whether this degenerate attention is intended.
    """

    def __init__(self, in_channels: int, num_heads: int = 4,):
        super().__init__()
        assert in_channels % num_heads == 0, "in_channels must be divisible by num_heads"
        self.in_channels = in_channels
        self.head_dim = in_channels // num_heads
        self.num_heads = num_heads
        self.scale = self.head_dim ** -0.5

        # Continuous child-offset encoding mapped to the feature dimension.
        # NOTE(review): assumes PE_NeRF(out_channels=C) emits 3 * C features to
        # match the LayerNorm/Linear below -- confirm against PE_NeRF.
        self.pos_embed = nn.Sequential(
            PE_NeRF(out_channels=in_channels, multires=10),
            nn.LayerNorm(in_channels * 3),
            nn.GELU(),
            nn.Linear(in_channels * 3, in_channels),
        )

        # Feature enhancement applied to parent features before attention.
        self.feat_enhance = nn.Sequential(
            nn.Linear(in_channels, in_channels * 2),
            nn.LayerNorm(in_channels * 2),
            nn.GELU(),
            nn.Linear(in_channels * 2, in_channels),
            nn.LayerNorm(in_channels),
        )

        self.q_proj = nn.Linear(in_channels, in_channels)  # query from position
        self.k_proj = nn.Linear(in_channels, in_channels)  # key from content
        self.v_proj = nn.Linear(in_channels, in_channels)  # value from content

        self.output_norm = nn.LayerNorm(in_channels)
        self.output_proj = nn.Sequential(
            nn.Linear(in_channels, in_channels * 2),
            nn.GELU(),
            nn.Linear(in_channels * 2, in_channels),
        )

        self._initialize_weights()

    def _initialize_weights(self):
        for proj in (self.q_proj, self.k_proj, self.v_proj):
            nn.init.xavier_uniform_(proj.weight)
            nn.init.zeros_(proj.bias)
        # Near-zero last layer: the residual branch starts almost inactive, so
        # children initially receive (approximately) the enhanced parent feature.
        nn.init.uniform_(self.output_proj[-1].weight, -1e-5, 1e-5)
        nn.init.zeros_(self.output_proj[-1].bias)

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        feats = input.feats
        N = feats.shape[0]

        n_coords = _child_offsets(DIM, input.device)
        num_children = n_coords.shape[0]

        # Child offsets centred at the parent: {0, 1} -> {-0.5, 0.5}.
        spatial_offsets = n_coords[:, 1:].float() - 0.5
        pos_emb = self.pos_embed(spatial_offsets)  # [num_children, C]

        enhanced_feats = self.feat_enhance(feats)  # [N, C]
        expanded_coords = _subdivide_coords(input.coords, n_coords)

        heads, head_dim = self.num_heads, self.head_dim
        K = self.k_proj(enhanced_feats).view(N, 1, heads, head_dim).expand(-1, num_children, -1, -1)
        V = self.v_proj(enhanced_feats).view(N, 1, heads, head_dim).expand(-1, num_children, -1, -1)
        Q = self.q_proj(pos_emb).view(1, num_children, heads, head_dim).expand(N, -1, -1, -1)

        # Manual scaled dot-product attention over the children axis.
        Q_ = Q.permute(0, 2, 1, 3)  # [N, H, children, D]
        K_ = K.permute(0, 2, 1, 3)
        V_ = V.permute(0, 2, 1, 3)
        attn_logits = torch.matmul(Q_, K_.transpose(-2, -1)) * self.scale
        # Subtract the row max for a numerically stable softmax.
        attn_logits = attn_logits - attn_logits.amax(dim=-1, keepdim=True)
        attn_weights = torch.softmax(attn_logits, dim=-1)
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0, posinf=0.0, neginf=0.0)
        attn_output = torch.matmul(attn_weights, V_)
        attn_out = attn_output.permute(0, 2, 1, 3).reshape(N, num_children, self.in_channels)

        # Inject position information, then transform with a residual to the parent.
        modulated = attn_out + pos_emb.unsqueeze(0)
        transformed = self.output_proj(self.output_norm(modulated))
        base_features = enhanced_feats.unsqueeze(1).expand(-1, num_children, -1)
        child_feats = base_features + transformed

        out = SparseTensor(
            child_feats.reshape(N * num_children, -1),
            expanded_coords.flatten(0, 1),
            input.shape
        )
        out._scale = _subdivided_scale(input._scale)
        out._spatial_cache = input._spatial_cache
        return out


# NOTE: earlier experimental subdivision variants (relative-linear,
# relative-embedding, position-specific filters, nn.Parameter cross-attention,
# modulation) used to sit here as commented-out code and were removed.


class SparseSubdivide_attn_axis(nn.Module):
    """
    Subdivide every voxel into its children and add a per-axis learned
    embedding of the child's absolute coordinate modulo ``resolution // 64``.

    This was the second of three ``SparseSubdivide_attn`` definitions in this
    module. It was shadowed by the last definition and therefore unreachable,
    so it now has its own name.

    Expects 3 spatial axes (x, y, z). ``num_heads`` is stored but unused.

    Raises:
        ValueError: if ``resolution // 64`` is less than 1, which would make the
            modulo in ``forward`` divide by zero.
    """

    def __init__(self, in_channels: int, num_heads: int = 4, resolution: int = 128):
        super().__init__()
        self.in_channels = in_channels
        self.num_heads = num_heads
        self.embed_dim = in_channels
        self.relative_coords = resolution // 64
        if self.relative_coords < 1:
            raise ValueError(f'resolution must be at least 64, got {resolution}')
        self.coord_embed_x = nn.Embedding(self.relative_coords, self.embed_dim)
        self.coord_embed_y = nn.Embedding(self.relative_coords, self.embed_dim)
        self.coord_embed_z = nn.Embedding(self.relative_coords, self.embed_dim)
        self.embed_proj = nn.Linear(in_channels * 3, in_channels)

    def forward(self, input):
        DIM = input.coords.shape[-1] - 1
        n_coords = _child_offsets(DIM, input.device)
        new_coords = _subdivide_coords(input.coords, n_coords)  # [N, children, DIM + 1]
        num_children = new_coords.shape[1]

        mod_coords = new_coords[:, :, 1:] % self.relative_coords
        x_embed = self.coord_embed_x(mod_coords[..., 0].long())
        y_embed = self.coord_embed_y(mod_coords[..., 1].long())
        z_embed = self.coord_embed_z(mod_coords[..., 2].long())
        pos_embed = self.embed_proj(torch.cat([x_embed, y_embed, z_embed], dim=-1))  # [N, children, C]

        new_feats = input.feats.unsqueeze(1).expand(-1, num_children, -1) + pos_embed
        out = SparseTensor(
            new_feats.flatten(0, 1),
            new_coords.flatten(0, 1),
            input.shape
        )
        out._scale = _subdivided_scale(input._scale)
        out._spatial_cache = input._spatial_cache
        return out


class SparseSubdivide_attn(nn.Module):
    """
    Subdivide every voxel into its 2**DIM children. Each child gets a copy of
    the parent feature plus a learned embedding identifying its position
    inside the parent.

    Despite the name, no attention is computed. ``num_heads`` is accepted for
    interface compatibility and is unused. The embedding table has 8 rows, so
    at most 3 spatial dimensions are supported.
    """

    def __init__(self, in_channels: int, num_heads: int = 4):
        super().__init__()
        self.in_channels = in_channels
        self.num_heads = num_heads
        self.pos_embed = nn.Embedding(8, in_channels)  # one row per child of a 3D voxel

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        n_coords = _child_offsets(DIM, input.device)
        factor = n_coords.shape[0]
        assert factor == 2 ** DIM

        new_coords = _subdivide_coords(input.coords, n_coords)
        base_feats = input.feats.unsqueeze(1).expand(-1, factor, -1)  # [N, children, C]
        child_ids = torch.arange(factor, device=input.device)
        pos_feats = self.pos_embed(child_ids).unsqueeze(0)  # [1, children, C]
        final_feats = (base_feats + pos_feats).flatten(0, 1)

        out = SparseTensor(
            feats=final_feats,
            coords=new_coords.flatten(0, 1),
            shape=input.shape
        )
        out._scale = _subdivided_scale(input._scale)
        # Propagate the spatial cache like every other subdivide/upsample module,
        # so that downstream SparseUpsample layers can still find their entries.
        out._spatial_cache = input._spatial_cache
        return out