|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
from functools import partial
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from tqdm import tqdm
|
|
|
import math, random
|
|
|
|
|
|
|
|
|
|
|
|
def tensor_kmeans_sklearn(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    """Cluster the per-pixel feature vectors of one image tensor with sklearn KMeans.

    Args:
        data_vecs: (1, C, H, W) feature tensor; only batch size 1 is supported.
        n_clusters: number of clusters to form.
        metric: ignored by this backend (sklearn KMeans is euclidean); kept for
            API symmetry with tensor_kmeans_pytorch.
        need_layer_masks: when True return one-hot masks (1, n_clusters, H, W);
            otherwise an integer label map (1, 1, H, W).
        max_iters: ignored by this backend; sklearn's max_iter=300 is used.

    Returns:
        Long label map (1, 1, H, W), or float one-hot masks (1, n_clusters, H, W).
    """
    N, C, H, W = data_vecs.shape
    assert N == 1, 'only support singe image tensor'

    # Flatten the spatial grid: (1, C, H, W) -> (H*W, C), one row per pixel.
    data_vecs = data_vecs.permute(0, 2, 3, 1).view(-1, C)
    feats_np = data_vecs.squeeze().detach().to("cpu").numpy()

    # NOTE(review): KMeans is expected to be sklearn.cluster.KMeans, imported
    # elsewhere in this module — confirm.
    km = KMeans(n_clusters=n_clusters, init='k-means++', n_init=10, max_iter=300)
    km.fit_predict(feats_np)

    labels = torch.from_numpy(km.labels_).to(data_vecs.device)
    id_maps = labels.reshape(1, 1, H, W).long()
    if not need_layer_masks:
        return id_maps
    one_hot = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
    return one_hot.permute(0, 3, 1, 2)
|
|
|
|
|
|
|
|
|
def tensor_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
    """Cluster the per-pixel feature vectors of one image tensor with the
    torch kmeans implementation in this module.

    Args:
        data_vecs: (1, C, H, W) feature tensor; only batch size 1 is supported.
        n_clusters: number of clusters to form.
        metric: distance passed through to kmeans ('euclidean' or 'cosine').
        need_layer_masks: when True return one-hot masks (1, n_clusters, H, W);
            otherwise an integer label map (1, 1, H, W).
        max_iters: iteration cap passed through to kmeans.

    Returns:
        Label map (1, 1, H, W), or float one-hot masks (1, n_clusters, H, W).
    """
    N, C, H, W = data_vecs.shape
    assert N == 1, 'only support singe image tensor'

    # Flatten the spatial grid: (1, C, H, W) -> (H*W, C), one row per pixel.
    data_vecs = data_vecs.permute(0, 2, 3, 1).view(-1, C)

    cluster_ids_x, _ = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric,
                              tqdm_flag=False, iter_limit=max_iters, device=data_vecs.device)
    id_maps = cluster_ids_x.reshape(1, 1, H, W)
    if not need_layer_masks:
        return id_maps
    one_hot = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
    return one_hot.permute(0, 3, 1, 2)
|
|
|
|
|
|
|
|
|
def batch_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', use_sklearn_kmeans=False):
    """Run per-image kmeans over a batch and stack the one-hot cluster masks.

    Args:
        data_vecs: (N, C, H, W) feature tensor.
        n_clusters: number of clusters per image.
        metric: distance metric forwarded to the single-image routine.
        use_sklearn_kmeans: pick the sklearn backend instead of the torch one.

    Returns:
        (N, n_clusters, H, W) float tensor of one-hot cluster masks.
    """
    N, C, H, W = data_vecs.shape
    single_image_fn = tensor_kmeans_sklearn if use_sklearn_kmeans else tensor_kmeans_pytorch
    masks = [single_image_fn(data_vecs[idx:idx + 1, :, :, :], n_clusters, metric, True)
             for idx in range(N)]
    return torch.cat(masks, dim=0)
|
|
|
|
|
|
|
|
|
def get_centroid_candidates(data_vecs, n_clusters=7, metric='euclidean', max_iters=20):
    """Return the kmeans centroids of the per-pixel features of an image tensor.

    Args:
        data_vecs: (N, C, H, W) feature tensor (all pixels are pooled together).
        n_clusters: number of centroids to compute.
        metric: distance passed through to kmeans ('euclidean' or 'cosine').
        max_iters: iteration cap passed through to kmeans.

    Returns:
        (n_clusters, C) tensor of cluster centers.
    """
    N, C, H, W = data_vecs.shape
    # (N, C, H, W) -> (N*H*W, C): one feature row per pixel.
    flat_vecs = data_vecs.permute(0, 2, 3, 1).view(-1, C)
    _, cluster_centers = kmeans(X=flat_vecs, num_clusters=n_clusters, distance=metric,
                                tqdm_flag=False, iter_limit=max_iters, device=flat_vecs.device)
    return cluster_centers
