File size: 2,295 Bytes
d1d4335
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
// Philox counter-based RNG implementation for Metal.
// Borrowed from aten/src/ATen/core/PhiloxRNGEngine.h,
// which in turn is based on
// http://www.thesalmons.org/john/random123/papers/random123sc11.pdf
#pragma once
#include <metal_stdlib>

namespace c10 {
namespace metal {

namespace detail {

// Maps a 32-bit word to a float in [0, 1).
// Only the low 31 bits of `value` participate; the top bit is discarded.
constexpr float uint32_to_uniform_float(uint32_t value) {
  constexpr uint32_t low31_mask = 0x7FFFFFFF;
  // Largest float for which 0x7FFFFFFF * inv_two_pow_31 still rounds below
  // 1.0f (0x7FFFFFFF itself becomes 2^31 when converted to float).
  constexpr float inv_two_pow_31 = 4.6566127342e-10;
  const float truncated = static_cast<float>(value & low31_mask);
  return truncated * inv_two_pow_31;
}

// Splits a 64-bit value into two 32-bit words: .x = high word, .y = low word.
inline uint2 splitlong(ulong v) {
  const uint high_word = static_cast<uint>(v >> 32);
  const uint low_word = static_cast<uint>(v & 0xffffffff);
  return uint2(high_word, low_word);
}

} // namespace detail

namespace philox4 {

// Full 64-bit product of two 32-bit words.
// Returns (high 32 bits, low 32 bits) of `a * b` in (.x, .y).
// `inline` because this is a header definition: without it, every
// translation unit including this file emits its own external definition.
inline uint2 mulhilo(uint a, uint b) {
  // Widen before multiplying so the high half of the product is kept.
  const auto product = static_cast<ulong>(a) * b;
  return detail::splitlong(product);
}
// One Philox4x32 round (Random123 `_philox4x32round`):
//   out = { hi1 ^ ctr.y ^ key.x, lo1, hi0 ^ ctr.w ^ key.y, lo0 }
// where (hi0, lo0) = SA * ctr.x and (hi1, lo1) = SB * ctr.z.
// The high product words are the ones mixed with the counter and key;
// the low words pass through unchanged.
inline uint4 single_round(uint4 ctr, uint2 key) {
  constexpr uint kPhiloxSA = 0xD2511F53;
  constexpr uint kPhiloxSB = 0xCD9E8D57;
  // mulhilo returns (high, low) in (.x, .y).
  const uint2 rc0 = mulhilo(kPhiloxSA, ctr.x);
  const uint2 rc1 = mulhilo(kPhiloxSB, ctr.z);
  return uint4(rc1.x ^ ctr.y ^ key.x, rc1.y, rc0.x ^ ctr.w ^ key.y, rc0.y);
}

// Applies exactly `rounds` Philox rounds to `ctr`. After each round the key
// is bumped by the Philox key-schedule constants (PHILOX_W32_0/1 in Random123).
// The bump after the last round is unused, which matches the reference loop
// "rounds-1 × (round, bump)" followed by one final round.
// rounds == 0 returns `ctr` unchanged; the previous `rounds - 1` bound
// wrapped around for unsigned 0.
inline uint4 multiple_rounds(uint4 ctr, uint2 key, uint rounds) {
  constexpr uint2 kPhilox10 = {0x9E3779B9, 0xBB67AE85};
  for (uint round = 0; round < rounds; ++round) {
    ctr = single_round(ctr, key);
    key += kPhilox10;
  }
  return ctr;
}

// Produces four 32-bit Philox outputs for a (seed, index) pair.
// Counter: ctr.xy = 0, ctr.zw = splitlong(index) (z = high word, w = low word).
// Key:     splitlong(seed) (x = high word, y = low word).
// Passes a round count of 10 to multiple_rounds.
inline uint4 rand(long seed, long index) {
  uint4 ctr = 0;
  ctr.zw = detail::splitlong(index);
  return multiple_rounds(ctr, detail::splitlong(seed), 10);
}

} // namespace philox4

// Standard-normal sample for (seed, index) via the Box-Muller transform,
// using the first two Philox output words as the two uniform inputs.
// `inline` because this is a header definition.
inline float randn(long seed, long index) {
  auto value = philox4::rand(seed, index);
  // uint32_to_uniform_float yields [0, 1); 1 - u maps that to (0, 1],
  // so log(u1) never sees 0.
  float u1 = 1.0 - detail::uint32_to_uniform_float(value.x);
  float u2 = 1.0 - detail::uint32_to_uniform_float(value.y);
  return ::metal::sqrt(-2.0 * ::metal::log(u1)) *
      ::metal::cos(2.0 * M_PI_F * u2);
}

// Uniform float in [0, 1) for (seed, index), taken from the first Philox
// output word (only its low 31 bits are used).
// `inline` because this is a header definition.
inline float rand(long seed, long index) {
  auto value = philox4::rand(seed, index);
  return detail::uint32_to_uniform_float(value.x);
}

long randint64(long seed, long index, long low, long high) {
  auto range = high - low;
  auto value = philox4::rand(seed, index);
  // TODO: Implement better algorithm for large ranges
  return low +
      static_cast<long>(detail::uint32_to_uniform_float(value.x) * range);
}

} // namespace metal
} // namespace c10