|
|
from typing import NamedTuple, Callable, Any, Tuple, List, Dict, Type, cast, Optional, TypeVar, overload, Union |
|
|
import functools |
|
|
from collections import namedtuple, OrderedDict |
|
|
|
|
|
# Generic type variables used by the typed aliases and overloads in this
# module: T and S stand for the matched value type(s), R for a callable's
# return type.
T = TypeVar('T')
S = TypeVar('S')
R = TypeVar('R')
|
|
|
|
|
""" |
|
|
Contains utility functions for working with nested python data structures. |
|
|
|
|
|
A *pytree* is a nested Python data structure. It is a tree in the sense that |
|
|
nodes are Python collections (e.g., list, tuple, dict) and the leaves are |
|
|
Python values. Furthermore, a pytree should not contain reference cycles. |
|
|
|
|
|
pytrees are useful for working with nested collections of Tensors. For example, |
|
|
one can use `tree_map` to map a function over all Tensors inside some nested |
|
|
collection of Tensors and `tree_flatten` to get a flat list of all Tensors |
|
|
inside some nested collection. pytrees are helpful for implementing nested |
|
|
collection support for PyTorch APIs. |
|
|
|
|
|
This pytree implementation is not very performant due to Python overhead. |
|
|
To improve the performance we can move parts of the implementation to C++. |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Auxiliary data a flatten function returns next to the children; it is handed
# back unchanged to the matching unflatten function.
Context = Any
# A pytree can be any Python object, so no narrower type is used.
PyTree = Any
# Takes a node and returns (list of children, context).
FlattenFunc = Callable[[PyTree], Tuple[List, Context]]
# Takes (list of children, context) and returns the rebuilt node.
UnflattenFunc = Callable[[List, Context], PyTree]
|
|
|
|
|
class NodeDef(NamedTuple):
    """Pair of functions describing how one node type is taken apart and rebuilt."""

    # Splits a node into (list of children, context).
    flatten_fn: FlattenFunc
    # Rebuilds a node from (list of children, context).
    unflatten_fn: UnflattenFunc
|
|
|
|
|
# Registry mapping a node type to its flatten/unflatten pair. An object whose
# node type is not a key of this dict is treated as a leaf (see `_is_leaf`).
SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}
|
|
|
|
|
def _register_pytree_node(typ: Any, flatten_fn: FlattenFunc, unflatten_fn: UnflattenFunc) -> None:
    """Record ``typ`` in ``SUPPORTED_NODES`` as a pytree node type.

    Args:
        typ: key under which the node definition is stored.
        flatten_fn: splits an instance into ``(children, context)``.
        unflatten_fn: rebuilds an instance from ``(children, context)``.

    An existing entry for ``typ`` is replaced.
    """
    node_def = NodeDef(flatten_fn=flatten_fn, unflatten_fn=unflatten_fn)
    SUPPORTED_NODES[typ] = node_def
|
|
|
|
|
def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
    """Flatten a dict: its values are the children, its keys are the context."""
    children = [*d.values()]
    keys = [*d.keys()]
    return children, keys
|
|
|
|
|
def _dict_unflatten(values: List[Any], context: Context) -> Dict[Any, Any]:
    """Rebuild a dict by pairing the keys in ``context`` with ``values`` in order."""
    return dict(zip(context, values))
|
|
|
|
|
def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
    """Flatten a list: the list itself is returned as the children (no copy)."""
    # Lists need no extra information to be rebuilt, so the context is None.
    context = None
    return d, context
|
|
|
|
|
def _list_unflatten(values: List[Any], context: Context) -> List[Any]:
    """Rebuild a list from its children; ``context`` is ignored."""
    return [*values]
|
|
|
|
|
def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
    """Flatten a tuple: its items become the children; the context is None."""
    children = [*d]
    return children, None
|
|
|
|
|
def _tuple_unflatten(values: List[Any], context: Context) -> Tuple[Any, ...]:
    """Rebuild a tuple from its children; ``context`` is ignored."""
    return (*values,)
|
|
|
|
|
def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]:
    """Flatten a namedtuple: its items are the children, its class is the context."""
    children = [*d]
    # The concrete class is kept so the same namedtuple type can be rebuilt.
    cls = type(d)
    return children, cls
|
|
|
|
|
def _namedtuple_unflatten(values: List[Any], context: Context) -> NamedTuple:
    """Rebuild a namedtuple by calling ``context`` with ``values`` as positional arguments."""
    rebuilt = context(*values)
    return cast(NamedTuple, rebuilt)
|
|
|
|
|
def _odict_flatten(d: 'OrderedDict[Any, Any]') -> Tuple[List[Any], Context]:
    """Flatten an OrderedDict: its values are the children, its keys are the context."""
    children = [*d.values()]
    keys = [*d.keys()]
    return children, keys
|
|
|
|
|
def _odict_unflatten(values: List[Any], context: Context) -> 'OrderedDict[Any, Any]':
    """Rebuild an OrderedDict by pairing the keys in ``context`` with ``values`` in order."""
    return OrderedDict(zip(context, values))
|
|
|
|
|
|
|
|
# Built-in container types that are treated as pytree nodes.
_register_pytree_node(typ=dict, flatten_fn=_dict_flatten, unflatten_fn=_dict_unflatten)
_register_pytree_node(typ=list, flatten_fn=_list_flatten, unflatten_fn=_list_unflatten)
_register_pytree_node(typ=tuple, flatten_fn=_tuple_flatten, unflatten_fn=_tuple_unflatten)
# The `namedtuple` factory function itself serves as the registry key for all
# namedtuple instances; `_get_node_type` returns it for any of them.
_register_pytree_node(
    typ=namedtuple,
    flatten_fn=_namedtuple_flatten,
    unflatten_fn=_namedtuple_unflatten,
)
_register_pytree_node(typ=OrderedDict, flatten_fn=_odict_flatten, unflatten_fn=_odict_unflatten)
|
|
|
|
|
|
|
|
|
|
|
def _is_namedtuple_instance(pytree: Any) -> bool: |
|
|
typ = type(pytree) |
|
|
bases = typ.__bases__ |
|
|
if len(bases) != 1 or bases[0] != tuple: |
|
|
return False |
|
|
fields = getattr(typ, '_fields', None) |
|
|
if not isinstance(fields, tuple): |
|
|
return False |
|
|
return all(type(entry) == str for entry in fields) |
|
|
|
|
|
def _get_node_type(pytree: Any) -> Any:
    """Return the key under which ``pytree`` would be looked up in ``SUPPORTED_NODES``.

    Every namedtuple instance maps to the ``namedtuple`` factory function;
    any other object maps to its own ``type``.
    """
    return namedtuple if _is_namedtuple_instance(pytree) else type(pytree)
|
|
|
|
|
|
|
|
def _is_leaf(pytree: PyTree) -> bool:
    """Return True if ``pytree``'s node type is not registered in ``SUPPORTED_NODES``."""
    node_type = _get_node_type(pytree)
    return node_type not in SUPPORTED_NODES
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TreeSpec:
    """Description of the structure of one pytree node and everything below it.

    Attributes:
        type: the node type of this level (``None`` for the ``LeafSpec``
            subclass, which passes ``None`` to ``__init__``).
        context: auxiliary data returned by the node's flatten function.
        children_specs: one ``TreeSpec`` per child, in order.
        num_leaves: total number of leaves under this node, computed as the
            sum of the children's ``num_leaves``.

    Because ``__eq__`` is defined without ``__hash__``, instances are not
    hashable.
    """

    def __init__(self, typ: Any, context: 'Context', children_specs: List['TreeSpec']) -> None:
        self.type = typ
        self.context = context
        self.children_specs = children_specs
        # A node holds exactly as many leaves as all of its children combined.
        self.num_leaves: int = sum(spec.num_leaves for spec in children_specs)

    def __repr__(self) -> str:
        return f'TreeSpec({self.type.__name__}, {self.context}, {self.children_specs})'

    def __eq__(self, other: Any) -> bool:
        """Compare type, context, children and leaf count.

        Returns ``NotImplemented`` when ``other`` is not a ``TreeSpec`` so
        that ``spec == 5`` evaluates to ``False`` instead of raising
        ``AttributeError`` on the missing attributes.
        """
        if not isinstance(other, TreeSpec):
            return NotImplemented
        # num_leaves is compared as well: a LeafSpec sets it to 1 after
        # construction, so it is not implied by the other three fields.
        result = (
            self.type == other.type
            and self.context == other.context
            and self.children_specs == other.children_specs
            and self.num_leaves == other.num_leaves
        )
        return cast(bool, result)

    def __ne__(self, other: Any) -> bool:
        """Inverse of ``__eq__``; passes ``NotImplemented`` through untouched."""
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result
|
|
|
|
|
class LeafSpec(TreeSpec):
    """``TreeSpec`` for a single leaf: no type, no context, no children, one leaf."""

    def __init__(self) -> None:
        super().__init__(typ=None, context=None, children_specs=[])
        # The base class derives 0 from the empty child list; a leaf counts as 1.
        self.num_leaves = 1

    def __repr__(self) -> str:
        return '*'
