File size: 12,085 Bytes
20347e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131f3b1
 
20347e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131f3b1
 
20347e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
// bitsandbytes MPS Metal kernels - ObjC++ dispatch
// Interfaces between PyTorch MPS tensors and Metal compute kernels.
// Uses the same dispatch pattern as kernels-community/activation, with
// get_command_buffer() moved inside dispatch_sync to avoid race conditions
// during model loading.

#include <torch/torch.h>

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

#include <algorithm>
#include <iostream>
#include <limits>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <unordered_map>

#ifdef EMBEDDED_METALLIB_HEADER
#include EMBEDDED_METALLIB_HEADER
#endif

// ============================================================================
// Metal helpers
// ============================================================================

// Returns the MTLBuffer backing `t`'s storage by reinterpreting the storage's
// raw data pointer as an id<MTLBuffer>.
// NOTE(review): only meaningful for MPS tensors, whose storage data pointer is
// assumed to be the MTLBuffer handle itself — confirm for the PyTorch version
// in use. The tensor's storage_offset is NOT applied here; callers add it.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& t) {
  auto raw_storage_ptr = t.storage().data();
  return __builtin_bit_cast(id<MTLBuffer>, raw_storage_ptr);
}

namespace {

// Lazily loaded Metal library holding the BnB kernels. Written only by
// get_library() while holding library_mutex.
static id<MTLLibrary> library = nil;
static std::mutex library_mutex;

// Returns the Metal library containing the BnB kernels, loading it on first
// use. When EMBEDDED_METALLIB_HEADER is defined the library is built from the
// embedded metallib; otherwise the device's default library is used.
// Returns nil (after logging to stderr) on failure; failures are not cached,
// so a later call retries. The public entry points call this from arbitrary
// caller threads, so the check-then-set on `library` is done under a lock.
id<MTLLibrary> get_library() {
  std::lock_guard<std::mutex> lock(library_mutex);
  if (library != nil)
    return library;

  id<MTLDevice> device = MTLCreateSystemDefaultDevice();
  if (device == nil) {
    std::cerr << "Failed to create Metal device" << std::endl;
    return nil;
  }

#ifdef EMBEDDED_METALLIB_HEADER
  NSError* error = nil;
  library = EMBEDDED_METALLIB_NAMESPACE::createLibrary(device, &error);
  if (library == nil) {
    std::cerr << "Failed to create Metal library from embedded header"
              << std::endl;
    if (error)
      std::cerr << "Error: " << [[error localizedDescription] UTF8String]
                << std::endl;
  }
#else
  // -newDefaultLibrary reports no NSError, so there is no detail to print.
  library = [device newDefaultLibrary];
  if (library == nil) {
    std::cerr << "Failed to load Metal library" << std::endl;
  }
#endif
  return library;
}

// Returns the compute pipeline for kernel function `name`, compiling it on
// first request and caching it by name. Returns nil (after logging) if the
// library, function, or pipeline cannot be created; failures are not cached.
// The cache is a process-wide std::unordered_map, which is not safe for
// concurrent use, so every access happens under cache_mutex. Lock order is
// always cache_mutex -> library_mutex (via get_library), so no deadlock.
id<MTLComputePipelineState> get_pipeline(const std::string& name) {
  static std::mutex cache_mutex;
  static std::unordered_map<std::string, id<MTLComputePipelineState>> cache;

  std::lock_guard<std::mutex> lock(cache_mutex);
  auto it = cache.find(name);
  if (it != cache.end())
    return it->second;

  id<MTLLibrary> lib = get_library();
  if (!lib)
    return nil;

  id<MTLFunction> func =
      [lib newFunctionWithName:[NSString stringWithUTF8String:name.c_str()]];
  if (!func) {
    std::cerr << "Kernel not found: " << name << std::endl;
    return nil;
  }

  NSError* error = nil;
  // Build on the library's own device so pipeline and library always match.
  id<MTLComputePipelineState> state =
      [lib.device newComputePipelineStateWithFunction:func error:&error];
  if (!state) {
    std::cerr << "Failed to create pipeline for " << name << std::endl;
    if (error)
      std::cerr << "Error: " << [[error localizedDescription] UTF8String]
                << std::endl;
    return nil;
  }
  cache[name] = state;
  return state;
}

// Maps a tensor dtype to the type token used in kernel names
// ("float", "half", "bfloat16_t").
// Throws std::runtime_error for any other dtype.
std::string type_str(torch::ScalarType type) {
  switch (type) {
    case torch::kFloat32:
      return "float";
    case torch::kFloat16:
      return "half";
    case torch::kBFloat16:
      return "bfloat16_t";
    default:
      throw std::runtime_error("Unsupported dtype for BnB MPS kernels");
  }
}

// Binds tensor `t`'s backing MTLBuffer to argument slot `index` of `enc`.
// storage_offset() counts elements, so it is scaled by element_size() to get
// the byte offset of the tensor's first element. Strides are not passed, so
// the kernel sees `t` as a dense buffer — callers must pass contiguous tensors.
void set_tensor(
    id<MTLComputeCommandEncoder> enc,
    const torch::Tensor& t,
    int index) {
  [enc setBuffer:getMTLBufferStorage(t)
          offset:t.storage_offset() * t.element_size()
         atIndex:index];
}

} // namespace

