File size: 2,718 Bytes
6a22ec9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 | # Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Version utils for testing."""
from __future__ import annotations

import functools
import warnings
from typing import Callable, Sequence

import packaging.version
def _release_older_than(installed: str, required: str) -> bool:
    """Return True if the release of ``installed`` is strictly older than ``required``.

    Only the release segment (``packaging``'s ``Version.release``) of each
    string is compared, so epoch, pre-/post-/dev-release and local labels are
    ignored: ``"2.1.0.dev20230601+cu118"`` compares as ``2.1.0``.

    Args:
        installed: version string of the installed package.
        required: version string to compare against.

    Returns:
        True when ``installed`` has a strictly lower release than ``required``.

    Errors raised by ``packaging.version.parse`` for malformed version strings
    propagate unchanged.
    """
    installed_release = packaging.version.parse(installed).release
    required_release = packaging.version.parse(required).release
    # Pad the shorter tuple with zeros before comparing: a raw tuple comparison
    # reports (1, 16) < (1, 16, 0), i.e. "1.16" as older than "1.16.0".
    width = max(len(installed_release), len(required_release))
    installed_release += (0,) * (width - len(installed_release))
    required_release += (0,) * (width - len(required_release))
    return installed_release < required_release


def onnx_older_than(version: str) -> bool:
    """Returns True if the ONNX version is older than the given version.

    Raises:
        ImportError: if ``onnx`` cannot be imported.
    """
    import onnx  # pylint: disable=import-outside-toplevel

    return _release_older_than(onnx.__version__, version)


def torch_older_than(version: str) -> bool:
    """Returns True if the torch version is older than the given version.

    Raises:
        ImportError: if ``torch`` cannot be imported.
    """
    import torch  # pylint: disable=import-outside-toplevel

    return _release_older_than(torch.__version__, version)


def transformers_older_than(version: str) -> bool | None:
    """Returns True if the transformers version is older than the given version.

    Unlike the other helpers in this module, a missing package is not an
    error here: None is returned when ``transformers`` cannot be imported.
    """
    try:
        import transformers  # pylint: disable=import-outside-toplevel
    except ImportError:
        return None
    return _release_older_than(transformers.__version__, version)


def onnxruntime_older_than(version: str) -> bool:
    """Returns True if the onnxruntime version is older than the given version.

    Raises:
        ImportError: if ``onnxruntime`` cannot be imported.
    """
    import onnxruntime  # pylint: disable=import-outside-toplevel

    return _release_older_than(onnxruntime.__version__, version)


def numpy_older_than(version: str) -> bool:
    """Returns True if the numpy version is older than the given version.

    Raises:
        ImportError: if ``numpy`` cannot be imported.
    """
    import numpy  # pylint: disable=import-outside-toplevel

    return _release_older_than(numpy.__version__, version)
def has_transformers() -> bool:
    """Report whether the optional ``transformers`` package is importable.

    Returns:
        False when ``import transformers`` raises ImportError, True otherwise.
    """
    try:
        import transformers  # pylint: disable=import-outside-toplevel
    except ImportError:
        return False
    # Touch the module so linters do not flag the import as unused.
    assert transformers
    return True
def ignore_warnings(warns: type[Warning] | Sequence[type[Warning]]) -> Callable:
    """Build a decorator that silences the given warning categories.

    The decorated callable runs inside :func:`warnings.catch_warnings`, so the
    "ignore" filters installed for it are discarded once it returns and the
    caller's warning filters are left unchanged.

    Args:
        warns: a warning class, or a sequence of warning classes, to ignore
            while the decorated callable runs.

    Returns:
        A decorator. The function it returns forwards all positional and
        keyword arguments to the decorated callable and returns its result.

    Raises:
        AssertionError: when the decorator is applied while ``warns`` is None.
    """

    def wrapper(fct):
        if warns is None:
            raise AssertionError(f"warns cannot be None for '{fct}'.")
        # Normalize to a tuple of classes: warning filters are matched with
        # issubclass(), which accepts a class or a tuple but rejects a list.
        categories = (warns,) if isinstance(warns, type) else tuple(warns)

        @functools.wraps(fct)
        def call_f(*args, **kwargs):
            with warnings.catch_warnings():
                for category in categories:
                    warnings.simplefilter("ignore", category)
                return fct(*args, **kwargs)

        return call_f

    return wrapper
|