File size: 8,723 Bytes
281d8ba
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
// csrc/moe.cpp

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/torch.h>

// Forward declarations for existing functions
//
// These helpers are defined in other translation units of this extension and
// are not visible here. The notes below describe how experts_cuda() calls
// them; anything beyond that is marked as an assumption to confirm against
// the actual definitions.

// Called with the flattened int32 expert ids in `x` and end_bit = 32.
// Presumably sorts the keys of `x` into `x_out` and writes, for each sorted
// key, its original position in `x` into `iota_out` (experts_cuda uses
// `iota_out` as the sorted-assignment permutation) — TODO confirm.
void sort_cuda(torch::Tensor x,
               int64_t end_bit,
               torch::Tensor x_out,
               torch::Tensor iota_out);

// Called with the sorted expert ids and an int32 `output` of E + 1 elements.
// experts_cuda reads output[i] as the start offset of expert i within the
// sorted assignments — TODO confirm whether the cumsum is exclusive and what
// output[E] holds.
void bincount_cumsum_cuda(torch::Tensor input,
                          torch::Tensor &output,
                          int64_t minlength);

// Presumably out[i] = in[idx_int32[i]] along dim 0 with int32 indices.
// experts_cuda ignores the returned tensor and relies on `out` being filled.
torch::Tensor index_select_out_cuda(torch::Tensor out,
                                    torch::Tensor in,
                                    torch::Tensor idx_int32);

// Called with x = hidden states [T, H] and output = [E, C, H]. Groups token
// rows by expert using the sorted permutation `indices` and offsets `bins`.
// NOTE(review): handling of tokens beyond capacity C is not visible here;
// presumably they are dropped — confirm.
void gather_cuda(torch::Tensor const &x,
                 torch::Tensor const &indices,
                 torch::Tensor const &bins,
                 torch::Tensor &output,
                 int64_t E,
                 int64_t C,
                 int64_t top_k);

// Called with src = expert outputs [E, C, H], per-assignment `weights` in
// sorted order, and y = the zero-initialised [T, H] result. Presumably
// accumulates weights * src rows back into token order (the inverse of
// gather_cuda) — TODO confirm.
void scatter_cuda(torch::Tensor const &src,
                  torch::Tensor const &indices,
                  torch::Tensor const &bins,
                  torch::Tensor const &weights,
                  torch::Tensor &y,
                  int64_t T,
                  int64_t E,
                  int64_t C,
                  int64_t top_k);

// Per-expert batched matmul: x is [E, C, *], weights is [E, *, *], and
// batch_sizes holds the (capacity-clamped) valid row count per expert.
// experts_cuda passes trans_b = true and ignores the return value, relying on
// `output` being written. NOTE(review): the exact meaning of trans_b and
// whether `output` is resized inside are not visible here — confirm.
torch::Tensor batch_mm(torch::Tensor x,
                       torch::Tensor weights,
                       torch::Tensor batch_sizes,
                       torch::Tensor output,
                       bool trans_b = false);

