robtaylor-chipflow committed on
Commit
2b765de
·
1 Parent(s): dd92cff

Use shared kernels-test-utils and set metal3.1 compatibility

Browse files

Replace inline device detection with kernels_test_utils.get_available_devices()
from the shared kernel-builder test utilities package. Set metal-std-version to
metal3.1 for macOS 14+ compatibility (was defaulting to metal4.0/macOS 26).

Co-developed-by: Claude Code v2.1.58 (claude-opus-4-6)

Files changed (3) hide show
  1. build.toml +1 -0
  2. flake.lock +3 -3
  3. tests/test_rotary_embedding.py +3 -9
build.toml CHANGED
@@ -10,6 +10,7 @@ src = [
10
 
11
  [kernel.rotary_embedding_metal]
12
  backend = "metal"
 
13
  src = [
14
  "rotary-embedding-metal/rotary_embedding.metal",
15
  "rotary-embedding-metal/rotary_embedding.mm",
 
10
 
11
  [kernel.rotary_embedding_metal]
12
  backend = "metal"
13
+ metal-std-version = "metal3.1"
14
  src = [
15
  "rotary-embedding-metal/rotary_embedding.metal",
16
  "rotary-embedding-metal/rotary_embedding.mm",
flake.lock CHANGED
@@ -41,11 +41,11 @@
41
  "rust-overlay": "rust-overlay"
42
  },
43
  "locked": {
44
- "lastModified": 1772650055,
45
- "narHash": "sha256-6R8dJEPH+uHJyvr3nZPZ/xFwULzR4UCsLQGSjLRsxQE=",
46
  "owner": "ChipFlow",
47
  "repo": "kernels",
48
- "rev": "f85b1d195c115acdb3f92c061a0dafcc0f9bfe79",
49
  "type": "github"
50
  },
51
  "original": {
 
41
  "rust-overlay": "rust-overlay"
42
  },
43
  "locked": {
44
+ "lastModified": 1773072978,
45
+ "narHash": "sha256-wTtMgTt1IMM5BFMh/lu+Y1jTw1P69aZcTr4fCNGvaw4=",
46
  "owner": "ChipFlow",
47
  "repo": "kernels",
48
+ "rev": "c220611160b60919af0c7c85438d82f3e3577aa2",
49
  "type": "github"
50
  },
51
  "original": {
tests/test_rotary_embedding.py CHANGED
@@ -7,17 +7,11 @@ for both NeoX (Llama/Mistral) and GPT-J rotation styles.
7
  import pytest
8
  import torch
9
 
10
- import rotary_embedding as ops
11
-
12
-
13
- def _is_mps_available() -> bool:
14
- return hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
15
 
 
16
 
17
- if _is_mps_available():
18
- DEVICES = ["mps"]
19
- else:
20
- DEVICES = [f"cuda:{i}" for i in range(max(1, torch.cuda.device_count()))]
21
 
22
  DTYPES = [torch.float32, torch.float16, torch.bfloat16]
23
  HEAD_SIZES = [64, 128, 256]
 
7
  import pytest
8
  import torch
9
 
10
+ from kernels_test_utils import get_available_devices
 
 
 
 
11
 
12
+ import rotary_embedding as ops
13
 
14
+ DEVICES = get_available_devices()
 
 
 
15
 
16
  DTYPES = [torch.float32, torch.float16, torch.bfloat16]
17
  HEAD_SIZES = [64, 128, 256]