File size: 6,939 Bytes
bc1b8eb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#include <torch/extension.h>
#include <ATen/ATen.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSDevice.h>
#include <c10/util/Exception.h>
#include <filesystem>
#include <dlfcn.h>
#import <Metal/Metal.h>
#import <Foundation/Foundation.h>
#include <mutex>

namespace fs = std::filesystem;

namespace {
// Small constant block copied to the kernels with -setBytes:length:atIndex:.
// NOTE(review): the field order and widths presumably have to match a struct
// of the same shape in the .metal source — confirm before changing anything.
struct SparseLinearParams {
    uint32_t N;      // row count: x2d.size(0) in grad_wb, gy2d.size(0) in grad_x
    uint32_t In;     // x2d.size(1) in grad_wb, weight.size(1) in grad_x
    uint32_t Out;    // active_mask.size(0) in grad_wb, gy2d.size(1) in grad_x
    uint32_t dummy;  // always passed as 0 by the callers in this file; presumably padding — confirm
};

// Process-wide cache, filled lazily on first use.
// g_lib is assigned in ensure_library_locked(); the two pipeline slots are
// filled through ensure_pipeline(), which takes g_mutex before touching any of them.
static id<MTLLibrary> g_lib = nil;
static id<MTLComputePipelineState> g_pipeline_grad_w = nil;
static id<MTLComputePipelineState> g_pipeline_grad_x = nil;
static std::mutex g_mutex;

// Builds the expected path of `sparse_linear_ops.metallib`: the directory of
// the binary image that contains this function (as reported by dladdr),
// joined with that file name.
// Returns an empty string when dladdr cannot resolve the image or reports no
// file name for it.
static std::string metallib_path_for_this_module() {
    Dl_info image_info;
    const int resolved = dladdr((void*)&metallib_path_for_this_module, &image_info);
    if (resolved == 0 || image_info.dli_fname == nullptr) {
        return std::string();
    }

    const fs::path module_dir = fs::path(image_info.dli_fname).parent_path();
    const fs::path metallib_file = module_dir / "sparse_linear_ops.metallib";
    return metallib_file.string();
}

// Loads `sparse_linear_ops.metallib` into g_lib on first use; once g_lib is
// set, later calls return immediately.
//
// The caller is expected to hold g_mutex already (ensure_pipeline locks it
// before calling here).
//
// Fails through TORCH_CHECK when:
//   - the extension's own path cannot be resolved,
//   - that path cannot be converted to an NSString,
//   - Metal fails to load the library from the resulting URL.
static void ensure_library_locked(id<MTLDevice> device) {
    if (g_lib != nil) return;

    std::string path = metallib_path_for_this_module();
    TORCH_CHECK(!path.empty(), "sparse_linear_ops: failed to locate extension path via dladdr");

    NSString* ns_path = [NSString stringWithUTF8String:path.c_str()];
    // stringWithUTF8String: yields nil for bytes that are not valid UTF-8, and
    // fileURLWithPath: raises an NSException on nil. Fail with a regular
    // TORCH_CHECK error instead so the caller sees a normal exception.
    TORCH_CHECK(ns_path != nil, "sparse_linear_ops: extension path is not valid UTF-8: ", path);

    NSURL* url = [NSURL fileURLWithPath:ns_path];
    NSError* err = nil;
    g_lib = [device newLibraryWithURL:url error:&err];
    // Success is judged by the return value, not by whether `err` was written.
    if (g_lib == nil) {
        const char* msg = err ? [[err localizedDescription] UTF8String] : "unknown error";
        TORCH_CHECK(false, "sparse_linear_ops: failed to load metallib at ", path, ": ", msg);
    }
}

// Returns the compute pipeline cached in `*pipeline`, building it on first use
// from the function named `fn_name` in the shared library.
//
// Holds g_mutex for the whole call, so library loading and pipeline creation
// are serialized.
//
// Fails through TORCH_CHECK when the function is missing from the metallib or
// when Metal cannot create the pipeline state.
static id<MTLComputePipelineState> ensure_pipeline(id<MTLDevice> device, id<MTLComputePipelineState>* pipeline, const char* fn_name) {
    std::lock_guard<std::mutex> guard(g_mutex);
    ensure_library_locked(device);

    if (*pipeline == nil) {
        id<MTLFunction> kernel_fn = [g_lib newFunctionWithName:[NSString stringWithUTF8String:fn_name]];
        TORCH_CHECK(kernel_fn != nil, "sparse_linear_ops: function `", fn_name, "` not found in metallib");

        NSError* build_error = nil;
        id<MTLComputePipelineState> built = [device newComputePipelineStateWithFunction:kernel_fn error:&build_error];
        if (built == nil) {
            const char* reason = build_error ? [[build_error localizedDescription] UTF8String] : "unknown error";
            TORCH_CHECK(false, "sparse_linear_ops: failed to create pipeline for ", fn_name, ": ", reason);
        }
        // Only a successfully built pipeline reaches the cache slot.
        *pipeline = built;
    }

    return *pipeline;
}

// Reads the context pointer of the tensor's storage DataPtr and returns it
// viewed as an id<MTLBuffer>, without transferring ownership (__bridge).
// Fails through TORCH_CHECK when the context pointer is null.
// NOTE(review): relies on the MPS allocator storing the MTLBuffer as the
// DataPtr context — confirm against the PyTorch version in use.
static inline id<MTLBuffer> storage_as_mtlbuffer(const at::Tensor& t) {
    void* const buffer_handle = t.storage().data_ptr().get_context();
    TORCH_CHECK(buffer_handle != nullptr, "sparse_linear_ops: expected MPS tensor storage with MTLBuffer context");

    id<MTLBuffer> buffer = (__bridge id<MTLBuffer>)buffer_handle;
    return buffer;
}

// Converts the tensor's storage offset from elements to bytes
// (storage_offset() * element_size()), for use as a buffer binding offset.
static inline NSUInteger storage_offset_bytes(const at::Tensor& t) {
    const int64_t offset_in_elements = t.storage_offset();
    const int64_t bytes_per_element = (int64_t)t.element_size();
    return (NSUInteger)(offset_in_elements * bytes_per_element);
}

// Validates that `t` lives on an MPS device, has dtype float32 and is
// contiguous. The checks run in that order and the first failure raises
// through TORCH_CHECK with a message prefixed by `name`.
static void check_mps_float_contig(const at::Tensor& t, const char* name) {
    const bool on_mps = t.device().is_mps();
    TORCH_CHECK(on_mps, name, " must be on MPS");

    const bool is_fp32 = (t.dtype() == at::kFloat);
    TORCH_CHECK(is_fp32, name, " must be float32 for v12 kernel");

    const bool is_dense = t.is_contiguous();
    TORCH_CHECK(is_dense, name, " must be contiguous");
}
} // namespace

