| #pragma once |
|
|
| #include <torch/torch.h> |
|
|
/// Gather rows of `x` into `output` according to routing `indices`, with
/// per-bucket offsets supplied in `bins`.
/// NOTE(review): declaration only — semantics inferred from names (looks like
/// an MoE token gather); confirm against the CUDA definition.
/// @param x       Source tensor (read-only).
/// @param indices Routing indices selecting rows of `x` — presumably sorted.
/// @param bins    Per-bucket boundaries — presumably a cumsum of counts; verify.
/// @param output  Destination tensor, written in place.
/// @param E       Presumably the number of experts/buckets — TODO confirm.
/// @param C       Presumably per-expert capacity (or column count) — TODO confirm.
/// @param top_k   Presumably experts selected per token — TODO confirm.
void gather_cuda(torch::Tensor const &x,
                 torch::Tensor const &indices,
                 torch::Tensor const &bins,
                 torch::Tensor &output,
                 int64_t E,
                 int64_t C,
                 int64_t top_k);
|
|
/// Scatter rows of `src` back into `y` according to `indices`/`bins`,
/// combining with `weights` — the inverse of `gather_cuda` plus a weighted
/// reduction, by the look of the signature.
/// NOTE(review): declaration only — semantics inferred from names; confirm
/// against the CUDA definition.
/// @param src     Expert-ordered rows to scatter (read-only).
/// @param indices Routing indices mapping rows of `src` to rows of `y`.
/// @param bins    Per-bucket boundaries — presumably a cumsum of counts; verify.
/// @param weights Presumably per-(token, expert) routing weights applied on
///                scatter — TODO confirm.
/// @param y       Destination tensor, written in place.
/// @param T       Presumably the token count — TODO confirm.
/// @param E       Presumably the number of experts — TODO confirm.
/// @param C       Presumably per-expert capacity (or column count) — TODO confirm.
/// @param top_k   Presumably experts selected per token — TODO confirm.
void scatter_cuda(torch::Tensor const &src,
                  torch::Tensor const &indices,
                  torch::Tensor const &bins,
                  torch::Tensor const &weights,
                  torch::Tensor &y,
                  int64_t T,
                  int64_t E,
                  int64_t C,
                  int64_t top_k);
|
|
/// Sort `x` on the GPU, writing sorted keys to `x_out` and the applied
/// permutation (indices of an iota sorted alongside the keys, i.e. an
/// argsort) to `iota_out`.
/// NOTE(review): `end_bit` suggests a radix sort where only the low
/// [0, end_bit) key bits are significant — confirm against the definition.
/// @param x        Keys to sort (taken by value: a tensor handle, not a copy
///                 of the data).
/// @param end_bit  Presumably the exclusive most-significant bit considered
///                 by the radix sort — TODO confirm.
/// @param x_out    Output tensor receiving the sorted keys.
/// @param iota_out Output tensor receiving the sort permutation.
void sort_cuda(torch::Tensor x,
               int64_t end_bit,
               torch::Tensor x_out,
               torch::Tensor iota_out);
|
|
/// Fused bincount + cumulative sum: count occurrences of each value in
/// `input` and write the running (prefix) sums into `output` — the usual way
/// to build per-bucket offsets ("bins") from sorted routing indices.
/// NOTE(review): declaration only — fusion and inclusive-vs-exclusive prefix
/// behavior inferred from the name; confirm against the CUDA definition.
/// @param input     Integer tensor of bucket ids to count.
/// @param output    Destination for the cumulative counts, written in place.
/// @param minlength Presumably the minimum number of bins, as in
///                  torch.bincount's `minlength` — TODO confirm.
void bincount_cumsum_cuda(torch::Tensor input,
                          torch::Tensor &output,
                          int64_t minlength);
|
|
/// Index-select rows of `in` by `idx_int32` into the preallocated `out`
/// (an out-variant of torch::index_select, per the name).
/// NOTE(review): declaration only — which dimension is selected, and whether
/// the return aliases `out`, must be confirmed against the definition.
/// @param out       Preallocated destination tensor.
/// @param in        Source tensor.
/// @param idx_int32 Index tensor — the name implies int32 indices are
///                  required; TODO confirm (torch::index_select itself
///                  accepts int32 or int64).
/// @return Presumably `out` after being filled — TODO confirm.
torch::Tensor index_select_out_cuda(torch::Tensor out,
                                    torch::Tensor in,
                                    torch::Tensor idx_int32);
