import os
import sys
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function

# pc_util is a compiled CUDA extension packaged as an egg; make it importable.
# Note: modifying LD_LIBRARY_PATH here only affects child processes, since the
# dynamic loader reads the variable at process startup.
os.environ['LD_LIBRARY_PATH'] = '/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64:' + os.environ.get('LD_LIBRARY_PATH', '')
sys.path.append('pc_util-1.0-py3.10-linux-x86_64.egg')
import pc_util


class Conv2ds(nn.Sequential):
    """Stack of Conv2dBN blocks; cns lists the channel sizes, e.g. [3, 64, 128]."""

    def __init__(self, cns):
        super().__init__()
        for i in range(len(cns) - 1):
            in_cn, out_cn = cns[i], cns[i + 1]
            self.add_module('conv%d' % (i + 1), Conv2dBN(in_cn, out_cn))


class Conv2dBN(nn.Module):
    """1x1 convolution -> ReLU -> BatchNorm (note: BN is applied after the activation)."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.bn = nn.BatchNorm2d(out_channel)
        self.conv = nn.Conv2d(in_channel, out_channel, 1)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x), inplace=True))


class Conv1ds(nn.Sequential):
    """Stack of Conv1dBN blocks; cns lists the channel sizes, e.g. [3, 64, 128]."""

    def __init__(self, cns):
        super().__init__()
        for i in range(len(cns) - 1):
            in_cn, out_cn = cns[i], cns[i + 1]
            self.add_module('conv%d' % (i + 1), Conv1dBN(in_cn, out_cn))


class Conv1dBN(nn.Module):
    """1x1 convolution -> ReLU -> BatchNorm (note: BN is applied after the activation)."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.bn = nn.BatchNorm1d(out_channel)
        self.conv = nn.Conv1d(in_channel, out_channel, 1)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x), inplace=True))


class Linears(nn.Sequential):
    """Stack of LinearBN blocks; cns lists the feature sizes, e.g. [128, 64, 16]."""

    def __init__(self, cns):
        super().__init__()
        for i in range(len(cns) - 1):
            in_cn, out_cn = cns[i], cns[i + 1]
            self.add_module('linear%d' % (i + 1), LinearBN(in_cn, out_cn))


class LinearBN(nn.Module):
    """Linear -> ReLU -> BatchNorm; the layer is stored as `conv` for symmetry with the conv blocks."""

    def __init__(self, in_channel, out_channel):
        super().__init__()
        self.bn = nn.BatchNorm1d(out_channel)
        self.conv = nn.Linear(in_channel, out_channel)

    def forward(self, x):
        return self.bn(F.relu(self.conv(x), inplace=True))

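# A minimal usage sketch of the block builders above; the shapes and channel
# sizes here are illustrative assumptions, not values used elsewhere in this file.
def _demo_blocks():
    mlp1d = Conv1ds([3, 64, 128])          # conv1: 3 -> 64, conv2: 64 -> 128
    x = torch.randn(8, 3, 1024)            # (B, C, N) per-point features
    y = mlp1d(x)                           # -> (8, 128, 1024)

    enc2d = Conv2ds([3, 32])               # a single 1x1 conv block over (B, C, H, W)
    f = enc2d(torch.randn(8, 3, 16, 16))   # -> (8, 32, 16, 16)

    fc = Linears([128, 64, 16])            # linear1: 128 -> 64, linear2: 64 -> 16
    z = fc(y.mean(dim=2))                  # pooled (8, 128) -> (8, 16)
    return y, f, z
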

def load_params_with_optimizer(net, filename, to_cpu=False, optimizer=None, logger=None):
    if not os.path.isfile(filename):
        raise FileNotFoundError(filename)

    if logger is not None:
        logger.info('==> Loading parameters from checkpoint')
    checkpoint = torch.load(filename, map_location='cpu' if to_cpu else None)
    epoch = checkpoint.get('epoch', -1)
    it = checkpoint.get('it', 0.0)

    net.load_state_dict(checkpoint['model_state'])

    if optimizer is not None:
        if logger is not None:
            logger.info('==> Loading optimizer parameters from checkpoint')
        optimizer.load_state_dict(checkpoint['optimizer_state'])

    if logger is not None:
        logger.info('==> Done')

    return it, epoch



def load_params(net, filename, logger=None):
    if not os.path.isfile(filename):
        raise FileNotFoundError(filename)
    if logger is not None:
        logger.info('==> Loading parameters from checkpoint')
    checkpoint = torch.load(filename)

    net.load_state_dict(checkpoint['model_state'])
    if logger is not None:
        logger.info('==> Done')


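# Hedged usage sketch for the checkpoint helpers above; the checkpoint path
# and the logger setup are illustrative assumptions.
def _demo_load(net, optimizer):
    import logging
    logger = logging.getLogger(__name__)
    # The checkpoint is expected to be a dict with 'model_state' and,
    # optionally, 'optimizer_state', 'epoch' and 'it'.
    it, epoch = load_params_with_optimizer(
        net, 'checkpoint_epoch_80.pth', to_cpu=True,
        optimizer=optimizer, logger=logger)
    return it, epoch
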


class DBSCANCluster(Function):

    @staticmethod
    def forward(ctx, eps: float, min_pts: int, point: torch.Tensor) -> torch.Tensor:
        """
        :param ctx:
        :param eps: float, DBSCAN neighborhood radius
        :param min_pts: int, DBSCAN core-point threshold
        :param point: (B, N, 3) xyz coordinates of the points
        :return:
            idx: (B, N) cluster index per point; -1 marks unclustered (noise) points
        """
        point = point.contiguous()

        B, N, _ = point.size()
        # Initialize every point as unclustered (-1), on the same device as the input.
        idx = torch.full((B, N), -1, dtype=torch.int32, device=point.device)

        pc_util.dbscan_wrapper(B, N, eps, min_pts, point, idx)
        ctx.mark_non_differentiable(idx)
        return idx

    @staticmethod
    def backward(ctx, grad_out):
        # Clustering is non-differentiable: return one None per forward input.
        return None, None, None


dbscan_cluster = DBSCANCluster.apply


class GetClusterPts(Function):

    @staticmethod
    def forward(ctx, point: torch.Tensor, cluster_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param ctx:
        :param point: (B, N, 3) xyz coordinates of the points
        :param cluster_idx: (B, N) cluster index per point
        :return:
            key_pts: (B, M, 3) cluster center pts, where M is the number of cluster classes
            num_cluster: (B, M) number of pts in each cluster class
        """
        point = point.contiguous()
        cluster_idx = cluster_idx.contiguous()

        B, N = cluster_idx.size()
        M = int(cluster_idx.max().item()) + 1
        key_pts = torch.zeros((B, M, 3), dtype=torch.float32, device=point.device)
        num_cluster = torch.zeros((B, M), dtype=torch.int32, device=point.device)
        pc_util.cluster_pts_wrapper(B, N, M, point, cluster_idx, key_pts, num_cluster)
        # Coordinates left exactly at zero (presumably empty cluster slots) are
        # replaced with a -10 sentinel.
        key_pts[key_pts * 1e4 == 0] = -1e1
        ctx.mark_non_differentiable(key_pts)
        ctx.mark_non_differentiable(num_cluster)
        return key_pts, num_cluster

    @staticmethod
    def backward(ctx, grad_key_pts, grad_num_cluster):
        # Cluster aggregation is non-differentiable: return one None per forward input.
        return None, None


get_cluster_pts = GetClusterPts.apply
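

# Hedged end-to-end sketch of the two CUDA ops above; the eps/min_pts values
# are illustrative assumptions, and pc_util requires CUDA tensors.
if __name__ == '__main__':
    pts = torch.randn(2, 4096, 3).cuda()           # (B, N, 3) xyz coordinates
    idx = dbscan_cluster(0.3, 10, pts)             # (B, N) cluster index, -1 = noise
    key_pts, num_cluster = get_cluster_pts(pts, idx)
    print(key_pts.shape, num_cluster.shape)        # (B, M, 3), (B, M)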