| | #pragma once |
| |
|
| | #include <c10/core/DeviceGuard.h> |
| | #include <c10/core/impl/DeviceGuardImplInterface.h> |
| | #include <c10/core/impl/GPUTrace.h> |
| | #include <c10/macros/Macros.h> |
| | #include <c10/util/Exception.h> |
| |
|
| | #include <c10/cuda/CUDACachingAllocator.h> |
| | #include <c10/cuda/CUDAException.h> |
| | #include <c10/cuda/CUDAFunctions.h> |
| | #include <c10/cuda/CUDAStream.h> |
| |
|
| | #include <cuda_runtime_api.h> |
| |
|
| | namespace c10 { |
| | namespace cuda { |
| | namespace impl { |
| |
|
| | struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { |
| | static constexpr DeviceType static_type = DeviceType::CUDA; |
| |
|
| | CUDAGuardImpl() {} |
| | explicit CUDAGuardImpl(DeviceType t) { |
| | TORCH_INTERNAL_ASSERT(t == DeviceType::CUDA); |
| | } |
| | DeviceType type() const override { |
| | return DeviceType::CUDA; |
| | } |
| | Device exchangeDevice(Device d) const override { |
| | TORCH_INTERNAL_ASSERT(d.is_cuda()); |
| | Device old_device = getDevice(); |
| | if (old_device.index() != d.index()) { |
| | C10_CUDA_CHECK(cudaSetDevice(d.index())); |
| | } |
| | return old_device; |
| | } |
| | Device getDevice() const override { |
| | int device; |
| | C10_CUDA_CHECK(cudaGetDevice(&device)); |
| | return Device(DeviceType::CUDA, device); |
| | } |
| | c10::optional<Device> uncheckedGetDevice() const noexcept { |
| | int device; |
| | const auto err = C10_CUDA_ERROR_HANDLED(cudaGetDevice(&device)); |
| | C10_CUDA_CHECK_WARN(err); |
| | if (err != cudaSuccess) { |
| | return c10::nullopt; |
| | } |
| | return Device(DeviceType::CUDA, device); |
| | } |
| | void setDevice(Device d) const override { |
| | TORCH_INTERNAL_ASSERT(d.is_cuda()); |
| | Device current_device = getDevice(); |
| | if (current_device != d) { |
| | C10_CUDA_CHECK(cudaSetDevice(d.index())); |
| | } |
| | } |
| | void uncheckedSetDevice(Device d) const noexcept override { |
| | auto current_device = uncheckedGetDevice(); |
| | if (!current_device.has_value() || current_device.value() != d) { |
| | C10_CUDA_CHECK_WARN(cudaSetDevice(d.index())); |
| | } |
| | } |
| | Stream getStream(Device d) const noexcept override { |
| | return getCurrentCUDAStream(d.index()).unwrap(); |
| | } |
| | Stream getDefaultStream(Device d) const override { |
| | return getDefaultCUDAStream(d.index()); |
| | } |
| | Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false) |
| | const override { |
| | return getStreamFromPool(isHighPriority, d.index()); |
| | } |
| | |
| | Stream exchangeStream(Stream s) const noexcept override { |
| | CUDAStream cs(s); |
| | auto old_stream = getCurrentCUDAStream(s.device().index()); |
| | setCurrentCUDAStream(cs); |
| | return old_stream.unwrap(); |
| | } |
| | DeviceIndex deviceCount() const noexcept override { |
| | return device_count(); |
| | } |
| |
|
| | |
| | void createEvent(cudaEvent_t* cuda_event, const EventFlag flag) const { |
| | |
| | auto cuda_flag = cudaEventDefault; |
| | switch (flag) { |
| | case EventFlag::PYTORCH_DEFAULT: |
| | case EventFlag::CUDA_EVENT_DISABLE_TIMING: |
| | cuda_flag = cudaEventDisableTiming; |
| | break; |
| | case EventFlag::BACKEND_DEFAULT: |
| | case EventFlag::CUDA_EVENT_DEFAULT: |
| | cuda_flag = cudaEventDefault; |
| | break; |
| | default: |
| | TORCH_CHECK(false, "CUDA event received unknown flag"); |
| | } |
| |
|
| | C10_CUDA_CHECK(cudaEventCreateWithFlags(cuda_event, cuda_flag)); |
| | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
| | if (C10_UNLIKELY(interp)) { |
| | (*interp)->trace_gpu_event_creation( |
| | reinterpret_cast<uintptr_t>(cuda_event)); |
| | } |
| | } |
| |
|
| | void destroyEvent(void* event, const DeviceIndex device_index) |
| | const noexcept override { |
| | if (!event) |
| | return; |
| | auto cuda_event = static_cast<cudaEvent_t>(event); |
| | int orig_device; |
| | C10_CUDA_CHECK_WARN(cudaGetDevice(&orig_device)); |
| | C10_CUDA_CHECK_WARN(cudaSetDevice(device_index)); |
| | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
| | if (C10_UNLIKELY(interp)) { |
| | (*interp)->trace_gpu_event_deletion( |
| | reinterpret_cast<uintptr_t>(cuda_event)); |
| | } |
| | C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event)); |
| | C10_CUDA_CHECK_WARN(cudaSetDevice(orig_device)); |
| | } |
| |
|
| | void record( |
| | void** event, |
| | const Stream& stream, |
| | const DeviceIndex device_index, |
| | const EventFlag flag) const override { |
| | TORCH_CHECK( |
| | device_index == -1 || device_index == stream.device_index(), |
| | "Event device index ", |
| | device_index, |
| | " does not match recording stream's device index ", |
| | stream.device_index(), |
| | "."); |
| |
|
| | cudaEvent_t cuda_event = static_cast<cudaEvent_t>(*event); |
| | CUDAStream cuda_stream{stream}; |
| |
|
| | |
| | const auto orig_device = getDevice(); |
| | setDevice(stream.device()); |
| |
|
| | |
| | if (!cuda_event) |
| | createEvent(&cuda_event, flag); |
| | C10_CUDA_CHECK(cudaEventRecord(cuda_event, cuda_stream)); |
| | |
| | *event = cuda_event; |
| | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
| | if (C10_UNLIKELY(interp)) { |
| | (*interp)->trace_gpu_event_record( |
| | reinterpret_cast<uintptr_t>(cuda_event), |
| | reinterpret_cast<uintptr_t>(cuda_stream.stream())); |
| | } |
| |
|
| | |
| | setDevice(orig_device); |
| | } |
| |
|
| | void block(void* event, const Stream& stream) const override { |
| | if (!event) |
| | return; |
| | cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event); |
| | CUDAStream cuda_stream{stream}; |
| | const auto orig_device = getDevice(); |
| | setDevice(stream.device()); |
| | C10_CUDA_CHECK(cudaStreamWaitEvent( |
| | cuda_stream, |
| | cuda_event, |
| | 0)); |
| | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
| | if (C10_UNLIKELY(interp)) { |
| | (*interp)->trace_gpu_event_wait( |
| | reinterpret_cast<uintptr_t>(cuda_event), |
| | reinterpret_cast<uintptr_t>(cuda_stream.stream())); |
| | } |
| | setDevice(orig_device); |
| | } |
| |
|
| | |
| | bool queryEvent(void* event) const override { |
| | if (!event) |
| | return true; |
| | cudaEvent_t cuda_event = static_cast<cudaEvent_t>(event); |
| | const cudaError_t err = C10_CUDA_ERROR_HANDLED(cudaEventQuery(cuda_event)); |
| | if (err != cudaErrorNotReady) { |
| | C10_CUDA_CHECK(err); |
| | } else { |
| | |
| | (void)cudaGetLastError(); |
| | } |
| | return (err == cudaSuccess); |
| | } |
| |
|
| | |
| | bool queryStream(const Stream& stream) const override { |
| | CUDAStream cuda_stream{stream}; |
| | return cuda_stream.query(); |
| | } |
| |
|
| | void synchronizeStream(const Stream& stream) const override { |
| | CUDAStream cuda_stream{stream}; |
| | cuda_stream.synchronize(); |
| | } |
| |
|
| | void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream) |
| | const override { |
| | CUDAStream cuda_stream{stream}; |
| | CUDACachingAllocator::recordStream(data_ptr, cuda_stream); |
| | } |
| | }; |
| |
|
| | } |
| | } |
| | } |
| |
|