|
|
#pragma once
|
|
|
|
|
|
#include <c10/core/DeviceType.h>
|
|
|
#include <c10/macros/Macros.h>
|
|
|
|
|
|
#include <atomic>
|
|
|
#include <utility>
|
|
|
#include <variant>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
C10_CLANG_DIAGNOSTIC_PUSH()
|
|
|
C10_CLANG_DIAGNOSTIC_IGNORE("-Wundefined-var-template")
|
|
|
|
|
|
namespace at::native {
|
|
|
|
|
|
// Enumerates the CPU kernel variants this build knows about.
//
// DEFAULT is always 0. Which other enumerators exist is decided by the
// HAVE_*_CPU_DEFINITION macros: exactly one of the VSX / ZVECTOR / SVE256
// branches is taken, otherwise the AVX2 + AVX512 pair is used. NUM_OPTIONS is
// declared last, so its value equals the number of variants declared before it.
//
// NOTE(review): the SVE256 enumerator additionally requires
// HAVE_ARM_BF16_CPU_DEFINITION, whereas the SVE256 dispatch slots elsewhere in
// this header are guarded by HAVE_SVE256_CPU_DEFINITION alone - confirm the
// mismatch is intended.
enum class CPUCapability {
  DEFAULT = 0,
#if defined(HAVE_VSX_CPU_DEFINITION)
  VSX = 1,
#elif defined(HAVE_ZVECTOR_CPU_DEFINITION)
  ZVECTOR = 1,
#elif defined(HAVE_SVE256_CPU_DEFINITION) && defined(HAVE_ARM_BF16_CPU_DEFINITION)
  SVE256 = 1,
#else
  AVX2 = 1,
  AVX512 = 2,
#endif
  NUM_OPTIONS
};
|
|
|
|
|
|
|
|
|
// Failure codes that can be carried in a DispatchResult instead of a pointer.
enum class ErrorType {
  MissingDeviceKernel,
  DeviceNotSupported
};

// Result of the try_* lookups on DispatchStubImpl: holds either a type-erased
// pointer (void*) or the ErrorType describing why no pointer was produced.
using DispatchResult = std::variant<void*, ErrorType>;
|
|
|
|
|
|
// Returns a CPUCapability value. Only the declaration lives in this header;
// how the capability is determined is up to the out-of-line definition.
CPUCapability get_cpu_capability();
|
|
|
|
|
|
// Primary template; intentionally left undefined. Only the partial
// specialization for function-pointer types `rT (*)(Args...)` is implemented.
template <typename FnPtr, typename T>
struct DispatchStub;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
struct TORCH_API DispatchStubImpl {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
DispatchResult try_get_call_ptr(
|
|
|
c10::DeviceType device_type
|
|
|
, void *DEFAULT
|
|
|
#ifdef HAVE_AVX512_CPU_DEFINITION
|
|
|
, void *AVX512
|
|
|
#endif
|
|
|
#ifdef HAVE_AVX2_CPU_DEFINITION
|
|
|
, void *AVX2
|
|
|
#endif
|
|
|
#ifdef HAVE_VSX_CPU_DEFINITION
|
|
|
, void *VSX
|
|
|
#endif
|
|
|
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
|
|
, void *ZVECTOR
|
|
|
#endif
|
|
|
#ifdef HAVE_SVE256_CPU_DEFINITION
|
|
|
, void *SVE256
|
|
|
#endif
|
|
|
);
|
|
|
|
|
|
|
|
|
|
|
|
DispatchResult try_choose_cpu_impl(
|
|
|
void *DEFAULT
|
|
|
#ifdef HAVE_AVX512_CPU_DEFINITION
|
|
|
, void *AVX512
|
|
|
#endif
|
|
|
#ifdef HAVE_AVX2_CPU_DEFINITION
|
|
|
, void *AVX2
|
|
|
#endif
|
|
|
#ifdef HAVE_VSX_CPU_DEFINITION
|
|
|
, void *VSX
|
|
|
#endif
|
|
|
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
|
|
, void *ZVECTOR
|
|
|
#endif
|
|
|
#ifdef HAVE_SVE256_CPU_DEFINITION
|
|
|
, void *SVE256
|
|
|
#endif
|
|
|
);
|
|
|
|
|
|
|
|
|
void* get_call_ptr(
|
|
|
c10::DeviceType device_type
|
|
|
, void *DEFAULT
|
|
|
#ifdef HAVE_AVX512_CPU_DEFINITION
|
|
|
, void *AVX512
|
|
|
#endif
|
|
|
#ifdef HAVE_AVX2_CPU_DEFINITION
|
|
|
, void *AVX2
|
|
|
#endif
|
|
|
#ifdef HAVE_VSX_CPU_DEFINITION
|
|
|
, void *VSX
|
|
|
#endif
|
|
|
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
|
|
, void *ZVECTOR
|
|
|
#endif
|
|
|
#ifdef HAVE_SVE256_CPU_DEFINITION
|
|
|
, void *SVE256
|
|
|
#endif
|
|
|
);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void* choose_cpu_impl(
|
|
|
void *DEFAULT
|
|
|
#ifdef HAVE_AVX512_CPU_DEFINITION
|
|
|
, void *AVX512
|
|
|
#endif
|
|
|
#ifdef HAVE_AVX2_CPU_DEFINITION
|
|
|
, void *AVX2
|
|
|
#endif
|
|
|
#ifdef HAVE_VSX_CPU_DEFINITION
|
|
|
, void *VSX
|
|
|
#endif
|
|
|
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
|
|
, void *ZVECTOR
|
|
|
#endif
|
|
|
#ifdef HAVE_SVE256_CPU_DEFINITION
|
|
|
, void *SVE256
|
|
|
#endif
|
|
|
);
|
|
|
|
|
|
|
|
|
|
|
|
#if defined(_MSC_VER) && defined(_DEBUG)
|
|
|
std::atomic<void*> cpu_dispatch_ptr;
|
|
|
void* cuda_dispatch_ptr;
|
|
|
void* hip_dispatch_ptr;
|
|
|
void* mps_dispatch_ptr;
|
|
|
void* mtia_dispatch_ptr;
|
|
|
#if defined(USE_XPU)
|
|
|
void* xpu_dispatch_ptr;
|
|
|
#endif
|
|
|
void* hpu_dispatch_ptr;
|
|
|
void* privateuse1_dispatch_ptr;
|
|
|
#else
|
|
|
std::atomic<void*> cpu_dispatch_ptr{nullptr};
|
|
|
void* cuda_dispatch_ptr = nullptr;
|
|
|
void* hip_dispatch_ptr = nullptr;
|
|
|
void* mps_dispatch_ptr = nullptr;
|
|
|
void* mtia_dispatch_ptr = nullptr;
|
|
|
#if defined(USE_XPU)
|
|
|
void* xpu_dispatch_ptr = nullptr;
|
|
|
#endif
|
|
|
void* hpu_dispatch_ptr = nullptr;
|
|
|
void* privateuse1_dispatch_ptr = nullptr;
|
|
|
#endif
|
|
|
};
|
|
|
|
|
|
template <typename rT, typename T, typename... Args>
|
|
|
struct DispatchStub<rT (*)(Args...), T> {
|
|
|
using FnPtr = rT (*) (Args...);
|
|
|
|
|
|
DispatchStub() = default;
|
|
|
DispatchStub(const DispatchStub&) = delete;
|
|
|
DispatchStub& operator=(const DispatchStub&) = delete;
|
|
|
|
|
|
private:
|
|
|
FnPtr get_call_ptr(const c10::DeviceType device_type) {
|
|
|
return reinterpret_cast<FnPtr>(
|
|
|
impl.get_call_ptr(device_type
|
|
|
, reinterpret_cast<void*>(DEFAULT)
|
|
|
#ifdef HAVE_AVX512_CPU_DEFINITION
|
|
|
, reinterpret_cast<void*>(AVX512)
|
|
|
#endif
|
|
|
#ifdef HAVE_AVX2_CPU_DEFINITION
|
|
|
, reinterpret_cast<void*>(AVX2)
|
|
|
#endif
|
|
|
#ifdef HAVE_VSX_CPU_DEFINITION
|
|
|
, reinterpret_cast<void*>(VSX)
|
|
|
#endif
|
|
|
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
|
|
, reinterpret_cast<void*>(ZVECTOR)
|
|
|
#endif
|
|
|
#ifdef HAVE_SVE256_CPU_DEFINITION
|
|
|
, reinterpret_cast<void*>(SVE256)
|
|
|
#endif
|
|
|
)
|
|
|
);
|
|
|
}
|
|
|
|
|
|
public:
|
|
|
template <typename... ArgTypes>
|
|
|
rT operator()(c10::DeviceType device_type, ArgTypes&&... args) {
|
|
|
FnPtr call_ptr = get_call_ptr(device_type);
|
|
|
return (*call_ptr)(std::forward<ArgTypes>(args)...);
|
|
|
}
|
|
|
|
|
|
void set_cuda_dispatch_ptr(FnPtr fn_ptr) {
|
|
|
impl.cuda_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
|
|
}
|
|
|
|
|
|
#if defined(USE_XPU)
|
|
|
void set_xpu_dispatch_ptr(FnPtr fn_ptr){
|
|
|
impl.xpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
|
|
}
|
|
|
#endif
|
|
|
|
|
|
void set_hpu_dispatch_ptr(FnPtr fn_ptr) {
|
|
|
impl.hpu_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
|
|
}
|
|
|
|
|
|
void set_hip_dispatch_ptr(FnPtr fn_ptr) {
|
|
|
impl.hip_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
|
|
}
|
|
|
|
|
|
void set_mps_dispatch_ptr(FnPtr fn_ptr) {
|
|
|
impl.mps_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
|
|
}
|
|
|
|
|
|
void set_mtia_dispatch_ptr(FnPtr fn_ptr) {
|
|
|
impl.mtia_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
|
|
}
|
|
|
|
|
|
void set_privateuse1_dispatch_ptr(FnPtr fn_ptr) {
|
|
|
impl.privateuse1_dispatch_ptr = reinterpret_cast<void*>(fn_ptr);
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
bool is_device_supported(const c10::DeviceType device_type) {
|
|
|
auto result = impl.try_get_call_ptr(device_type
|
|
|
, reinterpret_cast<void*>(DEFAULT)
|
|
|
#ifdef HAVE_AVX512_CPU_DEFINITION
|
|
|
, reinterpret_cast<void*>(AVX512)
|
|
|
#endif
|
|
|
#ifdef HAVE_AVX2_CPU_DEFINITION
|
|
|
, reinterpret_cast<void*>(AVX2)
|
|
|
#endif
|
|
|
#ifdef HAVE_VSX_CPU_DEFINITION
|
|
|
, reinterpret_cast<void*>(VSX)
|
|
|
#endif
|
|
|
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
|
|
, reinterpret_cast<void*>(ZVECTOR)
|
|
|
#endif
|
|
|
#ifdef HAVE_SVE256_CPU_DEFINITION
|
|
|
, reinterpret_cast<void*>(SVE256)
|
|
|
#endif
|
|
|
);
|
|
|
if (std::holds_alternative<ErrorType>(result)){
|
|
|
return false;
|
|
|
}
|
|
|
return true;
|
|
|
}
|
|
|
|
|
|
static TORCH_API FnPtr DEFAULT;
|
|
|
#ifdef HAVE_AVX512_CPU_DEFINITION
|
|
|
static TORCH_API FnPtr AVX512;
|
|
|
#endif
|
|
|
#ifdef HAVE_AVX2_CPU_DEFINITION
|
|
|
static TORCH_API FnPtr AVX2;
|
|
|
#endif
|
|
|
#ifdef HAVE_VSX_CPU_DEFINITION
|
|
|
static TORCH_API FnPtr VSX;
|
|
|
#endif
|
|
|
#ifdef HAVE_ZVECTOR_CPU_DEFINITION
|
|
|
static TORCH_API FnPtr ZVECTOR;
|
|
|
#endif
|
|
|
#ifdef HAVE_SVE256_CPU_DEFINITION
|
|
|
static TORCH_API FnPtr SVE256;
|
|
|
#endif
|
|
|
private:
|
|
|
DispatchStubImpl impl;
|
|
|
};
|
|
|
|
|
|
namespace {
|
|
|
// Registration helper: constructing one stores `value` in `stub` through
// set_cuda_dispatch_ptr(). The object itself carries no state.
template <typename DispatchStub>
struct RegisterCUDADispatch {
  RegisterCUDADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_cuda_dispatch_ptr(value);
  }
};
|
|
|
|
|
|
// Registration helper: constructing one stores `value` in `stub` through
// set_xpu_dispatch_ptr(). Instantiating it requires a stub type that provides
// that setter (DispatchStub only does when USE_XPU is defined).
template <typename DispatchStub>
struct RegisterXPUDispatch {
  RegisterXPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
    stub.set_xpu_dispatch_ptr(value);
  }
};
|
|
|
|
|
|
// Registration helper: constructing one stores `value` in `stub` through
// set_hpu_dispatch_ptr(). The object itself carries no state.
template <typename DispatchStub>
struct RegisterHPUDispatch {
  RegisterHPUDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value){
    stub.set_hpu_dispatch_ptr(value);
  }
};
|
|
|
|
|
|
// Registration helper: constructing one stores `value` in `stub` through
// set_mps_dispatch_ptr(). The object itself carries no state.
template <typename DispatchStub>
struct RegisterMPSDispatch {
  RegisterMPSDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_mps_dispatch_ptr(value);
  }
};
|
|
|
|
|
|
// Registration helper for HIP builds: constructing one stores `value` in
// `stub` through set_cuda_dispatch_ptr() - i.e. it fills the CUDA slot, not
// hip_dispatch_ptr.
// NOTE(review): kept as-is because REGISTER_DISPATCH under __HIPCC__ also
// routes to the CUDA registration; presumably HIP kernels are meant to be
// looked up through the CUDA slot - confirm before switching this to
// set_hip_dispatch_ptr().
template <typename DispatchStub>
struct RegisterHIPDispatch {
  RegisterHIPDispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_cuda_dispatch_ptr(value);
  }
};
|
|
|
|
|
|
// Registration helper: constructing one stores `value` in `stub` through
// set_mtia_dispatch_ptr(). The object itself carries no state.
template <typename DispatchStub>
struct RegisterMTIADispatch {
  RegisterMTIADispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_mtia_dispatch_ptr(value);
  }
};
|
|
|
|
|
|
// Registration helper: constructing one stores `value` in `stub` through
// set_privateuse1_dispatch_ptr(). The object itself carries no state.
template <typename DispatchStub>
struct RegisterPRIVATEUSE1Dispatch {
  RegisterPRIVATEUSE1Dispatch(DispatchStub &stub, typename DispatchStub::FnPtr value) {
    stub.set_privateuse1_dispatch_ptr(value);
  }
};
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// DECLARE_DISPATCH(fn, name)
//   fn   - the kernel's function-pointer type
//   name - identifier of the stub object
// Declares `struct name_DECLARE_DISPATCH_type`, derived from
// DispatchStub<fn, name_DECLARE_DISPATCH_type> (the struct passes itself as the
// tag type), with copy and move explicitly deleted, and then declares the
// extern stub object `name` of that type. Pair with DEFINE_DISPATCH(name) in
// exactly one translation unit to define the object.
#define DECLARE_DISPATCH(fn, name)                                                         \
  struct name##_DECLARE_DISPATCH_type : DispatchStub<fn, name##_DECLARE_DISPATCH_type> {   \
    name##_DECLARE_DISPATCH_type() = default;                                              \
    name##_DECLARE_DISPATCH_type(const name##_DECLARE_DISPATCH_type&) = delete;            \
    name##_DECLARE_DISPATCH_type& operator=(const name##_DECLARE_DISPATCH_type&) = delete; \
    name##_DECLARE_DISPATCH_type(name##_DECLARE_DISPATCH_type&&) = delete;                 \
    name##_DECLARE_DISPATCH_type& operator=(name##_DECLARE_DISPATCH_type&&) = delete;      \
    ~name##_DECLARE_DISPATCH_type() = default;                                             \
  };                                                                                       \
  extern TORCH_API struct name##_DECLARE_DISPATCH_type name;
