|
|
#pragma once |
|
|
|
|
|
#include <ATen/core/Array.h> |
|
|
#include <ATen/core/TransformationHelper.h> |
|
|
#include <c10/util/Half.h> |
|
|
#include <c10/util/BFloat16.h> |
|
|
#include <c10/util/MathConstants.h> |
|
|
#include <c10/util/Optional.h> |
|
|
#include <c10/macros/Macros.h> |
|
|
|
|
|
#include <type_traits> |
|
|
#include <limits> |
|
|
#include <cmath> |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
namespace at { |
|
|
namespace { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Integer-valued uniform sampler over a range of `range` values starting at
 * `base` (presumably the half-open interval [base, base + range); the actual
 * bits-to-range mapping is done by transformation::uniform_int_from_to).
 *
 * Bits consumed per sample:
 *   - one 64-bit draw (random64()) when T is int64_t, double, float or
 *     at::BFloat16 AND range >= 2^32;
 *   - one 32-bit draw (random()) in every other case.
 */
template <typename T>
struct uniform_int_from_to_distribution {

  C10_HOST_DEVICE inline uniform_int_from_to_distribution(uint64_t range, int64_t base)
      : range_(range), base_(base) {}

  template <typename RNG>
  C10_HOST_DEVICE inline T operator()(RNG generator) {
    // Sample types for which a wide (>= 2^32) range switches to random64().
    constexpr bool wide_capable_type =
        std::is_same<T, int64_t>::value ||
        std::is_same<T, double>::value ||
        std::is_same<T, float>::value ||
        std::is_same<T, at::BFloat16>::value;
    const bool range_needs_64_bits = range_ >= (1ULL << 32);

    if (!(wide_capable_type && range_needs_64_bits)) {
      return transformation::uniform_int_from_to<T>(generator->random(), range_, base_);
    }
    return transformation::uniform_int_from_to<T>(generator->random64(), range_, base_);
  }

 private:
  uint64_t range_;  // number of distinct values that can be produced
  int64_t base_;    // offset added to the sampled value
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Uniform sampler intended to cover the whole value range of T (per the name;
 * the mapping itself lives in transformation::uniform_int_full_range).
 *
 * Each call consumes exactly one 64-bit draw (random64()), whatever T is.
 */
template <typename T>
struct uniform_int_full_range_distribution {

  template <typename RNG>
  C10_HOST_DEVICE inline T operator()(RNG generator) {
    const auto raw_bits = generator->random64();
    return transformation::uniform_int_full_range<T>(raw_bits);
  }

};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Uniform integer-valued sampler with no explicit bounds; the output range is
 * decided by transformation::uniform_int<T> (not visible here).
 *
 * double and int64_t consume one 64-bit draw (random64()); every other T
 * consumes one 32-bit draw (random()).
 */
template <typename T>
struct uniform_int_distribution {

  template <typename RNG>
  C10_HOST_DEVICE inline T operator()(RNG generator) {
    const bool wants_64_bits =
        std::is_same<T, int64_t>::value || std::is_same<T, double>::value;
    if (!wants_64_bits) {
      return transformation::uniform_int<T>(generator->random());
    }
    return transformation::uniform_int<T>(generator->random64());
  }

};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Real-valued uniform sampler between `from` and `to` (presumably [from, to);
 * the mapping is done by transformation::uniform_real).
 *
 * The result type is dist_acctype<T> rather than T. double consumes one 64-bit
 * draw (random64()); every other T consumes one 32-bit draw (random()).
 */
template <typename T>
struct uniform_real_distribution {

  // Validates that from <= to and that the width (to - from) does not exceed
  // the largest value of T. The checks use TORCH_CHECK_IF_NOT_ON_CUDA, which by
  // its name is presumably inactive in CUDA device code.
  C10_HOST_DEVICE inline uniform_real_distribution(T from, T to)
      : from_(from), to_(to) {
    TORCH_CHECK_IF_NOT_ON_CUDA(from <= to);
    TORCH_CHECK_IF_NOT_ON_CUDA(to - from <= std::numeric_limits<T>::max());
  }

  template <typename RNG>
  C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
    const bool is_double = std::is_same<T, double>::value;
    if (!is_double) {
      return transformation::uniform_real<T>(generator->random(), from_, to_);
    }
    return transformation::uniform_real<T>(generator->random64(), from_, to_);
  }