// Encodes the `sparse_linear_grad_w_float` Metal kernel on the current MPS
// stream and returns {grad_w, grad_b}.
//
// Arguments (all on MPS, contiguous):
//   x2d          float32, 2-D, [N, In]
//   gy2d         float32, 2-D, [N, Out]
//   active_mask  bool, Out elements (Out is taken from active_mask.size(0))
//
// Returns:
//   grad_w  float32 [Out, In], zero-initialized before the kernel runs
//   grad_b  float32 [Out],     zero-initialized before the kernel runs
//
// Buffer bindings: 0 = x2d, 1 = gy2d, 2 = active_mask, 3 = grad_w, 4 = grad_b,
// 5 = SparseLinearParams. One 16x16 threadgroup is launched per tile of the
// (In, Out) grid.
//
// Raises (TORCH_CHECK) on wrong device/dtype/layout, mismatched shapes,
// dimensions that do not fit in uint32, or In == 0 with non-empty N and Out.
//
// NOTE(review): the encoder is used directly on the calling thread. PyTorch's
// own MPS kernels wrap encoding in a dispatch_sync on the stream's queue —
// confirm whether this extension is ever called from multiple threads.
std::vector<at::Tensor> sparse_linear_grad_wb(at::Tensor x2d, at::Tensor gy2d, at::Tensor active_mask) {
    check_mps_float_contig(x2d, "x2d");
    check_mps_float_contig(gy2d, "gy2d");
    TORCH_CHECK(active_mask.device().is_mps(), "active_mask must be on MPS");
    TORCH_CHECK(active_mask.dtype() == at::kBool, "active_mask must be bool");
    TORCH_CHECK(active_mask.is_contiguous(), "active_mask must be contiguous");

    // The kernel indexes raw buffers, so every shape assumption has to be
    // checked here; a mismatch would become an out-of-bounds GPU access.
    TORCH_CHECK(x2d.dim() == 2, "x2d must be 2-D, got ", x2d.dim(), " dims");
    TORCH_CHECK(gy2d.dim() == 2, "gy2d must be 2-D, got ", gy2d.dim(), " dims");
    TORCH_CHECK(active_mask.dim() >= 1, "active_mask must have at least one dimension");

    const int64_t N = x2d.size(0);
    const int64_t In = x2d.size(1);
    const int64_t Out = active_mask.size(0);
    TORCH_CHECK(active_mask.numel() == Out, "active_mask must hold exactly one flag per output feature");
    TORCH_CHECK(gy2d.size(1) == Out, "gy2d width must equal active_mask size");
    TORCH_CHECK(gy2d.size(0) == N, "gy2d rows (", gy2d.size(0), ") must equal x2d rows (", N, ")");

    // SparseLinearParams carries the sizes as uint32; refuse to truncate.
    const int64_t max_dim = (int64_t)UINT32_MAX;
    TORCH_CHECK(N <= max_dim && In <= max_dim && Out <= max_dim,
                "sparse_linear_grad_wb: N, In and Out must each fit in uint32");

    auto grad_w = at::zeros({Out, In}, x2d.options());
    auto grad_b = at::zeros({Out}, x2d.options());

    // With no rows or no output features the zero-initialized results are
    // returned as-is, which also avoids binding empty storages.
    if (N == 0 || Out == 0) {
        return {grad_w, grad_b};
    }
    // In == 0 would make the grid width zero, which is not a valid dispatch.
    TORCH_CHECK(In > 0, "sparse_linear_grad_wb: x2d must have at least one input feature");

    id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
    id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_w, "sparse_linear_grad_w_float");
    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();

    id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();
    [encoder setComputePipelineState:pipeline];

    // Binds a tensor's backing MTLBuffer at `idx`, honoring its storage offset.
    auto set_tensor = [&](const at::Tensor& t, int idx) {
        [encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
    };
    set_tensor(x2d, 0);
    set_tensor(gy2d, 1);
    set_tensor(active_mask, 2);
    set_tensor(grad_w, 3);
    set_tensor(grad_b, 4);

    SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, 0};
    [encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:5];

    // Ceil-divide so partial 16x16 tiles at the edges are still covered.
    MTLSize tg = MTLSizeMake(16, 16, 1);
    MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((Out + 15) / 16), 1);
    [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];

    return {grad_w, grad_b};
}

