# Hugging Face upload metadata (uploader: udbbdh, via huggingface_hub, rev 7340df2 verified)
# MIT License
# Copyright (c) 2020 Songyou Peng, Michael Niemeyer, Lars Mescheder, Marc Pollefeys, Andreas Geiger.
# Copyright (c) 2025 VAST-AI-Research and contributors.
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE
# modified from https://github.com/autonomousvision/convolutional_occupancy_networks/blob/master/src/encoder/pointnet.py
import torch
import torch.nn as nn
import copy
from torch import Tensor
from torch_scatter import scatter_mean
def scale_tensor(
    dat, inp_scale=None, tgt_scale=None
):
    """Affinely remap ``dat`` from the ``inp_scale`` interval to ``tgt_scale``.

    Defaults map [-0.5, 0.5] -> [0, 1]. The result is clamped a small
    epsilon inside the target bounds, so downstream floor/indexing never
    lands exactly on the boundary.
    """
    if inp_scale is None:
        inp_scale = (-0.5, 0.5)
    if tgt_scale is None:
        tgt_scale = (0, 1)
    assert tgt_scale[1] > tgt_scale[0] and inp_scale[1] > inp_scale[0]
    if isinstance(tgt_scale, Tensor):
        # Per-channel target ranges must line up with the last data axis.
        assert dat.shape[-1] == tgt_scale.shape[-1]
    src_lo, src_hi = inp_scale[0], inp_scale[1]
    dst_lo, dst_hi = tgt_scale[0], tgt_scale[1]
    normalized = (dat - src_lo) / (src_hi - src_lo)
    remapped = normalized * (dst_hi - dst_lo) + dst_lo
    return remapped.clamp(dst_lo + 1e-6, dst_hi - 1e-6)
# Resnet Blocks for pointnet
class ResnetBlockFC(nn.Module):
    """Fully connected ResNet block: ``x_s + fc_1(act(fc_0(act(x))))``.

    Args:
        size_in (int): input dimension
        size_out (int): output dimension (defaults to ``size_in``)
        size_h (int): hidden dimension (defaults to ``min(size_in, size_out)``)
    """

    def __init__(self, size_in, size_out=None, size_h=None):
        super().__init__()
        size_out = size_in if size_out is None else size_out
        size_h = min(size_in, size_out) if size_h is None else size_h
        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out
        # Submodules (creation order kept stable for reproducible RNG use).
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)
        self.actvn = nn.GELU(approximate="tanh")
        # A learned projection is only needed when dimensions differ;
        # otherwise the identity skip is used (shortcut stays None).
        self.shortcut = None if size_in == size_out else nn.Linear(size_in, size_out, bias=False)
        # Initialization, in the same order as the original implementation.
        self._init_linear(self.fc_0)
        if self.shortcut is not None:
            self._init_linear(self.shortcut)
        self._init_linear(self.fc_1)

    @staticmethod
    def _init_linear(layer):
        # Xavier-uniform weight, zero bias (when the layer has one).
        nn.init.xavier_uniform_(layer.weight)
        if layer.bias is not None:
            nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        hidden = self.fc_0(self.actvn(x))
        residual = self.fc_1(self.actvn(hidden))
        skip = x if self.shortcut is None else self.shortcut(x)
        return skip + residual
