|
|
|
|
|
|
|
|
#include <algorithm> |
|
|
#include <numeric> |
|
|
#include <sstream> |
|
|
#include <unordered_set> |
|
|
|
|
|
#include <nanobind/nanobind.h> |
|
|
#include <nanobind/stl/optional.h> |
|
|
#include <nanobind/stl/pair.h> |
|
|
#include <nanobind/stl/string.h> |
|
|
#include <nanobind/stl/unordered_set.h> |
|
|
#include <nanobind/stl/variant.h> |
|
|
#include <nanobind/stl/vector.h> |
|
|
|
|
|
#include "mlx/array.h" |
|
|
#include "mlx/compile.h" |
|
|
#include "mlx/compile_impl.h" |
|
|
#include "mlx/transforms.h" |
|
|
#include "mlx/transforms_impl.h" |
|
|
#include "mlx/utils.h" |
|
|
#include "python/src/mlx_func.h" |
|
|
#include "python/src/small_vector.h" |
|
|
#include "python/src/trees.h" |
|
|
|
|
|
namespace mx = mlx::core; |
|
|
namespace nb = nanobind; |
|
|
using namespace nb::literals; |
|
|
|
|
|
|
|
|
using mx::operator<<; |
|
|
|
|
|
using IntOrVec = std::variant<int, std::vector<int>>; |
|
|
using StrOrSet = std::variant<std::string, std::unordered_set<std::string>>; |
|
|
|
|
|
inline std::string type_name_str(const nb::handle& o) { |
|
|
return nb::cast<std::string>(nb::type_name(o.type())); |
|
|
} |
|
|
|
|
|
auto validate_argnums_argnames( |
|
|
const std::optional<IntOrVec>& argnums, |
|
|
const StrOrSet& argnames) { |
|
|
std::unordered_set<std::string> setnames; |
|
|
if (auto pv = std::get_if<std::string>(&argnames); pv) { |
|
|
setnames = {*pv}; |
|
|
} else { |
|
|
setnames = std::get<std::unordered_set<std::string>>(argnames); |
|
|
} |
|
|
|
|
|
if (!argnums.has_value()) { |
|
|
|
|
|
if (setnames.empty()) { |
|
|
return std::make_pair(std::vector<int>{0}, setnames); |
|
|
} else { |
|
|
return std::make_pair(std::vector<int>{}, setnames); |
|
|
} |
|
|
} |
|
|
|
|
|
std::vector<int> vecnums; |
|
|
if (auto pv = std::get_if<int>(&(*argnums)); pv) { |
|
|
vecnums = {*pv}; |
|
|
} else { |
|
|
vecnums = std::get<std::vector<int>>(*argnums); |
|
|
} |
|
|
|
|
|
return std::make_pair(vecnums, setnames); |
|
|
} |
|
|
|
|
|
auto py_value_and_grad( |
|
|
const nb::callable& fun, |
|
|
std::vector<int> argnums, |
|
|
std::unordered_set<std::string> argnames, |
|
|
const std::string& error_msg_tag, |
|
|
bool scalar_func_only) { |
|
|
|
|
|
if (argnums.size() == 0 && argnames.size() == 0) { |
|
|
throw std::invalid_argument( |
|
|
error_msg_tag + " Gradient wrt no argument requested"); |
|
|
} |
|
|
for (auto arg : argnums) { |
|
|
std::sort(argnums.begin(), argnums.end()); |
|
|
if (argnums[0] < 0) { |
|
|
std::ostringstream msg; |
|
|
msg << error_msg_tag |
|
|
<< " Can't compute the gradient of negative argument index " |
|
|
<< argnums[0]; |
|
|
throw std::invalid_argument(msg.str()); |
|
|
} |
|
|
for (int i = 1; i < argnums.size(); ++i) { |
|
|
if (argnums[i] == argnums[i - 1]) { |
|
|
std::ostringstream msg; |
|
|
msg << error_msg_tag << " Duplicate argument index " << argnums[0] |
|
|
<< " is not allowed."; |
|
|
throw std::invalid_argument(msg.str()); |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
return [fun, argnums, argnames, error_msg_tag, scalar_func_only]( |
|
|
nb::args& args, nb::kwargs& kwargs) { |
|
|
|
|
|
if (argnums.size() > 0 && argnums.back() >= args.size()) { |
|
|
std::ostringstream msg; |
|
|
msg << error_msg_tag << " Can't compute the gradient of argument index " |
|
|
<< argnums.back() << " because the function is called with only " |
|
|
<< args.size() << " positional arguments."; |
|
|
throw std::invalid_argument(msg.str()); |
|
|
} |
|
|
|
|
|
for (auto& key : argnames) { |
|
|
if (!kwargs.contains(key)) { |
|
|
std::ostringstream msg; |
|
|
msg << error_msg_tag |
|
|
<< " Can't compute the gradient of keyword argument '" << key |
|
|
<< "' because the function is called with the " |
|
|
<< "following keyword arguments {"; |
|
|
for (auto item : kwargs) { |
|
|
msg << nb::cast<std::string>(item.first) << ","; |
|
|
} |
|
|
msg << "}"; |
|
|
throw std::invalid_argument(msg.str()); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
std::vector<mx::array> arrays; |
|
|
std::vector<int> counts(1, 0); |
|
|
std::vector<int> gradient_indices; |
|
|
for (int i = 0, j = 0; i < args.size(); ++i) { |
|
|
bool needs_grad = (j < argnums.size() && argnums[j] == i); |
|
|
auto argsi = tree_flatten(args[i], needs_grad); |
|
|
if (needs_grad) { |
|
|
auto old_size = gradient_indices.size(); |
|
|
gradient_indices.resize(old_size + argsi.size()); |
|
|
std::iota( |
|
|
gradient_indices.begin() + old_size, |
|
|
gradient_indices.end(), |
|
|
arrays.size()); |
|
|
j++; |
|
|
counts.push_back(argsi.size()); |
|
|
} |
|
|
arrays.insert(arrays.end(), argsi.begin(), argsi.end()); |
|
|
} |
|
|
for (auto item : kwargs) { |
|
|
bool needs_grad = |
|
|
(argnames.find(nb::cast<std::string>(item.first)) != argnames.end()); |
|
|
auto argsk = tree_flatten(item.second, needs_grad); |
|
|
if (needs_grad) { |
|
|
auto old_size = gradient_indices.size(); |
|
|
gradient_indices.resize(old_size + argsk.size()); |
|
|
std::iota( |
|
|
gradient_indices.begin() + old_size, |
|
|
gradient_indices.end(), |
|
|
arrays.size()); |
|
|
counts.push_back(argsk.size()); |
|
|
} |
|
|
arrays.insert(arrays.end(), argsk.begin(), argsk.end()); |
|
|
} |
|
|
std::partial_sum(counts.cbegin(), counts.cend(), counts.begin()); |
|
|
|
|
|
|
|
|
|
|
|
nb::object py_value_out; |
|
|
auto value_and_grads = mx::value_and_grad( |
|
|
[&fun, |
|
|
&arrays, |
|
|
&args, |
|
|
&kwargs, |
|
|
&py_value_out, |
|
|
&error_msg_tag, |
|
|
scalar_func_only](const std::vector<mx::array>& a) { |
|
|
nb::list tree; |
|
|
tree.append(args); |
|
|
tree.append(kwargs); |
|
|
tree_fill(tree, a); |
|
|
|
|
|
|
|
|
py_value_out = fun(*tree[0], **tree[1]); |
|
|
|
|
|
|
|
|
|
|
|
int index = 0; |
|
|
tree_visit_update(tree, [&](nb::handle node) { |
|
|
auto replace_arr = nb::cast<mx::array>(node); |
|
|
if (replace_arr.id() == a[index].id()) { |
|
|
return nb::cast(arrays[index++]); |
|
|
} else { |
|
|
return nb::cast(replace_arr); |
|
|
} |
|
|
}); |
|
|
|
|
|
|
|
|
if (!nb::isinstance<mx::array>(py_value_out)) { |
|
|
if (scalar_func_only) { |
|
|
std::ostringstream msg; |
|
|
msg << error_msg_tag << " The return value of the function " |
|
|
<< "whose gradient we want to compute should be a " |
|
|
<< "scalar array; but " << type_name_str(py_value_out) |
|
|
<< " was returned."; |
|
|
throw std::invalid_argument(msg.str()); |
|
|
} |
|
|
if (!nb::isinstance<nb::tuple>(py_value_out)) { |
|
|
std::ostringstream msg; |
|
|
msg << error_msg_tag << " The return value of the function " |
|
|
<< "whose gradient we want to compute should be either a " |
|
|
<< "scalar array or a tuple with the first value being a " |
|
|
<< "scalar array (Union[array, tuple[array, Any, ...]]); but " |
|
|
<< type_name_str(py_value_out) << " was returned."; |
|
|
throw std::invalid_argument(msg.str()); |
|
|
} |
|
|
nb::tuple ret = nb::cast<nb::tuple>(py_value_out); |
|
|
if (ret.size() == 0) { |
|
|
std::ostringstream msg; |
|
|
msg << error_msg_tag << " The return value of the function " |
|
|
<< "whose gradient we want to compute should be either a " |
|
|
<< "scalar array or a non-empty tuple. The first value should be a " |
|
|
<< "scalar array and the rest can be anything. Instead, " |
|
|
<< "we got an empty tuple."; |
|
|
throw std::invalid_argument(msg.str()); |
|
|
} |
|
|
if (!nb::isinstance<mx::array>(ret[0])) { |
|
|
std::ostringstream msg; |
|
|
msg << error_msg_tag << " The return value of the function " |
|
|
<< "whose gradient we want to compute should be either a " |
|
|
<< "scalar array or a tuple with the first value being a " |
|
|
<< "scalar array (Union[array, tuple[array, Any, ...]]); but it " |
|
|
<< "was a tuple with the first value being of type " |
|
|
<< type_name_str(ret[0]) << " ."; |
|
|
throw std::invalid_argument(msg.str()); |
|
|
} |
|
|
} |
|
|
|
|
|
return tree_flatten(py_value_out, false); |
|
|
}, |
|
|
gradient_indices)(arrays); |
|
|
|
|
|
auto value = value_and_grads.first; |
|
|
auto gradients = value_and_grads.second; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
nb::object positional_grads; |
|
|
nb::object keyword_grads; |
|
|
nb::object py_grads; |
|
|
|
|
|
|
|
|
if (argnums.size() == 1) { |
|
|
positional_grads = tree_unflatten(args[argnums[0]], gradients, counts[0]); |
|
|
} else if (argnums.size() > 1) { |
|
|
nb::list grads_; |
|
|
for (int i = 0; i < argnums.size(); i++) { |
|
|
grads_.append(tree_unflatten(args[argnums[i]], gradients, counts[i])); |
|
|
} |
|
|
positional_grads = nb::tuple(grads_); |
|
|
} else { |
|
|
positional_grads = nb::none(); |
|
|
} |
|
|
|
|
|
|
|
|
if (argnames.size() == 0) { |
|
|
py_grads = positional_grads; |
|
|
} else { |
|
|
nb::dict grads_; |
|
|
int i = 0; |
|
|
for (auto item : kwargs) { |
|
|
auto k = nb::cast<std::string>(item.first); |
|
|
if (argnames.find(k) != argnames.end()) { |
|
|
grads_[k.c_str()] = tree_unflatten( |
|
|
nb::borrow(item.second), gradients, counts[i++ + argnums.size()]); |
|
|
} |
|
|
} |
|
|
keyword_grads = grads_; |
|
|
|
|
|
py_grads = nb::make_tuple(positional_grads, keyword_grads); |
|
|
} |
|
|
|
|
|
|
|
|
nb::object return_value = tree_unflatten(py_value_out, value); |
|
|
return std::make_pair(return_value, py_grads); |
|
|
}; |
|
|
} |
|
|
|
|
|
auto py_vmap( |
|
|
const nb::callable& fun, |
|
|
const nb::object& in_axes, |
|
|
const nb::object& out_axes) { |
|
|
return [fun, in_axes, out_axes](const nb::args& args) { |
|
|
auto axes_to_flat_tree = [](const nb::object& tree, |
|
|
const nb::object& axes, |
|
|
bool output_axes) { |
|
|
std::vector<int> flat_axes; |
|
|
bool encountered_tuple = false; |
|
|
tree_visit( |
|
|
{tree, axes}, |
|
|
[&flat_axes, &encountered_tuple, output_axes]( |
|
|
const std::vector<nb::object>& inputs) { |
|
|
if (nb::isinstance<mx::array>(inputs[0])) { |
|
|
if (inputs[1].is_none()) { |
|
|
flat_axes.push_back(-1); |
|
|
} else if (nb::isinstance<nb::int_>(inputs[1])) { |
|
|
int axis = nb::cast<int>(nb::cast<nb::int_>(inputs[1])); |
|
|
const mx::array& x = nb::cast<mx::array>(inputs[0]); |
|
|
if (axis < 0) { |
|
|
axis += x.ndim() + output_axes; |
|
|
} |
|
|
if (axis < 0 || axis >= (x.ndim() + output_axes)) { |
|
|
std::ostringstream msg; |
|
|
msg << "[vmap] Invalid" << (output_axes ? " output " : " ") |
|
|
<< "vectorization axis " << axis |
|
|
<< " for array with shape " << x.shape(); |
|
|
throw std::invalid_argument(msg.str()); |
|
|
} |
|
|
flat_axes.push_back(axis); |
|
|
} else if (nb::isinstance<nb::tuple>(inputs[1])) { |
|
|
encountered_tuple = true; |
|
|
auto l = nb::cast<nb::tuple>(inputs[1]); |
|
|
if (l.size() == 1 && nb::isinstance<nb::int_>(l[0])) { |
|
|
int axis = nb::cast<int>(nb::cast<nb::int_>(l[0])); |
|
|
const mx::array& x = nb::cast<mx::array>(inputs[0]); |
|
|
if (axis < 0) { |
|
|
axis += x.ndim() + output_axes; |
|
|
} |
|
|
if (axis < 0 || axis >= (x.ndim() + output_axes)) { |
|
|
std::ostringstream msg; |
|
|
msg << "[vmap] Invalid" << (output_axes ? " output " : " ") |
|
|
<< "vectorization axis " << axis |
|
|
<< " for array with shape " << x.shape(); |
|
|
throw std::invalid_argument(msg.str()); |
|
|
} |
|
|
flat_axes.push_back(axis); |
|
|
} else if (l.size() == 1 && l[0].is_none()) { |
|
|
flat_axes.push_back(-1); |
|
|
} else { |
|
|
throw std::invalid_argument( |
|
|
"[vmap] axis must be int or None."); |
|
|
} |
|
|
} else { |
|
|
throw std::invalid_argument("[vmap] axis must be int or None."); |
|
|
} |
|
|
} else { |
|
|
throw std::invalid_argument( |
|
|
"[vmap] The arguments should contain only arrays"); |
|
|
} |
|
|
}); |
|
|
if (encountered_tuple && !nb::isinstance<mx::array>(tree)) { |
|
|
throw std::invalid_argument("[vmap] axis must be int or None."); |
|
|
} |
|
|
return flat_axes; |
|
|
}; |
|
|
|
|
|
|
|
|
auto inputs = tree_flatten(args, true); |
|
|
auto flat_in_axes = |
|
|
axes_to_flat_tree((args.size() == 1) ? args[0] : args, in_axes, false); |
|
|
|
|
|
|
|
|
|
|
|
nb::object py_outputs; |
|
|
|
|
|
auto vmap_fn = |
|
|
[&fun, &args, &inputs, &py_outputs](const std::vector<mx::array>& a) { |
|
|
|
|
|
py_outputs = fun(*tree_unflatten(args, a)); |
|
|
|
|
|
|
|
|
return tree_flatten(py_outputs, true); |
|
|
}; |
|
|
|
|
|
auto [trace_inputs, trace_outputs] = |
|
|
mx::detail::vmap_trace(vmap_fn, inputs, flat_in_axes); |
|
|
|
|
|
auto flat_out_axes = axes_to_flat_tree(py_outputs, out_axes, true); |
|
|
|
|
|
|
|
|
auto outputs = mx::detail::vmap_replace( |
|
|
inputs, trace_inputs, trace_outputs, flat_in_axes, flat_out_axes); |
|
|
|
|
|
|
|
|
return tree_unflatten(py_outputs, outputs); |
|
|
}; |
|
|
} |
|
|
|
|
|
struct PyCompiledFun { |
|
|
nb::callable fun; |
|
|
std::uintptr_t fun_id; |
|
|
nb::object captured_inputs; |
|
|
nb::object captured_outputs; |
|
|
bool shapeless; |
|
|
|
|
|
|
|
|
|
|
|
struct AttachedData { |
|
|
nb::object output_structure; |
|
|
int num_outputs; |
|
|
|
|
|
AttachedData(nb::object output_structure_, int num_outputs_) |
|
|
: output_structure(output_structure_), num_outputs(num_outputs_) {} |
|
|
}; |
|
|
|
|
|
PyCompiledFun( |
|
|
const nb::callable& fun, |
|
|
nb::object inputs, |
|
|
nb::object outputs, |
|
|
bool shapeless) |
|
|
: fun(fun), |
|
|
fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())), |
|
|
captured_inputs(inputs), |
|
|
captured_outputs(outputs), |
|
|
shapeless(shapeless) {} |
|
|
|
|
|
PyCompiledFun(const PyCompiledFun&) = delete; |
|
|
PyCompiledFun& operator=(const PyCompiledFun&) = delete; |
|
|
PyCompiledFun& operator=(PyCompiledFun&& other) = delete; |
|
|
PyCompiledFun(PyCompiledFun&& other) |
|
|
: fun(std::move(other.fun)), |
|
|
fun_id(reinterpret_cast<std::uintptr_t>(fun.ptr())) { |
|
|
other.fun_id = 0; |
|
|
captured_inputs = std::move(other.captured_inputs); |
|
|
captured_outputs = std::move(other.captured_outputs); |
|
|
shapeless = other.shapeless; |
|
|
}; |
|
|
|
|
|
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) { |
|
|
|
|
|
std::vector<mx::array> inputs; |
|
|
|
|
|
|
|
|
std::vector<uint64_t> constants; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
constexpr uint64_t array_identifier = 18446744073709551557UL; |
|
|
constexpr uint64_t list_identifier = 18446744073709551533UL; |
|
|
constexpr uint64_t dict_identifier = 18446744073709551521UL; |
|
|
constexpr uint64_t none_identifier = 10239356951478402889UL; |
|
|
|
|
|
|
|
|
std::function<void(nb::handle)> recurse; |
|
|
recurse = [&](nb::handle obj) { |
|
|
if (nb::isinstance<nb::list>(obj)) { |
|
|
auto l = nb::cast<nb::list>(obj); |
|
|
constants.push_back(list_identifier); |
|
|
for (int i = 0; i < l.size(); ++i) { |
|
|
recurse(l[i]); |
|
|
} |
|
|
} else if (nb::isinstance<nb::tuple>(obj)) { |
|
|
auto l = nb::cast<nb::tuple>(obj); |
|
|
constants.push_back(list_identifier); |
|
|
for (auto item : obj) { |
|
|
recurse(item); |
|
|
} |
|
|
} else if (nb::isinstance<nb::dict>(obj)) { |
|
|
auto d = nb::cast<nb::dict>(obj); |
|
|
constants.push_back(dict_identifier); |
|
|
for (auto item : d) { |
|
|
auto r = item.first.attr("__hash__")(); |
|
|
constants.push_back(nb::cast<int64_t>(r)); |
|
|
recurse(item.second); |
|
|
} |
|
|
} else if (nb::isinstance<mx::array>(obj)) { |
|
|
inputs.push_back(nb::cast<mx::array>(obj)); |
|
|
constants.push_back(array_identifier); |
|
|
} else if (nb::isinstance<nb::str>(obj)) { |
|
|
auto r = obj.attr("__hash__")(); |
|
|
constants.push_back(nb::cast<int64_t>(r)); |
|
|
} else if (nb::isinstance<nb::int_>(obj)) { |
|
|
constants.push_back(nb::cast<int64_t>(obj)); |
|
|
} else if (nb::isinstance<nb::float_>(obj)) { |
|
|
auto r = nb::cast<double>(obj); |
|
|
constants.push_back(*reinterpret_cast<uint64_t*>(&r)); |
|
|
} else if (obj.is_none()) { |
|
|
constants.push_back(none_identifier); |
|
|
} else { |
|
|
std::ostringstream msg; |
|
|
msg << "[compile] Function arguments must be trees of arrays " |
|
|
<< "or constants (floats, ints, strings, or None), but received " |
|
|
<< "type " << type_name_str(obj) << "."; |
|
|
throw std::invalid_argument(msg.str()); |
|
|
} |
|
|
}; |
|
|
|
|
|
recurse(args); |
|
|
int num_args = inputs.size(); |
|
|
recurse(kwargs); |
|
|
auto compile_fun = [this, &args, &kwargs, num_args]( |
|
|
const std::vector<mx::array>& a) { |
|
|
|
|
|
std::vector<mx::array> flat_in_captures; |
|
|
std::vector<mx::array> trace_captures; |
|
|
if (!captured_inputs.is_none()) { |
|
|
flat_in_captures = tree_flatten(captured_inputs, false); |
|
|
trace_captures.insert( |
|
|
trace_captures.end(), a.end() - flat_in_captures.size(), a.end()); |
|
|
tree_fill(captured_inputs, trace_captures); |
|
|
} |
|
|
|
|
|
auto tree_outputs = |
|
|
fun(*tree_unflatten(args, a), **tree_unflatten(kwargs, a, num_args)); |
|
|
auto [outputs, py_outputs] = |
|
|
tree_flatten_with_structure(std::move(tree_outputs), false); |
|
|
|
|
|
std::shared_ptr<void> extra_data = |
|
|
std::make_shared<AttachedData>(py_outputs, outputs.size()); |
|
|
|
|
|
if (!captured_outputs.is_none()) { |
|
|
auto flat_out_captures = tree_flatten(captured_outputs, false); |
|
|
outputs.insert( |
|
|
outputs.end(), |
|
|
std::make_move_iterator(flat_out_captures.begin()), |
|
|
std::make_move_iterator(flat_out_captures.end())); |
|
|
} |
|
|
|
|
|
|
|
|
if (!captured_inputs.is_none()) { |
|
|
tree_replace(captured_inputs, trace_captures, flat_in_captures); |
|
|
} |
|
|
return mx::detail::ArraysAndExtra{outputs, extra_data}; |
|
|
}; |
|
|
|
|
|
if (!captured_inputs.is_none()) { |
|
|
auto flat_in_captures = tree_flatten(captured_inputs, false); |
|
|
inputs.insert( |
|
|
inputs.end(), |
|
|
std::make_move_iterator(flat_in_captures.begin()), |
|
|
std::make_move_iterator(flat_in_captures.end())); |
|
|
} |
|
|
|
|
|
|
|
|
auto [outputs, extra_data] = |
|
|
mx::detail::compile(compile_fun, fun_id, shapeless, constants)(inputs); |
|
|
|
|
|
int num_outputs = |
|
|
reinterpret_cast<AttachedData*>(extra_data.get())->num_outputs; |
|
|
nb::object py_outputs = |
|
|
reinterpret_cast<AttachedData*>(extra_data.get())->output_structure; |
|
|
|
|
|
if (!captured_outputs.is_none()) { |
|
|
std::vector<mx::array> captures( |
|
|
std::make_move_iterator(outputs.begin() + num_outputs), |
|
|
std::make_move_iterator(outputs.end())); |
|
|
tree_fill(captured_outputs, captures); |
|
|
} |
|
|
|
|
|
|
|
|
return tree_unflatten_from_structure(std::move(py_outputs), outputs); |
|
|
} |
|
|
|
|
|
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const { |
|
|
return const_cast<PyCompiledFun*>(this)->call_impl(args, kwargs); |
|
|
}; |
|
|
|
|
|
~PyCompiledFun() { |
|
|
nb::gil_scoped_acquire gil; |
|
|
|
|
|
mx::detail::compile_erase(fun_id); |
|
|
fun.reset(); |
|
|
captured_inputs.reset(); |
|
|
captured_outputs.reset(); |
|
|
} |
|
|
}; |
|
|
|
|
|
class PyCheckpointedFun { |
|
|
public: |
|
|
PyCheckpointedFun(nb::callable fun) : fun_(std::move(fun)) {} |
|
|
~PyCheckpointedFun() { |
|
|
nb::gil_scoped_acquire gil; |
|
|
|
|
|
fun_.reset(); |
|
|
} |
|
|
|
|
|
struct InnerFunction { |
|
|
nb::object fun_; |
|
|
nb::object args_structure_; |
|
|
std::weak_ptr<nb::object> output_structure_; |
|
|
|
|
|
InnerFunction( |
|
|
nb::object fun, |
|
|
nb::object args_structure, |
|
|
std::weak_ptr<nb::object> output_structure) |
|
|
: fun_(std::move(fun)), |
|
|
args_structure_(std::move(args_structure)), |
|
|
output_structure_(output_structure) {} |
|
|
~InnerFunction() { |
|
|
nb::gil_scoped_acquire gil; |
|
|
|
|
|
fun_.reset(); |
|
|
args_structure_.reset(); |
|
|
} |
|
|
|
|
|
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) { |
|
|
auto args = nb::cast<nb::tuple>( |
|
|
tree_unflatten_from_structure(args_structure_, inputs)); |
|
|
auto [outputs, output_structure] = |
|
|
tree_flatten_with_structure(fun_(*args[0], **args[1]), false); |
|
|
if (auto s = output_structure_.lock()) { |
|
|
*s = output_structure; |
|
|
} |
|
|
return outputs; |
|
|
} |
|
|
}; |
|
|
|
|
|
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) { |
|
|
auto output_structure = std::make_shared<nb::object>(); |
|
|
auto full_args = nb::make_tuple(args, kwargs); |
|
|
auto [inputs, args_structure] = |
|
|
tree_flatten_with_structure(full_args, false); |
|
|
|
|
|
auto outputs = mx::checkpoint( |
|
|
InnerFunction(fun_, args_structure, output_structure))(inputs); |
|
|
|
|
|
return tree_unflatten_from_structure(*output_structure, outputs); |
|
|
} |
|
|
|
|
|
nb::object operator()(const nb::args& args, const nb::kwargs& kwargs) const { |
|
|
return const_cast<PyCheckpointedFun*>(this)->call_impl(args, kwargs); |
|
|
} |
|
|
|
|
|
private: |
|
|
nb::callable fun_; |
|
|
}; |
|
|
|
|
|
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg); |
|
|
|
|
|
int py_custom_function_tp_clear(PyObject* self); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PyCustomFunction { |
|
|
public: |
|
|
PyCustomFunction(nb::callable fun) : fun_(std::move(fun)) {} |
|
|
~PyCustomFunction() { |
|
|
nb::gil_scoped_acquire gil; |
|
|
reset(); |
|
|
} |
|
|
|
|
|
struct InnerFunction { |
|
|
nb::callable fun_; |
|
|
nb::object input_structure_; |
|
|
std::shared_ptr<nb::object> output_structure_; |
|
|
|
|
|
InnerFunction( |
|
|
nb::callable fun, |
|
|
nb::object input_structure, |
|
|
std::shared_ptr<nb::object> output_structure) |
|
|
: fun_(std::move(fun)), |
|
|
input_structure_(std::move(input_structure)), |
|
|
output_structure_(std::move(output_structure)) {} |
|
|
~InnerFunction() { |
|
|
nb::gil_scoped_acquire gil; |
|
|
|
|
|
fun_.reset(); |
|
|
input_structure_.reset(); |
|
|
if (output_structure_.use_count() == 1) { |
|
|
output_structure_->reset(); |
|
|
} |
|
|
} |
|
|
|
|
|
std::vector<mx::array> operator()(const std::vector<mx::array>& inputs) { |
|
|
nb::gil_scoped_acquire gil; |
|
|
|
|
|
auto new_inputs = nb::cast<nb::tuple>( |
|
|
tree_unflatten_from_structure(input_structure_, inputs)); |
|
|
std::vector<mx::array> outputs; |
|
|
std::tie(outputs, *output_structure_) = |
|
|
tree_flatten_with_structure(fun_(*new_inputs[0], **new_inputs[1])); |
|
|
return outputs; |
|
|
} |
|
|
}; |
|
|
|
|
|
struct InnerVJPFunction { |
|
|
nb::callable vjp_fun_; |
|
|
nb::object input_structure_; |
|
|
std::shared_ptr<nb::object> output_structure_; |
|
|
|
|
|
InnerVJPFunction( |
|
|
nb::callable vjp_fun, |
|
|
nb::object input_structure, |
|
|
std::shared_ptr<nb::object> output_structure) |
|
|
: vjp_fun_(std::move(vjp_fun)), |
|
|
input_structure_(std::move(input_structure)), |
|
|
output_structure_(std::move(output_structure)) {} |
|
|
~InnerVJPFunction() { |
|
|
nb::gil_scoped_acquire gil; |
|
|
|
|
|
vjp_fun_.reset(); |
|
|
input_structure_.reset(); |
|
|
if (output_structure_.use_count() == 1) { |
|
|
output_structure_->reset(); |
|
|
} |
|
|
} |
|
|
|
|
|
std::vector<mx::array> operator()( |
|
|
const std::vector<mx::array>& primals, |
|
|
const std::vector<mx::array>& cotangents, |
|
|
const std::vector<mx::array>& outputs) { |
|
|
nb::gil_scoped_acquire gil; |
|
|
|
|
|
auto new_inputs = nb::cast<nb::tuple>( |
|
|
tree_unflatten_from_structure(input_structure_, primals)); |
|
|
auto args = nb::cast<nb::tuple>(new_inputs[0]); |
|
|
auto new_cotangents = |
|
|
tree_unflatten_from_structure(*output_structure_, cotangents); |
|
|
auto new_outputs = |
|
|
tree_unflatten_from_structure(*output_structure_, outputs); |
|
|
|
|
|
if (args.size() == 1) { |
|
|
return tree_flatten( |
|
|
vjp_fun_(args[0], new_cotangents, new_outputs, **new_inputs[1]), |
|
|
false); |
|
|
} else { |
|
|
return tree_flatten( |
|
|
vjp_fun_(args, new_cotangents, new_outputs, **new_inputs[1]), |
|
|
false); |
|
|
} |
|
|
} |
|
|
}; |
|
|
|
|
|
struct InnerJVPFunction { |
|
|
nb::callable jvp_fun_; |
|
|
nb::object input_structure_; |
|
|
|
|
|
InnerJVPFunction(nb::callable jvp_fun, nb::object input_structure) |
|
|
: jvp_fun_(std::move(jvp_fun)), |
|
|
input_structure_(std::move(input_structure)) {} |
|
|
~InnerJVPFunction() { |
|
|
nb::gil_scoped_acquire gil; |
|
|
|
|
|
jvp_fun_.reset(); |
|
|
input_structure_.reset(); |
|
|
} |
|
|
|
|
|
std::vector<mx::array> operator()( |
|
|
const std::vector<mx::array>& primals, |
|
|
const std::vector<mx::array>& tangents, |
|
|
const std::vector<int>& argnums) { |
|
|
nb::gil_scoped_acquire gil; |
|
|
|
|
|
auto new_inputs = nb::cast<nb::tuple>( |
|
|
tree_unflatten_from_structure(input_structure_, primals)); |
|
|
auto args = nb::cast<nb::tuple>(new_inputs[0]); |
|
|
auto kwargs = nb::cast<nb::dict>(new_inputs[1]); |
|
|
if (kwargs.size() > 0) { |
|
|
throw std::invalid_argument( |
|
|
"[custom jvp] Function should only accept positional arguments"); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
std::vector<bool> have_tangents(primals.size(), false); |
|
|
for (auto arg : argnums) { |
|
|
have_tangents[arg] = true; |
|
|
} |
|
|
int array_index = 0; |
|
|
int tangent_index = 0; |
|
|
auto new_tangents = |
|
|
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) { |
|
|
if (nb::isinstance<mx::array>(element) && |
|
|
have_tangents[array_index++]) { |
|
|
return nb::cast(tangents[tangent_index++]); |
|
|
} else { |
|
|
return nb::none(); |
|
|
} |
|
|
})); |
|
|
|
|
|
if (args.size() == 1) { |
|
|
return tree_flatten(jvp_fun_(args[0], new_tangents[0]), false); |
|
|
} else { |
|
|
return tree_flatten(jvp_fun_(args, new_tangents), false); |
|
|
} |
|
|
} |
|
|
}; |
|
|
|
|
|
struct InnerVmapFunction { |
|
|
nb::callable vmap_fun_; |
|
|
nb::object input_structure_; |
|
|
|
|
|
InnerVmapFunction(nb::callable vmap_fun, nb::object input_structure) |
|
|
: vmap_fun_(std::move(vmap_fun)), |
|
|
input_structure_(std::move(input_structure)) {} |
|
|
~InnerVmapFunction() { |
|
|
nb::gil_scoped_acquire gil; |
|
|
|
|
|
vmap_fun_.reset(); |
|
|
input_structure_.reset(); |
|
|
} |
|
|
|
|
|
std::pair<std::vector<mx::array>, std::vector<int>> operator()( |
|
|
const std::vector<mx::array>& inputs, |
|
|
const std::vector<int>& axes) { |
|
|
nb::gil_scoped_acquire gil; |
|
|
|
|
|
auto new_inputs = nb::cast<nb::tuple>( |
|
|
tree_unflatten_from_structure(input_structure_, inputs)); |
|
|
auto args = nb::cast<nb::tuple>(new_inputs[0]); |
|
|
auto kwargs = nb::cast<nb::dict>(new_inputs[1]); |
|
|
if (kwargs.size() > 0) { |
|
|
throw std::invalid_argument( |
|
|
"[custom vmap] Function should only accept positional arguments"); |
|
|
} |
|
|
|
|
|
int arr_index = 0; |
|
|
auto new_axes = |
|
|
nb::cast<nb::tuple>(tree_map(args, [&](nb::handle element) { |
|
|
int axis = axes[arr_index++]; |
|
|
if (nb::isinstance<mx::array>(element) && axis >= 0) { |
|
|
return nb::cast(axis); |
|
|
} else { |
|
|
return nb::none(); |
|
|
} |
|
|
})); |
|
|
|
|
|
nb::object result; |
|
|
if (args.size() == 1) { |
|
|
result = vmap_fun_(args[0], new_axes[0]); |
|
|
} else { |
|
|
result = vmap_fun_(args, new_axes); |
|
|
} |
|
|
|
|
|
if (!nb::isinstance<nb::tuple>(result)) { |
|
|
throw std::invalid_argument( |
|
|
"[custom vmap] Vmap function should return a tuple with 2 items."); |
|
|
} |
|
|
nb::tuple result_tuple = nb::cast<nb::tuple>(result); |
|
|
if (result_tuple.size() != 2) { |
|
|
throw std::invalid_argument( |
|
|
"[custom vmap] Vmap function should return a tuple with 2 items."); |
|
|
} |
|
|
|
|
|
std::vector<mx::array> outputs; |
|
|
std::vector<int> output_axes; |
|
|
tree_visit({result_tuple[0], result_tuple[1]}, [&](auto objects) { |
|
|
if (nb::isinstance<mx::array>(objects[0])) { |
|
|
outputs.push_back(nb::cast<mx::array>(objects[0])); |
|
|
output_axes.push_back( |
|
|
objects[1].is_none() ? -1 : nb::cast<int>(objects[1])); |
|
|
} |
|
|
}); |
|
|
|
|
|
return {outputs, output_axes}; |
|
|
} |
|
|
}; |
|
|
|
|
|
nb::object call_impl(const nb::args& args, const nb::kwargs& kwargs) { |
|
|
if (!vjp_fun_.has_value() && !jvp_fun_.has_value() && |
|
|
!vmap_fun_.has_value()) { |
|
|
return fun_(*args, **kwargs); |
|
|
} |
|
|
|
|
|
|
|
|
std::vector<mx::array> input_arrays; |
|
|
nb::object input_structure; |
|
|
auto full_args = nb::make_tuple(args, kwargs); |
|
|
std::tie(input_arrays, input_structure) = |
|
|
tree_flatten_with_structure(full_args, false); |
|
|
|
|
|
|
|
|
|
|
|
auto output_structure = std::make_shared<nb::object>(); |
|
|
|
|
|
|
|
|
|
|
|
auto f = mx::custom_function( |
|
|
InnerFunction(fun_, input_structure, output_structure), |
|
|
make_vjp_function(input_structure, output_structure), |
|
|
make_jvp_function(input_structure), |
|
|
make_vmap_function(input_structure)); |
|
|
|
|
|
auto outputs = f(input_arrays); |
|
|
return tree_unflatten_from_structure(*output_structure, outputs); |
|
|
} |
|
|
|
|
|
PyCustomFunction& set_vjp(nb::callable vjp_fun) { |
|
|
vjp_fun_ = vjp_fun; |
|
|
return *this; |
|
|
} |
|
|
|
|
|
PyCustomFunction& set_jvp(nb::callable jvp_fun) { |
|
|
jvp_fun_ = jvp_fun; |
|
|
return *this; |
|
|
} |
|
|
|
|
|
PyCustomFunction& set_vmap(nb::callable vmap_fun) { |
|
|
vmap_fun_ = vmap_fun; |
|
|
return *this; |
|
|
} |
|
|
void reset() { |
|
|
fun_.reset(); |
|
|
if (vjp_fun_.has_value()) { |
|
|
(*vjp_fun_).reset(); |
|
|
} |
|
|
if (jvp_fun_.has_value()) { |
|
|
(*jvp_fun_).reset(); |
|
|
} |
|
|
if (vmap_fun_.has_value()) { |
|
|
(*vmap_fun_).reset(); |
|
|
} |
|
|
} |
|
|
|
|
|
friend int py_custom_function_tp_traverse(PyObject*, visitproc, void*); |
|
|
|
|
|
private: |
|
|
std::optional<InnerVJPFunction> make_vjp_function( |
|
|
nb::object input_structure, |
|
|
std::shared_ptr<nb::object> output_structure) { |
|
|
if (!vjp_fun_.has_value()) { |
|
|
return std::nullopt; |
|
|
} |
|
|
|
|
|
return InnerVJPFunction(*vjp_fun_, input_structure, output_structure); |
|
|
} |
|
|
|
|
|
std::optional<InnerJVPFunction> make_jvp_function( |
|
|
nb::object input_structure) { |
|
|
if (!jvp_fun_.has_value()) { |
|
|
return std::nullopt; |
|
|
} |
|
|
|
|
|
return InnerJVPFunction(*jvp_fun_, input_structure); |
|
|
} |
|
|
|
|
|
std::optional<InnerVmapFunction> make_vmap_function( |
|
|
nb::object input_structure) { |
|
|
if (!vmap_fun_.has_value()) { |
|
|
return std::nullopt; |
|
|
} |
|
|
|
|
|
return InnerVmapFunction(*vmap_fun_, input_structure); |
|
|
} |
|
|
|
|
|
nb::callable fun_; |
|
|
std::optional<nb::callable> vjp_fun_; |
|
|
std::optional<nb::callable> jvp_fun_; |
|
|
std::optional<nb::callable> vmap_fun_; |
|
|
}; |
|
|
|
|
|
int py_custom_function_tp_traverse(PyObject* self, visitproc visit, void* arg) { |
|
|
Py_VISIT(Py_TYPE(self)); |
|
|
if (!nb::inst_ready(self)) { |
|
|
return 0; |
|
|
} |
|
|
|
|
|
auto* p = nb::inst_ptr<PyCustomFunction>(self); |
|
|
nb::handle v = nb::find(p->fun_); |
|
|
Py_VISIT(v.ptr()); |
|
|
if (p->vjp_fun_.has_value()) { |
|
|
nb::handle v = nb::find(*(p->vjp_fun_)); |
|
|
Py_VISIT(v.ptr()); |
|
|
} |
|
|
if (p->jvp_fun_.has_value()) { |
|
|
nb::handle v = nb::find(*(p->jvp_fun_)); |
|
|
Py_VISIT(v.ptr()); |
|
|
} |
|
|
if (p->vmap_fun_.has_value()) { |
|
|
nb::handle v = nb::find(*(p->vmap_fun_)); |
|
|
Py_VISIT(v.ptr()); |
|
|
} |
|
|
return 0; |
|
|
} |
|
|
int py_custom_function_tp_clear(PyObject* self) { |
|
|
auto* p = nb::inst_ptr<PyCustomFunction>(self); |
|
|
p->reset(); |
|
|
return 0; |
|
|
} |
|
|
PyType_Slot py_custom_function_slots[] = { |
|
|
{Py_tp_traverse, (void*)py_custom_function_tp_traverse}, |
|
|
{Py_tp_clear, (void*)py_custom_function_tp_clear}, |
|
|
{0, 0}}; |
|
|
|
|
|
void init_transforms(nb::module_& m) { |
|
|
nb::class_<PyCustomFunction>( |
|
|
m, |
|
|
"custom_function", |
|
|
nb::type_slots(py_custom_function_slots), |
|
|
R"pbdoc( |
|
|
Set up a function for custom gradient and vmap definitions. |
|
|
|
|
|
This class is meant to be used as a function decorator. Instances are |
|
|
callables that behave identically to the wrapped function. However, when |
|
|
a function transformation is used (e.g. computing gradients using |
|
|
:func:`value_and_grad`) then the functions defined via |
|
|
:meth:`custom_function.vjp`, :meth:`custom_function.jvp` and |
|
|
:meth:`custom_function.vmap` are used instead of the default transformation. |
|
|
|
|
|
Note, all custom transformations are optional. Undefined transformations |
|
|
fall back to the default behaviour. |
|
|
|
|
|
Example: |
|
|
|
|
|
.. code-block:: python |
|
|
|
|
|
import mlx.core as mx |
|
|
|
|
|
@mx.custom_function |
|
|
def f(x, y): |
|
|
return mx.sin(x) * y |
|
|
|
|
|
@f.vjp |
|
|
def f_vjp(primals, cotangent, output): |
|
|
x, y = primals |
|
|
return cotan * mx.cos(x) * y, cotan * mx.sin(x) |
|
|
|
|
|
@f.jvp |
|
|
def f_jvp(primals, tangents): |
|
|
x, y = primals |
|
|
dx, dy = tangents |
|
|
return dx * mx.cos(x) * y + dy * mx.sin(x) |
|
|
|
|
|
@f.vmap |
|
|
def f_vmap(inputs, axes): |
|
|
x, y = inputs |
|
|
ax, ay = axes |
|
|
if ay != ax and ax is not None: |
|
|
y = y.swapaxes(ay, ax) |
|
|
return mx.sin(x) * y, (ax or ay) |
|
|
|
|
|
All ``custom_function`` instances behave as pure functions. Namely, any |
|
|
variables captured will be treated as constants and no gradients will be |
|
|
computed with respect to the captured arrays. For instance: |
|
|
|
|
|
.. code-block:: python |
|
|
|
|
|
import mlx.core as mx |
|
|
|
|
|
def g(x, y): |
|
|
@mx.custom_function |
|
|
def f(x): |
|
|
return x * y |
|
|
|
|
|
@f.vjp |
|
|
def f_vjp(x, dx, fx): |
|
|
# Note that we have only x, dx and fx and nothing with respect to y |
|
|
raise ValueError("Abort!") |
|
|
|
|
|
return f(x) |
|
|
|
|
|
x = mx.array(2.0) |
|
|
y = mx.array(3.0) |
|
|
print(g(x, y)) # prints 6.0 |
|
|
print(mx.grad(g)(x, y)) # Raises exception |
|
|
print(mx.grad(g, argnums=1)(x, y)) # prints 0.0 |
|
|
)pbdoc") |
|
|
.def( |
|
|
nb::init<nb::callable>(), |
|
|
"f"_a, |
|
|
nb::sig("def __init__(self, f: Callable)")) |
|
|
.def("__call__", &PyCustomFunction::call_impl) |
|
|
.def( |
|
|
"vjp", |
|
|
&PyCustomFunction::set_vjp, |
|
|
"f"_a, |
|
|
nb::sig("def vjp(self, f: Callable)"), |
|
|
R"pbdoc( |
|
|
Define a custom vjp for the wrapped function. |
|
|
|
|
|
The vjp function takes three arguments: |
|
|
|
|
|
- *primals*: A pytree that contains all the positional arguments to |
|
|
the function. It could be a single array, a tuple of arrays or a |
|
|
full blown tuple of dicts of arrays etc. |
|
|
- *cotangents*: A pytree that matches the structure of the output |
|
|
but contains the cotangents (usually the gradients of the loss |
|
|
function with respect to the outputs). |
|
|
- *outputs*: The outputs of the function to be used to avoid |
|
|
recomputing them for the gradient computation. |
|
|
|
|
|
The vjp function should return the same pytree structure as the |
|
|
primals but containing the corresponding computed cotangents. |
|
|
)pbdoc") |
|
|
.def( |
|
|
"jvp", |
|
|
&PyCustomFunction::set_jvp, |
|
|
"f"_a, |
|
|
nb::sig("def jvp(self, f: Callable)"), |
|
|
R"pbdoc( |
|
|
Define a custom jvp for the wrapped function. |
|
|
|
|
|
The jvp function takes two arguments: |
|
|
|
|
|
- *primals*: A pytree that contains all the positional arguments to |
|
|
the function. It could be a single array, a tuple of arrays or a |
|
|
full blown tuple of dicts of arrays etc. |
|
|
- *tangents*: A pytree that matches the structure of the inputs but |
|
|
instead contains the gradients wrt to each input. Tangents could |
|
|
be ``None`` if some inputs don't have an associated gradient. |
|
|
|
|
|
The jvp function should return the same pytree structure as the |
|
|
outputs of the function but containing the tangents. |
|
|
)pbdoc") |
|
|
.def( |
|
|
"vmap", |
|
|
&PyCustomFunction::set_vmap, |
|
|
"f"_a, |
|
|
nb::sig("def vmap(self, f: Callable)"), |
|
|
R"pbdoc( |
|
|
Define a custom vectorization transformation for the wrapped function. |
|
|
|
|
|
The vmap function takes two arguments: |
|
|
|
|
|
- *inputs*: A pytree that contains all the positional arguments to |
|
|
the function. It could be a single array, a tuple of arrays or a |
|
|
full blown tuple of dicts of arrays etc. |
|
|
- *axes*: A pytree that matches the structure of the inputs but |
|
|
instead contains the vectorization axis for each input or |
|
|
``None`` if an input is not vectorized. |
|
|
|
|
|
The vmap function should return the outputs of the original |
|
|
function but vectorized over the provided axes. It should also |
|
|
return a pytree with the vectorization axes of each output. If some |
|
|
outputs are no longer vectorized, then their vectorization axis |
|
|
should be ``None``. |
|
|
)pbdoc"); |
|
|
|
|
|
m.def( |
|
|
"eval", |
|
|
[](const nb::args& args) { |
|
|
std::vector<mx::array> arrays = tree_flatten(args, false); |
|
|
{ |
|
|
nb::gil_scoped_release nogil; |
|
|
eval(arrays); |
|
|
} |
|
|
}, |
|
|
nb::arg(), |
|
|
nb::sig("def eval(*args) -> None"), |
|
|
R"pbdoc( |
|
|
Evaluate an :class:`array` or tree of :class:`array`. |
|
|
|
|
|
Args: |
|
|
*args (arrays or trees of arrays): Each argument can be a single array |
|
|
or a tree of arrays. If a tree is given the nodes can be a Python |
|
|
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not |
|
|
arrays are ignored. |
|
|
)pbdoc"); |
|
|
m.def( |
|
|
"async_eval", |
|
|
[](const nb::args& args) { |
|
|
std::vector<mx::array> arrays = tree_flatten(args, false); |
|
|
{ |
|
|
nb::gil_scoped_release nogil; |
|
|
async_eval(arrays); |
|
|
} |
|
|
}, |
|
|
nb::arg(), |
|
|
nb::sig("def async_eval(*args)"), |
|
|
R"pbdoc( |
|
|
Asynchronously evaluate an :class:`array` or tree of :class:`array`. |
|
|
|
|
|
.. note:: |
|
|
|
|
|
This is an experimental API and may change in future versions. |
|
|
|
|
|
Args: |
|
|
*args (arrays or trees of arrays): Each argument can be a single array |
|
|
or a tree of arrays. If a tree is given the nodes can be a Python |
|
|
:class:`list`, :class:`tuple` or :class:`dict`. Leaves which are not |
|
|
arrays are ignored. |
|
|
|
|
|
Example: |
|
|
>>> x = mx.array(1.0) |
|
|
>>> y = mx.exp(x) |
|
|
>>> mx.async_eval(y) |
|
|
>>> print(y) |
|
|
>>> |
|
|
>>> y = mx.exp(x) |
|
|
>>> mx.async_eval(y) |
|
|
>>> z = y + 3 |
|
|
>>> mx.async_eval(z) |
|
|
>>> print(z) |
|
|
)pbdoc"); |
|
|
m.def( |
|
|
"jvp", |
|
|
[](const nb::callable& fun, |
|
|
const std::vector<mx::array>& primals, |
|
|
const std::vector<mx::array>& tangents) { |
|
|
auto vfun = [&fun](const std::vector<mx::array>& primals) { |
|
|
auto out = fun(*nb::cast(primals)); |
|
|
if (nb::isinstance<mx::array>(out)) { |
|
|
return std::vector<mx::array>{nb::cast<mx::array>(out)}; |
|
|
} else { |
|
|
return nb::cast<std::vector<mx::array>>(out); |
|
|
} |
|
|
}; |
|
|
return jvp(vfun, primals, tangents); |
|
|
}, |
|
|
"fun"_a, |
|
|
"primals"_a, |
|
|
"tangents"_a, |
|
|
nb::sig( |
|
|
"def jvp(fun: Callable, primals: list[array], tangents: list[array]) -> tuple[list[array], list[array]]"), |
|
|
R"pbdoc( |
|
|
Compute the Jacobian-vector product. |
|
|
|
|
|
This computes the product of the Jacobian of a function ``fun`` evaluated |
|
|
at ``primals`` with the ``tangents``. |
|
|
|
|
|
Args: |
|
|
fun (Callable): A function which takes a variable number of :class:`array` |
|
|
and returns a single :class:`array` or list of :class:`array`. |
|
|
primals (list(array)): A list of :class:`array` at which to |
|
|
evaluate the Jacobian. |
|
|
tangents (list(array)): A list of :class:`array` which are the |
|
|
"vector" in the Jacobian-vector product. The ``tangents`` should be the |
|
|
same in number, shape, and type as the inputs of ``fun`` (i.e. the ``primals``). |
|
|
|
|
|
Returns: |
|
|
list(array): A list of the Jacobian-vector products which |
|
|
is the same in number, shape, and type of the inputs to ``fun``. |
|
|
)pbdoc"); |
|
|
m.def( |
|
|
"vjp", |
|
|
[](const nb::callable& fun, |
|
|
const std::vector<mx::array>& primals, |
|
|
const std::vector<mx::array>& cotangents) { |
|
|
auto vfun = [&fun](const std::vector<mx::array>& primals) { |
|
|
auto out = fun(*nb::cast(primals)); |
|
|
if (nb::isinstance<mx::array>(out)) { |
|
|
return std::vector<mx::array>{nb::cast<mx::array>(out)}; |
|
|
} else { |
|
|
return nb::cast<std::vector<mx::array>>(out); |
|
|
} |
|
|
}; |
|
|
return vjp(vfun, primals, cotangents); |
|
|
}, |
|
|
"fun"_a, |
|
|
"primals"_a, |
|
|
"cotangents"_a, |
|
|
nb::sig( |
|
|
"def vjp(fun: Callable, primals: list[array], cotangents: list[array]) -> tuple[list[array], list[array]]"), |
|
|
R"pbdoc( |
|
|
Compute the vector-Jacobian product. |
|
|
|
|
|
Computes the product of the ``cotangents`` with the Jacobian of a |
|
|
function ``fun`` evaluated at ``primals``. |
|
|
|
|
|
Args: |
|
|
fun (Callable): A function which takes a variable number of :class:`array` |
|
|
and returns a single :class:`array` or list of :class:`array`. |
|
|
primals (list(array)): A list of :class:`array` at which to |
|
|
evaluate the Jacobian. |
|
|
cotangents (list(array)): A list of :class:`array` which are the |
|
|
"vector" in the vector-Jacobian product. The ``cotangents`` should be the |
|
|
same in number, shape, and type as the outputs of ``fun``. |
|
|
|
|
|
Returns: |
|
|
list(array): A list of the vector-Jacobian products which |
|
|
is the same in number, shape, and type of the outputs of ``fun``. |
|
|
)pbdoc"); |
|
|
m.def( |
|
|
"value_and_grad", |
|
|
[](const nb::callable& fun, |
|
|
const std::optional<IntOrVec>& argnums, |
|
|
const StrOrSet& argnames) { |
|
|
auto [argnums_vec, argnames_set] = |
|
|
validate_argnums_argnames(argnums, argnames); |
|
|
return mlx_func( |
|
|
py_value_and_grad( |
|
|
fun, argnums_vec, argnames_set, "[value_and_grad]", false), |
|
|
fun); |
|
|
}, |
|
|
"fun"_a, |
|
|
"argnums"_a = nb::none(), |
|
|
"argnames"_a = std::vector<std::string>{}, |
|
|
nb::sig( |
|
|
"def value_and_grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"), |
|
|
R"pbdoc( |
|
|
Returns a function which computes the value and gradient of ``fun``. |
|
|
|
|
|
The function passed to :func:`value_and_grad` should return either |
|
|
a scalar loss or a tuple in which the first element is a scalar |
|
|
loss and the remaining elements can be anything. |
|
|
|
|
|
.. code-block:: python |
|
|
|
|
|
import mlx.core as mx |
|
|
|
|
|
def mse(params, inputs, targets): |
|
|
outputs = forward(params, inputs) |
|
|
lvalue = (outputs - targets).square().mean() |
|
|
return lvalue |
|
|
|
|
|
# Returns lvalue, dlvalue/dparams |
|
|
lvalue, grads = mx.value_and_grad(mse)(params, inputs, targets) |
|
|
|
|
|
def lasso(params, inputs, targets, a=1.0, b=1.0): |
|
|
outputs = forward(params, inputs) |
|
|
mse = (outputs - targets).square().mean() |
|
|
l1 = mx.abs(outputs - targets).mean() |
|
|
|
|
|
loss = a*mse + b*l1 |
|
|
|
|
|
return loss, mse, l1 |
|
|
|
|
|
(loss, mse, l1), grads = mx.value_and_grad(lasso)(params, inputs, targets) |
|
|
|
|
|
Args: |
|
|
fun (Callable): A function which takes a variable number of |
|
|
:class:`array` or trees of :class:`array` and returns |
|
|
a scalar output :class:`array` or a tuple the first element |
|
|
of which should be a scalar :class:`array`. |
|
|
argnums (int or list(int), optional): Specify the index (or indices) |
|
|
of the positional arguments of ``fun`` to compute the gradient |
|
|
with respect to. If neither ``argnums`` nor ``argnames`` are |
|
|
provided ``argnums`` defaults to ``0`` indicating ``fun``'s first |
|
|
argument. |
|
|
argnames (str or list(str), optional): Specify keyword arguments of |
|
|
``fun`` to compute gradients with respect to. It defaults to [] so |
|
|
no gradients for keyword arguments by default. |
|
|
|
|
|
Returns: |
|
|
Callable: A function which returns a tuple where the first element |
|
|
is the output of `fun` and the second element is the gradients w.r.t. |
|
|
the loss. |
|
|
)pbdoc"); |
|
|
m.def( |
|
|
"grad", |
|
|
[](const nb::callable& fun, |
|
|
const std::optional<IntOrVec>& argnums, |
|
|
const StrOrSet& argnames) { |
|
|
auto [argnums_vec, argnames_set] = |
|
|
validate_argnums_argnames(argnums, argnames); |
|
|
auto fn = |
|
|
py_value_and_grad(fun, argnums_vec, argnames_set, "[grad]", true); |
|
|
return mlx_func( |
|
|
[fn = std::move(fn)](nb::args& args, nb::kwargs& kwargs) { |
|
|
return fn(args, kwargs).second; |
|
|
}, |
|
|
fun); |
|
|
}, |
|
|
"fun"_a, |
|
|
"argnums"_a = nb::none(), |
|
|
"argnames"_a = std::vector<std::string>{}, |
|
|
nb::sig( |
|
|
"def grad(fun: Callable, argnums: Optional[Union[int, Sequence[int]]] = None, argnames: Union[str, Sequence[str]] = []) -> Callable"), |
|
|
R"pbdoc( |
|
|
Returns a function which computes the gradient of ``fun``. |
|
|
|
|
|
Args: |
|
|
fun (Callable): A function which takes a variable number of |
|
|
:class:`array` or trees of :class:`array` and returns |
|
|
a scalar output :class:`array`. |
|
|
argnums (int or list(int), optional): Specify the index (or indices) |
|
|
of the positional arguments of ``fun`` to compute the gradient |
|
|
with respect to. If neither ``argnums`` nor ``argnames`` are |
|
|
provided ``argnums`` defaults to ``0`` indicating ``fun``'s first |
|
|
argument. |
|
|
argnames (str or list(str), optional): Specify keyword arguments of |
|
|
``fun`` to compute gradients with respect to. It defaults to [] so |
|
|
no gradients for keyword arguments by default. |
|
|
|
|
|
Returns: |
|
|
Callable: A function which has the same input arguments as ``fun`` and |
|
|
returns the gradient(s). |
|
|
)pbdoc"); |
|
|
m.def( |
|
|
"vmap", |
|
|
[](const nb::callable& fun, |
|
|
const nb::object& in_axes, |
|
|
const nb::object& out_axes) { |
|
|
return mlx_func( |
|
|
py_vmap(fun, in_axes, out_axes), fun, in_axes, out_axes); |
|
|
}, |
|
|
"fun"_a, |
|
|
"in_axes"_a = 0, |
|
|
"out_axes"_a = 0, |
|
|
nb::sig( |
|
|
"def vmap(fun: Callable, in_axes: object = 0, out_axes: object = 0) -> Callable"), |
|
|
R"pbdoc( |
|
|
Returns a vectorized version of ``fun``. |
|
|
|
|
|
Args: |
|
|
fun (Callable): A function which takes a variable number of |
|
|
:class:`array` or a tree of :class:`array` and returns |
|
|
a variable number of :class:`array` or a tree of :class:`array`. |
|
|
in_axes (int, optional): An integer or a valid prefix tree of the |
|
|
inputs to ``fun`` where each node specifies the vmapped axis. If |
|
|
the value is ``None`` then the corresponding input(s) are not vmapped. |
|
|
Defaults to ``0``. |
|
|
out_axes (int, optional): An integer or a valid prefix tree of the |
|
|
outputs of ``fun`` where each node specifies the vmapped axis. If |
|
|
the value is ``None`` then the corresponding outputs(s) are not vmapped. |
|
|
Defaults to ``0``. |
|
|
|
|
|
Returns: |
|
|
Callable: The vectorized function. |
|
|
)pbdoc"); |
|
|
m.def( |
|
|
"compile", |
|
|
[](const nb::callable& fun, |
|
|
const nb::object& inputs, |
|
|
const nb::object& outputs, |
|
|
bool shapeless) { |
|
|
return mlx_func( |
|
|
nb::cpp_function(PyCompiledFun{fun, inputs, outputs, shapeless}), |
|
|
fun, |
|
|
inputs, |
|
|
outputs); |
|
|
}, |
|
|
"fun"_a, |
|
|
"inputs"_a = nb::none(), |
|
|
"outputs"_a = nb::none(), |
|
|
"shapeless"_a = false, |
|
|
nb::sig( |
|
|
"def compile(fun: Callable, inputs: Optional[object] = None, outputs: Optional[object] = None, shapeless: bool = False) -> Callable"), |
|
|
R"pbdoc( |
|
|
Returns a compiled function which produces the same output as ``fun``. |
|
|
|
|
|
Args: |
|
|
fun (Callable): A function which takes a variable number of |
|
|
:class:`array` or trees of :class:`array` and returns |
|
|
a variable number of :class:`array` or trees of :class:`array`. |
|
|
inputs (list or dict, optional): These inputs will be captured during |
|
|
the function compilation along with the inputs to ``fun``. The ``inputs`` |
|
|
can be a :obj:`list` or a :obj:`dict` containing arbitrarily nested |
|
|
lists, dictionaries, or arrays. Leaf nodes that are not |
|
|
:obj:`array` are ignored. Default: ``None`` |
|
|
outputs (list or dict, optional): These outputs will be captured and |
|
|
updated in a compiled function. The ``outputs`` can be a |
|
|
:obj:`list` or a :obj:`dict` containing arbitrarily nested lists, |
|
|
dictionaries, or arrays. Leaf nodes that are not :obj:`array` are ignored. |
|
|
Default: ``None`` |
|
|
shapeless (bool, optional): A function compiled with the ``shapeless`` |
|
|
option enabled will not be recompiled when the input shape changes. Not all |
|
|
functions can be compiled with ``shapeless`` enabled. Attempting to compile |
|
|
such functions with shapeless enabled will throw. Note, changing the number |
|
|
of dimensions or type of any input will result in a recompilation even with |
|
|
``shapeless`` set to ``True``. Default: ``False`` |
|
|
|
|
|
Returns: |
|
|
Callable: A compiled function which has the same input arguments |
|
|
as ``fun`` and returns the the same output(s). |
|
|
)pbdoc"); |
|
|
m.def( |
|
|
"disable_compile", |
|
|
&mx::disable_compile, |
|
|
R"pbdoc( |
|
|
Globally disable compilation. Setting the environment variable |
|
|
``MLX_DISABLE_COMPILE`` can also be used to disable compilation. |
|
|
)pbdoc"); |
|
|
m.def( |
|
|
"enable_compile", |
|
|
&mx::enable_compile, |
|
|
R"pbdoc( |
|
|
Globally enable compilation. This will override the environment |
|
|
variable ``MLX_DISABLE_COMPILE`` if set. |
|
|
)pbdoc"); |
|
|
m.def( |
|
|
"checkpoint", |
|
|
[](nb::callable fun) { return mlx_func(PyCheckpointedFun{fun}, fun); }, |
|
|
"fun"_a); |
|
|
|
|
|
|
|
|
auto atexit = nb::module_::import_("atexit"); |
|
|
atexit.attr("register")( |
|
|
nb::cpp_function([]() { mx::detail::compile_clear_cache(); })); |
|
|
} |
|
|
|