|
|
|
|
|
|
|
|
#include "python/src/mlx_func.h"

#include <new>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct gc_func { |
|
|
PyObject_HEAD |
|
|
|
|
|
PyObject* (*vectorcall)(PyObject*, PyObject* const*, size_t, PyObject*); |
|
|
|
|
|
PyObject* func; |
|
|
|
|
|
|
|
|
PyObject* orig_func; |
|
|
|
|
|
std::vector<PyObject*> deps; |
|
|
}; |
|
|
|
|
|
// tp_traverse slot: reports the instance's type, the wrapped callable and
// every entry of `deps` to the garbage collector's visitor.
//
// The parameter names `visit` and `arg` must not be changed: Py_VISIT is a
// CPython macro that refers to them by name.
int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) {
  Py_VISIT(Py_TYPE(self));

  auto* wrapper = reinterpret_cast<gc_func*>(self);
  Py_VISIT(wrapper->func);

  const std::vector<PyObject*>& dependencies = wrapper->deps;
  for (PyObject* dependency : dependencies) {
    Py_VISIT(dependency);
  }

  return 0;
}
|
|
|
|
|
// tp_clear slot: drops the strong reference to the wrapped callable and
// nulls the field. `orig_func` and `deps` are left untouched.
int gc_func_tp_clear(PyObject* self) {
  auto* wrapper = reinterpret_cast<gc_func*>(self);
  Py_CLEAR(wrapper->func);
  return 0;
}
|
|
|
|
|
// Getter for "__doc__": looks the attribute up on the wrapped callable and
// returns the result of PyObject_GetAttrString unchanged. The closure
// argument is unused.
PyObject* gc_func_get_doc(PyObject* self, void*) {
  PyObject* wrapped = reinterpret_cast<gc_func*>(self)->func;
  return PyObject_GetAttrString(wrapped, "__doc__");
}
|
|
|
|
|
// Getter for "__nb_signature__": looks the attribute up on the wrapped
// callable and returns the result of PyObject_GetAttrString unchanged. The
// closure argument is unused.
PyObject* gc_func_get_sig(PyObject* self, void*) {
  PyObject* wrapped = reinterpret_cast<gc_func*>(self)->func;
  return PyObject_GetAttrString(wrapped, "__nb_signature__");
}
|
|
|
|
|
// Vectorcall entry point installed on every gc_func instance: passes the
// arguments through to the wrapped callable without modification.
PyObject* gc_func_vectorcall(
    PyObject* self,
    PyObject* const* args,
    size_t nargs,
    PyObject* kwnames) {
  PyObject* target = reinterpret_cast<gc_func*>(self)->func;
  return PyObject_Vectorcall(target, args, nargs, kwnames);
}
|
|
|
|
|
// tp_dealloc slot: releases everything a gc_func instance owns and frees it.
//
// Steps, in order:
//   1. Stop GC tracking so the collector cannot traverse a half-destroyed
//      object.
//   2. Release the strong reference to the wrapped callable (may already be
//      null if tp_clear ran).
//   3. Destroy the `deps` vector. The object's storage comes from Python's
//      allocator, so no C++ destructor runs automatically; without this the
//      vector's heap buffer would leak on every deallocation. The pointers
//      stored in it are not reference-counted by this object and are
//      therefore not decref'd.
//   4. Free the object's memory.
//   5. Release the instance's reference to its type. The type is created with
//      PyType_FromSpec, i.e. it is a heap type, and CPython requires a custom
//      tp_dealloc of a heap type to decrement the type's reference count.
void gc_func_dealloc(PyObject* self) {
  // Read the type before the object's memory is released.
  PyTypeObject* tp = Py_TYPE(self);
  gc_func* w = reinterpret_cast<gc_func*>(self);

  PyObject_GC_UnTrack(self);
  Py_XDECREF(w->func);

  w->deps.~vector();

  PyObject_GC_Del(self);
  Py_DECREF(tp);
}
|
|
|
|
|
// Member table for the gc_func type. The only entry is the special
// "__vectorcalloffset__" member, which records where the vectorcall function
// pointer is located inside gc_func.
static PyMemberDef gc_func_members[] = {
    {"__vectorcalloffset__",
     T_PYSSIZET,
     static_cast<Py_ssize_t>(offsetof(gc_func, vectorcall)),
     READONLY,
     nullptr},
    // Sentinel terminating the table.
    {nullptr, 0, 0, 0, nullptr},
};
|
|
|
|
|
// Read-only computed attributes of the gc_func type. Both getters forward
// to the wrapped callable; no setters are provided.
static PyGetSetDef gc_func_getset[] = {
    {"__doc__", gc_func_get_doc, nullptr, nullptr, nullptr},
    {"__nb_signature__", gc_func_get_sig, nullptr, nullptr, nullptr},
    // Sentinel terminating the table.
    {nullptr, nullptr, nullptr, nullptr, nullptr},
};
|
|
|
|
|
// tp_getattro slot: resolves the requested attribute against `orig_func`
// instead of the wrapper itself, using Python's generic attribute lookup.
static PyObject* gc_func_getattro(PyObject* self, PyObject* name_) {
  PyObject* lookup_target = reinterpret_cast<gc_func*>(self)->orig_func;
  return PyObject_GenericGetAttr(lookup_target, name_);
}
|
|
|
|
|
|
|
|
// Slot table handed to PyType_FromSpec through gc_func_spec.
PyType_Slot gc_func_slots[] = {
    // Garbage-collection support.
    {Py_tp_traverse, reinterpret_cast<void*>(gc_func_tp_traverse)},
    {Py_tp_clear, reinterpret_cast<void*>(gc_func_tp_clear)},
    // Attribute access.
    {Py_tp_getset, reinterpret_cast<void*>(gc_func_getset)},
    {Py_tp_getattro, reinterpret_cast<void*>(gc_func_getattro)},
    {Py_tp_members, reinterpret_cast<void*>(gc_func_members)},
    // tp_call is routed through the vectorcall pointer stored per instance.
    {Py_tp_call, reinterpret_cast<void*>(PyVectorcall_Call)},
    // Destruction.
    {Py_tp_dealloc, reinterpret_cast<void*>(gc_func_dealloc)},
    // Sentinel terminating the table.
    {0, nullptr},
};
|
|
|
|
|
// Type specification used by init_mlx_func() to create the "mlx.gc_func"
// heap type.
static PyType_Spec gc_func_spec = {
    /* name      */ "mlx.gc_func",
    /* basicsize */ static_cast<int>(sizeof(gc_func)),
    /* itemsize  */ 0,
    // NB_HAVE_VECTORCALL is supplied by nanobind; presumably it maps to the
    // interpreter's vectorcall type flag (confirm against the nanobind
    // headers).
    /* flags     */ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | NB_HAVE_VECTORCALL,
    /* slots     */ gc_func_slots,
};
|
|
|
|
|
// The registered "mlx.gc_func" type object. Null until init_mlx_func() has
// run; mlx_func() allocates instances of it.
static PyTypeObject* gc_func_tp{nullptr};
|
|
|
|
|
// Wraps `func` in a new gc_func instance and returns it as a callable.
//
// Parameters:
//   func      - the callable that calls are forwarded to; the wrapper takes
//               its own strong reference to it.
//   orig_func - object that attribute lookups on the wrapper are forwarded
//               to. Stored without incrementing its reference count and also
//               appended to `deps`.
//   deps      - objects to report to the garbage collector when the wrapper
//               is traversed. Stored without incrementing reference counts.
//
// Returns: a new reference to the wrapper, owned by the returned handle.
//
// Throws (via nb::raise) if the gc_func type has not been registered with
// init_mlx_func() or if the instance cannot be allocated.
nb::callable mlx_func(
    nb::object func,
    const nb::callable& orig_func,
    std::vector<PyObject*> deps) {
  if (gc_func_tp == nullptr) {
    nb::raise("MLX function type has not been registered.");
  }

  // Grow the vector before allocating the Python object so that a throwing
  // push_back cannot leave a half-initialised instance behind.
  PyObject* orig = orig_func.ptr();
  deps.push_back(orig);

  gc_func* r = reinterpret_cast<gc_func*>(PyType_GenericAlloc(gc_func_tp, 0));
  if (r == nullptr) {
    nb::raise("Could not allocate MLX function object.");
  }

  // The storage returned by PyType_GenericAlloc is raw zeroed memory: no
  // std::vector has been constructed in it yet, so the member has to be
  // constructed in place rather than assigned to.
  ::new (static_cast<void*>(&r->deps)) std::vector<PyObject*>(std::move(deps));

  r->func = func.inc_ref().ptr();
  r->orig_func = orig;
  r->vectorcall = gc_func_vectorcall;

  return nb::steal<nb::callable>(reinterpret_cast<PyObject*>(r));
}
|
|
|
|
|
// Creates the "mlx.gc_func" type from gc_func_spec and stores it in
// gc_func_tp. Raises through nb::raise if the type cannot be created. The
// module argument is not used by this function.
void init_mlx_func(nb::module_& m) {
  PyObject* created_type = PyType_FromSpec(&gc_func_spec);
  gc_func_tp = reinterpret_cast<PyTypeObject*>(created_type);

  if (gc_func_tp == nullptr) {
    nb::raise("Could not register MLX function type.");
  }
}
|
|
|