| | #pragma once |
| |
|
| | #include <c10/core/CachingDeviceAllocator.h> |
| | #include <c10/core/DeviceType.h> |
| | #include <c10/macros/Macros.h> |
| |
|
| | #include <ATen/detail/MTIAHooksInterface.h> |
| | #include <optional> |
| |
|
| | namespace at::accelerator { |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | TORCH_API std::optional<c10::DeviceType> getAccelerator(bool checked = false); |
| |
|
| | |
| | TORCH_API bool isAccelerator(c10::DeviceType device_type); |
| |
|
| | |
| | template < |
| | typename... T, |
| | typename = std::enable_if_t<(std::is_same_v<T, c10::DeviceType> && ...)>> |
| | inline bool isAcceleratorExcluded( |
| | c10::DeviceType device_type, |
| | c10::DeviceType first_excluded, |
| | T... rest_excluded) { |
| | if constexpr (sizeof...(rest_excluded) > 0) { |
| | return device_type != first_excluded && |
| | isAcceleratorExcluded(device_type, rest_excluded...); |
| | } else { |
| | return device_type != first_excluded && isAccelerator(device_type); |
| | } |
| | } |
| |
|
| | |
| | |
| | TORCH_API c10::DeviceIndex deviceCount(); |
| |
|
| | |
| | TORCH_API void setDeviceIndex(c10::DeviceIndex device_index); |
| |
|
| | |
| | TORCH_API c10::DeviceIndex getDeviceIndex(); |
| |
|
| | |
| | |
| | TORCH_API void setCurrentStream(c10::Stream stream); |
| |
|
| | |
| | TORCH_API c10::Stream getCurrentStream(c10::DeviceIndex device_index); |
| |
|
| | |
| | |
| | TORCH_API void synchronizeDevice(c10::DeviceIndex device_index); |
| |
|
| | |
| | |
| | TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); |
| |
|
| | |
| | |
| | |
| | TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); |
| |
|
| | TORCH_API inline void emptyCache() { |
| | const auto device_type = getAccelerator(true).value(); |
| | at::getDeviceAllocator(device_type)->emptyCache(); |
| | } |
| |
|
| | TORCH_API inline at::CachingDeviceAllocator::DeviceStats getDeviceStats( |
| | c10::DeviceIndex device_index) { |
| | const auto device_type = getAccelerator(true).value(); |
| | return at::getDeviceAllocator(device_type)->getDeviceStats(device_index); |
| | } |
| |
|
| | TORCH_API inline void resetAccumulatedStats(c10::DeviceIndex device_index) { |
| | const auto device_type = getAccelerator(true).value(); |
| | at::getDeviceAllocator(device_type)->resetAccumulatedStats(device_index); |
| | } |
| |
|
| | TORCH_API inline void resetPeakStats(c10::DeviceIndex device_index) { |
| | const auto device_type = getAccelerator(true).value(); |
| | at::getDeviceAllocator(device_type)->resetPeakStats(device_index); |
| | } |
| |
|
| | } |
| |
|
| | namespace at { |
| | |
| | using at::accelerator::getAccelerator; |
| | using at::accelerator::isAccelerator; |
| | } |
| |
|