| |
| |
|
|
| #pragma once |
|
|
| #include "ck/ck.hpp" |
|
|
namespace ck {

namespace detail {

// Shared final stage of the prand_generator overloads below: scrambles the bits
// derived from the element value, then mixes in the element id and the seed.
//
// drop_bits: value-derived bits; both callers pass a value that fits in 16 bits.
// id_bits:   the element id reduced to 32 bits.
// seed:      user seed, XOR-ed into the result.
//
// The id is multiplied as an unsigned 32-bit value on purpose: the former
// expression `id * 229791` multiplied two signed ints, which is undefined
// behaviour once id exceeds INT32_MAX / 229791 (~9345). Unsigned arithmetic
// wraps modulo 2^32 and gives the same bit pattern a wrapping multiply did.
__host__ __device__ inline uint32_t prand_mix(uint32_t drop_bits, uint32_t id_bits, uint32_t seed)
{
    // 16-bit rotate right by 5: the low 5 bits move to bits 11..15.
    drop_bits = ((drop_bits & 31u) << 11) | (drop_bits >> 5);
    drop_bits *= 0x7000149u;
    return drop_bits ^ 0x13371337u ^ (id_bits * 229791u) ^ seed;
}

} // namespace detail

// Pseudo-random number generator, fp32 version.
//
// Returns a 32-bit pseudo-random value derived from the element id, the bit
// pattern of the element value `val`, and `seed` (defaults to the template
// argument `seed_t`). The result is deterministic for identical inputs.
// NOTE(review): presumably consumed by stochastic-rounding conversions; confirm
// against callers.
//
// NOTE: only the low 32 bits of `id` take part in the hash. If index_t is wider
// than 32 bits, very large ids can alias smaller ones.
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<float, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
    static_assert(sizeof(T) == sizeof(uint32_t), "fp32 prand_generator expects a 32-bit float");

    // Copy the object representation; dereferencing a reinterpret_cast-ed
    // pointer to an unrelated type violates strict aliasing.
    uint32_t bits = 0;
    __builtin_memcpy(&bits, &val, sizeof(bits));

    // Fold the 32-bit pattern into 16 bits (low half XOR high half).
    const uint32_t drop_bits = (bits & 0xFFFFu) ^ (bits >> 16);

    return detail::prand_mix(drop_bits, static_cast<uint32_t>(id), seed);
}

// Pseudo-random number generator, fp16 (_Float16) version.
//
// Same contract as the fp32 overload. The 16-bit value pattern is used directly,
// with no folding step. Only the low 32 bits of `id` are used.
template <typename T, uint32_t seed_t, std::enable_if_t<std::is_same<_Float16, T>{}, bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
    static_assert(sizeof(T) == sizeof(uint16_t), "fp16 prand_generator expects a 16-bit float");

    // Copy the object representation (see the fp32 overload for why not reinterpret_cast).
    uint16_t bits = 0;
    __builtin_memcpy(&bits, &val, sizeof(bits));

    return detail::prand_mix(static_cast<uint32_t>(bits), static_cast<uint32_t>(id), seed);
}

// Fallback for any T other than float or _Float16: always returns 0.
// All arguments are accepted and ignored so call sites stay uniform across types.
template <typename T,
          uint32_t seed_t,
          std::enable_if_t<!(std::is_same<float, T>{} || std::is_same<_Float16, T>{}), bool> = false>
__host__ __device__ uint32_t prand_generator(index_t id, T val, uint32_t seed = seed_t)
{
    std::ignore = id;
    std::ignore = val;
    std::ignore = seed;

    return 0;
}

} // namespace ck
|
|