| """Adapted from: |
| |
| https://github.com/PyTorchLightning/pytorch-lightning/blob/master/tests/helpers/runif.py |
| """ |
|
|
| import sys |
| from typing import Any, Dict, Optional |
|
|
| import pytest |
| import torch |
| from packaging.version import Version |
| from pkg_resources import get_distribution |
| from pytest import MarkDecorator |
|
|
| from tests.helpers.package_available import ( |
| _COMET_AVAILABLE, |
| _DEEPSPEED_AVAILABLE, |
| _FAIRSCALE_AVAILABLE, |
| _IS_WINDOWS, |
| _MLFLOW_AVAILABLE, |
| _NEPTUNE_AVAILABLE, |
| _SH_AVAILABLE, |
| _TPU_AVAILABLE, |
| _WANDB_AVAILABLE, |
| ) |
|
|
|
|
class RunIf:
    """RunIf wrapper for conditional skipping of tests.

    Fully compatible with `@pytest.mark`.

    Example:

    ```python
    @RunIf(min_torch="1.8")
    @pytest.mark.parametrize("arg1", [1.0, 2.0])
    def test_wrapper(arg1):
        assert arg1 > 0
    ```
    """

    def __new__(
        cls,
        min_gpus: int = 0,
        min_torch: Optional[str] = None,
        max_torch: Optional[str] = None,
        min_python: Optional[str] = None,
        skip_windows: bool = False,
        sh: bool = False,
        tpu: bool = False,
        fairscale: bool = False,
        deepspeed: bool = False,
        wandb: bool = False,
        neptune: bool = False,
        comet: bool = False,
        mlflow: bool = False,
        **kwargs: Any,
    ) -> MarkDecorator:
        """Creates a new `@RunIf` `MarkDecorator` decorator.

        Only the requirements that are explicitly requested are evaluated. The
        test is skipped when at least one requested requirement is not met, and
        the skip reason lists exactly the unmet requirements.

        :param min_gpus: Min number of GPUs required to run test.
        :param min_torch: Minimum pytorch version to run test (inclusive).
        :param max_torch: Maximum pytorch version to run test (exclusive).
        :param min_python: Minimum python version required to run test.
        :param skip_windows: Skip test for Windows platform.
        :param sh: If `sh` module is required to run the test.
        :param tpu: If TPU is available.
        :param fairscale: If `fairscale` module is required to run the test.
        :param deepspeed: If `deepspeed` module is required to run the test.
        :param wandb: If `wandb` module is required to run the test.
        :param neptune: If `neptune` module is required to run the test.
        :param comet: If `comet` module is required to run the test.
        :param mlflow: If `mlflow` module is required to run the test.
        :param kwargs: Native `pytest.mark.skipif` keyword arguments. `condition`
            and `reason` are supplied by this wrapper and must not be passed here.
        :return: A `pytest.mark.skipif` marker; note that no `RunIf` instance is
            ever created because `__new__` returns the marker directly.
        """
        # Pairs of (should_skip, label); keeping them together avoids two
        # parallel lists that have to stay index-aligned.
        checks = []

        if min_gpus:
            checks.append((torch.cuda.device_count() < min_gpus, f"GPUs>={min_gpus}"))

        if min_torch or max_torch:
            # Resolve the installed torch version once for both bounds.
            installed_torch = Version(get_distribution("torch").version)
            if min_torch:
                checks.append((installed_torch < Version(min_torch), f"torch>={min_torch}"))
            if max_torch:
                # The upper bound is exclusive: torch == max_torch is skipped.
                checks.append((installed_torch >= Version(max_torch), f"torch<{max_torch}"))

        if min_python:
            py = sys.version_info
            running_python = Version(f"{py.major}.{py.minor}.{py.micro}")
            checks.append((running_python < Version(min_python), f"python>={min_python}"))

        if skip_windows:
            checks.append((_IS_WINDOWS, "does not run on Windows"))

        # (requested, available, label) for every optional dependency; the order
        # determines the order of labels in the skip reason.
        optional_requirements = (
            (tpu, _TPU_AVAILABLE, "TPU"),
            (sh, _SH_AVAILABLE, "sh"),
            (fairscale, _FAIRSCALE_AVAILABLE, "fairscale"),
            (deepspeed, _DEEPSPEED_AVAILABLE, "deepspeed"),
            (wandb, _WANDB_AVAILABLE, "wandb"),
            (neptune, _NEPTUNE_AVAILABLE, "neptune"),
            (comet, _COMET_AVAILABLE, "comet"),
            (mlflow, _MLFLOW_AVAILABLE, "mlflow"),
        )
        for requested, available, label in optional_requirements:
            if requested:
                checks.append((not available, label))

        # Only requirements that actually failed are reported to the user.
        unmet = [label for should_skip, label in checks if should_skip]
        return pytest.mark.skipif(
            condition=bool(unmet),
            reason=f"Requires: [{' + '.join(unmet)}]",
            **kwargs,
        )
|
|