robtaylor-chipflow committed on
Commit
ae7295b
·
1 Parent(s): cf2c51f

Migrate to embedded metallib and update flake.lock

Browse files

Switch from file-based METALLIB_PATH loading to the embedded metallib
pattern (EMBEDDED_METALLIB_HEADER / EMBEDDED_METALLIB_NAMESPACE) that
kernel-builder now uses. This embeds the compiled Metal shader directly
into the shared library.

Also generate flake.lock for reproducible nix builds.

Co-developed-by: Claude Code v2.1.50 (claude-opus-4-6)

flake.lock ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1765121682,
6
+ "narHash": "sha256-4VBOP18BFeiPkyhy9o4ssBNQEvfvv1kXkasAYd0+rrA=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "65f23138d8d09a92e30f1e5c87611b23ef451bf3",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-utils": {
19
+ "inputs": {
20
+ "systems": "systems"
21
+ },
22
+ "locked": {
23
+ "lastModified": 1731533236,
24
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
25
+ "owner": "numtide",
26
+ "repo": "flake-utils",
27
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
28
+ "type": "github"
29
+ },
30
+ "original": {
31
+ "owner": "numtide",
32
+ "repo": "flake-utils",
33
+ "type": "github"
34
+ }
35
+ },
36
+ "kernel-builder": {
37
+ "inputs": {
38
+ "flake-compat": "flake-compat",
39
+ "flake-utils": "flake-utils",
40
+ "nixpkgs": "nixpkgs"
41
+ },
42
+ "locked": {
43
+ "lastModified": 1769448133,
44
+ "narHash": "sha256-XOp8+8u7fmXn1f63mJ40dPj/OHPMKtL9o4q7y0CUZFU=",
45
+ "owner": "huggingface",
46
+ "repo": "kernel-builder",
47
+ "rev": "078351df6e0fddb4a1a41ba3ffb8b804f58c4c6a",
48
+ "type": "github"
49
+ },
50
+ "original": {
51
+ "owner": "huggingface",
52
+ "repo": "kernel-builder",
53
+ "type": "github"
54
+ }
55
+ },
56
+ "nixpkgs": {
57
+ "locked": {
58
+ "lastModified": 1766341660,
59
+ "narHash": "sha256-4yG6vx7Dddk9/zh45Y2KM82OaRD4jO3HA9r98ORzysA=",
60
+ "owner": "NixOS",
61
+ "repo": "nixpkgs",
62
+ "rev": "26861f5606e3e4d1400771b513cc63e5f70151a6",
63
+ "type": "github"
64
+ },
65
+ "original": {
66
+ "owner": "NixOS",
67
+ "ref": "nixos-unstable-small",
68
+ "repo": "nixpkgs",
69
+ "type": "github"
70
+ }
71
+ },
72
+ "root": {
73
+ "inputs": {
74
+ "kernel-builder": "kernel-builder"
75
+ }
76
+ },
77
+ "systems": {
78
+ "locked": {
79
+ "lastModified": 1681028828,
80
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
81
+ "owner": "nix-systems",
82
+ "repo": "default",
83
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
84
+ "type": "github"
85
+ },
86
+ "original": {
87
+ "owner": "nix-systems",
88
+ "repo": "default",
89
+ "type": "github"
90
+ }
91
+ }
92
+ },
93
+ "root": "root",
94
+ "version": 7
95
+ }
rotary-embedding-metal/rotary_embedding.mm CHANGED
@@ -4,25 +4,19 @@
4
 
5
  #import <Foundation/Foundation.h>
6
  #import <Metal/Metal.h>
7
- #include <dlfcn.h>
8
  #include <string>
9
 
 
 
 
 
 
 
 
10
  static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
11
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
12
  }
13
 
14
- static std::string getModuleDirectory() {
15
- Dl_info dl_info;
16
- if (dladdr((void *)getModuleDirectory, &dl_info)) {
17
- std::string path(dl_info.dli_fname);
18
- size_t pos = path.find_last_of('/');
19
- if (pos != std::string::npos) {
20
- return path.substr(0, pos);
21
- }
22
- }
23
- return ".";
24
- }
25
-
26
  void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
27
  std::optional<torch::Tensor> key, int64_t head_size,
28
  torch::Tensor &cos_sin_cache, bool is_neox) {
@@ -74,16 +68,11 @@ void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
74
  id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
75
  TORCH_CHECK(cmdBuf, "Failed to get command buffer");
76
 
77
- // Load metallib.
78
- std::string moduleDir = getModuleDirectory();
79
- std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
80
-
81
- NSString *metallibPathStr =
82
- [NSString stringWithUTF8String:metallibPath.c_str()];
83
- NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
84
  NSError *error = nil;
85
- id<MTLLibrary> lib = [device newLibraryWithURL:metallibURL error:&error];
86
- TORCH_CHECK(lib, "Failed to load Metal library at ", metallibPath,
 
87
  error ? [NSString stringWithFormat:@": %@",
88
  error.localizedDescription]
89
  .UTF8String
 
4
 
5
  #import <Foundation/Foundation.h>
6
  #import <Metal/Metal.h>
 
7
  #include <string>
8
 
9
+ // Include the auto-generated header with embedded metallib.
10
+ #ifdef EMBEDDED_METALLIB_HEADER
11
+ #include EMBEDDED_METALLIB_HEADER
12
+ #else
13
+ #error "EMBEDDED_METALLIB_HEADER not defined"
14
+ #endif
15
+
16
  static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
17
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
18
  }
19
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
21
  std::optional<torch::Tensor> key, int64_t head_size,
22
  torch::Tensor &cos_sin_cache, bool is_neox) {
 
68
  id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
69
  TORCH_CHECK(cmdBuf, "Failed to get command buffer");
70
 
71
+ // Load embedded Metal library.
 
 
 
 
 
 
72
  NSError *error = nil;
73
+ id<MTLLibrary> lib =
74
+ EMBEDDED_METALLIB_NAMESPACE::createLibrary(device, &error);
75
+ TORCH_CHECK(lib, "Failed to create Metal library from embedded data",
76
  error ? [NSString stringWithFormat:@": %@",
77
  error.localizedDescription]
78
  .UTF8String