|
|
|
|
|
def tree_flatten(pytree: PyTree) -> Tuple[List[Any], TreeSpec]:
    """Flatten ``pytree`` into its leaves plus a ``TreeSpec`` of its structure.

    Returns:
        A ``(leaves, spec)`` pair. ``leaves`` lists the leaf values in the
        order the children are visited; ``spec`` can be passed to
        ``tree_unflatten`` to rebuild the pytree.
    """
    if _is_leaf(pytree):
        return [pytree], LeafSpec()

    node_type = _get_node_type(pytree)
    children, context = SUPPORTED_NODES[node_type].flatten_fn(pytree)

    # Flatten every child recursively, concatenating the leaves and keeping
    # one spec per child.
    leaves: List[Any] = []
    specs: List['TreeSpec'] = []
    for subtree in children:
        sub_leaves, sub_spec = tree_flatten(subtree)
        leaves.extend(sub_leaves)
        specs.append(sub_spec)

    return leaves, TreeSpec(node_type, context, specs)
|
|
|
|
|
|
|
|
def tree_unflatten(values: List[Any], spec: TreeSpec) -> PyTree:
    """Rebuild a pytree from a flat list of leaves and a ``TreeSpec``.

    This is the inverse of ``tree_flatten``.

    Raises:
        ValueError: if ``spec`` is not a ``TreeSpec``, or if ``values`` does
            not contain exactly ``spec.num_leaves`` items.
    """
    if not isinstance(spec, TreeSpec):
        raise ValueError(
            f'tree_unflatten(values, spec): Expected `spec` to be instance of '
            f'TreeSpec but got item of type {type(spec)}.')
    if len(values) != spec.num_leaves:
        raise ValueError(
            f'tree_unflatten(values, spec): `values` has length {len(values)} '
            f'but the spec refers to a pytree that holds {spec.num_leaves} '
            f'items ({spec}).')
    if isinstance(spec, LeafSpec):
        return values[0]

    rebuild = SUPPORTED_NODES[spec.type].unflatten_fn

    # Hand each child the consecutive slice of leaves that belongs to it.
    children = []
    offset = 0
    for child_spec in spec.children_specs:
        stop = offset + child_spec.num_leaves
        children.append(tree_unflatten(values[offset:stop], child_spec))
        offset = stop

    return rebuild(children, spec.context)
