File size: 1,721 Bytes
20347e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
#pragma once

#include <ATen/ATen.h>
#include <tuple>

// ============================================================================
// Blockwise 4-bit quantization (NF4/FP4)
// ============================================================================

// Quantize `input` blockwise at 4 bits per element.
// Returns {packed, absmax}:
//   packed - the quantized codes ("packed" here means two 4-bit codes per
//            byte, consistent with the [N, K/2] shape noted in the GEMV
//            banner below);
//   absmax - per-block scale factors (one per `blocksize` elements, per the
//            "blockwise" scheme) needed later for dequantization.
// `quant_type` selects the 4-bit code book — presumably an NF4/FP4 selector
// per the banner above; TODO(review): confirm the exact integer mapping
// against the implementation.
std::tuple<at::Tensor, at::Tensor> bnb_quantize_4bit(
    at::Tensor input,
    int64_t blocksize,
    int64_t quant_type);

// ============================================================================
// Blockwise 4-bit dequantization
// ============================================================================

// Dequantize a packed 4-bit tensor (as produced by bnb_quantize_4bit) back to
// a tensor of `output_dtype`.
//   packed       - packed 4-bit codes
//   absmax       - per-block scales produced at quantization time
//   blocksize    - elements per quantization block; must match the value used
//                  when quantizing
//   quant_type   - code-book selector; must match the quantization call
//   numel        - element count of the original (unpacked) tensor —
//                  presumably required because packing two codes per byte
//                  makes the true length unrecoverable from the packed size
//                  alone; TODO(review): confirm against the implementation
//   output_dtype - scalar type of the returned tensor
at::Tensor bnb_dequantize_4bit(
    at::Tensor packed,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t numel,
    c10::ScalarType output_dtype);

// ============================================================================
// Fused GEMV: y = dequant(W) @ x
// W: [N, K/2] packed, absmax: [N, K_groups], x: [..., K], y: [..., N]
// ============================================================================

// Matrix-vector product against a 4-bit-quantized weight matrix, fused so the
// weights are dequantized on the fly rather than materialized.
//   x               - input activations, shape [..., K] per the banner
//   w               - packed 4-bit weights, shape [N, K/2]
//   absmax          - per-block scales, shape [N, K_groups]
//   blocksize       - quantization block size (K_groups = K / blocksize,
//                     presumably — TODO(review): confirm)
//   quant_type      - code-book selector, as in bnb_quantize_4bit
//   output_features - appears to be N, the output dimension of y;
//                     TODO(review): confirm against the implementation
at::Tensor bnb_gemv_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features);

// ============================================================================
// Fused GEMM: Y = X @ dequant(W).T
// X: [M, K], W: [N, K/2] packed, absmax: [N, K_groups], Y: [M, N]
// ============================================================================

// Matrix-matrix product against a 4-bit-quantized weight matrix, fused so the
// weights are dequantized on the fly. Note W is used transposed (Y = X @ W^T),
// matching the GEMV variant above in parameter order and meaning.
//   x               - input matrix, shape [M, K] per the banner
//   w               - packed 4-bit weights, shape [N, K/2]
//   absmax          - per-block scales, shape [N, K_groups]
//   blocksize       - quantization block size, as in bnb_quantize_4bit
//   quant_type      - code-book selector, as in bnb_quantize_4bit
//   output_features - appears to be N, the output dimension of Y;
//                     TODO(review): confirm against the implementation
at::Tensor bnb_gemm_4bit(
    at::Tensor x,
    at::Tensor w,
    at::Tensor absmax,
    int64_t blocksize,
    int64_t quant_type,
    int64_t output_features);