|
|
|
|
|
|
// Defines the stub object `name`, whose type `name_DECLARE_DISPATCH_type` and
// extern declaration come from DECLARE_DISPATCH(fn, name). The macro does not
// supply the trailing semicolon.
#define DEFINE_DISPATCH(name) struct name##_DECLARE_DISPATCH_type name
|
|
|
|
|
|
// REGISTER_ARCH_DISPATCH(name, arch, fn)
// Emits the explicit specialization that defines the static member `arch`
// (DEFAULT, AVX2, ...) of stub `name`'s DispatchStub base and initializes it
// with `fn`. Must be expanded at namespace scope.
#define REGISTER_ARCH_DISPATCH(name, arch, fn) \
  template <> name##_DECLARE_DISPATCH_type::FnPtr TORCH_API DispatchStub<name##_DECLARE_DISPATCH_type::FnPtr, struct name##_DECLARE_DISPATCH_type>::arch = fn;
|
|
|
|
|
|
// Per-variant registration macros. REGISTER_<ARCH>_DISPATCH(name, fn) forwards
// to REGISTER_ARCH_DISPATCH(name, <ARCH>, fn) when the matching
// HAVE_<ARCH>_CPU_DEFINITION is defined and expands to nothing otherwise, so
// callers can list every variant unconditionally.
#ifdef HAVE_AVX512_CPU_DEFINITION
#define REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX512, fn)
#else
#define REGISTER_AVX512_DISPATCH(name, fn)
#endif

