| """ Origin code from https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py """ |
|
|
| """Returns points that minimizes the maximum distance of any point to a center. |
| |
| Implements the k-Center-Greedy method in |
| Ozan Sener and Silvio Savarese. A Geometric Approach to Active Learning for |
| Convolutional Neural Networks. https://arxiv.org/abs/1708.00489 2017 |
| |
| Distance metric defaults to l2 distance. Features used to calculate distance |
| are either raw features or if a model has transform method then uses the output |
| of model.transform(X). |
| |
| Can be extended to a robust k centers algorithm that ignores a certain number of |
| outlier datapoints. Resulting centers are solution to multiple integer program. |
| """ |
|
|
|
|
| import numpy as np |
| from sklearn.metrics import pairwise_distances |
| import torch |
| from tqdm import tqdm |
| import abc |
| import argparse |
|
|
class SamplingMethod(object):
    """Base class for active-learning sampling strategies.

    Holds the data matrix ``X``, labels ``y`` and a ``seed``. Subclasses
    implement ``select_batch_``; callers use ``select_batch``, which forwards
    its keyword arguments to it.
    """

    # NOTE(review): ``__metaclass__`` is Python 2 syntax and has no effect in
    # Python 3, so the ``abc.abstractmethod`` markers below are NOT enforced
    # (this class can be instantiated directly). Switching to
    # ``class SamplingMethod(metaclass=abc.ABCMeta)`` would enforce them;
    # confirm no subclass relies on the current lax behavior first.
    __metaclass__ = abc.ABCMeta

    @abc.abstractmethod
    def __init__(self, X, y, seed, **kwargs):
        """Store the data matrix ``X``, labels ``y`` and ``seed``."""
        self.X = X
        self.y = y
        self.seed = seed

    def flatten_X(self):
        """Return ``self.X`` as a 2-D ``(n_obs, n_features)`` array.

        Inputs with at most two dimensions are returned unchanged (same
        object). Higher-rank inputs keep their first axis, and all remaining
        axes are flattened into one.
        """
        shape = self.X.shape
        if len(shape) <= 2:
            return self.X
        # np.product was deprecated in NumPy 1.25 and removed in 2.0;
        # np.prod is the supported spelling.
        return np.reshape(self.X, (shape[0], int(np.prod(shape[1:]))))

    @abc.abstractmethod
    def select_batch_(self):
        """Select a batch of indices; implemented by subclasses."""
        return

    def select_batch(self, **kwargs):
        """Public entry point; forwards keyword arguments to ``select_batch_``."""
        return self.select_batch_(**kwargs)

    def to_dict(self):
        """Hook for subclasses to describe themselves; the base returns None."""
        return None
