|
|
"""
|
|
|
BitLinear - Simplified for training stability.
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (no learnable gain).

    Divides each vector along the last dimension by its RMS value.
    NOTE(review): ``dim`` is accepted but unused — kept for signature
    parity with gain-carrying RMSNorm variants; confirm before removing.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        # Small constant under the sqrt to avoid division by zero.
        self.eps = eps

    def forward(self, x):
        # Mean of squares over the feature axis, then divide by its root.
        mean_sq = (x * x).mean(dim=-1, keepdim=True)
        return x / torch.sqrt(mean_sq + self.eps)
|
|
|
|
|
|
|
|
|
class TernaryQuantize(torch.autograd.Function):
    """Quantize a tensor to {-1, 0, +1} with a straight-through estimator.

    Uses a single per-tensor scale equal to the reciprocal of the mean
    absolute value (BitNet b1.58 style); the backward pass is identity.
    """

    @staticmethod
    def forward(ctx, w):
        # Per-tensor scale; the mean is clamped away from zero so the
        # reciprocal cannot blow up on an all-zero tensor.
        scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
        ternary = (w * scale).round().clamp_(-1, 1)
        # Dequantize back to the original magnitude range.
        return ternary / scale

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: pass gradients through unchanged.
        return grad_output
|
|
|
|
|
|
|
|
|
class ActivationQuantize(torch.autograd.Function):
    """Fake-quantize activations to INT8 with a straight-through estimator.

    Scaling is per-row (last dimension) absmax into [-128, 127]; the
    backward pass is identity.
    """

    @staticmethod
    def forward(ctx, x):
        # Per-row absmax, clamped away from zero to keep the scale finite.
        absmax = x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
        scale = 127.0 / absmax
        quantized = (x * scale).round().clamp_(-128, 127)
        # Dequantize so downstream math stays in float.
        return quantized / scale

    @staticmethod
    def backward(ctx, grad_output):
        # Straight-through estimator: identity gradient.
        return grad_output
|
|
|
|
|
|
|
|
|
class BitLinear(nn.Linear):
    """Linear layer with ternary (1.58-bit) weight quantization.

    Forward pass: RMS-normalize the input, fake-quantize activations to
    INT8 and weights to {-1, 0, +1} (both via straight-through
    estimators), apply the linear map, then RMS-normalize the output.

    Parameters
    ----------
    in_features : int
        Size of each input sample.
    out_features : int
        Size of each output sample.
    bias : bool
        Whether to learn an additive bias (default ``True``).
    """

    def __init__(self, in_features, out_features, bias=True):
        # BUG FIX: ``bias`` was previously ignored (a bias tensor was
        # always allocated regardless of the argument); now honored.
        super().__init__(in_features, out_features, bias=bias)

        nn.init.normal_(self.weight, mean=0.0, std=0.02)
        self.rmsnorm = RMSNorm(in_features)

    def forward(self, x):
        w = self.weight
        x_norm = self.rmsnorm(x)

        # Straight-through estimators: quantized values in the forward
        # pass, identity gradients flowing to x_norm / w in the backward.
        x_quant = x_norm + (ActivationQuantize.apply(x_norm) - x_norm).detach()
        w_quant = w + (TernaryQuantize.apply(w) - w).detach()
        # BUG FIX: self.bias was previously dropped here, so the layer's
        # bias (when present) never affected the output.
        y = F.linear(x_quant, w_quant, self.bias)

        # NOTE(review): the output is RMS-normalized with the same module
        # as the input; this is only shape-safe because this file's
        # RMSNorm carries no per-dimension learnable gain.
        return self.rmsnorm(y)

    def get_inference_params(self):
        """Export ternary weights and scale for FPGA deployment.

        Returns
        -------
        dict
            ``weight_ternary`` — int8 tensor with values in {-1, 0, 1};
            ``weight_scale`` — scalar such that
            ``weight_ternary * weight_scale`` reproduces the dequantized
            weights used by the training-time forward pass;
            ``bias`` — only present when the layer has a bias.
        """
        with torch.no_grad():
            # CONSISTENCY FIX: use the same per-tensor absmean scale as
            # TernaryQuantize. The old code used a per-row mean
            # (dim=-1), so exported weights did not match the weights the
            # network was actually trained with.
            scale = self.weight.abs().mean().clamp(min=1e-5)
            w_ternary = (self.weight / scale).round().clamp(-1, 1).to(torch.int8)

            params = {
                'weight_ternary': w_ternary,
                'weight_scale': scale,
            }
            if self.bias is not None:
                params['bias'] = self.bias.detach().clone()
            return params