// ============================================================================
// Public API: quantize_4bit
// ============================================================================

/// Quantizes `input` blockwise to 4-bit codes with one absmax value per block.
///
/// @param input      MPS tensor of dtype float32, float16 or bfloat16, any
///                   shape; it is processed as a flat array of numel() values.
/// @param blocksize  Elements per quantization block: 64, 128, 256 or 512.
/// @param quant_type 1 = FP4, 2 = NF4.
/// @return (packed, absmax): `packed` is a uint8 tensor of ceil(numel / 2)
///         bytes; `absmax` is a float32 tensor of ceil(numel / blocksize)
///         entries, one per block.
/// Throws (TORCH_CHECK) on invalid arguments, too many elements for the
/// kernel's 32-bit count, or a missing kernel.
std::tuple<at::Tensor, at::Tensor> bnb_quantize_4bit(
    at::Tensor input,
    int64_t blocksize,
    int64_t quant_type) {
  TORCH_CHECK(input.is_mps(), "Input must be on MPS device");
  TORCH_CHECK(
      blocksize == 64 || blocksize == 128 || blocksize == 256 || blocksize == 512,
      "Only blocksize 64, 128, 256, and 512 are supported");
  TORCH_CHECK(
      quant_type == 1 || quant_type == 2,
      "quant_type must be 1 (FP4) or 2 (NF4)");
  // The element count is handed to the kernel as a 32-bit int.
  TORCH_CHECK(
      input.numel() <= std::numeric_limits<int>::max(),
      "Input has too many elements for BnB MPS kernels (max ",
      std::numeric_limits<int>::max(),
      ")");

  // The kernel walks the buffer linearly from the storage offset, so strided
  // (transposed, sliced, ...) inputs must be compacted first.
  input = input.contiguous();

  // 64-bit size math so `numel + blocksize - 1` / `numel + 1` cannot overflow.
  const int64_t numel = input.numel();
  const int64_t num_blocks = (numel + blocksize - 1) / blocksize;
  const int64_t packed_size = (numel + 1) / 2;

  auto absmax =
      torch::empty({num_blocks}, input.options().dtype(torch::kFloat32));
  auto packed =
      torch::empty({packed_size}, input.options().dtype(torch::kUInt8));

  // Built before the empty-input shortcut so unsupported dtypes still throw.
  std::stringstream ss;
  ss << "bnb_quantize_blockwise_" << type_str(input.scalar_type()) << "_bs_"
     << blocksize << "_qt_" << quant_type;
  const std::string kernel_name = ss.str();

  // Nothing to compute; also avoids issuing a zero-sized dispatch.
  if (numel == 0)
    return std::make_tuple(packed, absmax);

  auto pipeline = get_pipeline(kernel_name);
  TORCH_CHECK(pipeline, "Kernel not found: ", kernel_name);

  const int n = static_cast<int>(numel);

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        // Fetched inside dispatch_sync to avoid racing other work on the
        // MPS stream for the current command buffer.
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        int idx = 0;
        set_tensor(encoder, input, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, packed, idx++);
        [encoder setBytes:&n length:sizeof(int) atIndex:idx++];

        // One thread per quantization block.
        NSUInteger threads_per_tg = pipeline.threadExecutionWidth;
        MTLSize grid = MTLSizeMake(num_blocks, 1, 1);
        MTLSize tg = MTLSizeMake(threads_per_tg, 1, 1);
        [encoder dispatchThreads:grid threadsPerThreadgroup:tg];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return std::make_tuple(packed, absmax);
}

// ============================================================================
// Public API: dequantize_blockwise
// ============================================================================

