File size: 6,262 Bytes
c1af2fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
#pragma once

#include <ATen/Context.h>
#include <ATen/core/Generator.h>
#include <ATen/core/TensorBase.h>
#include <ATen/cuda/PhiloxCudaState.h>
#include <atomic>
#include <memory>
#include <unordered_set>
namespace at {

namespace cuda {
// Forward declaration only: CUDAGraph pointers are stored and exchanged by
// the generator state in this header (register_graph / unregister_graph),
// but the full definition is not needed here.
struct CUDAGraph;
}

/**
 * Note [CUDA Graph-safe RNG states]
 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 *
 * Strategy:
 * ~~~~~~~~~
 * (It helps to look at
 * cuda/detail/PhiloxCudaStateRaw.cuh and
 * cuda/detail/UnpackRaw.cuh
 * while you read this.)
 *
 * A CUDA graph containing multiple RNG ops behaves like a
 * single giant kernel from the perspective of ops external
 * to the graph.  During graph capture, logic in CUDAGeneratorImpl
 * records the total of all offset increments that occur in the
 * graphed region, and records the final total as the offset for
 * the entire graph.
 *
 * When the graph reruns, the logic that reruns it
 * increments this device's CUDA generator's offset
 * by that total.
 *
 * Meanwhile, within the graph, at capture time, instead of
 * populating PhiloxCudaStates with the uint64_t offset pulled
 * directly from the global state, PhiloxCudaState uses a pointer
 * to a one-element stream-local int64_t device tensor
 * holding an initial offset value, and a uint64_t holding an
 * intra-graph offset. (The intra-graph offset starts from zero
 * when capture begins.)  In each consumer kernel,
 * at::cuda::philox::unpack computes the offset to use for this kernel
 * as intra-graph offset + *initial offset.
 *
 * When the graph reruns, the logic that reruns it first
 * fill_s the initial offset tensor with this device's
 * CUDA generator's current offset.
 *
 * The control flow above ensures graphed execution is bitwise
 * identical to eager execution as long as RNG ops are enqueued
 * from a single thread, even if RNG ops and graphs containing
 * RNG ops are enqueued and run simultaneously on multiple streams.
 *
 * Usage:
 * ~~~~~~
 * PhiloxCudaState in this file, and unpack() in
 * cuda/CUDAGraphsUtils.cuh allow non-divergent use of
 * CUDAGeneratorImpl whether graph capture is underway or not.
 *
 * Each PhiloxCudaState instance should be used for one and only one
 * consumer kernel.
 *
 * Example (see e.g. native/cuda/Dropout.cu):
 *
 * #include <ATen/cuda/CUDAGeneratorImpl.h>
 * #include <ATen/cuda/CUDAGraphsUtils.cuh>
 *
 * __global__ void kernel(..., PhiloxCudaState philox_args) {
 *   auto seeds = at::cuda::philox::unpack(philox_args);
 *   IndexType idx = blockIdx.x * blockDim.x + threadIdx.x;
 *   curandStatePhilox4_32_10_t state;
 *   curand_init(std::get<0>(seeds), // seed
 *               idx,                // per-thread subsequence
 *               std::get<1>(seeds), // offset in subsequence
 *               &state);
 *   ...
 * }
 *
 * host_caller(...) {
 *   PhiloxCudaState rng_engine_inputs;
 *   {
 *     // See Note [Acquire lock when using random generators]
 *     std::lock_guard<std::mutex> lock(gen->mutex_);
 *
 *     // gen could be HostState or DevState here! No divergent code needed!
 *     rng_engine_inputs = gen->philox_cuda_state(offset_increment);
 *   }
 *   kernel<<<...>>>(..., rng_engine_inputs);
 * }
 *
 */

// Shared Philox RNG state for one CUDA generator. Reference-counted
// (intrusive_ptr_target) so a CUDAGeneratorImpl and the CUDA graphs
// registered with it can share a single state object.
// See Note [CUDA Graph-safe RNG states] above for the capture/replay scheme.
struct CUDAGeneratorState : public c10::intrusive_ptr_target {
  // Philox seed value.
  uint64_t seed_;
  // Running Philox offset used outside of graph capture.
  uint64_t philox_offset_per_thread_;
  // Offset accumulated within the current graph capture; per the Note
  // above, the intra-graph offset starts from zero when capture begins.
  uint32_t offset_intragraph_;
  // Presumably set while this state is being captured into a CUDA graph;
  // default-initialized to false. TODO(review): confirm in the .cpp.
  bool capturing_{};
  // Graphs currently sharing this state, maintained by
  // register_graph()/unregister_graph().
  std::unordered_set<cuda::CUDAGraph*> registered_graphs_;
  // Device-side tensors read at graph replay; per the Note above, the
  // offset tensor is a one-element tensor fill_ed with the generator's
  // current offset before each replay. NOTE(review): exact management is
  // in the out-of-line definitions, not visible here.
  at::TensorBase seed_extragraph_{};
  at::TensorBase offset_extragraph_{};

