csuhan's picture
Upload folder using huggingface_hub
b0c0df0 verified
import numpy as np
from sklearn.metrics import pairwise_distances
from tqdm import tqdm
from .sampling_def import SamplingMethod
class kCenterGreedy(SamplingMethod):
    """Greedy k-center selection over a feature matrix.

    Repeatedly picks the point that is currently farthest (Euclidean
    distance) from every already selected center, so that the selected
    points spread out over the data.

    Attributes:
        X: the data passed to the constructor.
        flat_X: result of ``self.flatten_X()``.
        features: matrix the distances are computed on (alias of ``flat_X``).
        min_distances: ``(n_obs, 1)`` array with the distance from every
            point to its nearest selected center, or ``None`` before any
            center has been registered.
        n_obs: number of observations, ``X.shape[0]``.
        already_selected: list of indices chosen so far, or ``None`` before
            the first call to ``select_batch``.
    """

    def __init__(self, X: np.ndarray):
        """Store the data and initialise an empty selection state.

        Args:
            X: data array; its first axis indexes the observations.
        """
        self.X = X
        # NOTE(review): flatten_X is not defined in this class; presumably
        # it is provided by SamplingMethod and flattens self.X -- confirm.
        self.flat_X = self.flatten_X()
        self.name = "kcenter"
        self.features = self.flat_X
        self.min_distances = None
        self.n_obs = self.X.shape[0]
        self.already_selected = None

    def update_distances(self, cluster_centers, only_new=True, reset_dist=False):
        """Update min distances given cluster centers.

        Args:
            cluster_centers: indices of cluster centers.
            only_new: only calculate distance for newly selected points
                (those not in ``already_selected``) and update
                ``min_distances``.
            reset_dist: whether to reset ``min_distances`` before updating.
        """
        if reset_dist:
            self.min_distances = None

        # Normalise to a plain list so that NumPy index arrays (whose truth
        # value is ambiguous) and None are both handled.
        cluster_centers = [] if cluster_centers is None else list(cluster_centers)

        if only_new:
            # Nothing has been selected yet when already_selected is None.
            selected = () if self.already_selected is None else self.already_selected
            seen = set(selected)
            cluster_centers = [d for d in cluster_centers if d not in seen]

        if len(cluster_centers) > 0:
            # Update min_distances for all examples given new cluster centers.
            x = self.features[cluster_centers]
            dist = pairwise_distances(self.features, x, metric="euclidean")
            # Collapse the per-center columns to one column holding the
            # distance to the nearest new center. Without this reduction,
            # passing several centers at once would broadcast min_distances
            # from (n, 1) to (n, k).
            nearest = np.min(dist, axis=1).reshape(-1, 1)
            if self.min_distances is None:
                self.min_distances = nearest
            else:
                self.min_distances = np.minimum(self.min_distances, nearest)

    def select_batch(self, N):
        """Greedily select ``N`` new cluster centers.

        Diversity promoting active learning method that greedily forms a
        batch to minimize the maximum distance to a cluster center among all
        unlabeled datapoints. The selection state (``already_selected`` and
        ``min_distances``) is kept on the instance, so repeated calls continue
        from the previous selection.

        Args:
            N: batch size.

        Returns:
            NumPy array with the indices of the points selected in this call.

        Raises:
            AssertionError: if the farthest point is already selected (for
                example when every remaining point coincides with a center).
        """
        print("Using flat_X as features.")

        new_batch = []

        for _ in tqdm(range(N), desc="K-Center Greedy"):
            if self.already_selected is None:
                # Seed with a fixed datapoint instead of a random one so
                # that the selection is deterministic.
                ind = 0
                self.already_selected = []
            else:
                ind = np.argmax(self.min_distances)
            # New examples should not be in already selected since those
            # points should have min_distance of zero to a cluster center.
            assert ind not in self.already_selected

            self.update_distances([ind], only_new=True, reset_dist=False)
            new_batch.append(ind)
            self.already_selected.append(ind)

        # min_distances is still None when nothing has ever been selected
        # (e.g. N == 0 on a fresh instance); there is nothing to report then.
        if self.min_distances is not None:
            # Reduce to a Python float: formatting a 1-element array with
            # %f relies on an array-to-scalar conversion NumPy deprecated.
            max_distance = float(np.max(self.min_distances))
            print("Maximum distance from cluster centers is %0.2f" % max_distance)

        new_batch = np.array(new_batch)
        return new_batch