# luohoa97's picture
# Deploy BitNet-Transformer Trainer
# d5b7ee9 verified
"""BitLinear Layer — Ternary (1.58-bit) quantization for PyTorch.
Based on "The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits" (2024).
Weights are quantized to {-1, 0, 1} for extreme efficiency.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
def weight_quant(w: torch.Tensor, *, rescale: bool = False) -> torch.Tensor:
    """Quantize weights to {-1, 0, 1} using absmean scaling.

    The tensor is divided by the mean of its absolute values (floored at
    1e-5 so an all-zero tensor does not divide by zero), clipped to
    [-1, 1] and rounded to the nearest integer.

    Args:
        w: Full-precision (latent) weight tensor.
        rescale: When False (the default, and the historical behaviour of
            this function) the raw ternary values are returned, so the
            absmean scale is dropped from the result. When True the ternary
            values are multiplied back by the absmean scale, which keeps the
            quantized weights at the same magnitude as ``w``.

    Returns:
        A tensor shaped like ``w`` whose forward value is the quantized
        weight and whose gradient with respect to ``w`` is the identity.
    """
    scale = w.abs().mean().clamp(min=1e-5)
    w_q = torch.round(torch.clamp(w / scale, -1, 1))
    if rescale:
        w_q = w_q * scale
    # Straight-Through Estimator (STE): the detached correction carries the
    # quantized value in the forward pass while contributing no gradient, so
    # backward sees d(out)/d(w) == 1.
    return w + (w_q - w).detach()
def activation_quant(x: torch.Tensor, dim=None) -> torch.Tensor:
    """Quantize activations to 8-bit using absmax scaling.

    Values are scaled so the largest magnitude maps to 127, rounded to the
    nearest integer, then scaled back to the original range.

    Args:
        x: Activation tensor.
        dim: Dimension (int) over which the absolute maximum is taken.
            ``None`` (the default, and the historical behaviour) uses one
            maximum for the whole tensor. Passing e.g. ``-1`` computes a
            separate scale per slice along that dimension, so one outlier
            does not coarsen the quantization grid of every other slice.

    Returns:
        A tensor shaped like ``x`` whose forward value is the de-quantized
        8-bit approximation and whose gradient with respect to ``x`` is the
        identity.
    """
    # A zero-element tensor (e.g. an empty batch) has no maximum and the
    # full reduction below would raise; there is nothing to quantize anyway.
    if x.numel() == 0:
        return x

    if dim is None:
        abs_max = x.abs().max()
    else:
        abs_max = x.abs().amax(dim=dim, keepdim=True)

    # Floor the maximum so an all-zero input does not divide by zero.
    scale = 127.0 / abs_max.clamp(min=1e-5)
    x_q = torch.round(torch.clamp(x * scale, -128, 127))
    # Straight-Through Estimator (STE): forward uses the quantized value,
    # backward treats the whole operation as the identity.
    return x + (x_q / scale - x).detach()
class BitLinear(nn.Linear):
    """Drop-in ``nn.Linear`` whose forward pass runs on quantized operands.

    Construction and parameters are inherited unchanged from ``nn.Linear``;
    only ``forward`` differs. The stored weight stays in full precision and
    is re-quantized on every call.
    """

    def forward(self, x):
        """Apply the linear map using ternary weights and 8-bit activations.

        Args:
            x: Input tensor accepted by ``F.linear``.

        Returns:
            ``F.linear`` of the quantized input and quantized weight, plus
            the layer's (unquantized) bias.
        """
        # Ternarize the latent weights first, then fake-quantize the input
        # to the 8-bit grid; both helpers are straight-through for gradients.
        ternary_weight = weight_quant(self.weight)
        quantized_input = activation_quant(x)
        # Still an ordinary floating-point matmul here; a dedicated ternary
        # kernel could get by with additions and subtractions only.
        return F.linear(quantized_input, ternary_weight, self.bias)
class BitRMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (often used with BitNet).

    Each vector along the last dimension is divided by the root of its mean
    squared value and then multiplied by a learned per-feature gain.

    Args:
        dim: Size of the last (feature) dimension; the gain has this length.
        eps: Constant added to the mean square before the reciprocal square
            root so an all-zero vector does not divide by zero.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Learned gain, initialised to ones.
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        """Normalize ``x`` over its last dimension and apply the gain.

        Args:
            x: Input tensor whose last dimension broadcasts against the
                gain of length ``dim``.

        Returns:
            The normalized tensor multiplied by ``self.weight``.
        """
        # Squaring in float16 overflows to inf once |x| reaches 256, which
        # makes rsqrt return 0 and silently zeroes the whole vector; bfloat16
        # has the range but very few mantissa bits for the mean. Compute the
        # statistic in float32 for those dtypes and cast back afterwards so
        # the output dtype is the same as before.
        low_precision = x.dtype in (torch.float16, torch.bfloat16)
        h = x.float() if low_precision else x

        mean_sq = h.pow(2).mean(-1, keepdim=True)
        x_normed = h * torch.rsqrt(mean_sq + self.eps)

        if low_precision:
            x_normed = x_normed.to(x.dtype)
        return x_normed * self.weight