| | import os |
| | from enum import Enum |
| | from .device_id import DeviceId |
| |
|
| | |
| |
|
class DeviceException(Exception):
    """Exception type for device-related errors in this module."""
| |
|
class _Device:
    """Tracks which compute device (CPU or a GPU) is currently selected.

    Selection is applied by writing the ``CUDA_VISIBLE_DEVICES`` environment
    variable, so it affects the whole process, not just this object.
    """

    def __init__(self):
        # Start on the CPU until a caller explicitly selects a GPU.
        self.set(DeviceId.CPU)

    def is_gpu(self):
        ''' Returns `True` if the current device is GPU, `False` otherwise. '''
        return self.current() is not DeviceId.CPU

    def current(self):
        """Return the ``DeviceId`` most recently selected via :meth:`set`."""
        return self._current_device

    def set(self, device: DeviceId):
        """Select ``device`` as the active compute device.

        Args:
            device: A ``DeviceId`` member. ``DeviceId.CPU`` sets
                ``CUDA_VISIBLE_DEVICES`` to the empty string; any other
                member writes ``str(device.value)`` into it (presumably a
                GPU index -- confirm against ``DeviceId``).

        Returns:
            The ``device`` that was selected.

        Raises:
            DeviceException: if ``device`` is not a ``DeviceId`` member.
                Validation happens before the environment is modified.
        """
        if not isinstance(device, DeviceId):
            raise DeviceException(
                f"expected a DeviceId member, got {device!r}"
            )

        if device == DeviceId.CPU:
            # An empty CUDA_VISIBLE_DEVICES leaves no GPUs visible to CUDA.
            os.environ['CUDA_VISIBLE_DEVICES'] = ''
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = str(device.value)
            # Deferred import: torch is only needed on the GPU path.
            # NOTE(review): CUDA reads CUDA_VISIBLE_DEVICES when it is first
            # initialized; changing it afterwards in the same process likely
            # has no effect -- confirm callers select the device early.
            import torch
            # Turn off cuDNN benchmark-mode autotuning for this process.
            torch.backends.cudnn.benchmark = False

        self._current_device = device
        return device