| TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { | |
| ops.def("rotary_embedding(Tensor positions, Tensor! query," | |
| " Tensor!? key, int head_size," | |
| " Tensor cos_sin_cache, bool is_neox) -> ()"); | |
| ops.impl("rotary_embedding", torch::kMPS, rotary_embedding); | |
| } | |
| REGISTER_EXTENSION(TORCH_EXTENSION_NAME) | |