/// Runs a mixture-of-experts MLP over a flattened batch of tokens on CUDA.
///
/// Pipeline: sort the (token, slot) -> expert assignments by expert id,
/// gather tokens into a per-expert [E, C, H] buffer, apply the gate/up
/// projection plus bias, a clamped SwiGLU-style activation, the down
/// projection plus bias, and finally scatter the routing-weighted expert
/// outputs back into token order.
///
/// @param hidden_states     [T, H] token activations.
/// @param router_indices    [T, K] expert id for each (token, slot).
/// @param routing_weights   [T, E] dense router weights, from which the K
///                          selected entries are gathered, or [T, K] weights
///                          already aligned with router_indices. When K == E
///                          the tensor is interpreted as [T, E].
/// @param gate_up_proj      [E, H, 2*H] gate/up projection weights.
/// @param gate_up_proj_bias [E, 2*H] gate/up projection bias.
/// @param down_proj         [E, H, H] down projection weights.
/// @param down_proj_bias    [E, H] down projection bias.
/// @param expert_capacity   C, rows per expert in the gathered buffer; the
///                          per-expert token counts are clamped to C.
/// @param num_experts       E, number of experts (must be positive).
/// @param top_k             K, experts selected per token.
/// @return [T, H] tensor with the dtype and device of hidden_states.
/// @throws c10::Error (via TORCH_CHECK) on device, shape or argument errors.
torch::Tensor experts_cuda(
    torch::Tensor hidden_states,     // [B*S, H] - flattened hidden states
    torch::Tensor router_indices,    // [B*S, K] - expert indices per token
    torch::Tensor routing_weights,   // [B*S, E] or [B*S, K] - routing weights
    torch::Tensor gate_up_proj,      // [E, H, 2*H] - gate/up projection weights
    torch::Tensor gate_up_proj_bias, // [E, 2*H] - gate/up projection bias
    torch::Tensor down_proj,         // [E, H, H] - down projection weights
    torch::Tensor down_proj_bias,    // [E, H] - down projection bias
    int64_t expert_capacity,         // C - capacity per expert
    int64_t num_experts,             // E - number of experts
    int64_t top_k                    // K - top-k routing
) {
  // ---- Validation ----------------------------------------------------------
  TORCH_CHECK(hidden_states.is_cuda(), "hidden_states must be on CUDA");
  TORCH_CHECK(router_indices.is_cuda(), "router_indices must be on CUDA");
  TORCH_CHECK(routing_weights.is_cuda(), "routing_weights must be on CUDA");
  TORCH_CHECK(gate_up_proj.is_cuda(), "gate_up_proj must be on CUDA");
  TORCH_CHECK(gate_up_proj_bias.is_cuda(), "gate_up_proj_bias must be on CUDA");
  TORCH_CHECK(down_proj.is_cuda(), "down_proj must be on CUDA");
  TORCH_CHECK(down_proj_bias.is_cuda(), "down_proj_bias must be on CUDA");

  TORCH_CHECK(hidden_states.ndimension() == 2,
              "hidden_states must be 2D [T, H]");
  TORCH_CHECK(router_indices.ndimension() == 2,
              "router_indices must be 2D [T, K]");
  TORCH_CHECK(routing_weights.ndimension() == 2,
              "routing_weights must be 2D [T, K]");
  TORCH_CHECK(gate_up_proj.ndimension() == 3,
              "gate_up_proj must be 3D [E, H, 2*H]");
  TORCH_CHECK(gate_up_proj_bias.ndimension() == 2,
              "gate_up_proj_bias must be 2D [E, 2*H]");
  TORCH_CHECK(down_proj.ndimension() == 3, "down_proj must be 3D [E, H, H]");
  TORCH_CHECK(down_proj_bias.ndimension() == 2,
              "down_proj_bias must be 2D [E, H]");

  // E == 0 would otherwise fail later with an opaque indexing error.
  TORCH_CHECK(num_experts > 0, "num_experts must be positive");
  TORCH_CHECK(expert_capacity >= 0, "expert_capacity must be non-negative");

  const int64_t T = hidden_states.size(0); // Total tokens
  const int64_t H = hidden_states.size(1); // Hidden size
  const int64_t E = num_experts;
  const int64_t C = expert_capacity;
  const int64_t K = top_k;

  TORCH_CHECK(router_indices.size(0) == T && router_indices.size(1) == K,
              "router_indices must be [T, K]");
  TORCH_CHECK(routing_weights.size(0) == T && (routing_weights.size(1) == K ||
                                               routing_weights.size(1) == E),
              "routing_weights must be [T, K] or [T, E]");
  TORCH_CHECK(gate_up_proj.size(0) == E && gate_up_proj.size(1) == H &&
                  gate_up_proj.size(2) == 2 * H,
              "gate_up_proj must be [E, H, 2*H]");
  TORCH_CHECK(gate_up_proj_bias.size(0) == E &&
                  gate_up_proj_bias.size(1) == 2 * H,
              "gate_up_proj_bias must be [E, 2*H]");
  TORCH_CHECK(down_proj.size(0) == E && down_proj.size(1) == H &&
                  down_proj.size(2) == H,
              "down_proj must be [E, H, H]");
  TORCH_CHECK(down_proj_bias.size(0) == E && down_proj_bias.size(1) == H,
              "down_proj_bias must be [E, H]");

  // Make the inputs' device current for this call. The custom kernels
  // presumably launch on the current device/stream, which would otherwise be
  // wrong for tensors living on a non-default GPU.
  const c10::cuda::CUDAGuard device_guard(hidden_states.device());

  // Ensure simple contiguity where helpful
  hidden_states = hidden_states.contiguous();
  router_indices = router_indices.contiguous();
  routing_weights = routing_weights.contiguous();

  const auto int32_opts = torch::TensorOptions()
                              .dtype(torch::kInt32)
                              .device(hidden_states.device());
  const auto float_opts = torch::TensorOptions()
                              .dtype(hidden_states.dtype())
                              .device(hidden_states.device());

  // Final output buffer; only scatter_cuda writes into it.
  torch::Tensor output = torch::zeros_like(hidden_states);

  // Number of (token, slot) assignments. With none there is nothing to
  // scatter, so return the zero result instead of launching empty kernels.
  const int64_t num_assignments = T * K;
  if (num_assignments == 0) {
    return output;
  }

  // ---- Sort assignments by expert -----------------------------------------
  torch::Tensor flat_indices =
      router_indices.flatten().to(torch::kInt32, /*non_blocking=*/true);
  torch::Tensor sorted_values = torch::empty_like(flat_indices);
  torch::Tensor sorted_indices = torch::empty_like(flat_indices);
  sort_cuda(flat_indices, 32, sorted_values, sorted_indices);

  // bins[i] is used below as the start offset of expert i within the sorted
  // assignments (int32 for a smaller footprint).
  torch::Tensor bins = torch::empty({E + 1}, int32_opts);
  bincount_cumsum_cuda(sorted_values, bins, E);

  // Gather tokens by expert: [T, H] -> [E, C, H]
  torch::Tensor x = torch::empty({E, C, H}, float_opts);
  gather_cuda(hidden_states, sorted_indices, bins, x, E, C, K);

  // ---- Per-expert token counts, clamped to capacity -----------------------
  // Computed entirely with device ops (no .item()) so the pipeline does not
  // stall on a host<->device synchronization.
  torch::Tensor expert_tokens = torch::empty({E}, int32_opts);
  if (E > 1) {
    expert_tokens.slice(0, 0, E - 1) =
        bins.slice(0, 1, E) - bins.slice(0, 0, E - 1);
    // Last expert: total assignments minus its start offset.
    expert_tokens.slice(0, E - 1, E) =
        bins.slice(0, E - 1, E).neg().add_(num_assignments);
  } else {
    expert_tokens.fill_(num_assignments);
  }
  expert_tokens = torch::clamp(expert_tokens, 0, static_cast<int32_t>(C));

  // ---- Gate/up projection + bias ------------------------------------------
  torch::Tensor gate_up = torch::empty({E, C, 2 * H}, float_opts);
  batch_mm(x, gate_up_proj, expert_tokens, gate_up, /*trans_b=*/true);
  gate_up.add_(gate_up_proj_bias.unsqueeze(1));

  // ---- Clamped SwiGLU-style activation ------------------------------------
  // Even columns of gate_up are the gate, odd columns the up projection.
  // The clamps are out-of-place, so `gate` and `up` are fresh tensors rather
  // than views of gate_up; gate_up's storage is therefore free to be reused
  // as the down-projection output below.
  namespace idx = torch::indexing;
  constexpr float kLimit = 7.0f;
  constexpr float kAlpha = 1.702f;
  torch::Tensor gate =
      gate_up.index({idx::Ellipsis, idx::Slice(idx::None, idx::None, 2)})
          .clamp(/*min=*/c10::nullopt, /*max=*/kLimit);
  torch::Tensor up =
      gate_up.index({idx::Ellipsis, idx::Slice(1, idx::None, 2)})
          .clamp(/*min=*/-kLimit, /*max=*/kLimit);
  gate.mul_(torch::sigmoid(gate * kAlpha));
  up.add_(1).mul_(gate); // up <- (up + 1) * gate * sigmoid(alpha * gate)

  // ---- Down projection + bias ---------------------------------------------
  // Reuse gate_up's allocation for the [E, C, H] result. Shrinking the shape
  // never reallocates, and giving batch_mm an output that already has the
  // expected shape does not depend on batch_mm resizing it.
  gate_up.resize_({E, C, H});
  batch_mm(up, down_proj, expert_tokens, gate_up, /*trans_b=*/true);
  gate_up.add_(down_proj_bias.unsqueeze(1));

  // ---- Routing weights in sorted-assignment order -------------------------
  // One weight per (token, slot) in router_indices' row-major order, in the
  // activation dtype (at::gather_out requires out and self dtypes to match).
  torch::Tensor selected_weights;
  if (routing_weights.size(1) == E) {
    // Dense [T, E]: pick each token's K selected experts. gather requires
    // int64 indices; .to() is a no-op when router_indices already is int64.
    selected_weights = torch::empty({T, K}, float_opts);
    at::gather_out(selected_weights,
                   routing_weights.to(hidden_states.scalar_type()),
                   /*dim=*/1,
                   router_indices.to(torch::kLong),
                   /*sparse_grad=*/false);
  } else {
    // Already [T, K] and aligned with router_indices.
    selected_weights = routing_weights.to(hidden_states.scalar_type());
  }

  // Reorder to match the sorted assignments (int32 indices, no conversion).
  torch::Tensor weights_sorted = torch::empty({num_assignments}, float_opts);
  index_select_out_cuda(weights_sorted,
                        selected_weights.reshape({num_assignments}),
                        sorted_indices);

  // ---- Scatter weighted expert outputs back to token order ----------------
  scatter_cuda(gate_up.view({E, C, H}),
               sorted_indices,
               bins,
               weights_sorted,
               output,
               T,
               E,
               C,
               K);

  return output;
}