|
|
from typing import Any |
|
|
import warnings |
|
|
import sys |
|
|
from functools import lru_cache as _lru_cache |
|
|
from contextlib import contextmanager |
|
|
from torch.backends import ContextProp, PropModule, __allow_nonbracketed_mutation |
|
|
|
|
|
try: |
|
|
import opt_einsum as _opt_einsum |
|
|
except ImportError: |
|
|
_opt_einsum = None |
|
|
|
|
|
|
|
|
@_lru_cache()
def is_available() -> bool:
    r"""Report whether the optional ``opt_einsum`` import succeeded.

    The module-level ``_opt_einsum`` name holds the imported package, or
    ``None`` when the import raised ``ImportError``. The result is memoized
    by the ``lru_cache`` decorator, so the check runs once.

    Returns:
        ``True`` if ``_opt_einsum`` is not ``None``, ``False`` otherwise.
    """
    opt_einsum_missing = _opt_einsum is None
    return not opt_einsum_missing
|
|
|
|
|
|
|
|
def get_opt_einsum() -> Any:
    r"""Hand back the imported ``opt_einsum`` package object.

    Returns:
        The module-level ``_opt_einsum`` value: the package if its import
        succeeded, otherwise ``None``.
    """
    return _opt_einsum
|
|
|
|
|
|
|
|
def _set_enabled(_enabled: bool) -> None:
    r"""Store ``_enabled`` in the module-level ``enabled`` flag.

    Args:
        _enabled: The new value for ``enabled``.

    Raises:
        ValueError: If ``_enabled`` is truthy while opt_einsum is not
            available. Disabling is always accepted.
    """
    if not is_available():
        if _enabled:
            message = (
                f"opt_einsum is not available, so setting `enabled` to {_enabled} will not reap "
                "the benefits of calculating an optimal path for einsum. torch.einsum will "
                "fall back to contracting from left to right. To enable this optimal path "
                "calculation, please install opt-einsum."
            )
            raise ValueError(message)

    global enabled
    enabled = _enabled
|
|
|
|
|
|
|
|
def _get_enabled() -> bool:
    r"""Return the current value of the module-level ``enabled`` flag."""
    return enabled
|
|
|
|
|
|
|
|
def _set_strategy(_strategy: str) -> None:
    r"""Validate ``_strategy`` and store it in the module-level ``strategy``.

    The checks run in this order: opt_einsum must be available, the
    module-level ``enabled`` flag must be truthy, and ``_strategy`` must be
    one of ``'auto'``, ``'greedy'`` or ``'optimal'``.

    Args:
        _strategy: The new value for ``strategy``.

    Raises:
        ValueError: If any of the three checks above fails.
    """
    if not is_available():
        message = (
            f"opt_einsum is not available, so setting `strategy` to {_strategy} will not be meaningful. "
            "torch.einsum will bypass path calculation and simply contract from left to right. "
            "Please install opt_einsum or unset `strategy`."
        )
        raise ValueError(message)

    if not enabled:
        message = (
            f"opt_einsum is not enabled, so setting a `strategy` to {_strategy} will not be meaningful. "
            "torch.einsum will bypass path calculation and simply contract from left to right. "
            "Please set `enabled` to `True` as well or unset `strategy`."
        )
        raise ValueError(message)

    accepted = ("auto", "greedy", "optimal")
    if _strategy not in accepted:
        raise ValueError(
            f"`strategy` must be one of the following: [auto, greedy, optimal] but is {_strategy}"
        )

    global strategy
    strategy = _strategy
|
|
|
|
|
|
|
|
def _get_strategy() -> str:
    r"""Return the current value of the module-level ``strategy`` setting."""
    return strategy
|
|
|
|
|
|
|
|
def set_flags(_enabled=None, _strategy=None):
    r"""Update the ``enabled`` / ``strategy`` settings and return the old ones.

    Args:
        _enabled: New value for ``enabled``; ``None`` leaves it untouched.
        _strategy: New value for ``strategy``; ``None`` leaves it untouched.

    Returns:
        A ``(enabled, strategy)`` tuple captured before any change is made.
        The strategy slot is ``None`` when opt_einsum is not available.

    Raises:
        ValueError: Propagated from ``_set_enabled`` / ``_set_strategy``
            when a requested value is rejected.
    """
    # Snapshot first so callers can later restore what was in effect.
    previous = (enabled, strategy if is_available() else None)

    # `enabled` is applied before `strategy` because the strategy setter
    # refuses to run while `enabled` is falsy.
    if _enabled is not None:
        _set_enabled(_enabled)
    if _strategy is not None:
        _set_strategy(_strategy)

    return previous
|
|
|
|
|
|
|
|
@contextmanager
def flags(enabled=None, strategy=None):
    r"""Context manager that temporarily overrides the opt_einsum settings.

    Args:
        enabled: Value for the ``enabled`` flag inside the ``with`` body;
            ``None`` leaves it unchanged.
        strategy: Value for the ``strategy`` setting inside the ``with``
            body; ``None`` leaves it unchanged.

    Raises:
        ValueError: Propagated from ``set_flags`` when a requested value is
            rejected. The settings in effect before the call are restored
            before the error propagates.
    """
    # Calling set_flags() with no arguments changes nothing and returns the
    # current (enabled, strategy) pair, which gives us a snapshot up front.
    with __allow_nonbracketed_mutation():
        orig_enabled, orig_strategy = set_flags()

    try:
        # Applied inside the `try` so that a partial update (enabled accepted,
        # strategy rejected) is rolled back by the `finally` below.
        with __allow_nonbracketed_mutation():
            set_flags(enabled, strategy)
        yield
    finally:
        # Recover the previous values.
        with __allow_nonbracketed_mutation():
            # The strategy setter refuses to run while `enabled` is falsy, so
            # restore the strategy with `enabled` forced on, then put
            # `enabled` back. Restoring in the naive (enabled, strategy) order
            # raises ValueError whenever the original `enabled` was False.
            # A non-None original strategy implies opt_einsum is available
            # (set_flags reports None otherwise), so enabling here is legal.
            if orig_strategy is not None:
                set_flags(True, orig_strategy)
            set_flags(orig_enabled)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class OptEinsumModule(PropModule):
    r"""``PropModule`` subclass used as the stand-in for this module.

    NOTE(review): the ``global`` declarations in the class body make the
    assignments below bind the module-level names ``enabled`` and
    ``strategy`` instead of creating class attributes -- confirm this is
    intended.
    """

    def __init__(self, m, name):
        """Forward the wrapped module ``m`` and its ``name`` to ``PropModule``."""
        super().__init__(m, name)

    global enabled
    enabled = ContextProp(_get_enabled, _set_enabled)

    global strategy
    # The strategy property is only created when opt_einsum can be used.
    strategy = ContextProp(_get_strategy, _set_strategy) if is_available() else None
|
|
|
|
|
|
|
|
|
|
|
# Swap this module's entry in sys.modules for an OptEinsumModule that wraps
# the original module object.
sys.modules[__name__] = OptEinsumModule(sys.modules[__name__], __name__)

# Initial settings: on with the 'auto' strategy when opt_einsum is available,
# otherwise off with no strategy.
enabled = bool(is_available())
strategy = "auto" if is_available() else None
|
|
|