|
|
|
|
|
|
|
|
"""Version utils for testing.""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import functools
import warnings
from typing import Callable, Sequence

import packaging.version
|
|
|
|
|
|
|
|
def onnx_older_than(version: str) -> bool:
    """Returns True if the ONNX version is older than the given version.

    Only the numeric release segment (e.g. ``1.14.1``) is compared; pre-,
    post-, dev- and local-version labels are ignored. Missing trailing
    components count as zero, so ``1.14`` and ``1.14.0`` compare equal.

    Args:
        version: version string to compare against, e.g. ``"1.14"``

    Returns:
        True if the installed onnx release is strictly older than ``version``.

    Raises:
        ImportError: if onnx is not installed.
    """
    import onnx

    installed = packaging.version.parse(onnx.__version__).release
    required = packaging.version.parse(version).release
    # Plain tuple comparison treats (1, 14) as smaller than (1, 14, 0);
    # zero-pad both tuples to the same length so equal releases compare equal.
    width = max(len(installed), len(required))
    installed += (0,) * (width - len(installed))
    required += (0,) * (width - len(required))
    return installed < required
|
|
|
|
|
|
|
|
def torch_older_than(version: str) -> bool:
    """Returns True if the torch version is older than the given version.

    Only the numeric release segment (e.g. ``2.1.0``) is compared; pre-,
    post-, dev- and local-version labels (such as ``+cu118``) are ignored.
    Missing trailing components count as zero, so ``2.1`` and ``2.1.0``
    compare equal.

    Args:
        version: version string to compare against, e.g. ``"2.1"``

    Returns:
        True if the installed torch release is strictly older than ``version``.

    Raises:
        ImportError: if torch is not installed.
    """
    import torch

    installed = packaging.version.parse(torch.__version__).release
    required = packaging.version.parse(version).release
    # Plain tuple comparison treats (2, 1) as smaller than (2, 1, 0);
    # zero-pad both tuples to the same length so equal releases compare equal.
    width = max(len(installed), len(required))
    installed += (0,) * (width - len(installed))
    required += (0,) * (width - len(required))
    return installed < required
|
|
|
|
|
|
|
|
def transformers_older_than(version: str) -> bool | None:
    """Returns True if the transformers version is older than the given version.

    Only the numeric release segment (e.g. ``4.35.2``) is compared; pre-,
    post-, dev- and local-version labels are ignored. Missing trailing
    components count as zero, so ``4.35`` and ``4.35.0`` compare equal.

    Args:
        version: version string to compare against, e.g. ``"4.35"``

    Returns:
        True if the installed transformers release is strictly older than
        ``version``, False otherwise, or None if transformers cannot be
        imported.
    """
    try:
        import transformers
    except ImportError:
        return None

    installed = packaging.version.parse(transformers.__version__).release
    required = packaging.version.parse(version).release
    # Plain tuple comparison treats (4, 35) as smaller than (4, 35, 0);
    # zero-pad both tuples to the same length so equal releases compare equal.
    width = max(len(installed), len(required))
    installed += (0,) * (width - len(installed))
    required += (0,) * (width - len(required))
    return installed < required
|
|
|
|
|
|
|
|
def onnxruntime_older_than(version: str) -> bool:
    """Returns True if the onnxruntime version is older than the given version.

    Only the numeric release segment (e.g. ``1.16.3``) is compared; pre-,
    post-, dev- and local-version labels are ignored. Missing trailing
    components count as zero, so ``1.16`` and ``1.16.0`` compare equal.

    Args:
        version: version string to compare against, e.g. ``"1.16"``

    Returns:
        True if the installed onnxruntime release is strictly older than
        ``version``.

    Raises:
        ImportError: if onnxruntime is not installed.
    """
    import onnxruntime

    installed = packaging.version.parse(onnxruntime.__version__).release
    required = packaging.version.parse(version).release
    # Plain tuple comparison treats (1, 16) as smaller than (1, 16, 0);
    # zero-pad both tuples to the same length so equal releases compare equal.
    width = max(len(installed), len(required))
    installed += (0,) * (width - len(installed))
    required += (0,) * (width - len(required))
    return installed < required
|
|
|
|
|
|
|
|
def numpy_older_than(version: str) -> bool:
    """Returns True if the numpy version is older than the given version.

    Only the numeric release segment (e.g. ``1.26.4``) is compared; pre-,
    post-, dev- and local-version labels are ignored. Missing trailing
    components count as zero, so ``2.0`` and ``2.0.0`` compare equal.

    Args:
        version: version string to compare against, e.g. ``"2.0"``

    Returns:
        True if the installed numpy release is strictly older than ``version``.

    Raises:
        ImportError: if numpy is not installed.
    """
    import numpy

    installed = packaging.version.parse(numpy.__version__).release
    required = packaging.version.parse(version).release
    # Plain tuple comparison treats (2, 0) as smaller than (2, 0, 0);
    # zero-pad both tuples to the same length so equal releases compare equal.
    width = max(len(installed), len(required))
    installed += (0,) * (width - len(installed))
    required += (0,) * (width - len(required))
    return installed < required
|
|
|
|
|
|
|
|
def has_transformers():
    """Report whether the ``transformers`` package can be imported.

    Returns:
        True when ``import transformers`` succeeds, False when it raises
        ImportError.
    """
    try:
        import transformers
    except ImportError:
        return False
    else:
        # Referencing the module keeps linters from flagging an unused import.
        assert transformers
        return True
|
|
|
|
|
|
|
|
def ignore_warnings(warns: type[Warning] | Sequence[type[Warning]]) -> Callable:
    """Catches warnings.

    Builds a decorator that runs the decorated function inside
    :func:`warnings.catch_warnings` with an "ignore" filter for each given
    category. The filters are discarded when the call returns.

    Args:
        warns: a warning category (e.g. ``DeprecationWarning``) or a sequence
            of categories to ignore while the decorated function runs

    Returns:
        a decorator; the function it produces forwards all positional and
        keyword arguments to the decorated function and returns its result

    Raises:
        AssertionError: at decoration time if ``warns`` is None
    """

    def wrapper(fct):
        if warns is None:
            raise AssertionError(f"warns cannot be None for '{fct}'.")
        # Warning filters are matched with issubclass(), which accepts a class
        # or a tuple of classes but raises TypeError for a list, so install
        # one filter per category to support any sequence.
        categories = (warns,) if isinstance(warns, type) else tuple(warns)

        @functools.wraps(fct)
        def call_f(*args, **kwargs):
            with warnings.catch_warnings():
                for category in categories:
                    warnings.simplefilter("ignore", category)
                return fct(*args, **kwargs)

        return call_f

    return wrapper
|
|
|