|
|
#pragma once
|
|
|
|
|
|
#include <ATen/BlasBackend.h>
|
|
|
#include <ATen/CPUGeneratorImpl.h>
|
|
|
#include <ATen/DeviceAccelerator.h>
|
|
|
#include <ATen/LinalgBackend.h>
|
|
|
#include <ATen/ROCmFABackend.h>
|
|
|
#include <ATen/SDPBackend.h>
|
|
|
#include <ATen/core/ATenGeneral.h>
|
|
|
#include <ATen/core/DeprecatedTypeProperties.h>
|
|
|
#include <ATen/core/Generator.h>
|
|
|
#include <ATen/core/LegacyTypeDispatch.h>
|
|
|
#include <ATen/detail/AcceleratorHooksInterface.h>
|
|
|
#include <ATen/detail/CUDAHooksInterface.h>
|
|
|
#include <ATen/detail/HIPHooksInterface.h>
|
|
|
#include <ATen/detail/HPUHooksInterface.h>
|
|
|
#include <ATen/detail/IPUHooksInterface.h>
|
|
|
#include <ATen/detail/MAIAHooksInterface.h>
|
|
|
#include <ATen/detail/MPSHooksInterface.h>
|
|
|
#include <ATen/detail/MTIAHooksInterface.h>
|
|
|
#include <ATen/detail/PrivateUse1HooksInterface.h>
|
|
|
#include <ATen/detail/XPUHooksInterface.h>
|
|
|
#include <c10/core/QEngine.h>
|
|
|
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
|
|
#include <c10/util/CallOnce.h>
|
|
|
#include <c10/util/Exception.h>
|
|
|
#include <c10/util/env.h>
|
|
|
#include <c10/util/irange.h>
|
|
|
|
|
|
#include <cstdint>
|
|
|
#include <mutex>
|
|
|
|
|
|
namespace at {
|
|
|
|
|
|
class Tensor;
|
|
|
|
|
|
// Precision modes ATen may use internally for float32 matrix multiplies.
// HIGHEST is the strictest (and default — see Context::float32_matmul_precision);
// HIGH/MEDIUM permit lower-precision internal math where backends support it.
enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM };
|
|
|
|
|
|
class TORCH_API Context {
|
|
|
public:
|
|
|
Context();
|
|
|
|
|
|
const Generator& defaultGenerator(Device device) {
|
|
|
c10::DeviceType device_type = device.type();
|
|
|
lazyInitDevice(device_type);
|
|
|
|
|
|
if (device_type == at::kCPU) {
|
|
|
return at::detail::getDefaultCPUGenerator();
|
|
|
} else {
|
|
|
return getAcceleratorHooksInterface(device_type)
|
|
|
.getDefaultGenerator(device.index());
|
|
|
}
|
|
|
}
|
|
|
|
|
|
const AcceleratorHooksInterface& getAcceleratorHooksInterface(
|
|
|
std::optional<c10::DeviceType> opt_device_type = std::nullopt) {
|
|
|
if (!opt_device_type.has_value()) {
|
|
|
opt_device_type = at::getAccelerator(true);
|
|
|
}
|
|
|
if (opt_device_type == at::kCUDA) {
|
|
|
return at::detail::getCUDAHooks();
|
|
|
} else if (opt_device_type == at::kXPU) {
|
|
|
return at::detail::getXPUHooks();
|
|
|
} else if (opt_device_type == at::kMPS) {
|
|
|
return at::detail::getMPSHooks();
|
|
|
} else if (opt_device_type == at::kPrivateUse1) {
|
|
|
return at::detail::getPrivateUse1Hooks();
|
|
|
} else if (opt_device_type == at::kMTIA) {
|
|
|
return at::detail::getMTIAHooks();
|
|
|
} else if (opt_device_type == at::kHIP) {
|
|
|
return at::detail::getHIPHooks();
|
|
|
} else if (opt_device_type == at::kHPU) {
|
|
|
return at::detail::getHPUHooks();
|
|
|
} else {
|
|
|
TORCH_CHECK(
|
|
|
false,
|
|
|
opt_device_type.has_value()
|
|
|
? c10::DeviceTypeName(opt_device_type.value())
|
|
|
: "None",
|
|
|
" device type not an accelerator.");
|
|
|
}
|
|
|
}
|
|
|
|
|
|
Device getDeviceFromPtr(void* data, c10::DeviceType device_type) {
|
|
|
lazyInitDevice(device_type);
|
|
|
|
|
|
if (device_type == at::kCPU) {
|
|
|
return c10::DeviceType::CPU;
|
|
|
} else {
|
|
|
return getAcceleratorHooksInterface(device_type).getDeviceFromPtr(data);
|
|
|
}
|
|
|
}
|
|
|
|
|
|
bool isPinnedPtr(
|
|
|
const void* data,
|
|
|
std::optional<c10::DeviceType> device_type = std::nullopt) {
|
|
|
auto opt_device_type =
|
|
|
device_type.has_value() ? device_type : at::getAccelerator();
|
|
|
if (!opt_device_type.has_value() ||
|
|
|
!at::isAccelerator(
|
|
|
opt_device_type.value())) {
|
|
|
return false;
|
|
|
}
|
|
|
if (!init_[static_cast<int8_t>(opt_device_type.value())].test_once()) {
|
|
|
|
|
|
return false;
|
|
|
}
|
|
|
return getAcceleratorHooksInterface(opt_device_type).isPinnedPtr(data);
|
|
|
}
|
|
|
|
|
|
Allocator* getPinnedMemoryAllocator(
|
|
|
std::optional<c10::DeviceType> device_type = std::nullopt) {
|
|
|
auto opt_device_type =
|
|
|
device_type.has_value() ? device_type : at::getAccelerator();
|
|
|
if (opt_device_type) {
|
|
|
lazyInitDevice(opt_device_type.value());
|
|
|
}
|
|
|
return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator();
|
|
|
}
|
|
|
|
|
|
void lazyInitDevice(c10::DeviceType device_type) {
|
|
|
if (device_type != at::kCPU) {
|
|
|
c10::call_once(init_[static_cast<int8_t>(device_type)], [&] {
|
|
|
getAcceleratorHooksInterface(device_type).init();
|
|
|
});
|
|
|
}
|
|
|
}
|
|
|
|
|
|
static bool hasOpenMP();
|
|
|
static bool hasMKL();
|
|
|
static bool hasKleidiAI();
|
|
|
static bool hasLAPACK();
|
|
|
static bool hasMKLDNN();
|
|
|
static bool hasMAGMA() {
|
|
|
return detail::getCUDAHooks().hasMAGMA();
|
|
|
}
|
|
|
static bool hasCUDA() {
|
|
|
return detail::getCUDAHooks().hasCUDA();
|
|
|
}
|
|
|
static bool hasMTIA() {
|
|
|
return detail::getMTIAHooks().hasMTIA();
|
|
|
}
|
|
|
static bool hasCUDART() {
|
|
|
return detail::getCUDAHooks().hasCUDART();
|
|
|
}
|
|
|
static long versionCUDART() {
|
|
|
return detail::getCUDAHooks().versionCUDART();
|
|
|
}
|
|
|
static bool hasCuDNN() {
|
|
|
return detail::getCUDAHooks().hasCuDNN();
|
|
|
}
|
|
|
static long versionCuDNN() {
|
|
|
return detail::getCUDAHooks().versionCuDNN();
|
|
|
}
|
|
|
static bool hasCuSOLVER() {
|
|
|
return detail::getCUDAHooks().hasCuSOLVER();
|
|
|
}
|
|
|
static bool hasCuBLASLt() {
|
|
|
return detail::getCUDAHooks().hasCuBLASLt();
|
|
|
}
|
|
|
static bool hasROCM() {
|
|
|
return detail::getCUDAHooks().hasROCM();
|
|
|
}
|
|
|
static bool hasHIP() {
|
|
|
return detail::getHIPHooks().hasHIP();
|
|
|
}
|
|
|
static bool hasMPS() {
|
|
|
return detail::getMPSHooks().hasMPS();
|
|
|
}
|
|
|
static bool hasIPU() {
|
|
|
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU);
|
|
|
}
|
|
|
static bool hasXLA() {
|
|
|
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA);
|
|
|
}
|
|
|
static bool hasXPU() {
|
|
|
return detail::getXPUHooks().hasXPU();
|
|
|
}
|
|
|
static bool hasLazy() {
|
|
|
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy);
|
|
|
}
|
|
|
static bool hasMAIA() {
|
|
|
return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA);
|
|
|
}
|
|
|
static bool hasHPU() {
|
|
|
return detail::getHPUHooks().hasHPU();
|
|
|
}
|
|
|
|
|
|
static const at::cuda::NVRTC& getNVRTC() {
|
|
|
return detail::getCUDAHooks().nvrtc();
|
|
|
}
|
|
|
|
|
|
static bool setFlushDenormal(bool on);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool userEnabledCuDNN() const;
|
|
|
void setUserEnabledCuDNN(bool e);
|
|
|
bool userEnabledMkldnn() const;
|
|
|
void setUserEnabledMkldnn(bool e);
|
|
|
bool benchmarkCuDNN() const;
|
|
|
void setBenchmarkCuDNN(bool);
|
|
|
int benchmarkLimitCuDNN() const;
|
|
|
void setBenchmarkLimitCuDNN(int);
|
|
|
bool deterministicCuDNN() const;
|
|
|
void setDeterministicCuDNN(bool);
|
|
|
bool deterministicMkldnn() const;
|
|
|
void setDeterministicMkldnn(bool);
|
|
|
bool userEnabledNNPACK() const;
|
|
|
void setUserEnabledNNPACK(bool e);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void setSDPPriorityOrder(const std::vector<int64_t>& order);
|
|
|
std::array<at::SDPBackend, at::num_sdp_backends> sDPPriorityOrder();
|
|
|
|
|
|
void setSDPUseFlash(bool);
|
|
|
bool userEnabledFlashSDP() const;
|
|
|
|
|
|
void setSDPUseMemEfficient(bool);
|
|
|
bool userEnabledMemEfficientSDP() const;
|
|
|
|
|
|
void setSDPUseMath(bool);
|
|
|
bool userEnabledMathSDP() const;
|
|
|
|
|
|
void setSDPUseCuDNN(bool);
|
|
|
bool userEnabledCuDNNSDP() const;
|
|
|
|
|
|
void setAllowFP16BF16ReductionMathSDP(bool);
|
|
|
bool allowFP16BF16ReductionMathSDP() const;
|
|
|
|
|
|
void setSDPUseOverrideable(bool);
|
|
|
bool userEnabledOverrideableSDP() const;
|
|
|
|
|
|
at::LinalgBackend linalgPreferredBackend() const;
|
|
|
void setLinalgPreferredBackend(at::LinalgBackend);
|
|
|
|
|
|
at::BlasBackend blasPreferredBackend();
|
|
|
void setBlasPreferredBackend(at::BlasBackend);
|
|
|
|
|
|
at::ROCmFABackend getROCmFAPreferredBackend() const;
|
|
|
void setROCmFAPreferredBackend(at::ROCmFABackend);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool deterministicAlgorithms() const;
|
|
|
bool deterministicAlgorithmsWarnOnly() const;
|
|
|
void setDeterministicAlgorithms(bool, bool);
|
|
|
bool deterministicFillUninitializedMemory() const;
|
|
|
void setDeterministicFillUninitializedMemory(bool);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
static void alertNotDeterministic(std::string_view const& caller);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
void alertCuBLASConfigNotDeterministic() const;
|
|
|
|
|
|
void setFloat32MatmulPrecision(const std::string& s);
|
|
|
bool allowTF32CuDNN() const;
|
|
|
void setAllowTF32CuDNN(bool);
|
|
|
bool allowTF32OneDNN() const;
|
|
|
void setAllowTF32OneDNN(bool);
|
|
|
bool allowTF32CuBLAS() const;
|
|
|
void setAllowTF32CuBLAS(bool);
|
|
|
Float32MatmulPrecision float32MatmulPrecision() const;
|
|
|
void setFloat32MatmulPrecision(Float32MatmulPrecision p);
|
|
|
bool allowFP16ReductionCuBLAS() const;
|
|
|
void setAllowFP16ReductionCuBLAS(bool);
|
|
|
bool allowBF16ReductionCuBLAS() const;
|
|
|
void setAllowBF16ReductionCuBLAS(bool);
|
|
|
bool allowFP16AccumulationCuBLAS() const;
|
|
|
void setAllowFP16AccumulationCuBLAS(bool);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::optional<int32_t> _SMCarveout_EXPERIMENTAL() const;
|
|
|
void _setSMCarveout_EXPERIMENTAL(std::optional<int32_t>);
|
|
|
|
|
|
at::QEngine qEngine() const;
|
|
|
void setQEngine(at::QEngine e);
|
|
|
static const std::vector<at::QEngine>& supportedQEngines();
|
|
|
static bool isXNNPACKAvailable();
|
|
|
void setCheckSparseTensorInvariants(bool e);
|
|
|
bool checkSparseTensorInvariants() const;
|
|
|
|
|
|
|
|
|
|
|
|
void setReleaseWeightsWhenPrepacking(bool e);
|
|
|
bool releaseWeightsWhenPrepacking() const;
|
|
|
|
|
|
void setDisplayVmapFallbackWarnings(bool enabled);
|
|
|
bool areVmapFallbackWarningsEnabled() const;
|
|
|
|
|
|
bool isDefaultMobileCPUAllocatorSet();
|
|
|
void setDefaultMobileCPUAllocator();
|
|
|
void unsetDefaultMobileCPUAllocator();
|
|
|
bool allowFP16ReductionCPU() const;
|
|
|
void setAllowFP16ReductionCPU(bool);
|
|
|
|
|
|
|
|
|
void lazyInitCUDA() {
|
|
|
TORCH_WARN_DEPRECATION(
|
|
|
"lazyInitCUDA is deprecated. Please use lazyInitDevice(at::kCUDA) instead.")
|
|
|
lazyInitDevice(at::kCUDA);
|
|
|
}
|
|
|
void lazyInitHIP() {
|
|
|
TORCH_WARN_DEPRECATION(
|
|
|
"lazyInitHIP is deprecated. Please use lazyInitDevice(at::kHIP) instead.")
|
|
|
lazyInitDevice(at::kHIP);
|
|
|
}
|
|
|
void lazyInitXPU() {
|
|
|
TORCH_WARN_DEPRECATION(
|
|
|
"lazyInitXPU is deprecated. Please use lazyInitDevice(at::kXPU) instead.")
|
|
|
lazyInitDevice(at::kXPU);
|
|
|
}
|
|
|
void lazyInitMTIA() {
|
|
|
TORCH_WARN_DEPRECATION(
|
|
|
"lazyInitMTIA is deprecated. Please use lazyInitDevice(at::kMTIA) instead.")
|
|
|
lazyInitDevice(at::kMTIA);
|
|
|
}
|
|
|
void lazyInitPrivateUse1() {
|
|
|
TORCH_WARN_DEPRECATION(
|
|
|
"lazyInitPrivateUse1 is deprecated. Please use lazyInitDevice(at::kPrivateUse1) instead.")
|
|
|
lazyInitDevice(at::kPrivateUse1);
|
|
|
}
|
|
|
|
|
|
private:
|
|
|
static bool checkCuBLASConfigDeterministic();
|
|
|
std::array<c10::once_flag, at::COMPILE_TIME_MAX_DEVICE_TYPES> init_;
|
|
|
bool enabled_cudnn = true;
|
|
|
bool deterministic_cudnn = false;
|
|
|
bool deterministic_mkldnn = false;
|
|
|
bool _deterministic_algorithms = false;
|
|
|
bool _deterministic_algorithms_warn_only = false;
|
|
|
bool _deterministic_fill_uninitialized_memory = true;
|
|
|
std::array<at::SDPBackend, at::num_sdp_backends> sdp_priority_order = {
|
|
|
at::SDPBackend::flash_attention,
|
|
|
at::SDPBackend::efficient_attention,
|
|
|
at::SDPBackend::math,
|
|
|
at::SDPBackend::cudnn_attention};
|
|
|
bool enabled_flashSDP = true;
|
|
|
bool enabled_mem_efficientSDP = true;
|
|
|
bool enabled_mathSDP = true;
|
|
|
bool enabled_cudnnSDP = true;
|
|
|
bool enabled_overrideable = true;
|
|
|
bool allow_fp16_bf16_reduction_mathSDP = false;
|
|
|
bool benchmark_cudnn = false;
|
|
|
Float32MatmulPrecision float32_matmul_precision =
|
|
|
c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true
|
|
|
? at::Float32MatmulPrecision::HIGH
|
|
|
: at::Float32MatmulPrecision::HIGHEST;
|
|
|
int benchmark_limit_cudnn = 10;
|
|
|
bool allow_tf32_cudnn = true;
|
|
|
bool allow_fp16_reduction_cublas = true;
|
|
|
bool allow_bf16_reduction_cublas = true;
|
|
|
bool allow_fp16_accumulation_cublas = false;
|
|
|
std::optional<int32_t> sm_carveout = std::nullopt;
|
|
|
bool enabled_mkldnn = true;
|
|
|
bool allow_tf32_onednn = false;
|
|
|
bool enabled_nnpack = true;
|
|
|
at::LinalgBackend linalg_preferred_backend =
|
|
|
(c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true ||
|
|
|
c10::utils::check_env("TORCH_LINALG_PREFER_HIPSOLVER") == true)
|
|
|
? at::LinalgBackend::Cusolver
|
|
|
: at::LinalgBackend::Default;
|
|
|
at::BlasBackend blas_preferred_backend =
|
|
|
(c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true ||
|
|
|
c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true)
|
|
|
? at::BlasBackend::Cublaslt
|
|
|
: at::BlasBackend::Default;
|
|
|
at::ROCmFABackend rocm_fa_preferred_backend =
|
|
|
c10::utils::check_env("TORCH_ROCM_FA_PREFER_CK") == true
|
|
|
? at::ROCmFABackend::Ck
|
|
|
: at::ROCmFABackend::Default;
|
|
|
#ifdef C10_MOBILE
|
|
|
bool release_original_weights = true;
|
|
|
#else
|
|
|
bool release_original_weights = false;
|
|
|
#endif
|
|
|
bool display_vmap_fallback_warnings_ = false;
|
|
|
std::optional<at::QEngine> quantized_engine = std::nullopt;
|
|
|
bool enable_sparse_tensor_invariant_checks = false;
|
|
|
bool allow_fp16_reduction_cpu = false;
|
|
|
|
|
|
Allocator* prev_allocator_ptr_{nullptr};
|
|
|
};
|
|
|
|
|
|
// Returns the process-wide singleton Context instance.
TORCH_API Context& globalContext();
|
|
|
|
|
|
// Forces construction of the global Context (and any side effects of its
// constructor); the returned reference is intentionally discarded.
inline void init() {
  globalContext();
}
|
|
|
|
|
|
// Returns the allocator used for CPU tensor storage.
TORCH_API Allocator* getCPUAllocator();
|
|
|
|
|
|
// Looks up the legacy DeprecatedTypeProperties for a (backend, scalar type)
// pair in the global registry.
inline DeprecatedTypeProperties& getDeprecatedTypeProperties(
    Backend p,
    ScalarType s) {
  auto& registry = globalDeprecatedTypePropertiesRegistry();
  return registry.getDeprecatedTypeProperties(p, s);
}
|
|
|
|
|
|
// Legacy type-properties accessor pinned to Backend::CPU.
inline DeprecatedTypeProperties& CPU(ScalarType s) {
  auto& registry = globalDeprecatedTypePropertiesRegistry();
  return registry.getDeprecatedTypeProperties(Backend::CPU, s);
}
|
|
|
|
|
|
// Legacy type-properties accessor pinned to Backend::CUDA.
inline DeprecatedTypeProperties& CUDA(ScalarType s) {
  auto& registry = globalDeprecatedTypePropertiesRegistry();
  return registry.getDeprecatedTypeProperties(Backend::CUDA, s);
}
|
|
|
|
|
|
// Legacy type-properties accessor pinned to Backend::HIP.
inline DeprecatedTypeProperties& HIP(ScalarType s) {
  auto& registry = globalDeprecatedTypePropertiesRegistry();
  return registry.getDeprecatedTypeProperties(Backend::HIP, s);
}
|
|
|
|
|
|
// Legacy type-properties accessor pinned to Backend::MPS.
inline DeprecatedTypeProperties& MPS(ScalarType s) {
  auto& registry = globalDeprecatedTypePropertiesRegistry();
  return registry.getDeprecatedTypeProperties(Backend::MPS, s);
}
|
|
|
|
|
|
inline bool hasCUDA() {
|
|
|
return globalContext().hasCUDA();
|
|
|
}
|
|
|
|
|
|
inline bool hasMTIA() {
|
|
|
return globalContext().hasMTIA();
|
|
|
}
|
|
|
|
|
|
inline bool hasHIP() {
|
|
|
return globalContext().hasHIP();
|
|
|
}
|
|
|
|
|
|
inline bool hasIPU() {
|
|
|
return globalContext().hasIPU();
|
|
|
}
|
|
|
|
|
|
inline bool hasXLA() {
|
|
|
return globalContext().hasXLA();
|
|
|
}
|
|
|
|
|
|
inline bool hasMPS() {
|
|
|
return globalContext().hasMPS();
|
|
|
}
|
|
|
|
|
|
inline bool hasMAIA() {
|
|
|
return globalContext().hasMAIA();
|
|
|
}
|
|
|
|
|
|
inline bool hasXPU() {
|
|
|
return globalContext().hasXPU();
|
|
|
}
|
|
|
|
|
|
inline bool hasHPU() {
|
|
|
return globalContext().hasHPU();
|
|
|
}
|
|
|
|
|
|
|
|
|
// Number of GPU-like devices: CUDA device count on CUDA builds, HIP device
// count on HIP builds, 0 otherwise. A build claiming both is rejected,
// since HIP masquerades as CUDA in ATen.
inline size_t getNumGPUs() {
  const bool has_cuda = hasCUDA();
  const bool has_hip = hasHIP();
  TORCH_CHECK(
      !(has_cuda && has_hip),
      "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades "
      "to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually "
      "means HIP. Rebuild PyTorch with one or the other disabled.");
  if (has_cuda) {
    return detail::getCUDAHooks().deviceCount();
  }
  if (has_hip) {
    return detail::getHIPHooks().getNumGPUs();
  }
  return 0;
}
|
|
|
|
|
|
inline bool hasOpenMP() {
|
|
|
return globalContext().hasOpenMP();
|
|
|
}
|
|
|
|
|
|
inline bool hasMKL() {
|
|
|
return globalContext().hasMKL();
|
|
|
}
|
|
|
|
|
|
inline bool hasKleidiAI() {
|
|
|
return globalContext().hasKleidiAI();
|
|
|
}
|
|
|
|
|
|
inline bool hasLAPACK() {
|
|
|
return globalContext().hasLAPACK();
|
|
|
}
|
|
|
|
|
|
inline bool hasMAGMA() {
|
|
|
return globalContext().hasMAGMA();
|
|
|
}
|
|
|
|
|
|
inline bool hasMKLDNN() {
|
|
|
return globalContext().hasMKLDNN();
|
|
|
}
|
|
|
|
|
|
// Seeds the default CPU generator with `seed`, then — if an accelerator is
// available — seeds the default generator of every visible accelerator
// device with the same value. Each generator's mutex is held only while its
// seed is set.
inline void manual_seed(uint64_t seed) {
  {
    auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU);
    // Lock the generator while mutating its state.
    std::lock_guard<std::mutex> lock(gen.mutex());
    gen.set_current_seed(seed);
  }

  const auto opt_device_type = at::getAccelerator();
  if (!opt_device_type.has_value()) {
    // No accelerator present: only the CPU generator is reseeded.
    return;
  }
  const auto num_gpus = globalContext()
                            .getAcceleratorHooksInterface(opt_device_type)
                            .deviceCount();
  for (const auto i : c10::irange(num_gpus)) {
    auto gen = globalContext().defaultGenerator(
        Device(opt_device_type.value(), static_cast<c10::DeviceIndex>(i)));
    {
      // Lock each device generator while setting its seed.
      std::lock_guard<std::mutex> lock(gen.mutex());
      gen.set_current_seed(seed);
    }
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Scoped guard related to disabling TF32: while a guard is alive,
// should_disable_tf32() is presumably true (constructor/destructor are
// defined out of line — confirm exact semantics in Context.cpp).
// Non-copyable and non-movable; `changed` records whether this particular
// instance flipped the state, so nested guards unwind correctly.
struct TORCH_API NoTF32Guard {
  NoTF32Guard();
  NoTF32Guard(NoTF32Guard&& other) = delete;
  NoTF32Guard(const NoTF32Guard&) = delete;
  NoTF32Guard& operator=(const NoTF32Guard&) = delete;
  NoTF32Guard& operator=(NoTF32Guard&&) = delete;
  ~NoTF32Guard();
  // Queries whether TF32 should currently be disabled.
  static bool should_disable_tf32();

 private:
  bool changed = false;
};
|
|
|
|
|
|
// Scoped marker for a ROCm backward pass: while a guard is alive,
// is_backward_pass() is presumably true (constructor/destructor are defined
// out of line — confirm scope, e.g. thread-locality, in Context.cpp).
// Non-copyable and non-movable.
struct TORCH_API ROCmBackwardPassGuard {
  ROCmBackwardPassGuard();
  ROCmBackwardPassGuard(ROCmBackwardPassGuard&& other) = delete;
  ROCmBackwardPassGuard(const ROCmBackwardPassGuard&) = delete;
  ROCmBackwardPassGuard& operator=(const ROCmBackwardPassGuard&) = delete;
  ROCmBackwardPassGuard& operator=(ROCmBackwardPassGuard&&) = delete;
  ~ROCmBackwardPassGuard();
  // Queries whether a backward pass is currently in flight.
  static bool is_backward_pass();
};
|
|
|
|
|
|
}
|
|
|
|