#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSDevice.h>
#include <c10/util/Exception.h>
#include <filesystem>
#include <dlfcn.h>
#import <Metal/Metal.h>
#import <Foundation/Foundation.h>
#include <mutex>
namespace fs = std::filesystem;
namespace {
// Problem-size constants handed to the Metal kernels by value via
// -[MTLComputeCommandEncoder setBytes:length:atIndex:].
// NOTE(review): the layout presumably has to match a struct declared in the
// .metal source, which is not visible here — confirm before reordering fields.
struct SparseLinearParams {
  uint32_t N;      // filled from size(0) of x2d / gy2d at the call sites
  uint32_t In;     // filled from size(1) of x2d / weight at the call sites
  uint32_t Out;    // filled from active_mask.size(0) / gy2d.size(1) at the call sites
  uint32_t dummy;  // always passed as 0 in this file; presumably padding — TODO confirm
};
// Process-wide cache of Metal objects, created on first use and kept for the
// lifetime of the module.
// The Metal library loaded from the .metallib file; assigned in ensure_library_locked().
static id<MTLLibrary> g_lib = nil;
// Pipeline for the "sparse_linear_grad_w_float" kernel; filled in by ensure_pipeline().
static id<MTLComputePipelineState> g_pipeline_grad_w = nil;
// Pipeline for the "sparse_linear_grad_x_float" kernel; filled in by ensure_pipeline().
static id<MTLComputePipelineState> g_pipeline_grad_x = nil;
// Held by ensure_pipeline() while the library and pipelines above are created.
static std::mutex g_mutex;
// Builds the path "<directory of the image containing this function>/sparse_linear_ops.metallib".
// Returns an empty string when dladdr() cannot resolve the image or reports no file name.
static std::string metallib_path_for_this_module() {
  // Ask the dynamic loader which loaded image contains this function's address.
  Dl_info image_info;
  const int resolved = dladdr((void*)&metallib_path_for_this_module, &image_info);
  if (resolved == 0 || image_info.dli_fname == nullptr) {
    return std::string();
  }

  const fs::path module_file(image_info.dli_fname);
  const fs::path metallib_file = module_file.parent_path() / "sparse_linear_ops.metallib";
  return metallib_file.string();
}
// Loads the Metal library into g_lib if it has not been loaded yet.
// Performs no locking itself; the only caller in this file (ensure_pipeline)
// takes g_mutex before calling it.
// Raises via TORCH_CHECK when the metallib path cannot be determined or the
// library fails to load.
static void ensure_library_locked(id<MTLDevice> device) {
  // A previous call already loaded the library.
  if (g_lib != nil) {
    return;
  }

  const std::string metallib_path = metallib_path_for_this_module();
  TORCH_CHECK(!metallib_path.empty(), "sparse_linear_ops: failed to locate extension path via dladdr");

  NSURL* metallib_url = [NSURL fileURLWithPath:[NSString stringWithUTF8String:metallib_path.c_str()]];
  NSError* load_error = nil;
  g_lib = [device newLibraryWithURL:metallib_url error:&load_error];
  if (g_lib != nil) {
    return;
  }

  // Loading failed: surface Metal's description when one was provided.
  const char* reason = (load_error != nil) ? [[load_error localizedDescription] UTF8String] : "unknown error";
  TORCH_CHECK(false, "sparse_linear_ops: failed to load metallib at ", metallib_path, ": ", reason);
}
// Returns the compute pipeline cached in *pipeline, building it from the
// kernel named fn_name on first use.
//
// device   - Metal device used to load the library and build the pipeline.
// pipeline - address of the cache slot (one of the g_pipeline_* globals).
// fn_name  - name of the kernel function to look up in the metallib.
//
// The whole function runs under g_mutex. Raises via TORCH_CHECK when the
// library cannot be loaded, the function is missing, or pipeline creation fails.
static id<MTLComputePipelineState> ensure_pipeline(id<MTLDevice> device, id<MTLComputePipelineState>* pipeline, const char* fn_name) {
  std::lock_guard<std::mutex> guard(g_mutex);
  ensure_library_locked(device);

  if (*pipeline == nil) {
    id<MTLFunction> kernel_fn = [g_lib newFunctionWithName:[NSString stringWithUTF8String:fn_name]];
    TORCH_CHECK(kernel_fn != nil, "sparse_linear_ops: function `", fn_name, "` not found in metallib");

    NSError* build_error = nil;
    *pipeline = [device newComputePipelineStateWithFunction:kernel_fn error:&build_error];
    if (*pipeline == nil) {
      const char* reason = (build_error != nil) ? [[build_error localizedDescription] UTF8String] : "unknown error";
      TORCH_CHECK(false, "sparse_linear_ops: failed to create pipeline for ", fn_name, ": ", reason);
    }
  }
  return *pipeline;
}
// Reinterprets the context pointer of the tensor's storage DataPtr as an
// id<MTLBuffer> (no ownership transfer: __bridge cast).
// Raises via TORCH_CHECK when the context pointer is null.
// NOTE(review): relies on the MPS allocator storing the MTLBuffer as the
// DataPtr context — confirm against the PyTorch version in use.
static inline id<MTLBuffer> storage_as_mtlbuffer(const at::Tensor& t) {
  void* const buffer_ctx = t.storage().data_ptr().get_context();
  TORCH_CHECK(buffer_ctx != nullptr, "sparse_linear_ops: expected MPS tensor storage with MTLBuffer context");
  return (__bridge id<MTLBuffer>)buffer_ctx;
}
// Byte offset of the tensor's first element within its storage:
// storage_offset() (in elements) multiplied by element_size().
static inline NSUInteger storage_offset_bytes(const at::Tensor& t) {
  const int64_t offset_in_elements = t.storage_offset();
  const int64_t bytes_per_element = (int64_t)t.element_size();
  return (NSUInteger)(offset_in_elements * bytes_per_element);
}
// Validates that tensor t lives on MPS, has dtype float32 and is contiguous.
// name is used as the prefix of the error message. Checks run in that order,
// so the first failing property is the one reported (via TORCH_CHECK).
static void check_mps_float_contig(const at::Tensor& t, const char* name) {
  const bool on_mps = t.device().is_mps();
  TORCH_CHECK(on_mps, name, " must be on MPS");

  const bool is_fp32 = (t.dtype() == at::kFloat);
  TORCH_CHECK(is_fp32, name, " must be float32 for v12 kernel");

  const bool contiguous = t.is_contiguous();
  TORCH_CHECK(contiguous, name, " must be contiguous");
}
} // namespace
// Encodes the "sparse_linear_grad_w_float" Metal kernel on the current MPS
// stream and returns {grad_w, grad_b} (described by the Python binding as
// "Sparse active-row Linear dW/db").
//
// x2d         - float32, contiguous, MPS, shape (N, In).
// gy2d        - float32, contiguous, MPS, shape (N, Out).
// active_mask - bool, contiguous, MPS; size(0) defines Out.
//
// Returns grad_w of shape (Out, In) and grad_b of shape (Out), both
// zero-initialised before the kernel runs.
//
// Raises (TORCH_CHECK) on wrong device/dtype/layout, mismatched shapes, sizes
// that do not fit the kernel's 32-bit parameters, or In == 0 with a non-empty
// batch (the dispatch grid would have zero width).
//
// NOTE(review): the encoder is used directly on the calling thread; ATen's own
// MPS kernels wrap encoding in the stream's serial queue — confirm whether
// that is needed for multi-threaded callers.
std::vector<at::Tensor> sparse_linear_grad_wb(at::Tensor x2d, at::Tensor gy2d, at::Tensor active_mask) {
  check_mps_float_contig(x2d, "x2d");
  check_mps_float_contig(gy2d, "gy2d");
  TORCH_CHECK(active_mask.device().is_mps(), "active_mask must be on MPS");
  TORCH_CHECK(active_mask.dtype() == at::kBool, "active_mask must be bool");
  TORCH_CHECK(active_mask.is_contiguous(), "active_mask must be contiguous");

  // Rank checks first so the size() calls below cannot index a missing dim
  // and so higher-rank tensors are not silently treated as 2-D.
  TORCH_CHECK(x2d.dim() == 2, "x2d must be 2-D, got ", x2d.dim(), " dims");
  TORCH_CHECK(gy2d.dim() == 2, "gy2d must be 2-D, got ", gy2d.dim(), " dims");

  const int64_t N = x2d.size(0);
  const int64_t In = x2d.size(1);
  const int64_t Out = active_mask.size(0);
  TORCH_CHECK(gy2d.size(1) == Out, "gy2d width must equal active_mask size");
  // The kernel is told there are N rows; gy2d must actually have that many.
  TORCH_CHECK(gy2d.size(0) == N, "gy2d rows must equal x2d rows");
  // SparseLinearParams carries the sizes as uint32_t; reject silent truncation.
  TORCH_CHECK(N <= (int64_t)UINT32_MAX && In <= (int64_t)UINT32_MAX && Out <= (int64_t)UINT32_MAX,
              "sparse_linear_grad_wb: N/In/Out must fit in 32 bits");

  auto grad_w = at::zeros({Out, In}, x2d.options());
  auto grad_b = at::zeros({Out}, x2d.options());

  // Empty batch or no output rows: the zero-filled results are already final,
  // and empty tensors have no buffer to bind.
  if (N == 0 || Out == 0) {
    return {grad_w, grad_b};
  }
  // With In == 0 the grid below would be zero threadgroups wide.
  TORCH_CHECK(In > 0, "sparse_linear_grad_wb: x2d must have at least one column");

  id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
  id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_w, "sparse_linear_grad_w_float");
  at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
  id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();
  [encoder setComputePipelineState:pipeline];

  // Binds a tensor's backing MTLBuffer (plus its storage offset) at slot idx.
  auto set_tensor = [&](const at::Tensor& t, int idx) {
    [encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
  };
  set_tensor(x2d, 0);
  set_tensor(gy2d, 1);
  set_tensor(active_mask, 2);
  set_tensor(grad_w, 3);
  set_tensor(grad_b, 4);

  SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, 0};
  [encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:5];

  // 16x16 threads per group; the grid covers In (x) by Out (y), rounded up.
  MTLSize tg = MTLSizeMake(16, 16, 1);
  MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((Out + 15) / 16), 1);
  [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];
  return {grad_w, grad_b};
}
// Encodes the "sparse_linear_grad_x_float" Metal kernel on the current MPS
// stream and returns grad_x (described by the Python binding as
// "Sparse active-row Linear dX").
//
// gy2d        - float32, contiguous, MPS, shape (N, Out).
// weight      - float32, contiguous, MPS, shape (Out, In).
// active_mask - bool, contiguous, MPS; size(0) must equal Out.
//
// Returns grad_x of shape (N, In), zero-initialised before the kernel runs.
//
// Raises (TORCH_CHECK) on wrong device/dtype/layout, mismatched shapes, or
// sizes that do not fit the kernel's 32-bit parameters.
//
// NOTE(review): the encoder is used directly on the calling thread; ATen's own
// MPS kernels wrap encoding in the stream's serial queue — confirm whether
// that is needed for multi-threaded callers.
at::Tensor sparse_linear_grad_x(at::Tensor gy2d, at::Tensor weight, at::Tensor active_mask) {
  check_mps_float_contig(gy2d, "gy2d");
  check_mps_float_contig(weight, "weight");
  TORCH_CHECK(active_mask.device().is_mps(), "active_mask must be on MPS");
  TORCH_CHECK(active_mask.dtype() == at::kBool, "active_mask must be bool");
  TORCH_CHECK(active_mask.is_contiguous(), "active_mask must be contiguous");

  // Rank checks first so the size() calls below cannot index a missing dim
  // and so higher-rank tensors are not silently treated as 2-D.
  TORCH_CHECK(gy2d.dim() == 2, "gy2d must be 2-D, got ", gy2d.dim(), " dims");
  TORCH_CHECK(weight.dim() == 2, "weight must be 2-D, got ", weight.dim(), " dims");

  const int64_t N = gy2d.size(0);
  const int64_t Out = gy2d.size(1);
  const int64_t In = weight.size(1);
  TORCH_CHECK(weight.size(0) == Out, "weight out_features must match gy2d width");
  // Same mask/width agreement that sparse_linear_grad_wb enforces; the kernel
  // presumably indexes the mask per output feature, so a shorter mask would be
  // read out of bounds.
  TORCH_CHECK(active_mask.size(0) == Out, "gy2d width must equal active_mask size");
  // SparseLinearParams carries the sizes as uint32_t; reject silent truncation.
  TORCH_CHECK(N <= (int64_t)UINT32_MAX && In <= (int64_t)UINT32_MAX && Out <= (int64_t)UINT32_MAX,
              "sparse_linear_grad_x: N/In/Out must fit in 32 bits");

  auto grad_x = at::zeros({N, In}, gy2d.options());

  // Nothing to compute: either the output is empty (N or In is 0) or there is
  // no output feature to accumulate over (Out is 0), so zeros are final. This
  // also avoids a zero-sized dispatch grid and binding buffer-less tensors.
  if (N == 0 || In == 0 || Out == 0) {
    return grad_x;
  }

  id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
  id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_x, "sparse_linear_grad_x_float");
  at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
  id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();
  [encoder setComputePipelineState:pipeline];

  // Binds a tensor's backing MTLBuffer (plus its storage offset) at slot idx.
  auto set_tensor = [&](const at::Tensor& t, int idx) {
    [encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
  };
  set_tensor(gy2d, 0);
  set_tensor(weight, 1);
  set_tensor(active_mask, 2);
  set_tensor(grad_x, 3);

  SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, 0};
  [encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:4];

  // 16x16 threads per group; the grid covers In (x) by N (y), rounded up.
  MTLSize tg = MTLSizeMake(16, 16, 1);
  MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((N + 15) / 16), 1);
  [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];
  return grad_x;
}
// Python bindings: registers both entry points on the extension module whose
// name is supplied by the build through TORCH_EXTENSION_NAME.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sparse_linear_grad_wb", &sparse_linear_grad_wb, "Sparse active-row Linear dW/db (Metal/MPS, fp32)");
  m.def("sparse_linear_grad_x", &sparse_linear_grad_x, "Sparse active-row Linear dX (Metal/MPS, fp32)");
}