# grapheneaffiliates's picture
# Upload python/bitlinear.py with huggingface_hub
# 2a50419 verified
"""
BitLinear: ternary {-1, 0, +1} linear layer with straight-through estimator.
Training: shadow float weights -> quantize forward -> STE backward
Inference: pure ternary weights -> matmul is add/sub only
Based on BitNet b1.58 (arxiv 2402.17764).
Drop-in replacement for nn.Linear. Use `use_bitlinear=True` in H4AttentionLayer
and H4TransformerBlock to swap all trainable projections to ternary.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def ternary_quantize(w):
    """Quantize a weight tensor to ternary values {-1, 0, +1} (absmean scheme).

    The tensor is divided by the mean of its absolute values (plus a small
    epsilon so an all-zero tensor never divides by zero), rounded to the
    nearest integer and clipped to [-1, +1]:

        scale = mean(|w|) + 1e-8
        w_q   = RoundClip(w / scale, -1, +1)

    Because the scale follows the tensor's own magnitude, the rounding
    threshold adapts to each layer's weight distribution (BitNet b1.58).

    Args:
        w: floating-point weight tensor of any shape.

    Returns:
        (w_q, scale): ``w_q`` has the shape and dtype of ``w`` with entries
        in {-1, 0, +1}; ``scale`` is a 0-dim tensor such that
        ``w_q * scale`` approximates ``w``.
    """
    absmean = w.abs().mean().add(1e-8)
    ternary = (w / absmean).round().clamp(-1, 1)
    return ternary, absmean
def activation_quant_int8(x):
    """Quantize activations per token onto the symmetric int8 grid [-127, 127].

    Each slice along the last dimension (a "token") is scaled by its own
    absolute maximum, floored at 1e-8 so an all-zero token does not divide
    by zero. The largest-magnitude entry of each token therefore maps to
    +/-127.

    Args:
        x: floating-point activation tensor; the last dim is quantized as a unit.

    Returns:
        (x_q, scale, Q_b): ``x_q`` holds integer-valued entries in
        [-127, 127] (kept in ``x``'s floating dtype); ``scale`` is the
        per-token absmax, shaped like ``x`` with the last dim reduced to 1;
        ``Q_b`` is the float 127.0. ``x_q * scale / Q_b`` approximately
        reconstructs ``x``.
    """
    levels = 127.0
    token_absmax = torch.amax(torch.abs(x), dim=-1, keepdim=True).clamp(min=1e-8)
    quantized = torch.round(x * levels / token_absmax).clamp(-levels, levels)
    return quantized, token_absmax, levels
class BitLinear(nn.Module):
    """
    Ternary linear layer with the same call signature as nn.Linear.

    Training (QAT): the layer keeps float "shadow" weights. Each forward
    pass quantizes them to {-1, 0, +1} * scale with ``ternary_quantize``
    and quantizes the input per token to the int8 grid with
    ``activation_quant_int8``. A straight-through estimator (STE) makes the
    forward use the quantized values while gradients flow to the float
    weight and input.

    Inference: ``freeze()`` stores the ternary weights as an int8 buffer
    plus a scalar scale and switches ``forward`` to a frozen path. That path
    applies the same activation quantization as training, so freezing does
    not change the layer's outputs. The matmul is still executed in floating
    point by ``F.linear``; only the operand values are ternary / int8.

    Note: unlike nn.Linear, ``bias`` defaults to False here.
    """

    def __init__(self, in_features, out_features, bias=False):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        # Same initialization nn.Linear uses for its weight.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        self.bias = nn.Parameter(torch.zeros(out_features)) if bias else None
        # Filled in by freeze(); None until then.
        self.register_buffer('_frozen_ternary', None)
        self.register_buffer('_frozen_scale', None)
        # Plain attribute (not a buffer/parameter), so it is not part of
        # state_dict; call freeze() again after loading if needed.
        self._inference_mode = False

    def forward(self, x):
        """Apply the layer to ``x``, whose last dim must be ``in_features``.

        Uses the frozen ternary weights once ``freeze()`` has been called;
        otherwise runs the QAT path with STE on both weight and input.
        """
        if self._inference_mode and self._frozen_ternary is not None:
            # Frozen path. Rebuild ternary * scale in the shadow weight's
            # dtype (not unconditionally float32) so half/bfloat16 modules
            # do not hit a dtype mismatch in F.linear.
            dtype = self.weight.dtype
            w = self._frozen_ternary.to(dtype) * self._frozen_scale.to(dtype)
            # Quantize activations exactly as the training forward does, so
            # the frozen layer reproduces the outputs it was trained with.
            x_q, x_scale, Q_b = activation_quant_int8(x)
            return F.linear(x_q * x_scale / Q_b, w, self.bias)

        # QAT forward with straight-through estimator (STE).
        # Weight STE: forward sees quantized weights, backward sees float shadow.
        w_q, w_scale = ternary_quantize(self.weight)
        w_ste = self.weight + (w_q * w_scale - self.weight).detach()
        # Activation STE: forward sees int8-quantized input, backward sees float.
        x_q, x_scale, Q_b = activation_quant_int8(x)
        x_ste = x + (x_q * x_scale / Q_b - x).detach()
        # Gradients reach self.weight and x through the non-detached terms.
        return F.linear(x_ste, w_ste, self.bias)

    def freeze(self):
        """Snapshot the current ternary weights and switch to the frozen path.

        The {-1, 0, +1} weights are stored as an int8 buffer and the absmean
        scale as a 0-dim buffer. The float shadow weight is kept, so
        ``unfreeze()`` can resume training.
        """
        w_q, w_s = ternary_quantize(self.weight.data)
        self._frozen_ternary = w_q.to(torch.int8)
        self._frozen_scale = w_s
        self._inference_mode = True

    def unfreeze(self):
        """Return to the QAT path using the float shadow weights.

        The frozen buffers are left in place (not cleared).
        """
        self._inference_mode = False

    @property
    def ternary_stats(self):
        """Fraction of weights that quantize to -1, 0 and +1 right now.

        Computed from the current shadow weights, not the frozen buffers.
        Returns a dict with keys 'neg1', 'zero' and 'pos1'.
        """
        w_q, _ = ternary_quantize(self.weight.data)
        n = w_q.numel()
        return {
            'neg1': (w_q == -1).sum().item() / n,
            'zero': (w_q == 0).sum().item() / n,
            'pos1': (w_q == 1).sum().item() / n,
        }

    def extra_repr(self):
        s = f'{self.in_features}, {self.out_features}, bias={self.bias is not None}'
        if self._inference_mode:
            s += ', frozen=True'
        return s