import torch
import functools
def get_default_device():
    """Return the preferred torch.device, trying CUDA, then MPS, then CPU.

    Preference order: "cuda" if a CUDA device is available, otherwise
    "mps" if the MPS backend is both built into this torch install and
    usable at runtime, otherwise "cpu".
    """
    mps_ok = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    name = "cuda" if torch.cuda.is_available() else ("mps" if mps_ok else "cpu")
    return torch.device(name)
def safe_autocast_decorator(enabled=True):
    """Decorator form of ``safe_autocast``.

    Wraps a function so its body runs under ``torch.amp.autocast`` on
    backends that support it (cuda/cpu); on other backends (e.g. MPS)
    the function is called without autocast.

    Args:
        enabled: forwarded to ``torch.amp.autocast``'s ``enabled`` flag.

    Returns:
        A decorator preserving the wrapped function's metadata via
        ``functools.wraps``.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Delegate to the context-manager twin so the backend-support
            # check lives in exactly one place (was duplicated here).
            with safe_autocast(enabled=enabled):
                return func(*args, **kwargs)
        return wrapper
    return decorator
import contextlib
@contextlib.contextmanager
def safe_autocast(enabled=True):
    """Context manager that enables torch autocast where supported.

    On cuda/cpu devices the enclosed block runs under
    ``torch.amp.autocast(device_type=..., enabled=enabled)``; on any
    other backend (e.g. MPS) the block runs as a plain passthrough,
    since autocast is skipped there.
    """
    dev_type = get_default_device().type
    # Unsupported backends get a do-nothing context instead of autocast.
    autocast_ctx = (
        torch.amp.autocast(device_type=dev_type, enabled=enabled)
        if dev_type in ("cuda", "cpu")
        else contextlib.nullcontext()
    )
    with autocast_ctx:
        yield