""" BitLinear - Simplified for training stability. """ import torch import torch.nn as nn import torch.nn.functional as F class RMSNorm(nn.Module): """Root Mean Square Layer Normalization.""" def __init__(self, dim, eps=1e-6): super().__init__() self.eps = eps def forward(self, x): rms = torch.sqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) return (x / rms) class TernaryQuantize(torch.autograd.Function): """Ternary quantization with straight-through estimator.""" @staticmethod def forward(ctx, w): scale = 1.0 / w.abs().mean().clamp_(min=1e-5) u = (w * scale).round().clamp_(-1, 1) / scale return u @staticmethod def backward(ctx, grad_output): return grad_output class ActivationQuantize(torch.autograd.Function): """INT8 activation quantization.""" @staticmethod def forward(ctx, x): scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5) y = (x * scale).round().clamp_(-128, 127) / scale return y @staticmethod def backward(ctx, grad_output): return grad_output class BitLinear(nn.Linear): """ Linear layer with ternary weight quantization. No internal normalization - caller handles it (Pre-Norm architecture). """ def __init__(self, in_features, out_features, bias=True): super().__init__(in_features, out_features) # Gentler initialization for ternary stability nn.init.normal_(self.weight, mean=0.0, std=0.02) self.rmsnorm = RMSNorm(in_features) def forward(self, x): w = self.weight # a weight tensor with shape [d, k] x_norm = self.rmsnorm(x) # A trick for implementing Straight−Through−Estimator (STE) using detach() x_quant = x_norm + (ActivationQuantize.apply(x_norm) - x_norm).detach() w_quant = w + (TernaryQuantize.apply(w) - w).detach() y = F.linear(x_quant, w_quant) return self.rmsnorm(y) def get_inference_params(self): """Export for FPGA deployment.""" with torch.no_grad(): scale = self.weight.abs().mean(dim=-1, keepdim=True).clamp(min=1e-5) w_ternary = (self.weight / scale).round().clamp(-1, 1).to(torch.int8) return { 'weight_ternary': w_ternary, 'weight_scale': scale.squeeze() }