| | import pytest |
| | from unittest import mock |
| |
|
| | import torch |
| |
|
| | from mlagents.torch_utils import set_torch_config, default_device |
| | from mlagents.trainers.settings import TorchSettings |
| |
|
| |
|
@pytest.mark.parametrize(
    "device_str, expected_type, expected_index, expected_tensor_type",
    [
        ("cpu", "cpu", None, torch.FloatTensor),
        ("cuda", "cuda", None, torch.cuda.FloatTensor),
        ("cuda:42", "cuda", 42, torch.cuda.FloatTensor),
        ("opengl", "opengl", None, torch.FloatTensor),
    ],
)
@mock.patch.object(torch, "set_default_tensor_type")
def test_set_torch_device(
    mock_set_default_tensor_type,
    device_str,
    expected_type,
    expected_index,
    expected_tensor_type,
):
    """Check that ``set_torch_config`` applies the requested device.

    For each ``device_str``, build ``TorchSettings(device=device_str)``, apply
    it, and verify that:

    * ``default_device()`` reports the expected device type and index
      (``None`` when the string carries no explicit ordinal);
    * ``torch.set_default_tensor_type`` was called exactly once with the
      expected tensor class. Per the parametrization, CUDA devices map to
      ``torch.cuda.FloatTensor`` and every other device to
      ``torch.FloatTensor``.

    ``torch.set_default_tensor_type`` is patched, so the real default tensor
    type is never changed by this test.

    Args:
        mock_set_default_tensor_type: Mock injected by ``mock.patch.object``.
        device_str: Device string passed to ``TorchSettings``.
        expected_type: Expected ``torch.device.type`` of ``default_device()``.
        expected_index: Expected ``torch.device.index`` (``None`` if absent).
        expected_tensor_type: Tensor class expected as the default type.
    """
    try:
        set_torch_config(TorchSettings(device=device_str))
        device = default_device()
        assert device.type == expected_type
        # ``==`` handles both the ``None`` and the integer-ordinal cases.
        assert device.index == expected_index
        mock_set_default_tensor_type.assert_called_once_with(expected_tensor_type)
    finally:
        # Always reapply ``device=None`` settings, even if an assertion failed,
        # so the device configured above does not leak into other tests
        # (presumably ``set_torch_config`` mutates process-wide state).
        set_torch_config(TorchSettings(device=None))
| |
|