 private:
  T from_;  // lower bound of the sampling interval
  T to_;    // upper bound of the sampling interval
};
|
|
|
|
|
|
|
|
|
|
|
// Generates a detection trait `has_member_<member><T>`. Its static `value` is
// true iff the expression `&T::<member>` is well-formed. In practice that means
// T has an accessible member of that name that is not overloaded (taking the
// address of an overload set without a target type is ill-formed, so it is
// reported as absent).
//
// Mechanism: overload resolution of `test<T>(0)` prefers the overload taking a
// `decltype(&U::member)` (0 converts to a null member pointer). When that
// expression is ill-formed, SFINAE removes the overload and the variadic
// fallback is chosen. The two overloads return types of different size, which
// `sizeof` distinguishes at compile time.
//
// Every continuation line must end in a backslash with no blank line in
// between; a blank line terminates the #define.
#define DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(member)              \
template <typename T>                                                \
struct has_member_##member                                           \
{                                                                    \
    typedef char yes;                                                \
    typedef long no;                                                 \
    template <typename U> static yes test(decltype(&U::member));     \
    template <typename U> static no test(...);                       \
    static constexpr bool value = sizeof(test<T>(0)) == sizeof(yes); \
}

// Traits detecting whether a generator type exposes the accessors of its
// cached normal sample (double and float variants). They drive the enable_if
// selection in the maybe_get/maybe_set_next_*_normal_sample helpers.
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_double_normal_sample);
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_double_normal_sample);
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_float_normal_sample);
DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_float_normal_sample);
|
|
|
|
|
// For a cache element type TYPE (instantiated below for double and float),
// generates two overload pairs. Each pair uses the generator's cached normal
// sample when the generator type supports it and degrades to a no-op otherwise.
//
//   bool maybe_get_next_<TYPE>_normal_sample(RNG* generator, ret_type* ret)
//     Enabled when RNG has BOTH next_<TYPE>_normal_sample and
//     set_next_<TYPE>_normal_sample. If the cached value is present (the
//     getter's result is truthy), it:
//       - copies that value into *ret,
//       - resets the cache to an empty c10::optional<TYPE>,
//       - returns true.
//     Otherwise it returns false and leaves *ret untouched.
//     When RNG lacks either accessor, the fallback overload just returns false.
//
//   void maybe_set_next_<TYPE>_normal_sample(RNG* generator, ret_type cache)
//     Enabled when RNG has set_next_<TYPE>_normal_sample: stores `cache`
//     there. Otherwise the fallback overload does nothing.
//
// The overloads are selected with std::enable_if_t on the has_member_* traits.
// As a result, generator types without these accessors still compile.
//
// Every continuation line must end in a backslash with no blank line in
// between; a blank line terminates the #define.
#define DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(TYPE)                                      \
                                                                                                    \
template <typename RNG, typename ret_type,                                                          \
          typename std::enable_if_t<(                                                               \
            has_member_next_##TYPE##_normal_sample<RNG>::value &&                                   \
            has_member_set_next_##TYPE##_normal_sample<RNG>::value                                  \
          ), int> = 0>                                                                              \
C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* generator, ret_type* ret) {  \
  if (generator->next_##TYPE##_normal_sample()) {                                                   \
    *ret = *(generator->next_##TYPE##_normal_sample());                                             \
    generator->set_next_##TYPE##_normal_sample(c10::optional<TYPE>());                              \
    return true;                                                                                    \
  }                                                                                                 \
  return false;                                                                                     \
}                                                                                                   \
                                                                                                    \
template <typename RNG, typename ret_type,                                                          \
          typename std::enable_if_t<(                                                               \
            !has_member_next_##TYPE##_normal_sample<RNG>::value ||                                  \
            !has_member_set_next_##TYPE##_normal_sample<RNG>::value                                 \
          ), int> = 0>                                                                              \
C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* , ret_type* ) {              \
  return false;                                                                                     \
}                                                                                                   \
                                                                                                    \
template <typename RNG, typename ret_type,                                                          \
          typename std::enable_if_t<(                                                               \
            has_member_set_next_##TYPE##_normal_sample<RNG>::value                                  \
          ), int> = 0>                                                                              \
