|
|
|
|
|
#pragma once |
|
|
#include <functional>
#include <utility>
#include <vector>

#include <nanobind/nanobind.h>

#include "mlx/array.h"
|
|
|
|
|
// Short namespace aliases used by every declaration in this header.
// NOTE(review): aliases declared at header scope become visible in every
// translation unit that includes this file; kept as-is because includers
// presumably rely on `mx::` / `nb::` being available.
namespace mx = mlx::core;

namespace nb = nanobind;
|
|
|
|
|
void tree_visit( |
|
|
const std::vector<nb::object>& trees, |
|
|
std::function<void(const std::vector<nb::object>&)> visitor); |
|
|
void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor); |
|
|
|
|
|
/**
 * Produce a new tree by applying `transform` to several trees at once.
 *
 * @param trees      Trees mapped over together; `transform` is handed one
 *                   node from each of them per call.
 * @param transform  Returns the object that presumably takes that position
 *                   in the resulting tree.
 * @return The tree assembled from the `transform` results.
 *
 * NOTE(review): behaviour on mismatched tree structures is defined in the
 * implementation; confirm there.
 */
nb::object tree_map(
    const std::vector<nb::object>& trees,
    std::function<nb::object(const std::vector<nb::object>&)> transform);

/**
 * Single-tree overload of tree_map.
 *
 * @param tree       The tree to map over.
 * @param transform  Called with each visited node (presumably each leaf)
 *                   and returns its replacement in the output tree.
 * @return The newly built tree.
 */
nb::object tree_map(nb::object tree, std::function<nb::object(nb::handle)> transform);
|
|
|
|
|
/**
 * Visit the nodes of `tree` with a visitor that returns an object.
 *
 * Unlike tree_visit, the visitor yields an `nb::object`. The name suggests
 * each visited node is overwritten in place with that result and nothing is
 * returned to the caller. NOTE(review): confirm the in-place semantics and
 * which nodes are visited against the implementation.
 *
 * @param tree     The tree to walk and (presumably) update.
 * @param visitor  Maps a visited node to its replacement object.
 */
void tree_visit_update(nb::object tree, std::function<nb::object(nb::handle)> visitor);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Fill `tree` with the arrays in `values`.
 *
 * `tree` is taken by non-const reference, so the caller's object may be
 * modified or rebound. NOTE(review): presumably the arrays are written into
 * the tree's leaves in traversal order, and some node kinds may be skipped;
 * confirm the exact rules in the implementation.
 *
 * @param tree    The tree to populate.
 * @param values  Arrays to place into the tree.
 */
void tree_fill(
    nb::object& tree,
    const std::vector<mx::array>& values);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Swap arrays inside `tree`: arrays appearing in `src` are presumably
 * replaced by the array at the same position in `dst`.
 *
 * NOTE(review): this header does not state whether `src` and `dst` must be
 * the same length, nor how array identity is matched. Confirm both in the
 * implementation.
 *
 * @param tree  The tree to modify (taken by non-const reference).
 * @param src   Arrays to look for.
 * @param dst   Replacement arrays.
 */
void tree_replace(nb::object& tree, const std::vector<mx::array>& src, const std::vector<mx::array>& dst);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Collect the arrays stored in `tree` into a flat vector.
 *
 * @param tree    The tree to flatten.
 * @param strict  Defaults to true. NOTE(review): presumably controls whether
 *                a non-array leaf is an error or is skipped; confirm in the
 *                implementation.
 * @return The arrays found in the tree (presumably in traversal order).
 */
std::vector<mx::array> tree_flatten(
    nb::handle tree,
    bool strict = true);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Rebuild a tree shaped like `tree` from a flat list of arrays
 * (presumably the inverse of tree_flatten).
 *
 * @param tree    Template whose structure the result follows.
 * @param values  Arrays to place into the rebuilt tree.
 * @param index   Defaults to 0. NOTE(review): presumably the position in
 *                `values` at which consumption starts; confirm.
 * @return The reconstructed tree.
 */
nb::object tree_unflatten(nb::object tree, const std::vector<mx::array>& values, int index = 0);
|
|
|
|
|
/**
 * Flatten `tree` and also return an object describing its structure.
 *
 * The second element is presumably meant to be passed back to
 * tree_unflatten_from_structure. NOTE(review): confirm the structure
 * object's format in the implementation.
 *
 * @param tree    The tree to flatten.
 * @param strict  Defaults to true; presumably has the same meaning as in
 *                tree_flatten (confirm).
 * @return A pair of (flattened arrays, structure object).
 */
std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(nb::object tree, bool strict = true);
|
|
|
|
|
/**
 * Rebuild a tree from a structure object plus a flat list of arrays.
 *
 * Presumably the counterpart of tree_flatten_with_structure.
 *
 * @param structure  Structure description, presumably as produced by
 *                   tree_flatten_with_structure.
 * @param values     Arrays to place into the rebuilt tree.
 * @param index      Defaults to 0. NOTE(review): presumably the starting
 *                   offset into `values`; confirm.
 * @return The reconstructed tree.
 */
nb::object tree_unflatten_from_structure(nb::object structure, const std::vector<mx::array>& values, int index = 0);
|
|
|