File size: 3,341 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
// Copyright © 2025 Apple Inc.

#include "python/src/mlx_func.h"

#include <new>

// A garbage collected function which wraps nb::cpp_function
// See https://github.com/wjakob/nanobind/discussions/919

// Instance layout of the Python-visible wrapper type. Instances are created
// with PyType_GenericAlloc (see mlx_func), so the storage starts out as raw
// zeroed memory rather than a constructed C++ object.
// NOTE(review): because no C++ constructor/destructor runs automatically for
// this struct, the lifetime of the non-trivial 'deps' member has to be
// managed by hand wherever instances are created and freed.
struct gc_func {
  PyObject_HEAD
      // Vector call entry point; its offset inside the struct is published
      // through the "__vectorcalloffset__" member so the interpreter can
      // find it. Set to gc_func_vectorcall when an instance is created.
      PyObject* (*vectorcall)(PyObject*, PyObject* const*, size_t, PyObject*);
  // The nanobind wrapper func. This is the only owning (strong) reference
  // held by the struct: it is incref'ed on creation and released on
  // clear/dealloc.
  PyObject* func;

  // The original wrapped func. Borrowed pointer (no incref is taken);
  // attribute lookups on the wrapper are forwarded to it.
  PyObject* orig_func;
  // Non-owning references to dependencies owned by 'func'. They are
  // reported to the garbage collector during traversal; 'orig_func' is
  // appended to this list when the wrapper is created.
  std::vector<PyObject*> deps;
};

// tp_traverse hook: hands every object this wrapper can reach to the
// collector's 'visit' callback -- its type, the wrapped nanobind callable and
// each borrowed dependency pointer. Py_VISIT skips null pointers and returns
// early from this function if the callback reports a non-zero status.
// Returns 0 when every visit succeeded.
int gc_func_tp_traverse(PyObject* self, visitproc visit, void* arg) {
  auto* wrapper = reinterpret_cast<gc_func*>(self);
  Py_VISIT(Py_TYPE(self));
  Py_VISIT(wrapper->func);
  for (PyObject* dep : wrapper->deps) {
    Py_VISIT(dep);
  }
  return 0;
}

// tp_clear hook: breaks reference cycles by dropping the one strong reference
// the wrapper holds (the nanobind callable). Always returns 0.
int gc_func_tp_clear(PyObject* self) {
  auto* w = reinterpret_cast<gc_func*>(self);
  // The dependency pointers are borrowed from 'func'. Once 'func' is released
  // they may point at freed objects, and a later tp_traverse on a wrapper
  // that survives clearing would hand those stale pointers to the collector.
  // Forget them first; this has no refcount effect since they are non-owning.
  w->deps.clear();
  // NOTE(review): 'orig_func' is also borrowed and is left untouched here
  // because the attribute hook dereferences it unconditionally.
  Py_CLEAR(w->func);
  return 0;
}

// Getter for "__doc__": looks the attribute up on the wrapped nanobind
// callable and returns whatever that lookup yields (the closure argument is
// unused).
PyObject* gc_func_get_doc(PyObject* self, void*) {
  auto* wrapper = reinterpret_cast<gc_func*>(self);
  PyObject* wrapped = wrapper->func;
  return PyObject_GetAttrString(wrapped, "__doc__");
}

// Getter for "__nb_signature__": looks the attribute up on the wrapped
// nanobind callable and returns whatever that lookup yields (the closure
// argument is unused).
PyObject* gc_func_get_sig(PyObject* self, void*) {
  auto* wrapper = reinterpret_cast<gc_func*>(self);
  PyObject* wrapped = wrapper->func;
  return PyObject_GetAttrString(wrapped, "__nb_signature__");
}

// Vector call entry point of the wrapper: passes the argument array, the
// argument count word and the keyword names through unchanged to the wrapped
// nanobind callable and returns its result.
PyObject* gc_func_vectorcall(
    PyObject* self,
    PyObject* const* args,
    size_t nargs,
    PyObject* kwnames) {
  PyObject* target = reinterpret_cast<gc_func*>(self)->func;
  return PyObject_Vectorcall(target, args, nargs, kwnames);
}

// tp_dealloc hook: tears down a wrapper instance and returns its memory to
// the GC allocator.
void gc_func_dealloc(PyObject* self) {
  // Instances of heap types (created via PyType_FromSpec) hold a reference to
  // their type, taken at allocation. Remember the type now; it is released
  // only after the instance memory has been freed.
  PyTypeObject* tp = Py_TYPE(self);
  auto* w = reinterpret_cast<gc_func*>(self);

  // Stop the collector from traversing a half-destroyed object.
  PyObject_GC_UnTrack(self);

  // The instance memory is freed with PyObject_GC_Del, which knows nothing
  // about C++ members, so the vector must be destroyed by hand or its heap
  // buffer leaks. The pointers it stores are non-owning, so no decref here.
  w->deps.~vector();

  // Drop the single strong reference held by the wrapper.
  Py_XDECREF(w->func);

  PyObject_GC_Del(self);
  Py_DECREF(tp);
}

