S23DR-P2R / model /model_utils.py
colin1842's picture
Upload 28 files
5a60eac verified
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
# import pc_util
import sys
os.environ['LD_LIBRARY_PATH'] = '/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:' + os.environ.get('LD_LIBRARY_PATH', '')
sys.path.append('pc_util-1.0-py3.10-linux-x86_64.egg')
import pc_util
from torch.autograd import Function, Variable
class Conv2ds(nn.Sequential):
    """Sequential stack of ``Conv2dBN`` layers described by a channel list.

    Consecutive entries of ``cns`` are the input/output channel counts of one
    layer, so ``len(cns) - 1`` layers are created. They are registered under
    the names ``conv1``, ``conv2``, ... in order.
    """

    def __init__(self, cns):
        super().__init__()
        # Pair every channel count with its successor: (c0, c1), (c1, c2), ...
        channel_pairs = zip(cns, cns[1:])
        for layer_no, (n_in, n_out) in enumerate(channel_pairs, start=1):
            self.add_module('conv%d' % layer_no, Conv2dBN(n_in, n_out))
class Conv2dBN(nn.Module):
    """Kernel-size-1 2-D convolution, followed by ReLU, followed by batch norm.

    Note the order: normalisation is applied *after* the activation.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # ``bn`` is registered before ``conv``; keep this order so parameter
        # and state-dict ordering stay as they are.
        self.bn = nn.BatchNorm2d(num_features=out_channel)
        self.conv = nn.Conv2d(
            in_channels=in_channel, out_channels=out_channel, kernel_size=1
        )

    def forward(self, x):
        """Return ``bn(relu(conv(x)))``."""
        activated = F.relu(self.conv(x), inplace=True)
        return self.bn(activated)
class Conv1ds(nn.Sequential):
    """Sequential stack of ``Conv1dBN`` layers described by a channel list.

    Consecutive entries of ``cns`` are the input/output channel counts of one
    layer, so ``len(cns) - 1`` layers are created. They are registered under
    the names ``conv1``, ``conv2``, ... in order.
    """

    def __init__(self, cns):
        super().__init__()
        # Pair every channel count with its successor: (c0, c1), (c1, c2), ...
        channel_pairs = zip(cns, cns[1:])
        for layer_no, (n_in, n_out) in enumerate(channel_pairs, start=1):
            self.add_module('conv%d' % layer_no, Conv1dBN(n_in, n_out))
class Conv1dBN(nn.Module):
    """Kernel-size-1 1-D convolution, followed by ReLU, followed by batch norm.

    Note the order: normalisation is applied *after* the activation.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # ``bn`` is registered before ``conv``; keep this order so parameter
        # and state-dict ordering stay as they are.
        self.bn = nn.BatchNorm1d(num_features=out_channel)
        self.conv = nn.Conv1d(
            in_channels=in_channel, out_channels=out_channel, kernel_size=1
        )

    def forward(self, x):
        """Return ``bn(relu(conv(x)))``."""
        activated = F.relu(self.conv(x), inplace=True)
        return self.bn(activated)
class Linears(nn.Sequential):
    """Sequential stack of ``LinearBN`` layers described by a feature-size list.

    Consecutive entries of ``cns`` are the input/output sizes of one layer, so
    ``len(cns) - 1`` layers are created. They are registered under the names
    ``linear1``, ``linear2``, ... in order.
    """

    def __init__(self, cns):
        super().__init__()
        # Pair every size with its successor: (c0, c1), (c1, c2), ...
        size_pairs = zip(cns, cns[1:])
        for layer_no, (n_in, n_out) in enumerate(size_pairs, start=1):
            self.add_module('linear%d' % layer_no, LinearBN(n_in, n_out))
class LinearBN(nn.Module):
    """Fully connected layer, followed by ReLU, followed by 1-D batch norm.

    Note the order: normalisation is applied *after* the activation.
    """

    def __init__(self, in_channel, out_channel):
        super().__init__()
        # ``bn`` is registered before the linear layer; keep this order so
        # parameter and state-dict ordering stay as they are.
        self.bn = nn.BatchNorm1d(num_features=out_channel)
        # The attribute is called ``conv`` although it holds an ``nn.Linear``;
        # the name is part of the state-dict keys and must not change.
        self.conv = nn.Linear(in_features=in_channel, out_features=out_channel)

    def forward(self, x):
        """Return ``bn(relu(linear(x)))``."""
        activated = F.relu(self.conv(x), inplace=True)
        return self.bn(activated)
