""" Model / Layer Config singleton state
Borrowed from https://github.com/huggingface/pytorch-image-models/blob/main/timm/layers/config.py#L130
"""
import os
import warnings
from typing import Any, Optional

import torch

__all__ = [
    'is_exportable', 'is_scriptable', 'is_no_jit', 'use_fused_attn',
    'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config', 'set_fused_attn'
]

# Set to True to prefer layers built without jit optimization (includes activations)
_NO_JIT = False

# Set to True to prefer activation layers without jit optimization
# (only set/restored via set_layer_config, not otherwise read in this module)
_NO_ACTIVATION_JIT = False

# Set to True when exporting a model (e.g. to ONNX); disables features that may not
# export cleanly, such as fused attention (see use_fused_attn below)
_EXPORTABLE = False

# Set to True when intending to run torch.jit.script on a model
_SCRIPTABLE = False

# Use torch.nn.functional.scaled_dot_product_attention (fused attention) where available
_HAS_FUSED_ATTN = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
if 'TIMM_FUSED_ATTN' in os.environ:
    _USE_FUSED_ATTN = int(os.environ['TIMM_FUSED_ATTN'])
else:
    _USE_FUSED_ATTN = 1  # 0 == off, 1 == on, 2 == on (experimental)
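
# Illustrative note (not part of the borrowed source): the flag above can be steered from
# the shell before Python starts, since the environment is read at import time, e.g.
#   TIMM_FUSED_ATTN=0 python train.py   # force fused attention off
#   TIMM_FUSED_ATTN=2 python train.py   # opt in to experimental fused-attn paths
# `train.py` is a hypothetical entry point; any script importing this module behaves the same.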


def is_no_jit():
    return _NO_JIT


class set_no_jit:
    def __init__(self, mode: bool) -> None:
        global _NO_JIT
        self.prev = _NO_JIT
        _NO_JIT = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _NO_JIT
        _NO_JIT = self.prev
        return False
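
# Illustrative sketch (not part of the borrowed source): the set_* classes double as
# one-shot setters and context managers; the previous flag value is restored on exit, e.g.
#   with set_no_jit(True):
#       block = build_block()   # hypothetical constructor; layers created here see _NO_JIT=True
#   # _NO_JIT reverts to its prior value here
# The same pattern applies to set_exportable and set_scriptable below.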


def is_exportable():
    return _EXPORTABLE


class set_exportable:
    def __init__(self, mode: bool) -> None:
        global _EXPORTABLE
        self.prev = _EXPORTABLE
        _EXPORTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _EXPORTABLE
        _EXPORTABLE = self.prev
        return False


def is_scriptable():
    return _SCRIPTABLE


class set_scriptable:
    def __init__(self, mode: bool) -> None:
        global _SCRIPTABLE
        self.prev = _SCRIPTABLE
        _SCRIPTABLE = mode

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        _SCRIPTABLE = self.prev
        return False


class set_layer_config:
    """ Layer config context manager that allows setting all layer config flags at once.
    If a flag arg is None, it will not change the current value.
    """
    def __init__(
            self,
            scriptable: Optional[bool] = None,
            exportable: Optional[bool] = None,
            no_jit: Optional[bool] = None,
            no_activation_jit: Optional[bool] = None):
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT
        if scriptable is not None:
            _SCRIPTABLE = scriptable
        if exportable is not None:
            _EXPORTABLE = exportable
        if no_jit is not None:
            _NO_JIT = no_jit
        if no_activation_jit is not None:
            _NO_ACTIVATION_JIT = no_activation_jit

    def __enter__(self) -> None:
        pass

    def __exit__(self, *args: Any) -> bool:
        global _SCRIPTABLE
        global _EXPORTABLE
        global _NO_JIT
        global _NO_ACTIVATION_JIT
        _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev
        return False
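
# Illustrative sketch (not part of the borrowed source): scope several flags at once, e.g.
#   with set_layer_config(scriptable=True, no_jit=True):
#       model = build_model()   # hypothetical constructor; all flags revert when the block exits
# Flags left as None keep whatever value they already had.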


def use_fused_attn(experimental: bool = False) -> bool:
    # Fused attention requires F.scaled_dot_product_attention and is disabled while
    # exporting, since export paths may not support the fused op.
    if not _HAS_FUSED_ATTN or _EXPORTABLE:
        return False
    if experimental:
        return _USE_FUSED_ATTN > 1
    return _USE_FUSED_ATTN > 0
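
# Illustrative sketch (an assumption about typical usage, not part of the borrowed source):
# a layer would usually capture the flag once at construction time and branch in forward(), e.g.
#   self.fused_attn = use_fused_attn()
#   ...
#   if self.fused_attn:
#       x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
#   else:
#       ...  # explicit softmax(q @ k.transpose(-2, -1) * scale) @ v fallback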


def set_fused_attn(enable: bool = True, experimental: bool = False):
    global _USE_FUSED_ATTN
    if not _HAS_FUSED_ATTN:
        warnings.warn('This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.')
        return
    if experimental and enable:
        _USE_FUSED_ATTN = 2
    elif enable:
        _USE_FUSED_ATTN = 1
    else:
        _USE_FUSED_ATTN = 0
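
# Illustrative sketch (not part of the borrowed source): toggling fused attention globally, e.g.
#   set_fused_attn(False)                     # fall back to unfused attention everywhere
#   set_fused_attn(True, experimental=True)   # opt in to experimental fused paths (flag == 2)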