vecmini-engine / src /bindings.cpp
levanel
vecmini1
e87a50a
#include <pybind11/pybind11.h>
#include <pybind11/numpy.h>
#include "IndexIVF.h"
#include "IndexIVFPQ.h"
#include "iostream"
#include <pybind11/stl.h>
#include <vector>
namespace py = pybind11;
// "vecmini" is the module name as seen from Python, i.e. `import vecmini`.
PYBIND11_MODULE(vecmini, m) {
m.doc() = "Vecmini: A mini custom IVF Vector Database with Metadata Filtering";
py::class_<IndexIVF>(m, "IndexIVF")
.def(py::init<int, int>(), py::arg("d"), py::arg("nbucket"))
.def("train", [](IndexIVF &self, int n, py::array_t<float, py::array::c_style> x) {
py::buffer_info buf = x.request();
self.train(n, (const float *)buf.ptr);
}, py::arg("n"), py::arg("x").noconvert())
.def("add", [](IndexIVF &self, int n,
py::array_t<float, py::array::c_style | py::array::forcecast> x,
py::array_t<uint64_t, py::array::c_style | py::array::forcecast> xids) {
py::buffer_info buf_x = x.request();
py::buffer_info buf_xids = xids.request();
self.add(n, (const float *)buf_x.ptr, (const uint64_t *)buf_xids.ptr);
}, py::arg("n"), py::arg("x"), py::arg("xids"))
// Expose search() - UPDATED FOR NPROBE AND BITMASK
.def("search", [](IndexIVF &self, int n,
py::array_t<float, py::array::c_style | py::array::forcecast> x,
int k, int nprobe, py::object bitmask) {
py::buffer_info buf_x = x.request();
// Empty arrays to hold the answers for Python
py::array_t<float> distances({n, k});
py::array_t<int> labels({n, k});
const uint8_t* bitmask_ptr = nullptr;
py::array_t<uint8_t> bitmask_arr;
if (!bitmask.is_none()) {
bitmask_arr = bitmask.cast<py::array_t<uint8_t, py::array::c_style | py::array::forcecast>>();
bitmask_ptr = (const uint8_t*)bitmask_arr.request().ptr;
std::cout<<"recieved bitmask , *pointer address->" <<(void*)bitmask_ptr<<"\n";
} else {
std::cout<<"recieved NONE\n";
}
self.search(n, (const float *)buf_x.ptr, k, nprobe, bitmask_ptr,
distances.mutable_data(), labels.mutable_data());
return py::make_tuple(distances, labels);
}, py::arg("n"), py::arg("x"), py::arg("k"), py::arg("nprobe"), py::arg("bitmask"));
py::class_<IndexIVFPQ>(m, "IndexIVFPQ")
.def(py::init<int, int, int>(),
py::arg("d"),
py::arg("nbucket"),
py::arg("m"))
.def("train", [](IndexIVFPQ &self, int n, py::array_t<float, py::array::c_style> x, bool subsampling, bool seed) {
py::buffer_info buf = x.request();
self.train(n, static_cast<const float *>(buf.ptr), subsampling, seed);
}, py::arg("n"), py::arg("x").noconvert(), py::arg("subsampling"), py::arg("seed"))
.def("add", [](IndexIVFPQ &self,int n, py::array_t<float, py::array::c_style> x, py::array_t<uint64_t, py::array::c_style> xids){
py::buffer_info bufx = x.request();
py::buffer_info bufxids = xids.request();
self.add(n, static_cast<const float *>(bufx.ptr),static_cast<const uint64_t *>(bufxids.ptr));
}, py::arg("n"), py::arg("x").noconvert(), py::arg("xids").noconvert())
.def("search", [](IndexIVFPQ &self, int n,
py::array_t<float, py::array::c_style> query,
int k, int nprobe){
py::buffer_info buf_query = query.request();
py::array_t<float> distances({n,k});
py::array_t<int64_t> labels({n,k});
self.search(n, static_cast<const float *>(buf_query.ptr), k, nprobe, distances.mutable_data(), labels.mutable_data());
return py::make_tuple(distances, labels);
}, py::arg("n"), py::arg("query").noconvert(), py::arg("k"), py::arg("nprobe"));
py::class_<IndexFlatL2>(m, "IndexFlatL2")
.def(py::init<int>(),
py::arg("d"))
.def("add", [](IndexFlatL2 &self,int n, py::array_t<float, py::array::c_style> x){
py::buffer_info bufx = x.request();
self.add(n, static_cast<const float *>(bufx.ptr));
}, py::arg("n"), py::arg("x").noconvert())
.def("search", [](IndexFlatL2 &self, int n,
py::array_t<float, py::array::c_style> x,
int k){
py::buffer_info bufx = x.request();
py::array_t<float> distances({n,k});
py::array_t<int> labels({n,k});
self.search(n, static_cast<const float *>(bufx.ptr), k, distances.mutable_data(), labels.mutable_data());
return py::make_tuple(distances, labels);
}, py::arg("n"), py::arg("x").noconvert(), py::arg("k"));
}