#include #include #include "IndexIVF.h" #include "IndexIVFPQ.h" #include "iostream" #include #include namespace py = pybind11; // "vecmini" is the name of the module you will type in python-> 'import vecmini' PYBIND11_MODULE(vecmini, m) { m.doc() = "Vecmini: A mini custom IVF Vector Database with Metadata Filtering"; py::class_(m, "IndexIVF") .def(py::init(), py::arg("d"), py::arg("nbucket")) .def("train", [](IndexIVF &self, int n, py::array_t x) { py::buffer_info buf = x.request(); self.train(n, (const float *)buf.ptr); }, py::arg("n"), py::arg("x").noconvert()) .def("add", [](IndexIVF &self, int n, py::array_t x, py::array_t xids) { py::buffer_info buf_x = x.request(); py::buffer_info buf_xids = xids.request(); self.add(n, (const float *)buf_x.ptr, (const uint64_t *)buf_xids.ptr); }, py::arg("n"), py::arg("x"), py::arg("xids")) // Expose search() - UPDATED FOR NPROBE AND BITMASK .def("search", [](IndexIVF &self, int n, py::array_t x, int k, int nprobe, py::object bitmask) { py::buffer_info buf_x = x.request(); // Empty arrays to hold the answers for Python py::array_t distances({n, k}); py::array_t labels({n, k}); const uint8_t* bitmask_ptr = nullptr; py::array_t bitmask_arr; if (!bitmask.is_none()) { bitmask_arr = bitmask.cast>(); bitmask_ptr = (const uint8_t*)bitmask_arr.request().ptr; std::cout<<"recieved bitmask , *pointer address->" <<(void*)bitmask_ptr<<"\n"; } else { std::cout<<"recieved NONE\n"; } self.search(n, (const float *)buf_x.ptr, k, nprobe, bitmask_ptr, distances.mutable_data(), labels.mutable_data()); return py::make_tuple(distances, labels); }, py::arg("n"), py::arg("x"), py::arg("k"), py::arg("nprobe"), py::arg("bitmask")); py::class_(m, "IndexIVFPQ") .def(py::init(), py::arg("d"), py::arg("nbucket"), py::arg("m")) .def("train", [](IndexIVFPQ &self, int n, py::array_t x, bool subsampling, bool seed) { py::buffer_info buf = x.request(); self.train(n, static_cast(buf.ptr), subsampling, seed); }, py::arg("n"), py::arg("x").noconvert(), py::arg("subsampling"), py::arg("seed")) .def("add", [](IndexIVFPQ &self,int n, py::array_t x, py::array_t xids){ py::buffer_info bufx = x.request(); py::buffer_info bufxids = xids.request(); self.add(n, static_cast(bufx.ptr),static_cast(bufxids.ptr)); }, py::arg("n"), py::arg("x").noconvert(), py::arg("xids").noconvert()) .def("search", [](IndexIVFPQ &self, int n, py::array_t query, int k, int nprobe){ py::buffer_info buf_query = query.request(); py::array_t distances({n,k}); py::array_t labels({n,k}); self.search(n, static_cast(buf_query.ptr), k, nprobe, distances.mutable_data(), labels.mutable_data()); return py::make_tuple(distances, labels); }, py::arg("n"), py::arg("query").noconvert(), py::arg("k"), py::arg("nprobe")); py::class_(m, "IndexFlatL2") .def(py::init(), py::arg("d")) .def("add", [](IndexFlatL2 &self,int n, py::array_t x){ py::buffer_info bufx = x.request(); self.add(n, static_cast(bufx.ptr)); }, py::arg("n"), py::arg("x").noconvert()) .def("search", [](IndexFlatL2 &self, int n, py::array_t x, int k){ py::buffer_info bufx = x.request(); py::array_t distances({n,k}); py::array_t labels({n,k}); self.search(n, static_cast(bufx.ptr), k, distances.mutable_data(), labels.mutable_data()); return py::make_tuple(distances, labels); }, py::arg("n"), py::arg("x").noconvert(), py::arg("k")); }