File size: 11,261 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
// Copyright © 2024 Apple Inc.
#include <nanobind/nanobind.h>
#include <nanobind/stl/optional.h>
#include <nanobind/stl/string.h>
#include <nanobind/stl/unordered_map.h>
#include <nanobind/stl/vector.h>

#include <fstream>
#include <sstream>
#include <stdexcept>

#include "mlx/array.h"
#include "mlx/export.h"
#include "mlx/graph_utils.h"
#include "python/src/small_vector.h"
#include "python/src/trees.h"

namespace mx = mlx::core;
namespace nb = nanobind;
using namespace nb::literals;

/// Normalize Python call arguments into the `(mx::Args, mx::Kwargs)` pair
/// used by the C++ export API.
///
/// Accepted call forms:
///   - f(**arrays)            : no positional args, keyword arrays only.
///   - f(array, ..., **arrays): positional arrays (the first positional
///                              argument is an array) plus keyword arrays.
///   - f(seq_or_dict)         : a single tuple/list of arrays, or a single
///                              dict with string keys and array values;
///                              no keyword arguments allowed.
///   - f(seq, dict)           : a tuple/list of arrays followed by a dict
///                              with string keys and array values; no
///                              keyword arguments allowed.
///
/// @param args    Positional arguments as received from Python.
/// @param kwargs  Keyword arguments as received from Python.
/// @param prefix  Tag prepended to the error message, e.g. "[export_function]".
/// @return The extracted positional and keyword arrays.
/// @throws std::invalid_argument if the arguments match none of the forms.
std::pair<mx::Args, mx::Kwargs> validate_and_extract_inputs(
    const nb::args& args,
    const nb::kwargs& kwargs,
    const std::string& prefix) {
  auto maybe_throw = [&prefix](bool valid) {
    if (!valid) {
      throw std::invalid_argument(
          prefix +
          " Inputs can either be a variable "
          "number of positional and keyword arrays or a single tuple "
          "and/or dictionary of arrays.");
    }
  };
  mx::Args args_;
  mx::Kwargs kwargs_;
  if (args.size() == 0) {
    // No args so kwargs must be keyword arrays
    maybe_throw(nb::try_cast(kwargs, kwargs_));
  } else if (nb::isinstance<mx::array>(args[0])) {
    // args is non-empty here (the size-0 case is handled above).
    // Args are positional arrays and kwargs are keyword arrays
    maybe_throw(nb::try_cast(args, args_));
    maybe_throw(nb::try_cast(kwargs, kwargs_));
  } else if (args.size() == 1) {
    // - args[0] can be a tuple or list of arrays, or a dict
    //   with string keys and array values
    // - kwargs should be empty
    maybe_throw(kwargs.size() == 0);
    if (!nb::try_cast(args[0], args_)) {
      maybe_throw(nb::try_cast(args[0], kwargs_));
    }
  } else if (args.size() == 2) {
    // - args[0] can be a tuple or list of arrays
    // - args[1] can be a dict of string keys with array values.
    // - kwargs should be empty
    maybe_throw(kwargs.size() == 0);
    maybe_throw(nb::try_cast(args[0], args_));
    maybe_throw(nb::try_cast(args[1], kwargs_));
  } else {
    maybe_throw(false);
  }
  // The locals are not used again, so hand them over instead of copying.
  return {std::move(args_), std::move(kwargs_)};
}

// Forward declaration so that PyFunctionExporter can befriend this GC
// traversal hook. The hook is defined after the class and reads its private
// `dep_` member.
int py_function_exporter_tp_traverse(
    PyObject* self,
    visitproc visit,
    void* arg);

/// Python-facing wrapper around `mx::FunctionExporter`, bound as the
/// `FunctionExporter` class. Instances are move-constructible but neither
/// copyable nor assignable.
///
/// `dep_` is a non-owning handle to the exported Python callable. It exists
/// only so that `py_function_exporter_tp_traverse` can report the callable to
/// Python's cyclic garbage collector. The owning reference is presumably held
/// inside `exporter_` (via the closure passed to `mx::exporter`); TODO confirm.
class PyFunctionExporter {
 public:
  PyFunctionExporter(mx::FunctionExporter exporter, nb::handle dep)
      : exporter_(std::move(exporter)), dep_(dep) {}

  ~PyFunctionExporter() {
    // NOTE(review): this guard is released when the destructor body ends,
    // which is *before* `exporter_` is destroyed. It therefore does not
    // cover member destruction on its own. Presumably the caller (nanobind's
    // instance deallocation) already holds the GIL; confirm.
    nb::gil_scoped_acquire gil;
  }

  PyFunctionExporter(const PyFunctionExporter&) = delete;
  PyFunctionExporter& operator=(const PyFunctionExporter&) = delete;
  PyFunctionExporter& operator=(PyFunctionExporter&&) = delete;

  // Transfers the underlying exporter. `dep_` is a plain non-owning handle.
  PyFunctionExporter(PyFunctionExporter&& other)
      : exporter_(std::move(other.exporter_)), dep_(std::move(other.dep_)) {}