/// Dequantizes blockwise 4-bit codes back to a flat tensor of `numel` values.
///
/// @param packed       MPS tensor holding the packed 4-bit codes; it must hold
///                     at least ceil(numel / 2) bytes (any dtype).
/// @param absmax       MPS tensor with at least ceil(numel / blocksize)
///                     entries, one per block.
///                     NOTE(review): presumably float32 as produced by
///                     bnb_quantize_4bit — confirm against the kernel.
/// @param blocksize    Elements per block: 64, 128, 256 or 512.
/// @param quant_type   1 = FP4, 2 = NF4.
/// @param numel        Number of values to produce (0 .. INT_MAX).
/// @param output_dtype float32, float16 or bfloat16.
/// @return 1-D tensor of `numel` elements of `output_dtype` on the MPS device.
/// Throws (TORCH_CHECK) on invalid arguments, undersized inputs, or a
/// missing kernel.
at::Tensor bnb_dequantize_4bit(
    at::Tensor packed,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t numel,
    torch::ScalarType output_dtype) {
  TORCH_CHECK(packed.is_mps(), "packed must be on MPS device");
  TORCH_CHECK(absmax.is_mps(), "absmax must be on MPS device");
  TORCH_CHECK(
      blocksize == 64 || blocksize == 128 || blocksize == 256 || blocksize == 512,
      "Only blocksize 64, 128, 256, and 512 are supported");
  TORCH_CHECK(
      quant_type == 1 || quant_type == 2,
      "quant_type must be 1 (FP4) or 2 (NF4)");
  // The element count is handed to the kernel as a 32-bit int.
  TORCH_CHECK(
      numel >= 0 && numel <= std::numeric_limits<int>::max(),
      "numel must be in [0, ",
      std::numeric_limits<int>::max(),
      "], got ",
      numel);

  // The kernel reads both inputs linearly from their storage offsets.
  packed = packed.contiguous();
  absmax = absmax.contiguous();

  const int64_t num_blocks = (numel + blocksize - 1) / blocksize;
  const int64_t packed_bytes_needed = (numel + 1) / 2;
  const int64_t packed_bytes = packed.numel() * packed.element_size();
  // Reject undersized inputs rather than letting the GPU read past them.
  // Bytes (not elements) are compared so the check is independent of
  // packed's dtype.
  TORCH_CHECK(
      packed_bytes >= packed_bytes_needed,
      "packed holds ",
      packed_bytes,
      " bytes but ",
      packed_bytes_needed,
      " are needed for ",
      numel,
      " values");
  TORCH_CHECK(
      absmax.numel() >= num_blocks,
      "absmax has ",
      absmax.numel(),
      " entries but ",
      num_blocks,
      " blocks are needed");

  auto output = torch::empty({numel}, packed.options().dtype(output_dtype));

  // Built before the empty shortcut so unsupported dtypes still throw.
  std::stringstream ss;
  ss << "bnb_dequantize_blockwise_" << type_str(output_dtype) << "_bs_"
     << blocksize << "_qt_" << quant_type;
  const std::string kernel_name = ss.str();

  // Nothing to compute; also avoids issuing a zero-sized dispatch.
  if (numel == 0)
    return output;

  auto pipeline = get_pipeline(kernel_name);
  TORCH_CHECK(pipeline, "Kernel not found: ", kernel_name);

  const int n = static_cast<int>(numel);

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        // Fetched inside dispatch_sync to avoid racing other work on the
        // MPS stream for the current command buffer.
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        int idx = 0;
        set_tensor(encoder, packed, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, output, idx++);
        [encoder setBytes:&n length:sizeof(int) atIndex:idx++];

        // One threadgroup per block. Its size starts at (blocksize + 1) / 2
        // (presumably one thread per packed byte of the block), is capped by
        // the pipeline's limit, and is raised to at least one SIMD width.
        NSUInteger max_tg = pipeline.maxTotalThreadsPerThreadgroup;
        NSUInteger desired = (blocksize + 1) / 2;
        NSUInteger tg_size =
            std::min(max_tg, std::max(static_cast<NSUInteger>(1), desired));
        if (tg_size < pipeline.threadExecutionWidth) {
          tg_size = std::min(pipeline.threadExecutionWidth, max_tg);
        }

        MTLSize grid = MTLSizeMake(tg_size * num_blocks, 1, 1);
        MTLSize tg = MTLSizeMake(tg_size, 1, 1);
        [encoder dispatchThreads:grid threadsPerThreadgroup:tg];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return output;
}

// ============================================================================
// Public API: GEMV (matrix-vector multiply)
// y = dequant(W) @ x
// ============================================================================

