| # MIT License | |
| # Copyright (c) Microsoft Corporation. | |
| # Copyright (c) 2025 VAST-AI-Research and contributors. | |
| # Permission is hereby granted, free of charge, to any person obtaining a copy | |
| # of this software and associated documentation files (the "Software"), to deal | |
| # in the Software without restriction, including without limitation the rights | |
| # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | |
| # copies of the Software, and to permit persons to whom the Software is | |
| # furnished to do so, subject to the following conditions: | |
| # The above copyright notice and this permission notice shall be included in all | |
| # copies or substantial portions of the Software. | |
| # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | |
| # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | |
| # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | |
| # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |
| # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | |
| # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE | |
| # SOFTWARE | |
| from typing import * | |
| import torch | |
| import torch.nn as nn | |
| from . import SparseTensor | |
| import torch.nn.functional as F | |
| import spconv.pytorch as spconv | |
| from typing import Optional | |
| from query_point import PE_NeRF | |
| from ...modules.sparse.transformer.blocks import SparseTransformerCrossBlock | |
__all__ = [
    'SparseDownsample',
    'SparseUpsample',
    'SparseSubdivide',
    # NOTE(review): 'SparseSubdivide_attn' is defined three times in this file;
    # only the LAST definition wins at import time and is what this export names.
    'SparseSubdivide_attn'
]
class SparseDownsample(nn.Module):
    """
    Downsample a sparse tensor by a factor of `factor`.

    Implemented as max pooling: features of all fine voxels that fall into the
    same coarse voxel are reduced with ``amax``.
    NOTE(review): an earlier revision used ``reduce='mean'`` (average pooling,
    see the commented line below); the docstring used to claim average pooling.
    """
    def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
        super(SparseDownsample, self).__init__()
        # Keep a tuple when a per-axis factor is given; a bare int is expanded
        # to every spatial axis in forward() once the rank is known.
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1  # spatial rank; coords[:, 0] is the batch index
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
        assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'
        # Integer-divide the spatial coordinates onto the coarse grid
        # (the batch column coord[0] is left untouched).
        coord = list(input.coords.unbind(dim=-1))
        for i, f in enumerate(factor):
            coord[i+1] = coord[i+1] // f
        # Pack (batch, x, y, ...) into one scalar code per voxel so merged
        # duplicates can be found with a single unique().
        MAX = [coord[i+1].max().item() + 1 for i in range(DIM)]
        OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
        code = sum([c * o for c, o in zip(coord, OFFSET)])
        code, idx = code.unique(return_inverse=True)
        # Reduce the features of all fine voxels sharing a coarse voxel.
        # NOTE(review): torch.scatter_reduce defaults to include_self=True, so the
        # zero init participates in the amax — negative activations are clamped
        # to 0. Confirm this is intended.
        new_feats = torch.scatter_reduce(
            torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype),
            dim=0,
            index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
            src=input.feats,
            # reduce='mean'  # previous behavior: average pooling
            reduce='amax',
        )
        # Unpack the scalar codes back into (batch, x, y, ...) coordinates.
        new_coords = torch.stack(
            [code // OFFSET[0]] +
            [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
            dim=-1
        )
        out = SparseTensor(new_feats, new_coords, input.shape,)
        out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
        out._spatial_cache = input._spatial_cache
        # Cache everything a paired SparseUpsample needs to restore this resolution:
        # fine coords, fine layout, and the fine->coarse index map.
        out.register_spatial_cache(f'upsample_{factor}_coords', input.coords)
        out.register_spatial_cache(f'upsample_{factor}_layout', input.layout)
        out.register_spatial_cache(f'upsample_{factor}_idx', idx)
        return out
| # class SparseDownsample(nn.Module): | |
| # """ | |
| # Downsample a sparse tensor by a factor of `factor`. | |
| # Implemented as average pooling. | |
| # """ | |
| # def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]): | |
| # super(SparseDownsample, self).__init__() | |
| # self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor | |
| # def forward(self, input: SparseTensor) -> SparseTensor: | |
| # DIM = input.coords.shape[-1] - 1 | |
| # factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM | |
| # assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.' | |
| # coord = list(input.coords.unbind(dim=-1)) | |
| # for i, f in enumerate(factor): | |
| # coord[i+1] = coord[i+1] // f | |
| # MAX = [coord[i+1].max().item() + 1 for i in range(DIM)] | |
| # OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1] | |
| # code = sum([c * o for c, o in zip(coord, OFFSET)]) | |
| # code, idx = code.unique(return_inverse=True) | |
| # new_feats = torch.scatter_reduce( | |
| # torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype), | |
| # dim=0, | |
| # index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]), | |
| # src=input.feats, | |
| # reduce='mean' | |
| # ) | |
| # new_coords = torch.stack( | |
| # [code // OFFSET[0]] + | |
| # [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)], | |
| # dim=-1 | |
| # ) | |
| # out = SparseTensor(new_feats, new_coords, input.shape,) | |
| # out._scale = tuple([s // f for s, f in zip(input._scale, factor)]) | |
| # out._spatial_cache = input._spatial_cache | |
| # if out.get_spatial_cache(f'upsample_{factor}_coords') is not None: | |
| # out.register_spatial_cache(f'upsample_{factor}_coords', [*out.get_spatial_cache(f'upsample_{factor}_coords'), input.coords]) | |
| # out.register_spatial_cache(f'upsample_{factor}_layout', [*out.get_spatial_cache(f'upsample_{factor}_layout'), input.layout]) | |
| # out.register_spatial_cache(f'upsample_{factor}_idx', [*out.get_spatial_cache(f'upsample_{factor}_idx'), idx]) | |
| # else: | |
| # out.register_spatial_cache(f'upsample_{factor}_coords', [input.coords]) | |
| # out.register_spatial_cache(f'upsample_{factor}_layout', [input.layout]) | |
| # out.register_spatial_cache(f'upsample_{factor}_idx', [idx]) | |
| # return out | |
class SparseUpsample(nn.Module):
    """
    Upsample a sparse tensor by a factor of `factor`.

    Implemented as nearest-neighbor interpolation: each fine voxel copies the
    feature of the coarse voxel it was pooled into by a paired SparseDownsample,
    using the coords/layout/index map that downsample cached.

    Raises:
        ValueError: if no paired SparseDownsample populated the spatial cache.
    """

    def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
        super(SparseUpsample, self).__init__()
        # Keep a tuple when a per-axis factor is given; a bare int is expanded
        # per spatial axis in forward() once the rank is known.
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1  # spatial rank; coords[:, 0] is the batch index
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
        assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.'
        # The paired SparseDownsample registered these as plain values (tensors /
        # layout), not lists — use them directly. The previous revision called
        # .pop(-1) on them (torch.Tensor has no .pop) and only checked for None
        # AFTER popping, so a missing cache crashed with the wrong error.
        new_coords = input.get_spatial_cache(f'upsample_{factor}_coords')
        new_layout = input.get_spatial_cache(f'upsample_{factor}_layout')
        idx = input.get_spatial_cache(f'upsample_{factor}_idx')
        # Validate the cache BEFORE using it so the failure mode is explicit.
        if any(x is None for x in [new_coords, new_layout, idx]):
            raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.')
        # Nearest neighbor: every fine voxel gathers its coarse parent's feature.
        new_feats = input.feats[idx]
        out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
        out._scale = tuple(s * f for s, f in zip(input._scale, factor))
        out._spatial_cache = input._spatial_cache
        return out
class SparseSubdivide(nn.Module):
    """
    Subdivide every voxel of a sparse tensor into its 2**DIM children,
    doubling the spatial resolution.

    Each child inherits its parent's feature unchanged (nearest-neighbor
    style upsampling with no learned parameters).
    """

    def __init__(self):
        super(SparseSubdivide, self).__init__()

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1  # spatial rank; coords[:, 0] is batch
        # Enumerate every corner offset of the unit hypercube: all {0,1}^DIM tuples.
        unit_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int)
        offsets = torch.nonzero(unit_cube)                                  # [2**DIM, DIM]
        offsets = torch.cat([torch.zeros_like(offsets[:, :1]), offsets], dim=-1)
        n_children = offsets.shape[0]
        assert n_children == 2 ** DIM
        # Move parent coordinates onto the fine grid, then fan out one child
        # per corner offset (batch column stays untouched).
        scaled = input.coords.clone()
        scaled[:, 1:] *= 2
        child_coords = scaled.unsqueeze(1) + offsets.unsqueeze(0).to(scaled.dtype)
        # Replicate each parent's feature once per child.
        child_feats = input.feats.unsqueeze(1).expand(
            input.feats.shape[0], n_children, *input.feats.shape[1:]
        )
        out = SparseTensor(child_feats.flatten(0, 1), child_coords.flatten(0, 1), input.shape)
        out._scale = input._scale * 2
        out._spatial_cache = input._spatial_cache
        return out
#################### new ########################
# NOTE: 7/30 — added cross-attention ("ca") subdivide variants below.
class SparseSubdivide_attn(nn.Module):
    """
    Attention-based upsampling: compute child voxel features with multi-head attention.
    Enhanced with residual connections, layer normalization, and position encoding.

    Improvements to overcome training plateau:
    1. Position encoding using relative offsets instead of indices
    2. Feature enhancement before attention
    3. Output normalization and projection
    4. Careful residual connections

    NOTE(review): this class name is re-defined twice further down the file, so
    this first definition is shadowed and never the one exported via __all__.
    """
    def __init__(self, in_channels: int, num_heads: int = 4,):
        super().__init__()
        assert in_channels % num_heads == 0, "in_channels must be divisible by num_heads"
        self.in_channels = in_channels
        self.head_dim = in_channels // num_heads
        self.num_heads = num_heads
        self.scale = self.head_dim ** -0.5
        # Enhanced position encoding (continuous offsets instead of discrete indices).
        # Earlier MLP variant kept for reference:
        # self.pos_embed = nn.Sequential(
        #     nn.Linear(3, 64),          # Process actual spatial offsets
        #     nn.LayerNorm(64),
        #     nn.GELU(),
        #     nn.Linear(64, in_channels) # Map to feature dimension
        # )
        # NOTE(review): the LayerNorm/Linear widths imply PE_NeRF emits
        # 3 * in_channels features for a 3-D offset — confirm against
        # query_point.PE_NeRF.
        self.pos_embed = nn.Sequential(
            PE_NeRF(out_channels=in_channels, multires=10),  # Process actual spatial offsets
            nn.LayerNorm(in_channels * 3),
            nn.GELU(),
            nn.Linear(in_channels * 3, in_channels)  # Map to feature dimension
        )
        # Feature enhancement before attention (bottleneck-free 2x-wide MLP).
        self.feat_enhance = nn.Sequential(
            nn.Linear(in_channels, in_channels * 2),
            nn.LayerNorm(in_channels * 2),
            nn.GELU(),
            nn.Linear(in_channels * 2, in_channels),
            nn.LayerNorm(in_channels)
        )
        # Attention projections
        self.q_proj = nn.Linear(in_channels, in_channels)  # Query from position
        self.k_proj = nn.Linear(in_channels, in_channels)  # Key from content
        self.v_proj = nn.Linear(in_channels, in_channels)  # Value from content
        # Output processing with residual
        self.output_norm = nn.LayerNorm(in_channels)
        self.output_proj = nn.Sequential(
            nn.Linear(in_channels, in_channels * 2),
            nn.GELU(),
            nn.Linear(in_channels * 2, in_channels)
        )
        # Initialize for stable training
        self._initialize_weights()

    def _initialize_weights(self):
        # Xavier init for the attention projections, zero biases.
        nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.k_proj.weight)
        nn.init.xavier_uniform_(self.v_proj.weight)
        nn.init.zeros_(self.q_proj.bias)
        nn.init.zeros_(self.k_proj.bias)
        nn.init.zeros_(self.v_proj.bias)
        # NOTE(review): the two zeros_ calls below are immediately overwritten
        # by uniform_/constant_ — likely leftovers. Net effect: tiny-uniform
        # final-layer weights, zero bias (near-identity residual branch at init).
        nn.init.zeros_(self.output_proj[-1].weight)
        nn.init.zeros_(self.output_proj[-1].bias)
        nn.init.uniform_(self.output_proj[-1].weight, -1e-5, 1e-5)
        nn.init.constant_(self.output_proj[-1].bias, 0)

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1  # Spatial dimensions (3 for 3D)
        device = input.device
        batch_coords = input.coords
        feats = input.feats
        # Generate child positions (identical to original SparseSubdivide)
        n_cube = torch.ones([2] * DIM, device=device, dtype=torch.int)
        n_coords = torch.nonzero(n_cube)  # [8, 3] for 3D
        n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
        # Calculate actual spatial offsets (centered: offsets in {-0.5, +0.5})
        spatial_offsets = (n_coords[:, 1:].float() - 0.5)  # Centered at origin
        pos_emb = self.pos_embed(spatial_offsets)  # [8, C]
        # Enhance original features before attention
        enhanced_feats = self.feat_enhance(feats)
        # Compute new coordinates (same as original)
        new_coords = batch_coords.clone()
        new_coords[:, 1:] *= 2
        expanded_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)
        # Prepare attention inputs
        N = feats.shape[0]  # Number of parent voxels
        num_children = n_coords.shape[0]  # Always 8 for 3D
        # Project features to K, V
        K = self.k_proj(enhanced_feats).view(N, 1, self.num_heads, self.head_dim)
        V = self.v_proj(enhanced_feats).view(N, 1, self.num_heads, self.head_dim)
        # Project position embeddings to Q
        Q = self.q_proj(pos_emb).view(1, num_children, self.num_heads, self.head_dim)
        #########################################
        # Earlier F.scaled_dot_product_attention path, kept for reference:
        # K = K.expand(-1, num_children, -1, -1)  # [N, 8, H, D]
        # V = V.expand(-1, num_children, -1, -1)  # [N, 8, H, D]
        # Q = Q.expand(N, -1, -1, -1)             # [N, 8, H, D]
        # attn_out = F.scaled_dot_product_attention(
        #     Q.permute(0, 2, 1, 3).reshape(N * num_children, self.num_heads, self.head_dim),
        #     K.permute(0, 2, 1, 3).reshape(N * num_children, self.num_heads, self.head_dim),
        #     V.permute(0, 2, 1, 3).reshape(N * num_children, self.num_heads, self.head_dim),
        #     dropout_p=0.0,
        # )
        # attn_out = attn_out.view(N, num_children, self.in_channels)
        K = K.expand(-1, num_children, -1, -1)  # [N, 8, H, D]
        V = V.expand(-1, num_children, -1, -1)  # [N, 8, H, D]
        Q = Q.expand(N, num_children, -1, -1)   # [N, 8, H, D]
        # === Manual scaled dot-product attention ===
        Q_ = Q.permute(0, 2, 1, 3)  # [N, H, 8, D]
        K_ = K.permute(0, 2, 1, 3)  # [N, H, 8, D]
        V_ = V.permute(0, 2, 1, 3)  # [N, H, 8, D]
        scale = self.head_dim ** -0.5
        attn_logits = torch.matmul(Q_, K_.transpose(-2, -1)) * scale  # [N, H, 8, 8]
        # Stabilize softmax: subtract the row-wise max before exponentiating.
        attn_logits = attn_logits - attn_logits.amax(dim=-1, keepdim=True)
        attn_weights = torch.softmax(attn_logits, dim=-1)
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0, posinf=0.0, neginf=0.0)
        attn_output = torch.matmul(attn_weights, V_)  # [N, H, 8, D]
        # Merge heads back into the original [N, 8, C] layout.
        attn_out = attn_output.permute(0, 2, 1, 3).reshape(N, num_children, self.in_channels)
        # Position injection and output processing
        modulated = attn_out + pos_emb.unsqueeze(0)  # Inject position info
        transformed = self.output_proj(self.output_norm(modulated))
        # Residual connection: Combine with expanded parent features
        base_features = enhanced_feats.unsqueeze(1).expand(-1, num_children, -1)
        child_feats = base_features + transformed  # Preserve original information
        # Create new sparse tensor
        out = SparseTensor(
            child_feats.reshape(N * num_children, -1),
            expanded_coords.flatten(0, 1),
            input.shape
        )
        out._scale = input._scale * 2
        out._spatial_cache = input._spatial_cache
        return out
