File size: 5,069 Bytes
e87a50a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h> 
#include "IndexIVF.h"
#include "IndexIVFPQ.h"
#include "iostream"
#include <pybind11/stl.h>
#include <vector>
namespace py = pybind11;

// "vecmini" is the name of the module you will type in python-> 'import vecmini'
PYBIND11_MODULE(vecmini, m) {
    m.doc() = "Vecmini: A mini custom IVF Vector Database with Metadata Filtering";

    // Array flavours accepted by the bindings below.
    //  - Strict*:  C-contiguous only; paired with .noconvert() so a wrong
    //              dtype/layout is rejected instead of being silently copied.
    //  - Cast*:    C-contiguous, and pybind11 may convert/copy the input
    //              (forcecast) to obtain the requested dtype/layout.
    using StrictFloats = py::array_t<float, py::array::c_style>;
    using StrictIds    = py::array_t<uint64_t, py::array::c_style>;
    using CastFloats   = py::array_t<float, py::array::c_style | py::array::forcecast>;
    using CastIds      = py::array_t<uint64_t, py::array::c_style | py::array::forcecast>;
    using CastMask     = py::array_t<uint8_t, py::array::c_style | py::array::forcecast>;

    // Every method forwards a caller-supplied count `n` together with a raw
    // buffer pointer to native code. This guard turns the obviously invalid
    // cases (negative count, or a buffer holding fewer than `n` elements) into
    // a Python ValueError instead of an out-of-bounds read.
    // It is only a necessary condition for vector data: the vector dimension
    // is not visible at this layer, so the full n * d requirement is not checked.
    const auto require_elements = [](int n, py::ssize_t available, const char *message) {
        if (n < 0 || available < static_cast<py::ssize_t>(n)) {
            throw py::value_error(message);
        }
    };

    // ------------------------------------------------------------------
    // IndexIVF
    // ------------------------------------------------------------------
    py::class_<IndexIVF>(m, "IndexIVF")
        .def(py::init<int, int>(), py::arg("d"), py::arg("nbucket"))

        // train(n, x): forwards n and the float buffer of x to IndexIVF::train.
        .def("train", [require_elements](IndexIVF &self, int n, StrictFloats x) {
            require_elements(n, x.size(), "train: x holds fewer than n elements (or n is negative)");
            self.train(n, x.data());
        }, py::arg("n"), py::arg("x").noconvert())

        // add(n, x, xids): forwards n, the float buffer of x and the id buffer
        // of xids to IndexIVF::add.
        .def("add", [require_elements](IndexIVF &self, int n, CastFloats x, CastIds xids) {
            require_elements(n, x.size(), "add: x holds fewer than n elements (or n is negative)");
            // presumably one id is consumed per added vector - verify against IndexIVF::add
            require_elements(n, xids.size(), "add: xids holds fewer than n elements (or n is negative)");
            self.add(n, x.data(), xids.data());
        }, py::arg("n"), py::arg("x"), py::arg("xids"))

        // search(n, x, k, nprobe, bitmask=None) -> (distances, labels)
        // Both results are freshly allocated (n, k) arrays filled in by
        // IndexIVF::search. `bitmask` may be None, in which case a null
        // pointer is handed to the native search.
        .def("search", [require_elements](IndexIVF &self, int n, CastFloats x,
                                          int k, int nprobe, py::object bitmask) {
            require_elements(n, x.size(), "search: x holds fewer than n elements (or n is negative)");

            // Output buffers returned to Python.
            py::array_t<float> distances({n, k});
            // NOTE(review): IndexIVFPQ.search returns int64 labels while this
            // one returns int; left as-is because it must match the native
            // IndexIVF::search signature, which is not visible here - confirm.
            py::array_t<int> labels({n, k});

            // The holder lives in the enclosing scope so the (possibly
            // converted) mask buffer stays alive for the whole native call.
            CastMask mask_holder;
            const uint8_t *mask_ptr = nullptr;
            if (!bitmask.is_none()) {
                mask_holder = bitmask.cast<CastMask>();
                mask_ptr = mask_holder.data();
            }

            self.search(n, x.data(), k, nprobe, mask_ptr,
                        distances.mutable_data(), labels.mutable_data());

            return py::make_tuple(distances, labels);
        }, py::arg("n"), py::arg("x"), py::arg("k"), py::arg("nprobe"),
           py::arg("bitmask") = py::none());

    // ------------------------------------------------------------------
    // IndexIVFPQ
    // ------------------------------------------------------------------
    py::class_<IndexIVFPQ>(m, "IndexIVFPQ")
        .def(py::init<int, int, int>(),
             py::arg("d"),
             py::arg("nbucket"),
             py::arg("m"))

        // train(n, x, subsampling, seed): forwards everything to IndexIVFPQ::train.
        // NOTE(review): `seed` is exposed as bool; if the native parameter is
        // an integer seed this should be widened - confirm against IndexIVFPQ.h.
        .def("train", [require_elements](IndexIVFPQ &self, int n, StrictFloats x,
                                         bool subsampling, bool seed) {
            require_elements(n, x.size(), "train: x holds fewer than n elements (or n is negative)");
            self.train(n, x.data(), subsampling, seed);
        }, py::arg("n"), py::arg("x").noconvert(), py::arg("subsampling"), py::arg("seed"))

        // add(n, x, xids): forwards n, the float buffer and the id buffer to IndexIVFPQ::add.
        .def("add", [require_elements](IndexIVFPQ &self, int n, StrictFloats x, StrictIds xids) {
            require_elements(n, x.size(), "add: x holds fewer than n elements (or n is negative)");
            // presumably one id is consumed per added vector - verify against IndexIVFPQ::add
            require_elements(n, xids.size(), "add: xids holds fewer than n elements (or n is negative)");
            self.add(n, x.data(), xids.data());
        }, py::arg("n"), py::arg("x").noconvert(), py::arg("xids").noconvert())

        // search(n, query, k, nprobe) -> (distances, labels), both of shape (n, k).
        .def("search", [require_elements](IndexIVFPQ &self, int n, StrictFloats query,
                                          int k, int nprobe) {
            require_elements(n, query.size(), "search: query holds fewer than n elements (or n is negative)");

            py::array_t<float> distances({n, k});
            py::array_t<int64_t> labels({n, k});

            self.search(n, query.data(), k, nprobe,
                        distances.mutable_data(), labels.mutable_data());

            return py::make_tuple(distances, labels);
        }, py::arg("n"), py::arg("query").noconvert(), py::arg("k"), py::arg("nprobe"));

    // ------------------------------------------------------------------
    // IndexFlatL2
    // NOTE(review): no header included directly by this file is named after
    // IndexFlatL2; presumably it is declared via IndexIVF.h / IndexIVFPQ.h - confirm.
    // ------------------------------------------------------------------
    py::class_<IndexFlatL2>(m, "IndexFlatL2")
        .def(py::init<int>(),
             py::arg("d"))

        // add(n, x): forwards n and the float buffer of x to IndexFlatL2::add.
        .def("add", [require_elements](IndexFlatL2 &self, int n, StrictFloats x) {
            require_elements(n, x.size(), "add: x holds fewer than n elements (or n is negative)");
            self.add(n, x.data());
        }, py::arg("n"), py::arg("x").noconvert())

        // search(n, x, k) -> (distances, labels), both of shape (n, k).
        .def("search", [require_elements](IndexFlatL2 &self, int n, StrictFloats x, int k) {
            require_elements(n, x.size(), "search: x holds fewer than n elements (or n is negative)");

            py::array_t<float> distances({n, k});
            py::array_t<int> labels({n, k});

            self.search(n, x.data(), k,
                        distances.mutable_data(), labels.mutable_data());

            return py::make_tuple(distances, labels);
        }, py::arg("n"), py::arg("x").noconvert(), py::arg("k"));

}