Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.
bc1b8eb | namespace fs = std::filesystem; | |
| namespace { | |
// Small constant block copied to the Metal kernels with -setBytes:length:atIndex:.
// Four consecutive 32-bit unsigned fields.
// NOTE(review): `dummy` is presumably padding so the struct matches the
// shader-side declaration -- confirm against the .metal source.
struct SparseLinearParams {
    uint32_t N;      // row count, filled from size(0) of x2d / gy2d by the launchers
    uint32_t In;     // input feature count, filled from size(1) of x2d / weight
    uint32_t Out;    // output feature count
    uint32_t dummy;  // the launchers in this file always pass 0
};
// Process-wide Metal state. Everything here starts out empty and is filled in
// on first use by ensure_pipeline(), which holds g_mutex while doing so.
static std::mutex g_mutex;

// Library loaded from the .metallib that sits next to this extension module.
static id<MTLLibrary> g_lib = nil;

// One cached compute pipeline per kernel entry point.
static id<MTLComputePipelineState> g_pipeline_grad_w = nil;
static id<MTLComputePipelineState> g_pipeline_grad_x = nil;
// Builds the path "<directory of the image containing this function>/sparse_linear_ops.metallib".
// The image is looked up with dladdr() using this function's own address.
// Returns an empty string when dladdr() fails or reports no file name.
static std::string metallib_path_for_this_module() {
    Dl_info image_info;
    const int found = dladdr((void*)&metallib_path_for_this_module, &image_info);
    if (found == 0 || image_info.dli_fname == nullptr) {
        return std::string();
    }

    const fs::path module_file(image_info.dli_fname);
    const fs::path module_dir = module_file.parent_path();
    return (module_dir / "sparse_linear_ops.metallib").string();
}
// Loads the metallib into g_lib the first time it is called; once g_lib is
// set, later calls return immediately.
//
// This function takes no lock itself. The caller in this file
// (ensure_pipeline) holds g_mutex around the call.
//
// Raises via TORCH_CHECK when the metallib path cannot be determined or when
// Metal fails to load the library (the NSError description is included).
static void ensure_library_locked(id<MTLDevice> device) {
    if (g_lib != nil) {
        return;
    }

    const std::string lib_path = metallib_path_for_this_module();
    TORCH_CHECK(!lib_path.empty(), "sparse_linear_ops: failed to locate extension path via dladdr");

    NSURL* lib_url = [NSURL fileURLWithPath:[NSString stringWithUTF8String:lib_path.c_str()]];
    NSError* load_error = nil;
    g_lib = [device newLibraryWithURL:lib_url error:&load_error];
    if (g_lib != nil) {
        return;
    }

    // Loading failed: surface Metal's own description when one was provided.
    const char* reason = "unknown error";
    if (load_error) {
        reason = [[load_error localizedDescription] UTF8String];
    }
    TORCH_CHECK(false, "sparse_linear_ops: failed to load metallib at ", lib_path, ": ", reason);
}
// Returns the compute pipeline cached in `*pipeline`, building it first if the
// slot is still nil.
//
//   device    Metal device used to load the library and build the pipeline.
//   pipeline  cache slot (one of the g_pipeline_* globals at the call sites here).
//   fn_name   kernel function name to look up in the metallib.
//
// g_mutex is held for the whole call, covering both the library load and the
// pipeline construction.
// Raises via TORCH_CHECK when the function is missing from the metallib or the
// pipeline cannot be created.
static id<MTLComputePipelineState> ensure_pipeline(id<MTLDevice> device, id<MTLComputePipelineState>* pipeline, const char* fn_name) {
    std::lock_guard<std::mutex> guard(g_mutex);
    ensure_library_locked(device);

    if (*pipeline == nil) {
        id<MTLFunction> kernel_fn = [g_lib newFunctionWithName:[NSString stringWithUTF8String:fn_name]];
        TORCH_CHECK(kernel_fn != nil, "sparse_linear_ops: function `", fn_name, "` not found in metallib");

        NSError* build_error = nil;
        *pipeline = [device newComputePipelineStateWithFunction:kernel_fn error:&build_error];
        if (*pipeline == nil) {
            const char* reason = "unknown error";
            if (build_error) {
                reason = [[build_error localizedDescription] UTF8String];
            }
            TORCH_CHECK(false, "sparse_linear_ops: failed to create pipeline for ", fn_name, ": ", reason);
        }
    }
    return *pipeline;
}
// Reinterprets the context pointer of the tensor's storage DataPtr as an
// id<MTLBuffer> (no ownership transfer) and returns it.
// Raises via TORCH_CHECK when that context pointer is null.
// NOTE(review): this relies on the MPS allocator storing the MTLBuffer as the
// DataPtr context -- confirm against the PyTorch version in use.
static inline id<MTLBuffer> storage_as_mtlbuffer(const at::Tensor& t) {
    const auto& backing = t.storage();
    void* buffer_handle = backing.data_ptr().get_context();
    TORCH_CHECK(buffer_handle != nullptr, "sparse_linear_ops: expected MPS tensor storage with MTLBuffer context");
    return (__bridge id<MTLBuffer>)buffer_handle;
}
// Byte offset of the tensor's first element inside its storage:
// storage_offset() (in elements) multiplied by element_size() (in bytes).
static inline NSUInteger storage_offset_bytes(const at::Tensor& t) {
    const int64_t element_offset = t.storage_offset();
    const int64_t bytes_per_element = (int64_t)t.element_size();
    return (NSUInteger)(element_offset * bytes_per_element);
}
// Validates a kernel input: the tensor must live on MPS, have dtype float32
// and be contiguous. The checks run in that order and the first failure raises
// via TORCH_CHECK with `name` as the message prefix.
static void check_mps_float_contig(const at::Tensor& t, const char* name) {
    // Device first: the remaining checks only matter for MPS tensors.
    TORCH_CHECK(t.device().is_mps(),
                name, " must be on MPS");

    // Only float32 is accepted.
    TORCH_CHECK(t.dtype() == at::kFloat,
                name, " must be float32 for v12 kernel");

    // The launchers bind the raw storage buffer, so the layout must be contiguous.
    TORCH_CHECK(t.is_contiguous(),
                name, " must be contiguous");
}
| } // namespace | |
// Host-side launcher for the "sparse_linear_grad_w_float" Metal kernel.
//
// Inputs (all on MPS and contiguous):
//   x2d          float32, 2-D; its sizes supply N (rows) and In (columns).
//   gy2d         float32, 2-D; must be N x Out.
//   active_mask  bool, 1-D; its length supplies Out.
//
// Returns {grad_w, grad_b}: float tensors of shape [Out, In] and [Out] that
// are zero-initialised here and bound to the kernel as buffers 3 and 4.
// NOTE(review): what the kernel writes into them is defined in the .metal
// source, which is not visible from this file -- presumably only the rows
// whose mask entry is true are filled.
//
// The dispatch is encoded on the current MPS stream's command encoder; this
// function does not wait for it to finish.
//
// Raises via TORCH_CHECK on wrong device, dtype, layout or shape.
std::vector<at::Tensor> sparse_linear_grad_wb(at::Tensor x2d, at::Tensor gy2d, at::Tensor active_mask) {
    check_mps_float_contig(x2d, "x2d");
    check_mps_float_contig(gy2d, "gy2d");
    TORCH_CHECK(active_mask.device().is_mps(), "active_mask must be on MPS");
    TORCH_CHECK(active_mask.dtype() == at::kBool, "active_mask must be bool");
    TORCH_CHECK(active_mask.is_contiguous(), "active_mask must be contiguous");

    // Rank checks: the kernel receives raw buffers plus N/In/Out only, so any
    // other rank would make it index the data with the wrong geometry.
    TORCH_CHECK(x2d.dim() == 2, "x2d must be 2-D, got ", x2d.dim(), "-D");
    TORCH_CHECK(gy2d.dim() == 2, "gy2d must be 2-D, got ", gy2d.dim(), "-D");
    TORCH_CHECK(active_mask.dim() == 1, "active_mask must be 1-D, got ", active_mask.dim(), "-D");

    const int64_t N = x2d.size(0);
    const int64_t In = x2d.size(1);
    const int64_t Out = active_mask.size(0);
    TORCH_CHECK(gy2d.size(1) == Out, "gy2d width must equal active_mask size");
    // Without this check the kernel would be told gy2d has N rows even when it
    // has fewer, reading past the end of its buffer.
    TORCH_CHECK(gy2d.size(0) == N, "gy2d rows (", gy2d.size(0), ") must equal x2d rows (", N, ")");

    // SparseLinearParams carries the sizes as uint32_t; refuse silent truncation.
    const int64_t kMaxDim = (int64_t)UINT32_MAX;
    TORCH_CHECK(N <= kMaxDim && In <= kMaxDim && Out <= kMaxDim,
                "sparse_linear_grad_wb: N, In and Out must each fit in uint32");

    auto grad_w = at::zeros({Out, In}, x2d.options());
    auto grad_b = at::zeros({Out}, x2d.options());

    // Nothing to launch for empty shapes: the grid below would have a zero
    // extent, and an empty tensor is not guaranteed to have a backing buffer
    // to bind. The zero-filled outputs are returned as-is.
    if (N == 0 || In == 0 || Out == 0) {
        return {grad_w, grad_b};
    }

    id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
    id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_w, "sparse_linear_grad_w_float");

    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
    id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();
    [encoder setComputePipelineState:pipeline];

    // Binds a tensor's storage buffer at its storage offset to a kernel slot.
    auto set_tensor = [&](const at::Tensor& t, int idx) {
        [encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
    };
    set_tensor(x2d, 0);
    set_tensor(gy2d, 1);
    set_tensor(active_mask, 2);
    set_tensor(grad_w, 3);
    set_tensor(grad_b, 4);

    SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, 0};
    [encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:5];

    // 16x16 threadgroups; the grid is ceil(In/16) x ceil(Out/16) threadgroups.
    MTLSize tg = MTLSizeMake(16, 16, 1);
    MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((Out + 15) / 16), 1);
    [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];

    return {grad_w, grad_b};
}
// Host-side launcher for the "sparse_linear_grad_x_float" Metal kernel.
//
// Inputs (all on MPS and contiguous):
//   gy2d         float32, 2-D; its sizes supply N (rows) and Out (columns).
//   weight       float32, 2-D; must be Out x In.
//   active_mask  bool, 1-D with exactly Out entries.
//
// Returns grad_x: a float tensor of shape [N, In] that is zero-initialised
// here and bound to the kernel as buffer 3.
// NOTE(review): what the kernel writes into it is defined in the .metal
// source, which is not visible from this file -- presumably only output rows
// whose mask entry is true contribute.
//
// The dispatch is encoded on the current MPS stream's command encoder; this
// function does not wait for it to finish.
//
// Raises via TORCH_CHECK on wrong device, dtype, layout or shape.
at::Tensor sparse_linear_grad_x(at::Tensor gy2d, at::Tensor weight, at::Tensor active_mask) {
    check_mps_float_contig(gy2d, "gy2d");
    check_mps_float_contig(weight, "weight");
    TORCH_CHECK(active_mask.device().is_mps(), "active_mask must be on MPS");
    TORCH_CHECK(active_mask.dtype() == at::kBool, "active_mask must be bool");
    TORCH_CHECK(active_mask.is_contiguous(), "active_mask must be contiguous");

    // Rank checks: the kernel receives raw buffers plus N/In/Out only, so any
    // other rank would make it index the data with the wrong geometry.
    TORCH_CHECK(gy2d.dim() == 2, "gy2d must be 2-D, got ", gy2d.dim(), "-D");
    TORCH_CHECK(weight.dim() == 2, "weight must be 2-D, got ", weight.dim(), "-D");
    TORCH_CHECK(active_mask.dim() == 1, "active_mask must be 1-D, got ", active_mask.dim(), "-D");

    const int64_t N = gy2d.size(0);
    const int64_t Out = gy2d.size(1);
    const int64_t In = weight.size(1);
    TORCH_CHECK(weight.size(0) == Out, "weight out_features must match gy2d width");
    // The kernel is told the mask has Out entries; a shorter mask would be
    // read past the end of its buffer.
    TORCH_CHECK(active_mask.size(0) == Out,
                "active_mask size (", active_mask.size(0), ") must equal gy2d width (", Out, ")");

    // SparseLinearParams carries the sizes as uint32_t; refuse silent truncation.
    const int64_t kMaxDim = (int64_t)UINT32_MAX;
    TORCH_CHECK(N <= kMaxDim && In <= kMaxDim && Out <= kMaxDim,
                "sparse_linear_grad_x: N, In and Out must each fit in uint32");

    auto grad_x = at::zeros({N, In}, gy2d.options());

    // Nothing to launch for empty shapes: the grid below would have a zero
    // extent, and an empty tensor is not guaranteed to have a backing buffer
    // to bind. The zero-filled output is returned as-is.
    if (N == 0 || In == 0 || Out == 0) {
        return grad_x;
    }

    id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
    id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_x, "sparse_linear_grad_x_float");

    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
    id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();
    [encoder setComputePipelineState:pipeline];

    // Binds a tensor's storage buffer at its storage offset to a kernel slot.
    auto set_tensor = [&](const at::Tensor& t, int idx) {
        [encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
    };
    set_tensor(gy2d, 0);
    set_tensor(weight, 1);
    set_tensor(active_mask, 2);
    set_tensor(grad_x, 3);

    SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, 0};
    [encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:4];

    // 16x16 threadgroups; the grid is ceil(In/16) x ceil(N/16) threadgroups.
    MTLSize tg = MTLSizeMake(16, 16, 1);
    MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((N + 15) / 16), 1);
    [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];

    return grad_x;
}
// Python bindings: exposes the two launchers under their C++ names on the
// extension module.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    // Returns [grad_w, grad_b].
    m.def("sparse_linear_grad_wb",
          &sparse_linear_grad_wb,
          "Sparse active-row Linear dW/db (Metal/MPS, fp32)");

    // Returns grad_x.
    m.def("sparse_linear_grad_x",
          &sparse_linear_grad_x,
          "Sparse active-row Linear dX (Metal/MPS, fp32)");
}