|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Utility functions for OpTree.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
from collections.abc import Iterable, Sequence |
|
|
from typing import TYPE_CHECKING, Any, Callable, overload |
|
|
|
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from optree.typing import S, T, U |
|
|
|
|
|
|
|
|
def total_order_sorted(
    iterable: Iterable[T],
    /,
    *,
    key: Callable[[T], Any] | None = None,
    reverse: bool = False,
) -> list[T]:
    """Sort an iterable in a total order.

    This is useful for sorting objects that are not comparable, e.g., dictionaries with different
    types of keys.

    Sorting is attempted in up to two passes:

    1. A plain ``sorted()`` call with ``key`` and ``reverse`` as given.
    2. If that raises :class:`TypeError`, every item (or its ``key`` value, when ``key`` is given)
       is paired with the fully-qualified name of its class, and those ``(type_name, value)``
       pairs are sorted instead. Values of different types are therefore ordered by type name
       first and only compared against values of the same type.

    If the second pass also raises :class:`TypeError`, the items are returned in input order.

    Args:
        iterable: The items to sort. It is consumed exactly once.
        key: Optional one-argument function that computes the sort key for each item.
        reverse: If :data:`True`, sort in descending order.

    Returns:
        A new list holding the items of ``iterable``.
    """
    items = list(iterable)

    try:
        # Fast path: the items (or their keys) are mutually comparable.
        return sorted(items, key=key, reverse=reverse)
    except TypeError:

        def typed_key(item: T) -> tuple[str, Any]:
            # NOTE: when ``key`` is given, it is evaluated again for every item in this pass.
            value = item if key is None else key(item)
            cls = value.__class__
            # Prefixing with the qualified class name makes values of different types
            # comparable (by type name) while same-type values still compare by value.
            return (f'{cls.__module__}.{cls.__qualname__}', value)

        try:
            return sorted(items, key=typed_key, reverse=reverse)
        except TypeError:
            # Same-type values are still incomparable (e.g., two dicts): keep input order.
            return items
|
|
|
|
|
|
|
|
@overload
def safe_zip(
    iter1: Iterable[T],
    /,
) -> zip[tuple[T]]: ...


@overload
def safe_zip(
    iter1: Iterable[T],
    iter2: Iterable[S],
    /,
) -> zip[tuple[T, S]]: ...


@overload
def safe_zip(
    iter1: Iterable[T],
    iter2: Iterable[S],
    iter3: Iterable[U],
    /,
) -> zip[tuple[T, S, U]]: ...


@overload
def safe_zip(
    iter1: Iterable[Any],
    iter2: Iterable[Any],
    iter3: Iterable[Any],
    iter4: Iterable[Any],
    /,
    *iters: Iterable[Any],
) -> zip[tuple[Any, ...]]: ...


def safe_zip(*args: Iterable[Any]) -> zip[tuple[Any, ...]]:
    """Strict zip that requires all arguments to be the same length.

    Non-:class:`~collections.abc.Sequence` arguments (e.g. generators) are materialized into lists
    first so that their lengths can be checked eagerly, before any pairing happens.

    Raises:
        ValueError: If the arguments do not all have the same length.
    """
    materialized = []
    for arg in args:
        # Sequences already know their length; anything else must be consumed into a list.
        materialized.append(arg if isinstance(arg, Sequence) else list(arg))

    lengths = [len(seq) for seq in materialized]
    if lengths and min(lengths) != max(lengths):
        raise ValueError(f'length mismatch: {lengths}')
    return zip(*materialized)
|
|
|
|
|
|
|
|
def unzip2(xys: Iterable[tuple[T, S]], /) -> tuple[tuple[T, ...], tuple[S, ...]]:
    """Unzip sequence of length-2 tuples into two tuples.

    Args:
        xys: An iterable of pairs. It is consumed exactly once.

    Returns:
        A pair ``(firsts, seconds)`` holding the first and second elements of every pair, in
        input order. An empty input yields ``((), ())``.

    Raises:
        ValueError: If an element does not unpack into exactly two values.
    """
    # Unpack each element explicitly instead of using ``zip(*xys)``: this rejects elements
    # that are not exactly length 2, and an empty input still produces two (empty) tuples.
    firsts: list[T] = []
    seconds: list[S] = []
    for pair in xys:
        first, second = pair
        firsts.append(first)
        seconds.append(second)
    return tuple(firsts), tuple(seconds)
|
|
|