File size: 1,749 Bytes
d5b7ee9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
"""BitLinear Layer — Ternary (1.58-bit) quantization for PyTorch.

Based on "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits" (2024).
Weights are quantized to {-1, 0, 1} for extreme efficiency.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F


def weight_quant(w):
    """Map a weight tensor onto the ternary grid {-1, 0, 1} via absmean scaling.

    The scale is the mean absolute value taken over every element of ``w``
    (a single per-tensor scalar), floored at 1e-5 so an all-zero tensor does
    not divide by zero. Each weight is divided by that scale, clipped to
    [-1, 1] and rounded to the nearest integer.

    Args:
        w: Weight tensor to quantize.

    Returns:
        A tensor shaped like ``w`` whose forward value is the ternary tensor
        (up to floating-point rounding of the final add) and whose gradient
        with respect to ``w`` is the identity.
    """
    absmean = w.abs().mean().clamp(min=1e-5)
    ternary = (w / absmean).clamp(-1, 1).round()
    # Straight-Through Estimator (STE): the detached residual carries no
    # gradient, so backward treats the whole quantizer as ``y = w``.
    # NOTE(review): the ternary values are not multiplied back by ``absmean``,
    # so the original weight magnitude is dropped here — confirm intended.
    residual = (ternary - w).detach()
    return w + residual


def activation_quant(x):
    """Fake-quantize activations to the signed 8-bit range using absmax scaling.

    A single per-tensor scale ``127 / max(|x|)`` is computed over every
    element of ``x`` (the maximum is floored at 1e-5 to avoid dividing by
    zero). Values are scaled, clipped to [-128, 127], rounded to integers and
    then divided by the scale again, so the result stays in the original
    value range and dtype rather than being an integer tensor.

    Args:
        x: Activation tensor to quantize.

    Returns:
        A tensor shaped like ``x`` whose forward value is the de-quantized
        8-bit approximation of ``x`` and whose gradient with respect to ``x``
        is the identity. A tensor with zero elements is returned unchanged.
    """
    # ``max()`` without a dim raises on a zero-element tensor; there is
    # nothing to quantize in that case, so pass the input straight through
    # (this keeps zero-size batches working, as they do for ``nn.Linear``).
    if x.numel() == 0:
        return x
    scale = 127.0 / x.abs().max().clamp(min=1e-5)
    x_q = torch.round(torch.clamp(x * scale, -128, 127))
    # Straight-Through Estimator (STE): the detached residual carries no
    # gradient, so backward treats the whole quantizer as ``y = x``.
    return x + (x_q / scale - x).detach()


class BitLinear(nn.Linear):
    """``nn.Linear`` variant that runs its forward pass on quantized operands.

    Construction and parameters are inherited unchanged from ``nn.Linear``;
    only ``forward`` differs. The stored weight stays in full precision and
    is re-quantized on every call.
    """

    def forward(self, x):
        """Apply the linear map using quantized activations and weights.

        Args:
            x: Input tensor, passed through ``activation_quant``.

        Returns:
            ``F.linear`` of the quantized input with the quantized weight.
            The bias (if any) is used as stored, without quantization.
        """
        # Activation quantization is optional for small models, but keeps the
        # layer faithful to BitNet.
        quantized_input = activation_quant(x)
        ternary_weight = weight_quant(self.weight)

        # Plain matrix multiplication here; a dedicated C++ kernel could
        # reduce this to additions and subtractions only.
        return F.linear(quantized_input, ternary_weight, self.bias)


class BitRMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (often used with BitNet).

    Divides the input by the root-mean-square of its last dimension and then
    multiplies by a learnable per-feature gain. There is no mean subtraction
    and no bias term.

    Args:
        dim: Length of the gain vector; it is broadcast against the last
            dimension of the input.
        eps: Constant added to the mean square before the reciprocal square
            root, so an all-zero row does not divide by zero.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Learnable gain, initialised to ones so the layer starts out as a
        # pure RMS normalisation.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalize ``x`` over its last dimension and apply the gain.

        Args:
            x: Input tensor; statistics are taken over dimension ``-1``.

        Returns:
            ``x / sqrt(mean(x**2, dim=-1) + eps) * weight``. For float16 and
            bfloat16 inputs the statistics are computed in float32 and the
            normalized values are cast back to the input dtype before the
            gain is applied.
        """
        # Squaring in float16 overflows to inf once |x| exceeds ~255, which
        # makes rsqrt return 0 and silently zeroes the output. Do the
        # reduction in float32 for half-precision inputs; wider dtypes are
        # left as they are.
        low_precision = x.dtype in (torch.float16, torch.bfloat16)
        x_work = x.float() if low_precision else x
        mean_sq = x_work.pow(2).mean(-1, keepdim=True)
        x_normed = x_work * torch.rsqrt(mean_sq + self.eps)
        if low_precision:
            # Cast back first so the result dtype of the gain multiply is the
            # same as it was before the float32 detour.
            x_normed = x_normed.to(x.dtype)
        return x_normed * self.weight