#ifdef HAVE_AVX2_CPU_DEFINITION
#define REGISTER_AVX2_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, AVX2, fn)
#else
#define REGISTER_AVX2_DISPATCH(name, fn)
#endif

#ifdef HAVE_VSX_CPU_DEFINITION
#define REGISTER_VSX_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, VSX, fn)
#else
#define REGISTER_VSX_DISPATCH(name, fn)
#endif

#ifdef HAVE_ZVECTOR_CPU_DEFINITION
#define REGISTER_ZVECTOR_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, ZVECTOR, fn)
#else
#define REGISTER_ZVECTOR_DISPATCH(name, fn)
#endif

#ifdef HAVE_SVE256_CPU_DEFINITION
#define REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, SVE256, fn)
#else
#define REGISTER_SVE256_DISPATCH(name, fn)
#endif
|
|
|
|
|
|
|
|
|
|
|
|
// Registers the same `fn` for the DEFAULT slot and for every per-variant slot
// of stub `name`; variants the build does not define expand to nothing.
#define REGISTER_ALL_CPU_DISPATCH(name, fn) \
  REGISTER_ARCH_DISPATCH(name, DEFAULT, fn) \
  REGISTER_AVX512_DISPATCH(name, fn)        \
  REGISTER_AVX2_DISPATCH(name, fn)          \
  REGISTER_VSX_DISPATCH(name, fn)           \
  REGISTER_ZVECTOR_DISPATCH(name, fn)       \
  REGISTER_SVE256_DISPATCH(name, fn)
