"""

PixelArtGen β€” BitLinear 1.58-bit Layer & RMSNorm



Implementation of the core BitNet b1.58 components:

- RMSNorm: Root Mean Square Layer Normalization (Zhang & Sennrich, 2019)

- BitLinear158: 1.58-bit linear layer with ternary weights {-1, 0, +1}



References:

- "The Era of 1-bit LLMs" (Ma et al., 2024) β€” arXiv:2402.17764

- "BitNet: Scaling 1-bit Transformers" (Wang et al., 2023) β€” arXiv:2310.11453

- "RMSNorm" (Zhang & Sennrich, 2019) β€” arXiv:1910.07467

"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
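
# Typical usage (illustrative sketch, not taken from the original module):
# BitLinear158 below is intended as a drop-in replacement for nn.Linear inside
# attention / FFN sub-layers. Because the layer normalizes its own input, the
# usual pre-LayerNorm in front of the projection is dropped, e.g.
#
#     proj = BitLinear158(d_model, d_model)   # instead of LayerNorm + nn.Linear
#     y = proj(x)                             # x: (batch, seq_len, d_model)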


class RMSNorm(nn.Module):
    """

    Root Mean Square Layer Normalization.

    

    Simpler and faster than LayerNorm β€” removes mean centering,

    keeps only the re-scaling by root mean square.

    

    RMSNorm(x) = x / RMS(x) * g

    where RMS(x) = sqrt(mean(x^2))

    

    Reference: arXiv:1910.07467

    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
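
# Worked example (illustrative) for RMSNorm: with x = [1.0, 2.0, 2.0],
# mean(x^2) = (1 + 4 + 4) / 3 = 3, so RMS(x) = sqrt(3) ≈ 1.732 and the
# normalized output is x / 1.732 ≈ [0.577, 1.155, 1.155], then scaled
# element-wise by g (initialized to ones).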


def activation_quant(x: torch.Tensor) -> torch.Tensor:
    """

    Per-token 8-bit activation quantization from BitNet b1.58.

    

    Quantizes activations to [-127, 127] per-token using absmax scaling.

    Symmetric quantization (no zero-point) as specified in the paper.

    

    Args:

        x: (..., d_model) float tensor

    Returns:

        Quantized tensor (still float for autograd compatibility), scale factor

    """
    Qb = 127  # 8-bit signed: 2^(8-1) - 1
    scale = x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    x_quant = (x * Qb / scale).clamp(-Qb, Qb).round()
    # STE: detach the rounding, keep gradients flowing
    x_quant = x + (x_quant * scale / Qb - x).detach()
    return x_quant
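
# Worked example (illustrative) for activation_quant: for a token
# x = [0.5, -2.0, 1.0], the absmax scale is 2.0, so
#   x * 127 / 2.0 = [31.75, -127.0, 63.5]  ->  round = [32, -127, 64]
# and the returned fake-quantized values are q * 2.0 / 127 ≈ [0.504, -2.0, 1.008],
# i.e. the original activations snapped to a 255-level per-token grid.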


def weight_quant(w: torch.Tensor) -> tuple:
    """

    Absmean ternary weight quantization from BitNet b1.58.

    

    Quantizes weights to {-1, 0, +1} using absmean scaling:

    1. Compute gamma = mean(|W|) 

    2. Scale: W_scaled = W / gamma

    3. Round to nearest in {-1, 0, +1}

    

    Args:

        w: (out_features, in_features) weight matrix

    Returns:

        (quantized_weights, scale_factor)

    """
    gamma = w.abs().mean().clamp(min=1e-5)
    w_scaled = w / gamma
    w_quant = w_scaled.clamp(-1, 1).round()
    # STE: detach the rounding, keep gradients on the latent weights
    w_quant = w + (w_quant * gamma - w).detach()
    return w_quant, gamma
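
# Worked example (illustrative) for weight_quant: for W = [[0.3, -1.2], [0.05, 0.8]],
# gamma = mean(|W|) = (0.3 + 1.2 + 0.05 + 0.8) / 4 = 0.5875,
# W / gamma ≈ [[0.51, -2.04], [0.09, 1.36]]  ->  clamp + round = [[1, -1], [0, 1]],
# so the returned fake-quantized values are {-0.5875, 0, +0.5875} with scale 0.5875.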


class BitLinear158(nn.Module):
    """

    1.58-bit Linear Layer from BitNet b1.58.

    

    Drop-in replacement for nn.Linear with:

    - Ternary weights {-1, 0, +1} via absmean quantization

    - 8-bit per-token activation quantization

    - Built-in RMSNorm (absorbs the preceding LayerNorm)

    - No bias (following BitNet b1.58 / LLaMA convention)

    - Full-precision latent weights maintained for training (STE)

    

    Forward pass:

        1. RMSNorm the input

        2. Quantize activations to 8-bit 

        3. Quantize weights to ternary

        4. Matrix multiply (effectively integer addition)

        5. Rescale output

    

    During training, gradients flow through quantization via the

    Straight-Through Estimator (STE) β€” the gradient of round() 

    is treated as the identity function.

    

    Reference: arXiv:2402.17764

    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Full-precision latent weight (master copy for training)
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

        # Built-in RMSNorm (replaces the preceding LayerNorm)
        self.rms_norm = RMSNorm(in_features)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Kaiming uniform initialization, same as nn.Linear."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """

        Args:

            x: (batch, seq_len, in_features)

        Returns:

            (batch, seq_len, out_features)

        """
        # 1. Normalize input (built-in RMSNorm)
        x = self.rms_norm(x)

        # 2. Quantize activations to 8-bit per-token
        x_q = activation_quant(x)

        # 3. Quantize weights to ternary {-1, 0, +1} (the absmean scale is
        # already folded into the fake-quantized weights)
        w_q, _ = weight_quant(self.weight)

        # 4. Matrix multiply with quantized weights and activations
        # In theory this is integer addition; in practice we use float
        # for autograd compatibility during training
        output = F.linear(x_q, w_q)

        return output

    def extra_repr(self) -> str:
        return f"in={self.in_features}, out={self.out_features}, bits=1.58"


class SwiGLU(nn.Module):
    """

    SwiGLU activation for Feed-Forward Networks.

    

    SwiGLU(x) = (Swish(xW1) βŠ™ xV) W2

    

    Uses 3 linear projections instead of 2, but the hidden dim

    is typically reduced by 2/3 to keep parameter count similar.

    

    When used with BitLinear158, all three projections are ternary.

    

    Reference: arXiv:2002.05202 (Shazeer, 2020)

    """

    def __init__(self, in_features: int, hidden_features: int = None, use_bitlinear: bool = True):
        super().__init__()
        hidden_features = hidden_features or int(in_features * 8 / 3)  # 2/3 of 4x expansion
        # Round to nearest multiple of 8 for efficiency
        hidden_features = ((hidden_features + 7) // 8) * 8
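        # e.g. (illustrative) in_features=256: int(256 * 8 / 3) = 682, rounded up
        # to 688; three 256x688 projections then roughly match a standard
        # two-projection 4x FFN (256 -> 1024 -> 256) in parameter count.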

        Linear = BitLinear158 if use_bitlinear else nn.Linear

        self.w1 = Linear(in_features, hidden_features)   # gate projection
        self.v = Linear(in_features, hidden_features)     # value projection
        self.w2 = Linear(hidden_features, in_features)    # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.v(x))


# ──── Testing ────────────────────────────────────────────────────

if __name__ == "__main__":
    print("Testing BitLinear158 components...")

    # Test RMSNorm
    norm = RMSNorm(256)
    x = torch.randn(2, 10, 256)
    y = norm(x)
    print(f"RMSNorm: {x.shape} -> {y.shape}, mean={y.mean():.4f}, std={y.std():.4f}")

    # Test weight quantization
    w = torch.randn(512, 256)
    w_q, scale = weight_quant(w)
    unique = torch.unique(w_q.detach())
    print(f"Weight quant: {w.shape}, unique values: {len(unique)}, scale: {scale:.4f}")
    print(f"  Ternary distribution: -1={((w_q.detach().round() == -1).sum().item())}, "
          f"0={((w_q.detach().round() == 0).sum().item())}, "
          f"+1={((w_q.detach().round() == 1).sum().item())}")

    # Test activation quantization
    a = torch.randn(2, 10, 256)
    a_q = activation_quant(a)
    print(f"Activation quant: range [{a_q.min():.2f}, {a_q.max():.2f}]")

    # Test BitLinear158
    layer = BitLinear158(256, 512)
    x = torch.randn(2, 10, 256)
    y = layer(x)
    print(f"BitLinear158: {x.shape} -> {y.shape}")

    # Test gradient flow (STE)
    loss = y.sum()
    loss.backward()
    assert layer.weight.grad is not None, "Gradient did not flow through STE!"
    print(f"STE gradient flow: OK (grad norm: {layer.weight.grad.norm():.4f})")

    # Test SwiGLU
    swiglu = SwiGLU(256, use_bitlinear=True)
    x = torch.randn(2, 10, 256)
    y = swiglu(x)
    print(f"SwiGLU (BitLinear): {x.shape} -> {y.shape}")
    total = sum(p.numel() for p in swiglu.parameters())
    print(f"  SwiGLU params: {total:,}")

    # Parameter comparison against a standard 4x-expansion GELU FFN
    ff_standard = nn.Sequential(nn.Linear(256, 1024), nn.GELU(), nn.Linear(1024, 256))
    ff_params = sum(p.numel() for p in ff_standard.parameters())
    print(f"  Standard FFN params: {ff_params:,}")
    print(f"  Ratio: {total / ff_params:.2f}x")

    print("\nAll tests passed! βœ“")