# udbbdh's picture
# Upload folder using huggingface_hub
# 7340df2 verified
# MIT License
# Copyright (c) Microsoft Corporation.
# Copyright (c) 2025 VAST-AI-Research and contributors.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
from typing import *
import torch
import torch.nn as nn
from . import SparseTensor
import torch.nn.functional as F
import spconv.pytorch as spconv
from typing import Optional
from query_point import PE_NeRF
from ...modules.sparse.transformer.blocks import SparseTransformerCrossBlock
# Public API of this module.
# NOTE(review): `SparseSubdivide_attn` is defined three times below; the
# final (positional-embedding-based) definition is the one exported.
__all__ = [
    'SparseDownsample',
    'SparseUpsample',
    'SparseSubdivide',
    'SparseSubdivide_attn'
]
class SparseDownsample(nn.Module):
    """
    Downsample a sparse tensor by a factor of `factor`.

    NOTE(review): despite the original "average pooling" description, the
    active reduction is ``'amax'``, i.e. this is MAX pooling per output
    voxel (the ``'mean'`` reduce is commented out below).

    The pre-downsample coords/layout/scatter-index are cached under the
    keys ``upsample_{factor}_*`` so a paired SparseUpsample can invert
    this operation.
    """
    def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
        super(SparseDownsample, self).__init__()
        # Normalize list -> tuple; a bare int is broadcast per spatial
        # dimension in forward().
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor

    def forward(self, input: SparseTensor) -> SparseTensor:
        # Coordinates are (batch, x, y, ...): the first column is the batch index.
        DIM = input.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
        assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'
        coord = list(input.coords.unbind(dim=-1))
        # Integer-divide each spatial axis by its downsample factor.
        for i, f in enumerate(factor):
            coord[i+1] = coord[i+1] // f
        # Encode (batch, coords) into one linear code per voxel so that all
        # input voxels mapping to the same output cell share a code.
        MAX = [coord[i+1].max().item() + 1 for i in range(DIM)]
        OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
        code = sum([c * o for c, o in zip(coord, OFFSET)])
        code, idx = code.unique(return_inverse=True)
        # Reduce all source voxels that collapse into each output voxel.
        # NOTE(review): scatter_reduce defaults to include_self=True, so the
        # zero-initialized destination participates in the 'amax' — negative
        # features get clamped at 0. Confirm this is intended.
        new_feats = torch.scatter_reduce(
            torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype),
            dim=0,
            index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
            src=input.feats,
            # reduce='mean'
            reduce='amax',
        )
        # Decode the linear codes back into (batch, x, y, ...) coordinates.
        new_coords = torch.stack(
            [code // OFFSET[0]] +
            [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
            dim=-1
        )
        out = SparseTensor(new_feats, new_coords, input.shape,)
        out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
        out._spatial_cache = input._spatial_cache
        # Cache the original (pre-downsample) structure for SparseUpsample.
        # NOTE(review): these are stored as bare tensors, while the current
        # SparseUpsample pops them like lists — verify the pairing in use.
        out.register_spatial_cache(f'upsample_{factor}_coords', input.coords)
        out.register_spatial_cache(f'upsample_{factor}_layout', input.layout)
        out.register_spatial_cache(f'upsample_{factor}_idx', idx)
        return out
# class SparseDownsample(nn.Module):
# """
# Downsample a sparse tensor by a factor of `factor`.
# Implemented as average pooling.
# """
# def __init__(self, factor: Union[int, Tuple[int, ...], List[int]]):
# super(SparseDownsample, self).__init__()
# self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor
# def forward(self, input: SparseTensor) -> SparseTensor:
# DIM = input.coords.shape[-1] - 1
# factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
# assert DIM == len(factor), 'Input coordinates must have the same dimension as the downsample factor.'
# coord = list(input.coords.unbind(dim=-1))
# for i, f in enumerate(factor):
# coord[i+1] = coord[i+1] // f
# MAX = [coord[i+1].max().item() + 1 for i in range(DIM)]
# OFFSET = torch.cumprod(torch.tensor(MAX[::-1]), 0).tolist()[::-1] + [1]
# code = sum([c * o for c, o in zip(coord, OFFSET)])
# code, idx = code.unique(return_inverse=True)
# new_feats = torch.scatter_reduce(
# torch.zeros(code.shape[0], input.feats.shape[1], device=input.feats.device, dtype=input.feats.dtype),
# dim=0,
# index=idx.unsqueeze(1).expand(-1, input.feats.shape[1]),
# src=input.feats,
# reduce='mean'
# )
# new_coords = torch.stack(
# [code // OFFSET[0]] +
# [(code // OFFSET[i+1]) % MAX[i] for i in range(DIM)],
# dim=-1
# )
# out = SparseTensor(new_feats, new_coords, input.shape,)
# out._scale = tuple([s // f for s, f in zip(input._scale, factor)])
# out._spatial_cache = input._spatial_cache
# if out.get_spatial_cache(f'upsample_{factor}_coords') is not None:
# out.register_spatial_cache(f'upsample_{factor}_coords', [*out.get_spatial_cache(f'upsample_{factor}_coords'), input.coords])
# out.register_spatial_cache(f'upsample_{factor}_layout', [*out.get_spatial_cache(f'upsample_{factor}_layout'), input.layout])
# out.register_spatial_cache(f'upsample_{factor}_idx', [*out.get_spatial_cache(f'upsample_{factor}_idx'), idx])
# else:
# out.register_spatial_cache(f'upsample_{factor}_coords', [input.coords])
# out.register_spatial_cache(f'upsample_{factor}_layout', [input.layout])
# out.register_spatial_cache(f'upsample_{factor}_idx', [idx])
# return out
class SparseUpsample(nn.Module):
    """
    Upsample a sparse tensor by a factor of `factor`.

    Implemented as nearest-neighbor interpolation: the original
    (pre-downsample) coordinates, layout, and scatter indices are read
    from the spatial cache written by a preceding SparseDownsample, so
    this module MUST be paired with one.
    """
    def __init__(self, factor: Union[int, Tuple[int, int, int], List[int]]):
        super(SparseUpsample, self).__init__()
        # Normalize list -> tuple; a bare int is broadcast per spatial
        # dimension in forward().
        self.factor = tuple(factor) if isinstance(factor, (list, tuple)) else factor

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        factor = self.factor if isinstance(self.factor, tuple) else (self.factor,) * DIM
        assert DIM == len(factor), 'Input coordinates must have the same dimension as the upsample factor.'
        new_coords = input.get_spatial_cache(f'upsample_{factor}_coords')
        new_layout = input.get_spatial_cache(f'upsample_{factor}_layout')
        idx = input.get_spatial_cache(f'upsample_{factor}_idx')
        # Validate the cache BEFORE dereferencing it, so a missing pairing
        # raises the intended ValueError instead of an AttributeError.
        if any(x is None for x in [new_coords, new_layout, idx]):
            raise ValueError('Upsample cache not found. SparseUpsample must be paired with SparseDownsample.')
        # The cache may hold either a single entry (the active
        # SparseDownsample registers bare tensors) or a stack of entries
        # from repeated downsampling (the list-based variant); consume the
        # most recent entry in the latter case.
        if isinstance(new_coords, list):
            new_coords = new_coords.pop(-1)
            new_layout = new_layout.pop(-1)
            idx = idx.pop(-1)
        # Nearest-neighbor: each original voxel takes the feature of the
        # coarse voxel it collapsed into.
        new_feats = input.feats[idx]
        out = SparseTensor(new_feats, new_coords, input.shape, new_layout)
        out._scale = tuple([s * f for s, f in zip(input._scale, factor)])
        out._spatial_cache = input._spatial_cache
        return out
class SparseSubdivide(nn.Module):
    """
    Subdivide every voxel of a sparse tensor into its 2^DIM children,
    copying the parent feature to each child (nearest-neighbor upsampling
    by a factor of 2 along every spatial axis).
    """
    def __init__(self):
        super(SparseSubdivide, self).__init__()

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        # Enumerate the corner offsets of a unit hypercube: 2^DIM children.
        cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int)
        offsets = torch.nonzero(cube)
        # Prepend a zero column so offsets align with (batch, x, y, ...) coords.
        offsets = torch.cat([torch.zeros_like(offsets[:, :1]), offsets], dim=-1)
        n_children = offsets.shape[0]
        assert n_children == 2 ** DIM
        # Child coordinates: parent coordinates scaled by 2 plus each offset.
        parent_coords = input.coords.clone()
        parent_coords[:, 1:] *= 2
        child_coords = parent_coords.unsqueeze(1) + offsets.unsqueeze(0).to(parent_coords.dtype)
        # Broadcast each parent feature to all of its children.
        child_feats = input.feats.unsqueeze(1).expand(input.feats.shape[0], n_children, *input.feats.shape[1:])
        out = SparseTensor(child_feats.flatten(0, 1), child_coords.flatten(0, 1), input.shape)
        out._scale = input._scale * 2
        out._spatial_cache = input._spatial_cache
        return out
#################### new ########################
# 7/30: added cross-attention ("ca") based subdivision variant below.
class SparseSubdivide_attn(nn.Module):
    """
    Attention-based upsampling: compute child voxel features with multi-head attention.
    Enhanced with residual connections, layer normalization, and position encoding.

    Improvements to overcome training plateau:
    1. Position encoding using relative offsets instead of indices
    2. Feature enhancement before attention
    3. Output normalization and projection
    4. Careful residual connections

    NOTE(review): this class is re-defined twice later in this file; the
    final definition wins at import time, so this variant is dead code
    unless the later definitions are removed.
    """
    def __init__(self, in_channels: int, num_heads: int = 4,):
        super().__init__()
        assert in_channels % num_heads == 0, "in_channels must be divisible by num_heads"
        self.in_channels = in_channels
        self.head_dim = in_channels // num_heads
        self.num_heads = num_heads
        # Standard 1/sqrt(d) attention scaling.
        self.scale = self.head_dim ** -0.5
        # Enhanced position encoding (continuous offsets instead of discrete indices)
        # self.pos_embed = nn.Sequential(
        #     nn.Linear(3, 64),  # Process actual spatial offsets
        #     nn.LayerNorm(64),
        #     nn.GELU(),
        #     nn.Linear(64, in_channels)  # Map to feature dimension
        # )
        # NeRF-style frequency encoding of the child offsets, projected to C.
        self.pos_embed = nn.Sequential(
            PE_NeRF(out_channels=in_channels, multires=10),  # Process actual spatial offsets
            nn.LayerNorm(in_channels * 3),
            nn.GELU(),
            nn.Linear(in_channels * 3, in_channels)  # Map to feature dimension
        )
        # Feature enhancement before attention
        self.feat_enhance = nn.Sequential(
            nn.Linear(in_channels, in_channels * 2),
            nn.LayerNorm(in_channels * 2),
            nn.GELU(),
            nn.Linear(in_channels * 2, in_channels),
            nn.LayerNorm(in_channels)
        )
        # Attention projections
        self.q_proj = nn.Linear(in_channels, in_channels)  # Query from position
        self.k_proj = nn.Linear(in_channels, in_channels)  # Key from content
        self.v_proj = nn.Linear(in_channels, in_channels)  # Value from content
        # Output processing with residual
        self.output_norm = nn.LayerNorm(in_channels)
        self.output_proj = nn.Sequential(
            nn.Linear(in_channels, in_channels * 2),
            nn.GELU(),
            nn.Linear(in_channels * 2, in_channels)
        )
        # Initialize for stable training
        self._initialize_weights()

    def _initialize_weights(self):
        # Xavier init for the attention projections; zero their biases.
        nn.init.xavier_uniform_(self.q_proj.weight)
        nn.init.xavier_uniform_(self.k_proj.weight)
        nn.init.xavier_uniform_(self.v_proj.weight)
        nn.init.zeros_(self.q_proj.bias)
        nn.init.zeros_(self.k_proj.bias)
        nn.init.zeros_(self.v_proj.bias)
        # NOTE(review): the next two zero-inits are immediately overridden
        # by the uniform/constant inits below and are therefore redundant.
        nn.init.zeros_(self.output_proj[-1].weight)
        nn.init.zeros_(self.output_proj[-1].bias)
        # Near-zero final projection so the residual path dominates at start.
        nn.init.uniform_(self.output_proj[-1].weight, -1e-5, 1e-5)
        nn.init.constant_(self.output_proj[-1].bias, 0)

    def forward(self, input: SparseTensor) -> SparseTensor:
        """Subdivide each parent voxel into 2^DIM children with attention-derived features."""
        DIM = input.coords.shape[-1] - 1  # Spatial dimensions (3 for 3D)
        device = input.device
        batch_coords = input.coords
        feats = input.feats
        # Generate child positions (identical to original SparseSubdivide)
        n_cube = torch.ones([2] * DIM, device=device, dtype=torch.int)
        n_coords = torch.nonzero(n_cube)  # [8, 3] for 3D
        n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
        # Calculate actual spatial offsets (normalized to [-0.5, 0.5])
        spatial_offsets = (n_coords[:, 1:].float() - 0.5)  # Centered at origin
        pos_emb = self.pos_embed(spatial_offsets)  # [8, C]
        # Enhance original features before attention
        enhanced_feats = self.feat_enhance(feats)
        # Compute new coordinates (same as original)
        new_coords = batch_coords.clone()
        new_coords[:, 1:] *= 2
        expanded_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)
        # Prepare attention inputs
        N = feats.shape[0]  # Number of parent voxels
        num_children = n_coords.shape[0]  # Always 8 for 3D
        # Project features to K, V
        K = self.k_proj(enhanced_feats).view(N, 1, self.num_heads, self.head_dim)
        V = self.v_proj(enhanced_feats).view(N, 1, self.num_heads, self.head_dim)
        # Project position embeddings to Q
        Q = self.q_proj(pos_emb).view(1, num_children, self.num_heads, self.head_dim)
        #########################################
        # Earlier F.scaled_dot_product_attention variant, kept for reference:
        # K = K.expand(-1, num_children, -1, -1)  # [N, 8, H, D]
        # V = V.expand(-1, num_children, -1, -1)  # [N, 8, H, D]
        # Q = Q.expand(N, -1, -1, -1)  # [N, 8, H, D]
        # attn_out = F.scaled_dot_product_attention(
        #     Q.permute(0, 2, 1, 3).reshape(N * num_children, self.num_heads, self.head_dim),
        #     K.permute(0, 2, 1, 3).reshape(N * num_children, self.num_heads, self.head_dim),
        #     V.permute(0, 2, 1, 3).reshape(N * num_children, self.num_heads, self.head_dim),
        #     dropout_p=0.0,
        # )
        # attn_out = attn_out.view(N, num_children, self.in_channels)
        K = K.expand(-1, num_children, -1, -1)  # [N, 8, H, D]
        V = V.expand(-1, num_children, -1, -1)  # [N, 8, H, D]
        Q = Q.expand(N, num_children, -1, -1)  # [N, 8, H, D]
        # === Manual scaled dot-product attention ===
        # NOTE(review): the 8 keys/values per parent are all copies of the
        # same vector, so the softmax weights are uniform and the attention
        # output equals V for every child — confirm this is intended.
        Q_ = Q.permute(0, 2, 1, 3)  # [N, H, 8, D]
        K_ = K.permute(0, 2, 1, 3)  # [N, H, 8, D]
        V_ = V.permute(0, 2, 1, 3)  # [N, H, 8, D]
        scale = self.head_dim ** -0.5
        attn_logits = torch.matmul(Q_, K_.transpose(-2, -1)) * scale  # [N, H, 8, 8]
        # Stabilized softmax: subtract the per-row max before exponentiating.
        attn_logits = attn_logits - attn_logits.amax(dim=-1, keepdim=True)
        attn_weights = torch.softmax(attn_logits, dim=-1)
        attn_weights = torch.nan_to_num(attn_weights, nan=0.0, posinf=0.0, neginf=0.0)
        attn_output = torch.matmul(attn_weights, V_)  # [N, H, 8, D]
        # Merge heads back into the original shape [N, 8, C]
        attn_out = attn_output.permute(0, 2, 1, 3).reshape(N, num_children, self.in_channels)
        # Position injection and output processing
        modulated = attn_out + pos_emb.unsqueeze(0)  # Inject position info
        transformed = self.output_proj(self.output_norm(modulated))
        # Residual connection: Combine with expanded parent features
        base_features = enhanced_feats.unsqueeze(1).expand(-1, num_children, -1)
        child_feats = base_features + transformed  # Preserve original information
        # Create new sparse tensor
        out = SparseTensor(
            child_feats.reshape(N * num_children, -1),
            expanded_coords.flatten(0, 1),
            input.shape
        )
        out._scale = input._scale * 2
        out._spatial_cache = input._spatial_cache
        return out
