| """ |
| AUTORAY - backend-agnostic array operations. |
| |
| |
| Copyright 2019-2023 Johnnie Gray |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| """ |
|
|
| import contextlib |
| import functools |
| import importlib |
| import itertools |
| import math |
| import threading |
| from collections import OrderedDict, defaultdict |
| from inspect import signature |
|
|
|
|
| def do(fn, *args, like=None, **kwargs): |
| """Do function named ``fn`` on ``(*args, **kwargs)``, performing single |
| dispatch to retrieve ``fn`` based on whichever library defines the class |
| of ``args[0]``, or the ``like`` keyword argument if specified. |
| |
| Examples |
| -------- |
| |
| Works on numpy arrays: |
| |
| >>> import numpy as np |
| >>> x_np = np.random.uniform(size=[5]) |
| >>> y_np = do('sqrt', x_np) |
| >>> y_np |
| array([0.32464973, 0.90379787, 0.85037325, 0.88729814, 0.46768083]) |
| |
| >>> type(y_np) |
| numpy.ndarray |
| |
| Works on cupy arrays: |
| |
| >>> import cupy as cp |
| >>> x_cp = cp.random.uniform(size=[5]) |
| >>> y_cp = do('sqrt', x_cp) |
| >>> y_cp |
| array([0.44541656, 0.88713113, 0.92626237, 0.64080557, 0.69620767]) |
| |
| >>> type(y_cp) |
| cupy.core.core.ndarray |
| |
| Works on tensorflow arrays: |
| |
| >>> import tensorflow as tf |
| >>> x_tf = tf.random.uniform(shape=[5]) |
| >>> y_tf = do('sqrt', x_tf) |
| >>> y_tf |
| <tf.Tensor 'Sqrt_1:0' shape=(5,) dtype=float32> |
| |
| >>> type(y_tf) |
| tensorflow.python.framework.ops.Tensor |
| |
| You get the idea. |
| |
| For functions that don't dispatch on the first argument you can use the |
| ``like`` keyword: |
| |
| >>> do('eye', 3, like=x_tf) |
| <tf.Tensor: id=91, shape=(3, 3), dtype=float32> |
| """ |
| backend = _choose_backend(fn, args, kwargs, like=like) |
| func = get_lib_fn(backend, fn) |
| return func(*args, **kwargs) |
|
|
|
|
| |
|
|
|
|
| def _default_infer_from_sig(fn, args, kwargs): |
| """This is the default backend dispatcher, used if no global backend has |
| been set. Hot swapping this function out as below avoids having to check |
| manually for a global backend or, worse, a thread-aware global backend, |
| on every call to ``do``. |
| """ |
| return _DISPATCHERS[fn](args, kwargs) |
|
|
|
|
| _global_backend = None |
| _inferrer_global = _default_infer_from_sig |
|
|
| |
| |
| |
| _infer_auto = _inferrer_global |
|
|
| |
| |
| |
| _global_backends_threadaware = {} |
| _inferrers_threadaware = {} |
| _importing_thrid = threading.get_ident() |
| _backend_lock = threading.Lock() |
|
|
|
|
| def _default_infer_from_sig_threadaware(fn, args, kwargs): |
| |
| thrid = threading.get_ident() |
| return _inferrers_threadaware.get(thrid, _inferrer_global)( |
| fn, args, kwargs |
| ) |
|
|
|
|
| def _always_the_same(fn, args, kwargs, backend): |
| return backend |
|
|
|
|
| def get_backend(get_globally="auto"): |
| """Return the universally set backend, if any. |
| |
| Parameters |
| ---------- |
| get_globally : {"auto", False, True}, optional |
| Which backend to return: |
| |
| - True: return the globally set backend, if any. |
| - False: return the backend set for the current thread, if any. |
| - "auto": return the globally set backend, if this thread is the thread |
| that imported autoray. Otherwise return the backend set for the |
| current thread, if any. |
| |
| Returns |
| ------- |
| backend : str or None |
| The name of the backend, or None if no backend is set. |
| """ |
| if get_globally == "auto": |
| get_globally = threading.get_ident() == _importing_thrid |
|
|
| if get_globally: |
| backend = _global_backend |
| else: |
| thrid = threading.get_ident() |
| backend = _global_backends_threadaware.get(thrid, None) |
|
|
| return backend |
|
|
|
|
| def set_backend(like, set_globally="auto"): |
| """Set a default global backend. The argument ``like`` can be an explicit |
| backend name or an ``array``. |
| |
| Parameters |
| ---------- |
| like : str or array |
| The backend to set. If an array, the backend of the array's class will |
| be set. |
| set_globally : {"auto", False, True}, optional |
| Whether to set the backend globally or for the current thread: |
| |
| - True: set the backend globally. |
| - False: set the backend for the current thread. |
| - "auto": set the backend globally if this thread is the thread that |
| imported autoray. Otherwise set the backend for the current thread. |
| |
| Only one thread should ever call this function with |
| ``set_globally=True`` (by default this is the importing thread). |
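| |
| Examples |
| -------- |
| With a default backend set, ``do`` no longer needs an array or a |
| ``like`` argument to dispatch on (a minimal sketch, assuming numpy is |
| installed): |
| |
| >>> set_backend("numpy") |
| >>> do("ones", (2,)) |
| array([1., 1.]) |
| >>> set_backend(None)  # clear the default again |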
| """ |
| global _global_backend |
| global _infer_auto |
| global _inferrer_global |
|
|
| if like is None: |
| backend = None |
| inferrer = _default_infer_from_sig |
| elif isinstance(like, str): |
| backend = like |
| inferrer = functools.partial(_always_the_same, backend=backend) |
| else: |
| backend = _infer_class_backend_cached(like.__class__) |
| inferrer = functools.partial(_always_the_same, backend=backend) |
|
|
| if set_globally == "auto": |
| set_globally = threading.get_ident() == _importing_thrid |
|
|
| if set_globally: |
| _global_backend = backend |
| _inferrer_global = inferrer |
| if not _inferrers_threadaware: |
| |
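| # no threads have their own backend - can dispatch with this inferrer |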
| _infer_auto = inferrer |
| else: |
| thrid = threading.get_ident() |
| _backend_lock.acquire() |
| if backend is None: |
| _global_backends_threadaware.pop(thrid, None) |
| _inferrers_threadaware.pop(thrid, None) |
| else: |
| _global_backends_threadaware[thrid] = backend |
| _inferrers_threadaware[thrid] = inferrer |
|
|
| if _inferrers_threadaware: |
| |
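| # some threads have their own backend - need thread-aware dispatch |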
| _infer_auto = _default_infer_from_sig_threadaware |
| else: |
| |
| |
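| # no thread-specific backends remain - use the global inferrer directly |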
| _infer_auto = _inferrer_global |
| _backend_lock.release() |
|
|
|
|
| @contextlib.contextmanager |
| def backend_like(like, set_globally="auto"): |
| """Context manager for setting a default backend. The argument ``like`` can |
| be an explicit backend name or an ``array`` to infer it from. |
| |
| Parameters |
| ---------- |
| like : str or array |
| The backend to set. If an array, the backend of the array's class will |
| be set. |
| set_globally : {"auto", False, True}, optional |
| Whether to set the backend globally or for the current thread: |
| |
| - True: set the backend globally. |
| - False: set the backend for the current thread. |
| - "auto": set the backend globally if this thread is the thread that |
| imported autoray. Otherwise set the backend for the current thread. |
| |
| Only one thread should ever call this function with |
| ``set_globally=True`` (by default this is the importing thread). |
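| |
| Examples |
| -------- |
| >>> with backend_like("numpy"): |
| ...     y = do("zeros", (2, 2)) |
| >>> type(y) |
| numpy.ndarray |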
| """ |
| if set_globally == "auto": |
| set_globally = threading.get_ident() == _importing_thrid |
|
|
| old_backend = get_backend(get_globally=set_globally) |
| try: |
| set_backend(like, set_globally) |
| yield |
| finally: |
| set_backend(old_backend, set_globally) |
|
|
|
|
| _CUSTOM_BACKENDS = {} |
|
|
|
|
| def register_backend(cls, name): |
| """Register the name (and by default the module or submodule) of a custom |
| array class. |
| |
| Parameters |
| ---------- |
| cls : type |
| The array class itself. |
| name : str |
| The name of the backend that should be used for this class. By default |
| this will be assumed to be the location of the relevant functions for |
| this class, but this can be overridden. |
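| |
| Examples |
| -------- |
| A sketch with a hypothetical array class: |
| |
| >>> class MyArray: |
| ...     pass |
| >>> register_backend(MyArray, "mylib") |
| >>> infer_backend(MyArray()) |
| 'mylib' |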
| """ |
| if not isinstance(cls, type): |
| raise TypeError("The array class itself should be supplied.") |
|
|
| global _CUSTOM_BACKENDS |
| _CUSTOM_BACKENDS[cls] = name |
|
|
|
|
| @functools.lru_cache(None) |
| def _infer_class_backend_cached(cls): |
| try: |
| import numpy as _numpy |
|
|
| if issubclass(cls, _numpy.ndarray): |
| return "numpy" |
| except ImportError: |
| |
| pass |
|
|
| if cls in _CUSTOM_BACKENDS: |
| return _CUSTOM_BACKENDS[cls] |
|
|
| lib = cls.__module__.split(".")[0] |
|
|
| |
| backend = _BACKEND_ALIASES.get(lib, lib) |
|
|
| return backend |
|
|
|
|
| def infer_backend(array): |
| """Get the name of the library that defined the class of ``array`` - unless |
| the class is a subclass of ``numpy.ndarray``, in which case ``numpy`` is |
| assumed to be the desired backend. |
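| |
| Examples |
| -------- |
| >>> import numpy as np |
| >>> infer_backend(np.ones(3)) |
| 'numpy' |
| >>> infer_backend([1, 2, 3]) |
| 'builtins' |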
| """ |
| return _infer_class_backend_cached(array.__class__) |
|
|
|
|
| multi_class_priorities = { |
| "builtins": -2, |
| "numpy": -1, |
| "autoray.lazy": 1, |
| } |
|
|
|
|
| @functools.lru_cache(None) |
| def _infer_class_backend_multi_cached(classes): |
| return max( |
| map(_infer_class_backend_cached, classes), |
| key=lambda n: multi_class_priorities.get(n, 0), |
| ) |
|
|
|
|
| def infer_backend_multi(*arrays): |
| """Infer which backend should be used for a function that takes multiple |
| arguments. This assigns a priority to each backend, and returns the backend |
| with the highest priority. By default, the priority is: |
| |
| - ``builtins``: -2 |
| - ``numpy``: -1 |
| - other backends: 0 |
| - ``autoray.lazy``: 1 |
| |
| I.e. when mixing with ``numpy``, other array libraries are preferred, when |
| mixing with ``autoray.lazy``, ``autoray.lazy`` is preferred. This has quite |
| low overhead due to caching. |
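| |
| Examples |
| -------- |
| For example, builtins defer to numpy: |
| |
| >>> infer_backend_multi(np.ones(3), [1, 2]) |
| 'numpy' |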
| """ |
| return _infer_class_backend_multi_cached( |
| tuple(array.__class__ for array in arrays) |
| ) |
|
|
|
|
| |
| |
| _CREATION_ROUTINES = { |
| "empty", |
| "eye", |
| "full", |
| "identity", |
| "ones", |
| "zeros", |
| |
| |
| |
| |
| |
| } |
|
|
| |
| _CREATION_INJECT = {} |
|
|
|
|
| def register_creation_routine( |
| backend, fn, inject_dtype=True, inject_device=False |
| ): |
| """Register a function that creates a new array, with `dtype` and possibly |
| `device` kwargs that should be inferred from the ``like`` argument. This is |
| not necessary for array creation routines that accept neither. |
| |
| Parameters |
| ---------- |
| backend : str |
| The backend to register the function for. |
| fn : str |
| The name of the function to register. |
| inject_dtype : bool, optional |
| Whether to inject a `dtype` argument based on the `like` argument. |
| inject_device : bool, optional |
| Whether to inject a `device` argument based on the `like` argument. |
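| |
| Examples |
| -------- |
| For example, this module later disables dtype injection for |
| tensorflow's ``linspace``: |
| |
| >>> register_creation_routine("tensorflow", "linspace", inject_dtype=False) |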
| """ |
| _CREATION_INJECT[backend, fn] = (inject_dtype, inject_device) |
|
|
|
|
| def _choose_backend(fn, args, kwargs, like=None): |
| """Private function to choose a backend based on function name and |
| signature, which passes args and kwargs by reference for performance and |
| also to allow injection of dtype and device arguments for array creation |
| routines. |
| """ |
| if like is None: |
| |
| return _infer_auto(fn, args, kwargs) |
| elif isinstance(like, str): |
| |
| return like |
| else: |
| |
| backend = _infer_class_backend_cached(like.__class__) |
|
|
| |
| if fn in _CREATION_ROUTINES: |
| try: |
| inject_dtype, inject_device = _CREATION_INJECT[backend, fn] |
| except KeyError: |
| |
| inject_dtype = True |
| inject_device = False |
| _CREATION_INJECT[backend, fn] = (inject_dtype, inject_device) |
|
|
| if inject_dtype: |
| kwargs.setdefault("dtype", getattr(like, "dtype", type(like))) |
| if inject_device: |
| kwargs.setdefault("device", like.device) |
|
|
| return backend |
|
|
|
|
| def choose_backend(fn, *args, like=None, **kwargs): |
| """Choose a backend based on function name, arguments, and the ``like`` |
| keyword argument. If ``like`` is not specified, the backend is inferred |
| from the call itself: a custom dispatcher registered for ``fn`` is used if |
| one exists, otherwise the class of the first argument decides. If ``like`` |
| is given, the backend is chosen from it directly - either an explicit |
| backend name or an arbitrary object to infer the backend from. |
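| |
| Examples |
| -------- |
| >>> choose_backend("sqrt", np.ones(3)) |
| 'numpy' |
| >>> choose_backend("eye", 3, like="torch") |
| 'torch' |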
| """ |
| return _choose_backend(fn, args, kwargs, like=like) |
|
|
|
|
| |
|
|
| |
| _BACKEND_ALIASES = {} |
|
|
| |
| _MODULE_ALIASES = {} |
|
|
| |
| _SUBMODULE_ALIASES = {} |
|
|
| |
| _FUNC_ALIASES = {} |
|
|
| |
| |
| _CUSTOM_WRAPPERS = {} |
|
|
| |
| |
| _FUNCS = {} |
|
|
| |
| |
| _COMPOSED_FUNCTION_GENERATORS = {} |
|
|
|
|
| def import_lib_fn(backend, fn): |
| |
| |
| |
| if fn in _COMPOSED_FUNCTION_GENERATORS: |
| return _COMPOSED_FUNCTION_GENERATORS[fn](backend) |
|
|
| try: |
| |
| |
| try: |
| full_location = _SUBMODULE_ALIASES[backend, fn] |
|
|
| |
| |
| only_fn = fn.split(".")[-1] |
|
|
| except KeyError: |
| full_location = backend |
|
|
| |
| |
| split_fn = fn.split(".") |
| full_location = ".".join([full_location] + split_fn[:-1]) |
| only_fn = split_fn[-1] |
|
|
| |
| |
| |
| |
| |
| |
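| # apply module aliases to the import location, checking entries in |
| # insertion order so that more specific prefixes (e.g. "jax.scipy") |
| # match before more general ones (e.g. "jax") |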
| for k, v in _MODULE_ALIASES.items(): |
| if full_location[: len(k)] == k: |
| full_location = full_location.replace(k, v, 1) |
| break |
|
|
| |
| |
| fn_name = _FUNC_ALIASES.get((backend, fn), only_fn) |
|
|
| |
| try: |
| lib = importlib.import_module(full_location) |
| except ImportError: |
| if "." in full_location: |
| |
| mod, *submods = full_location.split(".") |
| lib = importlib.import_module(mod) |
| |
| for submod in submods: |
| lib = getattr(lib, submod) |
| else: |
| |
| raise AttributeError |
|
|
| |
| wrapper = _CUSTOM_WRAPPERS.get((backend, fn), lambda fn: fn) |
|
|
| |
| lib_fn = _FUNCS[backend, fn] = wrapper(getattr(lib, fn_name)) |
|
|
| except AttributeError: |
| |
| backend_alt = backend + "[alt]" |
| if backend_alt in _MODULE_ALIASES: |
| return import_lib_fn(backend_alt, fn) |
|
|
| raise ImportError( |
| f"autoray couldn't find function '{fn}' for " |
| f"backend '{backend.replace('[alt]', '')}'." |
| ) |
|
|
| return lib_fn |
|
|
|
|
| def get_lib_fn(backend, fn): |
| """Cached retrieval of correct function for backend, all the logic for |
| finding the correct function only runs the first time. |
| |
| Parameters |
| ---------- |
| backend : str |
| The module defining the array class to dispatch on. |
| fn : str |
| The function to retrieve. |
| |
| Returns |
| ------- |
| callable |
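| |
| Examples |
| -------- |
| >>> get_lib_fn("numpy", "sqrt") is np.sqrt |
| True |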
| """ |
|
|
| try: |
| lib_fn = _FUNCS[backend, fn] |
| except KeyError: |
| lib_fn = import_lib_fn(backend, fn) |
| return lib_fn |
|
|
|
|
| |
|
|
|
|
| def register_function(backend, name, fn, wrap=False): |
| """Directly provide your own function. |
| |
| Parameters |
| ---------- |
| backend : str |
| The name of the backend to register the function for. |
| name : str |
| Name of the function, e.g. `'sum'` or `'linalg.svd'`. |
| fn : callable |
| The function to register. |
| wrap : bool, optional |
| Whether to wrap the old function like ``fn(old_fn)`` rather than |
| directly supply the entire new function. |
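| |
| Examples |
| -------- |
| A sketch registering a custom (hypothetical) function name for numpy: |
| |
| >>> register_function("numpy", "mynorm", lambda x: np.abs(x).sum()) |
| >>> do("mynorm", np.array([-1.0, 2.0])) |
| 3.0 |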
| """ |
| if wrap: |
| old = get_lib_fn(backend, name) |
| _FUNCS[backend, name] = fn(old) |
| else: |
| _FUNCS[backend, name] = fn |
|
|
|
|
| |
|
|
| TREE_MAP_REGISTRY = {} |
| TREE_APPLY_REGISTRY = {} |
| TREE_ITER_REGISTRY = {} |
|
|
|
|
| def tree_register_container(cls, mapper, iterator, applier): |
| """Register a new container type for use with ``tree_map`` and |
| ``tree_apply``. |
| |
| Parameters |
| ---------- |
| cls : type |
| The container type to register. |
| mapper : callable |
| A function that takes ``f``, ``tree`` and ``is_leaf`` and returns a new |
| tree of type ``cls`` with ``f`` applied to all leaves. |
| iterator : callable |
| A function that takes ``tree`` and ``is_leaf`` and yields all the leaves |
| in ``tree``. |
| applier : callable |
| A function that takes ``f``, ``tree`` and ``is_leaf`` and applies ``f`` |
| to all leaves in ``tree``. |
| """ |
| TREE_MAP_REGISTRY[cls] = mapper |
| TREE_ITER_REGISTRY[cls] = iterator |
| TREE_APPLY_REGISTRY[cls] = applier |
|
|
|
|
| IS_CONTAINER_CACHE = {} |
|
|
|
|
| def is_not_container(x): |
| """The default function to determine if an object is a leaf. This simply |
| checks that the object is *not* an instance of any of the registered |
| container types. |
| """ |
| try: |
| return IS_CONTAINER_CACHE[x.__class__] |
| except KeyError: |
| isleaf = not any(isinstance(x, cls) for cls in TREE_MAP_REGISTRY) |
| IS_CONTAINER_CACHE[x.__class__] = isleaf |
| return isleaf |
|
|
|
|
| def is_array(x): |
| """An alternative leaf tester for addressing only arrays within trees.""" |
| return hasattr(x, "shape") |
|
|
|
|
| def identity(f, tree, is_leaf): |
| return tree |
|
|
|
|
| TREE_MAPPER_CACHE = {} |
|
|
|
|
| def tree_map(f, tree, is_leaf=is_not_container): |
| """Map ``f`` over all leaves in ``tree``, returning a new pytree. |
| |
| Parameters |
| ---------- |
| f : callable |
| A function to apply to all leaves in ``tree``. |
| tree : pytree |
| A nested sequence of tuples, lists, dicts and other objects. |
| is_leaf : callable |
| A function to determine if an object is a leaf, ``f`` is only applied |
| to objects for which ``is_leaf(x)`` returns ``True``. |
| |
| Returns |
| ------- |
| pytree |
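| |
| Examples |
| -------- |
| >>> tree_map(lambda x: x + 1, {"a": 1, "b": (2, 3)}) |
| {'a': 2, 'b': (3, 4)} |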
| """ |
| if is_leaf(tree): |
| return f(tree) |
|
|
| try: |
| return TREE_MAPPER_CACHE[tree.__class__](f, tree, is_leaf) |
| except KeyError: |
| |
| for cls, mapper in reversed(TREE_MAP_REGISTRY.items()): |
| if isinstance(tree, cls): |
| break |
| else: |
| |
| mapper = identity |
| TREE_MAPPER_CACHE[tree.__class__] = mapper |
| return mapper(f, tree, is_leaf) |
|
|
|
|
| def empty(tree, is_leaf): |
| return iter(()) |
|
|
|
|
| TREE_ITER_CACHE = {} |
|
|
|
|
| def tree_iter(tree, is_leaf=is_not_container): |
| """Iterate over all leaves in ``tree``. |
| |
| Parameters |
| ---------- |
| tree : pytree |
| A nested sequence of tuples, lists, dicts and other objects. |
| is_leaf : callable |
| A function to determine if an object is a leaf, only objects for which |
| ``is_leaf(x)`` returns ``True`` are yielded. |
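| |
| Examples |
| -------- |
| >>> list(tree_iter([1, (2, {"x": 3})])) |
| [1, 2, 3] |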
| """ |
| if is_leaf(tree): |
| yield tree |
| return |
|
|
| try: |
| yield from TREE_ITER_CACHE[tree.__class__](tree, is_leaf) |
| except KeyError: |
| |
| for cls, iterator in reversed(TREE_ITER_REGISTRY.items()): |
| if isinstance(tree, cls): |
| break |
| else: |
| |
| iterator = empty |
| TREE_ITER_CACHE[tree.__class__] = iterator |
| yield from iterator(tree, is_leaf) |
|
|
|
|
| def nothing(f, tree, is_leaf): |
| pass |
|
|
|
|
| TREE_APPLIER_CACHE = {} |
|
|
|
|
| def tree_apply(f, tree, is_leaf=is_not_container): |
| """Apply ``f`` to all leaves in ``tree``, no new pytree is built. |
| |
| Parameters |
| ---------- |
| f : callable |
| A function to apply to all leaves in ``tree``. |
| tree : pytree |
| A nested sequence of tuples, lists, dicts and other objects. |
| is_leaf : callable |
| A function to determine if an object is a leaf, ``f`` is only applied |
| to objects for which ``is_leaf(x)`` returns ``True``. |
| """ |
| if is_leaf(tree): |
| f(tree) |
| return |
|
|
| try: |
| TREE_APPLIER_CACHE[tree.__class__](f, tree, is_leaf) |
| except KeyError: |
| |
| for cls, applier in reversed(TREE_APPLY_REGISTRY.items()): |
| if isinstance(tree, cls): |
| break |
| else: |
| |
| applier = nothing |
| TREE_APPLIER_CACHE[tree.__class__] = applier |
| applier(f, tree, is_leaf) |
|
|
|
|
| class Leaf: |
| """A singleton object to use as a placeholder in a pytree, for |
| unflattening. |
| """ |
|
|
| __slots__ = () |
|
|
| def __repr__(self): |
| return "Leaf" |
|
|
|
|
| LEAF = Leaf() |
|
|
|
|
| def is_leaf_placeholder(x): |
| |
| return x.__class__ is Leaf |
|
|
|
|
| def tree_flatten(tree, is_leaf=is_not_container, get_ref=False): |
| """Flatten ``tree`` into a list of leaves. |
| |
| Parameters |
| ---------- |
| tree : pytree |
| A nested sequence of tuples, lists, dicts and other objects. |
| is_leaf : callable |
| A function to determine if an object is a leaf, only objects for which |
| ``is_leaf(x)`` returns ``True`` are returned in the flattened list. |
| get_ref : bool |
| If ``True``, a reference tree is also returned which can be used to |
| reconstruct the original tree from a flattened list. |
| |
| Returns |
| ------- |
| objs : list |
| The flattened list of leaf objects. |
| (ref_tree) : pytree |
| If ``get_ref`` is ``True``, a reference tree, with leaves of ``Leaf``, |
| is returned which can be used to reconstruct the original tree. |
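| |
| Examples |
| -------- |
| >>> objs, ref = tree_flatten({"a": 1, "b": (2, 3)}, get_ref=True) |
| >>> objs |
| [1, 2, 3] |
| >>> ref |
| {'a': Leaf, 'b': (Leaf, Leaf)} |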
| """ |
| objs = [] |
| if get_ref: |
| |
|
|
| def f(x): |
| objs.append(x) |
| return LEAF |
|
|
| ref_tree = tree_map(f, tree, is_leaf) |
| return objs, ref_tree |
| else: |
| tree_apply(objs.append, tree, is_leaf) |
| return objs |
|
|
|
|
| def tree_unflatten(objs, tree, is_leaf=is_leaf_placeholder): |
| """Unflatten ``objs`` into a pytree of the same structure as ``tree``. |
| |
| Parameters |
| ---------- |
| objs : sequence |
| A sequence of objects to be unflattened into a pytree. |
| tree : pytree |
| A nested sequence of tuples, lists, dicts and other objects, the objs |
| will be inserted into a new pytree of the same structure. |
| is_leaf : callable |
| A function to determine if an object is a leaf, only objects for which |
| ``is_leaf(x)`` returns ``True`` will have the next item from ``objs`` |
| inserted. By default checks for the ``Leaf`` object inserted by |
| ``tree_flatten(..., get_ref=True)``. |
| |
| Returns |
| ------- |
| pytree |
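| |
| Examples |
| -------- |
| >>> tree_unflatten([10, 20, 30], {"a": LEAF, "b": (LEAF, LEAF)}) |
| {'a': 10, 'b': (20, 30)} |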
| """ |
| objs = iter(objs) |
| return tree_map(lambda _: next(objs), tree, is_leaf) |
|
|
|
|
| def tree_map_tuple(f, tree, is_leaf): |
| return tuple(tree_map(f, x, is_leaf) for x in tree) |
|
|
|
|
| def tree_iter_tuple(tree, is_leaf): |
| for x in tree: |
| yield from tree_iter(x, is_leaf) |
|
|
|
|
| def tree_apply_tuple(f, tree, is_leaf): |
| for x in tree: |
| tree_apply(f, x, is_leaf) |
|
|
|
|
| tree_register_container( |
| tuple, tree_map_tuple, tree_iter_tuple, tree_apply_tuple |
| ) |
|
|
|
|
| def tree_map_list(f, tree, is_leaf): |
| return [tree_map(f, x, is_leaf) for x in tree] |
|
|
|
|
| def tree_iter_list(tree, is_leaf): |
| for x in tree: |
| yield from tree_iter(x, is_leaf) |
|
|
|
|
| def tree_apply_list(f, tree, is_leaf): |
| for x in tree: |
| tree_apply(f, x, is_leaf) |
|
|
|
|
| tree_register_container(list, tree_map_list, tree_iter_list, tree_apply_list) |
|
|
|
|
| def tree_map_dict(f, tree, is_leaf): |
| return {k: tree_map(f, v, is_leaf) for k, v in tree.items()} |
|
|
|
|
| def tree_iter_dict(tree, is_leaf): |
| for v in tree.values(): |
| yield from tree_iter(v, is_leaf) |
|
|
|
|
| def tree_apply_dict(f, tree, is_leaf): |
| for v in tree.values(): |
| tree_apply(f, v, is_leaf) |
|
|
|
|
| tree_register_container(dict, tree_map_dict, tree_iter_dict, tree_apply_dict) |
|
|
|
|
| |
|
|
|
|
| class Composed: |
| """Compose a function built out of ``autoray.do`` calls. See the main |
| wrapper ``compose``. |
| """ |
|
|
| def __init__(self, fn, name=None): |
| self._default_fn = fn |
| if name is None: |
| name = fn.__name__ |
| self._name = name |
| self._supply_backend = "backend" in signature(fn).parameters |
|
|
| |
| |
| |
| _COMPOSED_FUNCTION_GENERATORS[self._name] = self.make_function |
|
|
| def register(self, backend, fn=None): |
| """Register a different implementation for ``backend``.""" |
| if fn is not None: |
| register_function(backend, self._name, fn) |
| else: |
| |
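| # act as a decorator, e.g. @composed_fn.register(backend) |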
| def wrapper(fn): |
| register_function(backend, self._name, fn) |
| return fn |
|
|
| return wrapper |
|
|
| def make_function(self, backend): |
| """Make a new function for the specific ``backend``.""" |
| if self._supply_backend: |
| |
| fn = functools.wraps(self._default_fn)( |
| functools.partial(self._default_fn, backend=backend) |
| ) |
| else: |
| fn = self._default_fn |
| self.register(backend, fn) |
| return fn |
|
|
| def __call__(self, *args, like=None, **kwargs): |
| backend = _choose_backend(self._name, args, kwargs, like=like) |
| |
| fn = get_lib_fn(backend, self._name) |
| return fn(*args, **kwargs) |
|
|
| def __repr__(self): |
| return f"Composed('{self._name}')" |
|
|
|
|
| def compose(fn=None, *, name=None): |
| """Take a function consisting of multiple ``autoray.do`` calls and compose |
| it into a new, single, named function, registered with ``autoray.do``. |
| |
| This creates a default implementation of this function for each new backend |
| encountered without explicitly having to write each out, but also allows |
| for specific implementations to be overridden for specific backends. |
| |
| If the function takes a ``backend`` argument, it will be supplied with the |
| backend name, to save having to re-choose the backend. |
| |
| Specific implementations can be provided by calling the ``register`` method |
| of the composed function, or it can itself be used like a decorator:: |
| |
| @compose |
| def foo(x): |
| ... |
| |
| @foo.register("numpy") |
| @numba.njit |
| def foo_numba(x): |
| ... |
| |
| Parameters |
| ---------- |
| fn : callable |
| The function to compose, and its default implementation. |
| name : str, optional |
| The name of the composed function. If not provided, the name of the |
| function will be used. |
| """ |
| if fn is None: |
| return functools.partial(compose, name=name) |
| return functools.wraps(fn)(Composed(fn, name)) |
|
|
|
|
| |
|
|
|
|
| @compose |
| def shape(x): |
| """Get the shape of an array as a tuple of int. This should be preferred |
| to calling `x.shape` directly, as it: |
| |
| 1. Allows customization (e.g. for torch and aesara which return |
| different types for shape - use `@shape.register(backend)` to |
| customize the behavior from this default implementation). |
| 2. Can be used on nested lists and tuples, without calling numpy. |
| |
| Parameters |
| ---------- |
| x : array_like |
| The array to get the shape of. It can be an arbitrary nested list or |
| tuple of arrays and scalars, but is assumed not to be ragged. |
| |
| Returns |
| ------- |
| shape : tuple of int |
| The size of each dimension of the array. |
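| |
| Examples |
| -------- |
| >>> shape([[1, 2, 3], [4, 5, 6]]) |
| (2, 3) |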
| """ |
| try: |
| return x.shape |
| except AttributeError: |
| |
| if isinstance(x, (list, tuple)): |
| d = len(x) |
| if d != 0: |
| |
| |
| return (d,) + shape(x[0]) |
| return (d,) |
| return () |
|
|
|
|
| @compose |
| def ndim(x): |
| """Get the number of dimensions of an array. This should be preferred to |
| calling `x.ndim`, since not all backends implement that, and it can also be |
| called on nested lists and tuples. |
| |
| Parameters |
| ---------- |
| x : array_like |
| The array to get the number of dimensions of. It can be an arbitrary |
| nested list or tuple of arrays and scalars. |
| |
| Returns |
| ------- |
| ndim : int |
| """ |
| try: |
| return x.ndim |
| except AttributeError: |
| return len(shape(x)) |
|
|
|
|
| @compose |
| def size(x): |
| """Get the size, or number of elements, of an array. This should be |
| preferred to calling `x.size`, since not all backends implement that, and |
| it can also be called on nested lists and tuples. |
| |
| Parameters |
| ---------- |
| x : array_like |
| The array to get the size of. It can be an arbitrary nested list or |
| tuple of arrays and scalars. |
| |
| Returns |
| ------- |
| size : int |
| """ |
| try: |
| return x.size |
| except AttributeError: |
| return math.prod(shape(x)) |
|
|
|
|
| def conj(x): |
| """Array conjugate.""" |
| return do("conj", x) |
|
|
|
|
| def transpose(x, *args): |
| """Array transpose.""" |
| return do("transpose", x, *args) |
|
|
|
|
| def dag(x): |
| """Array Hermitian transpose.""" |
| try: |
| return x.H |
| except AttributeError: |
| backend = _infer_class_backend_cached(x.__class__) |
| return do("conj", do("transpose", x, like=backend), like=backend) |
|
|
|
|
| def real(x): |
| """Array real part.""" |
| return do("real", x) |
|
|
|
|
| def imag(x): |
| """Array imaginary part.""" |
| return do("imag", x) |
|
|
|
|
| def reshape(x, shape): |
| """Array reshaped.""" |
| try: |
| return x.reshape(shape) |
| except AttributeError: |
| return do("reshape", x, shape) |
|
|
|
|
| def to_backend_dtype(dtype_name, like): |
| """Turn string specifier ``dtype_name`` into dtype of backend ``like``. |
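| |
| For example (illustrative, assuming torch is installed): |
| |
| >>> to_backend_dtype("float32", like="torch") |
| torch.float32 |
| """ |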
| if not isinstance(like, str): |
| like = _infer_class_backend_cached(like.__class__) |
|
|
| try: |
| return get_lib_fn(like, dtype_name) |
| except ImportError: |
| return dtype_name |
|
|
|
|
| @compose |
| def get_dtype_name(x): |
| """Find string specifier ``dtype_name`` of array ``x``.""" |
| dtype = x.dtype |
| try: |
| return dtype.name |
| except AttributeError: |
| return str(dtype) |
|
|
|
|
| _COMPLEX_DTYPES = {"complex64", "complex128"} |
| _DOUBLE_DTYPES = {"float64", "complex128"} |
| _DTYPE_MAP = { |
| (False, False): "float32", |
| (False, True): "float64", |
| (True, False): "complex64", |
| (True, True): "complex128", |
| } |
|
|
|
|
| def get_common_dtype(*arrays): |
| """Compute the minimal dtype sufficient for ``arrays``. |
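| |
| For example (a sketch, assuming numpy): |
| |
| >>> get_common_dtype(np.ones(2, dtype="float32"), np.ones(2, dtype="complex64")) |
| 'complex64' |
| """ |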
| dtypes = set(map(get_dtype_name, arrays)) |
| has_complex = not _COMPLEX_DTYPES.isdisjoint(dtypes) |
| has_double = not _DOUBLE_DTYPES.isdisjoint(dtypes) |
| return _DTYPE_MAP[has_complex, has_double] |
|
|
|
|
| def astype(x, dtype_name, **kwargs): |
| """Cast array as type ``dtype_name`` - tries ``x.astype`` first.""" |
| dtype = to_backend_dtype(dtype_name, like=x) |
| try: |
| return x.astype(dtype, **kwargs) |
| except AttributeError: |
| return do("astype", x, dtype, **kwargs) |
|
|
|
|
| def to_numpy(x): |
| """Get a numpy version of array ``x``.""" |
| return do("to_numpy", x) |
|
|
|
|
| |
|
|
|
|
| def svd_not_full_matrices_wrapper(fn): |
| @functools.wraps(fn) |
| def default_not_full_matrices(*args, **kwargs): |
| kwargs.setdefault("full_matrices", False) |
| return fn(*args, **kwargs) |
|
|
| return default_not_full_matrices |
|
|
|
|
| def svd_sUV_to_UsVH_wrapper(fn): |
| @functools.wraps(fn) |
| def numpy_like(*args, **kwargs): |
| s, U, V = fn(*args, **kwargs) |
| return U, s, dag(V) |
|
|
| return numpy_like |
|
|
|
|
| def svd_UsV_to_UsVH_wrapper(fn): |
| @functools.wraps(fn) |
| def numpy_like(*args, **kwargs): |
| U, s, V = fn(*args, **kwargs) |
| return U, s, dag(V) |
|
|
| return numpy_like |
|
|
|
|
| def svd_manual_full_matrices_kwarg(fn): |
| @functools.wraps(fn) |
| def numpy_like(*args, full_matrices=False, **kwargs): |
| U, s, VH = fn(*args, **kwargs) |
|
|
| if not full_matrices: |
| U, VH = U[:, : s.size], VH[: s.size, :] |
|
|
| return U, s, VH |
|
|
| return numpy_like |
|
|
|
|
| def qr_allow_fat(fn): |
| @functools.wraps(fn) |
| def numpy_like(a, **kwargs): |
| m, n = shape(a) |
|
|
| if m >= n: |
| |
| return fn(a, **kwargs) |
|
|
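| # fat matrix: factor the left square block, then extend R with the |
| # projection of the remaining columns |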
| Q, R_sq = fn(a[:, :m]) |
| R_r = dag(Q) @ a[:, m:] |
| R = do("concatenate", (R_sq, R_r), axis=1, like=a) |
|
|
| return Q, R |
|
|
| return numpy_like |
|
|
|
|
| def tril_to_band_part(fn): |
| @functools.wraps(fn) |
| def numpy_like(x, k=0): |
| if k < 0: |
| raise ValueError( |
| "'k' must be non-negative to recreate 'numpy.tril' " |
| "behaviour with 'tensorflow.matrix_band_part'." |
| ) |
|
|
| return fn(x, -1, k) |
|
|
| return numpy_like |
|
|
|
|
| def triu_to_band_part(fn): |
| @functools.wraps(fn) |
| def numpy_like(x, k=0): |
| if k > 0: |
| raise ValueError( |
| "'k' must be non-positive to recreate 'numpy.triu' " |
| "behaviour with 'tensorflow.matrix_band_part'." |
| ) |
|
|
| return fn(x, -k, -1) |
|
|
| return numpy_like |
|
|
|
|
| def cholesky_lower(fn): |
| @functools.wraps(fn) |
| def cholesky_numpy_like(a): |
| return fn(a, lower=True) |
|
|
| return cholesky_numpy_like |
|
|
|
|
| def binary_allow_1d_rhs_wrap(fn): |
| @functools.wraps(fn) |
| def allow_1d_rhs(a, b): |
| need_to_convert = ndim(a) != ndim(b) |
| if need_to_convert: |
| b = reshape(b, (*shape(b), 1)) |
| x = fn(a, b) |
| if need_to_convert: |
| x = reshape(x, shape(x)[:-1]) |
| return x |
|
|
| return allow_1d_rhs |
|
|
|
|
| def scale_random_uniform_manually(fn): |
| @functools.wraps(fn) |
| def numpy_like(low=0.0, high=1.0, size=None, dtype=None, **kwargs): |
| if size is None: |
| size = () |
|
|
| x = fn(size, **kwargs) |
|
|
| if (low != 0.0) or (high != 1.0): |
| x = (high - low) * x + low |
|
|
| if (dtype is not None) and get_dtype_name(x) != dtype: |
| x = astype(x, dtype) |
| return x |
|
|
| return numpy_like |
|
|
|
|
| def scale_random_normal_manually(fn): |
| @functools.wraps(fn) |
| def numpy_like(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs): |
| if size is None: |
| size = () |
|
|
| x = fn(size, **kwargs) |
|
|
| if (loc != 0.0) or (scale != 1.0): |
| x = scale * x + loc |
|
|
| if (dtype is not None) and get_dtype_name(x) != dtype: |
| x = astype(x, dtype) |
| return x |
|
|
| return numpy_like |
|
|
|
|
| def with_dtype_wrapper(fn): |
| """Add the ability to handle a `dtype` keyword. If not None, `dtype` |
| should be specified as a string - otherwise the comparison with the |
| array's dtype name always fails and a redundant conversion is performed. |
| """ |
|
|
| @functools.wraps(fn) |
| def with_dtype(*args, dtype=None, **kwargs): |
| A = fn(*args, **kwargs) |
| if (dtype is not None) and (dtype != get_dtype_name(A)): |
| A = astype(A, dtype) |
| return A |
|
|
| return with_dtype |
|
|
|
|
| def translate_wrapper(fn, translator): |
| """Wrap a function to match the api of another according to a translation. |
| The ``translator`` should be an ordered dict with entries like: |
| |
| (desired_kwarg: (backend_kwarg, default_value)) |
| |
| with the order defining the positional args of the function. |
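| |
| For example, a sketch mapping a numpy-style ``size`` kwarg to a |
| backend's ``shape`` kwarg: |
| |
| >>> from collections import OrderedDict |
| >>> f = translate_wrapper(lambda shape=(): shape, |
| ...                       OrderedDict([("size", ("shape", ()))])) |
| >>> f(size=(2, 3)) |
| (2, 3) |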
| """ |
|
|
| @functools.wraps(fn) |
| def translated_function(*args, **kwargs): |
| new_kwargs = {} |
| translation = translator.copy() |
|
|
| |
| for arg_value in args: |
| new_arg_name = translation.popitem(last=False)[1][0] |
| new_kwargs[new_arg_name] = arg_value |
|
|
| |
| for key, value in kwargs.items(): |
| try: |
| new_kwargs[translation.pop(key)[0]] = value |
| except KeyError: |
| new_kwargs[key] = value |
|
|
| |
| for opt in translation.values(): |
| if len(opt) == 2: |
| |
| new_kwargs[opt[0]] = opt[1] |
| |
|
|
| return fn(**new_kwargs) |
|
|
| return translated_function |
|
|
|
|
| def make_translator(t): |
| return functools.partial(translate_wrapper, translator=OrderedDict(t)) |
|
|
|
|
| def complex_add_re_im(re, im): |
| return re + 1j * im |
|
|
|
|
| def allclose(x, y, rtol=1e-05, atol=1e-08): |
| """Backend-agnostic version of ``numpy.allclose``.""" |
| return do("all", do("abs", x - y) <= atol + rtol * do("abs", y)) |
|
|
|
|
| def _handle_size_to_shape(size=None): |
| """Convert a numpy-style ``size`` argument to a backend ``shape`` tuple.""" |
| if size is None: |
| return () |
| try: |
| return tuple(size) |
| except TypeError: |
| return (size,) |
|
|
|
|
| |
|
|
|
|
| def wrap_args_kwargs_from_raw(fn): |
| """Take a function with signature ``(*args, **kwargs)`` and wrap it to |
| accept a single tuple of args and a dict of kwargs. |
| """ |
|
|
| @functools.wraps(fn) |
| def wrapped(args, kwargs): |
| return fn(*args, **kwargs) |
|
|
| return wrapped |
|
|
|
|
| def register_dispatch(fun, dispatcher, raw_signature=True): |
| """Register a new dispatcher, a function that takes the arguments and |
| keyword arguments of a function and returns the backend to use, when the |
| backend is not explicitly given. |
| |
| This is useful in case the backend to be used by a function cannot be |
| inferred from the first argument. |
| |
| Parameters |
| ---------- |
| fun : str |
| The name of the function to register the dispatcher for. |
| dispatcher : callable |
| The dispatcher function to use. This should take the arguments and |
| keyword arguments of the function and return the backend to use. |
| raw_signature : bool, optional |
| The ``dispatcher`` has signature ``(*args, **kwargs)`` if ``True``, |
| otherwise it has signature ``(args, kwargs)``. |
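| |
| Examples |
| -------- |
| A sketch dispatching a hypothetical function ``"solve_t"`` on its |
| second argument: |
| |
| >>> register_dispatch("solve_t", lambda *args, **kwargs: infer_backend(args[1])) |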
| """ |
| if raw_signature: |
| dispatcher = wrap_args_kwargs_from_raw(dispatcher) |
|
|
| _DISPATCHERS[fun] = dispatcher |
|
|
|
|
| def default_dispatcher(args, kwargs): |
| """Try to infer backend from first argument passed to function.""" |
| try: |
| return _infer_class_backend_cached(args[0].__class__) |
| except IndexError: |
| raise TypeError("No args to infer backend from.") |
|
|
|
|
| |
| |
| _DISPATCHERS = defaultdict(lambda: default_dispatcher) |
|
|
|
|
| def join_array_dispatcher(args, kwargs): |
| """Dispatcher for functions where first argument is a sequence.""" |
| try: |
| return _infer_class_backend_cached(args[0][0].__class__) |
| except (TypeError, ValueError): |
| |
| |
| return _infer_class_backend_cached(args[0].__class__) |
|
|
|
|
| |
| register_dispatch("concatenate", join_array_dispatcher, raw_signature=False) |
| register_dispatch("stack", join_array_dispatcher, raw_signature=False) |
| register_dispatch("block", join_array_dispatcher, raw_signature=False) |
| register_dispatch("vstack", join_array_dispatcher, raw_signature=False) |
| register_dispatch("hstack", join_array_dispatcher, raw_signature=False) |
| register_dispatch("dstack", join_array_dispatcher, raw_signature=False) |
| register_dispatch("column_stack", join_array_dispatcher, raw_signature=False) |
| register_dispatch("row_stack", join_array_dispatcher, raw_signature=False) |
|
|
|
|
| def einsum_dispatcher(args, kwargs): |
| """Dispatcher for handling einsum. |
| |
| einsum can be called with a str equation as the first argument, or with |
| 'interleaved' inputs. This dispatcher handles both cases and also takes |
| into account all arrays. |
| """ |
| return infer_backend_multi(*args) |
|
|
|
|
| register_dispatch("einsum", einsum_dispatcher, raw_signature=False) |
|
|
|
|
| def binary_dispatcher(args, kwargs): |
| """There are cases when we want to take into account both backends of two |
| arguments, e.g. a lazy variable and a constant array. |
| """ |
| return infer_backend_multi(*args[:2]) |
|
|
|
|
| register_dispatch("tensordot", binary_dispatcher, raw_signature=False) |
| register_dispatch("matmul", binary_dispatcher, raw_signature=False) |
| register_dispatch("multiply", binary_dispatcher, raw_signature=False) |
| register_dispatch("divide", binary_dispatcher, raw_signature=False) |
| register_dispatch("true_divide", binary_dispatcher, raw_signature=False) |
| register_dispatch("add", binary_dispatcher, raw_signature=False) |
| register_dispatch("subtract", binary_dispatcher, raw_signature=False) |
|
|
| |
|
|
| |
|
|
|
|
| class InjectDtypeDevice: |
| """Wrapper that injects default ``dtype`` and ``device`` arguments into a function call.""" |
|
|
| __slots__ = ("_fn", "_device", "_dtype") |
|
|
| def __init__(self, fn, device=None, dtype=None): |
| self._fn = fn |
| self._device = device |
| self._dtype = dtype |
|
|
| def __call__(self, *args, **kwargs): |
| if self._device is not None and "device" not in kwargs: |
| kwargs["device"] = self._device |
| if self._dtype is not None and "dtype" not in kwargs: |
| kwargs["dtype"] = self._dtype |
| return self._fn(*args, **kwargs) |
|
|
| def __repr__(self): |
| return ( |
| f"InjectDtypeDevice(fn={self._fn}, " |
| f"device={self._device}, " |
| f"dtype={self._dtype})" |
| ) |
|
|
|
|
| class AutoNamespace: |
| """Mimics a namespace, optionally for a specific backend, device, and |
| dtype, caching the lookup of functions, and injecting default device and |
| dtype arguments for certain creation routines. |
| |
| Parameters |
| ---------- |
| like : array_like, str, or None |
| The backend to use, or an object to infer the backend from. If None, |
| the default behavior is to use `autoray.do` and auto dispatch backend |
| at function call time. If given, the functions are cached at first |
| call. |
| device : str, optional |
| The device to use for the backend. If None, it will be inferred from |
| the `like` parameter if that is array-like, otherwise left unset. |
| dtype : str, optional |
| The dtype to use for the backend. If None, it will be inferred from |
| the `like` parameter if that is array-like, otherwise left unset. |
| submodule : str, optional |
| This is used internally when nesting attribute lookups, e.g. |
| `xp.random.normal`, `xp.linalg.eigh`. |
| """ |
|
|
| def __init__( |
| self, |
| like=None, |
| device=None, |
| dtype=None, |
| submodule=None, |
| ): |
| if like is None: |
| |
| self._backend = None |
| elif isinstance(like, str): |
| |
| self._backend = like |
| else: |
| |
| self._backend = _infer_class_backend_cached(like.__class__) |
|
|
| if device is None: |
| if like is not None: |
| if hasattr(like, "device"): |
| device = like.device |
| self._device = device |
|
|
| if dtype is None: |
| if like is not None: |
| if hasattr(like, "dtype"): |
| dtype = like.dtype |
| self._dtype = dtype |
| self._submodule = submodule |
|
|
| def _get_submodule(self, name): |
| new = object.__new__(type(self)) |
| new._backend = self._backend |
| new._device = self._device |
| new._dtype = self._dtype |
| new._submodule = name |
| return new |
|
|
| def _get_fn(self, name): |
| if self._submodule is not None: |
| |
| name = f"{self._submodule}.{name}" |
|
|
| if name in ("random", "linalg"): |
| |
| |
| return self._get_submodule(name) |
|
|
| if self._backend is None: |
| |
| return functools.partial(do, name) |
|
|
| fn = get_lib_fn(self._backend, name) |
|
|
| |
| if name in _CREATION_ROUTINES: |
| inject_dtype, inject_device = _CREATION_INJECT.get( |
| (self._backend, name), (True, False) |
| ) |
|
|
| if not inject_dtype: |
| |
| dtype_to_inject = None |
| else: |
| dtype_to_inject = self._dtype |
|
|
| if not inject_device: |
| |
| device_to_inject = None |
| else: |
| device_to_inject = self._device |
|
|
| if (dtype_to_inject is not None) or (device_to_inject is not None): |
| |
| fn = InjectDtypeDevice(fn, device_to_inject, dtype_to_inject) |
|
|
| return fn |
|
|
| def __getattribute__(self, name): |
| try: |
| return object.__getattribute__(self, name) |
| except AttributeError: |
| x = self._get_fn(name) |
| super().__getattribute__("__dict__")[name] = x |
| return x |
|
|
| def __repr__(self): |
| return ( |
| f"{self.__class__.__name__}(" |
| f"backend={self._backend}, " |
| f"device={self._device}, " |
| f"dtype={self._dtype}, " |
| f"submodule={self._submodule}" |
| ")" |
| ) |
|
|
|
|
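| # a module-level auto-dispatching namespace, mimicking the numpy api |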
| numpy = AutoNamespace() |
|
|
|
|
| def get_namespace(like=None, device=None, dtype=None, submodule=None): |
| """Get an automatic namespace object. |
| |
| If `like` is None, the namespace essentially provides an alternative syntax |
| to `do`, dispatching each function at calltime, and allowing the backend |
| and function implementations to be dynamically updated. |
| |
| If `like` is supplied however, the backend is eagerly dispatched and |
| functions are loaded and cached specifically for that backend. In this |
| case, default `device` and `dtype` can also be specified for various array |
| creation routines, or if `like` is an array, inferred from that. |
| |
| Parameters |
| ---------- |
| like : array-like, str or None, optional |
| An array-like object to dispatch on, an explicit backend name, or None. |
| device : str or None, optional |
| The device to use for array creation, or None to infer from `like`. |
| dtype : str or None, optional |
| The data type to use for array creation, or None to infer from `like`. |
| |
| Returns |
| ------- |
| AutoNamespace |
| An automatic namespace object. |
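| |
| Examples |
| -------- |
| >>> xp = get_namespace("numpy") |
| >>> xp.sqrt(xp.ones(3)) |
| array([1., 1., 1.]) |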
| """ |
| return AutoNamespace( |
| like=like, |
| device=device, |
| dtype=dtype, |
| submodule=submodule, |
| ) |
|
|
|
|
| |
| |
| |
|
|
| |
|
|
| _MODULE_ALIASES["decimal"] = "math" |
| _MODULE_ALIASES["builtins"] = "numpy" |
|
|
|
|
| _builtin_dtype_lookup = { |
| int: "int64", |
| float: "float64", |
| complex: "complex128", |
| } |
|
|
|
|
| @get_dtype_name.register("builtins") |
| def builtins_get_dtype_name(x): |
| return _builtin_dtype_lookup[x.__class__] |
|
|
|
|
| _FUNCS["builtins", "complex"] = complex |
|
|
| |
|
|
|
|
| def numpy_to_numpy(x): |
| return do("asarray", x, like="numpy") |
|
|
|
|
| _MODULE_ALIASES["numpy.scipy"] = "scipy" |
| _FUNCS["numpy", "to_numpy"] = numpy_to_numpy |
| _FUNCS["numpy", "complex"] = complex_add_re_im |
| _FUNCS["builtins", "to_numpy"] = numpy_to_numpy |
| _SUBMODULE_ALIASES["numpy", "linalg.lu"] = "scipy.linalg" |
| _SUBMODULE_ALIASES["numpy", "linalg.expm"] = "scipy.linalg" |
| _CUSTOM_WRAPPERS["numpy", "linalg.svd"] = svd_not_full_matrices_wrapper |
| _CUSTOM_WRAPPERS["numpy", "random.normal"] = with_dtype_wrapper |
| _CUSTOM_WRAPPERS["numpy", "random.uniform"] = with_dtype_wrapper |
|
|
| |
|
|
|
|
| def cupy_to_numpy(x): |
| return x.get() |
|
|
|
|
| _MODULE_ALIASES["cupy.scipy"] = "cupyx.scipy" |
| _FUNCS["cupy", "to_numpy"] = cupy_to_numpy |
| _FUNCS["cupy", "complex"] = complex_add_re_im |
| _CUSTOM_WRAPPERS["cupy", "linalg.svd"] = svd_not_full_matrices_wrapper |
|
|
| |
|
|
|
|
| def jax_to_numpy(x): |
| return do("asarray", x, like="numpy") |
|
|
|
|
| class JaxDefaultRNG: |
| """Stateful but deterministic random number generator for JAX following |
| numpy's Generator API, compatible with `jax.jit`. |
| """ |
|
|
| def __init__(self, seed, **kwargs): |
| import jax |
|
|
| self.jax = jax |
| self.key = jax.random.key(seed, **kwargs) |
|
|
| def binomial(self, n, p, size=None, **kwargs): |
| self.key, subkey = self.jax.random.split(self.key) |
| shape = _handle_size_to_shape(size) |
| return self.jax.random.binomial( |
| subkey, n=n, p=p, shape=shape, **kwargs |
| ) |
|
|
| def choice(self, a, size=None, replace=True, p=None, axis=0, **kwargs): |
| self.key, subkey = self.jax.random.split(self.key) |
| shape = _handle_size_to_shape(size) |
| return self.jax.random.choice( |
| subkey, |
| a, |
| shape=shape, |
| replace=replace, |
| p=p, |
| axis=axis, |
| **kwargs, |
| ) |
|
|
| def exponential(self, scale=1.0, size=None, **kwargs): |
| self.key, subkey = self.jax.random.split(self.key) |
| shape = _handle_size_to_shape(size) |
| x = self.jax.random.exponential(subkey, shape=shape, **kwargs) |
| if scale != 1.0: |
| x *= scale |
| return x |
|
|
| def gumbel(self, loc=0.0, scale=1.0, size=None, **kwargs): |
| self.key, subkey = self.jax.random.split(self.key) |
| shape = _handle_size_to_shape(size) |
| x = self.jax.random.gumbel(subkey, shape=shape, **kwargs) |
| if scale != 1.0: |
| x *= scale |
| if loc != 0.0: |
| x += loc |
| return x |
|
|
| def integers(self, low, high=None, size=None, **kwargs): |
| self.key, subkey = self.jax.random.split(self.key) |
| shape = _handle_size_to_shape(size) |
| if high is None: |
| high = low |
| low = 0 |
| return self.jax.random.randint( |
| subkey, |
| shape=shape, |
| minval=low, |
| maxval=high, |
| **kwargs, |
| ) |
|
|
| def normal(self, loc=0.0, scale=1.0, size=None, **kwargs): |
| self.key, subkey = self.jax.random.split(self.key) |
| shape = _handle_size_to_shape(size) |
| x = self.jax.random.normal(subkey, shape=shape, **kwargs) |
| if scale != 1.0: |
| x *= scale |
| if loc != 0.0: |
| x += loc |
| return x |
|
|
| def permutation(self, x, **kwargs): |
| self.key, subkey = self.jax.random.split(self.key) |
| return self.jax.random.permutation(subkey, x, **kwargs) |
|
|
| def poisson(self, lam=1.0, size=None, **kwargs): |
| self.key, subkey = self.jax.random.split(self.key) |
| shape = _handle_size_to_shape(size) |
| x = self.jax.random.poisson(subkey, lam, shape=shape, **kwargs) |
| return x |
|
|
| def random(self, size=None, **kwargs): |
| return self.uniform(size=size, **kwargs) |
|
|
| def uniform(self, low=0.0, high=1.0, size=None, **kwargs): |
| self.key, subkey = self.jax.random.split(self.key) |
| shape = _handle_size_to_shape(size) |
| return self.jax.random.uniform( |
| subkey, shape=shape, minval=low, maxval=high, **kwargs |
| ) |
|
|
|
|
| def jax_default_rng(seed, **kwargs): |
| if isinstance(seed, JaxDefaultRNG): |
| return seed |
| return JaxDefaultRNG(seed, **kwargs) |
|
|
|
|
| register_function("jax", "random.default_rng", jax_default_rng) |
| register_backend(JaxDefaultRNG, "jax") |
|
|
|
|
| _JAX_RANDOM_KEY = None |
|
|
|
|
| def jax_random_seed(seed=None): |
| from jax.random import PRNGKey |
|
|
| global _JAX_RANDOM_KEY |
| if seed is None: |
| from random import SystemRandom |
|
|
| seed = SystemRandom().randint(-(2**63), 2**63 - 1) |
| _JAX_RANDOM_KEY = PRNGKey(seed) |
|
|
|
|
| def jax_random_get_key(): |
| from jax.random import split |
|
|
| global _JAX_RANDOM_KEY |
| if _JAX_RANDOM_KEY is None: |
| jax_random_seed() |
| _JAX_RANDOM_KEY, subkey = split(_JAX_RANDOM_KEY) |
| return subkey |
|
|
|
|
| def jax_random_uniform(low=0.0, high=1.0, size=None, **kwargs): |
| from jax.random import uniform |
|
|
| if size is None: |
| size = () |
| return uniform( |
| jax_random_get_key(), shape=size, minval=low, maxval=high, **kwargs |
| ) |
|
|
|
|
| def jax_random_normal(loc=0.0, scale=1.0, size=None, **kwargs): |
| from jax.random import normal |
|
|
| if size is None: |
| size = () |
| x = normal(jax_random_get_key(), shape=size, **kwargs) |
| if scale != 1.0: |
| x *= scale |
| if loc != 0.0: |
| x += loc |
| return x |
|
|
|
|
| _BACKEND_ALIASES["jaxlib"] = "jax" |
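| # identity alias: stops "jax.scipy.*" being rewritten by the "jax" -> |
| # "jax.numpy" alias below |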
| _MODULE_ALIASES["jax.scipy"] = "jax.scipy" |
| _MODULE_ALIASES["jax"] = "jax.numpy" |
| _SUBMODULE_ALIASES["jax", "complex"] = "jax.lax" |
| _SUBMODULE_ALIASES["jax", "linalg.expm"] = "jax.scipy.linalg" |
| _SUBMODULE_ALIASES["jax", "linalg.householder_product"] = "jax.lax.linalg" |
| |
| _CUSTOM_WRAPPERS["jax", "linalg.qr"] = qr_allow_fat |
| _CUSTOM_WRAPPERS["jax", "linalg.svd"] = svd_not_full_matrices_wrapper |
| _FUNCS["jax", "to_numpy"] = jax_to_numpy |
| _FUNCS["jax", "random.seed"] = jax_random_seed |
| _FUNCS["jax", "random.uniform"] = jax_random_uniform |
| _FUNCS["jax", "random.normal"] = jax_random_normal |
|
|
|
|
| |
|
|
|
|
| @shape.register("aesara") |
| def aesara_shape(x): |
| return x.type.shape |
|
|
|
|
| _MODULE_ALIASES["aesara"] = "aesara.tensor" |
| _FUNCS["aesara", "shape"] = aesara_shape |
|
|
|
|
| |
|
|
| _MODULE_ALIASES["autograd"] = "autograd.numpy" |
| _CUSTOM_WRAPPERS["autograd", "linalg.svd"] = svd_not_full_matrices_wrapper |
| _FUNCS["autograd", "complex"] = complex_add_re_im |
|
|
|
|
| |
|
|
|
|
| def dask_to_numpy(x): |
| return x.compute() |
|
|
|
|
| def dask_eye_wrapper(fn): |
| |
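| # dask.array.eye takes ``chunks`` as its second positional argument, |
| # so ``M`` must be forwarded as a keyword |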
| @functools.wraps(fn) |
| def numpy_like(N, M=None, **kwargs): |
| return fn(N, M=M, **kwargs) |
|
|
| return numpy_like |
|
|
|
|
| _FUNCS["dask", "to_numpy"] = dask_to_numpy |
| _FUNCS["dask", "complex"] = complex_add_re_im |
| _FUNC_ALIASES["dask", "abs"] = "absolute" |
| _FUNC_ALIASES["dask", "identity"] = "eye" |
| _MODULE_ALIASES["dask"] = "dask.array" |
| _CUSTOM_WRAPPERS["dask", "linalg.svd"] = svd_manual_full_matrices_kwarg |
| _CUSTOM_WRAPPERS["dask", "linalg.cholesky"] = cholesky_lower |
| _CUSTOM_WRAPPERS["dask", "random.normal"] = with_dtype_wrapper |
| _CUSTOM_WRAPPERS["dask", "random.uniform"] = with_dtype_wrapper |
| _CUSTOM_WRAPPERS["dask", "eye"] = dask_eye_wrapper |
|
|
| |
|
|
|
|
| def mars_to_numpy(x): |
| return x.to_numpy() |
|
|
|
|
| _FUNCS["mars", "to_numpy"] = mars_to_numpy |
| _FUNCS["mars", "complex"] = complex_add_re_im |
| _MODULE_ALIASES["mars"] = "mars.tensor" |
| _CUSTOM_WRAPPERS["mars", "linalg.cholesky"] = cholesky_lower |
|
|
|
|
| |
|
|
|
|
| def ctf_array(x): |
| return do("astensor", x, like="ctf") |
|
|
|
|
| def ctf_to_numpy(x): |
| return x.to_nparray() |
|
|
|
|
| def ctf_count_nonzero(x): |
| return (x != 0).astype(int).sum() |
|
|
|
|
| @get_dtype_name.register("ctf") |
| def ctf_get_dtype_name(x): |
| return x.dtype.__name__ |
|
|
|
|
| _FUNCS["ctf", "array"] = ctf_array |
| _FUNCS["ctf", "complex"] = complex_add_re_im |
| _FUNCS["ctf", "allclose"] = allclose |
| _FUNCS["ctf", "to_numpy"] = ctf_to_numpy |
| _FUNCS["ctf", "count_nonzero"] = ctf_count_nonzero |
|
|
| _SUBMODULE_ALIASES["ctf", "float32"] = "numpy" |
| _SUBMODULE_ALIASES["ctf", "float64"] = "numpy" |
| _SUBMODULE_ALIASES["ctf", "complex64"] = "numpy" |
| _SUBMODULE_ALIASES["ctf", "complex128"] = "numpy" |
| _SUBMODULE_ALIASES["ctf", "linalg.svd"] = "ctf" |
| _SUBMODULE_ALIASES["ctf", "linalg.eigh"] = "ctf" |
| _SUBMODULE_ALIASES["ctf", "linalg.qr"] = "ctf" |
| _SUBMODULE_ALIASES["ctf", "linalg.norm"] = "ctf" |
|
|
| _FUNC_ALIASES["ctf", "random.uniform"] = "random" |
|
|
| _CUSTOM_WRAPPERS["ctf", "random.uniform"] = scale_random_uniform_manually |
|
|
|
|
| |
|
|
|
|
| def sparse_array(x): |
| return do("COO.from_numpy", x, like="sparse") |
|
|
|
|
| def sparse_to_numpy(x): |
| return x.todense() |
|
|
|
|
| def sparse_transpose(x, axes=None): |
| return x.transpose(axes) |
|
|
|
|
| def sparse_reshape(x, shape): |
| return x.reshape(shape) |
|
|
|
|
| def sparse_sum(x, axis=None, keepdims=False, dtype=None, out=None): |
| return x.sum(axis=axis, keepdims=keepdims, dtype=dtype, out=out) |
|
|
|
|
| def sparse_prod(x, axis=None, keepdims=False, dtype=None, out=None): |
| return x.prod(axis=axis, keepdims=keepdims, dtype=dtype, out=out) |
|
|
|
|
| def sparse_conj(x): |
| return x.conj() |
|
|
|
|
| def sparse_real(x): |
| return x.real |
|
|
|
|
| def sparse_imag(x): |
| return x.imag |
|
|
|
|
| def sparse_count_nonzero(x): |
| return x.nnz |
|
|
|
|
| def sparse_random_uniform(low=0.0, high=1.0, size=None, dtype=None, **kwargs): |
| def rvs(nnz): |
| return do( |
| "random.uniform", low, high, (nnz,), dtype=dtype, like="numpy" |
| ) |
|
|
| return do("random", size, data_rvs=rvs, **kwargs, like="sparse") |
|
|
|
|
| def sparse_random_normal(loc=0.0, scale=1.0, size=None, dtype=None, **kwargs): |
| def rvs(nnz): |
| return do( |
| "random.normal", loc, scale, (nnz,), dtype=dtype, like="numpy" |
| ) |
|
|
| return do("random", size, data_rvs=rvs, **kwargs, like="sparse") |
|
|
|
|
| _FUNCS["sparse", "array"] = sparse_array |
| _FUNCS["sparse", "to_numpy"] = sparse_to_numpy |
| _FUNCS["sparse", "transpose"] = sparse_transpose |
| _FUNCS["sparse", "reshape"] = sparse_reshape |
| _FUNCS["sparse", "sum"] = sparse_sum |
| _FUNCS["sparse", "prod"] = sparse_prod |
| _FUNCS["sparse", "conj"] = sparse_conj |
| _FUNCS["sparse", "real"] = sparse_real |
| _FUNCS["sparse", "imag"] = sparse_imag |
| _FUNCS["sparse", "complex"] = complex_add_re_im |
| _FUNCS["sparse", "count_nonzero"] = sparse_count_nonzero |
| _FUNCS["sparse", "random.uniform"] = sparse_random_uniform |
| _FUNCS["sparse", "random.normal"] = sparse_random_normal |
|
|
| _FUNC_ALIASES["sparse", "identity"] = "eye" |
|
|
| |
| for f in ( |
| "log", |
| "log2", |
| "log10", |
| "exp", |
| "sqrt", |
| "sign", |
| "sin", |
| "cos", |
| "tan", |
| "arcsin", |
| "arccos", |
| "arctan", |
| "sinh", |
| "cosh", |
| "tanh", |
| "arcsinh", |
| "arccosh", |
| "arctanh", |
| "tensordot", |
| |
| |
| |
| ): |
| _SUBMODULE_ALIASES["sparse", f] = "numpy" |
|
|
|
|
| |
|
|
|
|
| def tensorflow_to_numpy(x): |
| return x.numpy() |
|
|
|
|
| def tensorflow_indices(dimensions): |
| _meshgrid = get_lib_fn("tensorflow", "meshgrid") |
| _arange = get_lib_fn("tensorflow", "arange") |
| return _meshgrid(*map(_arange, dimensions), indexing="ij") |
|
|
|
|
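| # identity aliases: stop "tensorflow.linalg" and "tensorflow.random" being |
| # rewritten by the "tensorflow" -> "tensorflow.experimental.numpy" alias |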
| _MODULE_ALIASES["tensorflow.linalg"] = "tensorflow.linalg" |
| _MODULE_ALIASES["tensorflow.random"] = "tensorflow.random" |
| _MODULE_ALIASES["tensorflow"] = "tensorflow.experimental.numpy" |
|
|
| _FUNCS["tensorflow", "to_numpy"] = tensorflow_to_numpy |
| _FUNCS["tensorflow", "indices"] = tensorflow_indices |
|
|
| _FUNC_ALIASES["tensorflow", "astype"] = "cast" |
| _SUBMODULE_ALIASES["tensorflow", "cast"] = "tensorflow" |
| _SUBMODULE_ALIASES["tensorflow", "astype"] = "tensorflow" |
| _SUBMODULE_ALIASES["tensorflow", "complex"] = "tensorflow" |
|
|
| _CUSTOM_WRAPPERS["tensorflow", "linalg.svd"] = svd_sUV_to_UsVH_wrapper |
| _CUSTOM_WRAPPERS["tensorflow", "linalg.solve"] = binary_allow_1d_rhs_wrap |
| _CUSTOM_WRAPPERS["tensorflow", "random.uniform"] = make_translator( |
| [ |
| ("low", ("minval", 0.0)), |
| ("high", ("maxval", 1.0)), |
| ("size", ("shape", ())), |
| ] |
| ) |
| _CUSTOM_WRAPPERS["tensorflow", "random.normal"] = make_translator( |
| [ |
| ("loc", ("mean", 0.0)), |
| ("scale", ("stddev", 1.0)), |
| ("size", ("shape", ())), |
| ] |
| ) |
|
|
|
|
| def tensorflow_pad_wrap(tf_pad): |
| def numpy_like(array, pad_width, mode="constant", constant_values=0): |
| if mode != "constant": |
| raise NotImplementedError |
|
|
| try: |
| if len(pad_width) == 1: |
| pad_width = pad_width * ndim(array) |
| except TypeError: |
| pad_width = ((pad_width, pad_width),) * ndim(array) |
|
|
| return tf_pad( |
| array, pad_width, mode="CONSTANT", constant_values=constant_values |
| ) |
|
|
| return numpy_like |
|
|
|
|
| _CUSTOM_WRAPPERS["tensorflow", "pad"] = tensorflow_pad_wrap |
| _SUBMODULE_ALIASES["tensorflow", "pad"] = "tensorflow" |
|
|
|
|
| register_creation_routine("tensorflow", "linspace", inject_dtype=False) |
|
|
|
|
| |
|
|
|
|
| @shape.register("torch") |
| def torch_shape(x): |
| # torch returns a ``torch.Size`` object rather than a plain tuple |
| return tuple(map(int, x.shape)) |
|
|
|
|
| @size.register("torch") |
| def torch_size(x): |
| return x.numel() |
|
|
|
|
| def torch_to_numpy(x): |
| return x.detach().cpu().numpy() |
|
|
|
|
| def torch_copy(x): |
| return x.detach().clone() |
|
|
|
|
| def torch_transpose(x, axes=None): |
| if axes is None: |
| axes = reversed(range(0, x.ndimension())) |
| return x.permute(*axes) |
|
|
|
|
| def torch_count_nonzero(x): |
| return do("sum", x != 0, like="torch") |
|
|
|
|
| def torch_astype(x, dtype): |
| return x.to(dtype=to_backend_dtype(dtype, like=x)) |
|
|
|
|
| @functools.lru_cache(None) |
| def _torch_get_dtype_name(dtype): |
| return str(dtype).split(".")[-1] |
|
|
|
|
| @get_dtype_name.register("torch") |
| def torch_get_dtype_name(x): |
| return _torch_get_dtype_name(x.dtype) |
|
|
|
|
def torch_real(x):
    # guard against tensors without ``is_complex``
    try:
        if x.is_complex():
            return x.real
    except AttributeError:
        pass
    return x
|
|
|
|
def torch_imag(x):
    # guard against tensors without ``is_complex``
    try:
        if x.is_complex():
            return x.imag
    except AttributeError:
        pass
    # the imaginary part of a real tensor is simply zero
    return do("zeros_like", x)
|
|
|
|
def torch_linalg_solve_wrap(fn):
    @binary_allow_1d_rhs_wrap
    def numpy_like(a, b):
        # the older ``torch.solve`` takes its arguments reversed and returns
        # the LU factorization as well - keep just the solution
        return fn(b, a)[0]

    return numpy_like
|
|
|
|
def torch_linalg_eigh(x):
    # use the older ``torch.symeig`` interface
    return tuple(do("symeig", x, eigenvectors=True, like="torch"))
|
|
|
|
| def torch_linalg_eigvalsh(x): |
| return do("symeig", x, eigenvectors=False, like="torch")[0] |
|
|
|
|
| def torch_tensordot_wrap(fn): |
| @functools.wraps(fn) |
| def numpy_like(a, b, axes=2): |
| return fn(a, b, dims=axes) |
|
|
| return numpy_like |
|
|
|
|
def torch_pad(array, pad_width, mode="constant", constant_values=0):
    if mode != "constant":
        raise NotImplementedError

    try:
        # numpy specifies pads as ((dim0_low, dim0_high), (dim1_low, ...), ...)
        # torch expects a flat sequence starting from the *last* dimension:
        # (dimN_low, dimN_high, ..., dim0_low, dim0_high)
        pad = tuple(itertools.chain.from_iterable(reversed(pad_width)))

        # a single pair was specified - use it for all dimensions
        if len(pad) == 2:
            pad = pad * array.ndimension()

    except TypeError:
        # assume a single int was given - pad all dimensions symmetrically
        pad = (pad_width,) * 2 * array.ndimension()

    return do(
        "nn.functional.pad",
        array,
        pad=pad,
        mode=mode,
        value=constant_values,
        like="torch",
    )
|
|
|
|
def torch_split_wrap(fn):
    # numpy's ``split`` takes either a number of equal sections or explicit
    # split indices, whereas torch takes a chunk size or explicit chunk sizes
    @functools.wraps(fn)
    def numpy_like(ary, indices_or_sections, axis=0, **kwargs):
        if isinstance(indices_or_sections, int):
            split_size = shape(ary)[axis] // indices_or_sections
            return fn(ary, split_size, dim=axis, **kwargs)
        else:
            # torch doesn't handle an empty list of split indices
            if len(indices_or_sections) == 0:
                return (ary,)

            # convert split indices into chunk sizes via their differences
            diff = do(
                "diff",
                indices_or_sections,
                prepend=0,
                append=shape(ary)[axis],
                like="numpy",
            )
            diff = list(diff)
            return fn(ary, diff, dim=axis)

    return numpy_like
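
# worked example: splitting an axis of size 8 at indices [2, 5] yields
# ``diff = [2, 3, 3]``, the chunk sizes that torch's ``split`` expects:
#
#   >>> do('split', t, [2, 5], axis=0, like='torch[alt]')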
|
|
|
|
| def torch_zeros_ones_wrap(fn): |
| @functools.wraps(fn) |
| def numpy_like(shape, dtype=None, **kwargs): |
| if dtype is not None: |
| dtype = to_backend_dtype(dtype, like="torch") |
| return fn(shape, dtype=dtype, **kwargs) |
|
|
| return numpy_like |
|
|
|
|
| def torch_eye_wrap(fn): |
| @functools.wraps(fn) |
| def numpy_like(N, M=None, dtype=None, **kwargs): |
| if dtype is not None: |
| dtype = to_backend_dtype(dtype, like="torch") |
| if M is not None: |
| return fn(N, m=M, dtype=dtype, **kwargs) |
| else: |
| return fn(N, dtype=dtype, **kwargs) |
|
|
| return numpy_like |
|
|
|
|
def torch_sort_wrap(fn):
    @functools.wraps(fn)
    def numpy_like(a, axis=-1):
        # torch's sort returns (values, indices) - keep just the values
        return fn(a, dim=axis)[0]
|
|
| return numpy_like |
|
|
|
|
| def torch_indices(dimensions): |
| _meshgrid = get_lib_fn("torch", "meshgrid") |
| _arange = get_lib_fn("torch", "arange") |
| return _meshgrid(*map(_arange, dimensions), indexing="ij") |
|
|
|
|
def torch_flip_wrap(torch_flip):
    def numpy_like(x, axis=None):
        if axis is None:
            dims = tuple(range(x.ndimension()))
        elif isinstance(axis, int):
            dims = (axis,)
        else:
            # assume an iterable of ints
            dims = axis
        return torch_flip(x, dims)

    return numpy_like
|
|
|
|
class TorchDefaultRNG:
    """A minimal numpy ``default_rng`` style interface wrapping a seeded
    ``torch.Generator``.
    """

    def __init__(self, seed, device=None):
        import torch

        self._torch = torch
        self._generator = torch.Generator(device=device)
        self._generator.manual_seed(seed)


    def integers(self, low, high=None, size=None, **kwargs):
        # numpy style: a single argument is the exclusive upper bound
        if high is None:
            high = low
            low = 0
        size = _handle_size_to_shape(size)
        return self._torch.randint(
            low, high, size, generator=self._generator, **kwargs
        )


| def normal(self, loc=0.0, scale=1.0, size=None, **kwargs): |
| size = _handle_size_to_shape(size) |
| x = self._torch.randn(size, generator=self._generator, **kwargs) |
| if scale != 1.0: |
| x = x * scale |
| if loc != 0.0: |
| x = x + loc |
| return x |
|
|
| def random(self, size=None, **kwargs): |
| size = _handle_size_to_shape(size) |
| return self._torch.rand(size, generator=self._generator, **kwargs) |
|
|
| def permutation(self, x, **kwargs): |
| if isinstance(x, int): |
| return self._torch.randperm(x, generator=self._generator, **kwargs) |
|
|
| axis = kwargs.get("axis", 0) |
| n = x.shape[axis] |
| perm_indices = self._torch.randperm( |
| n, generator=self._generator, device=x.device |
| ) |
| return self._torch.index_select(x, axis, perm_indices) |
|
|
| def uniform(self, low=0.0, high=1.0, size=None, **kwargs): |
| size = _handle_size_to_shape(size) |
| x = self._torch.rand(size, generator=self._generator, **kwargs) |
| if low != 0.0 or high != 1.0: |
| x = x * (high - low) + low |
| return x |
|
|
|
|
| def torch_default_rng(seed, **kwargs): |
| if isinstance(seed, TorchDefaultRNG): |
| return seed |
| return TorchDefaultRNG(seed, **kwargs) |
|
|
|
|
| register_function("torch", "random.default_rng", torch_default_rng) |
| register_backend(TorchDefaultRNG, "torch") |
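
# usage sketch (assuming torch is installed) - mirrors numpy's Generator API:
#
#   >>> rng = do('random.default_rng', 42, like='torch')
#   >>> rng.uniform(low=-1.0, high=1.0, size=(2, 3))  # a torch tensor
#   >>> rng.permutation(5)  # a random permutation of 0..4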
|
|
|
|
| _FUNCS["torch", "pad"] = torch_pad |
| _FUNCS["torch", "real"] = torch_real |
| _FUNCS["torch", "imag"] = torch_imag |
| _FUNCS["torch", "astype"] = torch_astype |
| _FUNCS["torch", "copy"] = torch_copy |
| _FUNCS["torch", "to_numpy"] = torch_to_numpy |
| _FUNCS["torch", "complex"] = complex_add_re_im |
| _FUNCS["torch", "transpose"] = torch_transpose |
| _FUNCS["torch", "count_nonzero"] = torch_count_nonzero |
| _FUNCS["torch", "indices"] = torch_indices |
|
|
| _FUNC_ALIASES["torch", "array"] = "tensor" |
| _FUNC_ALIASES["torch", "asarray"] = "as_tensor" |
| _FUNC_ALIASES["torch", "clip"] = "clamp" |
| _FUNC_ALIASES["torch", "concatenate"] = "cat" |
| _FUNC_ALIASES["torch", "conjugate"] = "conj" |
| _FUNC_ALIASES["torch", "equal"] = "eq" |
| _FUNC_ALIASES["torch", "expand_dims"] = "unsqueeze" |
| _FUNC_ALIASES["torch", "identity"] = "eye" |
| _FUNC_ALIASES["torch", "linalg.expm"] = "matrix_exp" |
| _FUNC_ALIASES["torch", "max"] = "amax" |
| _FUNC_ALIASES["torch", "min"] = "amin" |
| _FUNC_ALIASES["torch", "power"] = "pow" |
| _FUNC_ALIASES["torch", "random.normal"] = "randn" |
| _FUNC_ALIASES["torch", "random.uniform"] = "rand" |
| _FUNC_ALIASES["torch", "scipy.linalg.expm"] = "matrix_exp" |
| _FUNC_ALIASES["torch", "split"] = "tensor_split" |
| _FUNC_ALIASES["torch", "take"] = "index_select" |
| _FUNC_ALIASES["torch", "take_along_axis"] = "take_along_dim" |
|
|
| _SUBMODULE_ALIASES["torch", "linalg.expm"] = "torch" |
| _SUBMODULE_ALIASES["torch", "scipy.linalg.expm"] = "torch" |
| _SUBMODULE_ALIASES["torch", "random.normal"] = "torch" |
| _SUBMODULE_ALIASES["torch", "random.uniform"] = "torch" |
|
|
| _CUSTOM_WRAPPERS["torch", "clip"] = make_translator( |
| [("a", ("input",)), ("a_min", ("min",)), ("a_max", ("max",))] |
| ) |
| _CUSTOM_WRAPPERS["torch", "concatenate"] = make_translator( |
| [("arrays", ("tensors",)), ("axis", ("dim", 0))] |
| ) |
| _CUSTOM_WRAPPERS["torch", "diagonal"] = make_translator( |
| [("a", ("input",)), ("axis1", ("dim1", 0)), ("axis2", ("dim2", 1))] |
| ) |
| _CUSTOM_WRAPPERS["torch", "empty"] = make_translator([("shape", ("size",))]) |
| _CUSTOM_WRAPPERS["torch", "expand_dims"] = make_translator( |
| [("a", ("input",)), ("axis", ("dim",))] |
| ) |
| _CUSTOM_WRAPPERS["torch", "eye"] = torch_eye_wrap |
| _CUSTOM_WRAPPERS["torch", "flip"] = torch_flip_wrap |
| _CUSTOM_WRAPPERS["torch", "linalg.svd"] = svd_not_full_matrices_wrapper |
| _CUSTOM_WRAPPERS["torch", "ones"] = torch_zeros_ones_wrap |
| _CUSTOM_WRAPPERS["torch", "random.normal"] = scale_random_normal_manually |
| _CUSTOM_WRAPPERS["torch", "random.uniform"] = scale_random_uniform_manually |
| _CUSTOM_WRAPPERS["torch", "sort"] = torch_sort_wrap |
| _CUSTOM_WRAPPERS["torch", "stack"] = make_translator( |
| [("arrays", ("tensors",)), ("axis", ("dim", 0))] |
| ) |
| _CUSTOM_WRAPPERS["torch", "take"] = make_translator( |
| [("a", ("input",)), ("indices", ("index",)), ("axis", ("dim",))] |
| ) |
| _CUSTOM_WRAPPERS["torch", "tensordot"] = torch_tensordot_wrap |
| _CUSTOM_WRAPPERS["torch", "tril"] = make_translator( |
| [("m", ("input",)), ("k", ("diagonal", 0))] |
| ) |
| _CUSTOM_WRAPPERS["torch", "triu"] = make_translator( |
| [("m", ("input",)), ("k", ("diagonal", 0))] |
| ) |
| _CUSTOM_WRAPPERS["torch", "zeros"] = torch_zeros_ones_wrap |
| _CUSTOM_WRAPPERS["torch", "take_along_axis"] = make_translator( |
| [("arr", ("input",)), ("indices", ("indices",)), ("axis", ("dim", -1))] |
| ) |
|
|
| _torch_reduce_translation = [ |
| ("a", ("input",)), |
| ("axis", ("dim",)), |
| ("keepdims", ("keepdim",)), |
| ] |
| for f in ("sum", "max", "min", "prod", "mean", "median", "std", "var"): |
| |
| _CUSTOM_WRAPPERS["torch", f] = make_translator(_torch_reduce_translation) |
|
|
| |
| _MODULE_ALIASES["torch[alt]"] = "torch" |
|
|
| _FUNCS["torch[alt]", "linalg.eigh"] = torch_linalg_eigh |
| _FUNCS["torch[alt]", "linalg.eigvalsh"] = torch_linalg_eigvalsh |
|
|
| _SUBMODULE_ALIASES["torch[alt]", "linalg.norm"] = "torch" |
| _SUBMODULE_ALIASES["torch[alt]", "linalg.qr"] = "torch" |
| _SUBMODULE_ALIASES["torch[alt]", "linalg.solve"] = "torch" |
| _SUBMODULE_ALIASES["torch[alt]", "linalg.svd"] = "torch" |
|
|
| _CUSTOM_WRAPPERS["torch[alt]", "linalg.qr"] = qr_allow_fat |
| _CUSTOM_WRAPPERS["torch[alt]", "linalg.solve"] = torch_linalg_solve_wrap |
| _CUSTOM_WRAPPERS["torch[alt]", "linalg.svd"] = svd_UsV_to_UsVH_wrapper |
| _CUSTOM_WRAPPERS["torch[alt]", "split"] = torch_split_wrap |
|
|
| for f in _CREATION_ROUTINES: |
| register_creation_routine("torch", f, inject_device=True) |
|
|
| |
|
|
|
|
# ---------------------------------- mxnet ---------------------------------- #


def mxnet_to_numpy(x):
    return x.asnumpy()
|
|
|
|
| _MODULE_ALIASES["mxnet"] = "mxnet.numpy" |
| _FUNCS["mxnet", "to_numpy"] = mxnet_to_numpy |
|
|
|
|
| |
|
|
# ---------------------------------- paddle --------------------------------- #


_paddle_dtype_name_conversion = {
| "BOOL": "bool", |
| "INT8": "int8", |
| "INT16": "int16", |
| "INT32": "int32", |
| "INT64": "int64", |
| "FP16": "float16", |
| "FP32": "float32", |
| "FP64": "float64", |
| "COMPLEX64": "complex64", |
| "COMPLEX128": "complex128", |
| } |
|
|
|
|
| @get_dtype_name.register("paddle") |
| def paddle_get_dtype_name(x): |
| return _paddle_dtype_name_conversion[x.dtype.name] |
|
|
|
|
| @shape.register("paddle") |
| def paddle_shape(x): |
| |
| return tuple(x.shape) |
|
|
|
|
| def paddle_to_numpy(x): |
| return x.numpy() |
|
|
|
|
| def paddle_transpose(a, axes=None): |
| if axes is None: |
| axes = tuple(range(a.ndim - 1, -1, -1)) |
| return a.transpose(perm=axes) |
|
|
|
|
def paddle_real(x):
    # guard against tensors without ``is_complex``
    try:
        if x.is_complex():
            return x.real()
    except AttributeError:
        pass
    return x
|
|
|
|
def paddle_imag(x):
    # guard against tensors without ``is_complex``
    try:
        if x.is_complex():
            return x.imag()
    except AttributeError:
        pass
    # the imaginary part of a real tensor is simply zero
    return do("zeros_like", x)
|
|
|
|
| def paddle_indices(dimensions): |
| _meshgrid = get_lib_fn("paddle", "meshgrid") |
| _arange = get_lib_fn("paddle", "arange") |
| return _meshgrid(*map(_arange, dimensions), indexing="ij") |
|
|
|
|
| def paddle_ravel(x): |
| return x.reshape((-1,)) |
|
|
|
|
def paddle_pad(array, pad_width, mode="constant", constant_values=0):
    if mode != "constant":
        raise NotImplementedError

    try:
        # flatten numpy style ((dim0_low, dim0_high), ...) pairs - paddle
        # takes a flat sequence, in first-to-last dimension order
        pad = tuple(itertools.chain.from_iterable(pad_width))

        # a single pair was specified - use it for all dimensions
        if len(pad) == 2:
            pad = pad * array.ndim

    except TypeError:
        # assume a single int was given - pad all dimensions symmetrically
        pad = (pad_width,) * 2 * array.ndim

    return do(
        "nn.functional.pad",
        array,
        pad=pad,
        mode=mode,
        value=constant_values,
        like="paddle",
    )
|
|
|
|
def paddle_wrap_reduction(fn):
    # translate numpy's ``keepdims`` to paddle's ``keepdim``
    def numpy_like(*args, **kwargs):
        keepdims = kwargs.pop("keepdims", None)
        if keepdims is not None:
            kwargs["keepdim"] = keepdims
        return fn(*args, **kwargs)

    return numpy_like
|
|
|
|
def paddle_split_wrap(fn):
    # paddle takes a number of sections or explicit chunk sizes, rather than
    # numpy style split indices
    @functools.wraps(fn)
    def numpy_like(ary, indices_or_sections, axis=0, **kwargs):
        if isinstance(indices_or_sections, int):
            return fn(ary, indices_or_sections, axis=axis, **kwargs)
        else:
            # convert split indices into chunk sizes via their differences
            diff = do(
                "diff",
                indices_or_sections,
                prepend=0,
                append=shape(ary)[axis],
                like="numpy",
            )
            diff = list(diff)
            return fn(ary, diff, axis=axis)

    return numpy_like
|
|
|
|
| _MODULE_ALIASES["paddle[alt]"] = "paddle" |
|
|
| _FUNCS["paddle", "to_numpy"] = paddle_to_numpy |
| _FUNCS["paddle", "transpose"] = paddle_transpose |
| _FUNCS["paddle", "real"] = paddle_real |
| _FUNCS["paddle", "imag"] = paddle_imag |
| _FUNCS["paddle", "indices"] = paddle_indices |
| _FUNCS["paddle", "ravel"] = paddle_ravel |
| _FUNCS["paddle", "pad"] = paddle_pad |
|
|
| _FUNC_ALIASES["paddle", "random.normal"] = "randn" |
| _FUNC_ALIASES["paddle", "random.uniform"] = "rand" |
| _FUNC_ALIASES["paddle", "asarray"] = "to_tensor" |
| _FUNC_ALIASES["paddle", "concatenate"] = "concat" |
| _FUNC_ALIASES["paddle", "power"] = "pow" |
| _FUNC_ALIASES["paddle", "identity"] = "eye" |
| _FUNC_ALIASES["paddle", "split"] = "tensor_split" |
|
|
| _SUBMODULE_ALIASES["paddle", "random.normal"] = "paddle" |
| _SUBMODULE_ALIASES["paddle", "random.uniform"] = "paddle" |
|
|
| _CUSTOM_WRAPPERS["paddle", "random.normal"] = scale_random_normal_manually |
| _CUSTOM_WRAPPERS["paddle", "random.uniform"] = scale_random_uniform_manually |
| _CUSTOM_WRAPPERS["paddle[alt]", "split"] = paddle_split_wrap |
| _CUSTOM_WRAPPERS["paddle", "tril"] = make_translator( |
| [("m", ("x",)), ("k", ("diagonal", 0))] |
| ) |
| _CUSTOM_WRAPPERS["paddle", "triu"] = make_translator( |
| [("m", ("x",)), ("k", ("diagonal", 0))] |
| ) |
| for f in ("sum", "max", "min", "prod", "mean", "std", "var"): |
| _CUSTOM_WRAPPERS["paddle", f] = paddle_wrap_reduction |
|
|
|
|
| |
|
|
|
|
| @shape.register("pytensor") |
| def pytensor_shape(x): |
| return x.type.shape |
|
|
|
|
def pytensor_wrap_qr_with_shapes(fn):
    import pytensor.tensor as pt

    @functools.wraps(fn)
    def qr_shaped(x, **kwargs):
        # pytensor's qr does not propagate static shape information - compute
        # the reduced-mode output shapes and pin them explicitly
        *b, m, n = x.type.shape
        k = min(m, n)
        q, r = fn(x, mode="economic", **kwargs)
        q = pt.specify_shape(q, (*b, m, k))
        r = pt.specify_shape(r, (*b, k, n))
        return q, r

    return qr_shaped
|
|
|
|
| def pytensor_wrap_svd_with_shapes(fn): |
| import pytensor.tensor as pt |
|
|
| @functools.wraps(fn) |
| def svd_shaped(x, full_matrices=False, **kwargs): |
| *b, m, n = x.type.shape |
|
|
| k = min(m, n) |
| if full_matrices: |
| u_shape = (*b, m, m) |
| vh_shape = (*b, n, n) |
| else: |
| u_shape = (*b, m, k) |
| vh_shape = (*b, k, n) |
|
|
| u, s, vh = fn(x, full_matrices=full_matrices, **kwargs) |
| u = pt.specify_shape(u, u_shape) |
| s = pt.specify_shape(s, (*b, k)) |
| vh = pt.specify_shape(vh, vh_shape) |
| return u, s, vh |
|
|
| return svd_shaped |
|
|
|
|
| _MODULE_ALIASES["pytensor"] = "pytensor.tensor" |
| _CUSTOM_WRAPPERS["pytensor", "linalg.qr"] = pytensor_wrap_qr_with_shapes |
| _CUSTOM_WRAPPERS["pytensor", "linalg.svd"] = pytensor_wrap_svd_with_shapes |
|
|