grapheneaffiliates committed (verified)
Commit 2a50419 · 1 Parent(s): febd523

Upload python/bitlinear.py with huggingface_hub

Files changed (1):
  1. python/bitlinear.py +113 -0
python/bitlinear.py ADDED
@@ -0,0 +1,113 @@
"""
BitLinear: ternary {-1, 0, +1} linear layer with straight-through estimator.

Training: shadow float weights -> quantize forward -> STE backward
Inference: pure ternary weights -> matmul is add/sub only

Based on BitNet b1.58 (arXiv 2402.17764).

Drop-in replacement for nn.Linear. Use `use_bitlinear=True` in H4AttentionLayer
and H4TransformerBlock to swap all trainable projections to ternary.
"""

import math

import torch
import torch.nn as nn
import torch.nn.functional as F
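
# Quick usage sketch (illustrative; see the __main__ smoke test at the bottom
# of this file for a runnable version):
#     layer = BitLinear(512, 512)   # training: STE forward, float shadow grads
#     layer.freeze()                # inference: frozen ternary int8 weights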


def ternary_quantize(w):
    """Quantize weights to {-1, 0, +1} via absmean scaling.

        scale = mean(|w|)
        w_q = RoundClip(w / scale, -1, +1)

    The absmean adapts the rounding boundary to each layer's weight
    distribution. This is the canonical BitNet b1.58 method.
    """
    scale = w.abs().mean() + 1e-8
    w_scaled = w / scale
    w_q = torch.clamp(torch.round(w_scaled), -1, 1)
    return w_q, scale
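
# Worked example of the absmean rule above (illustrative numbers):
#   w = [0.4, -0.05, 0.9]             -> scale = mean(|w|) = 0.45
#   w / scale = [0.889, -0.111, 2.0]  -> round + clip -> [1, 0, 1]
#   dequantized w_q * scale = [0.45, 0.0, 0.45]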


def activation_quant_int8(x):
    """Per-token absmax quantization to int8 range [-127, 127].

    Each token (last dim) gets its own scale factor.
    """
    Q_b = 127.0
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    x_q = torch.clamp(torch.round(x * Q_b / scale), -Q_b, Q_b)
    return x_q, scale, Q_b
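
# Worked example (illustrative): a token x = [0.5, -2.0] has absmax 2.0, so
#   x * 127 / scale = [31.75, -127.0] -> round -> [32.0, -127.0]
#   dequantizing with x_q * scale / 127 gives [0.504, -2.0] (absmax entry exact)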


class BitLinear(nn.Module):
    """
    Ternary linear layer. Drop-in replacement for nn.Linear.

    Forward pass uses quantized weights via STE so gradients
    flow to shadow float weights. Inference mode freezes to
    pure ternary weights; the frozen forward below dequantizes
    them for a reference float matmul.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        # Kaiming uniform init, the same scheme nn.Linear uses by default
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None

        self.register_buffer('_frozen_ternary', None)
        self.register_buffer('_frozen_scale', None)
        self._inference_mode = False

    def forward(self, x):
        if self._inference_mode and self._frozen_ternary is not None:
            # Frozen inference path: dequantize the ternary int8 weights for a
            # reference float matmul (a deployment kernel would use add/sub only)
            y = F.linear(x, self._frozen_ternary.float() * self._frozen_scale, self.bias)
            return y

        # QAT forward with straight-through estimator (STE)
        #
        # Weight STE: forward sees quantized weights, backward sees float shadow
        w_q, w_scale = ternary_quantize(self.weight)
        w_ste = self.weight + (w_q * w_scale - self.weight).detach()

        # Activation STE: forward sees int8-quantized input, backward sees float
        x_q, x_scale, Q_b = activation_quant_int8(x)
        x_ste = x + (x_q * x_scale / Q_b - x).detach()

        # Matmul through STE: gradients flow to self.weight and x
        y = F.linear(x_ste, w_ste, self.bias)
        return y
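
    # Note on the STE trick in forward: w_ste equals w_q * w_scale in the
    # forward pass, but the correction term is detach()'d, so d(w_ste)/d(weight)
    # is the identity and the backward pass updates the float shadow weights as
    # if no quantization had happened. The same holds for x_ste and x.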

    def freeze(self):
        """Lock to ternary for inference. After this, forward uses the frozen path."""
        w_q, w_s = ternary_quantize(self.weight.data)
        self._frozen_ternary = w_q.to(torch.int8)
        self._frozen_scale = w_s
        self._inference_mode = True

    def unfreeze(self):
        """Return to training mode with float shadow weights."""
        self._inference_mode = False

    @property
    def ternary_stats(self):
        """Distribution of {-1, 0, +1} in current ternary quantization."""
        w_q, _ = ternary_quantize(self.weight.data)
        n = w_q.numel()
        return {
            'neg1': (w_q == -1).sum().item() / n,
            'zero': (w_q == 0).sum().item() / n,
            'pos1': (w_q == 1).sum().item() / n,
        }

    def extra_repr(self):
        s = f'{self.in_features}, {self.out_features}, bias={self.bias is not None}'
        if self._inference_mode:
            s += ', frozen=True'
        return s
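

# Minimal smoke test (a sketch added for illustration; the shapes and seed are
# arbitrary, not part of the original upload):
if __name__ == "__main__":
    torch.manual_seed(0)
    layer = BitLinear(16, 8, bias=True)

    # QAT forward: gradients must reach the float shadow weights and the input
    x = torch.randn(4, 16, requires_grad=True)
    layer(x).sum().backward()
    assert layer.weight.grad is not None and x.grad is not None

    # Freeze and compare: the weight path matches exactly, so any residual
    # difference comes only from activation quantization in the QAT path
    y_train = layer(x).detach()
    layer.freeze()
    y_frozen = layer(x.detach())
    print("max |train - frozen|:", (y_train - y_frozen).abs().max().item())
    print("ternary stats:", layer.ternary_stats)
    print(layer)  # extra_repr shows frozen=True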