| std::vector<torch::Tensor> /* Declaration only (no body in this view): returns a vector of tensors. */ | |
| get_mla_metadata( /* NOTE(review): presumably produces the tile_scheduler_metadata and num_splits tensors that mha_fwd_kvcache_mla takes -- confirm against the definition. */ | |
| torch::Tensor &seqlens_k, /* Passed by non-const reference; presumably per-sequence key lengths -- TODO confirm shape and dtype. */ | |
| const int64_t num_heads_per_head_k, /* Presumably the number of query heads per key head -- TODO confirm. */ | |
| const int64_t num_heads_k /* Presumably the number of key heads -- TODO confirm. */ | |
| ); | |
| std::vector<torch::Tensor> /* Declaration only (no body in this view): returns a vector of tensors. */ | |
| mha_fwd_kvcache_mla( /* NOTE(review): the name suggests an MLA attention forward pass over a KV cache -- confirm against the definition. */ | |
| torch::Tensor &q, /* Passed by non-const reference; presumably the query tensor -- TODO confirm. */ | |
| const torch::Tensor &kcache, /* Presumably the key cache -- TODO confirm layout. */ | |
| // TODO: fix for optional -- the intended parameter is the commented-out std::optional form below; until then vcache_ is a required tensor. | |
| // std::optional<torch::Tensor> &vcache_, | |
| const torch::Tensor &vcache_, /* Stand-in for the optional parameter above. */ | |
| const int64_t head_size_v, /* Presumably the value head dimension -- TODO confirm. */ | |
| const torch::Tensor &seqlens_k, /* Const reference here, unlike get_mla_metadata which takes it by non-const reference. */ | |
| const torch::Tensor &block_table, /* Presumably the paged-cache block table -- TODO confirm. */ | |
| // TODO: should be float -- currently declared double. NOTE(review): possibly because operator registration expects double for float arguments; confirm before changing. | |
| const double softmax_scale, | |
| // TODO: fix for mutable bool -- declared const bool here. | |
| const bool is_causal_, | |
| const torch::Tensor &tile_scheduler_metadata, /* Presumably produced by get_mla_metadata -- TODO confirm. */ | |
| const torch::Tensor &num_splits, /* Presumably produced by get_mla_metadata -- TODO confirm. */ | |
| // TODO: remove when resolved -- NOTE(review): the purpose of unknown_param is not documented here; it defaults to 0 so callers may omit it. | |
| const int64_t unknown_param = 0 | |
| ); | |