|
|
/// Batched matrix multiply of `x` against per-batch `weights`, where
/// `batch_sizes` gives each batch's (variable) row count — i.e. a grouped
/// GEMM over ragged batches, per the signature.
/// NOTE(review): declaration only — layout expectations and whether the
/// return aliases `output` must be confirmed against the definition.
/// @param x           Input rows, presumably concatenated across batches in
///                    `batch_sizes` order — TODO confirm.
/// @param weights     Per-batch weight matrices.
/// @param batch_sizes Number of rows of `x` belonging to each batch.
/// @param output      Preallocated destination tensor.
/// @param trans_b     If true, presumably multiplies by the transpose of each
///                    weight matrix (defaults to false) — TODO confirm.
/// @return Presumably `output` after being filled — TODO confirm.
torch::Tensor
batch_mm(torch::Tensor x,
         torch::Tensor weights,
         torch::Tensor batch_sizes,
         torch::Tensor output,
         bool trans_b = false
);
|
|
/// Fused Mixture-of-Experts forward pass: route `hidden_states` through the
/// selected experts' gate/up and down projections and combine the results
/// with `routing_weights`.
/// NOTE(review): declaration only — the MoE interpretation follows from the
/// parameter names; activation function, capacity-overflow handling, and
/// output shape must be confirmed against the CUDA definition.
/// @param hidden_states     Input token activations.
/// @param router_indices    Expert index chosen for each (token, slot) pair —
///                          presumably shape [tokens, top_k]; TODO confirm.
/// @param routing_weights   Per-(token, expert) combination weights.
/// @param gate_up_proj      Stacked per-expert gate/up projection weights.
/// @param gate_up_proj_bias Stacked per-expert gate/up projection biases.
/// @param down_proj         Stacked per-expert down projection weights.
/// @param down_proj_bias    Stacked per-expert down projection biases.
/// @param expert_capacity   Presumably max tokens processed per expert —
///                          TODO confirm what happens on overflow.
/// @param num_experts       Number of experts.
/// @param top_k             Experts selected per token.
/// @return Presumably the combined expert outputs, same leading shape as
///         `hidden_states` — TODO confirm.
torch::Tensor experts_cuda(
    torch::Tensor hidden_states,
    torch::Tensor router_indices,
    torch::Tensor routing_weights,
    torch::Tensor gate_up_proj,
    torch::Tensor gate_up_proj_bias,
    torch::Tensor down_proj,
    torch::Tensor down_proj_bias,
    int64_t expert_capacity,
    int64_t num_experts,
    int64_t top_k
);
|
|
/// Backward pass for `experts_cuda`: given `grad_out` and the saved forward
/// inputs, compute gradients for the MoE parameters and inputs.
/// NOTE(review): declaration only — the exact contents and ordering of the
/// returned gradient vector are not visible here; confirm against the CUDA
/// definition before relying on them. Scalar parameters must match the
/// values used in the forward call.
/// @param grad_out          Gradient of the loss w.r.t. the forward output.
/// @param hidden_states     Forward-pass input activations (saved).
/// @param router_indices    Forward-pass expert selections (saved).
/// @param routing_weights   Forward-pass combination weights (saved).
/// @param gate_up_proj      Per-expert gate/up projection weights (saved).
/// @param gate_up_proj_bias Per-expert gate/up projection biases (saved).
/// @param down_proj         Per-expert down projection weights (saved).
/// @param down_proj_bias    Per-expert down projection biases (saved).
/// @param expert_capacity   Same capacity as the forward call.
/// @param num_experts       Number of experts.
/// @param top_k             Experts selected per token.
/// @return Vector of gradient tensors — presumably one per differentiable
///         input, in declaration order; TODO confirm.
std::vector<torch::Tensor> experts_backward_cuda(
    const torch::Tensor &grad_out,
    const torch::Tensor &hidden_states,
    const torch::Tensor &router_indices,
    const torch::Tensor &routing_weights,
    const torch::Tensor &gate_up_proj,
    const torch::Tensor &gate_up_proj_bias,
    const torch::Tensor &down_proj,
    const torch::Tensor &down_proj_bias,
    int64_t expert_capacity,
    int64_t num_experts,
    int64_t top_k
);
|
|