File size: 936 Bytes
c8b42eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
"""
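Model utils config: global switch controlling use of torch.nn.functional.scaled_dot_product_attention (fused attention).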
"""

import os
import warnings

import torch

# Public API of this module; the underscore-prefixed names are private state.
__all__ = [
    "use_fused_attn",
    "set_fused_attn",
]

# Use torch.nn.functional.scaled_dot_product_attention (SDPA) where possible.
_HAS_FUSED_ATTN = hasattr(torch.nn.functional, "scaled_dot_product_attention")


def _read_fused_attn_env(name: str = "UNICEPTION_FUSED_ATTN", default: int = 1) -> int:
    """Return the integer fused-attention level read from environment variable *name*.

    Args:
        name: Environment variable to read.
        default: Value returned when the variable is not set.

    Returns:
        The variable parsed with ``int`` (surrounding whitespace is tolerated),
        or *default* when it is unset.

    Raises:
        ValueError: If the variable is set but is not an integer literal. The
            message names the variable so an import-time failure is actionable.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(
            f"Environment variable {name} must be an integer (0 == off, 1 == on), got {raw!r}"
        ) from None


_USE_FUSED_ATTN = _read_fused_attn_env()  # 0 == off, 1 == on


def use_fused_attn() -> bool:
    """Return whether to use torch.nn.functional.scaled_dot_product_attention.

    True only when the installed PyTorch provides ``scaled_dot_product_attention``
    AND the fused-attention flag is enabled (via the ``UNICEPTION_FUSED_ATTN``
    environment variable at import time, or later via ``set_fused_attn``).
    """
    # Check availability as well as the flag: the flag defaults to "on" even on
    # PyTorch builds without SDPA, and callers would then call a missing function.
    return bool(_HAS_FUSED_ATTN) and _USE_FUSED_ATTN > 0


def set_fused_attn(enable: bool = True):
    """Set whether to use torch.nn.functional.scaled_dot_product_attention.

    Args:
        enable: ``True`` to turn fused attention on, ``False`` to turn it off.

    If the installed PyTorch lacks ``scaled_dot_product_attention``, a request to
    enable is ignored with a warning; a request to disable is always honored so
    the module flag never claims fused attention is on when it cannot be used.
    """
    global _USE_FUSED_ATTN
    if enable and not _HAS_FUSED_ATTN:
        # stacklevel=2 attributes the warning to the caller rather than this helper.
        warnings.warn(
            "This version of pytorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.",
            stacklevel=2,
        )
        return
    _USE_FUSED_ATTN = 1 if enable else 0