|
|
|
|
|
|
// Registers nullptr for every CPU slot of stub `name`, i.e. defines the static
// members without providing any CPU kernel.
#define REGISTER_NO_CPU_DISPATCH(name) \
  REGISTER_ALL_CPU_DISPATCH(name, nullptr)
|
|
|
|
|
|
// Backend registration macros. Each declares a static object named
// `name__register` of the matching Register<Backend>Dispatch helper; that
// helper's constructor stores `fn` in stub `name`. Because every macro uses
// the same object name, at most one of them can be expanded for a given stub
// in the same scope.
#define REGISTER_CUDA_DISPATCH(name, fn) \
  static RegisterCUDADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_XPU_DISPATCH(name, fn) \
  static RegisterXPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_HPU_DISPATCH(name, fn) \
  static RegisterHPUDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_HIP_DISPATCH(name, fn) \
  static RegisterHIPDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_MPS_DISPATCH(name, fn) \
  static RegisterMPSDispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_MTIA_DISPATCH(name, fn) \
  static RegisterMTIADispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);

#define REGISTER_PRIVATEUSE1_DISPATCH(name, fn) \
  static RegisterPRIVATEUSE1Dispatch<struct name##_DECLARE_DISPATCH_type> name ## __register(name, fn);
|
|
|
|
|
|
|
|
|
|
|
|
// REGISTER_DISPATCH(name, fn) picks the registration that matches how the
// current translation unit is being compiled. If none of the conditions below
// hold, the macro is left undefined.
#if defined(__CUDACC__)
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__HIPCC__)
// HIP compilations register through the CUDA path as well.
// NOTE(review): matches RegisterHIPDispatch, which also fills the CUDA slot;
// confirm before pointing this at REGISTER_HIP_DISPATCH.
#define REGISTER_DISPATCH(name, fn) REGISTER_CUDA_DISPATCH(name, fn)
#elif defined(__OBJC__) && defined(USE_MPS)
#define REGISTER_DISPATCH(name, fn) REGISTER_MPS_DISPATCH(name, fn)
#elif defined(CPU_CAPABILITY)
// CPU kernel files: CPU_CAPABILITY names the slot being compiled for.
#ifdef CPU_CAPABILITY_AVX512
// In AVX512 compilations REGISTER_DISPATCH always registers nullptr: both arms
// of the conditional are nullptr, and `fn` appears only in the condition
// (presumably so it still counts as used). Use ALSO_REGISTER_AVX512_DISPATCH
// to register `fn` itself.
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, ((void*)(fn) ? nullptr : nullptr))
#else
#define REGISTER_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
// Both ALSO_* macros register `fn` for the current CPU_CAPABILITY slot,
// bypassing the nullptr substitution above.
#define ALSO_REGISTER_AVX512_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#define ALSO_REGISTER_SVE256_DISPATCH(name, fn) REGISTER_ARCH_DISPATCH(name, CPU_CAPABILITY, fn)
#endif
|
|
|
}
|
|
|
|
|
|
// Restores the clang diagnostic state saved by the C10_CLANG_DIAGNOSTIC_PUSH()
// near the top of this header.
C10_CLANG_DIAGNOSTIC_POP()
|
|
|
|