# File size: 3,653 Bytes
# 9ed01de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
import os
import sys
import argparse
import numpy as np
from tqdm import tqdm
import faiss

# Directory containing this script; anchor for all relative paths below.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# Make sibling modules (extract_embeddings) importable when run as a script
# from another working directory.
if BASE_DIR not in sys.path:
    sys.path.insert(0, BASE_DIR)

from extract_embeddings import extract_feats

def resolve_path(*parts):
    """Join *parts* onto this script's directory and return an absolute path."""
    joined = os.path.join(BASE_DIR, *parts)
    return os.path.abspath(joined)


def build_search_index(img_feats, dest_dir, n_train=None, do_pca=False, m=32, nprobe=256, nlist=None, force_flat=False,
                       batch_size=None, use_gpu=False):
    """Build a faiss inner-product index over img_feats and persist it to dest_dir.

    Writes the fitted preprocessing transforms ('norm1.bin', optionally
    'pca.bin' and 'norm2.bin') plus the index file into dest_dir so that
    queries can replay the identical transform chain at search time.

    Args:
        img_feats: 2-D array of features, one row per image.
        dest_dir: output directory (created if missing).
        n_train: if given, train transforms/index on only the first n_train rows.
        do_pca: apply a whitening PCA (eigen_power=-0.5) followed by re-normalization.
        m: number of PQ sub-quantizers for the IVFPQ index.
        nprobe: IVF cells probed at query time.
        nlist: number of IVF cells; defaults to 4*sqrt(#train vectors).
        force_flat: force an exact flat index regardless of dataset size.
        batch_size: if set, add vectors to the index in chunks of this size.
        use_gpu: build on GPU 0, converting back to CPU before serialization.
    """
    os.makedirs(dest_dir, exist_ok=True)

    # Transforms and the index itself are trained on (a prefix of) the features.
    train_feats = img_feats if n_train is None else img_feats[:n_train]
    dim = train_feats.shape[1]

    if len(img_feats) < 100000 or force_flat:
        print('Using flat index')
        index = faiss.IndexFlatIP(dim)
    else:
        print('Using IVFPQ index')
        if nlist is None:
            # Common heuristic: nlist proportional to sqrt of the training set size.
            nlist = int(np.sqrt(train_feats.shape[0]) * 4)
            print('Setting nlist={}'.format(nlist))
        coarse_quantizer = faiss.IndexFlatIP(dim)
        index = faiss.IndexIVFPQ(coarse_quantizer, dim, nlist, m, 8)
        index.nprobe = nprobe

    if use_gpu:
        gpu_res = faiss.StandardGpuResources()
        index = faiss.index_cpu_to_gpu(gpu_res, 0, index)

    # L2-normalize so inner product behaves as cosine similarity; persist the transform.
    norm1 = faiss.NormalizationTransform(dim)
    norm1.train(train_feats)
    train_feats = norm1.apply_py(train_feats)
    faiss.write_VectorTransform(norm1, os.path.join(dest_dir, 'norm1.bin'))

    if do_pca:
        pca = faiss.PCAMatrix(d_in=dim, d_out=dim, eigen_power=-0.5)
        pca.train(train_feats)
        train_feats = pca.apply_py(train_feats)
        faiss.write_VectorTransform(pca, os.path.join(dest_dir, 'pca.bin'))
        # Whitening changes vector norms, so normalize again afterwards.
        norm2 = faiss.NormalizationTransform(dim)
        norm2.train(train_feats)
        train_feats = norm2.apply_py(train_feats)
        faiss.write_VectorTransform(norm2, os.path.join(dest_dir, 'norm2.bin'))

    print('Training index...')
    index.train(train_feats)

    def transform(x):
        # Replay the exact preprocessing chain fitted above.
        x = norm1.apply_py(x)
        if do_pca:
            x = norm2.apply_py(pca.apply_py(x))
        return x

    if batch_size is None:
        print('Adding vectors to index...')
        index.add(transform(img_feats))
    else:
        print('Adding vectors to index in batches...')
        for start in tqdm(range(0, len(img_feats), batch_size)):
            index.add(transform(img_feats[start:start + batch_size]))
    print('Done adding vectors to index')

    # NOTE(review): filename says IVPQ_PCA even for flat / non-PCA builds —
    # kept as-is because downstream loaders presumably expect this exact name.
    index_path = os.path.join(dest_dir, 'faiss_IVPQ_PCA.index')
    if use_gpu:
        # GPU indexes cannot be serialized directly; convert back first.
        index = faiss.index_gpu_to_cpu(index)
    faiss.write_index(index, index_path)



# Per-model build configurations, keyed by a human-readable model id.
# 'a1_config' and 'weight_path' presumably select the CLIP architecture and
# pretrained weights consumed by extract_feats — confirm against that module;
# 'img_dir' is the folder of images to embed.
index_configs = {
    'CLIP-Huge-Flickr-Flat':
        {
            'a1_config': 'ViT-H-14-quickgelu',
            'weight_path': 'dfn5b',
            'img_dir': resolve_path('..', 'coco_data', 'train2017'),
        },
}

# Destination root under which each model id gets its own index directory.
out_dir = resolve_path('..', 'coco_faiss_indexes')

if __name__ == '__main__':

    os.makedirs(out_dir, exist_ok=True)

    # Configuration to build; must be a key of index_configs.
    model_id = 'CLIP-Huge-Flickr-Flat'

    # Model ids containing 'Flat' force an exact (non-IVFPQ) index.
    force_flat = 'Flat' in model_id
    save_feats = True

    index_dir = os.path.join(out_dir, model_id)
    os.makedirs(index_dir, exist_ok=True)

    im_hashes, im_feats = extract_feats(index_configs[model_id])

    if save_feats:
        # Also cache the raw (pre-normalization) embeddings for reuse.
        feats_dir = resolve_path('..', 'coco_feats')
        os.makedirs(feats_dir, exist_ok=True)
        np.save(os.path.join(feats_dir, f'{model_id}_feats.npy'), im_feats)

    # Row i of the index corresponds to identifier im_hashes[i].
    np.save(os.path.join(index_dir, 'img_ids.npy'), im_hashes)

    build_search_index(im_feats, index_dir, force_flat=force_flat, use_gpu=True)