| | |
| |
|
#include "python/src/mlx_func.h"

#include <new>
#include <vector>
| |
|
| | |
| | |
| |
|
// Python object that wraps a callable (`func`), forwards calls to it through
// the vectorcall protocol, and reports extra objects to the cyclic GC.
struct gc_func {
  PyObject_HEAD

  // Vectorcall entry point. Its offset is published through the
  // "__vectorcalloffset__" member, so CPython dispatches calls through it.
  PyObject* (*vectorcall)(PyObject*, PyObject* const*, size_t, PyObject*);

  // The wrapped callable that calls are forwarded to. Owned: a strong
  // reference is stored by mlx_func and released in tp_clear / dealloc.
  PyObject* func;

  // The callable that attribute lookups are delegated to. Non-owning: no
  // reference count is taken or released for it in this file.
  // NOTE(review): presumably kept alive by `func`; confirm at call sites.
  PyObject* orig_func;

  // Extra objects reported to the GC by tp_traverse (orig_func is appended
  // by mlx_func). Non-owning: no references are taken or released for them.
  std::vector<PyObject*> deps;
};
| |
|
| | int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) { |
| | Py_VISIT(Py_TYPE(self)); |
| | gc_func* w = (gc_func*)self; |
| | Py_VISIT(w->func); |
| | for (auto d : w->deps) { |
| | Py_VISIT(d); |
| | } |
| | return 0; |
| | }; |
| |
|
| | int gc_func_tp_clear(PyObject* self) { |
| | gc_func* w = (gc_func*)self; |
| | Py_CLEAR(w->func); |
| | return 0; |
| | } |
| |
|
| | PyObject* gc_func_get_doc(PyObject* self, void*) { |
| | return PyObject_GetAttrString(((gc_func*)self)->func, "__doc__"); |
| | } |
| |
|
| | PyObject* gc_func_get_sig(PyObject* self, void*) { |
| | return PyObject_GetAttrString(((gc_func*)self)->func, "__nb_signature__"); |
| | } |
| |
|
| | PyObject* gc_func_vectorcall( |
| | PyObject* self, |
| | PyObject* const* args, |
| | size_t nargs, |
| | PyObject* kwnames) { |
| | return PyObject_Vectorcall(((gc_func*)self)->func, args, nargs, kwnames); |
| | } |
| |
|
| | void gc_func_dealloc(PyObject* self) { |
| | PyObject_GC_UnTrack(self); |
| | Py_XDECREF(((gc_func*)self)->func); |
| | PyObject_GC_Del(self); |
| | } |
| |
|
// Members of gc_func. The special "__vectorcalloffset__" entry is how a type
// created with PyType_FromSpec publishes tp_vectorcall_offset. Here it points
// at gc_func::vectorcall, so CPython dispatches calls to gc_func_vectorcall.
static PyMemberDef gc_func_members[] = {
    {"__vectorcalloffset__",
     T_PYSSIZET,
     (Py_ssize_t)offsetof(gc_func, vectorcall),
     READONLY,
     nullptr},
    // Sentinel terminating the table.
    {nullptr, 0, 0, 0, nullptr}};
| |
|
// Read-only getters for "__doc__" and "__nb_signature__". Both read the
// attribute from the wrapped `func`.
// NOTE(review): the type's tp_getattro (gc_func_getattro) delegates instance
// attribute lookups to orig_func, so these descriptors appear to be reached
// only when looked up on the type itself. Confirm whether they are still
// needed.
static PyGetSetDef gc_func_getset[] = {
    {"__doc__", gc_func_get_doc, nullptr, nullptr, nullptr},
    {"__nb_signature__", gc_func_get_sig, nullptr, nullptr, nullptr},
    // Sentinel terminating the table.
    {nullptr, nullptr, nullptr, nullptr, nullptr}};
| |
|
| | static PyObject* gc_func_getattro(PyObject* self, PyObject* name_) { |
| | gc_func* w = (gc_func*)self; |
| | return PyObject_GenericGetAttr(w->orig_func, name_); |
| | } |
| |
|
| | |
// Slot table for the heap type built from gc_func_spec.
// - tp_traverse / tp_clear: cyclic-GC support.
// - tp_getset / tp_members: the "__doc__"/"__nb_signature__" getters and the
//   "__vectorcalloffset__" member.
// - tp_getattro: delegates attribute lookups to orig_func.
// - tp_call: PyVectorcall_Call, which routes tuple/dict calls through the
//   instance's vectorcall pointer.
// - tp_dealloc: the destructor.
PyType_Slot gc_func_slots[] = {
    {Py_tp_traverse, (void*)gc_func_tp_traverse},
    {Py_tp_clear, (void*)gc_func_tp_clear},
    {Py_tp_getset, (void*)gc_func_getset},
    {Py_tp_getattro, (void*)gc_func_getattro},
    {Py_tp_members, (void*)gc_func_members},
    {Py_tp_call, (void*)PyVectorcall_Call},
    {Py_tp_dealloc, (void*)gc_func_dealloc},
    // Sentinel terminating the table.
    {0, 0}};
| |
|
// Spec for the "mlx.gc_func" type.
// Fields in order: name, basic size, item size (0 = not variable-sized),
// flags, slots.
// Flags: GC-enabled (Py_TPFLAGS_HAVE_GC) and vectorcall-capable
// (NB_HAVE_VECTORCALL — NOTE(review): presumably nanobind's alias for
// Py_TPFLAGS_HAVE_VECTORCALL).
static PyType_Spec gc_func_spec = {
    "mlx.gc_func",
    (int)sizeof(gc_func),
    0,
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC | NB_HAVE_VECTORCALL,
    gc_func_slots};

// Type object created from gc_func_spec by init_mlx_func. mlx_func allocates
// its instances from it, so it must be set before mlx_func is called.
static PyTypeObject* gc_func_tp = nullptr;
| |
|
| | nb::callable mlx_func( |
| | nb::object func, |
| | const nb::callable& orig_func, |
| | std::vector<PyObject*> deps) { |
| | gc_func* r = (gc_func*)PyType_GenericAlloc(gc_func_tp, 0); |
| | r->func = func.inc_ref().ptr(); |
| | r->orig_func = orig_func.ptr(); |
| | deps.push_back(r->orig_func); |
| | r->deps = std::move(deps); |
| | r->vectorcall = gc_func_vectorcall; |
| | return nb::steal<nb::callable>((PyObject*)r); |
| | } |
| |
|
| | void init_mlx_func(nb::module_& m) { |
| | gc_func_tp = (PyTypeObject*)PyType_FromSpec(&gc_func_spec); |
| | if (!gc_func_tp) { |
| | nb::raise("Could not register MLX function type."); |
| | } |
| | } |
| |
|