Rignet / RigNet /utils /cluster_utils.py
ckc99u's picture
Upload 116 files
3b8bcb1 verified
#-------------------------------------------------------------------------------
# Name: cluster_utils.py
# Purpose: utility functions for clustering
# RigNet Copyright 2020 University of Massachusetts
# RigNet is made available under General Public License Version 3 (GPLv3), or under a Commercial License.
# Please see the LICENSE README.txt file in the main directory for more information and instruction on using and licensing RigNet.
#-------------------------------------------------------------------------------
import sys
sys.path.append("./")
import numpy as np
def meanshift_cluster(pts_in, bandwidth, weights=None, max_iter=20):
    """
    Meanshift clustering with a truncated quadratic kernel and damped updates.

    On every iteration each point is moved 30% of the way towards the
    kernel-weighted mean of the points lying within ``bandwidth`` of it. The
    loop stops once the total displacement of all points in one step drops to
    1e-3 or below, or after ``max_iter - 1`` update steps (the iteration
    counter starts at 1).

    :param pts_in: input points, array of shape (N, D)
    :param bandwidth: kernel radius; pairs of points farther apart than this get zero kernel value
    :param weights: optional weight per point indicating its importance in the clustering.
                    An array holding exactly one value per point (shape (N,), (N, 1) or (1, N))
                    is applied per point; any other shape is broadcast against the
                    (N, N) kernel matrix unchanged.
    :param max_iter: iteration cap; at most ``max_iter - 1`` update steps are performed
    :return: points after clustering, array of shape (N, D)
    """
    pts = np.asarray(pts_in)

    if weights is not None:
        weights = np.asarray(weights)
        # The kernel matrix is normalised column-wise below, so a per-point
        # weight must scale rows, i.e. be a column vector. A flat (N,) array
        # would scale columns instead and cancel out in the normalisation,
        # silently disabling the weighting.
        if weights.size == pts.shape[0]:
            weights = weights.reshape(-1, 1)

    diff = 1e10
    num_iter = 1
    while diff > 1e-3 and num_iter < max_iter:
        # Pairwise squared Euclidean distances, shape (N, N).
        sq_dist = np.sum((pts[np.newaxis, ...] - pts[:, np.newaxis, :]) ** 2, axis=2)
        # Truncated quadratic kernel: zero outside the bandwidth.
        K = np.maximum(bandwidth ** 2 - sq_dist, 0.0)
        if weights is not None:
            K = K * weights
        col_sums = K.sum(axis=0, keepdims=True)
        # After the transpose, row j of P holds the normalised influence of
        # every point on point j; the epsilon guards against an all-zero column.
        P = (K / (col_sums + 1e-10)).transpose()
        # Damped step towards the weighted mean of the neighbourhood.
        shifted = 0.3 * (np.matmul(P, pts) - pts) + pts
        diff = np.sqrt(np.sum((shifted - pts) ** 2))
        pts = shifted
        num_iter += 1
    return pts
def nms_meanshift(pts_in, density, bandwidth):
    """
    NMS to extract modes after meanshift. Code refers to scikit-learn.

    Points are visited from highest to lowest density. Each point that is still
    kept suppresses every point within ``bandwidth`` of it (Euclidean distance)
    and is itself retained.

    :param pts_in: input points, array of shape (N, D)
    :param density: density at each point, one value per point
    :param bandwidth: bandwidth used in meanshift. Used here as neighbor region for NMS
    :return: extracted clusters: the rows of pts_in that survive suppression, in their original order
    """
    pts = np.asarray(pts_in)
    # Pairwise Euclidean distances, shape (N, N).
    sq_dist = np.sum((pts[np.newaxis, ...] - pts[:, np.newaxis, :]) ** 2, axis=2)
    dist = np.sqrt(sq_dist)

    # Flatten so that an (N, 1) density array is ordered the same way as (N,).
    order = np.argsort(np.asarray(density).ravel())[::-1]
    # Builtin bool: the np.bool alias was removed in NumPy 1.24.
    keep = np.ones(len(order), dtype=bool)
    for i in order:
        if keep[i]:
            # Suppress everything in the neighbourhood of the current mode...
            keep[dist[:, i] <= bandwidth] = False
            # ...but leave the current point itself as unique.
            keep[i] = True
    return pts[keep]