File size: 3,731 Bytes
95d28ad
 
 
 
 
9ffd725
 
95d28ad
9ffd725
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95d28ad
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
#include <ATen/Tensor.h>
#include "torch_binding.h"
#include "registration.h"

// Operator registration for this extension.
//
// Every operator below is declared with a schema string and then bound to its
// implementation for the MPS dispatch key. All schemas return `()`; a
// `Tensor!` argument is one the schema declares as written in place.
//
// Each schema is spelled as adjacent string literals (one logical parameter
// group per line); the compiler concatenates them into a single string, so
// whitespace inside the quotes is significant.
//
// NOTE(review): some plain `Tensor` arguments (`expert_ids` / `expert_scores`
// in f32_topk, `expert_offsets` / `intra_expert_offsets` in
// expert_routing_metadata, `kv_cache` in f32_bf16w_matmul_qkv) look like they
// may be written by the kernels. If so they should be annotated `Tensor!` --
// confirm against the implementations declared in torch_binding.h.
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  // ---- f32_bf16w_matmul ----
  ops.def(
      "f32_bf16w_matmul("
      "Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output, "
      "int num_tokens, int num_cols, int num_rows, int threadgroup_size"
      ") -> ()");
  ops.impl("f32_bf16w_matmul", torch::kMPS, &f32_bf16w_matmul_torch);

  // ---- bf16_f32_embeddings ----
  ops.def(
      "bf16_f32_embeddings("
      "Tensor token_ids, Tensor weight_bf16, Tensor! output, "
      "int threadgroup_size"
      ") -> ()");
  ops.impl("bf16_f32_embeddings", torch::kMPS, &bf16_f32_embeddings_torch);

  // ---- f32_bf16w_rmsnorm ----
  ops.def(
      "f32_bf16w_rmsnorm("
      "Tensor input, Tensor weight_bf16, Tensor! output, "
      "float epsilon"
      ") -> ()");
  ops.impl("f32_bf16w_rmsnorm", torch::kMPS, &f32_bf16w_rmsnorm_torch);

  // ---- dense matmul family: three ops sharing one argument list ----
  ops.def(
      "f32_bf16w_dense_matmul_qkv("
      "Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output"
      ") -> ()");
  ops.impl("f32_bf16w_dense_matmul_qkv", torch::kMPS,
           &f32_bf16w_dense_matmul_qkv_torch);

  ops.def(
      "f32_bf16w_dense_matmul_attn_output("
      "Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output"
      ") -> ()");
  ops.impl("f32_bf16w_dense_matmul_attn_output", torch::kMPS,
           &f32_bf16w_dense_matmul_attn_output_torch);

  ops.def(
      "f32_bf16w_dense_matmul_mlp_gate("
      "Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output"
      ") -> ()");
  ops.impl("f32_bf16w_dense_matmul_mlp_gate", torch::kMPS,
           &f32_bf16w_dense_matmul_mlp_gate_torch);

  // ---- f32_rope: `activations` is the only tensor and is marked mutable ----
  ops.def(
      "f32_rope("
      "Tensor! activations, "
      "float rope_base, float interpolation_scale, "
      "float yarn_offset, float yarn_scale, float yarn_multiplier, "
      "int num_tokens, int num_q_heads, int num_kv_heads, int attn_head_dim, "
      "int token_offset, int threadgroup_size"
      ") -> ()");
  ops.impl("f32_rope", torch::kMPS, &f32_rope_torch);

  // ---- f32_bf16w_matmul_qkv ----
  ops.def(
      "f32_bf16w_matmul_qkv("
      "Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output, "
      "Tensor kv_cache, int kv_cache_offset_bytes, "
      "int num_tokens, int num_cols, int num_q_heads, int num_kv_heads, "
      "int attn_head_dim, int token_offset, int max_tokens, "
      "float rope_base, float interpolation_scale, "
      "float yarn_offset, float yarn_scale, float yarn_multiplier, "
      "int threadgroup_size"
      ") -> ()");
  ops.impl("f32_bf16w_matmul_qkv", torch::kMPS, &f32_bf16w_matmul_qkv_torch);

  // ---- f32_sdpa: each tensor argument is followed by its *_offset_bytes ----
  ops.def(
      "f32_sdpa("
      "Tensor q, int q_offset_bytes, "
      "Tensor kv, int kv_offset_bytes, "
      "Tensor s_bf16, int s_offset_bytes, "
      "Tensor! output, int output_offset_bytes, "
      "int window, int kv_stride, "
      "int num_q_tokens, int num_kv_tokens, "
      "int num_q_heads, int num_kv_heads, int head_dim"
      ") -> ()");
  ops.impl("f32_sdpa", torch::kMPS, &f32_sdpa_torch);

  // ---- f32_topk ----
  ops.def(
      "f32_topk("
      "Tensor scores, Tensor expert_ids, Tensor expert_scores, "
      "int num_tokens, int num_experts, int num_active_experts"
      ") -> ()");
  ops.impl("f32_topk", torch::kMPS, &f32_topk_torch);

  // ---- expert_routing_metadata ----
  ops.def(
      "expert_routing_metadata("
      "Tensor expert_ids, Tensor expert_scores, "
      "Tensor expert_offsets, Tensor intra_expert_offsets, "
      "int num_tokens, int num_experts"
      ") -> ()");
  ops.impl("expert_routing_metadata", torch::kMPS,
           &expert_routing_metadata_torch);

  // ---- f32_scatter ----
  ops.def(
      "f32_scatter("
      "Tensor input, Tensor expert_ids, Tensor expert_scores, "
      "Tensor expert_offsets, Tensor intra_expert_offsets, Tensor! output, "
      "int num_channels, int num_tokens, int num_active_experts"
      ") -> ()");
  ops.impl("f32_scatter", torch::kMPS, &f32_scatter_torch);

  // ---- f32_bf16w_matmul_add: same argument list as f32_bf16w_matmul ----
  ops.def(
      "f32_bf16w_matmul_add("
      "Tensor input, Tensor weight_bf16, Tensor bias_bf16, Tensor! output, "
      "int num_tokens, int num_cols, int num_rows, int threadgroup_size"
      ") -> ()");
  ops.impl("f32_bf16w_matmul_add", torch::kMPS, &f32_bf16w_matmul_add_torch);
}

// Expands the project's REGISTER_EXTENSION macro for this extension's name.
// NOTE(review): the macro is not defined in this file (presumably it comes
// from registration.h and emits the Python module entry point so the built
// library can be imported) -- confirm against that header.
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)