| | #include <ATen/Tensor.h> |
| | #include "torch_binding.h" |
| | #include "registration.h" |
| |
|
| | TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { |
| | ops.def("f32_bf16w_matmul(Tensor input, Tensor weight_bf16, Tensor bias_bf16, " |
| | "Tensor! output, int num_tokens, int num_cols, int num_rows, int threadgroup_size) -> ()"); |
| | ops.impl("f32_bf16w_matmul", torch::kMPS, &f32_bf16w_matmul_torch); |
| |
|
| | ops.def("bf16_f32_embeddings(Tensor token_ids, Tensor weight_bf16, Tensor! output, " |
| | "int threadgroup_size) -> ()"); |
| | ops.impl("bf16_f32_embeddings", torch::kMPS, &bf16_f32_embeddings_torch); |
| |
|
| | ops.def("f32_bf16w_rmsnorm(Tensor input, Tensor weight_bf16, Tensor! output, float epsilon) -> ()"); |
| | ops.impl("f32_bf16w_rmsnorm", torch::kMPS, &f32_bf16w_rmsnorm_torch); |
| |
|
| | ops.def("f32_bf16w_dense_matmul_qkv(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output) -> ()"); |
| | ops.impl("f32_bf16w_dense_matmul_qkv", torch::kMPS, &f32_bf16w_dense_matmul_qkv_torch); |
| |
|
| | ops.def("f32_bf16w_dense_matmul_attn_output(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output) -> ()"); |
| | ops.impl("f32_bf16w_dense_matmul_attn_output", torch::kMPS, &f32_bf16w_dense_matmul_attn_output_torch); |
| |
|
| | ops.def("f32_bf16w_dense_matmul_mlp_gate(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output) -> ()"); |
| | ops.impl("f32_bf16w_dense_matmul_mlp_gate", torch::kMPS, &f32_bf16w_dense_matmul_mlp_gate_torch); |
| |
|
| | ops.def("f32_rope(Tensor! activations, float rope_base, float interpolation_scale, float yarn_offset, " |
| | "float yarn_scale, float yarn_multiplier, int num_tokens, int num_q_heads, int num_kv_heads, " |
| | "int attn_head_dim, int token_offset, int threadgroup_size) -> ()"); |
| | ops.impl("f32_rope", torch::kMPS, &f32_rope_torch); |
| |
|
| | ops.def("f32_bf16w_matmul_qkv(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output, Tensor kv_cache, " |
| | "int kv_cache_offset_bytes, int num_tokens, int num_cols, int num_q_heads, int num_kv_heads, " |
| | "int attn_head_dim, int token_offset, int max_tokens, float rope_base, float interpolation_scale, " |
| | "float yarn_offset, float yarn_scale, float yarn_multiplier, int threadgroup_size) -> ()"); |
| | ops.impl("f32_bf16w_matmul_qkv", torch::kMPS, &f32_bf16w_matmul_qkv_torch); |
| |
|
| | ops.def("f32_sdpa(Tensor q, int q_offset_bytes, Tensor kv, int kv_offset_bytes, Tensor s_bf16, int s_offset_bytes, " |
| | "Tensor! output, int output_offset_bytes, int window, int kv_stride, int num_q_tokens, int num_kv_tokens, " |
| | "int num_q_heads, int num_kv_heads, int head_dim) -> ()"); |
| | ops.impl("f32_sdpa", torch::kMPS, &f32_sdpa_torch); |
| |
|
| | ops.def("f32_topk(Tensor scores, Tensor expert_ids, Tensor expert_scores, int num_tokens, int num_experts, " |
| | "int num_active_experts) -> ()"); |
| | ops.impl("f32_topk", torch::kMPS, &f32_topk_torch); |
| |
|
| | ops.def("expert_routing_metadata(Tensor expert_ids, Tensor expert_scores, Tensor expert_offsets, " |
| | "Tensor intra_expert_offsets, int num_tokens, int num_experts) -> ()"); |
| | ops.impl("expert_routing_metadata", torch::kMPS, &expert_routing_metadata_torch); |
| |
|
| | ops.def("f32_scatter(Tensor input, Tensor expert_ids, Tensor expert_scores, Tensor expert_offsets, " |
| | "Tensor intra_expert_offsets, Tensor! output, int num_channels, int num_tokens, " |
| | "int num_active_experts) -> ()"); |
| | ops.impl("f32_scatter", torch::kMPS, &f32_scatter_torch); |
| |
|
| | ops.def("f32_bf16w_matmul_add(Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output, " |
| | "int num_tokens, int num_cols, int num_rows, int threadgroup_size) -> ()"); |
| | ops.impl("f32_bf16w_matmul_add", torch::kMPS, &f32_bf16w_matmul_add_torch); |
| | } |
| |
|
| | REGISTER_EXTENSION(TORCH_EXTENSION_NAME) |