Commit ·
ae7295b
1
Parent(s): cf2c51f
Migrate to embedded metallib and update flake.lock
Browse filesSwitch from file-based METALLIB_PATH loading to the embedded metallib
pattern (EMBEDDED_METALLIB_HEADER / EMBEDDED_METALLIB_NAMESPACE) that
kernel-builder now uses. This embeds the compiled Metal shader directly
into the shared library.
Also generate flake.lock for reproducible nix builds.
Co-developed-by: Claude Code v2.1.50 (claude-opus-4-6)
- flake.lock +95 -0
- rotary-embedding-metal/rotary_embedding.mm +11 -22
flake.lock
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nodes": {
|
| 3 |
+
"flake-compat": {
|
| 4 |
+
"locked": {
|
| 5 |
+
"lastModified": 1765121682,
|
| 6 |
+
"narHash": "sha256-4VBOP18BFeiPkyhy9o4ssBNQEvfvv1kXkasAYd0+rrA=",
|
| 7 |
+
"owner": "edolstra",
|
| 8 |
+
"repo": "flake-compat",
|
| 9 |
+
"rev": "65f23138d8d09a92e30f1e5c87611b23ef451bf3",
|
| 10 |
+
"type": "github"
|
| 11 |
+
},
|
| 12 |
+
"original": {
|
| 13 |
+
"owner": "edolstra",
|
| 14 |
+
"repo": "flake-compat",
|
| 15 |
+
"type": "github"
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
"flake-utils": {
|
| 19 |
+
"inputs": {
|
| 20 |
+
"systems": "systems"
|
| 21 |
+
},
|
| 22 |
+
"locked": {
|
| 23 |
+
"lastModified": 1731533236,
|
| 24 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
| 25 |
+
"owner": "numtide",
|
| 26 |
+
"repo": "flake-utils",
|
| 27 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
| 28 |
+
"type": "github"
|
| 29 |
+
},
|
| 30 |
+
"original": {
|
| 31 |
+
"owner": "numtide",
|
| 32 |
+
"repo": "flake-utils",
|
| 33 |
+
"type": "github"
|
| 34 |
+
}
|
| 35 |
+
},
|
| 36 |
+
"kernel-builder": {
|
| 37 |
+
"inputs": {
|
| 38 |
+
"flake-compat": "flake-compat",
|
| 39 |
+
"flake-utils": "flake-utils",
|
| 40 |
+
"nixpkgs": "nixpkgs"
|
| 41 |
+
},
|
| 42 |
+
"locked": {
|
| 43 |
+
"lastModified": 1769448133,
|
| 44 |
+
"narHash": "sha256-XOp8+8u7fmXn1f63mJ40dPj/OHPMKtL9o4q7y0CUZFU=",
|
| 45 |
+
"owner": "huggingface",
|
| 46 |
+
"repo": "kernel-builder",
|
| 47 |
+
"rev": "078351df6e0fddb4a1a41ba3ffb8b804f58c4c6a",
|
| 48 |
+
"type": "github"
|
| 49 |
+
},
|
| 50 |
+
"original": {
|
| 51 |
+
"owner": "huggingface",
|
| 52 |
+
"repo": "kernel-builder",
|
| 53 |
+
"type": "github"
|
| 54 |
+
}
|
| 55 |
+
},
|
| 56 |
+
"nixpkgs": {
|
| 57 |
+
"locked": {
|
| 58 |
+
"lastModified": 1766341660,
|
| 59 |
+
"narHash": "sha256-4yG6vx7Dddk9/zh45Y2KM82OaRD4jO3HA9r98ORzysA=",
|
| 60 |
+
"owner": "NixOS",
|
| 61 |
+
"repo": "nixpkgs",
|
| 62 |
+
"rev": "26861f5606e3e4d1400771b513cc63e5f70151a6",
|
| 63 |
+
"type": "github"
|
| 64 |
+
},
|
| 65 |
+
"original": {
|
| 66 |
+
"owner": "NixOS",
|
| 67 |
+
"ref": "nixos-unstable-small",
|
| 68 |
+
"repo": "nixpkgs",
|
| 69 |
+
"type": "github"
|
| 70 |
+
}
|
| 71 |
+
},
|
| 72 |
+
"root": {
|
| 73 |
+
"inputs": {
|
| 74 |
+
"kernel-builder": "kernel-builder"
|
| 75 |
+
}
|
| 76 |
+
},
|
| 77 |
+
"systems": {
|
| 78 |
+
"locked": {
|
| 79 |
+
"lastModified": 1681028828,
|
| 80 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 81 |
+
"owner": "nix-systems",
|
| 82 |
+
"repo": "default",
|
| 83 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 84 |
+
"type": "github"
|
| 85 |
+
},
|
| 86 |
+
"original": {
|
| 87 |
+
"owner": "nix-systems",
|
| 88 |
+
"repo": "default",
|
| 89 |
+
"type": "github"
|
| 90 |
+
}
|
| 91 |
+
}
|
| 92 |
+
},
|
| 93 |
+
"root": "root",
|
| 94 |
+
"version": 7
|
| 95 |
+
}
|
rotary-embedding-metal/rotary_embedding.mm
CHANGED
|
@@ -4,25 +4,19 @@
|
|
| 4 |
|
| 5 |
#import <Foundation/Foundation.h>
|
| 6 |
#import <Metal/Metal.h>
|
| 7 |
-
#include <dlfcn.h>
|
| 8 |
#include <string>
|
| 9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
|
| 11 |
return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
| 12 |
}
|
| 13 |
|
| 14 |
-
static std::string getModuleDirectory() {
|
| 15 |
-
Dl_info dl_info;
|
| 16 |
-
if (dladdr((void *)getModuleDirectory, &dl_info)) {
|
| 17 |
-
std::string path(dl_info.dli_fname);
|
| 18 |
-
size_t pos = path.find_last_of('/');
|
| 19 |
-
if (pos != std::string::npos) {
|
| 20 |
-
return path.substr(0, pos);
|
| 21 |
-
}
|
| 22 |
-
}
|
| 23 |
-
return ".";
|
| 24 |
-
}
|
| 25 |
-
|
| 26 |
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
|
| 27 |
std::optional<torch::Tensor> key, int64_t head_size,
|
| 28 |
torch::Tensor &cos_sin_cache, bool is_neox) {
|
|
@@ -74,16 +68,11 @@ void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
|
|
| 74 |
id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
|
| 75 |
TORCH_CHECK(cmdBuf, "Failed to get command buffer");
|
| 76 |
|
| 77 |
-
// Load
|
| 78 |
-
std::string moduleDir = getModuleDirectory();
|
| 79 |
-
std::string metallibPath = moduleDir + "/" + METALLIB_PATH;
|
| 80 |
-
|
| 81 |
-
NSString *metallibPathStr =
|
| 82 |
-
[NSString stringWithUTF8String:metallibPath.c_str()];
|
| 83 |
-
NSURL *metallibURL = [NSURL fileURLWithPath:metallibPathStr];
|
| 84 |
NSError *error = nil;
|
| 85 |
-
id<MTLLibrary> lib =
|
| 86 |
-
|
|
|
|
| 87 |
error ? [NSString stringWithFormat:@": %@",
|
| 88 |
error.localizedDescription]
|
| 89 |
.UTF8String
|
|
|
|
| 4 |
|
| 5 |
#import <Foundation/Foundation.h>
|
| 6 |
#import <Metal/Metal.h>
|
|
|
|
| 7 |
#include <string>
|
| 8 |
|
| 9 |
+
// Include the auto-generated header with embedded metallib.
|
| 10 |
+
#ifdef EMBEDDED_METALLIB_HEADER
|
| 11 |
+
#include EMBEDDED_METALLIB_HEADER
|
| 12 |
+
#else
|
| 13 |
+
#error "EMBEDDED_METALLIB_HEADER not defined"
|
| 14 |
+
#endif
|
| 15 |
+
|
| 16 |
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor &tensor) {
|
| 17 |
return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
| 18 |
}
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
void rotary_embedding(torch::Tensor &positions, torch::Tensor &query,
|
| 21 |
std::optional<torch::Tensor> key, int64_t head_size,
|
| 22 |
torch::Tensor &cos_sin_cache, bool is_neox) {
|
|
|
|
| 68 |
id<MTLCommandBuffer> cmdBuf = stream->commandBuffer();
|
| 69 |
TORCH_CHECK(cmdBuf, "Failed to get command buffer");
|
| 70 |
|
| 71 |
+
// Load embedded Metal library.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
NSError *error = nil;
|
| 73 |
+
id<MTLLibrary> lib =
|
| 74 |
+
EMBEDDED_METALLIB_NAMESPACE::createLibrary(device, &error);
|
| 75 |
+
TORCH_CHECK(lib, "Failed to create Metal library from embedded data",
|
| 76 |
error ? [NSString stringWithFormat:@": %@",
|
| 77 |
error.localizedDescription]
|
| 78 |
.UTF8String
|