Commit ·
2b765de
1
Parent(s): dd92cff
Use shared kernels-test-utils and set metal3.1 compatibility
Browse files
Replace inline device detection with kernels_test_utils.get_available_devices()
from the shared kernel-builder test utilities package. Set metal-std-version to
metal3.1 for macOS 14+ compatibility (was defaulting to metal4.0/macOS 26).
Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)
- build.toml +1 -0
- flake.lock +3 -3
- tests/test_rotary_embedding.py +3 -9
build.toml
CHANGED
|
@@ -10,6 +10,7 @@ src = [
|
|
| 10 |
|
| 11 |
[kernel.rotary_embedding_metal]
|
| 12 |
backend = "metal"
|
|
|
|
| 13 |
src = [
|
| 14 |
"rotary-embedding-metal/rotary_embedding.metal",
|
| 15 |
"rotary-embedding-metal/rotary_embedding.mm",
|
|
|
|
| 10 |
|
| 11 |
[kernel.rotary_embedding_metal]
|
| 12 |
backend = "metal"
|
| 13 |
+
metal-std-version = "metal3.1"
|
| 14 |
src = [
|
| 15 |
"rotary-embedding-metal/rotary_embedding.metal",
|
| 16 |
"rotary-embedding-metal/rotary_embedding.mm",
|
flake.lock
CHANGED
|
@@ -41,11 +41,11 @@
|
|
| 41 |
"rust-overlay": "rust-overlay"
|
| 42 |
},
|
| 43 |
"locked": {
|
| 44 |
-
"lastModified":
|
| 45 |
-
"narHash": "sha256-
|
| 46 |
"owner": "ChipFlow",
|
| 47 |
"repo": "kernels",
|
| 48 |
-
"rev": "
|
| 49 |
"type": "github"
|
| 50 |
},
|
| 51 |
"original": {
|
|
|
|
| 41 |
"rust-overlay": "rust-overlay"
|
| 42 |
},
|
| 43 |
"locked": {
|
| 44 |
+
"lastModified": 1773072978,
|
| 45 |
+
"narHash": "sha256-wTtMgTt1IMM5BFMh/lu+Y1jTw1P69aZcTr4fCNGvaw4=",
|
| 46 |
"owner": "ChipFlow",
|
| 47 |
"repo": "kernels",
|
| 48 |
+
"rev": "c220611160b60919af0c7c85438d82f3e3577aa2",
|
| 49 |
"type": "github"
|
| 50 |
},
|
| 51 |
"original": {
|
tests/test_rotary_embedding.py
CHANGED
|
@@ -7,17 +7,11 @@ for both NeoX (Llama/Mistral) and GPT-J rotation styles.
|
|
| 7 |
import pytest
|
| 8 |
import torch
|
| 9 |
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
def _is_mps_available() -> bool:
|
| 14 |
-
return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
|
| 15 |
|
|
|
|
| 16 |
|
| 17 |
-
|
| 18 |
-
DEVICES = ["mps"]
|
| 19 |
-
else:
|
| 20 |
-
DEVICES = [f"cuda:{i}" for i in range(max(1, torch.cuda.device_count()))]
|
| 21 |
|
| 22 |
DTYPES = [torch.float32, torch.float16, torch.bfloat16]
|
| 23 |
HEAD_SIZES = [64, 128, 256]
|
|
|
|
| 7 |
import pytest
|
| 8 |
import torch
|
| 9 |
|
| 10 |
+
from kernels_test_utils import get_available_devices
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
+
import rotary_embedding as ops
|
| 13 |
|
| 14 |
+
DEVICES = get_available_devices()
|
|
|
|
|
|
|
|
|
|
| 15 |
|
| 16 |
DTYPES = [torch.float32, torch.float16, torch.bfloat16]
|
| 17 |
HEAD_SIZES = [64, 128, 256]
|