File size: 4,748 Bytes
8d5039c 5a60eac 8d5039c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
# import pc_util
import sys
os.environ['LD_LIBRARY_PATH'] = '/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:' + os.environ.get('LD_LIBRARY_PATH', '')
sys.path.append('pc_util-1.0-py3.10-linux-x86_64.egg')
import pc_util
from torch.autograd import Function, Variable
class Conv2ds(nn.Sequential):
    """Sequential stack of ``Conv2dBN`` layers built from a list of channel sizes.

    For a channel list ``cns`` of length ``n`` this registers ``n - 1`` layers
    named ``conv1`` ... ``conv<n-1>``; layer ``k`` maps ``cns[k - 1]`` input
    channels to ``cns[k]`` output channels.
    """

    def __init__(self, cns):
        """
        :param cns: sequence of channel counts, consecutive entries form one layer
        """
        super().__init__()
        # Walk over consecutive (input, output) channel pairs, numbering from 1.
        channel_pairs = zip(cns[:-1], cns[1:])
        for layer_no, (c_in, c_out) in enumerate(channel_pairs, start=1):
            self.add_module('conv%d' % layer_no, Conv2dBN(c_in, c_out))
class Conv2dBN(nn.Module):
    """1x1 2-D convolution, then ReLU, then 2-D batch normalisation."""

    def __init__(self, in_channel, out_channel):
        """
        :param in_channel: number of input channels of the convolution
        :param out_channel: number of output channels (also the BN feature count)
        """
        super().__init__()
        # NOTE: ``bn`` is registered before ``conv``; the order is kept so that
        # parameter / state_dict ordering stays the same.
        self.bn = nn.BatchNorm2d(num_features=out_channel)
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=1)

    def forward(self, x):
        """Apply conv -> ReLU (in place on the conv output) -> batch norm."""
        activated = F.relu(self.conv(x), inplace=True)
        return self.bn(activated)
class Conv1ds(nn.Sequential):
    """Sequential stack of ``Conv1dBN`` layers built from a list of channel sizes.

    For a channel list ``cns`` of length ``n`` this registers ``n - 1`` layers
    named ``conv1`` ... ``conv<n-1>``; layer ``k`` maps ``cns[k - 1]`` input
    channels to ``cns[k]`` output channels.
    """

    def __init__(self, cns):
        """
        :param cns: sequence of channel counts, consecutive entries form one layer
        """
        super().__init__()
        # Walk over consecutive (input, output) channel pairs, numbering from 1.
        channel_pairs = zip(cns[:-1], cns[1:])
        for layer_no, (c_in, c_out) in enumerate(channel_pairs, start=1):
            self.add_module('conv%d' % layer_no, Conv1dBN(c_in, c_out))
class Conv1dBN(nn.Module):
    """Kernel-size-1 1-D convolution, then ReLU, then 1-D batch normalisation."""

    def __init__(self, in_channel, out_channel):
        """
        :param in_channel: number of input channels of the convolution
        :param out_channel: number of output channels (also the BN feature count)
        """
        super().__init__()
        # NOTE: ``bn`` is registered before ``conv``; the order is kept so that
        # parameter / state_dict ordering stays the same.
        self.bn = nn.BatchNorm1d(num_features=out_channel)
        self.conv = nn.Conv1d(in_channel, out_channel, kernel_size=1)

    def forward(self, x):
        """Apply conv -> ReLU (in place on the conv output) -> batch norm."""
        activated = F.relu(self.conv(x), inplace=True)
        return self.bn(activated)
class Linears(nn.Sequential):
    """Sequential stack of ``LinearBN`` layers built from a list of feature sizes.

    For a size list ``cns`` of length ``n`` this registers ``n - 1`` layers
    named ``linear1`` ... ``linear<n-1>``; layer ``k`` maps ``cns[k - 1]``
    input features to ``cns[k]`` output features.
    """

    def __init__(self, cns):
        """
        :param cns: sequence of feature counts, consecutive entries form one layer
        """
        super().__init__()
        # Walk over consecutive (input, output) size pairs, numbering from 1.
        size_pairs = zip(cns[:-1], cns[1:])
        for layer_no, (n_in, n_out) in enumerate(size_pairs, start=1):
            self.add_module('linear%d' % layer_no, LinearBN(n_in, n_out))
class LinearBN(nn.Module):
    """Fully connected layer, then ReLU, then 1-D batch normalisation."""

    def __init__(self, in_channel, out_channel):
        """
        :param in_channel: number of input features of the linear layer
        :param out_channel: number of output features (also the BN feature count)
        """
        super().__init__()
        # NOTE: ``bn`` is registered before the linear layer, and the linear
        # layer keeps the attribute name ``conv``; both are preserved so that
        # state_dict keys and ordering stay the same.
        self.bn = nn.BatchNorm1d(num_features=out_channel)
        self.conv = nn.Linear(in_features=in_channel, out_features=out_channel)

    def forward(self, x):
        """Apply linear -> ReLU (in place on the linear output) -> batch norm."""
        activated = F.relu(self.conv(x), inplace=True)
        return self.bn(activated)