| # ######################## relative linear ############################# | |
| # class SparseSubdivide_attn(nn.Module): | |
| # def __init__(self, in_channels: int, num_heads: int = 4, dropout: float = 0.05): | |
| # super().__init__() | |
| # assert in_channels % num_heads == 0, "in_channels must be divisible by num_heads" | |
| # self.in_channels = in_channels | |
| # self.head_dim = in_channels // num_heads | |
| # self.num_heads = num_heads | |
| # self.scale = self.head_dim ** -0.5 | |
| # self.pos_embed = nn.Sequential( | |
| # nn.Linear(3, in_channels), | |
| # nn.LayerNorm(in_channels), | |
| # nn.GELU(), | |
| # nn.Linear(in_channels, in_channels) | |
| # ) | |
| # self.feat_enhance = nn.Sequential( | |
| # nn.Linear(in_channels, in_channels * 2), | |
| # nn.LayerNorm(in_channels * 2), | |
| # nn.GELU(), | |
| # nn.Linear(in_channels * 2, in_channels), | |
| # nn.LayerNorm(in_channels) | |
| # ) | |
| # self.q_proj = nn.Linear(in_channels, in_channels) | |
| # self.k_proj = nn.Linear(in_channels, in_channels) | |
| # self.v_proj = nn.Linear(in_channels, in_channels) | |
| # self.output_norm = nn.LayerNorm(in_channels) | |
| # self.output_proj = nn.Sequential( | |
| # nn.Linear(in_channels, in_channels * 2), | |
| # nn.GELU(), | |
| # nn.Dropout(dropout), | |
| # nn.Linear(in_channels * 2, in_channels) | |
| # ) | |
| # # self._initialize_weights() | |
| # def _initialize_weights(self): | |
| # nn.init.xavier_uniform_(self.q_proj.weight) | |
| # nn.init.xavier_uniform_(self.k_proj.weight) | |
| # nn.init.xavier_uniform_(self.v_proj.weight) | |
| # nn.init.zeros_(self.q_proj.bias) | |
| # nn.init.zeros_(self.k_proj.bias) | |
| # nn.init.zeros_(self.v_proj.bias) | |
| # nn.init.zeros_(self.output_proj[-1].weight) | |
| # nn.init.constant_(self.output_proj[-1].bias, 0) | |
| # def forward(self, input: SparseTensor) -> SparseTensor: | |
| # DIM = input.coords.shape[-1] - 1 | |
| # device = input.device | |
| # coords = input.coords | |
| # feats = input.feats | |
| # N = feats.shape[0] | |
| # n_cube = torch.ones([2] * DIM, device=device, dtype=torch.int) | |
| # n_coords = torch.nonzero(n_cube) # [8, 3] | |
| # n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) # [8, 4] | |
| # spatial_offsets = n_coords[:, 1:].float() # [8, 3], 用于位置编码 | |
| # pos_emb = self.pos_embed(spatial_offsets.to(device=device, dtype=feats.dtype)) # [8, C] | |
| # Q = self.q_proj(pos_emb).view(1, 8, self.num_heads, self.head_dim).expand(N, -1, -1, -1) # [N, 8, H, D] | |
| # enhanced_feats = self.feat_enhance(feats) # [N, C] | |
| # K = self.k_proj(enhanced_feats).view(N, 1, self.num_heads, self.head_dim) # [N, 1, H, D] | |
| # V = self.v_proj(enhanced_feats).view(N, 1, self.num_heads, self.head_dim) # [N, 1, H, D] | |
| # Q_ = Q.permute(0, 2, 1, 3) # [N, H, 8, D] | |
| # K_ = K.permute(0, 2, 1, 3) # [N, H, 1, D] | |
| # V_ = V.permute(0, 2, 1, 3) # [N, H, 1, D] | |
| # attn_logits = torch.matmul(Q_, K_.transpose(-2, -1)) * self.scale # [N, H, 8, 1] | |
| # attn_weights = torch.softmax(attn_logits, dim=2) # over children | |
| # attn_weights = torch.nan_to_num(attn_weights, nan=0.0, posinf=0.0, neginf=0.0) | |
| # attn_output = torch.matmul(attn_weights, V_) # [N, H, 8, D] | |
| # attn_out = attn_output.permute(0, 2, 1, 3).reshape(N, 8, self.in_channels) # [N, 8, C] | |
| # modulated = attn_out + pos_emb.unsqueeze(0) # [N, 8, C] | |
| # transformed = self.output_proj(self.output_norm(modulated)) # [N, 8, C] | |
| # base_features = enhanced_feats.unsqueeze(1).expand(-1, 8, -1) | |
| # child_feats = base_features + transformed # [N, 8, C] | |
| # new_coords = coords.clone() | |
| # new_coords[:, 1:] *= 2 | |
| # expanded_coords = new_coords.unsqueeze(1) + n_coords.to(dtype=coords.dtype).unsqueeze(0) # [N, 8, 4] | |
| # return SparseTensor( | |
| # child_feats.reshape(N * 8, self.in_channels), | |
| # expanded_coords.reshape(N * 8, 4), | |
| # input.shape | |
| # ) | |
| # ######################## relative embedding ############################# | |
| # class SparseSubdivide_attn(nn.Module): | |
| # def __init__(self, in_channels: int, num_heads: int = 4, dropout: float = 0.05): | |
| # super().__init__() | |
| # assert in_channels % num_heads == 0, "in_channels must be divisible by num_heads" | |
| # self.in_channels = in_channels | |
| # self.head_dim = in_channels // num_heads | |
| # self.num_heads = num_heads | |
| # self.scale = self.head_dim ** -0.5 | |
| # self.pos_index_embed = nn.Embedding(8, in_channels) | |
| # self.feat_enhance = nn.Sequential( | |
| # nn.Linear(in_channels, in_channels * 2), | |
| # nn.LayerNorm(in_channels * 2), | |
| # nn.GELU(), | |
| # nn.Linear(in_channels * 2, in_channels), | |
| # nn.LayerNorm(in_channels) | |
| # ) | |
| # self.q_proj = nn.Linear(in_channels, in_channels) | |
| # self.k_proj = nn.Linear(in_channels, in_channels) | |
| # self.v_proj = nn.Linear(in_channels, in_channels) | |
| # self.output_norm = nn.LayerNorm(in_channels) | |
| # self.output_proj = nn.Sequential( | |
| # nn.Linear(in_channels, in_channels * 2), | |
| # nn.GELU(), | |
| # nn.Dropout(dropout), | |
| # nn.Linear(in_channels * 2, in_channels) | |
| # ) | |
| # def forward(self, input: SparseTensor) -> SparseTensor: | |
| # DIM = input.coords.shape[-1] - 1 | |
| # device = input.device | |
| # coords = input.coords | |
| # feats = input.feats | |
| # N = feats.shape[0] | |
| # n_cube = torch.ones([2] * DIM, device=device, dtype=torch.int) | |
| # n_coords = torch.nonzero(n_cube) | |
| # n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) # [8, 4] | |
| # pos_indices = torch.arange(8, device=device) | |
| # pos_emb = self.pos_index_embed(pos_indices) # [8, C] | |
| # Q = self.q_proj(pos_emb).view(1, 8, self.num_heads, self.head_dim).expand(N, -1, -1, -1) # [N, 8, H, D] | |
| # enhanced_feats = self.feat_enhance(feats) # [N, C] | |
| # K = self.k_proj(enhanced_feats).view(N, 1, self.num_heads, self.head_dim) # [N, 1, H, D] | |
| # V = self.v_proj(enhanced_feats).view(N, 1, self.num_heads, self.head_dim) # [N, 1, H, D] | |
| # # === Cross Attention === | |
| # Q_ = Q.permute(0, 2, 1, 3) # [N, H, 8, D] | |
| # K_ = K.permute(0, 2, 1, 3) # [N, H, 1, D] | |
| # V_ = V.permute(0, 2, 1, 3) # [N, H, 1, D] | |
| # attn_logits = torch.matmul(Q_, K_.transpose(-2, -1)) * self.scale # [N, H, 8, 1] | |
| # attn_weights = torch.softmax(attn_logits, dim=2) | |
| # attn_weights = torch.nan_to_num(attn_weights, nan=0.0, posinf=0.0, neginf=0.0) | |
| # attn_output = torch.matmul(attn_weights, V_) # [N, H, 8, D] | |
| # attn_out = attn_output.permute(0, 2, 1, 3).reshape(N, 8, self.in_channels) # [N, 8, C] | |
| # modulated = attn_out + pos_emb.unsqueeze(0) # [N, 8, C] | |
| # transformed = self.output_proj(self.output_norm(modulated)) # [N, 8, C] | |
| # base_features = enhanced_feats.unsqueeze(1).expand(-1, 8, -1) | |
| # child_feats = base_features + transformed # [N, 8, C] | |
| # new_coords = coords.clone() | |
| # new_coords[:, 1:] *= 2 | |
| # expanded_coords = new_coords.unsqueeze(1) + n_coords.to(dtype=coords.dtype).unsqueeze(0) # [N, 8, 4] | |
| # return SparseTensor( | |
| # child_feats.reshape(N * 8, self.in_channels), | |
| # expanded_coords.reshape(N * 8, 4), | |
| # input.shape | |
| # ) | |
| # ############################## Position-Specific Filters ##################################### | |
| # class SparseSubdivide_attn(nn.Module): | |
| # def __init__(self, in_channels: int, num_heads: int = 4, dropout: float = 0.05): | |
| # super().__init__() | |
| # self.in_channels = in_channels | |
| # self.num_heads = num_heads | |
| # self.head_dim = in_channels // num_heads | |
| # self.scale = self.head_dim ** -0.5 | |
| # # Position-aware modulation components | |
| # self.pos_index_embed = nn.Embedding(8, in_channels) | |
| # self.pos_filters = nn.Embedding(8, in_channels * in_channels) | |
| # # Feature enhancement | |
| # self.feat_enhance = nn.Sequential( | |
| # nn.Linear(in_channels, in_channels * 2), | |
| # nn.LayerNorm(in_channels * 2), | |
| # nn.GELU(), | |
| # nn.Linear(in_channels * 2, in_channels), | |
| # nn.LayerNorm(in_channels) | |
| # ) | |
| # # Output transformation | |
| # self.output_norm = nn.LayerNorm(in_channels) | |
| # self.output_proj = nn.Sequential( | |
| # nn.Linear(in_channels, in_channels * 2), | |
| # nn.GELU(), | |
| # nn.Dropout(dropout), | |
| # nn.Linear(in_channels * 2, in_channels) | |
| # ) | |
| # def forward(self, input: SparseTensor) -> SparseTensor: | |
| # DIM = input.coords.shape[-1] - 1 | |
| # device = input.device | |
| # coords = input.coords | |
| # feats = input.feats | |
| # N = feats.shape[0] | |
| # # Generate subdivision coordinates (8 children per voxel) | |
| # n_cube = torch.ones([2] * DIM, device=device, dtype=torch.int) | |
| # n_coords = torch.nonzero(n_cube) | |
| # n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) # [8, 4] | |
| # pos_indices = torch.arange(8, device=device) | |
| # # Position-aware feature modulation | |
| # pos_emb = self.pos_index_embed(pos_indices) # [8, C] | |
| # pos_filters = self.pos_filters(pos_indices) # [8, C*C] | |
| # pos_filters = pos_filters.view(8, self.in_channels, self.in_channels) # [8, C, C] | |
| # # Enhance parent features | |
| # enhanced_feats = self.feat_enhance(feats) # [N, C] | |
| # # Apply position-specific transformation | |
| # modulated_feats = torch.einsum('pci,nc->npc', pos_filters, enhanced_feats) # [N, 8, C] | |
| # modulated_feats = modulated_feats + pos_emb.unsqueeze(0) # Add positional encoding | |
| # # Final transformation | |
| # transformed = self.output_proj(self.output_norm(modulated_feats)) # [N, 8, C] | |
| # child_feats = enhanced_feats.unsqueeze(1) + transformed # Residual connection | |
| # # Compute new coordinates | |
| # new_coords = coords.clone() | |
| # new_coords[:, 1:] *= 2 | |
| # expanded_coords = new_coords.unsqueeze(1) + n_coords.to(dtype=coords.dtype).unsqueeze(0) # [N, 8, 4] | |
| # return SparseTensor( | |
| # child_feats.reshape(N * 8, self.in_channels), | |
| # expanded_coords.reshape(N * 8, 4), | |
| # input.shape | |
| # ) | |
| # # ################################# nn.Parameter ####################################### | |
| # # class SparseSubdivideCrossAttn(nn.Module): | |
| # # def __init__(self, in_channels: int, num_heads: int = 4, mlp_ratio: int = 4): | |
| # # super().__init__() | |
| # # self.in_channels = in_channels | |
| # # self.num_heads = num_heads | |
| # # self.head_dim = in_channels // num_heads | |
| # # self.scale = self.head_dim ** -0.5 | |
| # # self.mlp_ratio = mlp_ratio | |
| # # self.pos_embed = nn.Parameter(torch.randn(8, in_channels)) | |
| # # self.q_proj = nn.Linear(in_channels, in_channels) | |
| # # self.kv_proj = nn.Linear(in_channels, in_channels * 2) | |
| # # self.proj = nn.Linear(in_channels, in_channels) | |
| # # self.norm1 = nn.LayerNorm(in_channels) | |
| # # self.norm2 = nn.LayerNorm(in_channels) | |
| # # self.mlp = nn.Sequential( | |
| # # nn.Linear(in_channels, in_channels * mlp_ratio), | |
| # # nn.GELU(), | |
| # # nn.Linear(in_channels * mlp_ratio, in_channels) | |
| # # ) | |
| # # def forward(self, input: SparseTensor) -> SparseTensor: | |
| # # DIM = input.coords.shape[-1] - 1 | |
| # # device = input.device | |
| # # coords = input.coords | |
| # # feats = input.feats | |
| # # N = feats.shape[0] | |
| # # n_cube = torch.ones([2] * DIM, device=device, dtype=torch.int) | |
| # # n_coords = torch.nonzero(n_cube) | |
| # # n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) | |
| # # q = self.q_proj(self.pos_embed) | |
| # # q = q.reshape(8, self.num_heads, self.head_dim).permute(1, 0, 2) # [num_heads, 8, head_dim] | |
| # # kv = self.kv_proj(feats) | |
| # # kv = kv.reshape(N, 2, self.num_heads, self.head_dim).permute(2, 0, 1, 3) # [num_heads, N, 2, head_dim] | |
| # # k, v = kv[:, :, 0, :], kv[:, :, 1, :] # [num_heads, N, head_dim] for both | |
| # # attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale # [num_heads, 8, N] | |
| # # attn = torch.softmax(attn, dim=-1) | |
| # # out = torch.matmul(attn, v) # [num_heads, 8, head_dim] | |
| # # out = out.permute(1, 0, 2).reshape(8, N, self.in_channels) # [8, N, in_channels] | |
| # # out = out.permute(1, 0, 2) # [N, 8, in_channels] | |
| # # x = self.proj(out) + feats.unsqueeze(1) | |
| # # x = self.norm1(x) | |
| # # x = x + self.mlp(self.norm2(x)) | |
| # # new_coords = coords.clone() | |
| # # new_coords[:, 1:] *= 2 | |
| # # expanded_coords = new_coords.unsqueeze(1) + n_coords.to(dtype=coords.dtype).unsqueeze(0) | |
| # # return SparseTensor( | |
| # # x.reshape(N * 8, self.in_channels), | |
| # # expanded_coords.reshape(N * 8, 4), | |
| # # input.shape | |
| # # ) | |
| # # ################################ Modulation #################################### | |
| # # class SparseSubdivideModulation(nn.Module): | |
| # # def __init__(self, in_channels: int): | |
| # # super().__init__() | |
| # # self.in_channels = in_channels | |
| # # self.position_emb = nn.Embedding(8, in_channels) | |
| # # self.modulation_vectors = nn.Embedding(8, in_channels) | |
| # # self.feature_transformer = nn.Sequential( | |
| # # nn.Linear(in_channels, in_channels * 2), | |
| # # nn.LayerNorm(in_channels * 2), | |
| # # nn.GELU(), | |
| # # nn.Linear(in_channels * 2, in_channels) | |
| # # ) | |
| # # self.output_mlp = nn.Sequential( | |
| # # nn.Linear(in_channels, in_channels * 2), | |
| # # nn.GELU(), | |
| # # nn.Linear(in_channels * 2, in_channels) | |
| # # ) | |
| # # def forward(self, input: SparseTensor) -> SparseTensor: | |
| # # DIM = input.coords.shape[-1] - 1 | |
| # # device = input.device | |
| # # coords = input.coords | |
| # # feats = input.feats | |
| # # N = feats.shape[0] | |
| # # n_cube = torch.ones([2] * DIM, device=device, dtype=torch.int) | |
| # # n_coords = torch.nonzero(n_cube) | |
| # # n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) | |
| # # pos_ids = torch.arange(8, device=device) | |
| # # trans_feats = self.feature_transformer(feats) # [N, C] | |
| # # mod_vectors = self.modulation_vectors(pos_ids) # [8, C] | |
| # # pos_emb = self.position_emb(pos_ids) # [8, C] | |
| # # modulated_feats = torch.einsum('nc,pc->npc', trans_feats, mod_vectors) + pos_emb | |
| # # output_feats = self.output_mlp(modulated_feats) # [N, 8, C] | |
| # # new_coords = coords.clone() | |
| # # new_coords[:, 1:] *= 2 | |
| # # expanded_coords = new_coords.unsqueeze(1) + n_coords.to(dtype=coords.dtype).unsqueeze(0) | |
| # # return SparseTensor( | |
| # # output_feats.reshape(N * 8, self.in_channels), | |
| # # expanded_coords.reshape(N * 8, 4), | |
| # # input.shape | |
| # # ) | |
| # ############################## 16 * 3 embedding ############################## | |
| # # 730 | |
class SparseSubdivide_attn(nn.Module):
    """
    Subdivide each voxel into 2**DIM children and add a coordinate-derived
    positional embedding (separate x/y/z embedding tables, concatenated and
    projected) to the inherited parent feature.

    NOTE(review): shadowed by the next definition of the same name below —
    this version is never the one exported via __all__. `num_heads` is
    accepted but unused here.
    """
    def __init__(self, in_channels: int, num_heads: int = 4, resolution: int = 128):
        super().__init__()
        self.in_channels = in_channels
        self.num_heads = num_heads  # unused in this variant
        self.embed_dim = in_channels
        # Size of each per-axis embedding table; child coords are taken modulo
        # this value (2 for the default resolution=128), i.e. local parity per axis.
        self.relative_coords = resolution // 64
        self.coord_embed_x = nn.Embedding(self.relative_coords, self.embed_dim)
        self.coord_embed_y = nn.Embedding(self.relative_coords, self.embed_dim)
        self.coord_embed_z = nn.Embedding(self.relative_coords, self.embed_dim)
        # Fuse the three per-axis embeddings back to the feature width.
        self.embed_proj = nn.Linear(in_channels * 3, in_channels)

    def forward(self, input):
        DIM = input.coords.shape[-1] - 1  # spatial rank; coords[:, 0] is batch
        device = input.device
        feats = input.feats
        # Enumerate the 2**DIM corner offsets of the unit cube.
        n_cube = torch.ones([2]*DIM, device=device, dtype=torch.int)
        n_coords = torch.nonzero(n_cube)
        n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
        # Child coordinates: double the parent grid, add each corner offset.
        new_coords = input.coords.clone()
        new_coords[:, 1:] *= 2
        new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)  # [N, 8, 4]
        abs_coords = new_coords[:, :, 1:]  # [N, 8, 3]
        # Wrap absolute coords into the embedding-table range (local offset).
        mod_coords = abs_coords % self.relative_coords
        x_embed = self.coord_embed_x(mod_coords[..., 0].long())
        y_embed = self.coord_embed_y(mod_coords[..., 1].long())
        z_embed = self.coord_embed_z(mod_coords[..., 2].long())
        pos_embed = torch.cat([x_embed, y_embed, z_embed], dim=-1)  # [N, 8, 3C]
        pos_embed = self.embed_proj(pos_embed)  # [N, 8, C]
        # Parent features replicated per child, plus positional embedding.
        feats = feats.unsqueeze(1).expand(-1, 8, -1)  # [N, 8, C]
        new_feats = feats + pos_embed  # [N, 8, C]
        out = SparseTensor(
            new_feats.flatten(0, 1),
            new_coords.flatten(0, 1),
            input.shape
        )
        out._scale = input._scale * 2
        out._spatial_cache = input._spatial_cache
        return out
