File size: 2,155 Bytes
541e9bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import numpy as np
import torch

def clip_score(feature, ref_feature, logit_scale = 100.0, weight = 1, reduce = True):
    """Return the scaled cosine similarity of each feature row to a set of references.

    Args:
        feature: 2-D array (B, F); each row is L2-normalised before comparison.
        ref_feature: References of shape (B, T, F). A 2-D (T, F) array is shared
            by every row of ``feature`` (it gets a leading axis of size 1 that
            broadcasts against B), and a 1-D (F,) array is treated as a single
            reference (T == 1).
        logit_scale: Multiplier applied to the cosine similarity.
        weight: Multiplied into the (B, T) similarity. A 1-D weight is treated
            as one value per reference, i.e. reshaped to (1, T).
        reduce: If True, return the mean over the reference axis.

    Returns:
        Shape (B,) when ``reduce`` is True. Otherwise shape (B, T), squeezed to
        (B,) when there is exactly one reference.
    """
    ref_ndim = np.ndim(ref_feature)
    if ref_ndim == 1:
        # A single reference vector: previously this crashed on shape unpacking.
        ref_feature = np.reshape(ref_feature, (1, 1, -1))
    elif ref_ndim == 2:
        # One shared reference set for the whole batch.
        ref_feature = np.expand_dims(ref_feature, axis = 0)
    ref_size = np.shape(ref_feature)[1]

    feature = feature / np.linalg.norm(feature, axis = -1, keepdims = True)
    ref_feature = ref_feature / np.linalg.norm(ref_feature, axis = -1, keepdims = True)
    # A size-1 batch axis of ref_feature is broadcast against B by einsum.
    sim = logit_scale * np.einsum("bf,btf->bt", feature, ref_feature)

    if np.ndim(weight) == 1:
        # Per-reference weights -> row vector (1, T).
        weight = np.expand_dims(weight, axis = 0)
    sim = sim * weight

    if reduce:
        return sim.mean(axis = 1)
    return sim[..., 0] if ref_size == 1 else sim
    
def coreset_sampling(data, n_sample = 0.1, weight = 1, n_approximate = 10, logit_scale = 100, seed = 42):
    """Greedily pick a subset of the rows of ``data`` using ``clip_score`` similarities.

    Selection rule:
    - The first pick is the row with the highest mean similarity to a random
      subset of ``n_approximate`` rows.
    - Each row carries a running score: the minimum of that mean and its
      similarities to the rows already picked.
    - Every later pick is the row whose running score is largest.
    - Picked rows are excluded by setting their score to ``-inf``.

    NOTE(review): ``clip_score`` is a scaled cosine *similarity* (larger means
    more alike), yet it is combined with ``min``/``argmax`` as if it were a
    distance. If a farthest-point (k-center) coreset is intended, the criterion
    probably needs to be inverted -- confirm with the author.

    Args:
        data: Feature rows, converted with ``np.array`` if needed and passed to
            ``clip_score`` as ``feature``, so it must be 2-D (N, F).
        n_sample: How many rows to pick. A float, or an int below 1, is read as
            a fraction of ``len(data)`` and rounded. The result is clamped to
            ``[1, len(data)]`` (0 when ``data`` is empty).
        weight: Forwarded to ``clip_score``. ``None`` means 1. A 1-D weight is
            turned into a column ``(len(weight), 1)`` and a 2-D weight is
            transposed (presumably one weight per data row -- TODO confirm).
        n_approximate: Size of the random subset used for the first score. A
            float is read as a fraction of ``len(data)``. The size is clamped to
            ``[1, len(data)]``.
        logit_scale: Forwarded to ``clip_score``.
        seed: Source of randomness.
            - ``None`` uses the global ``np.random``.
            - An integer (Python or NumPy) seeds a fresh
              ``np.random.RandomState``.
            - Anything else is used directly and must provide
              ``choice(n, size, replace=False)``.

    Returns:
        ``np.arange(len(data))`` when every row is picked. Otherwise, a list of
        the picked row indices (Python ints) in selection order.
    """
    data = np.array(data) if not isinstance(data, np.ndarray) else data
    n_total = len(data)

    if isinstance(n_sample, float) or (isinstance(n_sample, int) and n_sample < 1):
        n_sample = round(n_total * n_sample)
    n_sample = max(min(n_sample, n_total), 1 if n_total != 0 else 0)

    weight = 1 if weight is None else weight
    if np.ndim(weight) == 2:
        weight = np.transpose(weight)
    elif np.ndim(weight) == 1:
        weight = np.expand_dims(weight, axis = -1)

    if seed is None:
        rng = np.random
    elif isinstance(seed, (int, np.integer)):
        rng = np.random.RandomState(seed)
    else:
        rng = seed

    if n_sample == n_total:
        return np.arange(n_sample)

    if isinstance(n_approximate, float):
        n_approximate = round(n_total * n_approximate)
    # At least one reference row: an empty reference set made the mean NaN.
    n_approximate = max(min(n_approximate, n_total), 1)
    approx_data = data[rng.choice(n_total, n_approximate, replace = False)]

    # clip_score returns (N, M) for M > 1 references but (N,) for M == 1.
    # Reshaping to (N, -1) makes the mean always run over the reference axis
    # and keeps `score` strictly 1-D (N,).
    score = clip_score(data, approx_data, weight = weight, logit_scale = logit_scale, reduce = False)
    score = np.reshape(score, (n_total, -1)).mean(axis = 1)

    indices = []
    for _ in range(n_sample):
        sample_index = int(np.argmax(score))
        indices.append(sample_index)
        # A single reference gives an (N,) result. Because `score` is also
        # (N,), np.minimum stays element-wise. The old (N, 1) column
        # broadcast this to (N, N) and corrupted argmax.
        sample_score = clip_score(data, data[[sample_index]], weight = weight, logit_scale = logit_scale, reduce = False)
        score = np.minimum(score, sample_score)
        score[sample_index] = -np.inf  # never pick the same row twice
    return indices