|
|
class kCenterGreedy(SamplingMethod):
    """Greedy k-center sampler (Sener & Savarese, 2017).

    Each step picks the point whose distance to its nearest selected center
    is largest, and then treats that point as a new center.
    ``self.min_distances`` caches every point's distance to its nearest known
    center. Each step therefore only computes distances to the one newly
    added center.
    """

    def __init__(self, X, y, seed, metric="euclidean"):
        """Create a sampler over ``X``.

        Args:
          X: data with observations along the first axis; inputs with more
            than two dimensions are flattened to 2-D via ``flatten_X``.
          y: labels; stored but not used by this sampler.
          seed: seed for the random choice of the first center when no center
            is known yet (``None`` means non-deterministic).
          metric: distance metric passed to
            ``sklearn.metrics.pairwise_distances``.
        """
        self.X = X
        self.y = y
        self.seed = seed
        self.flat_X = self.flatten_X()
        self.name = "kcenter"
        self.features = self.flat_X
        self.metric = metric
        # (n_obs, 1) distances to the nearest known center; None until the
        # first center is registered via update_distances.
        self.min_distances = None
        self.n_obs = self.X.shape[0]
        # Copy of the already_selected list passed to the last select_batch_.
        self.already_selected = []

    def update_distances(self, cluster_centers, only_new=True, reset_dist=False):
        """Fold cluster centers into ``self.min_distances``.

        Args:
          cluster_centers: indices (rows of ``self.features``) of centers.
          only_new: skip centers already recorded in ``self.already_selected``.
          reset_dist: discard the cached ``min_distances`` before updating.
        """
        if reset_dist:
            self.min_distances = None
        if only_new:
            known = set(self.already_selected)
            cluster_centers = [d for d in cluster_centers if d not in known]
        # len() rather than truthiness so index arrays are accepted too.
        if len(cluster_centers) == 0:
            return

        x = self.features[cluster_centers]
        dist = pairwise_distances(self.features, x, metric=self.metric)
        # Reduce to the nearest of the given centers *before* merging, so the
        # cache stays an (n_obs, 1) column. Merging the raw (n_obs, k) matrix
        # would broadcast the cache to k columns, and argmax would then return
        # flat indices beyond n_obs.
        nearest = np.min(dist, axis=1).reshape(-1, 1)
        if self.min_distances is None:
            self.min_distances = nearest
        else:
            self.min_distances = np.minimum(self.min_distances, nearest)

    def select_batch_(self, model, already_selected, N, **kwargs):
        """Greedily pick new points that minimize the max distance to a center.

        Diversity-promoting active-learning step: forms a batch so that the
        maximum distance from any datapoint to its nearest center is
        minimized.

        Args:
          model: optional model. If it has a callable ``transform`` method,
            distances are computed on ``model.transform(self.X)`` and the
            cached distances are rebuilt from ``already_selected``. Otherwise
            (e.g. ``None``) the current ``self.features`` are used; these are
            initially the flattened ``X``.
          already_selected: indices of points already selected; never returned.
          N: requested batch size, capped at the number of points not in
            ``already_selected`` so the batch never repeats an index.
          **kwargs: ignored; accepted for interface compatibility.

        Returns:
          List of ``int`` indices of the newly selected points, in pick order.
        """
        # Explicit capability check instead of a bare ``except:``, which also
        # swallowed KeyboardInterrupt and genuine transform errors.
        transform = getattr(model, "transform", None)
        if callable(transform):
            print("Getting transformed features...")
            self.features = transform(self.X)
            print("Calculating distances...")
            self.update_distances(already_selected, only_new=False, reset_dist=True)
        else:
            print("Using flat_X as features.")
            self.update_distances(already_selected, only_new=True, reset_dist=False)

        # Mask of points that may not be picked again (prior + this batch).
        taken = np.zeros(self.n_obs, dtype=bool)
        taken[np.asarray(list(already_selected), dtype=np.intp)] = True
        n_available = int((~taken).sum())
        rng = np.random.default_rng(self.seed)

        new_batch = []
        for _ in tqdm(range(min(N, n_available))):
            if self.min_distances is None:
                # No center known yet: start from a random untaken point.
                ind = int(rng.choice(np.flatnonzero(~taken)))
            else:
                # Farthest point from its nearest center. Taken points are
                # masked out so ties at distance 0 (e.g. duplicate rows)
                # cannot yield a repeated index.
                scores = np.where(taken, -np.inf, self.min_distances.ravel())
                ind = int(np.argmax(scores))
            taken[ind] = True
            self.update_distances([ind], only_new=True, reset_dist=False)
            new_batch.append(ind)

        if self.min_distances is not None:
            print(
                "Maximum distance from cluster centers is %0.2f"
                % float(np.max(self.min_distances))
            )

        # Copy so later mutation of the caller's list cannot silently change
        # what update_distances treats as "already known".
        self.already_selected = list(already_selected)

        return new_batch
def main():
    """Write a k-center-greedy diverse subset of a JSON-lines file.

    Rows ``[--start, --end)`` of the embedding tensor and of the input file
    are assumed to describe the same examples in the same order. Up to
    ``--num`` of them are selected with ``kCenterGreedy``, and the matching
    input lines are written to the output file in selection order.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--start', type=int)
    parser.add_argument("--end", type=int)
    parser.add_argument("--num", type=int, default=10000,
                        help="number of lines to select")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--embeddings",
        default="/home/aiscuser/fhw/embeddings/qwq_ins_embeddings.pt",
    )
    parser.add_argument(
        "--input",
        default="/home/aiscuser/fhw/data/qwq_python_selected.json",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="defaults to qwq_python_diverse_{start}_{end}.json next to --input's default",
    )
    args = parser.parse_args()

    output_path = args.output or (
        f"/home/aiscuser/fhw/data/qwq_python_diverse_{args.start}_{args.end}.json"
    )

    # map_location="cpu" lets GPU-saved embeddings load on any machine and
    # keeps them convertible to NumPy for sklearn.
    embeddings = torch.load(args.embeddings, map_location="cpu")
    # Read-only mode: the original "r+" needlessly required write permission.
    with open(args.input, "r", encoding="utf-8") as f:
        lines = f.readlines()[args.start:args.end]
    X = embeddings[args.start:args.end]
    if X.shape[0] != len(lines):
        raise SystemExit(
            f"embedding rows ({X.shape[0]}) and input lines ({len(lines)}) "
            "do not align for the requested range"
        )

    kcg = kCenterGreedy(X=X, y=None, seed=args.seed)
    batch = kcg.select_batch_(model=None, already_selected=[], N=args.num)
    # Open the output only once the selection succeeded, and close it reliably.
    with open(output_path, "w", encoding="utf-8") as fw:
        fw.writelines(lines[idx] for idx in batch)


if __name__ == '__main__':
    main()
|
|