// Metal/MPS backward kernels for a Linear layer whose output rows are gated by a
// bool `active_mask`, exposed to Python as a PyTorch C++ extension.
//
// Kernels are loaded from `sparse_linear_ops.metallib`, which must sit in the same
// directory as this extension's shared library and must export the functions
// `sparse_linear_grad_w_float` and `sparse_linear_grad_x_float`.

#include <torch/extension.h>
#include <ATen/mps/MPSDevice.h>
#include <ATen/mps/MPSStream.h>

#include <dlfcn.h>

#include <cstdint>
#include <filesystem>
#include <limits>
#include <mutex>
#include <string>
#include <vector>

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

namespace fs = std::filesystem;

namespace {

// Scalar arguments bound right after the tensor buffers of each kernel.
// NOTE(review): the layout must match the struct declared in the .metal source — confirm there.
struct SparseLinearParams {
  uint32_t N;
  uint32_t In;
  uint32_t Out;
  uint32_t dummy;  // unused by this file; always written as 0
};

// Lazily created Metal objects shared by every call. They are only created while
// g_mutex is held (see ensure_pipeline).
static id<MTLLibrary> g_lib = nil;
static id<MTLComputePipelineState> g_pipeline_grad_w = nil;
static id<MTLComputePipelineState> g_pipeline_grad_x = nil;
static std::mutex g_mutex;

// Threads per threadgroup along x and along y for both kernels.
constexpr int64_t kTile = 16;

// Returns "<directory of the shared object containing this code>/sparse_linear_ops.metallib",
// or an empty string when dladdr cannot resolve this module.
static std::string metallib_path_for_this_module() {
  Dl_info info;
  if (dladdr((void*)&metallib_path_for_this_module, &info) == 0 || info.dli_fname == nullptr)
    return std::string();
  fs::path so_path(info.dli_fname);
  return (so_path.parent_path() / "sparse_linear_ops.metallib").string();
}

// Loads g_lib from the metallib next to this module on first use.
// The caller must hold g_mutex. Throws (TORCH_CHECK) if the path or library cannot be loaded.
static void ensure_library_locked(id<MTLDevice> device) {
  if (g_lib != nil) return;
  std::string path = metallib_path_for_this_module();
  TORCH_CHECK(!path.empty(), "sparse_linear_ops: failed to locate extension path via dladdr");
  NSString* ns_path = [NSString stringWithUTF8String:path.c_str()];
  NSURL* url = [NSURL fileURLWithPath:ns_path];
  NSError* err = nil;
  g_lib = [device newLibraryWithURL:url error:&err];
  if (g_lib == nil) {
    const char* msg = err ? [[err localizedDescription] UTF8String] : "unknown error";
    TORCH_CHECK(false, "sparse_linear_ops: failed to load metallib at ", path, ": ", msg);
  }
}

// Returns the pipeline cached in *pipeline, building it from metallib function `fn_name`
// on first use. Serialized by g_mutex. Throws (TORCH_CHECK) if the function is missing or
// pipeline creation fails.
static id<MTLComputePipelineState> ensure_pipeline(id<MTLDevice> device,
                                                   id<MTLComputePipelineState>* pipeline,
                                                   const char* fn_name) {
  std::lock_guard<std::mutex> lock(g_mutex);
  ensure_library_locked(device);
  if (*pipeline != nil) return *pipeline;
  NSString* ns_fn = [NSString stringWithUTF8String:fn_name];
  id<MTLFunction> fn = [g_lib newFunctionWithName:ns_fn];
  TORCH_CHECK(fn != nil, "sparse_linear_ops: function `", fn_name, "` not found in metallib");
  NSError* err = nil;
  *pipeline = [device newComputePipelineStateWithFunction:fn error:&err];
  if (*pipeline == nil) {
    const char* msg = err ? [[err localizedDescription] UTF8String] : "unknown error";
    TORCH_CHECK(false, "sparse_linear_ops: failed to create pipeline for ", fn_name, ": ", msg);
  }
  return *pipeline;
}

// Returns the MTLBuffer behind `t`'s storage.
// Assumes the MPS allocator stores the MTLBuffer as the storage DataPtr context; the
// TORCH_CHECK is the only guard — NOTE(review): confirm for the targeted torch version.
static inline id<MTLBuffer> storage_as_mtlbuffer(const at::Tensor& t) {
  void* ctx = t.storage().data_ptr().get_context();
  TORCH_CHECK(ctx != nullptr, "sparse_linear_ops: expected MPS tensor storage with MTLBuffer context");
  return (__bridge id<MTLBuffer>)ctx;
}

// Byte offset of `t`'s first element within its storage buffer.
static inline NSUInteger storage_offset_bytes(const at::Tensor& t) {
  return (NSUInteger)(t.storage_offset() * (int64_t)t.element_size());
}

// Requires `t` to be a contiguous float32 MPS tensor.
static void check_mps_float_contig(const at::Tensor& t, const char* name) {
  TORCH_CHECK(t.device().is_mps(), name, " must be on MPS");
  TORCH_CHECK(t.dtype() == at::kFloat, name, " must be float32 for v12 kernel");
  TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
}

// Requires `t` to be a contiguous bool MPS tensor (the row mask shared by both entry points).
static void check_mps_bool_mask(const at::Tensor& t, const char* name) {
  TORCH_CHECK(t.device().is_mps(), name, " must be on MPS");
  TORCH_CHECK(t.dtype() == at::kBool, name, " must be bool");
  TORCH_CHECK(t.is_contiguous(), name, " must be contiguous");
}

// Narrows a tensor extent to the uint32 used by SparseLinearParams, refusing silent truncation.
static uint32_t checked_u32(int64_t value, const char* what) {
  TORCH_CHECK(value >= 0 && value <= (int64_t)std::numeric_limits<uint32_t>::max(),
              "sparse_linear_ops: ", what, " = ", value,
              " exceeds the uint32 range of the kernel parameters");
  return static_cast<uint32_t>(value);
}

// Number of kTile-wide threadgroups needed to cover `extent` threads.
static NSUInteger threadgroups_for(int64_t extent) {
  return (NSUInteger)((extent + kTile - 1) / kTile);
}

// Encodes one dispatch of `pipeline` on the current MPS stream: tensors[i] is bound at
// buffer index i, `params` at index tensors.size(), and `grid` threadgroups of
// kTile x kTile threads are dispatched.
static void encode_on_current_stream(id<MTLComputePipelineState> pipeline,
                                     const std::vector<at::Tensor>& tensors,
                                     const SparseLinearParams& params,
                                     MTLSize grid) {
  // Resolve buffers before entering the GCD block: storage_as_mtlbuffer may throw via
  // TORCH_CHECK, and C++ exceptions must not propagate out of a dispatch block.
  std::vector<id<MTLBuffer>> buffers;
  std::vector<NSUInteger> offsets;
  buffers.reserve(tensors.size());
  offsets.reserve(tensors.size());
  for (const at::Tensor& t : tensors) {
    buffers.push_back(storage_as_mtlbuffer(t));
    offsets.push_back(storage_offset_bytes(t));
  }

  // dispatch_sync runs the block before returning, so pointers to these locals stay valid.
  const std::vector<id<MTLBuffer>>* buffers_p = &buffers;
  const std::vector<NSUInteger>* offsets_p = &offsets;
  const SparseLinearParams* params_p = &params;
  const NSUInteger params_index = (NSUInteger)tensors.size();

  at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
  // The stream's command buffer/encoder are driven from its serial queue; encode there so
  // this dispatch is ordered with other work on the same stream.
  dispatch_sync(stream->queue(), ^{
    @autoreleasepool {
      id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();
      [encoder setComputePipelineState:pipeline];
      for (NSUInteger i = 0; i < params_index; ++i) {
        [encoder setBuffer:(*buffers_p)[i] offset:(*offsets_p)[i] atIndex:i];
      }
      [encoder setBytes:params_p length:sizeof(SparseLinearParams) atIndex:params_index];
      [encoder dispatchThreadgroups:grid
              threadsPerThreadgroup:MTLSizeMake((NSUInteger)kTile, (NSUInteger)kTile, 1)];
    }
  });
}

}  // namespace