// Encodes the `sparse_linear_grad_x_float` Metal kernel on the current MPS
// stream and returns grad_x.
//
// Arguments (all on MPS, contiguous):
//   gy2d         float32, 2-D, [N, Out]
//   weight       float32, 2-D, [Out, In]
//   active_mask  bool, Out elements
//
// Returns:
//   grad_x  float32 [N, In], zero-initialized before the kernel runs
//
// Buffer bindings: 0 = gy2d, 1 = weight, 2 = active_mask, 3 = grad_x,
// 4 = SparseLinearParams. One 16x16 threadgroup is launched per tile of the
// (In, N) grid.
//
// Raises (TORCH_CHECK) on wrong device/dtype/layout, mismatched shapes, or
// dimensions that do not fit in uint32.
//
// NOTE(review): the encoder is used directly on the calling thread. PyTorch's
// own MPS kernels wrap encoding in a dispatch_sync on the stream's queue —
// confirm whether this extension is ever called from multiple threads.
at::Tensor sparse_linear_grad_x(at::Tensor gy2d, at::Tensor weight, at::Tensor active_mask) {
    check_mps_float_contig(gy2d, "gy2d");
    check_mps_float_contig(weight, "weight");
    TORCH_CHECK(active_mask.device().is_mps(), "active_mask must be on MPS");
    TORCH_CHECK(active_mask.dtype() == at::kBool, "active_mask must be bool");
    TORCH_CHECK(active_mask.is_contiguous(), "active_mask must be contiguous");

    // The kernel indexes raw buffers, so every shape assumption has to be
    // checked here; a mismatch would become an out-of-bounds GPU access.
    TORCH_CHECK(gy2d.dim() == 2, "gy2d must be 2-D, got ", gy2d.dim(), " dims");
    TORCH_CHECK(weight.dim() == 2, "weight must be 2-D, got ", weight.dim(), " dims");

    const int64_t N = gy2d.size(0);
    const int64_t Out = gy2d.size(1);
    const int64_t In = weight.size(1);
    TORCH_CHECK(weight.size(0) == Out, "weight out_features must match gy2d width");
    // The mask is bound as a raw buffer and the kernel is told there are Out
    // features, so a shorter mask would be read past its end.
    TORCH_CHECK(active_mask.numel() == Out,
                "active_mask must have one element per output feature (expected ", Out,
                ", got ", active_mask.numel(), ")");

    // SparseLinearParams carries the sizes as uint32; refuse to truncate.
    const int64_t max_dim = (int64_t)UINT32_MAX;
    TORCH_CHECK(N <= max_dim && In <= max_dim && Out <= max_dim,
                "sparse_linear_grad_x: N, In and Out must each fit in uint32");

    auto grad_x = at::zeros({N, In}, gy2d.options());

    // Any zero-sized dimension means there is nothing to launch: the grid
    // would be empty (N or In == 0) or there are no output features to read
    // (Out == 0). Return the zero-initialized result without binding empty storages.
    if (N == 0 || In == 0 || Out == 0) {
        return grad_x;
    }

    id<MTLDevice> device = (id<MTLDevice>)at::mps::MPSDevice::getInstance()->device();
    id<MTLComputePipelineState> pipeline = ensure_pipeline(device, &g_pipeline_grad_x, "sparse_linear_grad_x_float");
    at::mps::MPSStream* stream = at::mps::getCurrentMPSStream();

    id<MTLComputeCommandEncoder> encoder = (id<MTLComputeCommandEncoder>)stream->commandEncoder();
    [encoder setComputePipelineState:pipeline];

    // Binds a tensor's backing MTLBuffer at `idx`, honoring its storage offset.
    auto set_tensor = [&](const at::Tensor& t, int idx) {
        [encoder setBuffer:storage_as_mtlbuffer(t) offset:storage_offset_bytes(t) atIndex:(NSUInteger)idx];
    };
    set_tensor(gy2d, 0);
    set_tensor(weight, 1);
    set_tensor(active_mask, 2);
    set_tensor(grad_x, 3);

    SparseLinearParams prm{(uint32_t)N, (uint32_t)In, (uint32_t)Out, 0};
    [encoder setBytes:&prm length:sizeof(SparseLinearParams) atIndex:4];

    // Ceil-divide so partial 16x16 tiles at the edges are still covered.
    MTLSize tg = MTLSizeMake(16, 16, 1);
    MTLSize grid = MTLSizeMake((NSUInteger)((In + 15) / 16), (NSUInteger)((N + 15) / 16), 1);
    [encoder dispatchThreadgroups:grid threadsPerThreadgroup:tg];

    return grad_x;
}

// Python bindings: both gradient entry points are exported under their C++
// names, each with a one-line description.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "sparse_linear_grad_wb",
        &sparse_linear_grad_wb,
        "Sparse active-row Linear dW/db (Metal/MPS, fp32)");

    m.def(
        "sparse_linear_grad_x",
        &sparse_linear_grad_x,
        "Sparse active-row Linear dX (Metal/MPS, fp32)");
}