r"""
This package introduces support for the current :ref:`accelerator<accelerators>` in Python.
"""

from typing import Optional
from typing_extensions import deprecated
import torch
from ._utils import _device_t, _get_device_index

__all__ = [
"current_accelerator",
"current_device_idx", # deprecated
"current_device_index",
"current_stream",
"device_count",
"device_index",
"is_available",
"set_device_idx", # deprecated
"set_device_index",
"set_stream",
"synchronize",
]


def device_count() -> int:
    r"""Return the number of devices available for the current :ref:`accelerator<accelerators>`.

    Returns:
        int: the number of devices available for the current :ref:`accelerator<accelerators>`.
            If there is no available accelerator, return 0.

    .. note:: This API delegates to the device-specific version of `device_count`.
        On CUDA, this API will NOT poison fork if NVML discovery succeeds.
        Otherwise, it will. For more details, see :ref:`multiprocessing-poison-fork-note`.
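
    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Illustrative sketch (not from the original docs): the count depends on the host.
        >>> num_devices = torch.accelerator.device_count()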
"""
acc = current_accelerator()
if acc is None:
return 0
mod = torch.get_device_module(acc)
return mod.device_count()


def is_available() -> bool:
    r"""Check if the current accelerator is available at runtime: it was built, all the
    required drivers are available and at least one device is visible.
    See :ref:`accelerator<accelerators>` for details.

    Returns:
        bool: A boolean indicating if there is an available :ref:`accelerator<accelerators>`.

    .. note:: This API delegates to the device-specific version of `is_available`.
        On CUDA, when the environment variable ``PYTORCH_NVML_BASED_CUDA_CHECK=1`` is set,
        this function will NOT poison fork. Otherwise, it will. For more details, see
        :ref:`multiprocessing-poison-fork-note`.

    Example::

        >>> assert torch.accelerator.is_available(), "No available accelerators detected."
    """
    # Why not just check "device_count() > 0" like other is_available calls?
    # Because devices like CUDA have a Python implementation of is_available that is
    # non-poisoning, and some features like DataLoader rely on it.
    # So we are careful to delegate to the Python version of the accelerator here.
acc = current_accelerator()
if acc is None:
return False
mod = torch.get_device_module(acc)
return mod.is_available()


def current_accelerator(check_available: bool = False) -> Optional[torch.device]:
    r"""Return the device of the accelerator available at compilation time.
    If no accelerator was available at compilation time, return None.
    See :ref:`accelerator<accelerators>` for details.

    Args:
        check_available (bool, optional): if True, also run a runtime check to see
            if the device is available via :func:`torch.accelerator.is_available`, on top
            of the compile-time check.
            Default: ``False``

    Returns:
        torch.device: return the current accelerator as :class:`torch.device`.

    .. note:: The index of the returned :class:`torch.device` will be ``None``, please use
        :func:`torch.accelerator.current_device_index` to know the current index being used.
        This API does NOT poison fork. For more details, see :ref:`multiprocessing-poison-fork-note`.

    Example::

        >>> # xdoctest:
        >>> # If an accelerator is available, send the model to it
        >>> model = torch.nn.Linear(2, 2)
        >>> if (current_device := current_accelerator(check_available=True)) is not None:
        >>>     model.to(current_device)
    """
if (acc := torch._C._accelerator_getAccelerator()) is not None:
if (not check_available) or (check_available and is_available()):
return acc
return None


def current_device_index() -> int:
    r"""Return the index of a currently selected device for the current :ref:`accelerator<accelerators>`.

    Returns:
        int: the index of a currently selected device.
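
    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Illustrative sketch (not from the original docs): query the current index.
        >>> idx = torch.accelerator.current_device_index()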
"""
return torch._C._accelerator_getDeviceIndex()


current_device_idx = deprecated(
"Use `current_device_index` instead.",
category=FutureWarning,
)(current_device_index)


def set_device_index(device: _device_t, /) -> None:
    r"""Set the current device index to a given device.

    Args:
        device (:class:`torch.device`, str, int): a given device that must match the current
            :ref:`accelerator<accelerators>` device type.

    .. note:: This function is a no-op if this device index is negative.
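
    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Illustrative sketch (not from the original docs): select device 0.
        >>> torch.accelerator.set_device_index(0)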
"""
device_index = _get_device_index(device, optional=False)
torch._C._accelerator_setDeviceIndex(device_index)


set_device_idx = deprecated(
"Use `set_device_index` instead.",
category=FutureWarning,
)(set_device_index)


def current_stream(device: _device_t = None, /) -> torch.Stream:
    r"""Return the currently selected stream for a given device.

    Args:
        device (:class:`torch.device`, str, int, optional): a given device that must match the current
            :ref:`accelerator<accelerators>` device type. If not given,
            use :func:`torch.accelerator.current_device_index` by default.

    Returns:
        torch.Stream: the currently selected stream for a given device.
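
    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Illustrative sketch (not from the original docs): stream for the current device.
        >>> stream = torch.accelerator.current_stream()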
"""
device_index = _get_device_index(device, optional=True)
return torch._C._accelerator_getStream(device_index)


def set_stream(stream: torch.Stream) -> None:
    r"""Set the current stream to a given stream.

    Args:
        stream (torch.Stream): a given stream that must match the current :ref:`accelerator<accelerators>` device type.

    .. note:: This function will set the current device index to the device index of the given stream.
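
    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Illustrative sketch (not from the original docs): re-installing the current
        >>> # stream is a harmless round-trip.
        >>> s = torch.accelerator.current_stream()
        >>> torch.accelerator.set_stream(s)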
"""
torch._C._accelerator_setStream(stream)


def synchronize(device: _device_t = None, /) -> None:
    r"""Wait for all kernels in all streams on the given device to complete.

    Args:
        device (:class:`torch.device`, str, int, optional): device for which to synchronize. It must match
            the current :ref:`accelerator<accelerators>` device type. If not given,
            use :func:`torch.accelerator.current_device_index` by default.

    .. note:: This function is a no-op if the current :ref:`accelerator<accelerators>` is not initialized.

    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> assert torch.accelerator.is_available(), "No available accelerators detected."
        >>> start_event = torch.Event(enable_timing=True)
        >>> end_event = torch.Event(enable_timing=True)
        >>> start_event.record()
        >>> tensor = torch.randn(100, device=torch.accelerator.current_accelerator())
        >>> total = torch.sum(tensor)
        >>> end_event.record()
        >>> torch.accelerator.synchronize()
        >>> elapsed_time_ms = start_event.elapsed_time(end_event)
    """
device_index = _get_device_index(device, optional=True)
torch._C._accelerator_synchronizeDevice(device_index)


class device_index:
    r"""Context manager to set the current device index for the current :ref:`accelerator<accelerators>`.

    Temporarily changes the current device index to the specified value for the duration
    of the context, and automatically restores the previous device index when exiting
    the context.

    Args:
        device (Optional[int]): a given device index to temporarily set. If None,
            no device index switching occurs.

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Set device 0 as the current device temporarily
        >>> with torch.accelerator.device_index(0):
        ...     # Code here runs with device 0 as the current device
        ...     pass
        >>> # Original device is now restored
        >>> # No-op when None is passed
        >>> with torch.accelerator.device_index(None):
        ...     # No device switching occurs
        ...     pass
    """

    def __init__(self, device: Optional[int], /) -> None:
        self.idx = device
        self.prev_idx = -1

    def __enter__(self) -> None:
        if self.idx is not None:
            self.prev_idx = torch._C._accelerator_exchangeDevice(self.idx)

    def __exit__(self, *exc_info: object) -> None:
        if self.idx is not None:
            torch._C._accelerator_maybeExchangeDevice(self.prev_idx)