| | import os |
| | import random |
| | import unittest |
| | from distutils.util import strtobool |
| |
|
| | import torch |
| |
|
| | from packaging import version |
| |
|
| |
|
# Module-wide random generator; `floats_tensor` falls back to it when no rng is passed.
global_rng = random.Random()

# Default device for tests: CUDA when torch reports it as available, otherwise CPU.
if torch.cuda.is_available():
    torch_device = "cuda"
else:
    torch_device = "cpu"

# Compare on the base version so local/pre-release suffixes in torch.__version__ are ignored.
is_torch_higher_equal_than_1_12 = version.parse(
    version.parse(torch.__version__).base_version
) >= version.parse("1.12")

# torch.backends.mps is only queried on torch >= 1.12; when MPS is available it
# takes precedence over the device selected above.
if is_torch_higher_equal_than_1_12 and torch.backends.mps.is_available():
    torch_device = "mps"
| |
|
| |
|
def parse_flag_from_env(key, default=False):
    """Read a boolean flag from the environment variable `key`.

    The accepted spellings (case-insensitive) are the ones historically
    accepted by `distutils.util.strtobool`:

    * true:  ``y``, ``yes``, ``t``, ``true``, ``on``, ``1``
    * false: ``n``, ``no``, ``f``, ``false``, ``off``, ``0``

    Args:
        key: Name of the environment variable to read.
        default: Value returned unchanged when `key` is not set.

    Returns:
        `default` if the variable is unset, otherwise `True` or `False`
        according to the variable's value.

    Raises:
        ValueError: If the variable is set to anything other than one of the
            accepted spellings (including the empty string).
    """
    raw = os.environ.get(key)
    if raw is None:
        # KEY isn't set: fall back to the caller-supplied default.
        return default

    # Parsed locally instead of through `distutils.util.strtobool`, because
    # `distutils` is deprecated (PEP 632) and removed in Python 3.12.
    normalized = raw.lower()
    if normalized in {"y", "yes", "t", "true", "on", "1"}:
        return True
    if normalized in {"n", "no", "f", "false", "off", "0"}:
        return False

    # KEY is set but does not spell a recognised boolean.
    raise ValueError(f"If set, {key} must be yes or no.")
| |
|
| |
|
# Read once at import time from the RUN_SLOW environment variable; unset means False.
_run_slow_tests = parse_flag_from_env(key="RUN_SLOW", default=False)
| |
|
| |
|
def floats_tensor(shape, scale=1.0, rng=None, name=None):
    """Build a contiguous float32 tensor of the given shape from random draws.

    Each element is one `rng.random()` draw multiplied by `scale`, generated
    in flat (row-major) order and then viewed as `shape`.

    Args:
        shape: Iterable of dimension sizes for the resulting tensor.
        scale: Factor applied to every random draw.
        rng: Object exposing a `random()` method; the module-level
            `global_rng` is used when this is None.
        name: Accepted but not used by this function.

    Returns:
        A `torch.float` tensor shaped like `shape`.
    """
    generator = global_rng if rng is None else rng

    # Number of scalars to draw: the product of all dimension sizes.
    element_count = 1
    for extent in shape:
        element_count *= extent

    samples = [generator.random() * scale for _ in range(element_count)]

    flat = torch.tensor(data=samples, dtype=torch.float)
    return flat.view(shape).contiguous()
| |
|
| |
|
def slow(test_case):
    """
    Decorator that flags a test as slow.

    The decorated test is skipped with the reason "test is slow" unless the module-level `_run_slow_tests` flag is
    truthy. That flag is read from the RUN_SLOW environment variable, so set RUN_SLOW to a truthy value to run slow
    tests.

    """
    skip_unless_slow_enabled = unittest.skipUnless(_run_slow_tests, "test is slow")
    return skip_unless_slow_enabled(test_case)
| |
|