/// Matrix-vector product with a 4-bit blockwise-quantized weight:
/// y = dequant(W) @ x.
///
/// @param x               MPS input vector of K = x.size(-1) elements. It may
///                        carry leading size-1 dims; the launch has no batch
///                        dimension, so exactly one vector is supported.
/// @param w               Packed 4-bit weight for an N x K matrix (at least
///                        ceil(N * K / 2) bytes).
/// @param absmax          Per-block scales for `w`.
/// @param blocksize       Quantization block size used for `w`.
/// @param quant_type      Quantization type used for `w` (kernel-name suffix).
/// @param output_features N, the number of output elements.
/// @return Tensor shaped like `x` with the last dim replaced by N, x's dtype.
/// Throws (TORCH_CHECK) on invalid or inconsistent shapes or a missing kernel.
at::Tensor bnb_gemv_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features) {
  TORCH_CHECK(
      x.is_mps() && w.is_mps() && absmax.is_mps(),
      "All tensors must be on MPS device");
  TORCH_CHECK(x.dim() >= 1, "x must have at least one dimension");
  TORCH_CHECK(output_features >= 0, "output_features must be non-negative");

  const int64_t K = x.size(-1);
  const int64_t N = output_features;
  // K and N are handed to the kernel as 32-bit ints.
  TORCH_CHECK(
      K <= std::numeric_limits<int>::max() &&
          N <= std::numeric_limits<int>::max(),
      "Matrix dimensions exceed the 32-bit range of BnB MPS kernels");

  // Kernels index every buffer linearly from its storage offset.
  x = x.contiguous();
  w = w.contiguous();
  absmax = absmax.contiguous();

  auto out_sizes = x.sizes().vec();
  out_sizes.back() = N;
  // Zero-initialized; this is also the correct result when K == 0.
  auto y = torch::zeros(out_sizes, x.options());

  // Built before the empty shortcut so unsupported dtypes still throw.
  std::stringstream ss;
  ss << "bnb_qmv_" << type_str(x.scalar_type()) << "_bs_" << blocksize
     << "_qt_" << quant_type;
  const std::string kernel_name = ss.str();

  // Nothing to compute (and no zero-sized dispatch / out-of-bounds access).
  if (y.numel() == 0 || K == 0)
    return y;

  // Only K and N reach the kernel and the grid has no batch axis, so it can
  // produce exactly one output vector; reject batched input explicitly.
  TORCH_CHECK(
      x.numel() == K,
      "bnb_gemv_4bit expects a single input vector of ",
      K,
      " elements, got shape ",
      x.sizes());
  const int64_t w_bytes_needed = (N * K + 1) / 2;
  const int64_t w_bytes = w.numel() * w.element_size();
  TORCH_CHECK(
      w_bytes >= w_bytes_needed,
      "w holds ",
      w_bytes,
      " bytes but ",
      w_bytes_needed,
      " are needed for a ",
      N,
      "x",
      K,
      " 4-bit matrix");

  auto pipeline = get_pipeline(kernel_name);
  TORCH_CHECK(pipeline, "Kernel not found: ", kernel_name);

  const int k_arg = static_cast<int>(K);
  const int n_arg = static_cast<int>(N);

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        // Fetched inside dispatch_sync to avoid racing other work on the
        // MPS stream for the current command buffer.
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        int idx = 0;
        set_tensor(encoder, w, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, x, idx++);
        set_tensor(encoder, y, idx++);
        [encoder setBytes:&k_arg length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&n_arg length:sizeof(int) atIndex:idx++];

        // 8 output rows per threadgroup of 64 threads (32 * 2, presumably two
        // 32-wide SIMD groups). 64-bit math avoids overflow near INT_MAX.
        const int64_t rows_per_tg = 8;
        const int64_t grid_y = (N + rows_per_tg - 1) / rows_per_tg;

        [encoder dispatchThreadgroups:MTLSizeMake(1, grid_y, 1)
                threadsPerThreadgroup:MTLSizeMake(32 * 2, 1, 1)];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return y;
}

// ============================================================================
// Public API: GEMM (matrix-matrix multiply with transposed weight)
// Y = X @ dequant(W).T
// ============================================================================

