|
|
#pragma once |
|
|
|
|
|
|
|
|
#ifdef _MSC_VER |
|
|
#define _USE_MATH_DEFINES |
|
|
#include <math.h> |
|
|
#endif |
|
|
|
|
|
#include <stdint.h> |
|
|
|
|
|
#ifdef __CUDACC__ |
|
|
#include <cuda.h> |
|
|
#endif |
|
|
|
|
|
#include <ATen/core/Array.h> |
|
|
#include <c10/macros/Macros.h> |
|
|
#include <c10/util/Exception.h> |
|
|
#include <c10/util/Half.h> |
|
|
#include <cmath> |
|
|
|
|
|
namespace at { |
|
|
|
|
|
|
|
|
namespace detail { |
|
|
|
|
|
typedef at::detail::Array<uint32_t, 4> UINT4; |
|
|
typedef at::detail::Array<uint32_t, 2> UINT2; |
|
|
typedef at::detail::Array<double, 2> DOUBLE2; |
|
|
typedef at::detail::Array<float, 2> FLOAT2; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class philox_engine { |
|
|
public: |
|
|
|
|
|
C10_HOST_DEVICE inline explicit philox_engine(uint64_t seed = 67280421310721, |
|
|
uint64_t subsequence = 0, |
|
|
uint64_t offset = 0) { |
|
|
|
|
|
reset_state(seed, subsequence); |
|
|
incr_n(offset); |
|
|
} |
|
|
|
|
|
C10_HOST_DEVICE inline void reset_state(uint64_t seed = 67280421310721, |
|
|
uint64_t subsequence = 0) { |
|
|
key_[0] = static_cast<uint32_t>(seed); |
|
|
key_[1] = static_cast<uint32_t>(seed >> 32); |
|
|
counter_ = detail::UINT4(0); |
|
|
counter_[2] = static_cast<uint32_t>(subsequence); |
|
|
counter_[3] = static_cast<uint32_t>(subsequence >> 32); |
|
|
STATE = 0; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
C10_HOST_DEVICE inline uint32_t operator()(int32_t n_rounds = 10) { |
|
|
if(STATE == 0) { |
|
|
detail::UINT4 counter = counter_; |
|
|
detail::UINT2 key = key_; |
|
|
output_ = rand(counter, key, n_rounds); |
|
|
incr(); |
|
|
} |
|
|
uint32_t ret = output_[STATE]; |
|
|
STATE = (STATE + 1) & 3; |
|
|
return ret; |
|
|
} |
|
|
|
|
|
inline float randn(uint32_t n_rounds) { |
|
|
#ifdef __CUDA_ARCH__ |
|
|
AT_ASSERT(false, "Unsupported invocation of randn on CUDA"); |
|
|
#endif |
|
|
reset_state(); |
|
|
detail::UINT4 counter = counter_; |
|
|
detail::UINT2 key = key_; |
|
|
detail::UINT4 i = rand(counter, key, n_rounds); |
|
|
detail::FLOAT2 prenorm; |
|
|
prenorm[0] = 1 - uint32_to_uniform_float(i[0]); |
|
|
prenorm[1] = 1 - uint32_to_uniform_float(i[1]); |
|
|
detail::FLOAT2 ret = normalize_pair_uniform(prenorm); |
|
|
return ret[0]; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
C10_HOST_DEVICE inline void incr_n(uint64_t n) { |
|
|
uint32_t nlo = static_cast<uint32_t>(n); |
|
|
uint32_t nhi = static_cast<uint32_t>(n >> 32); |
|
|
counter_[0] += nlo; |
|
|
|
|
|
if (counter_[0] < nlo) { |
|
|
nhi++; |
|
|
|
|
|
|
|
|
|
|
|
counter_[1] += nhi; |
|
|
if(nhi != 0) { |
|
|
if (nhi <= counter_[1]) { |
|
|
return; |
|
|
} |
|
|
} |
|
|
} else { |
|
|
|
|
|
|
|
|
|
|
|
counter_[1] += nhi; |
|
|
if (nhi <= counter_[1]) { |
|
|
return; |
|
|
} |
|
|
} |
|
|
if (++counter_[2]) |
|
|
return; |
|
|
++counter_[3]; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
C10_HOST_DEVICE inline void incr() { |
|
|
if (++counter_[0]) |
|
|
return; |
|
|
if (++counter_[1]) |
|
|
return; |
|
|
if (++counter_[2]) { |
|
|
return; |
|
|
} |
|
|
++counter_[3]; |
|
|
} |
|
|
|
|
|
private: |
|
|
detail::UINT4 counter_; |
|
|
detail::UINT4 output_; |
|
|
detail::UINT2 key_; |
|
|
uint32_t STATE; |
|
|
|
|
|
C10_HOST_DEVICE inline uint32_t mulhilo32(uint32_t a, uint32_t b, |
|
|
uint32_t *result_high) { |
|
|
#ifdef __CUDA_ARCH__ |
|
|
*result_high = __umulhi(a, b); |
|
|
return a*b; |
|
|
#else |
|
|
const uint64_t product = static_cast<uint64_t>(a) * b; |
|
|
*result_high = static_cast<uint32_t>(product >> 32); |
|
|
return static_cast<uint32_t>(product); |
|
|
#endif |
|
|
} |
|
|
|
|
|
C10_HOST_DEVICE inline detail::UINT4 single_round(detail::UINT4 ctr, detail::UINT2 in_key) { |
|
|
uint32_t hi0; |
|
|
uint32_t hi1; |
|
|
uint32_t lo0 = mulhilo32(kPhiloxSA, ctr[0], &hi0); |
|
|
uint32_t lo1 = mulhilo32(kPhiloxSB, ctr[2], &hi1); |
|
|
detail::UINT4 ret; |
|
|
ret[0] = hi1 ^ ctr[1] ^ in_key[0]; |
|
|
ret[1] = lo1; |
|
|
ret[2] = hi0 ^ ctr[3] ^ in_key[1]; |
|
|
ret[3] = lo0; |
|
|
return ret; |
|
|
} |
|
|
|
|
|
C10_HOST_DEVICE constexpr float uint32_to_uniform_float(uint32_t value) { |
|
|
|
|
|
constexpr float scale = 4.6566127342e-10; |
|
|
return static_cast<float>(value & 0x7FFFFFFF) * scale; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
C10_HOST_DEVICE inline detail::UINT4 rand(detail::UINT4& counter, detail::UINT2& key, uint32_t n_rounds) { |
|
|
for (uint32_t round = 0; round < (n_rounds - 1); round++) { |
|
|
counter = single_round(counter, key); |
|
|
key[0] += (kPhilox10A); key[1] += (kPhilox10B); |
|
|
} |
|
|
return single_round(counter, key); |
|
|
} |
|
|
|
|
|
inline detail::FLOAT2 normalize_pair_uniform(detail::FLOAT2 in) { |
|
|
|
|
|
float u1 = in[0]; |
|
|
|
|
|
constexpr float two_pi = 2.0 * M_PI; |
|
|
|
|
|
float mag = std::sqrt(-2.0 * std::log(u1)); |
|
|
|
|
|
detail::FLOAT2 ret; |
|
|
|
|
|
ret[0] = mag * std::cos(two_pi); |
|
|
ret[1] = mag * std::sin(two_pi); |
|
|
return ret; |
|
|
} |
|
|
|
|
|
|
|
|
static const uint32_t kPhilox10A = 0x9E3779B9; |
|
|
static const uint32_t kPhilox10B = 0xBB67AE85; |
|
|
static const uint32_t kPhiloxSA = 0xD2511F53; |
|
|
static const uint32_t kPhiloxSB = 0xCD9E8D57; |
|
|
}; |
|
|
|
|
|
// Public alias for philox_engine (the 4 x 32-bit Philox variant).
using Philox4_32 = philox_engine;
|
|
|
|
|
} |
|
|
|