"""Greedy k-Center core-set selection.

Origin code from
https://github.com/google/active-learning/blob/master/sampling_methods/kcenter_greedy.py

Returns points that minimize the maximum distance of any point to a center.

Implements the k-Center-Greedy method in
Ozan Sener and Silvio Savarese. A Geometric Approach to Active Learning for
Convolutional Neural Networks. https://arxiv.org/abs/1708.00489 2017

Distance metric defaults to l2 distance. Features used to calculate distance
are either raw features or, if a model has a transform method, the output of
model.transform(X).

Can be extended to a robust k centers algorithm that ignores a certain number
of outlier datapoints. Resulting centers are solution to multiple integer
program.

Run as a script, it selects a diverse subset of lines from a data file using
per-line embeddings loaded with ``torch.load`` (see ``main``).
"""

import abc
import argparse

import numpy as np
import torch
from sklearn.metrics import pairwise_distances
from tqdm import tqdm


class SamplingMethod(metaclass=abc.ABCMeta):
    """Abstract base class for batch sampling methods."""

    # metaclass=abc.ABCMeta replaces the Python 2 ``__metaclass__`` attribute,
    # which Python 3 ignores; abstract methods are now actually enforced.

    @abc.abstractmethod
    def __init__(self, X, y, seed, **kwargs):
        self.X = X
        self.y = y
        self.seed = seed

    def flatten_X(self):
        """Return ``self.X`` with all axes after the first collapsed into one.

        Inputs with at most two dimensions are returned unchanged.
        """
        shape = self.X.shape
        flat_X = self.X
        if len(shape) > 2:
            # np.product was removed in NumPy 2.0; np.prod is the same function.
            flat_X = np.reshape(self.X, (shape[0], int(np.prod(shape[1:]))))
        return flat_X

    @abc.abstractmethod
    def select_batch_(self):
        return

    def select_batch(self, **kwargs):
        """Forward all keyword arguments to ``select_batch_``."""
        return self.select_batch_(**kwargs)

    def to_dict(self):
        return None