def load_params_with_optimizer(net, filename, to_cpu=False, optimizer=None, logger=None):
    """Restore model (and optionally optimizer) state from a checkpoint file.

    The checkpoint is expected to be a dict with a ``'model_state'`` entry,
    an ``'optimizer_state'`` entry when ``optimizer`` is given, and optional
    ``'epoch'`` / ``'it'`` entries.

    :param net: module whose ``load_state_dict`` receives ``checkpoint['model_state']``
    :param filename: path of the checkpoint file
    :param to_cpu: if True, map all stored tensors to the CPU while loading
    :param optimizer: optional optimizer restored from ``checkpoint['optimizer_state']``
    :param logger: optional logger; progress messages are skipped when None
    :return: ``(it, epoch)`` read from the checkpoint (defaults ``0.0`` and ``-1``)
    :raises FileNotFoundError: if ``filename`` is not an existing file
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError('Checkpoint file not found: %s' % filename)

    def _log(message):
        # ``logger`` defaults to None, so logging has to be optional.
        if logger is not None:
            logger.info(message)

    _log('==> Loading parameters from checkpoint')
    # Honour ``to_cpu``: without a map_location, tensors are restored onto the
    # device they were saved from, which fails on machines without that device.
    map_location = torch.device('cpu') if to_cpu else None
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(filename, map_location=map_location)
    epoch = checkpoint.get('epoch', -1)
    it = checkpoint.get('it', 0.0)
    net.load_state_dict(checkpoint['model_state'])

    if optimizer is not None:
        _log('==> Loading optimizer parameters from checkpoint')
        optimizer.load_state_dict(checkpoint['optimizer_state'])
    _log('==> Done')

    return it, epoch
def load_params(net, filename, logger=None):
    """Load model weights from the ``'model_state'`` entry of a checkpoint file.

    :param net: module whose ``load_state_dict`` receives the stored weights
    :param filename: path of the checkpoint file
    :param logger: optional logger; progress messages are skipped when None
    :raises FileNotFoundError: if ``filename`` is not an existing file
    """
    if not os.path.isfile(filename):
        raise FileNotFoundError

    verbose = logger is not None
    if verbose:
        logger.info('==> Loading parameters from checkpoint')

    model_state = torch.load(filename)['model_state']
    net.load_state_dict(model_state)

    if verbose:
        logger.info('==> Done')
class DBSCANCluster(Function):
    """Autograd wrapper around the ``pc_util.dbscan_wrapper`` extension call.

    The output is an integer index tensor and is marked non-differentiable.
    """

    @staticmethod
    def forward(ctx, eps: float, min_pts: int, point: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param eps: float, dbscan eps
        :param min_pts: int, dbscan core point threshold
        :param point: (B, N, 3) xyz coordinates of the points
        :return:
            idx: (B, N) cluster idx, int32, initialised to -1 before the
                extension fills it in
        """
        point = point.contiguous()
        B, N, _ = point.size()
        # Allocate the output next to the input instead of on the "current"
        # CUDA device (what the legacy torch.cuda.IntTensor constructor did),
        # so both buffers handed to the extension live on the same device.
        # NOTE(review): the extension presumably requires CUDA tensors — confirm.
        idx = torch.full((B, N), -1, dtype=torch.int32, device=point.device)
        pc_util.dbscan_wrapper(B, N, eps, min_pts, point, idx)
        ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, grad_out):
        """No gradients: return one ``None`` per forward input (eps, min_pts, point)."""
        return None, None, None


dbscan_cluster = DBSCANCluster.apply
class GetClusterPts(Function):
    """Autograd wrapper around the ``pc_util.cluster_pts_wrapper`` extension call.

    Both outputs are marked non-differentiable.
    """

    @staticmethod
    def forward(ctx, point: torch.Tensor, cluster_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        :param ctx:
        :param point: (B, N, 3) xyz coordinates of the points
        :param cluster_idx: (B, N) cluster idx
        :return:
            key_pts: (B, M, 3) cluster center pts, M is max_num_cluster_class
                (largest value in ``cluster_idx`` plus one); entries that are
                exactly zero after the extension call are replaced by -10
            num_cluster: (B, M) cluster num, num of pts in each cluster class
        """
        # The extension receives raw buffers, so make both inputs contiguous
        # (DBSCANCluster does the same for ``point``).
        point = point.contiguous()
        cluster_idx = cluster_idx.contiguous()
        B, N = cluster_idx.size()
        # Convert to a plain Python int: it is used as a tensor dimension and
        # passed to the extension as an integer argument.
        M = int(cluster_idx.max().item()) + 1
        # Allocate outputs on the input's device rather than the "current"
        # CUDA device used by the legacy torch.cuda.*Tensor constructors.
        # NOTE(review): the extension presumably requires CUDA tensors — confirm.
        key_pts = torch.zeros((B, M, 3), dtype=torch.float32, device=point.device)
        num_cluster = torch.zeros((B, M), dtype=torch.int32, device=point.device)
        pc_util.cluster_pts_wrapper(B, N, M, point, cluster_idx, key_pts, num_cluster)
        # Overwrite every coordinate that is exactly zero with -10.
        # NOTE(review): presumably a sentinel for cluster slots the extension
        # left untouched; it also hits genuine coordinates equal to 0 — confirm.
        key_pts[key_pts * 1e4 == 0] = -1e1
        ctx.mark_non_differentiable(key_pts)
        ctx.mark_non_differentiable(num_cluster)
        return key_pts, num_cluster

    @staticmethod
    def backward(ctx, *grad_outputs):
        """No gradients: return one ``None`` per forward input (point, cluster_idx)."""
        return None, None


get_cluster_pts = GetClusterPts.apply