/// Matrix-matrix product with a transposed 4-bit blockwise-quantized weight:
/// Y = X @ dequant(W).T.
///
/// @param x               MPS input of shape (..., K) with at least 2 dims.
///                        All leading dims are treated as rows (M = their
///                        product), so batched inputs are fully computed.
/// @param w               Packed 4-bit weight for an N x K matrix (at least
///                        ceil(N * K / 2) bytes).
/// @param absmax          Per-block scales for `w`.
/// @param blocksize       Quantization block size used for `w`.
/// @param quant_type      Quantization type used for `w` (kernel-name suffix).
/// @param output_features N, the number of output columns.
/// @return Tensor of shape (..., N) in x's dtype.
/// Throws (TORCH_CHECK) on invalid shapes, sizes beyond the kernel's 32-bit
/// range, an undersized `w`, or a missing kernel.
at::Tensor bnb_gemm_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features) {
  TORCH_CHECK(
      x.is_mps() && w.is_mps() && absmax.is_mps(),
      "All tensors must be on MPS device");
  TORCH_CHECK(x.dim() >= 2, "Input must be at least 2D for GEMM");
  TORCH_CHECK(output_features >= 0, "output_features must be non-negative");

  // Kernels index every buffer linearly from its storage offset.
  x = x.contiguous();
  w = w.contiguous();
  absmax = absmax.contiguous();

  const int64_t K = x.size(-1);
  const int64_t N = output_features;
  // Fold every leading dim into the row count. For contiguous x (..., K) and
  // y (..., N), the buffers are exactly row-major (M, K) and (M, N) matrices.
  // Using only x.size(-2) would leave every batch after the first
  // uncomputed (still zero).
  int64_t M = 1;
  for (int64_t d = 0; d < x.dim() - 1; ++d)
    M *= x.size(d);

  // M, N and K are handed to the kernel as 32-bit ints.
  TORCH_CHECK(
      M <= std::numeric_limits<int>::max() &&
          N <= std::numeric_limits<int>::max() &&
          K <= std::numeric_limits<int>::max(),
      "Matrix dimensions exceed the 32-bit range of BnB MPS kernels");

  auto out_sizes = x.sizes().vec();
  out_sizes.back() = N;
  // Zero-initialized; this is also the correct result when K == 0.
  auto y = torch::zeros(out_sizes, x.options());

  // Built before the empty shortcut so unsupported dtypes still throw.
  std::stringstream ss;
  ss << "bnb_qmm_t_" << type_str(x.scalar_type()) << "_bs_" << blocksize
     << "_qt_" << quant_type;
  const std::string kernel_name = ss.str();

  // Nothing to compute; also avoids issuing a zero-sized dispatch.
  if (y.numel() == 0 || K == 0)
    return y;

  const int64_t w_bytes_needed = (N * K + 1) / 2;
  const int64_t w_bytes = w.numel() * w.element_size();
  TORCH_CHECK(
      w_bytes >= w_bytes_needed,
      "w holds ",
      w_bytes,
      " bytes but ",
      w_bytes_needed,
      " are needed for a ",
      N,
      "x",
      K,
      " 4-bit matrix");

  auto pipeline = get_pipeline(kernel_name);
  TORCH_CHECK(pipeline, "Kernel not found: ", kernel_name);

  const int k_arg = static_cast<int>(K);
  const int n_arg = static_cast<int>(N);
  const int m_arg = static_cast<int>(M);

  @autoreleasepool {
    dispatch_sync(torch::mps::get_dispatch_queue(), ^{
      @autoreleasepool {
        // Fetched inside dispatch_sync to avoid racing other work on the
        // MPS stream for the current command buffer.
        id<MTLCommandBuffer> commandBuffer =
            torch::mps::get_command_buffer();
        TORCH_CHECK(commandBuffer, "Failed to get MPS command buffer");

        id<MTLComputeCommandEncoder> encoder =
            [commandBuffer computeCommandEncoder];
        TORCH_CHECK(encoder, "Failed to create compute encoder");

        [encoder setComputePipelineState:pipeline];

        int idx = 0;
        set_tensor(encoder, w, idx++);
        set_tensor(encoder, absmax, idx++);
        set_tensor(encoder, x, idx++);
        set_tensor(encoder, y, idx++);
        [encoder setBytes:&k_arg length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&n_arg length:sizeof(int) atIndex:idx++];
        [encoder setBytes:&m_arg length:sizeof(int) atIndex:idx++];

        // One 128-thread threadgroup per 32x32 output tile (N along x, M
        // along y). 64-bit math avoids overflow near INT_MAX.
        const int64_t grid_x = (N + 31) / 32;
        const int64_t grid_y = (M + 31) / 32;

        [encoder dispatchThreadgroups:MTLSizeMake(grid_x, grid_y, 1)
                threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
        [encoder endEncoding];

        torch::mps::commit();
      }
    });
  }

  return y;
}