| #pragma once |
|
|
| #include <torch/torch.h> |
|
|
| #include <core/scalar_type.hpp> |
|
|
// Declaration only; the definition is not part of this header.
// Takes both tensors by non-const reference and returns void.
// NOTE(review): the names suggest `out` receives a SiLU-gated product computed
// from `input` — confirm the expected shapes/dtypes against the implementation.
| void silu_and_mul(torch::Tensor &out, torch::Tensor &input); |
|
|
// Declaration only; the definition is not part of this header.
// All four tensors are taken by non-const reference and nothing is returned.
// NOTE(review): presumably `gating_output` is the input and the other three
// tensors (`topk_weights`, `topk_indices`, `token_expert_indices`) are filled
// by the call — verify against the implementation before relying on this.
| void topk_softmax(torch::Tensor &topk_weights, torch::Tensor &topk_indices, |
| torch::Tensor &token_expert_indices, |
| torch::Tensor &gating_output); |
|
|
// Declaration only; the definition is not part of this header.
// Takes both tensors by non-const reference and returns void.
// Note the parameter order: (input, output) — the reverse of silu_and_mul,
// which takes (out, input).
// NOTE(review): presumably reduces `input` into `output`; the reduced axis is
// not visible from here — confirm against the implementation.
| void moe_sum(torch::Tensor &input, torch::Tensor &output); |
|
|
// Declaration only; the definition is not part of this header.
// Unlike the other declarations in this header, every tensor here is passed
// by value rather than by reference; `num_experts` and `block_size` are
// int64_t scalars. Returns void.
// NOTE(review): the names suggest `topk_ids` is the input and that
// `sorted_token_ids`, `experts_ids` and `num_tokens_post_pad` are written by
// the call — confirm against the implementation.
| void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, |
| int64_t block_size, torch::Tensor sorted_token_ids, |
| torch::Tensor experts_ids, |
| torch::Tensor num_tokens_post_pad); |
|
|
| #ifndef USE_ROCM |
// Declared only when USE_ROCM is not defined; callers built with USE_ROCM
// will not see this symbol.
// Declaration only; the definition is not part of this header.
// Returns a torch::Tensor by value. Most tensor arguments are const
// references; `b_zeros` and `workspace` are the only non-const references.
// `b_q_type_id` is a vllm::ScalarTypeId (declared in core/scalar_type.hpp)
// passed by value.
// NOTE(review): presumably a quantized (Marlin) GEMM for mixture-of-experts
// layers where `size_m`/`size_n`/`size_k` are the GEMM dimensions — confirm
// the tensor layouts and the meaning of the boolean flags against the
// implementation.
| torch::Tensor marlin_gemm_moe( |
| const torch::Tensor &a, const torch::Tensor &b_q_weights, |
| const torch::Tensor &sorted_ids, const torch::Tensor &topk_weights, |
| const torch::Tensor &topk_ids, const torch::Tensor &b_scales, |
| torch::Tensor &b_zeros, const torch::Tensor &g_idx, |
| const torch::Tensor &perm, torch::Tensor &workspace, |
| vllm::ScalarTypeId const b_q_type_id, int64_t size_m, int64_t size_n, |
| int64_t size_k, bool is_k_full, int64_t num_experts, int64_t topk, |
| int64_t moe_block_size, bool replicate_input, bool apply_weights); |
| #endif |
|
|