import pytest
import torch
# This fixture ensures the torch defaults don't get left in modified states between
# tests (e.g., when a test fails before restoring the original value), which
# can cause subsequent tests to fail.
@pytest.fixture(autouse=True)
def reset_torch_defaults():
    """Snapshot torch's default device and dtype, then put them back afterwards.

    Declared with ``autouse=True``, so pytest applies it to every test in
    scope without the test having to request it. The statements before
    ``yield`` run as setup; the statements after it run as teardown once
    the test has finished.
    """
    # Setup: remember whatever defaults are active before the test starts.
    saved_device, saved_dtype = (
        torch.get_default_device(),
        torch.get_default_dtype(),
    )

    yield  # the test body executes here

    # Teardown: undo any change the test made to the global defaults.
    torch.set_default_dtype(saved_dtype)
    torch.set_default_device(saved_device)
|