class kCenterGreedy(SamplingMethod):
    """k-Center-Greedy sampler (Sener & Savarese, 2017).

    Attributes:
      features: matrix distances are computed on; ``flat_X`` unless a model
        with a callable ``transform`` is passed to ``select_batch_``.
      min_distances: ``(n_obs, 1)`` array with each point's distance to its
        nearest cluster center, or None while no center exists.
      already_selected: the ``already_selected`` argument of the most recent
        ``select_batch_`` call.
    """

    def __init__(self, X, y, seed, metric="euclidean"):
        # Stores X, y and seed; the seed makes the random initial center
        # reproducible.
        super().__init__(X, y, seed)
        self.flat_X = self.flatten_X()
        self.name = "kcenter"
        self.features = self.flat_X
        self.metric = metric
        self.min_distances = None
        self.n_obs = self.X.shape[0]
        self.already_selected = []

    def update_distances(self, cluster_centers, only_new=True, reset_dist=False):
        """Update min distances given cluster centers.

        Args:
          cluster_centers: indices of cluster centers.
          only_new: only calculate distance for centers not in
            ``self.already_selected`` and update min_distances.
          reset_dist: whether to reset min_distances before updating.
        """
        if reset_dist:
            self.min_distances = None
        if only_new:
            already = set(self.already_selected)
            cluster_centers = [d for d in cluster_centers if d not in already]
        if len(cluster_centers) > 0:
            # Update min_distances for all examples given the new centers.
            x = self.features[cluster_centers]
            dist = pairwise_distances(self.features, x, metric=self.metric)
            # Reduce over the centers first so min_distances stays (n_obs, 1);
            # np.minimum against the raw (n_obs, k) matrix would broadcast it
            # to (n_obs, k) whenever more than one center is added at once.
            new_min = np.min(dist, axis=1).reshape(-1, 1)
            if self.min_distances is None:
                self.min_distances = new_min
            else:
                self.min_distances = np.minimum(self.min_distances, new_min)

    def select_batch_(self, model, already_selected, N, **kwargs):
        """Greedily pick a batch minimizing the max distance to a center.

        Diversity promoting active learning method that greedily forms a batch
        to minimize the maximum distance to a cluster center among all
        unlabeled datapoints.

        Args:
          model: object whose callable ``transform`` attribute maps ``self.X``
            to features; anything without one (e.g. None) makes the method
            use ``flat_X``.
          already_selected: indices of datapoints already selected; they act
            as existing cluster centers and are never returned.
          N: batch size, capped at the number of points not in
            ``already_selected``.
          **kwargs: ignored.

        Returns:
          List of int indices of the selected points, in selection order.
        """
        transform = getattr(model, "transform", None)
        if callable(transform):
            # Assumes that the transform function takes in original data and
            # not flattened data.
            print("Getting transformed features...")
            self.features = transform(self.X)
            print("Calculating distances...")
            self.update_distances(
                already_selected, only_new=False, reset_dist=True
            )
        else:
            print("Using flat_X as features.")
            self.update_distances(
                already_selected, only_new=True, reset_dist=False
            )

        # Track every point that may not be returned (again).
        is_taken = np.zeros(self.n_obs, dtype=bool)
        taken = list(set(already_selected))
        if taken:
            is_taken[taken] = True
        # Each point can be picked at most once.
        N = max(0, min(N, self.n_obs - int(is_taken.sum())))
        rng = np.random.RandomState(self.seed)

        new_batch = []
        for _ in tqdm(range(N)):
            if self.min_distances is None:
                # No center exists yet: initialize with a random datapoint.
                ind = int(rng.choice(np.flatnonzero(~is_taken)))
            else:
                # Farthest point that is not taken. Masking matters when
                # duplicate feature vectors leave the remaining points at
                # distance 0: a plain argmax would then return index 0 even
                # if it was already chosen.
                candidates = np.where(
                    is_taken, -np.inf, self.min_distances[:, 0]
                )
                ind = int(np.argmax(candidates))
            self.update_distances([ind], only_new=True, reset_dist=False)
            is_taken[ind] = True
            new_batch.append(ind)

        if self.min_distances is not None:
            print(
                "Maximum distance from cluster centers is %0.2f"
                % float(np.max(self.min_distances))
            )
        self.already_selected = already_selected
        return new_batch


def main():
    """Write the k-Center-Greedy subset of a data file's lines.

    Row ``i`` of the embeddings tensor is assumed to correspond to line ``i``
    of the data file. Both are sliced with ``[--start:--end]`` before
    selection, and the chosen lines are written in selection order.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--start', type=int)
    parser.add_argument("--end", type=int)
    parser.add_argument(
        "--embeddings",
        default="/home/aiscuser/fhw/embeddings/qwq_ins_embeddings.pt",
    )
    parser.add_argument(
        "--data",
        default="/home/aiscuser/fhw/data/qwq_python_selected.json",
    )
    parser.add_argument(
        "--output",
        default=None,
        help="defaults to /home/aiscuser/fhw/data/"
        "qwq_python_diverse_{start}_{end}.json",
    )
    parser.add_argument("--num", type=int, default=10000)
    parser.add_argument("--seed", type=int, default=42)
    args = parser.parse_args()

    output = args.output or (
        f"/home/aiscuser/fhw/data/qwq_python_diverse_{args.start}_{args.end}.json"
    )

    # map_location="cpu": sklearn's pairwise_distances needs host memory.
    # NOTE: torch.load unpickles its input; only load trusted files.
    embeddings = torch.load(args.embeddings, map_location="cpu")
    with open(args.data, "r", encoding="utf-8") as f:
        lines = f.readlines()[args.start:args.end]
    subset = embeddings[args.start:args.end]
    if subset.shape[0] > len(lines):
        # Fail before writing anything rather than hitting an IndexError
        # halfway through the output file.
        raise ValueError(
            f"{subset.shape[0]} embeddings but only {len(lines)} data lines "
            f"in the selected range"
        )

    kcg = kCenterGreedy(X=subset, y=None, seed=args.seed)
    batch = kcg.select_batch_(model=None, already_selected=[], N=args.num)
    with open(output, "w", encoding="utf-8") as fw:
        fw.writelines(lines[idx] for idx in batch)


if __name__ == '__main__':
    main()