C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* generator, ret_type cache) { \
  generator->set_next_##TYPE##_normal_sample(cache);                                                \
}                                                                                                   \
                                                                                                    \
template <typename RNG, typename ret_type,                                                          \
          typename std::enable_if_t<(                                                               \
            !has_member_set_next_##TYPE##_normal_sample<RNG>::value                                 \
          ), int> = 0>                                                                              \
C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* , ret_type ) {               \
}

// Instantiate the double-cache and float-cache helper overloads.
DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(double);
DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(float);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Normal sampler with mean `mean` and standard deviation `stdv`, built on the
 * Box-Muller transform.
 *
 * Each fresh step takes two uniform draws u1, u2 in (0, 1) and computes
 *   r = sqrt(-2 * log(1 - u2)),  theta = 2 * pi * u1.
 * This yields two independent standard-normal values: r * cos(theta) and
 * r * sin(theta).
 *   - The cosine value is returned.
 *   - The sine value is offered to the generator's normal-sample cache via
 *     maybe_set_next_*_normal_sample. The double cache is used when
 *     T == double; the float cache is used for every other T.
 *
 * On a later call, a cached value (if the generator type supports caching and
 * holds one) is consumed instead of drawing new uniforms. For generator types
 * without the cache accessors, those helpers are no-ops and every call draws
 * two uniforms.
 *
 * The standard-normal value is scaled by transformation::normal(value, mean,
 * stdv) (presumably value * stdv + mean). The result type is dist_acctype<T>.
 */
template <typename T>
struct normal_distribution {

  // Requires stdv_in >= 0. The check uses TORCH_CHECK_IF_NOT_ON_CUDA, which by
  // its name is presumably inactive in CUDA device code.
  C10_HOST_DEVICE inline normal_distribution(T mean_in, T stdv_in) {
    // The condition admits stdv_in == 0, so the message says "non-negative",
    // not "positive".
    TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in >= 0, "stdv_in must be non-negative: ", stdv_in);
    mean = mean_in;
    stdv = stdv_in;
  }

  template <typename RNG>
  C10_HOST_DEVICE inline dist_acctype<T> operator()(RNG generator){
    dist_acctype<T> ret;
    // Consume the cached second Box-Muller sample, if the generator has one.
    if (std::is_same<T, double>::value) {
      if (maybe_get_next_double_normal_sample(generator, &ret)) {
        return transformation::normal(ret, mean, stdv);
      }
    } else {
      if (maybe_get_next_float_normal_sample(generator, &ret)) {
        return transformation::normal(ret, mean, stdv);
      }
    }
    // Otherwise draw two uniforms and generate a fresh pair of normals.
    // The draw order (u1, then u2) fixes which RNG outputs feed which term.
    uniform_real_distribution<T> uniform(0.0, 1.0);
    const dist_acctype<T> u1 = uniform(generator);
    const dist_acctype<T> u2 = uniform(generator);
    // Uses 1 - u2 rather than u2: assuming uniform samples lie in [0, 1), the
    // log argument lies in (0, 1] and log(0) is never evaluated.
    const dist_acctype<T> r = ::sqrt(static_cast<T>(-2.0) * ::log(static_cast<T>(1.0)-u2));
    const dist_acctype<T> theta = static_cast<T>(2.0) * c10::pi<T> * u1;
    // Stash the sine sample for the next call (no-op if the RNG cannot cache).
    if (std::is_same<T, double>::value) {
      maybe_set_next_double_normal_sample(generator, r * ::sin(theta));
    } else {
      maybe_set_next_float_normal_sample(generator, r * ::sin(theta));
    }
    ret = r * ::cos(theta);
    return transformation::normal(ret, mean, stdv);
  }

 private:
  T mean;  // location parameter
  T stdv;  // standard deviation (>= 0)
};
|
|
|
|
|
// Type trait mapping a sample type T to a floating-point `type`:
// double maps to double; every other T (including float, Half, integers)
// maps to float. Its consumers are outside this header section.
template <typename T>
struct DiscreteDistributionType {
  typedef float type;
};

template <>
struct DiscreteDistributionType<double> {
  typedef double type;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Bernoulli sampler with parameter `p`.
 *
 * Each call takes one uniform draw between 0 and 1. That draw and `p` are then
 * turned into the result by transformation::bernoulli<T> (not visible here).
 */
template <typename T>
struct bernoulli_distribution {

