# PeiqingYang's picture
# integrate MatAnyone 1 & 2
# 77b7c2e
import torch
import functools
def get_default_device():
    """Return the preferred ``torch.device`` for this process.

    Backends are probed in priority order: CUDA first, then Apple MPS
    (which must be both built into this torch binary and available at
    runtime), and finally the CPU as the fallback.

    Returns:
        torch.device: ``cuda``, ``mps`` or ``cpu``, without a device index.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_built() and torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
def safe_autocast_decorator(enabled=True):
    """Build a decorator that runs the wrapped function under autocast.

    The device is looked up with ``get_default_device()`` on every call of
    the wrapped function, not once at decoration time. When that device is
    ``cuda`` or ``cpu`` the call executes inside ``torch.amp.autocast``;
    for any other device type (e.g. ``mps``) the function is called
    directly with no autocast context.

    Args:
        enabled: Forwarded to ``torch.amp.autocast(enabled=...)``.

    Returns:
        A decorator that preserves the wrapped function's metadata via
        ``functools.wraps``.
    """
    autocast_backends = ("cuda", "cpu")

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            backend = get_default_device().type
            if backend not in autocast_backends:
                # Other backends: call through without an autocast context.
                return func(*args, **kwargs)
            with torch.amp.autocast(device_type=backend, enabled=enabled):
                return func(*args, **kwargs)

        return wrapper

    return decorator
import contextlib
@contextlib.contextmanager
def safe_autocast(enabled=True):
    """Context manager that enters autocast only on ``cuda`` or ``cpu``.

    The device comes from ``get_default_device()``. For ``cuda`` and
    ``cpu`` the managed body runs inside ``torch.amp.autocast``; for MPS
    or any other device type the body runs with no autocast context.

    Args:
        enabled: Forwarded to ``torch.amp.autocast(enabled=...)``.
    """
    backend = get_default_device().type
    if backend in ("cuda", "cpu"):
        scope = torch.amp.autocast(device_type=backend, enabled=enabled)
    else:
        # MPS or other unsupported backends skip autocast
        scope = contextlib.nullcontext()
    with scope:
        yield