class LocalPoolPointnet(nn.Module):
    """PointNet encoder with local mean pooling over a sparse voxel grid.

    Point features are scattered (mean) into their occupied voxel, the
    pooled voxel features are gathered back to the points between ResNet
    blocks, and the final per-voxel features are returned packed in the
    row order of ``sparse_coords``.
    """

    def __init__(self, in_channels=3, out_channels=128, hidden_dim=128, scatter_type='mean', n_blocks=5):
        super().__init__()
        self.scatter_type = scatter_type
        self.in_channels = in_channels
        self.hidden_dim = hidden_dim
        self.out_channels = out_channels
        self.fc_pos = nn.Linear(in_channels, 2 * hidden_dim)
        # Each block consumes [point feats | pooled voxel feats] -> hidden_dim.
        self.blocks = nn.ModuleList([
            ResnetBlockFC(2 * hidden_dim, hidden_dim) for _ in range(n_blocks)
        ])
        self.fc_c = nn.Linear(hidden_dim, out_channels)
        if self.scatter_type == 'mean':
            self.scatter = scatter_mean
        else:
            raise ValueError('Incorrect scatter type')
        self.initialize_weights()

    def initialize_weights(self):
        # Xavier-uniform weights and zero biases for the two plain linears;
        # the ResNet blocks initialize themselves.
        nn.init.xavier_uniform_(self.fc_pos.weight)
        if self.fc_pos.bias is not None:
            nn.init.constant_(self.fc_pos.bias, 0)
        nn.init.xavier_uniform_(self.fc_c.weight)
        if self.fc_c.bias is not None:
            nn.init.constant_(self.fc_c.bias, 0)

    def convert_to_sparse_feats(self, c, sparse_coords):
        '''
        Pack dense per-batch grid features into the sparse (packed) layout.
        Assumes batch i's occupied voxels fill the first slots of c[i], in
        the same order as its rows of sparse_coords (the order produced by
        coordinate2sparseindex).
        Input:
            sparse_coords: Tensor [Nx, 4], point to sparse indices
            c: Tensor [B, res, C], input feats of each grid
        Output:
            c_out: Tensor [Nx, C], packed grid feats
        '''
        feats_new = torch.zeros((sparse_coords.shape[0], c.shape[-1]), device=c.device, dtype=c.dtype)
        offsets = 0
        # Read-only view of the batch column; no copy needed.
        batch_nums = sparse_coords[..., 0]
        for i in range(len(c)):
            coords_num_i = (batch_nums == i).sum()
            feats_new[offsets: offsets + coords_num_i] = c[i, : coords_num_i]
            offsets += coords_num_i
        return feats_new

    def generate_sparse_grid_features(self, index, c, max_coord_num):
        # Scatter (mean) point features into their voxel slot.
        bs = c.size(0)
        res = max_coord_num
        c_out = c.new_zeros(bs, self.out_channels, res)
        c_out = scatter_mean(c.permute(0, 2, 1), index, out=c_out).permute(0, 2, 1)  # B x res x C
        return c_out

    def pool_sparse_local(self, index, c, max_coord_num):
        '''
        Mean-pool point features per voxel, then broadcast the pooled
        feature back to every point of that voxel.
        Input:
            index: Tensor [B, 1, Np], sparse indices of each point
            c: Tensor [B, Np, C], input feats of each point
        Output:
            c_out: Tensor [B, Np, C], aggregated grid feats of each point
        '''
        bs, fea_dim = c.size(0), c.size(2)
        res = max_coord_num
        c_out = c.new_zeros(bs, fea_dim, res)
        c_out = self.scatter(c.permute(0, 2, 1), index, out=c_out)
        # Gather each voxel's pooled feature back to its points.
        c_out = c_out.gather(dim=2, index=index.expand(-1, fea_dim, -1))
        return c_out.permute(0, 2, 1)

    @torch.no_grad()
    def coordinate2sparseindex(self, x, sparse_coords, res):
        '''
        Input:
            x: Tensor [B, Np, 3], integer voxel coords scaled at ([0, 1] * res)
            sparse_coords: Tensor [Nx, 4] ([batch_number, x, y, z]); within
                each batch, rows must be sorted by flattened voxel index
                (searchsorted below relies on this)
            res: Int, resolution of the grid index
        Output:
            sparse_index: Tensor [B, 1, Np], sparse indices of each point
        '''
        B = x.shape[0]
        sparse_index = torch.zeros((B, x.shape[1]), device=x.device, dtype=torch.int64)
        # Flattened voxel id of each point: ((x * res) + y) * res + z.
        index = (x[..., 0] * res + x[..., 1]) * res + x[..., 2]
        # clone(): the flattened id is written in place over column 1.
        sparse_indices = sparse_coords.clone()
        sparse_indices[..., 1] = (sparse_indices[..., 1] * res + sparse_indices[..., 2]) * res + sparse_indices[..., 3]
        sparse_indices = sparse_indices[..., :2]
        for i in range(B):
            mask_i = sparse_indices[..., 0] == i
            coords_i = sparse_indices[mask_i, 1]
            # Rank of each point's voxel id within this batch's sorted
            # occupied-voxel list == its packed sparse index.
            sparse_index[i] = torch.searchsorted(coords_i, index[i])
        return sparse_index[:, None, :]

    def forward(self, p, sparse_coords, res=64, bbox_size=(-0.5, 0.5)):
        '''
        Input:
            p : Tensor [B, Np, D], points inside bbox_size; when
                D == in_channels the channels beyond xyz (e.g. normals)
                ride along unvoxelized
            sparse_coords: Tensor [Nx, 4] ([batch_number, x, y, z])
            res: Int, voxel grid resolution
            bbox_size: (lo, hi) extent of the point bounding box
        Output:
            sparse_pc_feats: [Nx, self.out_channels]
        '''
        batch_size, T, D = p.size()
        # Upper bound on occupied voxels per batch; +5 is a small safety
        # margin for the scatter out-buffer -- presumably headroom, TODO confirm.
        max_coord_num = 0
        for i in range(batch_size):
            max_coord_num = max(max_coord_num, (sparse_coords[..., 0] == i).sum().item() + 5)
        if D == self.in_channels:
            # Only xyz is voxelized; keep the extra channels aside.
            p, extra = p[..., :3], p[..., 3:]
        coord = (scale_tensor(p, inp_scale=bbox_size) * res)
        # Signed offset of each point from its voxel centroid, in [-1, 1].
        p = 2 * (coord - (coord.floor() + 0.5))
        index = self.coordinate2sparseindex(coord.long(), sparse_coords, res)
        if D == self.in_channels:
            p = torch.cat((p, extra), dim=-1)
        net = self.fc_pos(p)
        net = self.blocks[0](net)
        for block in self.blocks[1:]:
            pooled = self.pool_sparse_local(index, net, max_coord_num=max_coord_num)
            net = torch.cat([net, pooled], dim=2)
            net = block(net)
        c = self.fc_c(net)
        feats = self.generate_sparse_grid_features(index, c, max_coord_num=max_coord_num)
        feats = self.convert_to_sparse_feats(feats, sparse_coords)
        return feats