// Weight/bias gradients for the masked Linear layer.
//
// x2d:         (N, In) float32 MPS input activations.
// gy2d:        (N, Out) float32 MPS upstream gradient.
// active_mask: (Out,) bool MPS row mask.
// Returns {grad_w (Out, In), grad_b (Out,)}, both zero-initialized before the
// `sparse_linear_grad_w_float` kernel runs. The dispatch is enqueued on the current MPS
// stream; this function does not wait for it to finish.
std::vector<at::Tensor> sparse_linear_grad_wb(at::Tensor x2d, at::Tensor gy2d, at::Tensor active_mask) {
  check_mps_float_contig(x2d, "x2d");
  check_mps_float_contig(gy2d, "gy2d");
  check_mps_bool_mask(active_mask, "active_mask");
  TORCH_CHECK(x2d.dim() == 2, "x2d must be 2-D, got ", x2d.dim(), "-D");
  TORCH_CHECK(gy2d.dim() == 2, "gy2d must be 2-D, got ", gy2d.dim(), "-D");
  TORCH_CHECK(active_mask.dim() == 1, "active_mask must be 1-D, got ", active_mask.dim(), "-D");

  const int64_t N = x2d.size(0);
  const int64_t In = x2d.size(1);
  const int64_t Out = active_mask.size(0);
  // Both width and row count must agree, otherwise the kernel reads past gy2d.
  TORCH_CHECK(gy2d.size(0) == N, "gy2d rows must equal x2d rows");
  TORCH_CHECK(gy2d.size(1) == Out, "gy2d width must equal active_mask size");

  auto grad_w = at::zeros({Out, In}, x2d.options());
  auto grad_b = at::zeros({Out}, x2d.options());

  // An empty batch contributes nothing and Out == 0 yields empty outputs, so the zeros are
  // already exact; skipping avoids a zero-sized dispatch and buffer-less empty tensors.
  if (N == 0 || Out == 0) return {grad_w, grad_b};
  TORCH_CHECK(In > 0, "sparse_linear_grad_wb: x2d must have at least one column (in_features > 0)");

  const SparseLinearParams prm{checked_u32(N, "N"), checked_u32(In, "in_features"),
                               checked_u32(Out, "out_features"), 0};

  id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
  id<MTLComputePipelineState> pipeline =
      ensure_pipeline(device, &g_pipeline_grad_w, "sparse_linear_grad_w_float");

  // Buffer indices: 0 x2d, 1 gy2d, 2 active_mask, 3 grad_w, 4 grad_b, 5 params.
  // Grid: x tiles In, y tiles Out.
  encode_on_current_stream(pipeline, {x2d, gy2d, active_mask, grad_w, grad_b}, prm,
                           MTLSizeMake(threadgroups_for(In), threadgroups_for(Out), 1));
  return {grad_w, grad_b};
}

