| | import os |
| |
|
| | from distutils.version import LooseVersion |
| | import pkg_resources |
| | from mlagents.torch_utils import cpu_utils |
| | from mlagents.trainers.settings import TorchSettings |
| | from mlagents_envs.logging_util import get_logger |
| |
|
| |
|
logger = get_logger(__name__)


def assert_torch_installed(min_version: str = "1.6.0") -> None:
    """Verify that a compatible PyTorch distribution is installed.

    Looks up the installed ``torch`` distribution through ``pkg_resources``
    and compares its version with ``min_version`` using ``LooseVersion``.

    Args:
        min_version: Oldest accepted PyTorch version. Defaults to "1.6.0".

    Raises:
        AssertionError: If ``torch`` is not installed or is older than
            ``min_version``. It is raised explicitly rather than through an
            ``assert`` statement so the check still runs under ``python -O``.
    """
    try:
        torch_pkg = pkg_resources.get_distribution("torch")
    except pkg_resources.DistributionNotFound:
        # Not installed: report it through the same error as a too-old version.
        torch_pkg = None

    if torch_pkg is None or (
        LooseVersion(torch_pkg.version) < LooseVersion(min_version)
    ):
        raise AssertionError(
            "A compatible version of PyTorch was not installed. Please visit the PyTorch homepage "
            + "(https://pytorch.org/get-started/locally/) and follow the instructions to install. "
            + f"Version {min_version} and later are supported."
        )
| |
|
| |
|
# Fail fast with an actionable message before torch is imported below.
assert_torch_installed()

# Set before torch (and the OpenMP runtime it loads) is imported. OpenMP
# runtimes read their environment variables when they initialize, which may
# already have happened once torch.set_num_threads() has run, so setting this
# afterwards may have no effect. KMP_BLOCKTIME=0 tells the Intel OpenMP runtime
# that worker threads should sleep right after a parallel region instead of
# spin-waiting.
os.environ["KMP_BLOCKTIME"] = "0"

# torch is imported only after the version check above has passed.
import torch  # noqa: E402

torch.set_num_threads(cpu_utils.get_num_threads_to_use())

# Placeholder device; set_torch_config() replaces it.
_device = torch.device("cpu")
| |
|
| |
|
def set_torch_config(torch_settings: TorchSettings) -> None:
    """Select the module-wide torch device and matching default tensor type.

    An explicit ``torch_settings.device`` is used as-is. When it is ``None``,
    CUDA is chosen if ``torch.cuda.is_available()`` and the CPU otherwise.
    The default tensor type is then set to the CUDA or CPU ``FloatTensor``
    according to the chosen device's type.
    """
    global _device

    requested = torch_settings.device
    if requested is not None:
        chosen = requested
    elif torch.cuda.is_available():
        chosen = "cuda"
    else:
        chosen = "cpu"

    _device = torch.device(chosen)

    # Any CUDA device ("cuda", "cuda:1", ...) gets the CUDA float tensor type.
    tensor_type = (
        torch.cuda.FloatTensor if _device.type == "cuda" else torch.FloatTensor
    )
    torch.set_default_tensor_type(tensor_type)
    logger.debug(f"default Torch device: {_device}")
| |
|
| |
|
| | |
nn = torch.nn

# Apply the default configuration at import time. With no explicit device,
# set_torch_config() picks CUDA when it is available and the CPU otherwise.
set_torch_config(TorchSettings(device=None))
| |
|
| |
|
def default_device() -> torch.device:
    """Return the torch device most recently selected by ``set_torch_config``."""
    return _device
| |
|