| """ |
| Common utils for testing. |
| These functions allow testing only some frameworks, not all. |
| """ |
|
|
| import logging |
| import os |
| import warnings |
| from functools import lru_cache |
| from typing import List, Tuple |
|
|
| from einops import _backends |
|
|
| __author__ = "Alex Rogozhnikov" |
|
|
|
|
| |
| logging.getLogger("tensorflow").disabled = True |
| logging.getLogger("matplotlib").disabled = True |
|
|
| FLOAT_REDUCTIONS = ("min", "max", "sum", "mean", "prod") |
|
|
|
|
| def find_names_of_all_frameworks() -> List[str]: |
| backend_subclasses = [] |
| backends = _backends.AbstractBackend.__subclasses__() |
| while backends: |
| backend = backends.pop() |
| backends += backend.__subclasses__() |
| backend_subclasses.append(backend) |
| return [b.framework_name for b in backend_subclasses] |
|
|
|
|
| ENVVAR_NAME = "EINOPS_TEST_BACKENDS" |
|
|
|
|
| def unparse_backends(backend_names: List[str]) -> Tuple[str, str]: |
| _known_backends = find_names_of_all_frameworks() |
| for backend_name in backend_names: |
| if backend_name not in _known_backends: |
| raise RuntimeError(f"Unknown framework: {backend_name}") |
| return ENVVAR_NAME, ",".join(backend_names) |
|
|
|
|
| @lru_cache(maxsize=1) |
| def parse_backends_to_test() -> List[str]: |
| if ENVVAR_NAME not in os.environ: |
| raise RuntimeError(f"Testing frameworks were not specified, env var {ENVVAR_NAME} not set") |
| parsed_backends = os.environ[ENVVAR_NAME].split(",") |
| _known_backends = find_names_of_all_frameworks() |
| for backend_name in parsed_backends: |
| if backend_name not in _known_backends: |
| raise RuntimeError(f"Unknown framework: {backend_name}") |
|
|
| return parsed_backends |
|
|
|
|
| def is_backend_tested(backend: str) -> bool: |
| """Used to skip test if corresponding backend is not tested""" |
| if backend not in find_names_of_all_frameworks(): |
| raise RuntimeError(f"Unknown framework {backend}") |
| return backend in parse_backends_to_test() |
|
|
|
|
| def collect_test_backends(symbolic=False, layers=False) -> List[_backends.AbstractBackend]: |
| """ |
| :param symbolic: symbolic or imperative frameworks? |
| :param layers: layers or operations? |
| :return: list of backends satisfying set conditions |
| """ |
| if not symbolic: |
| if not layers: |
| backend_types = [ |
| _backends.NumpyBackend, |
| _backends.JaxBackend, |
| _backends.TorchBackend, |
| _backends.TensorflowBackend, |
| _backends.OneFlowBackend, |
| _backends.PaddleBackend, |
| _backends.CupyBackend, |
| ] |
| else: |
| backend_types = [ |
| _backends.TorchBackend, |
| _backends.OneFlowBackend, |
| _backends.PaddleBackend, |
| ] |
| else: |
| if not layers: |
| backend_types = [ |
| _backends.PyTensorBackend, |
| ] |
| else: |
| backend_types = [ |
| _backends.TFKerasBackend, |
| ] |
|
|
| backend_names_to_test = parse_backends_to_test() |
| result = [] |
| for backend_type in backend_types: |
| if backend_type.framework_name not in backend_names_to_test: |
| continue |
| try: |
| result.append(backend_type()) |
| except ImportError: |
| |
| |
| warnings.warn(f"backend could not be initialized for tests: {backend_type}", stacklevel=1) |
| return result |
|
|