|
|
#pragma once |
|
|
|
|
|
#include <c10/hip/HIPStream.h> |
|
|
|
|
|
|
|
|
|
|
|
namespace c10 { namespace hip { |
|
|
|
|
|
|
|
|
|
|
|
class HIPStreamMasqueradingAsCUDA { |
|
|
public: |
|
|
|
|
|
enum Unchecked { UNCHECKED }; |
|
|
|
|
|
explicit HIPStreamMasqueradingAsCUDA(Stream stream) |
|
|
: HIPStreamMasqueradingAsCUDA(UNCHECKED, stream) { |
|
|
|
|
|
TORCH_CHECK(stream.device().is_cuda() ); |
|
|
} |
|
|
|
|
|
explicit HIPStreamMasqueradingAsCUDA(Unchecked, Stream stream) |
|
|
|
|
|
: stream_( |
|
|
HIPStream( |
|
|
Stream( |
|
|
Stream::UNSAFE, |
|
|
Device(DeviceType::HIP, stream.device_index()), |
|
|
stream.id()) |
|
|
) |
|
|
) {} |
|
|
|
|
|
|
|
|
explicit HIPStreamMasqueradingAsCUDA(HIPStream stream) : stream_(stream) {} |
|
|
|
|
|
bool operator==(const HIPStreamMasqueradingAsCUDA& other) const noexcept { |
|
|
return stream_ == other.stream_; |
|
|
} |
|
|
|
|
|
bool operator!=(const HIPStreamMasqueradingAsCUDA& other) const noexcept { |
|
|
return stream_ != other.stream_; |
|
|
} |
|
|
|
|
|
operator hipStream_t() const { return stream_.stream(); } |
|
|
|
|
|
operator Stream() const { |
|
|
|
|
|
return Stream(Stream::UNSAFE, device(), id()); |
|
|
} |
|
|
|
|
|
DeviceIndex device_index() const { return stream_.device_index(); } |
|
|
|
|
|
Device device() const { |
|
|
|
|
|
return Device(DeviceType::CUDA, stream_.device_index()); |
|
|
} |
|
|
|
|
|
StreamId id() const { return stream_.id(); } |
|
|
bool query() const { return stream_.query(); } |
|
|
void synchronize() const { stream_.synchronize(); } |
|
|
int priority() const { return stream_.priority(); } |
|
|
hipStream_t stream() const { return stream_.stream(); } |
|
|
|
|
|
Stream unwrap() const { |
|
|
|
|
|
return Stream(Stream::UNSAFE, device(), id()); |
|
|
} |
|
|
|
|
|
uint64_t pack() const noexcept { |
|
|
|
|
|
return unwrap().pack(); |
|
|
} |
|
|
|
|
|
static HIPStreamMasqueradingAsCUDA unpack(uint64_t bits) { |
|
|
|
|
|
return HIPStreamMasqueradingAsCUDA(Stream::unpack(bits)); |
|
|
} |
|
|
|
|
|
static std::tuple<int, int> priority_range() { return HIPStream::priority_range(); } |
|
|
|
|
|
|
|
|
HIPStream hip_stream() const { return stream_; } |
|
|
|
|
|
private: |
|
|
HIPStream stream_; |
|
|
}; |
|
|
|
|
|
/// Obtains a stream via getStreamFromPool(isHighPriority, device) and wraps
/// it so that it reports a CUDA device type.
inline HIPStreamMasqueradingAsCUDA getStreamFromPoolMasqueradingAsCUDA(
    const bool isHighPriority = false,
    DeviceIndex device = -1) {
  auto pooled = getStreamFromPool(isHighPriority, device);
  return HIPStreamMasqueradingAsCUDA(pooled);
}
|
|
|
|
|
/// Wraps an externally created hipStream_t (via getStreamFromExternal) for
/// the given device so that it reports a CUDA device type.
inline HIPStreamMasqueradingAsCUDA getStreamFromExternalMasqueradingAsCUDA(
    hipStream_t ext_stream,
    DeviceIndex device) {
  auto external = getStreamFromExternal(ext_stream, device);
  return HIPStreamMasqueradingAsCUDA(external);
}
|
|
|
|
|
/// Returns getDefaultHIPStream(device_index) wrapped so that it reports a
/// CUDA device type.
inline HIPStreamMasqueradingAsCUDA getDefaultHIPStreamMasqueradingAsCUDA(
    DeviceIndex device_index = -1) {
  auto default_stream = getDefaultHIPStream(device_index);
  return HIPStreamMasqueradingAsCUDA(default_stream);
}
|
|
|
|
|
/// Returns getCurrentHIPStream(device_index) wrapped so that it reports a
/// CUDA device type.
inline HIPStreamMasqueradingAsCUDA getCurrentHIPStreamMasqueradingAsCUDA(
    DeviceIndex device_index = -1) {
  auto current = getCurrentHIPStream(device_index);
  return HIPStreamMasqueradingAsCUDA(current);
}
|
|
|
|
|
/// Unwraps `stream` back to its underlying HIPStream and passes it to
/// setCurrentHIPStream.
inline void setCurrentHIPStreamMasqueradingAsCUDA(
    HIPStreamMasqueradingAsCUDA stream) {
  auto underlying = stream.hip_stream();
  setCurrentHIPStream(underlying);
}
|
|
|
|
|
/// Prints the wrapped HIPStream followed by the suffix
/// " (masquerading as CUDA)".
inline std::ostream& operator<<(
    std::ostream& stream,
    const HIPStreamMasqueradingAsCUDA& s) {
  const auto underlying = s.hip_stream();
  stream << underlying << " (masquerading as CUDA)";
  return stream;
}
|
|
|
|
|
}} |
|
|
|
|
|
namespace std { |
|
|
template <> |
|
|
struct hash<c10::hip::HIPStreamMasqueradingAsCUDA> { |
|
|
size_t operator()(c10::hip::HIPStreamMasqueradingAsCUDA s) const noexcept { |
|
|
return std::hash<c10::Stream>{}(s.unwrap()); |
|
|
} |
|
|
}; |
|
|
} |
|
|
|