class SparseSubdivide_attn(nn.Module):
    """
    Subdivide every voxel into its 2**DIM children and add a learned per-child
    positional embedding to the inherited parent feature.

    Despite the name, no attention is computed in this (final, effective)
    variant: each child `i` receives ``parent_feature + pos_embed(i)``.

    Args:
        in_channels: feature dimension of the input sparse tensor.
        num_heads: unused; kept for interface compatibility with the earlier
            attention-based variants above.
    """
    def __init__(self, in_channels: int, num_heads: int = 4):
        super().__init__()
        self.in_channels = in_channels
        self.num_heads = num_heads  # unused, retained for interface compatibility
        self.pos_embed = nn.Embedding(8, in_channels)  # one embedding per child octant (3D)

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1  # spatial rank; coords[:, 0] is batch
        device = input.device
        # Enumerate the 2**DIM corner offsets of the unit cube.
        n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int)
        n_coords = torch.nonzero(n_cube)
        n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
        factor = n_coords.shape[0]
        assert factor == 2 ** DIM
        # Child coordinates: double the parent grid, then add each corner offset.
        new_coords = input.coords.clone()
        new_coords[:, 1:] *= 2
        new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)
        # Parent feature replicated per child, plus a learned per-child embedding.
        base_feats = input.feats.unsqueeze(1).expand(-1, factor, -1)  # [N, 8, C]
        child_ids = torch.arange(factor, device=device)               # [8]
        pos_feats = self.pos_embed(child_ids).unsqueeze(0)            # [1, 8, C]
        final_feats = (base_feats + pos_feats).flatten(0, 1)
        out = SparseTensor(
            feats=final_feats,
            coords=new_coords.flatten(0, 1),
            shape=input.shape
        )
        out._scale = input._scale * 2
        # Bug fix: propagate the spatial cache like every other resampling module
        # in this file, so SparseDownsample/SparseUpsample pairings that straddle
        # this layer keep working.
        out._spatial_cache = input._spatial_cache
        return out