| import torch | |
| from typing import Callable | |
def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
    """Wrap an AMP decorator so callers do not have to name a device.

    Args:
        dec: The underlying decorator (or decorator factory) to delegate to.
        cuda_amp_deprecated: When true, ``device_type="cuda"`` is supplied to
            ``dec`` on every call unless the caller already passed a
            ``device_type`` keyword. When false, arguments are forwarded
            unchanged.

    Returns:
        A callable that forwards its positional and keyword arguments to
        ``dec`` and returns whatever ``dec`` returns.
    """
    # Imported locally so this function does not depend on the module's
    # top-level import list.
    import functools

    @functools.wraps(dec)  # keep dec's __name__/__doc__ for introspection
    def decorator(*args, **kwargs):
        if cuda_amp_deprecated:
            # Default rather than overwrite: an explicit device_type from the
            # caller must not be silently replaced.
            kwargs.setdefault("device_type", "cuda")
        return dec(*args, **kwargs)

    return decorator
# Pick the AMP decorator pair based on what the installed torch exposes.
# `deprecated` is True when torch.amp provides custom_fwd; in that case the
# wrappers below pass device_type="cuda" to it. Otherwise the pair from
# torch.cuda.amp is used and arguments are forwarded unchanged.
deprecated = hasattr(torch.amp, "custom_fwd")  # type: ignore[attr-defined]

if deprecated:
    from torch.amp import custom_fwd, custom_bwd  # type: ignore[attr-defined]
else:
    from torch.cuda.amp import custom_fwd, custom_bwd

# Rebind the public names to the wrapped versions so importers of this module
# get one calling convention regardless of which branch was taken.
custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
custom_bwd = custom_amp_decorator(custom_bwd, deprecated)