|
|
import numpy as np |
|
|
from sklearn.metrics import pairwise_distances |
|
|
from tqdm import tqdm |
|
|
|
|
|
from .sampling_def import SamplingMethod |
|
|
|
|
|
|
|
|
class kCenterGreedy(SamplingMethod):
    """Greedy k-center sampler over the flattened input features.

    Each call to ``select_batch`` adds points one at a time, always taking the
    point whose Euclidean distance to its nearest already-selected center is
    largest. Selection state (``already_selected`` and ``min_distances``) is
    kept on the instance, so successive batches extend the same set of centers.
    """

    def __init__(self, X: np.ndarray):
        """Store the data and initialise empty selection state.

        Args:
            X: input data whose first axis indexes observations. The feature
                matrix is obtained from ``self.flatten_X()`` (presumably
                defined on ``SamplingMethod``; not visible here).
        """
        self.X = X
        self.flat_X = self.flatten_X()
        self.name = "kcenter"
        # Distances are computed on the flattened representation.
        self.features = self.flat_X
        # (n_obs, 1) column of each point's distance to its nearest selected
        # center; None until the first center has been chosen.
        self.min_distances = None
        self.n_obs = self.X.shape[0]
        # Indices selected so far (a list), or None before any selection.
        self.already_selected = None

    def update_distances(self, cluster_centers, only_new=True, reset_dist=False):
        """Update every point's distance to its nearest cluster center.

        Args:
            cluster_centers: indices (into ``self.features``) of cluster
                centers; any iterable of integer indices.
            only_new: if True, skip centers already listed in
                ``self.already_selected`` and only fold in the new ones.
            reset_dist: if True, discard the existing ``min_distances`` so they
                are recomputed from ``cluster_centers`` alone.
        """
        if reset_dist:
            self.min_distances = None

        # Normalise to a list so the emptiness check below also works for
        # numpy arrays (whose truth value is ambiguous).
        centers = list(cluster_centers)

        if only_new:
            # already_selected is None before the first selection; treat that
            # as "nothing selected yet" rather than failing on `in None`.
            selected = set(
                () if self.already_selected is None else self.already_selected
            )
            centers = [c for c in centers if c not in selected]

        if not centers:
            return

        x = self.features[centers]
        dist = pairwise_distances(self.features, x, metric="euclidean")
        # Collapse to the distance to the closest of the given centers so the
        # result is always an (n_obs, 1) column. Merging the raw (n_obs, k)
        # matrix would broadcast min_distances to (n_obs, k) and corrupt the
        # argmax in select_batch whenever k > 1.
        nearest = np.min(dist, axis=1).reshape(-1, 1)

        if self.min_distances is None:
            self.min_distances = nearest
        else:
            self.min_distances = np.minimum(self.min_distances, nearest)

    def select_batch(self, N):
        """Greedily select ``N`` new points that are far from existing centers.

        Diversity-promoting active-learning step: each iteration takes the
        point whose distance to its nearest selected center is largest, which
        greedily minimises the maximum distance from any point to a center.

        Args:
            N: number of new points to select.

        Returns:
            np.ndarray of integer indices of the newly selected points, in
            selection order. They are also appended to
            ``self.already_selected``.

        Raises:
            ValueError: if ``N`` exceeds the number of not-yet-selected points.
        """
        print("Using flat_X as features.")

        n_selected = 0 if self.already_selected is None else len(self.already_selected)
        n_available = self.n_obs - n_selected
        # Validate before touching any state: asking for more points than
        # remain would otherwise only surface as a failed assert mid-loop
        # (or silently yield duplicates when asserts are disabled).
        if N > n_available:
            raise ValueError(
                f"Cannot select {N} new points: only {n_available} of "
                f"{self.n_obs} points remain unselected."
            )

        new_batch = []

        for _ in tqdm(range(N), desc="K-Center Greedy"):
            if self.already_selected is None:
                # First center ever chosen is fixed at index 0.
                ind = 0
                self.already_selected = []
            else:
                ind = int(np.argmax(self.min_distances))

            # Selected points are their own nearest center (distance ~0), so
            # the farthest point should never be one already chosen.
            assert ind not in self.already_selected

            self.update_distances([ind], only_new=True, reset_dist=False)
            new_batch.append(ind)
            self.already_selected.append(ind)

        if self.min_distances is not None:
            print(
                "Maximum distance from cluster centers is %0.2f"
                % float(np.max(self.min_distances))
            )

        # Explicit integer dtype so an empty batch is still usable as an index.
        return np.array(new_batch, dtype=np.intp)
|
|
|