  // Requires 0 <= p_in <= 1. The check uses TORCH_CHECK_IF_NOT_ON_CUDA, which
  // by its name is presumably inactive in CUDA device code.
  C10_HOST_DEVICE inline bernoulli_distribution(T p_in) : p(p_in) {
    TORCH_CHECK_IF_NOT_ON_CUDA(p_in >= 0 && p_in <= 1);
  }

  template <typename RNG>
  C10_HOST_DEVICE inline T operator()(RNG generator) {
    uniform_real_distribution<T> unit_uniform(0.0, 1.0);
    const auto u = unit_uniform(generator);
    return transformation::bernoulli<T>(u, p);
  }

 private:
  T p;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Geometric sampler parameterized by `p`.
 *
 * Each call takes one uniform draw between 0 and 1. That draw and `p` are then
 * turned into the result by transformation::geometric<T> (not visible here).
 */
template <typename T>
struct geometric_distribution {

  // Requires 0 < p_in < 1 (both bounds strict). The check uses
  // TORCH_CHECK_IF_NOT_ON_CUDA, presumably inactive in CUDA device code.
  C10_HOST_DEVICE inline geometric_distribution(T p_in) : p(p_in) {
    TORCH_CHECK_IF_NOT_ON_CUDA(p_in > 0 && p_in < 1);
  }

  template <typename RNG>
  C10_HOST_DEVICE inline T operator()(RNG generator) {
    uniform_real_distribution<T> unit_uniform(0.0, 1.0);
    const auto u = unit_uniform(generator);
    return transformation::geometric<T>(u, p);
  }

 private:
  T p;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Exponential sampler with parameter `lambda` (presumably the rate; the
 * formula lives in transformation::exponential<T>).
 *
 * Each call takes one uniform draw between 0 and 1. No validation of `lambda`
 * is performed here.
 */
template <typename T>
struct exponential_distribution {

  C10_HOST_DEVICE inline exponential_distribution(T lambda_in) : lambda(lambda_in) {}

  template <typename RNG>
  C10_HOST_DEVICE inline T operator()(RNG generator) {
    uniform_real_distribution<T> unit_uniform(0.0, 1.0);
    const auto u = unit_uniform(generator);
    return transformation::exponential<T>(u, lambda);
  }

 private:
  T lambda;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Cauchy sampler with location `median` and scale `sigma` (parameter roles
 * per their names; the formula lives in transformation::cauchy<T>).
 *
 * Each call takes one uniform draw between 0 and 1. Neither parameter is
 * validated here.
 */
template <typename T>
struct cauchy_distribution {

  C10_HOST_DEVICE inline cauchy_distribution(T median_in, T sigma_in)
      : median(median_in), sigma(sigma_in) {}

  template <typename RNG>
  C10_HOST_DEVICE inline T operator()(RNG generator) {
    uniform_real_distribution<T> unit_uniform(0.0, 1.0);
    const auto u = unit_uniform(generator);
    return transformation::cauchy<T>(u, median, sigma);
  }

 private:
  T median;
  T sigma;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Log-normal sampler.
 *
 * Each call draws one sample from normal_distribution<T>(mean, stdv) and
 * passes it through transformation::log_normal<T> (presumably exp). `mean` and
 * `stdv` are therefore the parameters of the underlying normal, not of the
 * log-normal result.
 */
template <typename T>
struct lognormal_distribution {

  // Requires stdv_in > 0 (strict). The check uses TORCH_CHECK_IF_NOT_ON_CUDA,
  // which by its name is presumably inactive in CUDA device code.
  C10_HOST_DEVICE inline lognormal_distribution(T mean_in, T stdv_in)
      : mean(mean_in), stdv(stdv_in) {
    TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in > 0);
  }

  template<typename RNG>
  C10_HOST_DEVICE inline T operator()(RNG generator){
    normal_distribution<T> underlying_normal(mean, stdv);
    const auto normal_sample = underlying_normal(generator);
    return transformation::log_normal<T>(normal_sample);
  }

 private:
  T mean;
  T stdv;
};
|
|
} |
|
|
} |
|
|
|