"""
Model Utils Config
"""
import os
import warnings
import torch
__all__ = ["use_fused_attn", "set_fused_attn"]
# Use torch.scaled_dot_product_attention where possible
_HAS_FUSED_ATTN = hasattr(torch.nn.functional, "scaled_dot_product_attention")
if "UNICEPTION_FUSED_ATTN" in os.environ:
_USE_FUSED_ATTN = int(os.environ["UNICEPTION_FUSED_ATTN"])
else:
_USE_FUSED_ATTN = 1 # 0 == off, 1 == on
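# Example override from the shell (the script name here is hypothetical):
#   UNICEPTION_FUSED_ATTN=0 python train.py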


def use_fused_attn() -> bool:
    """Return whether to use torch.nn.functional.scaled_dot_product_attention."""
    # Guard on availability as well as the flag: on older PyTorch builds the
    # default flag value of 1 must not report fused attention as usable.
    return _HAS_FUSED_ATTN and _USE_FUSED_ATTN > 0


def set_fused_attn(enable: bool = True):
    """Set whether to use torch.nn.functional.scaled_dot_product_attention."""
    global _USE_FUSED_ATTN
    if not _HAS_FUSED_ATTN:
        warnings.warn("This version of PyTorch does not have F.scaled_dot_product_attention, fused_attn flag ignored.")
        return
    _USE_FUSED_ATTN = 1 if enable else 0
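

# Minimal usage sketch, not part of the module's public API: shows how a
# caller might dispatch between the fused kernel and a manual attention path
# based on use_fused_attn(). The tensor shapes are illustrative assumptions.
if __name__ == "__main__":
    q = k = v = torch.randn(1, 4, 8, 16)  # (batch, heads, tokens, head_dim)
    if use_fused_attn():
        # Fused path: single kernel for softmax(Q K^T / sqrt(d)) V
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
    else:
        # Unfused fallback: explicit softmax(Q K^T / sqrt(d)) V
        scale = q.shape[-1] ** -0.5
        out = ((q @ k.transpose(-2, -1)) * scale).softmax(dim=-1) @ v
    print(out.shape)  # torch.Size([1, 4, 8, 16])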