| | #pragma once |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | #include <c10/core/Device.h> |
| | #include <c10/core/impl/GPUTrace.h> |
| | #include <c10/cuda/CUDAException.h> |
| | #include <c10/cuda/CUDAMacros.h> |
| | #include <cuda_runtime_api.h> |
| | namespace c10::cuda { |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | C10_CUDA_API DeviceIndex device_count() noexcept; |
| |
|
| | |
| | C10_CUDA_API DeviceIndex device_count_ensure_non_zero(); |
| |
|
| | C10_CUDA_API DeviceIndex current_device(); |
| |
|
| | C10_CUDA_API void set_device(DeviceIndex device, const bool force = false); |
| |
|
| | C10_CUDA_API void device_synchronize(); |
| |
|
| | C10_CUDA_API void warn_or_error_on_sync(); |
| |
|
| | |
| | C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count); |
| |
|
| | C10_CUDA_API cudaError_t GetDevice(DeviceIndex* device); |
| |
|
| | C10_CUDA_API cudaError_t |
| | SetDevice(DeviceIndex device, const bool force = false); |
| |
|
| | C10_CUDA_API cudaError_t MaybeSetDevice(DeviceIndex device); |
| |
|
| | C10_CUDA_API DeviceIndex ExchangeDevice(DeviceIndex device); |
| |
|
| | C10_CUDA_API DeviceIndex MaybeExchangeDevice(DeviceIndex device); |
| |
|
| | C10_CUDA_API void SetTargetDevice(); |
| |
|
// Debug mode consulted before synchronizing calls. L_DISABLED (value 0) is the
// default stored by WarningState and the only mode in which the inline sync
// helpers in this header skip warn_or_error_on_sync().
// NOTE(review): L_WARN / L_ERROR presumably choose between emitting a warning
// and raising an error — confirm against warn_or_error_on_sync()'s definition.
enum class SyncDebugMode {
  L_DISABLED = 0,
  L_WARN,
  L_ERROR,
};

// Small holder for the synchronization debug mode.
// The mode is a plain data member: no locking or atomics are applied here.
class WarningState {
 public:
  // Overwrites the stored mode with `l`.
  void set_sync_debug_mode(SyncDebugMode l) {
    sync_debug_mode = l;
  }

  // Returns the stored mode. Declared const because it only reads state, so
  // the mode can also be queried through a const WarningState reference.
  SyncDebugMode get_sync_debug_mode() const {
    return sync_debug_mode;
  }

 private:
  // Starts out disabled until set_sync_debug_mode() is called.
  SyncDebugMode sync_debug_mode = SyncDebugMode::L_DISABLED;
};
| |
|
| | C10_CUDA_API __inline__ WarningState& warning_state() { |
| | static WarningState warning_state_; |
| | return warning_state_; |
| | } |
| | |
| | |
| | C10_CUDA_API void __inline__ memcpy_and_sync( |
| | void* dst, |
| | const void* src, |
| | int64_t nbytes, |
| | cudaMemcpyKind kind, |
| | cudaStream_t stream) { |
| | if (C10_UNLIKELY( |
| | warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) { |
| | warn_or_error_on_sync(); |
| | } |
| | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
| | if (C10_UNLIKELY(interp)) { |
| | (*interp)->trace_gpu_stream_synchronization( |
| | c10::kCUDA, reinterpret_cast<uintptr_t>(stream)); |
| | } |
| | #if defined(USE_ROCM) && USE_ROCM |
| | |
| | |
| | |
| | hipStreamCaptureStatus captureStatus; |
| | C10_CUDA_CHECK(hipStreamGetCaptureInfo(stream, &captureStatus, nullptr)); |
| | if (C10_LIKELY(captureStatus == hipStreamCaptureStatusNone)) { |
| | C10_CUDA_CHECK(hipMemcpyWithStream(dst, src, nbytes, kind, stream)); |
| | } else { |
| | C10_CUDA_CHECK(hipErrorStreamCaptureUnsupported); |
| | } |
| | #else |
| | C10_CUDA_CHECK(cudaMemcpyAsync(dst, src, nbytes, kind, stream)); |
| | C10_CUDA_CHECK(cudaStreamSynchronize(stream)); |
| | #endif |
| | } |
| |
|
| | C10_CUDA_API void __inline__ stream_synchronize(cudaStream_t stream) { |
| | if (C10_UNLIKELY( |
| | warning_state().get_sync_debug_mode() != SyncDebugMode::L_DISABLED)) { |
| | warn_or_error_on_sync(); |
| | } |
| | const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); |
| | if (C10_UNLIKELY(interp)) { |
| | (*interp)->trace_gpu_stream_synchronization( |
| | c10::kCUDA, reinterpret_cast<uintptr_t>(stream)); |
| | } |
| | C10_CUDA_CHECK(cudaStreamSynchronize(stream)); |
| | } |
| |
|
| | C10_CUDA_API bool hasPrimaryContext(DeviceIndex device_index); |
| | C10_CUDA_API std::optional<DeviceIndex> getDeviceIndexWithPrimaryContext(); |
| |
|
| | } |
| |
|