| | |
| |
|
| | #include "python/src/trees.h" |
| |
|
| | template <typename T, typename U, typename V> |
| | void validate_subtrees(const std::vector<nb::object>& subtrees) { |
| | int len = nb::cast<T>(subtrees[0]).size(); |
| | for (auto& subtree : subtrees) { |
| | if ((nb::isinstance<T>(subtree) && nb::cast<T>(subtree).size() != len) || |
| | nb::isinstance<U>(subtree) || nb::isinstance<V>(subtree)) { |
| | throw std::invalid_argument( |
| | "[tree_map] Additional input tree is not a valid prefix of the first tree."); |
| | } |
| | } |
| | } |
| |
|
| | nb::object tree_map( |
| | const std::vector<nb::object>& trees, |
| | std::function<nb::object(const std::vector<nb::object>&)> transform) { |
| | std::function<nb::object(const std::vector<nb::object>&)> recurse; |
| |
|
| | recurse = [&](const std::vector<nb::object>& subtrees) { |
| | if (nb::isinstance<nb::list>(subtrees[0])) { |
| | nb::list l; |
| | std::vector<nb::object> items(subtrees.size()); |
| | validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees); |
| | for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) { |
| | for (int j = 0; j < subtrees.size(); ++j) { |
| | if (nb::isinstance<nb::list>(subtrees[j])) { |
| | items[j] = nb::cast<nb::list>(subtrees[j])[i]; |
| | } else { |
| | items[j] = subtrees[j]; |
| | } |
| | } |
| | l.append(recurse(items)); |
| | } |
| | return nb::cast<nb::object>(l); |
| | } else if (nb::isinstance<nb::tuple>(subtrees[0])) { |
| | |
| | std::vector<nb::object> items(subtrees.size()); |
| | int len = nb::cast<nb::tuple>(subtrees[0]).size(); |
| | nb::list l; |
| | validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees); |
| | for (int i = 0; i < len; ++i) { |
| | for (int j = 0; j < subtrees.size(); ++j) { |
| | if (nb::isinstance<nb::tuple>(subtrees[j])) { |
| | items[j] = nb::cast<nb::tuple>(subtrees[j])[i]; |
| | } else { |
| | items[j] = subtrees[j]; |
| | } |
| | } |
| | l.append(recurse(items)); |
| | } |
| | return nb::cast<nb::object>(nb::tuple(l)); |
| | } else if (nb::isinstance<nb::dict>(subtrees[0])) { |
| | std::vector<nb::object> items(subtrees.size()); |
| | validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees); |
| | nb::dict d; |
| | for (auto item : nb::cast<nb::dict>(subtrees[0])) { |
| | for (int j = 0; j < subtrees.size(); ++j) { |
| | if (nb::isinstance<nb::dict>(subtrees[j])) { |
| | auto subdict = nb::cast<nb::dict>(subtrees[j]); |
| | if (!subdict.contains(item.first)) { |
| | throw std::invalid_argument( |
| | "[tree_map] Tree is not a valid prefix tree of the first tree."); |
| | } |
| | items[j] = subdict[item.first]; |
| | } else { |
| | items[j] = subtrees[j]; |
| | } |
| | } |
| | d[item.first] = recurse(items); |
| | } |
| | return nb::cast<nb::object>(d); |
| | } else { |
| | return transform(subtrees); |
| | } |
| | }; |
| | return recurse(trees); |
| | } |
| |
|
| | nb::object tree_map( |
| | nb::object tree, |
| | std::function<nb::object(nb::handle)> transform) { |
| | return tree_map({tree}, [&](std::vector<nb::object> inputs) { |
| | return transform(inputs[0]); |
| | }); |
| | } |
| |
|
| | void tree_visit( |
| | const std::vector<nb::object>& trees, |
| | std::function<void(const std::vector<nb::object>&)> visitor) { |
| | std::function<void(const std::vector<nb::object>&)> recurse; |
| |
|
| | recurse = [&](const std::vector<nb::object>& subtrees) { |
| | if (nb::isinstance<nb::list>(subtrees[0])) { |
| | std::vector<nb::object> items(subtrees.size()); |
| | validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees); |
| | for (int i = 0; i < nb::cast<nb::list>(subtrees[0]).size(); ++i) { |
| | for (int j = 0; j < subtrees.size(); ++j) { |
| | if (nb::isinstance<nb::list>(subtrees[j])) { |
| | items[j] = nb::cast<nb::list>(subtrees[j])[i]; |
| | } else { |
| | items[j] = subtrees[j]; |
| | } |
| | } |
| | recurse(items); |
| | } |
| | } else if (nb::isinstance<nb::tuple>(subtrees[0])) { |
| | |
| | std::vector<nb::object> items(subtrees.size()); |
| | int len = nb::cast<nb::tuple>(subtrees[0]).size(); |
| | validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees); |
| | for (int i = 0; i < len; ++i) { |
| | for (int j = 0; j < subtrees.size(); ++j) { |
| | if (nb::isinstance<nb::tuple>(subtrees[j])) { |
| | items[j] = nb::cast<nb::tuple>(subtrees[j])[i]; |
| | } else { |
| | items[j] = subtrees[j]; |
| | } |
| | } |
| | recurse(items); |
| | } |
| | } else if (nb::isinstance<nb::dict>(subtrees[0])) { |
| | std::vector<nb::object> items(subtrees.size()); |
| | validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees); |
| | for (auto item : nb::cast<nb::dict>(subtrees[0])) { |
| | for (int j = 0; j < subtrees.size(); ++j) { |
| | if (nb::isinstance<nb::dict>(subtrees[j])) { |
| | auto subdict = nb::cast<nb::dict>(subtrees[j]); |
| | if (!subdict.contains(item.first)) { |
| | throw std::invalid_argument( |
| | "[tree_visit] Tree is not a valid prefix tree of the first tree."); |
| | } |
| | items[j] = subdict[item.first]; |
| | } else { |
| | items[j] = subtrees[j]; |
| | } |
| | } |
| | recurse(items); |
| | } |
| | } else { |
| | visitor(subtrees); |
| | } |
| | }; |
| | return recurse(trees); |
| | } |
| |
|
| | void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) { |
| | std::function<void(nb::handle)> recurse; |
| | recurse = [&](nb::handle subtree) { |
| | if (nb::isinstance<nb::list>(subtree) || |
| | nb::isinstance<nb::tuple>(subtree)) { |
| | for (auto item : subtree) { |
| | recurse(item); |
| | } |
| | } else if (nb::isinstance<nb::dict>(subtree)) { |
| | for (auto item : nb::cast<nb::dict>(subtree)) { |
| | recurse(item.second); |
| | } |
| | } else { |
| | visitor(subtree); |
| | } |
| | }; |
| |
|
| | recurse(tree); |
| | } |
| |
|
| | void tree_visit_update( |
| | nb::object tree, |
| | std::function<nb::object(nb::handle)> visitor) { |
| | std::function<nb::object(nb::handle)> recurse; |
| | recurse = [&](nb::handle subtree) { |
| | if (nb::isinstance<nb::list>(subtree)) { |
| | auto l = nb::cast<nb::list>(subtree); |
| | for (int i = 0; i < l.size(); ++i) { |
| | l[i] = recurse(l[i]); |
| | } |
| | return nb::cast<nb::object>(l); |
| | } else if (nb::isinstance<nb::tuple>(subtree)) { |
| | nb::list l(subtree); |
| | for (int i = 0; i < l.size(); ++i) { |
| | l[i] = recurse(l[i]); |
| | } |
| | return nb::cast<nb::object>(nb::tuple(l)); |
| | } else if (nb::isinstance<nb::dict>(subtree)) { |
| | auto d = nb::cast<nb::dict>(subtree); |
| | for (auto item : d) { |
| | d[item.first] = recurse(item.second); |
| | } |
| | return nb::cast<nb::object>(d); |
| | } else if (nb::isinstance<mx::array>(subtree)) { |
| | return visitor(subtree); |
| | } else { |
| | return nb::cast<nb::object>(subtree); |
| | } |
| | }; |
| | recurse(tree); |
| | } |
| |
|
| | |
| | |
| | |
| | void tree_fill(nb::object& tree, const std::vector<mx::array>& values) { |
| | size_t index = 0; |
| | tree_visit_update( |
| | tree, [&](nb::handle node) { return nb::cast(values[index++]); }); |
| | } |
| |
|
| | |
| | void tree_replace( |
| | nb::object& tree, |
| | const std::vector<mx::array>& src, |
| | const std::vector<mx::array>& dst) { |
| | std::unordered_map<uintptr_t, mx::array> src_to_dst; |
| | for (int i = 0; i < src.size(); ++i) { |
| | src_to_dst.insert({src[i].id(), dst[i]}); |
| | } |
| | tree_visit_update(tree, [&](nb::handle node) { |
| | auto arr = nb::cast<mx::array>(node); |
| | if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) { |
| | return nb::cast(it->second); |
| | } |
| | return nb::cast(arr); |
| | }); |
| | } |
| |
|
| | std::vector<mx::array> tree_flatten(nb::handle tree, bool strict ) { |
| | std::vector<mx::array> flat_tree; |
| |
|
| | tree_visit(tree, [&](nb::handle obj) { |
| | if (nb::isinstance<mx::array>(obj)) { |
| | flat_tree.push_back(nb::cast<mx::array>(obj)); |
| | } else if (strict) { |
| | throw std::invalid_argument( |
| | "[tree_flatten] The argument should contain only arrays"); |
| | } |
| | }); |
| |
|
| | return flat_tree; |
| | } |
| |
|
| | nb::object tree_unflatten( |
| | nb::object tree, |
| | const std::vector<mx::array>& values, |
| | int index ) { |
| | return tree_map(tree, [&](nb::handle obj) { |
| | if (nb::isinstance<mx::array>(obj)) { |
| | return nb::cast(values[index++]); |
| | } else { |
| | return nb::cast<nb::object>(obj); |
| | } |
| | }); |
| | } |
| |
|
| | nb::object structure_sentinel() { |
| | static nb::object sentinel; |
| |
|
| | if (sentinel.ptr() == nullptr) { |
| | sentinel = nb::capsule(&sentinel); |
| | |
| | |
| | sentinel.inc_ref(); |
| | } |
| |
|
| | return sentinel; |
| | } |
| |
|
| | std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure( |
| | nb::object tree, |
| | bool strict ) { |
| | auto sentinel = structure_sentinel(); |
| | std::vector<mx::array> flat_tree; |
| | auto structure = tree_map( |
| | tree, |
| | [&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) { |
| | if (nb::isinstance<mx::array>(obj)) { |
| | flat_tree.push_back(nb::cast<mx::array>(obj)); |
| | return sentinel; |
| | } else if (!strict) { |
| | return nb::cast<nb::object>(obj); |
| | } else { |
| | throw std::invalid_argument( |
| | "[tree_flatten] The argument should contain only arrays"); |
| | } |
| | }); |
| |
|
| | return {flat_tree, structure}; |
| | } |
| |
|
| | nb::object tree_unflatten_from_structure( |
| | nb::object structure, |
| | const std::vector<mx::array>& values, |
| | int index ) { |
| | auto sentinel = structure_sentinel(); |
| | return tree_map(structure, [&](nb::handle obj) { |
| | if (obj.is(sentinel)) { |
| | return nb::cast(values[index++]); |
| | } else { |
| | return nb::cast<nb::object>(obj); |
| | } |
| | }); |
| | } |
| |
|