|
|
|
|
|
|
|
|
#include "python/src/trees.h" |
|
|
|
|
|
template <typename T, typename U, typename V> |
|
|
void validate_subtrees(const std::vector<nb::object>& subtrees) { |
|
|
int len = nb::cast<T>(subtrees[0]).size(); |
|
|
for (auto& subtree : subtrees) { |
|
|
if ((nb::isinstance<T>(subtree) && nb::cast<T>(subtree).size() != len) || |
|
|
nb::isinstance<U>(subtree) || nb::isinstance<V>(subtree)) { |
|
|
throw std::invalid_argument( |
|
|
"[tree_map] Additional input tree is not a valid prefix of the first tree."); |
|
|
} |
|
|
} |
|
|
} |
|
|
|
|
|
// Apply `transform` to the leaves of one or more pytrees traversed in
// lock-step, returning a new tree shaped like trees[0].
//
// At each node, trees[0] decides the structure:
//   * list  -> a new list whose i-th element maps the i-th children;
//   * tuple -> a new tuple, built the same way;
//   * dict  -> a new dict with trees[0]'s keys;
//   * anything else is a leaf: `transform` is called with the current
//     subtree of every input tree and its return value is used as-is.
// The other trees must be prefixes of trees[0] (see validate_subtrees). At a
// container node they either have the same container type and size, or are
// a non-container value that is passed unchanged to every child.
//
// Throws std::invalid_argument if a tree is not a valid prefix, or if a dict
// in another tree lacks a key present in the corresponding dict of trees[0].
// NOTE(review): trees must be non-empty; trees[0] is read unconditionally.
nb::object tree_map(
    const std::vector<nb::object>& trees,
    std::function<nb::object(const std::vector<nb::object>&)> transform) {
  std::function<nb::object(const std::vector<nb::object>&)> recurse;

  recurse = [&](const std::vector<nb::object>& subtrees) {
    if (nb::isinstance<nb::list>(subtrees[0])) {
      nb::list l;
      std::vector<nb::object> items(subtrees.size());
      validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
      // Cast once. size() is still re-read on every iteration, as in the
      // original loop, so a Python callback that changes the list's length is
      // observed.
      auto first = nb::cast<nb::list>(subtrees[0]);
      for (size_t i = 0; i < first.size(); ++i) {
        for (size_t j = 0; j < subtrees.size(); ++j) {
          if (nb::isinstance<nb::list>(subtrees[j])) {
            items[j] = nb::cast<nb::list>(subtrees[j])[i];
          } else {
            // A leaf in a prefix tree is passed to every child.
            items[j] = subtrees[j];
          }
        }
        l.append(recurse(items));
      }
      return nb::cast<nb::object>(l);
    } else if (nb::isinstance<nb::tuple>(subtrees[0])) {
      std::vector<nb::object> items(subtrees.size());
      const size_t len = nb::cast<nb::tuple>(subtrees[0]).size();
      nb::list l;
      validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
      for (size_t i = 0; i < len; ++i) {
        for (size_t j = 0; j < subtrees.size(); ++j) {
          if (nb::isinstance<nb::tuple>(subtrees[j])) {
            items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
          } else {
            items[j] = subtrees[j];
          }
        }
        l.append(recurse(items));
      }
      // Results are collected in a list and converted to a tuple at the end.
      return nb::cast<nb::object>(nb::tuple(l));
    } else if (nb::isinstance<nb::dict>(subtrees[0])) {
      std::vector<nb::object> items(subtrees.size());
      validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
      nb::dict d;
      for (auto item : nb::cast<nb::dict>(subtrees[0])) {
        for (size_t j = 0; j < subtrees.size(); ++j) {
          if (nb::isinstance<nb::dict>(subtrees[j])) {
            auto subdict = nb::cast<nb::dict>(subtrees[j]);
            if (!subdict.contains(item.first)) {
              throw std::invalid_argument(
                  "[tree_map] Tree is not a valid prefix tree of the first tree.");
            }
            items[j] = subdict[item.first];
          } else {
            items[j] = subtrees[j];
          }
        }
        d[item.first] = recurse(items);
      }
      return nb::cast<nb::object>(d);
    } else {
      return transform(subtrees);
    }
  };

  return recurse(trees);
}
|
|
|
|
|
// Single-tree convenience overload: equivalent to the multi-tree tree_map
// called with trees = {tree}, where `transform` receives each leaf.
// Returns the new tree built by that overload.
nb::object tree_map(
    nb::object tree,
    std::function<nb::object(nb::handle)> transform) {
  // Take the leaf vector by const reference. A by-value parameter would copy
  // the vector (and bump every element's refcount) on each leaf.
  return tree_map({tree}, [&](const std::vector<nb::object>& inputs) {
    return transform(inputs[0]);
  });
}
|
|
|
|
|
// Walk one or more pytrees in lock-step and call `visitor` at every leaf of
// trees[0], passing the current subtree of every input tree.
//
// Lists, tuples and dicts of trees[0] are descended into. The other trees
// must be prefixes of trees[0] (see validate_subtrees): at a container node
// they either match its container type and size, or are a non-container value
// that is passed unchanged to every child. No new tree is built.
//
// Throws std::invalid_argument if a tree is not a valid prefix, or if a dict
// in another tree lacks a key present in the corresponding dict of trees[0].
// NOTE(review): trees must be non-empty; trees[0] is read unconditionally.
void tree_visit(
    const std::vector<nb::object>& trees,
    std::function<void(const std::vector<nb::object>&)> visitor) {
  std::function<void(const std::vector<nb::object>&)> recurse;

  recurse = [&](const std::vector<nb::object>& subtrees) {
    if (nb::isinstance<nb::list>(subtrees[0])) {
      std::vector<nb::object> items(subtrees.size());
      validate_subtrees<nb::list, nb::tuple, nb::dict>(subtrees);
      // Cast once; size() is re-read each iteration as in the original loop.
      auto first = nb::cast<nb::list>(subtrees[0]);
      for (size_t i = 0; i < first.size(); ++i) {
        for (size_t j = 0; j < subtrees.size(); ++j) {
          if (nb::isinstance<nb::list>(subtrees[j])) {
            items[j] = nb::cast<nb::list>(subtrees[j])[i];
          } else {
            // A leaf in a prefix tree is passed to every child.
            items[j] = subtrees[j];
          }
        }
        recurse(items);
      }
    } else if (nb::isinstance<nb::tuple>(subtrees[0])) {
      std::vector<nb::object> items(subtrees.size());
      const size_t len = nb::cast<nb::tuple>(subtrees[0]).size();
      validate_subtrees<nb::tuple, nb::list, nb::dict>(subtrees);
      for (size_t i = 0; i < len; ++i) {
        for (size_t j = 0; j < subtrees.size(); ++j) {
          if (nb::isinstance<nb::tuple>(subtrees[j])) {
            items[j] = nb::cast<nb::tuple>(subtrees[j])[i];
          } else {
            items[j] = subtrees[j];
          }
        }
        recurse(items);
      }
    } else if (nb::isinstance<nb::dict>(subtrees[0])) {
      std::vector<nb::object> items(subtrees.size());
      validate_subtrees<nb::dict, nb::list, nb::tuple>(subtrees);
      for (auto item : nb::cast<nb::dict>(subtrees[0])) {
        for (size_t j = 0; j < subtrees.size(); ++j) {
          if (nb::isinstance<nb::dict>(subtrees[j])) {
            auto subdict = nb::cast<nb::dict>(subtrees[j]);
            if (!subdict.contains(item.first)) {
              throw std::invalid_argument(
                  "[tree_visit] Tree is not a valid prefix tree of the first tree.");
            }
            items[j] = subdict[item.first];
          } else {
            items[j] = subtrees[j];
          }
        }
        recurse(items);
      }
    } else {
      visitor(subtrees);
    }
  };

  recurse(trees);
}
|
|
|
|
|
// Depth-first walk over a single pytree that calls `visitor` on every leaf.
//
// Lists and tuples are descended element by element. Dicts are descended
// value by value; their keys are never visited. Every other object is a leaf.
// Nothing is validated or copied.
void tree_visit(nb::handle tree, std::function<void(nb::handle)> visitor) {
  std::function<void(nb::handle)> walk;

  walk = [&](nb::handle node) {
    if (nb::isinstance<nb::list>(node) || nb::isinstance<nb::tuple>(node)) {
      for (nb::handle child : node) {
        walk(child);
      }
      return;
    }
    if (!nb::isinstance<nb::dict>(node)) {
      visitor(node);
      return;
    }
    for (auto entry : nb::cast<nb::dict>(node)) {
      walk(entry.second);
    }
  };

  walk(tree);
}
|
|
|
|
|
// Walk `tree` and replace every mx::array leaf with visitor(leaf).
//
// Results are written back where the container allows it:
//   * lists: each element is overwritten in place with its updated value;
//   * dicts: each value is overwritten in place (the key set is unchanged);
//   * tuples: being immutable, a new tuple of updated elements is built and
//     handed back to the parent, which stores it only if the parent is a
//     list or dict;
//   * mx::array leaves are replaced by the visitor's return value;
//   * any other object is returned unchanged and the visitor is not called.
// The value rebuilt for the root is discarded. If `tree` is itself a tuple
// (or an array), the visitor still runs on its arrays, but arrays held
// directly by that root are not written back. Lists and dicts nested inside
// it are still updated in place.
void tree_visit_update(
    nb::object tree,
    std::function<nb::object(nb::handle)> visitor) {
  std::function<nb::object(nb::handle)> recurse;
  recurse = [&](nb::handle subtree) {
    if (nb::isinstance<nb::list>(subtree)) {
      auto l = nb::cast<nb::list>(subtree);
      for (size_t i = 0; i < l.size(); ++i) {
        l[i] = recurse(l[i]);
      }
      return nb::cast<nb::object>(l);
    } else if (nb::isinstance<nb::tuple>(subtree)) {
      // Copy into a mutable list, update it, then freeze it into a new tuple.
      nb::list l(subtree);
      for (size_t i = 0; i < l.size(); ++i) {
        l[i] = recurse(l[i]);
      }
      return nb::cast<nb::object>(nb::tuple(l));
    } else if (nb::isinstance<nb::dict>(subtree)) {
      auto d = nb::cast<nb::dict>(subtree);
      // Only values of existing keys are reassigned, so the key set does not
      // change while the dict is being iterated.
      for (auto item : d) {
        d[item.first] = recurse(item.second);
      }
      return nb::cast<nb::object>(d);
    } else if (nb::isinstance<mx::array>(subtree)) {
      return visitor(subtree);
    } else {
      return nb::cast<nb::object>(subtree);
    }
  };
  recurse(tree);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Replace the mx::array leaves of `tree`, in traversal order, with values[0],
// values[1], ... Containers are updated as tree_visit_update does: lists and
// dicts in place, tuples only through a list/dict parent.
//
// Extra entries in `values` are ignored. Throws std::invalid_argument if the
// tree contains more arrays than `values` provides, instead of reading past
// the end of the vector.
void tree_fill(nb::object& tree, const std::vector<mx::array>& values) {
  size_t index = 0;
  tree_visit_update(tree, [&](nb::handle /* node */) {
    if (index >= values.size()) {
      throw std::invalid_argument(
          "[tree_fill] Not enough values to fill the tree.");
    }
    return nb::cast(values[index++]);
  });
}
|
|
|
|
|
|
|
|
// Replace every mx::array leaf of `tree` whose id() equals src[i].id() with
// dst[i]. Other arrays are cast back unchanged. Containers are updated as in
// tree_visit_update.
//
// If the same array id appears more than once in `src`, the first mapping
// wins. Extra entries in `dst` are ignored. Throws std::invalid_argument if
// `dst` is shorter than `src`, instead of reading past the end of `dst`.
void tree_replace(
    nb::object& tree,
    const std::vector<mx::array>& src,
    const std::vector<mx::array>& dst) {
  if (dst.size() < src.size()) {
    throw std::invalid_argument(
        "[tree_replace] Expected at least as many replacement arrays as source arrays.");
  }
  std::unordered_map<uintptr_t, mx::array> src_to_dst;
  src_to_dst.reserve(src.size());
  for (size_t i = 0; i < src.size(); ++i) {
    // insert() keeps the existing mapping when an id repeats.
    src_to_dst.insert({src[i].id(), dst[i]});
  }
  tree_visit_update(tree, [&](nb::handle node) {
    auto arr = nb::cast<mx::array>(node);
    if (auto it = src_to_dst.find(arr.id()); it != src_to_dst.end()) {
      return nb::cast(it->second);
    }
    return nb::cast(arr);
  });
}
|
|
|
|
|
// Collect the mx::array leaves of `tree` in traversal order.
//
// Non-array leaves are skipped. When `strict` is true, a non-array leaf
// throws std::invalid_argument instead.
std::vector<mx::array> tree_flatten(nb::handle tree, bool strict) {
  std::vector<mx::array> leaves;

  tree_visit(tree, [&](nb::handle leaf) {
    if (!nb::isinstance<mx::array>(leaf)) {
      if (strict) {
        throw std::invalid_argument(
            "[tree_flatten] The argument should contain only arrays");
      }
      return;
    }
    leaves.push_back(nb::cast<mx::array>(leaf));
  });

  return leaves;
}
|
|
|
|
|
// Return a copy of `tree` (rebuilt by tree_map) in which the mx::array leaves
// are replaced, in traversal order, by values[index], values[index + 1], ...
// Non-array leaves are kept as-is.
//
// Throws std::invalid_argument if `index` is negative or the tree has more
// array leaves than remain in `values`, instead of reading out of bounds.
nb::object tree_unflatten(
    nb::object tree,
    const std::vector<mx::array>& values,
    int index) {
  return tree_map(tree, [&](nb::handle obj) -> nb::object {
    if (!nb::isinstance<mx::array>(obj)) {
      return nb::cast<nb::object>(obj);
    }
    if (index < 0 || static_cast<size_t>(index) >= values.size()) {
      throw std::invalid_argument(
          "[tree_unflatten] The tree has more arrays than the provided values.");
    }
    return nb::cast(values[index++]);
  });
}
|
|
|
|
|
// Return the single placeholder object that tree_flatten_with_structure puts
// in place of each array leaf. tree_unflatten_from_structure recognizes it by
// identity (obj.is(sentinel)).
//
// The object is created lazily on the first call, as a capsule wrapping the
// address of the static itself. That pointer only serves to make a unique
// object; nothing in this file reads it back.
// NOTE(review): the lazy initialization is unsynchronized; presumably callers
// hold the GIL — confirm.
nb::object structure_sentinel() {
  static nb::object sentinel;

  if (sentinel.ptr() == nullptr) {
    sentinel = nb::capsule(&sentinel);
    // Take an extra reference that is never released. The destructor of the
    // static `sentinel` then cannot drop the refcount to zero at program exit.
    // NOTE(review): presumably this avoids deallocating a Python object after
    // interpreter shutdown — confirm.
    sentinel.inc_ref();
  }

  return sentinel;
}
|
|
|
|
|
// Flatten `tree` into its mx::array leaves (in traversal order) plus a
// "structure": a copy of the tree, built by tree_map, in which every array
// leaf is replaced by structure_sentinel(). tree_unflatten_from_structure
// consumes such a structure.
//
// Non-array leaves are kept in the structure when `strict` is false. When
// `strict` is true they throw std::invalid_argument.
std::pair<std::vector<mx::array>, nb::object> tree_flatten_with_structure(
    nb::object tree,
    bool strict) {
  auto sentinel = structure_sentinel();
  std::vector<mx::array> flat_tree;
  auto structure = tree_map(
      tree,
      [&flat_tree, sentinel = std::move(sentinel), strict](nb::handle obj) {
        if (nb::isinstance<mx::array>(obj)) {
          flat_tree.push_back(nb::cast<mx::array>(obj));
          return sentinel;
        } else if (!strict) {
          return nb::cast<nb::object>(obj);
        } else {
          throw std::invalid_argument(
              "[tree_flatten] The argument should contain only arrays");
        }
      });

  // Move both results out. A braced return of lvalues would copy the whole
  // array vector.
  return {std::move(flat_tree), std::move(structure)};
}
|
|
|
|
|
// Rebuild a tree from a `structure` produced by tree_flatten_with_structure.
// Each occurrence of structure_sentinel() is replaced, in traversal order, by
// values[index], values[index + 1], ...; every other leaf is kept as-is.
//
// Throws std::invalid_argument if `index` is negative or the structure has
// more sentinels than remain in `values`, instead of reading out of bounds.
nb::object tree_unflatten_from_structure(
    nb::object structure,
    const std::vector<mx::array>& values,
    int index) {
  auto sentinel = structure_sentinel();
  return tree_map(structure, [&](nb::handle obj) -> nb::object {
    if (!obj.is(sentinel)) {
      return nb::cast<nb::object>(obj);
    }
    if (index < 0 || static_cast<size_t>(index) >= values.size()) {
      throw std::invalid_argument(
          "[tree_unflatten] The structure has more arrays than the provided values.");
    }
    return nb::cast(values[index++]);
  });
}
|
|
|