| |
| |
| |
| |
| |
|
|
| import torch |
|
|
|
|
class TorchAutocast:
    """Context manager that optionally wraps ``torch.autocast``.

    Allows you to enable and disable autocast. This is especially useful
    when dealing with different architectures and clusters with different
    levels of support. When disabled, entering and exiting the context
    does nothing.

    Args:
        enabled (bool): Whether to enable torch.autocast or not.
        args: Additional args for torch.autocast.
        kwargs: Additional kwargs for torch.autocast.
    """

    def __init__(self, enabled: bool, *args, **kwargs):
        # ``None`` marks the disabled state; ``torch.autocast`` is only
        # constructed (and its arguments only used) when enabled.
        self.autocast = torch.autocast(*args, **kwargs) if enabled else None

    def __enter__(self):
        """Enter the wrapped autocast context, or do nothing when disabled.

        Raises:
            RuntimeError: If the wrapped context raises ``RuntimeError`` on
                entry. The new error names the dtype and device in use and
                is explicitly chained to the original error.
        """
        if self.autocast is None:
            return
        try:
            self.autocast.__enter__()
        except RuntimeError as err:
            device = self.autocast.device
            dtype = self.autocast.fast_dtype
            # Chain with ``from err`` so the underlying failure is reported
            # as the direct cause rather than as incidental context.
            raise RuntimeError(
                f"There was an error autocasting with dtype={dtype} device={device}\n"
                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
            ) from err

    def __exit__(self, *args, **kwargs):
        """Exit the wrapped autocast context, or do nothing when disabled.

        Always returns ``None``, so an exception raised inside the ``with``
        body is never suppressed by this wrapper.
        """
        if self.autocast is None:
            return
        self.autocast.__exit__(*args, **kwargs)
|
|