| import torch | |
| import functools | |
def get_default_device():
    """Return the best available torch device: CUDA first, then MPS, then CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_usable = torch.backends.mps.is_built() and torch.backends.mps.is_available()
    return torch.device("mps" if mps_usable else "cpu")
def safe_autocast_decorator(enabled=True):
    """Decorator factory that runs the wrapped function under torch.amp.autocast.

    Autocast is applied only on device types it supports here ("cuda", "cpu");
    on other backends (e.g. MPS) the function is called directly without it.

    Args:
        enabled: Passed through to torch.amp.autocast's ``enabled`` flag.

    Returns:
        A decorator preserving the wrapped function's metadata.
    """
    def decorator(func):
        # Fix: functools.wraps preserves func's __name__/__doc__/__wrapped__,
        # which the original wrapper discarded.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            device = get_default_device()
            if device.type in ("cuda", "cpu"):
                with torch.amp.autocast(device_type=device.type, enabled=enabled):
                    return func(*args, **kwargs)
            # MPS and other backends are not supported by this autocast path.
            return func(*args, **kwargs)
        return wrapper
    return decorator
| import contextlib | |
@contextlib.contextmanager
def safe_autocast(enabled=True):
    """Context manager applying torch.amp.autocast on supported device types.

    Bug fix: the original was a bare generator function; without the
    ``@contextlib.contextmanager`` decorator, ``with safe_autocast():``
    raised a TypeError because a generator is not a context manager.

    Args:
        enabled: Passed through to torch.amp.autocast's ``enabled`` flag.
    """
    device = get_default_device()
    if device.type in ("cuda", "cpu"):
        with torch.amp.autocast(device_type=device.type, enabled=enabled):
            yield
    else:
        # MPS or other unsupported backends skip autocast entirely.
        yield