File size: 724 Bytes
e87a50a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#pragma once
#include <vector>
#include "IndexFlat.h"
#include <cstdint>

/// Inverted-file (IVF) index: a coarse router over `nbucket` buckets plus
/// per-bucket storage of vectors and their ids.
///
/// NOTE(review): method bodies are not in this header. The descriptions
/// below are inferred from names and signatures and should be confirmed
/// against the implementation.
class IndexIVF {
private:
    // Vector dimensionality (presumably set from the constructor argument).
    // Default-initialized, like ntotal/trained below, so the object never
    // holds an indeterminate value if a constructor path forgets to set it.
    int d = 0;
    // Number of buckets / inverted lists (presumably set by the constructor).
    int nbucket = 0;
    // Number of vectors stored so far (presumably updated by add()).
    int ntotal = 0;
    // Presumably set once train() or inject_centroids() has run — TODO confirm.
    bool trained = false;

    // Flat L2 index; presumably holds the bucket centroids used to route
    // added vectors and queries to buckets — TODO confirm.
    IndexFlatL2 router;
    // Per-bucket storage: presumably memory[b] holds the flattened vectors of
    // bucket b and memory_ids[b] the matching ids passed to add().
    std::vector<std::vector<float>> memory;
    std::vector<std::vector<uint64_t>> memory_ids;

public: // The interface (your benchmark script is allowed to use these)
    /// Construct an empty index.
    /// @param d        vector dimensionality (presumably)
    /// @param nbucket  number of buckets (inverted lists)
    IndexIVF(int d, int nbucket);

    /// Train the index on @p n input vectors.
    /// @param n  number of training vectors
    /// @param x  training data; presumably n * d contiguous floats — TODO confirm
    void train(int n, const float *x);

    /// Add @p n vectors to the index under caller-supplied ids.
    /// @param n     number of vectors
    /// @param x     vector data; presumably n * d contiguous floats
    /// @param xids  ids for the added vectors; presumably n entries
    void add(int n, const float *x, const uint64_t *xids);

    /// Search the index for the @p k nearest neighbours of each of @p n queries.
    /// @param n           number of query vectors
    /// @param x           query data; presumably n * d contiguous floats
    /// @param k           neighbours requested per query
    /// @param nprobe      presumably the number of buckets scanned per query
    /// @param bitmask     presumably an id filter — TODO confirm semantics
    /// @param distances   output; presumably n * k entries
    /// @param labels      output; presumably n * k entries
    /// @param L1_summary  optional auxiliary filter data (may be nullptr);
    ///                    semantics not visible here — TODO confirm
    /// NOTE(review): ids are accepted as uint64_t (add) but reported through
    /// `int *labels`; ids above INT_MAX may not round-trip. Confirm in the
    /// implementation.
    void search(int n, const float *x, int k, int nprobe,
                const uint8_t *bitmask, float *distances, int *labels,
                const uint8_t *L1_summary = nullptr);

    /// Install externally computed centroids instead of calling train().
    /// @param external_centroids  presumably nbucket * d contiguous floats
    void inject_centroids(const float *external_centroids);
};