File size: 3,731 Bytes
#include <ATen/Tensor.h>
#include "torch_binding.h"
#include "registration.h"
// Operator registrations for this extension.
//
// Each entry declares an operator schema with `ops.def` and then binds the
// implementation for the MPS backend with `ops.impl`. Every operator returns
// `()`; results are delivered by writing into caller-provided tensors, which
// the schema marks with `Tensor!` so that PyTorch's alias analysis knows the
// argument is mutated in place.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // ---- Matmul / embedding / normalization ---------------------------------
  ops.def(
      "f32_bf16w_matmul(Tensor input, Tensor weight_bf16, Tensor bias_bf16, "
      "Tensor! output, int num_tokens, int num_cols, int num_rows, int threadgroup_size) -> ()");
  ops.impl("f32_bf16w_matmul", torch::kMPS, &f32_bf16w_matmul_torch);

  ops.def(
      "bf16_f32_embeddings(Tensor token_ids, Tensor weight_bf16, Tensor! output, "
      "int threadgroup_size) -> ()");
  ops.impl("bf16_f32_embeddings", torch::kMPS, &bf16_f32_embeddings_torch);

  ops.def(
      "f32_bf16w_rmsnorm(Tensor input, Tensor weight_bf16, Tensor! output, float epsilon) -> ()");
  ops.impl("f32_bf16w_rmsnorm", torch::kMPS, &f32_bf16w_rmsnorm_torch);

  // ---- Dense matmul variants (identical signatures) -----------------------
  ops.def(
      "f32_bf16w_dense_matmul_qkv(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output) -> ()");
  ops.impl("f32_bf16w_dense_matmul_qkv", torch::kMPS, &f32_bf16w_dense_matmul_qkv_torch);

  ops.def(
      "f32_bf16w_dense_matmul_attn_output(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output) -> ()");
  ops.impl("f32_bf16w_dense_matmul_attn_output", torch::kMPS, &f32_bf16w_dense_matmul_attn_output_torch);

  ops.def(
      "f32_bf16w_dense_matmul_mlp_gate(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output) -> ()");
  ops.impl("f32_bf16w_dense_matmul_mlp_gate", torch::kMPS, &f32_bf16w_dense_matmul_mlp_gate_torch);

  // ---- Rotary embedding / attention ---------------------------------------
  // `activations` is the only tensor argument and is declared mutable.
  ops.def(
      "f32_rope(Tensor! activations, float rope_base, float interpolation_scale, float yarn_offset, "
      "float yarn_scale, float yarn_multiplier, int num_tokens, int num_q_heads, int num_kv_heads, "
      "int attn_head_dim, int token_offset, int threadgroup_size) -> ()");
  ops.impl("f32_rope", torch::kMPS, &f32_rope_torch);

  // NOTE(review): `kv_cache` is declared as a read-only Tensor here. If the
  // kernel writes into the cache, it should be `Tensor! kv_cache` as well --
  // confirm against the implementation of f32_bf16w_matmul_qkv_torch.
  ops.def(
      "f32_bf16w_matmul_qkv(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output, Tensor kv_cache, "
      "int kv_cache_offset_bytes, int num_tokens, int num_cols, int num_q_heads, int num_kv_heads, "
      "int attn_head_dim, int token_offset, int max_tokens, float rope_base, float interpolation_scale, "
      "float yarn_offset, float yarn_scale, float yarn_multiplier, int threadgroup_size) -> ()");
  ops.impl("f32_bf16w_matmul_qkv", torch::kMPS, &f32_bf16w_matmul_qkv_torch);

  ops.def(
      "f32_sdpa(Tensor q, int q_offset_bytes, Tensor kv, int kv_offset_bytes, Tensor s_bf16, int s_offset_bytes, "
      "Tensor! output, int output_offset_bytes, int window, int kv_stride, int num_q_tokens, int num_kv_tokens, "
      "int num_q_heads, int num_kv_heads, int head_dim) -> ()");
  ops.impl("f32_sdpa", torch::kMPS, &f32_sdpa_torch);

  // ---- Expert routing ------------------------------------------------------
  // This op returns `()`, so its results must be written into one of its
  // tensor arguments; the schema previously declared none of them mutable.
  // `expert_ids` / `expert_scores` are marked as the written buffers because
  // the later ops in this section take them as plain inputs.
  // NOTE(review): confirm against f32_topk_torch that these are the outputs.
  ops.def(
      "f32_topk(Tensor scores, Tensor! expert_ids, Tensor! expert_scores, int num_tokens, int num_experts, "
      "int num_active_experts) -> ()");
  ops.impl("f32_topk", torch::kMPS, &f32_topk_torch);

  // Same situation as f32_topk: `()` return and previously no mutable
  // argument. `expert_offsets` / `intra_expert_offsets` are marked mutable
  // because f32_scatter below takes them as plain inputs.
  // NOTE(review): confirm against expert_routing_metadata_torch.
  ops.def(
      "expert_routing_metadata(Tensor expert_ids, Tensor expert_scores, Tensor! expert_offsets, "
      "Tensor! intra_expert_offsets, int num_tokens, int num_experts) -> ()");
  ops.impl("expert_routing_metadata", torch::kMPS, &expert_routing_metadata_torch);

  ops.def(
      "f32_scatter(Tensor input, Tensor expert_ids, Tensor expert_scores, Tensor expert_offsets, "
      "Tensor intra_expert_offsets, Tensor! output, int num_channels, int num_tokens, "
      "int num_active_experts) -> ()");
  ops.impl("f32_scatter", torch::kMPS, &f32_scatter_torch);

  // ---- Matmul with accumulate ---------------------------------------------
  ops.def(
      "f32_bf16w_matmul_add(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output, "
      "int num_tokens, int num_cols, int num_rows, int threadgroup_size) -> ()");
  ops.impl("f32_bf16w_matmul_add", torch::kMPS, &f32_bf16w_matmul_add_torch);
}
// Invokes the REGISTER_EXTENSION macro for this extension's name.
// NOTE(review): the macro is presumably supplied by "registration.h" and
// emits the Python module entry point -- confirm against that header.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)