# TaoNet-pico-T1 / bitlinear.py
# Uploaded to the HuggingFace Hub by Lobakkang (commit 2981407, verified):
# "Upload TaoNet model to HuggingFace Hub"
"""
BitLinear - Simplified for training stability.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class RMSNorm(nn.Module):
    """Parameter-free Root Mean Square layer normalization.

    Divides the input by the RMS of its last dimension. Unlike the usual
    RMSNorm there is no learned gain, so `dim` is accepted only for
    interface compatibility and is otherwise unused.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        # `dim` is intentionally not stored: the norm is computed over the
        # last axis of whatever tensor is passed in.
        self.eps = eps

    def forward(self, x):
        # Mean of squared entries over the last dim, with eps to keep the
        # sqrt away from zero for all-zero rows.
        mean_sq = torch.mean(torch.square(x), dim=-1, keepdim=True)
        return x / torch.sqrt(mean_sq + self.eps)
class TernaryQuantize(torch.autograd.Function):
    """Fake-quantize weights to {-1, 0, +1} * scale with an STE backward.

    Forward follows the BitNet-b1.58 recipe: scale by the reciprocal mean
    absolute weight, round, clamp to the ternary range, then rescale back
    to float. Backward is the identity (straight-through estimator).
    """

    @staticmethod
    def forward(ctx, w):
        # Per-tensor reciprocal scale; the clamp floors mean|w| at 1e-5 so
        # an all-zero tensor cannot produce a division by zero.
        recip = 1.0 / w.abs().mean().clamp(min=1e-5)
        ternary = (w * recip).round().clamp(-1, 1)
        return ternary / recip

    @staticmethod
    def backward(ctx, grad_output):
        # STE: pretend the quantizer was the identity.
        return grad_output
class ActivationQuantize(torch.autograd.Function):
    """Fake-quantize activations to signed INT8 with an STE backward.

    Uses a per-row (last-dimension) absmax scale mapping the largest
    magnitude to 127, rounds, clamps to [-128, 127], and rescales back to
    float. Backward is the identity (straight-through estimator).
    """

    @staticmethod
    def forward(ctx, x):
        # Absmax over the last dim, floored at 1e-5 so all-zero rows do
        # not divide by zero.
        row_max = x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
        scale = 127.0 / row_max
        return (x * scale).round().clamp(-128, 127) / scale

    @staticmethod
    def backward(ctx, grad_output):
        # STE: pretend the quantizer was the identity.
        return grad_output
class BitLinear(nn.Linear):
    """Linear layer with ternary weight and INT8 activation quantization.

    The input is RMS-normalized before quantization and the output is
    RMS-normalized again before being returned, so this layer performs
    its own normalization. Quantization is simulated in float during
    training; gradients flow through the rounding via the detach-based
    straight-through estimator (STE).
    """

    def __init__(self, in_features, out_features, bias=True):
        # BUG FIX: `bias` was accepted but never forwarded, so a bias
        # parameter was unconditionally created. Honor the flag.
        super().__init__(in_features, out_features, bias=bias)
        # Gentler initialization for ternary stability.
        nn.init.normal_(self.weight, mean=0.0, std=0.02)
        # Parameter-free RMSNorm normalizes over the last dim of any
        # tensor, so the same instance is reused for the input
        # (in_features) and the output (out_features).
        self.rmsnorm = RMSNorm(in_features)

    def forward(self, x):
        w = self.weight  # [out_features, in_features]
        x_norm = self.rmsnorm(x)
        # STE via detach(): the forward pass sees the quantized values,
        # the backward pass treats quantization as the identity.
        x_quant = x_norm + (ActivationQuantize.apply(x_norm) - x_norm).detach()
        w_quant = w + (TernaryQuantize.apply(w) - w).detach()
        # BUG FIX: self.bias was previously dropped here, so it was never
        # applied and never received gradients.
        y = F.linear(x_quant, w_quant, self.bias)
        return self.rmsnorm(y)

    def get_inference_params(self):
        """Export ternary weights and their scale for FPGA deployment.

        Uses the same per-tensor scale as the training-time
        TernaryQuantize (mean |w| over the whole tensor), so the exported
        integers reproduce training behavior. NOTE(review): this replaces
        an earlier per-row export scale that did not match the trainer;
        'weight_scale' is now a 0-dim tensor — confirm downstream FPGA
        tooling against this shape.
        """
        with torch.no_grad():
            scale = self.weight.abs().mean().clamp(min=1e-5)  # scalar
            w_ternary = (self.weight / scale).round().clamp(-1, 1).to(torch.int8)
            return {
                'weight_ternary': w_ternary,
                'weight_scale': scale
            }