# ######################## relative linear #############################
# class SparseSubdivide_attn(nn.Module):
# def __init__(self, in_channels: int, num_heads: int = 4, dropout: float = 0.05):
# super().__init__()
# assert in_channels % num_heads == 0, "in_channels must be divisible by num_heads"
# self.in_channels = in_channels
# self.head_dim = in_channels // num_heads
# self.num_heads = num_heads
# self.scale = self.head_dim ** -0.5
# self.pos_embed = nn.Sequential(
# nn.Linear(3, in_channels),
# nn.LayerNorm(in_channels),
# nn.GELU(),
# nn.Linear(in_channels, in_channels)
# )
# self.feat_enhance = nn.Sequential(
# nn.Linear(in_channels, in_channels * 2),
# nn.LayerNorm(in_channels * 2),
# nn.GELU(),
# nn.Linear(in_channels * 2, in_channels),
# nn.LayerNorm(in_channels)
# )
# self.q_proj = nn.Linear(in_channels, in_channels)
# self.k_proj = nn.Linear(in_channels, in_channels)
# self.v_proj = nn.Linear(in_channels, in_channels)
# self.output_norm = nn.LayerNorm(in_channels)
# self.output_proj = nn.Sequential(
# nn.Linear(in_channels, in_channels * 2),
# nn.GELU(),
# nn.Dropout(dropout),
# nn.Linear(in_channels * 2, in_channels)
# )
# # self._initialize_weights()
# def _initialize_weights(self):
# nn.init.xavier_uniform_(self.q_proj.weight)
# nn.init.xavier_uniform_(self.k_proj.weight)
# nn.init.xavier_uniform_(self.v_proj.weight)
# nn.init.zeros_(self.q_proj.bias)
# nn.init.zeros_(self.k_proj.bias)
# nn.init.zeros_(self.v_proj.bias)
# nn.init.zeros_(self.output_proj[-1].weight)
# nn.init.constant_(self.output_proj[-1].bias, 0)
# def forward(self, input: SparseTensor) -> SparseTensor:
# DIM = input.coords.shape[-1] - 1
# device = input.device
# coords = input.coords
# feats = input.feats
# N = feats.shape[0]
# n_cube = torch.ones([2] * DIM, device=device, dtype=torch.int)
# n_coords = torch.nonzero(n_cube) # [8, 3]
# n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) # [8, 4]
# spatial_offsets = n_coords[:, 1:].float() # [8, 3], 用于位置编码
# pos_emb = self.pos_embed(spatial_offsets.to(device=device, dtype=feats.dtype)) # [8, C]
# Q = self.q_proj(pos_emb).view(1, 8, self.num_heads, self.head_dim).expand(N, -1, -1, -1) # [N, 8, H, D]
# enhanced_feats = self.feat_enhance(feats) # [N, C]
# K = self.k_proj(enhanced_feats).view(N, 1, self.num_heads, self.head_dim) # [N, 1, H, D]
# V = self.v_proj(enhanced_feats).view(N, 1, self.num_heads, self.head_dim) # [N, 1, H, D]
# Q_ = Q.permute(0, 2, 1, 3) # [N, H, 8, D]
# K_ = K.permute(0, 2, 1, 3) # [N, H, 1, D]
# V_ = V.permute(0, 2, 1, 3) # [N, H, 1, D]
# attn_logits = torch.matmul(Q_, K_.transpose(-2, -1)) * self.scale # [N, H, 8, 1]
# attn_weights = torch.softmax(attn_logits, dim=2) # over children
# attn_weights = torch.nan_to_num(attn_weights, nan=0.0, posinf=0.0, neginf=0.0)
# attn_output = torch.matmul(attn_weights, V_) # [N, H, 8, D]
# attn_out = attn_output.permute(0, 2, 1, 3).reshape(N, 8, self.in_channels) # [N, 8, C]
# modulated = attn_out + pos_emb.unsqueeze(0) # [N, 8, C]
# transformed = self.output_proj(self.output_norm(modulated)) # [N, 8, C]
# base_features = enhanced_feats.unsqueeze(1).expand(-1, 8, -1)
# child_feats = base_features + transformed # [N, 8, C]
# new_coords = coords.clone()
# new_coords[:, 1:] *= 2
# expanded_coords = new_coords.unsqueeze(1) + n_coords.to(dtype=coords.dtype).unsqueeze(0) # [N, 8, 4]
# return SparseTensor(
# child_feats.reshape(N * 8, self.in_channels),
# expanded_coords.reshape(N * 8, 4),
# input.shape
# )
# ######################## relative embedding #############################
# class SparseSubdivide_attn(nn.Module):
# def __init__(self, in_channels: int, num_heads: int = 4, dropout: float = 0.05):
# super().__init__()
# assert in_channels % num_heads == 0, "in_channels must be divisible by num_heads"
# self.in_channels = in_channels
# self.head_dim = in_channels // num_heads
# self.num_heads = num_heads
# self.scale = self.head_dim ** -0.5
# self.pos_index_embed = nn.Embedding(8, in_channels)
# self.feat_enhance = nn.Sequential(
# nn.Linear(in_channels, in_channels * 2),
# nn.LayerNorm(in_channels * 2),
# nn.GELU(),
# nn.Linear(in_channels * 2, in_channels),
# nn.LayerNorm(in_channels)
# )
# self.q_proj = nn.Linear(in_channels, in_channels)
# self.k_proj = nn.Linear(in_channels, in_channels)
# self.v_proj = nn.Linear(in_channels, in_channels)
# self.output_norm = nn.LayerNorm(in_channels)
# self.output_proj = nn.Sequential(
# nn.Linear(in_channels, in_channels * 2),
# nn.GELU(),
# nn.Dropout(dropout),
# nn.Linear(in_channels * 2, in_channels)
# )
# def forward(self, input: SparseTensor) -> SparseTensor:
# DIM = input.coords.shape[-1] - 1
# device = input.device
# coords = input.coords
# feats = input.feats
# N = feats.shape[0]
# n_cube = torch.ones([2] * DIM, device=device, dtype=torch.int)
# n_coords = torch.nonzero(n_cube)
# n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) # [8, 4]
# pos_indices = torch.arange(8, device=device)
# pos_emb = self.pos_index_embed(pos_indices) # [8, C]
# Q = self.q_proj(pos_emb).view(1, 8, self.num_heads, self.head_dim).expand(N, -1, -1, -1) # [N, 8, H, D]
# enhanced_feats = self.feat_enhance(feats) # [N, C]
# K = self.k_proj(enhanced_feats).view(N, 1, self.num_heads, self.head_dim) # [N, 1, H, D]
# V = self.v_proj(enhanced_feats).view(N, 1, self.num_heads, self.head_dim) # [N, 1, H, D]
# # === Cross Attention ===
# Q_ = Q.permute(0, 2, 1, 3) # [N, H, 8, D]
# K_ = K.permute(0, 2, 1, 3) # [N, H, 1, D]
# V_ = V.permute(0, 2, 1, 3) # [N, H, 1, D]
# attn_logits = torch.matmul(Q_, K_.transpose(-2, -1)) * self.scale # [N, H, 8, 1]
# attn_weights = torch.softmax(attn_logits, dim=2)
# attn_weights = torch.nan_to_num(attn_weights, nan=0.0, posinf=0.0, neginf=0.0)
# attn_output = torch.matmul(attn_weights, V_) # [N, H, 8, D]
# attn_out = attn_output.permute(0, 2, 1, 3).reshape(N, 8, self.in_channels) # [N, 8, C]
# modulated = attn_out + pos_emb.unsqueeze(0) # [N, 8, C]
# transformed = self.output_proj(self.output_norm(modulated)) # [N, 8, C]
# base_features = enhanced_feats.unsqueeze(1).expand(-1, 8, -1)
# child_feats = base_features + transformed # [N, 8, C]
# new_coords = coords.clone()
# new_coords[:, 1:] *= 2
# expanded_coords = new_coords.unsqueeze(1) + n_coords.to(dtype=coords.dtype).unsqueeze(0) # [N, 8, 4]
# return SparseTensor(
# child_feats.reshape(N * 8, self.in_channels),
# expanded_coords.reshape(N * 8, 4),
# input.shape
# )
# ############################## Position-Specific Filters #####################################
# class SparseSubdivide_attn(nn.Module):
# def __init__(self, in_channels: int, num_heads: int = 4, dropout: float = 0.05):
# super().__init__()
# self.in_channels = in_channels
# self.num_heads = num_heads
# self.head_dim = in_channels // num_heads
# self.scale = self.head_dim ** -0.5
# # Position-aware modulation components
# self.pos_index_embed = nn.Embedding(8, in_channels)
# self.pos_filters = nn.Embedding(8, in_channels * in_channels)
# # Feature enhancement
# self.feat_enhance = nn.Sequential(
# nn.Linear(in_channels, in_channels * 2),
# nn.LayerNorm(in_channels * 2),
# nn.GELU(),
# nn.Linear(in_channels * 2, in_channels),
# nn.LayerNorm(in_channels)
# )
# # Output transformation
# self.output_norm = nn.LayerNorm(in_channels)
# self.output_proj = nn.Sequential(
# nn.Linear(in_channels, in_channels * 2),
# nn.GELU(),
# nn.Dropout(dropout),
# nn.Linear(in_channels * 2, in_channels)
# )
# def forward(self, input: SparseTensor) -> SparseTensor:
# DIM = input.coords.shape[-1] - 1
# device = input.device
# coords = input.coords
# feats = input.feats
# N = feats.shape[0]
# # Generate subdivision coordinates (8 children per voxel)
# n_cube = torch.ones([2] * DIM, device=device, dtype=torch.int)
# n_coords = torch.nonzero(n_cube)
# n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1) # [8, 4]
# pos_indices = torch.arange(8, device=device)
# # Position-aware feature modulation
# pos_emb = self.pos_index_embed(pos_indices) # [8, C]
# pos_filters = self.pos_filters(pos_indices) # [8, C*C]
# pos_filters = pos_filters.view(8, self.in_channels, self.in_channels) # [8, C, C]
# # Enhance parent features
# enhanced_feats = self.feat_enhance(feats) # [N, C]
# # Apply position-specific transformation
# modulated_feats = torch.einsum('pci,nc->npc', pos_filters, enhanced_feats) # [N, 8, C]
# modulated_feats = modulated_feats + pos_emb.unsqueeze(0) # Add positional encoding
# # Final transformation
# transformed = self.output_proj(self.output_norm(modulated_feats)) # [N, 8, C]
# child_feats = enhanced_feats.unsqueeze(1) + transformed # Residual connection
# # Compute new coordinates
# new_coords = coords.clone()
# new_coords[:, 1:] *= 2
# expanded_coords = new_coords.unsqueeze(1) + n_coords.to(dtype=coords.dtype).unsqueeze(0) # [N, 8, 4]
# return SparseTensor(
# child_feats.reshape(N * 8, self.in_channels),
# expanded_coords.reshape(N * 8, 4),
# input.shape
# )
# # ################################# nn.Parameter #######################################
# # class SparseSubdivideCrossAttn(nn.Module):
# # def __init__(self, in_channels: int, num_heads: int = 4, mlp_ratio: int = 4):
# # super().__init__()
# # self.in_channels = in_channels
# # self.num_heads = num_heads
# # self.head_dim = in_channels // num_heads
# # self.scale = self.head_dim ** -0.5
# # self.mlp_ratio = mlp_ratio
# # self.pos_embed = nn.Parameter(torch.randn(8, in_channels))
# # self.q_proj = nn.Linear(in_channels, in_channels)
# # self.kv_proj = nn.Linear(in_channels, in_channels * 2)
# # self.proj = nn.Linear(in_channels, in_channels)
# # self.norm1 = nn.LayerNorm(in_channels)
# # self.norm2 = nn.LayerNorm(in_channels)
# # self.mlp = nn.Sequential(
# # nn.Linear(in_channels, in_channels * mlp_ratio),
# # nn.GELU(),
# # nn.Linear(in_channels * mlp_ratio, in_channels)
# # )
# # def forward(self, input: SparseTensor) -> SparseTensor:
# # DIM = input.coords.shape[-1] - 1
# # device = input.device
# # coords = input.coords
# # feats = input.feats
# # N = feats.shape[0]
# # n_cube = torch.ones([2] * DIM, device=device, dtype=torch.int)
# # n_coords = torch.nonzero(n_cube)
# # n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
# # q = self.q_proj(self.pos_embed)
# # q = q.reshape(8, self.num_heads, self.head_dim).permute(1, 0, 2) # [num_heads, 8, head_dim]
# # kv = self.kv_proj(feats)
# # kv = kv.reshape(N, 2, self.num_heads, self.head_dim).permute(2, 0, 1, 3) # [num_heads, N, 2, head_dim]
# # k, v = kv[:, :, 0, :], kv[:, :, 1, :] # [num_heads, N, head_dim] for both
# # attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale # [num_heads, 8, N]
# # attn = torch.softmax(attn, dim=-1)
# # out = torch.matmul(attn, v) # [num_heads, 8, head_dim]
# # out = out.permute(1, 0, 2).reshape(8, N, self.in_channels) # [8, N, in_channels]
# # out = out.permute(1, 0, 2) # [N, 8, in_channels]
# # x = self.proj(out) + feats.unsqueeze(1)
# # x = self.norm1(x)
# # x = x + self.mlp(self.norm2(x))
# # new_coords = coords.clone()
# # new_coords[:, 1:] *= 2
# # expanded_coords = new_coords.unsqueeze(1) + n_coords.to(dtype=coords.dtype).unsqueeze(0)
# # return SparseTensor(
# # x.reshape(N * 8, self.in_channels),
# # expanded_coords.reshape(N * 8, 4),
# # input.shape
# # )
# # ################################ Modulation ####################################
# # class SparseSubdivideModulation(nn.Module):
# # def __init__(self, in_channels: int):
# # super().__init__()
# # self.in_channels = in_channels
# # self.position_emb = nn.Embedding(8, in_channels)
# # self.modulation_vectors = nn.Embedding(8, in_channels)
# # self.feature_transformer = nn.Sequential(
# # nn.Linear(in_channels, in_channels * 2),
# # nn.LayerNorm(in_channels * 2),
# # nn.GELU(),
# # nn.Linear(in_channels * 2, in_channels)
# # )
# # self.output_mlp = nn.Sequential(
# # nn.Linear(in_channels, in_channels * 2),
# # nn.GELU(),
# # nn.Linear(in_channels * 2, in_channels)
# # )
# # def forward(self, input: SparseTensor) -> SparseTensor:
# # DIM = input.coords.shape[-1] - 1
# # device = input.device
# # coords = input.coords
# # feats = input.feats
# # N = feats.shape[0]
# # n_cube = torch.ones([2] * DIM, device=device, dtype=torch.int)
# # n_coords = torch.nonzero(n_cube)
# # n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
# # pos_ids = torch.arange(8, device=device)
# # trans_feats = self.feature_transformer(feats) # [N, C]
# # mod_vectors = self.modulation_vectors(pos_ids) # [8, C]
# # pos_emb = self.position_emb(pos_ids) # [8, C]
# # modulated_feats = torch.einsum('nc,pc->npc', trans_feats, mod_vectors) + pos_emb
# # output_feats = self.output_mlp(modulated_feats) # [N, 8, C]
# # new_coords = coords.clone()
# # new_coords[:, 1:] *= 2
# # expanded_coords = new_coords.unsqueeze(1) + n_coords.to(dtype=coords.dtype).unsqueeze(0)
# # return SparseTensor(
# # output_feats.reshape(N * 8, self.in_channels),
# # expanded_coords.reshape(N * 8, 4),
# # input.shape
# # )
# ############################## 16 * 3 embedding ##############################
# # 730
class SparseSubdivide_attn(nn.Module):
    """
    Subdivide each voxel into 8 children and add a learned positional
    embedding to the copied parent feature. The embedding is built from
    per-axis coordinate embeddings (x/y/z, each taken modulo a small
    period) concatenated and projected back to C channels.

    NOTE(review): this definition is shadowed by the later class of the
    same name in this file; kept for reference only.
    """
    def __init__(self, in_channels: int, num_heads: int = 4, resolution: int = 128):
        super().__init__()
        self.in_channels = in_channels
        self.num_heads = num_heads  # unused in this variant; kept for interface parity
        self.embed_dim = in_channels
        # Coordinates are embedded modulo this period (resolution // 64).
        self.relative_coords = resolution // 64
        self.coord_embed_x = nn.Embedding(self.relative_coords, self.embed_dim)
        self.coord_embed_y = nn.Embedding(self.relative_coords, self.embed_dim)
        self.coord_embed_z = nn.Embedding(self.relative_coords, self.embed_dim)
        self.embed_proj = nn.Linear(in_channels * 3, in_channels)

    def forward(self, input):
        ndim = input.coords.shape[-1] - 1
        device = input.device
        # Corner offsets of the unit hypercube -> 2^ndim (8) children per voxel.
        corner_mask = torch.ones([2] * ndim, device=device, dtype=torch.int)
        corners = torch.nonzero(corner_mask)
        corners = torch.cat([torch.zeros_like(corners[:, :1]), corners], dim=-1)
        # Child coordinates: parent coordinates scaled by 2 plus each offset.
        child_coords = input.coords.clone()
        child_coords[:, 1:] *= 2
        child_coords = child_coords.unsqueeze(1) + corners.unsqueeze(0).to(child_coords.dtype)  # [N, 8, 4]
        # Per-axis positional embedding of the (wrapped) child coordinates.
        wrapped = child_coords[:, :, 1:] % self.relative_coords  # [N, 8, 3]
        axis_embeds = [
            self.coord_embed_x(wrapped[..., 0].long()),
            self.coord_embed_y(wrapped[..., 1].long()),
            self.coord_embed_z(wrapped[..., 2].long()),
        ]
        pos = self.embed_proj(torch.cat(axis_embeds, dim=-1))  # [N, 8, C]
        # Child feature = copied parent feature + positional embedding.
        child_feats = input.feats.unsqueeze(1).expand(-1, 8, -1) + pos  # [N, 8, C]
        out = SparseTensor(
            child_feats.flatten(0, 1),
            child_coords.flatten(0, 1),
            input.shape
        )
        out._scale = input._scale * 2
        out._spatial_cache = input._spatial_cache
        return out
