| #pragma once |
|
|
| #include <torch/torch.h> |
|
|
/// Prototype of `f32_bf16w_matmul_torch`; only the declaration is visible in
/// this header. Returns nothing (`void`).
///
/// NOTE(review): the name suggests a matrix multiply of f32 activations by
/// bf16 weights -- confirm dtypes and layouts against the definition.
///
/// @param input            Tensor taken by const reference (not modified
///                         through this reference).
/// @param weight_bf16      Tensor taken by const reference; presumably bf16
///                         weights, going by the name -- TODO confirm.
/// @param bias_bf16        Tensor taken by const reference; presumably a bf16
///                         bias -- TODO confirm.
/// @param output           The only non-const tensor reference, so presumably
///                         the destination of the result -- TODO confirm.
/// @param num_tokens       Presumably the number of token rows processed --
///                         TODO confirm.
/// @param num_cols         Column count; which operand it describes cannot be
///                         determined from this declaration.
/// @param num_rows         Row count; which operand it describes cannot be
///                         determined from this declaration.
/// @param threadgroup_size Presumably the kernel threadgroup size to launch
///                         with; valid range/default handling unknown here.
| void f32_bf16w_matmul_torch(const at::Tensor& input, |
| const at::Tensor& weight_bf16, |
| const at::Tensor& bias_bf16, |
| at::Tensor& output, |
| int64_t num_tokens, |
| int64_t num_cols, |
| int64_t num_rows, |
| int64_t threadgroup_size); |
|
|
/// Prototype of `bf16_f32_embeddings_torch`; only the declaration is visible
/// in this header. Returns nothing (`void`).
///
/// NOTE(review): the name suggests an embedding lookup from a bf16 table into
/// f32 output -- confirm against the definition.
///
/// @param token_ids        Tensor taken by const reference; presumably the
///                         token indices to look up -- TODO confirm dtype.
/// @param weight_bf16      Tensor taken by const reference; presumably the
///                         bf16 embedding table -- TODO confirm.
/// @param output           The only non-const tensor reference, so presumably
///                         where the looked-up rows are written -- TODO confirm.
/// @param threadgroup_size Presumably the kernel threadgroup size; valid
///                         range/default handling unknown here.
| void bf16_f32_embeddings_torch(const at::Tensor& token_ids, |
| const at::Tensor& weight_bf16, |
| at::Tensor& output, |
| int64_t threadgroup_size); |
|
|
/// Prototype of `f32_bf16w_rmsnorm_torch`; only the declaration is visible in
/// this header. Returns nothing (`void`).
///
/// NOTE(review): the name suggests RMS normalization of f32 input scaled by
/// bf16 weights -- confirm against the definition.
///
/// @param input       Tensor taken by const reference.
/// @param weight_bf16 Tensor taken by const reference; presumably the bf16
///                    scale vector -- TODO confirm.
/// @param output      The only non-const tensor reference, so presumably the
///                    destination of the normalized values -- TODO confirm.
/// @param epsilon     Double-precision scalar; presumably the stabilizing term
///                    added before the reciprocal square root -- TODO confirm.
| void f32_bf16w_rmsnorm_torch(const at::Tensor& input, |
| const at::Tensor& weight_bf16, |
| at::Tensor& output, |
| double epsilon); |
|
|
/// Prototype of `f32_bf16w_dense_matmul_qkv_torch`; only the declaration is
/// visible in this header. Returns nothing (`void`).
///
/// Takes no explicit dimension arguments; sizes are presumably derived from
/// the tensors themselves -- TODO confirm against the definition.
/// NOTE(review): the name suggests the dense matmul used for the QKV
/// projection -- confirm.
///
/// @param input       Tensor taken by const reference.
/// @param weight_bf16 Tensor taken by const reference; presumably bf16 weights.
/// @param bias_bf16   Tensor taken by const reference; presumably a bf16 bias.
/// @param output      The only non-const tensor reference, so presumably the
///                    destination of the result -- TODO confirm.
| void f32_bf16w_dense_matmul_qkv_torch(const at::Tensor& input, |
| const at::Tensor& weight_bf16, |
| const at::Tensor& bias_bf16, |
| at::Tensor& output); |
|
|
/// Prototype of `f32_bf16w_dense_matmul_attn_output_torch`; only the
/// declaration is visible in this header. Returns nothing (`void`).
///
/// Takes no explicit dimension arguments; sizes are presumably derived from
/// the tensors themselves -- TODO confirm against the definition.
/// NOTE(review): the name suggests the dense matmul used for the attention
/// output projection -- confirm.
///
/// @param input       Tensor taken by const reference.
/// @param weight_bf16 Tensor taken by const reference; presumably bf16 weights.
/// @param bias_bf16   Tensor taken by const reference; presumably a bf16 bias.
/// @param output      The only non-const tensor reference, so presumably the
///                    destination of the result -- TODO confirm.
| void f32_bf16w_dense_matmul_attn_output_torch(const at::Tensor& input, |
| const at::Tensor& weight_bf16, |
| const at::Tensor& bias_bf16, |
| at::Tensor& output); |
|
|
/// Prototype of `f32_bf16w_dense_matmul_mlp_gate_torch`; only the declaration
/// is visible in this header. Returns nothing (`void`).
///
/// Takes no explicit dimension arguments; sizes are presumably derived from
/// the tensors themselves -- TODO confirm against the definition.
/// NOTE(review): the name suggests the dense matmul used for the MLP gate
/// projection -- confirm.
///
/// @param input       Tensor taken by const reference.
/// @param weight_bf16 Tensor taken by const reference; presumably bf16 weights.
/// @param bias_bf16   Tensor taken by const reference; presumably a bf16 bias.
/// @param output      The only non-const tensor reference, so presumably the
///                    destination of the result -- TODO confirm.
| void f32_bf16w_dense_matmul_mlp_gate_torch(const at::Tensor& input, |
| const at::Tensor& weight_bf16, |
| const at::Tensor& bias_bf16, |
| at::Tensor& output); |
|
|
/// Prototype of `f32_rope_torch`; only the declaration is visible in this
/// header. Returns nothing (`void`).
///
/// NOTE(review): the name and the `rope_*` / `yarn_*` parameters suggest
/// rotary position embedding with YaRN-style scaling -- confirm the exact
/// formula against the definition.
///
/// @param activations         The only tensor argument, taken by non-const
///                            reference; presumably modified in place since
///                            there is no separate output -- TODO confirm.
/// @param rope_base           Double scalar; presumably the RoPE frequency base.
/// @param interpolation_scale Double scalar; meaning not visible here.
/// @param yarn_offset         Double scalar; meaning not visible here.
/// @param yarn_scale          Double scalar; meaning not visible here.
/// @param yarn_multiplier     Double scalar; meaning not visible here.
/// @param num_tokens          Presumably the number of tokens in `activations`.
/// @param num_q_heads         Presumably the number of query heads.
/// @param num_kv_heads        Presumably the number of key/value heads.
/// @param attn_head_dim       Presumably the per-head dimension.
/// @param token_offset        Presumably the position of the first token
///                            within the sequence -- TODO confirm.
/// @param threadgroup_size    Presumably the kernel threadgroup size; valid
///                            range/default handling unknown here.
| void f32_rope_torch(at::Tensor& activations, |
| double rope_base, |
| double interpolation_scale, |
| double yarn_offset, |
| double yarn_scale, |
| double yarn_multiplier, |
| int64_t num_tokens, |
| int64_t num_q_heads, |
| int64_t num_kv_heads, |
| int64_t attn_head_dim, |
| int64_t token_offset, |
| int64_t threadgroup_size); |
|
|
/// Prototype of `f32_bf16w_matmul_qkv_torch`; only the declaration is visible
/// in this header. Returns nothing (`void`).
///
/// NOTE(review): the parameter list combines matmul arguments, a KV cache and
/// `rope_*` / `yarn_*` scalars, which suggests a fused QKV projection + RoPE +
/// cache write -- this is an inference from names only; confirm against the
/// definition.
///
/// @param input                 Tensor taken by const reference.
/// @param weight_bf16           Tensor taken by const reference; presumably
///                              bf16 weights.
/// @param bias_bf16             Tensor taken by const reference; presumably a
///                              bf16 bias.
/// @param output                Non-const tensor reference; presumably written.
/// @param kv_cache              Non-const tensor reference; presumably updated
///                              with new key/value entries -- TODO confirm.
/// @param kv_cache_offset_bytes Offset into `kv_cache`, in bytes per the name.
/// @param num_tokens            Presumably the number of tokens processed.
/// @param num_cols              Column count; which operand it describes is
///                              not visible here.
/// @param num_q_heads           Presumably the number of query heads.
/// @param num_kv_heads          Presumably the number of key/value heads.
/// @param attn_head_dim         Presumably the per-head dimension.
/// @param token_offset          Presumably the position of the first token.
/// @param max_tokens            Presumably the KV cache capacity in tokens --
///                              TODO confirm.
/// @param rope_base             Double scalar; presumably the RoPE base.
/// @param interpolation_scale   Double scalar; meaning not visible here.
/// @param yarn_offset           Double scalar; meaning not visible here.
/// @param yarn_scale            Double scalar; meaning not visible here.
/// @param yarn_multiplier       Double scalar; meaning not visible here.
/// @param threadgroup_size      Presumably the kernel threadgroup size.
| void f32_bf16w_matmul_qkv_torch(const at::Tensor& input, |
| const at::Tensor& weight_bf16, |
| const at::Tensor& bias_bf16, |
| at::Tensor& output, |
| at::Tensor& kv_cache, |
| int64_t kv_cache_offset_bytes, |
| int64_t num_tokens, |
| int64_t num_cols, |
| int64_t num_q_heads, |
| int64_t num_kv_heads, |
| int64_t attn_head_dim, |
| int64_t token_offset, |
| int64_t max_tokens, |
| double rope_base, |
| double interpolation_scale, |
| double yarn_offset, |
| double yarn_scale, |
| double yarn_multiplier, |
| int64_t threadgroup_size); |
|
|
/// Prototype of `f32_sdpa_torch`; only the declaration is visible in this
/// header. Returns nothing (`void`).
///
/// NOTE(review): "sdpa" presumably stands for scaled dot-product attention --
/// confirm against the definition. Each tensor argument is paired with a
/// `*_offset_bytes` value; per the names these are byte offsets, but how they
/// are applied is not visible here.
///
/// @param q                   Tensor taken by const reference; presumably queries.
/// @param q_offset_bytes      Offset associated with `q`, in bytes per the name.
/// @param kv                  Tensor taken by const reference; presumably the
///                            key/value storage -- layout TODO confirm.
/// @param kv_offset_bytes     Offset associated with `kv`, in bytes per the name.
/// @param s_bf16              Tensor taken by const reference; purpose not
///                            determinable from this declaration (bf16 per the
///                            name) -- TODO confirm.
/// @param s_offset_bytes      Offset associated with `s_bf16`, in bytes per
///                            the name.
/// @param output              The only non-const tensor reference, so
///                            presumably the destination -- TODO confirm.
/// @param output_offset_bytes Offset associated with `output`, in bytes per
///                            the name.
/// @param window              Presumably an attention window size; units and
///                            sentinel values unknown here.
/// @param kv_stride           Presumably a stride within `kv`; units unknown.
/// @param num_q_tokens        Presumably the number of query tokens.
/// @param num_kv_tokens       Presumably the number of key/value tokens.
/// @param num_q_heads         Presumably the number of query heads.
/// @param num_kv_heads        Presumably the number of key/value heads.
/// @param head_dim            Presumably the per-head dimension.
| void f32_sdpa_torch(const at::Tensor& q, |
| int64_t q_offset_bytes, |
| const at::Tensor& kv, |
| int64_t kv_offset_bytes, |
| const at::Tensor& s_bf16, |
| int64_t s_offset_bytes, |
| at::Tensor& output, |
| int64_t output_offset_bytes, |
| int64_t window, |
| int64_t kv_stride, |
| int64_t num_q_tokens, |
| int64_t num_kv_tokens, |
| int64_t num_q_heads, |
| int64_t num_kv_heads, |
| int64_t head_dim); |
|
|
/// Prototype of `f32_topk_torch`; only the declaration is visible in this
/// header. Returns nothing (`void`).
///
/// NOTE(review): the name and parameters suggest selecting the top
/// `num_active_experts` of `num_experts` scores per token -- confirm against
/// the definition.
///
/// @param scores             Tensor taken by const reference; presumably
///                           per-token expert scores.
/// @param expert_ids         Non-const tensor reference; presumably receives
///                           the selected expert indices -- TODO confirm.
/// @param expert_scores      Non-const tensor reference; presumably receives
///                           the scores of the selected experts -- TODO confirm.
/// @param num_tokens         Presumably the number of tokens.
/// @param num_experts        Presumably the total number of experts.
/// @param num_active_experts Presumably the number of experts kept per token.
| void f32_topk_torch(const at::Tensor& scores, |
| at::Tensor& expert_ids, |
| at::Tensor& expert_scores, |
| int64_t num_tokens, |
| int64_t num_experts, |
| int64_t num_active_experts); |
|
|
/// Prototype of `expert_routing_metadata_torch`; only the declaration is
/// visible in this header. Returns nothing (`void`).
///
/// NOTE(review): presumably derives per-expert routing offsets from the
/// selected expert ids -- confirm against the definition.
///
/// @param expert_ids           Tensor taken by const reference.
/// @param expert_scores        Tensor taken by const reference.
/// @param expert_offsets       Non-const tensor reference; presumably an
///                             output -- TODO confirm contents.
/// @param intra_expert_offsets Non-const tensor reference; presumably an
///                             output -- TODO confirm contents.
/// @param num_tokens           Presumably the number of tokens.
/// @param num_experts          Presumably the total number of experts.
| void expert_routing_metadata_torch(const at::Tensor& expert_ids, |
| const at::Tensor& expert_scores, |
| at::Tensor& expert_offsets, |
| at::Tensor& intra_expert_offsets, |
| int64_t num_tokens, |
| int64_t num_experts); |
|
|
/// Prototype of `f32_scatter_torch`; only the declaration is visible in this
/// header. Returns nothing (`void`).
///
/// NOTE(review): presumably scatters token rows of `input` into per-expert
/// positions in `output` using the routing tensors -- confirm against the
/// definition.
///
/// @param input                Tensor taken by const reference.
/// @param expert_ids           Tensor taken by const reference.
/// @param expert_scores        Tensor taken by const reference.
/// @param expert_offsets       Tensor taken by const reference.
/// @param intra_expert_offsets Tensor taken by const reference.
/// @param output               The only non-const tensor reference, so
///                             presumably the destination -- TODO confirm.
/// @param num_channels         Presumably the number of channels per token row.
/// @param num_tokens           Presumably the number of tokens.
/// @param num_active_experts   Presumably the number of experts selected per
///                             token.
| void f32_scatter_torch(const at::Tensor& input, |
| const at::Tensor& expert_ids, |
| const at::Tensor& expert_scores, |
| const at::Tensor& expert_offsets, |
| const at::Tensor& intra_expert_offsets, |
| at::Tensor& output, |
| int64_t num_channels, |
| int64_t num_tokens, |
| int64_t num_active_experts); |
|
|
/// Prototype of `f32_bf16w_matmul_add_torch`; only the declaration is visible
/// in this header. Returns nothing (`void`).
///
/// NOTE(review): the `_add` suffix suggests the matmul result is accumulated
/// into `output` rather than overwriting it -- this is an inference from the
/// name only; confirm against the definition.
///
/// @param input            Tensor taken by const reference.
/// @param weight_bf16      Tensor taken by const reference; presumably bf16
///                         weights.
/// @param bias_bf16        Tensor taken by const reference; presumably a bf16
///                         bias.
/// @param output           The only non-const tensor reference, so presumably
///                         read and/or written as the destination -- TODO
///                         confirm.
/// @param num_tokens       Presumably the number of token rows processed.
/// @param num_cols         Column count; which operand it describes cannot be
///                         determined from this declaration.
/// @param num_rows         Row count; which operand it describes cannot be
///                         determined from this declaration.
/// @param threadgroup_size Presumably the kernel threadgroup size; valid
///                         range/default handling unknown here.
| void f32_bf16w_matmul_add_torch(const at::Tensor& input, |
| const at::Tensor& weight_bf16, |
| const at::Tensor& bias_bf16, |
| at::Tensor& output, |
| int64_t num_tokens, |
| int64_t num_cols, |
| int64_t num_rows, |
| int64_t threadgroup_size); |