BitPixelLM / model /bitlinear.py
BlakePeavy's picture
Upload BitPixelLM model artifacts
72e872c verified
"""
PixelArtGen — BitLinear 1.58-bit Layer & RMSNorm

Implementation of the core BitNet b1.58 components:
- RMSNorm: Root Mean Square Layer Normalization (Zhang & Sennrich, 2019)
- BitLinear158: 1.58-bit linear layer with ternary weights {-1, 0, +1}

References:
- "The Era of 1-bit LLMs" (Ma et al., 2024) — arXiv:2402.17764
- "BitNet: Scaling 1-bit Transformers" (Wang et al., 2023) — arXiv:2310.11453
- "RMSNorm" (Zhang & Sennrich, 2019) — arXiv:1910.07467
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class RMSNorm(nn.Module):
    """
    Root Mean Square layer normalization (arXiv:1910.07467).

    Unlike LayerNorm there is no mean subtraction; the input is only
    divided by its root mean square over the last dimension and then
    multiplied by a learned per-feature gain:

        RMSNorm(x) = x / sqrt(mean(x^2) + eps) * weight

    Args:
        dim: size of the last (feature) dimension; also the length of
            the learned gain vector, which starts at all ones.
        eps: constant added to the mean square before the reciprocal
            square root, guarding against division by zero.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable per-feature gain, initialised to the identity scaling.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        """Divide ``x`` by its RMS taken over the last dimension."""
        mean_square = torch.pow(x, 2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalise ``x`` over its last dimension and apply the gain.

        The normalisation itself is carried out in float32 and cast back
        to the dtype of ``x`` before the gain is applied.
        """
        normed = self._norm(x.float())
        return normed.type_as(x) * self.weight
def activation_quant(x: torch.Tensor) -> torch.Tensor:
    """
    Fake-quantize activations to 8 bits, one scale per token.

    Each slice along the last dimension is scaled by its own absolute
    maximum onto the symmetric integer grid [-127, 127] (no zero-point),
    rounded, and mapped back to the original scale. The result therefore
    stays a float tensor in the same range as ``x``.

    Rounding is wrapped in a straight-through estimator: the forward
    value is the quantized one, while the gradient with respect to ``x``
    is the identity.

    Args:
        x: float tensor; quantization statistics are taken over the
            last dimension.

    Returns:
        A single tensor shaped like ``x`` holding the dequantized
        8-bit values. The scale is not returned.
    """
    q_max = 127  # largest magnitude of a signed 8-bit value: 2^(8-1) - 1

    # Per-token absmax, floored so an all-zero token cannot divide by zero.
    absmax, _ = x.abs().max(dim=-1, keepdim=True)
    scale = torch.clamp(absmax, min=1e-5)

    levels = torch.round(torch.clamp(x * q_max / scale, -q_max, q_max))
    dequantized = levels * scale / q_max

    # STE: the detached residual carries the quantization error forward
    # but contributes nothing to the backward pass.
    residual = (dequantized - x).detach()
    return x + residual
def weight_quant(w: torch.Tensor) -> tuple:
    """
    Absmean ternary fake-quantization of a weight tensor (BitNet b1.58).

    Procedure:
        1. gamma = mean(|w|) over the whole tensor (floored at 1e-5)
        2. w / gamma is clamped to [-1, 1] and rounded, giving a
           ternary code in {-1, 0, +1}
        3. the code is multiplied back by gamma

    The returned weights are therefore on the original scale, taking
    values in {-gamma, 0, +gamma} up to floating-point round-off, not
    the raw ternary code. Rounding is wrapped in a straight-through
    estimator, so gradients reach the latent weights ``w`` unchanged.

    Args:
        w: latent full-precision weight tensor.

    Returns:
        ``(fake_quantized_weights, gamma)`` where the first item has the
        shape of ``w`` and ``gamma`` is a 0-dim tensor.
    """
    # A single scale for the whole tensor; the floor avoids dividing by zero.
    gamma = torch.clamp(w.abs().mean(), min=1e-5)

    ternary = torch.round(torch.clamp(w / gamma, -1, 1))
    dequantized = ternary * gamma

    # STE: the forward pass sees the quantized value; the backward pass
    # treats the whole operation as the identity on w.
    ste_weights = w + (dequantized - w).detach()
    return ste_weights, gamma
class BitLinear158(nn.Module):
    """
    1.58-bit linear layer from BitNet b1.58 (arXiv:2402.17764).

    Intended as a replacement for a bias-free ``nn.Linear`` with:
    - ternary weights via absmean quantization (``weight_quant``)
    - 8-bit per-token activation quantization (``activation_quant``)
    - a built-in RMSNorm applied to the input
    - no bias term
    - a full-precision latent weight kept as the trainable parameter

    Forward pass:
        1. RMSNorm the input
        2. Fake-quantize activations to 8 bits
        3. Fake-quantize weights to ternary
        4. Matrix multiply

    There is no separate output rescaling step: both quantizers return
    values already mapped back to their original scale, so the product
    in step 4 is directly the layer output.

    During training, gradients pass through both quantizers via the
    straight-through estimator, i.e. rounding is treated as the identity
    in the backward pass.

    Args:
        in_features: size of the last dimension of the input.
        out_features: size of the last dimension of the output.
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Full-precision latent weight (master copy updated by the optimizer).
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        # Built-in input normalization.
        self.rms_norm = RMSNorm(in_features)
        self._init_weights()

    def _init_weights(self):
        """Kaiming-uniform initialization with the same ``a`` as nn.Linear."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the normalized, fake-quantized linear projection.

        Args:
            x: tensor whose last dimension is ``in_features``,
                e.g. (batch, seq_len, in_features).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``out_features``.
        """
        # 1. Normalize input (built-in RMSNorm).
        x = self.rms_norm(x)
        # 2. Fake-quantize activations to 8-bit, per token.
        x_q = activation_quant(x)
        # 3. Fake-quantize weights to ternary. The scale is already folded
        #    into the returned weights, so it is not needed here.
        w_q, _ = weight_quant(self.weight)
        # 4. Matrix multiply. Float arithmetic is used so autograd can
        #    differentiate through the layer during training.
        return F.linear(x_q, w_q)

    def extra_repr(self) -> str:
        """Summary shown inside ``repr()`` of the module."""
        return f"in={self.in_features}, out={self.out_features}, bits=1.58"
class SwiGLU(nn.Module):
    """
    SwiGLU feed-forward block (Shazeer, 2020; arXiv:2002.05202).

        SwiGLU(x) = w2( silu(w1(x)) * v(x) )

    where ``*`` is element-wise. Three projections are used instead of
    the usual two, so the default hidden width is 8/3 of the input width
    (two thirds of a 4x expansion) to keep the parameter count similar.

    Args:
        in_features: width of the input and of the output.
        hidden_features: width of the gated hidden layer. When omitted
            (or falsy) it defaults to ``int(in_features * 8 / 3)``. In
            both cases it is then rounded up to a multiple of 8.
        use_bitlinear: build the projections from ``BitLinear158`` when
            true, from ``nn.Linear`` otherwise.
    """

    def __init__(self, in_features: int, hidden_features: int = None, use_bitlinear: bool = True):
        super().__init__()

        if not hidden_features:
            hidden_features = int(in_features * 8 / 3)  # 2/3 of 4x expansion

        # Round up to the next multiple of 8 for efficiency.
        padded = hidden_features + 7
        hidden_features = (padded // 8) * 8

        if use_bitlinear:
            projection = BitLinear158
        else:
            projection = nn.Linear

        self.w1 = projection(in_features, hidden_features)  # gate projection
        self.v = projection(in_features, hidden_features)  # value projection
        self.w2 = projection(hidden_features, in_features)  # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Gate the value branch with silu(w1(x)) and project back."""
        gate = F.silu(self.w1(x))
        value = self.v(x)
        return self.w2(gate * value)
# ──── Testing ────────────────────────────────────────────────────
if __name__ == "__main__":
    # Smoke test: runs every component once on random data and prints a
    # short report. Only the STE gradient check is a hard assertion.
    print("Testing BitLinear158 components...")

    # Test RMSNorm
    norm = RMSNorm(256)
    x = torch.randn(2, 10, 256)
    y = norm(x)
    print(f"RMSNorm: {x.shape} -> {y.shape}, mean={y.mean():.4f}, std={y.std():.4f}")

    # Test weight quantization.
    # weight_quant returns weights on the gamma scale (about -gamma, 0,
    # +gamma), so divide by the scale to recover the ternary code before
    # counting. Rounding the scaled values directly only works when gamma
    # happens to exceed 0.5, and float round-off can yield more than
    # three distinct raw values.
    w = torch.randn(512, 256)
    w_q, scale = weight_quant(w)
    ternary = (w_q.detach() / scale).round()
    unique = torch.unique(ternary)
    print(f"Weight quant: {w.shape}, unique values: {len(unique)}, scale: {scale:.4f}")
    print(f"  Ternary distribution: -1={((ternary == -1).sum().item())}, "
          f"0={((ternary == 0).sum().item())}, "
          f"+1={((ternary == 1).sum().item())}")

    # Test activation quantization
    a = torch.randn(2, 10, 256)
    a_q = activation_quant(a)
    print(f"Activation quant: range [{a_q.min():.2f}, {a_q.max():.2f}]")

    # Test BitLinear158
    layer = BitLinear158(256, 512)
    x = torch.randn(2, 10, 256)
    y = layer(x)
    print(f"BitLinear158: {x.shape} -> {y.shape}")

    # Test gradient flow (STE): the latent weight must receive a gradient
    # even though the forward pass rounds it.
    loss = y.sum()
    loss.backward()
    assert layer.weight.grad is not None, "Gradient did not flow through STE!"
    print(f"STE gradient flow: OK (grad norm: {layer.weight.grad.norm():.4f})")

    # Test SwiGLU
    swiglu = SwiGLU(256, use_bitlinear=True)
    x = torch.randn(2, 10, 256)
    y = swiglu(x)
    print(f"SwiGLU (BitLinear): {x.shape} -> {y.shape}")
    total = sum(p.numel() for p in swiglu.parameters())
    print(f"  SwiGLU params: {total:,}")

    # Parameter comparison against a plain two-layer GELU feed-forward net
    ff_standard = nn.Sequential(nn.Linear(256, 512), nn.GELU(), nn.Linear(512, 256))
    ff_params = sum(p.numel() for p in ff_standard.parameters())
    print(f"  Standard FFN params: {ff_params:,}")
    print(f"  Ratio: {total / ff_params:.2f}x")

    print("\nAll tests passed! ✓")