  // philox_offset_per_thread and offset_intragraph default to zero; the
  // seed defaults to ATen's default_rng_seed_val.
  CUDAGeneratorState(
      uint64_t seed = default_rng_seed_val,
      uint64_t philox_offset_per_thread = 0,
      uint32_t offset_intragraph = 0)
      : seed_(seed),
        philox_offset_per_thread_(philox_offset_per_thread),
        offset_intragraph_(offset_intragraph) {}

  // Advances the offset by `increment` (global or intra-graph depending on
  // capture state, per the Note above). Defined out of line.
  void increase(uint64_t increment);

  // Register/unregister a CUDA graph that shares this state.
  void register_graph(cuda::CUDAGraph* graph);
  void unregister_graph(cuda::CUDAGraph* graph);

  // Hook run when graph capture of this state begins.
  void capture_prologue();
  // capture_epilogue returns the wholegraph_increment
  uint64_t capture_epilogue();
  // Hook run before a graph replay; `wholegraph_increment` is the total
  // offset increment recorded for the entire graph (see Note above).
  void replay_prologue(uint64_t wholegraph_increment);
  // Deep copy of this state as a fresh intrusive_ptr.
  c10::intrusive_ptr<CUDAGeneratorState> clone();
};

// CUDA implementation of c10::GeneratorImpl, backed by a shared
// CUDAGeneratorState. Together with at::cuda::philox::unpack (see Note
// [CUDA Graph-safe RNG states] above) it lets RNG consumer kernels run
// identically whether or not CUDA graph capture is underway.
struct TORCH_CUDA_CPP_API CUDAGeneratorImpl : public c10::GeneratorImpl {
  // Constructors
  // device_index = -1 presumably means "current device" — TODO(review)
  // confirm against the .cpp definition.
  CUDAGeneratorImpl(DeviceIndex device_index = -1);
  // Constructs a generator that shares an existing state object (used to
  // share RNG state graph-safely; see Note above).
  CUDAGeneratorImpl(
      DeviceIndex device_index,
      c10::intrusive_ptr<CUDAGeneratorState> state_);
  ~CUDAGeneratorImpl() override = default;

  // CUDAGeneratorImpl methods
  std::shared_ptr<CUDAGeneratorImpl> clone() const;
  // Seed/offset accessors and full-state (de)serialization, overriding the
  // c10::GeneratorImpl interface.
  void set_current_seed(uint64_t seed) override;
  void set_offset(uint64_t offset) override;
  uint64_t get_offset() const override;
  uint64_t current_seed() const override;
  uint64_t seed() override;
  void set_state(const c10::TensorImpl& new_state) override;
  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
  // Graph-safe variants that exchange the shared CUDAGeneratorState rather
  // than a serialized tensor.
  void graphsafe_set_state(

      const c10::intrusive_ptr<GeneratorImpl>& state) override;
  c10::intrusive_ptr<c10::GeneratorImpl> graphsafe_get_state() const override;

  // Direct access to the Philox offset (see Note above for its role).
  void set_philox_offset_per_thread(uint64_t offset);
  uint64_t philox_offset_per_thread() const;

  // Forwarded graph registration (see CUDAGeneratorState).
  void register_graph(cuda::CUDAGraph* graph);
  void unregister_graph(cuda::CUDAGraph* graph);

  // Generates a PhiloxCudaState with a specified increment, and increment
  // current state
  PhiloxCudaState philox_cuda_state(uint64_t increment);

  // Returns true exactly once — on the first call, when the atomic flag is
  // still clear — and false on every later call (test_and_set reports the
  // flag's previous value).
  bool reset_rnn_state() {
    return !no_reset_rnn_state_.test_and_set();
  }

  // Temporarily accommodates call sites that use philox_engine_inputs.
  // Allows incremental refactor of call sites to use philox_cuda_state.
  std::pair<uint64_t, uint64_t> philox_engine_inputs(uint64_t increment);

  // Device type this generator targets (CUDA).
  static c10::DeviceType device_type();

 private:
  CUDAGeneratorImpl* clone_impl() const override;

  // Shared, possibly graph-registered RNG state (see Note above).
  c10::intrusive_ptr<CUDAGeneratorState> state_;
  // One-shot flag backing reset_rnn_state(); value-initialized to clear.
  std::atomic_flag no_reset_rnn_state_{};
};

namespace cuda::detail {

// Returns the process-wide default generator for the given CUDA device.
// device_index = -1 presumably selects the current device — TODO(review)
// confirm in the .cpp definition.
TORCH_CUDA_CPP_API const Generator& getDefaultCUDAGenerator(

    DeviceIndex device_index = -1);
// Creates a new, independent CUDA Generator for the given device (not the
// shared default one).
TORCH_CUDA_CPP_API Generator createCUDAGenerator(DeviceIndex device_index = -1);

} // namespace cuda::detail
} // namespace at