r"""

This package introduces support for the current :ref:`accelerator<accelerators>` in python.

"""

from typing import Optional
from typing_extensions import deprecated

import torch

from ._utils import _device_t, _get_device_index


__all__ = [
    "current_accelerator",
    "current_device_idx",  # deprecated
    "current_device_index",
    "current_stream",
    "device_count",
    "device_index",
    "is_available",
    "set_device_idx",  # deprecated
    "set_device_index",
    "set_stream",
    "synchronize",
]


def device_count() -> int:
    r"""Return the number of current :ref:`accelerator<accelerators>` available.



    Returns:

        int: the number of the current :ref:`accelerator<accelerators>` available.

            If there is no available accelerators, return 0.



    .. note:: This API delegates to the device-specific version of `device_count`.

        On CUDA, this API will NOT poison fork if NVML discovery succeeds.

        Otherwise, it will. For more details, see :ref:`multiprocessing-poison-fork-note`.
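
    Example::

        >>> # Illustrative sketch: returns 0 on machines without any accelerator
        >>> if torch.accelerator.device_count() > 0:
        ...     print(f"{torch.accelerator.device_count()} accelerator device(s) visible")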

    """
    acc = current_accelerator()
    if acc is None:
        return 0

    mod = torch.get_device_module(acc)
    return mod.device_count()


def is_available() -> bool:
    r"""Check if the current accelerator is available at runtime: it was build, all the

    required drivers are available and at least one device is visible.

    See :ref:`accelerator<accelerators>` for details.



    Returns:

        bool: A boolean indicating if there is an available :ref:`accelerator<accelerators>`.



    .. note:: This API delegates to the device-specific version of `is_available`.

        On CUDA, when the environment variable ``PYTORCH_NVML_BASED_CUDA_CHECK=1`` is set,

        this function will NOT poison fork. Otherwise, it will. For more details, see

        :ref:`multiprocessing-poison-fork-note`.



    Example::



        >>> assert torch.accelerator.is_available() "No available accelerators detected."

    """
    # Why not just check "device_count() > 0" like other is_available calls?
    # Because devices like CUDA have a Python implementation of is_available that is
    # non-poisoning, and some features like DataLoader rely on it.
    # So we are careful to delegate to the Python version for the current accelerator here.
    acc = current_accelerator()
    if acc is None:
        return False

    mod = torch.get_device_module(acc)
    return mod.is_available()


def current_accelerator(check_available: bool = False) -> Optional[torch.device]:
    r"""Return the device of the accelerator available at compilation time.

    If no accelerator were available at compilation time, returns None.

    See :ref:`accelerator<accelerators>` for details.



    Args:

        check_available (bool, optional): if True, will also do a runtime check to see

            if the device :func:`torch.accelerator.is_available` on top of the compile-time

            check.

            Default: ``False``



    Returns:

        torch.device: return the current accelerator as :class:`torch.device`.



    .. note:: The index of the returned :class:`torch.device` will be ``None``, please use

        :func:`torch.accelerator.current_device_index` to know the current index being used.

        This API does NOT poison fork. For more details, see :ref:`multiprocessing-poison-fork-note`.



    Example::



        >>> # xdoctest:

        >>> # If an accelerator is available, sent the model to it

        >>> model = torch.nn.Linear(2, 2)

        >>> if (current_device := current_accelerator(check_available=True)) is not None:

        >>>     model.to(current_device)

    """
    if (acc := torch._C._accelerator_getAccelerator()) is not None:
        if (not check_available) or (check_available and is_available()):
            return acc
    return None


def current_device_index() -> int:
    r"""Return the index of a currently selected device for the current :ref:`accelerator<accelerators>`.



    Returns:

        int: the index of a currently selected device.
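
    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Illustrative sketch; assumes at least one accelerator device is visible
        >>> idx = torch.accelerator.current_device_index()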

    """
    return torch._C._accelerator_getDeviceIndex()


current_device_idx = deprecated(
    "Use `current_device_index` instead.",
    category=FutureWarning,
)(current_device_index)


def set_device_index(device: _device_t, /) -> None:
    r"""Set the current device index to a given device.



    Args:

        device (:class:`torch.device`, str, int): a given device that must match the current

            :ref:`accelerator<accelerators>` device type.



    .. note:: This function is a no-op if this device index is negative.
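
    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Illustrative sketch; assumes device 0 exists for the current accelerator
        >>> torch.accelerator.set_device_index(0)
        >>> assert torch.accelerator.current_device_index() == 0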

    """
    device_index = _get_device_index(device, optional=False)
    torch._C._accelerator_setDeviceIndex(device_index)


set_device_idx = deprecated(
    "Use `set_device_index` instead.",
    category=FutureWarning,
)(set_device_index)


def current_stream(device: _device_t = None, /) -> torch.Stream:
    r"""Return the currently selected stream for a given device.



    Args:

        device (:class:`torch.device`, str, int, optional): a given device that must match the current

            :ref:`accelerator<accelerators>` device type. If not given,

            use :func:`torch.accelerator.current_device_index` by default.



    Returns:

        torch.Stream: the currently selected stream for a given device.
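
    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Illustrative sketch; assumes at least one accelerator device is visible
        >>> stream = torch.accelerator.current_stream()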

    """
    device_index = _get_device_index(device, optional=True)
    return torch._C._accelerator_getStream(device_index)


def set_stream(stream: torch.Stream) -> None:
    r"""Set the current stream to a given stream.



    Args:

        stream (torch.Stream): a given stream that must match the current :ref:`accelerator<accelerators>` device type.



    .. note:: This function will set the current device index to the device index of the given stream.
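
    Example::

        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # Illustrative sketch; assumes at least one accelerator device is visible
        >>> s = torch.Stream(device=torch.accelerator.current_accelerator())
        >>> torch.accelerator.set_stream(s)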

    """
    torch._C._accelerator_setStream(stream)


def synchronize(device: _device_t = None, /) -> None:
    r"""Wait for all kernels in all streams on the given device to complete.



    Args:

        device (:class:`torch.device`, str, int, optional): device for which to synchronize. It must match

            the current :ref:`accelerator<accelerators>` device type. If not given,

            use :func:`torch.accelerator.current_device_index` by default.



    .. note:: This function is a no-op if the current :ref:`accelerator<accelerators>` is not initialized.



    Example::



        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)

        >>> assert torch.accelerator.is_available() "No available accelerators detected."

        >>> start_event = torch.Event(enable_timing=True)

        >>> end_event = torch.Event(enable_timing=True)

        >>> start_event.record()

        >>> tensor = torch.randn(100, device=torch.accelerator.current_accelerator())

        >>> sum = torch.sum(tensor)

        >>> end_event.record()

        >>> torch.accelerator.synchronize()

        >>> elapsed_time_ms = start_event.elapsed_time(end_event)

    """
    device_index = _get_device_index(device, optional=True)
    torch._C._accelerator_synchronizeDevice(device_index)


class device_index:
    r"""Context manager to set the current device index for the current :ref:`accelerator<accelerators>`.

    Temporarily changes the current device index to the specified value for the duration

    of the context, and automatically restores the previous device index when exiting

    the context.



    Args:

        device (Optional[int]): a given device index to temporarily set. If None,

            no device index switching occurs.



    Examples:



        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)

        >>> # Set device 0 as the current device temporarily

        >>> with torch.accelerator.device_index(0):

        ...     # Code here runs with device 0 as the current device

        ...     pass

        >>> # Original device is now restored

        >>> # No-op when None is passed

        >>> with torch.accelerator.device_index(None):

        ...     # No device switching occurs

        ...     pass

    """

    def __init__(self, device: Optional[int], /) -> None:
        self.idx = device
        self.prev_idx = -1

    def __enter__(self) -> None:
        if self.idx is not None:
            # Swap in the requested device index, remembering the previous one
            # so it can be restored on exit.
            self.prev_idx = torch._C._accelerator_exchangeDevice(self.idx)

    def __exit__(self, *exc_info: object) -> None:
        if self.idx is not None:
            # Restore the previous device index; the "maybe" variant avoids
            # initializing a device that was never used.
            torch._C._accelerator_maybeExchangeDevice(self.prev_idx)