File size: 4,607 Bytes
20347e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
from typing import Optional, Tuple

import torch

from ._ops import ops

# Codebook selectors accepted by every ``quant_type`` parameter in this module.
# The numeric values mirror bitsandbytes' DataType_t and must stay in sync.
FP4: int = 1  # FP4 codebook
NF4: int = 2  # NF4 codebook


def quantize_4bit(
    input: torch.Tensor,
    blocksize: int = 64,
    quant_type: int = NF4,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``input`` blockwise to 4 bits with the NF4 or FP4 codebook.

    Arguments are forwarded unchanged to the ``bnb_quantize_4bit`` op.

    Args:
        input: Tensor to quantize, on the MPS device (float16, bfloat16 or
            float32).
        blocksize: Number of elements covered by each quantization block
            (64 or 128).
        quant_type: Codebook selector, ``FP4`` (1) or ``NF4`` (2).

    Returns:
        ``(packed, absmax)``:
            packed: uint8 tensor of packed 4-bit values (``numel / 2`` entries).
            absmax: float32 tensor holding each block's max absolute value.
    """
    quantize_op = ops.bnb_quantize_4bit
    return quantize_op(input, blocksize, quant_type)


def dequantize_4bit(
    packed: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int = 64,
    quant_type: int = NF4,
    numel: int = -1,
    output_dtype: torch.dtype = torch.float16,
) -> torch.Tensor:
    """Expand blockwise 4-bit codes (NF4 or FP4) back to a dense tensor.

    Args:
        packed: uint8 tensor of packed 4-bit values.
        absmax: float32 tensor of per-block max absolute values.
        blocksize: Number of elements per quantization block (64 or 128).
        quant_type: Codebook selector, ``FP4`` (1) or ``NF4`` (2).
        numel: Element count of the original tensor. Any negative value
            (the default is -1) means "infer it" as ``packed.numel() * 2``,
            which is always even; pass the real count when it differs.
        output_dtype: Scalar type of the returned tensor.

    Returns:
        The dequantized tensor.
    """
    # Negative numel is the "not provided" marker: assume every byte is full.
    element_count = packed.numel() * 2 if numel < 0 else numel
    dequantize_op = ops.bnb_dequantize_4bit
    return dequantize_op(
        packed,
        absmax,
        blocksize,
        quant_type,
        element_count,
        output_dtype,
    )


def gemv_4bit(
    x: torch.Tensor,
    w: torch.Tensor,
    absmax: torch.Tensor,
    output_features: int,
    blocksize: int = 64,
    quant_type: int = NF4,
) -> torch.Tensor:
    """Multiply a vector by a 4-bit quantized weight matrix in one fused op.

    Evaluates ``y = dequant(W) @ x`` where ``W`` is stored blockwise in NF4
    or FP4.

    Args:
        x: Input vector ``[..., K]`` on the MPS device.
        w: Packed weight matrix ``[N, K/2]`` (uint8) on the MPS device.
        absmax: Per-block scales ``[N, ceil(K/blocksize)]`` (float32).
        output_features: Output width ``N``.
        blocksize: Quantization block size (64 or 128).
        quant_type: Codebook selector, ``FP4`` (1) or ``NF4`` (2).

    Returns:
        Output tensor ``[..., N]``.
    """
    # Note the op's argument order: quantization settings come before N.
    op_args = (x, w, absmax, blocksize, quant_type, output_features)
    return ops.bnb_gemv_4bit(*op_args)


def gemm_4bit(
    x: torch.Tensor,
    w: torch.Tensor,
    absmax: torch.Tensor,
    output_features: int,
    blocksize: int = 64,
    quant_type: int = NF4,
) -> torch.Tensor:
    """Multiply a matrix by a transposed 4-bit quantized weight in one fused op.

    Evaluates ``Y = X @ dequant(W).T`` where ``W`` is stored blockwise in NF4
    or FP4.

    Args:
        x: Input matrix ``[..., M, K]`` on the MPS device.
        w: Packed weight matrix ``[N, K/2]`` (uint8) on the MPS device.
        absmax: Per-block scales ``[N, ceil(K/blocksize)]`` (float32).
        output_features: Output width ``N``.
        blocksize: Quantization block size (64 or 128).
        quant_type: Codebook selector, ``FP4`` (1) or ``NF4`` (2).

    Returns:
        Output tensor ``[..., M, N]``.
    """
    gemm_op = ops.bnb_gemm_4bit
    # The op expects the quantization settings before the output width.
    return gemm_op(x, w, absmax, blocksize, quant_type, output_features)


def linear_4bit(
    x: torch.Tensor,
    w: torch.Tensor,
    absmax: torch.Tensor,
    output_features: int,
    blocksize: int = 64,
    quant_type: int = NF4,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """4-bit quantized linear layer (auto-selects GEMV or GEMM).

    Single-row inputs are routed to the GEMV kernel. That means a 1-D vector
    ``[K]`` or a tensor whose second-to-last dimension is 1 (``[..., 1, K]``).
    Everything else goes through GEMM.

    Args:
        x: Input tensor on MPS device, ``[K]`` or ``[..., M, K]``.
        w: Packed weight [N, K/2] (uint8).
        absmax: Scales [N, ceil(K/blocksize)] (float32).
        output_features: N.
        blocksize: 64 or 128.
        quant_type: FP4 (1) or NF4 (2).
        bias: Optional bias [N], added to the output.

    Returns:
        Output tensor: ``[N]`` for 1-D input, otherwise ``[..., M, N]``.
    """
    if x.dim() == 1:
        # gemv_4bit maps [..., K] -> [..., N]. Flatten rather than squeeze(0).
        # With N == 1, squeeze(0) would turn [1] into a 0-d tensor. Flattening
        # also handles a kernel that returns [1, N].
        y = gemv_4bit(x, w, absmax, output_features, blocksize, quant_type)
        y = y.reshape(-1)
    elif x.dim() >= 2 and x.size(-2) == 1:
        # [..., 1, K]: drop the singleton row for GEMV, then restore it.
        y = gemv_4bit(
            x.squeeze(-2), w, absmax, output_features, blocksize, quant_type
        )
        y = y.unsqueeze(-2)
    else:
        # Multi-row input ([..., M, K] with M > 1) goes to GEMM.
        y = gemm_4bit(x, w, absmax, output_features, blocksize, quant_type)

    if bias is not None:
        y = y + bias

    return y

# Public API. The quant-type constants are exported alongside the functions
# because every ``quant_type`` parameter in this module expects one of them.
__all__ = [
    "FP4",
    "NF4",
    "quantize_4bit",
    "dequantize_4bit",
    "gemv_4bit",
    "gemm_4bit",
    "linear_4bit",
]