def load_params_with_optimizer(net, filename, to_cpu=False, optimizer=None, logger=None):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is expected to be a mapping with a ``'model_state'`` entry,
    an ``'optimizer_state'`` entry when ``optimizer`` is given, and optional
    ``'epoch'`` / ``'it'`` counters.

    :param net: module whose ``load_state_dict`` receives ``'model_state'``
    :param filename: path of the checkpoint file
    :param to_cpu: if True, map all stored tensors to the CPU while loading
    :param optimizer: optional optimizer to restore from ``'optimizer_state'``
    :param logger: optional logger; progress messages are skipped when None
    :return:
        (it, epoch): stored counters, defaulting to ``0.0`` and ``-1``
    :raises FileNotFoundError: if ``filename`` is not an existing file
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError('checkpoint file not found: %s' % filename)

    def _log(message):
        # ``logger`` defaults to None, so every log call has to be guarded.
        if logger is not None:
            logger.info(message)

    _log('==> Loading parameters from checkpoint')
    # map_location=None keeps torch.load's default device placement.
    map_location = torch.device('cpu') if to_cpu else None
    checkpoint = torch.load(filename, map_location=map_location)
    epoch = checkpoint.get('epoch', -1)
    it = checkpoint.get('it', 0.0)
    net.load_state_dict(checkpoint['model_state'])

    if optimizer is not None:
        _log('==> Loading optimizer parameters from checkpoint')
        optimizer.load_state_dict(checkpoint['optimizer_state'])
    _log('==> Done')

    return it, epoch
def load_params(net, filename, logger=None, to_cpu=False):
    """Restore model weights from the ``'model_state'`` entry of a checkpoint.

    :param net: module whose ``load_state_dict`` receives ``'model_state'``
    :param filename: path of the checkpoint file
    :param logger: optional logger; progress messages are skipped when None
    :param to_cpu: if True, map all stored tensors to the CPU while loading
        (lets a checkpoint saved from a GPU be read on a CPU-only host)
    :raises FileNotFoundError: if ``filename`` is not an existing file
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError('checkpoint file not found: %s' % filename)

    if logger is not None:
        logger.info('==> Loading parameters from checkpoint')

    # map_location=None keeps torch.load's default device placement.
    map_location = torch.device('cpu') if to_cpu else None
    checkpoint = torch.load(filename, map_location=map_location)
    net.load_state_dict(checkpoint['model_state'])

    if logger is not None:
        logger.info('==> Done')
class DBSCANCluster(Function):
    """Autograd ``Function`` wrapping the ``pc_util.dbscan_wrapper`` extension call.

    The output is an integer label tensor and is marked non-differentiable.
    """

    @staticmethod
    def forward(ctx, eps: float, min_pts: int, point: torch.Tensor) -> torch.Tensor:
        """
        :param ctx: autograd context
        :param eps: float, dbscan eps
        :param min_pts: int, dbscan core point threshold
        :param point: (B, N, 3) xyz coordinates of the points
        :return:
            idx: (B, N) cluster idx
        """
        point = point.contiguous()
        B, N, _ = point.size()
        # The label buffer is pre-filled with -1 and then handed to the
        # extension, which writes into it in place. It is allocated on the
        # same device as ``point`` instead of whatever CUDA device is current.
        # NOTE(review): -1 presumably marks unassigned/noise points, and the
        # extension presumably requires CUDA tensors -- confirm in pc_util.
        idx = torch.full((B, N), -1, dtype=torch.int32, device=point.device)
        pc_util.dbscan_wrapper(B, N, eps, min_pts, point, idx)
        ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, grad_out):
        # One entry per forward input (eps, min_pts, point); none of them
        # receives a gradient.
        return None, None, None


dbscan_cluster = DBSCANCluster.apply
class GetClusterPts(Function):
    """Autograd ``Function`` wrapping the ``pc_util.cluster_pts_wrapper`` extension call.

    Both outputs are marked non-differentiable.
    """

    @staticmethod
    def forward(ctx, point: torch.Tensor, cluster_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        :param ctx: autograd context
        :param point: (B, N, 3) xyz coordinates of the points
        :param cluster_idx: (B, N) cluster idx
        :return:
            key_pts: (B, M, 3) cluster center pts, M is max_num_cluster_class
            num_cluster: (B, M) cluster num, num of pts in each cluster class
        """
        # Both inputs are handed to the extension, so both are made contiguous
        # (the original only did this for ``cluster_idx``).
        point = point.contiguous()
        cluster_idx = cluster_idx.contiguous()
        B, N = cluster_idx.size()
        # Largest label over the whole batch, plus one. Converted to a plain
        # Python int so that it is a valid size and a valid scalar argument.
        M = int(cluster_idx.max().item()) + 1
        # Output buffers live on the same device as ``point`` instead of
        # whatever CUDA device happens to be current.
        key_pts = torch.zeros((B, M, 3), dtype=torch.float32, device=point.device)
        num_cluster = torch.zeros((B, M), dtype=torch.int32, device=point.device)
        pc_util.cluster_pts_wrapper(B, N, M, point, cluster_idx, key_pts, num_cluster)
        # Every coordinate that is still exactly zero is replaced by -10.
        # NOTE(review): presumably a sentinel for cluster slots a batch item
        # does not use; it also rewrites genuine zero coordinates -- confirm.
        key_pts[key_pts * 1e4 == 0] = -1e1
        ctx.mark_non_differentiable(key_pts)
        ctx.mark_non_differentiable(num_cluster)
        return key_pts, num_cluster

    @staticmethod
    def backward(ctx, *grad_outputs):
        # forward returns two outputs, so autograd would pass two gradients;
        # one entry is returned per forward input (point, cluster_idx).
        return None, None


get_cluster_pts = GetClusterPts.apply