|
|
|
|
|
def tree_map(fn: Any, pytree: PyTree) -> PyTree:
    """Apply ``fn`` to every leaf of ``pytree`` and return a pytree of the same structure."""
    leaves, spec = tree_flatten(pytree)
    mapped = list(map(fn, leaves))
    return tree_unflatten(mapped, spec)
|
|
|
|
|
# A pair of types, matching the tuple form accepted by `isinstance`.
Type2 = Tuple[Type[T], Type[S]]
# A single type or a tuple of any number of types.
TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]

# Callable taking a value of either of two types.
Fn2 = Callable[[Union[T, S]], R]
# Callable taking a value of one type.
Fn = Callable[[T], R]
# Callable taking a value of any type.
FnAny = Callable[[Any], R]

# Shape of the decorator returned by `map_only`: it takes a function of type T
# and returns a function accepting any value.
MapOnlyFn = Callable[[T], Callable[[Any], Any]]
|
|
|
|
|
|
|
|
|
|
|
@overload
def map_only(ty: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
    ...


@overload
def map_only(ty: Type[T]) -> MapOnlyFn[Fn[T, Any]]:
    ...


@overload
def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    ...


def map_only(ty: TypeAny) -> MapOnlyFn[FnAny[Any]]:
    """
    Build a decorator that restricts a function to values of type ``ty``.

    The decorated function is called only for arguments that pass
    ``isinstance(x, ty)``; every other argument is returned unchanged. This
    removes the boilerplate from a tree_map callback that should only touch,
    say, tensors. Instead of

        def go(t):
            if isinstance(t, Tensor):
                return ...
            else:
                return t

    it is enough to write

        @map_only(Tensor)
        def go(t):
            return ...

    See also 'tree_map_only', which applies this directly to a pytree.
    """
    def deco(f: Callable[[T], Any]) -> Callable[[Any], Any]:
        @functools.wraps(f)
        def wrapper(x: T) -> Any:
            # Values of other types pass through untouched.
            return f(x) if isinstance(x, ty) else x
        return wrapper
    return deco
|
|
|
|
|
@overload
def tree_map_only(ty: Type[T], fn: Fn[T, Any], pytree: PyTree) -> PyTree:
    ...


@overload
def tree_map_only(ty: Type2[T, S], fn: Fn2[T, S, Any], pytree: PyTree) -> PyTree:
    ...


def tree_map_only(ty: TypeAny, fn: FnAny[Any], pytree: PyTree) -> PyTree:
    """Apply ``fn`` to the leaves of ``pytree`` that are instances of ``ty``.

    Leaves of any other type are kept as they are.
    """
    guarded_fn = map_only(ty)(fn)
    return tree_map(guarded_fn, pytree)
|
|
|
|
|
def tree_all(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
    """Return True if ``pred`` is truthy for every leaf of ``pytree`` (True when there are no leaves)."""
    leaves, _spec = tree_flatten(pytree)
    return all(pred(leaf) for leaf in leaves)
|
|
|
|
|
def tree_any(pred: Callable[[Any], bool], pytree: PyTree) -> bool:
    """Return True if ``pred`` is truthy for at least one leaf of ``pytree``."""
    leaves, _spec = tree_flatten(pytree)
    return any(pred(leaf) for leaf in leaves)
|
|
|
|
|
@overload
def tree_all_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool:
    ...


@overload
def tree_all_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool:
    ...


def tree_all_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
    """Return True if ``pred`` is truthy for every leaf that is an instance of ``ty``.

    Leaves of other types are skipped; the result is True when no leaf matches ``ty``.
    """
    leaves, _spec = tree_flatten(pytree)
    for leaf in leaves:
        # Stop at the first matching leaf that fails the predicate.
        if isinstance(leaf, ty) and not pred(leaf):
            return False
    return True
|
|
|
|
|
@overload
def tree_any_only(ty: Type[T], pred: Fn[T, bool], pytree: PyTree) -> bool:
    ...


@overload
def tree_any_only(ty: Type2[T, S], pred: Fn2[T, S, bool], pytree: PyTree) -> bool:
    ...


def tree_any_only(ty: TypeAny, pred: FnAny[bool], pytree: PyTree) -> bool:
    """Return True if ``pred`` is truthy for at least one leaf that is an instance of ``ty``.

    Leaves of other types are skipped; the result is False when no leaf matches ``ty``.
    """
    leaves, _spec = tree_flatten(pytree)
    for leaf in leaves:
        # Stop at the first matching leaf that satisfies the predicate.
        if isinstance(leaf, ty) and pred(leaf):
            return True
    return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _broadcast_to_and_flatten(pytree: PyTree, spec: TreeSpec) -> Optional[List[Any]]:
    """Broadcast ``pytree`` to the structure in ``spec`` and return the flat leaves.

    A leaf in ``pytree`` is repeated once for every leaf of the spec it is
    matched against, so ``pytree=0`` with a spec for a two-element list
    yields ``[0, 0]``. Returns None when the structures are incompatible:
    a non-leaf ``pytree`` against a ``LeafSpec``, or a mismatch in node
    type, number of children or context.
    """
    assert isinstance(spec, TreeSpec)

    if _is_leaf(pytree):
        return [pytree] * spec.num_leaves
    if isinstance(spec, LeafSpec):
        return None
    node_type = _get_node_type(pytree)
    if node_type != spec.type:
        return None

    children, context = SUPPORTED_NODES[node_type].flatten_fn(pytree)

    # The node must have the same arity and the same context as the spec.
    if len(children) != len(spec.children_specs) or context != spec.context:
        return None

    # Broadcast child by child; a single failure makes the whole call fail.
    flattened: List[Any] = []
    for subtree, sub_spec in zip(children, spec.children_specs):
        sub_flat = _broadcast_to_and_flatten(subtree, sub_spec)
        if sub_flat is None:
            return None
        flattened.extend(sub_flat)

    return flattened
|
|
|