|
|
|
|
|
|
|
|
|
def find_distinctive_elements(data_tensor, n_clusters=7, topk=3, metric='euclidean'):
    """For each image and each cluster, mark the topk pixels whose features are
    closest (squared euclidean) to that cluster's centroid.

    Args:
        data_tensor: (N, C, H, W) feature tensor.
        n_clusters: number of kmeans clusters per image.
        topk: how many nearest pixels to mark per cluster.
        metric: distance used by the kmeans centroid step.

    Returns:
        (N, n_clusters, H, W) binary float mask; ties at the cutoff distance
        may mark more than topk pixels.
    """
    N, C, H, W = data_tensor.shape

    # Per-image centroids -> (N, n_clusters, C).
    centroid_list = []
    for idx in range(N):
        cluster_centers = get_centroid_candidates(data_tensor[idx:idx+1, :, :, :], n_clusters, metric)
        centroid_list.append(cluster_centers)
    batch_centroids = torch.stack(centroid_list, dim=0)
    data_vecs = data_tensor.flatten(2)  # (N, C, H*W)

    # Squared euclidean distance via the expansion ||a||^2 - 2 a.b + ||b||^2.
    AtB = torch.matmul(batch_centroids, data_vecs)  # (N, n_clusters, H*W)
    # FIX: the old code materialized the full (n_clusters, n_clusters) and
    # (H*W, H*W) Gram matrices only to read their diagonals; the latter costs
    # O((H*W)^2) memory. Summing squares directly yields the same diagonals.
    diag_A = batch_centroids.pow(2).sum(dim=2)  # (N, n_clusters): ||centroid||^2
    diag_B = data_vecs.pow(2).sum(dim=1)        # (N, H*W): ||pixel||^2
    A2 = diag_A.unsqueeze(2).repeat(1, 1, H*W)
    B2 = diag_B.unsqueeze(1).repeat(1, n_clusters, 1)
    distance_map = A2 - 2*AtB + B2

    # Keep every pixel whose distance is within the topk-th smallest value
    # per (image, cluster); <= means exact ties at the cutoff are included.
    values, indices = distance_map.topk(topk, dim=2, largest=False, sorted=True)
    cluster_mask = torch.where(distance_map <= values[:, :, topk-1:],
                               torch.ones_like(distance_map),
                               torch.zeros_like(distance_map))
    cluster_mask = cluster_mask.view(N, n_clusters, H, W)
    return cluster_mask
|
|
|
|
|
|
|
|
|
|
|
|
'''
The kmeans implementation below is adapted from:
https://github.com/subhadarship/kmeans_pytorch
'''
|
|
|
|
|
|
|
|
|
def initialize(X, num_clusters):
    """
    initialize cluster centers by sampling rows of X without replacement
    :param X: (torch.tensor) matrix, shape (num_samples, num_features)
    :param num_clusters: (int) number of clusters
    :return: (torch.tensor) initial cluster centers, shape (num_clusters, num_features)
    """
    # Reseeding the global numpy RNG makes the pick deterministic across calls
    # (and mutates global RNG state — preserved intentionally).
    np.random.seed(1)
    chosen = np.random.choice(len(X), num_clusters, replace=False)
    return X[chosen]
