| import pytest | |
| import torch | |
| # This fixture ensures the torch defaults don't get left in modified states between | |
| # tests (e.g., when a test fails before restoring the original value), which | |
| # can cause subsequent tests to fail. | |
| def reset_torch_defaults(): | |
| orig_default_device = torch.get_default_device() | |
| orig_default_dtype = torch.get_default_dtype() | |
| yield | |
| torch.set_default_dtype(orig_default_dtype) | |
| torch.set_default_device(orig_default_device) | |