File size: 997 Bytes
ccef021 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 | import torch
# Stack of the ``enabled`` flags of the currently entered ``LowPrecisionMode``
# contexts, innermost last. ``LowPrecisionMode.__enter__`` appends to it,
# ``LowPrecisionMode.__exit__`` pops from it, and ``is_low_precision_mode()``
# reads the last entry; an empty stack means no context is active.
# NOTE: this is a plain module-level list, so every thread in the process
# sees the same stack.
_is_low_precision_mode_stack = []
class LowPrecisionMode:
    """Context manager that sets the low-precision flag for a ``with`` block.

    Entering the context pushes ``enabled`` onto the module-level
    ``_is_low_precision_mode_stack``; leaving it pops the most recent entry.
    Contexts may be nested: the innermost one determines what
    ``is_low_precision_mode()`` reports.

    Args:
        enabled: Flag value to push while the context is active. Passing
            ``False`` lets an inner block override an enclosing enabled
            context.
    """

    def __init__(self, enabled: bool = True):
        self.enabled = enabled

    def __enter__(self) -> "LowPrecisionMode":
        """Push this context's flag and return the manager itself.

        Returning ``self`` makes ``with LowPrecisionMode() as ctx:`` bind the
        manager instead of ``None``.
        """
        # The list is mutated in place, never rebound, so no ``global`` needed.
        _is_low_precision_mode_stack.append(self.enabled)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """Pop the most recently pushed flag.

        Returns ``None``, so an exception raised inside the ``with`` block is
        not suppressed and propagates to the caller after the pop.
        """
        _is_low_precision_mode_stack.pop()
def is_low_precision_mode() -> bool:
    """Report the flag of the innermost active ``LowPrecisionMode`` context.

    Returns:
        The last entry of ``_is_low_precision_mode_stack``, or ``False`` when
        the stack is empty (no context has been entered).
    """
    # Read-only access to the module-level stack; the top entry wins.
    if _is_low_precision_mode_stack:
        return _is_low_precision_mode_stack[-1]
    return False
def optional_cast_to_bf16_and_cast_back(tensor: torch.Tensor) -> torch.Tensor:
    """Round-trip a float32 tensor through bfloat16 when low-precision mode is on.

    Args:
        tensor: Input tensor; its dtype must be ``torch.float32``.

    Returns:
        If ``is_low_precision_mode()`` is truthy, a float32 tensor obtained by
        converting ``tensor`` to ``torch.bfloat16`` and back to
        ``torch.float32``. Otherwise the input ``tensor`` object itself,
        unchanged.

    Raises:
        AssertionError: If ``tensor.dtype`` is not ``torch.float32`` (only
            when assertions are enabled).
    """
    assert tensor.dtype == torch.float32, "Input tensor must be of dtype torch.float32 for optional casting."
    # Outside low-precision mode this function is a pass-through.
    if not is_low_precision_mode():
        return tensor
    # float32 -> bfloat16 -> float32: the result keeps the float32 dtype but
    # its values have passed through the bfloat16 representation.
    return tensor.to(torch.bfloat16).to(torch.float32)
|