Spaces:
Sleeping
Sleeping
| """ | |
| PixelArtGen β BitLinear 1.58-bit Layer & RMSNorm | |
| Implementation of the core BitNet b1.58 components: | |
| - RMSNorm: Root Mean Square Layer Normalization (Zhang & Sennrich, 2019) | |
| - BitLinear158: 1.58-bit linear layer with ternary weights {-1, 0, +1} | |
| References: | |
| - "The Era of 1-bit LLMs" (Ma et al., 2024) β arXiv:2402.17764 | |
| - "BitNet: Scaling 1-bit Transformers" (Wang et al., 2023) β arXiv:2310.11453 | |
| - "RMSNorm" (Zhang & Sennrich, 2019) β arXiv:1910.07467 | |
| """ | |
import math
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
class RMSNorm(nn.Module):
    """
    Root Mean Square Layer Normalization.

    A lighter alternative to LayerNorm: no mean subtraction, only division
    by the root-mean-square of the last dimension, followed by a learnable
    per-feature gain ``g``:

        RMSNorm(x) = x / sqrt(mean(x^2, dim=-1) + eps) * g

    The statistic is computed in float32 and the normalized tensor is cast
    back to the input dtype before the gain is applied.

    Reference: arXiv:1910.07467
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # Learnable gain g (one entry per feature); starts as the identity.
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        mean_square = x.pow(2).mean(dim=-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, then restore dtype.
        normed = self._norm(x.float()).type_as(x)
        return normed * self.weight
def activation_quant(x: torch.Tensor, num_bits: int = 8) -> torch.Tensor:
    """
    Per-token absmax activation fake-quantization from BitNet b1.58.

    Each token (slice along the last dimension) is scaled by its own
    absolute maximum so that it spans the symmetric integer range
    [-Qb, Qb], with Qb = 2^(num_bits - 1) - 1 (127 for the default 8 bits).
    Quantization is symmetric (no zero-point), as specified in the paper.
    The integer values are immediately de-quantized back to the original
    scale, so the returned tensor holds float values lying on the grid.

    Gradients use the Straight-Through Estimator (STE): the forward pass
    sees the quantized values; the backward pass treats quantization as
    the identity.

    Args:
        x: (..., d_model) float tensor.
        num_bits: Bit width of the signed integer grid (default 8).

    Returns:
        De-quantized activations with the same shape and dtype as ``x``.
        The per-token scale is NOT returned.

    Raises:
        ValueError: if ``num_bits`` < 2 (Qb would be 0).
    """
    if num_bits < 2:
        raise ValueError(f"num_bits must be >= 2, got {num_bits}")
    Qb = 2 ** (num_bits - 1) - 1  # 127 for 8-bit signed
    # Per-token absmax; the clamp avoids dividing by zero on all-zero tokens.
    scale = x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    x_quant = (x * Qb / scale).clamp(-Qb, Qb).round()
    # STE: forward value is the de-quantized grid point, gradient is identity.
    return x + (x_quant * scale / Qb - x).detach()
def weight_quant(w: torch.Tensor, eps: float = 1e-5) -> tuple:
    """
    Absmean ternary weight fake-quantization from BitNet b1.58.

    1. gamma = mean(|W|) over the whole tensor (clamped to at least ``eps``).
    2. Scale: W_scaled = W / gamma.
    3. Clamp to [-1, 1] and round to the nearest of {-1, 0, +1}.
    4. De-quantize: multiply the ternary codes back by gamma.

    The returned weights therefore take only the values {-gamma, 0, +gamma};
    divide them by the returned ``gamma`` to recover the ternary codes.
    ``torch.round`` rounds halves to even, so |W_scaled| == 0.5 maps to 0.

    Gradients reach the latent full-precision ``w`` through the
    Straight-Through Estimator (identity w.r.t. ``w``).

    Args:
        w: (out_features, in_features) weight matrix (the math is
            shape-agnostic; gamma is a single per-tensor scale).
        eps: Lower bound on gamma, avoiding division by zero for an
            all-zero tensor.

    Returns:
        (w_quant, gamma): de-quantized weights with the same shape as ``w``,
        and the 0-dim absmean scale.
    """
    gamma = w.abs().mean().clamp(min=eps)
    w_ternary = (w / gamma).clamp(-1, 1).round()
    # STE: forward uses gamma * ternary code, backward is identity w.r.t. w.
    w_quant = w + (w_ternary * gamma - w).detach()
    return w_quant, gamma
class BitLinear158(nn.Module):
    """
    1.58-bit Linear Layer from BitNet b1.58.

    Replacement for nn.Linear with:
      - Ternary weights {-1, 0, +1} (times a per-tensor absmean scale)
      - 8-bit per-token activation quantization
      - Built-in RMSNorm (absorbs the preceding LayerNorm)
      - No bias (following BitNet b1.58 / LLaMA convention)
      - Full-precision latent weights maintained for training (STE)

    Forward pass:
      1. RMSNorm the input
      2. Fake-quantize activations to 8 bits (per token, absmax)
      3. Fake-quantize weights to ternary (absmean)
      4. Matrix multiply

    Both quantizers return de-quantized tensors (codes already multiplied
    back by their scales), so the matmul output is already on the original
    scale and no separate rescaling step is needed.

    During training, gradients flow through quantization via the
    Straight-Through Estimator (STE) — the gradient of round()
    is treated as the identity function.

    Reference: arXiv:2402.17764
    """

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Full-precision latent weight (master copy for training)
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        # Built-in RMSNorm (replaces the preceding LayerNorm)
        self.rms_norm = RMSNorm(in_features)
        self._init_weights()

    def _init_weights(self):
        """Kaiming uniform initialization, same as nn.Linear."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (..., in_features), e.g. (batch, seq_len, in_features).
        Returns:
            (..., out_features)
        """
        # 1. Normalize input (built-in RMSNorm)
        x = self.rms_norm(x)
        # 2. Fake-quantize activations to 8-bit per token
        x_q = activation_quant(x)
        # 3. Fake-quantize weights to ternary; the scale is already folded in
        w_q, _ = weight_quant(self.weight)
        # 4. In theory this is integer addition; here it is simulated in
        #    floating point so autograd works during training.
        return F.linear(x_q, w_q)

    def extra_repr(self) -> str:
        return f"in={self.in_features}, out={self.out_features}, bits=1.58"
class SwiGLU(nn.Module):
    """
    SwiGLU feed-forward block.

        SwiGLU(x) = (Swish(x W1) ⊙ x V) W2

    where ⊙ is element-wise multiplication and Swish is SiLU. It uses three
    projections instead of the two in a classic FFN, so the default hidden
    width is 2/3 of a 4x expansion (8/3 * in_features), keeping the
    parameter count close to that of a 4x FFN.

    With ``use_bitlinear=True`` all three projections are BitLinear158
    (ternary weights); otherwise they are ordinary ``nn.Linear`` layers.

    Args:
        in_features: Model width (input and output dimension).
        hidden_features: Hidden width before rounding; ``None`` selects
            ``in_features * 8 // 3``.
        use_bitlinear: Use BitLinear158 instead of nn.Linear.
        multiple_of: The hidden width is rounded UP to a multiple of this
            value (default 8).

    Raises:
        ValueError: if ``hidden_features`` or ``multiple_of`` is not positive.

    Reference: arXiv:2002.05202 (Shazeer, 2020)
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        use_bitlinear: bool = True,
        multiple_of: int = 8,
    ):
        super().__init__()
        if multiple_of < 1:
            raise ValueError(f"multiple_of must be positive, got {multiple_of}")
        if hidden_features is None:
            # 2/3 of a 4x expansion (integer form of int(in_features * 8 / 3))
            hidden_features = in_features * 8 // 3
        elif hidden_features < 1:
            raise ValueError(f"hidden_features must be positive, got {hidden_features}")
        # Round UP to the next multiple of `multiple_of` for efficiency
        hidden_features = ((hidden_features + multiple_of - 1) // multiple_of) * multiple_of

        Linear = BitLinear158 if use_bitlinear else nn.Linear
        self.w1 = Linear(in_features, hidden_features)  # gate projection
        self.v = Linear(in_features, hidden_features)   # value projection
        self.w2 = Linear(hidden_features, in_features)  # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = F.silu(self.w1(x))
        value = self.v(x)
        return self.w2(gate * value)
# ──── Testing ────────────────────────────────────────────────────
if __name__ == "__main__":
    print("Testing BitLinear158 components...")

    # Test RMSNorm
    norm = RMSNorm(256)
    x = torch.randn(2, 10, 256)
    y = norm(x)
    print(f"RMSNorm: {x.shape} -> {y.shape}, mean={y.mean():.4f}, std={y.std():.4f}")

    # Test weight quantization.
    # weight_quant returns de-quantized weights ({-scale, 0, +scale}), so
    # divide by the scale to recover the ternary codes before counting them;
    # rounding w_q directly only works when scale happens to be near 1.
    w = torch.randn(512, 256)
    w_q, scale = weight_quant(w)
    unique = torch.unique(w_q.detach())
    print(f"Weight quant: {w.shape}, unique values: {len(unique)}, scale: {scale:.4f}")
    codes = (w_q.detach() / scale).round()
    print(f" Ternary distribution: -1={((codes == -1).sum().item())}, "
          f"0={((codes == 0).sum().item())}, "
          f"+1={((codes == 1).sum().item())}")

    # Test activation quantization
    a = torch.randn(2, 10, 256)
    a_q = activation_quant(a)
    print(f"Activation quant: range [{a_q.min():.2f}, {a_q.max():.2f}]")

    # Test BitLinear158
    layer = BitLinear158(256, 512)
    x = torch.randn(2, 10, 256)
    y = layer(x)
    print(f"BitLinear158: {x.shape} -> {y.shape}")

    # Test gradient flow (STE)
    loss = y.sum()
    loss.backward()
    assert layer.weight.grad is not None, "Gradient did not flow through STE!"
    print(f"STE gradient flow: OK (grad norm: {layer.weight.grad.norm():.4f})")

    # Test SwiGLU
    swiglu = SwiGLU(256, use_bitlinear=True)
    x = torch.randn(2, 10, 256)
    y = swiglu(x)
    print(f"SwiGLU (BitLinear): {x.shape} -> {y.shape}")
    total = sum(p.numel() for p in swiglu.parameters())
    print(f" SwiGLU params: {total:,}")

    # Parameter comparison against a classic 4x-expansion FFN: SwiGLU's
    # default hidden width (2/3 of 4x) is sized to match that baseline,
    # so the ratio below should be close to 1.
    d_model = 256
    ff_standard = nn.Sequential(
        nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
    )
    ff_params = sum(p.numel() for p in ff_standard.parameters())
    print(f" Standard FFN params: {ff_params:,}")
    print(f" Ratio: {total / ff_params:.2f}x")

    print("\nAll tests passed! ✓")