|
|
|
|
|
|
|
|
|
def kmeans(
        X,
        num_clusters,
        distance='euclidean',
        cluster_centers=None,
        tol=1e-4,
        tqdm_flag=True,
        iter_limit=0,
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001
):
    """
    perform kmeans

    :param X: (torch.tensor) matrix, shape (num_samples, num_features)
    :param num_clusters: (int) number of clusters
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param cluster_centers: (torch.tensor) centers to resume from; None (or a
        plain list, accepted for backward compatibility with the old ``[]``
        default) means "initialize fresh from randomly chosen samples"
    :param tol: (float) convergence threshold on the squared center shift [default: 0.0001]
    :param tqdm_flag: Allows to turn logs on and off
    :param iter_limit: hard limit for max number of iterations (0 = no limit)
    :param device: (torch.device) device [default: cpu]
    :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0 (unused
        in this function; kept for signature compatibility)
    :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
    """
    if tqdm_flag:
        print(f'running k-means on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    else:
        raise NotImplementedError

    # convert to float and move to the target device
    X = X.float()
    X = X.to(device)

    # FIX: the old signature used a mutable default argument (cluster_centers=[]);
    # None now means "no centers supplied", while an explicit list is still
    # treated the same way for backward compatibility.
    if cluster_centers is None or isinstance(cluster_centers, list):
        initial_state = initialize(X, num_clusters)
    else:
        if tqdm_flag:
            print('resuming')
        # snap each provided center to its nearest data point
        initial_state = cluster_centers
        dis = pairwise_distance_function(X, initial_state)
        choice_points = torch.argmin(dis, dim=0)
        initial_state = X[choice_points]
        initial_state = initial_state.to(device)

    iteration = 0
    if tqdm_flag:
        tqdm_meter = tqdm(desc='[running kmeans]')
    while True:

        dis = pairwise_distance_function(X, initial_state)

        # assign every sample to its nearest center
        choice_cluster = torch.argmin(dis, dim=1)

        initial_state_pre = initial_state.clone()

        for index in range(num_clusters):
            # FIX: squeeze(1) instead of squeeze() — a bare squeeze() turns a
            # single-member cluster's (1, 1) nonzero result into a 0-d scalar,
            # which torch.index_select rejects; squeeze(1) always yields the
            # required 1-D index tensor (identical for every other size).
            selected = torch.nonzero(choice_cluster == index).squeeze(1).to(device)

            selected = torch.index_select(X, 0, selected)

            # re-seed an empty cluster with a random data point
            if selected.shape[0] == 0:
                selected = X[torch.randint(len(X), (1,))]

            initial_state[index] = selected.mean(dim=0)

        # total euclidean movement of the centers during this iteration
        center_shift = torch.sum(
            torch.sqrt(
                torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
            ))

        iteration = iteration + 1

        if tqdm_flag:
            tqdm_meter.set_postfix(
                iteration=f'{iteration}',
                center_shift=f'{center_shift ** 2:0.6f}',
                tol=f'{tol:0.6f}'
            )
            tqdm_meter.update()
        if center_shift ** 2 < tol:
            break
        if iter_limit != 0 and iteration >= iter_limit:
            break

    return choice_cluster.to(device), initial_state.to(device)
|
|
|
|
|
|
|
|
|
def kmeans_predict(
        X,
        cluster_centers,
        distance='euclidean',
        device=torch.device('cpu'),
        gamma_for_soft_dtw=0.001,
        tqdm_flag=True
):
    """
    predict using cluster centers
    :param X: (torch.tensor) matrix, shape (num_samples, num_features)
    :param cluster_centers: (torch.tensor) cluster centers
    :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
    :param device: (torch.device) device [default: 'cpu']
    :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0
    :param tqdm_flag: Allows to turn logs on and off
    :return: (torch.tensor) cluster ids, on cpu
    """
    if tqdm_flag:
        print(f'predicting on {device}..')

    if distance == 'euclidean':
        pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
    elif distance == 'cosine':
        pairwise_distance_function = partial(pairwise_cosine, device=device)
    elif distance == 'soft_dtw':
        # NOTE(review): SoftDTW / pairwise_soft_dtw are not defined in this
        # file; presumably provided elsewhere — confirm before using this branch.
        sdtw = SoftDTW(use_cuda=device.type == 'cuda', gamma=gamma_for_soft_dtw)
        pairwise_distance_function = partial(pairwise_soft_dtw, sdtw=sdtw, device=device)
    else:
        raise NotImplementedError

    # convert to float and move to the target device
    X = X.float()
    X = X.to(device)

    dis = pairwise_distance_function(X, cluster_centers)
    nearest = torch.argmin(dis, dim=1)
    return nearest.cpu()
|
|
|
|
|
|
|
|
|
def pairwise_distance(data1, data2, device=torch.device('cpu'), tqdm_flag=True):
    """Squared euclidean distance between every row of data1 and every row of data2.

    :param data1: (torch.tensor) matrix, shape (N, C)
    :param data2: (torch.tensor) matrix, shape (M, C)
    :param device: (torch.device) device to compute on
    :param tqdm_flag: print the device when True
    :return: (torch.tensor) distances of shape (N, M), with size-1 dims squeezed out
    """
    if tqdm_flag:
        print(f'device is :{device}')

    data1, data2 = data1.to(device), data2.to(device)

    # Broadcast rows against rows: (N, 1, C) vs (1, M, C).
    diff = data1.unsqueeze(dim=1) - data2.unsqueeze(dim=0)
    squared = diff ** 2.0
    return squared.sum(dim=-1).squeeze()
|
|
|
|
|
|
|
|
|
def pairwise_cosine(data1, data2, device=torch.device('cpu')):
    """Cosine distance (1 - cosine similarity) between all row pairs.

    :param data1: (torch.tensor) matrix, shape (N, C)
    :param data2: (torch.tensor) matrix, shape (M, C)
    :param device: (torch.device) device to compute on
    :return: (torch.tensor) distances of shape (N, M), with size-1 dims squeezed out
    """
    data1, data2 = data1.to(device), data2.to(device)

    # Broadcast rows against rows: (N, 1, C) vs (1, M, C).
    A = data1.unsqueeze(dim=1)
    B = data2.unsqueeze(dim=0)

    # After L2-normalising, the elementwise-product sum is the cosine similarity.
    A_unit = A / A.norm(dim=-1, keepdim=True)
    B_unit = B / B.norm(dim=-1, keepdim=True)
    similarity = (A_unit * B_unit).sum(dim=-1).squeeze()
    return 1 - similarity