class SparseSubdivide_attn(nn.Module):
    """
    Subdivide each voxel into its 2^DIM children, adding a learned
    per-child positional embedding to the copied parent feature.

    Despite the name, this final variant performs no attention; it is a
    positional-embedding-based subdivision. (This is the last of several
    same-named definitions in this file, so it is the one that takes
    effect at import time.)

    Args:
        in_channels: feature channel count C of the input sparse tensor.
        num_heads: unused in this variant; kept for interface compatibility.
    """
    def __init__(self, in_channels: int, num_heads: int = 4):
        super().__init__()
        self.in_channels = in_channels
        self.num_heads = num_heads
        # One learned embedding per child position.
        # NOTE(review): hard-coded to 8 entries, i.e. this assumes DIM == 3.
        self.pos_embed = nn.Embedding(8, in_channels)  # [8, C]

    def forward(self, input: SparseTensor) -> SparseTensor:
        DIM = input.coords.shape[-1] - 1
        device = input.device
        # Enumerate the 2^DIM corner offsets of the unit hypercube.
        n_cube = torch.ones([2] * DIM, device=input.device, dtype=torch.int)
        n_coords = torch.nonzero(n_cube)
        n_coords = torch.cat([torch.zeros_like(n_coords[:, :1]), n_coords], dim=-1)
        factor = n_coords.shape[0]
        assert factor == 2 ** DIM
        # Child coordinates: parent coordinates scaled by 2 plus each offset.
        new_coords = input.coords.clone()
        new_coords[:, 1:] *= 2
        new_coords = new_coords.unsqueeze(1) + n_coords.unsqueeze(0).to(new_coords.dtype)
        # Copy the parent feature to each child, then add the per-child embedding.
        base_feats = input.feats.unsqueeze(1).expand(-1, factor, -1)  # [N, 8, C]
        child_ids = torch.arange(8, device=device)  # [8]
        pos_feats = self.pos_embed(child_ids).unsqueeze(0)  # [1, 8, C]
        query_feats = base_feats + pos_feats  # [N, 8, C]
        final_feats = query_feats.flatten(0, 1)
        out = SparseTensor(
            feats=final_feats,
            coords=new_coords.flatten(0, 1),
            shape=input.shape
        )
        out._scale = input._scale * 2
        # Fix: propagate the spatial cache like every sibling subdivide
        # module, so a downstream paired SparseUpsample still finds the
        # caches registered by earlier SparseDownsample layers.
        out._spatial_cache = input._spatial_cache
        return out