  /// Forward to the underlying exporter's `close()`.
  void close() {
    exporter_.close();
  }

  /// Forward a set of example inputs to the underlying exporter.
  void operator()(const mx::Args& args, const mx::Kwargs& kwargs) {
    exporter_(args, kwargs);
  }

  friend int py_function_exporter_tp_traverse(PyObject*, visitproc, void*);

 private:
  mx::FunctionExporter exporter_;
  nb::handle dep_;
};

/// `tp_traverse` slot for `FunctionExporter` instances, called by Python's
/// cyclic garbage collector.
///
/// Always reports the instance's type object. Once nanobind reports the
/// instance as ready, it also reports the callable referenced by `dep_`.
int py_function_exporter_tp_traverse(
    PyObject* self,
    visitproc visit,
    void* arg) {
  // Instances of heap types must report their type object to the GC.
  Py_VISIT(Py_TYPE(self));
  // Only touch the C++ payload once nanobind has finished constructing it.
  if (nb::inst_ready(self)) {
    PyFunctionExporter* wrapper = nb::inst_ptr<PyFunctionExporter>(self);
    Py_VISIT(wrapper->dep_.ptr());
  }
  return 0;
}

// Extra CPython type slots installed on the `FunctionExporter` Python type.
// The array ends with an all-zero sentinel entry.
PyType_Slot py_function_exporter_slots[] = {
    {Py_tp_traverse,
     reinterpret_cast<void*>(&py_function_exporter_tp_traverse)},
    {0, nullptr},
};

/// Adapt a Python callable to the C++ callback shape used by the export API.
///
/// The returned closure takes ownership of `fun`. When invoked with
/// positional and keyword arrays, it calls `fun(*args, **kwargs)` in Python
/// and converts the result to `std::vector<mx::array>`.
///
/// The Python function may return either a single `array` or a sequence that
/// nanobind can convert to a vector of arrays (e.g. a tuple or list).
/// @throws std::invalid_argument (from the closure) for any other return value.
auto wrap_export_function(nb::callable fun) {
  return
      [fun = std::move(fun)](const mx::Args& args_, const mx::Kwargs& kwargs_) {
        // Rebuild Python-level *args / **kwargs from the C++ containers.
        auto kwargs = nb::dict();
        kwargs.update(nb::cast(kwargs_));
        auto args = nb::tuple(nb::cast(args_));
        auto outputs = fun(*args, **kwargs);

        std::vector<mx::array> outputs_;
        if (nb::isinstance<mx::array>(outputs)) {
          outputs_.push_back(nb::cast<mx::array>(outputs));
        } else if (!nb::try_cast(outputs, outputs_)) {
          throw std::invalid_argument(
              "[export_function] Outputs can be either a single array "
              "or a tuple or list of arrays.");
        }
        return outputs_;
      };
}

