| """
|
| PixelArtGen β BitLinear 1.58-bit Layer & RMSNorm
|
|
|
| Implementation of the core BitNet b1.58 components:
|
| - RMSNorm: Root Mean Square Layer Normalization (Zhang & Sennrich, 2019)
|
| - BitLinear158: 1.58-bit linear layer with ternary weights {-1, 0, +1}
|
|
|
| References:
|
| - "The Era of 1-bit LLMs" (Ma et al., 2024) β arXiv:2402.17764
|
| - "BitNet: Scaling 1-bit Transformers" (Wang et al., 2023) β arXiv:2310.11453
|
| - "RMSNorm" (Zhang & Sennrich, 2019) β arXiv:1910.07467
|
| """

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    Simpler and faster than LayerNorm: it removes the mean centering and
    keeps only the re-scaling by the root mean square.

        RMSNorm(x) = x / RMS(x) * g
        where RMS(x) = sqrt(mean(x^2))

    Reference: arXiv:1910.07467
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, then cast back to the
        # input dtype before applying the learned gain.
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
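
# Worked example (illustrative numbers): for a single vector x = [3.0, 4.0]
# with the default unit gain, RMS(x) = sqrt((9 + 16) / 2) ~= 3.54, so
# RMSNorm(x) ~= [0.85, 1.13]. Unlike LayerNorm, the output mean is not forced
# to zero; only the scale is normalized.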


def activation_quant(x: torch.Tensor) -> torch.Tensor:
    """
    Per-token 8-bit activation quantization from BitNet b1.58.

    Quantizes activations to integer levels in [-127, 127] per token using
    absmax scaling. Symmetric quantization (no zero point), as specified in
    the paper.

    Args:
        x: (..., d_model) float tensor
    Returns:
        Fake-quantized tensor, dequantized back to the input scale and kept
        in float for autograd compatibility
    """
    Qb = 127
    scale = x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    x_quant = (x * Qb / scale).clamp(-Qb, Qb).round()

    # Straight-through estimator: the forward value is the dequantized
    # quantization result, while the gradient bypasses round() entirely.
    x_quant = x + (x_quant * scale / Qb - x).detach()
    return x_quant
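
# Worked example (illustrative numbers): for a token x = [0.2, -0.5, 1.0],
# scale = max|x| = 1.0, so x * 127 / scale = [25.4, -63.5, 127.0], which
# rounds to the integer codes [25, -64, 127]; dequantizing with scale / 127
# gives [0.197, -0.504, 1.0], the value the forward pass actually uses.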


def weight_quant(w: torch.Tensor) -> tuple:
    """
    Absmean ternary weight quantization from BitNet b1.58.

    Quantizes weights to {-1, 0, +1} using absmean scaling:
    1. Compute gamma = mean(|W|)
    2. Scale: W_scaled = W / gamma
    3. Clamp to [-1, 1] and round to the nearest value in {-1, 0, +1}

    Args:
        w: (out_features, in_features) weight matrix
    Returns:
        (fake-quantized weights rescaled by gamma, scale factor gamma)
    """
    gamma = w.abs().mean().clamp(min=1e-5)
    w_scaled = w / gamma
    w_quant = w_scaled.clamp(-1, 1).round()

    # Straight-through estimator: the forward value is the ternary weights
    # times gamma, while gradients flow to the full-precision latent weights.
    w_quant = w + (w_quant * gamma - w).detach()
    return w_quant, gamma
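
# Worked example (illustrative numbers): for W = [[0.4, -0.2], [0.05, -1.0]],
# gamma = mean|W| = 0.4125 and W / gamma ~= [[0.97, -0.48], [0.12, -2.42]],
# which clamps and rounds to the ternary codes [[1, 0], [0, -1]]; the forward
# pass then uses codes * gamma = [[0.4125, 0], [0, -0.4125]].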


class BitLinear158(nn.Module):
    """
    1.58-bit Linear Layer from BitNet b1.58.

    Drop-in replacement for nn.Linear with:
    - Ternary weights {-1, 0, +1} via absmean quantization
    - 8-bit per-token activation quantization
    - Built-in RMSNorm (absorbs the preceding LayerNorm)
    - No bias (following the BitNet b1.58 / LLaMA convention)
    - Full-precision latent weights maintained for training (STE)

    Forward pass:
    1. RMSNorm the input
    2. Quantize activations to 8 bits
    3. Quantize weights to ternary
    4. Matrix multiply (reducible to integer additions at inference time)
    5. Rescale the output (here folded into the fake-quantization steps)

    During training, gradients flow through quantization via the
    Straight-Through Estimator (STE): the gradient of round() is treated
    as the identity function.

    Reference: arXiv:2402.17764
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Full-precision latent weights, quantized on the fly in forward().
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

        # BitLinear absorbs the LayerNorm that would normally precede it.
        self.rms_norm = RMSNorm(in_features)

        self._init_weights()

    def _init_weights(self):
        """Kaiming uniform initialization, same as nn.Linear."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, in_features)
        Returns:
            (batch, seq_len, out_features)
        """
        # 1. Normalize the input.
        x = self.rms_norm(x)

        # 2. Fake-quantize activations to 8 bits per token.
        x_q = activation_quant(x)

        # 3. Fake-quantize weights to ternary. The scale gamma is already
        #    multiplied back into w_q, so no separate rescaling is needed.
        w_q, _ = weight_quant(self.weight)

        # 4. Linear projection with the quantized operands.
        output = F.linear(x_q, w_q)

        return output

    def extra_repr(self) -> str:
        return f"in={self.in_features}, out={self.out_features}, bits=1.58"
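
# Inference-export sketch (illustrative, not used by the training code above):
# the ternary codes and per-tensor scale of a trained layer could be recovered
# as
#
#     gamma = layer.weight.detach().abs().mean().clamp(min=1e-5)
#     codes = (layer.weight.detach() / gamma).clamp(-1, 1).round().to(torch.int8)
#
# so that a dedicated kernel can replace the matmul with additions and apply
# the scales afterwards.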


class SwiGLU(nn.Module):
    """
    SwiGLU activation for Feed-Forward Networks.

        SwiGLU(x) = (Swish(x W1) * (x V)) W2

    Uses 3 linear projections instead of 2, but the hidden dim is typically
    reduced to about 2/3 of the usual 4x width to keep the parameter count
    similar.

    When used with BitLinear158, all three projections are ternary.

    Reference: arXiv:2002.05202 (Shazeer, 2020)
    """

    def __init__(self, in_features: int, hidden_features: int = None, use_bitlinear: bool = True):
        super().__init__()
        # Default hidden width: ~8/3 * in_features, rounded up to a multiple
        # of 8 for hardware friendliness.
        hidden_features = hidden_features or int(in_features * 8 / 3)
        hidden_features = ((hidden_features + 7) // 8) * 8

        Linear = BitLinear158 if use_bitlinear else nn.Linear

        self.w1 = Linear(in_features, hidden_features)
        self.v = Linear(in_features, hidden_features)
        self.w2 = Linear(hidden_features, in_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.v(x))
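
# Sizing example (illustrative numbers): for in_features = 256 the default
# hidden width is int(256 * 8 / 3) = 682, rounded up to 688. The three
# 256 x 688 projections then total 3 * 176,128 = 528,384 weights, close to the
# 2 * 256 * 1024 = 524,288 weights of a standard two-projection FFN with a 4x
# hidden width.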
| if __name__ == "__main__":
|
| print("Testing BitLinear158 components...")
|
|
|
|
|
| norm = RMSNorm(256)
|
| x = torch.randn(2, 10, 256)
|
| y = norm(x)
|
| print(f"RMSNorm: {x.shape} -> {y.shape}, mean={y.mean():.4f}, std={y.std():.4f}")
|
|
|
|
|
| w = torch.randn(512, 256)
|
| w_q, scale = weight_quant(w)
|
| unique = torch.unique(w_q.detach())
|
| print(f"Weight quant: {w.shape}, unique values: {len(unique)}, scale: {scale:.4f}")
|
| print(f" Ternary distribution: -1={((w_q.detach().round() == -1).sum().item())}, "
|
| f"0={((w_q.detach().round() == 0).sum().item())}, "
|
| f"+1={((w_q.detach().round() == 1).sum().item())}")
|
|
|
|
|
| a = torch.randn(2, 10, 256)
|
| a_q = activation_quant(a)
|
| print(f"Activation quant: range [{a_q.min():.2f}, {a_q.max():.2f}]")

    # BitLinear158 forward pass
    layer = BitLinear158(256, 512)
    x = torch.randn(2, 10, 256)
    y = layer(x)
    print(f"BitLinear158: {x.shape} -> {y.shape}")

    # Gradient flow through the straight-through estimator
    loss = y.sum()
    loss.backward()
    assert layer.weight.grad is not None, "Gradient did not flow through STE!"
    print(f"STE gradient flow: OK (grad norm: {layer.weight.grad.norm():.4f})")

    # SwiGLU with BitLinear projections
    swiglu = SwiGLU(256, use_bitlinear=True)
    x = torch.randn(2, 10, 256)
    y = swiglu(x)
    print(f"SwiGLU (BitLinear): {x.shape} -> {y.shape}")
    total = sum(p.numel() for p in swiglu.parameters())
    print(f"  SwiGLU params: {total:,}")

    # Parameter count vs. a standard GELU FFN with a 512-wide hidden layer
    ff_standard = nn.Sequential(nn.Linear(256, 512), nn.GELU(), nn.Linear(512, 256))
    ff_params = sum(p.numel() for p in ff_standard.parameters())
    print(f"  Standard FFN params: {ff_params:,}")
    print(f"  Ratio: {total / ff_params:.2f}x")

    print("\nAll tests passed! ✓")