theapemachine's picture
Add sparse transformer v19 with Triton-backed KNN scheduler and various backward modes. Includes utilities for synthetic data generation and model training. Implements chunked sparse updates and integrates with existing sparse linear layers.
bc1b8eb
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSDevice.h>
#include <c10/util/Exception.h>
#include <filesystem>
#include <dlfcn.h>
#import <Metal/Metal.h>
#import <Foundation/Foundation.h>
#include <mutex>
namespace fs = std::filesystem;
namespace {
// Scalar arguments handed to the compute kernels by value through
// -[MTLComputeCommandEncoder setBytes:length:atIndex:].
// NOTE(review): the kernel-side declaration lives in the .metallib and is not
// visible here; field order and widths presumably have to mirror it - confirm.
struct SparseLinearParams {
uint32_t N;      // row count of the 2-D inputs (x2d.size(0) / gy2d.size(0) at the call sites)
uint32_t In;     // input feature count (x2d.size(1) / weight.size(1) at the call sites)
uint32_t Out;    // output feature count (active_mask.size(0) / gy2d.size(1) at the call sites)
uint32_t dummy;  // always written as 0 by the callers in this file; presumably padding - confirm
};
// Lazily-initialised Metal state shared by every call into this module.
// g_lib is filled in once by ensure_library_locked(); each pipeline object is
// built on first use by ensure_pipeline(), which holds g_mutex while it reads
// or writes any of these globals.
static id<MTLLibrary> g_lib = nil;
static id<MTLComputePipelineState> g_pipeline_grad_w = nil;
static id<MTLComputePipelineState> g_pipeline_grad_x = nil;
static std::mutex g_mutex;
// Builds the path "<directory of the binary image containing this function>/
// sparse_linear_ops.metallib". Returns an empty string when dladdr() cannot
// resolve the image or reports no file name for it.
static std::string metallib_path_for_this_module() {
  Dl_info image_info;
  const int resolved = dladdr((void*)&metallib_path_for_this_module, &image_info);
  if (resolved == 0) {
    return std::string();
  }
  if (image_info.dli_fname == nullptr) {
    return std::string();
  }

  // Start from the directory that holds the loaded image and append the
  // fixed metallib file name.
  std::filesystem::path metallib = std::filesystem::path(image_info.dli_fname).parent_path();
  metallib /= "sparse_linear_ops.metallib";
  return metallib.string();
}
// Loads the module's .metallib into g_lib the first time it is called; later
// calls return immediately. The only caller in this file (ensure_pipeline)
// invokes it while holding g_mutex, hence the "_locked" suffix.
// Raises through TORCH_CHECK when the metallib path cannot be determined or
// when Metal refuses to load the library.
static void ensure_library_locked(id<MTLDevice> device) {
  if (g_lib != nil) {
    return;
  }

  const std::string path = metallib_path_for_this_module();
  TORCH_CHECK(!path.empty(), "sparse_linear_ops: failed to locate extension path via dladdr");

  NSURL* lib_url = [NSURL fileURLWithPath:[NSString stringWithUTF8String:path.c_str()]];
  NSError* load_error = nil;
  g_lib = [device newLibraryWithURL:lib_url error:&load_error];
  if (g_lib != nil) {
    return;
  }

  // Loading failed: surface Metal's description when one was provided.
  const char* msg = "unknown error";
  if (load_error) {
    msg = [[load_error localizedDescription] UTF8String];
  }
  TORCH_CHECK(false, "sparse_linear_ops: failed to load metallib at ", path, ": ", msg);
}
// Returns the compute pipeline cached in *pipeline, building it from the
// metallib function named fn_name on first use. The whole lookup/build runs
// under g_mutex, and the library itself is loaded on demand beforehand.
// Raises through TORCH_CHECK if the function is missing from the metallib or
// if Metal fails to create the pipeline state.
static id<MTLComputePipelineState> ensure_pipeline(id<MTLDevice> device, id<MTLComputePipelineState>* pipeline, const char* fn_name) {
  std::lock_guard<std::mutex> guard(g_mutex);
  ensure_library_locked(device);

  if (*pipeline == nil) {
    id<MTLFunction> kernel_fn = [g_lib newFunctionWithName:[NSString stringWithUTF8String:fn_name]];
    TORCH_CHECK(kernel_fn != nil, "sparse_linear_ops: function `", fn_name, "` not found in metallib");

    NSError* build_error = nil;
    *pipeline = [device newComputePipelineStateWithFunction:kernel_fn error:&build_error];
    if (*pipeline == nil) {
      const char* msg = "unknown error";
      if (build_error) {
        msg = [[build_error localizedDescription] UTF8String];
      }
      TORCH_CHECK(false, "sparse_linear_ops: failed to create pipeline for ", fn_name, ": ", msg);
    }
  }
  return *pipeline;
}
// Reinterprets the context pointer of the tensor's storage DataPtr as an
// id<MTLBuffer> (no ownership transfer). Raises through TORCH_CHECK when the
// context pointer is null.
static inline id<MTLBuffer> storage_as_mtlbuffer(const at::Tensor& t) {
  const auto& backing = t.storage();
  void* const buffer_ctx = backing.data_ptr().get_context();
  TORCH_CHECK(buffer_ctx != nullptr, "sparse_linear_ops: expected MPS tensor storage with MTLBuffer context");
  return (__bridge id<MTLBuffer>)buffer_ctx;
}
// Converts the tensor's storage offset (counted in elements) into a byte
// offset suitable for -setBuffer:offset:atIndex:.
static inline NSUInteger storage_offset_bytes(const at::Tensor& t) {
  const int64_t offset_in_elements = t.storage_offset();
  const int64_t bytes_per_element = (int64_t)t.element_size();
  return (NSUInteger)(offset_in_elements * bytes_per_element);
}
// Validates the three preconditions shared by the float inputs of this
// module: the tensor lives on MPS, holds float32, and is contiguous.
// `name` is only used as the prefix of the error message.
static void check_mps_float_contig(const at::Tensor& t, const char* name) {
  const bool on_mps = t.device().is_mps();
  TORCH_CHECK(on_mps, name, " must be on MPS");

  const bool is_fp32 = (t.dtype() == at::kFloat);
  TORCH_CHECK(is_fp32, name, " must be float32 for v12 kernel");

  const bool dense = t.is_contiguous();
  TORCH_CHECK(dense, name, " must be contiguous");
}
} // namespace
// Encodes the "sparse_linear_grad_w_float" Metal kernel on the current MPS
// stream and returns {grad_w, grad_b}.
//
// Inputs (all on MPS, contiguous):
//   x2d          float32, 2-D [N, In]
//   gy2d         float32, 2-D [N, Out]
//   active_mask  bool, Out elements (Out is taken from active_mask.size(0))
// Returns:
//   grad_w [Out, In] and grad_b [Out], float32, allocated zero-filled here and
//   bound to the kernel as buffers 3 and 4. What the kernel writes into them is
//   defined in the .metallib, which is not visible from this file.
// Raises (TORCH_CHECK) on wrong device/dtype/layout, mismatched shapes,
// dimensions that do not fit the kernel's uint32 parameters, or In == 0 with a
// non-empty batch.
//
// NOTE(review): PyTorch's built-in MPS ops encode inside
// dispatch_sync(stream->queue(), ...); confirm whether this op needs the same
// serialization when called from several threads.
std::vector<at::Tensor> sparse_linear_grad_wb(at::Tensor x2d, at::Tensor gy2d, at::Tensor active_mask) {
  check_mps_float_contig(x2d, "x2d");
  check_mps_float_contig(gy2d, "gy2d");
  TORCH_CHECK(active_mask.device().is_mps(), "active_mask must be on MPS");
  TORCH_CHECK(active_mask.dtype() == at::kBool, "active_mask must be bool");
  TORCH_CHECK(active_mask.is_contiguous(), "active_mask must be contiguous");

  // Rank checks come first: size(1) below is meaningless otherwise, and the
  // kernel indexes the raw buffers using N/In/Out only.
  TORCH_CHECK(x2d.dim() == 2, "x2d must be 2-D, got ", x2d.dim(), " dims");
  TORCH_CHECK(gy2d.dim() == 2, "gy2d must be 2-D, got ", gy2d.dim(), " dims");
  TORCH_CHECK(active_mask.dim() >= 1, "active_mask must have at least 1 dim");

  const int64_t N = x2d.size(0);
  const int64_t In = x2d.size(1);
  const int64_t Out = active_mask.size(0);

  // Every buffer must really hold the extent advertised to the kernel,
  // otherwise the GPU reads past the end of the allocation.
  TORCH_CHECK(active_mask.numel() == Out, "active_mask must contain exactly one entry per output row");
  TORCH_CHECK(gy2d.size(0) == N, "gy2d rows must equal x2d rows");
  TORCH_CHECK(gy2d.size(1) == Out, "gy2d width must equal active_mask size");

  // SparseLinearParams carries 32-bit extents; refuse silent truncation.
  const int64_t kMaxDim = (int64_t)UINT32_MAX;
  TORCH_CHECK(N <= kMaxDim && In <= kMaxDim && Out <= kMaxDim,
              "sparse_linear_grad_wb: N, In and Out must each fit in uint32");

  auto grad_w = at::zeros({Out, In}, x2d.options());
  auto grad_b = at::zeros({Out}, x2d.options());

  // With no rows or no outputs the zero-filled results are already final, and
  // launching would mean binding empty tensors or dispatching an empty grid.
  if (N == 0 || Out == 0) {
    return {grad_w, grad_b};
  }
  // The grid width is derived from In, so In == 0 cannot be dispatched at all;
  // fail with a clear message instead of issuing a zero-sized dispatch.
  TORCH_CHECK(In > 0, "sparse_linear_grad_wb: x2d must have at least one input feature");

  id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
  id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_w, "sparse_linear_grad_w_float");
  at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
  id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();
  [encoder setComputePipelineState:pipeline];

  // Binds a tensor's backing MTLBuffer (plus its byte offset) at slot idx.
  auto set_tensor = [&](const at::Tensor& t, int idx) {
    [encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
  };
  set_tensor(x2d, 0);
  set_tensor(gy2d, 1);
  set_tensor(active_mask, 2);
  set_tensor(grad_w, 3);
  set_tensor(grad_b, 4);

  SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, 0};
  [encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:5];

  // 16x16 threadgroups tiled over (In, Out), rounded up to cover the edges.
  MTLSize tg = MTLSizeMake(16, 16, 1);
  MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((Out + 15) / 16), 1);
  [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];
  return {grad_w, grad_b};
}
// Encodes the "sparse_linear_grad_x_float" Metal kernel on the current MPS
// stream and returns grad_x.
//
// Inputs (all on MPS, contiguous):
//   gy2d         float32, 2-D [N, Out]
//   weight       float32, 2-D [Out, In]
//   active_mask  bool, Out elements
// Returns:
//   grad_x [N, In], float32, allocated zero-filled here and bound to the
//   kernel as buffer 3. What the kernel writes into it is defined in the
//   .metallib, which is not visible from this file.
// Raises (TORCH_CHECK) on wrong device/dtype/layout, mismatched shapes, or
// dimensions that do not fit the kernel's uint32 parameters.
//
// NOTE(review): PyTorch's built-in MPS ops encode inside
// dispatch_sync(stream->queue(), ...); confirm whether this op needs the same
// serialization when called from several threads.
at::Tensor sparse_linear_grad_x(at::Tensor gy2d, at::Tensor weight, at::Tensor active_mask) {
  check_mps_float_contig(gy2d, "gy2d");
  check_mps_float_contig(weight, "weight");
  TORCH_CHECK(active_mask.device().is_mps(), "active_mask must be on MPS");
  TORCH_CHECK(active_mask.dtype() == at::kBool, "active_mask must be bool");
  TORCH_CHECK(active_mask.is_contiguous(), "active_mask must be contiguous");

  // Rank checks come first: size(1) below is meaningless otherwise, and the
  // kernel indexes the raw buffers using N/In/Out only.
  TORCH_CHECK(gy2d.dim() == 2, "gy2d must be 2-D, got ", gy2d.dim(), " dims");
  TORCH_CHECK(weight.dim() == 2, "weight must be 2-D, got ", weight.dim(), " dims");

  const int64_t N = gy2d.size(0);
  const int64_t Out = gy2d.size(1);
  const int64_t In = weight.size(1);
  TORCH_CHECK(weight.size(0) == Out, "weight out_features must match gy2d width");
  // The mask is bound as a raw buffer and Out is passed to the kernel, so a
  // mask of a different length would be read out of bounds.
  TORCH_CHECK(active_mask.numel() == Out, "active_mask size must match gy2d width");

  // SparseLinearParams carries 32-bit extents; refuse silent truncation.
  const int64_t kMaxDim = (int64_t)UINT32_MAX;
  TORCH_CHECK(N <= kMaxDim && In <= kMaxDim && Out <= kMaxDim,
              "sparse_linear_grad_x: N, In and Out must each fit in uint32");

  auto grad_x = at::zeros({N, In}, gy2d.options());

  // N == 0 or In == 0 makes grad_x empty and the dispatch grid zero-sized;
  // Out == 0 leaves nothing to accumulate. In all three cases the zero-filled
  // tensor is already the result, so skip binding empty buffers entirely.
  if (N == 0 || In == 0 || Out == 0) {
    return grad_x;
  }

  id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
  id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_x, "sparse_linear_grad_x_float");
  at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();
  id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();
  [encoder setComputePipelineState:pipeline];

  // Binds a tensor's backing MTLBuffer (plus its byte offset) at slot idx.
  auto set_tensor = [&](const at::Tensor& t, int idx) {
    [encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
  };
  set_tensor(gy2d, 0);
  set_tensor(weight, 1);
  set_tensor(active_mask, 2);
  set_tensor(grad_x, 3);

  SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, 0};
  [encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:4];

  // 16x16 threadgroups tiled over (In, N), rounded up to cover the edges.
  MTLSize tg = MTLSizeMake(16, 16, 1);
  MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((N + 15) / 16), 1);
  [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];
  return grad_x;
}
// Python bindings: both gradient entry points are exported under the same
// names they have in C++.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "sparse_linear_grad_wb",
      &sparse_linear_grad_wb,
      "Sparse active-row Linear dW/db (Metal/MPS, fp32)");
  m.def(
      "sparse_linear_grad_x",
      &sparse_linear_grad_x,
      "Sparse active-row Linear dX (Metal/MPS, fp32)");
}