// Member table for the wrapper type. Its single real entry publishes the byte
// offset of gc_func::vectorcall under the name "__vectorcalloffset__"; the
// table ends with an all-null sentinel entry.
static PyMemberDef gc_func_members[] = {
    {"__vectorcalloffset__",
     T_PYSSIZET,
     static_cast<Py_ssize_t>(offsetof(gc_func, vectorcall)),
     READONLY,
     nullptr},
    // Sentinel
    {nullptr, 0, 0, 0, nullptr}};

// Read-only computed attributes of the wrapper type (no setters are
// installed). Both getters read the corresponding attribute from the wrapped
// nanobind callable. The table ends with an all-null sentinel entry.
static PyGetSetDef gc_func_getset[] = {
    // Docstring of the wrapped callable
    {"__doc__", gc_func_get_doc, nullptr, nullptr, nullptr},
    // nanobind signature of the wrapped callable
    {"__nb_signature__", gc_func_get_sig, nullptr, nullptr, nullptr},
    // Sentinel
    {nullptr, nullptr, nullptr, nullptr, nullptr}};

// tp_getattro hook: attribute lookups on the wrapper are answered by running
// the generic attribute lookup on the original wrapped function instead of on
// the wrapper itself.
static PyObject* gc_func_getattro(PyObject* self, PyObject* name_) {
  PyObject* target = reinterpret_cast<gc_func*>(self)->orig_func;
  return PyObject_GenericGetAttr(target, name_);
}

// Type slots installed on the wrapper type: GC support (traverse/clear),
// computed attributes, attribute forwarding, the vectorcall offset member,
// a tp_call that routes through the vectorcall entry point, and the
// deallocator. The list is terminated by a zero slot.
PyType_Slot gc_func_slots[] = {
    {Py_tp_traverse, reinterpret_cast<void*>(gc_func_tp_traverse)},
    {Py_tp_clear, reinterpret_cast<void*>(gc_func_tp_clear)},
    {Py_tp_getset, static_cast<void*>(gc_func_getset)},
    {Py_tp_getattro, reinterpret_cast<void*>(gc_func_getattro)},
    {Py_tp_members, static_cast<void*>(gc_func_members)},
    {Py_tp_call, reinterpret_cast<void*>(PyVectorcall_Call)},
    {Py_tp_dealloc, reinterpret_cast<void*>(gc_func_dealloc)},
    {0, nullptr}};

// Specification used to create the wrapper type: fixed-size instances of
// sizeof(gc_func) bytes, GC-enabled, with vectorcall support, using the
// slots listed in gc_func_slots.
static PyType_Spec gc_func_spec = {
    /* .name = */ "mlx.gc_func",
    /* .basicsize = */ static_cast<int>(sizeof(gc_func)),
    /* .itemsize = */ 0,
    /* .flags = */ NB_HAVE_VECTORCALL | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_DEFAULT,
    /* .slots = */ gc_func_slots};

// The wrapper type object; null until init_mlx_func has created it.
static PyTypeObject* gc_func_tp{nullptr};

// Wraps a nanobind callable in a garbage-collector-aware Python object.
//
//   func      - the nanobind callable to forward calls to; the wrapper takes
//               its own strong reference to it.
//   orig_func - the original wrapped function; stored as a borrowed pointer
//               and used to answer attribute lookups on the wrapper.
//   deps      - borrowed pointers to objects kept alive by 'func'; they are
//               reported to the garbage collector on behalf of 'func'.
//               'orig_func' is appended to this list.
//
// Returns a new callable owning the wrapper object. Throws if the wrapper
// type has not been registered yet or if the allocation fails.
nb::callable mlx_func(
    nb::object func,
    const nb::callable& orig_func,
    std::vector<PyObject*> deps) {
  if (gc_func_tp == nullptr) {
    nb::raise("MLX function type is not registered.");
  }

  // Grow the dependency list before allocating the Python object so that a
  // std::bad_alloc here cannot strand a partially initialized instance.
  deps.push_back(orig_func.ptr());

  auto* r = reinterpret_cast<gc_func*>(PyType_GenericAlloc(gc_func_tp, 0));
  if (r == nullptr) {
    // Allocation failed and the Python error indicator is set; propagate it.
    throw nb::python_error();
  }

  // PyType_GenericAlloc returns zeroed memory without running any C++
  // constructor, so the vector member is constructed in place rather than
  // assigned to.
  new (&r->deps) std::vector<PyObject*>(std::move(deps));

  r->func = func.inc_ref().ptr();
  r->orig_func = orig_func.ptr();
  r->vectorcall = gc_func_vectorcall;
  return nb::steal<nb::callable>(reinterpret_cast<PyObject*>(r));
}

// Creates the wrapper type from gc_func_spec and stores it in gc_func_tp.
// Raises if the type could not be created. The module argument is not used
// by this function.
void init_mlx_func(nb::module_& m) {
  PyObject* type_obj = PyType_FromSpec(&gc_func_spec);
  gc_func_tp = reinterpret_cast<PyTypeObject*>(type_obj);
  if (gc_func_tp == nullptr) {
    nb::raise("Could not register MLX function type.");
  }
}