/// Register the export-related bindings on module `m`: `export_function`,
/// `import_function`, the `FunctionExporter` class, `exporter`, and
/// `export_to_dot`.
void init_export(nb::module_& m) {
  m.def(
      "export_function",
      [](const std::string& file,
         const nb::callable& fun,
         const nb::args& args,
         bool shapeless,
         const nb::kwargs& kwargs) {
        auto [args_, kwargs_] =
            validate_and_extract_inputs(args, kwargs, "[export_function]");
        mx::export_function(
            file, wrap_export_function(fun), args_, kwargs_, shapeless);
      },
      "file"_a,
      "fun"_a,
      "args"_a,
      nb::kw_only(),
      "shapeless"_a = false,
      "kwargs"_a,
      R"pbdoc(
        Export a function to a file.

        Example input arrays must be provided to export a function. The example
        inputs can be variable ``*args`` and ``**kwargs`` or a tuple of arrays
        and/or dictionary of string keys with array values.

        .. warning::

          This is part of an experimental API which is likely to
          change in future versions of MLX. Functions exported with older
          versions of MLX may not be compatible with future versions.

        Args:
            file (str): File path to export the function to.
            fun (Callable): A function which takes as input zero or more
              :class:`array` and returns one or more :class:`array`.
            *args (array): Example array inputs to the function.
            shapeless (bool, optional): Whether or not the function allows
              inputs with variable shapes. Default: ``False``.
            **kwargs (array): Additional example keyword array inputs to the
              function.

        Example:

          .. code-block:: python

            def fun(x, y):
                return x + y

            x = mx.array(1)
            y = mx.array([1, 2, 3])
            mx.export_function("fun.mlxfn", fun, x, y=y)
      )pbdoc");
  m.def(
      "import_function",
      [](const std::string& file) {
        // The imported function is captured by value in the returned
        // Python callable; every call re-validates its inputs.
        return nb::cpp_function(
            [fn = mx::import_function(file)](
                const nb::args& args, const nb::kwargs& kwargs) {
              auto [args_, kwargs_] = validate_and_extract_inputs(
                  args, kwargs, "[import_function::call]");
              return nb::tuple(nb::cast(fn(args_, kwargs_)));
            });
      },
      "file"_a,
      nb::sig("def import_function(file: str) -> Callable"),
      R"pbdoc(
        Import a function from a file.

        The imported function can be called either with ``*args`` and
        ``**kwargs`` or with a tuple of arrays and/or dictionary of string
        keys with array values. Imported functions always return a tuple of
        arrays.

        .. warning::

          This is part of an experimental API which is likely to
          change in future versions of MLX. Functions exported with older
          versions of MLX may not be compatible with future versions.

        Args:
            file (str): The file path to import the function from.

        Returns:
            Callable: The imported function.

        Example:
          >>> fn = mx.import_function("function.mlxfn")
          >>> out = fn(a, b, x=x, y=y)[0]
          >>>
          >>> out = fn((a, b), {"x": x, "y": y})[0]
      )pbdoc");

  nb::class_<PyFunctionExporter>(
      m,
      "FunctionExporter",
      nb::type_slots(py_function_exporter_slots),
      R"pbdoc(
       A context managing class for exporting multiple traces of the same
       function to a file.

       Make an instance of this class by calling :func:`exporter`.
      )pbdoc")
      .def("close", &PyFunctionExporter::close)
      .def("__enter__", [](PyFunctionExporter& exporter) { return &exporter; })
      .def(
          "__exit__",
          [](PyFunctionExporter& exporter,
             const std::optional<nb::object>&,
             const std::optional<nb::object>&,
             const std::optional<nb::object>&) { exporter.close(); },
          "exc_type"_a = nb::none(),
          "exc_value"_a = nb::none(),
          "traceback"_a = nb::none())
      .def(
          "__call__",
          [](PyFunctionExporter& exporter,
             const nb::args& args,
             const nb::kwargs& kwargs) {
            auto [args_, kwargs_] =
                validate_and_extract_inputs(args, kwargs, "[export_function]");
            exporter(args_, kwargs_);
          });

  m.def(
      "exporter",
      [](const std::string& file, nb::callable fun, bool shapeless) {
        // `fun` is also passed as the wrapper's GC-visible dependency.
        return PyFunctionExporter{
            mx::exporter(file, wrap_export_function(fun), shapeless), fun};
      },
      "file"_a,
      "fun"_a,
      nb::kw_only(),
      "shapeless"_a = false,
      R"pbdoc(
        Make a callable object to export multiple traces of a function to a file.

        .. warning::

          This is part of an experimental API which is likely to
          change in future versions of MLX. Functions exported with older
          versions of MLX may not be compatible with future versions.

        Args:
            file (str): File path to export the function to.
            fun (Callable): A function which takes as input zero or more
              :class:`array` and returns one or more :class:`array`.
            shapeless (bool, optional): Whether or not the function allows
              inputs with variable shapes. Default: ``False``.

        Example:

          .. code-block:: python

            def fun(*args):
                return sum(args)

            with mx.exporter("fun.mlxfn", fun) as exporter:
                exporter(mx.array(1))
                exporter(mx.array(1), mx.array(2))
                exporter(mx.array(1), mx.array(2), mx.array(3))
      )pbdoc");
  m.def(
      "export_to_dot",
      [](nb::object file, const nb::args& args, const nb::kwargs& kwargs) {
        std::vector<mx::array> arrays =
            tree_flatten(nb::make_tuple(args, kwargs));
        // Keyword arguments double as human-readable node names.
        mx::NodeNamer namer;
        for (const auto& n : kwargs) {
          namer.set_name(
              nb::cast<mx::array>(n.second), nb::cast<std::string>(n.first));
        }
        if (nb::isinstance<nb::str>(file)) {
          auto path = nb::cast<std::string>(file);
          std::ofstream out(path);
          // Fail loudly instead of silently producing no output.
          if (!out) {
            throw std::runtime_error(
                "[export_to_dot] Unable to open file " + path +
                " for writing.");
          }
          mx::export_to_dot(out, std::move(namer), arrays);
        } else if (nb::hasattr(file, "write")) {
          // Render into memory, then hand the text to the file-like object.
          std::ostringstream out;
          mx::export_to_dot(out, std::move(namer), arrays);
          auto write = file.attr("write");
          write(out.str());
        } else {
          throw std::invalid_argument(
              "[export_to_dot] Accepts file-like objects or strings "
              "to be used as filenames.");
        }
      },
      "file"_a,
      "args"_a,
      "kwargs"_a,
      R"pbdoc(
        Export a graph to DOT format for visualization.

        A variable number of output arrays can be provided for exporting.
        The graph exported will recursively include all unevaluated inputs of
        the provided outputs.

        Args:
            file (str or file-like): The file path to export to, or an object
              with a ``write`` method which receives the DOT text.
            *args (array): The output arrays.
            **kwargs (dict[str, array]): Provide some names for arrays in the
              graph to make the result easier to parse.

        Example:
          >>> a = mx.array(1) + mx.array(2)
          >>> mx.export_to_dot("graph.dot", a)
          >>> x = mx.array(1)
          >>> y = mx.array(2)
          >>> mx.export_to_dot("graph.dot", x + y, x=x, y=y)
      )pbdoc");
}