// Input gradient for the masked Linear layer.
//
// gy2d:        (N, Out) float32 MPS upstream gradient.
// weight:      (Out, In) float32 MPS weight.
// active_mask: (Out,) bool MPS row mask.
// Returns grad_x (N, In), zero-initialized before the `sparse_linear_grad_x_float` kernel
// runs. The dispatch is enqueued on the current MPS stream; this function does not wait
// for it to finish.
at::Tensor sparse_linear_grad_x(at::Tensor gy2d, at::Tensor weight, at::Tensor active_mask) {
  check_mps_float_contig(gy2d, "gy2d");
  check_mps_float_contig(weight, "weight");
  check_mps_bool_mask(active_mask, "active_mask");
  TORCH_CHECK(gy2d.dim() == 2, "gy2d must be 2-D, got ", gy2d.dim(), "-D");
  TORCH_CHECK(weight.dim() == 2, "weight must be 2-D, got ", weight.dim(), "-D");
  TORCH_CHECK(active_mask.dim() == 1, "active_mask must be 1-D, got ", active_mask.dim(), "-D");

  const int64_t N = gy2d.size(0);
  const int64_t Out = gy2d.size(1);
  const int64_t In = weight.size(1);
  TORCH_CHECK(weight.size(0) == Out, "weight out_features must match gy2d width");
  // Without this the kernel could index past the end of the mask buffer.
  TORCH_CHECK(active_mask.size(0) == Out, "active_mask size must match gy2d width");

  auto grad_x = at::zeros({N, In}, gy2d.options());

  // Any empty extent means the result is empty or a sum over nothing, so the zeros are exact.
  if (N == 0 || In == 0 || Out == 0) return grad_x;

  const SparseLinearParams prm{checked_u32(N, "N"), checked_u32(In, "in_features"),
                               checked_u32(Out, "out_features"), 0};

  id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
  id<MTLComputePipelineState> pipeline =
      ensure_pipeline(device, &g_pipeline_grad_x, "sparse_linear_grad_x_float");

  // Buffer indices: 0 gy2d, 1 weight, 2 active_mask, 3 grad_x, 4 params.
  // Grid: x tiles In, y tiles N.
  encode_on_current_stream(pipeline, {gy2d, weight, active_mask, grad_x}, prm,
                           MTLSizeMake(threadgroups_for(In), threadgroups_for(N), 1));
  return grad_x;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sparse_linear_grad_wb", &sparse_linear_grad_wb, "Sparse active-row Linear dW/db (Metal/MPS, fp32)");
  m.def("sparse_linear_grad_x", &sparse_linear_grad_x, "Sparse active-row Linear dX (Metal/MPS, fp32)");
}