""" PixelArtGen — BitLinear 1.58-bit Layer & RMSNorm Implementation of the core BitNet b1.58 components: - RMSNorm: Root Mean Square Layer Normalization (Zhang & Sennrich, 2019) - BitLinear158: 1.58-bit linear layer with ternary weights {-1, 0, +1} References: - "The Era of 1-bit LLMs" (Ma et al., 2024) — arXiv:2402.17764 - "BitNet: Scaling 1-bit Transformers" (Wang et al., 2023) — arXiv:2310.11453 - "RMSNorm" (Zhang & Sennrich, 2019) — arXiv:1910.07467 """ import torch import torch.nn as nn import torch.nn.functional as F import math class RMSNorm(nn.Module): """ Root Mean Square Layer Normalization. Simpler and faster than LayerNorm — removes mean centering, keeps only the re-scaling by root mean square. RMSNorm(x) = x / RMS(x) * g where RMS(x) = sqrt(mean(x^2)) Reference: arXiv:1910.07467 """ def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def _norm(self, x: torch.Tensor) -> torch.Tensor: return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) def forward(self, x: torch.Tensor) -> torch.Tensor: output = self._norm(x.float()).type_as(x) return output * self.weight def activation_quant(x: torch.Tensor) -> torch.Tensor: """ Per-token 8-bit activation quantization from BitNet b1.58. Quantizes activations to [-127, 127] per-token using absmax scaling. Symmetric quantization (no zero-point) as specified in the paper. Args: x: (..., d_model) float tensor Returns: Quantized tensor (still float for autograd compatibility), scale factor """ Qb = 127 # 8-bit signed: 2^(8-1) - 1 scale = x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5) x_quant = (x * Qb / scale).clamp(-Qb, Qb).round() # STE: detach the rounding, keep gradients flowing x_quant = x + (x_quant * scale / Qb - x).detach() return x_quant def weight_quant(w: torch.Tensor) -> tuple: """ Absmean ternary weight quantization from BitNet b1.58. Quantizes weights to {-1, 0, +1} using absmean scaling: 1. Compute gamma = mean(|W|) 2. Scale: W_scaled = W / gamma 3. Round to nearest in {-1, 0, +1} Args: w: (out_features, in_features) weight matrix Returns: (quantized_weights, scale_factor) """ gamma = w.abs().mean().clamp(min=1e-5) w_scaled = w / gamma w_quant = w_scaled.clamp(-1, 1).round() # STE: detach the rounding, keep gradients on the latent weights w_quant = w + (w_quant * gamma - w).detach() return w_quant, gamma class BitLinear158(nn.Module): """ 1.58-bit Linear Layer from BitNet b1.58. Drop-in replacement for nn.Linear with: - Ternary weights {-1, 0, +1} via absmean quantization - 8-bit per-token activation quantization - Built-in RMSNorm (absorbs the preceding LayerNorm) - No bias (following BitNet b1.58 / LLaMA convention) - Full-precision latent weights maintained for training (STE) Forward pass: 1. RMSNorm the input 2. Quantize activations to 8-bit 3. Quantize weights to ternary 4. Matrix multiply (effectively integer addition) 5. Rescale output During training, gradients flow through quantization via the Straight-Through Estimator (STE) — the gradient of round() is treated as the identity function. 

    Reference: arXiv:2402.17764
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Full-precision latent weight (master copy for training)
        self.weight = nn.Parameter(torch.empty(out_features, in_features))

        # Built-in RMSNorm (replaces the preceding LayerNorm)
        self.rms_norm = RMSNorm(in_features)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Kaiming uniform initialization, same as nn.Linear."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (batch, seq_len, in_features)

        Returns:
            (batch, seq_len, out_features)
        """
        # 1. Normalize input (built-in RMSNorm)
        x = self.rms_norm(x)

        # 2. Quantize activations to 8-bit per-token
        x_q = activation_quant(x)

        # 3. Quantize weights to ternary {-1, 0, +1}
        #    (the gamma scale is already folded back into w_q)
        w_q, _ = weight_quant(self.weight)

        # 4. Matrix multiply with quantized weights and activations.
        #    In theory this is integer addition; in practice we use float
        #    for autograd compatibility during training.
        output = F.linear(x_q, w_q)
        return output

    def extra_repr(self) -> str:
        return f"in={self.in_features}, out={self.out_features}, bits=1.58"


class SwiGLU(nn.Module):
    """
    SwiGLU activation for Feed-Forward Networks.

        SwiGLU(x) = (Swish(xW1) ⊙ xV) W2

    Uses 3 linear projections instead of 2, but the hidden dim is typically
    reduced to 2/3 of the usual 4x expansion to keep the parameter count
    similar. When used with BitLinear158, all three projections are ternary.

    Reference: arXiv:2002.05202 (Shazeer, 2020)
    """

    def __init__(self, in_features: int, hidden_features: int = None,
                 use_bitlinear: bool = True):
        super().__init__()
        hidden_features = hidden_features or int(in_features * 8 / 3)  # 2/3 of a 4x expansion
        # Round up to the next multiple of 8 for efficiency
        hidden_features = ((hidden_features + 7) // 8) * 8

        Linear = BitLinear158 if use_bitlinear else nn.Linear
        self.w1 = Linear(in_features, hidden_features)  # gate projection
        self.v = Linear(in_features, hidden_features)   # value projection
        self.w2 = Linear(hidden_features, in_features)  # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w2(F.silu(self.w1(x)) * self.v(x))


# ──── Testing ────────────────────────────────────────────────────

if __name__ == "__main__":
    print("Testing BitLinear158 components...")

    # Test RMSNorm
    norm = RMSNorm(256)
    x = torch.randn(2, 10, 256)
    y = norm(x)
    print(f"RMSNorm: {x.shape} -> {y.shape}, mean={y.mean():.4f}, std={y.std():.4f}")

    # Test weight quantization
    w = torch.randn(512, 256)
    w_q, scale = weight_quant(w)
    w_ternary = (w_q.detach() / scale).round()  # recover the ternary codes {-1, 0, +1}
    unique = torch.unique(w_ternary)
    print(f"Weight quant: {w.shape}, unique values: {len(unique)}, scale: {scale:.4f}")
    print(f"  Ternary distribution: -1={(w_ternary == -1).sum().item()}, "
          f"0={(w_ternary == 0).sum().item()}, "
          f"+1={(w_ternary == 1).sum().item()}")

    # Test activation quantization
    a = torch.randn(2, 10, 256)
    a_q = activation_quant(a)
    print(f"Activation quant: range [{a_q.min():.2f}, {a_q.max():.2f}]")

    # Test BitLinear158
    layer = BitLinear158(256, 512)
    x = torch.randn(2, 10, 256)
    y = layer(x)
    print(f"BitLinear158: {x.shape} -> {y.shape}")

    # Test gradient flow (STE)
    loss = y.sum()
    loss.backward()
    assert layer.weight.grad is not None, "Gradient did not flow through STE!"
print(f"STE gradient flow: OK (grad norm: {layer.weight.grad.norm():.4f})") # Test SwiGLU swiglu = SwiGLU(256, use_bitlinear=True) x = torch.randn(2, 10, 256) y = swiglu(x) print(f"SwiGLU (BitLinear): {x.shape} -> {y.shape}") total = sum(p.numel() for p in swiglu.parameters()) print(f" SwiGLU params: {total:,}") # Parameter comparison ff_standard = nn.Sequential(nn.Linear(256, 512), nn.GELU(), nn.Linear(512, 256)) ff_params = sum(p.numel() for p in ff_standard.parameters()) print(f" Standard FFN params: {ff_params:,}") print(f" Ratio: {total / ff_params:.2f}x") print("\nAll tests passed! ✓")