| | import contextlib |
| | import functools as ft |
| | import inspect |
| | from typing import TypeAlias, TypeVar, cast |
| |
|
| | import beartype |
| | import jax |
| | import jax._src.tree_util as private_tree_util |
| | import jax.core |
| | from jaxtyping import Array |
| | from jaxtyping import ArrayLike |
| | from jaxtyping import Bool |
| | from jaxtyping import DTypeLike |
| | from jaxtyping import Float |
| | from jaxtyping import Int |
| | from jaxtyping import Key |
| | from jaxtyping import Num |
| | from jaxtyping import PyTree |
| | from jaxtyping import Real |
| | from jaxtyping import UInt8 |
| | from jaxtyping import config |
| | from jaxtyping import jaxtyped |
| | import jaxtyping._decorator |
| |
|
| | |
| |
|
# Generic type variable for helpers that return their argument's own type.
T = TypeVar("T")

# Alias of `jax.typing.ArrayLike`; presumably used to annotate PRNG-key
# arguments (TODO confirm against callers).
KeyArrayLike: TypeAlias = jax.typing.ArrayLike

# A PyTree whose leaves are floating-point array-likes of any shape
# (jaxtyping's "..." dimension spec matches an arbitrary number of axes).
Params: TypeAlias = PyTree[Float[ArrayLike, "..."]]
| |
|
| |
|
| | |
| | |
def typecheck(t: T) -> T:
    """Return ``t`` unchanged.

    Despite the name, no runtime type checking is performed here: whatever is
    passed in (for example a function being decorated) is handed back as-is.

    NOTE(review): this module imports ``beartype`` and ``jaxtyped`` but this
    function uses neither -- confirm the pass-through behavior is intentional.
    """
    unchanged = t
    return unchanged
| |
|
| |
|
@contextlib.contextmanager
def disable_typechecking():
    """Temporarily set jaxtyping's ``jaxtyping_disable`` config flag to ``True``.

    Usage::

        with disable_typechecking():
            ...

    The previous value of the flag is captured on entry and restored on exit.
    The restore runs in a ``finally`` clause, so it also happens when the body
    of the ``with`` block raises; otherwise an exception would leave the flag
    set for the rest of the process.
    """
    initial = config.jaxtyping_disable
    config.update("jaxtyping_disable", True)
    try:
        yield
    finally:
        # Always restore the caller's setting, even on exceptions.
        config.update("jaxtyping_disable", initial)
| |
|
| |
|
def check_pytree_equality(*, expected: PyTree, got: PyTree, check_shapes: bool = False, check_dtypes: bool = False):
    """Checks that two PyTrees have the same structure and optionally checks shapes and dtypes. Creates a much nicer
    error message than if `jax.tree.map` is naively used on PyTrees with different structures.

    Args:
        expected: Reference PyTree.
        got: PyTree compared against ``expected``.
        check_shapes: If True, corresponding leaves must also have equal shapes.
        check_dtypes: If True, corresponding leaves must also have equal dtypes.

    Raises:
        ValueError: If the tree structures differ (every structural difference
            reported by JAX is listed, one per line), or, when shape/dtype
            checking is requested, if any pair of corresponding leaves
            mismatches (every mismatch is listed, one per line).
    """
    import jax.numpy as jnp  # only needed to inspect leaf shapes/dtypes

    if errors := list(private_tree_util.equality_errors(expected, got)):
        # One line per difference, joined with single newlines (no blank lines
        # between entries and no trailing newline).
        details = "\n".join(
            f" - at keypath '{jax.tree_util.keystr(path)}': expected {thing1}, got {thing2}, so {explanation}."
            for path, thing1, thing2, explanation in errors
        )
        raise ValueError("PyTrees have different structure:\n" + details)

    if not (check_shapes or check_dtypes):
        return

    def leaf_dtype(leaf):
        # Arrays (including tracers) expose `.dtype`; Python scalars, which
        # ArrayLike admits, do not -- fall back to the dtype JAX assigns them.
        return leaf.dtype if hasattr(leaf, "dtype") else jnp.result_type(leaf)

    mismatches: list[str] = []

    def check(kp, x, y):
        if check_shapes:
            # jnp.shape also works for Python scalars (shape ``()``).
            x_shape, y_shape = jnp.shape(x), jnp.shape(y)
            if x_shape != y_shape:
                mismatches.append(
                    f"Shape mismatch at {jax.tree_util.keystr(kp)}: expected {x_shape}, got {y_shape}"
                )
        if check_dtypes:
            x_dtype, y_dtype = leaf_dtype(x), leaf_dtype(y)
            if x_dtype != y_dtype:
                mismatches.append(
                    f"Dtype mismatch at {jax.tree_util.keystr(kp)}: expected {x_dtype}, got {y_dtype}"
                )

    # Structures are known to match here, so leaves pair up one-to-one.
    jax.tree_util.tree_map_with_path(check, expected, got)

    if mismatches:
        raise ValueError("\n".join(mismatches))
| |
|