| import numpy as np | |
| import torch | |
def clip_score(feature, ref_feature, logit_scale=100.0, weight=1, reduce=True):
    """Return scaled cosine similarities between features and references.

    Args:
        feature: 2-D array indexed as (batch, feature_dim).
        ref_feature: 3-D array indexed as (batch, n_ref, feature_dim). A 2-D
            array is treated as (n_ref, feature_dim) and given a leading
            axis of length one.
        logit_scale: Factor applied to every cosine similarity.
        weight: Multiplier applied to the similarity matrix. A 1-D value is
            given a leading axis so it lines up with the reference axis;
            anything else is multiplied as-is (normal NumPy broadcasting).
        reduce: When true, average the similarities over the reference axis.

    Returns:
        The per-row mean over references when ``reduce`` is true. Otherwise
        the (batch, n_ref) similarity matrix, with the reference axis
        dropped when there is exactly one reference.
    """
    if np.ndim(ref_feature) == 2:
        ref_feature = np.expand_dims(ref_feature, axis=0)
    _, ref_size = np.shape(ref_feature)[:2]

    # L2-normalise both sides so the contraction below is a cosine similarity.
    unit_feature = feature / np.linalg.norm(feature, axis=-1, keepdims=True)
    unit_ref = ref_feature / np.linalg.norm(ref_feature, axis=-1, keepdims=True)
    sim = logit_scale * np.einsum("bf,btf->bt", unit_feature, unit_ref)

    if np.ndim(weight) == 1:
        weight = np.expand_dims(weight, axis=0)
    sim = sim * weight

    if reduce:
        return sim.mean(axis=1)
    if ref_size == 1:
        return sim[..., 0]
    return sim
def coreset_sampling(data, n_sample=0.1, weight=1, n_approximate=10, logit_scale=100, seed=42):
    """Greedily pick row indices of ``data`` using ``clip_score`` similarities.

    Args:
        data: Sequence or array of feature rows; converted to ``np.ndarray``
            when it is not one already.
        n_sample: Number of rows to select. A float, or an int below 1, is
            interpreted as a fraction of ``len(data)``. The result is
            clamped to ``[1, len(data)]`` (0 for empty data).
        weight: Optional multiplier forwarded to ``clip_score``. ``None``
            means 1; a 1-D value becomes a column, a 2-D value is
            transposed.
        n_approximate: Number of random anchor rows used for the initial
            score (a float is a fraction of ``len(data)``). Clamped to
            ``[1, len(data)]``.
        logit_scale: Forwarded to ``clip_score``.
        seed: Integer seed, an object exposing ``choice`` (e.g. a
            ``RandomState``), or ``None`` to use the global ``np.random``.

    Returns:
        ``np.arange(len(data))`` when every row is requested, otherwise a
        list of ``n_sample`` distinct integer row indices in selection order.
    """
    if not isinstance(data, np.ndarray):
        data = np.array(data)
    n_data = len(data)

    if isinstance(n_sample, float) or (isinstance(n_sample, int) and n_sample < 1):
        n_sample = round(n_data * n_sample)
    n_sample = max(min(n_sample, n_data), 1 if n_data != 0 else 0)

    weight = 1 if weight is None else weight
    if np.ndim(weight) == 2:
        weight = np.transpose(weight)
    elif np.ndim(weight) == 1:
        weight = np.expand_dims(weight, axis=-1)

    if seed is None:
        random = np.random
    elif isinstance(seed, (int, np.integer)):
        # NumPy integer scalars are not ``int`` instances but are valid seeds.
        random = np.random.RandomState(seed)
    else:
        random = seed

    if n_sample == n_data:
        return np.arange(n_sample)

    if isinstance(n_approximate, float):
        n_anchor = round(n_data * n_approximate)
    else:
        n_anchor = n_approximate
    # At least one anchor is needed, otherwise the initial score is a mean
    # over an empty axis (NaN).
    n_anchor = max(min(n_anchor, n_data), 1)
    approx_data = data[random.choice(n_data, n_anchor, replace=False)]

    dist = clip_score(data, approx_data, weight=weight, logit_scale=logit_scale, reduce=False)
    # clip_score drops the reference axis when there is a single anchor, so
    # restore it before averaging. Keep the running score one-dimensional:
    # a (N, 1) score combined with the (N,) per-sample score below would
    # broadcast to (N, N) and corrupt the argmax indices.
    dist = np.reshape(dist, (n_data, -1)).mean(axis=1)

    # NOTE(review): the scores are similarities (larger = more alike) and the
    # loop selects the row with the highest running minimum. Confirm this is
    # the intended criterion rather than a farthest-point (lowest similarity)
    # rule.
    indices = []
    for _ in range(n_sample):
        sample_index = int(np.argmax(dist))
        indices.append(sample_index)
        sample_dist = clip_score(
            data, data[[sample_index]], weight=weight, logit_scale=logit_scale, reduce=False
        )
        dist = np.minimum(dist, np.reshape(sample_dist, -1))
        # Exclude the chosen row from every later argmax.
        dist[sample_index] = -np.inf
    return indices